github.com/david-imola/snapd@v0.0.0-20210611180407-2de8ddeece6d/testutil/lowlevel.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2017-2018 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package testutil
    21  
    22  import (
    23  	"fmt"
    24  	"os"
    25  	"strings"
    26  	"syscall"
    27  	"time"
    28  
    29  	"gopkg.in/check.v1"
    30  
    31  	"github.com/snapcore/snapd/osutil/mount"
    32  	"github.com/snapcore/snapd/osutil/sys"
    33  )
    34  
    35  const umountNoFollow = 8
    36  
    37  // fakeFileInfo implements os.FileInfo for testing.
    38  //
    39  // Some of the functions panic as we don't expect them to be called.
    40  // Feel free to expand them as necessary.
    41  type fakeFileInfo struct {
    42  	name string
    43  	mode os.FileMode
    44  }
    45  
    46  func (fi *fakeFileInfo) Name() string      { return fi.name }
    47  func (*fakeFileInfo) Size() int64          { panic("unexpected call") }
    48  func (fi *fakeFileInfo) Mode() os.FileMode { return fi.mode }
    49  func (*fakeFileInfo) ModTime() time.Time   { panic("unexpected call") }
    50  func (fi *fakeFileInfo) IsDir() bool       { return fi.Mode().IsDir() }
    51  func (*fakeFileInfo) Sys() interface{}     { panic("unexpected call") }
    52  
    53  // FakeFileInfo returns a fake object implementing os.FileInfo
    54  func FakeFileInfo(name string, mode os.FileMode) os.FileInfo {
    55  	return &fakeFileInfo{name: name, mode: mode}
    56  }
    57  
    58  // Convenient FakeFileInfo objects for InsertLstatResult
    59  var (
    60  	FileInfoFile    = &fakeFileInfo{}
    61  	FileInfoDir     = &fakeFileInfo{mode: os.ModeDir}
    62  	FileInfoSymlink = &fakeFileInfo{mode: os.ModeSymlink}
    63  )
    64  
    65  // Formatter for flags passed to open syscall.
    66  //
    67  // Not all flags are handled. Unknown flags cause a panic.
    68  // Please expand the set of recognized flags as tests require.
    69  func formatOpenFlags(flags int) string {
    70  	var fl []string
    71  	if flags&syscall.O_NOFOLLOW != 0 {
    72  		flags ^= syscall.O_NOFOLLOW
    73  		fl = append(fl, "O_NOFOLLOW")
    74  	}
    75  	if flags&syscall.O_CLOEXEC != 0 {
    76  		flags ^= syscall.O_CLOEXEC
    77  		fl = append(fl, "O_CLOEXEC")
    78  	}
    79  	if flags&syscall.O_DIRECTORY != 0 {
    80  		flags ^= syscall.O_DIRECTORY
    81  		fl = append(fl, "O_DIRECTORY")
    82  	}
    83  	if flags&syscall.O_RDWR != 0 {
    84  		flags ^= syscall.O_RDWR
    85  		fl = append(fl, "O_RDWR")
    86  	}
    87  	if flags&syscall.O_CREAT != 0 {
    88  		flags ^= syscall.O_CREAT
    89  		fl = append(fl, "O_CREAT")
    90  	}
    91  	if flags&syscall.O_EXCL != 0 {
    92  		flags ^= syscall.O_EXCL
    93  		fl = append(fl, "O_EXCL")
    94  	}
    95  	if flags&sys.O_PATH != 0 {
    96  		flags ^= sys.O_PATH
    97  		fl = append(fl, "O_PATH")
    98  	}
    99  	if flags != 0 {
   100  		panic(fmt.Errorf("unrecognized open flags %d", flags))
   101  	}
   102  	if len(fl) == 0 {
   103  		return "0"
   104  	}
   105  	return strings.Join(fl, "|")
   106  }
   107  
   108  // Formatter for flags passed to mount syscall.
   109  //
   110  // Not all flags are handled. Unknown flags cause a panic.
   111  // Please expand the set of recognized flags as tests require.
   112  func formatMountFlags(flags int) string {
   113  	fl, unknown := mount.MountFlagsToOpts(flags)
   114  	if unknown != 0 {
   115  		panic(fmt.Errorf("unrecognized mount flags %d", unknown))
   116  	}
   117  	if len(fl) == 0 {
   118  		return "0"
   119  	}
   120  	return strings.Join(fl, "|")
   121  }
   122  
   123  // Formatter for flags passed to unmount syscall.
   124  //
   125  // Not all flags are handled. Unknown flags cause a panic.
   126  // Please expand the set of recognized flags as tests require.
   127  func formatUnmountFlags(flags int) string {
   128  	fl, unknown := mount.UnmountFlagsToOpts(flags)
   129  	if unknown != 0 {
   130  		panic(fmt.Errorf("unrecognized unmount flags %d", unknown))
   131  	}
   132  	if len(fl) == 0 {
   133  		return "0"
   134  	}
   135  	return strings.Join(fl, "|")
   136  }
   137  
   138  // CallResultError describes a system call and the corresponding result or error.
   139  //
   140  // The field names stand for Call, Result and Error respectively. They are
   141  // abbreviated due to the nature of their use (in large quantity).
   142  type CallResultError struct {
   143  	C string
   144  	R interface{}
   145  	E error
   146  }
   147  
   148  // SyscallRecorder stores which system calls were invoked.
   149  //
   150  // The recorder supports a small set of features useful for testing: injecting
   151  // failures, returning pre-arranged test data, allocation, tracking and
   152  // verification of file descriptors.
   153  type SyscallRecorder struct {
   154  	// History of all the system calls made.
   155  	rcalls []CallResultError
   156  	// Error function for a given system call.
   157  	errors map[string]func() error
   158  	// pre-arranged result of lstat, fstat and readdir calls.
   159  	osLstats    map[string]os.FileInfo
   160  	sysLstats   map[string]syscall.Stat_t
   161  	fstats      map[string]syscall.Stat_t
   162  	fstatfses   map[string]func() syscall.Statfs_t
   163  	readdirs    map[string][]os.FileInfo
   164  	readlinkats map[string]string
   165  	// allocated file descriptors
   166  	fds map[int]string
   167  }
   168  
   169  // InsertFault makes given subsequent call to return the specified error.
   170  //
   171  // If one error is provided then the call will reliably fail that way.
   172  // If multiple errors are given then they will be used on subsequent calls
   173  // until the errors finally run out and the call succeeds.
   174  func (sys *SyscallRecorder) InsertFault(call string, errors ...error) {
   175  	if sys.errors == nil {
   176  		sys.errors = make(map[string]func() error)
   177  	}
   178  	if len(errors) == 1 {
   179  		// deterministic error
   180  		sys.errors[call] = func() error {
   181  			return errors[0]
   182  		}
   183  	} else {
   184  		// error sequence
   185  		sys.errors[call] = func() error {
   186  			if len(errors) > 0 {
   187  				err := errors[0]
   188  				errors = errors[1:]
   189  				return err
   190  			}
   191  			return nil
   192  		}
   193  	}
   194  }
   195  
   196  // InsertFaultFunc arranges given function to be called whenever given call is made.
   197  //
   198  // The main purpose is to allow to vary the behavior of a given system call over time.
   199  // The provided function can return an error or nil to indicate success.
   200  func (sys *SyscallRecorder) InsertFaultFunc(call string, fn func() error) {
   201  	if sys.errors == nil {
   202  		sys.errors = make(map[string]func() error)
   203  	}
   204  	sys.errors[call] = fn
   205  }
   206  
   207  // Calls returns the sequence of mocked calls that have been made.
   208  func (sys *SyscallRecorder) Calls() []string {
   209  	if len(sys.rcalls) == 0 {
   210  		return nil
   211  	}
   212  	calls := make([]string, 0, len(sys.rcalls))
   213  	for _, rc := range sys.rcalls {
   214  		calls = append(calls, rc.C)
   215  	}
   216  	return calls
   217  }
   218  
   219  // RCalls returns the sequence of mocked calls that have been made along with their results.
   220  func (sys *SyscallRecorder) RCalls() []CallResultError {
   221  	return sys.rcalls
   222  }
   223  
   224  // rcall remembers that a given call has occurred and returns a pre-arranged error or value, if any
   225  func (sys *SyscallRecorder) rcall(call string, resultFn func(call string) (interface{}, error)) (val interface{}, err error) {
   226  	if errorFn := sys.errors[call]; errorFn != nil {
   227  		err = errorFn()
   228  	}
   229  	if err == nil && resultFn != nil {
   230  		val, err = resultFn(call)
   231  	}
   232  	if err != nil {
   233  		sys.rcalls = append(sys.rcalls, CallResultError{C: call, E: err})
   234  	} else {
   235  		sys.rcalls = append(sys.rcalls, CallResultError{C: call, R: val})
   236  	}
   237  	return val, err
   238  }
   239  
   240  // allocFd assigns a file descriptor to a given operation.
   241  func (sys *SyscallRecorder) allocFd(name string) int {
   242  	if sys.fds == nil {
   243  		sys.fds = make(map[int]string)
   244  	}
   245  
   246  	// Use 3 as the lowest number for tests to look more plausible.
   247  	for i := 3; i < 100; i++ {
   248  		if _, ok := sys.fds[i]; !ok {
   249  			sys.fds[i] = name
   250  			return i
   251  		}
   252  	}
   253  	panic("cannot find unused file descriptor")
   254  }
   255  
   256  // freeFd closes an open file descriptor.
   257  func (sys *SyscallRecorder) freeFd(fd int) error {
   258  	if _, ok := sys.fds[fd]; !ok {
   259  		return fmt.Errorf("attempting to close a closed file descriptor %d", fd)
   260  	}
   261  	delete(sys.fds, fd)
   262  	return nil
   263  }
   264  
   265  // StrayDescriptorsError returns an error if any descriptor is left unclosed.
   266  func (sys *SyscallRecorder) StrayDescriptorsError() error {
   267  	for fd, name := range sys.fds {
   268  		return fmt.Errorf("unclosed file descriptor %d (%s)", fd, name)
   269  	}
   270  	return nil
   271  }
   272  
   273  // CheckForStrayDescriptors ensures that all fake file descriptors are closed.
   274  func (sys *SyscallRecorder) CheckForStrayDescriptors(c *check.C) {
   275  	c.Assert(sys.StrayDescriptorsError(), check.IsNil)
   276  }
   277  
   278  // Open is a fake implementation of syscall.Open
   279  func (sys *SyscallRecorder) Open(path string, flags int, mode uint32) (int, error) {
   280  	call := fmt.Sprintf("open %q %s %#o", path, formatOpenFlags(flags), mode)
   281  	fd, err := sys.rcall(call, func(call string) (interface{}, error) {
   282  		return sys.allocFd(call), nil
   283  	})
   284  	if err != nil {
   285  		return -1, err
   286  	}
   287  	return fd.(int), nil
   288  }
   289  
   290  // Openat is a fake implementation of syscall.Openat
   291  func (sys *SyscallRecorder) Openat(dirfd int, path string, flags int, mode uint32) (int, error) {
   292  	call := fmt.Sprintf("openat %d %q %s %#o", dirfd, path, formatOpenFlags(flags), mode)
   293  	fd, err := sys.rcall(call, func(call string) (interface{}, error) {
   294  		if _, ok := sys.fds[dirfd]; !ok {
   295  			return -1, fmt.Errorf("attempting to openat with an invalid file descriptor %d", dirfd)
   296  		}
   297  		return sys.allocFd(call), nil
   298  	})
   299  	if err != nil {
   300  		return -1, err
   301  	}
   302  	return fd.(int), nil
   303  }
   304  
   305  // Close is a fake implementation of syscall.Close
   306  func (sys *SyscallRecorder) Close(fd int) error {
   307  	call := fmt.Sprintf("close %d", fd)
   308  	_, err := sys.rcall(call, func(call string) (interface{}, error) {
   309  		return nil, sys.freeFd(fd)
   310  	})
   311  	return err
   312  }
   313  
   314  // Fchown is a fake implementation of syscall.Fchown
   315  func (sys *SyscallRecorder) Fchown(fd int, uid sys.UserID, gid sys.GroupID) error {
   316  	call := fmt.Sprintf("fchown %d %d %d", fd, uid, gid)
   317  	_, err := sys.rcall(call, func(call string) (interface{}, error) {
   318  		if _, ok := sys.fds[fd]; !ok {
   319  			return nil, fmt.Errorf("attempting to fchown an invalid file descriptor %d", fd)
   320  		}
   321  		return nil, nil
   322  	})
   323  	return err
   324  }
   325  
   326  // Mkdirat is a fake implementation of syscall.Mkdirat
   327  func (sys *SyscallRecorder) Mkdirat(dirfd int, path string, mode uint32) error {
   328  	call := fmt.Sprintf("mkdirat %d %q %#o", dirfd, path, mode)
   329  	_, err := sys.rcall(call, func(call string) (interface{}, error) {
   330  		if _, ok := sys.fds[dirfd]; !ok {
   331  			return nil, fmt.Errorf("attempting to mkdirat with an invalid file descriptor %d", dirfd)
   332  		}
   333  		return nil, nil
   334  	})
   335  	return err
   336  }
   337  
   338  // Mount is a fake implementation of syscall.Mount
   339  func (sys *SyscallRecorder) Mount(source string, target string, fstype string, flags uintptr, data string) error {
   340  	call := fmt.Sprintf("mount %q %q %q %s %q", source, target, fstype, formatMountFlags(int(flags)), data)
   341  	_, err := sys.rcall(call, nil)
   342  	return err
   343  }
   344  
   345  // Unmount is a fake implementation of syscall.Unmount
   346  func (sys *SyscallRecorder) Unmount(target string, flags int) error {
   347  	call := fmt.Sprintf("unmount %q %s", target, formatUnmountFlags(flags))
   348  	_, err := sys.rcall(call, nil)
   349  	return err
   350  }
   351  
   352  // InsertOsLstatResult makes given subsequent call to OsLstat return the specified fake file info.
   353  func (sys *SyscallRecorder) InsertOsLstatResult(call string, fi os.FileInfo) {
   354  	if sys.osLstats == nil {
   355  		sys.osLstats = make(map[string]os.FileInfo)
   356  	}
   357  	sys.osLstats[call] = fi
   358  }
   359  
   360  // InsertSysLstatResult makes given subsequent call to SysLstat return the specified fake file info.
   361  func (sys *SyscallRecorder) InsertSysLstatResult(call string, sb syscall.Stat_t) {
   362  	if sys.sysLstats == nil {
   363  		sys.sysLstats = make(map[string]syscall.Stat_t)
   364  	}
   365  	sys.sysLstats[call] = sb
   366  }
   367  
   368  // OsLstat is a fake implementation of os.Lstat
   369  func (sys *SyscallRecorder) OsLstat(name string) (os.FileInfo, error) {
   370  	// NOTE the syscall.Lstat uses a different signature `lstat %q <ptr>`.
   371  	call := fmt.Sprintf("lstat %q", name)
   372  	val, err := sys.rcall(call, func(call string) (interface{}, error) {
   373  		if fi, ok := sys.osLstats[call]; ok {
   374  			return fi, nil
   375  		}
   376  		panic(fmt.Sprintf("one of InsertOsLstatResult() or InsertFault() for %s must be used", call))
   377  	})
   378  	if err != nil {
   379  		return nil, err
   380  	}
   381  	return val.(os.FileInfo), err
   382  }
   383  
   384  // SysLstat is a fake implementation of syscall.Lstat
   385  func (sys *SyscallRecorder) SysLstat(name string, sb *syscall.Stat_t) error {
   386  	// NOTE the os.Lstat uses a different signature `lstat %q`.
   387  	call := fmt.Sprintf("lstat %q <ptr>", name)
   388  	val, err := sys.rcall(call, func(call string) (interface{}, error) {
   389  		if buf, ok := sys.sysLstats[call]; ok {
   390  			return buf, nil
   391  		}
   392  		panic(fmt.Sprintf("one of InsertSysLstatResult() or InsertFault() for %s must be used", call))
   393  	})
   394  	if err == nil && sb != nil {
   395  		*sb = val.(syscall.Stat_t)
   396  	}
   397  	return err
   398  }
   399  
   400  // InsertFstatResult makes given subsequent call fstat return the specified stat buffer.
   401  func (sys *SyscallRecorder) InsertFstatResult(call string, buf syscall.Stat_t) {
   402  	if sys.fstats == nil {
   403  		sys.fstats = make(map[string]syscall.Stat_t)
   404  	}
   405  	sys.fstats[call] = buf
   406  }
   407  
   408  // Fstat is a fake implementation of syscall.Fstat
   409  func (sys *SyscallRecorder) Fstat(fd int, buf *syscall.Stat_t) error {
   410  	call := fmt.Sprintf("fstat %d <ptr>", fd)
   411  	val, err := sys.rcall(call, func(call string) (interface{}, error) {
   412  		if _, ok := sys.fds[fd]; !ok {
   413  			return nil, fmt.Errorf("attempting to fstat with an invalid file descriptor %d", fd)
   414  		}
   415  		if buf, ok := sys.fstats[call]; ok {
   416  			return buf, nil
   417  		}
   418  		panic(fmt.Sprintf("one of InsertFstatResult() or InsertFault() for %s must be used", call))
   419  	})
   420  	if err == nil && buf != nil {
   421  		*buf = val.(syscall.Stat_t)
   422  	}
   423  	return err
   424  }
   425  
   426  // InsertFstatfsResult makes given subsequent call fstatfs return the specified stat buffer.
   427  func (sys *SyscallRecorder) InsertFstatfsResult(call string, bufs ...syscall.Statfs_t) {
   428  	if sys.fstatfses == nil {
   429  		sys.fstatfses = make(map[string]func() syscall.Statfs_t)
   430  	}
   431  	if len(bufs) == 0 {
   432  		panic("cannot provide zero results to InsertFstatfsResult")
   433  	}
   434  	if len(bufs) == 1 {
   435  		// deterministic behavior
   436  		sys.fstatfses[call] = func() syscall.Statfs_t {
   437  			return bufs[0]
   438  		}
   439  	} else {
   440  		// sequential results with the last element repeated forever.
   441  		sys.fstatfses[call] = func() syscall.Statfs_t {
   442  			buf := bufs[0]
   443  			if len(bufs) > 1 {
   444  				bufs = bufs[1:]
   445  			}
   446  			return buf
   447  		}
   448  	}
   449  }
   450  
   451  // Fstatfs is a fake implementation of syscall.Fstatfs
   452  func (sys *SyscallRecorder) Fstatfs(fd int, buf *syscall.Statfs_t) error {
   453  	call := fmt.Sprintf("fstatfs %d <ptr>", fd)
   454  	val, err := sys.rcall(call, func(call string) (interface{}, error) {
   455  		if _, ok := sys.fds[fd]; !ok {
   456  			return nil, fmt.Errorf("attempting to fstatfs with an invalid file descriptor %d", fd)
   457  		}
   458  		if bufFn, ok := sys.fstatfses[call]; ok {
   459  			return bufFn(), nil
   460  		}
   461  		panic(fmt.Sprintf("one of InsertFstatfsResult() or InsertFault() for %s must be used", call))
   462  	})
   463  	if err == nil && buf != nil {
   464  		*buf = val.(syscall.Statfs_t)
   465  	}
   466  	return err
   467  }
   468  
   469  // InsertReadDirResult makes given subsequent call readdir return the specified fake file infos.
   470  func (sys *SyscallRecorder) InsertReadDirResult(call string, infos []os.FileInfo) {
   471  	if sys.readdirs == nil {
   472  		sys.readdirs = make(map[string][]os.FileInfo)
   473  	}
   474  	sys.readdirs[call] = infos
   475  }
   476  
   477  // ReadDir is a fake implementation of os.ReadDir
   478  func (sys *SyscallRecorder) ReadDir(dirname string) ([]os.FileInfo, error) {
   479  	call := fmt.Sprintf("readdir %q", dirname)
   480  	val, err := sys.rcall(call, func(call string) (interface{}, error) {
   481  		if fi, ok := sys.readdirs[call]; ok {
   482  			return fi, nil
   483  		}
   484  		panic(fmt.Sprintf("one of InsertReadDirResult() or InsertFault() for %s must be used", call))
   485  	})
   486  	if err == nil {
   487  		return val.([]os.FileInfo), nil
   488  	}
   489  	return nil, err
   490  }
   491  
   492  // Symlink is a fake implementation of syscall.Symlink
   493  func (sys *SyscallRecorder) Symlink(oldname, newname string) error {
   494  	call := fmt.Sprintf("symlink %q -> %q", newname, oldname)
   495  	_, err := sys.rcall(call, nil)
   496  	return err
   497  }
   498  
   499  // Symlinkat is a fake implementation of osutil.Symlinkat (syscall.Symlinkat is not exposed)
   500  func (sys *SyscallRecorder) Symlinkat(oldname string, dirfd int, newname string) error {
   501  	call := fmt.Sprintf("symlinkat %q %d %q", oldname, dirfd, newname)
   502  	_, err := sys.rcall(call, func(call string) (interface{}, error) {
   503  		if _, ok := sys.fds[dirfd]; !ok {
   504  			return nil, fmt.Errorf("attempting to symlinkat with an invalid file descriptor %d", dirfd)
   505  		}
   506  		return nil, nil
   507  	})
   508  	return err
   509  }
   510  
   511  // InsertReadlinkatResult makes given subsequent call to readlinkat return the specified oldname.
   512  func (sys *SyscallRecorder) InsertReadlinkatResult(call, oldname string) {
   513  	if sys.readlinkats == nil {
   514  		sys.readlinkats = make(map[string]string)
   515  	}
   516  	sys.readlinkats[call] = oldname
   517  }
   518  
   519  // Readlinkat is a fake implementation of osutil.Readlinkat (syscall.Readlinkat is not exposed)
   520  func (sys *SyscallRecorder) Readlinkat(dirfd int, path string, buf []byte) (int, error) {
   521  	call := fmt.Sprintf("readlinkat %d %q <ptr>", dirfd, path)
   522  	val, err := sys.rcall(call, func(call string) (interface{}, error) {
   523  		if _, ok := sys.fds[dirfd]; !ok {
   524  			return nil, fmt.Errorf("attempting to readlinkat with an invalid file descriptor %d", dirfd)
   525  		}
   526  		if oldname, ok := sys.readlinkats[call]; ok {
   527  			return oldname, nil
   528  		}
   529  		panic(fmt.Sprintf("one of InsertReadlinkatResult() or InsertFault() for %s must be used", call))
   530  	})
   531  	if err == nil {
   532  		n := copy(buf, val.(string))
   533  		return n, nil
   534  	}
   535  	return 0, err
   536  }
   537  
   538  // Remove is a fake implementation of os.Remove
   539  func (sys *SyscallRecorder) Remove(name string) error {
   540  	call := fmt.Sprintf("remove %q", name)
   541  	_, err := sys.rcall(call, nil)
   542  	return err
   543  }
   544  
   545  // Fchdir is a fake implementation of syscall.Fchdir
   546  func (sys *SyscallRecorder) Fchdir(fd int) error {
   547  	call := fmt.Sprintf("fchdir %d", fd)
   548  	_, err := sys.rcall(call, func(call string) (interface{}, error) {
   549  		if _, ok := sys.fds[fd]; !ok {
   550  			return nil, fmt.Errorf("attempting to fchdir with an invalid file descriptor %d", fd)
   551  		}
   552  		return nil, nil
   553  	})
   554  	return err
   555  }