github.com/freetocompute/snapd@v0.0.0-20210618182524-2fb355d72fd9/overlord/snapstate/autorefresh_gating.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2021 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package snapstate
    21  
    22  import (
    23  	"fmt"
    24  	"os"
    25  	"sort"
    26  	"strings"
    27  	"time"
    28  
    29  	"github.com/snapcore/snapd/interfaces"
    30  	"github.com/snapcore/snapd/interfaces/mount"
    31  	"github.com/snapcore/snapd/logger"
    32  	"github.com/snapcore/snapd/overlord/ifacestate/ifacerepo"
    33  	"github.com/snapcore/snapd/overlord/state"
    34  	"github.com/snapcore/snapd/release"
    35  	"github.com/snapcore/snapd/snap"
    36  )
    37  
// gateAutoRefreshHookName is the name of the hook that a snap must define to
// be considered when computing snaps affected by an auto-refresh.
var gateAutoRefreshHookName = "gate-auto-refresh"
    39  
// gateAutoRefreshAction represents the action executed by
// snapctl refresh --hold or --proceed and stored in the context of
// gate-auto-refresh hook.
type gateAutoRefreshAction int

const (
	// GateAutoRefreshProceed is the zero value of gateAutoRefreshAction;
	// going by its name it corresponds to snapctl refresh --proceed.
	GateAutoRefreshProceed gateAutoRefreshAction = iota
	// GateAutoRefreshHold is the action that, going by its name, corresponds
	// to snapctl refresh --hold.
	GateAutoRefreshHold
)
    49  
// maxOtherHoldDuration is the cumulative time for which a gating snap may hold
// the refresh of snaps other than itself.
const maxOtherHoldDuration = time.Hour * 48
    52  
    53  var timeNow = func() time.Time {
    54  	return time.Now()
    55  }
    56  
    57  func lastRefreshed(st *state.State, snapName string) (time.Time, error) {
    58  	var snapst SnapState
    59  	if err := Get(st, snapName, &snapst); err != nil {
    60  		return time.Time{}, fmt.Errorf("internal error, cannot get snap %q: %v", snapName, err)
    61  	}
    62  	// try to get last refresh time from snapstate, but it may not be present
    63  	// for snaps installed before the introduction of last-refresh attribute.
    64  	if snapst.LastRefreshTime != nil {
    65  		return *snapst.LastRefreshTime, nil
    66  	}
    67  	snapInfo, err := snapst.CurrentInfo()
    68  	if err != nil {
    69  		return time.Time{}, err
    70  	}
    71  	// fall back to the modification time of .snap blob file as it's the best
    72  	// approximation of last refresh time.
    73  	fst, err := os.Stat(snapInfo.MountFile())
    74  	if err != nil {
    75  		return time.Time{}, err
    76  	}
    77  	return fst.ModTime(), nil
    78  }
    79  
// holdState describes a single hold placed by one gating snap on one held
// snap. It is stored under the "snaps-hold" state key, serialized with the
// "first-held" and "hold-until" JSON keys.
type holdState struct {
	// FirstHeld keeps the time when the given snap was first held for refresh by a gating snap.
	FirstHeld time.Time `json:"first-held"`
	// HoldUntil stores the desired end time for holding.
	HoldUntil time.Time `json:"hold-until"`
}
    86  
    87  func refreshGating(st *state.State) (map[string]map[string]*holdState, error) {
    88  	// held snaps -> holding snap(s) -> first-held/hold-until time
    89  	var gating map[string]map[string]*holdState
    90  	err := st.Get("snaps-hold", &gating)
    91  	if err != nil && err != state.ErrNoState {
    92  		return nil, fmt.Errorf("internal error: cannot get snaps-hold: %v", err)
    93  	}
    94  	if err == state.ErrNoState {
    95  		return make(map[string]map[string]*holdState), nil
    96  	}
    97  	return gating, nil
    98  }
    99  
   100  // HoldDurationError contains the that error prevents requested hold, along with
   101  // hold time that's left (if any).
   102  type HoldDurationError struct {
   103  	Err          error
   104  	DurationLeft time.Duration
   105  }
   106  
   107  func (h *HoldDurationError) Error() string {
   108  	return h.Err.Error()
   109  }
   110  
   111  // HoldError contains the details of snaps that cannot to be held.
   112  type HoldError struct {
   113  	SnapsInError map[string]HoldDurationError
   114  }
   115  
   116  func (h *HoldError) Error() string {
   117  	l := []string{""}
   118  	for _, e := range h.SnapsInError {
   119  		l = append(l, e.Error())
   120  	}
   121  	return fmt.Sprintf("cannot hold some snaps:%s", strings.Join(l, "\n - "))
   122  }
   123  
   124  func maxAllowedPostponement(gatingSnap, affectedSnap string, maxPostponement time.Duration) time.Duration {
   125  	if affectedSnap == gatingSnap {
   126  		return maxPostponement
   127  	}
   128  	return maxOtherHoldDuration
   129  }
   130  
   131  // holdDurationLeft computes the maximum duration that's left for holding a refresh
   132  // given current time, last refresh time, time when snap was first held, maximum
   133  // duration allowed for the given snap and maximum overall postponement allowed by
   134  // snapd.
   135  func holdDurationLeft(now time.Time, lastRefresh, firstHeld time.Time, maxDuration, maxPostponement time.Duration) time.Duration {
   136  	d1 := firstHeld.Add(maxDuration).Sub(now)
   137  	d2 := lastRefresh.Add(maxPostponement).Sub(now)
   138  	if d1 < d2 {
   139  		return d1
   140  	}
   141  	return d2
   142  }
   143  
   144  // HoldRefresh marks affectingSnaps as held for refresh for up to holdTime.
   145  // HoldTime of zero denotes maximum allowed hold time.
   146  // Holding may fail for only some snaps in which case HoldError is returned and
   147  // it contains the details of failed ones.
   148  func HoldRefresh(st *state.State, gatingSnap string, holdDuration time.Duration, affectingSnaps ...string) error {
   149  	gating, err := refreshGating(st)
   150  	if err != nil {
   151  		return err
   152  	}
   153  	herr := &HoldError{
   154  		SnapsInError: make(map[string]HoldDurationError),
   155  	}
   156  	now := timeNow()
   157  	for _, heldSnap := range affectingSnaps {
   158  		hold, ok := gating[heldSnap][gatingSnap]
   159  		if !ok {
   160  			hold = &holdState{
   161  				FirstHeld: now,
   162  			}
   163  		}
   164  
   165  		lastRefreshTime, err := lastRefreshed(st, heldSnap)
   166  		if err != nil {
   167  			return err
   168  		}
   169  
   170  		mp := maxPostponement - maxPostponementBuffer
   171  		maxDur := maxAllowedPostponement(gatingSnap, heldSnap, mp)
   172  
   173  		// calculate max hold duration that's left considering previous hold
   174  		// requests of this snap and last refresh time.
   175  		left := holdDurationLeft(now, lastRefreshTime, hold.FirstHeld, maxDur, mp)
   176  		if left <= 0 {
   177  			herr.SnapsInError[heldSnap] = HoldDurationError{
   178  				Err: fmt.Errorf("snap %q cannot hold snap %q anymore, maximum refresh postponement exceeded", gatingSnap, heldSnap),
   179  			}
   180  			continue
   181  		}
   182  
   183  		dur := holdDuration
   184  		if dur == 0 {
   185  			// duration not specified, using a default one (maximum) or what's
   186  			// left of it.
   187  			dur = left
   188  		} else {
   189  			// explicit hold duration requested
   190  			if dur > maxDur {
   191  				herr.SnapsInError[heldSnap] = HoldDurationError{
   192  					Err:          fmt.Errorf("requested holding duration for snap %q of %s by snap %q exceeds maximum holding time", heldSnap, holdDuration, gatingSnap),
   193  					DurationLeft: left,
   194  				}
   195  				continue
   196  			}
   197  		}
   198  
   199  		newHold := now.Add(dur)
   200  		cutOff := lastRefreshTime.Add(maxPostponement - maxPostponementBuffer)
   201  
   202  		// consider last refresh time and adjust hold duration if needed so it's
   203  		// not exceeded.
   204  		if newHold.Before(cutOff) {
   205  			hold.HoldUntil = newHold
   206  		} else {
   207  			hold.HoldUntil = cutOff
   208  		}
   209  
   210  		// finally store/update gating hold data
   211  		if _, ok := gating[heldSnap]; !ok {
   212  			gating[heldSnap] = make(map[string]*holdState)
   213  		}
   214  		gating[heldSnap][gatingSnap] = hold
   215  	}
   216  
   217  	if len(herr.SnapsInError) != len(affectingSnaps) {
   218  		st.Set("snaps-hold", gating)
   219  	}
   220  	if len(herr.SnapsInError) > 0 {
   221  		return herr
   222  	}
   223  	return nil
   224  }
   225  
   226  // ProceedWithRefresh unblocks all snaps held by gatingSnap for refresh. This
   227  // should be called for --proceed on the gatingSnap.
   228  func ProceedWithRefresh(st *state.State, gatingSnap string) error {
   229  	gating, err := refreshGating(st)
   230  	if err != nil {
   231  		return err
   232  	}
   233  	if len(gating) == 0 {
   234  		return nil
   235  	}
   236  
   237  	var changed bool
   238  	for heldSnap, gatingSnaps := range gating {
   239  		if _, ok := gatingSnaps[gatingSnap]; ok {
   240  			delete(gatingSnaps, gatingSnap)
   241  			changed = true
   242  		}
   243  		if len(gatingSnaps) == 0 {
   244  			delete(gating, heldSnap)
   245  		}
   246  	}
   247  
   248  	if changed {
   249  		st.Set("snaps-hold", gating)
   250  	}
   251  	return nil
   252  }
   253  
   254  // pruneGating removes affecting snaps that are not in candidates (meaning
   255  // there is no update for them anymore).
   256  func pruneGating(st *state.State, candidates map[string]*refreshCandidate) error {
   257  	gating, err := refreshGating(st)
   258  	if err != nil {
   259  		return err
   260  	}
   261  
   262  	if len(gating) == 0 {
   263  		return nil
   264  	}
   265  
   266  	var changed bool
   267  	for affectingSnap := range gating {
   268  		if candidates[affectingSnap] == nil {
   269  			// the snap doesn't have an update anymore, forget it
   270  			delete(gating, affectingSnap)
   271  			changed = true
   272  		}
   273  	}
   274  	if changed {
   275  		st.Set("snaps-hold", gating)
   276  	}
   277  	return nil
   278  }
   279  
   280  // resetGatingForRefreshed resets gating information by removing refreshedSnaps
   281  // (they are not held anymore). This should be called for all successfully
   282  // refreshed snaps.
   283  func resetGatingForRefreshed(st *state.State, refreshedSnaps ...string) error {
   284  	gating, err := refreshGating(st)
   285  	if err != nil {
   286  		return err
   287  	}
   288  	if len(gating) == 0 {
   289  		return nil
   290  	}
   291  
   292  	var changed bool
   293  	for _, snapName := range refreshedSnaps {
   294  		if _, ok := gating[snapName]; ok {
   295  			delete(gating, snapName)
   296  			changed = true
   297  		}
   298  	}
   299  
   300  	if changed {
   301  		st.Set("snaps-hold", gating)
   302  	}
   303  	return nil
   304  }
   305  
   306  // heldSnaps returns all snaps that are gated and shouldn't be refreshed.
   307  func heldSnaps(st *state.State) (map[string]bool, error) {
   308  	gating, err := refreshGating(st)
   309  	if err != nil {
   310  		return nil, err
   311  	}
   312  	if len(gating) == 0 {
   313  		return nil, nil
   314  	}
   315  
   316  	now := timeNow()
   317  
   318  	held := make(map[string]bool)
   319  Loop:
   320  	for heldSnap, holdingSnaps := range gating {
   321  		refreshed, err := lastRefreshed(st, heldSnap)
   322  		if err != nil {
   323  			return nil, err
   324  		}
   325  		// make sure we don't hold any snap for more than maxPostponement
   326  		if refreshed.Add(maxPostponement).Before(now) {
   327  			continue
   328  		}
   329  		for _, hold := range holdingSnaps {
   330  			if hold.HoldUntil.Before(now) {
   331  				continue
   332  			}
   333  			held[heldSnap] = true
   334  			continue Loop
   335  		}
   336  	}
   337  	return held, nil
   338  }
   339  
// affectedSnapInfo describes how a snap is affected by the refresh of other
// snaps, as computed by affectedByRefresh.
type affectedSnapInfo struct {
	// Restart is set when the snap is affected by the refresh of the boot
	// base, a kernel or a gadget, or through a connected slot or plug.
	Restart bool
	// Base is set when the snap is affected because its base snap is being
	// refreshed.
	Base bool
	// AffectingSnaps is the set of instance names of the refreshed snaps
	// that affect this snap.
	AffectingSnaps map[string]bool
}
   345  
// affectedByRefresh computes which installed snaps are affected by refreshing
// the given updates. Only active snaps that define the gate-auto-refresh hook
// are considered. The result maps the instance name of each affected snap to
// an affectedSnapInfo describing how it is affected and by which updates.
//
// A considered snap is affected by an update when:
//   - the update is the boot base (non-classic systems only), a kernel or a
//     gadget: all considered snaps are affected;
//   - the update is a base snap (or "core") used as the base of the snap;
//   - the snap has a plug connected to a slot of the update (core and snapd
//     updates are excluded from this check);
//   - the snap is connected to the update through an interface for which
//     usesMountBackend is true (slots are checked for snapd/core updates
//     only, plugs for every update).
func affectedByRefresh(st *state.State, updates []*snap.Info) (map[string]*affectedSnapInfo, error) {
	all, err := All(st)
	if err != nil {
		return nil, err
	}

	// bootBase stays empty on classic systems, which disables the boot base
	// check below.
	var bootBase string
	if !release.OnClassic {
		deviceCtx, err := DeviceCtx(st, nil, nil)
		if err != nil {
			return nil, fmt.Errorf("cannot get device context: %v", err)
		}
		bootBaseInfo, err := BootBaseInfo(st, deviceCtx)
		if err != nil {
			return nil, fmt.Errorf("cannot get boot base info: %v", err)
		}
		bootBase = bootBaseInfo.InstanceName()
	}

	// byBase maps a base snap name to the instance names of the considered
	// snaps using it. While building it, "all" is narrowed down to the snaps
	// that are active and have the gate-auto-refresh hook.
	byBase := make(map[string][]string)
	for name, snapSt := range all {
		if !snapSt.Active {
			delete(all, name)
			continue
		}
		inf, err := snapSt.CurrentInfo()
		if err != nil {
			return nil, err
		}
		// optimization: do not consider snaps that don't have gate-auto-refresh hook.
		if inf.Hooks[gateAutoRefreshHookName] == nil {
			delete(all, name)
			continue
		}

		base := inf.Base
		// snaps with base "none" are not recorded in byBase (but remain in
		// "all").
		if base == "none" {
			continue
		}
		// an empty base is treated as "core"
		if inf.Base == "" {
			base = "core"
		}
		byBase[base] = append(byBase[base], inf.InstanceName())
	}

	affected := make(map[string]*affectedSnapInfo)

	// addAffected records that snapName is affected by the refresh of
	// affectedBy. The restart and base flags are only ever switched on, never
	// cleared by a later call.
	addAffected := func(snapName, affectedBy string, restart bool, base bool) {
		if affected[snapName] == nil {
			affected[snapName] = &affectedSnapInfo{
				AffectingSnaps: map[string]bool{},
			}
		}
		affectedInfo := affected[snapName]
		if restart {
			affectedInfo.Restart = restart
		}
		if base {
			affectedInfo.Base = base
		}
		affectedInfo.AffectingSnaps[affectedBy] = true
	}

	for _, up := range updates {
		// on core system, affected by update of boot base
		if bootBase != "" && up.InstanceName() == bootBase {
			for _, snapSt := range all {
				addAffected(snapSt.InstanceName(), up.InstanceName(), true, false)
			}
		}

		// snaps that can trigger reboot
		// XXX: gadget refresh doesn't always require reboot, refine this
		if up.Type() == snap.TypeKernel || up.Type() == snap.TypeGadget {
			for _, snapSt := range all {
				addAffected(snapSt.InstanceName(), up.InstanceName(), true, false)
			}
			// every considered snap is already marked, skip the
			// base/interface checks for this update.
			continue
		}
		if up.Type() == snap.TypeBase || up.SnapName() == "core" {
			// affected by refresh of this base snap
			for _, snapName := range byBase[up.InstanceName()] {
				addAffected(snapName, up.InstanceName(), false, true)
			}
		}

		repo := ifacerepo.Get(st)

		// consider slots provided by refreshed snap, but exclude core and snapd
		// since they provide system-level slots that are generally not disrupted
		// by snap updates.
		if up.SnapType != snap.TypeSnapd && up.SnapName() != "core" {
			for _, slotInfo := range up.Slots {
				conns, err := repo.Connected(up.InstanceName(), slotInfo.Name)
				if err != nil {
					return nil, err
				}
				for _, cref := range conns {
					// affected only if it wasn't optimized out above
					if all[cref.PlugRef.Snap] != nil {
						addAffected(cref.PlugRef.Snap, up.InstanceName(), true, false)
					}
				}
			}
		}

		// consider mount backend plugs/slots;
		// for slot side only consider snapd/core because they are ignored by the
		// earlier loop around slots.
		if up.SnapType == snap.TypeSnapd || up.SnapType == snap.TypeOS {
			for _, slotInfo := range up.Slots {
				iface := repo.Interface(slotInfo.Interface)
				if iface == nil {
					return nil, fmt.Errorf("internal error: unknown interface %s", slotInfo.Interface)
				}
				if !usesMountBackend(iface) {
					continue
				}
				conns, err := repo.Connected(up.InstanceName(), slotInfo.Name)
				if err != nil {
					return nil, err
				}
				for _, cref := range conns {
					if all[cref.PlugRef.Snap] != nil {
						addAffected(cref.PlugRef.Snap, up.InstanceName(), true, false)
					}
				}
			}
		}
		// plug side: snaps providing the slot that a mount-backend plug of
		// the refreshed snap is connected to.
		for _, plugInfo := range up.Plugs {
			iface := repo.Interface(plugInfo.Interface)
			if iface == nil {
				return nil, fmt.Errorf("internal error: unknown interface %s", plugInfo.Interface)
			}
			if !usesMountBackend(iface) {
				continue
			}
			conns, err := repo.Connected(up.InstanceName(), plugInfo.Name)
			if err != nil {
				return nil, err
			}
			for _, cref := range conns {
				if all[cref.SlotRef.Snap] != nil {
					addAffected(cref.SlotRef.Snap, up.InstanceName(), true, false)
				}
			}
		}
	}

	return affected, nil
}
   497  
   498  // XXX: this is too wide and affects all commonInterface-based interfaces; we
   499  // need metadata on the relevant interfaces.
   500  func usesMountBackend(iface interfaces.Interface) bool {
   501  	type definer1 interface {
   502  		MountConnectedSlot(*mount.Specification, *interfaces.ConnectedPlug, *interfaces.ConnectedSlot) error
   503  	}
   504  	type definer2 interface {
   505  		MountConnectedPlug(*mount.Specification, *interfaces.ConnectedPlug, *interfaces.ConnectedSlot) error
   506  	}
   507  	type definer3 interface {
   508  		MountPermanentPlug(*mount.Specification, *snap.PlugInfo) error
   509  	}
   510  	type definer4 interface {
   511  		MountPermanentSlot(*mount.Specification, *snap.SlotInfo) error
   512  	}
   513  
   514  	if _, ok := iface.(definer1); ok {
   515  		return true
   516  	}
   517  	if _, ok := iface.(definer2); ok {
   518  		return true
   519  	}
   520  	if _, ok := iface.(definer3); ok {
   521  		return true
   522  	}
   523  	if _, ok := iface.(definer4); ok {
   524  		return true
   525  	}
   526  	return false
   527  }
   528  
   529  // createGateAutoRefreshHooks creates gate-auto-refresh hooks for all affectedSnaps.
   530  // The hooks will have their context data set from affectedSnapInfo flags (base, restart).
   531  // Hook tasks will be chained to run sequentially.
   532  func createGateAutoRefreshHooks(st *state.State, affectedSnaps map[string]*affectedSnapInfo) *state.TaskSet {
   533  	ts := state.NewTaskSet()
   534  	var prev *state.Task
   535  	// sort names for easy testing
   536  	names := make([]string, 0, len(affectedSnaps))
   537  	for snapName := range affectedSnaps {
   538  		names = append(names, snapName)
   539  	}
   540  	sort.Strings(names)
   541  	for _, snapName := range names {
   542  		affected := affectedSnaps[snapName]
   543  		hookTask := SetupGateAutoRefreshHook(st, snapName, affected.Base, affected.Restart, affected.AffectingSnaps)
   544  		// XXX: it should be fine to run the hooks in parallel
   545  		if prev != nil {
   546  			hookTask.WaitFor(prev)
   547  		}
   548  		ts.AddTask(hookTask)
   549  		prev = hookTask
   550  	}
   551  	return ts
   552  }
   553  
   554  // snapsToRefresh returns all snaps that should proceed with refresh considering
   555  // hold behavior.
   556  var snapsToRefresh = func(gatingTask *state.Task) ([]*refreshCandidate, error) {
   557  	var snaps map[string]*refreshCandidate
   558  	if err := gatingTask.Get("snaps", &snaps); err != nil {
   559  		return nil, err
   560  	}
   561  
   562  	held, err := heldSnaps(gatingTask.State())
   563  	if err != nil {
   564  		return nil, err
   565  	}
   566  
   567  	var skipped []string
   568  	var candidates []*refreshCandidate
   569  	for _, s := range snaps {
   570  		if !held[s.InstanceName()] {
   571  			candidates = append(candidates, s)
   572  		} else {
   573  			skipped = append(skipped, s.InstanceName())
   574  		}
   575  	}
   576  
   577  	if len(skipped) > 0 {
   578  		sort.Strings(skipped)
   579  		logger.Noticef("skipping refresh of held snaps: %s", strings.Join(skipped, ","))
   580  	}
   581  
   582  	return candidates, nil
   583  }