github.com/stolowski/snapd@v0.0.0-20210407085831-115137ce5a22/osutil/export_test.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2016 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package osutil
    21  
    22  import (
    23  	"fmt"
    24  	"io"
    25  	"io/ioutil"
    26  	"os"
    27  	"os/exec"
    28  	"os/user"
    29  	"syscall"
    30  	"time"
    31  
    32  	"github.com/snapcore/snapd/osutil/sys"
    33  	"github.com/snapcore/snapd/strutil"
    34  )
    35  
    36  var (
    37  	StreamsEqualChunked  = streamsEqualChunked
    38  	FilesAreEqualChunked = filesAreEqualChunked
    39  	SudoersFile          = sudoersFile
    40  	DoCopyFile           = doCopyFile
    41  )
    42  
    43  type Fileish = fileish
    44  
    45  func MockMaxCp(new int64) (restore func()) {
    46  	old := maxcp
    47  	maxcp = new
    48  	return func() {
    49  		maxcp = old
    50  	}
    51  }
    52  
    53  func MockCopyFile(new func(fileish, fileish, os.FileInfo) error) (restore func()) {
    54  	old := copyfile
    55  	copyfile = new
    56  	return func() {
    57  		copyfile = old
    58  	}
    59  }
    60  
    61  func MockOpenFile(new func(string, int, os.FileMode) (fileish, error)) (restore func()) {
    62  	old := openfile
    63  	openfile = new
    64  	return func() {
    65  		openfile = old
    66  	}
    67  }
    68  
    69  func MockSyscallSettimeofday(f func(*syscall.Timeval) error) (restore func()) {
    70  	old := syscallSettimeofday
    71  	syscallSettimeofday = f
    72  	return func() {
    73  		syscallSettimeofday = old
    74  	}
    75  }
    76  
    77  func MockUserLookup(mock func(name string) (*user.User, error)) func() {
    78  	realUserLookup := userLookup
    79  	userLookup = mock
    80  
    81  	return func() { userLookup = realUserLookup }
    82  }
    83  
    84  func MockUserCurrent(mock func() (*user.User, error)) func() {
    85  	realUserCurrent := userCurrent
    86  	userCurrent = mock
    87  
    88  	return func() { userCurrent = realUserCurrent }
    89  }
    90  
    91  func MockSudoersDotD(mockDir string) func() {
    92  	realSudoersD := sudoersDotD
    93  	sudoersDotD = mockDir
    94  
    95  	return func() { sudoersDotD = realSudoersD }
    96  }
    97  
    98  func MockSyscallKill(f func(int, syscall.Signal) error) func() {
    99  	oldSyscallKill := syscallKill
   100  	syscallKill = f
   101  	return func() {
   102  		syscallKill = oldSyscallKill
   103  	}
   104  }
   105  
   106  func MockSyscallStatfs(f func(string, *syscall.Statfs_t) error) func() {
   107  	oldSyscallStatfs := syscallStatfs
   108  	syscallStatfs = f
   109  	return func() {
   110  		syscallStatfs = oldSyscallStatfs
   111  	}
   112  }
   113  
   114  func MockSyscallGetpgid(f func(int) (int, error)) func() {
   115  	oldSyscallGetpgid := syscallGetpgid
   116  	syscallGetpgid = f
   117  	return func() {
   118  		syscallGetpgid = oldSyscallGetpgid
   119  	}
   120  }
   121  
   122  func MockCmdWaitTimeout(timeout time.Duration) func() {
   123  	oldCmdWaitTimeout := cmdWaitTimeout
   124  	cmdWaitTimeout = timeout
   125  	return func() {
   126  		cmdWaitTimeout = oldCmdWaitTimeout
   127  	}
   128  }
   129  
   130  func WaitingReaderGuts(r io.Reader) (io.Reader, *exec.Cmd) {
   131  	wr := r.(*waitingReader)
   132  	return wr.reader, wr.cmd
   133  }
   134  
   135  func MockChown(f func(*os.File, sys.UserID, sys.GroupID) error) func() {
   136  	oldChown := chown
   137  	chown = f
   138  	return func() {
   139  		chown = oldChown
   140  	}
   141  }
   142  
   143  func MockLookPath(new func(string) (string, error)) (restore func()) {
   144  	old := lookPath
   145  	lookPath = new
   146  	return func() {
   147  		lookPath = old
   148  	}
   149  }
   150  
   151  func SetAtomicFileRenamed(aw *AtomicFile, renamed bool) {
   152  	aw.renamed = renamed
   153  }
   154  
   155  func SetUnsafeIO(b bool) func() {
   156  	oldSnapdUnsafeIO := snapdUnsafeIO
   157  	snapdUnsafeIO = b
   158  	return func() {
   159  		snapdUnsafeIO = oldSnapdUnsafeIO
   160  	}
   161  }
   162  
   163  func GetUnsafeIO() bool {
   164  	// a getter so that tests do not attempt to modify that directly
   165  	return snapdUnsafeIO
   166  }
   167  
   168  func MockOsReadlink(f func(string) (string, error)) func() {
   169  	realOsReadlink := osReadlink
   170  	osReadlink = f
   171  
   172  	return func() { osReadlink = realOsReadlink }
   173  }
   174  
   175  // MockEtcFstab mocks content of /etc/fstab read by IsHomeUsingNFS
   176  func MockEtcFstab(text string) (restore func()) {
   177  	old := etcFstab
   178  	f, err := ioutil.TempFile("", "fstab")
   179  	if err != nil {
   180  		panic(fmt.Errorf("cannot open temporary file: %s", err))
   181  	}
   182  	if err := ioutil.WriteFile(f.Name(), []byte(text), 0644); err != nil {
   183  		panic(fmt.Errorf("cannot write mock fstab file: %s", err))
   184  	}
   185  	etcFstab = f.Name()
   186  	return func() {
   187  		if etcFstab == "/etc/fstab" {
   188  			panic("respectfully refusing to remove /etc/fstab")
   189  		}
   190  		os.Remove(etcFstab)
   191  		etcFstab = old
   192  	}
   193  }
   194  
   195  // MockUname mocks syscall.Uname as used by MachineName and KernelVersion
   196  func MockUname(f func(*syscall.Utsname) error) (restore func()) {
   197  	old := syscallUname
   198  	syscallUname = f
   199  
   200  	return func() {
   201  		syscallUname = old
   202  	}
   203  }
   204  
   205  var (
   206  	FindUidNoGetentFallback = findUidNoGetentFallback
   207  	FindGidNoGetentFallback = findGidNoGetentFallback
   208  
   209  	FindUidWithGetentFallback = findUidWithGetentFallback
   210  	FindGidWithGetentFallback = findGidWithGetentFallback
   211  )
   212  
   213  func MockFindUidNoFallback(mock func(name string) (uint64, error)) (restore func()) {
   214  	old := findUidNoGetentFallback
   215  	findUidNoGetentFallback = mock
   216  	return func() { findUidNoGetentFallback = old }
   217  }
   218  
   219  func MockFindGidNoFallback(mock func(name string) (uint64, error)) (restore func()) {
   220  	old := findGidNoGetentFallback
   221  	findGidNoGetentFallback = mock
   222  	return func() { findGidNoGetentFallback = old }
   223  }
   224  
   225  const MaxSymlinkTries = maxSymlinkTries
   226  
   227  var ParseRawEnvironment = parseRawEnvironment
   228  
   229  // ParseRawExpandableEnv returns a new expandable environment parsed from key=value strings.
   230  func ParseRawExpandableEnv(entries []string) (ExpandableEnv, error) {
   231  	om := strutil.NewOrderedMap()
   232  	for _, entry := range entries {
   233  		key, value, err := parseEnvEntry(entry)
   234  		if err != nil {
   235  			return ExpandableEnv{}, err
   236  		}
   237  		if om.Get(key) != "" {
   238  			return ExpandableEnv{}, fmt.Errorf("cannot overwrite earlier value of %q", key)
   239  		}
   240  		om.Set(key, value)
   241  	}
   242  	return ExpandableEnv{OrderedMap: om}, nil
   243  }
   244  
   245  // this is weird to use in a test, but it is so that we can test the actual
   246  // implementation of LoadMountInfo, which normally panics during tests if not
   247  // properly mocked
   248  func MountInfoMustMock(new bool) (restore func()) {
   249  	old := mountInfoMustMockInTests
   250  	mountInfoMustMockInTests = new
   251  	return func() {
   252  		mountInfoMustMockInTests = old
   253  	}
   254  }
   255  
   256  // this should not be used except to test the actual implementation logic of
   257  // LoadMountInfo, if you are trying to mock /proc/self/mountinfo in a test,
   258  // use MockMountInfo(), which is exported and the right way to do that.
   259  func MockProcSelfMountInfoLocation(new string) (restore func()) {
   260  	old := procSelfMountInfo
   261  	procSelfMountInfo = new
   262  	return func() {
   263  		procSelfMountInfo = old
   264  	}
   265  }