Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pkg/detectors/jdbc/jdbc.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,9 @@ func tryRedactRegex(conn string) (string, bool) {
}

var supportedSubprotocols = map[string]func(logContext.Context, string) (jdbc, error){
"mysql": parseMySQL,
"postgresql": parsePostgres,
"sqlserver": parseSqlServer,
"mysql": ParseMySQL,
"postgresql": ParsePostgres,
"sqlserver": ParseSqlServer,
}

type pingResult struct {
Expand Down
54 changes: 34 additions & 20 deletions pkg/detectors/jdbc/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,26 @@ import (
"github.com/go-sql-driver/mysql"
)

type mysqlJDBC struct {
conn string
userPass string
host string
params string
type MysqlJDBC struct {
Conn string
User string
Password string
Host string
Params string
Database string
}

func (s *mysqlJDBC) ping(ctx context.Context) pingResult {
func (s *MysqlJDBC) ping(ctx context.Context) pingResult {
return ping(ctx, "mysql", isMySQLErrorDeterminate,
buildMySQLConnectionString(s.host, "", s.userPass, s.params))
BuildMySQLConnectionString(s.Host, "", s.User, s.Password, s.Params))
}

func buildMySQLConnectionString(host, database, userPass, params string) string {
func BuildMySQLConnectionString(host, database, user, password, params string) string {
conn := host + "/" + database
userPass := user
if password != "" {
userPass = userPass + ":" + password
}
if userPass != "" {
conn = userPass + "@" + conn
}
Expand All @@ -51,7 +57,7 @@ func isMySQLErrorDeterminate(err error) bool {
return false
}

func parseMySQL(ctx logContext.Context, subname string) (jdbc, error) {
func ParseMySQL(ctx logContext.Context, subname string) (jdbc, error) {
// expected form: [subprotocol:]//[user:password@]HOST[/DB][?key=val[&key=val]]
if !strings.HasPrefix(subname, "//") {
return nil, errors.New("expected host to start with //")
Expand All @@ -70,11 +76,13 @@ func parseMySQL(ctx logContext.Context, subname string) (jdbc, error) {
Info("Skipping invalid MySQL URL - no password or host found")
return nil, fmt.Errorf("missing host or password in connection string")
}
return &mysqlJDBC{
conn: subname[2:],
userPass: cfg.User + ":" + cfg.Passwd,
host: fmt.Sprintf("tcp(%s)", cfg.Addr),
params: "timeout=5s",
return &MysqlJDBC{
Conn: subname[2:],
User: cfg.User,
Password: cfg.Passwd,
Host: fmt.Sprintf("tcp(%s)", cfg.Addr),
Params: "timeout=5s",
Database: cfg.DBName,
}, nil
}

Expand Down Expand Up @@ -107,13 +115,19 @@ func parseMySQLURI(ctx logContext.Context, subname string) (jdbc, error) {
return nil, fmt.Errorf("missing host or password in connection string")
}

userAndPass := user + ":" + pass
// Parse database name
dbName := strings.TrimPrefix(u.Path, "/")
if dbName == "" {
dbName = "mysql" // default DB
}

return &mysqlJDBC{
conn: subname[2:],
userPass: userAndPass,
host: fmt.Sprintf("tcp(%s)", u.Host),
params: "timeout=5s",
return &MysqlJDBC{
Conn: subname[2:],
User: user,
Password: pass,
Host: fmt.Sprintf("tcp(%s)", u.Host),
Params: "timeout=5s",
Database: dbName,
}, nil

}
4 changes: 3 additions & 1 deletion pkg/detectors/jdbc/mysql_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"github.com/testcontainers/testcontainers-go/modules/mysql"

logContext "github.com/trufflesecurity/trufflehog/v3/pkg/context"
)

func TestMySQL(t *testing.T) {
Expand Down Expand Up @@ -89,7 +91,7 @@ func TestMySQL(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
j, err := parseMySQL(tt.input)
j, err := ParseMySQL(logContext.Background(), tt.input)

if err != nil {
got := result{ParseErr: true}
Expand Down
11 changes: 5 additions & 6 deletions pkg/detectors/jdbc/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package jdbc

import (
"context"
"strings"
"testing"

logContext "github.com/trufflesecurity/trufflehog/v3/pkg/context"
Expand Down Expand Up @@ -59,7 +58,7 @@ func TestParseMySQLMissingCredentials(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := logContext.AddLogger(context.Background())
j, err := parseMySQL(ctx, tt.subname)
j, err := ParseMySQL(ctx, tt.subname)

if tt.shouldBeNil {
if j != nil {
Expand Down Expand Up @@ -103,15 +102,15 @@ func TestParseMySQLUsernameRecognition(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := logContext.AddLogger(context.Background())
j, err := parseMySQL(ctx, tt.subname)
j, err := ParseMySQL(ctx, tt.subname)
if err != nil {
t.Fatalf("parseMySQL() error = %v", err)
}

mysqlConn := j.(*mysqlJDBC)
if !strings.Contains(mysqlConn.userPass, tt.wantUsername) {
mysqlConn := j.(*MysqlJDBC)
if mysqlConn.User != tt.wantUsername {
t.Errorf("Connection string does not contain expected username '%s'\nGot: %s\nExpected: %s",
tt.wantUsername, mysqlConn.userPass, tt.wantUsername)
tt.wantUsername, mysqlConn.User, tt.wantUsername)
}
})
}
Expand Down
64 changes: 38 additions & 26 deletions pkg/detectors/jdbc/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,23 @@ import (
"github.com/lib/pq"
)

type postgresJDBC struct {
conn string
params map[string]string
type PostgresJDBC struct {
Host string
User string
Password string
Database string
Params map[string]string
}

func (s *postgresJDBC) ping(ctx context.Context) pingResult {
func (s *PostgresJDBC) ping(ctx context.Context) pingResult {
// It is crucial that we try to build a connection string ourselves before using the one we found. This is because
// if the found connection string doesn't include a username, the driver will attempt to connect using the current
// user's name, which will fail in a way that looks like a determinate failure, thus terminating the waterfall. In
// contrast, when we build a connection string ourselves, if there's no username, we try 'postgres' instead, which
// actually has a chance of working.
return ping(ctx, "postgres", isPostgresErrorDeterminate,
buildPostgresConnectionString(s.params, true),
buildPostgresConnectionString(s.params, false),
BuildPostgresConnectionString(s.Host, s.User, s.Password, "postgres", s.Params, true),
BuildPostgresConnectionString(s.Host, s.User, s.Password, "postgres", s.Params, false),
)
}

Expand Down Expand Up @@ -59,7 +62,7 @@ func joinKeyValues(m map[string]string, sep string) string {
return strings.Join(data, sep)
}

func parsePostgres(ctx logContext.Context, subname string) (jdbc, error) {
func ParsePostgres(ctx logContext.Context, subname string) (jdbc, error) {
// expected form: [subprotocol:]//[user:password@]HOST[/DB][?key=val[&key=val]]

if !strings.HasPrefix(subname, "//") {
Expand All @@ -77,16 +80,20 @@ func parsePostgres(ctx logContext.Context, subname string) (jdbc, error) {
}

params := map[string]string{
"host": u.Host,
"dbname": dbName,
"connect_timeout": "5",
}

postgresJDBC := &PostgresJDBC{
Host: u.Host,
Database: dbName,
Params: params,
}
Comment on lines 82 to 88
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can use a single ConnectionInfo struct across all handlers, and later clean it up from jdbc analyzer models.go.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that is a great idea. I'll unify these as well.


if u.User != nil {
params["user"] = u.User.Username()
postgresJDBC.User = u.User.Username()
pass, set := u.User.Password()
if set {
params["password"] = pass
postgresJDBC.Password = pass
}
}

Expand All @@ -95,46 +102,51 @@ func parsePostgres(ctx logContext.Context, subname string) (jdbc, error) {
// https://www.postgresql.org/docs/current/libpq-ssl.html#LIBPQ-SSL-PROTECTION
case "disable", "allow", "prefer",
"require", "verify-ca", "verify-full":
params["sslmode"] = v[0]
postgresJDBC.Params["sslmode"] = v[0]
}
}

if v := u.Query().Get("user"); v != "" {
params["user"] = v
postgresJDBC.User = v
}

if v := u.Query().Get("password"); v != "" {
params["password"] = v
postgresJDBC.Password = v
}

if params["host"] == "" || params["password"] == "" {
if postgresJDBC.Host == "" || postgresJDBC.Password == "" {
ctx.Logger().WithName("jdbc").
V(2).
Info("Skipping invalid Postgres URL - no password or host found")
return nil, fmt.Errorf("missing host or password in connection string")
}

return &postgresJDBC{subname[2:], params}, nil
return postgresJDBC, nil
}

func buildPostgresConnectionString(params map[string]string, includeDbName bool) string {
func BuildPostgresConnectionString(host string, user string, password string, dbName string, params map[string]string, includeDbName bool) string {
data := map[string]string{
// default user
"user": "postgres",
"user": "postgres",
"password": password,
"host": host,
}
if user != "" {
data["user"] = user
}
if h, p, found := strings.Cut(host, ":"); found {
data["host"] = h
data["port"] = p
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As per the coding guidelines

Suggested change
if h, p, found := strings.Cut(host, ":"); found {
data["host"] = h
data["port"] = p
}
if h, p, ok := strings.Cut(host, ":"); ok {
data["host"] = h
data["port"] = p
}

for key, val := range params {
if key == "host" {
if h, p, found := strings.Cut(val, ":"); found {
data["host"] = h
data["port"] = p
continue
}
}
data[key] = val
}

if !includeDbName {
if includeDbName {
data["dbname"] = "postgres"
if dbName != "" {
data["dbname"] = dbName
}
}

connStr := joinKeyValues(data, " ")
Expand Down
4 changes: 3 additions & 1 deletion pkg/detectors/jdbc/postgres_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import (
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"

logContext "github.com/trufflesecurity/trufflehog/v3/pkg/context"
)

func TestPostgres(t *testing.T) {
Expand Down Expand Up @@ -119,7 +121,7 @@ func TestPostgres(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
j, err := parsePostgres(tt.input)
j, err := ParsePostgres(logContext.Background(), tt.input)
if err != nil {
got := result{ParseErr: true}

Expand Down
12 changes: 6 additions & 6 deletions pkg/detectors/jdbc/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func TestParsePostgresMissingCredentials(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := logContext.AddLogger(context.Background())
j, err := parsePostgres(ctx, tt.subname)
j, err := ParsePostgres(ctx, tt.subname)

if tt.shouldBeNil {
if j != nil {
Expand Down Expand Up @@ -86,14 +86,14 @@ func TestParsePostgresUsernameRecognition(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := logContext.AddLogger(context.Background())
j, err := parsePostgres(ctx, tt.subname)
j, err := ParsePostgres(ctx, tt.subname)
if err != nil {
t.Fatalf("parsePostgres() error = %v", err)
t.Fatalf("ParsePostgres() error = %v", err)
}

pgConn := j.(*postgresJDBC)
if pgConn.params["user"] != tt.wantUsername {
t.Errorf("expected username '%s', got '%s'", tt.wantUsername, pgConn.params["user"])
pgConn := j.(*PostgresJDBC)
if pgConn.User != tt.wantUsername {
t.Errorf("expected username '%s', got '%s'", tt.wantUsername, pgConn.User)
}
})
}
Expand Down
Loading
Loading