github.com/hugh712/snapd@v0.0.0-20200910133618-1a99902bd583/overlord/devicestate/handlers_serial.go (about) 1 // -*- Mode: Go; indent-tabs-mode: t -*- 2 /* 3 * Copyright (C) 2016-2020 Canonical Ltd 4 * 5 * This program is free software: you can redistribute it and/or modify 6 * it under the terms of the GNU General Public License version 3 as 7 * published by the Free Software Foundation. 8 * 9 * This program is distributed in the hope that it will be useful, 10 * but WITHOUT ANY WARRANTY; without even the implied warranty of 11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 * GNU General Public License for more details. 13 * 14 * You should have received a copy of the GNU General Public License 15 * along with this program. If not, see <http://www.gnu.org/licenses/>. 16 * 17 */ 18 19 package devicestate 20 21 import ( 22 "bytes" 23 "crypto/rsa" 24 "encoding/json" 25 "errors" 26 "fmt" 27 "io" 28 "net/http" 29 "net/url" 30 "strconv" 31 "strings" 32 "time" 33 34 "gopkg.in/tomb.v2" 35 36 "github.com/snapcore/snapd/asserts" 37 "github.com/snapcore/snapd/httputil" 38 "github.com/snapcore/snapd/logger" 39 "github.com/snapcore/snapd/overlord/assertstate" 40 "github.com/snapcore/snapd/overlord/auth" 41 "github.com/snapcore/snapd/overlord/configstate/config" 42 "github.com/snapcore/snapd/overlord/configstate/proxyconf" 43 "github.com/snapcore/snapd/overlord/snapstate" 44 "github.com/snapcore/snapd/overlord/state" 45 "github.com/snapcore/snapd/snapdenv" 46 "github.com/snapcore/snapd/strutil" 47 "github.com/snapcore/snapd/timings" 48 ) 49 50 func baseURL() *url.URL { 51 if snapdenv.UseStagingStore() { 52 return mustParse("https://api.staging.snapcraft.io/") 53 } 54 return mustParse("https://api.snapcraft.io/") 55 } 56 57 func mustParse(s string) *url.URL { 58 u, err := url.Parse(s) 59 if err != nil { 60 panic(err) 61 } 62 return u 63 } 64 65 var ( 66 keyLength = 4096 67 retryInterval = 60 * time.Second 68 maxTentatives = 15 69 baseStoreURL = baseURL().ResolveReference(authRef) 70 71 authRef = mustParse("api/v1/snaps/auth/") // authRef must end in / for the following refs to work 72 reqIdRef = mustParse("request-id") 73 serialRef = mustParse("serial") 74 devicesRef = mustParse("devices") 75 76 // we accept a stream with the serial assertion as well 77 registrationCapabilities = []string{"serial-stream"} 78 ) 79 80 func (m *DeviceManager) doGenerateDeviceKey(t *state.Task, _ *tomb.Tomb) error { 81 st := t.State() 82 st.Lock() 83 defer st.Unlock() 84 85 perfTimings := state.TimingsForTask(t) 86 defer perfTimings.Save(st) 87 88 device, err := m.device() 89 if err != nil { 90 return err 91 } 92 93 if device.KeyID != "" { 94 // nothing to do 95 return nil 96 } 97 98 st.Unlock() 99 var keyPair *rsa.PrivateKey 100 timings.Run(perfTimings, "generate-rsa-key", "generating device key pair", func(tm timings.Measurer) { 101 keyPair, err = generateRSAKey(keyLength) 102 }) 103 st.Lock() 104 if err != nil { 105 return fmt.Errorf("cannot generate device key pair: %v", err) 106 } 107 108 privKey := asserts.RSAPrivateKey(keyPair) 109 err = m.keypairMgr.Put(privKey) 110 if err != nil { 111 return fmt.Errorf("cannot store device key pair: %v", err) 112 } 113 114 device.KeyID = privKey.PublicKey().ID() 115 err = m.setDevice(device) 116 if err != nil { 117 return err 118 } 119 t.SetStatus(state.DoneStatus) 120 return nil 121 } 122 123 func newEnoughProxy(st *state.State, proxyURL *url.URL, client *http.Client) bool { 124 st.Unlock() 125 defer st.Lock() 126 127 const prefix = "Cannot check whether proxy store supports a custom serial vault" 128 129 req, err := http.NewRequest("HEAD", proxyURL.String(), nil) 130 if err != nil { 131 // can't really happen unless proxyURL is somehow broken 132 logger.Debugf(prefix+": %v", err) 133 return false 134 } 135 req.Header.Set("User-Agent", snapdenv.UserAgent()) 136 resp, err := client.Do(req) 137 if err != nil { 138 // some sort of network or protocol error 139 logger.Debugf(prefix+": %v", err) 140 return false 141 } 142 resp.Body.Close() 143 if resp.StatusCode != 200 { 144 logger.Debugf(prefix+": Head request returned %s.", resp.Status) 145 return false 146 } 147 verstr := resp.Header.Get("Snap-Store-Version") 148 ver, err := strconv.Atoi(verstr) 149 if err != nil { 150 logger.Debugf(prefix+": Bogus Snap-Store-Version header %q.", verstr) 151 return false 152 } 153 return ver >= 6 154 } 155 156 func (cfg *serialRequestConfig) setURLs(proxyURL, svcURL *url.URL) { 157 base := baseStoreURL 158 if proxyURL != nil { 159 if svcURL != nil { 160 if cfg.headers == nil { 161 cfg.headers = make(map[string]string, 1) 162 } 163 cfg.headers["X-Snap-Device-Service-URL"] = svcURL.String() 164 } 165 base = proxyURL.ResolveReference(authRef) 166 } else if svcURL != nil { 167 base = svcURL 168 } 169 170 cfg.requestIDURL = base.ResolveReference(reqIdRef).String() 171 if svcURL != nil && proxyURL == nil { 172 // talking directly to the custom device service 173 cfg.serialRequestURL = base.ResolveReference(serialRef).String() 174 } else { 175 cfg.serialRequestURL = base.ResolveReference(devicesRef).String() 176 } 177 } 178 179 // A registrationContext handles the contextual information needed 180 // for the initial registration or a re-registration. 181 type registrationContext interface { 182 Device() (*auth.DeviceState, error) 183 184 Model() *asserts.Model 185 186 GadgetForSerialRequestConfig() string 187 SerialRequestExtraHeaders() map[string]interface{} 188 SerialRequestAncillaryAssertions() []asserts.Assertion 189 190 FinishRegistration(serial *asserts.Serial) error 191 192 ForRemodeling() bool 193 } 194 195 // initialRegistrationContext is a thin wrapper around DeviceManager 196 // implementing registrationContext for initial regitration 197 type initialRegistrationContext struct { 198 deviceMgr *DeviceManager 199 200 model *asserts.Model 201 } 202 203 func (rc *initialRegistrationContext) ForRemodeling() bool { 204 return false 205 } 206 207 func (rc *initialRegistrationContext) Device() (*auth.DeviceState, error) { 208 return rc.deviceMgr.device() 209 } 210 211 func (rc *initialRegistrationContext) Model() *asserts.Model { 212 return rc.model 213 } 214 215 func (rc *initialRegistrationContext) GadgetForSerialRequestConfig() string { 216 return rc.model.Gadget() 217 } 218 219 func (rc *initialRegistrationContext) SerialRequestExtraHeaders() map[string]interface{} { 220 return nil 221 } 222 223 func (rc *initialRegistrationContext) SerialRequestAncillaryAssertions() []asserts.Assertion { 224 return []asserts.Assertion{rc.model} 225 } 226 227 func (rc *initialRegistrationContext) FinishRegistration(serial *asserts.Serial) error { 228 device, err := rc.deviceMgr.device() 229 if err != nil { 230 return err 231 } 232 233 device.Serial = serial.Serial() 234 if err := rc.deviceMgr.setDevice(device); err != nil { 235 return err 236 } 237 rc.deviceMgr.markRegistered() 238 239 // make sure we timely consider anything that was blocked on 240 // registration 241 rc.deviceMgr.state.EnsureBefore(0) 242 243 return nil 244 } 245 246 // registrationCtx returns a registrationContext appropriate for the task and its change. 247 func (m *DeviceManager) registrationCtx(t *state.Task) (registrationContext, error) { 248 remodCtx, err := remodelCtxFromTask(t) 249 if err != nil && err != state.ErrNoState { 250 return nil, err 251 } 252 if regCtx, ok := remodCtx.(registrationContext); ok { 253 return regCtx, nil 254 } 255 model, err := m.Model() 256 if err != nil { 257 return nil, err 258 } 259 260 return &initialRegistrationContext{ 261 deviceMgr: m, 262 model: model, 263 }, nil 264 } 265 266 type serialSetup struct { 267 SerialRequest string `json:"serial-request"` 268 Serial string `json:"serial"` 269 } 270 271 type requestIDResp struct { 272 RequestID string `json:"request-id"` 273 } 274 275 func retryErr(t *state.Task, nTentatives int, reason string, a ...interface{}) error { 276 t.State().Lock() 277 defer t.State().Unlock() 278 if nTentatives >= maxTentatives { 279 return fmt.Errorf(reason, a...) 280 } 281 t.Errorf(reason, a...) 282 return &state.Retry{After: retryInterval} 283 } 284 285 type serverError struct { 286 Message string `json:"message"` 287 Errors []*serverError `json:"error_list"` 288 } 289 290 func retryBadStatus(t *state.Task, nTentatives int, reason string, resp *http.Response) error { 291 if resp.StatusCode > 500 { 292 // likely temporary 293 return retryErr(t, nTentatives, "%s: unexpected status %d", reason, resp.StatusCode) 294 } 295 if resp.Header.Get("Content-Type") == "application/json" { 296 var srvErr serverError 297 dec := json.NewDecoder(resp.Body) 298 err := dec.Decode(&srvErr) 299 if err == nil { 300 msg := srvErr.Message 301 if msg == "" && len(srvErr.Errors) > 0 { 302 msg = srvErr.Errors[0].Message 303 } 304 if msg != "" { 305 return fmt.Errorf("%s: %s", reason, msg) 306 } 307 } 308 } 309 return fmt.Errorf("%s: unexpected status %d", reason, resp.StatusCode) 310 } 311 312 func prepareSerialRequest(t *state.Task, regCtx registrationContext, privKey asserts.PrivateKey, device *auth.DeviceState, client *http.Client, cfg *serialRequestConfig) (string, error) { 313 // limit tentatives starting from scratch before going to 314 // slower full retries 315 var nTentatives int 316 err := t.Get("pre-poll-tentatives", &nTentatives) 317 if err != nil && err != state.ErrNoState { 318 return "", err 319 } 320 nTentatives++ 321 t.Set("pre-poll-tentatives", nTentatives) 322 323 st := t.State() 324 st.Unlock() 325 defer st.Lock() 326 327 req, err := http.NewRequest("POST", cfg.requestIDURL, nil) 328 if err != nil { 329 return "", fmt.Errorf("internal error: cannot create request-id request %q", cfg.requestIDURL) 330 } 331 req.Header.Set("User-Agent", snapdenv.UserAgent()) 332 cfg.applyHeaders(req) 333 334 resp, err := client.Do(req) 335 if err != nil { 336 if !httputil.ShouldRetryError(err) { 337 // a non temporary net error fully errors out and triggers a retry 338 // retries 339 return "", fmt.Errorf("cannot retrieve request-id for making a request for a serial: %v", err) 340 } 341 if httputil.NoNetwork(err) { 342 // If there is no network there is no need to count 343 // this as a tentatives attempt. If we do it this 344 // way the risk is that we tried a bunch of times 345 // with no network and if we hit the server for real 346 // and it replies with something we need to retry 347 // we will not because nTentatives is way over the 348 // limit. 349 st.Lock() 350 t.Set("pre-poll-tentatives", 0) 351 st.Unlock() 352 // Retry quickly if there is no network 353 // (yet). This ensures that we try to get a serial 354 // as soon as the user configured the network of the 355 // device 356 noNetworkRetryInterval := retryInterval / 2 357 return "", &state.Retry{After: noNetworkRetryInterval} 358 } 359 360 return "", retryErr(t, nTentatives, "cannot retrieve request-id for making a request for a serial: %v", err) 361 } 362 defer resp.Body.Close() 363 if resp.StatusCode != 200 { 364 return "", retryBadStatus(t, nTentatives, "cannot retrieve request-id for making a request for a serial", resp) 365 } 366 367 dec := json.NewDecoder(resp.Body) 368 var requestID requestIDResp 369 err = dec.Decode(&requestID) 370 if err != nil { // assume broken i/o 371 return "", retryErr(t, nTentatives, "cannot read response with request-id for making a request for a serial: %v", err) 372 } 373 374 encodedPubKey, err := asserts.EncodePublicKey(privKey.PublicKey()) 375 if err != nil { 376 return "", fmt.Errorf("internal error: cannot encode device public key: %v", err) 377 378 } 379 380 headers := map[string]interface{}{ 381 "brand-id": device.Brand, 382 "model": device.Model, 383 "request-id": requestID.RequestID, 384 "device-key": string(encodedPubKey), 385 } 386 if cfg.proposedSerial != "" { 387 headers["serial"] = cfg.proposedSerial 388 } 389 390 for k, v := range regCtx.SerialRequestExtraHeaders() { 391 headers[k] = v 392 } 393 394 serialReq, err := asserts.SignWithoutAuthority(asserts.SerialRequestType, headers, cfg.body, privKey) 395 if err != nil { 396 return "", err 397 } 398 399 buf := new(bytes.Buffer) 400 encoder := asserts.NewEncoder(buf) 401 if err := encoder.Encode(serialReq); err != nil { 402 return "", fmt.Errorf("cannot encode serial-request: %v", err) 403 } 404 405 for _, ancillaryAs := range regCtx.SerialRequestAncillaryAssertions() { 406 if err := encoder.Encode(ancillaryAs); err != nil { 407 return "", fmt.Errorf("cannot encode ancillary assertion: %v", err) 408 } 409 410 } 411 412 return buf.String(), nil 413 } 414 415 var errPoll = errors.New("serial-request accepted, poll later") 416 417 func submitSerialRequest(t *state.Task, serialRequest string, client *http.Client, cfg *serialRequestConfig) (*asserts.Serial, *asserts.Batch, error) { 418 st := t.State() 419 st.Unlock() 420 defer st.Lock() 421 422 req, err := http.NewRequest("POST", cfg.serialRequestURL, bytes.NewBufferString(serialRequest)) 423 if err != nil { 424 return nil, nil, fmt.Errorf("internal error: cannot create serial-request request %q", cfg.serialRequestURL) 425 } 426 req.Header.Set("User-Agent", snapdenv.UserAgent()) 427 req.Header.Set("Snap-Device-Capabilities", strings.Join(registrationCapabilities, " ")) 428 cfg.applyHeaders(req) 429 req.Header.Set("Content-Type", asserts.MediaType) 430 431 resp, err := client.Do(req) 432 if err != nil { 433 return nil, nil, retryErr(t, 0, "cannot deliver device serial request: %v", err) 434 } 435 defer resp.Body.Close() 436 437 switch resp.StatusCode { 438 case 200, 201: 439 case 202: 440 return nil, nil, errPoll 441 default: 442 return nil, nil, retryBadStatus(t, 0, "cannot deliver device serial request", resp) 443 } 444 445 var serial *asserts.Serial 446 var batch *asserts.Batch 447 // decode body with stream of assertions, of which one is the serial 448 dec := asserts.NewDecoder(resp.Body) 449 for { 450 got, err := dec.Decode() 451 if err == io.EOF { 452 break 453 } 454 if err != nil { // assume broken i/o 455 return nil, nil, retryErr(t, 0, "cannot read response to request for a serial: %v", err) 456 } 457 if got.Type() == asserts.SerialType { 458 if serial != nil { 459 return nil, nil, fmt.Errorf("cannot accept more than a single device serial assertion from the device service") 460 } 461 serial = got.(*asserts.Serial) 462 } else { 463 if batch == nil { 464 batch = asserts.NewBatch(nil) 465 } 466 if err := batch.Add(got); err != nil { 467 return nil, nil, err 468 } 469 } 470 // TODO: consider a size limit? 471 } 472 473 if serial == nil { 474 return nil, nil, fmt.Errorf("cannot proceed, received assertion stream from the device service missing device serial assertion") 475 } 476 477 return serial, batch, nil 478 } 479 480 var httputilNewHTTPClient = httputil.NewHTTPClient 481 482 func getSerial(t *state.Task, regCtx registrationContext, privKey asserts.PrivateKey, device *auth.DeviceState, tm timings.Measurer) (serial *asserts.Serial, ancillaryBatch *asserts.Batch, err error) { 483 var serialSup serialSetup 484 err = t.Get("serial-setup", &serialSup) 485 if err != nil && err != state.ErrNoState { 486 return nil, nil, err 487 } 488 489 if serialSup.Serial != "" { 490 // we got a serial, just haven't managed to save its info yet 491 a, err := asserts.Decode([]byte(serialSup.Serial)) 492 if err != nil { 493 return nil, nil, fmt.Errorf("internal error: cannot decode previously saved serial: %v", err) 494 } 495 return a.(*asserts.Serial), nil, nil 496 } 497 498 st := t.State() 499 proxyConf := proxyconf.New(st) 500 client := httputilNewHTTPClient(&httputil.ClientOptions{ 501 Timeout: 30 * time.Second, 502 MayLogBody: true, 503 Proxy: proxyConf.Conf, 504 ProxyConnectHeader: http.Header{"User-Agent": []string{snapdenv.UserAgent()}}, 505 }) 506 507 cfg, err := getSerialRequestConfig(t, regCtx, client) 508 if err != nil { 509 return nil, nil, err 510 } 511 512 // NB: until we get at least an Accepted (202) we need to 513 // retry from scratch creating a new request-id because the 514 // previous one used could have expired 515 516 if serialSup.SerialRequest == "" { 517 var serialRequest string 518 var err error 519 timings.Run(tm, "prepare-serial-request", "prepare device serial request", func(timings.Measurer) { 520 serialRequest, err = prepareSerialRequest(t, regCtx, privKey, device, client, cfg) 521 }) 522 if err != nil { // errors & retries 523 return nil, nil, err 524 } 525 526 serialSup.SerialRequest = serialRequest 527 } 528 529 timings.Run(tm, "submit-serial-request", "submit device serial request", func(timings.Measurer) { 530 serial, ancillaryBatch, err = submitSerialRequest(t, serialSup.SerialRequest, client, cfg) 531 }) 532 if err == errPoll { 533 // we can/should reuse the serial-request 534 t.Set("serial-setup", serialSup) 535 return nil, nil, errPoll 536 } 537 if err != nil { // errors & retries 538 return nil, nil, err 539 } 540 541 keyID := privKey.PublicKey().ID() 542 if serial.BrandID() != device.Brand || serial.Model() != device.Model || serial.DeviceKey().ID() != keyID { 543 return nil, nil, fmt.Errorf("obtained serial assertion does not match provided device identity information (brand, model, key id): %s / %s / %s != %s / %s / %s", serial.BrandID(), serial.Model(), serial.DeviceKey().ID(), device.Brand, device.Model, keyID) 544 } 545 546 // cross check authority if different from brand-id 547 if serial.BrandID() != serial.AuthorityID() { 548 model := regCtx.Model() 549 if !strutil.ListContains(model.SerialAuthority(), serial.AuthorityID()) { 550 return nil, nil, fmt.Errorf("obtained serial assertion is signed by authority %q different from brand %q without model assertion with serial-authority set to to allow for them", serial.AuthorityID(), serial.BrandID()) 551 } 552 } 553 554 if ancillaryBatch == nil { 555 serialSup.Serial = string(asserts.Encode(serial)) 556 t.Set("serial-setup", serialSup) 557 } 558 559 if repeatRequestSerial == "after-got-serial" { 560 // For testing purposes, ensure a crash in this state works. 561 return nil, nil, &state.Retry{} 562 } 563 564 return serial, ancillaryBatch, nil 565 } 566 567 type serialRequestConfig struct { 568 requestIDURL string 569 serialRequestURL string 570 headers map[string]string 571 proposedSerial string 572 body []byte 573 } 574 575 func (cfg *serialRequestConfig) applyHeaders(req *http.Request) { 576 for k, v := range cfg.headers { 577 req.Header.Set(k, v) 578 } 579 } 580 581 func getSerialRequestConfig(t *state.Task, regCtx registrationContext, client *http.Client) (*serialRequestConfig, error) { 582 var svcURL, proxyURL *url.URL 583 584 st := t.State() 585 tr := config.NewTransaction(st) 586 if proxyStore, err := proxyStore(st, tr); err != nil && err != state.ErrNoState { 587 return nil, err 588 } else if proxyStore != nil { 589 proxyURL = proxyStore.URL() 590 } 591 592 cfg := serialRequestConfig{} 593 594 gadgetName := regCtx.GadgetForSerialRequestConfig() 595 // gadget is optional on classic 596 if gadgetName != "" { 597 var gadgetSt snapstate.SnapState 598 if err := snapstate.Get(st, gadgetName, &gadgetSt); err != nil { 599 return nil, fmt.Errorf("cannot find gadget snap %q: %v", gadgetName, err) 600 } 601 602 var svcURI string 603 err := tr.GetMaybe(gadgetName, "device-service.url", &svcURI) 604 if err != nil { 605 return nil, err 606 } 607 608 if svcURI != "" { 609 svcURL, err = url.Parse(svcURI) 610 if err != nil { 611 return nil, fmt.Errorf("cannot parse device registration base URL %q: %v", svcURI, err) 612 } 613 if !strings.HasSuffix(svcURL.Path, "/") { 614 svcURL.Path += "/" 615 } 616 } 617 618 err = tr.GetMaybe(gadgetName, "device-service.headers", &cfg.headers) 619 if err != nil { 620 return nil, err 621 } 622 623 var bodyStr string 624 err = tr.GetMaybe(gadgetName, "registration.body", &bodyStr) 625 if err != nil { 626 return nil, err 627 } 628 629 cfg.body = []byte(bodyStr) 630 631 err = tr.GetMaybe(gadgetName, "registration.proposed-serial", &cfg.proposedSerial) 632 if err != nil { 633 return nil, err 634 } 635 } 636 637 if proxyURL != nil && svcURL != nil && !newEnoughProxy(st, proxyURL, client) { 638 logger.Noticef("Proxy store does not support custom serial vault; ignoring the proxy") 639 proxyURL = nil 640 } 641 642 cfg.setURLs(proxyURL, svcURL) 643 644 return &cfg, nil 645 } 646 647 func (m *DeviceManager) doRequestSerial(t *state.Task, _ *tomb.Tomb) error { 648 st := t.State() 649 st.Lock() 650 defer st.Unlock() 651 652 perfTimings := state.TimingsForTask(t) 653 defer perfTimings.Save(st) 654 655 regCtx, err := m.registrationCtx(t) 656 if err != nil { 657 return err 658 } 659 660 device, err := regCtx.Device() 661 if err != nil { 662 return err 663 } 664 665 // NB: the keyPair is fixed for now 666 privKey, err := m.keyPair() 667 if err == state.ErrNoState { 668 return fmt.Errorf("internal error: cannot find device key pair") 669 } 670 if err != nil { 671 return err 672 } 673 674 // make this idempotent, look if we have already a serial assertion 675 // for privKey 676 serials, err := assertstate.DB(st).FindMany(asserts.SerialType, map[string]string{ 677 "brand-id": device.Brand, 678 "model": device.Model, 679 "device-key-sha3-384": privKey.PublicKey().ID(), 680 }) 681 if err != nil && !asserts.IsNotFound(err) { 682 return err 683 } 684 685 finish := func(serial *asserts.Serial) error { 686 if regCtx.FinishRegistration(serial); err != nil { 687 return err 688 } 689 t.SetStatus(state.DoneStatus) 690 return nil 691 } 692 693 if len(serials) == 1 { 694 // means we saved the assertion but didn't get to the end of the task 695 return finish(serials[0].(*asserts.Serial)) 696 } 697 if len(serials) > 1 { 698 return fmt.Errorf("internal error: multiple serial assertions for the same device key") 699 } 700 701 var serial *asserts.Serial 702 var ancillaryBatch *asserts.Batch 703 timings.Run(perfTimings, "get-serial", "get device serial", func(tm timings.Measurer) { 704 serial, ancillaryBatch, err = getSerial(t, regCtx, privKey, device, tm) 705 }) 706 if err == errPoll { 707 t.Logf("Will poll for device serial assertion in 60 seconds") 708 return &state.Retry{After: retryInterval} 709 } 710 if err != nil { // errors & retries 711 return err 712 713 } 714 715 // TODO: the accept* helpers put the serial directly in the 716 // system assertion database, that will not work 717 // for 3rd-party signed serials in the case of a remodel 718 // because the model is added only later. If needed, the best way 719 // to fix this requires rethinking how remodel and new assertions 720 // interact 721 if ancillaryBatch == nil { 722 // the device service returned only the serial 723 if err := acceptSerialOnly(t, serial, perfTimings); err != nil { 724 return err 725 } 726 } else { 727 // the device service returned a stream of assertions 728 timings.Run(perfTimings, "fetch-keys", "fetch signing key chain", func(timings.Measurer) { 729 err = acceptSerialPlusBatch(t, serial, ancillaryBatch) 730 }) 731 if err != nil { 732 t.Errorf("cannot accept stream of assertions from device service: %v", err) 733 return err 734 } 735 } 736 737 if repeatRequestSerial == "after-add-serial" { 738 // For testing purposes, ensure a crash in this state works. 739 return &state.Retry{} 740 } 741 742 return finish(serial) 743 } 744 745 func acceptSerialOnly(t *state.Task, serial *asserts.Serial, perfTimings *timings.Timings) error { 746 st := t.State() 747 var err error 748 var errAcctKey error 749 // try to fetch the signing key chain of the serial 750 timings.Run(perfTimings, "fetch-keys", "fetch signing key chain", func(timings.Measurer) { 751 errAcctKey, err = fetchKeys(st, serial.SignKeyID()) 752 }) 753 if err != nil { 754 return err 755 } 756 757 // add the serial assertion to the system assertion db 758 err = assertstate.Add(st, serial) 759 if err != nil { 760 // if we had failed to fetch the signing key, retry in a bit 761 if errAcctKey != nil { 762 t.Errorf("cannot fetch signing key for the serial: %v", errAcctKey) 763 return &state.Retry{After: retryInterval} 764 } 765 return err 766 } 767 768 return nil 769 } 770 771 func acceptSerialPlusBatch(t *state.Task, serial *asserts.Serial, batch *asserts.Batch) error { 772 st := t.State() 773 err := batch.Add(serial) 774 if err != nil { 775 return err 776 } 777 return assertstate.AddBatch(st, batch, &asserts.CommitOptions{Precheck: true}) 778 } 779 780 var repeatRequestSerial string // for tests 781 782 func fetchKeys(st *state.State, keyID string) (errAcctKey error, err error) { 783 // TODO: right now any store should be good enough here but 784 // that might change. As an alternative we do support 785 // receiving a stream with any relevant assertions. 786 sto := snapstate.Store(st, nil) 787 db := assertstate.DB(st) 788 789 retrieveError := false 790 retrieve := func(ref *asserts.Ref) (asserts.Assertion, error) { 791 st.Unlock() 792 defer st.Lock() 793 a, err := sto.Assertion(ref.Type, ref.PrimaryKey, nil) 794 retrieveError = err != nil 795 return a, err 796 } 797 798 save := func(a asserts.Assertion) error { 799 err = assertstate.Add(st, a) 800 if err != nil && !asserts.IsUnaccceptedUpdate(err) { 801 return err 802 } 803 return nil 804 } 805 806 f := asserts.NewFetcher(db, retrieve, save) 807 808 keyRef := &asserts.Ref{ 809 Type: asserts.AccountKeyType, 810 PrimaryKey: []string{keyID}, 811 } 812 if err := f.Fetch(keyRef); err != nil { 813 if retrieveError { 814 return err, nil 815 } else { 816 return nil, err 817 } 818 } 819 return nil, nil 820 }