github.com/hugh712/snapd@v0.0.0-20200910133618-1a99902bd583/cmd/snap-repair/runner.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2017-2020 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package main
    21  
    22  import (
    23  	"bufio"
    24  	"bytes"
    25  	"crypto/tls"
    26  	"encoding/json"
    27  	"errors"
    28  	"fmt"
    29  	"io"
    30  	"io/ioutil"
    31  	"net/http"
    32  	"net/url"
    33  	"os"
    34  	"os/exec"
    35  	"path/filepath"
    36  	"strconv"
    37  	"strings"
    38  	"syscall"
    39  	"time"
    40  
    41  	"gopkg.in/retry.v1"
    42  
    43  	"github.com/snapcore/snapd/arch"
    44  	"github.com/snapcore/snapd/asserts"
    45  	"github.com/snapcore/snapd/asserts/sysdb"
    46  	"github.com/snapcore/snapd/dirs"
    47  	"github.com/snapcore/snapd/errtracker"
    48  	"github.com/snapcore/snapd/httputil"
    49  	"github.com/snapcore/snapd/logger"
    50  	"github.com/snapcore/snapd/osutil"
    51  	"github.com/snapcore/snapd/release"
    52  	"github.com/snapcore/snapd/snapdenv"
    53  	"github.com/snapcore/snapd/strutil"
    54  )
    55  
var (
	// defaultRepairTimeout is how long Repair.Run waits for a repair
	// script to finish before killing its process group.
	// TODO: move inside the repairs themselves?
	defaultRepairTimeout = 30 * time.Minute
)

// errtrackerReportRepair is the function used by errtrackerReport to
// submit a failure report; it is held in a variable, presumably so
// that tests can replace it.
var errtrackerReportRepair = errtracker.ReportRepair
    62  
// Repair is a runnable repair.
//
// It embeds the repair assertion and additionally records the Runner
// (run) whose state tracks it and its 1-based position (sequence) in
// the brand's sequence of repairs within that state.
type Repair struct {
	*asserts.Repair

	run      *Runner
	sequence int
}
    70  
    71  func (r *Repair) RunDir() string {
    72  	return filepath.Join(dirs.SnapRepairRunDir, r.BrandID(), strconv.Itoa(r.RepairID()))
    73  }
    74  
    75  func (r *Repair) String() string {
    76  	return fmt.Sprintf("%s-%v", r.BrandID(), r.RepairID())
    77  }
    78  
    79  // SetStatus sets the status of the repair in the state and saves the latter.
    80  func (r *Repair) SetStatus(status RepairStatus) {
    81  	brandID := r.BrandID()
    82  	cur := *r.run.state.Sequences[brandID][r.sequence-1]
    83  	cur.Status = status
    84  	r.run.setRepairState(brandID, cur)
    85  	r.run.SaveState()
    86  }
    87  
    88  // makeRepairSymlink ensures $dir/repair exists and is a symlink to
    89  // /usr/lib/snapd/snap-repair
    90  func makeRepairSymlink(dir string) (err error) {
    91  	// make "repair" binary available to the repair scripts via symlink
    92  	// to the real snap-repair
    93  	if err = os.MkdirAll(dir, 0755); err != nil {
    94  		return err
    95  	}
    96  
    97  	old := filepath.Join(dirs.CoreLibExecDir, "snap-repair")
    98  	new := filepath.Join(dir, "repair")
    99  	if err := os.Symlink(old, new); err != nil && !os.IsExist(err) {
   100  		return err
   101  	}
   102  
   103  	return nil
   104  }
   105  
// Run executes the repair script leaving execution trail files on disk.
//
// The script body is written to <rundir>/r<revision>.script and run with
// <rundir>/work as working directory. Its stdout and stderr are
// collected in <rundir>/r<revision>.running, which is renamed to
// <rundir>/r<revision>.<status> once the script is finished or has been
// killed after defaultRepairTimeout. The status is what the script last
// wrote to the pipe passed to it as file descriptor 3, and it is also
// recorded in the runner state via SetStatus.
//
// A failing or timed out script does not make Run return an error: the
// failure is reported through errtrackerReport and appended to the
// log. Run returns an error only if the setup, starting the script, or
// the final rename fails.
func (r *Repair) Run() error {
	// write the script to disk
	rundir := r.RunDir()
	err := os.MkdirAll(rundir, 0775)
	if err != nil {
		return err
	}

	// ensure the script can use "repair done"
	repairToolsDir := filepath.Join(dirs.SnapRunRepairDir, "tools")
	if err := makeRepairSymlink(repairToolsDir); err != nil {
		return err
	}

	baseName := fmt.Sprintf("r%d", r.Revision())
	script := filepath.Join(rundir, baseName+".script")
	err = osutil.AtomicWriteFile(script, r.Body(), 0700, 0)
	if err != nil {
		return err
	}

	// the log carries a ".running" suffix until the final status is known
	logPath := filepath.Join(rundir, baseName+".running")
	logf, err := os.OpenFile(logPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
	if err != nil {
		return err
	}
	defer logf.Close()

	// header preceding the script output in the log
	fmt.Fprintf(logf, "repair: %s\n", r)
	fmt.Fprintf(logf, "revision: %d\n", r.Revision())
	fmt.Fprintf(logf, "summary: %s\n", r.Summary())
	fmt.Fprintf(logf, "output:\n")

	// pipe over which the script reports its status
	statusR, statusW, err := os.Pipe()
	if err != nil {
		return err
	}
	defer statusR.Close()
	defer statusW.Close()

	logger.Debugf("executing %s", script)

	// run the script
	env := os.Environ()
	// we need to hardcode FD=3 because this is the FD after
	// exec.Command() forked. there is no way in go currently
	// to run something right after fork() in the child to
	// know the fd. However because go will close all fds
	// except the ones in "cmd.ExtraFiles" we are safe to set "3"
	env = append(env, "SNAP_REPAIR_STATUS_FD=3")
	env = append(env, "SNAP_REPAIR_RUN_DIR="+rundir)
	// inject repairToolDir into PATH so that the script can use
	// `repair {done,skip,retry}`
	var havePath bool
	for i, envStr := range env {
		if strings.HasPrefix(envStr, "PATH=") {
			newEnv := fmt.Sprintf("%s:%s", strings.TrimSuffix(envStr, ":"), repairToolsDir)
			env[i] = newEnv
			havePath = true
		}
	}
	if !havePath {
		env = append(env, "PATH=/usr/sbin:/usr/bin:/sbin:/bin:"+repairToolsDir)
	}

	workdir := filepath.Join(rundir, "work")
	if err := os.MkdirAll(workdir, 0700); err != nil {
		return err
	}

	cmd := exec.Command(script)
	// own process group, so that on timeout the whole group can be killed
	cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
	cmd.Env = env
	cmd.Dir = workdir
	cmd.ExtraFiles = []*os.File{statusW}
	cmd.Stdout = logf
	cmd.Stderr = logf
	if err = cmd.Start(); err != nil {
		return err
	}
	// close our copy of the write end, otherwise reading from
	// statusR below could not reach EOF
	statusW.Close()

	// wait for repair to finish or timeout
	var scriptErr error
	killTimerCh := time.After(defaultRepairTimeout)
	// buffered so that the goroutine can finish even if nobody
	// receives after a timeout
	doneCh := make(chan error, 1)
	go func() {
		doneCh <- cmd.Wait()
		close(doneCh)
	}()
	select {
	case scriptErr = <-doneCh:
		// done
	case <-killTimerCh:
		if err := osutil.KillProcessGroup(cmd); err != nil {
			logger.Noticef("cannot kill timed out repair %s: %s", r, err)
		}
		scriptErr = fmt.Errorf("repair did not finish within %s", defaultRepairTimeout)
	}
	// read repair status pipe, use the last value
	status := readStatus(statusR)
	statusPath := filepath.Join(rundir, baseName+"."+status.String())

	// if the script had an error exit status still honor what we
	// read from the status-pipe, however report the error
	if scriptErr != nil {
		scriptErr = fmt.Errorf("repair %s revision %d failed: %s", r, r.Revision(), scriptErr)
		if err := r.errtrackerReport(scriptErr, status, logPath); err != nil {
			logger.Noticef("cannot report error to errtracker: %s", err)
		}
		// ensure the error is present in the output log
		fmt.Fprintf(logf, "\n%s", scriptErr)
	}
	// the log file name now reflects the final status
	if err := os.Rename(logPath, statusPath); err != nil {
		return err
	}
	r.SetStatus(status)

	return nil
}
   227  
   228  func readStatus(r io.Reader) RepairStatus {
   229  	var status RepairStatus
   230  	scanner := bufio.NewScanner(r)
   231  	for scanner.Scan() {
   232  		switch strings.TrimSpace(scanner.Text()) {
   233  		case "done":
   234  			status = DoneStatus
   235  		// TODO: support having a script skip over many and up to a given repair-id #
   236  		case "skip":
   237  			status = SkipStatus
   238  		}
   239  	}
   240  	if scanner.Err() != nil {
   241  		return RetryStatus
   242  	}
   243  	return status
   244  }
   245  
   246  // errtrackerReport reports an repairErr with the given logPath to the
   247  // snap error tracker.
   248  func (r *Repair) errtrackerReport(repairErr error, status RepairStatus, logPath string) error {
   249  	errMsg := repairErr.Error()
   250  
   251  	scriptOutput, err := ioutil.ReadFile(logPath)
   252  	if err != nil {
   253  		logger.Noticef("cannot read %s", logPath)
   254  	}
   255  	s := fmt.Sprintf("%s/%d", r.BrandID(), r.RepairID())
   256  
   257  	dupSig := fmt.Sprintf("%s\n%s\noutput:\n%s", s, errMsg, scriptOutput)
   258  	extra := map[string]string{
   259  		"Revision": strconv.Itoa(r.Revision()),
   260  		"BrandID":  r.BrandID(),
   261  		"RepairID": strconv.Itoa(r.RepairID()),
   262  		"Status":   status.String(),
   263  	}
   264  	_, err = errtrackerReportRepair(s, errMsg, dupSig, extra)
   265  	return err
   266  }
   267  
// Runner implements fetching, tracking and running repairs.
//
// BaseURL is the URL against which the "repairs/<brand-id>/<repair-id>"
// endpoints are resolved and cli is the HTTP client used to query them.
// state is the control state persisted by SaveState, and stateModified
// records whether it has changes that are not yet saved.
type Runner struct {
	BaseURL *url.URL
	cli     *http.Client

	state         state
	stateModified bool

	// sequenceNext keeps track of the next integer id in a brand sequence to be considered in this run, see Next.
	sequenceNext map[string]int
}
   279  
   280  // NewRunner returns a Runner.
   281  func NewRunner() *Runner {
   282  	run := &Runner{
   283  		sequenceNext: make(map[string]int),
   284  	}
   285  	opts := httputil.ClientOptions{
   286  		MayLogBody:         false,
   287  		ProxyConnectHeader: http.Header{"User-Agent": []string{snapdenv.UserAgent()}},
   288  		TLSConfig: &tls.Config{
   289  			Time: run.now,
   290  		},
   291  		ExtraSSLCerts: &httputil.ExtraSSLCertsFromDir{
   292  			Dir: dirs.SnapdStoreSSLCertsDir,
   293  		},
   294  	}
   295  	run.cli = httputil.NewHTTPClient(&opts)
   296  	return run
   297  }
   298  
var (
	// fetchRetryStrategy is the retry strategy used by Fetch:
	// exponential backoff starting at 500ms with factor 2.5, wrapped
	// in retry.LimitTime(90s) and retry.LimitCount(7).
	fetchRetryStrategy = retry.LimitCount(7, retry.LimitTime(90*time.Second,
		retry.Exponential{
			Initial: 500 * time.Millisecond,
			Factor:  2.5,
		},
	))

	// peekRetryStrategy is the retry strategy used by Peek:
	// exponential backoff starting at 300ms with factor 2.5, wrapped
	// in retry.LimitTime(44s) and retry.LimitCount(5).
	peekRetryStrategy = retry.LimitCount(5, retry.LimitTime(44*time.Second,
		retry.Exponential{
			Initial: 300 * time.Millisecond,
			Factor:  2.5,
		},
	))
)

var (
	// ErrRepairNotFound is returned by Fetch and Peek when the server
	// answers with status 404.
	ErrRepairNotFound    = errors.New("repair not found")
	// ErrRepairNotModified is returned by Fetch when the server
	// answers with status 304 or the fetched revision is not newer
	// than the requested one.
	ErrRepairNotModified = errors.New("repair was not modified")
)

var (
	// maxRepairScriptSize is the maximum body size (24 MiB) that Fetch
	// allows for repair assertions when decoding a response.
	maxRepairScriptSize = 24 * 1024 * 1024
)
   323  
   324  // Fetch retrieves a stream with the repair with the given ids and any
   325  // auxiliary assertions. If revision>=0 the request will include an
   326  // If-None-Match header with an ETag for the revision, and
   327  // ErrRepairNotModified is returned if the revision is still current.
   328  func (run *Runner) Fetch(brandID string, repairID int, revision int) (*asserts.Repair, []asserts.Assertion, error) {
   329  	u, err := run.BaseURL.Parse(fmt.Sprintf("repairs/%s/%d", brandID, repairID))
   330  	if err != nil {
   331  		return nil, nil, err
   332  	}
   333  
   334  	var r []asserts.Assertion
   335  	resp, err := httputil.RetryRequest(u.String(), func() (*http.Response, error) {
   336  		req, err := http.NewRequest("GET", u.String(), nil)
   337  		if err != nil {
   338  			return nil, err
   339  		}
   340  		req.Header.Set("User-Agent", snapdenv.UserAgent())
   341  		req.Header.Set("Accept", "application/x.ubuntu.assertion")
   342  		if revision >= 0 {
   343  			req.Header.Set("If-None-Match", fmt.Sprintf(`"%d"`, revision))
   344  		}
   345  		return run.cli.Do(req)
   346  	}, func(resp *http.Response) error {
   347  		if resp.StatusCode == 200 {
   348  			logger.Debugf("fetching repair %s-%d", brandID, repairID)
   349  			// decode assertions
   350  			dec := asserts.NewDecoderWithTypeMaxBodySize(resp.Body, map[*asserts.AssertionType]int{
   351  				asserts.RepairType: maxRepairScriptSize,
   352  			})
   353  			for {
   354  				a, err := dec.Decode()
   355  				if err == io.EOF {
   356  					break
   357  				}
   358  				if err != nil {
   359  					return err
   360  				}
   361  				r = append(r, a)
   362  			}
   363  			if len(r) == 0 {
   364  				return io.ErrUnexpectedEOF
   365  			}
   366  		}
   367  		return nil
   368  	}, fetchRetryStrategy)
   369  
   370  	if err != nil {
   371  		return nil, nil, err
   372  	}
   373  
   374  	moveTimeLowerBound := true
   375  	defer func() {
   376  		if moveTimeLowerBound {
   377  			t, _ := http.ParseTime(resp.Header.Get("Date"))
   378  			run.moveTimeLowerBound(t)
   379  		}
   380  	}()
   381  
   382  	switch resp.StatusCode {
   383  	case 200:
   384  		// ok
   385  	case 304:
   386  		// not modified
   387  		return nil, nil, ErrRepairNotModified
   388  	case 404:
   389  		return nil, nil, ErrRepairNotFound
   390  	default:
   391  		moveTimeLowerBound = false
   392  		return nil, nil, fmt.Errorf("cannot fetch repair, unexpected status %d", resp.StatusCode)
   393  	}
   394  
   395  	repair, aux, err := checkStream(brandID, repairID, r)
   396  	if err != nil {
   397  		return nil, nil, fmt.Errorf("cannot fetch repair, %v", err)
   398  	}
   399  
   400  	if repair.Revision() <= revision {
   401  		// this shouldn't happen but if it does we behave like
   402  		// all the rest of assertion infrastructure and ignore
   403  		// the now superseded revision
   404  		return nil, nil, ErrRepairNotModified
   405  	}
   406  
   407  	return repair, aux, err
   408  }
   409  
   410  func checkStream(brandID string, repairID int, r []asserts.Assertion) (repair *asserts.Repair, aux []asserts.Assertion, err error) {
   411  	if len(r) == 0 {
   412  		return nil, nil, fmt.Errorf("empty repair assertions stream")
   413  	}
   414  	var ok bool
   415  	repair, ok = r[0].(*asserts.Repair)
   416  	if !ok {
   417  		return nil, nil, fmt.Errorf("unexpected first assertion %q", r[0].Type().Name)
   418  	}
   419  
   420  	if repair.BrandID() != brandID || repair.RepairID() != repairID {
   421  		return nil, nil, fmt.Errorf("repair id mismatch %s/%d != %s/%d", repair.BrandID(), repair.RepairID(), brandID, repairID)
   422  	}
   423  
   424  	return repair, r[1:], nil
   425  }
   426  
// peekResp is the JSON document decoded by Peek from a successful
// response; only its "headers" member is used.
type peekResp struct {
	Headers map[string]interface{} `json:"headers"`
}
   430  
   431  // Peek retrieves the headers for the repair with the given ids.
   432  func (run *Runner) Peek(brandID string, repairID int) (headers map[string]interface{}, err error) {
   433  	u, err := run.BaseURL.Parse(fmt.Sprintf("repairs/%s/%d", brandID, repairID))
   434  	if err != nil {
   435  		return nil, err
   436  	}
   437  
   438  	var rsp peekResp
   439  
   440  	resp, err := httputil.RetryRequest(u.String(), func() (*http.Response, error) {
   441  		req, err := http.NewRequest("GET", u.String(), nil)
   442  		if err != nil {
   443  			return nil, err
   444  		}
   445  		req.Header.Set("User-Agent", snapdenv.UserAgent())
   446  		req.Header.Set("Accept", "application/json")
   447  		return run.cli.Do(req)
   448  	}, func(resp *http.Response) error {
   449  		rsp.Headers = nil
   450  		if resp.StatusCode == 200 {
   451  			dec := json.NewDecoder(resp.Body)
   452  			return dec.Decode(&rsp)
   453  		}
   454  		return nil
   455  	}, peekRetryStrategy)
   456  
   457  	if err != nil {
   458  		return nil, err
   459  	}
   460  
   461  	moveTimeLowerBound := true
   462  	defer func() {
   463  		if moveTimeLowerBound {
   464  			t, _ := http.ParseTime(resp.Header.Get("Date"))
   465  			run.moveTimeLowerBound(t)
   466  		}
   467  	}()
   468  
   469  	switch resp.StatusCode {
   470  	case 200:
   471  		// ok
   472  	case 404:
   473  		return nil, ErrRepairNotFound
   474  	default:
   475  		moveTimeLowerBound = false
   476  		return nil, fmt.Errorf("cannot peek repair headers, unexpected status %d", resp.StatusCode)
   477  	}
   478  
   479  	headers = rsp.Headers
   480  	if headers["brand-id"] != brandID || headers["repair-id"] != strconv.Itoa(repairID) {
   481  		return nil, fmt.Errorf("cannot peek repair headers, repair id mismatch %s/%s != %s/%d", headers["brand-id"], headers["repair-id"], brandID, repairID)
   482  	}
   483  
   484  	return headers, nil
   485  }
   486  
// deviceInfo captures information about the device: the brand and the
// model, serialized to JSON as "brand" and "model".
type deviceInfo struct {
	Brand string `json:"brand"`
	Model string `json:"model"`
}
   492  
   493  // RepairStatus represents the possible statuses of a repair.
   494  type RepairStatus int
   495  
   496  const (
   497  	RetryStatus RepairStatus = iota
   498  	SkipStatus
   499  	DoneStatus
   500  )
   501  
   502  func (rs RepairStatus) String() string {
   503  	switch rs {
   504  	case RetryStatus:
   505  		return "retry"
   506  	case SkipStatus:
   507  		return "skip"
   508  	case DoneStatus:
   509  		return "done"
   510  	default:
   511  		return "unknown"
   512  	}
   513  }
   514  
// RepairState holds the current revision and status of a repair in a sequence of repairs.
// Sequence is the 1-based position of the repair in its brand sequence;
// setRepairState stores the state at index Sequence-1.
type RepairState struct {
	Sequence int          `json:"sequence"`
	Revision int          `json:"revision"`
	Status   RepairStatus `json:"status"`
}
   521  
// state holds the atomically updated control state of the runner with sequences of repairs and their states.
// Sequences is keyed by brand id. TimeLowerBound is the earliest time
// the runner accepts as current, see now and moveTimeLowerBound.
type state struct {
	Device         deviceInfo                `json:"device"`
	Sequences      map[string][]*RepairState `json:"sequences,omitempty"`
	TimeLowerBound time.Time                 `json:"time-lower-bound"`
}
   528  
   529  func (run *Runner) setRepairState(brandID string, state RepairState) {
   530  	if run.state.Sequences == nil {
   531  		run.state.Sequences = make(map[string][]*RepairState)
   532  	}
   533  	sequence := run.state.Sequences[brandID]
   534  	if state.Sequence > len(sequence) {
   535  		run.stateModified = true
   536  		run.state.Sequences[brandID] = append(sequence, &state)
   537  	} else if *sequence[state.Sequence-1] != state {
   538  		run.stateModified = true
   539  		sequence[state.Sequence-1] = &state
   540  	}
   541  }
   542  
   543  func (run *Runner) readState() error {
   544  	r, err := os.Open(dirs.SnapRepairStateFile)
   545  	if err != nil {
   546  		return err
   547  	}
   548  	defer r.Close()
   549  	dec := json.NewDecoder(r)
   550  	return dec.Decode(&run.state)
   551  }
   552  
   553  func (run *Runner) moveTimeLowerBound(t time.Time) {
   554  	if t.After(run.state.TimeLowerBound) {
   555  		run.stateModified = true
   556  		run.state.TimeLowerBound = t.UTC()
   557  	}
   558  }
   559  
   560  var timeNow = time.Now
   561  
   562  func (run *Runner) now() time.Time {
   563  	now := timeNow().UTC()
   564  	if now.Before(run.state.TimeLowerBound) {
   565  		return run.state.TimeLowerBound
   566  	}
   567  	return now
   568  }
   569  
   570  func (run *Runner) initState() error {
   571  	if err := os.MkdirAll(dirs.SnapRepairDir, 0775); err != nil {
   572  		return fmt.Errorf("cannot create repair state directory: %v", err)
   573  	}
   574  	// best-effort remove old
   575  	os.Remove(dirs.SnapRepairStateFile)
   576  	run.state = state{}
   577  	// initialize time lower bound with image built time/seed.yaml time
   578  	info, err := os.Stat(filepath.Join(dirs.SnapSeedDir, "seed.yaml"))
   579  	if err != nil {
   580  		return err
   581  	}
   582  	run.moveTimeLowerBound(info.ModTime())
   583  	// initialize device info
   584  	if err := run.initDeviceInfo(); err != nil {
   585  		return err
   586  	}
   587  	run.stateModified = true
   588  	return run.SaveState()
   589  }
   590  
   591  func trustedBackstore(trusted []asserts.Assertion) asserts.Backstore {
   592  	trustedBS := asserts.NewMemoryBackstore()
   593  	for _, t := range trusted {
   594  		trustedBS.Put(t.Type(), t)
   595  	}
   596  	return trustedBS
   597  }
   598  
   599  func checkAuthorityID(a asserts.Assertion, trusted asserts.Backstore) error {
   600  	assertType := a.Type()
   601  	if assertType != asserts.AccountKeyType && assertType != asserts.AccountType {
   602  		return nil
   603  	}
   604  	// check that account and account-key assertions are signed by
   605  	// a trusted authority
   606  	acctID := a.AuthorityID()
   607  	_, err := trusted.Get(asserts.AccountType, []string{acctID}, asserts.AccountType.MaxSupportedFormat())
   608  	if err != nil && !asserts.IsNotFound(err) {
   609  		return err
   610  	}
   611  	if asserts.IsNotFound(err) {
   612  		return fmt.Errorf("%v not signed by trusted authority: %s", a.Ref(), acctID)
   613  	}
   614  	return nil
   615  }
   616  
   617  func verifySignatures(a asserts.Assertion, workBS asserts.Backstore, trusted asserts.Backstore) error {
   618  	if err := checkAuthorityID(a, trusted); err != nil {
   619  		return err
   620  	}
   621  	acctKeyMaxSuppFormat := asserts.AccountKeyType.MaxSupportedFormat()
   622  
   623  	seen := make(map[string]bool)
   624  	bottom := false
   625  	for !bottom {
   626  		u := a.Ref().Unique()
   627  		if seen[u] {
   628  			return fmt.Errorf("circular assertions")
   629  		}
   630  		seen[u] = true
   631  		signKey := []string{a.SignKeyID()}
   632  		key, err := trusted.Get(asserts.AccountKeyType, signKey, acctKeyMaxSuppFormat)
   633  		if err != nil && !asserts.IsNotFound(err) {
   634  			return err
   635  		}
   636  		if err == nil {
   637  			bottom = true
   638  		} else {
   639  			key, err = workBS.Get(asserts.AccountKeyType, signKey, acctKeyMaxSuppFormat)
   640  			if err != nil && !asserts.IsNotFound(err) {
   641  				return err
   642  			}
   643  			if asserts.IsNotFound(err) {
   644  				return fmt.Errorf("cannot find public key %q", signKey[0])
   645  			}
   646  			if err := checkAuthorityID(key, trusted); err != nil {
   647  				return err
   648  			}
   649  		}
   650  		if err := asserts.CheckSignature(a, key.(*asserts.AccountKey), nil, time.Time{}); err != nil {
   651  			return err
   652  		}
   653  		a = key
   654  	}
   655  	return nil
   656  }
   657  
   658  func (run *Runner) initDeviceInfo() error {
   659  	const errPrefix = "cannot set device information: "
   660  
   661  	workBS := asserts.NewMemoryBackstore()
   662  	assertSeedDir := filepath.Join(dirs.SnapSeedDir, "assertions")
   663  	dc, err := ioutil.ReadDir(assertSeedDir)
   664  	if err != nil {
   665  		return err
   666  	}
   667  	var model *asserts.Model
   668  	for _, fi := range dc {
   669  		fn := filepath.Join(assertSeedDir, fi.Name())
   670  		f, err := os.Open(fn)
   671  		if err != nil {
   672  			// best effort
   673  			continue
   674  		}
   675  		dec := asserts.NewDecoder(f)
   676  		for {
   677  			a, err := dec.Decode()
   678  			if err != nil {
   679  				// best effort
   680  				break
   681  			}
   682  			switch a.Type() {
   683  			case asserts.ModelType:
   684  				if model != nil {
   685  					return fmt.Errorf(errPrefix + "multiple models in seed assertions")
   686  				}
   687  				model = a.(*asserts.Model)
   688  			case asserts.AccountType, asserts.AccountKeyType:
   689  				workBS.Put(a.Type(), a)
   690  			}
   691  		}
   692  	}
   693  	if model == nil {
   694  		return fmt.Errorf(errPrefix + "no model assertion in seed data")
   695  	}
   696  	trustedBS := trustedBackstore(sysdb.Trusted())
   697  	if err := verifySignatures(model, workBS, trustedBS); err != nil {
   698  		return fmt.Errorf(errPrefix+"%v", err)
   699  	}
   700  	acctPK := []string{model.BrandID()}
   701  	acctMaxSupFormat := asserts.AccountType.MaxSupportedFormat()
   702  	acct, err := trustedBS.Get(asserts.AccountType, acctPK, acctMaxSupFormat)
   703  	if err != nil {
   704  		var err error
   705  		acct, err = workBS.Get(asserts.AccountType, acctPK, acctMaxSupFormat)
   706  		if err != nil {
   707  			return fmt.Errorf(errPrefix + "no brand account assertion in seed data")
   708  		}
   709  	}
   710  	if err := verifySignatures(acct, workBS, trustedBS); err != nil {
   711  		return fmt.Errorf(errPrefix+"%v", err)
   712  	}
   713  	run.state.Device.Brand = model.BrandID()
   714  	run.state.Device.Model = model.Model()
   715  	return nil
   716  }
   717  
   718  // LoadState loads the repairs' state from disk, and (re)initializes it if it's missing or corrupted.
   719  func (run *Runner) LoadState() error {
   720  	err := run.readState()
   721  	if err == nil {
   722  		return nil
   723  	}
   724  	// error => initialize from scratch
   725  	if !os.IsNotExist(err) {
   726  		logger.Noticef("cannor read repair state: %v", err)
   727  	}
   728  	return run.initState()
   729  }
   730  
   731  // SaveState saves the repairs' state to disk.
   732  func (run *Runner) SaveState() error {
   733  	if !run.stateModified {
   734  		return nil
   735  	}
   736  	m, err := json.Marshal(&run.state)
   737  	if err != nil {
   738  		return fmt.Errorf("cannot marshal repair state: %v", err)
   739  	}
   740  	err = osutil.AtomicWriteFile(dirs.SnapRepairStateFile, m, 0600, 0)
   741  	if err != nil {
   742  		return fmt.Errorf("cannot save repair state: %v", err)
   743  	}
   744  	run.stateModified = false
   745  	return nil
   746  }
   747  
   748  func stringList(headers map[string]interface{}, name string) ([]string, error) {
   749  	v, ok := headers[name]
   750  	if !ok {
   751  		return nil, nil
   752  	}
   753  	l, ok := v.([]interface{})
   754  	if !ok {
   755  		return nil, fmt.Errorf("header %q is not a list", name)
   756  	}
   757  	r := make([]string, len(l))
   758  	for i, v := range l {
   759  		s, ok := v.(string)
   760  		if !ok {
   761  			return nil, fmt.Errorf("header %q contains non-string elements", name)
   762  		}
   763  		r[i] = s
   764  	}
   765  	return r, nil
   766  }
   767  
   768  // Applicable returns whether a repair with the given headers is applicable to the device.
   769  func (run *Runner) Applicable(headers map[string]interface{}) bool {
   770  	if headers["disabled"] == "true" {
   771  		return false
   772  	}
   773  	series, err := stringList(headers, "series")
   774  	if err != nil {
   775  		return false
   776  	}
   777  	if len(series) != 0 && !strutil.ListContains(series, release.Series) {
   778  		return false
   779  	}
   780  	archs, err := stringList(headers, "architectures")
   781  	if err != nil {
   782  		return false
   783  	}
   784  	if len(archs) != 0 && !strutil.ListContains(archs, arch.DpkgArchitecture()) {
   785  		return false
   786  	}
   787  	brandModel := fmt.Sprintf("%s/%s", run.state.Device.Brand, run.state.Device.Model)
   788  	models, err := stringList(headers, "models")
   789  	if err != nil {
   790  		return false
   791  	}
   792  	if len(models) != 0 && !strutil.ListContains(models, brandModel) {
   793  		// model prefix matching: brand/prefix*
   794  		hit := false
   795  		for _, patt := range models {
   796  			if strings.HasSuffix(patt, "*") && strings.ContainsRune(patt, '/') {
   797  				if strings.HasPrefix(brandModel, strings.TrimSuffix(patt, "*")) {
   798  					hit = true
   799  					break
   800  				}
   801  			}
   802  		}
   803  		if !hit {
   804  			return false
   805  		}
   806  	}
   807  	return true
   808  }
   809  
   810  var errSkip = errors.New("repair unnecessary on this system")
   811  
   812  func (run *Runner) fetch(brandID string, repairID int) (repair *asserts.Repair, aux []asserts.Assertion, err error) {
   813  	headers, err := run.Peek(brandID, repairID)
   814  	if err != nil {
   815  		return nil, nil, err
   816  	}
   817  	if !run.Applicable(headers) {
   818  		return nil, nil, errSkip
   819  	}
   820  	return run.Fetch(brandID, repairID, -1)
   821  }
   822  
   823  func (run *Runner) refetch(brandID string, repairID, revision int) (repair *asserts.Repair, aux []asserts.Assertion, err error) {
   824  	return run.Fetch(brandID, repairID, revision)
   825  }
   826  
   827  func (run *Runner) saveStream(brandID string, repairID int, repair *asserts.Repair, aux []asserts.Assertion) error {
   828  	d := filepath.Join(dirs.SnapRepairAssertsDir, brandID, strconv.Itoa(repairID))
   829  	err := os.MkdirAll(d, 0775)
   830  	if err != nil {
   831  		return err
   832  	}
   833  	buf := &bytes.Buffer{}
   834  	enc := asserts.NewEncoder(buf)
   835  	r := append([]asserts.Assertion{repair}, aux...)
   836  	for _, a := range r {
   837  		if err := enc.Encode(a); err != nil {
   838  			return fmt.Errorf("cannot encode repair assertions %s-%d for saving: %v", brandID, repairID, err)
   839  		}
   840  	}
   841  	p := filepath.Join(d, fmt.Sprintf("r%d.repair", r[0].Revision()))
   842  	return osutil.AtomicWriteFile(p, buf.Bytes(), 0600, 0)
   843  }
   844  
   845  func (run *Runner) readSavedStream(brandID string, repairID, revision int) (repair *asserts.Repair, aux []asserts.Assertion, err error) {
   846  	d := filepath.Join(dirs.SnapRepairAssertsDir, brandID, strconv.Itoa(repairID))
   847  	p := filepath.Join(d, fmt.Sprintf("r%d.repair", revision))
   848  	f, err := os.Open(p)
   849  	if err != nil {
   850  		return nil, nil, err
   851  	}
   852  	defer f.Close()
   853  
   854  	dec := asserts.NewDecoder(f)
   855  	var r []asserts.Assertion
   856  	for {
   857  		a, err := dec.Decode()
   858  		if err == io.EOF {
   859  			break
   860  		}
   861  		if err != nil {
   862  			return nil, nil, fmt.Errorf("cannot decode repair assertions %s-%d from disk: %v", brandID, repairID, err)
   863  		}
   864  		r = append(r, a)
   865  	}
   866  	return checkStream(brandID, repairID, r)
   867  }
   868  
   869  func (run *Runner) makeReady(brandID string, sequenceNext int) (repair *asserts.Repair, err error) {
   870  	sequence := run.state.Sequences[brandID]
   871  	var aux []asserts.Assertion
   872  	var state RepairState
   873  	if sequenceNext <= len(sequence) {
   874  		// consider retries
   875  		state = *sequence[sequenceNext-1]
   876  		if state.Status != RetryStatus {
   877  			return nil, errSkip
   878  		}
   879  		var err error
   880  		repair, aux, err = run.refetch(brandID, state.Sequence, state.Revision)
   881  		if err != nil {
   882  			if err != ErrRepairNotModified {
   883  				logger.Noticef("cannot refetch repair %s-%d, will retry what is on disk: %v", brandID, sequenceNext, err)
   884  			}
   885  			// try to use what we have already on disk
   886  			repair, aux, err = run.readSavedStream(brandID, state.Sequence, state.Revision)
   887  			if err != nil {
   888  				return nil, err
   889  			}
   890  		}
   891  	} else {
   892  		// fetch the next repair in the sequence
   893  		// assumes no gaps, each repair id is present so far,
   894  		// possibly skipped
   895  		var err error
   896  		repair, aux, err = run.fetch(brandID, sequenceNext)
   897  		if err != nil && err != errSkip {
   898  			return nil, err
   899  		}
   900  		state = RepairState{
   901  			Sequence: sequenceNext,
   902  		}
   903  		if err == errSkip {
   904  			// TODO: store headers to justify decision
   905  			state.Status = SkipStatus
   906  			run.setRepairState(brandID, state)
   907  			return nil, errSkip
   908  		}
   909  	}
   910  	// verify with signatures
   911  	if err := run.Verify(repair, aux); err != nil {
   912  		return nil, fmt.Errorf("cannot verify repair %s-%d: %v", brandID, state.Sequence, err)
   913  	}
   914  	if err := run.saveStream(brandID, state.Sequence, repair, aux); err != nil {
   915  		return nil, err
   916  	}
   917  	state.Revision = repair.Revision()
   918  	if !run.Applicable(repair.Headers()) {
   919  		state.Status = SkipStatus
   920  		run.setRepairState(brandID, state)
   921  		return nil, errSkip
   922  	}
   923  	run.setRepairState(brandID, state)
   924  	return repair, nil
   925  }
   926  
   927  // Next returns the next repair for the brand id sequence to run/retry or ErrRepairNotFound if there is none atm. It updates the state as required.
   928  func (run *Runner) Next(brandID string) (*Repair, error) {
   929  	sequenceNext := run.sequenceNext[brandID]
   930  	if sequenceNext == 0 {
   931  		sequenceNext = 1
   932  	}
   933  	for {
   934  		repair, err := run.makeReady(brandID, sequenceNext)
   935  		// SaveState is a no-op unless makeReady modified the state
   936  		stateErr := run.SaveState()
   937  		if err != nil && err != errSkip && err != ErrRepairNotFound {
   938  			// err is a non trivial error, just log the SaveState error and report err
   939  			if stateErr != nil {
   940  				logger.Noticef("%v", stateErr)
   941  			}
   942  			return nil, err
   943  		}
   944  		if stateErr != nil {
   945  			return nil, stateErr
   946  		}
   947  		if err == ErrRepairNotFound {
   948  			return nil, ErrRepairNotFound
   949  		}
   950  
   951  		sequenceNext += 1
   952  		run.sequenceNext[brandID] = sequenceNext
   953  		if err == errSkip {
   954  			continue
   955  		}
   956  
   957  		return &Repair{
   958  			Repair:   repair,
   959  			run:      run,
   960  			sequence: sequenceNext - 1,
   961  		}, nil
   962  	}
   963  }
   964  
   965  // Limit trust to specific keys while there's no delegation or limited
   966  // keys support.  The obtained assertion stream may also include
   967  // account keys that are directly or indirectly signed by a trusted
   968  // key.
   969  var (
   970  	trustedRepairRootKeys []*asserts.AccountKey
   971  )
   972  
   973  // Verify verifies that the repair is properly signed by the specific
   974  // trusted root keys or by account keys in the stream (passed via aux)
   975  // directly or indirectly signed by a trusted key.
   976  func (run *Runner) Verify(repair *asserts.Repair, aux []asserts.Assertion) error {
   977  	workBS := asserts.NewMemoryBackstore()
   978  	for _, a := range aux {
   979  		if a.Type() != asserts.AccountKeyType {
   980  			continue
   981  		}
   982  		err := workBS.Put(asserts.AccountKeyType, a)
   983  		if err != nil {
   984  			return err
   985  		}
   986  	}
   987  	trustedBS := asserts.NewMemoryBackstore()
   988  	for _, t := range trustedRepairRootKeys {
   989  		trustedBS.Put(asserts.AccountKeyType, t)
   990  	}
   991  	for _, t := range sysdb.Trusted() {
   992  		// we do *not* add the defalt sysdb trusted account
   993  		// keys here because the repair assertions have their
   994  		// own *dedicated* root of trust
   995  		if t.Type() == asserts.AccountType {
   996  			trustedBS.Put(asserts.AccountType, t)
   997  		}
   998  	}
   999  
  1000  	return verifySignatures(repair, workBS, trustedBS)
  1001  }