gopkg.in/ubuntu-core/snappy.v0@v0.0.0-20210902073436-25a8614f10a6/overlord/state/state.go (about) 1 // -*- Mode: Go; indent-tabs-mode: t -*- 2 3 /* 4 * Copyright (C) 2016-2020 Canonical Ltd 5 * 6 * This program is free software: you can redistribute it and/or modify 7 * it under the terms of the GNU General Public License version 3 as 8 * published by the Free Software Foundation. 9 * 10 * This program is distributed in the hope that it will be useful, 11 * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 * GNU General Public License for more details. 14 * 15 * You should have received a copy of the GNU General Public License 16 * along with this program. If not, see <http://www.gnu.org/licenses/>. 17 * 18 */ 19 20 // Package state implements the representation of system state. 21 package state 22 23 import ( 24 "encoding/json" 25 "errors" 26 "fmt" 27 "io" 28 "sort" 29 "strconv" 30 "sync" 31 "sync/atomic" 32 "time" 33 34 "github.com/snapcore/snapd/logger" 35 ) 36 37 // A Backend is used by State to checkpoint on every unlock operation 38 // and to mediate requests to ensure the state sooner or request restarts. 39 type Backend interface { 40 Checkpoint(data []byte) error 41 EnsureBefore(d time.Duration) 42 // TODO: take flags to ask for reboot vs restart? 43 RequestRestart(t RestartType) 44 } 45 46 type customData map[string]*json.RawMessage 47 48 func (data customData) get(key string, value interface{}) error { 49 entryJSON := data[key] 50 if entryJSON == nil { 51 return ErrNoState 52 } 53 err := json.Unmarshal(*entryJSON, value) 54 if err != nil { 55 return fmt.Errorf("internal error: could not unmarshal state entry %q: %v", key, err) 56 } 57 return nil 58 } 59 60 func (data customData) has(key string) bool { 61 return data[key] != nil 62 } 63 64 func (data customData) set(key string, value interface{}) { 65 if value == nil { 66 delete(data, key) 67 return 68 } 69 serialized, err := json.Marshal(value) 70 if err != nil { 71 logger.Panicf("internal error: could not marshal value for state entry %q: %v", key, err) 72 } 73 entryJSON := json.RawMessage(serialized) 74 data[key] = &entryJSON 75 } 76 77 type RestartType int 78 79 const ( 80 RestartUnset RestartType = iota 81 RestartDaemon 82 RestartSystem 83 // RestartSystemNow is like RestartSystem but action is immediate 84 RestartSystemNow 85 // RestartSocket will restart the daemon so that it goes into 86 // socket activation mode. 87 RestartSocket 88 // Stop just stops the daemon (used with image pre-seeding) 89 StopDaemon 90 // RestartSystemHaltNow will shutdown --halt the system asap 91 RestartSystemHaltNow 92 // RestartSystemPoweroffNow will shutdown --poweroff the system asap 93 RestartSystemPoweroffNow 94 ) 95 96 // State represents an evolving system state that persists across restarts. 97 // 98 // The State is concurrency-safe, and all reads and writes to it must be 99 // performed with the state locked. It's a runtime error (panic) to perform 100 // operations without it. 101 // 102 // The state is persisted on every unlock operation via the StateBackend 103 // it was initialized with. 104 type State struct { 105 mu sync.Mutex 106 muC int32 107 108 lastTaskId int 109 lastChangeId int 110 lastLaneId int 111 112 backend Backend 113 data customData 114 changes map[string]*Change 115 tasks map[string]*Task 116 warnings map[string]*Warning 117 118 modified bool 119 120 cache map[interface{}]interface{} 121 122 restarting RestartType 123 restartLck sync.Mutex 124 bootID string 125 } 126 127 // New returns a new empty state. 128 func New(backend Backend) *State { 129 return &State{ 130 backend: backend, 131 data: make(customData), 132 changes: make(map[string]*Change), 133 tasks: make(map[string]*Task), 134 warnings: make(map[string]*Warning), 135 modified: true, 136 cache: make(map[interface{}]interface{}), 137 } 138 } 139 140 // Modified returns whether the state was modified since the last checkpoint. 141 func (s *State) Modified() bool { 142 return s.modified 143 } 144 145 // Lock acquires the state lock. 146 func (s *State) Lock() { 147 s.mu.Lock() 148 atomic.AddInt32(&s.muC, 1) 149 } 150 151 func (s *State) reading() { 152 if atomic.LoadInt32(&s.muC) != 1 { 153 panic("internal error: accessing state without lock") 154 } 155 } 156 157 func (s *State) writing() { 158 s.modified = true 159 if atomic.LoadInt32(&s.muC) != 1 { 160 panic("internal error: accessing state without lock") 161 } 162 } 163 164 func (s *State) unlock() { 165 atomic.AddInt32(&s.muC, -1) 166 s.mu.Unlock() 167 } 168 169 type marshalledState struct { 170 Data map[string]*json.RawMessage `json:"data"` 171 Changes map[string]*Change `json:"changes"` 172 Tasks map[string]*Task `json:"tasks"` 173 Warnings []*Warning `json:"warnings,omitempty"` 174 175 LastChangeId int `json:"last-change-id"` 176 LastTaskId int `json:"last-task-id"` 177 LastLaneId int `json:"last-lane-id"` 178 } 179 180 // MarshalJSON makes State a json.Marshaller 181 func (s *State) MarshalJSON() ([]byte, error) { 182 s.reading() 183 return json.Marshal(marshalledState{ 184 Data: s.data, 185 Changes: s.changes, 186 Tasks: s.tasks, 187 Warnings: s.flattenWarnings(), 188 189 LastTaskId: s.lastTaskId, 190 LastChangeId: s.lastChangeId, 191 LastLaneId: s.lastLaneId, 192 }) 193 } 194 195 // UnmarshalJSON makes State a json.Unmarshaller 196 func (s *State) UnmarshalJSON(data []byte) error { 197 s.writing() 198 var unmarshalled marshalledState 199 err := json.Unmarshal(data, &unmarshalled) 200 if err != nil { 201 return err 202 } 203 s.data = unmarshalled.Data 204 s.changes = unmarshalled.Changes 205 s.tasks = unmarshalled.Tasks 206 s.unflattenWarnings(unmarshalled.Warnings) 207 s.lastChangeId = unmarshalled.LastChangeId 208 s.lastTaskId = unmarshalled.LastTaskId 209 s.lastLaneId = unmarshalled.LastLaneId 210 // backlink state again 211 for _, t := range s.tasks { 212 t.state = s 213 } 214 for _, chg := range s.changes { 215 chg.state = s 216 chg.finishUnmarshal() 217 } 218 return nil 219 } 220 221 func (s *State) checkpointData() []byte { 222 data, err := json.Marshal(s) 223 if err != nil { 224 // this shouldn't happen, because the actual delicate serializing happens at various Set()s 225 logger.Panicf("internal error: could not marshal state for checkpointing: %v", err) 226 } 227 return data 228 } 229 230 // unlock checkpoint retry parameters (5 mins of retries by default) 231 var ( 232 unlockCheckpointRetryMaxTime = 5 * time.Minute 233 unlockCheckpointRetryInterval = 3 * time.Second 234 ) 235 236 // Unlock releases the state lock and checkpoints the state. 237 // It does not return until the state is correctly checkpointed. 238 // After too many unsuccessful checkpoint attempts, it panics. 239 func (s *State) Unlock() { 240 defer s.unlock() 241 242 if !s.modified || s.backend == nil { 243 return 244 } 245 246 data := s.checkpointData() 247 var err error 248 start := time.Now() 249 for time.Since(start) <= unlockCheckpointRetryMaxTime { 250 if err = s.backend.Checkpoint(data); err == nil { 251 s.modified = false 252 return 253 } 254 time.Sleep(unlockCheckpointRetryInterval) 255 } 256 logger.Panicf("cannot checkpoint even after %v of retries every %v: %v", unlockCheckpointRetryMaxTime, unlockCheckpointRetryInterval, err) 257 } 258 259 // EnsureBefore asks for an ensure pass to happen sooner within duration from now. 260 func (s *State) EnsureBefore(d time.Duration) { 261 if s.backend != nil { 262 s.backend.EnsureBefore(d) 263 } 264 } 265 266 // RequestRestart asks for a restart of the managing process. 267 // The state needs to be locked to request a RestartSystem. 268 func (s *State) RequestRestart(t RestartType) { 269 if s.backend != nil { 270 switch t { 271 case RestartSystem, RestartSystemNow, RestartSystemHaltNow, RestartSystemPoweroffNow: 272 if s.bootID == "" { 273 panic("internal error: cannot request a system restart if current boot ID was not provided via VerifyReboot") 274 } 275 s.Set("system-restart-from-boot-id", s.bootID) 276 } 277 s.restartLck.Lock() 278 s.restarting = t 279 s.restartLck.Unlock() 280 s.backend.RequestRestart(t) 281 } 282 } 283 284 // Restarting returns whether a restart was requested with RequestRestart and of which type. 285 func (s *State) Restarting() (bool, RestartType) { 286 s.restartLck.Lock() 287 defer s.restartLck.Unlock() 288 return s.restarting != RestartUnset, s.restarting 289 } 290 291 var ErrExpectedReboot = errors.New("expected reboot did not happen") 292 293 // VerifyReboot checks if the state remembers that a system restart was 294 // requested and whether it succeeded based on the provided current 295 // boot id. It returns ErrExpectedReboot if the expected reboot did 296 // not happen yet. It must be called early in the usage of state and 297 // before an RequestRestart with RestartSystem is attempted. 298 // It must be called with the state lock held. 299 func (s *State) VerifyReboot(curBootID string) error { 300 var fromBootID string 301 err := s.Get("system-restart-from-boot-id", &fromBootID) 302 if err != nil && err != ErrNoState { 303 return err 304 } 305 s.bootID = curBootID 306 if fromBootID == "" { 307 return nil 308 } 309 if fromBootID == curBootID { 310 return ErrExpectedReboot 311 } 312 // we rebooted alright 313 s.ClearReboot() 314 return nil 315 } 316 317 // ClearReboot clears state information about tracking requested reboots. 318 func (s *State) ClearReboot() { 319 s.Set("system-restart-from-boot-id", nil) 320 } 321 322 func MockRestarting(s *State, restarting RestartType) RestartType { 323 s.restartLck.Lock() 324 defer s.restartLck.Unlock() 325 old := s.restarting 326 s.restarting = restarting 327 return old 328 } 329 330 // ErrNoState represents the case of no state entry for a given key. 331 var ErrNoState = errors.New("no state entry for key") 332 333 // Get unmarshals the stored value associated with the provided key 334 // into the value parameter. 335 // It returns ErrNoState if there is no entry for key. 336 func (s *State) Get(key string, value interface{}) error { 337 s.reading() 338 return s.data.get(key, value) 339 } 340 341 // Set associates value with key for future consulting by managers. 342 // The provided value must properly marshal and unmarshal with encoding/json. 343 func (s *State) Set(key string, value interface{}) { 344 s.writing() 345 s.data.set(key, value) 346 } 347 348 // Cached returns the cached value associated with the provided key. 349 // It returns nil if there is no entry for key. 350 func (s *State) Cached(key interface{}) interface{} { 351 s.reading() 352 return s.cache[key] 353 } 354 355 // Cache associates value with key for future consulting by managers. 356 // The cached value is not persisted. 357 func (s *State) Cache(key, value interface{}) { 358 s.reading() // Doesn't touch persisted data. 359 if value == nil { 360 delete(s.cache, key) 361 } else { 362 s.cache[key] = value 363 } 364 } 365 366 // NewChange adds a new change to the state. 367 func (s *State) NewChange(kind, summary string) *Change { 368 s.writing() 369 s.lastChangeId++ 370 id := strconv.Itoa(s.lastChangeId) 371 chg := newChange(s, id, kind, summary) 372 s.changes[id] = chg 373 return chg 374 } 375 376 // NewLane creates a new lane in the state. 377 func (s *State) NewLane() int { 378 s.writing() 379 s.lastLaneId++ 380 return s.lastLaneId 381 } 382 383 // Changes returns all changes currently known to the state. 384 func (s *State) Changes() []*Change { 385 s.reading() 386 res := make([]*Change, 0, len(s.changes)) 387 for _, chg := range s.changes { 388 res = append(res, chg) 389 } 390 return res 391 } 392 393 // Change returns the change for the given ID. 394 func (s *State) Change(id string) *Change { 395 s.reading() 396 return s.changes[id] 397 } 398 399 // NewTask creates a new task. 400 // It usually will be registered with a Change using AddTask or 401 // through a TaskSet. 402 func (s *State) NewTask(kind, summary string) *Task { 403 s.writing() 404 s.lastTaskId++ 405 id := strconv.Itoa(s.lastTaskId) 406 t := newTask(s, id, kind, summary) 407 s.tasks[id] = t 408 return t 409 } 410 411 // Tasks returns all tasks currently known to the state and linked to changes. 412 func (s *State) Tasks() []*Task { 413 s.reading() 414 res := make([]*Task, 0, len(s.tasks)) 415 for _, t := range s.tasks { 416 if t.Change() == nil { // skip unlinked tasks 417 continue 418 } 419 res = append(res, t) 420 } 421 return res 422 } 423 424 // Task returns the task for the given ID if the task has been linked to a change. 425 func (s *State) Task(id string) *Task { 426 s.reading() 427 t := s.tasks[id] 428 if t == nil || t.Change() == nil { 429 return nil 430 } 431 return t 432 } 433 434 // TaskCount returns the number of tasks that currently exist in the state, 435 // whether linked to a change or not. 436 func (s *State) TaskCount() int { 437 s.reading() 438 return len(s.tasks) 439 } 440 441 func (s *State) tasksIn(tids []string) []*Task { 442 res := make([]*Task, len(tids)) 443 for i, tid := range tids { 444 res[i] = s.tasks[tid] 445 } 446 return res 447 } 448 449 // Prune does several cleanup tasks to the in-memory state: 450 // 451 // * it removes changes that became ready for more than pruneWait and aborts 452 // tasks spawned for more than abortWait. 453 // 454 // * it removes tasks unlinked to changes after pruneWait. When there are more 455 // changes than the limit set via "maxReadyChanges" those changes in ready 456 // state will also removed even if they are below the pruneWait duration. 457 // 458 // * it removes expired warnings. 459 func (s *State) Prune(startOfOperation time.Time, pruneWait, abortWait time.Duration, maxReadyChanges int) { 460 now := time.Now() 461 pruneLimit := now.Add(-pruneWait) 462 abortLimit := now.Add(-abortWait) 463 464 // sort from oldest to newest 465 changes := s.Changes() 466 sort.Sort(byReadyTime(changes)) 467 468 readyChangesCount := 0 469 for i := range changes { 470 // changes are sorted (not-ready sorts first) 471 // so we know we can iterate in reverse and break once we 472 // find a ready time of "zero" 473 chg := changes[len(changes)-i-1] 474 if chg.ReadyTime().IsZero() { 475 break 476 } 477 readyChangesCount++ 478 } 479 480 for k, w := range s.warnings { 481 if w.ExpiredBefore(now) { 482 delete(s.warnings, k) 483 } 484 } 485 486 for _, chg := range changes { 487 readyTime := chg.ReadyTime() 488 spawnTime := chg.SpawnTime() 489 if spawnTime.Before(startOfOperation) { 490 spawnTime = startOfOperation 491 } 492 if readyTime.IsZero() { 493 if spawnTime.Before(pruneLimit) && len(chg.Tasks()) == 0 { 494 chg.Abort() 495 delete(s.changes, chg.ID()) 496 } else if spawnTime.Before(abortLimit) { 497 chg.Abort() 498 } 499 continue 500 } 501 // change old or we have too many changes 502 if readyTime.Before(pruneLimit) || readyChangesCount > maxReadyChanges { 503 s.writing() 504 for _, t := range chg.Tasks() { 505 delete(s.tasks, t.ID()) 506 } 507 delete(s.changes, chg.ID()) 508 readyChangesCount-- 509 } 510 } 511 512 for tid, t := range s.tasks { 513 // TODO: this could be done more aggressively 514 if t.Change() == nil && t.SpawnTime().Before(pruneLimit) { 515 s.writing() 516 delete(s.tasks, tid) 517 } 518 } 519 } 520 521 // GetMaybeTimings implements timings.GetSaver 522 func (s *State) GetMaybeTimings(timings interface{}) error { 523 err := s.Get("timings", timings) 524 if err != nil && err != ErrNoState { 525 return err 526 } 527 return nil 528 } 529 530 // SaveTimings implements timings.GetSaver 531 func (s *State) SaveTimings(timings interface{}) { 532 s.Set("timings", timings) 533 } 534 535 // ReadState returns the state deserialized from r. 536 func ReadState(backend Backend, r io.Reader) (*State, error) { 537 s := new(State) 538 s.Lock() 539 defer s.unlock() 540 d := json.NewDecoder(r) 541 err := d.Decode(&s) 542 if err != nil { 543 return nil, fmt.Errorf("cannot read state: %s", err) 544 } 545 s.backend = backend 546 s.modified = false 547 s.cache = make(map[interface{}]interface{}) 548 return s, err 549 }