github.com/canonical/ubuntu-image@v0.0.0-20240430122802-2202fe98b290/internal/testhelper/testhelper.go (about) 1 // Package testhelper provides helpers to ease mocking functions and methods 2 // provided by packages such as os or http. 3 package testhelper 4 5 import ( 6 "fmt" 7 "io" 8 "io/fs" 9 "net/http" 10 "os" 11 ) 12 13 // SaveCWD gets the current working directory and returns a function to go back to it 14 // nolint: errcheck 15 func SaveCWD() func() { 16 wd, _ := os.Getwd() 17 return func() { 18 _ = os.Chdir(wd) 19 } 20 } 21 22 // OSMockConf enables setting thresholds to indicate how many calls a the mocked 23 // functions should accept before returning an error. 24 // See osMock methods for specific behaviors. 25 type OSMockConf struct { 26 OsutilCopySpecialFileThreshold uint 27 ReadDirThreshold uint 28 RemoveThreshold uint 29 TruncateThreshold uint 30 OpenFileThreshold uint 31 MkdirAllThreshold uint 32 HttpGetThreshold uint 33 ReadAllThreshold uint 34 } 35 36 // osMock holds methods to easily mock functions from os and snapd/osutil packages 37 // Each method can be configured to fail after a given number of calls 38 // This could be improved by letting the mock functions calls the real 39 // functions before failing. 40 type osMock struct { 41 conf *OSMockConf 42 beforeOsutilCopySpecialFileFail uint 43 beforeReadDirFail uint 44 beforeRemoveFail uint 45 beforeTruncateFail uint 46 beforeOpenFileFail uint 47 beforeMkdirAllFail uint 48 beforeHttpGetFail uint 49 beforeReadAllFail uint 50 } 51 52 // CopySpecialFile mocks CopySpecialFile github.com/snapcore/snapd/osutil 53 func (o *osMock) CopySpecialFile(path, dest string) error { 54 if o.beforeOsutilCopySpecialFileFail >= o.conf.OsutilCopySpecialFileThreshold { 55 return fmt.Errorf("CopySpecialFile fail") 56 } 57 o.beforeOsutilCopySpecialFileFail++ 58 59 return nil 60 } 61 62 // ReadDir mocks os.ReadDir 63 func (o *osMock) ReadDir(name string) ([]fs.DirEntry, error) { 64 if o.beforeReadDirFail >= o.conf.ReadDirThreshold { 65 return nil, fmt.Errorf("ReadDir fail") 66 } 67 o.beforeReadDirFail++ 68 69 return []fs.DirEntry{}, nil 70 } 71 72 // Remove mocks os.Remove 73 func (o *osMock) Remove(name string) error { 74 if o.beforeRemoveFail >= o.conf.RemoveThreshold { 75 return fmt.Errorf("Remove fail") 76 } 77 o.beforeRemoveFail++ 78 79 return nil 80 } 81 82 // Truncate mocks osTruncate 83 func (o *osMock) Truncate(name string, size int64) error { 84 if o.beforeTruncateFail >= o.conf.TruncateThreshold { 85 return fmt.Errorf("Truncate fail") 86 } 87 o.beforeTruncateFail++ 88 89 return nil 90 } 91 92 // OpenFile mocks os.OpenFile 93 func (o *osMock) OpenFile(name string, flag int, perm os.FileMode) (*os.File, error) { 94 if o.beforeOpenFileFail >= o.conf.OpenFileThreshold { 95 return nil, fmt.Errorf("OpenFile fail") 96 } 97 o.beforeOpenFileFail++ 98 99 return &os.File{}, nil 100 } 101 102 // MkdirAll mocks os.MkdirAll 103 func (o *osMock) MkdirAll(path string, perm os.FileMode) error { 104 if o.beforeOpenFileFail >= o.conf.OpenFileThreshold { 105 return fmt.Errorf("OpenFile fail") 106 } 107 o.beforeMkdirAllFail++ 108 109 return nil 110 } 111 112 // HttpGet mocks http.Get 113 func (o *osMock) HttpGet(path string) (*http.Response, error) { 114 if o.beforeHttpGetFail >= o.conf.HttpGetThreshold { 115 return nil, fmt.Errorf("HttpGet fail") 116 } 117 o.beforeHttpGetFail++ 118 119 return &http.Response{}, nil 120 } 121 122 // ReadAll mocks os.ReadAll 123 func (o *osMock) ReadAll(io.Reader) ([]byte, error) { 124 if o.beforeReadAllFail >= o.conf.ReadAllThreshold { 125 return nil, fmt.Errorf("ReadAll fail") 126 } 127 o.beforeReadAllFail++ 128 129 return []byte{}, nil 130 } 131 132 func NewOSMock(conf *OSMockConf) *osMock { 133 return &osMock{conf: conf} 134 }