github.com/rigado/snapd@v2.42.5-go-mod+incompatible/overlord/devicestate/handlers.go (about) 1 // -*- Mode: Go; indent-tabs-mode: t -*- 2 /* 3 * Copyright (C) 2016-2017 Canonical Ltd 4 * 5 * This program is free software: you can redistribute it and/or modify 6 * it under the terms of the GNU General Public License version 3 as 7 * published by the Free Software Foundation. 8 * 9 * This program is distributed in the hope that it will be useful, 10 * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 * GNU General Public License for more details. 13 * 14 * You should have received a copy of the GNU General Public License 15 * along with this program. If not, see <http://www.gnu.org/licenses/>. 16 * 17 */ 18 19 package devicestate 20 21 import ( 22 "bytes" 23 "crypto/rsa" 24 "encoding/json" 25 "errors" 26 "fmt" 27 "io" 28 "net" 29 "net/http" 30 "net/url" 31 "os" 32 "path/filepath" 33 "strconv" 34 "strings" 35 "time" 36 37 "gopkg.in/tomb.v2" 38 39 "github.com/snapcore/snapd/asserts" 40 "github.com/snapcore/snapd/dirs" 41 "github.com/snapcore/snapd/gadget" 42 "github.com/snapcore/snapd/httputil" 43 "github.com/snapcore/snapd/logger" 44 "github.com/snapcore/snapd/osutil" 45 "github.com/snapcore/snapd/overlord/assertstate" 46 "github.com/snapcore/snapd/overlord/auth" 47 "github.com/snapcore/snapd/overlord/configstate/config" 48 "github.com/snapcore/snapd/overlord/configstate/proxyconf" 49 "github.com/snapcore/snapd/overlord/snapstate" 50 "github.com/snapcore/snapd/overlord/state" 51 "github.com/snapcore/snapd/release" 52 "github.com/snapcore/snapd/snap" 53 "github.com/snapcore/snapd/snap/naming" 54 "github.com/snapcore/snapd/timings" 55 ) 56 57 func (m *DeviceManager) doMarkSeeded(t *state.Task, _ *tomb.Tomb) error { 58 st := t.State() 59 st.Lock() 60 defer st.Unlock() 61 62 st.Set("seed-time", time.Now()) 63 st.Set("seeded", true) 64 // make sure we setup a fallback model/consider the next phase 65 // (registration) timely 66 st.EnsureBefore(0) 67 return nil 68 } 69 70 func isSameAssertsRevision(err error) bool { 71 if e, ok := err.(*asserts.RevisionError); ok { 72 if e.Used == e.Current { 73 return true 74 } 75 } 76 return false 77 } 78 79 func (m *DeviceManager) doSetModel(t *state.Task, _ *tomb.Tomb) error { 80 st := t.State() 81 st.Lock() 82 defer st.Unlock() 83 84 remodCtx, err := remodelCtxFromTask(t) 85 if err != nil { 86 return err 87 } 88 new := remodCtx.Model() 89 90 err = assertstate.Add(st, new) 91 if err != nil && !isSameAssertsRevision(err) { 92 return err 93 } 94 95 // unmark no-longer required snaps 96 requiredSnaps := getAllRequiredSnapsForModel(new) 97 // TODO:XXX: have AllByRef 98 snapStates, err := snapstate.All(st) 99 if err != nil { 100 return err 101 } 102 for snapName, snapst := range snapStates { 103 // TODO: remove this type restriction once we remodel 104 // gadgets and add tests that ensure 105 // that the required flag is properly set/unset 106 typ, err := snapst.Type() 107 if err != nil { 108 return err 109 } 110 if typ != snap.TypeApp && typ != snap.TypeBase && typ != snap.TypeKernel { 111 continue 112 } 113 // clean required flag if no-longer needed 114 if snapst.Flags.Required && !requiredSnaps.Contains(naming.Snap(snapName)) { 115 snapst.Flags.Required = false 116 snapstate.Set(st, snapName, snapst) 117 } 118 // TODO: clean "required" flag of "core" if a remodel 119 // moves from the "core" snap to a different 120 // bootable base snap. 121 } 122 123 return remodCtx.Finish() 124 } 125 126 func (m *DeviceManager) cleanupRemodel(t *state.Task, _ *tomb.Tomb) error { 127 st := t.State() 128 st.Lock() 129 defer st.Unlock() 130 // cleanup the cached remodel context 131 cleanupRemodelCtx(t.Change()) 132 return nil 133 } 134 135 func useStaging() bool { 136 return osutil.GetenvBool("SNAPPY_USE_STAGING_STORE") 137 } 138 139 func baseURL() *url.URL { 140 if useStaging() { 141 return mustParse("https://api.staging.snapcraft.io/") 142 } 143 return mustParse("https://api.snapcraft.io/") 144 } 145 146 func mustParse(s string) *url.URL { 147 u, err := url.Parse(s) 148 if err != nil { 149 panic(err) 150 } 151 return u 152 } 153 154 var ( 155 keyLength = 4096 156 retryInterval = 60 * time.Second 157 maxTentatives = 15 158 baseStoreURL = baseURL().ResolveReference(authRef) 159 160 authRef = mustParse("api/v1/snaps/auth/") // authRef must end in / for the following refs to work 161 reqIdRef = mustParse("request-id") 162 serialRef = mustParse("serial") 163 devicesRef = mustParse("devices") 164 165 // we accept a stream with the serial assertion as well 166 registrationCapabilities = []string{"serial-stream"} 167 ) 168 169 func newEnoughProxy(st *state.State, proxyURL *url.URL, client *http.Client) bool { 170 st.Unlock() 171 defer st.Lock() 172 173 const prefix = "Cannot check whether proxy store supports a custom serial vault" 174 175 req, err := http.NewRequest("HEAD", proxyURL.String(), nil) 176 if err != nil { 177 // can't really happen unless proxyURL is somehow broken 178 logger.Debugf(prefix+": %v", err) 179 return false 180 } 181 req.Header.Set("User-Agent", httputil.UserAgent()) 182 resp, err := client.Do(req) 183 if err != nil { 184 // some sort of network or protocol error 185 logger.Debugf(prefix+": %v", err) 186 return false 187 } 188 resp.Body.Close() 189 if resp.StatusCode != 200 { 190 logger.Debugf(prefix+": Head request returned %s.", resp.Status) 191 return false 192 } 193 verstr := resp.Header.Get("Snap-Store-Version") 194 ver, err := strconv.Atoi(verstr) 195 if err != nil { 196 logger.Debugf(prefix+": Bogus Snap-Store-Version header %q.", verstr) 197 return false 198 } 199 return ver >= 6 200 } 201 202 func (cfg *serialRequestConfig) setURLs(proxyURL, svcURL *url.URL) { 203 base := baseStoreURL 204 if proxyURL != nil { 205 if svcURL != nil { 206 if cfg.headers == nil { 207 cfg.headers = make(map[string]string, 1) 208 } 209 cfg.headers["X-Snap-Device-Service-URL"] = svcURL.String() 210 } 211 base = proxyURL.ResolveReference(authRef) 212 } else if svcURL != nil { 213 base = svcURL 214 } 215 216 cfg.requestIDURL = base.ResolveReference(reqIdRef).String() 217 if svcURL != nil && proxyURL == nil { 218 // talking directly to the custom device service 219 cfg.serialRequestURL = base.ResolveReference(serialRef).String() 220 } else { 221 cfg.serialRequestURL = base.ResolveReference(devicesRef).String() 222 } 223 } 224 225 func (m *DeviceManager) doGenerateDeviceKey(t *state.Task, _ *tomb.Tomb) error { 226 st := t.State() 227 st.Lock() 228 defer st.Unlock() 229 230 perfTimings := timings.NewForTask(t) 231 defer perfTimings.Save(st) 232 233 device, err := m.device() 234 if err != nil { 235 return err 236 } 237 238 if device.KeyID != "" { 239 // nothing to do 240 return nil 241 } 242 243 st.Unlock() 244 var keyPair *rsa.PrivateKey 245 timings.Run(perfTimings, "generate-rsa-key", "generating device key pair", func(tm timings.Measurer) { 246 keyPair, err = generateRSAKey(keyLength) 247 }) 248 st.Lock() 249 if err != nil { 250 return fmt.Errorf("cannot generate device key pair: %v", err) 251 } 252 253 privKey := asserts.RSAPrivateKey(keyPair) 254 err = m.keypairMgr.Put(privKey) 255 if err != nil { 256 return fmt.Errorf("cannot store device key pair: %v", err) 257 } 258 259 device.KeyID = privKey.PublicKey().ID() 260 err = m.setDevice(device) 261 if err != nil { 262 return err 263 } 264 t.SetStatus(state.DoneStatus) 265 return nil 266 } 267 268 // A registrationContext handles the contextual information needed 269 // for the initial registration or a re-registration. 270 type registrationContext interface { 271 Device() (*auth.DeviceState, error) 272 273 GadgetForSerialRequestConfig() string 274 SerialRequestExtraHeaders() map[string]interface{} 275 SerialRequestAncillaryAssertions() []asserts.Assertion 276 277 FinishRegistration(serial *asserts.Serial) error 278 279 ForRemodeling() bool 280 } 281 282 // initialRegistrationContext is a thin wrapper around DeviceManager 283 // implementing registrationContext for initial regitration 284 type initialRegistrationContext struct { 285 deviceMgr *DeviceManager 286 287 gadget string 288 } 289 290 func (rc *initialRegistrationContext) ForRemodeling() bool { 291 return false 292 } 293 294 func (rc *initialRegistrationContext) Device() (*auth.DeviceState, error) { 295 return rc.deviceMgr.device() 296 } 297 298 func (rc *initialRegistrationContext) GadgetForSerialRequestConfig() string { 299 return rc.gadget 300 } 301 302 func (rc *initialRegistrationContext) SerialRequestExtraHeaders() map[string]interface{} { 303 return nil 304 } 305 306 func (rc *initialRegistrationContext) SerialRequestAncillaryAssertions() []asserts.Assertion { 307 return nil 308 } 309 310 func (rc *initialRegistrationContext) FinishRegistration(serial *asserts.Serial) error { 311 device, err := rc.deviceMgr.device() 312 if err != nil { 313 return err 314 } 315 316 device.Serial = serial.Serial() 317 if err := rc.deviceMgr.setDevice(device); err != nil { 318 return err 319 } 320 rc.deviceMgr.markRegistered() 321 322 // make sure we timely consider anything that was blocked on 323 // registration 324 rc.deviceMgr.state.EnsureBefore(0) 325 326 return nil 327 } 328 329 // registrationCtx returns a registrationContext appropriate for the task and its change. 330 func (m *DeviceManager) registrationCtx(t *state.Task) (registrationContext, error) { 331 remodCtx, err := remodelCtxFromTask(t) 332 if err != nil && err != state.ErrNoState { 333 return nil, err 334 } 335 if regCtx, ok := remodCtx.(registrationContext); ok { 336 return regCtx, nil 337 } 338 model, err := m.Model() 339 if err != nil { 340 return nil, err 341 } 342 343 return &initialRegistrationContext{ 344 deviceMgr: m, 345 gadget: model.Gadget(), 346 }, nil 347 } 348 349 type serialSetup struct { 350 SerialRequest string `json:"serial-request"` 351 Serial string `json:"serial"` 352 } 353 354 type requestIDResp struct { 355 RequestID string `json:"request-id"` 356 } 357 358 func retryErr(t *state.Task, nTentatives int, reason string, a ...interface{}) error { 359 t.State().Lock() 360 defer t.State().Unlock() 361 if nTentatives >= maxTentatives { 362 return fmt.Errorf(reason, a...) 363 } 364 t.Errorf(reason, a...) 365 return &state.Retry{After: retryInterval} 366 } 367 368 type serverError struct { 369 Message string `json:"message"` 370 Errors []*serverError `json:"error_list"` 371 } 372 373 func retryBadStatus(t *state.Task, nTentatives int, reason string, resp *http.Response) error { 374 if resp.StatusCode > 500 { 375 // likely temporary 376 return retryErr(t, nTentatives, "%s: unexpected status %d", reason, resp.StatusCode) 377 } 378 if resp.Header.Get("Content-Type") == "application/json" { 379 var srvErr serverError 380 dec := json.NewDecoder(resp.Body) 381 err := dec.Decode(&srvErr) 382 if err == nil { 383 msg := srvErr.Message 384 if msg == "" && len(srvErr.Errors) > 0 { 385 msg = srvErr.Errors[0].Message 386 } 387 if msg != "" { 388 return fmt.Errorf("%s: %s", reason, msg) 389 } 390 } 391 } 392 return fmt.Errorf("%s: unexpected status %d", reason, resp.StatusCode) 393 } 394 395 func prepareSerialRequest(t *state.Task, regCtx registrationContext, privKey asserts.PrivateKey, device *auth.DeviceState, client *http.Client, cfg *serialRequestConfig) (string, error) { 396 // limit tentatives starting from scratch before going to 397 // slower full retries 398 var nTentatives int 399 err := t.Get("pre-poll-tentatives", &nTentatives) 400 if err != nil && err != state.ErrNoState { 401 return "", err 402 } 403 nTentatives++ 404 t.Set("pre-poll-tentatives", nTentatives) 405 406 st := t.State() 407 st.Unlock() 408 defer st.Lock() 409 410 req, err := http.NewRequest("POST", cfg.requestIDURL, nil) 411 if err != nil { 412 return "", fmt.Errorf("internal error: cannot create request-id request %q", cfg.requestIDURL) 413 } 414 req.Header.Set("User-Agent", httputil.UserAgent()) 415 cfg.applyHeaders(req) 416 417 resp, err := client.Do(req) 418 if err != nil { 419 if netErr, ok := err.(net.Error); ok && !netErr.Temporary() { 420 // a non temporary net error, like a DNS no 421 // host, error out and do full retries 422 return "", fmt.Errorf("cannot retrieve request-id for making a request for a serial: %v", err) 423 } 424 return "", retryErr(t, nTentatives, "cannot retrieve request-id for making a request for a serial: %v", err) 425 } 426 defer resp.Body.Close() 427 if resp.StatusCode != 200 { 428 return "", retryBadStatus(t, nTentatives, "cannot retrieve request-id for making a request for a serial", resp) 429 } 430 431 dec := json.NewDecoder(resp.Body) 432 var requestID requestIDResp 433 err = dec.Decode(&requestID) 434 if err != nil { // assume broken i/o 435 return "", retryErr(t, nTentatives, "cannot read response with request-id for making a request for a serial: %v", err) 436 } 437 438 encodedPubKey, err := asserts.EncodePublicKey(privKey.PublicKey()) 439 if err != nil { 440 return "", fmt.Errorf("internal error: cannot encode device public key: %v", err) 441 442 } 443 444 headers := map[string]interface{}{ 445 "brand-id": device.Brand, 446 "model": device.Model, 447 "request-id": requestID.RequestID, 448 "device-key": string(encodedPubKey), 449 } 450 if cfg.proposedSerial != "" { 451 headers["serial"] = cfg.proposedSerial 452 } 453 454 for k, v := range regCtx.SerialRequestExtraHeaders() { 455 headers[k] = v 456 } 457 458 serialReq, err := asserts.SignWithoutAuthority(asserts.SerialRequestType, headers, cfg.body, privKey) 459 if err != nil { 460 return "", err 461 } 462 463 buf := new(bytes.Buffer) 464 encoder := asserts.NewEncoder(buf) 465 if err := encoder.Encode(serialReq); err != nil { 466 return "", fmt.Errorf("cannot encode serial-request: %v", err) 467 } 468 469 for _, ancillaryAs := range regCtx.SerialRequestAncillaryAssertions() { 470 if err := encoder.Encode(ancillaryAs); err != nil { 471 return "", fmt.Errorf("cannot encode ancillary assertion: %v", err) 472 } 473 474 } 475 476 return buf.String(), nil 477 } 478 479 var errPoll = errors.New("serial-request accepted, poll later") 480 481 func submitSerialRequest(t *state.Task, serialRequest string, client *http.Client, cfg *serialRequestConfig) (*asserts.Serial, *asserts.Batch, error) { 482 st := t.State() 483 st.Unlock() 484 defer st.Lock() 485 486 req, err := http.NewRequest("POST", cfg.serialRequestURL, bytes.NewBufferString(serialRequest)) 487 if err != nil { 488 return nil, nil, fmt.Errorf("internal error: cannot create serial-request request %q", cfg.serialRequestURL) 489 } 490 req.Header.Set("User-Agent", httputil.UserAgent()) 491 req.Header.Set("Snap-Device-Capabilities", strings.Join(registrationCapabilities, " ")) 492 cfg.applyHeaders(req) 493 req.Header.Set("Content-Type", asserts.MediaType) 494 495 resp, err := client.Do(req) 496 if err != nil { 497 return nil, nil, retryErr(t, 0, "cannot deliver device serial request: %v", err) 498 } 499 defer resp.Body.Close() 500 501 switch resp.StatusCode { 502 case 200, 201: 503 case 202: 504 return nil, nil, errPoll 505 default: 506 return nil, nil, retryBadStatus(t, 0, "cannot deliver device serial request", resp) 507 } 508 509 var serial *asserts.Serial 510 var batch *asserts.Batch 511 // decode body with stream of assertions, of which one is the serial 512 dec := asserts.NewDecoder(resp.Body) 513 for { 514 got, err := dec.Decode() 515 if err == io.EOF { 516 break 517 } 518 if err != nil { // assume broken i/o 519 return nil, nil, retryErr(t, 0, "cannot read response to request for a serial: %v", err) 520 } 521 if got.Type() == asserts.SerialType { 522 if serial != nil { 523 return nil, nil, fmt.Errorf("cannot accept more than a single device serial assertion from the device service") 524 } 525 serial = got.(*asserts.Serial) 526 } else { 527 if batch == nil { 528 batch = asserts.NewBatch(nil) 529 } 530 if err := batch.Add(got); err != nil { 531 return nil, nil, err 532 } 533 } 534 // TODO: consider a size limit? 535 } 536 537 if serial == nil { 538 return nil, nil, fmt.Errorf("cannot proceed, received assertion stream from the device service missing device serial assertion") 539 } 540 541 return serial, batch, nil 542 } 543 544 func getSerial(t *state.Task, regCtx registrationContext, privKey asserts.PrivateKey, device *auth.DeviceState, tm timings.Measurer) (serial *asserts.Serial, ancillaryBatch *asserts.Batch, err error) { 545 var serialSup serialSetup 546 err = t.Get("serial-setup", &serialSup) 547 if err != nil && err != state.ErrNoState { 548 return nil, nil, err 549 } 550 551 if serialSup.Serial != "" { 552 // we got a serial, just haven't managed to save its info yet 553 a, err := asserts.Decode([]byte(serialSup.Serial)) 554 if err != nil { 555 return nil, nil, fmt.Errorf("internal error: cannot decode previously saved serial: %v", err) 556 } 557 return a.(*asserts.Serial), nil, nil 558 } 559 560 st := t.State() 561 proxyConf := proxyconf.New(st) 562 client := httputil.NewHTTPClient(&httputil.ClientOptions{ 563 Timeout: 30 * time.Second, 564 MayLogBody: true, 565 Proxy: proxyConf.Conf, 566 }) 567 568 cfg, err := getSerialRequestConfig(t, regCtx, client) 569 if err != nil { 570 return nil, nil, err 571 } 572 573 // NB: until we get at least an Accepted (202) we need to 574 // retry from scratch creating a new request-id because the 575 // previous one used could have expired 576 577 if serialSup.SerialRequest == "" { 578 var serialRequest string 579 var err error 580 timings.Run(tm, "prepare-serial-request", "prepare device serial request", func(timings.Measurer) { 581 serialRequest, err = prepareSerialRequest(t, regCtx, privKey, device, client, cfg) 582 }) 583 if err != nil { // errors & retries 584 return nil, nil, err 585 } 586 587 serialSup.SerialRequest = serialRequest 588 } 589 590 timings.Run(tm, "submit-serial-request", "submit device serial request", func(timings.Measurer) { 591 serial, ancillaryBatch, err = submitSerialRequest(t, serialSup.SerialRequest, client, cfg) 592 }) 593 if err == errPoll { 594 // we can/should reuse the serial-request 595 t.Set("serial-setup", serialSup) 596 return nil, nil, errPoll 597 } 598 if err != nil { // errors & retries 599 return nil, nil, err 600 } 601 602 keyID := privKey.PublicKey().ID() 603 if serial.BrandID() != device.Brand || serial.Model() != device.Model || serial.DeviceKey().ID() != keyID { 604 return nil, nil, fmt.Errorf("obtained serial assertion does not match provided device identity information (brand, model, key id): %s / %s / %s != %s / %s / %s", serial.BrandID(), serial.Model(), serial.DeviceKey().ID(), device.Brand, device.Model, keyID) 605 } 606 607 if ancillaryBatch == nil { 608 serialSup.Serial = string(asserts.Encode(serial)) 609 t.Set("serial-setup", serialSup) 610 } 611 612 if repeatRequestSerial == "after-got-serial" { 613 // For testing purposes, ensure a crash in this state works. 614 return nil, nil, &state.Retry{} 615 } 616 617 return serial, ancillaryBatch, nil 618 } 619 620 type serialRequestConfig struct { 621 requestIDURL string 622 serialRequestURL string 623 headers map[string]string 624 proposedSerial string 625 body []byte 626 } 627 628 func (cfg *serialRequestConfig) applyHeaders(req *http.Request) { 629 for k, v := range cfg.headers { 630 req.Header.Set(k, v) 631 } 632 } 633 634 func getSerialRequestConfig(t *state.Task, regCtx registrationContext, client *http.Client) (*serialRequestConfig, error) { 635 var svcURL, proxyURL *url.URL 636 637 st := t.State() 638 tr := config.NewTransaction(st) 639 if proxyStore, err := proxyStore(st, tr); err != nil && err != state.ErrNoState { 640 return nil, err 641 } else if proxyStore != nil { 642 proxyURL = proxyStore.URL() 643 } 644 645 cfg := serialRequestConfig{} 646 647 gadgetName := regCtx.GadgetForSerialRequestConfig() 648 // gadget is optional on classic 649 if gadgetName != "" { 650 var gadgetSt snapstate.SnapState 651 if err := snapstate.Get(st, gadgetName, &gadgetSt); err != nil { 652 return nil, fmt.Errorf("cannot find gadget snap %q: %v", gadgetName, err) 653 } 654 655 var svcURI string 656 err := tr.GetMaybe(gadgetName, "device-service.url", &svcURI) 657 if err != nil { 658 return nil, err 659 } 660 661 if svcURI != "" { 662 svcURL, err = url.Parse(svcURI) 663 if err != nil { 664 return nil, fmt.Errorf("cannot parse device registration base URL %q: %v", svcURI, err) 665 } 666 if !strings.HasSuffix(svcURL.Path, "/") { 667 svcURL.Path += "/" 668 } 669 } 670 671 err = tr.GetMaybe(gadgetName, "device-service.headers", &cfg.headers) 672 if err != nil { 673 return nil, err 674 } 675 676 var bodyStr string 677 err = tr.GetMaybe(gadgetName, "registration.body", &bodyStr) 678 if err != nil { 679 return nil, err 680 } 681 682 cfg.body = []byte(bodyStr) 683 684 err = tr.GetMaybe(gadgetName, "registration.proposed-serial", &cfg.proposedSerial) 685 if err != nil { 686 return nil, err 687 } 688 } 689 690 if proxyURL != nil && svcURL != nil && !newEnoughProxy(st, proxyURL, client) { 691 logger.Noticef("Proxy store does not support custom serial vault; ignoring the proxy") 692 proxyURL = nil 693 } 694 695 cfg.setURLs(proxyURL, svcURL) 696 697 return &cfg, nil 698 } 699 700 func (m *DeviceManager) doRequestSerial(t *state.Task, _ *tomb.Tomb) error { 701 st := t.State() 702 st.Lock() 703 defer st.Unlock() 704 705 perfTimings := timings.NewForTask(t) 706 defer perfTimings.Save(st) 707 708 regCtx, err := m.registrationCtx(t) 709 if err != nil { 710 return err 711 } 712 713 device, err := regCtx.Device() 714 if err != nil { 715 return err 716 } 717 718 // NB: the keyPair is fixed for now 719 privKey, err := m.keyPair() 720 if err == state.ErrNoState { 721 return fmt.Errorf("internal error: cannot find device key pair") 722 } 723 if err != nil { 724 return err 725 } 726 727 // make this idempotent, look if we have already a serial assertion 728 // for privKey 729 serials, err := assertstate.DB(st).FindMany(asserts.SerialType, map[string]string{ 730 "brand-id": device.Brand, 731 "model": device.Model, 732 "device-key-sha3-384": privKey.PublicKey().ID(), 733 }) 734 if err != nil && !asserts.IsNotFound(err) { 735 return err 736 } 737 738 finish := func(serial *asserts.Serial) error { 739 if regCtx.FinishRegistration(serial); err != nil { 740 return err 741 } 742 t.SetStatus(state.DoneStatus) 743 return nil 744 } 745 746 if len(serials) == 1 { 747 // means we saved the assertion but didn't get to the end of the task 748 return finish(serials[0].(*asserts.Serial)) 749 } 750 if len(serials) > 1 { 751 return fmt.Errorf("internal error: multiple serial assertions for the same device key") 752 } 753 754 var serial *asserts.Serial 755 var ancillaryBatch *asserts.Batch 756 timings.Run(perfTimings, "get-serial", "get device serial", func(tm timings.Measurer) { 757 serial, ancillaryBatch, err = getSerial(t, regCtx, privKey, device, tm) 758 }) 759 if err == errPoll { 760 t.Logf("Will poll for device serial assertion in 60 seconds") 761 return &state.Retry{After: retryInterval} 762 } 763 if err != nil { // errors & retries 764 return err 765 766 } 767 768 if ancillaryBatch == nil { 769 // the device service returned only the serial 770 if err := acceptSerialOnly(t, serial, perfTimings); err != nil { 771 return err 772 } 773 } else { 774 // the device service returned a stream of assertions 775 timings.Run(perfTimings, "fetch-keys", "fetch signing key chain", func(timings.Measurer) { 776 err = acceptSerialPlusBatch(t, serial, ancillaryBatch) 777 }) 778 if err != nil { 779 t.Errorf("cannot accept stream of assertions from device service: %v", err) 780 return err 781 } 782 } 783 784 if repeatRequestSerial == "after-add-serial" { 785 // For testing purposes, ensure a crash in this state works. 786 return &state.Retry{} 787 } 788 789 return finish(serial) 790 } 791 792 func acceptSerialOnly(t *state.Task, serial *asserts.Serial, perfTimings *timings.Timings) error { 793 st := t.State() 794 var err error 795 var errAcctKey error 796 // try to fetch the signing key chain of the serial 797 timings.Run(perfTimings, "fetch-keys", "fetch signing key chain", func(timings.Measurer) { 798 errAcctKey, err = fetchKeys(st, serial.SignKeyID()) 799 }) 800 if err != nil { 801 return err 802 } 803 804 // add the serial assertion to the system assertion db 805 err = assertstate.Add(st, serial) 806 if err != nil { 807 // if we had failed to fetch the signing key, retry in a bit 808 if errAcctKey != nil { 809 t.Errorf("cannot fetch signing key for the serial: %v", errAcctKey) 810 return &state.Retry{After: retryInterval} 811 } 812 return err 813 } 814 815 return nil 816 } 817 818 func acceptSerialPlusBatch(t *state.Task, serial *asserts.Serial, batch *asserts.Batch) error { 819 st := t.State() 820 err := batch.Add(serial) 821 if err != nil { 822 return err 823 } 824 return assertstate.AddBatch(st, batch, &asserts.CommitOptions{Precheck: true}) 825 } 826 827 var repeatRequestSerial string // for tests 828 829 func fetchKeys(st *state.State, keyID string) (errAcctKey error, err error) { 830 // TODO: right now any store should be good enough here but 831 // that might change. As an alternative we do support 832 // receiving a stream with any relevant assertions. 833 sto := snapstate.Store(st, nil) 834 db := assertstate.DB(st) 835 836 retrieveError := false 837 retrieve := func(ref *asserts.Ref) (asserts.Assertion, error) { 838 st.Unlock() 839 defer st.Lock() 840 a, err := sto.Assertion(ref.Type, ref.PrimaryKey, nil) 841 retrieveError = err != nil 842 return a, err 843 } 844 845 save := func(a asserts.Assertion) error { 846 err = assertstate.Add(st, a) 847 if err != nil && !asserts.IsUnaccceptedUpdate(err) { 848 return err 849 } 850 return nil 851 } 852 853 f := asserts.NewFetcher(db, retrieve, save) 854 855 keyRef := &asserts.Ref{ 856 Type: asserts.AccountKeyType, 857 PrimaryKey: []string{keyID}, 858 } 859 if err := f.Fetch(keyRef); err != nil { 860 if retrieveError { 861 return err, nil 862 } else { 863 return nil, err 864 } 865 } 866 return nil, nil 867 } 868 869 func (m *DeviceManager) doPrepareRemodeling(t *state.Task, tmb *tomb.Tomb) error { 870 st := t.State() 871 st.Lock() 872 defer st.Unlock() 873 874 remodCtx, err := remodelCtxFromTask(t) 875 if err != nil { 876 return err 877 } 878 current, err := findModel(st) 879 if err != nil { 880 return err 881 } 882 883 sto := remodCtx.Store() 884 if sto == nil { 885 return fmt.Errorf("internal error: re-registration remodeling should have built a store") 886 } 887 // ensure a new session accounting for the new brand/model 888 st.Unlock() 889 _, err = sto.EnsureDeviceSession() 890 st.Lock() 891 if err != nil { 892 return fmt.Errorf("cannot get a store session based on the new model assertion: %v", err) 893 } 894 895 chgID := t.Change().ID() 896 897 tss, err := remodelTasks(tmb.Context(nil), st, current, remodCtx.Model(), remodCtx, chgID) 898 if err != nil { 899 return err 900 } 901 902 allTs := state.NewTaskSet() 903 for _, ts := range tss { 904 allTs.AddAll(ts) 905 } 906 snapstate.InjectTasks(t, allTs) 907 908 st.EnsureBefore(0) 909 t.SetStatus(state.DoneStatus) 910 911 return nil 912 } 913 914 func snapState(st *state.State, name string) (*snapstate.SnapState, error) { 915 var snapst snapstate.SnapState 916 err := snapstate.Get(st, name, &snapst) 917 if err != nil && err != state.ErrNoState { 918 return nil, err 919 } 920 return &snapst, nil 921 } 922 923 func makeRollbackDir(name string) (string, error) { 924 rollbackDir := filepath.Join(dirs.SnapRollbackDir, name) 925 926 if err := os.MkdirAll(rollbackDir, 0750); err != nil { 927 return "", err 928 } 929 930 return rollbackDir, nil 931 } 932 933 func currentGadgetInfo(snapst *snapstate.SnapState) (*gadget.GadgetData, error) { 934 currentInfo, err := snapst.CurrentInfo() 935 if err != nil && err != snapstate.ErrNoCurrent { 936 return nil, err 937 } 938 if currentInfo == nil { 939 // no current yet 940 return nil, nil 941 } 942 const onClassic = false 943 gi, err := gadget.ReadInfo(currentInfo.MountDir(), onClassic) 944 if err != nil { 945 return nil, err 946 } 947 return &gadget.GadgetData{Info: gi, RootDir: currentInfo.MountDir()}, nil 948 } 949 950 func pendingGadgetInfo(snapsup *snapstate.SnapSetup) (*gadget.GadgetData, error) { 951 info, err := snap.ReadInfo(snapsup.InstanceName(), snapsup.SideInfo) 952 if err != nil { 953 return nil, err 954 } 955 const onClassic = false 956 update, err := gadget.ReadInfo(info.MountDir(), onClassic) 957 if err != nil { 958 return nil, err 959 } 960 return &gadget.GadgetData{Info: update, RootDir: info.MountDir()}, nil 961 } 962 963 func gadgetCurrentAndUpdate(st *state.State, snapsup *snapstate.SnapSetup) (current *gadget.GadgetData, update *gadget.GadgetData, err error) { 964 snapst, err := snapState(st, snapsup.InstanceName()) 965 if err != nil { 966 return nil, nil, err 967 } 968 969 currentData, err := currentGadgetInfo(snapst) 970 if err != nil { 971 return nil, nil, fmt.Errorf("cannot read current gadget snap details: %v", err) 972 } 973 if currentData == nil { 974 // don't bother reading update if there is no current 975 return nil, nil, nil 976 } 977 978 newData, err := pendingGadgetInfo(snapsup) 979 if err != nil { 980 return nil, nil, fmt.Errorf("cannot read candidate gadget snap details: %v", err) 981 } 982 983 return currentData, newData, nil 984 } 985 986 var ( 987 gadgetUpdate = gadget.Update 988 ) 989 990 func (m *DeviceManager) doUpdateGadgetAssets(t *state.Task, _ *tomb.Tomb) error { 991 if release.OnClassic { 992 return fmt.Errorf("cannot run update gadget assets task on a classic system") 993 } 994 995 st := t.State() 996 st.Lock() 997 defer st.Unlock() 998 999 snapsup, err := snapstate.TaskSnapSetup(t) 1000 if err != nil { 1001 return err 1002 } 1003 1004 currentData, updateData, err := gadgetCurrentAndUpdate(t.State(), snapsup) 1005 if err != nil { 1006 return err 1007 } 1008 if currentData == nil { 1009 // no updates during first boot & seeding 1010 return nil 1011 } 1012 1013 snapRollbackDir, err := makeRollbackDir(fmt.Sprintf("%v_%v", snapsup.InstanceName(), snapsup.SideInfo.Revision)) 1014 if err != nil { 1015 return fmt.Errorf("cannot prepare update rollback directory: %v", err) 1016 } 1017 1018 st.Unlock() 1019 err = gadgetUpdate(*currentData, *updateData, snapRollbackDir) 1020 st.Lock() 1021 if err != nil { 1022 if err == gadget.ErrNoUpdate { 1023 // no update needed 1024 t.Logf("No gadget assets update needed") 1025 return nil 1026 } 1027 return err 1028 } 1029 1030 t.SetStatus(state.DoneStatus) 1031 1032 if err := os.RemoveAll(snapRollbackDir); err != nil && !os.IsNotExist(err) { 1033 logger.Noticef("failed to remove gadget update rollback directory %q: %v", snapRollbackDir, err) 1034 } 1035 1036 // TODO: consider having the option to do this early via recovery in 1037 // core20, have fallback code as well there 1038 st.RequestRestart(state.RestartSystem) 1039 1040 return nil 1041 }