github.com/kubiko/snapd@v0.0.0-20201013125620-d4f3094d9ddf/cmd/snap-repair/runner.go (about) 1 // -*- Mode: Go; indent-tabs-mode: t -*- 2 3 /* 4 * Copyright (C) 2017-2020 Canonical Ltd 5 * 6 * This program is free software: you can redistribute it and/or modify 7 * it under the terms of the GNU General Public License version 3 as 8 * published by the Free Software Foundation. 9 * 10 * This program is distributed in the hope that it will be useful, 11 * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 * GNU General Public License for more details. 14 * 15 * You should have received a copy of the GNU General Public License 16 * along with this program. If not, see <http://www.gnu.org/licenses/>. 17 * 18 */ 19 20 package main 21 22 import ( 23 "bufio" 24 "bytes" 25 "crypto/tls" 26 "encoding/json" 27 "errors" 28 "fmt" 29 "io" 30 "io/ioutil" 31 "net/http" 32 "net/url" 33 "os" 34 "os/exec" 35 "path/filepath" 36 "strconv" 37 "strings" 38 "syscall" 39 "time" 40 41 "github.com/mvo5/goconfigparser" 42 "gopkg.in/retry.v1" 43 44 "github.com/snapcore/snapd/arch" 45 "github.com/snapcore/snapd/asserts" 46 "github.com/snapcore/snapd/asserts/sysdb" 47 "github.com/snapcore/snapd/dirs" 48 "github.com/snapcore/snapd/errtracker" 49 "github.com/snapcore/snapd/httputil" 50 "github.com/snapcore/snapd/logger" 51 "github.com/snapcore/snapd/osutil" 52 "github.com/snapcore/snapd/release" 53 "github.com/snapcore/snapd/snapdenv" 54 "github.com/snapcore/snapd/strutil" 55 ) 56 57 var ( 58 // TODO: move inside the repairs themselves? 59 defaultRepairTimeout = 30 * time.Minute 60 ) 61 62 var errtrackerReportRepair = errtracker.ReportRepair 63 64 // Repair is a runnable repair. 65 type Repair struct { 66 *asserts.Repair 67 68 run *Runner 69 sequence int 70 } 71 72 func (r *Repair) RunDir() string { 73 return filepath.Join(dirs.SnapRepairRunDir, r.BrandID(), strconv.Itoa(r.RepairID())) 74 } 75 76 func (r *Repair) String() string { 77 return fmt.Sprintf("%s-%v", r.BrandID(), r.RepairID()) 78 } 79 80 // SetStatus sets the status of the repair in the state and saves the latter. 81 func (r *Repair) SetStatus(status RepairStatus) { 82 brandID := r.BrandID() 83 cur := *r.run.state.Sequences[brandID][r.sequence-1] 84 cur.Status = status 85 r.run.setRepairState(brandID, cur) 86 r.run.SaveState() 87 } 88 89 // makeRepairSymlink ensures $dir/repair exists and is a symlink to 90 // /usr/lib/snapd/snap-repair 91 func makeRepairSymlink(dir string) (err error) { 92 // make "repair" binary available to the repair scripts via symlink 93 // to the real snap-repair 94 if err = os.MkdirAll(dir, 0755); err != nil { 95 return err 96 } 97 98 old := filepath.Join(dirs.CoreLibExecDir, "snap-repair") 99 new := filepath.Join(dir, "repair") 100 if err := os.Symlink(old, new); err != nil && !os.IsExist(err) { 101 return err 102 } 103 104 return nil 105 } 106 107 // Run executes the repair script leaving execution trail files on disk. 108 func (r *Repair) Run() error { 109 // write the script to disk 110 rundir := r.RunDir() 111 err := os.MkdirAll(rundir, 0775) 112 if err != nil { 113 return err 114 } 115 116 // ensure the script can use "repair done" 117 repairToolsDir := filepath.Join(dirs.SnapRunRepairDir, "tools") 118 if err := makeRepairSymlink(repairToolsDir); err != nil { 119 return err 120 } 121 122 baseName := fmt.Sprintf("r%d", r.Revision()) 123 script := filepath.Join(rundir, baseName+".script") 124 err = osutil.AtomicWriteFile(script, r.Body(), 0700, 0) 125 if err != nil { 126 return err 127 } 128 129 logPath := filepath.Join(rundir, baseName+".running") 130 logf, err := os.OpenFile(logPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) 131 if err != nil { 132 return err 133 } 134 defer logf.Close() 135 136 fmt.Fprintf(logf, "repair: %s\n", r) 137 fmt.Fprintf(logf, "revision: %d\n", r.Revision()) 138 fmt.Fprintf(logf, "summary: %s\n", r.Summary()) 139 fmt.Fprintf(logf, "output:\n") 140 141 statusR, statusW, err := os.Pipe() 142 if err != nil { 143 return err 144 } 145 defer statusR.Close() 146 defer statusW.Close() 147 148 logger.Debugf("executing %s", script) 149 150 // run the script 151 env := os.Environ() 152 // we need to hardcode FD=3 because this is the FD after 153 // exec.Command() forked. there is no way in go currently 154 // to run something right after fork() in the child to 155 // know the fd. However because go will close all fds 156 // except the ones in "cmd.ExtraFiles" we are safe to set "3" 157 env = append(env, "SNAP_REPAIR_STATUS_FD=3") 158 env = append(env, "SNAP_REPAIR_RUN_DIR="+rundir) 159 // inject repairToolDir into PATH so that the script can use 160 // `repair {done,skip,retry}` 161 var havePath bool 162 for i, envStr := range env { 163 if strings.HasPrefix(envStr, "PATH=") { 164 newEnv := fmt.Sprintf("%s:%s", strings.TrimSuffix(envStr, ":"), repairToolsDir) 165 env[i] = newEnv 166 havePath = true 167 } 168 } 169 if !havePath { 170 env = append(env, "PATH=/usr/sbin:/usr/bin:/sbin:/bin:"+repairToolsDir) 171 } 172 173 workdir := filepath.Join(rundir, "work") 174 if err := os.MkdirAll(workdir, 0700); err != nil { 175 return err 176 } 177 178 cmd := exec.Command(script) 179 cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} 180 cmd.Env = env 181 cmd.Dir = workdir 182 cmd.ExtraFiles = []*os.File{statusW} 183 cmd.Stdout = logf 184 cmd.Stderr = logf 185 if err = cmd.Start(); err != nil { 186 return err 187 } 188 statusW.Close() 189 190 // wait for repair to finish or timeout 191 var scriptErr error 192 killTimerCh := time.After(defaultRepairTimeout) 193 doneCh := make(chan error, 1) 194 go func() { 195 doneCh <- cmd.Wait() 196 close(doneCh) 197 }() 198 select { 199 case scriptErr = <-doneCh: 200 // done 201 case <-killTimerCh: 202 if err := osutil.KillProcessGroup(cmd); err != nil { 203 logger.Noticef("cannot kill timed out repair %s: %s", r, err) 204 } 205 scriptErr = fmt.Errorf("repair did not finish within %s", defaultRepairTimeout) 206 } 207 // read repair status pipe, use the last value 208 status := readStatus(statusR) 209 statusPath := filepath.Join(rundir, baseName+"."+status.String()) 210 211 // if the script had an error exit status still honor what we 212 // read from the status-pipe, however report the error 213 if scriptErr != nil { 214 scriptErr = fmt.Errorf("repair %s revision %d failed: %s", r, r.Revision(), scriptErr) 215 if err := r.errtrackerReport(scriptErr, status, logPath); err != nil { 216 logger.Noticef("cannot report error to errtracker: %s", err) 217 } 218 // ensure the error is present in the output log 219 fmt.Fprintf(logf, "\n%s", scriptErr) 220 } 221 if err := os.Rename(logPath, statusPath); err != nil { 222 return err 223 } 224 r.SetStatus(status) 225 226 return nil 227 } 228 229 func readStatus(r io.Reader) RepairStatus { 230 var status RepairStatus 231 scanner := bufio.NewScanner(r) 232 for scanner.Scan() { 233 switch strings.TrimSpace(scanner.Text()) { 234 case "done": 235 status = DoneStatus 236 // TODO: support having a script skip over many and up to a given repair-id # 237 case "skip": 238 status = SkipStatus 239 } 240 } 241 if scanner.Err() != nil { 242 return RetryStatus 243 } 244 return status 245 } 246 247 // errtrackerReport reports an repairErr with the given logPath to the 248 // snap error tracker. 249 func (r *Repair) errtrackerReport(repairErr error, status RepairStatus, logPath string) error { 250 errMsg := repairErr.Error() 251 252 scriptOutput, err := ioutil.ReadFile(logPath) 253 if err != nil { 254 logger.Noticef("cannot read %s", logPath) 255 } 256 s := fmt.Sprintf("%s/%d", r.BrandID(), r.RepairID()) 257 258 dupSig := fmt.Sprintf("%s\n%s\noutput:\n%s", s, errMsg, scriptOutput) 259 extra := map[string]string{ 260 "Revision": strconv.Itoa(r.Revision()), 261 "BrandID": r.BrandID(), 262 "RepairID": strconv.Itoa(r.RepairID()), 263 "Status": status.String(), 264 } 265 _, err = errtrackerReportRepair(s, errMsg, dupSig, extra) 266 return err 267 } 268 269 // Runner implements fetching, tracking and running repairs. 270 type Runner struct { 271 BaseURL *url.URL 272 cli *http.Client 273 274 state state 275 stateModified bool 276 277 // sequenceNext keeps track of the next integer id in a brand sequence to considered in this run, see Next. 278 sequenceNext map[string]int 279 } 280 281 // NewRunner returns a Runner. 282 func NewRunner() *Runner { 283 run := &Runner{ 284 sequenceNext: make(map[string]int), 285 } 286 opts := httputil.ClientOptions{ 287 MayLogBody: false, 288 ProxyConnectHeader: http.Header{"User-Agent": []string{snapdenv.UserAgent()}}, 289 TLSConfig: &tls.Config{ 290 Time: run.now, 291 }, 292 ExtraSSLCerts: &httputil.ExtraSSLCertsFromDir{ 293 Dir: dirs.SnapdStoreSSLCertsDir, 294 }, 295 } 296 run.cli = httputil.NewHTTPClient(&opts) 297 return run 298 } 299 300 var ( 301 fetchRetryStrategy = retry.LimitCount(7, retry.LimitTime(90*time.Second, 302 retry.Exponential{ 303 Initial: 500 * time.Millisecond, 304 Factor: 2.5, 305 }, 306 )) 307 308 peekRetryStrategy = retry.LimitCount(5, retry.LimitTime(44*time.Second, 309 retry.Exponential{ 310 Initial: 300 * time.Millisecond, 311 Factor: 2.5, 312 }, 313 )) 314 ) 315 316 var ( 317 ErrRepairNotFound = errors.New("repair not found") 318 ErrRepairNotModified = errors.New("repair was not modified") 319 ) 320 321 var ( 322 maxRepairScriptSize = 24 * 1024 * 1024 323 ) 324 325 // Fetch retrieves a stream with the repair with the given ids and any 326 // auxiliary assertions. If revision>=0 the request will include an 327 // If-None-Match header with an ETag for the revision, and 328 // ErrRepairNotModified is returned if the revision is still current. 329 func (run *Runner) Fetch(brandID string, repairID int, revision int) (*asserts.Repair, []asserts.Assertion, error) { 330 u, err := run.BaseURL.Parse(fmt.Sprintf("repairs/%s/%d", brandID, repairID)) 331 if err != nil { 332 return nil, nil, err 333 } 334 335 var r []asserts.Assertion 336 resp, err := httputil.RetryRequest(u.String(), func() (*http.Response, error) { 337 req, err := http.NewRequest("GET", u.String(), nil) 338 if err != nil { 339 return nil, err 340 } 341 req.Header.Set("User-Agent", snapdenv.UserAgent()) 342 req.Header.Set("Accept", "application/x.ubuntu.assertion") 343 if revision >= 0 { 344 req.Header.Set("If-None-Match", fmt.Sprintf(`"%d"`, revision)) 345 } 346 return run.cli.Do(req) 347 }, func(resp *http.Response) error { 348 if resp.StatusCode == 200 { 349 logger.Debugf("fetching repair %s-%d", brandID, repairID) 350 // decode assertions 351 dec := asserts.NewDecoderWithTypeMaxBodySize(resp.Body, map[*asserts.AssertionType]int{ 352 asserts.RepairType: maxRepairScriptSize, 353 }) 354 for { 355 a, err := dec.Decode() 356 if err == io.EOF { 357 break 358 } 359 if err != nil { 360 return err 361 } 362 r = append(r, a) 363 } 364 if len(r) == 0 { 365 return io.ErrUnexpectedEOF 366 } 367 } 368 return nil 369 }, fetchRetryStrategy) 370 371 if err != nil { 372 return nil, nil, err 373 } 374 375 moveTimeLowerBound := true 376 defer func() { 377 if moveTimeLowerBound { 378 t, _ := http.ParseTime(resp.Header.Get("Date")) 379 run.moveTimeLowerBound(t) 380 } 381 }() 382 383 switch resp.StatusCode { 384 case 200: 385 // ok 386 case 304: 387 // not modified 388 return nil, nil, ErrRepairNotModified 389 case 404: 390 return nil, nil, ErrRepairNotFound 391 default: 392 moveTimeLowerBound = false 393 return nil, nil, fmt.Errorf("cannot fetch repair, unexpected status %d", resp.StatusCode) 394 } 395 396 repair, aux, err := checkStream(brandID, repairID, r) 397 if err != nil { 398 return nil, nil, fmt.Errorf("cannot fetch repair, %v", err) 399 } 400 401 if repair.Revision() <= revision { 402 // this shouldn't happen but if it does we behave like 403 // all the rest of assertion infrastructure and ignore 404 // the now superseded revision 405 return nil, nil, ErrRepairNotModified 406 } 407 408 return repair, aux, err 409 } 410 411 func checkStream(brandID string, repairID int, r []asserts.Assertion) (repair *asserts.Repair, aux []asserts.Assertion, err error) { 412 if len(r) == 0 { 413 return nil, nil, fmt.Errorf("empty repair assertions stream") 414 } 415 var ok bool 416 repair, ok = r[0].(*asserts.Repair) 417 if !ok { 418 return nil, nil, fmt.Errorf("unexpected first assertion %q", r[0].Type().Name) 419 } 420 421 if repair.BrandID() != brandID || repair.RepairID() != repairID { 422 return nil, nil, fmt.Errorf("repair id mismatch %s/%d != %s/%d", repair.BrandID(), repair.RepairID(), brandID, repairID) 423 } 424 425 return repair, r[1:], nil 426 } 427 428 type peekResp struct { 429 Headers map[string]interface{} `json:"headers"` 430 } 431 432 // Peek retrieves the headers for the repair with the given ids. 433 func (run *Runner) Peek(brandID string, repairID int) (headers map[string]interface{}, err error) { 434 u, err := run.BaseURL.Parse(fmt.Sprintf("repairs/%s/%d", brandID, repairID)) 435 if err != nil { 436 return nil, err 437 } 438 439 var rsp peekResp 440 441 resp, err := httputil.RetryRequest(u.String(), func() (*http.Response, error) { 442 req, err := http.NewRequest("GET", u.String(), nil) 443 if err != nil { 444 return nil, err 445 } 446 req.Header.Set("User-Agent", snapdenv.UserAgent()) 447 req.Header.Set("Accept", "application/json") 448 return run.cli.Do(req) 449 }, func(resp *http.Response) error { 450 rsp.Headers = nil 451 if resp.StatusCode == 200 { 452 dec := json.NewDecoder(resp.Body) 453 return dec.Decode(&rsp) 454 } 455 return nil 456 }, peekRetryStrategy) 457 458 if err != nil { 459 return nil, err 460 } 461 462 moveTimeLowerBound := true 463 defer func() { 464 if moveTimeLowerBound { 465 t, _ := http.ParseTime(resp.Header.Get("Date")) 466 run.moveTimeLowerBound(t) 467 } 468 }() 469 470 switch resp.StatusCode { 471 case 200: 472 // ok 473 case 404: 474 return nil, ErrRepairNotFound 475 default: 476 moveTimeLowerBound = false 477 return nil, fmt.Errorf("cannot peek repair headers, unexpected status %d", resp.StatusCode) 478 } 479 480 headers = rsp.Headers 481 if headers["brand-id"] != brandID || headers["repair-id"] != strconv.Itoa(repairID) { 482 return nil, fmt.Errorf("cannot peek repair headers, repair id mismatch %s/%s != %s/%d", headers["brand-id"], headers["repair-id"], brandID, repairID) 483 } 484 485 return headers, nil 486 } 487 488 // deviceInfo captures information about the device. 489 type deviceInfo struct { 490 Brand string `json:"brand"` 491 Model string `json:"model"` 492 } 493 494 // RepairStatus represents the possible statuses of a repair. 495 type RepairStatus int 496 497 const ( 498 RetryStatus RepairStatus = iota 499 SkipStatus 500 DoneStatus 501 ) 502 503 func (rs RepairStatus) String() string { 504 switch rs { 505 case RetryStatus: 506 return "retry" 507 case SkipStatus: 508 return "skip" 509 case DoneStatus: 510 return "done" 511 default: 512 return "unknown" 513 } 514 } 515 516 // RepairState holds the current revision and status of a repair in a sequence of repairs. 517 type RepairState struct { 518 Sequence int `json:"sequence"` 519 Revision int `json:"revision"` 520 Status RepairStatus `json:"status"` 521 } 522 523 // state holds the atomically updated control state of the runner with sequences of repairs and their states. 524 type state struct { 525 Device deviceInfo `json:"device"` 526 Sequences map[string][]*RepairState `json:"sequences,omitempty"` 527 TimeLowerBound time.Time `json:"time-lower-bound"` 528 } 529 530 func (run *Runner) setRepairState(brandID string, state RepairState) { 531 if run.state.Sequences == nil { 532 run.state.Sequences = make(map[string][]*RepairState) 533 } 534 sequence := run.state.Sequences[brandID] 535 if state.Sequence > len(sequence) { 536 run.stateModified = true 537 run.state.Sequences[brandID] = append(sequence, &state) 538 } else if *sequence[state.Sequence-1] != state { 539 run.stateModified = true 540 sequence[state.Sequence-1] = &state 541 } 542 } 543 544 func (run *Runner) readState() error { 545 r, err := os.Open(dirs.SnapRepairStateFile) 546 if err != nil { 547 return err 548 } 549 defer r.Close() 550 dec := json.NewDecoder(r) 551 return dec.Decode(&run.state) 552 } 553 554 func (run *Runner) moveTimeLowerBound(t time.Time) { 555 if t.After(run.state.TimeLowerBound) { 556 run.stateModified = true 557 run.state.TimeLowerBound = t.UTC() 558 } 559 } 560 561 var timeNow = time.Now 562 563 func (run *Runner) now() time.Time { 564 now := timeNow().UTC() 565 if now.Before(run.state.TimeLowerBound) { 566 return run.state.TimeLowerBound 567 } 568 return now 569 } 570 571 func (run *Runner) initState() error { 572 if err := os.MkdirAll(dirs.SnapRepairDir, 0775); err != nil { 573 return fmt.Errorf("cannot create repair state directory: %v", err) 574 } 575 // best-effort remove old 576 os.Remove(dirs.SnapRepairStateFile) 577 run.state = state{} 578 // initialize time lower bound with image built time/seed.yaml time 579 if err := run.findTimeLowerBound(); err != nil { 580 return err 581 } 582 // initialize device info 583 if err := run.initDeviceInfo(); err != nil { 584 return err 585 } 586 run.stateModified = true 587 return run.SaveState() 588 } 589 590 func trustedBackstore(trusted []asserts.Assertion) asserts.Backstore { 591 trustedBS := asserts.NewMemoryBackstore() 592 for _, t := range trusted { 593 trustedBS.Put(t.Type(), t) 594 } 595 return trustedBS 596 } 597 598 func checkAuthorityID(a asserts.Assertion, trusted asserts.Backstore) error { 599 assertType := a.Type() 600 if assertType != asserts.AccountKeyType && assertType != asserts.AccountType { 601 return nil 602 } 603 // check that account and account-key assertions are signed by 604 // a trusted authority 605 acctID := a.AuthorityID() 606 _, err := trusted.Get(asserts.AccountType, []string{acctID}, asserts.AccountType.MaxSupportedFormat()) 607 if err != nil && !asserts.IsNotFound(err) { 608 return err 609 } 610 if asserts.IsNotFound(err) { 611 return fmt.Errorf("%v not signed by trusted authority: %s", a.Ref(), acctID) 612 } 613 return nil 614 } 615 616 func verifySignatures(a asserts.Assertion, workBS asserts.Backstore, trusted asserts.Backstore) error { 617 if err := checkAuthorityID(a, trusted); err != nil { 618 return err 619 } 620 acctKeyMaxSuppFormat := asserts.AccountKeyType.MaxSupportedFormat() 621 622 seen := make(map[string]bool) 623 bottom := false 624 for !bottom { 625 u := a.Ref().Unique() 626 if seen[u] { 627 return fmt.Errorf("circular assertions") 628 } 629 seen[u] = true 630 signKey := []string{a.SignKeyID()} 631 key, err := trusted.Get(asserts.AccountKeyType, signKey, acctKeyMaxSuppFormat) 632 if err != nil && !asserts.IsNotFound(err) { 633 return err 634 } 635 if err == nil { 636 bottom = true 637 } else { 638 key, err = workBS.Get(asserts.AccountKeyType, signKey, acctKeyMaxSuppFormat) 639 if err != nil && !asserts.IsNotFound(err) { 640 return err 641 } 642 if asserts.IsNotFound(err) { 643 return fmt.Errorf("cannot find public key %q", signKey[0]) 644 } 645 if err := checkAuthorityID(key, trusted); err != nil { 646 return err 647 } 648 } 649 if err := asserts.CheckSignature(a, key.(*asserts.AccountKey), nil, time.Time{}); err != nil { 650 return err 651 } 652 a = key 653 } 654 return nil 655 } 656 657 func (run *Runner) findTimeLowerBound() error { 658 timeLowerBoundSources := []string{ 659 // uc16 660 filepath.Join(dirs.SnapSeedDir, "seed.yaml"), 661 // uc20+ 662 dirs.SnapModeenvFile, 663 } 664 // add all model files from uc20 seeds 665 allModels, err := filepath.Glob(filepath.Join(dirs.SnapSeedDir, "systems/*/model")) 666 if err != nil { 667 return err 668 } 669 timeLowerBoundSources = append(timeLowerBoundSources, allModels...) 670 671 // use all files as potential time inputs 672 for _, p := range timeLowerBoundSources { 673 info, err := os.Stat(p) 674 if os.IsNotExist(err) { 675 continue 676 } 677 if err != nil { 678 return err 679 } 680 run.moveTimeLowerBound(info.ModTime()) 681 } 682 return nil 683 } 684 685 func findBrandAndModel() (string, string, error) { 686 if osutil.FileExists(dirs.SnapModeenvFile) { 687 return findBrandAndModel20() 688 } 689 return findBrandAndModel16() 690 } 691 692 func findBrandAndModel20() (brand, model string, err error) { 693 cfg := goconfigparser.New() 694 cfg.AllowNoSectionHeader = true 695 if err := cfg.ReadFile(dirs.SnapModeenvFile); err != nil { 696 return "", "", err 697 } 698 brandAndModel, err := cfg.Get("", "model") 699 if err != nil { 700 return "", "", err 701 } 702 l := strings.SplitN(brandAndModel, "/", 2) 703 if len(l) != 2 { 704 return "", "", fmt.Errorf("cannot find brand/model in modeenv model string %q", brandAndModel) 705 } 706 707 return l[0], l[1], nil 708 } 709 710 func findBrandAndModel16() (brand, model string, err error) { 711 workBS := asserts.NewMemoryBackstore() 712 assertSeedDir := filepath.Join(dirs.SnapSeedDir, "assertions") 713 dc, err := ioutil.ReadDir(assertSeedDir) 714 if err != nil { 715 return "", "", err 716 } 717 var modelAs *asserts.Model 718 for _, fi := range dc { 719 fn := filepath.Join(assertSeedDir, fi.Name()) 720 f, err := os.Open(fn) 721 if err != nil { 722 // best effort 723 continue 724 } 725 dec := asserts.NewDecoder(f) 726 for { 727 a, err := dec.Decode() 728 if err != nil { 729 // best effort 730 break 731 } 732 switch a.Type() { 733 case asserts.ModelType: 734 if modelAs != nil { 735 return "", "", fmt.Errorf("multiple models in seed assertions") 736 } 737 modelAs = a.(*asserts.Model) 738 case asserts.AccountType, asserts.AccountKeyType: 739 workBS.Put(a.Type(), a) 740 } 741 } 742 } 743 if modelAs == nil { 744 return "", "", fmt.Errorf("no model assertion in seed data") 745 } 746 trustedBS := trustedBackstore(sysdb.Trusted()) 747 if err := verifySignatures(modelAs, workBS, trustedBS); err != nil { 748 return "", "", err 749 } 750 acctPK := []string{modelAs.BrandID()} 751 acctMaxSupFormat := asserts.AccountType.MaxSupportedFormat() 752 acct, err := trustedBS.Get(asserts.AccountType, acctPK, acctMaxSupFormat) 753 if err != nil { 754 var err error 755 acct, err = workBS.Get(asserts.AccountType, acctPK, acctMaxSupFormat) 756 if err != nil { 757 return "", "", fmt.Errorf("no brand account assertion in seed data") 758 } 759 } 760 if err := verifySignatures(acct, workBS, trustedBS); err != nil { 761 return "", "", err 762 } 763 return modelAs.BrandID(), modelAs.Model(), nil 764 } 765 766 func (run *Runner) initDeviceInfo() error { 767 brandID, model, err := findBrandAndModel() 768 if err != nil { 769 return fmt.Errorf("cannot set device information: %v", err) 770 } 771 run.state.Device.Brand = brandID 772 run.state.Device.Model = model 773 return nil 774 } 775 776 // LoadState loads the repairs' state from disk, and (re)initializes it if it's missing or corrupted. 777 func (run *Runner) LoadState() error { 778 err := run.readState() 779 if err == nil { 780 return nil 781 } 782 // error => initialize from scratch 783 if !os.IsNotExist(err) { 784 logger.Noticef("cannor read repair state: %v", err) 785 } 786 return run.initState() 787 } 788 789 // SaveState saves the repairs' state to disk. 790 func (run *Runner) SaveState() error { 791 if !run.stateModified { 792 return nil 793 } 794 m, err := json.Marshal(&run.state) 795 if err != nil { 796 return fmt.Errorf("cannot marshal repair state: %v", err) 797 } 798 err = osutil.AtomicWriteFile(dirs.SnapRepairStateFile, m, 0600, 0) 799 if err != nil { 800 return fmt.Errorf("cannot save repair state: %v", err) 801 } 802 run.stateModified = false 803 return nil 804 } 805 806 func stringList(headers map[string]interface{}, name string) ([]string, error) { 807 v, ok := headers[name] 808 if !ok { 809 return nil, nil 810 } 811 l, ok := v.([]interface{}) 812 if !ok { 813 return nil, fmt.Errorf("header %q is not a list", name) 814 } 815 r := make([]string, len(l)) 816 for i, v := range l { 817 s, ok := v.(string) 818 if !ok { 819 return nil, fmt.Errorf("header %q contains non-string elements", name) 820 } 821 r[i] = s 822 } 823 return r, nil 824 } 825 826 // Applicable returns whether a repair with the given headers is applicable to the device. 827 func (run *Runner) Applicable(headers map[string]interface{}) bool { 828 if headers["disabled"] == "true" { 829 return false 830 } 831 series, err := stringList(headers, "series") 832 if err != nil { 833 return false 834 } 835 if len(series) != 0 && !strutil.ListContains(series, release.Series) { 836 return false 837 } 838 archs, err := stringList(headers, "architectures") 839 if err != nil { 840 return false 841 } 842 if len(archs) != 0 && !strutil.ListContains(archs, arch.DpkgArchitecture()) { 843 return false 844 } 845 brandModel := fmt.Sprintf("%s/%s", run.state.Device.Brand, run.state.Device.Model) 846 models, err := stringList(headers, "models") 847 if err != nil { 848 return false 849 } 850 if len(models) != 0 && !strutil.ListContains(models, brandModel) { 851 // model prefix matching: brand/prefix* 852 hit := false 853 for _, patt := range models { 854 if strings.HasSuffix(patt, "*") && strings.ContainsRune(patt, '/') { 855 if strings.HasPrefix(brandModel, strings.TrimSuffix(patt, "*")) { 856 hit = true 857 break 858 } 859 } 860 } 861 if !hit { 862 return false 863 } 864 } 865 return true 866 } 867 868 var errSkip = errors.New("repair unnecessary on this system") 869 870 func (run *Runner) fetch(brandID string, repairID int) (repair *asserts.Repair, aux []asserts.Assertion, err error) { 871 headers, err := run.Peek(brandID, repairID) 872 if err != nil { 873 return nil, nil, err 874 } 875 if !run.Applicable(headers) { 876 return nil, nil, errSkip 877 } 878 return run.Fetch(brandID, repairID, -1) 879 } 880 881 func (run *Runner) refetch(brandID string, repairID, revision int) (repair *asserts.Repair, aux []asserts.Assertion, err error) { 882 return run.Fetch(brandID, repairID, revision) 883 } 884 885 func (run *Runner) saveStream(brandID string, repairID int, repair *asserts.Repair, aux []asserts.Assertion) error { 886 d := filepath.Join(dirs.SnapRepairAssertsDir, brandID, strconv.Itoa(repairID)) 887 err := os.MkdirAll(d, 0775) 888 if err != nil { 889 return err 890 } 891 buf := &bytes.Buffer{} 892 enc := asserts.NewEncoder(buf) 893 r := append([]asserts.Assertion{repair}, aux...) 894 for _, a := range r { 895 if err := enc.Encode(a); err != nil { 896 return fmt.Errorf("cannot encode repair assertions %s-%d for saving: %v", brandID, repairID, err) 897 } 898 } 899 p := filepath.Join(d, fmt.Sprintf("r%d.repair", r[0].Revision())) 900 return osutil.AtomicWriteFile(p, buf.Bytes(), 0600, 0) 901 } 902 903 func (run *Runner) readSavedStream(brandID string, repairID, revision int) (repair *asserts.Repair, aux []asserts.Assertion, err error) { 904 d := filepath.Join(dirs.SnapRepairAssertsDir, brandID, strconv.Itoa(repairID)) 905 p := filepath.Join(d, fmt.Sprintf("r%d.repair", revision)) 906 f, err := os.Open(p) 907 if err != nil { 908 return nil, nil, err 909 } 910 defer f.Close() 911 912 dec := asserts.NewDecoder(f) 913 var r []asserts.Assertion 914 for { 915 a, err := dec.Decode() 916 if err == io.EOF { 917 break 918 } 919 if err != nil { 920 return nil, nil, fmt.Errorf("cannot decode repair assertions %s-%d from disk: %v", brandID, repairID, err) 921 } 922 r = append(r, a) 923 } 924 return checkStream(brandID, repairID, r) 925 } 926 927 func (run *Runner) makeReady(brandID string, sequenceNext int) (repair *asserts.Repair, err error) { 928 sequence := run.state.Sequences[brandID] 929 var aux []asserts.Assertion 930 var state RepairState 931 if sequenceNext <= len(sequence) { 932 // consider retries 933 state = *sequence[sequenceNext-1] 934 if state.Status != RetryStatus { 935 return nil, errSkip 936 } 937 var err error 938 repair, aux, err = run.refetch(brandID, state.Sequence, state.Revision) 939 if err != nil { 940 if err != ErrRepairNotModified { 941 logger.Noticef("cannot refetch repair %s-%d, will retry what is on disk: %v", brandID, sequenceNext, err) 942 } 943 // try to use what we have already on disk 944 repair, aux, err = run.readSavedStream(brandID, state.Sequence, state.Revision) 945 if err != nil { 946 return nil, err 947 } 948 } 949 } else { 950 // fetch the next repair in the sequence 951 // assumes no gaps, each repair id is present so far, 952 // possibly skipped 953 var err error 954 repair, aux, err = run.fetch(brandID, sequenceNext) 955 if err != nil && err != errSkip { 956 return nil, err 957 } 958 state = RepairState{ 959 Sequence: sequenceNext, 960 } 961 if err == errSkip { 962 // TODO: store headers to justify decision 963 state.Status = SkipStatus 964 run.setRepairState(brandID, state) 965 return nil, errSkip 966 } 967 } 968 // verify with signatures 969 if err := run.Verify(repair, aux); err != nil { 970 return nil, fmt.Errorf("cannot verify repair %s-%d: %v", brandID, state.Sequence, err) 971 } 972 if err := run.saveStream(brandID, state.Sequence, repair, aux); err != nil { 973 return nil, err 974 } 975 state.Revision = repair.Revision() 976 if !run.Applicable(repair.Headers()) { 977 state.Status = SkipStatus 978 run.setRepairState(brandID, state) 979 return nil, errSkip 980 } 981 run.setRepairState(brandID, state) 982 return repair, nil 983 } 984 985 // Next returns the next repair for the brand id sequence to run/retry or ErrRepairNotFound if there is none atm. It updates the state as required. 986 func (run *Runner) Next(brandID string) (*Repair, error) { 987 sequenceNext := run.sequenceNext[brandID] 988 if sequenceNext == 0 { 989 sequenceNext = 1 990 } 991 for { 992 repair, err := run.makeReady(brandID, sequenceNext) 993 // SaveState is a no-op unless makeReady modified the state 994 stateErr := run.SaveState() 995 if err != nil && err != errSkip && err != ErrRepairNotFound { 996 // err is a non trivial error, just log the SaveState error and report err 997 if stateErr != nil { 998 logger.Noticef("%v", stateErr) 999 } 1000 return nil, err 1001 } 1002 if stateErr != nil { 1003 return nil, stateErr 1004 } 1005 if err == ErrRepairNotFound { 1006 return nil, ErrRepairNotFound 1007 } 1008 1009 sequenceNext += 1 1010 run.sequenceNext[brandID] = sequenceNext 1011 if err == errSkip { 1012 continue 1013 } 1014 1015 return &Repair{ 1016 Repair: repair, 1017 run: run, 1018 sequence: sequenceNext - 1, 1019 }, nil 1020 } 1021 } 1022 1023 // Limit trust to specific keys while there's no delegation or limited 1024 // keys support. The obtained assertion stream may also include 1025 // account keys that are directly or indirectly signed by a trusted 1026 // key. 1027 var ( 1028 trustedRepairRootKeys []*asserts.AccountKey 1029 ) 1030 1031 // Verify verifies that the repair is properly signed by the specific 1032 // trusted root keys or by account keys in the stream (passed via aux) 1033 // directly or indirectly signed by a trusted key. 1034 func (run *Runner) Verify(repair *asserts.Repair, aux []asserts.Assertion) error { 1035 workBS := asserts.NewMemoryBackstore() 1036 for _, a := range aux { 1037 if a.Type() != asserts.AccountKeyType { 1038 continue 1039 } 1040 err := workBS.Put(asserts.AccountKeyType, a) 1041 if err != nil { 1042 return err 1043 } 1044 } 1045 trustedBS := asserts.NewMemoryBackstore() 1046 for _, t := range trustedRepairRootKeys { 1047 trustedBS.Put(asserts.AccountKeyType, t) 1048 } 1049 for _, t := range sysdb.Trusted() { 1050 // we do *not* add the defalt sysdb trusted account 1051 // keys here because the repair assertions have their 1052 // own *dedicated* root of trust 1053 if t.Type() == asserts.AccountType { 1054 trustedBS.Put(asserts.AccountType, t) 1055 } 1056 } 1057 1058 return verifySignatures(repair, workBS, trustedBS) 1059 }