github.com/meulengracht/snapd@v0.0.0-20210719210640-8bde69bcc84e/overlord/devicestate/handlers_serial.go (about) 1 // -*- Mode: Go; indent-tabs-mode: t -*- 2 /* 3 * Copyright (C) 2016-2020 Canonical Ltd 4 * 5 * This program is free software: you can redistribute it and/or modify 6 * it under the terms of the GNU General Public License version 3 as 7 * published by the Free Software Foundation. 8 * 9 * This program is distributed in the hope that it will be useful, 10 * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 * GNU General Public License for more details. 13 * 14 * You should have received a copy of the GNU General Public License 15 * along with this program. If not, see <http://www.gnu.org/licenses/>. 16 * 17 */ 18 19 package devicestate 20 21 import ( 22 "bytes" 23 "crypto/rsa" 24 "encoding/json" 25 "errors" 26 "fmt" 27 "io" 28 "net/http" 29 "net/url" 30 "strconv" 31 "strings" 32 "time" 33 34 "gopkg.in/tomb.v2" 35 36 "github.com/snapcore/snapd/asserts" 37 "github.com/snapcore/snapd/httputil" 38 "github.com/snapcore/snapd/logger" 39 "github.com/snapcore/snapd/overlord/assertstate" 40 "github.com/snapcore/snapd/overlord/auth" 41 "github.com/snapcore/snapd/overlord/configstate/config" 42 "github.com/snapcore/snapd/overlord/configstate/proxyconf" 43 "github.com/snapcore/snapd/overlord/snapstate" 44 "github.com/snapcore/snapd/overlord/state" 45 "github.com/snapcore/snapd/snapdenv" 46 "github.com/snapcore/snapd/strutil" 47 "github.com/snapcore/snapd/timings" 48 ) 49 50 func baseURL() *url.URL { 51 if snapdenv.UseStagingStore() { 52 return mustParse("https://api.staging.snapcraft.io/") 53 } 54 return mustParse("https://api.snapcraft.io/") 55 } 56 57 func mustParse(s string) *url.URL { 58 u, err := url.Parse(s) 59 if err != nil { 60 panic(err) 61 } 62 return u 63 } 64 65 var ( 66 keyLength = 4096 67 retryInterval = 60 * time.Second 68 maxTentatives = 15 69 baseStoreURL = baseURL().ResolveReference(authRef) 70 71 authRef = mustParse("api/v1/snaps/auth/") // authRef must end in / for the following refs to work 72 reqIdRef = mustParse("request-id") 73 serialRef = mustParse("serial") 74 devicesRef = mustParse("devices") 75 76 // we accept a stream with the serial assertion as well 77 registrationCapabilities = []string{"serial-stream"} 78 ) 79 80 func (m *DeviceManager) doGenerateDeviceKey(t *state.Task, _ *tomb.Tomb) error { 81 st := t.State() 82 st.Lock() 83 defer st.Unlock() 84 85 perfTimings := state.TimingsForTask(t) 86 defer perfTimings.Save(st) 87 88 device, err := m.device() 89 if err != nil { 90 return err 91 } 92 93 if device.KeyID != "" { 94 // nothing to do 95 return nil 96 } 97 98 st.Unlock() 99 var keyPair *rsa.PrivateKey 100 timings.Run(perfTimings, "generate-rsa-key", "generating device key pair", func(tm timings.Measurer) { 101 keyPair, err = generateRSAKey(keyLength) 102 }) 103 st.Lock() 104 if err != nil { 105 return fmt.Errorf("cannot generate device key pair: %v", err) 106 } 107 108 privKey := asserts.RSAPrivateKey(keyPair) 109 err = m.withKeypairMgr(func(keypairMgr asserts.KeypairManager) error { 110 return keypairMgr.Put(privKey) 111 }) 112 if err != nil { 113 return fmt.Errorf("cannot store device key pair: %v", err) 114 } 115 116 device.KeyID = privKey.PublicKey().ID() 117 err = m.setDevice(device) 118 if err != nil { 119 return err 120 } 121 t.SetStatus(state.DoneStatus) 122 return nil 123 } 124 125 func newEnoughProxy(st *state.State, proxyURL *url.URL, client *http.Client) bool { 126 st.Unlock() 127 defer st.Lock() 128 129 const prefix = "Cannot check whether proxy store supports a custom serial vault" 130 131 req, err := http.NewRequest("HEAD", proxyURL.String(), nil) 132 if err != nil { 133 // can't really happen unless proxyURL is somehow broken 134 logger.Debugf(prefix+": %v", err) 135 return false 136 } 137 req.Header.Set("User-Agent", snapdenv.UserAgent()) 138 resp, err := client.Do(req) 139 if err != nil { 140 // some sort of network or protocol error 141 logger.Debugf(prefix+": %v", err) 142 return false 143 } 144 resp.Body.Close() 145 if resp.StatusCode != 200 { 146 logger.Debugf(prefix+": Head request returned %s.", resp.Status) 147 return false 148 } 149 verstr := resp.Header.Get("Snap-Store-Version") 150 ver, err := strconv.Atoi(verstr) 151 if err != nil { 152 logger.Debugf(prefix+": Bogus Snap-Store-Version header %q.", verstr) 153 return false 154 } 155 return ver >= 6 156 } 157 158 func (cfg *serialRequestConfig) setURLs(proxyURL, svcURL *url.URL) { 159 base := baseStoreURL 160 if proxyURL != nil { 161 if svcURL != nil { 162 if cfg.headers == nil { 163 cfg.headers = make(map[string]string, 1) 164 } 165 cfg.headers["X-Snap-Device-Service-URL"] = svcURL.String() 166 } 167 base = proxyURL.ResolveReference(authRef) 168 } else if svcURL != nil { 169 base = svcURL 170 } 171 172 cfg.requestIDURL = base.ResolveReference(reqIdRef).String() 173 if svcURL != nil && proxyURL == nil { 174 // talking directly to the custom device service 175 cfg.serialRequestURL = base.ResolveReference(serialRef).String() 176 } else { 177 cfg.serialRequestURL = base.ResolveReference(devicesRef).String() 178 } 179 } 180 181 // A registrationContext handles the contextual information needed 182 // for the initial registration or a re-registration. 183 type registrationContext interface { 184 Device() (*auth.DeviceState, error) 185 186 Model() *asserts.Model 187 188 GadgetForSerialRequestConfig() string 189 SerialRequestExtraHeaders() map[string]interface{} 190 SerialRequestAncillaryAssertions() []asserts.Assertion 191 192 FinishRegistration(serial *asserts.Serial) error 193 194 ForRemodeling() bool 195 } 196 197 // initialRegistrationContext is a thin wrapper around DeviceManager 198 // implementing registrationContext for initial regitration 199 type initialRegistrationContext struct { 200 deviceMgr *DeviceManager 201 202 model *asserts.Model 203 } 204 205 func (rc *initialRegistrationContext) ForRemodeling() bool { 206 return false 207 } 208 209 func (rc *initialRegistrationContext) Device() (*auth.DeviceState, error) { 210 return rc.deviceMgr.device() 211 } 212 213 func (rc *initialRegistrationContext) Model() *asserts.Model { 214 return rc.model 215 } 216 217 func (rc *initialRegistrationContext) GadgetForSerialRequestConfig() string { 218 return rc.model.Gadget() 219 } 220 221 func (rc *initialRegistrationContext) SerialRequestExtraHeaders() map[string]interface{} { 222 return nil 223 } 224 225 func (rc *initialRegistrationContext) SerialRequestAncillaryAssertions() []asserts.Assertion { 226 return []asserts.Assertion{rc.model} 227 } 228 229 func (rc *initialRegistrationContext) FinishRegistration(serial *asserts.Serial) error { 230 device, err := rc.deviceMgr.device() 231 if err != nil { 232 return err 233 } 234 235 device.Serial = serial.Serial() 236 if err := rc.deviceMgr.setDevice(device); err != nil { 237 return err 238 } 239 rc.deviceMgr.markRegistered() 240 241 // make sure we timely consider anything that was blocked on 242 // registration 243 rc.deviceMgr.state.EnsureBefore(0) 244 245 return nil 246 } 247 248 // registrationCtx returns a registrationContext appropriate for the task and its change. 249 func (m *DeviceManager) registrationCtx(t *state.Task) (registrationContext, error) { 250 remodCtx, err := remodelCtxFromTask(t) 251 if err != nil && err != state.ErrNoState { 252 return nil, err 253 } 254 if regCtx, ok := remodCtx.(registrationContext); ok { 255 return regCtx, nil 256 } 257 model, err := m.Model() 258 if err != nil { 259 return nil, err 260 } 261 262 return &initialRegistrationContext{ 263 deviceMgr: m, 264 model: model, 265 }, nil 266 } 267 268 type serialSetup struct { 269 SerialRequest string `json:"serial-request"` 270 Serial string `json:"serial"` 271 } 272 273 type requestIDResp struct { 274 RequestID string `json:"request-id"` 275 } 276 277 func retryErr(t *state.Task, nTentatives int, reason string, a ...interface{}) error { 278 t.State().Lock() 279 defer t.State().Unlock() 280 if nTentatives >= maxTentatives { 281 return fmt.Errorf(reason, a...) 282 } 283 t.Errorf(reason, a...) 284 return &state.Retry{After: retryInterval} 285 } 286 287 type serverError struct { 288 Message string `json:"message"` 289 Errors []*serverError `json:"error_list"` 290 } 291 292 func retryBadStatus(t *state.Task, nTentatives int, reason string, resp *http.Response) error { 293 if resp.StatusCode > 500 { 294 // likely temporary 295 return retryErr(t, nTentatives, "%s: unexpected status %d", reason, resp.StatusCode) 296 } 297 if resp.Header.Get("Content-Type") == "application/json" { 298 var srvErr serverError 299 dec := json.NewDecoder(resp.Body) 300 err := dec.Decode(&srvErr) 301 if err == nil { 302 msg := srvErr.Message 303 if msg == "" && len(srvErr.Errors) > 0 { 304 msg = srvErr.Errors[0].Message 305 } 306 if msg != "" { 307 return fmt.Errorf("%s: %s", reason, msg) 308 } 309 } 310 } 311 return fmt.Errorf("%s: unexpected status %d", reason, resp.StatusCode) 312 } 313 314 func prepareSerialRequest(t *state.Task, regCtx registrationContext, privKey asserts.PrivateKey, device *auth.DeviceState, client *http.Client, cfg *serialRequestConfig) (string, error) { 315 // limit tentatives starting from scratch before going to 316 // slower full retries 317 var nTentatives int 318 err := t.Get("pre-poll-tentatives", &nTentatives) 319 if err != nil && err != state.ErrNoState { 320 return "", err 321 } 322 nTentatives++ 323 t.Set("pre-poll-tentatives", nTentatives) 324 325 st := t.State() 326 st.Unlock() 327 defer st.Lock() 328 329 req, err := http.NewRequest("POST", cfg.requestIDURL, nil) 330 if err != nil { 331 return "", fmt.Errorf("internal error: cannot create request-id request %q", cfg.requestIDURL) 332 } 333 req.Header.Set("User-Agent", snapdenv.UserAgent()) 334 cfg.applyHeaders(req) 335 336 resp, err := client.Do(req) 337 if err != nil { 338 if !httputil.ShouldRetryError(err) { 339 // a non temporary net error fully errors out and triggers a retry 340 // retries 341 return "", fmt.Errorf("cannot retrieve request-id for making a request for a serial: %v", err) 342 } 343 if httputil.NoNetwork(err) { 344 // If there is no network there is no need to count 345 // this as a tentatives attempt. If we do it this 346 // way the risk is that we tried a bunch of times 347 // with no network and if we hit the server for real 348 // and it replies with something we need to retry 349 // we will not because nTentatives is way over the 350 // limit. 351 st.Lock() 352 t.Set("pre-poll-tentatives", 0) 353 st.Unlock() 354 // Retry quickly if there is no network 355 // (yet). This ensures that we try to get a serial 356 // as soon as the user configured the network of the 357 // device 358 noNetworkRetryInterval := retryInterval / 2 359 return "", &state.Retry{After: noNetworkRetryInterval} 360 } 361 362 return "", retryErr(t, nTentatives, "cannot retrieve request-id for making a request for a serial: %v", err) 363 } 364 defer resp.Body.Close() 365 if resp.StatusCode != 200 { 366 return "", retryBadStatus(t, nTentatives, "cannot retrieve request-id for making a request for a serial", resp) 367 } 368 369 dec := json.NewDecoder(resp.Body) 370 var requestID requestIDResp 371 err = dec.Decode(&requestID) 372 if err != nil { // assume broken i/o 373 return "", retryErr(t, nTentatives, "cannot read response with request-id for making a request for a serial: %v", err) 374 } 375 376 encodedPubKey, err := asserts.EncodePublicKey(privKey.PublicKey()) 377 if err != nil { 378 return "", fmt.Errorf("internal error: cannot encode device public key: %v", err) 379 380 } 381 382 headers := map[string]interface{}{ 383 "brand-id": device.Brand, 384 "model": device.Model, 385 "request-id": requestID.RequestID, 386 "device-key": string(encodedPubKey), 387 } 388 if cfg.proposedSerial != "" { 389 headers["serial"] = cfg.proposedSerial 390 } 391 392 for k, v := range regCtx.SerialRequestExtraHeaders() { 393 headers[k] = v 394 } 395 396 serialReq, err := asserts.SignWithoutAuthority(asserts.SerialRequestType, headers, cfg.body, privKey) 397 if err != nil { 398 return "", err 399 } 400 401 buf := new(bytes.Buffer) 402 encoder := asserts.NewEncoder(buf) 403 if err := encoder.Encode(serialReq); err != nil { 404 return "", fmt.Errorf("cannot encode serial-request: %v", err) 405 } 406 407 for _, ancillaryAs := range regCtx.SerialRequestAncillaryAssertions() { 408 if err := encoder.Encode(ancillaryAs); err != nil { 409 return "", fmt.Errorf("cannot encode ancillary assertion: %v", err) 410 } 411 412 } 413 414 return buf.String(), nil 415 } 416 417 var errPoll = errors.New("serial-request accepted, poll later") 418 419 func submitSerialRequest(t *state.Task, serialRequest string, client *http.Client, cfg *serialRequestConfig) (*asserts.Serial, *asserts.Batch, error) { 420 st := t.State() 421 st.Unlock() 422 defer st.Lock() 423 424 req, err := http.NewRequest("POST", cfg.serialRequestURL, bytes.NewBufferString(serialRequest)) 425 if err != nil { 426 return nil, nil, fmt.Errorf("internal error: cannot create serial-request request %q", cfg.serialRequestURL) 427 } 428 req.Header.Set("User-Agent", snapdenv.UserAgent()) 429 req.Header.Set("Snap-Device-Capabilities", strings.Join(registrationCapabilities, " ")) 430 cfg.applyHeaders(req) 431 req.Header.Set("Content-Type", asserts.MediaType) 432 433 resp, err := client.Do(req) 434 if err != nil { 435 return nil, nil, retryErr(t, 0, "cannot deliver device serial request: %v", err) 436 } 437 defer resp.Body.Close() 438 439 switch resp.StatusCode { 440 case 200, 201: 441 case 202: 442 return nil, nil, errPoll 443 default: 444 return nil, nil, retryBadStatus(t, 0, "cannot deliver device serial request", resp) 445 } 446 447 var serial *asserts.Serial 448 var batch *asserts.Batch 449 // decode body with stream of assertions, of which one is the serial 450 dec := asserts.NewDecoder(resp.Body) 451 for { 452 got, err := dec.Decode() 453 if err == io.EOF { 454 break 455 } 456 if err != nil { // assume broken i/o 457 return nil, nil, retryErr(t, 0, "cannot read response to request for a serial: %v", err) 458 } 459 if got.Type() == asserts.SerialType { 460 if serial != nil { 461 return nil, nil, fmt.Errorf("cannot accept more than a single device serial assertion from the device service") 462 } 463 serial = got.(*asserts.Serial) 464 } else { 465 if batch == nil { 466 batch = asserts.NewBatch(nil) 467 } 468 if err := batch.Add(got); err != nil { 469 return nil, nil, err 470 } 471 } 472 // TODO: consider a size limit? 473 } 474 475 if serial == nil { 476 return nil, nil, fmt.Errorf("cannot proceed, received assertion stream from the device service missing device serial assertion") 477 } 478 479 return serial, batch, nil 480 } 481 482 var httputilNewHTTPClient = httputil.NewHTTPClient 483 484 func getSerial(t *state.Task, regCtx registrationContext, privKey asserts.PrivateKey, device *auth.DeviceState, tm timings.Measurer) (serial *asserts.Serial, ancillaryBatch *asserts.Batch, err error) { 485 var serialSup serialSetup 486 err = t.Get("serial-setup", &serialSup) 487 if err != nil && err != state.ErrNoState { 488 return nil, nil, err 489 } 490 491 if serialSup.Serial != "" { 492 // we got a serial, just haven't managed to save its info yet 493 a, err := asserts.Decode([]byte(serialSup.Serial)) 494 if err != nil { 495 return nil, nil, fmt.Errorf("internal error: cannot decode previously saved serial: %v", err) 496 } 497 return a.(*asserts.Serial), nil, nil 498 } 499 500 st := t.State() 501 proxyConf := proxyconf.New(st) 502 client := httputilNewHTTPClient(&httputil.ClientOptions{ 503 Timeout: 30 * time.Second, 504 MayLogBody: true, 505 Proxy: proxyConf.Conf, 506 ProxyConnectHeader: http.Header{"User-Agent": []string{snapdenv.UserAgent()}}, 507 }) 508 509 cfg, err := getSerialRequestConfig(t, regCtx, client) 510 if err != nil { 511 return nil, nil, err 512 } 513 514 // NB: until we get at least an Accepted (202) we need to 515 // retry from scratch creating a new request-id because the 516 // previous one used could have expired 517 518 if serialSup.SerialRequest == "" { 519 var serialRequest string 520 var err error 521 timings.Run(tm, "prepare-serial-request", "prepare device serial request", func(timings.Measurer) { 522 serialRequest, err = prepareSerialRequest(t, regCtx, privKey, device, client, cfg) 523 }) 524 if err != nil { // errors & retries 525 return nil, nil, err 526 } 527 528 serialSup.SerialRequest = serialRequest 529 } 530 531 timings.Run(tm, "submit-serial-request", "submit device serial request", func(timings.Measurer) { 532 serial, ancillaryBatch, err = submitSerialRequest(t, serialSup.SerialRequest, client, cfg) 533 }) 534 if err == errPoll { 535 // we can/should reuse the serial-request 536 t.Set("serial-setup", serialSup) 537 return nil, nil, errPoll 538 } 539 if err != nil { // errors & retries 540 return nil, nil, err 541 } 542 543 keyID := privKey.PublicKey().ID() 544 if serial.BrandID() != device.Brand || serial.Model() != device.Model || serial.DeviceKey().ID() != keyID { 545 return nil, nil, fmt.Errorf("obtained serial assertion does not match provided device identity information (brand, model, key id): %s / %s / %s != %s / %s / %s", serial.BrandID(), serial.Model(), serial.DeviceKey().ID(), device.Brand, device.Model, keyID) 546 } 547 548 // cross check authority if different from brand-id 549 if serial.BrandID() != serial.AuthorityID() { 550 model := regCtx.Model() 551 if !strutil.ListContains(model.SerialAuthority(), serial.AuthorityID()) { 552 return nil, nil, fmt.Errorf("obtained serial assertion is signed by authority %q different from brand %q without model assertion with serial-authority set to to allow for them", serial.AuthorityID(), serial.BrandID()) 553 } 554 } 555 556 if ancillaryBatch == nil { 557 serialSup.Serial = string(asserts.Encode(serial)) 558 t.Set("serial-setup", serialSup) 559 } 560 561 if repeatRequestSerial == "after-got-serial" { 562 // For testing purposes, ensure a crash in this state works. 563 return nil, nil, &state.Retry{} 564 } 565 566 return serial, ancillaryBatch, nil 567 } 568 569 type serialRequestConfig struct { 570 requestIDURL string 571 serialRequestURL string 572 headers map[string]string 573 proposedSerial string 574 body []byte 575 } 576 577 func (cfg *serialRequestConfig) applyHeaders(req *http.Request) { 578 for k, v := range cfg.headers { 579 req.Header.Set(k, v) 580 } 581 } 582 583 func getSerialRequestConfig(t *state.Task, regCtx registrationContext, client *http.Client) (*serialRequestConfig, error) { 584 var svcURL, proxyURL *url.URL 585 586 st := t.State() 587 tr := config.NewTransaction(st) 588 if proxyStore, err := proxyStore(st, tr); err != nil && err != state.ErrNoState { 589 return nil, err 590 } else if proxyStore != nil { 591 proxyURL = proxyStore.URL() 592 } 593 594 cfg := serialRequestConfig{} 595 596 gadgetName := regCtx.GadgetForSerialRequestConfig() 597 // gadget is optional on classic 598 if gadgetName != "" { 599 var gadgetSt snapstate.SnapState 600 if err := snapstate.Get(st, gadgetName, &gadgetSt); err != nil { 601 return nil, fmt.Errorf("cannot find gadget snap %q: %v", gadgetName, err) 602 } 603 604 var svcURI string 605 err := tr.GetMaybe(gadgetName, "device-service.url", &svcURI) 606 if err != nil { 607 return nil, err 608 } 609 610 if svcURI != "" { 611 svcURL, err = url.Parse(svcURI) 612 if err != nil { 613 return nil, fmt.Errorf("cannot parse device registration base URL %q: %v", svcURI, err) 614 } 615 if !strings.HasSuffix(svcURL.Path, "/") { 616 svcURL.Path += "/" 617 } 618 } 619 620 err = tr.GetMaybe(gadgetName, "device-service.headers", &cfg.headers) 621 if err != nil { 622 return nil, err 623 } 624 625 var bodyStr string 626 err = tr.GetMaybe(gadgetName, "registration.body", &bodyStr) 627 if err != nil { 628 return nil, err 629 } 630 631 cfg.body = []byte(bodyStr) 632 633 err = tr.GetMaybe(gadgetName, "registration.proposed-serial", &cfg.proposedSerial) 634 if err != nil { 635 return nil, err 636 } 637 } 638 639 if proxyURL != nil && svcURL != nil && !newEnoughProxy(st, proxyURL, client) { 640 logger.Noticef("Proxy store does not support custom serial vault; ignoring the proxy") 641 proxyURL = nil 642 } 643 644 cfg.setURLs(proxyURL, svcURL) 645 646 return &cfg, nil 647 } 648 649 func (m *DeviceManager) doRequestSerial(t *state.Task, _ *tomb.Tomb) error { 650 st := t.State() 651 st.Lock() 652 defer st.Unlock() 653 654 perfTimings := state.TimingsForTask(t) 655 defer perfTimings.Save(st) 656 657 regCtx, err := m.registrationCtx(t) 658 if err != nil { 659 return err 660 } 661 662 device, err := regCtx.Device() 663 if err != nil { 664 return err 665 } 666 667 // NB: the keyPair is fixed for now 668 privKey, err := m.keyPair() 669 if err == state.ErrNoState { 670 return fmt.Errorf("internal error: cannot find device key pair") 671 } 672 if err != nil { 673 return err 674 } 675 676 // make this idempotent, look if we have already a serial assertion 677 // for privKey 678 serials, err := assertstate.DB(st).FindMany(asserts.SerialType, map[string]string{ 679 "brand-id": device.Brand, 680 "model": device.Model, 681 "device-key-sha3-384": privKey.PublicKey().ID(), 682 }) 683 if err != nil && !asserts.IsNotFound(err) { 684 return err 685 } 686 687 finish := func(serial *asserts.Serial) error { 688 // save serial if appropriate into the device save 689 // assertion database 690 err := m.withSaveAssertDB(func(savedb *asserts.Database) error { 691 db := assertstate.DB(st) 692 retrieve := func(ref *asserts.Ref) (asserts.Assertion, error) { 693 return ref.Resolve(db.Find) 694 } 695 b := asserts.NewBatch(nil) 696 err := b.Fetch(savedb, retrieve, func(f asserts.Fetcher) error { 697 // save the associated model as well 698 // as it might be required for cross-checks 699 // of the serial 700 if err := f.Save(regCtx.Model()); err != nil { 701 return err 702 } 703 return f.Save(serial) 704 }) 705 if err != nil { 706 return err 707 } 708 return b.CommitTo(savedb, nil) 709 }) 710 if err != nil && err != errNoSaveSupport { 711 return fmt.Errorf("cannot save serial to device save assertion database: %v", err) 712 } 713 714 if err := regCtx.FinishRegistration(serial); err != nil { 715 return err 716 } 717 t.SetStatus(state.DoneStatus) 718 return nil 719 } 720 721 if len(serials) == 1 { 722 // means we saved the assertion but didn't get to the end of the task 723 return finish(serials[0].(*asserts.Serial)) 724 } 725 if len(serials) > 1 { 726 return fmt.Errorf("internal error: multiple serial assertions for the same device key") 727 } 728 729 var serial *asserts.Serial 730 var ancillaryBatch *asserts.Batch 731 timings.Run(perfTimings, "get-serial", "get device serial", func(tm timings.Measurer) { 732 serial, ancillaryBatch, err = getSerial(t, regCtx, privKey, device, tm) 733 }) 734 if err == errPoll { 735 t.Logf("Will poll for device serial assertion in 60 seconds") 736 return &state.Retry{After: retryInterval} 737 } 738 if err != nil { // errors & retries 739 return err 740 741 } 742 743 // TODO: the accept* helpers put the serial directly in the 744 // system assertion database, that will not work 745 // for 3rd-party signed serials in the case of a remodel 746 // because the model is added only later. If needed, the best way 747 // to fix this requires rethinking how remodel and new assertions 748 // interact 749 if ancillaryBatch == nil { 750 // the device service returned only the serial 751 if err := acceptSerialOnly(t, serial, perfTimings); err != nil { 752 return err 753 } 754 } else { 755 // the device service returned a stream of assertions 756 timings.Run(perfTimings, "fetch-keys", "fetch signing key chain", func(timings.Measurer) { 757 err = acceptSerialPlusBatch(t, serial, ancillaryBatch) 758 }) 759 if err != nil { 760 t.Errorf("cannot accept stream of assertions from device service: %v", err) 761 return err 762 } 763 } 764 765 if repeatRequestSerial == "after-add-serial" { 766 // For testing purposes, ensure a crash in this state works. 767 return &state.Retry{} 768 } 769 770 return finish(serial) 771 } 772 773 func acceptSerialOnly(t *state.Task, serial *asserts.Serial, perfTimings *timings.Timings) error { 774 st := t.State() 775 var err error 776 var errAcctKey error 777 // try to fetch the signing key chain of the serial 778 timings.Run(perfTimings, "fetch-keys", "fetch signing key chain", func(timings.Measurer) { 779 errAcctKey, err = fetchKeys(st, serial.SignKeyID()) 780 }) 781 if err != nil { 782 return err 783 } 784 785 // add the serial assertion to the system assertion db 786 err = assertstate.Add(st, serial) 787 if err != nil { 788 // if we had failed to fetch the signing key, retry in a bit 789 if errAcctKey != nil { 790 t.Errorf("cannot fetch signing key for the serial: %v", errAcctKey) 791 return &state.Retry{After: retryInterval} 792 } 793 return err 794 } 795 796 return nil 797 } 798 799 func acceptSerialPlusBatch(t *state.Task, serial *asserts.Serial, batch *asserts.Batch) error { 800 st := t.State() 801 err := batch.Add(serial) 802 if err != nil { 803 return err 804 } 805 return assertstate.AddBatch(st, batch, &asserts.CommitOptions{Precheck: true}) 806 } 807 808 var repeatRequestSerial string // for tests 809 810 func fetchKeys(st *state.State, keyID string) (errAcctKey error, err error) { 811 // TODO: right now any store should be good enough here but 812 // that might change. As an alternative we do support 813 // receiving a stream with any relevant assertions. 814 sto := snapstate.Store(st, nil) 815 db := assertstate.DB(st) 816 817 retrieveError := false 818 retrieve := func(ref *asserts.Ref) (asserts.Assertion, error) { 819 st.Unlock() 820 defer st.Lock() 821 a, err := sto.Assertion(ref.Type, ref.PrimaryKey, nil) 822 retrieveError = err != nil 823 return a, err 824 } 825 826 save := func(a asserts.Assertion) error { 827 err = assertstate.Add(st, a) 828 if err != nil && !asserts.IsUnaccceptedUpdate(err) { 829 return err 830 } 831 return nil 832 } 833 834 f := asserts.NewFetcher(db, retrieve, save) 835 836 keyRef := &asserts.Ref{ 837 Type: asserts.AccountKeyType, 838 PrimaryKey: []string{keyID}, 839 } 840 if err := f.Fetch(keyRef); err != nil { 841 if retrieveError { 842 return err, nil 843 } else { 844 return nil, err 845 } 846 } 847 return nil, nil 848 }