Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions pkg/detectors/jdbc/jdbc.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,16 +207,25 @@ func tryRedactRegex(conn string) (string, bool) {
}

var supportedSubprotocols = map[string]func(logContext.Context, string) (jdbc, error){
"mysql": parseMySQL,
"postgresql": parsePostgres,
"sqlserver": parseSqlServer,
"mysql": ParseMySQL,
"postgresql": ParsePostgres,
"sqlserver": ParseSqlServer,
}

type pingResult struct {
err error
determinate bool
}

// ConnectionInfo holds parsed connection information
type ConnectionInfo struct {
Host string // includes port if specified, e.g., "host:port"
Database string
User string
Password string
Params map[string]string
}

type jdbc interface {
ping(context.Context) pingResult
}
Expand Down
59 changes: 37 additions & 22 deletions pkg/detectors/jdbc/mysql.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,30 @@ import (
"github.com/go-sql-driver/mysql"
)

type mysqlJDBC struct {
conn string
userPass string
host string
params string
type MysqlJDBC struct {
ConnectionInfo
}

func (s *mysqlJDBC) ping(ctx context.Context) pingResult {
func (s *MysqlJDBC) ping(ctx context.Context) pingResult {
return ping(ctx, "mysql", isMySQLErrorDeterminate,
buildMySQLConnectionString(s.host, "", s.userPass, s.params))
BuildMySQLConnectionString(s.Host, "", s.User, s.Password, s.Params))
}

func buildMySQLConnectionString(host, database, userPass, params string) string {
func BuildMySQLConnectionString(host, database, user, password string, params map[string]string) string {
conn := host + "/" + database
userPass := user
if password != "" {
userPass = userPass + ":" + password
}
if userPass != "" {
conn = userPass + "@" + conn
}
if params != "" {
conn = conn + "?" + params
if len(params) > 0 {
var paramList []string
for k, v := range params {
paramList = append(paramList, fmt.Sprintf("%s=%s", k, v))
}
conn = conn + "?" + strings.Join(paramList, "&")
}
return conn
}
Expand All @@ -51,7 +56,7 @@ func isMySQLErrorDeterminate(err error) bool {
return false
}

func parseMySQL(ctx logContext.Context, subname string) (jdbc, error) {
func ParseMySQL(ctx logContext.Context, subname string) (jdbc, error) {
// expected form: [subprotocol:]//[user:password@]HOST[/DB][?key=val[&key=val]]
if !strings.HasPrefix(subname, "//") {
return nil, errors.New("expected host to start with //")
Expand All @@ -70,11 +75,14 @@ func parseMySQL(ctx logContext.Context, subname string) (jdbc, error) {
Info("Skipping invalid MySQL URL - no password or host found")
return nil, fmt.Errorf("missing host or password in connection string")
}
return &mysqlJDBC{
conn: subname[2:],
userPass: cfg.User + ":" + cfg.Passwd,
host: fmt.Sprintf("tcp(%s)", cfg.Addr),
params: "timeout=5s",
return &MysqlJDBC{
ConnectionInfo: ConnectionInfo{
User: cfg.User,
Password: cfg.Passwd,
Host: fmt.Sprintf("tcp(%s)", cfg.Addr),
Params: map[string]string{"timeout": "5s"},
Database: cfg.DBName,
},
}, nil
}

Expand Down Expand Up @@ -107,13 +115,20 @@ func parseMySQLURI(ctx logContext.Context, subname string) (jdbc, error) {
return nil, fmt.Errorf("missing host or password in connection string")
}

userAndPass := user + ":" + pass
// Parse database name
dbName := strings.TrimPrefix(u.Path, "/")
if dbName == "" {
dbName = "mysql" // default DB
}

return &mysqlJDBC{
conn: subname[2:],
userPass: userAndPass,
host: fmt.Sprintf("tcp(%s)", u.Host),
params: "timeout=5s",
return &MysqlJDBC{
ConnectionInfo: ConnectionInfo{
User: user,
Password: pass,
Host: fmt.Sprintf("tcp(%s)", u.Host),
Params: map[string]string{"timeout": "5s"},
Database: dbName,
},
}, nil

}
4 changes: 3 additions & 1 deletion pkg/detectors/jdbc/mysql_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"github.com/testcontainers/testcontainers-go/modules/mysql"

logContext "github.com/trufflesecurity/trufflehog/v3/pkg/context"
)

func TestMySQL(t *testing.T) {
Expand Down Expand Up @@ -89,7 +91,7 @@ func TestMySQL(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.input, func(t *testing.T) {
j, err := parseMySQL(tt.input)
j, err := ParseMySQL(logContext.Background(), tt.input)

if err != nil {
got := result{ParseErr: true}
Expand Down
11 changes: 5 additions & 6 deletions pkg/detectors/jdbc/mysql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package jdbc

import (
"context"
"strings"
"testing"

logContext "github.com/trufflesecurity/trufflehog/v3/pkg/context"
Expand Down Expand Up @@ -59,7 +58,7 @@ func TestParseMySQLMissingCredentials(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := logContext.AddLogger(context.Background())
j, err := parseMySQL(ctx, tt.subname)
j, err := ParseMySQL(ctx, tt.subname)

if tt.shouldBeNil {
if j != nil {
Expand Down Expand Up @@ -103,15 +102,15 @@ func TestParseMySQLUsernameRecognition(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := logContext.AddLogger(context.Background())
j, err := parseMySQL(ctx, tt.subname)
j, err := ParseMySQL(ctx, tt.subname)
if err != nil {
t.Fatalf("parseMySQL() error = %v", err)
}

mysqlConn := j.(*mysqlJDBC)
if !strings.Contains(mysqlConn.userPass, tt.wantUsername) {
mysqlConn := j.(*MysqlJDBC)
if mysqlConn.User != tt.wantUsername {
t.Errorf("Connection string does not contain expected username '%s'\nGot: %s\nExpected: %s",
tt.wantUsername, mysqlConn.userPass, tt.wantUsername)
tt.wantUsername, mysqlConn.User, tt.wantUsername)
}
})
}
Expand Down
62 changes: 36 additions & 26 deletions pkg/detectors/jdbc/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,19 @@ import (
"github.com/lib/pq"
)

type postgresJDBC struct {
conn string
params map[string]string
type PostgresJDBC struct {
ConnectionInfo
}

func (s *postgresJDBC) ping(ctx context.Context) pingResult {
func (s *PostgresJDBC) ping(ctx context.Context) pingResult {
// It is crucial that we try to build a connection string ourselves before using the one we found. This is because
// if the found connection string doesn't include a username, the driver will attempt to connect using the current
// user's name, which will fail in a way that looks like a determinate failure, thus terminating the waterfall. In
// contrast, when we build a connection string ourselves, if there's no username, we try 'postgres' instead, which
// actually has a chance of working.
return ping(ctx, "postgres", isPostgresErrorDeterminate,
buildPostgresConnectionString(s.params, true),
buildPostgresConnectionString(s.params, false),
BuildPostgresConnectionString(s.Host, s.User, s.Password, "postgres", s.Params, true),
BuildPostgresConnectionString(s.Host, s.User, s.Password, "postgres", s.Params, false),
)
}

Expand Down Expand Up @@ -59,7 +58,7 @@ func joinKeyValues(m map[string]string, sep string) string {
return strings.Join(data, sep)
}

func parsePostgres(ctx logContext.Context, subname string) (jdbc, error) {
func ParsePostgres(ctx logContext.Context, subname string) (jdbc, error) {
// expected form: [subprotocol:]//[user:password@]HOST[/DB][?key=val[&key=val]]

if !strings.HasPrefix(subname, "//") {
Expand All @@ -77,16 +76,22 @@ func parsePostgres(ctx logContext.Context, subname string) (jdbc, error) {
}

params := map[string]string{
"host": u.Host,
"dbname": dbName,
"connect_timeout": "5",
}

postgresJDBC := &PostgresJDBC{
ConnectionInfo: ConnectionInfo{
Host: u.Host,
Database: dbName,
Params: params,
},
}

if u.User != nil {
params["user"] = u.User.Username()
postgresJDBC.User = u.User.Username()
pass, set := u.User.Password()
if set {
params["password"] = pass
postgresJDBC.Password = pass
}
}

Expand All @@ -95,46 +100,51 @@ func parsePostgres(ctx logContext.Context, subname string) (jdbc, error) {
// https://www.postgresql.org/docs/current/libpq-ssl.html#LIBPQ-SSL-PROTECTION
case "disable", "allow", "prefer",
"require", "verify-ca", "verify-full":
params["sslmode"] = v[0]
postgresJDBC.Params["sslmode"] = v[0]
}
}

if v := u.Query().Get("user"); v != "" {
params["user"] = v
postgresJDBC.User = v
}

if v := u.Query().Get("password"); v != "" {
params["password"] = v
postgresJDBC.Password = v
}

if params["host"] == "" || params["password"] == "" {
if postgresJDBC.Host == "" || postgresJDBC.Password == "" {
ctx.Logger().WithName("jdbc").
V(2).
Info("Skipping invalid Postgres URL - no password or host found")
return nil, fmt.Errorf("missing host or password in connection string")
}

return &postgresJDBC{subname[2:], params}, nil
return postgresJDBC, nil
}

func buildPostgresConnectionString(params map[string]string, includeDbName bool) string {
func BuildPostgresConnectionString(host string, user string, password string, dbName string, params map[string]string, includeDbName bool) string {
data := map[string]string{
// default user
"user": "postgres",
"user": "postgres",
"password": password,
"host": host,
}
if user != "" {
data["user"] = user
}
if h, p, ok := strings.Cut(host, ":"); ok {
data["host"] = h
data["port"] = p
}
for key, val := range params {
if key == "host" {
if h, p, found := strings.Cut(val, ":"); found {
data["host"] = h
data["port"] = p
continue
}
}
data[key] = val
}

if !includeDbName {
if includeDbName {
data["dbname"] = "postgres"
if dbName != "" {
data["dbname"] = dbName
}
}

connStr := joinKeyValues(data, " ")
Expand Down
4 changes: 3 additions & 1 deletion pkg/detectors/jdbc/postgres_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import (
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/modules/postgres"
"github.com/testcontainers/testcontainers-go/wait"

logContext "github.com/trufflesecurity/trufflehog/v3/pkg/context"
)

func TestPostgres(t *testing.T) {
Expand Down Expand Up @@ -119,7 +121,7 @@ func TestPostgres(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
j, err := parsePostgres(tt.input)
j, err := ParsePostgres(logContext.Background(), tt.input)
if err != nil {
got := result{ParseErr: true}

Expand Down
12 changes: 6 additions & 6 deletions pkg/detectors/jdbc/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func TestParsePostgresMissingCredentials(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := logContext.AddLogger(context.Background())
j, err := parsePostgres(ctx, tt.subname)
j, err := ParsePostgres(ctx, tt.subname)

if tt.shouldBeNil {
if j != nil {
Expand Down Expand Up @@ -86,14 +86,14 @@ func TestParsePostgresUsernameRecognition(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := logContext.AddLogger(context.Background())
j, err := parsePostgres(ctx, tt.subname)
j, err := ParsePostgres(ctx, tt.subname)
if err != nil {
t.Fatalf("parsePostgres() error = %v", err)
t.Fatalf("ParsePostgres() error = %v", err)
}

pgConn := j.(*postgresJDBC)
if pgConn.params["user"] != tt.wantUsername {
t.Errorf("expected username '%s', got '%s'", tt.wantUsername, pgConn.params["user"])
pgConn := j.(*PostgresJDBC)
if pgConn.User != tt.wantUsername {
t.Errorf("expected username '%s', got '%s'", tt.wantUsername, pgConn.User)
}
})
}
Expand Down
Loading
Loading