github.com/rigado/snapd@v2.42.5-go-mod+incompatible/overlord/devicestate/handlers.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  /*
     3   * Copyright (C) 2016-2017 Canonical Ltd
     4   *
     5   * This program is free software: you can redistribute it and/or modify
     6   * it under the terms of the GNU General Public License version 3 as
     7   * published by the Free Software Foundation.
     8   *
     9   * This program is distributed in the hope that it will be useful,
    10   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    11   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    12   * GNU General Public License for more details.
    13   *
    14   * You should have received a copy of the GNU General Public License
    15   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    16   *
    17   */
    18  
    19  package devicestate
    20  
    21  import (
    22  	"bytes"
    23  	"crypto/rsa"
    24  	"encoding/json"
    25  	"errors"
    26  	"fmt"
    27  	"io"
    28  	"net"
    29  	"net/http"
    30  	"net/url"
    31  	"os"
    32  	"path/filepath"
    33  	"strconv"
    34  	"strings"
    35  	"time"
    36  
    37  	"gopkg.in/tomb.v2"
    38  
    39  	"github.com/snapcore/snapd/asserts"
    40  	"github.com/snapcore/snapd/dirs"
    41  	"github.com/snapcore/snapd/gadget"
    42  	"github.com/snapcore/snapd/httputil"
    43  	"github.com/snapcore/snapd/logger"
    44  	"github.com/snapcore/snapd/osutil"
    45  	"github.com/snapcore/snapd/overlord/assertstate"
    46  	"github.com/snapcore/snapd/overlord/auth"
    47  	"github.com/snapcore/snapd/overlord/configstate/config"
    48  	"github.com/snapcore/snapd/overlord/configstate/proxyconf"
    49  	"github.com/snapcore/snapd/overlord/snapstate"
    50  	"github.com/snapcore/snapd/overlord/state"
    51  	"github.com/snapcore/snapd/release"
    52  	"github.com/snapcore/snapd/snap"
    53  	"github.com/snapcore/snapd/snap/naming"
    54  	"github.com/snapcore/snapd/timings"
    55  )
    56  
    57  func (m *DeviceManager) doMarkSeeded(t *state.Task, _ *tomb.Tomb) error {
    58  	st := t.State()
    59  	st.Lock()
    60  	defer st.Unlock()
    61  
    62  	st.Set("seed-time", time.Now())
    63  	st.Set("seeded", true)
    64  	// make sure we setup a fallback model/consider the next phase
    65  	// (registration) timely
    66  	st.EnsureBefore(0)
    67  	return nil
    68  }
    69  
    70  func isSameAssertsRevision(err error) bool {
    71  	if e, ok := err.(*asserts.RevisionError); ok {
    72  		if e.Used == e.Current {
    73  			return true
    74  		}
    75  	}
    76  	return false
    77  }
    78  
    79  func (m *DeviceManager) doSetModel(t *state.Task, _ *tomb.Tomb) error {
    80  	st := t.State()
    81  	st.Lock()
    82  	defer st.Unlock()
    83  
    84  	remodCtx, err := remodelCtxFromTask(t)
    85  	if err != nil {
    86  		return err
    87  	}
    88  	new := remodCtx.Model()
    89  
    90  	err = assertstate.Add(st, new)
    91  	if err != nil && !isSameAssertsRevision(err) {
    92  		return err
    93  	}
    94  
    95  	// unmark no-longer required snaps
    96  	requiredSnaps := getAllRequiredSnapsForModel(new)
    97  	// TODO:XXX: have AllByRef
    98  	snapStates, err := snapstate.All(st)
    99  	if err != nil {
   100  		return err
   101  	}
   102  	for snapName, snapst := range snapStates {
   103  		// TODO: remove this type restriction once we remodel
   104  		//       gadgets and add tests that ensure
   105  		//       that the required flag is properly set/unset
   106  		typ, err := snapst.Type()
   107  		if err != nil {
   108  			return err
   109  		}
   110  		if typ != snap.TypeApp && typ != snap.TypeBase && typ != snap.TypeKernel {
   111  			continue
   112  		}
   113  		// clean required flag if no-longer needed
   114  		if snapst.Flags.Required && !requiredSnaps.Contains(naming.Snap(snapName)) {
   115  			snapst.Flags.Required = false
   116  			snapstate.Set(st, snapName, snapst)
   117  		}
   118  		// TODO: clean "required" flag of "core" if a remodel
   119  		//       moves from the "core" snap to a different
   120  		//       bootable base snap.
   121  	}
   122  
   123  	return remodCtx.Finish()
   124  }
   125  
   126  func (m *DeviceManager) cleanupRemodel(t *state.Task, _ *tomb.Tomb) error {
   127  	st := t.State()
   128  	st.Lock()
   129  	defer st.Unlock()
   130  	// cleanup the cached remodel context
   131  	cleanupRemodelCtx(t.Change())
   132  	return nil
   133  }
   134  
   135  func useStaging() bool {
   136  	return osutil.GetenvBool("SNAPPY_USE_STAGING_STORE")
   137  }
   138  
   139  func baseURL() *url.URL {
   140  	if useStaging() {
   141  		return mustParse("https://api.staging.snapcraft.io/")
   142  	}
   143  	return mustParse("https://api.snapcraft.io/")
   144  }
   145  
   146  func mustParse(s string) *url.URL {
   147  	u, err := url.Parse(s)
   148  	if err != nil {
   149  		panic(err)
   150  	}
   151  	return u
   152  }
   153  
   154  var (
   155  	keyLength     = 4096
   156  	retryInterval = 60 * time.Second
   157  	maxTentatives = 15
   158  	baseStoreURL  = baseURL().ResolveReference(authRef)
   159  
   160  	authRef    = mustParse("api/v1/snaps/auth/") // authRef must end in / for the following refs to work
   161  	reqIdRef   = mustParse("request-id")
   162  	serialRef  = mustParse("serial")
   163  	devicesRef = mustParse("devices")
   164  
   165  	// we accept a stream with the serial assertion as well
   166  	registrationCapabilities = []string{"serial-stream"}
   167  )
   168  
   169  func newEnoughProxy(st *state.State, proxyURL *url.URL, client *http.Client) bool {
   170  	st.Unlock()
   171  	defer st.Lock()
   172  
   173  	const prefix = "Cannot check whether proxy store supports a custom serial vault"
   174  
   175  	req, err := http.NewRequest("HEAD", proxyURL.String(), nil)
   176  	if err != nil {
   177  		// can't really happen unless proxyURL is somehow broken
   178  		logger.Debugf(prefix+": %v", err)
   179  		return false
   180  	}
   181  	req.Header.Set("User-Agent", httputil.UserAgent())
   182  	resp, err := client.Do(req)
   183  	if err != nil {
   184  		// some sort of network or protocol error
   185  		logger.Debugf(prefix+": %v", err)
   186  		return false
   187  	}
   188  	resp.Body.Close()
   189  	if resp.StatusCode != 200 {
   190  		logger.Debugf(prefix+": Head request returned %s.", resp.Status)
   191  		return false
   192  	}
   193  	verstr := resp.Header.Get("Snap-Store-Version")
   194  	ver, err := strconv.Atoi(verstr)
   195  	if err != nil {
   196  		logger.Debugf(prefix+": Bogus Snap-Store-Version header %q.", verstr)
   197  		return false
   198  	}
   199  	return ver >= 6
   200  }
   201  
   202  func (cfg *serialRequestConfig) setURLs(proxyURL, svcURL *url.URL) {
   203  	base := baseStoreURL
   204  	if proxyURL != nil {
   205  		if svcURL != nil {
   206  			if cfg.headers == nil {
   207  				cfg.headers = make(map[string]string, 1)
   208  			}
   209  			cfg.headers["X-Snap-Device-Service-URL"] = svcURL.String()
   210  		}
   211  		base = proxyURL.ResolveReference(authRef)
   212  	} else if svcURL != nil {
   213  		base = svcURL
   214  	}
   215  
   216  	cfg.requestIDURL = base.ResolveReference(reqIdRef).String()
   217  	if svcURL != nil && proxyURL == nil {
   218  		// talking directly to the custom device service
   219  		cfg.serialRequestURL = base.ResolveReference(serialRef).String()
   220  	} else {
   221  		cfg.serialRequestURL = base.ResolveReference(devicesRef).String()
   222  	}
   223  }
   224  
   225  func (m *DeviceManager) doGenerateDeviceKey(t *state.Task, _ *tomb.Tomb) error {
   226  	st := t.State()
   227  	st.Lock()
   228  	defer st.Unlock()
   229  
   230  	perfTimings := timings.NewForTask(t)
   231  	defer perfTimings.Save(st)
   232  
   233  	device, err := m.device()
   234  	if err != nil {
   235  		return err
   236  	}
   237  
   238  	if device.KeyID != "" {
   239  		// nothing to do
   240  		return nil
   241  	}
   242  
   243  	st.Unlock()
   244  	var keyPair *rsa.PrivateKey
   245  	timings.Run(perfTimings, "generate-rsa-key", "generating device key pair", func(tm timings.Measurer) {
   246  		keyPair, err = generateRSAKey(keyLength)
   247  	})
   248  	st.Lock()
   249  	if err != nil {
   250  		return fmt.Errorf("cannot generate device key pair: %v", err)
   251  	}
   252  
   253  	privKey := asserts.RSAPrivateKey(keyPair)
   254  	err = m.keypairMgr.Put(privKey)
   255  	if err != nil {
   256  		return fmt.Errorf("cannot store device key pair: %v", err)
   257  	}
   258  
   259  	device.KeyID = privKey.PublicKey().ID()
   260  	err = m.setDevice(device)
   261  	if err != nil {
   262  		return err
   263  	}
   264  	t.SetStatus(state.DoneStatus)
   265  	return nil
   266  }
   267  
   268  // A registrationContext handles the contextual information needed
   269  // for the initial registration or a re-registration.
   270  type registrationContext interface {
   271  	Device() (*auth.DeviceState, error)
   272  
   273  	GadgetForSerialRequestConfig() string
   274  	SerialRequestExtraHeaders() map[string]interface{}
   275  	SerialRequestAncillaryAssertions() []asserts.Assertion
   276  
   277  	FinishRegistration(serial *asserts.Serial) error
   278  
   279  	ForRemodeling() bool
   280  }
   281  
   282  // initialRegistrationContext is a thin wrapper around DeviceManager
   283  // implementing registrationContext for initial regitration
   284  type initialRegistrationContext struct {
   285  	deviceMgr *DeviceManager
   286  
   287  	gadget string
   288  }
   289  
   290  func (rc *initialRegistrationContext) ForRemodeling() bool {
   291  	return false
   292  }
   293  
   294  func (rc *initialRegistrationContext) Device() (*auth.DeviceState, error) {
   295  	return rc.deviceMgr.device()
   296  }
   297  
   298  func (rc *initialRegistrationContext) GadgetForSerialRequestConfig() string {
   299  	return rc.gadget
   300  }
   301  
   302  func (rc *initialRegistrationContext) SerialRequestExtraHeaders() map[string]interface{} {
   303  	return nil
   304  }
   305  
   306  func (rc *initialRegistrationContext) SerialRequestAncillaryAssertions() []asserts.Assertion {
   307  	return nil
   308  }
   309  
   310  func (rc *initialRegistrationContext) FinishRegistration(serial *asserts.Serial) error {
   311  	device, err := rc.deviceMgr.device()
   312  	if err != nil {
   313  		return err
   314  	}
   315  
   316  	device.Serial = serial.Serial()
   317  	if err := rc.deviceMgr.setDevice(device); err != nil {
   318  		return err
   319  	}
   320  	rc.deviceMgr.markRegistered()
   321  
   322  	// make sure we timely consider anything that was blocked on
   323  	// registration
   324  	rc.deviceMgr.state.EnsureBefore(0)
   325  
   326  	return nil
   327  }
   328  
   329  // registrationCtx returns a registrationContext appropriate for the task and its change.
   330  func (m *DeviceManager) registrationCtx(t *state.Task) (registrationContext, error) {
   331  	remodCtx, err := remodelCtxFromTask(t)
   332  	if err != nil && err != state.ErrNoState {
   333  		return nil, err
   334  	}
   335  	if regCtx, ok := remodCtx.(registrationContext); ok {
   336  		return regCtx, nil
   337  	}
   338  	model, err := m.Model()
   339  	if err != nil {
   340  		return nil, err
   341  	}
   342  
   343  	return &initialRegistrationContext{
   344  		deviceMgr: m,
   345  		gadget:    model.Gadget(),
   346  	}, nil
   347  }
   348  
   349  type serialSetup struct {
   350  	SerialRequest string `json:"serial-request"`
   351  	Serial        string `json:"serial"`
   352  }
   353  
   354  type requestIDResp struct {
   355  	RequestID string `json:"request-id"`
   356  }
   357  
   358  func retryErr(t *state.Task, nTentatives int, reason string, a ...interface{}) error {
   359  	t.State().Lock()
   360  	defer t.State().Unlock()
   361  	if nTentatives >= maxTentatives {
   362  		return fmt.Errorf(reason, a...)
   363  	}
   364  	t.Errorf(reason, a...)
   365  	return &state.Retry{After: retryInterval}
   366  }
   367  
   368  type serverError struct {
   369  	Message string         `json:"message"`
   370  	Errors  []*serverError `json:"error_list"`
   371  }
   372  
   373  func retryBadStatus(t *state.Task, nTentatives int, reason string, resp *http.Response) error {
   374  	if resp.StatusCode > 500 {
   375  		// likely temporary
   376  		return retryErr(t, nTentatives, "%s: unexpected status %d", reason, resp.StatusCode)
   377  	}
   378  	if resp.Header.Get("Content-Type") == "application/json" {
   379  		var srvErr serverError
   380  		dec := json.NewDecoder(resp.Body)
   381  		err := dec.Decode(&srvErr)
   382  		if err == nil {
   383  			msg := srvErr.Message
   384  			if msg == "" && len(srvErr.Errors) > 0 {
   385  				msg = srvErr.Errors[0].Message
   386  			}
   387  			if msg != "" {
   388  				return fmt.Errorf("%s: %s", reason, msg)
   389  			}
   390  		}
   391  	}
   392  	return fmt.Errorf("%s: unexpected status %d", reason, resp.StatusCode)
   393  }
   394  
   395  func prepareSerialRequest(t *state.Task, regCtx registrationContext, privKey asserts.PrivateKey, device *auth.DeviceState, client *http.Client, cfg *serialRequestConfig) (string, error) {
   396  	// limit tentatives starting from scratch before going to
   397  	// slower full retries
   398  	var nTentatives int
   399  	err := t.Get("pre-poll-tentatives", &nTentatives)
   400  	if err != nil && err != state.ErrNoState {
   401  		return "", err
   402  	}
   403  	nTentatives++
   404  	t.Set("pre-poll-tentatives", nTentatives)
   405  
   406  	st := t.State()
   407  	st.Unlock()
   408  	defer st.Lock()
   409  
   410  	req, err := http.NewRequest("POST", cfg.requestIDURL, nil)
   411  	if err != nil {
   412  		return "", fmt.Errorf("internal error: cannot create request-id request %q", cfg.requestIDURL)
   413  	}
   414  	req.Header.Set("User-Agent", httputil.UserAgent())
   415  	cfg.applyHeaders(req)
   416  
   417  	resp, err := client.Do(req)
   418  	if err != nil {
   419  		if netErr, ok := err.(net.Error); ok && !netErr.Temporary() {
   420  			// a non temporary net error, like a DNS no
   421  			// host, error out and do full retries
   422  			return "", fmt.Errorf("cannot retrieve request-id for making a request for a serial: %v", err)
   423  		}
   424  		return "", retryErr(t, nTentatives, "cannot retrieve request-id for making a request for a serial: %v", err)
   425  	}
   426  	defer resp.Body.Close()
   427  	if resp.StatusCode != 200 {
   428  		return "", retryBadStatus(t, nTentatives, "cannot retrieve request-id for making a request for a serial", resp)
   429  	}
   430  
   431  	dec := json.NewDecoder(resp.Body)
   432  	var requestID requestIDResp
   433  	err = dec.Decode(&requestID)
   434  	if err != nil { // assume broken i/o
   435  		return "", retryErr(t, nTentatives, "cannot read response with request-id for making a request for a serial: %v", err)
   436  	}
   437  
   438  	encodedPubKey, err := asserts.EncodePublicKey(privKey.PublicKey())
   439  	if err != nil {
   440  		return "", fmt.Errorf("internal error: cannot encode device public key: %v", err)
   441  
   442  	}
   443  
   444  	headers := map[string]interface{}{
   445  		"brand-id":   device.Brand,
   446  		"model":      device.Model,
   447  		"request-id": requestID.RequestID,
   448  		"device-key": string(encodedPubKey),
   449  	}
   450  	if cfg.proposedSerial != "" {
   451  		headers["serial"] = cfg.proposedSerial
   452  	}
   453  
   454  	for k, v := range regCtx.SerialRequestExtraHeaders() {
   455  		headers[k] = v
   456  	}
   457  
   458  	serialReq, err := asserts.SignWithoutAuthority(asserts.SerialRequestType, headers, cfg.body, privKey)
   459  	if err != nil {
   460  		return "", err
   461  	}
   462  
   463  	buf := new(bytes.Buffer)
   464  	encoder := asserts.NewEncoder(buf)
   465  	if err := encoder.Encode(serialReq); err != nil {
   466  		return "", fmt.Errorf("cannot encode serial-request: %v", err)
   467  	}
   468  
   469  	for _, ancillaryAs := range regCtx.SerialRequestAncillaryAssertions() {
   470  		if err := encoder.Encode(ancillaryAs); err != nil {
   471  			return "", fmt.Errorf("cannot encode ancillary assertion: %v", err)
   472  		}
   473  
   474  	}
   475  
   476  	return buf.String(), nil
   477  }
   478  
   479  var errPoll = errors.New("serial-request accepted, poll later")
   480  
   481  func submitSerialRequest(t *state.Task, serialRequest string, client *http.Client, cfg *serialRequestConfig) (*asserts.Serial, *asserts.Batch, error) {
   482  	st := t.State()
   483  	st.Unlock()
   484  	defer st.Lock()
   485  
   486  	req, err := http.NewRequest("POST", cfg.serialRequestURL, bytes.NewBufferString(serialRequest))
   487  	if err != nil {
   488  		return nil, nil, fmt.Errorf("internal error: cannot create serial-request request %q", cfg.serialRequestURL)
   489  	}
   490  	req.Header.Set("User-Agent", httputil.UserAgent())
   491  	req.Header.Set("Snap-Device-Capabilities", strings.Join(registrationCapabilities, " "))
   492  	cfg.applyHeaders(req)
   493  	req.Header.Set("Content-Type", asserts.MediaType)
   494  
   495  	resp, err := client.Do(req)
   496  	if err != nil {
   497  		return nil, nil, retryErr(t, 0, "cannot deliver device serial request: %v", err)
   498  	}
   499  	defer resp.Body.Close()
   500  
   501  	switch resp.StatusCode {
   502  	case 200, 201:
   503  	case 202:
   504  		return nil, nil, errPoll
   505  	default:
   506  		return nil, nil, retryBadStatus(t, 0, "cannot deliver device serial request", resp)
   507  	}
   508  
   509  	var serial *asserts.Serial
   510  	var batch *asserts.Batch
   511  	// decode body with stream of assertions, of which one is the serial
   512  	dec := asserts.NewDecoder(resp.Body)
   513  	for {
   514  		got, err := dec.Decode()
   515  		if err == io.EOF {
   516  			break
   517  		}
   518  		if err != nil { // assume broken i/o
   519  			return nil, nil, retryErr(t, 0, "cannot read response to request for a serial: %v", err)
   520  		}
   521  		if got.Type() == asserts.SerialType {
   522  			if serial != nil {
   523  				return nil, nil, fmt.Errorf("cannot accept more than a single device serial assertion from the device service")
   524  			}
   525  			serial = got.(*asserts.Serial)
   526  		} else {
   527  			if batch == nil {
   528  				batch = asserts.NewBatch(nil)
   529  			}
   530  			if err := batch.Add(got); err != nil {
   531  				return nil, nil, err
   532  			}
   533  		}
   534  		// TODO: consider a size limit?
   535  	}
   536  
   537  	if serial == nil {
   538  		return nil, nil, fmt.Errorf("cannot proceed, received assertion stream from the device service missing device serial assertion")
   539  	}
   540  
   541  	return serial, batch, nil
   542  }
   543  
   544  func getSerial(t *state.Task, regCtx registrationContext, privKey asserts.PrivateKey, device *auth.DeviceState, tm timings.Measurer) (serial *asserts.Serial, ancillaryBatch *asserts.Batch, err error) {
   545  	var serialSup serialSetup
   546  	err = t.Get("serial-setup", &serialSup)
   547  	if err != nil && err != state.ErrNoState {
   548  		return nil, nil, err
   549  	}
   550  
   551  	if serialSup.Serial != "" {
   552  		// we got a serial, just haven't managed to save its info yet
   553  		a, err := asserts.Decode([]byte(serialSup.Serial))
   554  		if err != nil {
   555  			return nil, nil, fmt.Errorf("internal error: cannot decode previously saved serial: %v", err)
   556  		}
   557  		return a.(*asserts.Serial), nil, nil
   558  	}
   559  
   560  	st := t.State()
   561  	proxyConf := proxyconf.New(st)
   562  	client := httputil.NewHTTPClient(&httputil.ClientOptions{
   563  		Timeout:    30 * time.Second,
   564  		MayLogBody: true,
   565  		Proxy:      proxyConf.Conf,
   566  	})
   567  
   568  	cfg, err := getSerialRequestConfig(t, regCtx, client)
   569  	if err != nil {
   570  		return nil, nil, err
   571  	}
   572  
   573  	// NB: until we get at least an Accepted (202) we need to
   574  	// retry from scratch creating a new request-id because the
   575  	// previous one used could have expired
   576  
   577  	if serialSup.SerialRequest == "" {
   578  		var serialRequest string
   579  		var err error
   580  		timings.Run(tm, "prepare-serial-request", "prepare device serial request", func(timings.Measurer) {
   581  			serialRequest, err = prepareSerialRequest(t, regCtx, privKey, device, client, cfg)
   582  		})
   583  		if err != nil { // errors & retries
   584  			return nil, nil, err
   585  		}
   586  
   587  		serialSup.SerialRequest = serialRequest
   588  	}
   589  
   590  	timings.Run(tm, "submit-serial-request", "submit device serial request", func(timings.Measurer) {
   591  		serial, ancillaryBatch, err = submitSerialRequest(t, serialSup.SerialRequest, client, cfg)
   592  	})
   593  	if err == errPoll {
   594  		// we can/should reuse the serial-request
   595  		t.Set("serial-setup", serialSup)
   596  		return nil, nil, errPoll
   597  	}
   598  	if err != nil { // errors & retries
   599  		return nil, nil, err
   600  	}
   601  
   602  	keyID := privKey.PublicKey().ID()
   603  	if serial.BrandID() != device.Brand || serial.Model() != device.Model || serial.DeviceKey().ID() != keyID {
   604  		return nil, nil, fmt.Errorf("obtained serial assertion does not match provided device identity information (brand, model, key id): %s / %s / %s != %s / %s / %s", serial.BrandID(), serial.Model(), serial.DeviceKey().ID(), device.Brand, device.Model, keyID)
   605  	}
   606  
   607  	if ancillaryBatch == nil {
   608  		serialSup.Serial = string(asserts.Encode(serial))
   609  		t.Set("serial-setup", serialSup)
   610  	}
   611  
   612  	if repeatRequestSerial == "after-got-serial" {
   613  		// For testing purposes, ensure a crash in this state works.
   614  		return nil, nil, &state.Retry{}
   615  	}
   616  
   617  	return serial, ancillaryBatch, nil
   618  }
   619  
   620  type serialRequestConfig struct {
   621  	requestIDURL     string
   622  	serialRequestURL string
   623  	headers          map[string]string
   624  	proposedSerial   string
   625  	body             []byte
   626  }
   627  
   628  func (cfg *serialRequestConfig) applyHeaders(req *http.Request) {
   629  	for k, v := range cfg.headers {
   630  		req.Header.Set(k, v)
   631  	}
   632  }
   633  
   634  func getSerialRequestConfig(t *state.Task, regCtx registrationContext, client *http.Client) (*serialRequestConfig, error) {
   635  	var svcURL, proxyURL *url.URL
   636  
   637  	st := t.State()
   638  	tr := config.NewTransaction(st)
   639  	if proxyStore, err := proxyStore(st, tr); err != nil && err != state.ErrNoState {
   640  		return nil, err
   641  	} else if proxyStore != nil {
   642  		proxyURL = proxyStore.URL()
   643  	}
   644  
   645  	cfg := serialRequestConfig{}
   646  
   647  	gadgetName := regCtx.GadgetForSerialRequestConfig()
   648  	// gadget is optional on classic
   649  	if gadgetName != "" {
   650  		var gadgetSt snapstate.SnapState
   651  		if err := snapstate.Get(st, gadgetName, &gadgetSt); err != nil {
   652  			return nil, fmt.Errorf("cannot find gadget snap %q: %v", gadgetName, err)
   653  		}
   654  
   655  		var svcURI string
   656  		err := tr.GetMaybe(gadgetName, "device-service.url", &svcURI)
   657  		if err != nil {
   658  			return nil, err
   659  		}
   660  
   661  		if svcURI != "" {
   662  			svcURL, err = url.Parse(svcURI)
   663  			if err != nil {
   664  				return nil, fmt.Errorf("cannot parse device registration base URL %q: %v", svcURI, err)
   665  			}
   666  			if !strings.HasSuffix(svcURL.Path, "/") {
   667  				svcURL.Path += "/"
   668  			}
   669  		}
   670  
   671  		err = tr.GetMaybe(gadgetName, "device-service.headers", &cfg.headers)
   672  		if err != nil {
   673  			return nil, err
   674  		}
   675  
   676  		var bodyStr string
   677  		err = tr.GetMaybe(gadgetName, "registration.body", &bodyStr)
   678  		if err != nil {
   679  			return nil, err
   680  		}
   681  
   682  		cfg.body = []byte(bodyStr)
   683  
   684  		err = tr.GetMaybe(gadgetName, "registration.proposed-serial", &cfg.proposedSerial)
   685  		if err != nil {
   686  			return nil, err
   687  		}
   688  	}
   689  
   690  	if proxyURL != nil && svcURL != nil && !newEnoughProxy(st, proxyURL, client) {
   691  		logger.Noticef("Proxy store does not support custom serial vault; ignoring the proxy")
   692  		proxyURL = nil
   693  	}
   694  
   695  	cfg.setURLs(proxyURL, svcURL)
   696  
   697  	return &cfg, nil
   698  }
   699  
   700  func (m *DeviceManager) doRequestSerial(t *state.Task, _ *tomb.Tomb) error {
   701  	st := t.State()
   702  	st.Lock()
   703  	defer st.Unlock()
   704  
   705  	perfTimings := timings.NewForTask(t)
   706  	defer perfTimings.Save(st)
   707  
   708  	regCtx, err := m.registrationCtx(t)
   709  	if err != nil {
   710  		return err
   711  	}
   712  
   713  	device, err := regCtx.Device()
   714  	if err != nil {
   715  		return err
   716  	}
   717  
   718  	// NB: the keyPair is fixed for now
   719  	privKey, err := m.keyPair()
   720  	if err == state.ErrNoState {
   721  		return fmt.Errorf("internal error: cannot find device key pair")
   722  	}
   723  	if err != nil {
   724  		return err
   725  	}
   726  
   727  	// make this idempotent, look if we have already a serial assertion
   728  	// for privKey
   729  	serials, err := assertstate.DB(st).FindMany(asserts.SerialType, map[string]string{
   730  		"brand-id":            device.Brand,
   731  		"model":               device.Model,
   732  		"device-key-sha3-384": privKey.PublicKey().ID(),
   733  	})
   734  	if err != nil && !asserts.IsNotFound(err) {
   735  		return err
   736  	}
   737  
   738  	finish := func(serial *asserts.Serial) error {
   739  		if regCtx.FinishRegistration(serial); err != nil {
   740  			return err
   741  		}
   742  		t.SetStatus(state.DoneStatus)
   743  		return nil
   744  	}
   745  
   746  	if len(serials) == 1 {
   747  		// means we saved the assertion but didn't get to the end of the task
   748  		return finish(serials[0].(*asserts.Serial))
   749  	}
   750  	if len(serials) > 1 {
   751  		return fmt.Errorf("internal error: multiple serial assertions for the same device key")
   752  	}
   753  
   754  	var serial *asserts.Serial
   755  	var ancillaryBatch *asserts.Batch
   756  	timings.Run(perfTimings, "get-serial", "get device serial", func(tm timings.Measurer) {
   757  		serial, ancillaryBatch, err = getSerial(t, regCtx, privKey, device, tm)
   758  	})
   759  	if err == errPoll {
   760  		t.Logf("Will poll for device serial assertion in 60 seconds")
   761  		return &state.Retry{After: retryInterval}
   762  	}
   763  	if err != nil { // errors & retries
   764  		return err
   765  
   766  	}
   767  
   768  	if ancillaryBatch == nil {
   769  		// the device service returned only the serial
   770  		if err := acceptSerialOnly(t, serial, perfTimings); err != nil {
   771  			return err
   772  		}
   773  	} else {
   774  		// the device service returned a stream of assertions
   775  		timings.Run(perfTimings, "fetch-keys", "fetch signing key chain", func(timings.Measurer) {
   776  			err = acceptSerialPlusBatch(t, serial, ancillaryBatch)
   777  		})
   778  		if err != nil {
   779  			t.Errorf("cannot accept stream of assertions from device service: %v", err)
   780  			return err
   781  		}
   782  	}
   783  
   784  	if repeatRequestSerial == "after-add-serial" {
   785  		// For testing purposes, ensure a crash in this state works.
   786  		return &state.Retry{}
   787  	}
   788  
   789  	return finish(serial)
   790  }
   791  
   792  func acceptSerialOnly(t *state.Task, serial *asserts.Serial, perfTimings *timings.Timings) error {
   793  	st := t.State()
   794  	var err error
   795  	var errAcctKey error
   796  	// try to fetch the signing key chain of the serial
   797  	timings.Run(perfTimings, "fetch-keys", "fetch signing key chain", func(timings.Measurer) {
   798  		errAcctKey, err = fetchKeys(st, serial.SignKeyID())
   799  	})
   800  	if err != nil {
   801  		return err
   802  	}
   803  
   804  	// add the serial assertion to the system assertion db
   805  	err = assertstate.Add(st, serial)
   806  	if err != nil {
   807  		// if we had failed to fetch the signing key, retry in a bit
   808  		if errAcctKey != nil {
   809  			t.Errorf("cannot fetch signing key for the serial: %v", errAcctKey)
   810  			return &state.Retry{After: retryInterval}
   811  		}
   812  		return err
   813  	}
   814  
   815  	return nil
   816  }
   817  
   818  func acceptSerialPlusBatch(t *state.Task, serial *asserts.Serial, batch *asserts.Batch) error {
   819  	st := t.State()
   820  	err := batch.Add(serial)
   821  	if err != nil {
   822  		return err
   823  	}
   824  	return assertstate.AddBatch(st, batch, &asserts.CommitOptions{Precheck: true})
   825  }
   826  
   827  var repeatRequestSerial string // for tests
   828  
   829  func fetchKeys(st *state.State, keyID string) (errAcctKey error, err error) {
   830  	// TODO: right now any store should be good enough here but
   831  	// that might change. As an alternative we do support
   832  	// receiving a stream with any relevant assertions.
   833  	sto := snapstate.Store(st, nil)
   834  	db := assertstate.DB(st)
   835  
   836  	retrieveError := false
   837  	retrieve := func(ref *asserts.Ref) (asserts.Assertion, error) {
   838  		st.Unlock()
   839  		defer st.Lock()
   840  		a, err := sto.Assertion(ref.Type, ref.PrimaryKey, nil)
   841  		retrieveError = err != nil
   842  		return a, err
   843  	}
   844  
   845  	save := func(a asserts.Assertion) error {
   846  		err = assertstate.Add(st, a)
   847  		if err != nil && !asserts.IsUnaccceptedUpdate(err) {
   848  			return err
   849  		}
   850  		return nil
   851  	}
   852  
   853  	f := asserts.NewFetcher(db, retrieve, save)
   854  
   855  	keyRef := &asserts.Ref{
   856  		Type:       asserts.AccountKeyType,
   857  		PrimaryKey: []string{keyID},
   858  	}
   859  	if err := f.Fetch(keyRef); err != nil {
   860  		if retrieveError {
   861  			return err, nil
   862  		} else {
   863  			return nil, err
   864  		}
   865  	}
   866  	return nil, nil
   867  }
   868  
   869  func (m *DeviceManager) doPrepareRemodeling(t *state.Task, tmb *tomb.Tomb) error {
   870  	st := t.State()
   871  	st.Lock()
   872  	defer st.Unlock()
   873  
   874  	remodCtx, err := remodelCtxFromTask(t)
   875  	if err != nil {
   876  		return err
   877  	}
   878  	current, err := findModel(st)
   879  	if err != nil {
   880  		return err
   881  	}
   882  
   883  	sto := remodCtx.Store()
   884  	if sto == nil {
   885  		return fmt.Errorf("internal error: re-registration remodeling should have built a store")
   886  	}
   887  	// ensure a new session accounting for the new brand/model
   888  	st.Unlock()
   889  	_, err = sto.EnsureDeviceSession()
   890  	st.Lock()
   891  	if err != nil {
   892  		return fmt.Errorf("cannot get a store session based on the new model assertion: %v", err)
   893  	}
   894  
   895  	chgID := t.Change().ID()
   896  
   897  	tss, err := remodelTasks(tmb.Context(nil), st, current, remodCtx.Model(), remodCtx, chgID)
   898  	if err != nil {
   899  		return err
   900  	}
   901  
   902  	allTs := state.NewTaskSet()
   903  	for _, ts := range tss {
   904  		allTs.AddAll(ts)
   905  	}
   906  	snapstate.InjectTasks(t, allTs)
   907  
   908  	st.EnsureBefore(0)
   909  	t.SetStatus(state.DoneStatus)
   910  
   911  	return nil
   912  }
   913  
   914  func snapState(st *state.State, name string) (*snapstate.SnapState, error) {
   915  	var snapst snapstate.SnapState
   916  	err := snapstate.Get(st, name, &snapst)
   917  	if err != nil && err != state.ErrNoState {
   918  		return nil, err
   919  	}
   920  	return &snapst, nil
   921  }
   922  
   923  func makeRollbackDir(name string) (string, error) {
   924  	rollbackDir := filepath.Join(dirs.SnapRollbackDir, name)
   925  
   926  	if err := os.MkdirAll(rollbackDir, 0750); err != nil {
   927  		return "", err
   928  	}
   929  
   930  	return rollbackDir, nil
   931  }
   932  
   933  func currentGadgetInfo(snapst *snapstate.SnapState) (*gadget.GadgetData, error) {
   934  	currentInfo, err := snapst.CurrentInfo()
   935  	if err != nil && err != snapstate.ErrNoCurrent {
   936  		return nil, err
   937  	}
   938  	if currentInfo == nil {
   939  		// no current yet
   940  		return nil, nil
   941  	}
   942  	const onClassic = false
   943  	gi, err := gadget.ReadInfo(currentInfo.MountDir(), onClassic)
   944  	if err != nil {
   945  		return nil, err
   946  	}
   947  	return &gadget.GadgetData{Info: gi, RootDir: currentInfo.MountDir()}, nil
   948  }
   949  
   950  func pendingGadgetInfo(snapsup *snapstate.SnapSetup) (*gadget.GadgetData, error) {
   951  	info, err := snap.ReadInfo(snapsup.InstanceName(), snapsup.SideInfo)
   952  	if err != nil {
   953  		return nil, err
   954  	}
   955  	const onClassic = false
   956  	update, err := gadget.ReadInfo(info.MountDir(), onClassic)
   957  	if err != nil {
   958  		return nil, err
   959  	}
   960  	return &gadget.GadgetData{Info: update, RootDir: info.MountDir()}, nil
   961  }
   962  
   963  func gadgetCurrentAndUpdate(st *state.State, snapsup *snapstate.SnapSetup) (current *gadget.GadgetData, update *gadget.GadgetData, err error) {
   964  	snapst, err := snapState(st, snapsup.InstanceName())
   965  	if err != nil {
   966  		return nil, nil, err
   967  	}
   968  
   969  	currentData, err := currentGadgetInfo(snapst)
   970  	if err != nil {
   971  		return nil, nil, fmt.Errorf("cannot read current gadget snap details: %v", err)
   972  	}
   973  	if currentData == nil {
   974  		// don't bother reading update if there is no current
   975  		return nil, nil, nil
   976  	}
   977  
   978  	newData, err := pendingGadgetInfo(snapsup)
   979  	if err != nil {
   980  		return nil, nil, fmt.Errorf("cannot read candidate gadget snap details: %v", err)
   981  	}
   982  
   983  	return currentData, newData, nil
   984  }
   985  
   986  var (
   987  	gadgetUpdate = gadget.Update
   988  )
   989  
   990  func (m *DeviceManager) doUpdateGadgetAssets(t *state.Task, _ *tomb.Tomb) error {
   991  	if release.OnClassic {
   992  		return fmt.Errorf("cannot run update gadget assets task on a classic system")
   993  	}
   994  
   995  	st := t.State()
   996  	st.Lock()
   997  	defer st.Unlock()
   998  
   999  	snapsup, err := snapstate.TaskSnapSetup(t)
  1000  	if err != nil {
  1001  		return err
  1002  	}
  1003  
  1004  	currentData, updateData, err := gadgetCurrentAndUpdate(t.State(), snapsup)
  1005  	if err != nil {
  1006  		return err
  1007  	}
  1008  	if currentData == nil {
  1009  		// no updates during first boot & seeding
  1010  		return nil
  1011  	}
  1012  
  1013  	snapRollbackDir, err := makeRollbackDir(fmt.Sprintf("%v_%v", snapsup.InstanceName(), snapsup.SideInfo.Revision))
  1014  	if err != nil {
  1015  		return fmt.Errorf("cannot prepare update rollback directory: %v", err)
  1016  	}
  1017  
  1018  	st.Unlock()
  1019  	err = gadgetUpdate(*currentData, *updateData, snapRollbackDir)
  1020  	st.Lock()
  1021  	if err != nil {
  1022  		if err == gadget.ErrNoUpdate {
  1023  			// no update needed
  1024  			t.Logf("No gadget assets update needed")
  1025  			return nil
  1026  		}
  1027  		return err
  1028  	}
  1029  
  1030  	t.SetStatus(state.DoneStatus)
  1031  
  1032  	if err := os.RemoveAll(snapRollbackDir); err != nil && !os.IsNotExist(err) {
  1033  		logger.Noticef("failed to remove gadget update rollback directory %q: %v", snapRollbackDir, err)
  1034  	}
  1035  
  1036  	// TODO: consider having the option to do this early via recovery in
  1037  	// core20, have fallback code as well there
  1038  	st.RequestRestart(state.RestartSystem)
  1039  
  1040  	return nil
  1041  }