github.com/devseccon/trivy@v0.47.1-0.20231123133102-bd902a0bd996/pkg/mapfs/file.go (about)

     1  package mapfs
     2  
     3  import (
     4  	"io"
     5  	"io/fs"
     6  	"os"
     7  	"path/filepath"
     8  	"sort"
     9  	"strings"
    10  	"time"
    11  
    12  	"golang.org/x/xerrors"
    13  
    14  	xsync "github.com/devseccon/trivy/pkg/x/sync"
    15  )
    16  
    17  var separator = "/"
    18  
    19  // file represents one of them:
    20  // - an actual file
    21  // - a virtual file
    22  // - a virtual dir
    23  type file struct {
    24  	underlyingPath string // underlying file path
    25  	data           []byte // virtual file, only either of 'path' or 'data' has a value.
    26  	stat           fileStat
    27  	files          xsync.Map[string, *file]
    28  }
    29  
    30  func (f *file) isVirtual() bool {
    31  	return len(f.data) != 0 || f.stat.IsDir()
    32  }
    33  
    34  func (f *file) Open(name string) (fs.File, error) {
    35  	if name == "" || name == "." {
    36  		return f.open()
    37  	}
    38  
    39  	if sub, err := f.getFile(name); err == nil {
    40  		return sub.open()
    41  	}
    42  
    43  	return nil, &fs.PathError{
    44  		Op:   "open",
    45  		Path: name,
    46  		Err:  fs.ErrNotExist,
    47  	}
    48  }
    49  
    50  func (f *file) open() (fs.File, error) {
    51  	switch {
    52  	case f.stat.IsDir(): // Directory
    53  		entries, err := f.ReadDir(".")
    54  		if err != nil {
    55  			return nil, xerrors.Errorf("read dir error: %w", err)
    56  		}
    57  		return &mapDir{
    58  			path:     f.underlyingPath,
    59  			fileStat: f.stat,
    60  			entry:    entries,
    61  		}, nil
    62  	case len(f.data) != 0: // Virtual file
    63  		return &openMapFile{
    64  			path:   f.stat.name,
    65  			file:   f,
    66  			offset: 0,
    67  		}, nil
    68  	default: // Real file
    69  		return os.Open(f.underlyingPath)
    70  	}
    71  }
    72  
    73  func (f *file) Remove(name string) error {
    74  	if name == "" || name == "." {
    75  		return nil
    76  	}
    77  
    78  	return f.removePath(name, false)
    79  }
    80  
    81  func (f *file) RemoveAll(name string) error {
    82  	if name == "" || name == "." {
    83  		return nil
    84  	}
    85  
    86  	return f.removePath(name, true)
    87  }
    88  
    89  func (f *file) removePath(name string, recursive bool) error {
    90  	parts := strings.Split(name, separator)
    91  	if len(parts) == 1 {
    92  		sub, ok := f.files.Load(name)
    93  		if !ok {
    94  			return fs.ErrNotExist
    95  		}
    96  		if sub.files.Len() != 0 && !recursive {
    97  			return fs.ErrInvalid
    98  		}
    99  		f.files.Delete(name)
   100  		return nil
   101  	}
   102  
   103  	sub, err := f.getFile(parts[0])
   104  	if err != nil {
   105  		return err
   106  	} else if !sub.stat.IsDir() {
   107  		return fs.ErrNotExist
   108  	}
   109  
   110  	return sub.removePath(strings.Join(parts[1:], separator), recursive)
   111  }
   112  
   113  func (f *file) getFile(name string) (*file, error) {
   114  	if name == "" || name == "." {
   115  		return f, nil
   116  	}
   117  	parts := strings.Split(name, separator)
   118  	if len(parts) == 1 {
   119  		f, ok := f.files.Load(name)
   120  		if ok {
   121  			return f, nil
   122  		}
   123  		return nil, fs.ErrNotExist
   124  	}
   125  
   126  	sub, ok := f.files.Load(parts[0])
   127  	if !ok || !sub.stat.IsDir() {
   128  		return nil, fs.ErrNotExist
   129  	}
   130  
   131  	return sub.getFile(strings.Join(parts[1:], separator))
   132  }
   133  
   134  func (f *file) ReadDir(name string) ([]fs.DirEntry, error) {
   135  	if name == "" || name == "." {
   136  		var entries []fs.DirEntry
   137  		var err error
   138  		f.files.Range(func(name string, value *file) bool {
   139  			if value.isVirtual() {
   140  				entries = append(entries, &value.stat)
   141  			} else {
   142  				var fi os.FileInfo
   143  				fi, err = os.Stat(value.underlyingPath)
   144  				if err != nil {
   145  					return false
   146  				}
   147  				entries = append(entries, &fileStat{
   148  					name:    name,
   149  					size:    fi.Size(),
   150  					mode:    fi.Mode(),
   151  					modTime: fi.ModTime(),
   152  					sys:     fi.Sys(),
   153  				})
   154  			}
   155  			return true
   156  		})
   157  		if err != nil {
   158  			return nil, xerrors.Errorf("range error: %w", err)
   159  		}
   160  		sort.Slice(entries, func(i, j int) bool { return entries[i].Name() < entries[j].Name() })
   161  		return entries, nil
   162  	}
   163  
   164  	parts := strings.Split(name, separator)
   165  	dir, ok := f.files.Load(parts[0])
   166  	if !ok || !dir.stat.IsDir() {
   167  		return nil, fs.ErrNotExist
   168  	}
   169  	return dir.ReadDir(strings.Join(parts[1:], separator))
   170  }
   171  
   172  func (f *file) MkdirAll(path string, perm fs.FileMode) error {
   173  	parts := strings.Split(path, separator)
   174  
   175  	if path == "" || path == "." {
   176  		return nil
   177  	}
   178  
   179  	if perm&fs.ModeDir == 0 {
   180  		perm |= fs.ModeDir
   181  	}
   182  
   183  	sub := &file{
   184  		stat: fileStat{
   185  			name:    parts[0],
   186  			size:    0x100,
   187  			modTime: time.Now(),
   188  			mode:    perm,
   189  		},
   190  		files: xsync.Map[string, *file]{},
   191  	}
   192  
   193  	// Create the directory when the key is not present
   194  	sub, loaded := f.files.LoadOrStore(parts[0], sub)
   195  	if loaded && !sub.stat.IsDir() {
   196  		return fs.ErrExist
   197  	}
   198  
   199  	if len(parts) == 1 {
   200  		return nil
   201  	}
   202  
   203  	return sub.MkdirAll(strings.Join(parts[1:], separator), perm)
   204  }
   205  
   206  func (f *file) WriteFile(path, underlyingPath string) error {
   207  	parts := strings.Split(path, separator)
   208  
   209  	if len(parts) == 1 {
   210  		f.files.Store(parts[0], &file{
   211  			underlyingPath: underlyingPath,
   212  		})
   213  		return nil
   214  	}
   215  
   216  	dir, ok := f.files.Load(parts[0])
   217  	if !ok || !dir.stat.IsDir() {
   218  		return fs.ErrNotExist
   219  	}
   220  
   221  	return dir.WriteFile(strings.Join(parts[1:], separator), underlyingPath)
   222  }
   223  
   224  func (f *file) WriteVirtualFile(path string, data []byte, mode fs.FileMode) error {
   225  	if mode&fs.ModeDir != 0 {
   226  		return xerrors.Errorf("invalid perm: %v", mode)
   227  	}
   228  	parts := strings.Split(path, separator)
   229  
   230  	if len(parts) == 1 {
   231  		f.files.Store(parts[0], &file{
   232  			data: data,
   233  			stat: fileStat{
   234  				name:    parts[0],
   235  				size:    int64(len(data)),
   236  				mode:    mode,
   237  				modTime: time.Now(),
   238  			},
   239  		})
   240  		return nil
   241  	}
   242  
   243  	dir, ok := f.files.Load(parts[0])
   244  	if !ok || !dir.stat.IsDir() {
   245  		return fs.ErrNotExist
   246  	}
   247  
   248  	return dir.WriteVirtualFile(strings.Join(parts[1:], separator), data, mode)
   249  }
   250  
   251  func (f *file) glob(pattern string) ([]string, error) {
   252  	var entries []string
   253  	parts := strings.Split(pattern, separator)
   254  
   255  	var err error
   256  	f.files.Range(func(name string, sub *file) bool {
   257  		if ok, err := filepath.Match(parts[0], name); err != nil {
   258  			return false
   259  		} else if ok {
   260  			if len(parts) == 1 {
   261  				entries = append(entries, name)
   262  			} else {
   263  				subEntries, err := sub.glob(strings.Join(parts[1:], separator))
   264  				if err != nil {
   265  					return false
   266  				}
   267  				for _, sub := range subEntries {
   268  					entries = append(entries, name+separator+sub)
   269  				}
   270  			}
   271  		}
   272  		return true
   273  	})
   274  	if err != nil {
   275  		return nil, xerrors.Errorf("range error: %w", err)
   276  	}
   277  
   278  	sort.Strings(entries)
   279  	return entries, nil
   280  }
   281  
   282  // An openMapFile is a regular (non-directory) fs.File open for reading.
   283  // ported from https://github.com/golang/go/blob/99bc53f5e819c2d2d49f2a56c488898085be3982/src/testing/fstest/mapfs.go
   284  type openMapFile struct {
   285  	path string
   286  	*file
   287  	offset int64
   288  }
   289  
   290  func (f *openMapFile) Stat() (fs.FileInfo, error) { return &f.file.stat, nil }
   291  
   292  func (f *openMapFile) Close() error { return nil }
   293  
   294  func (f *openMapFile) Read(b []byte) (int, error) {
   295  	if f.offset >= int64(len(f.file.data)) {
   296  		return 0, io.EOF
   297  	}
   298  	if f.offset < 0 {
   299  		return 0, &fs.PathError{
   300  			Op:   "read",
   301  			Path: f.path,
   302  			Err:  fs.ErrInvalid,
   303  		}
   304  	}
   305  	n := copy(b, f.file.data[f.offset:])
   306  	f.offset += int64(n)
   307  	return n, nil
   308  }
   309  
   310  func (f *openMapFile) Seek(offset int64, whence int) (int64, error) {
   311  	switch whence {
   312  	case 0:
   313  		// offset += 0
   314  	case 1:
   315  		offset += f.offset
   316  	case 2:
   317  		offset += int64(len(f.file.data))
   318  	}
   319  	if offset < 0 || offset > int64(len(f.file.data)) {
   320  		return 0, &fs.PathError{
   321  			Op:   "seek",
   322  			Path: f.path,
   323  			Err:  fs.ErrInvalid,
   324  		}
   325  	}
   326  	f.offset = offset
   327  	return offset, nil
   328  }
   329  
   330  func (f *openMapFile) ReadAt(b []byte, offset int64) (int, error) {
   331  	if offset < 0 || offset > int64(len(f.file.data)) {
   332  		return 0, &fs.PathError{
   333  			Op:   "read",
   334  			Path: f.path,
   335  			Err:  fs.ErrInvalid,
   336  		}
   337  	}
   338  	n := copy(b, f.file.data[offset:])
   339  	if n < len(b) {
   340  		return n, io.EOF
   341  	}
   342  	return n, nil
   343  }
   344  
   345  // A mapDir is a directory fs.File (so also fs.ReadDirFile) open for reading.
   346  type mapDir struct {
   347  	path string
   348  	fileStat
   349  	entry  []fs.DirEntry
   350  	offset int
   351  }
   352  
   353  func (d *mapDir) Stat() (fs.FileInfo, error) { return &d.fileStat, nil }
   354  func (d *mapDir) Close() error               { return nil }
   355  func (d *mapDir) Read(_ []byte) (int, error) {
   356  	return 0, &fs.PathError{
   357  		Op:   "read",
   358  		Path: d.path,
   359  		Err:  fs.ErrInvalid,
   360  	}
   361  }
   362  
   363  func (d *mapDir) ReadDir(count int) ([]fs.DirEntry, error) {
   364  	n := len(d.entry) - d.offset
   365  	if n == 0 && count > 0 {
   366  		return nil, io.EOF
   367  	}
   368  	if count > 0 && n > count {
   369  		n = count
   370  	}
   371  	list := make([]fs.DirEntry, n)
   372  	for i := range list {
   373  		list[i] = d.entry[d.offset+i]
   374  	}
   375  	d.offset += n
   376  	return list, nil
   377  }