Skip to content

Commit b8dddcd

Browse files
committed
make approach more robust
1 parent 48945e6 commit b8dddcd

File tree

4 files changed

+403
-8
lines changed

4 files changed

+403
-8
lines changed

pkg/pb/sourcespb/sources.pb.go

Lines changed: 8 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pkg/sources/filesystem/filesystem.go

Lines changed: 152 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ import (
1313
"google.golang.org/protobuf/proto"
1414
"google.golang.org/protobuf/types/known/anypb"
1515

16+
"github.com/trufflesecurity/trufflehog/v3/pkg/cache"
17+
"github.com/trufflesecurity/trufflehog/v3/pkg/cache/lru"
1618
"github.com/trufflesecurity/trufflehog/v3/pkg/common"
1719
"github.com/trufflesecurity/trufflehog/v3/pkg/context"
1820
"github.com/trufflesecurity/trufflehog/v3/pkg/feature"
@@ -36,6 +38,22 @@ type Source struct {
3638
filter *common.Filter
3739
skipBinaries bool
3840
followSymlinks bool
41+
// scanRootPaths tracks the top-level directories/files being scanned.
42+
// Used to enforce depth-1 symlink following: only symlinks that are direct children
43+
// of these paths will be followed, preventing deep symlink chains.
44+
scanRootPaths map[string]struct{}
45+
// visitedPaths is an LRU cache tracking canonical paths of followed symlinks.
46+
// Only created when followSymlinks=true to avoid memory overhead.
47+
//
48+
// Why LRU cache instead of a map:
49+
// - Bounded memory: Limits to 10k paths (~1MB) even for massive directory trees
50+
// - Per-path reset: Cache is recreated for each scan path to prevent accumulation
51+
// - Loop detection: Prevents scanning the same file multiple times via different symlinks
52+
//
53+
// Why depth-1 limiting:
54+
// - Prevents infinite loops: Symlink chains (A->B->C->...) are limited
55+
// - Predictable behavior: Users know exactly which symlinks will be followed
56+
visitedPaths cache.Cache[struct{}]
3957
sources.Progress
4058
sources.CommonSourceUnitUnmarshaller
4159
}
@@ -95,7 +113,30 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ .
95113
}
96114
s.SetProgressComplete(i, len(s.paths), fmt.Sprintf("Path: %s", path), "")
97115

116+
// Initialize per-path tracking - critically important for memory management.
117+
// scanRootPaths is reset for each top-level path to track depth-1 symlinks.
118+
s.scanRootPaths = make(map[string]struct{})
119+
120+
// Create LRU cache only if following symlinks to avoid unnecessary memory allocation.
121+
// The cache is recreated for each scan path to prevent memory accumulation across
122+
// multiple scans. This ensures O(paths_per_scan) memory instead of O(total_paths).
123+
if s.followSymlinks {
124+
// Maximum of 10k paths limits memory to ~1MB even for very large directory trees.
125+
// If a directory has >10k symlinks, oldest entries are evicted (LRU behavior).
126+
const maxCacheSize = 10000
127+
cache, err := lru.NewCache[struct{}]("filesystem_visited", lru.WithCapacity[struct{}](maxCacheSize))
128+
if err != nil {
129+
logger.Error(err, "failed to create LRU cache for symlink tracking")
130+
continue
131+
}
132+
s.visitedPaths = cache
133+
}
134+
98135
cleanPath := filepath.Clean(path)
136+
137+
// Store the scan root path for depth tracking
138+
s.scanRootPaths[cleanPath] = struct{}{}
139+
99140
var fileInfo fs.FileInfo
100141
var err error
101142
if s.followSymlinks {
@@ -113,6 +154,29 @@ func (s *Source) Chunks(ctx context.Context, chunksChan chan *sources.Chunk, _ .
113154
continue
114155
}
115156

157+
// If followSymlinks is enabled and this is a symlink, check for loops
158+
if s.followSymlinks && fileInfo.Mode()&os.ModeSymlink != 0 {
159+
canonicalPath, err := filepath.EvalSymlinks(cleanPath)
160+
if err != nil {
161+
logger.V(5).Info("unable to resolve symlink", "path", cleanPath, "error", err)
162+
continue
163+
}
164+
165+
// Check for loops using the LRU cache
166+
if s.visitedPaths.Exists(canonicalPath) {
167+
logger.Info("skipping symlink loop detected", "path", cleanPath, "target", canonicalPath)
168+
continue
169+
}
170+
s.visitedPaths.Set(canonicalPath, struct{}{})
171+
172+
// Re-stat the canonical path to determine if it's a file or directory
173+
fileInfo, err = os.Stat(canonicalPath)
174+
if err != nil {
175+
logger.Error(err, "unable to stat symlink target")
176+
continue
177+
}
178+
}
179+
116180
if fileInfo.IsDir() {
117181
err = s.scanDir(ctx, cleanPath, chunksChan)
118182
} else {
@@ -156,8 +220,42 @@ func (s *Source) scanDir(ctx context.Context, path string, chunksChan chan *sour
156220
return nil // skip the file
157221
}
158222

159-
// Handle symlinks when followSymlinks is enabled
223+
// Handle symlinks when followSymlinks is enabled.
224+
// DEPTH-1 ENFORCEMENT: Only follow symlinks that are direct children of scan root paths.
225+
// This prevents:
226+
// 1. Infinite symlink chains (A->B->C->...)
227+
// 2. Deep directory traversal through symlinks
228+
//
229+
// Example: If scanning /path/to/dir:
230+
// - /path/to/dir/link.txt (direct child) -> WILL be followed
231+
// - /path/to/dir/subdir/link.txt (not direct child) -> will NOT be followed
160232
if s.followSymlinks && d.Type()&fs.ModeSymlink != 0 {
233+
// Only follow symlinks that are direct children of the scan root
234+
if !s.isDirectChild(fullPath) {
235+
ctx.Logger().V(5).Info("skipping symlink (not a direct child of scan root)", "path", fullPath)
236+
return nil
237+
}
238+
239+
// Resolve the symlink to its canonical path for loop detection.
240+
// This handles cases where multiple symlinks point to the same file.
241+
canonicalPath, err := filepath.EvalSymlinks(fullPath)
242+
if err != nil {
243+
// Broken symlink or permission issue, skip it
244+
ctx.Logger().V(5).Info("unable to resolve symlink", "path", fullPath, "error", err)
245+
return nil
246+
}
247+
248+
// Check for loops using LRU cache.
249+
// Prevents scanning the same file multiple times if reachable via different symlinks.
250+
// Also prevents infinite loops where symlinks form cycles.
251+
if s.followSymlinks && s.visitedPaths != nil && s.visitedPaths.Exists(canonicalPath) {
252+
ctx.Logger().Info("skipping symlink loop detected", "path", fullPath, "target", canonicalPath)
253+
return nil
254+
}
255+
if s.followSymlinks && s.visitedPaths != nil {
256+
s.visitedPaths.Set(canonicalPath, struct{}{})
257+
}
258+
161259
// Follow the symlink to see what it points to
162260
targetInfo, err := os.Stat(fullPath)
163261
if err != nil {
@@ -209,6 +307,18 @@ func (s *Source) scanDir(ctx context.Context, path string, chunksChan chan *sour
209307

210308
var skipSymlinkErr = errors.New("skipping symlink")
211309

310+
// isDirectChild checks if a path is a direct child of any scan root path.
311+
// This enforces depth-1 symlink following to prevent:
312+
// - Infinite symlink loops
313+
// - Deep directory traversal through symlinks
314+
//
315+
// Returns true only if the symlink's parent directory matches a scan root path.
316+
func (s *Source) isDirectChild(path string) bool {
317+
dir := filepath.Clean(filepath.Dir(path))
318+
_, isRoot := s.scanRootPaths[dir]
319+
return isRoot
320+
}
321+
212322
func (s *Source) scanFile(ctx context.Context, path string, chunksChan chan *sources.Chunk) error {
213323
fileCtx := context.WithValues(ctx, "path", path)
214324
var fileStat fs.FileInfo
@@ -282,7 +392,26 @@ func (s *Source) ChunkUnit(ctx context.Context, unit sources.SourceUnit, reporte
282392
path, _ := unit.SourceUnitID()
283393
logger := ctx.Logger().WithValues("path", path)
284394

395+
// Initialize per-unit tracking - same rationale as Chunks() method.
396+
// Each ChunkUnit call gets fresh tracking to prevent memory accumulation.
397+
s.scanRootPaths = make(map[string]struct{})
398+
399+
// Create LRU cache only if following symlinks.
400+
// Memory is bounded to 10k paths (~1MB) per unit scan.
401+
if s.followSymlinks {
402+
const maxCacheSize = 10000
403+
cache, err := lru.NewCache[struct{}]("filesystem_visited", lru.WithCapacity[struct{}](maxCacheSize))
404+
if err != nil {
405+
return reporter.ChunkErr(ctx, fmt.Errorf("failed to create LRU cache: %w", err))
406+
}
407+
s.visitedPaths = cache
408+
}
409+
285410
cleanPath := filepath.Clean(path)
411+
412+
// Store the scan root path for depth tracking
413+
s.scanRootPaths[cleanPath] = struct{}{}
414+
286415
var fileInfo fs.FileInfo
287416
var err error
288417
if s.followSymlinks {
@@ -294,6 +423,28 @@ func (s *Source) ChunkUnit(ctx context.Context, unit sources.SourceUnit, reporte
294423
return reporter.ChunkErr(ctx, fmt.Errorf("unable to get file info: %w", err))
295424
}
296425

426+
// If followSymlinks is enabled and this is a symlink, check for loops
427+
if s.followSymlinks && fileInfo.Mode()&os.ModeSymlink != 0 {
428+
canonicalPath, err := filepath.EvalSymlinks(cleanPath)
429+
if err != nil {
430+
logger.V(5).Info("unable to resolve symlink", "path", cleanPath, "error", err)
431+
return reporter.ChunkErr(ctx, fmt.Errorf("unable to resolve symlink: %w", err))
432+
}
433+
434+
// Check for loops
435+
if s.visitedPaths.Exists(canonicalPath) {
436+
logger.Info("skipping symlink loop detected", "path", cleanPath, "target", canonicalPath)
437+
return nil
438+
}
439+
s.visitedPaths.Set(canonicalPath, struct{}{})
440+
441+
// Re-stat the canonical path to determine if it's a file or directory
442+
fileInfo, err = os.Stat(canonicalPath)
443+
if err != nil {
444+
return reporter.ChunkErr(ctx, fmt.Errorf("unable to stat symlink target: %w", err))
445+
}
446+
}
447+
297448
ch := make(chan *sources.Chunk)
298449
var scanErr error
299450
go func() {

0 commit comments

Comments
 (0)