github.com/meulengracht/snapd@v0.0.0-20210719210640-8bde69bcc84e/overlord/devicestate/handlers_serial.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  /*
     3   * Copyright (C) 2016-2020 Canonical Ltd
     4   *
     5   * This program is free software: you can redistribute it and/or modify
     6   * it under the terms of the GNU General Public License version 3 as
     7   * published by the Free Software Foundation.
     8   *
     9   * This program is distributed in the hope that it will be useful,
    10   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    11   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    12   * GNU General Public License for more details.
    13   *
    14   * You should have received a copy of the GNU General Public License
    15   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    16   *
    17   */
    18  
    19  package devicestate
    20  
import (
	"bytes"
	"crypto/rsa"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"mime"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"time"

	"gopkg.in/tomb.v2"

	"github.com/snapcore/snapd/asserts"
	"github.com/snapcore/snapd/httputil"
	"github.com/snapcore/snapd/logger"
	"github.com/snapcore/snapd/overlord/assertstate"
	"github.com/snapcore/snapd/overlord/auth"
	"github.com/snapcore/snapd/overlord/configstate/config"
	"github.com/snapcore/snapd/overlord/configstate/proxyconf"
	"github.com/snapcore/snapd/overlord/snapstate"
	"github.com/snapcore/snapd/overlord/state"
	"github.com/snapcore/snapd/snapdenv"
	"github.com/snapcore/snapd/strutil"
	"github.com/snapcore/snapd/timings"
)
    49  
    50  func baseURL() *url.URL {
    51  	if snapdenv.UseStagingStore() {
    52  		return mustParse("https://api.staging.snapcraft.io/")
    53  	}
    54  	return mustParse("https://api.snapcraft.io/")
    55  }
    56  
    57  func mustParse(s string) *url.URL {
    58  	u, err := url.Parse(s)
    59  	if err != nil {
    60  		panic(err)
    61  	}
    62  	return u
    63  }
    64  
    65  var (
    66  	keyLength     = 4096
    67  	retryInterval = 60 * time.Second
    68  	maxTentatives = 15
    69  	baseStoreURL  = baseURL().ResolveReference(authRef)
    70  
    71  	authRef    = mustParse("api/v1/snaps/auth/") // authRef must end in / for the following refs to work
    72  	reqIdRef   = mustParse("request-id")
    73  	serialRef  = mustParse("serial")
    74  	devicesRef = mustParse("devices")
    75  
    76  	// we accept a stream with the serial assertion as well
    77  	registrationCapabilities = []string{"serial-stream"}
    78  )
    79  
    80  func (m *DeviceManager) doGenerateDeviceKey(t *state.Task, _ *tomb.Tomb) error {
    81  	st := t.State()
    82  	st.Lock()
    83  	defer st.Unlock()
    84  
    85  	perfTimings := state.TimingsForTask(t)
    86  	defer perfTimings.Save(st)
    87  
    88  	device, err := m.device()
    89  	if err != nil {
    90  		return err
    91  	}
    92  
    93  	if device.KeyID != "" {
    94  		// nothing to do
    95  		return nil
    96  	}
    97  
    98  	st.Unlock()
    99  	var keyPair *rsa.PrivateKey
   100  	timings.Run(perfTimings, "generate-rsa-key", "generating device key pair", func(tm timings.Measurer) {
   101  		keyPair, err = generateRSAKey(keyLength)
   102  	})
   103  	st.Lock()
   104  	if err != nil {
   105  		return fmt.Errorf("cannot generate device key pair: %v", err)
   106  	}
   107  
   108  	privKey := asserts.RSAPrivateKey(keyPair)
   109  	err = m.withKeypairMgr(func(keypairMgr asserts.KeypairManager) error {
   110  		return keypairMgr.Put(privKey)
   111  	})
   112  	if err != nil {
   113  		return fmt.Errorf("cannot store device key pair: %v", err)
   114  	}
   115  
   116  	device.KeyID = privKey.PublicKey().ID()
   117  	err = m.setDevice(device)
   118  	if err != nil {
   119  		return err
   120  	}
   121  	t.SetStatus(state.DoneStatus)
   122  	return nil
   123  }
   124  
   125  func newEnoughProxy(st *state.State, proxyURL *url.URL, client *http.Client) bool {
   126  	st.Unlock()
   127  	defer st.Lock()
   128  
   129  	const prefix = "Cannot check whether proxy store supports a custom serial vault"
   130  
   131  	req, err := http.NewRequest("HEAD", proxyURL.String(), nil)
   132  	if err != nil {
   133  		// can't really happen unless proxyURL is somehow broken
   134  		logger.Debugf(prefix+": %v", err)
   135  		return false
   136  	}
   137  	req.Header.Set("User-Agent", snapdenv.UserAgent())
   138  	resp, err := client.Do(req)
   139  	if err != nil {
   140  		// some sort of network or protocol error
   141  		logger.Debugf(prefix+": %v", err)
   142  		return false
   143  	}
   144  	resp.Body.Close()
   145  	if resp.StatusCode != 200 {
   146  		logger.Debugf(prefix+": Head request returned %s.", resp.Status)
   147  		return false
   148  	}
   149  	verstr := resp.Header.Get("Snap-Store-Version")
   150  	ver, err := strconv.Atoi(verstr)
   151  	if err != nil {
   152  		logger.Debugf(prefix+": Bogus Snap-Store-Version header %q.", verstr)
   153  		return false
   154  	}
   155  	return ver >= 6
   156  }
   157  
   158  func (cfg *serialRequestConfig) setURLs(proxyURL, svcURL *url.URL) {
   159  	base := baseStoreURL
   160  	if proxyURL != nil {
   161  		if svcURL != nil {
   162  			if cfg.headers == nil {
   163  				cfg.headers = make(map[string]string, 1)
   164  			}
   165  			cfg.headers["X-Snap-Device-Service-URL"] = svcURL.String()
   166  		}
   167  		base = proxyURL.ResolveReference(authRef)
   168  	} else if svcURL != nil {
   169  		base = svcURL
   170  	}
   171  
   172  	cfg.requestIDURL = base.ResolveReference(reqIdRef).String()
   173  	if svcURL != nil && proxyURL == nil {
   174  		// talking directly to the custom device service
   175  		cfg.serialRequestURL = base.ResolveReference(serialRef).String()
   176  	} else {
   177  		cfg.serialRequestURL = base.ResolveReference(devicesRef).String()
   178  	}
   179  }
   180  
   181  // A registrationContext handles the contextual information needed
   182  // for the initial registration or a re-registration.
   183  type registrationContext interface {
   184  	Device() (*auth.DeviceState, error)
   185  
   186  	Model() *asserts.Model
   187  
   188  	GadgetForSerialRequestConfig() string
   189  	SerialRequestExtraHeaders() map[string]interface{}
   190  	SerialRequestAncillaryAssertions() []asserts.Assertion
   191  
   192  	FinishRegistration(serial *asserts.Serial) error
   193  
   194  	ForRemodeling() bool
   195  }
   196  
   197  // initialRegistrationContext is a thin wrapper around DeviceManager
   198  // implementing registrationContext for initial regitration
   199  type initialRegistrationContext struct {
   200  	deviceMgr *DeviceManager
   201  
   202  	model *asserts.Model
   203  }
   204  
   205  func (rc *initialRegistrationContext) ForRemodeling() bool {
   206  	return false
   207  }
   208  
   209  func (rc *initialRegistrationContext) Device() (*auth.DeviceState, error) {
   210  	return rc.deviceMgr.device()
   211  }
   212  
   213  func (rc *initialRegistrationContext) Model() *asserts.Model {
   214  	return rc.model
   215  }
   216  
   217  func (rc *initialRegistrationContext) GadgetForSerialRequestConfig() string {
   218  	return rc.model.Gadget()
   219  }
   220  
   221  func (rc *initialRegistrationContext) SerialRequestExtraHeaders() map[string]interface{} {
   222  	return nil
   223  }
   224  
   225  func (rc *initialRegistrationContext) SerialRequestAncillaryAssertions() []asserts.Assertion {
   226  	return []asserts.Assertion{rc.model}
   227  }
   228  
   229  func (rc *initialRegistrationContext) FinishRegistration(serial *asserts.Serial) error {
   230  	device, err := rc.deviceMgr.device()
   231  	if err != nil {
   232  		return err
   233  	}
   234  
   235  	device.Serial = serial.Serial()
   236  	if err := rc.deviceMgr.setDevice(device); err != nil {
   237  		return err
   238  	}
   239  	rc.deviceMgr.markRegistered()
   240  
   241  	// make sure we timely consider anything that was blocked on
   242  	// registration
   243  	rc.deviceMgr.state.EnsureBefore(0)
   244  
   245  	return nil
   246  }
   247  
   248  // registrationCtx returns a registrationContext appropriate for the task and its change.
   249  func (m *DeviceManager) registrationCtx(t *state.Task) (registrationContext, error) {
   250  	remodCtx, err := remodelCtxFromTask(t)
   251  	if err != nil && err != state.ErrNoState {
   252  		return nil, err
   253  	}
   254  	if regCtx, ok := remodCtx.(registrationContext); ok {
   255  		return regCtx, nil
   256  	}
   257  	model, err := m.Model()
   258  	if err != nil {
   259  		return nil, err
   260  	}
   261  
   262  	return &initialRegistrationContext{
   263  		deviceMgr: m,
   264  		model:     model,
   265  	}, nil
   266  }
   267  
   268  type serialSetup struct {
   269  	SerialRequest string `json:"serial-request"`
   270  	Serial        string `json:"serial"`
   271  }
   272  
   273  type requestIDResp struct {
   274  	RequestID string `json:"request-id"`
   275  }
   276  
   277  func retryErr(t *state.Task, nTentatives int, reason string, a ...interface{}) error {
   278  	t.State().Lock()
   279  	defer t.State().Unlock()
   280  	if nTentatives >= maxTentatives {
   281  		return fmt.Errorf(reason, a...)
   282  	}
   283  	t.Errorf(reason, a...)
   284  	return &state.Retry{After: retryInterval}
   285  }
   286  
   287  type serverError struct {
   288  	Message string         `json:"message"`
   289  	Errors  []*serverError `json:"error_list"`
   290  }
   291  
   292  func retryBadStatus(t *state.Task, nTentatives int, reason string, resp *http.Response) error {
   293  	if resp.StatusCode > 500 {
   294  		// likely temporary
   295  		return retryErr(t, nTentatives, "%s: unexpected status %d", reason, resp.StatusCode)
   296  	}
   297  	if resp.Header.Get("Content-Type") == "application/json" {
   298  		var srvErr serverError
   299  		dec := json.NewDecoder(resp.Body)
   300  		err := dec.Decode(&srvErr)
   301  		if err == nil {
   302  			msg := srvErr.Message
   303  			if msg == "" && len(srvErr.Errors) > 0 {
   304  				msg = srvErr.Errors[0].Message
   305  			}
   306  			if msg != "" {
   307  				return fmt.Errorf("%s: %s", reason, msg)
   308  			}
   309  		}
   310  	}
   311  	return fmt.Errorf("%s: unexpected status %d", reason, resp.StatusCode)
   312  }
   313  
// prepareSerialRequest builds the encoded serial-request to submit.
//
// It first counts the attempt in the task's "pre-poll-tentatives" value.
// Then, with the state lock released, it POSTs to cfg.requestIDURL to get
// a fresh request-id. It signs a serial-request assertion with privKey,
// using cfg.body as body. The headers are brand-id, model, request-id and
// the encoded device public key, plus "serial" when cfg.proposedSerial is
// set. regCtx.SerialRequestExtraHeaders() are merged in last and win on
// key collisions. The returned string is the encoded serial-request
// followed by the encoded regCtx.SerialRequestAncillaryAssertions().
//
// Retryable network failures, bad statuses and unreadable replies are
// reported via retryErr/retryBadStatus, bounded by maxTentatives. A missing
// network resets the counter and asks for a retry after half of
// retryInterval. The state lock is expected to be held on entry (it is
// unlocked here) and is held again on return.
func prepareSerialRequest(t *state.Task, regCtx registrationContext, privKey asserts.PrivateKey, device *auth.DeviceState, client *http.Client, cfg *serialRequestConfig) (string, error) {
	// limit tentatives starting from scratch before going to
	// slower full retries
	var nTentatives int
	err := t.Get("pre-poll-tentatives", &nTentatives)
	if err != nil && err != state.ErrNoState {
		return "", err
	}
	nTentatives++
	t.Set("pre-poll-tentatives", nTentatives)

	// do the network round-trip without holding the state lock
	st := t.State()
	st.Unlock()
	defer st.Lock()

	req, err := http.NewRequest("POST", cfg.requestIDURL, nil)
	if err != nil {
		return "", fmt.Errorf("internal error: cannot create request-id request %q", cfg.requestIDURL)
	}
	req.Header.Set("User-Agent", snapdenv.UserAgent())
	cfg.applyHeaders(req)

	resp, err := client.Do(req)
	if err != nil {
		if !httputil.ShouldRetryError(err) {
			// a non temporary net error fully errors out the task
			// (a plain error, not a state.Retry), leaving it to the
			// slower full retries mentioned above
			return "", fmt.Errorf("cannot retrieve request-id for making a request for a serial: %v", err)
		}
		if httputil.NoNetwork(err) {
			// If there is no network there is no need to count
			// this as a tentative attempt. Otherwise, after trying
			// a bunch of times with no network, a real server reply
			// asking us to retry would not be honoured because
			// nTentatives would already be way over the limit.
			st.Lock()
			t.Set("pre-poll-tentatives", 0)
			st.Unlock()
			// Retry quickly if there is no network
			// (yet). This ensures that we try to get a serial
			// as soon as the user configured the network of the
			// device
			noNetworkRetryInterval := retryInterval / 2
			return "", &state.Retry{After: noNetworkRetryInterval}
		}

		return "", retryErr(t, nTentatives, "cannot retrieve request-id for making a request for a serial: %v", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != 200 {
		return "", retryBadStatus(t, nTentatives, "cannot retrieve request-id for making a request for a serial", resp)
	}

	dec := json.NewDecoder(resp.Body)
	var requestID requestIDResp
	err = dec.Decode(&requestID)
	if err != nil { // assume broken i/o
		return "", retryErr(t, nTentatives, "cannot read response with request-id for making a request for a serial: %v", err)
	}

	encodedPubKey, err := asserts.EncodePublicKey(privKey.PublicKey())
	if err != nil {
		return "", fmt.Errorf("internal error: cannot encode device public key: %v", err)

	}

	headers := map[string]interface{}{
		"brand-id":   device.Brand,
		"model":      device.Model,
		"request-id": requestID.RequestID,
		"device-key": string(encodedPubKey),
	}
	if cfg.proposedSerial != "" {
		headers["serial"] = cfg.proposedSerial
	}

	// extra headers are applied last, so they override the ones above
	for k, v := range regCtx.SerialRequestExtraHeaders() {
		headers[k] = v
	}

	serialReq, err := asserts.SignWithoutAuthority(asserts.SerialRequestType, headers, cfg.body, privKey)
	if err != nil {
		return "", err
	}

	// the result is a stream: the serial-request first, then any
	// ancillary assertions
	buf := new(bytes.Buffer)
	encoder := asserts.NewEncoder(buf)
	if err := encoder.Encode(serialReq); err != nil {
		return "", fmt.Errorf("cannot encode serial-request: %v", err)
	}

	for _, ancillaryAs := range regCtx.SerialRequestAncillaryAssertions() {
		if err := encoder.Encode(ancillaryAs); err != nil {
			return "", fmt.Errorf("cannot encode ancillary assertion: %v", err)
		}

	}

	return buf.String(), nil
}
   416  
   417  var errPoll = errors.New("serial-request accepted, poll later")
   418  
   419  func submitSerialRequest(t *state.Task, serialRequest string, client *http.Client, cfg *serialRequestConfig) (*asserts.Serial, *asserts.Batch, error) {
   420  	st := t.State()
   421  	st.Unlock()
   422  	defer st.Lock()
   423  
   424  	req, err := http.NewRequest("POST", cfg.serialRequestURL, bytes.NewBufferString(serialRequest))
   425  	if err != nil {
   426  		return nil, nil, fmt.Errorf("internal error: cannot create serial-request request %q", cfg.serialRequestURL)
   427  	}
   428  	req.Header.Set("User-Agent", snapdenv.UserAgent())
   429  	req.Header.Set("Snap-Device-Capabilities", strings.Join(registrationCapabilities, " "))
   430  	cfg.applyHeaders(req)
   431  	req.Header.Set("Content-Type", asserts.MediaType)
   432  
   433  	resp, err := client.Do(req)
   434  	if err != nil {
   435  		return nil, nil, retryErr(t, 0, "cannot deliver device serial request: %v", err)
   436  	}
   437  	defer resp.Body.Close()
   438  
   439  	switch resp.StatusCode {
   440  	case 200, 201:
   441  	case 202:
   442  		return nil, nil, errPoll
   443  	default:
   444  		return nil, nil, retryBadStatus(t, 0, "cannot deliver device serial request", resp)
   445  	}
   446  
   447  	var serial *asserts.Serial
   448  	var batch *asserts.Batch
   449  	// decode body with stream of assertions, of which one is the serial
   450  	dec := asserts.NewDecoder(resp.Body)
   451  	for {
   452  		got, err := dec.Decode()
   453  		if err == io.EOF {
   454  			break
   455  		}
   456  		if err != nil { // assume broken i/o
   457  			return nil, nil, retryErr(t, 0, "cannot read response to request for a serial: %v", err)
   458  		}
   459  		if got.Type() == asserts.SerialType {
   460  			if serial != nil {
   461  				return nil, nil, fmt.Errorf("cannot accept more than a single device serial assertion from the device service")
   462  			}
   463  			serial = got.(*asserts.Serial)
   464  		} else {
   465  			if batch == nil {
   466  				batch = asserts.NewBatch(nil)
   467  			}
   468  			if err := batch.Add(got); err != nil {
   469  				return nil, nil, err
   470  			}
   471  		}
   472  		// TODO: consider a size limit?
   473  	}
   474  
   475  	if serial == nil {
   476  		return nil, nil, fmt.Errorf("cannot proceed, received assertion stream from the device service missing device serial assertion")
   477  	}
   478  
   479  	return serial, batch, nil
   480  }
   481  
   482  var httputilNewHTTPClient = httputil.NewHTTPClient
   483  
// getSerial obtains a serial assertion for the device. It resumes from
// whatever progress was saved on the task under "serial-setup":
//
//   - a serial saved earlier is decoded and returned, with a nil batch;
//   - a serial-request saved earlier (after a 202) is resubmitted as is;
//   - otherwise a fresh serial-request is prepared first.
//
// It returns errPoll, after saving the serial-request, when the service
// accepted the request but has no serial yet. An obtained serial must
// match the device brand, model and key. If it is signed by an authority
// other than the brand, that authority must be listed in the model's
// serial-authority. A serial received without ancillary assertions is
// saved in "serial-setup" so a later attempt can reuse it. The state lock
// is expected to be held; the helpers called here release it around
// network I/O.
func getSerial(t *state.Task, regCtx registrationContext, privKey asserts.PrivateKey, device *auth.DeviceState, tm timings.Measurer) (serial *asserts.Serial, ancillaryBatch *asserts.Batch, err error) {
	var serialSup serialSetup
	err = t.Get("serial-setup", &serialSup)
	if err != nil && err != state.ErrNoState {
		return nil, nil, err
	}

	if serialSup.Serial != "" {
		// we got a serial, just haven't managed to save its info yet
		a, err := asserts.Decode([]byte(serialSup.Serial))
		if err != nil {
			return nil, nil, fmt.Errorf("internal error: cannot decode previously saved serial: %v", err)
		}
		return a.(*asserts.Serial), nil, nil
	}

	st := t.State()
	proxyConf := proxyconf.New(st)
	client := httputilNewHTTPClient(&httputil.ClientOptions{
		Timeout:            30 * time.Second,
		MayLogBody:         true,
		Proxy:              proxyConf.Conf,
		ProxyConnectHeader: http.Header{"User-Agent": []string{snapdenv.UserAgent()}},
	})

	cfg, err := getSerialRequestConfig(t, regCtx, client)
	if err != nil {
		return nil, nil, err
	}

	// NB: until we get at least an Accepted (202) we need to
	// retry from scratch creating a new request-id because the
	// previous one used could have expired

	if serialSup.SerialRequest == "" {
		var serialRequest string
		var err error
		timings.Run(tm, "prepare-serial-request", "prepare device serial request", func(timings.Measurer) {
			serialRequest, err = prepareSerialRequest(t, regCtx, privKey, device, client, cfg)
		})
		if err != nil { // errors & retries
			return nil, nil, err
		}

		serialSup.SerialRequest = serialRequest
	}

	timings.Run(tm, "submit-serial-request", "submit device serial request", func(timings.Measurer) {
		serial, ancillaryBatch, err = submitSerialRequest(t, serialSup.SerialRequest, client, cfg)
	})
	if err == errPoll {
		// we can/should reuse the serial-request
		t.Set("serial-setup", serialSup)
		return nil, nil, errPoll
	}
	if err != nil { // errors & retries
		return nil, nil, err
	}

	// the serial must be for this exact device identity
	keyID := privKey.PublicKey().ID()
	if serial.BrandID() != device.Brand || serial.Model() != device.Model || serial.DeviceKey().ID() != keyID {
		return nil, nil, fmt.Errorf("obtained serial assertion does not match provided device identity information (brand, model, key id): %s / %s / %s != %s / %s / %s", serial.BrandID(), serial.Model(), serial.DeviceKey().ID(), device.Brand, device.Model, keyID)
	}

	// cross check authority if different from brand-id
	// NOTE(review): the error message below contains a doubled "to to";
	// left untouched since callers/tests may match the exact text.
	if serial.BrandID() != serial.AuthorityID() {
		model := regCtx.Model()
		if !strutil.ListContains(model.SerialAuthority(), serial.AuthorityID()) {
			return nil, nil, fmt.Errorf("obtained serial assertion is signed by authority %q different from brand %q without model assertion with serial-authority set to to allow for them", serial.AuthorityID(), serial.BrandID())
		}
	}

	// only a lone serial is remembered for reuse; a serial that came with
	// ancillary assertions is not saved here
	if ancillaryBatch == nil {
		serialSup.Serial = string(asserts.Encode(serial))
		t.Set("serial-setup", serialSup)
	}

	if repeatRequestSerial == "after-got-serial" {
		// For testing purposes, ensure a crash in this state works.
		return nil, nil, &state.Retry{}
	}

	return serial, ancillaryBatch, nil
}
   568  
   569  type serialRequestConfig struct {
   570  	requestIDURL     string
   571  	serialRequestURL string
   572  	headers          map[string]string
   573  	proposedSerial   string
   574  	body             []byte
   575  }
   576  
   577  func (cfg *serialRequestConfig) applyHeaders(req *http.Request) {
   578  	for k, v := range cfg.headers {
   579  		req.Header.Set(k, v)
   580  	}
   581  }
   582  
   583  func getSerialRequestConfig(t *state.Task, regCtx registrationContext, client *http.Client) (*serialRequestConfig, error) {
   584  	var svcURL, proxyURL *url.URL
   585  
   586  	st := t.State()
   587  	tr := config.NewTransaction(st)
   588  	if proxyStore, err := proxyStore(st, tr); err != nil && err != state.ErrNoState {
   589  		return nil, err
   590  	} else if proxyStore != nil {
   591  		proxyURL = proxyStore.URL()
   592  	}
   593  
   594  	cfg := serialRequestConfig{}
   595  
   596  	gadgetName := regCtx.GadgetForSerialRequestConfig()
   597  	// gadget is optional on classic
   598  	if gadgetName != "" {
   599  		var gadgetSt snapstate.SnapState
   600  		if err := snapstate.Get(st, gadgetName, &gadgetSt); err != nil {
   601  			return nil, fmt.Errorf("cannot find gadget snap %q: %v", gadgetName, err)
   602  		}
   603  
   604  		var svcURI string
   605  		err := tr.GetMaybe(gadgetName, "device-service.url", &svcURI)
   606  		if err != nil {
   607  			return nil, err
   608  		}
   609  
   610  		if svcURI != "" {
   611  			svcURL, err = url.Parse(svcURI)
   612  			if err != nil {
   613  				return nil, fmt.Errorf("cannot parse device registration base URL %q: %v", svcURI, err)
   614  			}
   615  			if !strings.HasSuffix(svcURL.Path, "/") {
   616  				svcURL.Path += "/"
   617  			}
   618  		}
   619  
   620  		err = tr.GetMaybe(gadgetName, "device-service.headers", &cfg.headers)
   621  		if err != nil {
   622  			return nil, err
   623  		}
   624  
   625  		var bodyStr string
   626  		err = tr.GetMaybe(gadgetName, "registration.body", &bodyStr)
   627  		if err != nil {
   628  			return nil, err
   629  		}
   630  
   631  		cfg.body = []byte(bodyStr)
   632  
   633  		err = tr.GetMaybe(gadgetName, "registration.proposed-serial", &cfg.proposedSerial)
   634  		if err != nil {
   635  			return nil, err
   636  		}
   637  	}
   638  
   639  	if proxyURL != nil && svcURL != nil && !newEnoughProxy(st, proxyURL, client) {
   640  		logger.Noticef("Proxy store does not support custom serial vault; ignoring the proxy")
   641  		proxyURL = nil
   642  	}
   643  
   644  	cfg.setURLs(proxyURL, svcURL)
   645  
   646  	return &cfg, nil
   647  }
   648  
// doRequestSerial is the task handler that obtains and records the device
// serial assertion.
//
// It is idempotent. If the system assertion database already has a serial
// for the device's brand, model and current device key, it just finishes
// with that one. Otherwise it gets a serial through getSerial, asking to
// be retried after retryInterval while the service reports it is not
// ready. It then adds the serial, together with any accompanying
// assertions, to the system assertion database, and finishes.
//
// Finishing saves the model and the serial into the device save assertion
// database (skipped when that is not supported), then calls
// regCtx.FinishRegistration and marks the task done.
func (m *DeviceManager) doRequestSerial(t *state.Task, _ *tomb.Tomb) error {
	st := t.State()
	st.Lock()
	defer st.Unlock()

	perfTimings := state.TimingsForTask(t)
	defer perfTimings.Save(st)

	regCtx, err := m.registrationCtx(t)
	if err != nil {
		return err
	}

	device, err := regCtx.Device()
	if err != nil {
		return err
	}

	// NB: the keyPair is fixed for now
	privKey, err := m.keyPair()
	if err == state.ErrNoState {
		return fmt.Errorf("internal error: cannot find device key pair")
	}
	if err != nil {
		return err
	}

	// make this idempotent, look if we have already a serial assertion
	// for privKey
	serials, err := assertstate.DB(st).FindMany(asserts.SerialType, map[string]string{
		"brand-id":            device.Brand,
		"model":               device.Model,
		"device-key-sha3-384": privKey.PublicKey().ID(),
	})
	if err != nil && !asserts.IsNotFound(err) {
		return err
	}

	// finish completes the task once serial is in the system database
	finish := func(serial *asserts.Serial) error {
		// save serial if appropriate into the device save
		// assertion database
		err := m.withSaveAssertDB(func(savedb *asserts.Database) error {
			db := assertstate.DB(st)
			// assertions are resolved from the system database
			retrieve := func(ref *asserts.Ref) (asserts.Assertion, error) {
				return ref.Resolve(db.Find)
			}
			b := asserts.NewBatch(nil)
			err := b.Fetch(savedb, retrieve, func(f asserts.Fetcher) error {
				// save the associated model as well
				// as it might be required for cross-checks
				// of the serial
				if err := f.Save(regCtx.Model()); err != nil {
					return err
				}
				return f.Save(serial)
			})
			if err != nil {
				return err
			}
			return b.CommitTo(savedb, nil)
		})
		// lack of a device save assertion database is not an error
		if err != nil && err != errNoSaveSupport {
			return fmt.Errorf("cannot save serial to device save assertion database: %v", err)
		}

		if err := regCtx.FinishRegistration(serial); err != nil {
			return err
		}
		t.SetStatus(state.DoneStatus)
		return nil
	}

	if len(serials) == 1 {
		// means we saved the assertion but didn't get to the end of the task
		return finish(serials[0].(*asserts.Serial))
	}
	if len(serials) > 1 {
		return fmt.Errorf("internal error: multiple serial assertions for the same device key")
	}

	var serial *asserts.Serial
	var ancillaryBatch *asserts.Batch
	timings.Run(perfTimings, "get-serial", "get device serial", func(tm timings.Measurer) {
		serial, ancillaryBatch, err = getSerial(t, regCtx, privKey, device, tm)
	})
	if err == errPoll {
		// NOTE(review): the log message hard-codes 60 seconds while the
		// actual delay is retryInterval, which is a variable.
		t.Logf("Will poll for device serial assertion in 60 seconds")
		return &state.Retry{After: retryInterval}
	}
	if err != nil { // errors & retries
		return err

	}

	// TODO: the accept* helpers put the serial directly in the
	// system assertion database, that will not work
	// for 3rd-party signed serials in the case of a remodel
	// because the model is added only later. If needed, the best way
	// to fix this requires rethinking how remodel and new assertions
	// interact
	if ancillaryBatch == nil {
		// the device service returned only the serial
		if err := acceptSerialOnly(t, serial, perfTimings); err != nil {
			return err
		}
	} else {
		// the device service returned a stream of assertions
		timings.Run(perfTimings, "fetch-keys", "fetch signing key chain", func(timings.Measurer) {
			err = acceptSerialPlusBatch(t, serial, ancillaryBatch)
		})
		if err != nil {
			t.Errorf("cannot accept stream of assertions from device service: %v", err)
			return err
		}
	}

	if repeatRequestSerial == "after-add-serial" {
		// For testing purposes, ensure a crash in this state works.
		return &state.Retry{}
	}

	return finish(serial)
}
   772  
   773  func acceptSerialOnly(t *state.Task, serial *asserts.Serial, perfTimings *timings.Timings) error {
   774  	st := t.State()
   775  	var err error
   776  	var errAcctKey error
   777  	// try to fetch the signing key chain of the serial
   778  	timings.Run(perfTimings, "fetch-keys", "fetch signing key chain", func(timings.Measurer) {
   779  		errAcctKey, err = fetchKeys(st, serial.SignKeyID())
   780  	})
   781  	if err != nil {
   782  		return err
   783  	}
   784  
   785  	// add the serial assertion to the system assertion db
   786  	err = assertstate.Add(st, serial)
   787  	if err != nil {
   788  		// if we had failed to fetch the signing key, retry in a bit
   789  		if errAcctKey != nil {
   790  			t.Errorf("cannot fetch signing key for the serial: %v", errAcctKey)
   791  			return &state.Retry{After: retryInterval}
   792  		}
   793  		return err
   794  	}
   795  
   796  	return nil
   797  }
   798  
   799  func acceptSerialPlusBatch(t *state.Task, serial *asserts.Serial, batch *asserts.Batch) error {
   800  	st := t.State()
   801  	err := batch.Add(serial)
   802  	if err != nil {
   803  		return err
   804  	}
   805  	return assertstate.AddBatch(st, batch, &asserts.CommitOptions{Precheck: true})
   806  }
   807  
   808  var repeatRequestSerial string // for tests
   809  
   810  func fetchKeys(st *state.State, keyID string) (errAcctKey error, err error) {
   811  	// TODO: right now any store should be good enough here but
   812  	// that might change. As an alternative we do support
   813  	// receiving a stream with any relevant assertions.
   814  	sto := snapstate.Store(st, nil)
   815  	db := assertstate.DB(st)
   816  
   817  	retrieveError := false
   818  	retrieve := func(ref *asserts.Ref) (asserts.Assertion, error) {
   819  		st.Unlock()
   820  		defer st.Lock()
   821  		a, err := sto.Assertion(ref.Type, ref.PrimaryKey, nil)
   822  		retrieveError = err != nil
   823  		return a, err
   824  	}
   825  
   826  	save := func(a asserts.Assertion) error {
   827  		err = assertstate.Add(st, a)
   828  		if err != nil && !asserts.IsUnaccceptedUpdate(err) {
   829  			return err
   830  		}
   831  		return nil
   832  	}
   833  
   834  	f := asserts.NewFetcher(db, retrieve, save)
   835  
   836  	keyRef := &asserts.Ref{
   837  		Type:       asserts.AccountKeyType,
   838  		PrimaryKey: []string{keyID},
   839  	}
   840  	if err := f.Fetch(keyRef); err != nil {
   841  		if retrieveError {
   842  			return err, nil
   843  		} else {
   844  			return nil, err
   845  		}
   846  	}
   847  	return nil, nil
   848  }