github.com/ethanhsieh/snapd@v0.0.0-20210615102523-3db9b8e4edc5/overlord/snapstate/autorefresh_gating.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2021 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package snapstate
    21  
    22  import (
    23  	"fmt"
    24  	"os"
    25  	"sort"
    26  	"strings"
    27  	"time"
    28  
    29  	"github.com/snapcore/snapd/interfaces"
    30  	"github.com/snapcore/snapd/interfaces/mount"
    31  	"github.com/snapcore/snapd/overlord/ifacestate/ifacerepo"
    32  	"github.com/snapcore/snapd/overlord/state"
    33  	"github.com/snapcore/snapd/release"
    34  	"github.com/snapcore/snapd/snap"
    35  )
    36  
    37  var gateAutoRefreshHookName = "gate-auto-refresh"
    38  
    39  // gateAutoRefreshAction represents the action executed by
    40  // snapctl refresh --hold or --proceed and stored in the context of
    41  // gate-auto-refresh hook.
    42  type gateAutoRefreshAction int
    43  
    44  const (
    45  	GateAutoRefreshProceed gateAutoRefreshAction = iota
    46  	GateAutoRefreshHold
    47  )
    48  
    49  // cumulative hold time for snaps other than self
    50  const maxOtherHoldDuration = time.Hour * 48
    51  
    52  var timeNow = func() time.Time {
    53  	return time.Now()
    54  }
    55  
    56  func lastRefreshed(st *state.State, snapName string) (time.Time, error) {
    57  	var snapst SnapState
    58  	if err := Get(st, snapName, &snapst); err != nil {
    59  		return time.Time{}, fmt.Errorf("internal error, cannot get snap %q: %v", snapName, err)
    60  	}
    61  	// try to get last refresh time from snapstate, but it may not be present
    62  	// for snaps installed before the introduction of last-refresh attribute.
    63  	if snapst.LastRefreshTime != nil {
    64  		return *snapst.LastRefreshTime, nil
    65  	}
    66  	snapInfo, err := snapst.CurrentInfo()
    67  	if err != nil {
    68  		return time.Time{}, err
    69  	}
    70  	// fall back to the modification time of .snap blob file as it's the best
    71  	// approximation of last refresh time.
    72  	fst, err := os.Stat(snapInfo.MountFile())
    73  	if err != nil {
    74  		return time.Time{}, err
    75  	}
    76  	return fst.ModTime(), nil
    77  }
    78  
// holdState records the hold that one gating snap placed on one held snap.
// Values are stored JSON-encoded in state under "snaps-hold" as
// held snap -> gating snap -> *holdState (see refreshGating).
type holdState struct {
	// FirstHeld keeps the time when the given snap was first held for refresh by a gating snap.
	// HoldRefresh sets it only when creating a new hold entry and keeps it on later
	// requests, so it anchors the cumulative hold limit.
	FirstHeld time.Time `json:"first-held"`
	// HoldUntil stores the desired end time for holding.
	// heldSnaps treats a hold as expired once HoldUntil is before the current time.
	HoldUntil time.Time `json:"hold-until"`
}
    85  
    86  func refreshGating(st *state.State) (map[string]map[string]*holdState, error) {
    87  	// held snaps -> holding snap(s) -> first-held/hold-until time
    88  	var gating map[string]map[string]*holdState
    89  	err := st.Get("snaps-hold", &gating)
    90  	if err != nil && err != state.ErrNoState {
    91  		return nil, fmt.Errorf("internal error: cannot get snaps-hold: %v", err)
    92  	}
    93  	if err == state.ErrNoState {
    94  		return make(map[string]map[string]*holdState), nil
    95  	}
    96  	return gating, nil
    97  }
    98  
    99  // HoldDurationError contains the that error prevents requested hold, along with
   100  // hold time that's left (if any).
   101  type HoldDurationError struct {
   102  	Err          error
   103  	DurationLeft time.Duration
   104  }
   105  
   106  func (h *HoldDurationError) Error() string {
   107  	return h.Err.Error()
   108  }
   109  
   110  // HoldError contains the details of snaps that cannot to be held.
   111  type HoldError struct {
   112  	SnapsInError map[string]HoldDurationError
   113  }
   114  
   115  func (h *HoldError) Error() string {
   116  	l := []string{""}
   117  	for _, e := range h.SnapsInError {
   118  		l = append(l, e.Error())
   119  	}
   120  	return fmt.Sprintf("cannot hold some snaps:%s", strings.Join(l, "\n - "))
   121  }
   122  
   123  func maxAllowedPostponement(gatingSnap, affectedSnap string, maxPostponement time.Duration) time.Duration {
   124  	if affectedSnap == gatingSnap {
   125  		return maxPostponement
   126  	}
   127  	return maxOtherHoldDuration
   128  }
   129  
   130  // holdDurationLeft computes the maximum duration that's left for holding a refresh
   131  // given current time, last refresh time, time when snap was first held, maximum
   132  // duration allowed for the given snap and maximum overall postponement allowed by
   133  // snapd.
   134  func holdDurationLeft(now time.Time, lastRefresh, firstHeld time.Time, maxDuration, maxPostponement time.Duration) time.Duration {
   135  	d1 := firstHeld.Add(maxDuration).Sub(now)
   136  	d2 := lastRefresh.Add(maxPostponement).Sub(now)
   137  	if d1 < d2 {
   138  		return d1
   139  	}
   140  	return d2
   141  }
   142  
   143  // HoldRefresh marks affectingSnaps as held for refresh for up to holdTime.
   144  // HoldTime of zero denotes maximum allowed hold time.
   145  // Holding may fail for only some snaps in which case HoldError is returned and
   146  // it contains the details of failed ones.
   147  func HoldRefresh(st *state.State, gatingSnap string, holdDuration time.Duration, affectingSnaps ...string) error {
   148  	gating, err := refreshGating(st)
   149  	if err != nil {
   150  		return err
   151  	}
   152  	herr := &HoldError{
   153  		SnapsInError: make(map[string]HoldDurationError),
   154  	}
   155  	now := timeNow()
   156  	for _, heldSnap := range affectingSnaps {
   157  		hold, ok := gating[heldSnap][gatingSnap]
   158  		if !ok {
   159  			hold = &holdState{
   160  				FirstHeld: now,
   161  			}
   162  		}
   163  
   164  		lastRefreshTime, err := lastRefreshed(st, heldSnap)
   165  		if err != nil {
   166  			return err
   167  		}
   168  
   169  		mp := maxPostponement - maxPostponementBuffer
   170  		maxDur := maxAllowedPostponement(gatingSnap, heldSnap, mp)
   171  
   172  		// calculate max hold duration that's left considering previous hold
   173  		// requests of this snap and last refresh time.
   174  		left := holdDurationLeft(now, lastRefreshTime, hold.FirstHeld, maxDur, mp)
   175  		if left <= 0 {
   176  			herr.SnapsInError[heldSnap] = HoldDurationError{
   177  				Err: fmt.Errorf("snap %q cannot hold snap %q anymore, maximum refresh postponement exceeded", gatingSnap, heldSnap),
   178  			}
   179  			continue
   180  		}
   181  
   182  		dur := holdDuration
   183  		if dur == 0 {
   184  			// duration not specified, using a default one (maximum) or what's
   185  			// left of it.
   186  			dur = left
   187  		} else {
   188  			// explicit hold duration requested
   189  			if dur > maxDur {
   190  				herr.SnapsInError[heldSnap] = HoldDurationError{
   191  					Err:          fmt.Errorf("requested holding duration for snap %q of %s by snap %q exceeds maximum holding time", heldSnap, holdDuration, gatingSnap),
   192  					DurationLeft: left,
   193  				}
   194  				continue
   195  			}
   196  		}
   197  
   198  		newHold := now.Add(dur)
   199  		cutOff := lastRefreshTime.Add(maxPostponement - maxPostponementBuffer)
   200  
   201  		// consider last refresh time and adjust hold duration if needed so it's
   202  		// not exceeded.
   203  		if newHold.Before(cutOff) {
   204  			hold.HoldUntil = newHold
   205  		} else {
   206  			hold.HoldUntil = cutOff
   207  		}
   208  
   209  		// finally store/update gating hold data
   210  		if _, ok := gating[heldSnap]; !ok {
   211  			gating[heldSnap] = make(map[string]*holdState)
   212  		}
   213  		gating[heldSnap][gatingSnap] = hold
   214  	}
   215  
   216  	if len(herr.SnapsInError) != len(affectingSnaps) {
   217  		st.Set("snaps-hold", gating)
   218  	}
   219  	if len(herr.SnapsInError) > 0 {
   220  		return herr
   221  	}
   222  	return nil
   223  }
   224  
   225  // ProceedWithRefresh unblocks all snaps held by gatingSnap for refresh. This
   226  // should be called for --proceed on the gatingSnap.
   227  func ProceedWithRefresh(st *state.State, gatingSnap string) error {
   228  	gating, err := refreshGating(st)
   229  	if err != nil {
   230  		return err
   231  	}
   232  	if len(gating) == 0 {
   233  		return nil
   234  	}
   235  
   236  	var changed bool
   237  	for heldSnap, gatingSnaps := range gating {
   238  		if _, ok := gatingSnaps[gatingSnap]; ok {
   239  			delete(gatingSnaps, gatingSnap)
   240  			changed = true
   241  		}
   242  		if len(gatingSnaps) == 0 {
   243  			delete(gating, heldSnap)
   244  		}
   245  	}
   246  
   247  	if changed {
   248  		st.Set("snaps-hold", gating)
   249  	}
   250  	return nil
   251  }
   252  
   253  // resetGatingForRefreshed resets gating information by removing refreshedSnaps
   254  // (they are not held anymore). This should be called for all successfully
   255  // refreshed snaps.
   256  func resetGatingForRefreshed(st *state.State, refreshedSnaps ...string) error {
   257  	gating, err := refreshGating(st)
   258  	if err != nil {
   259  		return err
   260  	}
   261  	if len(gating) == 0 {
   262  		return nil
   263  	}
   264  
   265  	var changed bool
   266  	for _, snapName := range refreshedSnaps {
   267  		if _, ok := gating[snapName]; ok {
   268  			delete(gating, snapName)
   269  			changed = true
   270  		}
   271  	}
   272  
   273  	if changed {
   274  		st.Set("snaps-hold", gating)
   275  	}
   276  	return nil
   277  }
   278  
   279  // heldSnaps returns all snaps that are gated and shouldn't be refreshed.
   280  func heldSnaps(st *state.State) (map[string]bool, error) {
   281  	gating, err := refreshGating(st)
   282  	if err != nil {
   283  		return nil, err
   284  	}
   285  	if len(gating) == 0 {
   286  		return nil, nil
   287  	}
   288  
   289  	now := timeNow()
   290  
   291  	held := make(map[string]bool)
   292  Loop:
   293  	for heldSnap, holdingSnaps := range gating {
   294  		refreshed, err := lastRefreshed(st, heldSnap)
   295  		if err != nil {
   296  			return nil, err
   297  		}
   298  		// make sure we don't hold any snap for more than maxPostponement
   299  		if refreshed.Add(maxPostponement).Before(now) {
   300  			continue
   301  		}
   302  		for _, hold := range holdingSnaps {
   303  			if hold.HoldUntil.Before(now) {
   304  				continue
   305  			}
   306  			held[heldSnap] = true
   307  			continue Loop
   308  		}
   309  	}
   310  	return held, nil
   311  }
   312  
// affectedSnapInfo describes how a snap is affected by a set of refreshes, as
// computed by affectedByRefresh. The flags are only ever set to true, never
// reset, as more affecting refreshes are recorded.
type affectedSnapInfo struct {
	// Restart is set by affectedByRefresh for boot base, kernel and gadget
	// refreshes, and for refreshes of snaps connected to this one through
	// the interfaces it inspects.
	Restart        bool
	// Base is set when the snap's base (or "core" if it declares none) is
	// among the refreshed snaps.
	Base           bool
	// AffectingSnaps is the set of refreshed snaps (instance names) that
	// affect this snap.
	AffectingSnaps map[string]bool
}
   318  
// affectedByRefresh computes which installed snaps are affected by refreshing
// updates, and how. The result is keyed by the affected snap's instance name.
//
// Only active snaps whose current revision declares a gate-auto-refresh hook
// are candidates. Such a snap is recorded as affected when:
//   - an update is the boot base (non-classic systems only): Restart;
//   - an update is a kernel or gadget: Restart (no further checks are made for
//     that update);
//   - an update is a base snap, or the "core" snap, that the candidate uses as
//     its base: Base;
//   - an update other than snapd/core provides a slot connected to one of
//     the candidate's plugs: Restart;
//   - an update of type snapd or OS provides a slot whose interface uses the
//     mount backend and is connected to one of the candidate's plugs: Restart;
//   - an update has a plug whose interface uses the mount backend and is
//     connected to one of the candidate's slots: Restart.
//
// Errors are returned if the device context or boot base cannot be
// determined, if a snap's current info cannot be read, if querying
// connections fails, or if an update references an unknown interface.
func affectedByRefresh(st *state.State, updates []*snap.Info) (map[string]*affectedSnapInfo, error) {
	all, err := All(st)
	if err != nil {
		return nil, err
	}

	// bootBase stays empty on classic, which disables the boot base check below.
	var bootBase string
	if !release.OnClassic {
		deviceCtx, err := DeviceCtx(st, nil, nil)
		if err != nil {
			return nil, fmt.Errorf("cannot get device context: %v", err)
		}
		bootBaseInfo, err := BootBaseInfo(st, deviceCtx)
		if err != nil {
			return nil, fmt.Errorf("cannot get boot base info: %v", err)
		}
		bootBase = bootBaseInfo.InstanceName()
	}

	// Prune `all` down to candidate snaps; after this loop it doubles as the
	// membership set consulted by the connection checks further down.
	// byBase maps a base name to the candidates using it.
	byBase := make(map[string][]string)
	for name, snapSt := range all {
		if !snapSt.Active {
			delete(all, name)
			continue
		}
		inf, err := snapSt.CurrentInfo()
		if err != nil {
			return nil, err
		}
		// optimization: do not consider snaps that don't have gate-auto-refresh hook.
		if inf.Hooks[gateAutoRefreshHookName] == nil {
			delete(all, name)
			continue
		}

		// Snaps with base "none" remain candidates but are not grouped by base.
		base := inf.Base
		if base == "none" {
			continue
		}
		// NOTE(review): an empty base maps to "core" regardless of the snap's
		// type (e.g. a base or snapd snap carrying the hook) — confirm intended.
		if inf.Base == "" {
			base = "core"
		}
		byBase[base] = append(byBase[base], inf.InstanceName())
	}

	affected := make(map[string]*affectedSnapInfo)

	// addAffected records that snapName is affected by affectedBy; the
	// restart/base flags are only ever turned on, never cleared.
	addAffected := func(snapName, affectedBy string, restart bool, base bool) {
		if affected[snapName] == nil {
			affected[snapName] = &affectedSnapInfo{
				AffectingSnaps: map[string]bool{},
			}
		}
		affectedInfo := affected[snapName]
		if restart {
			affectedInfo.Restart = restart
		}
		if base {
			affectedInfo.Base = base
		}
		affectedInfo.AffectingSnaps[affectedBy] = true
	}

	for _, up := range updates {
		// on core system, affected by update of boot base
		if bootBase != "" && up.InstanceName() == bootBase {
			for _, snapSt := range all {
				addAffected(snapSt.InstanceName(), up.InstanceName(), true, false)
			}
		}

		// snaps that can trigger reboot
		// XXX: gadget refresh doesn't always require reboot, refine this
		if up.Type() == snap.TypeKernel || up.Type() == snap.TypeGadget {
			for _, snapSt := range all {
				addAffected(snapSt.InstanceName(), up.InstanceName(), true, false)
			}
			// every candidate is already affected; skip the interface checks
			continue
		}
		if up.Type() == snap.TypeBase || up.SnapName() == "core" {
			// affected by refresh of this base snap
			for _, snapName := range byBase[up.InstanceName()] {
				addAffected(snapName, up.InstanceName(), false, true)
			}
		}

		// Fetched here, not before the loop, so kernel/gadget-only updates
		// never touch the interfaces repository.
		repo := ifacerepo.Get(st)

		// consider slots provided by refreshed snap, but exclude core and snapd
		// since they provide system-level slots that are generally not disrupted
		// by snap updates.
		if up.SnapType != snap.TypeSnapd && up.SnapName() != "core" {
			for _, slotInfo := range up.Slots {
				conns, err := repo.Connected(up.InstanceName(), slotInfo.Name)
				if err != nil {
					return nil, err
				}
				for _, cref := range conns {
					// affected only if it wasn't optimized out above
					if all[cref.PlugRef.Snap] != nil {
						addAffected(cref.PlugRef.Snap, up.InstanceName(), true, false)
					}
				}
			}
		}

		// consider mount backend plugs/slots;
		// for slot side only consider snapd/core because they are ignored by the
		// earlier loop around slots.
		if up.SnapType == snap.TypeSnapd || up.SnapType == snap.TypeOS {
			for _, slotInfo := range up.Slots {
				iface := repo.Interface(slotInfo.Interface)
				if iface == nil {
					return nil, fmt.Errorf("internal error: unknown interface %s", slotInfo.Interface)
				}
				if !usesMountBackend(iface) {
					continue
				}
				conns, err := repo.Connected(up.InstanceName(), slotInfo.Name)
				if err != nil {
					return nil, err
				}
				for _, cref := range conns {
					if all[cref.PlugRef.Snap] != nil {
						addAffected(cref.PlugRef.Snap, up.InstanceName(), true, false)
					}
				}
			}
		}
		// plug side: candidates providing the slot of a mount-backend
		// connection from the refreshed snap
		for _, plugInfo := range up.Plugs {
			iface := repo.Interface(plugInfo.Interface)
			if iface == nil {
				return nil, fmt.Errorf("internal error: unknown interface %s", plugInfo.Interface)
			}
			if !usesMountBackend(iface) {
				continue
			}
			conns, err := repo.Connected(up.InstanceName(), plugInfo.Name)
			if err != nil {
				return nil, err
			}
			for _, cref := range conns {
				if all[cref.SlotRef.Snap] != nil {
					addAffected(cref.SlotRef.Snap, up.InstanceName(), true, false)
				}
			}
		}
	}

	return affected, nil
}
   470  
   471  // XXX: this is too wide and affects all commonInterface-based interfaces; we
   472  // need metadata on the relevant interfaces.
   473  func usesMountBackend(iface interfaces.Interface) bool {
   474  	type definer1 interface {
   475  		MountConnectedSlot(*mount.Specification, *interfaces.ConnectedPlug, *interfaces.ConnectedSlot) error
   476  	}
   477  	type definer2 interface {
   478  		MountConnectedPlug(*mount.Specification, *interfaces.ConnectedPlug, *interfaces.ConnectedSlot) error
   479  	}
   480  	type definer3 interface {
   481  		MountPermanentPlug(*mount.Specification, *snap.PlugInfo) error
   482  	}
   483  	type definer4 interface {
   484  		MountPermanentSlot(*mount.Specification, *snap.SlotInfo) error
   485  	}
   486  
   487  	if _, ok := iface.(definer1); ok {
   488  		return true
   489  	}
   490  	if _, ok := iface.(definer2); ok {
   491  		return true
   492  	}
   493  	if _, ok := iface.(definer3); ok {
   494  		return true
   495  	}
   496  	if _, ok := iface.(definer4); ok {
   497  		return true
   498  	}
   499  	return false
   500  }
   501  
   502  // createGateAutoRefreshHooks creates gate-auto-refresh hooks for all affectedSnaps.
   503  // The hooks will have their context data set from affectedSnapInfo flags (base, restart).
   504  // Hook tasks will be chained to run sequentially.
   505  func createGateAutoRefreshHooks(st *state.State, affectedSnaps map[string]*affectedSnapInfo) *state.TaskSet {
   506  	ts := state.NewTaskSet()
   507  	var prev *state.Task
   508  	// sort names for easy testing
   509  	names := make([]string, 0, len(affectedSnaps))
   510  	for snapName := range affectedSnaps {
   511  		names = append(names, snapName)
   512  	}
   513  	sort.Strings(names)
   514  	for _, snapName := range names {
   515  		affected := affectedSnaps[snapName]
   516  		hookTask := SetupGateAutoRefreshHook(st, snapName, affected.Base, affected.Restart, affected.AffectingSnaps)
   517  		// XXX: it should be fine to run the hooks in parallel
   518  		if prev != nil {
   519  			hookTask.WaitFor(prev)
   520  		}
   521  		ts.AddTask(hookTask)
   522  		prev = hookTask
   523  	}
   524  	return ts
   525  }
   526  
   527  // snapsToRefresh returns all snaps that should proceed with refresh considering
   528  // hold behavior.
   529  var snapsToRefresh = func(gatingTask *state.Task) ([]*refreshCandidate, error) {
   530  	// TODO: consider holding (responses from gating hooks) here.
   531  	// Return all refresh candidates for now.
   532  	var snaps map[string]*refreshCandidate
   533  	if err := gatingTask.Get("snaps", &snaps); err != nil {
   534  		return nil, err
   535  	}
   536  	var candidates []*refreshCandidate
   537  	for _, s := range snaps {
   538  		candidates = append(candidates, s)
   539  	}
   540  	return candidates, nil
   541  }