github.com/hugh712/snapd@v0.0.0-20200910133618-1a99902bd583/overlord/devicestate/handlers_serial.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  /*
     3   * Copyright (C) 2016-2020 Canonical Ltd
     4   *
     5   * This program is free software: you can redistribute it and/or modify
     6   * it under the terms of the GNU General Public License version 3 as
     7   * published by the Free Software Foundation.
     8   *
     9   * This program is distributed in the hope that it will be useful,
    10   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    11   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    12   * GNU General Public License for more details.
    13   *
    14   * You should have received a copy of the GNU General Public License
    15   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    16   *
    17   */
    18  
    19  package devicestate
    20  
import (
	"bytes"
	"crypto/rsa"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"mime"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"time"

	"gopkg.in/tomb.v2"

	"github.com/snapcore/snapd/asserts"
	"github.com/snapcore/snapd/httputil"
	"github.com/snapcore/snapd/logger"
	"github.com/snapcore/snapd/overlord/assertstate"
	"github.com/snapcore/snapd/overlord/auth"
	"github.com/snapcore/snapd/overlord/configstate/config"
	"github.com/snapcore/snapd/overlord/configstate/proxyconf"
	"github.com/snapcore/snapd/overlord/snapstate"
	"github.com/snapcore/snapd/overlord/state"
	"github.com/snapcore/snapd/snapdenv"
	"github.com/snapcore/snapd/strutil"
	"github.com/snapcore/snapd/timings"
)
    49  
    50  func baseURL() *url.URL {
    51  	if snapdenv.UseStagingStore() {
    52  		return mustParse("https://api.staging.snapcraft.io/")
    53  	}
    54  	return mustParse("https://api.snapcraft.io/")
    55  }
    56  
    57  func mustParse(s string) *url.URL {
    58  	u, err := url.Parse(s)
    59  	if err != nil {
    60  		panic(err)
    61  	}
    62  	return u
    63  }
    64  
    65  var (
    66  	keyLength     = 4096
    67  	retryInterval = 60 * time.Second
    68  	maxTentatives = 15
    69  	baseStoreURL  = baseURL().ResolveReference(authRef)
    70  
    71  	authRef    = mustParse("api/v1/snaps/auth/") // authRef must end in / for the following refs to work
    72  	reqIdRef   = mustParse("request-id")
    73  	serialRef  = mustParse("serial")
    74  	devicesRef = mustParse("devices")
    75  
    76  	// we accept a stream with the serial assertion as well
    77  	registrationCapabilities = []string{"serial-stream"}
    78  )
    79  
    80  func (m *DeviceManager) doGenerateDeviceKey(t *state.Task, _ *tomb.Tomb) error {
    81  	st := t.State()
    82  	st.Lock()
    83  	defer st.Unlock()
    84  
    85  	perfTimings := state.TimingsForTask(t)
    86  	defer perfTimings.Save(st)
    87  
    88  	device, err := m.device()
    89  	if err != nil {
    90  		return err
    91  	}
    92  
    93  	if device.KeyID != "" {
    94  		// nothing to do
    95  		return nil
    96  	}
    97  
    98  	st.Unlock()
    99  	var keyPair *rsa.PrivateKey
   100  	timings.Run(perfTimings, "generate-rsa-key", "generating device key pair", func(tm timings.Measurer) {
   101  		keyPair, err = generateRSAKey(keyLength)
   102  	})
   103  	st.Lock()
   104  	if err != nil {
   105  		return fmt.Errorf("cannot generate device key pair: %v", err)
   106  	}
   107  
   108  	privKey := asserts.RSAPrivateKey(keyPair)
   109  	err = m.keypairMgr.Put(privKey)
   110  	if err != nil {
   111  		return fmt.Errorf("cannot store device key pair: %v", err)
   112  	}
   113  
   114  	device.KeyID = privKey.PublicKey().ID()
   115  	err = m.setDevice(device)
   116  	if err != nil {
   117  		return err
   118  	}
   119  	t.SetStatus(state.DoneStatus)
   120  	return nil
   121  }
   122  
   123  func newEnoughProxy(st *state.State, proxyURL *url.URL, client *http.Client) bool {
   124  	st.Unlock()
   125  	defer st.Lock()
   126  
   127  	const prefix = "Cannot check whether proxy store supports a custom serial vault"
   128  
   129  	req, err := http.NewRequest("HEAD", proxyURL.String(), nil)
   130  	if err != nil {
   131  		// can't really happen unless proxyURL is somehow broken
   132  		logger.Debugf(prefix+": %v", err)
   133  		return false
   134  	}
   135  	req.Header.Set("User-Agent", snapdenv.UserAgent())
   136  	resp, err := client.Do(req)
   137  	if err != nil {
   138  		// some sort of network or protocol error
   139  		logger.Debugf(prefix+": %v", err)
   140  		return false
   141  	}
   142  	resp.Body.Close()
   143  	if resp.StatusCode != 200 {
   144  		logger.Debugf(prefix+": Head request returned %s.", resp.Status)
   145  		return false
   146  	}
   147  	verstr := resp.Header.Get("Snap-Store-Version")
   148  	ver, err := strconv.Atoi(verstr)
   149  	if err != nil {
   150  		logger.Debugf(prefix+": Bogus Snap-Store-Version header %q.", verstr)
   151  		return false
   152  	}
   153  	return ver >= 6
   154  }
   155  
   156  func (cfg *serialRequestConfig) setURLs(proxyURL, svcURL *url.URL) {
   157  	base := baseStoreURL
   158  	if proxyURL != nil {
   159  		if svcURL != nil {
   160  			if cfg.headers == nil {
   161  				cfg.headers = make(map[string]string, 1)
   162  			}
   163  			cfg.headers["X-Snap-Device-Service-URL"] = svcURL.String()
   164  		}
   165  		base = proxyURL.ResolveReference(authRef)
   166  	} else if svcURL != nil {
   167  		base = svcURL
   168  	}
   169  
   170  	cfg.requestIDURL = base.ResolveReference(reqIdRef).String()
   171  	if svcURL != nil && proxyURL == nil {
   172  		// talking directly to the custom device service
   173  		cfg.serialRequestURL = base.ResolveReference(serialRef).String()
   174  	} else {
   175  		cfg.serialRequestURL = base.ResolveReference(devicesRef).String()
   176  	}
   177  }
   178  
// A registrationContext handles the contextual information needed
// for the initial registration or a re-registration.
type registrationContext interface {
	// Device returns the device state (brand, model, key id, ...) the
	// serial is requested for.
	Device() (*auth.DeviceState, error)

	// Model returns the model assertion relevant for this registration;
	// its serial-authority is consulted when the serial is not signed by
	// the brand.
	Model() *asserts.Model

	// GadgetForSerialRequestConfig returns the name of the gadget snap whose
	// configuration (device-service.*, registration.*) customizes the
	// serial request. An empty name means there is no gadget to consult.
	GadgetForSerialRequestConfig() string
	// SerialRequestExtraHeaders returns headers added to, and overriding,
	// the default headers of the serial-request assertion.
	SerialRequestExtraHeaders() map[string]interface{}
	// SerialRequestAncillaryAssertions returns the assertions encoded after
	// the serial-request in the request body.
	SerialRequestAncillaryAssertions() []asserts.Assertion

	// FinishRegistration records the obtained serial.
	FinishRegistration(serial *asserts.Serial) error

	// ForRemodeling reports whether this registration is part of a remodel.
	ForRemodeling() bool
}
   194  
   195  // initialRegistrationContext is a thin wrapper around DeviceManager
   196  // implementing registrationContext for initial regitration
   197  type initialRegistrationContext struct {
   198  	deviceMgr *DeviceManager
   199  
   200  	model *asserts.Model
   201  }
   202  
   203  func (rc *initialRegistrationContext) ForRemodeling() bool {
   204  	return false
   205  }
   206  
   207  func (rc *initialRegistrationContext) Device() (*auth.DeviceState, error) {
   208  	return rc.deviceMgr.device()
   209  }
   210  
   211  func (rc *initialRegistrationContext) Model() *asserts.Model {
   212  	return rc.model
   213  }
   214  
   215  func (rc *initialRegistrationContext) GadgetForSerialRequestConfig() string {
   216  	return rc.model.Gadget()
   217  }
   218  
   219  func (rc *initialRegistrationContext) SerialRequestExtraHeaders() map[string]interface{} {
   220  	return nil
   221  }
   222  
   223  func (rc *initialRegistrationContext) SerialRequestAncillaryAssertions() []asserts.Assertion {
   224  	return []asserts.Assertion{rc.model}
   225  }
   226  
   227  func (rc *initialRegistrationContext) FinishRegistration(serial *asserts.Serial) error {
   228  	device, err := rc.deviceMgr.device()
   229  	if err != nil {
   230  		return err
   231  	}
   232  
   233  	device.Serial = serial.Serial()
   234  	if err := rc.deviceMgr.setDevice(device); err != nil {
   235  		return err
   236  	}
   237  	rc.deviceMgr.markRegistered()
   238  
   239  	// make sure we timely consider anything that was blocked on
   240  	// registration
   241  	rc.deviceMgr.state.EnsureBefore(0)
   242  
   243  	return nil
   244  }
   245  
   246  // registrationCtx returns a registrationContext appropriate for the task and its change.
   247  func (m *DeviceManager) registrationCtx(t *state.Task) (registrationContext, error) {
   248  	remodCtx, err := remodelCtxFromTask(t)
   249  	if err != nil && err != state.ErrNoState {
   250  		return nil, err
   251  	}
   252  	if regCtx, ok := remodCtx.(registrationContext); ok {
   253  		return regCtx, nil
   254  	}
   255  	model, err := m.Model()
   256  	if err != nil {
   257  		return nil, err
   258  	}
   259  
   260  	return &initialRegistrationContext{
   261  		deviceMgr: m,
   262  		model:     model,
   263  	}, nil
   264  }
   265  
// serialSetup is saved on the task under "serial-setup" so that a retried
// task can resume the serial registration where it left off.
type serialSetup struct {
	// SerialRequest is the encoded serial-request (plus ancillary
	// assertions) the device service accepted (HTTP 202); it is reused
	// when polling for the serial.
	SerialRequest string `json:"serial-request"`
	// Serial is the encoded serial assertion already obtained but whose
	// info has not been fully saved yet.
	Serial string `json:"serial"`
}
   270  
// requestIDResp is the JSON response returned by the request-id endpoint.
type requestIDResp struct {
	RequestID string `json:"request-id"`
}
   274  
   275  func retryErr(t *state.Task, nTentatives int, reason string, a ...interface{}) error {
   276  	t.State().Lock()
   277  	defer t.State().Unlock()
   278  	if nTentatives >= maxTentatives {
   279  		return fmt.Errorf(reason, a...)
   280  	}
   281  	t.Errorf(reason, a...)
   282  	return &state.Retry{After: retryInterval}
   283  }
   284  
// serverError is the JSON error body the device service may send with a
// failed response; retryBadStatus extracts a human readable message from
// it. Errors is a nested list of further errors, each possibly with its
// own message.
type serverError struct {
	Message string         `json:"message"`
	Errors  []*serverError `json:"error_list"`
}
   289  
   290  func retryBadStatus(t *state.Task, nTentatives int, reason string, resp *http.Response) error {
   291  	if resp.StatusCode > 500 {
   292  		// likely temporary
   293  		return retryErr(t, nTentatives, "%s: unexpected status %d", reason, resp.StatusCode)
   294  	}
   295  	if resp.Header.Get("Content-Type") == "application/json" {
   296  		var srvErr serverError
   297  		dec := json.NewDecoder(resp.Body)
   298  		err := dec.Decode(&srvErr)
   299  		if err == nil {
   300  			msg := srvErr.Message
   301  			if msg == "" && len(srvErr.Errors) > 0 {
   302  				msg = srvErr.Errors[0].Message
   303  			}
   304  			if msg != "" {
   305  				return fmt.Errorf("%s: %s", reason, msg)
   306  			}
   307  		}
   308  	}
   309  	return fmt.Errorf("%s: unexpected status %d", reason, resp.StatusCode)
   310  }
   311  
   312  func prepareSerialRequest(t *state.Task, regCtx registrationContext, privKey asserts.PrivateKey, device *auth.DeviceState, client *http.Client, cfg *serialRequestConfig) (string, error) {
   313  	// limit tentatives starting from scratch before going to
   314  	// slower full retries
   315  	var nTentatives int
   316  	err := t.Get("pre-poll-tentatives", &nTentatives)
   317  	if err != nil && err != state.ErrNoState {
   318  		return "", err
   319  	}
   320  	nTentatives++
   321  	t.Set("pre-poll-tentatives", nTentatives)
   322  
   323  	st := t.State()
   324  	st.Unlock()
   325  	defer st.Lock()
   326  
   327  	req, err := http.NewRequest("POST", cfg.requestIDURL, nil)
   328  	if err != nil {
   329  		return "", fmt.Errorf("internal error: cannot create request-id request %q", cfg.requestIDURL)
   330  	}
   331  	req.Header.Set("User-Agent", snapdenv.UserAgent())
   332  	cfg.applyHeaders(req)
   333  
   334  	resp, err := client.Do(req)
   335  	if err != nil {
   336  		if !httputil.ShouldRetryError(err) {
   337  			// a non temporary net error fully errors out and triggers a retry
   338  			// retries
   339  			return "", fmt.Errorf("cannot retrieve request-id for making a request for a serial: %v", err)
   340  		}
   341  		if httputil.NoNetwork(err) {
   342  			// If there is no network there is no need to count
   343  			// this as a tentatives attempt. If we do it this
   344  			// way the risk is that we tried a bunch of times
   345  			// with no network and if we hit the server for real
   346  			// and it replies with something we need to retry
   347  			// we will not because nTentatives is way over the
   348  			// limit.
   349  			st.Lock()
   350  			t.Set("pre-poll-tentatives", 0)
   351  			st.Unlock()
   352  			// Retry quickly if there is no network
   353  			// (yet). This ensures that we try to get a serial
   354  			// as soon as the user configured the network of the
   355  			// device
   356  			noNetworkRetryInterval := retryInterval / 2
   357  			return "", &state.Retry{After: noNetworkRetryInterval}
   358  		}
   359  
   360  		return "", retryErr(t, nTentatives, "cannot retrieve request-id for making a request for a serial: %v", err)
   361  	}
   362  	defer resp.Body.Close()
   363  	if resp.StatusCode != 200 {
   364  		return "", retryBadStatus(t, nTentatives, "cannot retrieve request-id for making a request for a serial", resp)
   365  	}
   366  
   367  	dec := json.NewDecoder(resp.Body)
   368  	var requestID requestIDResp
   369  	err = dec.Decode(&requestID)
   370  	if err != nil { // assume broken i/o
   371  		return "", retryErr(t, nTentatives, "cannot read response with request-id for making a request for a serial: %v", err)
   372  	}
   373  
   374  	encodedPubKey, err := asserts.EncodePublicKey(privKey.PublicKey())
   375  	if err != nil {
   376  		return "", fmt.Errorf("internal error: cannot encode device public key: %v", err)
   377  
   378  	}
   379  
   380  	headers := map[string]interface{}{
   381  		"brand-id":   device.Brand,
   382  		"model":      device.Model,
   383  		"request-id": requestID.RequestID,
   384  		"device-key": string(encodedPubKey),
   385  	}
   386  	if cfg.proposedSerial != "" {
   387  		headers["serial"] = cfg.proposedSerial
   388  	}
   389  
   390  	for k, v := range regCtx.SerialRequestExtraHeaders() {
   391  		headers[k] = v
   392  	}
   393  
   394  	serialReq, err := asserts.SignWithoutAuthority(asserts.SerialRequestType, headers, cfg.body, privKey)
   395  	if err != nil {
   396  		return "", err
   397  	}
   398  
   399  	buf := new(bytes.Buffer)
   400  	encoder := asserts.NewEncoder(buf)
   401  	if err := encoder.Encode(serialReq); err != nil {
   402  		return "", fmt.Errorf("cannot encode serial-request: %v", err)
   403  	}
   404  
   405  	for _, ancillaryAs := range regCtx.SerialRequestAncillaryAssertions() {
   406  		if err := encoder.Encode(ancillaryAs); err != nil {
   407  			return "", fmt.Errorf("cannot encode ancillary assertion: %v", err)
   408  		}
   409  
   410  	}
   411  
   412  	return buf.String(), nil
   413  }
   414  
// errPoll is returned by submitSerialRequest when the device service
// accepted the serial-request (HTTP 202) without returning the serial yet;
// the serial-request is kept and the task polls again later.
var errPoll = errors.New("serial-request accepted, poll later")
   416  
   417  func submitSerialRequest(t *state.Task, serialRequest string, client *http.Client, cfg *serialRequestConfig) (*asserts.Serial, *asserts.Batch, error) {
   418  	st := t.State()
   419  	st.Unlock()
   420  	defer st.Lock()
   421  
   422  	req, err := http.NewRequest("POST", cfg.serialRequestURL, bytes.NewBufferString(serialRequest))
   423  	if err != nil {
   424  		return nil, nil, fmt.Errorf("internal error: cannot create serial-request request %q", cfg.serialRequestURL)
   425  	}
   426  	req.Header.Set("User-Agent", snapdenv.UserAgent())
   427  	req.Header.Set("Snap-Device-Capabilities", strings.Join(registrationCapabilities, " "))
   428  	cfg.applyHeaders(req)
   429  	req.Header.Set("Content-Type", asserts.MediaType)
   430  
   431  	resp, err := client.Do(req)
   432  	if err != nil {
   433  		return nil, nil, retryErr(t, 0, "cannot deliver device serial request: %v", err)
   434  	}
   435  	defer resp.Body.Close()
   436  
   437  	switch resp.StatusCode {
   438  	case 200, 201:
   439  	case 202:
   440  		return nil, nil, errPoll
   441  	default:
   442  		return nil, nil, retryBadStatus(t, 0, "cannot deliver device serial request", resp)
   443  	}
   444  
   445  	var serial *asserts.Serial
   446  	var batch *asserts.Batch
   447  	// decode body with stream of assertions, of which one is the serial
   448  	dec := asserts.NewDecoder(resp.Body)
   449  	for {
   450  		got, err := dec.Decode()
   451  		if err == io.EOF {
   452  			break
   453  		}
   454  		if err != nil { // assume broken i/o
   455  			return nil, nil, retryErr(t, 0, "cannot read response to request for a serial: %v", err)
   456  		}
   457  		if got.Type() == asserts.SerialType {
   458  			if serial != nil {
   459  				return nil, nil, fmt.Errorf("cannot accept more than a single device serial assertion from the device service")
   460  			}
   461  			serial = got.(*asserts.Serial)
   462  		} else {
   463  			if batch == nil {
   464  				batch = asserts.NewBatch(nil)
   465  			}
   466  			if err := batch.Add(got); err != nil {
   467  				return nil, nil, err
   468  			}
   469  		}
   470  		// TODO: consider a size limit?
   471  	}
   472  
   473  	if serial == nil {
   474  		return nil, nil, fmt.Errorf("cannot proceed, received assertion stream from the device service missing device serial assertion")
   475  	}
   476  
   477  	return serial, batch, nil
   478  }
   479  
// httputilNewHTTPClient creates the HTTP client used by getSerial to talk
// to the device service. It is a variable rather than a direct call,
// presumably so tests can substitute the client (NOTE(review): confirm).
var httputilNewHTTPClient = httputil.NewHTTPClient
   481  
// getSerial obtains a serial assertion for the device from the device
// service.
//
// Flow:
//   - A serial saved earlier in the task's "serial-setup" is reused
//     without any network traffic.
//   - Otherwise a serial-request is prepared (or an already accepted one is
//     reused) and submitted.
//   - errPoll is returned, after saving the serial-request, when the
//     service accepted it but has no serial yet.
//   - The obtained serial is checked against the device brand/model/key.
//     When signed by an authority other than the brand, that authority
//     must be listed in the model's serial-authority.
//
// The ancillary batch is nil when the service returned only the serial. It
// must be called with the state locked; the helpers release it during
// network interactions.
func getSerial(t *state.Task, regCtx registrationContext, privKey asserts.PrivateKey, device *auth.DeviceState, tm timings.Measurer) (serial *asserts.Serial, ancillaryBatch *asserts.Batch, err error) {
	var serialSup serialSetup
	err = t.Get("serial-setup", &serialSup)
	if err != nil && err != state.ErrNoState {
		return nil, nil, err
	}

	if serialSup.Serial != "" {
		// we got a serial, just haven't managed to save its info yet
		a, err := asserts.Decode([]byte(serialSup.Serial))
		if err != nil {
			return nil, nil, fmt.Errorf("internal error: cannot decode previously saved serial: %v", err)
		}
		return a.(*asserts.Serial), nil, nil
	}

	st := t.State()
	proxyConf := proxyconf.New(st)
	client := httputilNewHTTPClient(&httputil.ClientOptions{
		Timeout:            30 * time.Second,
		MayLogBody:         true,
		Proxy:              proxyConf.Conf,
		ProxyConnectHeader: http.Header{"User-Agent": []string{snapdenv.UserAgent()}},
	})

	cfg, err := getSerialRequestConfig(t, regCtx, client)
	if err != nil {
		return nil, nil, err
	}

	// NB: until we get at least an Accepted (202) we need to
	// retry from scratch creating a new request-id because the
	// previous one used could have expired

	if serialSup.SerialRequest == "" {
		var serialRequest string
		// this err deliberately shadows the named result; it is checked
		// right after the closure runs
		var err error
		timings.Run(tm, "prepare-serial-request", "prepare device serial request", func(timings.Measurer) {
			serialRequest, err = prepareSerialRequest(t, regCtx, privKey, device, client, cfg)
		})
		if err != nil { // errors & retries
			return nil, nil, err
		}

		serialSup.SerialRequest = serialRequest
	}

	// the closure assigns directly to the named results
	timings.Run(tm, "submit-serial-request", "submit device serial request", func(timings.Measurer) {
		serial, ancillaryBatch, err = submitSerialRequest(t, serialSup.SerialRequest, client, cfg)
	})
	if err == errPoll {
		// we can/should reuse the serial-request
		t.Set("serial-setup", serialSup)
		return nil, nil, errPoll
	}
	if err != nil { // errors & retries
		return nil, nil, err
	}

	keyID := privKey.PublicKey().ID()
	if serial.BrandID() != device.Brand || serial.Model() != device.Model || serial.DeviceKey().ID() != keyID {
		return nil, nil, fmt.Errorf("obtained serial assertion does not match provided device identity information (brand, model, key id): %s / %s / %s != %s / %s / %s", serial.BrandID(), serial.Model(), serial.DeviceKey().ID(), device.Brand, device.Model, keyID)
	}

	// cross check authority if different from brand-id
	if serial.BrandID() != serial.AuthorityID() {
		model := regCtx.Model()
		if !strutil.ListContains(model.SerialAuthority(), serial.AuthorityID()) {
			return nil, nil, fmt.Errorf("obtained serial assertion is signed by authority %q different from brand %q without model assertion with serial-authority set to to allow for them", serial.AuthorityID(), serial.BrandID())
		}
	}

	// Only remember the serial when it came alone. The early-return path
	// above restores just the serial (with a nil batch), presumably
	// because the ancillary batch is not persisted alongside it.
	if ancillaryBatch == nil {
		serialSup.Serial = string(asserts.Encode(serial))
		t.Set("serial-setup", serialSup)
	}

	if repeatRequestSerial == "after-got-serial" {
		// For testing purposes, ensure a crash in this state works.
		return nil, nil, &state.Retry{}
	}

	return serial, ancillaryBatch, nil
}
   566  
// serialRequestConfig holds the endpoints and the gadget-provided
// customizations used to request a device serial.
type serialRequestConfig struct {
	// requestIDURL is POSTed to obtain a fresh request-id.
	requestIDURL     string
	// serialRequestURL is where the signed serial-request is POSTed.
	serialRequestURL string
	// headers are extra HTTP headers set on both requests: gadget
	// "device-service.headers", plus X-Snap-Device-Service-URL when
	// going through a proxy store.
	headers          map[string]string
	// proposedSerial, if non-empty, becomes the "serial" header of the
	// serial-request (gadget "registration.proposed-serial").
	proposedSerial   string
	// body is the body of the serial-request assertion (gadget
	// "registration.body").
	body             []byte
}
   574  
   575  func (cfg *serialRequestConfig) applyHeaders(req *http.Request) {
   576  	for k, v := range cfg.headers {
   577  		req.Header.Set(k, v)
   578  	}
   579  }
   580  
   581  func getSerialRequestConfig(t *state.Task, regCtx registrationContext, client *http.Client) (*serialRequestConfig, error) {
   582  	var svcURL, proxyURL *url.URL
   583  
   584  	st := t.State()
   585  	tr := config.NewTransaction(st)
   586  	if proxyStore, err := proxyStore(st, tr); err != nil && err != state.ErrNoState {
   587  		return nil, err
   588  	} else if proxyStore != nil {
   589  		proxyURL = proxyStore.URL()
   590  	}
   591  
   592  	cfg := serialRequestConfig{}
   593  
   594  	gadgetName := regCtx.GadgetForSerialRequestConfig()
   595  	// gadget is optional on classic
   596  	if gadgetName != "" {
   597  		var gadgetSt snapstate.SnapState
   598  		if err := snapstate.Get(st, gadgetName, &gadgetSt); err != nil {
   599  			return nil, fmt.Errorf("cannot find gadget snap %q: %v", gadgetName, err)
   600  		}
   601  
   602  		var svcURI string
   603  		err := tr.GetMaybe(gadgetName, "device-service.url", &svcURI)
   604  		if err != nil {
   605  			return nil, err
   606  		}
   607  
   608  		if svcURI != "" {
   609  			svcURL, err = url.Parse(svcURI)
   610  			if err != nil {
   611  				return nil, fmt.Errorf("cannot parse device registration base URL %q: %v", svcURI, err)
   612  			}
   613  			if !strings.HasSuffix(svcURL.Path, "/") {
   614  				svcURL.Path += "/"
   615  			}
   616  		}
   617  
   618  		err = tr.GetMaybe(gadgetName, "device-service.headers", &cfg.headers)
   619  		if err != nil {
   620  			return nil, err
   621  		}
   622  
   623  		var bodyStr string
   624  		err = tr.GetMaybe(gadgetName, "registration.body", &bodyStr)
   625  		if err != nil {
   626  			return nil, err
   627  		}
   628  
   629  		cfg.body = []byte(bodyStr)
   630  
   631  		err = tr.GetMaybe(gadgetName, "registration.proposed-serial", &cfg.proposedSerial)
   632  		if err != nil {
   633  			return nil, err
   634  		}
   635  	}
   636  
   637  	if proxyURL != nil && svcURL != nil && !newEnoughProxy(st, proxyURL, client) {
   638  		logger.Noticef("Proxy store does not support custom serial vault; ignoring the proxy")
   639  		proxyURL = nil
   640  	}
   641  
   642  	cfg.setURLs(proxyURL, svcURL)
   643  
   644  	return &cfg, nil
   645  }
   646  
   647  func (m *DeviceManager) doRequestSerial(t *state.Task, _ *tomb.Tomb) error {
   648  	st := t.State()
   649  	st.Lock()
   650  	defer st.Unlock()
   651  
   652  	perfTimings := state.TimingsForTask(t)
   653  	defer perfTimings.Save(st)
   654  
   655  	regCtx, err := m.registrationCtx(t)
   656  	if err != nil {
   657  		return err
   658  	}
   659  
   660  	device, err := regCtx.Device()
   661  	if err != nil {
   662  		return err
   663  	}
   664  
   665  	// NB: the keyPair is fixed for now
   666  	privKey, err := m.keyPair()
   667  	if err == state.ErrNoState {
   668  		return fmt.Errorf("internal error: cannot find device key pair")
   669  	}
   670  	if err != nil {
   671  		return err
   672  	}
   673  
   674  	// make this idempotent, look if we have already a serial assertion
   675  	// for privKey
   676  	serials, err := assertstate.DB(st).FindMany(asserts.SerialType, map[string]string{
   677  		"brand-id":            device.Brand,
   678  		"model":               device.Model,
   679  		"device-key-sha3-384": privKey.PublicKey().ID(),
   680  	})
   681  	if err != nil && !asserts.IsNotFound(err) {
   682  		return err
   683  	}
   684  
   685  	finish := func(serial *asserts.Serial) error {
   686  		if regCtx.FinishRegistration(serial); err != nil {
   687  			return err
   688  		}
   689  		t.SetStatus(state.DoneStatus)
   690  		return nil
   691  	}
   692  
   693  	if len(serials) == 1 {
   694  		// means we saved the assertion but didn't get to the end of the task
   695  		return finish(serials[0].(*asserts.Serial))
   696  	}
   697  	if len(serials) > 1 {
   698  		return fmt.Errorf("internal error: multiple serial assertions for the same device key")
   699  	}
   700  
   701  	var serial *asserts.Serial
   702  	var ancillaryBatch *asserts.Batch
   703  	timings.Run(perfTimings, "get-serial", "get device serial", func(tm timings.Measurer) {
   704  		serial, ancillaryBatch, err = getSerial(t, regCtx, privKey, device, tm)
   705  	})
   706  	if err == errPoll {
   707  		t.Logf("Will poll for device serial assertion in 60 seconds")
   708  		return &state.Retry{After: retryInterval}
   709  	}
   710  	if err != nil { // errors & retries
   711  		return err
   712  
   713  	}
   714  
   715  	// TODO: the accept* helpers put the serial directly in the
   716  	// system assertion database, that will not work
   717  	// for 3rd-party signed serials in the case of a remodel
   718  	// because the model is added only later. If needed, the best way
   719  	// to fix this requires rethinking how remodel and new assertions
   720  	// interact
   721  	if ancillaryBatch == nil {
   722  		// the device service returned only the serial
   723  		if err := acceptSerialOnly(t, serial, perfTimings); err != nil {
   724  			return err
   725  		}
   726  	} else {
   727  		// the device service returned a stream of assertions
   728  		timings.Run(perfTimings, "fetch-keys", "fetch signing key chain", func(timings.Measurer) {
   729  			err = acceptSerialPlusBatch(t, serial, ancillaryBatch)
   730  		})
   731  		if err != nil {
   732  			t.Errorf("cannot accept stream of assertions from device service: %v", err)
   733  			return err
   734  		}
   735  	}
   736  
   737  	if repeatRequestSerial == "after-add-serial" {
   738  		// For testing purposes, ensure a crash in this state works.
   739  		return &state.Retry{}
   740  	}
   741  
   742  	return finish(serial)
   743  }
   744  
   745  func acceptSerialOnly(t *state.Task, serial *asserts.Serial, perfTimings *timings.Timings) error {
   746  	st := t.State()
   747  	var err error
   748  	var errAcctKey error
   749  	// try to fetch the signing key chain of the serial
   750  	timings.Run(perfTimings, "fetch-keys", "fetch signing key chain", func(timings.Measurer) {
   751  		errAcctKey, err = fetchKeys(st, serial.SignKeyID())
   752  	})
   753  	if err != nil {
   754  		return err
   755  	}
   756  
   757  	// add the serial assertion to the system assertion db
   758  	err = assertstate.Add(st, serial)
   759  	if err != nil {
   760  		// if we had failed to fetch the signing key, retry in a bit
   761  		if errAcctKey != nil {
   762  			t.Errorf("cannot fetch signing key for the serial: %v", errAcctKey)
   763  			return &state.Retry{After: retryInterval}
   764  		}
   765  		return err
   766  	}
   767  
   768  	return nil
   769  }
   770  
   771  func acceptSerialPlusBatch(t *state.Task, serial *asserts.Serial, batch *asserts.Batch) error {
   772  	st := t.State()
   773  	err := batch.Add(serial)
   774  	if err != nil {
   775  		return err
   776  	}
   777  	return assertstate.AddBatch(st, batch, &asserts.CommitOptions{Precheck: true})
   778  }
   779  
// repeatRequestSerial is only set by tests. When it is "after-got-serial"
// or "after-add-serial", the serial request code returns a bare
// state.Retry at that point, simulating a crash in order to exercise
// resuming.
var repeatRequestSerial string // for tests
   781  
   782  func fetchKeys(st *state.State, keyID string) (errAcctKey error, err error) {
   783  	// TODO: right now any store should be good enough here but
   784  	// that might change. As an alternative we do support
   785  	// receiving a stream with any relevant assertions.
   786  	sto := snapstate.Store(st, nil)
   787  	db := assertstate.DB(st)
   788  
   789  	retrieveError := false
   790  	retrieve := func(ref *asserts.Ref) (asserts.Assertion, error) {
   791  		st.Unlock()
   792  		defer st.Lock()
   793  		a, err := sto.Assertion(ref.Type, ref.PrimaryKey, nil)
   794  		retrieveError = err != nil
   795  		return a, err
   796  	}
   797  
   798  	save := func(a asserts.Assertion) error {
   799  		err = assertstate.Add(st, a)
   800  		if err != nil && !asserts.IsUnaccceptedUpdate(err) {
   801  			return err
   802  		}
   803  		return nil
   804  	}
   805  
   806  	f := asserts.NewFetcher(db, retrieve, save)
   807  
   808  	keyRef := &asserts.Ref{
   809  		Type:       asserts.AccountKeyType,
   810  		PrimaryKey: []string{keyID},
   811  	}
   812  	if err := f.Fetch(keyRef); err != nil {
   813  		if retrieveError {
   814  			return err, nil
   815  		} else {
   816  			return nil, err
   817  		}
   818  	}
   819  	return nil, nil
   820  }