// -*- Mode: Go; indent-tabs-mode: t -*-

/*
 * Copyright (C) 2017 Canonical Ltd
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 3 as
 * published by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 *
 */

package main

import (
	"bufio"
	"bytes"
	"crypto/tls"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"net/url"
	"os"
	"os/exec"
	"path/filepath"
	"strconv"
	"strings"
	"syscall"
	"time"

	"gopkg.in/retry.v1"

	"github.com/snapcore/snapd/arch"
	"github.com/snapcore/snapd/asserts"
	"github.com/snapcore/snapd/asserts/sysdb"
	"github.com/snapcore/snapd/dirs"
	"github.com/snapcore/snapd/errtracker"
	"github.com/snapcore/snapd/httputil"
	"github.com/snapcore/snapd/logger"
	"github.com/snapcore/snapd/osutil"
	"github.com/snapcore/snapd/release"
	"github.com/snapcore/snapd/strutil"
)

var (
	// defaultRepairTimeout is how long Repair.Run lets a repair script
	// run before killing its process group and reporting a timeout.
	// TODO: move inside the repairs themselves?
	defaultRepairTimeout = 30 * time.Minute
)

// errtrackerReportRepair is the function used by Repair.errtrackerReport
// to send reports; kept in a variable, presumably so tests can replace it.
var errtrackerReportRepair = errtracker.ReportRepair

// Repair is a runnable repair.
63 type Repair struct { 64 *asserts.Repair 65 66 run *Runner 67 sequence int 68 } 69 70 func (r *Repair) RunDir() string { 71 return filepath.Join(dirs.SnapRepairRunDir, r.BrandID(), strconv.Itoa(r.RepairID())) 72 } 73 74 func (r *Repair) String() string { 75 return fmt.Sprintf("%s-%v", r.BrandID(), r.RepairID()) 76 } 77 78 // SetStatus sets the status of the repair in the state and saves the latter. 79 func (r *Repair) SetStatus(status RepairStatus) { 80 brandID := r.BrandID() 81 cur := *r.run.state.Sequences[brandID][r.sequence-1] 82 cur.Status = status 83 r.run.setRepairState(brandID, cur) 84 r.run.SaveState() 85 } 86 87 // makeRepairSymlink ensures $dir/repair exists and is a symlink to 88 // /usr/lib/snapd/snap-repair 89 func makeRepairSymlink(dir string) (err error) { 90 // make "repair" binary available to the repair scripts via symlink 91 // to the real snap-repair 92 if err = os.MkdirAll(dir, 0755); err != nil { 93 return err 94 } 95 96 old := filepath.Join(dirs.CoreLibExecDir, "snap-repair") 97 new := filepath.Join(dir, "repair") 98 if err := os.Symlink(old, new); err != nil && !os.IsExist(err) { 99 return err 100 } 101 102 return nil 103 } 104 105 // Run executes the repair script leaving execution trail files on disk. 
106 func (r *Repair) Run() error { 107 // write the script to disk 108 rundir := r.RunDir() 109 err := os.MkdirAll(rundir, 0775) 110 if err != nil { 111 return err 112 } 113 114 // ensure the script can use "repair done" 115 repairToolsDir := filepath.Join(dirs.SnapRunRepairDir, "tools") 116 if err := makeRepairSymlink(repairToolsDir); err != nil { 117 return err 118 } 119 120 baseName := fmt.Sprintf("r%d", r.Revision()) 121 script := filepath.Join(rundir, baseName+".script") 122 err = osutil.AtomicWriteFile(script, r.Body(), 0700, 0) 123 if err != nil { 124 return err 125 } 126 127 logPath := filepath.Join(rundir, baseName+".running") 128 logf, err := os.OpenFile(logPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) 129 if err != nil { 130 return err 131 } 132 defer logf.Close() 133 134 fmt.Fprintf(logf, "repair: %s\n", r) 135 fmt.Fprintf(logf, "revision: %d\n", r.Revision()) 136 fmt.Fprintf(logf, "summary: %s\n", r.Summary()) 137 fmt.Fprintf(logf, "output:\n") 138 139 statusR, statusW, err := os.Pipe() 140 if err != nil { 141 return err 142 } 143 defer statusR.Close() 144 defer statusW.Close() 145 146 logger.Debugf("executing %s", script) 147 148 // run the script 149 env := os.Environ() 150 // we need to hardcode FD=3 because this is the FD after 151 // exec.Command() forked. there is no way in go currently 152 // to run something right after fork() in the child to 153 // know the fd. 
However because go will close all fds 154 // except the ones in "cmd.ExtraFiles" we are safe to set "3" 155 env = append(env, "SNAP_REPAIR_STATUS_FD=3") 156 env = append(env, "SNAP_REPAIR_RUN_DIR="+rundir) 157 // inject repairToolDir into PATH so that the script can use 158 // `repair {done,skip,retry}` 159 var havePath bool 160 for i, envStr := range env { 161 if strings.HasPrefix(envStr, "PATH=") { 162 newEnv := fmt.Sprintf("%s:%s", strings.TrimSuffix(envStr, ":"), repairToolsDir) 163 env[i] = newEnv 164 havePath = true 165 } 166 } 167 if !havePath { 168 env = append(env, "PATH=/usr/sbin:/usr/bin:/sbin:/bin:"+repairToolsDir) 169 } 170 171 workdir := filepath.Join(rundir, "work") 172 if err := os.MkdirAll(workdir, 0700); err != nil { 173 return err 174 } 175 176 cmd := exec.Command(script) 177 cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} 178 cmd.Env = env 179 cmd.Dir = workdir 180 cmd.ExtraFiles = []*os.File{statusW} 181 cmd.Stdout = logf 182 cmd.Stderr = logf 183 if err = cmd.Start(); err != nil { 184 return err 185 } 186 statusW.Close() 187 188 // wait for repair to finish or timeout 189 var scriptErr error 190 killTimerCh := time.After(defaultRepairTimeout) 191 doneCh := make(chan error, 1) 192 go func() { 193 doneCh <- cmd.Wait() 194 close(doneCh) 195 }() 196 select { 197 case scriptErr = <-doneCh: 198 // done 199 case <-killTimerCh: 200 if err := osutil.KillProcessGroup(cmd); err != nil { 201 logger.Noticef("cannot kill timed out repair %s: %s", r, err) 202 } 203 scriptErr = fmt.Errorf("repair did not finish within %s", defaultRepairTimeout) 204 } 205 // read repair status pipe, use the last value 206 status := readStatus(statusR) 207 statusPath := filepath.Join(rundir, baseName+"."+status.String()) 208 209 // if the script had an error exit status still honor what we 210 // read from the status-pipe, however report the error 211 if scriptErr != nil { 212 scriptErr = fmt.Errorf("repair %s revision %d failed: %s", r, r.Revision(), scriptErr) 213 if 
err := r.errtrackerReport(scriptErr, status, logPath); err != nil { 214 logger.Noticef("cannot report error to errtracker: %s", err) 215 } 216 // ensure the error is present in the output log 217 fmt.Fprintf(logf, "\n%s", scriptErr) 218 } 219 if err := os.Rename(logPath, statusPath); err != nil { 220 return err 221 } 222 r.SetStatus(status) 223 224 return nil 225 } 226 227 func readStatus(r io.Reader) RepairStatus { 228 var status RepairStatus 229 scanner := bufio.NewScanner(r) 230 for scanner.Scan() { 231 switch strings.TrimSpace(scanner.Text()) { 232 case "done": 233 status = DoneStatus 234 // TODO: support having a script skip over many and up to a given repair-id # 235 case "skip": 236 status = SkipStatus 237 } 238 } 239 if scanner.Err() != nil { 240 return RetryStatus 241 } 242 return status 243 } 244 245 // errtrackerReport reports an repairErr with the given logPath to the 246 // snap error tracker. 247 func (r *Repair) errtrackerReport(repairErr error, status RepairStatus, logPath string) error { 248 errMsg := repairErr.Error() 249 250 scriptOutput, err := ioutil.ReadFile(logPath) 251 if err != nil { 252 logger.Noticef("cannot read %s", logPath) 253 } 254 s := fmt.Sprintf("%s/%d", r.BrandID(), r.RepairID()) 255 256 dupSig := fmt.Sprintf("%s\n%s\noutput:\n%s", s, errMsg, scriptOutput) 257 extra := map[string]string{ 258 "Revision": strconv.Itoa(r.Revision()), 259 "BrandID": r.BrandID(), 260 "RepairID": strconv.Itoa(r.RepairID()), 261 "Status": status.String(), 262 } 263 _, err = errtrackerReportRepair(s, errMsg, dupSig, extra) 264 return err 265 } 266 267 // Runner implements fetching, tracking and running repairs. 268 type Runner struct { 269 BaseURL *url.URL 270 cli *http.Client 271 272 state state 273 stateModified bool 274 275 // sequenceNext keeps track of the next integer id in a brand sequence to considered in this run, see Next. 276 sequenceNext map[string]int 277 } 278 279 // NewRunner returns a Runner. 
280 func NewRunner() *Runner { 281 run := &Runner{ 282 sequenceNext: make(map[string]int), 283 } 284 opts := httputil.ClientOptions{ 285 MayLogBody: false, 286 TLSConfig: &tls.Config{ 287 Time: run.now, 288 }, 289 } 290 run.cli = httputil.NewHTTPClient(&opts) 291 return run 292 } 293 294 var ( 295 fetchRetryStrategy = retry.LimitCount(7, retry.LimitTime(90*time.Second, 296 retry.Exponential{ 297 Initial: 500 * time.Millisecond, 298 Factor: 2.5, 299 }, 300 )) 301 302 peekRetryStrategy = retry.LimitCount(5, retry.LimitTime(44*time.Second, 303 retry.Exponential{ 304 Initial: 300 * time.Millisecond, 305 Factor: 2.5, 306 }, 307 )) 308 ) 309 310 var ( 311 ErrRepairNotFound = errors.New("repair not found") 312 ErrRepairNotModified = errors.New("repair was not modified") 313 ) 314 315 var ( 316 maxRepairScriptSize = 24 * 1024 * 1024 317 ) 318 319 // Fetch retrieves a stream with the repair with the given ids and any 320 // auxiliary assertions. If revision>=0 the request will include an 321 // If-None-Match header with an ETag for the revision, and 322 // ErrRepairNotModified is returned if the revision is still current. 
323 func (run *Runner) Fetch(brandID string, repairID int, revision int) (*asserts.Repair, []asserts.Assertion, error) { 324 u, err := run.BaseURL.Parse(fmt.Sprintf("repairs/%s/%d", brandID, repairID)) 325 if err != nil { 326 return nil, nil, err 327 } 328 329 var r []asserts.Assertion 330 resp, err := httputil.RetryRequest(u.String(), func() (*http.Response, error) { 331 req, err := http.NewRequest("GET", u.String(), nil) 332 if err != nil { 333 return nil, err 334 } 335 req.Header.Set("User-Agent", httputil.UserAgent()) 336 req.Header.Set("Accept", "application/x.ubuntu.assertion") 337 if revision >= 0 { 338 req.Header.Set("If-None-Match", fmt.Sprintf(`"%d"`, revision)) 339 } 340 return run.cli.Do(req) 341 }, func(resp *http.Response) error { 342 if resp.StatusCode == 200 { 343 logger.Debugf("fetching repair %s-%d", brandID, repairID) 344 // decode assertions 345 dec := asserts.NewDecoderWithTypeMaxBodySize(resp.Body, map[*asserts.AssertionType]int{ 346 asserts.RepairType: maxRepairScriptSize, 347 }) 348 for { 349 a, err := dec.Decode() 350 if err == io.EOF { 351 break 352 } 353 if err != nil { 354 return err 355 } 356 r = append(r, a) 357 } 358 if len(r) == 0 { 359 return io.ErrUnexpectedEOF 360 } 361 } 362 return nil 363 }, fetchRetryStrategy) 364 365 if err != nil { 366 return nil, nil, err 367 } 368 369 moveTimeLowerBound := true 370 defer func() { 371 if moveTimeLowerBound { 372 t, _ := http.ParseTime(resp.Header.Get("Date")) 373 run.moveTimeLowerBound(t) 374 } 375 }() 376 377 switch resp.StatusCode { 378 case 200: 379 // ok 380 case 304: 381 // not modified 382 return nil, nil, ErrRepairNotModified 383 case 404: 384 return nil, nil, ErrRepairNotFound 385 default: 386 moveTimeLowerBound = false 387 return nil, nil, fmt.Errorf("cannot fetch repair, unexpected status %d", resp.StatusCode) 388 } 389 390 repair, aux, err := checkStream(brandID, repairID, r) 391 if err != nil { 392 return nil, nil, fmt.Errorf("cannot fetch repair, %v", err) 393 } 394 395 if 
repair.Revision() <= revision { 396 // this shouldn't happen but if it does we behave like 397 // all the rest of assertion infrastructure and ignore 398 // the now superseded revision 399 return nil, nil, ErrRepairNotModified 400 } 401 402 return repair, aux, err 403 } 404 405 func checkStream(brandID string, repairID int, r []asserts.Assertion) (repair *asserts.Repair, aux []asserts.Assertion, err error) { 406 if len(r) == 0 { 407 return nil, nil, fmt.Errorf("empty repair assertions stream") 408 } 409 var ok bool 410 repair, ok = r[0].(*asserts.Repair) 411 if !ok { 412 return nil, nil, fmt.Errorf("unexpected first assertion %q", r[0].Type().Name) 413 } 414 415 if repair.BrandID() != brandID || repair.RepairID() != repairID { 416 return nil, nil, fmt.Errorf("repair id mismatch %s/%d != %s/%d", repair.BrandID(), repair.RepairID(), brandID, repairID) 417 } 418 419 return repair, r[1:], nil 420 } 421 422 type peekResp struct { 423 Headers map[string]interface{} `json:"headers"` 424 } 425 426 // Peek retrieves the headers for the repair with the given ids. 
427 func (run *Runner) Peek(brandID string, repairID int) (headers map[string]interface{}, err error) { 428 u, err := run.BaseURL.Parse(fmt.Sprintf("repairs/%s/%d", brandID, repairID)) 429 if err != nil { 430 return nil, err 431 } 432 433 var rsp peekResp 434 435 resp, err := httputil.RetryRequest(u.String(), func() (*http.Response, error) { 436 req, err := http.NewRequest("GET", u.String(), nil) 437 if err != nil { 438 return nil, err 439 } 440 req.Header.Set("User-Agent", httputil.UserAgent()) 441 req.Header.Set("Accept", "application/json") 442 return run.cli.Do(req) 443 }, func(resp *http.Response) error { 444 rsp.Headers = nil 445 if resp.StatusCode == 200 { 446 dec := json.NewDecoder(resp.Body) 447 return dec.Decode(&rsp) 448 } 449 return nil 450 }, peekRetryStrategy) 451 452 if err != nil { 453 return nil, err 454 } 455 456 moveTimeLowerBound := true 457 defer func() { 458 if moveTimeLowerBound { 459 t, _ := http.ParseTime(resp.Header.Get("Date")) 460 run.moveTimeLowerBound(t) 461 } 462 }() 463 464 switch resp.StatusCode { 465 case 200: 466 // ok 467 case 404: 468 return nil, ErrRepairNotFound 469 default: 470 moveTimeLowerBound = false 471 return nil, fmt.Errorf("cannot peek repair headers, unexpected status %d", resp.StatusCode) 472 } 473 474 headers = rsp.Headers 475 if headers["brand-id"] != brandID || headers["repair-id"] != strconv.Itoa(repairID) { 476 return nil, fmt.Errorf("cannot peek repair headers, repair id mismatch %s/%s != %s/%d", headers["brand-id"], headers["repair-id"], brandID, repairID) 477 } 478 479 return headers, nil 480 } 481 482 // deviceInfo captures information about the device. 483 type deviceInfo struct { 484 Brand string `json:"brand"` 485 Model string `json:"model"` 486 } 487 488 // RepairStatus represents the possible statuses of a repair. 
489 type RepairStatus int 490 491 const ( 492 RetryStatus RepairStatus = iota 493 SkipStatus 494 DoneStatus 495 ) 496 497 func (rs RepairStatus) String() string { 498 switch rs { 499 case RetryStatus: 500 return "retry" 501 case SkipStatus: 502 return "skip" 503 case DoneStatus: 504 return "done" 505 default: 506 return "unknown" 507 } 508 } 509 510 // RepairState holds the current revision and status of a repair in a sequence of repairs. 511 type RepairState struct { 512 Sequence int `json:"sequence"` 513 Revision int `json:"revision"` 514 Status RepairStatus `json:"status"` 515 } 516 517 // state holds the atomically updated control state of the runner with sequences of repairs and their states. 518 type state struct { 519 Device deviceInfo `json:"device"` 520 Sequences map[string][]*RepairState `json:"sequences,omitempty"` 521 TimeLowerBound time.Time `json:"time-lower-bound"` 522 } 523 524 func (run *Runner) setRepairState(brandID string, state RepairState) { 525 if run.state.Sequences == nil { 526 run.state.Sequences = make(map[string][]*RepairState) 527 } 528 sequence := run.state.Sequences[brandID] 529 if state.Sequence > len(sequence) { 530 run.stateModified = true 531 run.state.Sequences[brandID] = append(sequence, &state) 532 } else if *sequence[state.Sequence-1] != state { 533 run.stateModified = true 534 sequence[state.Sequence-1] = &state 535 } 536 } 537 538 func (run *Runner) readState() error { 539 r, err := os.Open(dirs.SnapRepairStateFile) 540 if err != nil { 541 return err 542 } 543 defer r.Close() 544 dec := json.NewDecoder(r) 545 return dec.Decode(&run.state) 546 } 547 548 func (run *Runner) moveTimeLowerBound(t time.Time) { 549 if t.After(run.state.TimeLowerBound) { 550 run.stateModified = true 551 run.state.TimeLowerBound = t.UTC() 552 } 553 } 554 555 var timeNow = time.Now 556 557 func (run *Runner) now() time.Time { 558 now := timeNow().UTC() 559 if now.Before(run.state.TimeLowerBound) { 560 return run.state.TimeLowerBound 561 } 562 return 
now 563 } 564 565 func (run *Runner) initState() error { 566 if err := os.MkdirAll(dirs.SnapRepairDir, 0775); err != nil { 567 return fmt.Errorf("cannot create repair state directory: %v", err) 568 } 569 // best-effort remove old 570 os.Remove(dirs.SnapRepairStateFile) 571 run.state = state{} 572 // initialize time lower bound with image built time/seed.yaml time 573 info, err := os.Stat(filepath.Join(dirs.SnapSeedDir, "seed.yaml")) 574 if err != nil { 575 return err 576 } 577 run.moveTimeLowerBound(info.ModTime()) 578 // initialize device info 579 if err := run.initDeviceInfo(); err != nil { 580 return err 581 } 582 run.stateModified = true 583 return run.SaveState() 584 } 585 586 func trustedBackstore(trusted []asserts.Assertion) asserts.Backstore { 587 trustedBS := asserts.NewMemoryBackstore() 588 for _, t := range trusted { 589 trustedBS.Put(t.Type(), t) 590 } 591 return trustedBS 592 } 593 594 func checkAuthorityID(a asserts.Assertion, trusted asserts.Backstore) error { 595 assertType := a.Type() 596 if assertType != asserts.AccountKeyType && assertType != asserts.AccountType { 597 return nil 598 } 599 // check that account and account-key assertions are signed by 600 // a trusted authority 601 acctID := a.AuthorityID() 602 _, err := trusted.Get(asserts.AccountType, []string{acctID}, asserts.AccountType.MaxSupportedFormat()) 603 if err != nil && !asserts.IsNotFound(err) { 604 return err 605 } 606 if asserts.IsNotFound(err) { 607 return fmt.Errorf("%v not signed by trusted authority: %s", a.Ref(), acctID) 608 } 609 return nil 610 } 611 612 func verifySignatures(a asserts.Assertion, workBS asserts.Backstore, trusted asserts.Backstore) error { 613 if err := checkAuthorityID(a, trusted); err != nil { 614 return err 615 } 616 acctKeyMaxSuppFormat := asserts.AccountKeyType.MaxSupportedFormat() 617 618 seen := make(map[string]bool) 619 bottom := false 620 for !bottom { 621 u := a.Ref().Unique() 622 if seen[u] { 623 return fmt.Errorf("circular assertions") 624 } 625 
seen[u] = true 626 signKey := []string{a.SignKeyID()} 627 key, err := trusted.Get(asserts.AccountKeyType, signKey, acctKeyMaxSuppFormat) 628 if err != nil && !asserts.IsNotFound(err) { 629 return err 630 } 631 if err == nil { 632 bottom = true 633 } else { 634 key, err = workBS.Get(asserts.AccountKeyType, signKey, acctKeyMaxSuppFormat) 635 if err != nil && !asserts.IsNotFound(err) { 636 return err 637 } 638 if asserts.IsNotFound(err) { 639 return fmt.Errorf("cannot find public key %q", signKey[0]) 640 } 641 if err := checkAuthorityID(key, trusted); err != nil { 642 return err 643 } 644 } 645 if err := asserts.CheckSignature(a, key.(*asserts.AccountKey), nil, time.Time{}); err != nil { 646 return err 647 } 648 a = key 649 } 650 return nil 651 } 652 653 func (run *Runner) initDeviceInfo() error { 654 const errPrefix = "cannot set device information: " 655 656 workBS := asserts.NewMemoryBackstore() 657 assertSeedDir := filepath.Join(dirs.SnapSeedDir, "assertions") 658 dc, err := ioutil.ReadDir(assertSeedDir) 659 if err != nil { 660 return err 661 } 662 var model *asserts.Model 663 for _, fi := range dc { 664 fn := filepath.Join(assertSeedDir, fi.Name()) 665 f, err := os.Open(fn) 666 if err != nil { 667 // best effort 668 continue 669 } 670 dec := asserts.NewDecoder(f) 671 for { 672 a, err := dec.Decode() 673 if err != nil { 674 // best effort 675 break 676 } 677 switch a.Type() { 678 case asserts.ModelType: 679 if model != nil { 680 return fmt.Errorf(errPrefix + "multiple models in seed assertions") 681 } 682 model = a.(*asserts.Model) 683 case asserts.AccountType, asserts.AccountKeyType: 684 workBS.Put(a.Type(), a) 685 } 686 } 687 } 688 if model == nil { 689 return fmt.Errorf(errPrefix + "no model assertion in seed data") 690 } 691 trustedBS := trustedBackstore(sysdb.Trusted()) 692 if err := verifySignatures(model, workBS, trustedBS); err != nil { 693 return fmt.Errorf(errPrefix+"%v", err) 694 } 695 acctPK := []string{model.BrandID()} 696 acctMaxSupFormat := 
asserts.AccountType.MaxSupportedFormat() 697 acct, err := trustedBS.Get(asserts.AccountType, acctPK, acctMaxSupFormat) 698 if err != nil { 699 var err error 700 acct, err = workBS.Get(asserts.AccountType, acctPK, acctMaxSupFormat) 701 if err != nil { 702 return fmt.Errorf(errPrefix + "no brand account assertion in seed data") 703 } 704 } 705 if err := verifySignatures(acct, workBS, trustedBS); err != nil { 706 return fmt.Errorf(errPrefix+"%v", err) 707 } 708 run.state.Device.Brand = model.BrandID() 709 run.state.Device.Model = model.Model() 710 return nil 711 } 712 713 // LoadState loads the repairs' state from disk, and (re)initializes it if it's missing or corrupted. 714 func (run *Runner) LoadState() error { 715 err := run.readState() 716 if err == nil { 717 return nil 718 } 719 // error => initialize from scratch 720 if !os.IsNotExist(err) { 721 logger.Noticef("cannor read repair state: %v", err) 722 } 723 return run.initState() 724 } 725 726 // SaveState saves the repairs' state to disk. 
727 func (run *Runner) SaveState() error { 728 if !run.stateModified { 729 return nil 730 } 731 m, err := json.Marshal(&run.state) 732 if err != nil { 733 return fmt.Errorf("cannot marshal repair state: %v", err) 734 } 735 err = osutil.AtomicWriteFile(dirs.SnapRepairStateFile, m, 0600, 0) 736 if err != nil { 737 return fmt.Errorf("cannot save repair state: %v", err) 738 } 739 run.stateModified = false 740 return nil 741 } 742 743 func stringList(headers map[string]interface{}, name string) ([]string, error) { 744 v, ok := headers[name] 745 if !ok { 746 return nil, nil 747 } 748 l, ok := v.([]interface{}) 749 if !ok { 750 return nil, fmt.Errorf("header %q is not a list", name) 751 } 752 r := make([]string, len(l)) 753 for i, v := range l { 754 s, ok := v.(string) 755 if !ok { 756 return nil, fmt.Errorf("header %q contains non-string elements", name) 757 } 758 r[i] = s 759 } 760 return r, nil 761 } 762 763 // Applicable returns whether a repair with the given headers is applicable to the device. 
func (run *Runner) Applicable(headers map[string]interface{}) bool {
	// a "disabled" header with the string value "true" turns the repair off
	if headers["disabled"] == "true" {
		return false
	}
	// malformed list headers make the repair inapplicable
	series, err := stringList(headers, "series")
	if err != nil {
		return false
	}
	if len(series) != 0 && !strutil.ListContains(series, release.Series) {
		return false
	}
	archs, err := stringList(headers, "architectures")
	if err != nil {
		return false
	}
	if len(archs) != 0 && !strutil.ListContains(archs, arch.DpkgArchitecture()) {
		return false
	}
	brandModel := fmt.Sprintf("%s/%s", run.state.Device.Brand, run.state.Device.Model)
	models, err := stringList(headers, "models")
	if err != nil {
		return false
	}
	if len(models) != 0 && !strutil.ListContains(models, brandModel) {
		// model prefix matching: brand/prefix*
		// (only patterns that contain a '/' and end in '*' are considered)
		hit := false
		for _, patt := range models {
			if strings.HasSuffix(patt, "*") && strings.ContainsRune(patt, '/') {
				if strings.HasPrefix(brandModel, strings.TrimSuffix(patt, "*")) {
					hit = true
					break
				}
			}
		}
		if !hit {
			return false
		}
	}
	return true
}

// errSkip signals that a repair does not need to run on this system
// (not applicable, or no longer in the retry status).
var errSkip = errors.New("repair unnecessary on this system")

// fetch peeks at the headers of a not yet seen repair and, only if they
// are applicable to this device, fetches the full assertion stream.
// It returns errSkip when the headers are not applicable.
func (run *Runner) fetch(brandID string, repairID int) (repair *asserts.Repair, aux []asserts.Assertion, err error) {
	headers, err := run.Peek(brandID, repairID)
	if err != nil {
		return nil, nil, err
	}
	if !run.Applicable(headers) {
		return nil, nil, errSkip
	}
	return run.Fetch(brandID, repairID, -1)
}

// refetch fetches a repair already known at the given revision; Fetch
// returns ErrRepairNotModified if that revision is still current.
func (run *Runner) refetch(brandID string, repairID, revision int) (repair *asserts.Repair, aux []asserts.Assertion, err error) {
	return run.Fetch(brandID, repairID, revision)
}

// saveStream writes the repair followed by its auxiliary assertions to
// <dirs.SnapRepairAssertsDir>/<brand-id>/<repair-id>/r<revision>.repair.
func (run *Runner) saveStream(brandID string, repairID int, repair *asserts.Repair, aux []asserts.Assertion) error {
	d := filepath.Join(dirs.SnapRepairAssertsDir, brandID, strconv.Itoa(repairID))
	err := os.MkdirAll(d, 0775)
	if err != nil {
		return err
	}
	buf := &bytes.Buffer{}
	enc := asserts.NewEncoder(buf)
	r := append([]asserts.Assertion{repair}, aux...)
	for _, a := range r {
		if err := enc.Encode(a); err != nil {
			return fmt.Errorf("cannot encode repair assertions %s-%d for saving: %v", brandID, repairID, err)
		}
	}
	p := filepath.Join(d, fmt.Sprintf("r%d.repair", r[0].Revision()))
	return osutil.AtomicWriteFile(p, buf.Bytes(), 0600, 0)
}

// readSavedStream reads back the stream written by saveStream for the
// given revision and checks it with checkStream.
func (run *Runner) readSavedStream(brandID string, repairID, revision int) (repair *asserts.Repair, aux []asserts.Assertion, err error) {
	d := filepath.Join(dirs.SnapRepairAssertsDir, brandID, strconv.Itoa(repairID))
	p := filepath.Join(d, fmt.Sprintf("r%d.repair", revision))
	f, err := os.Open(p)
	if err != nil {
		return nil, nil, err
	}
	defer f.Close()

	dec := asserts.NewDecoder(f)
	var r []asserts.Assertion
	for {
		a, err := dec.Decode()
		if err == io.EOF {
			break
		}
		if err != nil {
			return nil, nil, fmt.Errorf("cannot decode repair assertions %s-%d from disk: %v", brandID, repairID, err)
		}
		r = append(r, a)
	}
	return checkStream(brandID, repairID, r)
}

// makeReady prepares the repair at position sequenceNext of brandID's
// sequence. An already known repair is only considered while its status
// is RetryStatus: it is refetched, or read from disk if refetching fails
// or it is not modified. A new repair is fetched. The stream is then
// verified and saved. The state is updated via setRepairState (not
// saved), and errSkip is returned whenever the repair should not run.
func (run *Runner) makeReady(brandID string, sequenceNext int) (repair *asserts.Repair, err error) {
	sequence := run.state.Sequences[brandID]
	var aux []asserts.Assertion
	var state RepairState
	if sequenceNext <= len(sequence) {
		// consider retries
		state = *sequence[sequenceNext-1]
		if state.Status != RetryStatus {
			return nil, errSkip
		}
		var err error
		repair, aux, err = run.refetch(brandID, state.Sequence, state.Revision)
		if err != nil {
			if err != ErrRepairNotModified {
				logger.Noticef("cannot refetch repair %s-%d, will retry what is on disk: %v", brandID, sequenceNext, err)
			}
			// try to use what we have already on disk
			repair, aux, err = run.readSavedStream(brandID, state.Sequence, state.Revision)
			if err != nil {
				return nil, err
			}
		}
	} else {
		// fetch the next repair in the sequence
		// assumes no gaps, each repair id is present so far,
		// possibly skipped
		var err error
		repair, aux, err = run.fetch(brandID, sequenceNext)
		if err != nil && err != errSkip {
			return nil, err
		}
		state = RepairState{
			Sequence: sequenceNext,
		}
		if err == errSkip {
			// TODO: store headers to justify decision
			state.Status = SkipStatus
			run.setRepairState(brandID, state)
			return nil, errSkip
		}
	}
	// verify with signatures
	if err := run.Verify(repair, aux); err != nil {
		return nil, fmt.Errorf("cannot verify repair %s-%d: %v", brandID, state.Sequence, err)
	}
	if err := run.saveStream(brandID, state.Sequence, repair, aux); err != nil {
		return nil, err
	}
	state.Revision = repair.Revision()
	// re-check applicability against the headers of the verified assertion
	if !run.Applicable(repair.Headers()) {
		state.Status = SkipStatus
		run.setRepairState(brandID, state)
		return nil, errSkip
	}
	run.setRepairState(brandID, state)
	return repair, nil
}

// Next returns the next repair for the brand id sequence to run/retry, or ErrRepairNotFound if there is currently none. It updates and saves the state as required.
923 func (run *Runner) Next(brandID string) (*Repair, error) { 924 sequenceNext := run.sequenceNext[brandID] 925 if sequenceNext == 0 { 926 sequenceNext = 1 927 } 928 for { 929 repair, err := run.makeReady(brandID, sequenceNext) 930 // SaveState is a no-op unless makeReady modified the state 931 stateErr := run.SaveState() 932 if err != nil && err != errSkip && err != ErrRepairNotFound { 933 // err is a non trivial error, just log the SaveState error and report err 934 if stateErr != nil { 935 logger.Noticef("%v", stateErr) 936 } 937 return nil, err 938 } 939 if stateErr != nil { 940 return nil, stateErr 941 } 942 if err == ErrRepairNotFound { 943 return nil, ErrRepairNotFound 944 } 945 946 sequenceNext += 1 947 run.sequenceNext[brandID] = sequenceNext 948 if err == errSkip { 949 continue 950 } 951 952 return &Repair{ 953 Repair: repair, 954 run: run, 955 sequence: sequenceNext - 1, 956 }, nil 957 } 958 } 959 960 // Limit trust to specific keys while there's no delegation or limited 961 // keys support. The obtained assertion stream may also include 962 // account keys that are directly or indirectly signed by a trusted 963 // key. 964 var ( 965 trustedRepairRootKeys []*asserts.AccountKey 966 ) 967 968 // Verify verifies that the repair is properly signed by the specific 969 // trusted root keys or by account keys in the stream (passed via aux) 970 // directly or indirectly signed by a trusted key. 
971 func (run *Runner) Verify(repair *asserts.Repair, aux []asserts.Assertion) error { 972 workBS := asserts.NewMemoryBackstore() 973 for _, a := range aux { 974 if a.Type() != asserts.AccountKeyType { 975 continue 976 } 977 err := workBS.Put(asserts.AccountKeyType, a) 978 if err != nil { 979 return err 980 } 981 } 982 trustedBS := asserts.NewMemoryBackstore() 983 for _, t := range trustedRepairRootKeys { 984 trustedBS.Put(asserts.AccountKeyType, t) 985 } 986 for _, t := range sysdb.Trusted() { 987 if t.Type() == asserts.AccountType { 988 trustedBS.Put(asserts.AccountType, t) 989 } 990 } 991 992 return verifySignatures(repair, workBS, trustedBS) 993 }