Skip to content

Commit 0fab92f

Browse files
authored
[INS-170] Unify JDBC URL parsing across detectors and analyzers (#4574)
* publicize jdbc url parsing methods for mysql * publicize jdbc url parsing methods for postgresql * publicize jdbc url parsing methods for sqlserver * make postgres consistent with others * keep default user as postgres * use same connectioninfo struct for all handlers
1 parent 05cccb5 commit 0fab92f

File tree

9 files changed

+138
-93
lines changed

9 files changed

+138
-93
lines changed

pkg/detectors/jdbc/jdbc.go

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,16 +207,25 @@ func tryRedactRegex(conn string) (string, bool) {
207207
}
208208

209209
var supportedSubprotocols = map[string]func(logContext.Context, string) (jdbc, error){
210-
"mysql": parseMySQL,
211-
"postgresql": parsePostgres,
212-
"sqlserver": parseSqlServer,
210+
"mysql": ParseMySQL,
211+
"postgresql": ParsePostgres,
212+
"sqlserver": ParseSqlServer,
213213
}
214214

215215
type pingResult struct {
216216
err error
217217
determinate bool
218218
}
219219

220+
// ConnectionInfo holds parsed connection information
221+
type ConnectionInfo struct {
222+
Host string // includes port if specified, e.g., "host:port"
223+
Database string
224+
User string
225+
Password string
226+
Params map[string]string
227+
}
228+
220229
type jdbc interface {
221230
ping(context.Context) pingResult
222231
}

pkg/detectors/jdbc/mysql.go

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,30 @@ import (
1212
"github.com/go-sql-driver/mysql"
1313
)
1414

15-
type mysqlJDBC struct {
16-
conn string
17-
userPass string
18-
host string
19-
params string
15+
type MysqlJDBC struct {
16+
ConnectionInfo
2017
}
2118

22-
func (s *mysqlJDBC) ping(ctx context.Context) pingResult {
19+
func (s *MysqlJDBC) ping(ctx context.Context) pingResult {
2320
return ping(ctx, "mysql", isMySQLErrorDeterminate,
24-
buildMySQLConnectionString(s.host, "", s.userPass, s.params))
21+
BuildMySQLConnectionString(s.Host, "", s.User, s.Password, s.Params))
2522
}
2623

27-
func buildMySQLConnectionString(host, database, userPass, params string) string {
24+
func BuildMySQLConnectionString(host, database, user, password string, params map[string]string) string {
2825
conn := host + "/" + database
26+
userPass := user
27+
if password != "" {
28+
userPass = userPass + ":" + password
29+
}
2930
if userPass != "" {
3031
conn = userPass + "@" + conn
3132
}
32-
if params != "" {
33-
conn = conn + "?" + params
33+
if len(params) > 0 {
34+
var paramList []string
35+
for k, v := range params {
36+
paramList = append(paramList, fmt.Sprintf("%s=%s", k, v))
37+
}
38+
conn = conn + "?" + strings.Join(paramList, "&")
3439
}
3540
return conn
3641
}
@@ -51,7 +56,7 @@ func isMySQLErrorDeterminate(err error) bool {
5156
return false
5257
}
5358

54-
func parseMySQL(ctx logContext.Context, subname string) (jdbc, error) {
59+
func ParseMySQL(ctx logContext.Context, subname string) (jdbc, error) {
5560
// expected form: [subprotocol:]//[user:password@]HOST[/DB][?key=val[&key=val]]
5661
if !strings.HasPrefix(subname, "//") {
5762
return nil, errors.New("expected host to start with //")
@@ -70,11 +75,14 @@ func parseMySQL(ctx logContext.Context, subname string) (jdbc, error) {
7075
Info("Skipping invalid MySQL URL - no password or host found")
7176
return nil, fmt.Errorf("missing host or password in connection string")
7277
}
73-
return &mysqlJDBC{
74-
conn: subname[2:],
75-
userPass: cfg.User + ":" + cfg.Passwd,
76-
host: fmt.Sprintf("tcp(%s)", cfg.Addr),
77-
params: "timeout=5s",
78+
return &MysqlJDBC{
79+
ConnectionInfo: ConnectionInfo{
80+
User: cfg.User,
81+
Password: cfg.Passwd,
82+
Host: fmt.Sprintf("tcp(%s)", cfg.Addr),
83+
Params: map[string]string{"timeout": "5s"},
84+
Database: cfg.DBName,
85+
},
7886
}, nil
7987
}
8088

@@ -107,13 +115,20 @@ func parseMySQLURI(ctx logContext.Context, subname string) (jdbc, error) {
107115
return nil, fmt.Errorf("missing host or password in connection string")
108116
}
109117

110-
userAndPass := user + ":" + pass
118+
// Parse database name
119+
dbName := strings.TrimPrefix(u.Path, "/")
120+
if dbName == "" {
121+
dbName = "mysql" // default DB
122+
}
111123

112-
return &mysqlJDBC{
113-
conn: subname[2:],
114-
userPass: userAndPass,
115-
host: fmt.Sprintf("tcp(%s)", u.Host),
116-
params: "timeout=5s",
124+
return &MysqlJDBC{
125+
ConnectionInfo: ConnectionInfo{
126+
User: user,
127+
Password: pass,
128+
Host: fmt.Sprintf("tcp(%s)", u.Host),
129+
Params: map[string]string{"timeout": "5s"},
130+
Database: dbName,
131+
},
117132
}, nil
118133

119134
}

pkg/detectors/jdbc/mysql_integration_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ import (
1212
"github.com/google/go-cmp/cmp"
1313
"github.com/stretchr/testify/assert"
1414
"github.com/testcontainers/testcontainers-go/modules/mysql"
15+
16+
logContext "github.com/trufflesecurity/trufflehog/v3/pkg/context"
1517
)
1618

1719
func TestMySQL(t *testing.T) {
@@ -89,7 +91,7 @@ func TestMySQL(t *testing.T) {
8991
}
9092
for _, tt := range tests {
9193
t.Run(tt.input, func(t *testing.T) {
92-
j, err := parseMySQL(tt.input)
94+
j, err := ParseMySQL(logContext.Background(), tt.input)
9395

9496
if err != nil {
9597
got := result{ParseErr: true}

pkg/detectors/jdbc/mysql_test.go

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package jdbc
22

33
import (
44
"context"
5-
"strings"
65
"testing"
76

87
logContext "github.com/trufflesecurity/trufflehog/v3/pkg/context"
@@ -59,7 +58,7 @@ func TestParseMySQLMissingCredentials(t *testing.T) {
5958
for _, tt := range tests {
6059
t.Run(tt.name, func(t *testing.T) {
6160
ctx := logContext.AddLogger(context.Background())
62-
j, err := parseMySQL(ctx, tt.subname)
61+
j, err := ParseMySQL(ctx, tt.subname)
6362

6463
if tt.shouldBeNil {
6564
if j != nil {
@@ -103,15 +102,15 @@ func TestParseMySQLUsernameRecognition(t *testing.T) {
103102
for _, tt := range tests {
104103
t.Run(tt.name, func(t *testing.T) {
105104
ctx := logContext.AddLogger(context.Background())
106-
j, err := parseMySQL(ctx, tt.subname)
105+
j, err := ParseMySQL(ctx, tt.subname)
107106
if err != nil {
108107
t.Fatalf("parseMySQL() error = %v", err)
109108
}
110109

111-
mysqlConn := j.(*mysqlJDBC)
112-
if !strings.Contains(mysqlConn.userPass, tt.wantUsername) {
110+
mysqlConn := j.(*MysqlJDBC)
111+
if mysqlConn.User != tt.wantUsername {
113112
t.Errorf("Connection string does not contain expected username '%s'\nGot: %s\nExpected: %s",
114-
tt.wantUsername, mysqlConn.userPass, tt.wantUsername)
113+
tt.wantUsername, mysqlConn.User, tt.wantUsername)
115114
}
116115
})
117116
}

pkg/detectors/jdbc/postgres.go

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,19 @@ import (
1212
"github.com/lib/pq"
1313
)
1414

15-
type postgresJDBC struct {
16-
conn string
17-
params map[string]string
15+
type PostgresJDBC struct {
16+
ConnectionInfo
1817
}
1918

20-
func (s *postgresJDBC) ping(ctx context.Context) pingResult {
19+
func (s *PostgresJDBC) ping(ctx context.Context) pingResult {
2120
// It is crucial that we try to build a connection string ourselves before using the one we found. This is because
2221
// if the found connection string doesn't include a username, the driver will attempt to connect using the current
2322
// user's name, which will fail in a way that looks like a determinate failure, thus terminating the waterfall. In
2423
// contrast, when we build a connection string ourselves, if there's no username, we try 'postgres' instead, which
2524
// actually has a chance of working.
2625
return ping(ctx, "postgres", isPostgresErrorDeterminate,
27-
buildPostgresConnectionString(s.params, true),
28-
buildPostgresConnectionString(s.params, false),
26+
BuildPostgresConnectionString(s.Host, s.User, s.Password, "postgres", s.Params, true),
27+
BuildPostgresConnectionString(s.Host, s.User, s.Password, "postgres", s.Params, false),
2928
)
3029
}
3130

@@ -59,7 +58,7 @@ func joinKeyValues(m map[string]string, sep string) string {
5958
return strings.Join(data, sep)
6059
}
6160

62-
func parsePostgres(ctx logContext.Context, subname string) (jdbc, error) {
61+
func ParsePostgres(ctx logContext.Context, subname string) (jdbc, error) {
6362
// expected form: [subprotocol:]//[user:password@]HOST[/DB][?key=val[&key=val]]
6463

6564
if !strings.HasPrefix(subname, "//") {
@@ -77,16 +76,22 @@ func parsePostgres(ctx logContext.Context, subname string) (jdbc, error) {
7776
}
7877

7978
params := map[string]string{
80-
"host": u.Host,
81-
"dbname": dbName,
8279
"connect_timeout": "5",
8380
}
8481

82+
postgresJDBC := &PostgresJDBC{
83+
ConnectionInfo: ConnectionInfo{
84+
Host: u.Host,
85+
Database: dbName,
86+
Params: params,
87+
},
88+
}
89+
8590
if u.User != nil {
86-
params["user"] = u.User.Username()
91+
postgresJDBC.User = u.User.Username()
8792
pass, set := u.User.Password()
8893
if set {
89-
params["password"] = pass
94+
postgresJDBC.Password = pass
9095
}
9196
}
9297

@@ -95,46 +100,51 @@ func parsePostgres(ctx logContext.Context, subname string) (jdbc, error) {
95100
// https://www.postgresql.org/docs/current/libpq-ssl.html#LIBPQ-SSL-PROTECTION
96101
case "disable", "allow", "prefer",
97102
"require", "verify-ca", "verify-full":
98-
params["sslmode"] = v[0]
103+
postgresJDBC.Params["sslmode"] = v[0]
99104
}
100105
}
101106

102107
if v := u.Query().Get("user"); v != "" {
103-
params["user"] = v
108+
postgresJDBC.User = v
104109
}
105110

106111
if v := u.Query().Get("password"); v != "" {
107-
params["password"] = v
112+
postgresJDBC.Password = v
108113
}
109114

110-
if params["host"] == "" || params["password"] == "" {
115+
if postgresJDBC.Host == "" || postgresJDBC.Password == "" {
111116
ctx.Logger().WithName("jdbc").
112117
V(2).
113118
Info("Skipping invalid Postgres URL - no password or host found")
114119
return nil, fmt.Errorf("missing host or password in connection string")
115120
}
116121

117-
return &postgresJDBC{subname[2:], params}, nil
122+
return postgresJDBC, nil
118123
}
119124

120-
func buildPostgresConnectionString(params map[string]string, includeDbName bool) string {
125+
func BuildPostgresConnectionString(host string, user string, password string, dbName string, params map[string]string, includeDbName bool) string {
121126
data := map[string]string{
122127
// default user
123-
"user": "postgres",
128+
"user": "postgres",
129+
"password": password,
130+
"host": host,
131+
}
132+
if user != "" {
133+
data["user"] = user
134+
}
135+
if h, p, ok := strings.Cut(host, ":"); ok {
136+
data["host"] = h
137+
data["port"] = p
124138
}
125139
for key, val := range params {
126-
if key == "host" {
127-
if h, p, found := strings.Cut(val, ":"); found {
128-
data["host"] = h
129-
data["port"] = p
130-
continue
131-
}
132-
}
133140
data[key] = val
134141
}
135142

136-
if !includeDbName {
143+
if includeDbName {
137144
data["dbname"] = "postgres"
145+
if dbName != "" {
146+
data["dbname"] = dbName
147+
}
138148
}
139149

140150
connStr := joinKeyValues(data, " ")

pkg/detectors/jdbc/postgres_integration_test.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ import (
1515
"github.com/testcontainers/testcontainers-go"
1616
"github.com/testcontainers/testcontainers-go/modules/postgres"
1717
"github.com/testcontainers/testcontainers-go/wait"
18+
19+
logContext "github.com/trufflesecurity/trufflehog/v3/pkg/context"
1820
)
1921

2022
func TestPostgres(t *testing.T) {
@@ -119,7 +121,7 @@ func TestPostgres(t *testing.T) {
119121

120122
for _, tt := range tests {
121123
t.Run(tt.name, func(t *testing.T) {
122-
j, err := parsePostgres(tt.input)
124+
j, err := ParsePostgres(logContext.Background(), tt.input)
123125
if err != nil {
124126
got := result{ParseErr: true}
125127

pkg/detectors/jdbc/postgres_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func TestParsePostgresMissingCredentials(t *testing.T) {
4747
for _, tt := range tests {
4848
t.Run(tt.name, func(t *testing.T) {
4949
ctx := logContext.AddLogger(context.Background())
50-
j, err := parsePostgres(ctx, tt.subname)
50+
j, err := ParsePostgres(ctx, tt.subname)
5151

5252
if tt.shouldBeNil {
5353
if j != nil {
@@ -86,14 +86,14 @@ func TestParsePostgresUsernameRecognition(t *testing.T) {
8686
for _, tt := range tests {
8787
t.Run(tt.name, func(t *testing.T) {
8888
ctx := logContext.AddLogger(context.Background())
89-
j, err := parsePostgres(ctx, tt.subname)
89+
j, err := ParsePostgres(ctx, tt.subname)
9090
if err != nil {
91-
t.Fatalf("parsePostgres() error = %v", err)
91+
t.Fatalf("ParsePostgres() error = %v", err)
9292
}
9393

94-
pgConn := j.(*postgresJDBC)
95-
if pgConn.params["user"] != tt.wantUsername {
96-
t.Errorf("expected username '%s', got '%s'", tt.wantUsername, pgConn.params["user"])
94+
pgConn := j.(*PostgresJDBC)
95+
if pgConn.User != tt.wantUsername {
96+
t.Errorf("expected username '%s', got '%s'", tt.wantUsername, pgConn.User)
9797
}
9898
})
9999
}

0 commit comments

Comments
 (0)