github.com/rigado/snapd@v2.42.5-go-mod+incompatible/cmd/snap-repair/runner.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2017 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package main
    21  
    22  import (
    23  	"bufio"
    24  	"bytes"
    25  	"crypto/tls"
    26  	"encoding/json"
    27  	"errors"
    28  	"fmt"
    29  	"io"
    30  	"io/ioutil"
    31  	"net/http"
    32  	"net/url"
    33  	"os"
    34  	"os/exec"
    35  	"path/filepath"
    36  	"strconv"
    37  	"strings"
    38  	"syscall"
    39  	"time"
    40  
    41  	"gopkg.in/retry.v1"
    42  
    43  	"github.com/snapcore/snapd/arch"
    44  	"github.com/snapcore/snapd/asserts"
    45  	"github.com/snapcore/snapd/asserts/sysdb"
    46  	"github.com/snapcore/snapd/dirs"
    47  	"github.com/snapcore/snapd/errtracker"
    48  	"github.com/snapcore/snapd/httputil"
    49  	"github.com/snapcore/snapd/logger"
    50  	"github.com/snapcore/snapd/osutil"
    51  	"github.com/snapcore/snapd/release"
    52  	"github.com/snapcore/snapd/strutil"
    53  )
    54  
    55  var (
    56  	// TODO: move inside the repairs themselves?
    57  	defaultRepairTimeout = 30 * time.Minute
    58  )
    59  
// errtrackerReportRepair is the function Repair.errtrackerReport uses to send
// failed-repair reports to the snap error tracker. It is a variable rather
// than a direct call.
// NOTE(review): presumably so tests can replace it — confirm.
var errtrackerReportRepair = errtracker.ReportRepair
    61  
// Repair is a runnable repair.
// It wraps the repair assertion together with the Runner that produced it
// and the repair's position in its brand sequence.
type Repair struct {
	*asserts.Repair

	// run is the Runner whose state SetStatus updates and saves.
	run *Runner

	// sequence is the 1-based position of the repair in its brand sequence
	// in run's state; SetStatus indexes the sequence with sequence-1.
	sequence int
}
    69  
    70  func (r *Repair) RunDir() string {
    71  	return filepath.Join(dirs.SnapRepairRunDir, r.BrandID(), strconv.Itoa(r.RepairID()))
    72  }
    73  
    74  func (r *Repair) String() string {
    75  	return fmt.Sprintf("%s-%v", r.BrandID(), r.RepairID())
    76  }
    77  
    78  // SetStatus sets the status of the repair in the state and saves the latter.
    79  func (r *Repair) SetStatus(status RepairStatus) {
    80  	brandID := r.BrandID()
    81  	cur := *r.run.state.Sequences[brandID][r.sequence-1]
    82  	cur.Status = status
    83  	r.run.setRepairState(brandID, cur)
    84  	r.run.SaveState()
    85  }
    86  
    87  // makeRepairSymlink ensures $dir/repair exists and is a symlink to
    88  // /usr/lib/snapd/snap-repair
    89  func makeRepairSymlink(dir string) (err error) {
    90  	// make "repair" binary available to the repair scripts via symlink
    91  	// to the real snap-repair
    92  	if err = os.MkdirAll(dir, 0755); err != nil {
    93  		return err
    94  	}
    95  
    96  	old := filepath.Join(dirs.CoreLibExecDir, "snap-repair")
    97  	new := filepath.Join(dir, "repair")
    98  	if err := os.Symlink(old, new); err != nil && !os.IsExist(err) {
    99  		return err
   100  	}
   101  
   102  	return nil
   103  }
   104  
// Run executes the repair script leaving execution trail files on disk.
//
// The script body is written to <RunDir>/r<revision>.script and executed
// from <RunDir>/work. Its environment gets SNAP_REPAIR_STATUS_FD,
// SNAP_REPAIR_RUN_DIR and a PATH that includes the "repair" helper. Script
// stdout/stderr go to <RunDir>/r<revision>.running. Once the script exits,
// or is killed after defaultRepairTimeout, that log is renamed to
// r<revision>.<status>. The status is the last one read from the status
// pipe, and it is stored with SetStatus.
//
// A script that fails or times out is reported to the error tracker and
// noted in the log, but it does not make Run return an error. Run returns an
// error only if preparing or starting the script fails, or if the final
// rename of the log fails.
func (r *Repair) Run() error {
	// write the script to disk
	rundir := r.RunDir()
	err := os.MkdirAll(rundir, 0775)
	if err != nil {
		return err
	}

	// ensure the script can use "repair done"
	repairToolsDir := filepath.Join(dirs.SnapRunRepairDir, "tools")
	if err := makeRepairSymlink(repairToolsDir); err != nil {
		return err
	}

	// all trail files of this revision share the r<revision> base name
	baseName := fmt.Sprintf("r%d", r.Revision())
	script := filepath.Join(rundir, baseName+".script")
	err = osutil.AtomicWriteFile(script, r.Body(), 0700, 0)
	if err != nil {
		return err
	}

	// the log starts out as r<revision>.running and is renamed at the end
	// to reflect the final status
	logPath := filepath.Join(rundir, baseName+".running")
	logf, err := os.OpenFile(logPath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0600)
	if err != nil {
		return err
	}
	defer logf.Close()

	fmt.Fprintf(logf, "repair: %s\n", r)
	fmt.Fprintf(logf, "revision: %d\n", r.Revision())
	fmt.Fprintf(logf, "summary: %s\n", r.Summary())
	fmt.Fprintf(logf, "output:\n")

	// the script reports its status by writing to statusW, which it
	// inherits as fd 3 (see SNAP_REPAIR_STATUS_FD below)
	statusR, statusW, err := os.Pipe()
	if err != nil {
		return err
	}
	defer statusR.Close()
	defer statusW.Close()

	logger.Debugf("executing %s", script)

	// run the script
	env := os.Environ()
	// we need to hardcode FD=3 because this is the FD after
	// exec.Command() forked. there is no way in go currently
	// to run something right after fork() in the child to
	// know the fd. However because go will close all fds
	// except the ones in "cmd.ExtraFiles" we are safe to set "3"
	env = append(env, "SNAP_REPAIR_STATUS_FD=3")
	env = append(env, "SNAP_REPAIR_RUN_DIR="+rundir)
	// inject repairToolDir into PATH so that the script can use
	// `repair {done,skip,retry}`
	var havePath bool
	for i, envStr := range env {
		if strings.HasPrefix(envStr, "PATH=") {
			newEnv := fmt.Sprintf("%s:%s", strings.TrimSuffix(envStr, ":"), repairToolsDir)
			env[i] = newEnv
			havePath = true
		}
	}
	if !havePath {
		env = append(env, "PATH=/usr/sbin:/usr/bin:/sbin:/bin:"+repairToolsDir)
	}

	workdir := filepath.Join(rundir, "work")
	if err := os.MkdirAll(workdir, 0700); err != nil {
		return err
	}

	cmd := exec.Command(script)
	// run the script in its own process group so that on timeout the
	// whole group can be killed (see KillProcessGroup below)
	cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
	cmd.Env = env
	cmd.Dir = workdir
	cmd.ExtraFiles = []*os.File{statusW}
	cmd.Stdout = logf
	cmd.Stderr = logf
	if err = cmd.Start(); err != nil {
		return err
	}
	// Close the parent's copy of the write end right away. Reading statusR
	// can then reach EOF once the child side is closed. The deferred Close
	// above becomes a no-op whose error is ignored.
	statusW.Close()

	// wait for repair to finish or timeout
	var scriptErr error
	killTimerCh := time.After(defaultRepairTimeout)
	// buffered, so the waiting goroutine can complete its send even if the
	// timeout branch is taken and nothing receives from doneCh
	doneCh := make(chan error, 1)
	go func() {
		doneCh <- cmd.Wait()
		close(doneCh)
	}()
	select {
	case scriptErr = <-doneCh:
		// done
	case <-killTimerCh:
		if err := osutil.KillProcessGroup(cmd); err != nil {
			logger.Noticef("cannot kill timed out repair %s: %s", r, err)
		}
		scriptErr = fmt.Errorf("repair did not finish within %s", defaultRepairTimeout)
	}
	// read repair status pipe, use the last value
	status := readStatus(statusR)
	statusPath := filepath.Join(rundir, baseName+"."+status.String())

	// if the script had an error exit status still honor what we
	// read from the status-pipe, however report the error
	if scriptErr != nil {
		scriptErr = fmt.Errorf("repair %s revision %d failed: %s", r, r.Revision(), scriptErr)
		if err := r.errtrackerReport(scriptErr, status, logPath); err != nil {
			logger.Noticef("cannot report error to errtracker: %s", err)
		}
		// ensure the error is present in the output log
		fmt.Fprintf(logf, "\n%s", scriptErr)
	}
	// r<revision>.running -> r<revision>.<status>
	if err := os.Rename(logPath, statusPath); err != nil {
		return err
	}
	r.SetStatus(status)

	return nil
}
   226  
   227  func readStatus(r io.Reader) RepairStatus {
   228  	var status RepairStatus
   229  	scanner := bufio.NewScanner(r)
   230  	for scanner.Scan() {
   231  		switch strings.TrimSpace(scanner.Text()) {
   232  		case "done":
   233  			status = DoneStatus
   234  		// TODO: support having a script skip over many and up to a given repair-id #
   235  		case "skip":
   236  			status = SkipStatus
   237  		}
   238  	}
   239  	if scanner.Err() != nil {
   240  		return RetryStatus
   241  	}
   242  	return status
   243  }
   244  
   245  // errtrackerReport reports an repairErr with the given logPath to the
   246  // snap error tracker.
   247  func (r *Repair) errtrackerReport(repairErr error, status RepairStatus, logPath string) error {
   248  	errMsg := repairErr.Error()
   249  
   250  	scriptOutput, err := ioutil.ReadFile(logPath)
   251  	if err != nil {
   252  		logger.Noticef("cannot read %s", logPath)
   253  	}
   254  	s := fmt.Sprintf("%s/%d", r.BrandID(), r.RepairID())
   255  
   256  	dupSig := fmt.Sprintf("%s\n%s\noutput:\n%s", s, errMsg, scriptOutput)
   257  	extra := map[string]string{
   258  		"Revision": strconv.Itoa(r.Revision()),
   259  		"BrandID":  r.BrandID(),
   260  		"RepairID": strconv.Itoa(r.RepairID()),
   261  		"Status":   status.String(),
   262  	}
   263  	_, err = errtrackerReportRepair(s, errMsg, dupSig, extra)
   264  	return err
   265  }
   266  
// Runner implements fetching, tracking and running repairs.
type Runner struct {
	// BaseURL is the base against which "repairs/<brand-id>/<repair-id>"
	// is resolved by Fetch and Peek.
	BaseURL *url.URL

	// cli is the HTTP client used for all requests (built by NewRunner).
	cli *http.Client

	// state is the persisted control state (see LoadState/SaveState).
	state state

	// stateModified records unsaved changes to state; SaveState is a
	// no-op while it is false.
	stateModified bool

	// sequenceNext keeps track of the next integer id in a brand sequence to considered in this run, see Next.
	sequenceNext map[string]int
}
   278  
   279  // NewRunner returns a Runner.
   280  func NewRunner() *Runner {
   281  	run := &Runner{
   282  		sequenceNext: make(map[string]int),
   283  	}
   284  	opts := httputil.ClientOptions{
   285  		MayLogBody: false,
   286  		TLSConfig: &tls.Config{
   287  			Time: run.now,
   288  		},
   289  	}
   290  	run.cli = httputil.NewHTTPClient(&opts)
   291  	return run
   292  }
   293  
   294  var (
   295  	fetchRetryStrategy = retry.LimitCount(7, retry.LimitTime(90*time.Second,
   296  		retry.Exponential{
   297  			Initial: 500 * time.Millisecond,
   298  			Factor:  2.5,
   299  		},
   300  	))
   301  
   302  	peekRetryStrategy = retry.LimitCount(5, retry.LimitTime(44*time.Second,
   303  		retry.Exponential{
   304  			Initial: 300 * time.Millisecond,
   305  			Factor:  2.5,
   306  		},
   307  	))
   308  )
   309  
   310  var (
   311  	ErrRepairNotFound    = errors.New("repair not found")
   312  	ErrRepairNotModified = errors.New("repair was not modified")
   313  )
   314  
   315  var (
   316  	maxRepairScriptSize = 24 * 1024 * 1024
   317  )
   318  
   319  // Fetch retrieves a stream with the repair with the given ids and any
   320  // auxiliary assertions. If revision>=0 the request will include an
   321  // If-None-Match header with an ETag for the revision, and
   322  // ErrRepairNotModified is returned if the revision is still current.
   323  func (run *Runner) Fetch(brandID string, repairID int, revision int) (*asserts.Repair, []asserts.Assertion, error) {
   324  	u, err := run.BaseURL.Parse(fmt.Sprintf("repairs/%s/%d", brandID, repairID))
   325  	if err != nil {
   326  		return nil, nil, err
   327  	}
   328  
   329  	var r []asserts.Assertion
   330  	resp, err := httputil.RetryRequest(u.String(), func() (*http.Response, error) {
   331  		req, err := http.NewRequest("GET", u.String(), nil)
   332  		if err != nil {
   333  			return nil, err
   334  		}
   335  		req.Header.Set("User-Agent", httputil.UserAgent())
   336  		req.Header.Set("Accept", "application/x.ubuntu.assertion")
   337  		if revision >= 0 {
   338  			req.Header.Set("If-None-Match", fmt.Sprintf(`"%d"`, revision))
   339  		}
   340  		return run.cli.Do(req)
   341  	}, func(resp *http.Response) error {
   342  		if resp.StatusCode == 200 {
   343  			logger.Debugf("fetching repair %s-%d", brandID, repairID)
   344  			// decode assertions
   345  			dec := asserts.NewDecoderWithTypeMaxBodySize(resp.Body, map[*asserts.AssertionType]int{
   346  				asserts.RepairType: maxRepairScriptSize,
   347  			})
   348  			for {
   349  				a, err := dec.Decode()
   350  				if err == io.EOF {
   351  					break
   352  				}
   353  				if err != nil {
   354  					return err
   355  				}
   356  				r = append(r, a)
   357  			}
   358  			if len(r) == 0 {
   359  				return io.ErrUnexpectedEOF
   360  			}
   361  		}
   362  		return nil
   363  	}, fetchRetryStrategy)
   364  
   365  	if err != nil {
   366  		return nil, nil, err
   367  	}
   368  
   369  	moveTimeLowerBound := true
   370  	defer func() {
   371  		if moveTimeLowerBound {
   372  			t, _ := http.ParseTime(resp.Header.Get("Date"))
   373  			run.moveTimeLowerBound(t)
   374  		}
   375  	}()
   376  
   377  	switch resp.StatusCode {
   378  	case 200:
   379  		// ok
   380  	case 304:
   381  		// not modified
   382  		return nil, nil, ErrRepairNotModified
   383  	case 404:
   384  		return nil, nil, ErrRepairNotFound
   385  	default:
   386  		moveTimeLowerBound = false
   387  		return nil, nil, fmt.Errorf("cannot fetch repair, unexpected status %d", resp.StatusCode)
   388  	}
   389  
   390  	repair, aux, err := checkStream(brandID, repairID, r)
   391  	if err != nil {
   392  		return nil, nil, fmt.Errorf("cannot fetch repair, %v", err)
   393  	}
   394  
   395  	if repair.Revision() <= revision {
   396  		// this shouldn't happen but if it does we behave like
   397  		// all the rest of assertion infrastructure and ignore
   398  		// the now superseded revision
   399  		return nil, nil, ErrRepairNotModified
   400  	}
   401  
   402  	return repair, aux, err
   403  }
   404  
   405  func checkStream(brandID string, repairID int, r []asserts.Assertion) (repair *asserts.Repair, aux []asserts.Assertion, err error) {
   406  	if len(r) == 0 {
   407  		return nil, nil, fmt.Errorf("empty repair assertions stream")
   408  	}
   409  	var ok bool
   410  	repair, ok = r[0].(*asserts.Repair)
   411  	if !ok {
   412  		return nil, nil, fmt.Errorf("unexpected first assertion %q", r[0].Type().Name)
   413  	}
   414  
   415  	if repair.BrandID() != brandID || repair.RepairID() != repairID {
   416  		return nil, nil, fmt.Errorf("repair id mismatch %s/%d != %s/%d", repair.BrandID(), repair.RepairID(), brandID, repairID)
   417  	}
   418  
   419  	return repair, r[1:], nil
   420  }
   421  
// peekResp is the JSON document decoded by Peek, which requests it with
// "Accept: application/json". Only the repair's headers are used.
type peekResp struct {
	Headers map[string]interface{} `json:"headers"`
}
   425  
   426  // Peek retrieves the headers for the repair with the given ids.
   427  func (run *Runner) Peek(brandID string, repairID int) (headers map[string]interface{}, err error) {
   428  	u, err := run.BaseURL.Parse(fmt.Sprintf("repairs/%s/%d", brandID, repairID))
   429  	if err != nil {
   430  		return nil, err
   431  	}
   432  
   433  	var rsp peekResp
   434  
   435  	resp, err := httputil.RetryRequest(u.String(), func() (*http.Response, error) {
   436  		req, err := http.NewRequest("GET", u.String(), nil)
   437  		if err != nil {
   438  			return nil, err
   439  		}
   440  		req.Header.Set("User-Agent", httputil.UserAgent())
   441  		req.Header.Set("Accept", "application/json")
   442  		return run.cli.Do(req)
   443  	}, func(resp *http.Response) error {
   444  		rsp.Headers = nil
   445  		if resp.StatusCode == 200 {
   446  			dec := json.NewDecoder(resp.Body)
   447  			return dec.Decode(&rsp)
   448  		}
   449  		return nil
   450  	}, peekRetryStrategy)
   451  
   452  	if err != nil {
   453  		return nil, err
   454  	}
   455  
   456  	moveTimeLowerBound := true
   457  	defer func() {
   458  		if moveTimeLowerBound {
   459  			t, _ := http.ParseTime(resp.Header.Get("Date"))
   460  			run.moveTimeLowerBound(t)
   461  		}
   462  	}()
   463  
   464  	switch resp.StatusCode {
   465  	case 200:
   466  		// ok
   467  	case 404:
   468  		return nil, ErrRepairNotFound
   469  	default:
   470  		moveTimeLowerBound = false
   471  		return nil, fmt.Errorf("cannot peek repair headers, unexpected status %d", resp.StatusCode)
   472  	}
   473  
   474  	headers = rsp.Headers
   475  	if headers["brand-id"] != brandID || headers["repair-id"] != strconv.Itoa(repairID) {
   476  		return nil, fmt.Errorf("cannot peek repair headers, repair id mismatch %s/%s != %s/%d", headers["brand-id"], headers["repair-id"], brandID, repairID)
   477  	}
   478  
   479  	return headers, nil
   480  }
   481  
// deviceInfo captures information about the device.
// Both fields are set by initDeviceInfo from the model assertion in the
// seed. Applicable matches them as "<Brand>/<Model>" against a repair's
// "models" header.
type deviceInfo struct {
	Brand string `json:"brand"`
	Model string `json:"model"`
}
   487  
   488  // RepairStatus represents the possible statuses of a repair.
   489  type RepairStatus int
   490  
   491  const (
   492  	RetryStatus RepairStatus = iota
   493  	SkipStatus
   494  	DoneStatus
   495  )
   496  
   497  func (rs RepairStatus) String() string {
   498  	switch rs {
   499  	case RetryStatus:
   500  		return "retry"
   501  	case SkipStatus:
   502  		return "skip"
   503  	case DoneStatus:
   504  		return "done"
   505  	default:
   506  		return "unknown"
   507  	}
   508  }
   509  
// RepairState holds the current revision and status of a repair in a sequence of repairs.
type RepairState struct {
	// Sequence is the 1-based position of the repair in its brand
	// sequence. It is also used as the repair id when refetching (see
	// makeReady).
	Sequence int `json:"sequence"`

	// Revision is the revision of the repair last obtained; it stays 0
	// for repairs skipped before being fetched.
	Revision int `json:"revision"`

	// Status is the last recorded status of the repair.
	Status RepairStatus `json:"status"`
}
   516  
// state holds the atomically updated control state of the runner with sequences of repairs and their states.
// It is persisted as JSON in dirs.SnapRepairStateFile (see SaveState and readState).
type state struct {
	// Device is the brand and model from the seed model assertion.
	Device deviceInfo `json:"device"`

	// Sequences maps a brand id to its repairs. The repair with sequence
	// number n is stored at index n-1.
	Sequences map[string][]*RepairState `json:"sequences,omitempty"`

	// TimeLowerBound is the latest trusted time seen (seed.yaml mtime,
	// server Date headers). Runner.now never returns an earlier time.
	TimeLowerBound time.Time `json:"time-lower-bound"`
}
   523  
   524  func (run *Runner) setRepairState(brandID string, state RepairState) {
   525  	if run.state.Sequences == nil {
   526  		run.state.Sequences = make(map[string][]*RepairState)
   527  	}
   528  	sequence := run.state.Sequences[brandID]
   529  	if state.Sequence > len(sequence) {
   530  		run.stateModified = true
   531  		run.state.Sequences[brandID] = append(sequence, &state)
   532  	} else if *sequence[state.Sequence-1] != state {
   533  		run.stateModified = true
   534  		sequence[state.Sequence-1] = &state
   535  	}
   536  }
   537  
   538  func (run *Runner) readState() error {
   539  	r, err := os.Open(dirs.SnapRepairStateFile)
   540  	if err != nil {
   541  		return err
   542  	}
   543  	defer r.Close()
   544  	dec := json.NewDecoder(r)
   545  	return dec.Decode(&run.state)
   546  }
   547  
   548  func (run *Runner) moveTimeLowerBound(t time.Time) {
   549  	if t.After(run.state.TimeLowerBound) {
   550  		run.stateModified = true
   551  		run.state.TimeLowerBound = t.UTC()
   552  	}
   553  }
   554  
   555  var timeNow = time.Now
   556  
   557  func (run *Runner) now() time.Time {
   558  	now := timeNow().UTC()
   559  	if now.Before(run.state.TimeLowerBound) {
   560  		return run.state.TimeLowerBound
   561  	}
   562  	return now
   563  }
   564  
   565  func (run *Runner) initState() error {
   566  	if err := os.MkdirAll(dirs.SnapRepairDir, 0775); err != nil {
   567  		return fmt.Errorf("cannot create repair state directory: %v", err)
   568  	}
   569  	// best-effort remove old
   570  	os.Remove(dirs.SnapRepairStateFile)
   571  	run.state = state{}
   572  	// initialize time lower bound with image built time/seed.yaml time
   573  	info, err := os.Stat(filepath.Join(dirs.SnapSeedDir, "seed.yaml"))
   574  	if err != nil {
   575  		return err
   576  	}
   577  	run.moveTimeLowerBound(info.ModTime())
   578  	// initialize device info
   579  	if err := run.initDeviceInfo(); err != nil {
   580  		return err
   581  	}
   582  	run.stateModified = true
   583  	return run.SaveState()
   584  }
   585  
   586  func trustedBackstore(trusted []asserts.Assertion) asserts.Backstore {
   587  	trustedBS := asserts.NewMemoryBackstore()
   588  	for _, t := range trusted {
   589  		trustedBS.Put(t.Type(), t)
   590  	}
   591  	return trustedBS
   592  }
   593  
   594  func checkAuthorityID(a asserts.Assertion, trusted asserts.Backstore) error {
   595  	assertType := a.Type()
   596  	if assertType != asserts.AccountKeyType && assertType != asserts.AccountType {
   597  		return nil
   598  	}
   599  	// check that account and account-key assertions are signed by
   600  	// a trusted authority
   601  	acctID := a.AuthorityID()
   602  	_, err := trusted.Get(asserts.AccountType, []string{acctID}, asserts.AccountType.MaxSupportedFormat())
   603  	if err != nil && !asserts.IsNotFound(err) {
   604  		return err
   605  	}
   606  	if asserts.IsNotFound(err) {
   607  		return fmt.Errorf("%v not signed by trusted authority: %s", a.Ref(), acctID)
   608  	}
   609  	return nil
   610  }
   611  
// verifySignatures checks the signature chain of a.
//
// Each assertion in the chain must be signed by an account-key found either
// in trusted, which ends the chain, or in workBS. In the workBS case, that
// key is checked with checkAuthorityID and its own signature is verified
// next. The starting assertion a is also checked with checkAuthorityID.
// A chain that revisits an assertion is rejected as circular.
func verifySignatures(a asserts.Assertion, workBS asserts.Backstore, trusted asserts.Backstore) error {
	if err := checkAuthorityID(a, trusted); err != nil {
		return err
	}
	acctKeyMaxSuppFormat := asserts.AccountKeyType.MaxSupportedFormat()

	// seen guards against signing loops in the untrusted keys
	seen := make(map[string]bool)
	bottom := false
	for !bottom {
		u := a.Ref().Unique()
		if seen[u] {
			return fmt.Errorf("circular assertions")
		}
		seen[u] = true
		signKey := []string{a.SignKeyID()}
		key, err := trusted.Get(asserts.AccountKeyType, signKey, acctKeyMaxSuppFormat)
		if err != nil && !asserts.IsNotFound(err) {
			return err
		}
		if err == nil {
			// signed by a trusted key: this is the last link to check
			bottom = true
		} else {
			// not trusted directly: the key must come from the stream and
			// is itself verified in the next iteration
			key, err = workBS.Get(asserts.AccountKeyType, signKey, acctKeyMaxSuppFormat)
			if err != nil && !asserts.IsNotFound(err) {
				return err
			}
			if asserts.IsNotFound(err) {
				return fmt.Errorf("cannot find public key %q", signKey[0])
			}
			if err := checkAuthorityID(key, trusted); err != nil {
				return err
			}
		}
		if err := asserts.CheckSignature(a, key.(*asserts.AccountKey), nil, time.Time{}); err != nil {
			return err
		}
		a = key
	}
	return nil
}
   652  
   653  func (run *Runner) initDeviceInfo() error {
   654  	const errPrefix = "cannot set device information: "
   655  
   656  	workBS := asserts.NewMemoryBackstore()
   657  	assertSeedDir := filepath.Join(dirs.SnapSeedDir, "assertions")
   658  	dc, err := ioutil.ReadDir(assertSeedDir)
   659  	if err != nil {
   660  		return err
   661  	}
   662  	var model *asserts.Model
   663  	for _, fi := range dc {
   664  		fn := filepath.Join(assertSeedDir, fi.Name())
   665  		f, err := os.Open(fn)
   666  		if err != nil {
   667  			// best effort
   668  			continue
   669  		}
   670  		dec := asserts.NewDecoder(f)
   671  		for {
   672  			a, err := dec.Decode()
   673  			if err != nil {
   674  				// best effort
   675  				break
   676  			}
   677  			switch a.Type() {
   678  			case asserts.ModelType:
   679  				if model != nil {
   680  					return fmt.Errorf(errPrefix + "multiple models in seed assertions")
   681  				}
   682  				model = a.(*asserts.Model)
   683  			case asserts.AccountType, asserts.AccountKeyType:
   684  				workBS.Put(a.Type(), a)
   685  			}
   686  		}
   687  	}
   688  	if model == nil {
   689  		return fmt.Errorf(errPrefix + "no model assertion in seed data")
   690  	}
   691  	trustedBS := trustedBackstore(sysdb.Trusted())
   692  	if err := verifySignatures(model, workBS, trustedBS); err != nil {
   693  		return fmt.Errorf(errPrefix+"%v", err)
   694  	}
   695  	acctPK := []string{model.BrandID()}
   696  	acctMaxSupFormat := asserts.AccountType.MaxSupportedFormat()
   697  	acct, err := trustedBS.Get(asserts.AccountType, acctPK, acctMaxSupFormat)
   698  	if err != nil {
   699  		var err error
   700  		acct, err = workBS.Get(asserts.AccountType, acctPK, acctMaxSupFormat)
   701  		if err != nil {
   702  			return fmt.Errorf(errPrefix + "no brand account assertion in seed data")
   703  		}
   704  	}
   705  	if err := verifySignatures(acct, workBS, trustedBS); err != nil {
   706  		return fmt.Errorf(errPrefix+"%v", err)
   707  	}
   708  	run.state.Device.Brand = model.BrandID()
   709  	run.state.Device.Model = model.Model()
   710  	return nil
   711  }
   712  
   713  // LoadState loads the repairs' state from disk, and (re)initializes it if it's missing or corrupted.
   714  func (run *Runner) LoadState() error {
   715  	err := run.readState()
   716  	if err == nil {
   717  		return nil
   718  	}
   719  	// error => initialize from scratch
   720  	if !os.IsNotExist(err) {
   721  		logger.Noticef("cannor read repair state: %v", err)
   722  	}
   723  	return run.initState()
   724  }
   725  
   726  // SaveState saves the repairs' state to disk.
   727  func (run *Runner) SaveState() error {
   728  	if !run.stateModified {
   729  		return nil
   730  	}
   731  	m, err := json.Marshal(&run.state)
   732  	if err != nil {
   733  		return fmt.Errorf("cannot marshal repair state: %v", err)
   734  	}
   735  	err = osutil.AtomicWriteFile(dirs.SnapRepairStateFile, m, 0600, 0)
   736  	if err != nil {
   737  		return fmt.Errorf("cannot save repair state: %v", err)
   738  	}
   739  	run.stateModified = false
   740  	return nil
   741  }
   742  
   743  func stringList(headers map[string]interface{}, name string) ([]string, error) {
   744  	v, ok := headers[name]
   745  	if !ok {
   746  		return nil, nil
   747  	}
   748  	l, ok := v.([]interface{})
   749  	if !ok {
   750  		return nil, fmt.Errorf("header %q is not a list", name)
   751  	}
   752  	r := make([]string, len(l))
   753  	for i, v := range l {
   754  		s, ok := v.(string)
   755  		if !ok {
   756  			return nil, fmt.Errorf("header %q contains non-string elements", name)
   757  		}
   758  		r[i] = s
   759  	}
   760  	return r, nil
   761  }
   762  
   763  // Applicable returns whether a repair with the given headers is applicable to the device.
   764  func (run *Runner) Applicable(headers map[string]interface{}) bool {
   765  	if headers["disabled"] == "true" {
   766  		return false
   767  	}
   768  	series, err := stringList(headers, "series")
   769  	if err != nil {
   770  		return false
   771  	}
   772  	if len(series) != 0 && !strutil.ListContains(series, release.Series) {
   773  		return false
   774  	}
   775  	archs, err := stringList(headers, "architectures")
   776  	if err != nil {
   777  		return false
   778  	}
   779  	if len(archs) != 0 && !strutil.ListContains(archs, arch.DpkgArchitecture()) {
   780  		return false
   781  	}
   782  	brandModel := fmt.Sprintf("%s/%s", run.state.Device.Brand, run.state.Device.Model)
   783  	models, err := stringList(headers, "models")
   784  	if err != nil {
   785  		return false
   786  	}
   787  	if len(models) != 0 && !strutil.ListContains(models, brandModel) {
   788  		// model prefix matching: brand/prefix*
   789  		hit := false
   790  		for _, patt := range models {
   791  			if strings.HasSuffix(patt, "*") && strings.ContainsRune(patt, '/') {
   792  				if strings.HasPrefix(brandModel, strings.TrimSuffix(patt, "*")) {
   793  					hit = true
   794  					break
   795  				}
   796  			}
   797  		}
   798  		if !hit {
   799  			return false
   800  		}
   801  	}
   802  	return true
   803  }
   804  
   805  var errSkip = errors.New("repair unnecessary on this system")
   806  
   807  func (run *Runner) fetch(brandID string, repairID int) (repair *asserts.Repair, aux []asserts.Assertion, err error) {
   808  	headers, err := run.Peek(brandID, repairID)
   809  	if err != nil {
   810  		return nil, nil, err
   811  	}
   812  	if !run.Applicable(headers) {
   813  		return nil, nil, errSkip
   814  	}
   815  	return run.Fetch(brandID, repairID, -1)
   816  }
   817  
   818  func (run *Runner) refetch(brandID string, repairID, revision int) (repair *asserts.Repair, aux []asserts.Assertion, err error) {
   819  	return run.Fetch(brandID, repairID, revision)
   820  }
   821  
   822  func (run *Runner) saveStream(brandID string, repairID int, repair *asserts.Repair, aux []asserts.Assertion) error {
   823  	d := filepath.Join(dirs.SnapRepairAssertsDir, brandID, strconv.Itoa(repairID))
   824  	err := os.MkdirAll(d, 0775)
   825  	if err != nil {
   826  		return err
   827  	}
   828  	buf := &bytes.Buffer{}
   829  	enc := asserts.NewEncoder(buf)
   830  	r := append([]asserts.Assertion{repair}, aux...)
   831  	for _, a := range r {
   832  		if err := enc.Encode(a); err != nil {
   833  			return fmt.Errorf("cannot encode repair assertions %s-%d for saving: %v", brandID, repairID, err)
   834  		}
   835  	}
   836  	p := filepath.Join(d, fmt.Sprintf("r%d.repair", r[0].Revision()))
   837  	return osutil.AtomicWriteFile(p, buf.Bytes(), 0600, 0)
   838  }
   839  
   840  func (run *Runner) readSavedStream(brandID string, repairID, revision int) (repair *asserts.Repair, aux []asserts.Assertion, err error) {
   841  	d := filepath.Join(dirs.SnapRepairAssertsDir, brandID, strconv.Itoa(repairID))
   842  	p := filepath.Join(d, fmt.Sprintf("r%d.repair", revision))
   843  	f, err := os.Open(p)
   844  	if err != nil {
   845  		return nil, nil, err
   846  	}
   847  	defer f.Close()
   848  
   849  	dec := asserts.NewDecoder(f)
   850  	var r []asserts.Assertion
   851  	for {
   852  		a, err := dec.Decode()
   853  		if err == io.EOF {
   854  			break
   855  		}
   856  		if err != nil {
   857  			return nil, nil, fmt.Errorf("cannot decode repair assertions %s-%d from disk: %v", brandID, repairID, err)
   858  		}
   859  		r = append(r, a)
   860  	}
   861  	return checkStream(brandID, repairID, r)
   862  }
   863  
// makeReady prepares the repair at position sequenceNext of the brandID
// sequence for running, and records its state with setRepairState. Saving
// that state is left to the caller.
//
// An already known repair (sequenceNext <= len(sequence)) is reconsidered
// only when its status is RetryStatus. In that case a newer revision is
// refetched, falling back to the copy saved on disk. A repair not known yet
// is fetched after peeking at its headers. The obtained repair is then
// verified, saved to disk and checked for applicability.
//
// It returns errSkip when the repair must not be run: it is already
// done/skipped or it is not applicable (the skip is recorded in the state).
// Errors from fetching, such as ErrRepairNotFound, are returned as they are.
func (run *Runner) makeReady(brandID string, sequenceNext int) (repair *asserts.Repair, err error) {
	sequence := run.state.Sequences[brandID]
	var aux []asserts.Assertion
	var state RepairState
	if sequenceNext <= len(sequence) {
		// consider retries
		state = *sequence[sequenceNext-1]
		if state.Status != RetryStatus {
			return nil, errSkip
		}
		var err error
		repair, aux, err = run.refetch(brandID, state.Sequence, state.Revision)
		if err != nil {
			// not modified is the normal case; only other failures are
			// worth a notice
			if err != ErrRepairNotModified {
				logger.Noticef("cannot refetch repair %s-%d, will retry what is on disk: %v", brandID, sequenceNext, err)
			}
			// try to use what we have already on disk
			repair, aux, err = run.readSavedStream(brandID, state.Sequence, state.Revision)
			if err != nil {
				return nil, err
			}
		}
	} else {
		// fetch the next repair in the sequence
		// assumes no gaps, each repair id is present so far,
		// possibly skipped
		var err error
		repair, aux, err = run.fetch(brandID, sequenceNext)
		if err != nil && err != errSkip {
			return nil, err
		}
		state = RepairState{
			Sequence: sequenceNext,
		}
		if err == errSkip {
			// TODO: store headers to justify decision
			state.Status = SkipStatus
			run.setRepairState(brandID, state)
			return nil, errSkip
		}
	}
	// verify with signatures
	if err := run.Verify(repair, aux); err != nil {
		return nil, fmt.Errorf("cannot verify repair %s-%d: %v", brandID, state.Sequence, err)
	}
	if err := run.saveStream(brandID, state.Sequence, repair, aux); err != nil {
		return nil, err
	}
	state.Revision = repair.Revision()
	// the full assertion headers are checked again, not just the peeked ones
	if !run.Applicable(repair.Headers()) {
		state.Status = SkipStatus
		run.setRepairState(brandID, state)
		return nil, errSkip
	}
	run.setRepairState(brandID, state)
	return repair, nil
}
   921  
// Next returns the next repair for the brand id sequence to run/retry or ErrRepairNotFound if there is none atm. It updates the state as required.
//
// Repairs that makeReady reports as errSkip are passed over. The position
// reached is remembered in run.sequenceNext, so a later call within the same
// run continues after the last repair returned. The state is saved after
// every makeReady call.
func (run *Runner) Next(brandID string) (*Repair, error) {
	sequenceNext := run.sequenceNext[brandID]
	if sequenceNext == 0 {
		// sequences are 1-based
		sequenceNext = 1
	}
	for {
		repair, err := run.makeReady(brandID, sequenceNext)
		// SaveState is a no-op unless makeReady modified the state
		stateErr := run.SaveState()
		if err != nil && err != errSkip && err != ErrRepairNotFound {
			// err is a non trivial error, just log the SaveState error and report err
			if stateErr != nil {
				logger.Noticef("%v", stateErr)
			}
			return nil, err
		}
		if stateErr != nil {
			return nil, stateErr
		}
		if err == ErrRepairNotFound {
			// end of the sequence for now; sequenceNext is not advanced
			return nil, ErrRepairNotFound
		}

		sequenceNext += 1
		run.sequenceNext[brandID] = sequenceNext
		if err == errSkip {
			continue
		}

		return &Repair{
			Repair:   repair,
			run:      run,
			sequence: sequenceNext - 1,
		}, nil
	}
}
   959  
   960  // Limit trust to specific keys while there's no delegation or limited
   961  // keys support.  The obtained assertion stream may also include
   962  // account keys that are directly or indirectly signed by a trusted
   963  // key.
   964  var (
   965  	trustedRepairRootKeys []*asserts.AccountKey
   966  )
   967  
   968  // Verify verifies that the repair is properly signed by the specific
   969  // trusted root keys or by account keys in the stream (passed via aux)
   970  // directly or indirectly signed by a trusted key.
   971  func (run *Runner) Verify(repair *asserts.Repair, aux []asserts.Assertion) error {
   972  	workBS := asserts.NewMemoryBackstore()
   973  	for _, a := range aux {
   974  		if a.Type() != asserts.AccountKeyType {
   975  			continue
   976  		}
   977  		err := workBS.Put(asserts.AccountKeyType, a)
   978  		if err != nil {
   979  			return err
   980  		}
   981  	}
   982  	trustedBS := asserts.NewMemoryBackstore()
   983  	for _, t := range trustedRepairRootKeys {
   984  		trustedBS.Put(asserts.AccountKeyType, t)
   985  	}
   986  	for _, t := range sysdb.Trusted() {
   987  		if t.Type() == asserts.AccountType {
   988  			trustedBS.Put(asserts.AccountType, t)
   989  		}
   990  	}
   991  
   992  	return verifySignatures(repair, workBS, trustedBS)
   993  }