github.com/canonical/ubuntu-image@v0.0.0-20240430122802-2202fe98b290/internal/testhelper/testhelper.go (about)

     1  // Package testhelper provides helpers to ease mocking functions and methods
     2  // provided by packages such as os or http.
     3  package testhelper
     4  
     5  import (
     6  	"fmt"
     7  	"io"
     8  	"io/fs"
     9  	"net/http"
    10  	"os"
    11  )
    12  
    13  // SaveCWD gets the current working directory and returns a function to go back to it
    14  // nolint: errcheck
    15  func SaveCWD() func() {
    16  	wd, _ := os.Getwd()
    17  	return func() {
    18  		_ = os.Chdir(wd)
    19  	}
    20  }
    21  
    22  // OSMockConf enables setting thresholds to indicate how many calls a the mocked
    23  // functions should accept before returning an error.
    24  // See osMock methods for specific behaviors.
    25  type OSMockConf struct {
    26  	OsutilCopySpecialFileThreshold uint
    27  	ReadDirThreshold               uint
    28  	RemoveThreshold                uint
    29  	TruncateThreshold              uint
    30  	OpenFileThreshold              uint
    31  	MkdirAllThreshold              uint
    32  	HttpGetThreshold               uint
    33  	ReadAllThreshold               uint
    34  }
    35  
    36  // osMock holds methods to easily mock functions from os and snapd/osutil packages
    37  // Each method can be configured to fail after a given number of calls
    38  // This could be improved by letting the mock functions calls the real
    39  // functions before failing.
    40  type osMock struct {
    41  	conf                            *OSMockConf
    42  	beforeOsutilCopySpecialFileFail uint
    43  	beforeReadDirFail               uint
    44  	beforeRemoveFail                uint
    45  	beforeTruncateFail              uint
    46  	beforeOpenFileFail              uint
    47  	beforeMkdirAllFail              uint
    48  	beforeHttpGetFail               uint
    49  	beforeReadAllFail               uint
    50  }
    51  
    52  // CopySpecialFile mocks CopySpecialFile github.com/snapcore/snapd/osutil
    53  func (o *osMock) CopySpecialFile(path, dest string) error {
    54  	if o.beforeOsutilCopySpecialFileFail >= o.conf.OsutilCopySpecialFileThreshold {
    55  		return fmt.Errorf("CopySpecialFile fail")
    56  	}
    57  	o.beforeOsutilCopySpecialFileFail++
    58  
    59  	return nil
    60  }
    61  
    62  // ReadDir mocks os.ReadDir
    63  func (o *osMock) ReadDir(name string) ([]fs.DirEntry, error) {
    64  	if o.beforeReadDirFail >= o.conf.ReadDirThreshold {
    65  		return nil, fmt.Errorf("ReadDir fail")
    66  	}
    67  	o.beforeReadDirFail++
    68  
    69  	return []fs.DirEntry{}, nil
    70  }
    71  
    72  // Remove mocks os.Remove
    73  func (o *osMock) Remove(name string) error {
    74  	if o.beforeRemoveFail >= o.conf.RemoveThreshold {
    75  		return fmt.Errorf("Remove fail")
    76  	}
    77  	o.beforeRemoveFail++
    78  
    79  	return nil
    80  }
    81  
    82  // Truncate mocks osTruncate
    83  func (o *osMock) Truncate(name string, size int64) error {
    84  	if o.beforeTruncateFail >= o.conf.TruncateThreshold {
    85  		return fmt.Errorf("Truncate fail")
    86  	}
    87  	o.beforeTruncateFail++
    88  
    89  	return nil
    90  }
    91  
    92  // OpenFile mocks os.OpenFile
    93  func (o *osMock) OpenFile(name string, flag int, perm os.FileMode) (*os.File, error) {
    94  	if o.beforeOpenFileFail >= o.conf.OpenFileThreshold {
    95  		return nil, fmt.Errorf("OpenFile fail")
    96  	}
    97  	o.beforeOpenFileFail++
    98  
    99  	return &os.File{}, nil
   100  }
   101  
   102  // MkdirAll mocks os.MkdirAll
   103  func (o *osMock) MkdirAll(path string, perm os.FileMode) error {
   104  	if o.beforeOpenFileFail >= o.conf.OpenFileThreshold {
   105  		return fmt.Errorf("OpenFile fail")
   106  	}
   107  	o.beforeMkdirAllFail++
   108  
   109  	return nil
   110  }
   111  
   112  // HttpGet mocks http.Get
   113  func (o *osMock) HttpGet(path string) (*http.Response, error) {
   114  	if o.beforeHttpGetFail >= o.conf.HttpGetThreshold {
   115  		return nil, fmt.Errorf("HttpGet fail")
   116  	}
   117  	o.beforeHttpGetFail++
   118  
   119  	return &http.Response{}, nil
   120  }
   121  
   122  // ReadAll mocks os.ReadAll
   123  func (o *osMock) ReadAll(io.Reader) ([]byte, error) {
   124  	if o.beforeReadAllFail >= o.conf.ReadAllThreshold {
   125  		return nil, fmt.Errorf("ReadAll fail")
   126  	}
   127  	o.beforeReadAllFail++
   128  
   129  	return []byte{}, nil
   130  }
   131  
   132  func NewOSMock(conf *OSMockConf) *osMock {
   133  	return &osMock{conf: conf}
   134  }