github.com/meulengracht/snapd@v0.0.0-20210719210640-8bde69bcc84e/overlord/snapstate/autorefresh_gating.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2021 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package snapstate
    21  
    22  import (
    23  	"encoding/json"
    24  	"fmt"
    25  	"os"
    26  	"sort"
    27  	"strings"
    28  	"time"
    29  
    30  	"github.com/snapcore/snapd/interfaces"
    31  	"github.com/snapcore/snapd/interfaces/mount"
    32  	"github.com/snapcore/snapd/logger"
    33  	"github.com/snapcore/snapd/overlord/ifacestate/ifacerepo"
    34  	"github.com/snapcore/snapd/overlord/state"
    35  	"github.com/snapcore/snapd/release"
    36  	"github.com/snapcore/snapd/snap"
    37  )
    38  
// gateAutoRefreshHookName is the hook name looked up in a snap's Hooks map to
// decide whether the snap takes part in auto-refresh gating.
var gateAutoRefreshHookName = "gate-auto-refresh"

// GateAutoRefreshAction represents the action executed by
// snapctl refresh --hold or --proceed and stored in the context of
// gate-auto-refresh hook.
type GateAutoRefreshAction int

const (
	// GateAutoRefreshProceed is the zero value of GateAutoRefreshAction
	// (presumably the action recorded for --proceed).
	GateAutoRefreshProceed GateAutoRefreshAction = iota
	// GateAutoRefreshHold is the second action value (presumably the action
	// recorded for --hold).
	GateAutoRefreshHold
)

// maxOtherHoldDuration is the cumulative hold time allowed for snaps other
// than the gating snap itself.
const maxOtherHoldDuration = time.Hour * 48
    53  
    54  var timeNow = func() time.Time {
    55  	return time.Now()
    56  }
    57  
    58  func lastRefreshed(st *state.State, snapName string) (time.Time, error) {
    59  	var snapst SnapState
    60  	if err := Get(st, snapName, &snapst); err != nil {
    61  		return time.Time{}, fmt.Errorf("internal error, cannot get snap %q: %v", snapName, err)
    62  	}
    63  	// try to get last refresh time from snapstate, but it may not be present
    64  	// for snaps installed before the introduction of last-refresh attribute.
    65  	if snapst.LastRefreshTime != nil {
    66  		return *snapst.LastRefreshTime, nil
    67  	}
    68  	snapInfo, err := snapst.CurrentInfo()
    69  	if err != nil {
    70  		return time.Time{}, err
    71  	}
    72  	// fall back to the modification time of .snap blob file as it's the best
    73  	// approximation of last refresh time.
    74  	fst, err := os.Stat(snapInfo.MountFile())
    75  	if err != nil {
    76  		return time.Time{}, err
    77  	}
    78  	return fst.ModTime(), nil
    79  }
    80  
// holdState is the per (held snap, gating snap) record kept under the
// "snaps-hold" state key.
type holdState struct {
	// FirstHeld keeps the time when the given snap was first held for refresh by a gating snap.
	FirstHeld time.Time `json:"first-held"`
	// HoldUntil stores the desired end time for holding.
	HoldUntil time.Time `json:"hold-until"`
}
    87  
    88  func refreshGating(st *state.State) (map[string]map[string]*holdState, error) {
    89  	// held snaps -> holding snap(s) -> first-held/hold-until time
    90  	var gating map[string]map[string]*holdState
    91  	err := st.Get("snaps-hold", &gating)
    92  	if err != nil && err != state.ErrNoState {
    93  		return nil, fmt.Errorf("internal error: cannot get snaps-hold: %v", err)
    94  	}
    95  	if err == state.ErrNoState {
    96  		return make(map[string]map[string]*holdState), nil
    97  	}
    98  	return gating, nil
    99  }
   100  
   101  // HoldDurationError contains the that error prevents requested hold, along with
   102  // hold time that's left (if any).
   103  type HoldDurationError struct {
   104  	Err          error
   105  	DurationLeft time.Duration
   106  }
   107  
   108  func (h *HoldDurationError) Error() string {
   109  	return h.Err.Error()
   110  }
   111  
   112  // HoldError contains the details of snaps that cannot to be held.
   113  type HoldError struct {
   114  	SnapsInError map[string]HoldDurationError
   115  }
   116  
   117  func (h *HoldError) Error() string {
   118  	l := []string{""}
   119  	for _, e := range h.SnapsInError {
   120  		l = append(l, e.Error())
   121  	}
   122  	return fmt.Sprintf("cannot hold some snaps:%s", strings.Join(l, "\n - "))
   123  }
   124  
   125  func maxAllowedPostponement(gatingSnap, affectedSnap string, maxPostponement time.Duration) time.Duration {
   126  	if affectedSnap == gatingSnap {
   127  		return maxPostponement
   128  	}
   129  	return maxOtherHoldDuration
   130  }
   131  
   132  // holdDurationLeft computes the maximum duration that's left for holding a refresh
   133  // given current time, last refresh time, time when snap was first held, maximum
   134  // duration allowed for the given snap and maximum overall postponement allowed by
   135  // snapd.
   136  func holdDurationLeft(now time.Time, lastRefresh, firstHeld time.Time, maxDuration, maxPostponement time.Duration) time.Duration {
   137  	d1 := firstHeld.Add(maxDuration).Sub(now)
   138  	d2 := lastRefresh.Add(maxPostponement).Sub(now)
   139  	if d1 < d2 {
   140  		return d1
   141  	}
   142  	return d2
   143  }
   144  
   145  // HoldRefresh marks affectingSnaps as held for refresh for up to holdTime.
   146  // HoldTime of zero denotes maximum allowed hold time.
   147  // Holding may fail for only some snaps in which case HoldError is returned and
   148  // it contains the details of failed ones.
   149  func HoldRefresh(st *state.State, gatingSnap string, holdDuration time.Duration, affectingSnaps ...string) error {
   150  	gating, err := refreshGating(st)
   151  	if err != nil {
   152  		return err
   153  	}
   154  	herr := &HoldError{
   155  		SnapsInError: make(map[string]HoldDurationError),
   156  	}
   157  	now := timeNow()
   158  	for _, heldSnap := range affectingSnaps {
   159  		hold, ok := gating[heldSnap][gatingSnap]
   160  		if !ok {
   161  			hold = &holdState{
   162  				FirstHeld: now,
   163  			}
   164  		}
   165  
   166  		lastRefreshTime, err := lastRefreshed(st, heldSnap)
   167  		if err != nil {
   168  			return err
   169  		}
   170  
   171  		mp := maxPostponement - maxPostponementBuffer
   172  		maxDur := maxAllowedPostponement(gatingSnap, heldSnap, mp)
   173  
   174  		// calculate max hold duration that's left considering previous hold
   175  		// requests of this snap and last refresh time.
   176  		left := holdDurationLeft(now, lastRefreshTime, hold.FirstHeld, maxDur, mp)
   177  		if left <= 0 {
   178  			herr.SnapsInError[heldSnap] = HoldDurationError{
   179  				Err: fmt.Errorf("snap %q cannot hold snap %q anymore, maximum refresh postponement exceeded", gatingSnap, heldSnap),
   180  			}
   181  			continue
   182  		}
   183  
   184  		dur := holdDuration
   185  		if dur == 0 {
   186  			// duration not specified, using a default one (maximum) or what's
   187  			// left of it.
   188  			dur = left
   189  		} else {
   190  			// explicit hold duration requested
   191  			if dur > maxDur {
   192  				herr.SnapsInError[heldSnap] = HoldDurationError{
   193  					Err:          fmt.Errorf("requested holding duration for snap %q of %s by snap %q exceeds maximum holding time", heldSnap, holdDuration, gatingSnap),
   194  					DurationLeft: left,
   195  				}
   196  				continue
   197  			}
   198  		}
   199  
   200  		newHold := now.Add(dur)
   201  		cutOff := lastRefreshTime.Add(maxPostponement - maxPostponementBuffer)
   202  
   203  		// consider last refresh time and adjust hold duration if needed so it's
   204  		// not exceeded.
   205  		if newHold.Before(cutOff) {
   206  			hold.HoldUntil = newHold
   207  		} else {
   208  			hold.HoldUntil = cutOff
   209  		}
   210  
   211  		// finally store/update gating hold data
   212  		if _, ok := gating[heldSnap]; !ok {
   213  			gating[heldSnap] = make(map[string]*holdState)
   214  		}
   215  		gating[heldSnap][gatingSnap] = hold
   216  	}
   217  
   218  	if len(herr.SnapsInError) != len(affectingSnaps) {
   219  		st.Set("snaps-hold", gating)
   220  	}
   221  	if len(herr.SnapsInError) > 0 {
   222  		return herr
   223  	}
   224  	return nil
   225  }
   226  
   227  // ProceedWithRefresh unblocks all snaps held by gatingSnap for refresh. This
   228  // should be called for --proceed on the gatingSnap.
   229  func ProceedWithRefresh(st *state.State, gatingSnap string) error {
   230  	gating, err := refreshGating(st)
   231  	if err != nil {
   232  		return err
   233  	}
   234  	if len(gating) == 0 {
   235  		return nil
   236  	}
   237  
   238  	var changed bool
   239  	for heldSnap, gatingSnaps := range gating {
   240  		if _, ok := gatingSnaps[gatingSnap]; ok {
   241  			delete(gatingSnaps, gatingSnap)
   242  			changed = true
   243  		}
   244  		if len(gatingSnaps) == 0 {
   245  			delete(gating, heldSnap)
   246  		}
   247  	}
   248  
   249  	if changed {
   250  		st.Set("snaps-hold", gating)
   251  	}
   252  	return nil
   253  }
   254  
   255  // pruneGating removes affecting snaps that are not in candidates (meaning
   256  // there is no update for them anymore).
   257  func pruneGating(st *state.State, candidates map[string]*refreshCandidate) error {
   258  	gating, err := refreshGating(st)
   259  	if err != nil {
   260  		return err
   261  	}
   262  
   263  	if len(gating) == 0 {
   264  		return nil
   265  	}
   266  
   267  	var changed bool
   268  	for affectingSnap := range gating {
   269  		if candidates[affectingSnap] == nil {
   270  			// the snap doesn't have an update anymore, forget it
   271  			delete(gating, affectingSnap)
   272  			changed = true
   273  		}
   274  	}
   275  	if changed {
   276  		st.Set("snaps-hold", gating)
   277  	}
   278  	return nil
   279  }
   280  
   281  // resetGatingForRefreshed resets gating information by removing refreshedSnaps
   282  // (they are not held anymore). This should be called for snaps about to be
   283  // refreshed.
   284  func resetGatingForRefreshed(st *state.State, refreshedSnaps ...string) error {
   285  	gating, err := refreshGating(st)
   286  	if err != nil {
   287  		return err
   288  	}
   289  	if len(gating) == 0 {
   290  		return nil
   291  	}
   292  
   293  	var changed bool
   294  	for _, snapName := range refreshedSnaps {
   295  		if _, ok := gating[snapName]; ok {
   296  			delete(gating, snapName)
   297  			changed = true
   298  		}
   299  	}
   300  
   301  	if changed {
   302  		st.Set("snaps-hold", gating)
   303  	}
   304  	return nil
   305  }
   306  
   307  // pruneSnapsHold removes the given snap from snaps-hold, whether it was an
   308  // affecting snap or gating snap. This should be called when a snap gets
   309  // removed.
   310  func pruneSnapsHold(st *state.State, snapName string) error {
   311  	gating, err := refreshGating(st)
   312  	if err != nil {
   313  		return err
   314  	}
   315  	if len(gating) == 0 {
   316  		return nil
   317  	}
   318  
   319  	var changed bool
   320  
   321  	if _, ok := gating[snapName]; ok {
   322  		delete(gating, snapName)
   323  		changed = true
   324  	}
   325  
   326  	for heldSnap, holdingSnaps := range gating {
   327  		if _, ok := holdingSnaps[snapName]; ok {
   328  			delete(holdingSnaps, snapName)
   329  			if len(holdingSnaps) == 0 {
   330  				delete(gating, heldSnap)
   331  			}
   332  			changed = true
   333  		}
   334  	}
   335  
   336  	if changed {
   337  		st.Set("snaps-hold", gating)
   338  	}
   339  
   340  	return nil
   341  }
   342  
   343  // heldSnaps returns all snaps that are gated and shouldn't be refreshed.
   344  func heldSnaps(st *state.State) (map[string]bool, error) {
   345  	gating, err := refreshGating(st)
   346  	if err != nil {
   347  		return nil, err
   348  	}
   349  	if len(gating) == 0 {
   350  		return nil, nil
   351  	}
   352  
   353  	now := timeNow()
   354  
   355  	held := make(map[string]bool)
   356  Loop:
   357  	for heldSnap, holdingSnaps := range gating {
   358  		refreshed, err := lastRefreshed(st, heldSnap)
   359  		if err != nil {
   360  			return nil, err
   361  		}
   362  		// make sure we don't hold any snap for more than maxPostponement
   363  		if refreshed.Add(maxPostponement).Before(now) {
   364  			continue
   365  		}
   366  		for _, hold := range holdingSnaps {
   367  			if hold.HoldUntil.Before(now) {
   368  				continue
   369  			}
   370  			held[heldSnap] = true
   371  			continue Loop
   372  		}
   373  	}
   374  	return held, nil
   375  }
   376  
// affectedSnapInfo describes how a snap is affected by the refresh of other
// snaps, as computed by affectedByRefresh.
type affectedSnapInfo struct {
	// Restart is set when at least one affecting snap was registered with
	// the restart flag (boot base, kernel, gadget or connected
	// interfaces).
	Restart bool
	// Base is set when the snap is affected by the refresh of its base snap.
	Base bool
	// AffectingSnaps is the set of instance names of the refreshed snaps
	// that affect this snap.
	AffectingSnaps map[string]bool
}
   382  
// affectedByRefresh computes which installed snaps are affected by
// refreshing the given updates. Only active snaps that have a
// gate-auto-refresh hook are considered. The result maps the instance name
// of each affected snap to the details of how it is affected.
//
// A considered snap is affected by an update when:
//   - the update is the boot base (non-classic systems only), a kernel or a
//     gadget: all considered snaps are affected (with restart);
//   - the update is a base snap or "core": snaps using it as their base are
//     affected (with base);
//   - the update is not snapd/core and has connected slots: the snaps on the
//     plug side are affected (with restart);
//   - the update is snapd or of OS type and has connected slots of an
//     interface using the mount backend: the snaps on the plug side are
//     affected (with restart);
//   - the update has connected plugs of an interface using the mount
//     backend: the snaps on the other side are affected (with restart).
//
// An error is returned when snap states, device context, boot base info,
// snap info or connections cannot be obtained, or when an update refers to
// an interface unknown to the repository.
func affectedByRefresh(st *state.State, updates []*snap.Info) (map[string]*affectedSnapInfo, error) {
	all, err := All(st)
	if err != nil {
		return nil, err
	}

	// bootBase stays empty on classic systems
	var bootBase string
	if !release.OnClassic {
		deviceCtx, err := DeviceCtx(st, nil, nil)
		if err != nil {
			return nil, fmt.Errorf("cannot get device context: %v", err)
		}
		bootBaseInfo, err := BootBaseInfo(st, deviceCtx)
		if err != nil {
			return nil, fmt.Errorf("cannot get boot base info: %v", err)
		}
		bootBase = bootBaseInfo.InstanceName()
	}

	// byBase maps base snap name -> instance names of considered snaps using
	// it; at the same time "all" is reduced to the considered snaps only.
	byBase := make(map[string][]string)
	for name, snapSt := range all {
		if !snapSt.Active {
			delete(all, name)
			continue
		}
		inf, err := snapSt.CurrentInfo()
		if err != nil {
			return nil, err
		}
		// optimization: do not consider snaps that don't have gate-auto-refresh hook.
		if inf.Hooks[gateAutoRefreshHookName] == nil {
			delete(all, name)
			continue
		}

		base := inf.Base
		// snaps with base "none" stay in "all" but are not recorded in byBase
		if base == "none" {
			continue
		}
		// an empty base is treated as "core"
		if inf.Base == "" {
			base = "core"
		}
		byBase[base] = append(byBase[base], inf.InstanceName())
	}

	affected := make(map[string]*affectedSnapInfo)

	// addAffected records that snapName is affected by affectedBy; the
	// restart and base flags are only ever switched on, never cleared.
	addAffected := func(snapName, affectedBy string, restart bool, base bool) {
		if affected[snapName] == nil {
			affected[snapName] = &affectedSnapInfo{
				AffectingSnaps: map[string]bool{},
			}
		}
		affectedInfo := affected[snapName]
		if restart {
			affectedInfo.Restart = restart
		}
		if base {
			affectedInfo.Base = base
		}
		affectedInfo.AffectingSnaps[affectedBy] = true
	}

	for _, up := range updates {
		// on core system, affected by update of boot base
		if bootBase != "" && up.InstanceName() == bootBase {
			for _, snapSt := range all {
				addAffected(snapSt.InstanceName(), up.InstanceName(), true, false)
			}
		}

		// snaps that can trigger reboot
		// XXX: gadget refresh doesn't always require reboot, refine this
		if up.Type() == snap.TypeKernel || up.Type() == snap.TypeGadget {
			for _, snapSt := range all {
				addAffected(snapSt.InstanceName(), up.InstanceName(), true, false)
			}
			// everything is already affected, skip the remaining checks
			continue
		}
		if up.Type() == snap.TypeBase || up.SnapName() == "core" {
			// affected by refresh of this base snap
			for _, snapName := range byBase[up.InstanceName()] {
				addAffected(snapName, up.InstanceName(), false, true)
			}
		}

		repo := ifacerepo.Get(st)

		// consider slots provided by refreshed snap, but exclude core and snapd
		// since they provide system-level slots that are generally not disrupted
		// by snap updates.
		if up.SnapType != snap.TypeSnapd && up.SnapName() != "core" {
			for _, slotInfo := range up.Slots {
				conns, err := repo.Connected(up.InstanceName(), slotInfo.Name)
				if err != nil {
					return nil, err
				}
				for _, cref := range conns {
					// affected only if it wasn't optimized out above
					if all[cref.PlugRef.Snap] != nil {
						addAffected(cref.PlugRef.Snap, up.InstanceName(), true, false)
					}
				}
			}
		}

		// consider mount backend plugs/slots;
		// for slot side only consider snapd/core because they are ignored by the
		// earlier loop around slots.
		if up.SnapType == snap.TypeSnapd || up.SnapType == snap.TypeOS {
			for _, slotInfo := range up.Slots {
				iface := repo.Interface(slotInfo.Interface)
				if iface == nil {
					return nil, fmt.Errorf("internal error: unknown interface %s", slotInfo.Interface)
				}
				if !usesMountBackend(iface) {
					continue
				}
				conns, err := repo.Connected(up.InstanceName(), slotInfo.Name)
				if err != nil {
					return nil, err
				}
				for _, cref := range conns {
					if all[cref.PlugRef.Snap] != nil {
						addAffected(cref.PlugRef.Snap, up.InstanceName(), true, false)
					}
				}
			}
		}
		// plug side: the affected snaps are the ones providing the connected slots
		for _, plugInfo := range up.Plugs {
			iface := repo.Interface(plugInfo.Interface)
			if iface == nil {
				return nil, fmt.Errorf("internal error: unknown interface %s", plugInfo.Interface)
			}
			if !usesMountBackend(iface) {
				continue
			}
			conns, err := repo.Connected(up.InstanceName(), plugInfo.Name)
			if err != nil {
				return nil, err
			}
			for _, cref := range conns {
				if all[cref.SlotRef.Snap] != nil {
					addAffected(cref.SlotRef.Snap, up.InstanceName(), true, false)
				}
			}
		}
	}

	return affected, nil
}
   534  
   535  // XXX: this is too wide and affects all commonInterface-based interfaces; we
   536  // need metadata on the relevant interfaces.
   537  func usesMountBackend(iface interfaces.Interface) bool {
   538  	type definer1 interface {
   539  		MountConnectedSlot(*mount.Specification, *interfaces.ConnectedPlug, *interfaces.ConnectedSlot) error
   540  	}
   541  	type definer2 interface {
   542  		MountConnectedPlug(*mount.Specification, *interfaces.ConnectedPlug, *interfaces.ConnectedSlot) error
   543  	}
   544  	type definer3 interface {
   545  		MountPermanentPlug(*mount.Specification, *snap.PlugInfo) error
   546  	}
   547  	type definer4 interface {
   548  		MountPermanentSlot(*mount.Specification, *snap.SlotInfo) error
   549  	}
   550  
   551  	if _, ok := iface.(definer1); ok {
   552  		return true
   553  	}
   554  	if _, ok := iface.(definer2); ok {
   555  		return true
   556  	}
   557  	if _, ok := iface.(definer3); ok {
   558  		return true
   559  	}
   560  	if _, ok := iface.(definer4); ok {
   561  		return true
   562  	}
   563  	return false
   564  }
   565  
   566  // createGateAutoRefreshHooks creates gate-auto-refresh hooks for all affectedSnaps.
   567  // The hooks will have their context data set from affectedSnapInfo flags (base, restart).
   568  // Hook tasks will be chained to run sequentially.
   569  func createGateAutoRefreshHooks(st *state.State, affectedSnaps map[string]*affectedSnapInfo) *state.TaskSet {
   570  	ts := state.NewTaskSet()
   571  	var prev *state.Task
   572  	// sort names for easy testing
   573  	names := make([]string, 0, len(affectedSnaps))
   574  	for snapName := range affectedSnaps {
   575  		names = append(names, snapName)
   576  	}
   577  	sort.Strings(names)
   578  	for _, snapName := range names {
   579  		affected := affectedSnaps[snapName]
   580  		hookTask := SetupGateAutoRefreshHook(st, snapName, affected.Base, affected.Restart, affected.AffectingSnaps)
   581  		// XXX: it should be fine to run the hooks in parallel
   582  		if prev != nil {
   583  			hookTask.WaitFor(prev)
   584  		}
   585  		ts.AddTask(hookTask)
   586  		prev = hookTask
   587  	}
   588  	return ts
   589  }
   590  
   591  func conditionalAutoRefreshAffectedSnaps(t *state.Task) ([]string, error) {
   592  	var snaps map[string]*json.RawMessage
   593  	if err := t.Get("snaps", &snaps); err != nil {
   594  		return nil, fmt.Errorf("internal error: cannot get snaps to update for %s task %s", t.Kind(), t.ID())
   595  	}
   596  	names := make([]string, 0, len(snaps))
   597  	for sn := range snaps {
   598  		// TODO: drop snaps once we know the outcome of gate-auto-refresh hooks.
   599  		names = append(names, sn)
   600  	}
   601  	return names, nil
   602  }
   603  
   604  // snapsToRefresh returns all snaps that should proceed with refresh considering
   605  // hold behavior.
   606  var snapsToRefresh = func(gatingTask *state.Task) ([]*refreshCandidate, error) {
   607  	var snaps map[string]*refreshCandidate
   608  	if err := gatingTask.Get("snaps", &snaps); err != nil {
   609  		return nil, err
   610  	}
   611  
   612  	held, err := heldSnaps(gatingTask.State())
   613  	if err != nil {
   614  		return nil, err
   615  	}
   616  
   617  	var skipped []string
   618  	var candidates []*refreshCandidate
   619  	for _, s := range snaps {
   620  		if !held[s.InstanceName()] {
   621  			candidates = append(candidates, s)
   622  		} else {
   623  			skipped = append(skipped, s.InstanceName())
   624  		}
   625  	}
   626  
   627  	if len(skipped) > 0 {
   628  		sort.Strings(skipped)
   629  		logger.Noticef("skipping refresh of held snaps: %s", strings.Join(skipped, ","))
   630  	}
   631  
   632  	return candidates, nil
   633  }