github.com/hugh712/snapd@v0.0.0-20200910133618-1a99902bd583/cmd/snap-repair/runner.go (about) 1 // -*- Mode: Go; indent-tabs-mode: t -*- 2 3 /* 4 * Copyright (C) 2017-2020 Canonical Ltd 5 * 6 * This program is free software: you can redistribute it and/or modify 7 * it under the terms of the GNU General Public License version 3 as 8 * published by the Free Software Foundation. 9 * 10 * This program is distributed in the hope that it will be useful, 11 * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 * GNU General Public License for more details. 14 * 15 * You should have received a copy of the GNU General Public License 16 * along with this program. If not, see <http://www.gnu.org/licenses/>. 17 * 18 */ 19 20 package main 21 22 import ( 23 "bufio" 24 "bytes" 25 "crypto/tls" 26 "encoding/json" 27 "errors" 28 "fmt" 29 "io" 30 "io/ioutil" 31 "net/http" 32 "net/url" 33 "os" 34 "os/exec" 35 "path/filepath" 36 "strconv" 37 "strings" 38 "syscall" 39 "time" 40 41 "gopkg.in/retry.v1" 42 43 "github.com/snapcore/snapd/arch" 44 "github.com/snapcore/snapd/asserts" 45 "github.com/snapcore/snapd/asserts/sysdb" 46 "github.com/snapcore/snapd/dirs" 47 "github.com/snapcore/snapd/errtracker" 48 "github.com/snapcore/snapd/httputil" 49 "github.com/snapcore/snapd/logger" 50 "github.com/snapcore/snapd/osutil" 51 "github.com/snapcore/snapd/release" 52 "github.com/snapcore/snapd/snapdenv" 53 "github.com/snapcore/snapd/strutil" 54 ) 55 56 var ( 57 // TODO: move inside the repairs themselves? 58 defaultRepairTimeout = 30 * time.Minute 59 ) 60 61 var errtrackerReportRepair = errtracker.ReportRepair 62 63 // Repair is a runnable repair. 64 type Repair struct { 65 *asserts.Repair 66 67 run *Runner 68 sequence int 69 } 70 71 func (r *Repair) RunDir() string { 72 return filepath.Join(dirs.SnapRepairRunDir, r.BrandID(), strconv.Itoa(r.RepairID())) 73 } 74 75 func (r *Repair) String() string { 76 return fmt.Sprintf("%s-%v", r.BrandID(), r.RepairID()) 77 } 78 79 // SetStatus sets the status of the repair in the state and saves the latter. 80 func (r *Repair) SetStatus(status RepairStatus) { 81 brandID := r.BrandID() 82 cur := *r.run.state.Sequences[brandID][r.sequence-1] 83 cur.Status = status 84 r.run.setRepairState(brandID, cur) 85 r.run.SaveState() 86 } 87 88 // makeRepairSymlink ensures $dir/repair exists and is a symlink to 89 // /usr/lib/snapd/snap-repair 90 func makeRepairSymlink(dir string) (err error) { 91 // make "repair" binary available to the repair scripts via symlink 92 // to the real snap-repair 93 if err = os.MkdirAll(dir, 0755); err != nil { 94 return err 95 } 96 97 old := filepath.Join(dirs.CoreLibExecDir, "snap-repair") 98 new := filepath.Join(dir, "repair") 99 if err := os.Symlink(old, new); err != nil && !os.IsExist(err) { 100 return err 101 } 102 103 return nil 104 } 105 106 // Run executes the repair script leaving execution trail files on disk. 107 func (r *Repair) Run() error { 108 // write the script to disk 109 rundir := r.RunDir() 110 err := os.MkdirAll(rundir, 0775) 111 if err != nil { 112 return err 113 } 114 115 // ensure the script can use "repair done" 116 repairToolsDir := filepath.Join(dirs.SnapRunRepairDir, "tools") 117 if err := makeRepairSymlink(repairToolsDir); err != nil { 118 return err 119 } 120 121 baseName := fmt.Sprintf("r%d", r.Revision()) 122 script := filepath.Join(rundir, baseName+".script") 123 err = osutil.AtomicWriteFile(script, r.Body(), 0700, 0) 124 if err != nil { 125 return err 126 } 127 128 logPath := filepath.Join(rundir, baseName+".running") 129 logf, err := os.OpenFile(logPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600) 130 if err != nil { 131 return err 132 } 133 defer logf.Close() 134 135 fmt.Fprintf(logf, "repair: %s\n", r) 136 fmt.Fprintf(logf, "revision: %d\n", r.Revision()) 137 fmt.Fprintf(logf, "summary: %s\n", r.Summary()) 138 fmt.Fprintf(logf, "output:\n") 139 140 statusR, statusW, err := os.Pipe() 141 if err != nil { 142 return err 143 } 144 defer statusR.Close() 145 defer statusW.Close() 146 147 logger.Debugf("executing %s", script) 148 149 // run the script 150 env := os.Environ() 151 // we need to hardcode FD=3 because this is the FD after 152 // exec.Command() forked. there is no way in go currently 153 // to run something right after fork() in the child to 154 // know the fd. However because go will close all fds 155 // except the ones in "cmd.ExtraFiles" we are safe to set "3" 156 env = append(env, "SNAP_REPAIR_STATUS_FD=3") 157 env = append(env, "SNAP_REPAIR_RUN_DIR="+rundir) 158 // inject repairToolDir into PATH so that the script can use 159 // `repair {done,skip,retry}` 160 var havePath bool 161 for i, envStr := range env { 162 if strings.HasPrefix(envStr, "PATH=") { 163 newEnv := fmt.Sprintf("%s:%s", strings.TrimSuffix(envStr, ":"), repairToolsDir) 164 env[i] = newEnv 165 havePath = true 166 } 167 } 168 if !havePath { 169 env = append(env, "PATH=/usr/sbin:/usr/bin:/sbin:/bin:"+repairToolsDir) 170 } 171 172 workdir := filepath.Join(rundir, "work") 173 if err := os.MkdirAll(workdir, 0700); err != nil { 174 return err 175 } 176 177 cmd := exec.Command(script) 178 cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} 179 cmd.Env = env 180 cmd.Dir = workdir 181 cmd.ExtraFiles = []*os.File{statusW} 182 cmd.Stdout = logf 183 cmd.Stderr = logf 184 if err = cmd.Start(); err != nil { 185 return err 186 } 187 statusW.Close() 188 189 // wait for repair to finish or timeout 190 var scriptErr error 191 killTimerCh := time.After(defaultRepairTimeout) 192 doneCh := make(chan error, 1) 193 go func() { 194 doneCh <- cmd.Wait() 195 close(doneCh) 196 }() 197 select { 198 case scriptErr = <-doneCh: 199 // done 200 case <-killTimerCh: 201 if err := osutil.KillProcessGroup(cmd); err != nil { 202 logger.Noticef("cannot kill timed out repair %s: %s", r, err) 203 } 204 scriptErr = fmt.Errorf("repair did not finish within %s", defaultRepairTimeout) 205 } 206 // read repair status pipe, use the last value 207 status := readStatus(statusR) 208 statusPath := filepath.Join(rundir, baseName+"."+status.String()) 209 210 // if the script had an error exit status still honor what we 211 // read from the status-pipe, however report the error 212 if scriptErr != nil { 213 scriptErr = fmt.Errorf("repair %s revision %d failed: %s", r, r.Revision(), scriptErr) 214 if err := r.errtrackerReport(scriptErr, status, logPath); err != nil { 215 logger.Noticef("cannot report error to errtracker: %s", err) 216 } 217 // ensure the error is present in the output log 218 fmt.Fprintf(logf, "\n%s", scriptErr) 219 } 220 if err := os.Rename(logPath, statusPath); err != nil { 221 return err 222 } 223 r.SetStatus(status) 224 225 return nil 226 } 227 228 func readStatus(r io.Reader) RepairStatus { 229 var status RepairStatus 230 scanner := bufio.NewScanner(r) 231 for scanner.Scan() { 232 switch strings.TrimSpace(scanner.Text()) { 233 case "done": 234 status = DoneStatus 235 // TODO: support having a script skip over many and up to a given repair-id # 236 case "skip": 237 status = SkipStatus 238 } 239 } 240 if scanner.Err() != nil { 241 return RetryStatus 242 } 243 return status 244 } 245 246 // errtrackerReport reports an repairErr with the given logPath to the 247 // snap error tracker. 248 func (r *Repair) errtrackerReport(repairErr error, status RepairStatus, logPath string) error { 249 errMsg := repairErr.Error() 250 251 scriptOutput, err := ioutil.ReadFile(logPath) 252 if err != nil { 253 logger.Noticef("cannot read %s", logPath) 254 } 255 s := fmt.Sprintf("%s/%d", r.BrandID(), r.RepairID()) 256 257 dupSig := fmt.Sprintf("%s\n%s\noutput:\n%s", s, errMsg, scriptOutput) 258 extra := map[string]string{ 259 "Revision": strconv.Itoa(r.Revision()), 260 "BrandID": r.BrandID(), 261 "RepairID": strconv.Itoa(r.RepairID()), 262 "Status": status.String(), 263 } 264 _, err = errtrackerReportRepair(s, errMsg, dupSig, extra) 265 return err 266 } 267 268 // Runner implements fetching, tracking and running repairs. 269 type Runner struct { 270 BaseURL *url.URL 271 cli *http.Client 272 273 state state 274 stateModified bool 275 276 // sequenceNext keeps track of the next integer id in a brand sequence to considered in this run, see Next. 277 sequenceNext map[string]int 278 } 279 280 // NewRunner returns a Runner. 281 func NewRunner() *Runner { 282 run := &Runner{ 283 sequenceNext: make(map[string]int), 284 } 285 opts := httputil.ClientOptions{ 286 MayLogBody: false, 287 ProxyConnectHeader: http.Header{"User-Agent": []string{snapdenv.UserAgent()}}, 288 TLSConfig: &tls.Config{ 289 Time: run.now, 290 }, 291 ExtraSSLCerts: &httputil.ExtraSSLCertsFromDir{ 292 Dir: dirs.SnapdStoreSSLCertsDir, 293 }, 294 } 295 run.cli = httputil.NewHTTPClient(&opts) 296 return run 297 } 298 299 var ( 300 fetchRetryStrategy = retry.LimitCount(7, retry.LimitTime(90*time.Second, 301 retry.Exponential{ 302 Initial: 500 * time.Millisecond, 303 Factor: 2.5, 304 }, 305 )) 306 307 peekRetryStrategy = retry.LimitCount(5, retry.LimitTime(44*time.Second, 308 retry.Exponential{ 309 Initial: 300 * time.Millisecond, 310 Factor: 2.5, 311 }, 312 )) 313 ) 314 315 var ( 316 ErrRepairNotFound = errors.New("repair not found") 317 ErrRepairNotModified = errors.New("repair was not modified") 318 ) 319 320 var ( 321 maxRepairScriptSize = 24 * 1024 * 1024 322 ) 323 324 // Fetch retrieves a stream with the repair with the given ids and any 325 // auxiliary assertions. If revision>=0 the request will include an 326 // If-None-Match header with an ETag for the revision, and 327 // ErrRepairNotModified is returned if the revision is still current. 328 func (run *Runner) Fetch(brandID string, repairID int, revision int) (*asserts.Repair, []asserts.Assertion, error) { 329 u, err := run.BaseURL.Parse(fmt.Sprintf("repairs/%s/%d", brandID, repairID)) 330 if err != nil { 331 return nil, nil, err 332 } 333 334 var r []asserts.Assertion 335 resp, err := httputil.RetryRequest(u.String(), func() (*http.Response, error) { 336 req, err := http.NewRequest("GET", u.String(), nil) 337 if err != nil { 338 return nil, err 339 } 340 req.Header.Set("User-Agent", snapdenv.UserAgent()) 341 req.Header.Set("Accept", "application/x.ubuntu.assertion") 342 if revision >= 0 { 343 req.Header.Set("If-None-Match", fmt.Sprintf(`"%d"`, revision)) 344 } 345 return run.cli.Do(req) 346 }, func(resp *http.Response) error { 347 if resp.StatusCode == 200 { 348 logger.Debugf("fetching repair %s-%d", brandID, repairID) 349 // decode assertions 350 dec := asserts.NewDecoderWithTypeMaxBodySize(resp.Body, map[*asserts.AssertionType]int{ 351 asserts.RepairType: maxRepairScriptSize, 352 }) 353 for { 354 a, err := dec.Decode() 355 if err == io.EOF { 356 break 357 } 358 if err != nil { 359 return err 360 } 361 r = append(r, a) 362 } 363 if len(r) == 0 { 364 return io.ErrUnexpectedEOF 365 } 366 } 367 return nil 368 }, fetchRetryStrategy) 369 370 if err != nil { 371 return nil, nil, err 372 } 373 374 moveTimeLowerBound := true 375 defer func() { 376 if moveTimeLowerBound { 377 t, _ := http.ParseTime(resp.Header.Get("Date")) 378 run.moveTimeLowerBound(t) 379 } 380 }() 381 382 switch resp.StatusCode { 383 case 200: 384 // ok 385 case 304: 386 // not modified 387 return nil, nil, ErrRepairNotModified 388 case 404: 389 return nil, nil, ErrRepairNotFound 390 default: 391 moveTimeLowerBound = false 392 return nil, nil, fmt.Errorf("cannot fetch repair, unexpected status %d", resp.StatusCode) 393 } 394 395 repair, aux, err := checkStream(brandID, repairID, r) 396 if err != nil { 397 return nil, nil, fmt.Errorf("cannot fetch repair, %v", err) 398 } 399 400 if repair.Revision() <= revision { 401 // this shouldn't happen but if it does we behave like 402 // all the rest of assertion infrastructure and ignore 403 // the now superseded revision 404 return nil, nil, ErrRepairNotModified 405 } 406 407 return repair, aux, err 408 } 409 410 func checkStream(brandID string, repairID int, r []asserts.Assertion) (repair *asserts.Repair, aux []asserts.Assertion, err error) { 411 if len(r) == 0 { 412 return nil, nil, fmt.Errorf("empty repair assertions stream") 413 } 414 var ok bool 415 repair, ok = r[0].(*asserts.Repair) 416 if !ok { 417 return nil, nil, fmt.Errorf("unexpected first assertion %q", r[0].Type().Name) 418 } 419 420 if repair.BrandID() != brandID || repair.RepairID() != repairID { 421 return nil, nil, fmt.Errorf("repair id mismatch %s/%d != %s/%d", repair.BrandID(), repair.RepairID(), brandID, repairID) 422 } 423 424 return repair, r[1:], nil 425 } 426 427 type peekResp struct { 428 Headers map[string]interface{} `json:"headers"` 429 } 430 431 // Peek retrieves the headers for the repair with the given ids. 432 func (run *Runner) Peek(brandID string, repairID int) (headers map[string]interface{}, err error) { 433 u, err := run.BaseURL.Parse(fmt.Sprintf("repairs/%s/%d", brandID, repairID)) 434 if err != nil { 435 return nil, err 436 } 437 438 var rsp peekResp 439 440 resp, err := httputil.RetryRequest(u.String(), func() (*http.Response, error) { 441 req, err := http.NewRequest("GET", u.String(), nil) 442 if err != nil { 443 return nil, err 444 } 445 req.Header.Set("User-Agent", snapdenv.UserAgent()) 446 req.Header.Set("Accept", "application/json") 447 return run.cli.Do(req) 448 }, func(resp *http.Response) error { 449 rsp.Headers = nil 450 if resp.StatusCode == 200 { 451 dec := json.NewDecoder(resp.Body) 452 return dec.Decode(&rsp) 453 } 454 return nil 455 }, peekRetryStrategy) 456 457 if err != nil { 458 return nil, err 459 } 460 461 moveTimeLowerBound := true 462 defer func() { 463 if moveTimeLowerBound { 464 t, _ := http.ParseTime(resp.Header.Get("Date")) 465 run.moveTimeLowerBound(t) 466 } 467 }() 468 469 switch resp.StatusCode { 470 case 200: 471 // ok 472 case 404: 473 return nil, ErrRepairNotFound 474 default: 475 moveTimeLowerBound = false 476 return nil, fmt.Errorf("cannot peek repair headers, unexpected status %d", resp.StatusCode) 477 } 478 479 headers = rsp.Headers 480 if headers["brand-id"] != brandID || headers["repair-id"] != strconv.Itoa(repairID) { 481 return nil, fmt.Errorf("cannot peek repair headers, repair id mismatch %s/%s != %s/%d", headers["brand-id"], headers["repair-id"], brandID, repairID) 482 } 483 484 return headers, nil 485 } 486 487 // deviceInfo captures information about the device. 488 type deviceInfo struct { 489 Brand string `json:"brand"` 490 Model string `json:"model"` 491 } 492 493 // RepairStatus represents the possible statuses of a repair. 494 type RepairStatus int 495 496 const ( 497 RetryStatus RepairStatus = iota 498 SkipStatus 499 DoneStatus 500 ) 501 502 func (rs RepairStatus) String() string { 503 switch rs { 504 case RetryStatus: 505 return "retry" 506 case SkipStatus: 507 return "skip" 508 case DoneStatus: 509 return "done" 510 default: 511 return "unknown" 512 } 513 } 514 515 // RepairState holds the current revision and status of a repair in a sequence of repairs. 516 type RepairState struct { 517 Sequence int `json:"sequence"` 518 Revision int `json:"revision"` 519 Status RepairStatus `json:"status"` 520 } 521 522 // state holds the atomically updated control state of the runner with sequences of repairs and their states. 523 type state struct { 524 Device deviceInfo `json:"device"` 525 Sequences map[string][]*RepairState `json:"sequences,omitempty"` 526 TimeLowerBound time.Time `json:"time-lower-bound"` 527 } 528 529 func (run *Runner) setRepairState(brandID string, state RepairState) { 530 if run.state.Sequences == nil { 531 run.state.Sequences = make(map[string][]*RepairState) 532 } 533 sequence := run.state.Sequences[brandID] 534 if state.Sequence > len(sequence) { 535 run.stateModified = true 536 run.state.Sequences[brandID] = append(sequence, &state) 537 } else if *sequence[state.Sequence-1] != state { 538 run.stateModified = true 539 sequence[state.Sequence-1] = &state 540 } 541 } 542 543 func (run *Runner) readState() error { 544 r, err := os.Open(dirs.SnapRepairStateFile) 545 if err != nil { 546 return err 547 } 548 defer r.Close() 549 dec := json.NewDecoder(r) 550 return dec.Decode(&run.state) 551 } 552 553 func (run *Runner) moveTimeLowerBound(t time.Time) { 554 if t.After(run.state.TimeLowerBound) { 555 run.stateModified = true 556 run.state.TimeLowerBound = t.UTC() 557 } 558 } 559 560 var timeNow = time.Now 561 562 func (run *Runner) now() time.Time { 563 now := timeNow().UTC() 564 if now.Before(run.state.TimeLowerBound) { 565 return run.state.TimeLowerBound 566 } 567 return now 568 } 569 570 func (run *Runner) initState() error { 571 if err := os.MkdirAll(dirs.SnapRepairDir, 0775); err != nil { 572 return fmt.Errorf("cannot create repair state directory: %v", err) 573 } 574 // best-effort remove old 575 os.Remove(dirs.SnapRepairStateFile) 576 run.state = state{} 577 // initialize time lower bound with image built time/seed.yaml time 578 info, err := os.Stat(filepath.Join(dirs.SnapSeedDir, "seed.yaml")) 579 if err != nil { 580 return err 581 } 582 run.moveTimeLowerBound(info.ModTime()) 583 // initialize device info 584 if err := run.initDeviceInfo(); err != nil { 585 return err 586 } 587 run.stateModified = true 588 return run.SaveState() 589 } 590 591 func trustedBackstore(trusted []asserts.Assertion) asserts.Backstore { 592 trustedBS := asserts.NewMemoryBackstore() 593 for _, t := range trusted { 594 trustedBS.Put(t.Type(), t) 595 } 596 return trustedBS 597 } 598 599 func checkAuthorityID(a asserts.Assertion, trusted asserts.Backstore) error { 600 assertType := a.Type() 601 if assertType != asserts.AccountKeyType && assertType != asserts.AccountType { 602 return nil 603 } 604 // check that account and account-key assertions are signed by 605 // a trusted authority 606 acctID := a.AuthorityID() 607 _, err := trusted.Get(asserts.AccountType, []string{acctID}, asserts.AccountType.MaxSupportedFormat()) 608 if err != nil && !asserts.IsNotFound(err) { 609 return err 610 } 611 if asserts.IsNotFound(err) { 612 return fmt.Errorf("%v not signed by trusted authority: %s", a.Ref(), acctID) 613 } 614 return nil 615 } 616 617 func verifySignatures(a asserts.Assertion, workBS asserts.Backstore, trusted asserts.Backstore) error { 618 if err := checkAuthorityID(a, trusted); err != nil { 619 return err 620 } 621 acctKeyMaxSuppFormat := asserts.AccountKeyType.MaxSupportedFormat() 622 623 seen := make(map[string]bool) 624 bottom := false 625 for !bottom { 626 u := a.Ref().Unique() 627 if seen[u] { 628 return fmt.Errorf("circular assertions") 629 } 630 seen[u] = true 631 signKey := []string{a.SignKeyID()} 632 key, err := trusted.Get(asserts.AccountKeyType, signKey, acctKeyMaxSuppFormat) 633 if err != nil && !asserts.IsNotFound(err) { 634 return err 635 } 636 if err == nil { 637 bottom = true 638 } else { 639 key, err = workBS.Get(asserts.AccountKeyType, signKey, acctKeyMaxSuppFormat) 640 if err != nil && !asserts.IsNotFound(err) { 641 return err 642 } 643 if asserts.IsNotFound(err) { 644 return fmt.Errorf("cannot find public key %q", signKey[0]) 645 } 646 if err := checkAuthorityID(key, trusted); err != nil { 647 return err 648 } 649 } 650 if err := asserts.CheckSignature(a, key.(*asserts.AccountKey), nil, time.Time{}); err != nil { 651 return err 652 } 653 a = key 654 } 655 return nil 656 } 657 658 func (run *Runner) initDeviceInfo() error { 659 const errPrefix = "cannot set device information: " 660 661 workBS := asserts.NewMemoryBackstore() 662 assertSeedDir := filepath.Join(dirs.SnapSeedDir, "assertions") 663 dc, err := ioutil.ReadDir(assertSeedDir) 664 if err != nil { 665 return err 666 } 667 var model *asserts.Model 668 for _, fi := range dc { 669 fn := filepath.Join(assertSeedDir, fi.Name()) 670 f, err := os.Open(fn) 671 if err != nil { 672 // best effort 673 continue 674 } 675 dec := asserts.NewDecoder(f) 676 for { 677 a, err := dec.Decode() 678 if err != nil { 679 // best effort 680 break 681 } 682 switch a.Type() { 683 case asserts.ModelType: 684 if model != nil { 685 return fmt.Errorf(errPrefix + "multiple models in seed assertions") 686 } 687 model = a.(*asserts.Model) 688 case asserts.AccountType, asserts.AccountKeyType: 689 workBS.Put(a.Type(), a) 690 } 691 } 692 } 693 if model == nil { 694 return fmt.Errorf(errPrefix + "no model assertion in seed data") 695 } 696 trustedBS := trustedBackstore(sysdb.Trusted()) 697 if err := verifySignatures(model, workBS, trustedBS); err != nil { 698 return fmt.Errorf(errPrefix+"%v", err) 699 } 700 acctPK := []string{model.BrandID()} 701 acctMaxSupFormat := asserts.AccountType.MaxSupportedFormat() 702 acct, err := trustedBS.Get(asserts.AccountType, acctPK, acctMaxSupFormat) 703 if err != nil { 704 var err error 705 acct, err = workBS.Get(asserts.AccountType, acctPK, acctMaxSupFormat) 706 if err != nil { 707 return fmt.Errorf(errPrefix + "no brand account assertion in seed data") 708 } 709 } 710 if err := verifySignatures(acct, workBS, trustedBS); err != nil { 711 return fmt.Errorf(errPrefix+"%v", err) 712 } 713 run.state.Device.Brand = model.BrandID() 714 run.state.Device.Model = model.Model() 715 return nil 716 } 717 718 // LoadState loads the repairs' state from disk, and (re)initializes it if it's missing or corrupted. 719 func (run *Runner) LoadState() error { 720 err := run.readState() 721 if err == nil { 722 return nil 723 } 724 // error => initialize from scratch 725 if !os.IsNotExist(err) { 726 logger.Noticef("cannor read repair state: %v", err) 727 } 728 return run.initState() 729 } 730 731 // SaveState saves the repairs' state to disk. 732 func (run *Runner) SaveState() error { 733 if !run.stateModified { 734 return nil 735 } 736 m, err := json.Marshal(&run.state) 737 if err != nil { 738 return fmt.Errorf("cannot marshal repair state: %v", err) 739 } 740 err = osutil.AtomicWriteFile(dirs.SnapRepairStateFile, m, 0600, 0) 741 if err != nil { 742 return fmt.Errorf("cannot save repair state: %v", err) 743 } 744 run.stateModified = false 745 return nil 746 } 747 748 func stringList(headers map[string]interface{}, name string) ([]string, error) { 749 v, ok := headers[name] 750 if !ok { 751 return nil, nil 752 } 753 l, ok := v.([]interface{}) 754 if !ok { 755 return nil, fmt.Errorf("header %q is not a list", name) 756 } 757 r := make([]string, len(l)) 758 for i, v := range l { 759 s, ok := v.(string) 760 if !ok { 761 return nil, fmt.Errorf("header %q contains non-string elements", name) 762 } 763 r[i] = s 764 } 765 return r, nil 766 } 767 768 // Applicable returns whether a repair with the given headers is applicable to the device. 769 func (run *Runner) Applicable(headers map[string]interface{}) bool { 770 if headers["disabled"] == "true" { 771 return false 772 } 773 series, err := stringList(headers, "series") 774 if err != nil { 775 return false 776 } 777 if len(series) != 0 && !strutil.ListContains(series, release.Series) { 778 return false 779 } 780 archs, err := stringList(headers, "architectures") 781 if err != nil { 782 return false 783 } 784 if len(archs) != 0 && !strutil.ListContains(archs, arch.DpkgArchitecture()) { 785 return false 786 } 787 brandModel := fmt.Sprintf("%s/%s", run.state.Device.Brand, run.state.Device.Model) 788 models, err := stringList(headers, "models") 789 if err != nil { 790 return false 791 } 792 if len(models) != 0 && !strutil.ListContains(models, brandModel) { 793 // model prefix matching: brand/prefix* 794 hit := false 795 for _, patt := range models { 796 if strings.HasSuffix(patt, "*") && strings.ContainsRune(patt, '/') { 797 if strings.HasPrefix(brandModel, strings.TrimSuffix(patt, "*")) { 798 hit = true 799 break 800 } 801 } 802 } 803 if !hit { 804 return false 805 } 806 } 807 return true 808 } 809 810 var errSkip = errors.New("repair unnecessary on this system") 811 812 func (run *Runner) fetch(brandID string, repairID int) (repair *asserts.Repair, aux []asserts.Assertion, err error) { 813 headers, err := run.Peek(brandID, repairID) 814 if err != nil { 815 return nil, nil, err 816 } 817 if !run.Applicable(headers) { 818 return nil, nil, errSkip 819 } 820 return run.Fetch(brandID, repairID, -1) 821 } 822 823 func (run *Runner) refetch(brandID string, repairID, revision int) (repair *asserts.Repair, aux []asserts.Assertion, err error) { 824 return run.Fetch(brandID, repairID, revision) 825 } 826 827 func (run *Runner) saveStream(brandID string, repairID int, repair *asserts.Repair, aux []asserts.Assertion) error { 828 d := filepath.Join(dirs.SnapRepairAssertsDir, brandID, strconv.Itoa(repairID)) 829 err := os.MkdirAll(d, 0775) 830 if err != nil { 831 return err 832 } 833 buf := &bytes.Buffer{} 834 enc := asserts.NewEncoder(buf) 835 r := append([]asserts.Assertion{repair}, aux...) 836 for _, a := range r { 837 if err := enc.Encode(a); err != nil { 838 return fmt.Errorf("cannot encode repair assertions %s-%d for saving: %v", brandID, repairID, err) 839 } 840 } 841 p := filepath.Join(d, fmt.Sprintf("r%d.repair", r[0].Revision())) 842 return osutil.AtomicWriteFile(p, buf.Bytes(), 0600, 0) 843 } 844 845 func (run *Runner) readSavedStream(brandID string, repairID, revision int) (repair *asserts.Repair, aux []asserts.Assertion, err error) { 846 d := filepath.Join(dirs.SnapRepairAssertsDir, brandID, strconv.Itoa(repairID)) 847 p := filepath.Join(d, fmt.Sprintf("r%d.repair", revision)) 848 f, err := os.Open(p) 849 if err != nil { 850 return nil, nil, err 851 } 852 defer f.Close() 853 854 dec := asserts.NewDecoder(f) 855 var r []asserts.Assertion 856 for { 857 a, err := dec.Decode() 858 if err == io.EOF { 859 break 860 } 861 if err != nil { 862 return nil, nil, fmt.Errorf("cannot decode repair assertions %s-%d from disk: %v", brandID, repairID, err) 863 } 864 r = append(r, a) 865 } 866 return checkStream(brandID, repairID, r) 867 } 868 869 func (run *Runner) makeReady(brandID string, sequenceNext int) (repair *asserts.Repair, err error) { 870 sequence := run.state.Sequences[brandID] 871 var aux []asserts.Assertion 872 var state RepairState 873 if sequenceNext <= len(sequence) { 874 // consider retries 875 state = *sequence[sequenceNext-1] 876 if state.Status != RetryStatus { 877 return nil, errSkip 878 } 879 var err error 880 repair, aux, err = run.refetch(brandID, state.Sequence, state.Revision) 881 if err != nil { 882 if err != ErrRepairNotModified { 883 logger.Noticef("cannot refetch repair %s-%d, will retry what is on disk: %v", brandID, sequenceNext, err) 884 } 885 // try to use what we have already on disk 886 repair, aux, err = run.readSavedStream(brandID, state.Sequence, state.Revision) 887 if err != nil { 888 return nil, err 889 } 890 } 891 } else { 892 // fetch the next repair in the sequence 893 // assumes no gaps, each repair id is present so far, 894 // possibly skipped 895 var err error 896 repair, aux, err = run.fetch(brandID, sequenceNext) 897 if err != nil && err != errSkip { 898 return nil, err 899 } 900 state = RepairState{ 901 Sequence: sequenceNext, 902 } 903 if err == errSkip { 904 // TODO: store headers to justify decision 905 state.Status = SkipStatus 906 run.setRepairState(brandID, state) 907 return nil, errSkip 908 } 909 } 910 // verify with signatures 911 if err := run.Verify(repair, aux); err != nil { 912 return nil, fmt.Errorf("cannot verify repair %s-%d: %v", brandID, state.Sequence, err) 913 } 914 if err := run.saveStream(brandID, state.Sequence, repair, aux); err != nil { 915 return nil, err 916 } 917 state.Revision = repair.Revision() 918 if !run.Applicable(repair.Headers()) { 919 state.Status = SkipStatus 920 run.setRepairState(brandID, state) 921 return nil, errSkip 922 } 923 run.setRepairState(brandID, state) 924 return repair, nil 925 } 926 927 // Next returns the next repair for the brand id sequence to run/retry or ErrRepairNotFound if there is none atm. It updates the state as required. 928 func (run *Runner) Next(brandID string) (*Repair, error) { 929 sequenceNext := run.sequenceNext[brandID] 930 if sequenceNext == 0 { 931 sequenceNext = 1 932 } 933 for { 934 repair, err := run.makeReady(brandID, sequenceNext) 935 // SaveState is a no-op unless makeReady modified the state 936 stateErr := run.SaveState() 937 if err != nil && err != errSkip && err != ErrRepairNotFound { 938 // err is a non trivial error, just log the SaveState error and report err 939 if stateErr != nil { 940 logger.Noticef("%v", stateErr) 941 } 942 return nil, err 943 } 944 if stateErr != nil { 945 return nil, stateErr 946 } 947 if err == ErrRepairNotFound { 948 return nil, ErrRepairNotFound 949 } 950 951 sequenceNext += 1 952 run.sequenceNext[brandID] = sequenceNext 953 if err == errSkip { 954 continue 955 } 956 957 return &Repair{ 958 Repair: repair, 959 run: run, 960 sequence: sequenceNext - 1, 961 }, nil 962 } 963 } 964 965 // Limit trust to specific keys while there's no delegation or limited 966 // keys support. The obtained assertion stream may also include 967 // account keys that are directly or indirectly signed by a trusted 968 // key. 969 var ( 970 trustedRepairRootKeys []*asserts.AccountKey 971 ) 972 973 // Verify verifies that the repair is properly signed by the specific 974 // trusted root keys or by account keys in the stream (passed via aux) 975 // directly or indirectly signed by a trusted key. 976 func (run *Runner) Verify(repair *asserts.Repair, aux []asserts.Assertion) error { 977 workBS := asserts.NewMemoryBackstore() 978 for _, a := range aux { 979 if a.Type() != asserts.AccountKeyType { 980 continue 981 } 982 err := workBS.Put(asserts.AccountKeyType, a) 983 if err != nil { 984 return err 985 } 986 } 987 trustedBS := asserts.NewMemoryBackstore() 988 for _, t := range trustedRepairRootKeys { 989 trustedBS.Put(asserts.AccountKeyType, t) 990 } 991 for _, t := range sysdb.Trusted() { 992 // we do *not* add the defalt sysdb trusted account 993 // keys here because the repair assertions have their 994 // own *dedicated* root of trust 995 if t.Type() == asserts.AccountType { 996 trustedBS.Put(asserts.AccountType, t) 997 } 998 } 999 1000 return verifySignatures(repair, workBS, trustedBS) 1001 }