github.com/canonical/ubuntu-image@v0.0.0-20240430122802-2202fe98b290/internal/helper/helper.go (about)

     1  package helper
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/sha256"
     6  	"fmt"
     7  	"io"
     8  	"os"
     9  	"os/exec"
    10  	"path/filepath"
    11  	"reflect"
    12  	"strings"
    13  
    14  	"github.com/invopop/jsonschema"
    15  	"github.com/snapcore/snapd/gadget/quantity"
    16  	"github.com/snapcore/snapd/osutil"
    17  	"github.com/xeipuuv/gojsonschema"
    18  
    19  	"github.com/canonical/ubuntu-image/internal/commands"
    20  )
    21  
    22  // define some functions that can be mocked by test cases
    23  var (
    24  	osRename         = os.Rename
    25  	osRemove         = os.Remove
    26  	osWriteFile      = os.WriteFile
    27  	osutilFileExists = osutil.FileExists
    28  )
    29  
    30  func BoolPtr(b bool) *bool {
    31  	return &b
    32  }
    33  
    34  // CaptureStd returns an io.Reader to read what was printed, and teardown
    35  func CaptureStd(toCap **os.File) (io.Reader, func(), error) {
    36  	stdCap, stdCapW, err := os.Pipe()
    37  	if err != nil {
    38  		return nil, nil, err
    39  	}
    40  	oldStdCap := *toCap
    41  	*toCap = stdCapW
    42  	closed := false
    43  	return stdCap, func() {
    44  		// only teardown once
    45  		if closed {
    46  			return
    47  		}
    48  		*toCap = oldStdCap
    49  		stdCapW.Close()
    50  		closed = true
    51  	}, nil
    52  }
    53  
    54  // InitCommonOpts initializes default common options for state machines.
    55  // This is used for test scenarios to avoid nil pointer dereferences
    56  func InitCommonOpts() (*commands.CommonOpts, *commands.StateMachineOpts) {
    57  	commonOpts := new(commands.CommonOpts)
    58  	// This is a workaround to set the default value for test cases. Normally
    59  	// go-flags makes sure that the option has a sane value at all times, but
    60  	// for tests we'd have to set it manually all the time.
    61  	commonOpts.SectorSize = "512"
    62  	return commonOpts, new(commands.StateMachineOpts)
    63  }
    64  
    65  // RunScript runs scripts from disk. Currently only used for hooks
    66  func RunScript(hookScript string) error {
    67  	hookScriptCmd := exec.Command(hookScript)
    68  	hookScriptCmd.Env = os.Environ()
    69  	hookScriptCmd.Stdout = os.Stdout
    70  	hookScriptCmd.Stderr = os.Stderr
    71  	if err := hookScriptCmd.Run(); err != nil {
    72  		return fmt.Errorf("Error running hook script %s: %s", hookScript, err.Error())
    73  	}
    74  	return nil
    75  }
    76  
    77  // Du recurses through a directory similar to du and adds all the sizes of files together
    78  func Du(path string) (quantity.Size, error) {
    79  	duCommand := *exec.Command("du", "-s", "-B1")
    80  	duCommand.Args = append(duCommand.Args, path)
    81  
    82  	duBytes, err := duCommand.Output()
    83  	if err != nil {
    84  		return quantity.Size(0), err
    85  	}
    86  	sizeString := strings.Split(string(duBytes), "\t")[0]
    87  	size, err := quantity.ParseSize(sizeString)
    88  	return size, err
    89  }
    90  
    91  // CopyBlob runs `dd` to copy a blob to an image file
    92  func CopyBlob(ddArgs []string) error {
    93  	ddCommand := *exec.Command("dd")
    94  	ddCommand.Args = append(ddCommand.Args, ddArgs...)
    95  
    96  	if err := ddCommand.Run(); err != nil {
    97  		return fmt.Errorf("Command \"%s\" returned with %s", ddCommand.String(), err.Error())
    98  	}
    99  	return nil
   100  }
   101  
   102  // SetDefaults iterates through the keys in a struct and sets
   103  // default values if one is specified with a struct tag of "default".
   104  // Currently only default values of strings, slice of strings, and
   105  // bools are supported
   106  func SetDefaults(needsDefaults interface{}) error {
   107  	value := reflect.ValueOf(needsDefaults)
   108  	if value.Kind() != reflect.Ptr {
   109  		return fmt.Errorf("The argument to SetDefaults must be a pointer")
   110  	}
   111  	elem := value.Elem()
   112  	for i := 0; i < elem.NumField(); i++ {
   113  		field := elem.Field(i)
   114  		// if we're dealing with a slice of pointers to structs,
   115  		// iterate through it and set the defaults for each struct pointer
   116  		if isSliceOfPtrToStructs(field) {
   117  			err := setDefaultsToSlice(field)
   118  			if err != nil {
   119  				return err
   120  			}
   121  		} else if field.Type().Kind() == reflect.Ptr {
   122  			err := setDefaultToPtr(field, elem, i)
   123  			if err != nil {
   124  				return err
   125  			}
   126  		} else {
   127  			err := setDefaultToBasicType(field, elem, i)
   128  			if err != nil {
   129  				return err
   130  			}
   131  		}
   132  	}
   133  	return nil
   134  }
   135  
   136  // setDefaultsToSlice sets default values to elements of a slice.
   137  // It assumes it was already checked field is a non empty slice. Otherwise this
   138  // function will probably panic.
   139  func setDefaultsToSlice(field reflect.Value) error {
   140  	for i := 0; i < field.Cap(); i++ {
   141  		err := SetDefaults(field.Index(i).Interface())
   142  		if err != nil {
   143  			return err
   144  		}
   145  	}
   146  	return nil
   147  }
   148  
   149  // setDefaultToPtr sets a default value to a field being a pointer.
   150  // It assumes it was already checked that field is a ptr. Otherwise this
   151  // function will probably panic.
   152  func setDefaultToPtr(field reflect.Value, elem reflect.Value, fieldIndex int) error {
   153  	// if it's a pointer to a struct, look for default types
   154  	if field.Elem().Kind() == reflect.Struct {
   155  		err := SetDefaults(field.Interface())
   156  		if err != nil {
   157  			return err
   158  		}
   159  		// special case for pointer to bools
   160  	} else if field.Type().Elem() == reflect.TypeOf(true) {
   161  		// if a value is set, do nothing
   162  		if !field.IsNil() {
   163  			return nil
   164  		}
   165  		tags := elem.Type().Field(fieldIndex).Tag
   166  		defaultValue, hasDefault := tags.Lookup("default")
   167  		if !hasDefault {
   168  			// If no default and no value is set, make sure we have a valid
   169  			// value consistent with the "zero" value for a bool (false)
   170  			field.Set(reflect.ValueOf(BoolPtr(false)))
   171  			return nil
   172  		}
   173  		if defaultValue == "true" {
   174  			field.Set(reflect.ValueOf(BoolPtr(true)))
   175  		} else {
   176  			field.Set(reflect.ValueOf(BoolPtr(false)))
   177  		}
   178  	}
   179  	return nil
   180  }
   181  
   182  // setDefaultToBasicType sets the default value to a basic type (and slice) based on the
   183  // "default" tag.
   184  func setDefaultToBasicType(field reflect.Value, elem reflect.Value, fieldIndex int) error {
   185  	tags := elem.Type().Field(fieldIndex).Tag
   186  	defaultValue, hasDefault := tags.Lookup("default")
   187  	if !hasDefault {
   188  		return nil
   189  	}
   190  	indirectedField := reflect.Indirect(field)
   191  	if indirectedField.CanSet() && field.IsZero() {
   192  		varType := field.Type().Kind()
   193  		switch varType {
   194  		case reflect.String:
   195  			field.SetString(defaultValue)
   196  		case reflect.Slice:
   197  			defaultValues := strings.Split(defaultValue, ",")
   198  			field.Set(reflect.ValueOf(defaultValues))
   199  		case reflect.Bool:
   200  			return fmt.Errorf("Setting default value of a boolean not supported. Use a pointer to boolean instead.")
   201  		default:
   202  			return fmt.Errorf("Setting default value of type %s not supported",
   203  				varType)
   204  		}
   205  	}
   206  	return nil
   207  }
   208  
   209  // CheckEmptyFields iterates through the image definition struct and
   210  // checks for fields that are present but return IsZero == true.
   211  // TODO: I've created a PR upstream in xeipuuv/gojsonschema
   212  // https://github.com/xeipuuv/gojsonschema/pull/352
   213  // if it gets merged this can be deleted
   214  func CheckEmptyFields(Interface interface{}, result *gojsonschema.Result, schema *jsonschema.Schema) error {
   215  	value := reflect.ValueOf(Interface)
   216  	if value.Kind() != reflect.Ptr {
   217  		return fmt.Errorf("The argument to CheckEmptyFields must be a pointer")
   218  	}
   219  	elem := value.Elem()
   220  	for i := 0; i < elem.NumField(); i++ {
   221  		field := elem.Field(i)
   222  		// if we're dealing with a slice, iterate through
   223  		// it and search for missing required fields in each
   224  		// element of the slice
   225  		if field.Type().Kind() == reflect.Slice {
   226  			err := checkEmptyFieldsInSlice(field, result, schema)
   227  			if err != nil {
   228  				return err
   229  			}
   230  		} else if field.Type().Kind() == reflect.Ptr {
   231  			// otherwise if it's just a pointer to a nested struct
   232  			// search it for empty required fields
   233  			err := checkEmptyFieldsInPtr(field, result, schema)
   234  			if err != nil {
   235  				return err
   236  			}
   237  		} else {
   238  			tags := elem.Type().Field(i).Tag
   239  			if !isRequiredFromTags(tags) && !isRequiredFromSchema(elem, i, schema) {
   240  				continue
   241  			}
   242  			// this is a required field, check for zero values
   243  			if !reflect.Indirect(field).IsZero() {
   244  				continue
   245  			}
   246  			jsonContext := gojsonschema.NewJsonContext("image_definition", nil)
   247  			errDetail := gojsonschema.ErrorDetails{
   248  				"property": tags.Get("yaml"),
   249  				"parent":   elem.Type().Name(),
   250  			}
   251  			result.AddError(
   252  				newMissingFieldError(
   253  					gojsonschema.NewJsonContext("missing_field", jsonContext),
   254  					52,
   255  					errDetail,
   256  				),
   257  				errDetail,
   258  			)
   259  		}
   260  	}
   261  	return nil
   262  }
   263  
   264  func checkEmptyFieldsInSlice(field reflect.Value, result *gojsonschema.Result, schema *jsonschema.Schema) error {
   265  	for i := 0; i < field.Cap(); i++ {
   266  		sliceElem := field.Index(i)
   267  		if sliceElem.Kind() == reflect.Ptr && sliceElem.Elem().Kind() == reflect.Struct {
   268  			err := CheckEmptyFields(sliceElem.Interface(), result, schema)
   269  			if err != nil {
   270  				return err
   271  			}
   272  		}
   273  	}
   274  	return nil
   275  }
   276  
   277  func checkEmptyFieldsInPtr(field reflect.Value, result *gojsonschema.Result, schema *jsonschema.Schema) error {
   278  	if field.Elem().Kind() == reflect.Struct {
   279  		err := CheckEmptyFields(field.Interface(), result, schema)
   280  		if err != nil {
   281  			return err
   282  		}
   283  	}
   284  	return nil
   285  }
   286  
   287  // isRequiredFromTags checks if the field is required from the JSON tags
   288  func isRequiredFromTags(tags reflect.StructTag) bool {
   289  	jsonTag, hasJSON := tags.Lookup("json")
   290  	if hasJSON {
   291  		if !strings.Contains(jsonTag, "omitempty") {
   292  			return true
   293  		}
   294  	}
   295  	return false
   296  }
   297  
   298  // isRequiredFromSchema checks if the field is required from the schema
   299  func isRequiredFromSchema(elem reflect.Value, i int, schema *jsonschema.Schema) bool {
   300  	for _, requiredField := range schema.Required {
   301  		if elem.Type().Field(i).Name == requiredField {
   302  			return true
   303  		}
   304  	}
   305  	return false
   306  }
   307  
   308  func newMissingFieldError(context *gojsonschema.JsonContext, value interface{}, details gojsonschema.ErrorDetails) *MissingFieldError {
   309  	err := MissingFieldError{}
   310  	err.SetContext(context)
   311  	err.SetType("missing_field_error")
   312  	err.SetValue(value)
   313  	err.SetDescriptionFormat("Key \"{{.property}}\" is required in struct \"{{.parent}}\", but is not in the YAML file!")
   314  	err.SetDetails(details)
   315  
   316  	return &err
   317  }
   318  
   319  // MissingFieldError is used when the fields exist but are the zero value for their type
   320  type MissingFieldError struct {
   321  	gojsonschema.ResultErrorFields
   322  }
   323  
   324  // SliceHasElement searches for a string in a slice of strings and returns whether it
   325  // is found
   326  func SliceHasElement(haystack []string, needle string) bool {
   327  	found := false
   328  	for _, element := range haystack {
   329  		if element == needle {
   330  			found = true
   331  		}
   332  	}
   333  	return found
   334  }
   335  
   336  // SetCommandOutput sets the output of a command to either use a multiwriter
   337  // or behave as a normal command and store the output in a buffer
   338  func SetCommandOutput(cmd *exec.Cmd, liveOutput bool) (cmdOutput *bytes.Buffer) {
   339  	var cmdOutputBuffer bytes.Buffer
   340  	cmdOutput = &cmdOutputBuffer
   341  	cmd.Stdout = cmdOutput
   342  	cmd.Stderr = cmdOutput
   343  	if liveOutput {
   344  		mwriter := io.MultiWriter(os.Stdout, cmdOutput)
   345  		cmd.Stdout = mwriter
   346  		cmd.Stderr = mwriter
   347  	}
   348  	return cmdOutput
   349  }
   350  
   351  func RunCmd(cmd *exec.Cmd, debug bool) error {
   352  	output := SetCommandOutput(cmd, debug)
   353  	err := cmd.Run()
   354  	if err != nil {
   355  		return fmt.Errorf("Error running command \"%s\". Error: %s. Output:\n%s",
   356  			cmd.String(), err.Error(), output.String())
   357  	}
   358  	return nil
   359  }
   360  
   361  // RunCmds runs a list of commands and returns the error
   362  // It stops at the first error
   363  func RunCmds(cmds []*exec.Cmd, debug bool) error {
   364  	for _, cmd := range cmds {
   365  		err := RunCmd(cmd, debug)
   366  		if err != nil {
   367  			return err
   368  		}
   369  	}
   370  
   371  	return nil
   372  }
   373  
   374  // SafeQuantitySubtraction subtracts quantities while checking for integer underflow
   375  func SafeQuantitySubtraction(orig, subtract quantity.Size) quantity.Size {
   376  	if subtract > orig {
   377  		return 0
   378  	}
   379  	return orig - subtract
   380  }
   381  
   382  // CreateTarArchive places all of the files from a source directory into a tar.
   383  // Currently supported are uncompressed tar archives and the following
   384  // compression types: zip, gzip, xz bzip2, zstd
   385  func CreateTarArchive(src, dest, compression string, verbose, debug bool) error {
   386  	tarCommand := *exec.Command(
   387  		"tar",
   388  		"--directory",
   389  		src,
   390  		"--xattrs",
   391  		"--xattrs-include=*",
   392  		"--create",
   393  		"--file",
   394  		dest,
   395  		".",
   396  	)
   397  	if debug {
   398  		tarCommand.Args = append(tarCommand.Args, "--verbose")
   399  	}
   400  	// set up any compression arguments
   401  	switch compression {
   402  	case "uncompressed":
   403  		break
   404  	case "bzip2":
   405  		tarCommand.Args = append(tarCommand.Args, "--bzip2")
   406  	case "gzip":
   407  		tarCommand.Args = append(tarCommand.Args, "--gzip")
   408  	case "xz":
   409  		tarCommand.Args = append(tarCommand.Args, "--xz")
   410  	case "zstd":
   411  		tarCommand.Args = append(tarCommand.Args, "--zstd")
   412  	default:
   413  		return fmt.Errorf("Unknown compression type: \"%s\"", compression)
   414  	}
   415  
   416  	tarOutput := SetCommandOutput(&tarCommand, debug)
   417  	if err := tarCommand.Run(); err != nil {
   418  		return fmt.Errorf("Error running \"tar\" command \"%s\". "+
   419  			"Error is \"%s\". Full output below:\n%s",
   420  			tarCommand.String(), err.Error(), tarOutput.String())
   421  	}
   422  	return nil
   423  }
   424  
   425  // ExtractTarArchive extracts all the files from a tar. Currently supported are
   426  // uncompressed tar archives and the following compression types: zip, gzip, xz
   427  // bzip2, zstd
   428  func ExtractTarArchive(src, dest string, verbose, debug bool) error {
   429  	tarCommand := *exec.Command(
   430  		"tar",
   431  		"--xattrs",
   432  		"--xattrs-include=*",
   433  		"--extract",
   434  		"--file",
   435  		src,
   436  		"--directory",
   437  		dest,
   438  	)
   439  	if debug {
   440  		tarCommand.Args = append(tarCommand.Args, "--verbose")
   441  	}
   442  	tarOutput := SetCommandOutput(&tarCommand, debug)
   443  	if err := tarCommand.Run(); err != nil {
   444  		return fmt.Errorf("Error running \"tar\" command \"%s\". "+
   445  			"Error is \"%s\". Full output below:\n%s",
   446  			tarCommand.String(), err.Error(), tarOutput.String())
   447  	}
   448  	return nil
   449  }
   450  
   451  // CalculateSHA256 calculates the SHA256 sum of the file provided as an argument
   452  func CalculateSHA256(fileName string) (string, error) {
   453  	f, err := os.Open(fileName)
   454  	if err != nil {
   455  		return "", fmt.Errorf("Error opening file \"%s\" to calculate SHA256 sum: \"%s\"", fileName, err.Error())
   456  	}
   457  	defer f.Close()
   458  
   459  	hasher := sha256.New()
   460  	_, err = io.Copy(hasher, f)
   461  	if err != nil {
   462  		return "", fmt.Errorf("Error calculating SHA256 sum of file \"%s\": \"%s\"", fileName, err.Error())
   463  	}
   464  
   465  	return string(hasher.Sum(nil)), nil
   466  }
   467  
   468  // CheckTags iterates through the keys in a struct and looks for
   469  // a value passed in as a parameter. It returns the yaml name of
   470  // the key and an error. Currently only boolean values for the tags
   471  // are supported
   472  func CheckTags(searchStruct interface{}, tag string) (string, error) {
   473  	value := reflect.ValueOf(searchStruct)
   474  	if value.Kind() != reflect.Ptr {
   475  		return "", fmt.Errorf("The argument to CheckTags must be a pointer")
   476  	}
   477  	elem := value.Elem()
   478  
   479  	return checkTagsOnField(elem, tag)
   480  }
   481  
   482  func isSliceOfPtrToStructs(field reflect.Value) bool {
   483  	return field.Type().Kind() == reflect.Slice &&
   484  		field.Cap() > 0 &&
   485  		field.Index(0).Kind() == reflect.Pointer
   486  }
   487  
   488  func checkTagsOnField(elem reflect.Value, tag string) (string, error) {
   489  	for i := 0; i < elem.NumField(); i++ {
   490  		field := elem.Field(i)
   491  		// if we're dealing with a slice of pointers to structs,
   492  		// iterate through it and check the tags for each struct pointer
   493  		if isSliceOfPtrToStructs(field) {
   494  			for i := 0; i < field.Cap(); i++ {
   495  				tagUsed, err := CheckTags(field.Index(i).Interface(), tag)
   496  				if err != nil {
   497  					return "", err
   498  				}
   499  				if tagUsed != "" {
   500  					// just return on the first one found.
   501  					// user can iteratively work through error
   502  					// messages if there are more than one
   503  					return tagUsed, nil
   504  				}
   505  			}
   506  		} else if !field.IsNil() {
   507  			tags := elem.Type().Field(i).Tag
   508  			tagValue, hasTag := tags.Lookup(tag)
   509  			if hasTag && tagValue == "true" {
   510  				yamlName, _ := tags.Lookup("yaml")
   511  				return yamlName, nil
   512  			}
   513  		}
   514  	}
   515  	// no true value found
   516  	return "", nil
   517  }
   518  
   519  // BackupAndCopyResolvConf creates a backup of /etc/resolv.conf in a chroot
   520  // and copies the contents from the host system into the chroot
   521  func BackupAndCopyResolvConf(chroot string) error {
   522  	if osutil.FileExists(filepath.Join(chroot, "etc", "resolv.conf.tmp")) {
   523  		// already backed up/copied so do nothing
   524  		return nil
   525  	}
   526  	src := filepath.Join(chroot, "etc", "resolv.conf")
   527  	dest := filepath.Join(chroot, "etc", "resolv.conf.tmp")
   528  	if err := os.Rename(src, dest); err != nil {
   529  		return fmt.Errorf("Error moving file \"%s\" to \"%s\": %s", src, dest, err.Error())
   530  	}
   531  	dest = src
   532  	src = filepath.Join("/etc", "resolv.conf")
   533  	if err := osutil.CopyFile(src, dest, osutil.CopyFlagDefault); err != nil {
   534  		return fmt.Errorf("Error copying file \"%s\" to \"%s\": %s", src, dest, err.Error())
   535  	}
   536  	return nil
   537  }
   538  
   539  // RestoreResolvConf restores the resolv.conf in the chroot from the
   540  // version that was backed up by BackupAndCopyResolvConf
   541  func RestoreResolvConf(chroot string) error {
   542  	if !osutil.FileExists(filepath.Join(chroot, "etc", "resolv.conf.tmp")) {
   543  		return nil
   544  	}
   545  	if osutil.IsSymlink(filepath.Join(chroot, "etc", "resolv.conf")) {
   546  		// As per what live-build does, handle the case where some package
   547  		// in the install_packages phase converts resolv.conf into a
   548  		// symlink. In such case we don't restore our backup but instead
   549  		// remove it, leaving the symlink around.
   550  		backup := filepath.Join(chroot, "etc", "resolv.conf.tmp")
   551  		if err := osRemove(backup); err != nil {
   552  			return fmt.Errorf("Error removing file \"%s\": %s", backup, err.Error())
   553  		}
   554  	} else {
   555  		src := filepath.Join(chroot, "etc", "resolv.conf.tmp")
   556  		dest := filepath.Join(chroot, "etc", "resolv.conf")
   557  		if err := osRename(src, dest); err != nil {
   558  			return fmt.Errorf("Error moving file \"%s\" to \"%s\": %s", src, dest, err.Error())
   559  		}
   560  	}
   561  	return nil
   562  }
   563  
   564  const backupExt = ".REAL"
   565  
   566  // BackupReplace backup the target file and replace it with the given content
   567  // Returns the restore function.
   568  func BackupReplace(target string, content string) (func(error) error, error) {
   569  	backup := target + backupExt
   570  	if osutilFileExists(backup) {
   571  		// already backed up so do nothing
   572  		return nil, nil
   573  	}
   574  
   575  	if err := osRename(target, backup); err != nil {
   576  		return nil, fmt.Errorf("Error moving file \"%s\" to \"%s\": %s", target, backup, err.Error())
   577  	}
   578  
   579  	if err := osWriteFile(target, []byte(content), 0755); err != nil {
   580  		return nil, fmt.Errorf("Error writing to %s : %s", target, err.Error())
   581  	}
   582  
   583  	return genRestoreFile(target), nil
   584  }
   585  
   586  // genRestoreFile returns the function to be called to restore the backuped file
   587  func genRestoreFile(target string) func(err error) error {
   588  	return func(err error) error {
   589  		src := target + backupExt
   590  		if !osutilFileExists(src) {
   591  			return err
   592  		}
   593  
   594  		if tmpErr := osRename(src, target); tmpErr != nil {
   595  			tmpErr = fmt.Errorf("Error moving file \"%s\" to \"%s\": %s", src, target, tmpErr.Error())
   596  			return fmt.Errorf("%s\n%s", err, tmpErr)
   597  		}
   598  
   599  		return err
   600  	}
   601  }