github.com/canonical/ubuntu-image@v0.0.0-20240430122802-2202fe98b290/internal/helper/helper.go (about) 1 package helper 2 3 import ( 4 "bytes" 5 "crypto/sha256" 6 "fmt" 7 "io" 8 "os" 9 "os/exec" 10 "path/filepath" 11 "reflect" 12 "strings" 13 14 "github.com/invopop/jsonschema" 15 "github.com/snapcore/snapd/gadget/quantity" 16 "github.com/snapcore/snapd/osutil" 17 "github.com/xeipuuv/gojsonschema" 18 19 "github.com/canonical/ubuntu-image/internal/commands" 20 ) 21 22 // define some functions that can be mocked by test cases 23 var ( 24 osRename = os.Rename 25 osRemove = os.Remove 26 osWriteFile = os.WriteFile 27 osutilFileExists = osutil.FileExists 28 ) 29 30 func BoolPtr(b bool) *bool { 31 return &b 32 } 33 34 // CaptureStd returns an io.Reader to read what was printed, and teardown 35 func CaptureStd(toCap **os.File) (io.Reader, func(), error) { 36 stdCap, stdCapW, err := os.Pipe() 37 if err != nil { 38 return nil, nil, err 39 } 40 oldStdCap := *toCap 41 *toCap = stdCapW 42 closed := false 43 return stdCap, func() { 44 // only teardown once 45 if closed { 46 return 47 } 48 *toCap = oldStdCap 49 stdCapW.Close() 50 closed = true 51 }, nil 52 } 53 54 // InitCommonOpts initializes default common options for state machines. 55 // This is used for test scenarios to avoid nil pointer dereferences 56 func InitCommonOpts() (*commands.CommonOpts, *commands.StateMachineOpts) { 57 commonOpts := new(commands.CommonOpts) 58 // This is a workaround to set the default value for test cases. Normally 59 // go-flags makes sure that the option has a sane value at all times, but 60 // for tests we'd have to set it manually all the time. 61 commonOpts.SectorSize = "512" 62 return commonOpts, new(commands.StateMachineOpts) 63 } 64 65 // RunScript runs scripts from disk. Currently only used for hooks 66 func RunScript(hookScript string) error { 67 hookScriptCmd := exec.Command(hookScript) 68 hookScriptCmd.Env = os.Environ() 69 hookScriptCmd.Stdout = os.Stdout 70 hookScriptCmd.Stderr = os.Stderr 71 if err := hookScriptCmd.Run(); err != nil { 72 return fmt.Errorf("Error running hook script %s: %s", hookScript, err.Error()) 73 } 74 return nil 75 } 76 77 // Du recurses through a directory similar to du and adds all the sizes of files together 78 func Du(path string) (quantity.Size, error) { 79 duCommand := *exec.Command("du", "-s", "-B1") 80 duCommand.Args = append(duCommand.Args, path) 81 82 duBytes, err := duCommand.Output() 83 if err != nil { 84 return quantity.Size(0), err 85 } 86 sizeString := strings.Split(string(duBytes), "\t")[0] 87 size, err := quantity.ParseSize(sizeString) 88 return size, err 89 } 90 91 // CopyBlob runs `dd` to copy a blob to an image file 92 func CopyBlob(ddArgs []string) error { 93 ddCommand := *exec.Command("dd") 94 ddCommand.Args = append(ddCommand.Args, ddArgs...) 95 96 if err := ddCommand.Run(); err != nil { 97 return fmt.Errorf("Command \"%s\" returned with %s", ddCommand.String(), err.Error()) 98 } 99 return nil 100 } 101 102 // SetDefaults iterates through the keys in a struct and sets 103 // default values if one is specified with a struct tag of "default". 104 // Currently only default values of strings, slice of strings, and 105 // bools are supported 106 func SetDefaults(needsDefaults interface{}) error { 107 value := reflect.ValueOf(needsDefaults) 108 if value.Kind() != reflect.Ptr { 109 return fmt.Errorf("The argument to SetDefaults must be a pointer") 110 } 111 elem := value.Elem() 112 for i := 0; i < elem.NumField(); i++ { 113 field := elem.Field(i) 114 // if we're dealing with a slice of pointers to structs, 115 // iterate through it and set the defaults for each struct pointer 116 if isSliceOfPtrToStructs(field) { 117 err := setDefaultsToSlice(field) 118 if err != nil { 119 return err 120 } 121 } else if field.Type().Kind() == reflect.Ptr { 122 err := setDefaultToPtr(field, elem, i) 123 if err != nil { 124 return err 125 } 126 } else { 127 err := setDefaultToBasicType(field, elem, i) 128 if err != nil { 129 return err 130 } 131 } 132 } 133 return nil 134 } 135 136 // setDefaultsToSlice sets default values to elements of a slice. 137 // It assumes it was already checked field is a non empty slice. Otherwise this 138 // function will probably panic. 139 func setDefaultsToSlice(field reflect.Value) error { 140 for i := 0; i < field.Cap(); i++ { 141 err := SetDefaults(field.Index(i).Interface()) 142 if err != nil { 143 return err 144 } 145 } 146 return nil 147 } 148 149 // setDefaultToPtr sets a default value to a field being a pointer. 150 // It assumes it was already checked that field is a ptr. Otherwise this 151 // function will probably panic. 152 func setDefaultToPtr(field reflect.Value, elem reflect.Value, fieldIndex int) error { 153 // if it's a pointer to a struct, look for default types 154 if field.Elem().Kind() == reflect.Struct { 155 err := SetDefaults(field.Interface()) 156 if err != nil { 157 return err 158 } 159 // special case for pointer to bools 160 } else if field.Type().Elem() == reflect.TypeOf(true) { 161 // if a value is set, do nothing 162 if !field.IsNil() { 163 return nil 164 } 165 tags := elem.Type().Field(fieldIndex).Tag 166 defaultValue, hasDefault := tags.Lookup("default") 167 if !hasDefault { 168 // If no default and no value is set, make sure we have a valid 169 // value consistent with the "zero" value for a bool (false) 170 field.Set(reflect.ValueOf(BoolPtr(false))) 171 return nil 172 } 173 if defaultValue == "true" { 174 field.Set(reflect.ValueOf(BoolPtr(true))) 175 } else { 176 field.Set(reflect.ValueOf(BoolPtr(false))) 177 } 178 } 179 return nil 180 } 181 182 // setDefaultToBasicType sets the default value to a basic type (and slice) based on the 183 // "default" tag. 184 func setDefaultToBasicType(field reflect.Value, elem reflect.Value, fieldIndex int) error { 185 tags := elem.Type().Field(fieldIndex).Tag 186 defaultValue, hasDefault := tags.Lookup("default") 187 if !hasDefault { 188 return nil 189 } 190 indirectedField := reflect.Indirect(field) 191 if indirectedField.CanSet() && field.IsZero() { 192 varType := field.Type().Kind() 193 switch varType { 194 case reflect.String: 195 field.SetString(defaultValue) 196 case reflect.Slice: 197 defaultValues := strings.Split(defaultValue, ",") 198 field.Set(reflect.ValueOf(defaultValues)) 199 case reflect.Bool: 200 return fmt.Errorf("Setting default value of a boolean not supported. Use a pointer to boolean instead.") 201 default: 202 return fmt.Errorf("Setting default value of type %s not supported", 203 varType) 204 } 205 } 206 return nil 207 } 208 209 // CheckEmptyFields iterates through the image definition struct and 210 // checks for fields that are present but return IsZero == true. 211 // TODO: I've created a PR upstream in xeipuuv/gojsonschema 212 // https://github.com/xeipuuv/gojsonschema/pull/352 213 // if it gets merged this can be deleted 214 func CheckEmptyFields(Interface interface{}, result *gojsonschema.Result, schema *jsonschema.Schema) error { 215 value := reflect.ValueOf(Interface) 216 if value.Kind() != reflect.Ptr { 217 return fmt.Errorf("The argument to CheckEmptyFields must be a pointer") 218 } 219 elem := value.Elem() 220 for i := 0; i < elem.NumField(); i++ { 221 field := elem.Field(i) 222 // if we're dealing with a slice, iterate through 223 // it and search for missing required fields in each 224 // element of the slice 225 if field.Type().Kind() == reflect.Slice { 226 err := checkEmptyFieldsInSlice(field, result, schema) 227 if err != nil { 228 return err 229 } 230 } else if field.Type().Kind() == reflect.Ptr { 231 // otherwise if it's just a pointer to a nested struct 232 // search it for empty required fields 233 err := checkEmptyFieldsInPtr(field, result, schema) 234 if err != nil { 235 return err 236 } 237 } else { 238 tags := elem.Type().Field(i).Tag 239 if !isRequiredFromTags(tags) && !isRequiredFromSchema(elem, i, schema) { 240 continue 241 } 242 // this is a required field, check for zero values 243 if !reflect.Indirect(field).IsZero() { 244 continue 245 } 246 jsonContext := gojsonschema.NewJsonContext("image_definition", nil) 247 errDetail := gojsonschema.ErrorDetails{ 248 "property": tags.Get("yaml"), 249 "parent": elem.Type().Name(), 250 } 251 result.AddError( 252 newMissingFieldError( 253 gojsonschema.NewJsonContext("missing_field", jsonContext), 254 52, 255 errDetail, 256 ), 257 errDetail, 258 ) 259 } 260 } 261 return nil 262 } 263 264 func checkEmptyFieldsInSlice(field reflect.Value, result *gojsonschema.Result, schema *jsonschema.Schema) error { 265 for i := 0; i < field.Cap(); i++ { 266 sliceElem := field.Index(i) 267 if sliceElem.Kind() == reflect.Ptr && sliceElem.Elem().Kind() == reflect.Struct { 268 err := CheckEmptyFields(sliceElem.Interface(), result, schema) 269 if err != nil { 270 return err 271 } 272 } 273 } 274 return nil 275 } 276 277 func checkEmptyFieldsInPtr(field reflect.Value, result *gojsonschema.Result, schema *jsonschema.Schema) error { 278 if field.Elem().Kind() == reflect.Struct { 279 err := CheckEmptyFields(field.Interface(), result, schema) 280 if err != nil { 281 return err 282 } 283 } 284 return nil 285 } 286 287 // isRequiredFromTags checks if the field is required from the JSON tags 288 func isRequiredFromTags(tags reflect.StructTag) bool { 289 jsonTag, hasJSON := tags.Lookup("json") 290 if hasJSON { 291 if !strings.Contains(jsonTag, "omitempty") { 292 return true 293 } 294 } 295 return false 296 } 297 298 // isRequiredFromSchema checks if the field is required from the schema 299 func isRequiredFromSchema(elem reflect.Value, i int, schema *jsonschema.Schema) bool { 300 for _, requiredField := range schema.Required { 301 if elem.Type().Field(i).Name == requiredField { 302 return true 303 } 304 } 305 return false 306 } 307 308 func newMissingFieldError(context *gojsonschema.JsonContext, value interface{}, details gojsonschema.ErrorDetails) *MissingFieldError { 309 err := MissingFieldError{} 310 err.SetContext(context) 311 err.SetType("missing_field_error") 312 err.SetValue(value) 313 err.SetDescriptionFormat("Key \"{{.property}}\" is required in struct \"{{.parent}}\", but is not in the YAML file!") 314 err.SetDetails(details) 315 316 return &err 317 } 318 319 // MissingFieldError is used when the fields exist but are the zero value for their type 320 type MissingFieldError struct { 321 gojsonschema.ResultErrorFields 322 } 323 324 // SliceHasElement searches for a string in a slice of strings and returns whether it 325 // is found 326 func SliceHasElement(haystack []string, needle string) bool { 327 found := false 328 for _, element := range haystack { 329 if element == needle { 330 found = true 331 } 332 } 333 return found 334 } 335 336 // SetCommandOutput sets the output of a command to either use a multiwriter 337 // or behave as a normal command and store the output in a buffer 338 func SetCommandOutput(cmd *exec.Cmd, liveOutput bool) (cmdOutput *bytes.Buffer) { 339 var cmdOutputBuffer bytes.Buffer 340 cmdOutput = &cmdOutputBuffer 341 cmd.Stdout = cmdOutput 342 cmd.Stderr = cmdOutput 343 if liveOutput { 344 mwriter := io.MultiWriter(os.Stdout, cmdOutput) 345 cmd.Stdout = mwriter 346 cmd.Stderr = mwriter 347 } 348 return cmdOutput 349 } 350 351 func RunCmd(cmd *exec.Cmd, debug bool) error { 352 output := SetCommandOutput(cmd, debug) 353 err := cmd.Run() 354 if err != nil { 355 return fmt.Errorf("Error running command \"%s\". Error: %s. Output:\n%s", 356 cmd.String(), err.Error(), output.String()) 357 } 358 return nil 359 } 360 361 // RunCmds runs a list of commands and returns the error 362 // It stops at the first error 363 func RunCmds(cmds []*exec.Cmd, debug bool) error { 364 for _, cmd := range cmds { 365 err := RunCmd(cmd, debug) 366 if err != nil { 367 return err 368 } 369 } 370 371 return nil 372 } 373 374 // SafeQuantitySubtraction subtracts quantities while checking for integer underflow 375 func SafeQuantitySubtraction(orig, subtract quantity.Size) quantity.Size { 376 if subtract > orig { 377 return 0 378 } 379 return orig - subtract 380 } 381 382 // CreateTarArchive places all of the files from a source directory into a tar. 383 // Currently supported are uncompressed tar archives and the following 384 // compression types: zip, gzip, xz bzip2, zstd 385 func CreateTarArchive(src, dest, compression string, verbose, debug bool) error { 386 tarCommand := *exec.Command( 387 "tar", 388 "--directory", 389 src, 390 "--xattrs", 391 "--xattrs-include=*", 392 "--create", 393 "--file", 394 dest, 395 ".", 396 ) 397 if debug { 398 tarCommand.Args = append(tarCommand.Args, "--verbose") 399 } 400 // set up any compression arguments 401 switch compression { 402 case "uncompressed": 403 break 404 case "bzip2": 405 tarCommand.Args = append(tarCommand.Args, "--bzip2") 406 case "gzip": 407 tarCommand.Args = append(tarCommand.Args, "--gzip") 408 case "xz": 409 tarCommand.Args = append(tarCommand.Args, "--xz") 410 case "zstd": 411 tarCommand.Args = append(tarCommand.Args, "--zstd") 412 default: 413 return fmt.Errorf("Unknown compression type: \"%s\"", compression) 414 } 415 416 tarOutput := SetCommandOutput(&tarCommand, debug) 417 if err := tarCommand.Run(); err != nil { 418 return fmt.Errorf("Error running \"tar\" command \"%s\". "+ 419 "Error is \"%s\". Full output below:\n%s", 420 tarCommand.String(), err.Error(), tarOutput.String()) 421 } 422 return nil 423 } 424 425 // ExtractTarArchive extracts all the files from a tar. Currently supported are 426 // uncompressed tar archives and the following compression types: zip, gzip, xz 427 // bzip2, zstd 428 func ExtractTarArchive(src, dest string, verbose, debug bool) error { 429 tarCommand := *exec.Command( 430 "tar", 431 "--xattrs", 432 "--xattrs-include=*", 433 "--extract", 434 "--file", 435 src, 436 "--directory", 437 dest, 438 ) 439 if debug { 440 tarCommand.Args = append(tarCommand.Args, "--verbose") 441 } 442 tarOutput := SetCommandOutput(&tarCommand, debug) 443 if err := tarCommand.Run(); err != nil { 444 return fmt.Errorf("Error running \"tar\" command \"%s\". "+ 445 "Error is \"%s\". Full output below:\n%s", 446 tarCommand.String(), err.Error(), tarOutput.String()) 447 } 448 return nil 449 } 450 451 // CalculateSHA256 calculates the SHA256 sum of the file provided as an argument 452 func CalculateSHA256(fileName string) (string, error) { 453 f, err := os.Open(fileName) 454 if err != nil { 455 return "", fmt.Errorf("Error opening file \"%s\" to calculate SHA256 sum: \"%s\"", fileName, err.Error()) 456 } 457 defer f.Close() 458 459 hasher := sha256.New() 460 _, err = io.Copy(hasher, f) 461 if err != nil { 462 return "", fmt.Errorf("Error calculating SHA256 sum of file \"%s\": \"%s\"", fileName, err.Error()) 463 } 464 465 return string(hasher.Sum(nil)), nil 466 } 467 468 // CheckTags iterates through the keys in a struct and looks for 469 // a value passed in as a parameter. It returns the yaml name of 470 // the key and an error. Currently only boolean values for the tags 471 // are supported 472 func CheckTags(searchStruct interface{}, tag string) (string, error) { 473 value := reflect.ValueOf(searchStruct) 474 if value.Kind() != reflect.Ptr { 475 return "", fmt.Errorf("The argument to CheckTags must be a pointer") 476 } 477 elem := value.Elem() 478 479 return checkTagsOnField(elem, tag) 480 } 481 482 func isSliceOfPtrToStructs(field reflect.Value) bool { 483 return field.Type().Kind() == reflect.Slice && 484 field.Cap() > 0 && 485 field.Index(0).Kind() == reflect.Pointer 486 } 487 488 func checkTagsOnField(elem reflect.Value, tag string) (string, error) { 489 for i := 0; i < elem.NumField(); i++ { 490 field := elem.Field(i) 491 // if we're dealing with a slice of pointers to structs, 492 // iterate through it and check the tags for each struct pointer 493 if isSliceOfPtrToStructs(field) { 494 for i := 0; i < field.Cap(); i++ { 495 tagUsed, err := CheckTags(field.Index(i).Interface(), tag) 496 if err != nil { 497 return "", err 498 } 499 if tagUsed != "" { 500 // just return on the first one found. 501 // user can iteratively work through error 502 // messages if there are more than one 503 return tagUsed, nil 504 } 505 } 506 } else if !field.IsNil() { 507 tags := elem.Type().Field(i).Tag 508 tagValue, hasTag := tags.Lookup(tag) 509 if hasTag && tagValue == "true" { 510 yamlName, _ := tags.Lookup("yaml") 511 return yamlName, nil 512 } 513 } 514 } 515 // no true value found 516 return "", nil 517 } 518 519 // BackupAndCopyResolvConf creates a backup of /etc/resolv.conf in a chroot 520 // and copies the contents from the host system into the chroot 521 func BackupAndCopyResolvConf(chroot string) error { 522 if osutil.FileExists(filepath.Join(chroot, "etc", "resolv.conf.tmp")) { 523 // already backed up/copied so do nothing 524 return nil 525 } 526 src := filepath.Join(chroot, "etc", "resolv.conf") 527 dest := filepath.Join(chroot, "etc", "resolv.conf.tmp") 528 if err := os.Rename(src, dest); err != nil { 529 return fmt.Errorf("Error moving file \"%s\" to \"%s\": %s", src, dest, err.Error()) 530 } 531 dest = src 532 src = filepath.Join("/etc", "resolv.conf") 533 if err := osutil.CopyFile(src, dest, osutil.CopyFlagDefault); err != nil { 534 return fmt.Errorf("Error copying file \"%s\" to \"%s\": %s", src, dest, err.Error()) 535 } 536 return nil 537 } 538 539 // RestoreResolvConf restores the resolv.conf in the chroot from the 540 // version that was backed up by BackupAndCopyResolvConf 541 func RestoreResolvConf(chroot string) error { 542 if !osutil.FileExists(filepath.Join(chroot, "etc", "resolv.conf.tmp")) { 543 return nil 544 } 545 if osutil.IsSymlink(filepath.Join(chroot, "etc", "resolv.conf")) { 546 // As per what live-build does, handle the case where some package 547 // in the install_packages phase converts resolv.conf into a 548 // symlink. In such case we don't restore our backup but instead 549 // remove it, leaving the symlink around. 550 backup := filepath.Join(chroot, "etc", "resolv.conf.tmp") 551 if err := osRemove(backup); err != nil { 552 return fmt.Errorf("Error removing file \"%s\": %s", backup, err.Error()) 553 } 554 } else { 555 src := filepath.Join(chroot, "etc", "resolv.conf.tmp") 556 dest := filepath.Join(chroot, "etc", "resolv.conf") 557 if err := osRename(src, dest); err != nil { 558 return fmt.Errorf("Error moving file \"%s\" to \"%s\": %s", src, dest, err.Error()) 559 } 560 } 561 return nil 562 } 563 564 const backupExt = ".REAL" 565 566 // BackupReplace backup the target file and replace it with the given content 567 // Returns the restore function. 568 func BackupReplace(target string, content string) (func(error) error, error) { 569 backup := target + backupExt 570 if osutilFileExists(backup) { 571 // already backed up so do nothing 572 return nil, nil 573 } 574 575 if err := osRename(target, backup); err != nil { 576 return nil, fmt.Errorf("Error moving file \"%s\" to \"%s\": %s", target, backup, err.Error()) 577 } 578 579 if err := osWriteFile(target, []byte(content), 0755); err != nil { 580 return nil, fmt.Errorf("Error writing to %s : %s", target, err.Error()) 581 } 582 583 return genRestoreFile(target), nil 584 } 585 586 // genRestoreFile returns the function to be called to restore the backuped file 587 func genRestoreFile(target string) func(err error) error { 588 return func(err error) error { 589 src := target + backupExt 590 if !osutilFileExists(src) { 591 return err 592 } 593 594 if tmpErr := osRename(src, target); tmpErr != nil { 595 tmpErr = fmt.Errorf("Error moving file \"%s\" to \"%s\": %s", src, target, tmpErr.Error()) 596 return fmt.Errorf("%s\n%s", err, tmpErr) 597 } 598 599 return err 600 } 601 }