// -*- Mode: Go; indent-tabs-mode: t -*-

/*
 * Copyright (C) 2021 Canonical Ltd
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 3 as
 * published by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 *
 */

package snapstate

import (
	"encoding/json"
	"fmt"
	"os"
	"sort"
	"strings"
	"time"

	"github.com/snapcore/snapd/interfaces"
	"github.com/snapcore/snapd/interfaces/mount"
	"github.com/snapcore/snapd/logger"
	"github.com/snapcore/snapd/overlord/ifacestate/ifacerepo"
	"github.com/snapcore/snapd/overlord/state"
	"github.com/snapcore/snapd/release"
	"github.com/snapcore/snapd/snap"
)

// gateAutoRefreshHookName is the name of the snap hook used for auto-refresh
// gating; only snaps that define this hook are considered by
// affectedByRefresh.
var gateAutoRefreshHookName = "gate-auto-refresh"

// GateAutoRefreshAction represents the action executed by
// snapctl refresh --hold or --proceed and stored in the context of
// the gate-auto-refresh hook.
44 type GateAutoRefreshAction int 45 46 const ( 47 GateAutoRefreshProceed GateAutoRefreshAction = iota 48 GateAutoRefreshHold 49 ) 50 51 // cumulative hold time for snaps other than self 52 const maxOtherHoldDuration = time.Hour * 48 53 54 var timeNow = func() time.Time { 55 return time.Now() 56 } 57 58 func lastRefreshed(st *state.State, snapName string) (time.Time, error) { 59 var snapst SnapState 60 if err := Get(st, snapName, &snapst); err != nil { 61 return time.Time{}, fmt.Errorf("internal error, cannot get snap %q: %v", snapName, err) 62 } 63 // try to get last refresh time from snapstate, but it may not be present 64 // for snaps installed before the introduction of last-refresh attribute. 65 if snapst.LastRefreshTime != nil { 66 return *snapst.LastRefreshTime, nil 67 } 68 snapInfo, err := snapst.CurrentInfo() 69 if err != nil { 70 return time.Time{}, err 71 } 72 // fall back to the modification time of .snap blob file as it's the best 73 // approximation of last refresh time. 74 fst, err := os.Stat(snapInfo.MountFile()) 75 if err != nil { 76 return time.Time{}, err 77 } 78 return fst.ModTime(), nil 79 } 80 81 type holdState struct { 82 // FirstHeld keeps the time when the given snap was first held for refresh by a gating snap. 83 FirstHeld time.Time `json:"first-held"` 84 // HoldUntil stores the desired end time for holding. 
85 HoldUntil time.Time `json:"hold-until"` 86 } 87 88 func refreshGating(st *state.State) (map[string]map[string]*holdState, error) { 89 // held snaps -> holding snap(s) -> first-held/hold-until time 90 var gating map[string]map[string]*holdState 91 err := st.Get("snaps-hold", &gating) 92 if err != nil && err != state.ErrNoState { 93 return nil, fmt.Errorf("internal error: cannot get snaps-hold: %v", err) 94 } 95 if err == state.ErrNoState { 96 return make(map[string]map[string]*holdState), nil 97 } 98 return gating, nil 99 } 100 101 // HoldDurationError contains the that error prevents requested hold, along with 102 // hold time that's left (if any). 103 type HoldDurationError struct { 104 Err error 105 DurationLeft time.Duration 106 } 107 108 func (h *HoldDurationError) Error() string { 109 return h.Err.Error() 110 } 111 112 // HoldError contains the details of snaps that cannot to be held. 113 type HoldError struct { 114 SnapsInError map[string]HoldDurationError 115 } 116 117 func (h *HoldError) Error() string { 118 l := []string{""} 119 for _, e := range h.SnapsInError { 120 l = append(l, e.Error()) 121 } 122 return fmt.Sprintf("cannot hold some snaps:%s", strings.Join(l, "\n - ")) 123 } 124 125 func maxAllowedPostponement(gatingSnap, affectedSnap string, maxPostponement time.Duration) time.Duration { 126 if affectedSnap == gatingSnap { 127 return maxPostponement 128 } 129 return maxOtherHoldDuration 130 } 131 132 // holdDurationLeft computes the maximum duration that's left for holding a refresh 133 // given current time, last refresh time, time when snap was first held, maximum 134 // duration allowed for the given snap and maximum overall postponement allowed by 135 // snapd. 
136 func holdDurationLeft(now time.Time, lastRefresh, firstHeld time.Time, maxDuration, maxPostponement time.Duration) time.Duration { 137 d1 := firstHeld.Add(maxDuration).Sub(now) 138 d2 := lastRefresh.Add(maxPostponement).Sub(now) 139 if d1 < d2 { 140 return d1 141 } 142 return d2 143 } 144 145 // HoldRefresh marks affectingSnaps as held for refresh for up to holdTime. 146 // HoldTime of zero denotes maximum allowed hold time. 147 // Holding may fail for only some snaps in which case HoldError is returned and 148 // it contains the details of failed ones. 149 func HoldRefresh(st *state.State, gatingSnap string, holdDuration time.Duration, affectingSnaps ...string) error { 150 gating, err := refreshGating(st) 151 if err != nil { 152 return err 153 } 154 herr := &HoldError{ 155 SnapsInError: make(map[string]HoldDurationError), 156 } 157 now := timeNow() 158 for _, heldSnap := range affectingSnaps { 159 hold, ok := gating[heldSnap][gatingSnap] 160 if !ok { 161 hold = &holdState{ 162 FirstHeld: now, 163 } 164 } 165 166 lastRefreshTime, err := lastRefreshed(st, heldSnap) 167 if err != nil { 168 return err 169 } 170 171 mp := maxPostponement - maxPostponementBuffer 172 maxDur := maxAllowedPostponement(gatingSnap, heldSnap, mp) 173 174 // calculate max hold duration that's left considering previous hold 175 // requests of this snap and last refresh time. 176 left := holdDurationLeft(now, lastRefreshTime, hold.FirstHeld, maxDur, mp) 177 if left <= 0 { 178 herr.SnapsInError[heldSnap] = HoldDurationError{ 179 Err: fmt.Errorf("snap %q cannot hold snap %q anymore, maximum refresh postponement exceeded", gatingSnap, heldSnap), 180 } 181 continue 182 } 183 184 dur := holdDuration 185 if dur == 0 { 186 // duration not specified, using a default one (maximum) or what's 187 // left of it. 
188 dur = left 189 } else { 190 // explicit hold duration requested 191 if dur > maxDur { 192 herr.SnapsInError[heldSnap] = HoldDurationError{ 193 Err: fmt.Errorf("requested holding duration for snap %q of %s by snap %q exceeds maximum holding time", heldSnap, holdDuration, gatingSnap), 194 DurationLeft: left, 195 } 196 continue 197 } 198 } 199 200 newHold := now.Add(dur) 201 cutOff := lastRefreshTime.Add(maxPostponement - maxPostponementBuffer) 202 203 // consider last refresh time and adjust hold duration if needed so it's 204 // not exceeded. 205 if newHold.Before(cutOff) { 206 hold.HoldUntil = newHold 207 } else { 208 hold.HoldUntil = cutOff 209 } 210 211 // finally store/update gating hold data 212 if _, ok := gating[heldSnap]; !ok { 213 gating[heldSnap] = make(map[string]*holdState) 214 } 215 gating[heldSnap][gatingSnap] = hold 216 } 217 218 if len(herr.SnapsInError) != len(affectingSnaps) { 219 st.Set("snaps-hold", gating) 220 } 221 if len(herr.SnapsInError) > 0 { 222 return herr 223 } 224 return nil 225 } 226 227 // ProceedWithRefresh unblocks all snaps held by gatingSnap for refresh. This 228 // should be called for --proceed on the gatingSnap. 229 func ProceedWithRefresh(st *state.State, gatingSnap string) error { 230 gating, err := refreshGating(st) 231 if err != nil { 232 return err 233 } 234 if len(gating) == 0 { 235 return nil 236 } 237 238 var changed bool 239 for heldSnap, gatingSnaps := range gating { 240 if _, ok := gatingSnaps[gatingSnap]; ok { 241 delete(gatingSnaps, gatingSnap) 242 changed = true 243 } 244 if len(gatingSnaps) == 0 { 245 delete(gating, heldSnap) 246 } 247 } 248 249 if changed { 250 st.Set("snaps-hold", gating) 251 } 252 return nil 253 } 254 255 // pruneGating removes affecting snaps that are not in candidates (meaning 256 // there is no update for them anymore). 
257 func pruneGating(st *state.State, candidates map[string]*refreshCandidate) error { 258 gating, err := refreshGating(st) 259 if err != nil { 260 return err 261 } 262 263 if len(gating) == 0 { 264 return nil 265 } 266 267 var changed bool 268 for affectingSnap := range gating { 269 if candidates[affectingSnap] == nil { 270 // the snap doesn't have an update anymore, forget it 271 delete(gating, affectingSnap) 272 changed = true 273 } 274 } 275 if changed { 276 st.Set("snaps-hold", gating) 277 } 278 return nil 279 } 280 281 // resetGatingForRefreshed resets gating information by removing refreshedSnaps 282 // (they are not held anymore). This should be called for snaps about to be 283 // refreshed. 284 func resetGatingForRefreshed(st *state.State, refreshedSnaps ...string) error { 285 gating, err := refreshGating(st) 286 if err != nil { 287 return err 288 } 289 if len(gating) == 0 { 290 return nil 291 } 292 293 var changed bool 294 for _, snapName := range refreshedSnaps { 295 if _, ok := gating[snapName]; ok { 296 delete(gating, snapName) 297 changed = true 298 } 299 } 300 301 if changed { 302 st.Set("snaps-hold", gating) 303 } 304 return nil 305 } 306 307 // pruneSnapsHold removes the given snap from snaps-hold, whether it was an 308 // affecting snap or gating snap. This should be called when a snap gets 309 // removed. 
310 func pruneSnapsHold(st *state.State, snapName string) error { 311 gating, err := refreshGating(st) 312 if err != nil { 313 return err 314 } 315 if len(gating) == 0 { 316 return nil 317 } 318 319 var changed bool 320 321 if _, ok := gating[snapName]; ok { 322 delete(gating, snapName) 323 changed = true 324 } 325 326 for heldSnap, holdingSnaps := range gating { 327 if _, ok := holdingSnaps[snapName]; ok { 328 delete(holdingSnaps, snapName) 329 if len(holdingSnaps) == 0 { 330 delete(gating, heldSnap) 331 } 332 changed = true 333 } 334 } 335 336 if changed { 337 st.Set("snaps-hold", gating) 338 } 339 340 return nil 341 } 342 343 // heldSnaps returns all snaps that are gated and shouldn't be refreshed. 344 func heldSnaps(st *state.State) (map[string]bool, error) { 345 gating, err := refreshGating(st) 346 if err != nil { 347 return nil, err 348 } 349 if len(gating) == 0 { 350 return nil, nil 351 } 352 353 now := timeNow() 354 355 held := make(map[string]bool) 356 Loop: 357 for heldSnap, holdingSnaps := range gating { 358 refreshed, err := lastRefreshed(st, heldSnap) 359 if err != nil { 360 return nil, err 361 } 362 // make sure we don't hold any snap for more than maxPostponement 363 if refreshed.Add(maxPostponement).Before(now) { 364 continue 365 } 366 for _, hold := range holdingSnaps { 367 if hold.HoldUntil.Before(now) { 368 continue 369 } 370 held[heldSnap] = true 371 continue Loop 372 } 373 } 374 return held, nil 375 } 376 377 type affectedSnapInfo struct { 378 Restart bool 379 Base bool 380 AffectingSnaps map[string]bool 381 } 382 383 func affectedByRefresh(st *state.State, updates []*snap.Info) (map[string]*affectedSnapInfo, error) { 384 all, err := All(st) 385 if err != nil { 386 return nil, err 387 } 388 389 var bootBase string 390 if !release.OnClassic { 391 deviceCtx, err := DeviceCtx(st, nil, nil) 392 if err != nil { 393 return nil, fmt.Errorf("cannot get device context: %v", err) 394 } 395 bootBaseInfo, err := BootBaseInfo(st, deviceCtx) 396 if err != 
nil { 397 return nil, fmt.Errorf("cannot get boot base info: %v", err) 398 } 399 bootBase = bootBaseInfo.InstanceName() 400 } 401 402 byBase := make(map[string][]string) 403 for name, snapSt := range all { 404 if !snapSt.Active { 405 delete(all, name) 406 continue 407 } 408 inf, err := snapSt.CurrentInfo() 409 if err != nil { 410 return nil, err 411 } 412 // optimization: do not consider snaps that don't have gate-auto-refresh hook. 413 if inf.Hooks[gateAutoRefreshHookName] == nil { 414 delete(all, name) 415 continue 416 } 417 418 base := inf.Base 419 if base == "none" { 420 continue 421 } 422 if inf.Base == "" { 423 base = "core" 424 } 425 byBase[base] = append(byBase[base], inf.InstanceName()) 426 } 427 428 affected := make(map[string]*affectedSnapInfo) 429 430 addAffected := func(snapName, affectedBy string, restart bool, base bool) { 431 if affected[snapName] == nil { 432 affected[snapName] = &affectedSnapInfo{ 433 AffectingSnaps: map[string]bool{}, 434 } 435 } 436 affectedInfo := affected[snapName] 437 if restart { 438 affectedInfo.Restart = restart 439 } 440 if base { 441 affectedInfo.Base = base 442 } 443 affectedInfo.AffectingSnaps[affectedBy] = true 444 } 445 446 for _, up := range updates { 447 // on core system, affected by update of boot base 448 if bootBase != "" && up.InstanceName() == bootBase { 449 for _, snapSt := range all { 450 addAffected(snapSt.InstanceName(), up.InstanceName(), true, false) 451 } 452 } 453 454 // snaps that can trigger reboot 455 // XXX: gadget refresh doesn't always require reboot, refine this 456 if up.Type() == snap.TypeKernel || up.Type() == snap.TypeGadget { 457 for _, snapSt := range all { 458 addAffected(snapSt.InstanceName(), up.InstanceName(), true, false) 459 } 460 continue 461 } 462 if up.Type() == snap.TypeBase || up.SnapName() == "core" { 463 // affected by refresh of this base snap 464 for _, snapName := range byBase[up.InstanceName()] { 465 addAffected(snapName, up.InstanceName(), false, true) 466 } 467 } 468 
469 repo := ifacerepo.Get(st) 470 471 // consider slots provided by refreshed snap, but exclude core and snapd 472 // since they provide system-level slots that are generally not disrupted 473 // by snap updates. 474 if up.SnapType != snap.TypeSnapd && up.SnapName() != "core" { 475 for _, slotInfo := range up.Slots { 476 conns, err := repo.Connected(up.InstanceName(), slotInfo.Name) 477 if err != nil { 478 return nil, err 479 } 480 for _, cref := range conns { 481 // affected only if it wasn't optimized out above 482 if all[cref.PlugRef.Snap] != nil { 483 addAffected(cref.PlugRef.Snap, up.InstanceName(), true, false) 484 } 485 } 486 } 487 } 488 489 // consider mount backend plugs/slots; 490 // for slot side only consider snapd/core because they are ignored by the 491 // earlier loop around slots. 492 if up.SnapType == snap.TypeSnapd || up.SnapType == snap.TypeOS { 493 for _, slotInfo := range up.Slots { 494 iface := repo.Interface(slotInfo.Interface) 495 if iface == nil { 496 return nil, fmt.Errorf("internal error: unknown interface %s", slotInfo.Interface) 497 } 498 if !usesMountBackend(iface) { 499 continue 500 } 501 conns, err := repo.Connected(up.InstanceName(), slotInfo.Name) 502 if err != nil { 503 return nil, err 504 } 505 for _, cref := range conns { 506 if all[cref.PlugRef.Snap] != nil { 507 addAffected(cref.PlugRef.Snap, up.InstanceName(), true, false) 508 } 509 } 510 } 511 } 512 for _, plugInfo := range up.Plugs { 513 iface := repo.Interface(plugInfo.Interface) 514 if iface == nil { 515 return nil, fmt.Errorf("internal error: unknown interface %s", plugInfo.Interface) 516 } 517 if !usesMountBackend(iface) { 518 continue 519 } 520 conns, err := repo.Connected(up.InstanceName(), plugInfo.Name) 521 if err != nil { 522 return nil, err 523 } 524 for _, cref := range conns { 525 if all[cref.SlotRef.Snap] != nil { 526 addAffected(cref.SlotRef.Snap, up.InstanceName(), true, false) 527 } 528 } 529 } 530 } 531 532 return affected, nil 533 } 534 535 // XXX: this is 
too wide and affects all commonInterface-based interfaces; we 536 // need metadata on the relevant interfaces. 537 func usesMountBackend(iface interfaces.Interface) bool { 538 type definer1 interface { 539 MountConnectedSlot(*mount.Specification, *interfaces.ConnectedPlug, *interfaces.ConnectedSlot) error 540 } 541 type definer2 interface { 542 MountConnectedPlug(*mount.Specification, *interfaces.ConnectedPlug, *interfaces.ConnectedSlot) error 543 } 544 type definer3 interface { 545 MountPermanentPlug(*mount.Specification, *snap.PlugInfo) error 546 } 547 type definer4 interface { 548 MountPermanentSlot(*mount.Specification, *snap.SlotInfo) error 549 } 550 551 if _, ok := iface.(definer1); ok { 552 return true 553 } 554 if _, ok := iface.(definer2); ok { 555 return true 556 } 557 if _, ok := iface.(definer3); ok { 558 return true 559 } 560 if _, ok := iface.(definer4); ok { 561 return true 562 } 563 return false 564 } 565 566 // createGateAutoRefreshHooks creates gate-auto-refresh hooks for all affectedSnaps. 567 // The hooks will have their context data set from affectedSnapInfo flags (base, restart). 568 // Hook tasks will be chained to run sequentially. 
569 func createGateAutoRefreshHooks(st *state.State, affectedSnaps map[string]*affectedSnapInfo) *state.TaskSet { 570 ts := state.NewTaskSet() 571 var prev *state.Task 572 // sort names for easy testing 573 names := make([]string, 0, len(affectedSnaps)) 574 for snapName := range affectedSnaps { 575 names = append(names, snapName) 576 } 577 sort.Strings(names) 578 for _, snapName := range names { 579 affected := affectedSnaps[snapName] 580 hookTask := SetupGateAutoRefreshHook(st, snapName, affected.Base, affected.Restart, affected.AffectingSnaps) 581 // XXX: it should be fine to run the hooks in parallel 582 if prev != nil { 583 hookTask.WaitFor(prev) 584 } 585 ts.AddTask(hookTask) 586 prev = hookTask 587 } 588 return ts 589 } 590 591 func conditionalAutoRefreshAffectedSnaps(t *state.Task) ([]string, error) { 592 var snaps map[string]*json.RawMessage 593 if err := t.Get("snaps", &snaps); err != nil { 594 return nil, fmt.Errorf("internal error: cannot get snaps to update for %s task %s", t.Kind(), t.ID()) 595 } 596 names := make([]string, 0, len(snaps)) 597 for sn := range snaps { 598 // TODO: drop snaps once we know the outcome of gate-auto-refresh hooks. 599 names = append(names, sn) 600 } 601 return names, nil 602 } 603 604 // snapsToRefresh returns all snaps that should proceed with refresh considering 605 // hold behavior. 
606 var snapsToRefresh = func(gatingTask *state.Task) ([]*refreshCandidate, error) { 607 var snaps map[string]*refreshCandidate 608 if err := gatingTask.Get("snaps", &snaps); err != nil { 609 return nil, err 610 } 611 612 held, err := heldSnaps(gatingTask.State()) 613 if err != nil { 614 return nil, err 615 } 616 617 var skipped []string 618 var candidates []*refreshCandidate 619 for _, s := range snaps { 620 if !held[s.InstanceName()] { 621 candidates = append(candidates, s) 622 } else { 623 skipped = append(skipped, s.InstanceName()) 624 } 625 } 626 627 if len(skipped) > 0 { 628 sort.Strings(skipped) 629 logger.Noticef("skipping refresh of held snaps: %s", strings.Join(skipped, ",")) 630 } 631 632 return candidates, nil 633 }