github.com/david-imola/snapd@v0.0.0-20210611180407-2de8ddeece6d/overlord/snapstate/autorefresh_gating.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2021 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package snapstate
    21  
    22  import (
    23  	"fmt"
    24  	"os"
    25  	"sort"
    26  	"strings"
    27  	"time"
    28  
    29  	"github.com/snapcore/snapd/interfaces"
    30  	"github.com/snapcore/snapd/interfaces/mount"
    31  	"github.com/snapcore/snapd/overlord/ifacestate/ifacerepo"
    32  	"github.com/snapcore/snapd/overlord/state"
    33  	"github.com/snapcore/snapd/release"
    34  	"github.com/snapcore/snapd/snap"
    35  )
    36  
    37  var gateAutoRefreshHookName = "gate-auto-refresh"
    38  
    39  // cumulative hold time for snaps other than self
    40  const maxOtherHoldDuration = time.Hour * 48
    41  
    42  var timeNow = func() time.Time {
    43  	return time.Now()
    44  }
    45  
    46  func lastRefreshed(st *state.State, snapName string) (time.Time, error) {
    47  	var snapst SnapState
    48  	if err := Get(st, snapName, &snapst); err != nil {
    49  		return time.Time{}, fmt.Errorf("internal error, cannot get snap %q: %v", snapName, err)
    50  	}
    51  	// try to get last refresh time from snapstate, but it may not be present
    52  	// for snaps installed before the introduction of last-refresh attribute.
    53  	if snapst.LastRefreshTime != nil {
    54  		return *snapst.LastRefreshTime, nil
    55  	}
    56  	snapInfo, err := snapst.CurrentInfo()
    57  	if err != nil {
    58  		return time.Time{}, err
    59  	}
    60  	// fall back to the modification time of .snap blob file as it's the best
    61  	// approximation of last refresh time.
    62  	fst, err := os.Stat(snapInfo.MountFile())
    63  	if err != nil {
    64  		return time.Time{}, err
    65  	}
    66  	return fst.ModTime(), nil
    67  }
    68  
// holdState describes a single hold placed on a snap by one gating snap. It
// is the innermost value of the map kept under the "snaps-hold" state key
// (held snap -> holding snap -> holdState).
type holdState struct {
	// FirstHeld keeps the time when the given snap was first held for refresh by a gating snap.
	FirstHeld time.Time `json:"first-held"`
	// HoldUntil stores the desired end time for holding.
	HoldUntil time.Time `json:"hold-until"`
}
    75  
    76  func refreshGating(st *state.State) (map[string]map[string]*holdState, error) {
    77  	// held snaps -> holding snap(s) -> first-held/hold-until time
    78  	var gating map[string]map[string]*holdState
    79  	err := st.Get("snaps-hold", &gating)
    80  	if err != nil && err != state.ErrNoState {
    81  		return nil, fmt.Errorf("internal error: cannot get snaps-hold: %v", err)
    82  	}
    83  	if err == state.ErrNoState {
    84  		return make(map[string]map[string]*holdState), nil
    85  	}
    86  	return gating, nil
    87  }
    88  
    89  // HoldDurationError contains the that error prevents requested hold, along with
    90  // hold time that's left (if any).
    91  type HoldDurationError struct {
    92  	Err          error
    93  	DurationLeft time.Duration
    94  }
    95  
    96  func (h *HoldDurationError) Error() string {
    97  	return h.Err.Error()
    98  }
    99  
   100  // HoldError contains the details of snaps that cannot to be held.
   101  type HoldError struct {
   102  	SnapsInError map[string]HoldDurationError
   103  }
   104  
   105  func (h *HoldError) Error() string {
   106  	l := []string{""}
   107  	for _, e := range h.SnapsInError {
   108  		l = append(l, e.Error())
   109  	}
   110  	return fmt.Sprintf("cannot hold some snaps:%s", strings.Join(l, "\n - "))
   111  }
   112  
   113  func maxAllowedPostponement(gatingSnap, affectedSnap string, maxPostponement time.Duration) time.Duration {
   114  	if affectedSnap == gatingSnap {
   115  		return maxPostponement
   116  	}
   117  	return maxOtherHoldDuration
   118  }
   119  
   120  // holdDurationLeft computes the maximum duration that's left for holding a refresh
   121  // given current time, last refresh time, time when snap was first held, maximum
   122  // duration allowed for the given snap and maximum overall postponement allowed by
   123  // snapd.
   124  func holdDurationLeft(now time.Time, lastRefresh, firstHeld time.Time, maxDuration, maxPostponement time.Duration) time.Duration {
   125  	d1 := firstHeld.Add(maxDuration).Sub(now)
   126  	d2 := lastRefresh.Add(maxPostponement).Sub(now)
   127  	if d1 < d2 {
   128  		return d1
   129  	}
   130  	return d2
   131  }
   132  
   133  // HoldRefresh marks affectingSnaps as held for refresh for up to holdTime.
   134  // HoldTime of zero denotes maximum allowed hold time.
   135  // Holding may fail for only some snaps in which case HoldError is returned and
   136  // it contains the details of failed ones.
   137  func HoldRefresh(st *state.State, gatingSnap string, holdDuration time.Duration, affectingSnaps ...string) error {
   138  	gating, err := refreshGating(st)
   139  	if err != nil {
   140  		return err
   141  	}
   142  	herr := &HoldError{
   143  		SnapsInError: make(map[string]HoldDurationError),
   144  	}
   145  	now := timeNow()
   146  	for _, heldSnap := range affectingSnaps {
   147  		hold, ok := gating[heldSnap][gatingSnap]
   148  		if !ok {
   149  			hold = &holdState{
   150  				FirstHeld: now,
   151  			}
   152  		}
   153  
   154  		lastRefreshTime, err := lastRefreshed(st, heldSnap)
   155  		if err != nil {
   156  			return err
   157  		}
   158  
   159  		mp := maxPostponement - maxPostponementBuffer
   160  		maxDur := maxAllowedPostponement(gatingSnap, heldSnap, mp)
   161  
   162  		// calculate max hold duration that's left considering previous hold
   163  		// requests of this snap and last refresh time.
   164  		left := holdDurationLeft(now, lastRefreshTime, hold.FirstHeld, maxDur, mp)
   165  		if left <= 0 {
   166  			herr.SnapsInError[heldSnap] = HoldDurationError{
   167  				Err: fmt.Errorf("snap %q cannot hold snap %q anymore, maximum refresh postponement exceeded", gatingSnap, heldSnap),
   168  			}
   169  			continue
   170  		}
   171  
   172  		dur := holdDuration
   173  		if dur == 0 {
   174  			// duration not specified, using a default one (maximum) or what's
   175  			// left of it.
   176  			dur = left
   177  		} else {
   178  			// explicit hold duration requested
   179  			if dur > maxDur {
   180  				herr.SnapsInError[heldSnap] = HoldDurationError{
   181  					Err:          fmt.Errorf("requested holding duration for snap %q of %s by snap %q exceeds maximum holding time", heldSnap, holdDuration, gatingSnap),
   182  					DurationLeft: left,
   183  				}
   184  				continue
   185  			}
   186  		}
   187  
   188  		newHold := now.Add(dur)
   189  		cutOff := lastRefreshTime.Add(maxPostponement - maxPostponementBuffer)
   190  
   191  		// consider last refresh time and adjust hold duration if needed so it's
   192  		// not exceeded.
   193  		if newHold.Before(cutOff) {
   194  			hold.HoldUntil = newHold
   195  		} else {
   196  			hold.HoldUntil = cutOff
   197  		}
   198  
   199  		// finally store/update gating hold data
   200  		if _, ok := gating[heldSnap]; !ok {
   201  			gating[heldSnap] = make(map[string]*holdState)
   202  		}
   203  		gating[heldSnap][gatingSnap] = hold
   204  	}
   205  
   206  	if len(herr.SnapsInError) != len(affectingSnaps) {
   207  		st.Set("snaps-hold", gating)
   208  	}
   209  	if len(herr.SnapsInError) > 0 {
   210  		return herr
   211  	}
   212  	return nil
   213  }
   214  
   215  // ProceedWithRefresh unblocks all snaps held by gatingSnap for refresh. This
   216  // should be called for --proceed on the gatingSnap.
   217  func ProceedWithRefresh(st *state.State, gatingSnap string) error {
   218  	gating, err := refreshGating(st)
   219  	if err != nil {
   220  		return err
   221  	}
   222  	if len(gating) == 0 {
   223  		return nil
   224  	}
   225  
   226  	var changed bool
   227  	for heldSnap, gatingSnaps := range gating {
   228  		if _, ok := gatingSnaps[gatingSnap]; ok {
   229  			delete(gatingSnaps, gatingSnap)
   230  			changed = true
   231  		}
   232  		if len(gatingSnaps) == 0 {
   233  			delete(gating, heldSnap)
   234  		}
   235  	}
   236  
   237  	if changed {
   238  		st.Set("snaps-hold", gating)
   239  	}
   240  	return nil
   241  }
   242  
   243  // resetGatingForRefreshed resets gating information by removing refreshedSnaps
   244  // (they are not held anymore). This should be called for all successfully
   245  // refreshed snaps.
   246  func resetGatingForRefreshed(st *state.State, refreshedSnaps ...string) error {
   247  	gating, err := refreshGating(st)
   248  	if err != nil {
   249  		return err
   250  	}
   251  	if len(gating) == 0 {
   252  		return nil
   253  	}
   254  
   255  	var changed bool
   256  	for _, snapName := range refreshedSnaps {
   257  		if _, ok := gating[snapName]; ok {
   258  			delete(gating, snapName)
   259  			changed = true
   260  		}
   261  	}
   262  
   263  	if changed {
   264  		st.Set("snaps-hold", gating)
   265  	}
   266  	return nil
   267  }
   268  
   269  // heldSnaps returns all snaps that are gated and shouldn't be refreshed.
   270  func heldSnaps(st *state.State) (map[string]bool, error) {
   271  	gating, err := refreshGating(st)
   272  	if err != nil {
   273  		return nil, err
   274  	}
   275  	if len(gating) == 0 {
   276  		return nil, nil
   277  	}
   278  
   279  	now := timeNow()
   280  
   281  	held := make(map[string]bool)
   282  Loop:
   283  	for heldSnap, holdingSnaps := range gating {
   284  		refreshed, err := lastRefreshed(st, heldSnap)
   285  		if err != nil {
   286  			return nil, err
   287  		}
   288  		// make sure we don't hold any snap for more than maxPostponement
   289  		if refreshed.Add(maxPostponement).Before(now) {
   290  			continue
   291  		}
   292  		for _, hold := range holdingSnaps {
   293  			if hold.HoldUntil.Before(now) {
   294  				continue
   295  			}
   296  			held[heldSnap] = true
   297  			continue Loop
   298  		}
   299  	}
   300  	return held, nil
   301  }
   302  
// affectedSnapInfo describes how a snap is affected by a pending refresh of
// other snaps, as computed by affectedByRefresh.
type affectedSnapInfo struct {
	// Restart is set when at least one affecting snap was recorded with the
	// restart flag (boot base, kernel or gadget update, or a connected
	// plug/slot of the refreshed snap).
	Restart bool
	// Base is set when the snap is affected by a refresh of its base snap.
	Base bool
	// AffectingSnaps is the set of instance names of the refreshed snaps
	// that affect this snap.
	AffectingSnaps map[string]bool
}
   308  
// affectedByRefresh computes which installed snaps are affected by refreshing
// the given updates. Only active snaps that define the gate-auto-refresh hook
// are considered. The result is keyed by the name of the affected snap and
// records, for each of them, the affecting snaps and whether the snap is
// affected through a restart and/or through its base.
func affectedByRefresh(st *state.State, updates []*snap.Info) (map[string]*affectedSnapInfo, error) {
	all, err := All(st)
	if err != nil {
		return nil, err
	}

	// the boot base is only looked up on non-classic systems; it stays empty
	// otherwise, which disables the boot base check below.
	var bootBase string
	if !release.OnClassic {
		deviceCtx, err := DeviceCtx(st, nil, nil)
		if err != nil {
			return nil, fmt.Errorf("cannot get device context: %v", err)
		}
		bootBaseInfo, err := BootBaseInfo(st, deviceCtx)
		if err != nil {
			return nil, fmt.Errorf("cannot get boot base info: %v", err)
		}
		bootBase = bootBaseInfo.InstanceName()
	}

	// byBase maps a base snap name to the snaps using it. While building it,
	// "all" is filtered in place so that it only keeps the snaps that are
	// relevant for gating.
	byBase := make(map[string][]string)
	for name, snapSt := range all {
		if !snapSt.Active {
			delete(all, name)
			continue
		}
		inf, err := snapSt.CurrentInfo()
		if err != nil {
			return nil, err
		}
		// optimization: do not consider snaps that don't have gate-auto-refresh hook.
		if inf.Hooks[gateAutoRefreshHookName] == nil {
			delete(all, name)
			continue
		}

		base := inf.Base
		// snaps with base "none" stay in "all" but are not recorded under
		// any base.
		if base == "none" {
			continue
		}
		// an empty base is treated as "core"
		if inf.Base == "" {
			base = "core"
		}
		byBase[base] = append(byBase[base], inf.InstanceName())
	}

	affected := make(map[string]*affectedSnapInfo)

	// addAffected records that snapName is affected by affectedBy. The
	// restart and base flags are only ever switched on, never cleared.
	addAffected := func(snapName, affectedBy string, restart bool, base bool) {
		if affected[snapName] == nil {
			affected[snapName] = &affectedSnapInfo{
				AffectingSnaps: map[string]bool{},
			}
		}
		affectedInfo := affected[snapName]
		if restart {
			affectedInfo.Restart = restart
		}
		if base {
			affectedInfo.Base = base
		}
		affectedInfo.AffectingSnaps[affectedBy] = true
	}

	for _, up := range updates {
		// on core system, affected by update of boot base
		if bootBase != "" && up.InstanceName() == bootBase {
			for _, snapSt := range all {
				addAffected(snapSt.InstanceName(), up.InstanceName(), true, false)
			}
		}

		// snaps that can trigger reboot
		// XXX: gadget refresh doesn't always require reboot, refine this
		if up.Type() == snap.TypeKernel || up.Type() == snap.TypeGadget {
			for _, snapSt := range all {
				addAffected(snapSt.InstanceName(), up.InstanceName(), true, false)
			}
			// every considered snap is already marked, skip the remaining
			// checks for this update.
			continue
		}
		if up.Type() == snap.TypeBase || up.SnapName() == "core" {
			// affected by refresh of this base snap
			for _, snapName := range byBase[up.InstanceName()] {
				addAffected(snapName, up.InstanceName(), false, true)
			}
		}

		repo := ifacerepo.Get(st)

		// consider slots provided by refreshed snap, but exclude core and snapd
		// since they provide system-level slots that are generally not disrupted
		// by snap updates.
		if up.SnapType != snap.TypeSnapd && up.SnapName() != "core" {
			for _, slotInfo := range up.Slots {
				conns, err := repo.Connected(up.InstanceName(), slotInfo.Name)
				if err != nil {
					return nil, err
				}
				for _, cref := range conns {
					// affected only if it wasn't optimized out above
					if all[cref.PlugRef.Snap] != nil {
						addAffected(cref.PlugRef.Snap, up.InstanceName(), true, false)
					}
				}
			}
		}

		// consider mount backend plugs/slots;
		// for slot side only consider snapd/core because they are ignored by the
		// earlier loop around slots.
		if up.SnapType == snap.TypeSnapd || up.SnapType == snap.TypeOS {
			for _, slotInfo := range up.Slots {
				iface := repo.Interface(slotInfo.Interface)
				if iface == nil {
					return nil, fmt.Errorf("internal error: unknown interface %s", slotInfo.Interface)
				}
				if !usesMountBackend(iface) {
					continue
				}
				conns, err := repo.Connected(up.InstanceName(), slotInfo.Name)
				if err != nil {
					return nil, err
				}
				for _, cref := range conns {
					// affected only if it wasn't filtered out of "all" above
					if all[cref.PlugRef.Snap] != nil {
						addAffected(cref.PlugRef.Snap, up.InstanceName(), true, false)
					}
				}
			}
		}
		// plug side: snaps providing the slots that the refreshed snap's
		// mount-backend plugs are connected to.
		for _, plugInfo := range up.Plugs {
			iface := repo.Interface(plugInfo.Interface)
			if iface == nil {
				return nil, fmt.Errorf("internal error: unknown interface %s", plugInfo.Interface)
			}
			if !usesMountBackend(iface) {
				continue
			}
			conns, err := repo.Connected(up.InstanceName(), plugInfo.Name)
			if err != nil {
				return nil, err
			}
			for _, cref := range conns {
				// affected only if it wasn't filtered out of "all" above
				if all[cref.SlotRef.Snap] != nil {
					addAffected(cref.SlotRef.Snap, up.InstanceName(), true, false)
				}
			}
		}
	}

	return affected, nil
}
   460  
   461  // XXX: this is too wide and affects all commonInterface-based interfaces; we
   462  // need metadata on the relevant interfaces.
   463  func usesMountBackend(iface interfaces.Interface) bool {
   464  	type definer1 interface {
   465  		MountConnectedSlot(*mount.Specification, *interfaces.ConnectedPlug, *interfaces.ConnectedSlot) error
   466  	}
   467  	type definer2 interface {
   468  		MountConnectedPlug(*mount.Specification, *interfaces.ConnectedPlug, *interfaces.ConnectedSlot) error
   469  	}
   470  	type definer3 interface {
   471  		MountPermanentPlug(*mount.Specification, *snap.PlugInfo) error
   472  	}
   473  	type definer4 interface {
   474  		MountPermanentSlot(*mount.Specification, *snap.SlotInfo) error
   475  	}
   476  
   477  	if _, ok := iface.(definer1); ok {
   478  		return true
   479  	}
   480  	if _, ok := iface.(definer2); ok {
   481  		return true
   482  	}
   483  	if _, ok := iface.(definer3); ok {
   484  		return true
   485  	}
   486  	if _, ok := iface.(definer4); ok {
   487  		return true
   488  	}
   489  	return false
   490  }
   491  
   492  // createGateAutoRefreshHooks creates gate-auto-refresh hooks for all affectedSnaps.
   493  // The hooks will have their context data set from affectedSnapInfo flags (base, restart).
   494  // Hook tasks will be chained to run sequentially.
   495  func createGateAutoRefreshHooks(st *state.State, affectedSnaps map[string]*affectedSnapInfo) *state.TaskSet {
   496  	ts := state.NewTaskSet()
   497  	var prev *state.Task
   498  	// sort names for easy testing
   499  	names := make([]string, 0, len(affectedSnaps))
   500  	for snapName := range affectedSnaps {
   501  		names = append(names, snapName)
   502  	}
   503  	sort.Strings(names)
   504  	for _, snapName := range names {
   505  		affected := affectedSnaps[snapName]
   506  		hookTask := SetupGateAutoRefreshHook(st, snapName, affected.Base, affected.Restart)
   507  		// XXX: it should be fine to run the hooks in parallel
   508  		if prev != nil {
   509  			hookTask.WaitFor(prev)
   510  		}
   511  		ts.AddTask(hookTask)
   512  		prev = hookTask
   513  	}
   514  	return ts
   515  }
   516  
   517  // snapsToRefresh returns all snaps that should proceed with refresh considering
   518  // hold behavior.
   519  var snapsToRefresh = func(gatingTask *state.Task) ([]*refreshCandidate, error) {
   520  	// TODO: consider holding (responses from gating hooks) here.
   521  	// Return all refresh candidates for now.
   522  	var snaps map[string]*refreshCandidate
   523  	if err := gatingTask.Get("snaps", &snaps); err != nil {
   524  		return nil, err
   525  	}
   526  	var candidates []*refreshCandidate
   527  	for _, s := range snaps {
   528  		candidates = append(candidates, s)
   529  	}
   530  	return candidates, nil
   531  }