github.com/Lephar/snapd@v0.0.0-20210825215435-c7fba9cef4d2/cmd/snap-repair/runner.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2017-2020 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package main
    21  
    22  import (
    23  	"bufio"
    24  	"bytes"
    25  	"crypto/tls"
    26  	"encoding/json"
    27  	"errors"
    28  	"fmt"
    29  	"io"
    30  	"io/ioutil"
    31  	"net/http"
    32  	"net/url"
    33  	"os"
    34  	"os/exec"
    35  	"path/filepath"
    36  	"strconv"
    37  	"strings"
    38  	"syscall"
    39  	"time"
    40  
    41  	"github.com/mvo5/goconfigparser"
    42  	"gopkg.in/retry.v1"
    43  
    44  	"github.com/snapcore/snapd/arch"
    45  	"github.com/snapcore/snapd/asserts"
    46  	"github.com/snapcore/snapd/asserts/sysdb"
    47  	"github.com/snapcore/snapd/dirs"
    48  	"github.com/snapcore/snapd/errtracker"
    49  	"github.com/snapcore/snapd/httputil"
    50  	"github.com/snapcore/snapd/logger"
    51  	"github.com/snapcore/snapd/osutil"
    52  	"github.com/snapcore/snapd/release"
    53  	"github.com/snapcore/snapd/snap"
    54  	"github.com/snapcore/snapd/snapdenv"
    55  	"github.com/snapcore/snapd/strutil"
    56  )
    57  
    58  var (
    59  	// TODO: move inside the repairs themselves?
    60  	defaultRepairTimeout = 30 * time.Minute
    61  )
    62  
    63  var errtrackerReportRepair = errtracker.ReportRepair
    64  
    65  // Repair is a runnable repair.
    66  type Repair struct {
    67  	*asserts.Repair
    68  
    69  	run      *Runner
    70  	sequence int
    71  }
    72  
    73  func (r *Repair) RunDir() string {
    74  	return filepath.Join(dirs.SnapRepairRunDir, r.BrandID(), strconv.Itoa(r.RepairID()))
    75  }
    76  
    77  func (r *Repair) String() string {
    78  	return fmt.Sprintf("%s-%v", r.BrandID(), r.RepairID())
    79  }
    80  
    81  // SetStatus sets the status of the repair in the state and saves the latter.
    82  func (r *Repair) SetStatus(status RepairStatus) {
    83  	brandID := r.BrandID()
    84  	cur := *r.run.state.Sequences[brandID][r.sequence-1]
    85  	cur.Status = status
    86  	r.run.setRepairState(brandID, cur)
    87  	r.run.SaveState()
    88  }
    89  
    90  // makeRepairSymlink ensures $dir/repair exists and is a symlink to
    91  // /usr/lib/snapd/snap-repair
    92  func makeRepairSymlink(dir string) (err error) {
    93  	// make "repair" binary available to the repair scripts via symlink
    94  	// to the real snap-repair
    95  	if err = os.MkdirAll(dir, 0755); err != nil {
    96  		return err
    97  	}
    98  
    99  	old := filepath.Join(dirs.CoreLibExecDir, "snap-repair")
   100  	new := filepath.Join(dir, "repair")
   101  	if err := os.Symlink(old, new); err != nil && !os.IsExist(err) {
   102  		return err
   103  	}
   104  
   105  	return nil
   106  }
   107  
   108  // Run executes the repair script leaving execution trail files on disk.
   109  func (r *Repair) Run() error {
   110  	// write the script to disk
   111  	rundir := r.RunDir()
   112  	err := os.MkdirAll(rundir, 0775)
   113  	if err != nil {
   114  		return err
   115  	}
   116  
   117  	// ensure the script can use "repair done"
   118  	repairToolsDir := filepath.Join(dirs.SnapRunRepairDir, "tools")
   119  	if err := makeRepairSymlink(repairToolsDir); err != nil {
   120  		return err
   121  	}
   122  
   123  	baseName := fmt.Sprintf("r%d", r.Revision())
   124  	script := filepath.Join(rundir, baseName+".script")
   125  	err = osutil.AtomicWriteFile(script, r.Body(), 0700, 0)
   126  	if err != nil {
   127  		return err
   128  	}
   129  
   130  	logPath := filepath.Join(rundir, baseName+".running")
   131  	logf, err := os.OpenFile(logPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
   132  	if err != nil {
   133  		return err
   134  	}
   135  	defer logf.Close()
   136  
   137  	fmt.Fprintf(logf, "repair: %s\n", r)
   138  	fmt.Fprintf(logf, "revision: %d\n", r.Revision())
   139  	fmt.Fprintf(logf, "summary: %s\n", r.Summary())
   140  	fmt.Fprintf(logf, "output:\n")
   141  
   142  	statusR, statusW, err := os.Pipe()
   143  	if err != nil {
   144  		return err
   145  	}
   146  	defer statusR.Close()
   147  	defer statusW.Close()
   148  
   149  	logger.Debugf("executing %s", script)
   150  
   151  	// run the script
   152  	env := os.Environ()
   153  	// we need to hardcode FD=3 because this is the FD after
   154  	// exec.Command() forked. there is no way in go currently
   155  	// to run something right after fork() in the child to
   156  	// know the fd. However because go will close all fds
   157  	// except the ones in "cmd.ExtraFiles" we are safe to set "3"
   158  	env = append(env, "SNAP_REPAIR_STATUS_FD=3")
   159  	env = append(env, "SNAP_REPAIR_RUN_DIR="+rundir)
   160  	// inject repairToolDir into PATH so that the script can use
   161  	// `repair {done,skip,retry}`
   162  	var havePath bool
   163  	for i, envStr := range env {
   164  		if strings.HasPrefix(envStr, "PATH=") {
   165  			newEnv := fmt.Sprintf("%s:%s", strings.TrimSuffix(envStr, ":"), repairToolsDir)
   166  			env[i] = newEnv
   167  			havePath = true
   168  		}
   169  	}
   170  	if !havePath {
   171  		env = append(env, "PATH=/usr/sbin:/usr/bin:/sbin:/bin:"+repairToolsDir)
   172  	}
   173  
   174  	// TODO:UC20 what other details about recover mode should be included in the
   175  	// env for the repair assertion to read about? probably somethings related
   176  	// to degraded.json
   177  	if r.run.state.Device.Mode != "" {
   178  		env = append(env, fmt.Sprintf("SNAP_SYSTEM_MODE=%s", r.run.state.Device.Mode))
   179  	}
   180  
   181  	workdir := filepath.Join(rundir, "work")
   182  	if err := os.MkdirAll(workdir, 0700); err != nil {
   183  		return err
   184  	}
   185  
   186  	cmd := exec.Command(script)
   187  	cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
   188  	cmd.Env = env
   189  	cmd.Dir = workdir
   190  	cmd.ExtraFiles = []*os.File{statusW}
   191  	cmd.Stdout = logf
   192  	cmd.Stderr = logf
   193  	if err = cmd.Start(); err != nil {
   194  		return err
   195  	}
   196  	statusW.Close()
   197  
   198  	// wait for repair to finish or timeout
   199  	var scriptErr error
   200  	killTimerCh := time.After(defaultRepairTimeout)
   201  	doneCh := make(chan error, 1)
   202  	go func() {
   203  		doneCh <- cmd.Wait()
   204  		close(doneCh)
   205  	}()
   206  	select {
   207  	case scriptErr = <-doneCh:
   208  		// done
   209  	case <-killTimerCh:
   210  		if err := osutil.KillProcessGroup(cmd); err != nil {
   211  			logger.Noticef("cannot kill timed out repair %s: %s", r, err)
   212  		}
   213  		scriptErr = fmt.Errorf("repair did not finish within %s", defaultRepairTimeout)
   214  	}
   215  	// read repair status pipe, use the last value
   216  	status := readStatus(statusR)
   217  	statusPath := filepath.Join(rundir, baseName+"."+status.String())
   218  
   219  	// if the script had an error exit status still honor what we
   220  	// read from the status-pipe, however report the error
   221  	if scriptErr != nil {
   222  		scriptErr = fmt.Errorf("repair %s revision %d failed: %s", r, r.Revision(), scriptErr)
   223  		if err := r.errtrackerReport(scriptErr, status, logPath); err != nil {
   224  			logger.Noticef("cannot report error to errtracker: %s", err)
   225  		}
   226  		// ensure the error is present in the output log
   227  		fmt.Fprintf(logf, "\n%s", scriptErr)
   228  	}
   229  	if err := os.Rename(logPath, statusPath); err != nil {
   230  		return err
   231  	}
   232  	r.SetStatus(status)
   233  
   234  	return nil
   235  }
   236  
   237  func readStatus(r io.Reader) RepairStatus {
   238  	var status RepairStatus
   239  	scanner := bufio.NewScanner(r)
   240  	for scanner.Scan() {
   241  		switch strings.TrimSpace(scanner.Text()) {
   242  		case "done":
   243  			status = DoneStatus
   244  		// TODO: support having a script skip over many and up to a given repair-id #
   245  		case "skip":
   246  			status = SkipStatus
   247  		}
   248  	}
   249  	if scanner.Err() != nil {
   250  		return RetryStatus
   251  	}
   252  	return status
   253  }
   254  
   255  // errtrackerReport reports an repairErr with the given logPath to the
   256  // snap error tracker.
   257  func (r *Repair) errtrackerReport(repairErr error, status RepairStatus, logPath string) error {
   258  	errMsg := repairErr.Error()
   259  
   260  	scriptOutput, err := ioutil.ReadFile(logPath)
   261  	if err != nil {
   262  		logger.Noticef("cannot read %s", logPath)
   263  	}
   264  	s := fmt.Sprintf("%s/%d", r.BrandID(), r.RepairID())
   265  
   266  	dupSig := fmt.Sprintf("%s\n%s\noutput:\n%s", s, errMsg, scriptOutput)
   267  	extra := map[string]string{
   268  		"Revision": strconv.Itoa(r.Revision()),
   269  		"BrandID":  r.BrandID(),
   270  		"RepairID": strconv.Itoa(r.RepairID()),
   271  		"Status":   status.String(),
   272  	}
   273  	_, err = errtrackerReportRepair(s, errMsg, dupSig, extra)
   274  	return err
   275  }
   276  
   277  // Runner implements fetching, tracking and running repairs.
   278  type Runner struct {
   279  	BaseURL *url.URL
   280  	cli     *http.Client
   281  
   282  	state         state
   283  	stateModified bool
   284  
   285  	// sequenceNext keeps track of the next integer id in a brand sequence to considered in this run, see Next.
   286  	sequenceNext map[string]int
   287  }
   288  
   289  // NewRunner returns a Runner.
   290  func NewRunner() *Runner {
   291  	run := &Runner{
   292  		sequenceNext: make(map[string]int),
   293  	}
   294  	opts := httputil.ClientOptions{
   295  		MayLogBody:         false,
   296  		ProxyConnectHeader: http.Header{"User-Agent": []string{snapdenv.UserAgent()}},
   297  		TLSConfig: &tls.Config{
   298  			Time: run.now,
   299  		},
   300  		ExtraSSLCerts: &httputil.ExtraSSLCertsFromDir{
   301  			Dir: dirs.SnapdStoreSSLCertsDir,
   302  		},
   303  	}
   304  	run.cli = httputil.NewHTTPClient(&opts)
   305  	return run
   306  }
   307  
   308  var (
   309  	fetchRetryStrategy = retry.LimitCount(7, retry.LimitTime(90*time.Second,
   310  		retry.Exponential{
   311  			Initial: 500 * time.Millisecond,
   312  			Factor:  2.5,
   313  		},
   314  	))
   315  
   316  	peekRetryStrategy = retry.LimitCount(6, retry.LimitTime(44*time.Second,
   317  		retry.Exponential{
   318  			Initial: 500 * time.Millisecond,
   319  			Factor:  2.5,
   320  		},
   321  	))
   322  )
   323  
   324  var (
   325  	ErrRepairNotFound    = errors.New("repair not found")
   326  	ErrRepairNotModified = errors.New("repair was not modified")
   327  )
   328  
   329  var (
   330  	maxRepairScriptSize = 24 * 1024 * 1024
   331  )
   332  
   333  // Fetch retrieves a stream with the repair with the given ids and any
   334  // auxiliary assertions. If revision>=0 the request will include an
   335  // If-None-Match header with an ETag for the revision, and
   336  // ErrRepairNotModified is returned if the revision is still current.
   337  func (run *Runner) Fetch(brandID string, repairID int, revision int) (*asserts.Repair, []asserts.Assertion, error) {
   338  	u, err := run.BaseURL.Parse(fmt.Sprintf("repairs/%s/%d", brandID, repairID))
   339  	if err != nil {
   340  		return nil, nil, err
   341  	}
   342  
   343  	var r []asserts.Assertion
   344  	resp, err := httputil.RetryRequest(u.String(), func() (*http.Response, error) {
   345  		req, err := http.NewRequest("GET", u.String(), nil)
   346  		if err != nil {
   347  			return nil, err
   348  		}
   349  		req.Header.Set("User-Agent", snapdenv.UserAgent())
   350  		req.Header.Set("Accept", "application/x.ubuntu.assertion")
   351  		if revision >= 0 {
   352  			req.Header.Set("If-None-Match", fmt.Sprintf(`"%d"`, revision))
   353  		}
   354  		return run.cli.Do(req)
   355  	}, func(resp *http.Response) error {
   356  		if resp.StatusCode == 200 {
   357  			logger.Debugf("fetching repair %s-%d", brandID, repairID)
   358  
   359  			// TODO: use something like TransferSpeedMonitoringWriter to avoid stalling here
   360  			// decode assertions
   361  			dec := asserts.NewDecoderWithTypeMaxBodySize(resp.Body, map[*asserts.AssertionType]int{
   362  				asserts.RepairType: maxRepairScriptSize,
   363  			})
   364  			for {
   365  				a, err := dec.Decode()
   366  				if err == io.EOF {
   367  					break
   368  				}
   369  				if err != nil {
   370  					return err
   371  				}
   372  				r = append(r, a)
   373  			}
   374  			if len(r) == 0 {
   375  				return io.ErrUnexpectedEOF
   376  			}
   377  		}
   378  		return nil
   379  	}, fetchRetryStrategy)
   380  
   381  	if err != nil {
   382  		return nil, nil, err
   383  	}
   384  
   385  	moveTimeLowerBound := true
   386  	defer func() {
   387  		if moveTimeLowerBound {
   388  			t, _ := http.ParseTime(resp.Header.Get("Date"))
   389  			run.moveTimeLowerBound(t)
   390  		}
   391  	}()
   392  
   393  	switch resp.StatusCode {
   394  	case 200:
   395  		// ok
   396  	case 304:
   397  		// not modified
   398  		return nil, nil, ErrRepairNotModified
   399  	case 404:
   400  		return nil, nil, ErrRepairNotFound
   401  	default:
   402  		moveTimeLowerBound = false
   403  		return nil, nil, fmt.Errorf("cannot fetch repair, unexpected status %d", resp.StatusCode)
   404  	}
   405  
   406  	repair, aux, err := checkStream(brandID, repairID, r)
   407  	if err != nil {
   408  		return nil, nil, fmt.Errorf("cannot fetch repair, %v", err)
   409  	}
   410  
   411  	if repair.Revision() <= revision {
   412  		// this shouldn't happen but if it does we behave like
   413  		// all the rest of assertion infrastructure and ignore
   414  		// the now superseded revision
   415  		return nil, nil, ErrRepairNotModified
   416  	}
   417  
   418  	return repair, aux, err
   419  }
   420  
   421  func checkStream(brandID string, repairID int, r []asserts.Assertion) (repair *asserts.Repair, aux []asserts.Assertion, err error) {
   422  	if len(r) == 0 {
   423  		return nil, nil, fmt.Errorf("empty repair assertions stream")
   424  	}
   425  	var ok bool
   426  	repair, ok = r[0].(*asserts.Repair)
   427  	if !ok {
   428  		return nil, nil, fmt.Errorf("unexpected first assertion %q", r[0].Type().Name)
   429  	}
   430  
   431  	if repair.BrandID() != brandID || repair.RepairID() != repairID {
   432  		return nil, nil, fmt.Errorf("repair id mismatch %s/%d != %s/%d", repair.BrandID(), repair.RepairID(), brandID, repairID)
   433  	}
   434  
   435  	return repair, r[1:], nil
   436  }
   437  
   438  type peekResp struct {
   439  	Headers map[string]interface{} `json:"headers"`
   440  }
   441  
   442  // Peek retrieves the headers for the repair with the given ids.
   443  func (run *Runner) Peek(brandID string, repairID int) (headers map[string]interface{}, err error) {
   444  	u, err := run.BaseURL.Parse(fmt.Sprintf("repairs/%s/%d", brandID, repairID))
   445  	if err != nil {
   446  		return nil, err
   447  	}
   448  
   449  	var rsp peekResp
   450  
   451  	resp, err := httputil.RetryRequest(u.String(), func() (*http.Response, error) {
   452  		// TODO: setup a overall request timeout using contexts
   453  		// can be many minutes but not unlimited like now
   454  		req, err := http.NewRequest("GET", u.String(), nil)
   455  		if err != nil {
   456  			return nil, err
   457  		}
   458  		req.Header.Set("User-Agent", snapdenv.UserAgent())
   459  		req.Header.Set("Accept", "application/json")
   460  		return run.cli.Do(req)
   461  	}, func(resp *http.Response) error {
   462  		rsp.Headers = nil
   463  		if resp.StatusCode == 200 {
   464  			dec := json.NewDecoder(resp.Body)
   465  			return dec.Decode(&rsp)
   466  		}
   467  		return nil
   468  	}, peekRetryStrategy)
   469  
   470  	if err != nil {
   471  		return nil, err
   472  	}
   473  
   474  	moveTimeLowerBound := true
   475  	defer func() {
   476  		if moveTimeLowerBound {
   477  			t, _ := http.ParseTime(resp.Header.Get("Date"))
   478  			run.moveTimeLowerBound(t)
   479  		}
   480  	}()
   481  
   482  	switch resp.StatusCode {
   483  	case 200:
   484  		// ok
   485  	case 404:
   486  		return nil, ErrRepairNotFound
   487  	default:
   488  		moveTimeLowerBound = false
   489  		return nil, fmt.Errorf("cannot peek repair headers, unexpected status %d", resp.StatusCode)
   490  	}
   491  
   492  	headers = rsp.Headers
   493  	if headers["brand-id"] != brandID || headers["repair-id"] != strconv.Itoa(repairID) {
   494  		return nil, fmt.Errorf("cannot peek repair headers, repair id mismatch %s/%s != %s/%d", headers["brand-id"], headers["repair-id"], brandID, repairID)
   495  	}
   496  
   497  	return headers, nil
   498  }
   499  
   500  // deviceInfo captures information about the device.
   501  type deviceInfo struct {
   502  	Brand string `json:"brand"`
   503  	Model string `json:"model"`
   504  	Base  string `json:"base"`
   505  	Mode  string `json:"mode"`
   506  }
   507  
   508  // RepairStatus represents the possible statuses of a repair.
   509  type RepairStatus int
   510  
   511  const (
   512  	RetryStatus RepairStatus = iota
   513  	SkipStatus
   514  	DoneStatus
   515  )
   516  
   517  func (rs RepairStatus) String() string {
   518  	switch rs {
   519  	case RetryStatus:
   520  		return "retry"
   521  	case SkipStatus:
   522  		return "skip"
   523  	case DoneStatus:
   524  		return "done"
   525  	default:
   526  		return "unknown"
   527  	}
   528  }
   529  
   530  // RepairState holds the current revision and status of a repair in a sequence of repairs.
   531  type RepairState struct {
   532  	Sequence int          `json:"sequence"`
   533  	Revision int          `json:"revision"`
   534  	Status   RepairStatus `json:"status"`
   535  }
   536  
   537  // state holds the atomically updated control state of the runner with sequences of repairs and their states.
   538  type state struct {
   539  	Device         deviceInfo                `json:"device"`
   540  	Sequences      map[string][]*RepairState `json:"sequences,omitempty"`
   541  	TimeLowerBound time.Time                 `json:"time-lower-bound"`
   542  }
   543  
   544  func (run *Runner) setRepairState(brandID string, state RepairState) {
   545  	if run.state.Sequences == nil {
   546  		run.state.Sequences = make(map[string][]*RepairState)
   547  	}
   548  	sequence := run.state.Sequences[brandID]
   549  	if state.Sequence > len(sequence) {
   550  		run.stateModified = true
   551  		run.state.Sequences[brandID] = append(sequence, &state)
   552  	} else if *sequence[state.Sequence-1] != state {
   553  		run.stateModified = true
   554  		sequence[state.Sequence-1] = &state
   555  	}
   556  }
   557  
   558  func (run *Runner) readState() error {
   559  	r, err := os.Open(dirs.SnapRepairStateFile)
   560  	if err != nil {
   561  		return err
   562  	}
   563  	defer r.Close()
   564  	dec := json.NewDecoder(r)
   565  	return dec.Decode(&run.state)
   566  }
   567  
   568  func (run *Runner) moveTimeLowerBound(t time.Time) {
   569  	if t.After(run.state.TimeLowerBound) {
   570  		run.stateModified = true
   571  		run.state.TimeLowerBound = t.UTC()
   572  	}
   573  }
   574  
   575  var timeNow = time.Now
   576  
   577  func (run *Runner) now() time.Time {
   578  	now := timeNow().UTC()
   579  	if now.Before(run.state.TimeLowerBound) {
   580  		return run.state.TimeLowerBound
   581  	}
   582  	return now
   583  }
   584  
   585  func (run *Runner) initState() error {
   586  	if err := os.MkdirAll(dirs.SnapRepairDir, 0775); err != nil {
   587  		return fmt.Errorf("cannot create repair state directory: %v", err)
   588  	}
   589  	// best-effort remove old
   590  	os.Remove(dirs.SnapRepairStateFile)
   591  	run.state = state{}
   592  	// initialize time lower bound with image built time/seed.yaml time
   593  	if err := run.findTimeLowerBound(); err != nil {
   594  		return err
   595  	}
   596  	// initialize device info
   597  	if err := run.initDeviceInfo(); err != nil {
   598  		return err
   599  	}
   600  	run.stateModified = true
   601  	return run.SaveState()
   602  }
   603  
   604  func trustedBackstore(trusted []asserts.Assertion) asserts.Backstore {
   605  	trustedBS := asserts.NewMemoryBackstore()
   606  	for _, t := range trusted {
   607  		trustedBS.Put(t.Type(), t)
   608  	}
   609  	return trustedBS
   610  }
   611  
   612  func checkAuthorityID(a asserts.Assertion, trusted asserts.Backstore) error {
   613  	assertType := a.Type()
   614  	if assertType != asserts.AccountKeyType && assertType != asserts.AccountType {
   615  		return nil
   616  	}
   617  	// check that account and account-key assertions are signed by
   618  	// a trusted authority
   619  	acctID := a.AuthorityID()
   620  	_, err := trusted.Get(asserts.AccountType, []string{acctID}, asserts.AccountType.MaxSupportedFormat())
   621  	if err != nil && !asserts.IsNotFound(err) {
   622  		return err
   623  	}
   624  	if asserts.IsNotFound(err) {
   625  		return fmt.Errorf("%v not signed by trusted authority: %s", a.Ref(), acctID)
   626  	}
   627  	return nil
   628  }
   629  
   630  func verifySignatures(a asserts.Assertion, workBS asserts.Backstore, trusted asserts.Backstore) error {
   631  	if err := checkAuthorityID(a, trusted); err != nil {
   632  		return err
   633  	}
   634  	acctKeyMaxSuppFormat := asserts.AccountKeyType.MaxSupportedFormat()
   635  
   636  	seen := make(map[string]bool)
   637  	bottom := false
   638  	for !bottom {
   639  		u := a.Ref().Unique()
   640  		if seen[u] {
   641  			return fmt.Errorf("circular assertions")
   642  		}
   643  		seen[u] = true
   644  		signKey := []string{a.SignKeyID()}
   645  		key, err := trusted.Get(asserts.AccountKeyType, signKey, acctKeyMaxSuppFormat)
   646  		if err != nil && !asserts.IsNotFound(err) {
   647  			return err
   648  		}
   649  		if err == nil {
   650  			bottom = true
   651  		} else {
   652  			key, err = workBS.Get(asserts.AccountKeyType, signKey, acctKeyMaxSuppFormat)
   653  			if err != nil && !asserts.IsNotFound(err) {
   654  				return err
   655  			}
   656  			if asserts.IsNotFound(err) {
   657  				return fmt.Errorf("cannot find public key %q", signKey[0])
   658  			}
   659  			if err := checkAuthorityID(key, trusted); err != nil {
   660  				return err
   661  			}
   662  		}
   663  		if err := asserts.CheckSignature(a, key.(*asserts.AccountKey), nil, time.Time{}, time.Time{}); err != nil {
   664  			return err
   665  		}
   666  		a = key
   667  	}
   668  	return nil
   669  }
   670  
   671  func (run *Runner) findTimeLowerBound() error {
   672  	timeLowerBoundSources := []string{
   673  		// uc16
   674  		filepath.Join(dirs.SnapSeedDir, "seed.yaml"),
   675  		// uc20+
   676  		dirs.SnapModeenvFile,
   677  	}
   678  	// add all model files from uc20 seeds
   679  	allModels, err := filepath.Glob(filepath.Join(dirs.SnapSeedDir, "systems/*/model"))
   680  	if err != nil {
   681  		return err
   682  	}
   683  	timeLowerBoundSources = append(timeLowerBoundSources, allModels...)
   684  
   685  	// use all files as potential time inputs
   686  	for _, p := range timeLowerBoundSources {
   687  		info, err := os.Stat(p)
   688  		if os.IsNotExist(err) {
   689  			continue
   690  		}
   691  		if err != nil {
   692  			return err
   693  		}
   694  		run.moveTimeLowerBound(info.ModTime())
   695  	}
   696  	return nil
   697  }
   698  
   699  func findBrandAndModel() (*deviceInfo, error) {
   700  	if osutil.FileExists(dirs.SnapModeenvFile) {
   701  		return findDevInfo20()
   702  	}
   703  	return findDevInfo16()
   704  }
   705  
   706  func findDevInfo20() (*deviceInfo, error) {
   707  	cfg := goconfigparser.New()
   708  	cfg.AllowNoSectionHeader = true
   709  	if err := cfg.ReadFile(dirs.SnapModeenvFile); err != nil {
   710  		return nil, err
   711  	}
   712  	brandAndModel, err := cfg.Get("", "model")
   713  	if err != nil {
   714  		return nil, err
   715  	}
   716  	l := strings.SplitN(brandAndModel, "/", 2)
   717  	if len(l) != 2 {
   718  		return nil, fmt.Errorf("cannot find brand/model in modeenv model string %q", brandAndModel)
   719  	}
   720  
   721  	mode, err := cfg.Get("", "mode")
   722  	if err != nil {
   723  		return nil, err
   724  	}
   725  
   726  	baseName, err := cfg.Get("", "base")
   727  	if err != nil {
   728  		return nil, err
   729  	}
   730  
   731  	baseSn, err := snap.ParsePlaceInfoFromSnapFileName(baseName)
   732  	if err != nil {
   733  		return nil, err
   734  	}
   735  
   736  	return &deviceInfo{
   737  		Brand: l[0],
   738  		Model: l[1],
   739  		Base:  baseSn.SnapName(),
   740  		Mode:  mode,
   741  	}, nil
   742  }
   743  
   744  func findDevInfo16() (*deviceInfo, error) {
   745  	workBS := asserts.NewMemoryBackstore()
   746  	assertSeedDir := filepath.Join(dirs.SnapSeedDir, "assertions")
   747  	dc, err := ioutil.ReadDir(assertSeedDir)
   748  	if err != nil {
   749  		return nil, err
   750  	}
   751  	var modelAs *asserts.Model
   752  	for _, fi := range dc {
   753  		fn := filepath.Join(assertSeedDir, fi.Name())
   754  		f, err := os.Open(fn)
   755  		if err != nil {
   756  			// best effort
   757  			continue
   758  		}
   759  		dec := asserts.NewDecoder(f)
   760  		for {
   761  			a, err := dec.Decode()
   762  			if err != nil {
   763  				// best effort
   764  				break
   765  			}
   766  			switch a.Type() {
   767  			case asserts.ModelType:
   768  				if modelAs != nil {
   769  					return nil, fmt.Errorf("multiple models in seed assertions")
   770  				}
   771  				modelAs = a.(*asserts.Model)
   772  			case asserts.AccountType, asserts.AccountKeyType:
   773  				workBS.Put(a.Type(), a)
   774  			}
   775  		}
   776  	}
   777  	if modelAs == nil {
   778  		return nil, fmt.Errorf("no model assertion in seed data")
   779  	}
   780  	trustedBS := trustedBackstore(sysdb.Trusted())
   781  	if err := verifySignatures(modelAs, workBS, trustedBS); err != nil {
   782  		return nil, err
   783  	}
   784  	acctPK := []string{modelAs.BrandID()}
   785  	acctMaxSupFormat := asserts.AccountType.MaxSupportedFormat()
   786  	acct, err := trustedBS.Get(asserts.AccountType, acctPK, acctMaxSupFormat)
   787  	if err != nil {
   788  		var err error
   789  		acct, err = workBS.Get(asserts.AccountType, acctPK, acctMaxSupFormat)
   790  		if err != nil {
   791  			return nil, fmt.Errorf("no brand account assertion in seed data")
   792  		}
   793  	}
   794  	if err := verifySignatures(acct, workBS, trustedBS); err != nil {
   795  		return nil, err
   796  	}
   797  
   798  	// get the base snap as well, on uc16 it won't be specified in the model
   799  	// assertion and instead will be empty, so in this case we replace it with
   800  	// "core"
   801  	base := modelAs.Base()
   802  	if modelAs.Base() == "" {
   803  		base = "core"
   804  	}
   805  
   806  	return &deviceInfo{
   807  		Brand: modelAs.BrandID(),
   808  		Model: modelAs.Model(),
   809  		Base:  base,
   810  		// Mode is unset on uc16/uc18
   811  	}, nil
   812  }
   813  
   814  func (run *Runner) initDeviceInfo() error {
   815  	dev, err := findBrandAndModel()
   816  	if err != nil {
   817  		return fmt.Errorf("cannot set device information: %v", err)
   818  	}
   819  	run.state.Device = *dev
   820  
   821  	return nil
   822  }
   823  
   824  // LoadState loads the repairs' state from disk, and (re)initializes it if it's missing or corrupted.
   825  func (run *Runner) LoadState() error {
   826  	err := run.readState()
   827  	if err == nil {
   828  		return nil
   829  	}
   830  	// error => initialize from scratch
   831  	if !os.IsNotExist(err) {
   832  		logger.Noticef("cannor read repair state: %v", err)
   833  	}
   834  	return run.initState()
   835  }
   836  
   837  // SaveState saves the repairs' state to disk.
   838  func (run *Runner) SaveState() error {
   839  	if !run.stateModified {
   840  		return nil
   841  	}
   842  	m, err := json.Marshal(&run.state)
   843  	if err != nil {
   844  		return fmt.Errorf("cannot marshal repair state: %v", err)
   845  	}
   846  	err = osutil.AtomicWriteFile(dirs.SnapRepairStateFile, m, 0600, 0)
   847  	if err != nil {
   848  		return fmt.Errorf("cannot save repair state: %v", err)
   849  	}
   850  	run.stateModified = false
   851  	return nil
   852  }
   853  
   854  func stringList(headers map[string]interface{}, name string) ([]string, error) {
   855  	v, ok := headers[name]
   856  	if !ok {
   857  		return nil, nil
   858  	}
   859  	l, ok := v.([]interface{})
   860  	if !ok {
   861  		return nil, fmt.Errorf("header %q is not a list", name)
   862  	}
   863  	r := make([]string, len(l))
   864  	for i, v := range l {
   865  		s, ok := v.(string)
   866  		if !ok {
   867  			return nil, fmt.Errorf("header %q contains non-string elements", name)
   868  		}
   869  		r[i] = s
   870  	}
   871  	return r, nil
   872  }
   873  
   874  // Applicable returns whether a repair with the given headers is applicable to the device.
   875  func (run *Runner) Applicable(headers map[string]interface{}) bool {
   876  	if headers["disabled"] == "true" {
   877  		return false
   878  	}
   879  	series, err := stringList(headers, "series")
   880  	if err != nil {
   881  		return false
   882  	}
   883  	if len(series) != 0 && !strutil.ListContains(series, release.Series) {
   884  		return false
   885  	}
   886  	archs, err := stringList(headers, "architectures")
   887  	if err != nil {
   888  		return false
   889  	}
   890  	if len(archs) != 0 && !strutil.ListContains(archs, arch.DpkgArchitecture()) {
   891  		return false
   892  	}
   893  	brandModel := fmt.Sprintf("%s/%s", run.state.Device.Brand, run.state.Device.Model)
   894  	models, err := stringList(headers, "models")
   895  	if err != nil {
   896  		return false
   897  	}
   898  	if len(models) != 0 && !strutil.ListContains(models, brandModel) {
   899  		// model prefix matching: brand/prefix*
   900  		hit := false
   901  		for _, patt := range models {
   902  			if strings.HasSuffix(patt, "*") && strings.ContainsRune(patt, '/') {
   903  				if strings.HasPrefix(brandModel, strings.TrimSuffix(patt, "*")) {
   904  					hit = true
   905  					break
   906  				}
   907  			}
   908  		}
   909  		if !hit {
   910  			return false
   911  		}
   912  	}
   913  
   914  	// also filter by base snaps and modes
   915  	bases, err := stringList(headers, "bases")
   916  	if err != nil {
   917  		return false
   918  	}
   919  
   920  	if len(bases) != 0 && !strutil.ListContains(bases, run.state.Device.Base) {
   921  		return false
   922  	}
   923  
   924  	modes, err := stringList(headers, "modes")
   925  	if err != nil {
   926  		return false
   927  	}
   928  
   929  	// modes is slightly more nuanced, if the modes setting in the assertion
   930  	// header is unset, then it means it runs on all uc16/uc18 devices, but only
   931  	// during run mode on uc20 devices
   932  	if run.state.Device.Mode == "" {
   933  		// uc16 / uc18 device, the assertion is only applicable to us if modes
   934  		// is unset
   935  		if len(modes) != 0 {
   936  			return false
   937  		}
   938  		// else modes is unset and still applies to us
   939  	} else {
   940  		// uc20 device
   941  		switch {
   942  		case len(modes) == 0 && run.state.Device.Mode != "run":
   943  			// if modes is unset, then it is only applicable if we are
   944  			// in run mode
   945  			return false
   946  		case len(modes) != 0 && !strutil.ListContains(modes, run.state.Device.Mode):
   947  			// modes was specified and our current mode is not in the header, so
   948  			// not applicable to us
   949  			return false
   950  		}
   951  		// other cases are either that we are in run mode and modes is unset (in
   952  		// which case it is applicable) or modes is set to something with our
   953  		// current mode in the list (also in which case it is applicable)
   954  	}
   955  
   956  	return true
   957  }
   958  
   959  var errSkip = errors.New("repair unnecessary on this system")
   960  
   961  func (run *Runner) fetch(brandID string, repairID int) (repair *asserts.Repair, aux []asserts.Assertion, err error) {
   962  	headers, err := run.Peek(brandID, repairID)
   963  	if err != nil {
   964  		return nil, nil, err
   965  	}
   966  	if !run.Applicable(headers) {
   967  		return nil, nil, errSkip
   968  	}
   969  	return run.Fetch(brandID, repairID, -1)
   970  }
   971  
   972  func (run *Runner) refetch(brandID string, repairID, revision int) (repair *asserts.Repair, aux []asserts.Assertion, err error) {
   973  	return run.Fetch(brandID, repairID, revision)
   974  }
   975  
   976  func (run *Runner) saveStream(brandID string, repairID int, repair *asserts.Repair, aux []asserts.Assertion) error {
   977  	d := filepath.Join(dirs.SnapRepairAssertsDir, brandID, strconv.Itoa(repairID))
   978  	err := os.MkdirAll(d, 0775)
   979  	if err != nil {
   980  		return err
   981  	}
   982  	buf := &bytes.Buffer{}
   983  	enc := asserts.NewEncoder(buf)
   984  	r := append([]asserts.Assertion{repair}, aux...)
   985  	for _, a := range r {
   986  		if err := enc.Encode(a); err != nil {
   987  			return fmt.Errorf("cannot encode repair assertions %s-%d for saving: %v", brandID, repairID, err)
   988  		}
   989  	}
   990  	p := filepath.Join(d, fmt.Sprintf("r%d.repair", r[0].Revision()))
   991  	return osutil.AtomicWriteFile(p, buf.Bytes(), 0600, 0)
   992  }
   993  
   994  func (run *Runner) readSavedStream(brandID string, repairID, revision int) (repair *asserts.Repair, aux []asserts.Assertion, err error) {
   995  	d := filepath.Join(dirs.SnapRepairAssertsDir, brandID, strconv.Itoa(repairID))
   996  	p := filepath.Join(d, fmt.Sprintf("r%d.repair", revision))
   997  	f, err := os.Open(p)
   998  	if err != nil {
   999  		return nil, nil, err
  1000  	}
  1001  	defer f.Close()
  1002  
  1003  	dec := asserts.NewDecoder(f)
  1004  	var r []asserts.Assertion
  1005  	for {
  1006  		a, err := dec.Decode()
  1007  		if err == io.EOF {
  1008  			break
  1009  		}
  1010  		if err != nil {
  1011  			return nil, nil, fmt.Errorf("cannot decode repair assertions %s-%d from disk: %v", brandID, repairID, err)
  1012  		}
  1013  		r = append(r, a)
  1014  	}
  1015  	return checkStream(brandID, repairID, r)
  1016  }
  1017  
  1018  func (run *Runner) makeReady(brandID string, sequenceNext int) (repair *asserts.Repair, err error) {
  1019  	sequence := run.state.Sequences[brandID]
  1020  	var aux []asserts.Assertion
  1021  	var state RepairState
  1022  	if sequenceNext <= len(sequence) {
  1023  		// consider retries
  1024  		state = *sequence[sequenceNext-1]
  1025  		if state.Status != RetryStatus {
  1026  			return nil, errSkip
  1027  		}
  1028  		var err error
  1029  		repair, aux, err = run.refetch(brandID, state.Sequence, state.Revision)
  1030  		if err != nil {
  1031  			if err != ErrRepairNotModified {
  1032  				logger.Noticef("cannot refetch repair %s-%d, will retry what is on disk: %v", brandID, sequenceNext, err)
  1033  			}
  1034  			// try to use what we have already on disk
  1035  			repair, aux, err = run.readSavedStream(brandID, state.Sequence, state.Revision)
  1036  			if err != nil {
  1037  				return nil, err
  1038  			}
  1039  		}
  1040  	} else {
  1041  		// fetch the next repair in the sequence
  1042  		// assumes no gaps, each repair id is present so far,
  1043  		// possibly skipped
  1044  		var err error
  1045  		repair, aux, err = run.fetch(brandID, sequenceNext)
  1046  		if err != nil && err != errSkip {
  1047  			return nil, err
  1048  		}
  1049  		state = RepairState{
  1050  			Sequence: sequenceNext,
  1051  		}
  1052  		if err == errSkip {
  1053  			// TODO: store headers to justify decision
  1054  			state.Status = SkipStatus
  1055  			run.setRepairState(brandID, state)
  1056  			return nil, errSkip
  1057  		}
  1058  	}
  1059  	// verify with signatures
  1060  	if err := run.Verify(repair, aux); err != nil {
  1061  		return nil, fmt.Errorf("cannot verify repair %s-%d: %v", brandID, state.Sequence, err)
  1062  	}
  1063  	if err := run.saveStream(brandID, state.Sequence, repair, aux); err != nil {
  1064  		return nil, err
  1065  	}
  1066  	state.Revision = repair.Revision()
  1067  	if !run.Applicable(repair.Headers()) {
  1068  		state.Status = SkipStatus
  1069  		run.setRepairState(brandID, state)
  1070  		return nil, errSkip
  1071  	}
  1072  	run.setRepairState(brandID, state)
  1073  	return repair, nil
  1074  }
  1075  
  1076  // Next returns the next repair for the brand id sequence to run/retry or
  1077  // ErrRepairNotFound if there is none atm. It updates the state as required.
  1078  func (run *Runner) Next(brandID string) (*Repair, error) {
  1079  	sequenceNext := run.sequenceNext[brandID]
  1080  	if sequenceNext == 0 {
  1081  		sequenceNext = 1
  1082  	}
  1083  	for {
  1084  		repair, err := run.makeReady(brandID, sequenceNext)
  1085  		// SaveState is a no-op unless makeReady modified the state
  1086  		stateErr := run.SaveState()
  1087  		if err != nil && err != errSkip && err != ErrRepairNotFound {
  1088  			// err is a non trivial error, just log the SaveState error and report err
  1089  			if stateErr != nil {
  1090  				logger.Noticef("%v", stateErr)
  1091  			}
  1092  			return nil, err
  1093  		}
  1094  		if stateErr != nil {
  1095  			return nil, stateErr
  1096  		}
  1097  		if err == ErrRepairNotFound {
  1098  			return nil, ErrRepairNotFound
  1099  		}
  1100  
  1101  		sequenceNext += 1
  1102  		run.sequenceNext[brandID] = sequenceNext
  1103  		if err == errSkip {
  1104  			continue
  1105  		}
  1106  
  1107  		return &Repair{
  1108  			Repair:   repair,
  1109  			run:      run,
  1110  			sequence: sequenceNext - 1,
  1111  		}, nil
  1112  	}
  1113  }
  1114  
  1115  // Limit trust to specific keys while there's no delegation or limited
  1116  // keys support.  The obtained assertion stream may also include
  1117  // account keys that are directly or indirectly signed by a trusted
  1118  // key.
  1119  var (
  1120  	trustedRepairRootKeys []*asserts.AccountKey
  1121  )
  1122  
  1123  // Verify verifies that the repair is properly signed by the specific
  1124  // trusted root keys or by account keys in the stream (passed via aux)
  1125  // directly or indirectly signed by a trusted key.
  1126  func (run *Runner) Verify(repair *asserts.Repair, aux []asserts.Assertion) error {
  1127  	workBS := asserts.NewMemoryBackstore()
  1128  	for _, a := range aux {
  1129  		if a.Type() != asserts.AccountKeyType {
  1130  			continue
  1131  		}
  1132  		err := workBS.Put(asserts.AccountKeyType, a)
  1133  		if err != nil {
  1134  			return err
  1135  		}
  1136  	}
  1137  	trustedBS := asserts.NewMemoryBackstore()
  1138  	for _, t := range trustedRepairRootKeys {
  1139  		trustedBS.Put(asserts.AccountKeyType, t)
  1140  	}
  1141  	for _, t := range sysdb.Trusted() {
  1142  		// we do *not* add the defalt sysdb trusted account
  1143  		// keys here because the repair assertions have their
  1144  		// own *dedicated* root of trust
  1145  		if t.Type() == asserts.AccountType {
  1146  			trustedBS.Put(asserts.AccountType, t)
  1147  		}
  1148  	}
  1149  
  1150  	return verifySignatures(repair, workBS, trustedBS)
  1151  }