github.com/rigado/snapd@v2.42.5-go-mod+incompatible/osutil/export_test.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2016 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package osutil
    21  
    22  import (
    23  	"fmt"
    24  	"io"
    25  	"io/ioutil"
    26  	"os"
    27  	"os/exec"
    28  	"os/user"
    29  	"syscall"
    30  	"time"
    31  
    32  	"github.com/snapcore/snapd/osutil/sys"
    33  )
    34  
    35  func MockUserLookup(mock func(name string) (*user.User, error)) func() {
    36  	realUserLookup := userLookup
    37  	userLookup = mock
    38  
    39  	return func() { userLookup = realUserLookup }
    40  }
    41  
    42  func MockUserCurrent(mock func() (*user.User, error)) func() {
    43  	realUserCurrent := userCurrent
    44  	userCurrent = mock
    45  
    46  	return func() { userCurrent = realUserCurrent }
    47  }
    48  
    49  func MockSudoersDotD(mockDir string) func() {
    50  	realSudoersD := sudoersDotD
    51  	sudoersDotD = mockDir
    52  
    53  	return func() { sudoersDotD = realSudoersD }
    54  }
    55  
    56  func MockSyscallKill(f func(int, syscall.Signal) error) func() {
    57  	oldSyscallKill := syscallKill
    58  	syscallKill = f
    59  	return func() {
    60  		syscallKill = oldSyscallKill
    61  	}
    62  }
    63  
    64  func MockSyscallGetpgid(f func(int) (int, error)) func() {
    65  	oldSyscallGetpgid := syscallGetpgid
    66  	syscallGetpgid = f
    67  	return func() {
    68  		syscallGetpgid = oldSyscallGetpgid
    69  	}
    70  }
    71  
    72  func MockCmdWaitTimeout(timeout time.Duration) func() {
    73  	oldCmdWaitTimeout := cmdWaitTimeout
    74  	cmdWaitTimeout = timeout
    75  	return func() {
    76  		cmdWaitTimeout = oldCmdWaitTimeout
    77  	}
    78  }
    79  
    80  func WaitingReaderGuts(r io.Reader) (io.Reader, *exec.Cmd) {
    81  	wr := r.(*waitingReader)
    82  	return wr.reader, wr.cmd
    83  }
    84  
    85  func MockChown(f func(*os.File, sys.UserID, sys.GroupID) error) func() {
    86  	oldChown := chown
    87  	chown = f
    88  	return func() {
    89  		chown = oldChown
    90  	}
    91  }
    92  
    93  func SetAtomicFileRenamed(aw *AtomicFile, renamed bool) {
    94  	aw.renamed = renamed
    95  }
    96  
    97  func SetUnsafeIO(b bool) func() {
    98  	oldSnapdUnsafeIO := snapdUnsafeIO
    99  	snapdUnsafeIO = b
   100  	return func() {
   101  		snapdUnsafeIO = oldSnapdUnsafeIO
   102  	}
   103  }
   104  
   105  func MockOsReadlink(f func(string) (string, error)) func() {
   106  	realOsReadlink := osReadlink
   107  	osReadlink = f
   108  
   109  	return func() { osReadlink = realOsReadlink }
   110  }
   111  
   112  //MockMountInfo mocks content of /proc/self/mountinfo read by IsHomeUsingNFS
   113  func MockMountInfo(text string) (restore func()) {
   114  	old := procSelfMountInfo
   115  	f, err := ioutil.TempFile("", "mountinfo")
   116  	if err != nil {
   117  		panic(fmt.Errorf("cannot open temporary file: %s", err))
   118  	}
   119  	if err := ioutil.WriteFile(f.Name(), []byte(text), 0644); err != nil {
   120  		panic(fmt.Errorf("cannot write mock mountinfo file: %s", err))
   121  	}
   122  	procSelfMountInfo = f.Name()
   123  	return func() {
   124  		os.Remove(procSelfMountInfo)
   125  		procSelfMountInfo = old
   126  	}
   127  }
   128  
   129  // MockEtcFstab mocks content of /etc/fstab read by IsHomeUsingNFS
   130  func MockEtcFstab(text string) (restore func()) {
   131  	old := etcFstab
   132  	f, err := ioutil.TempFile("", "fstab")
   133  	if err != nil {
   134  		panic(fmt.Errorf("cannot open temporary file: %s", err))
   135  	}
   136  	if err := ioutil.WriteFile(f.Name(), []byte(text), 0644); err != nil {
   137  		panic(fmt.Errorf("cannot write mock fstab file: %s", err))
   138  	}
   139  	etcFstab = f.Name()
   140  	return func() {
   141  		if etcFstab == "/etc/fstab" {
   142  			panic("respectfully refusing to remove /etc/fstab")
   143  		}
   144  		os.Remove(etcFstab)
   145  		etcFstab = old
   146  	}
   147  }
   148  
   149  // MockUname mocks syscall.Uname as used by MachineName and KernelVersion
   150  func MockUname(f func(*syscall.Utsname) error) (restore func()) {
   151  	old := syscallUname
   152  	syscallUname = f
   153  
   154  	return func() {
   155  		syscallUname = old
   156  	}
   157  }
   158  
   159  var (
   160  	FindUidNoGetentFallback = findUidNoGetentFallback
   161  	FindGidNoGetentFallback = findGidNoGetentFallback
   162  
   163  	FindUidWithGetentFallback = findUidWithGetentFallback
   164  	FindGidWithGetentFallback = findGidWithGetentFallback
   165  )
   166  
   167  func MockFindUidNoFallback(mock func(name string) (uint64, error)) (restore func()) {
   168  	old := findUidNoGetentFallback
   169  	findUidNoGetentFallback = mock
   170  	return func() { findUidNoGetentFallback = old }
   171  }
   172  
   173  func MockFindGidNoFallback(mock func(name string) (uint64, error)) (restore func()) {
   174  	old := findGidNoGetentFallback
   175  	findGidNoGetentFallback = mock
   176  	return func() { findGidNoGetentFallback = old }
   177  }
   178  
   179  func MockFindUid(mock func(name string) (uint64, error)) (restore func()) {
   180  	old := findUid
   181  	findUid = mock
   182  	return func() { findUid = old }
   183  }
   184  
   185  func MockFindGid(mock func(name string) (uint64, error)) (restore func()) {
   186  	old := findGid
   187  	findGid = mock
   188  	return func() { findGid = old }
   189  }