github.com/rigado/snapd@v2.42.5-go-mod+incompatible/overlord/state/state.go (about) 1 // -*- Mode: Go; indent-tabs-mode: t -*- 2 3 /* 4 * Copyright (C) 2016 Canonical Ltd 5 * 6 * This program is free software: you can redistribute it and/or modify 7 * it under the terms of the GNU General Public License version 3 as 8 * published by the Free Software Foundation. 9 * 10 * This program is distributed in the hope that it will be useful, 11 * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 * GNU General Public License for more details. 14 * 15 * You should have received a copy of the GNU General Public License 16 * along with this program. If not, see <http://www.gnu.org/licenses/>. 17 * 18 */ 19 20 // Package state implements the representation of system state. 21 package state 22 23 import ( 24 "encoding/json" 25 "errors" 26 "fmt" 27 "io" 28 "sort" 29 "strconv" 30 "sync" 31 "sync/atomic" 32 "time" 33 34 "github.com/snapcore/snapd/logger" 35 ) 36 37 // A Backend is used by State to checkpoint on every unlock operation 38 // and to mediate requests to ensure the state sooner or request restarts. 39 type Backend interface { 40 Checkpoint(data []byte) error 41 EnsureBefore(d time.Duration) 42 // TODO: take flags to ask for reboot vs restart? 43 RequestRestart(t RestartType) 44 } 45 46 type customData map[string]*json.RawMessage 47 48 func (data customData) get(key string, value interface{}) error { 49 entryJSON := data[key] 50 if entryJSON == nil { 51 return ErrNoState 52 } 53 err := json.Unmarshal(*entryJSON, value) 54 if err != nil { 55 return fmt.Errorf("internal error: could not unmarshal state entry %q: %v", key, err) 56 } 57 return nil 58 } 59 60 func (data customData) has(key string) bool { 61 return data[key] != nil 62 } 63 64 func (data customData) set(key string, value interface{}) { 65 if value == nil { 66 delete(data, key) 67 return 68 } 69 serialized, err := json.Marshal(value) 70 if err != nil { 71 logger.Panicf("internal error: could not marshal value for state entry %q: %v", key, err) 72 } 73 entryJSON := json.RawMessage(serialized) 74 data[key] = &entryJSON 75 } 76 77 type RestartType int 78 79 const ( 80 RestartUnset RestartType = iota 81 RestartDaemon 82 RestartSystem 83 // RestartSocket will restart the daemon so that it goes into 84 // socket activation mode. 85 RestartSocket 86 ) 87 88 // State represents an evolving system state that persists across restarts. 89 // 90 // The State is concurrency-safe, and all reads and writes to it must be 91 // performed with the state locked. It's a runtime error (panic) to perform 92 // operations without it. 93 // 94 // The state is persisted on every unlock operation via the StateBackend 95 // it was initialized with. 96 type State struct { 97 mu sync.Mutex 98 muC int32 99 100 lastTaskId int 101 lastChangeId int 102 lastLaneId int 103 104 backend Backend 105 data customData 106 changes map[string]*Change 107 tasks map[string]*Task 108 warnings map[string]*Warning 109 110 modified bool 111 112 cache map[interface{}]interface{} 113 114 restarting RestartType 115 restartLck sync.Mutex 116 bootID string 117 } 118 119 // New returns a new empty state. 120 func New(backend Backend) *State { 121 return &State{ 122 backend: backend, 123 data: make(customData), 124 changes: make(map[string]*Change), 125 tasks: make(map[string]*Task), 126 warnings: make(map[string]*Warning), 127 modified: true, 128 cache: make(map[interface{}]interface{}), 129 } 130 } 131 132 // Modified returns whether the state was modified since the last checkpoint. 133 func (s *State) Modified() bool { 134 return s.modified 135 } 136 137 // Lock acquires the state lock. 138 func (s *State) Lock() { 139 s.mu.Lock() 140 atomic.AddInt32(&s.muC, 1) 141 } 142 143 func (s *State) reading() { 144 if atomic.LoadInt32(&s.muC) != 1 { 145 panic("internal error: accessing state without lock") 146 } 147 } 148 149 func (s *State) writing() { 150 s.modified = true 151 if atomic.LoadInt32(&s.muC) != 1 { 152 panic("internal error: accessing state without lock") 153 } 154 } 155 156 func (s *State) unlock() { 157 atomic.AddInt32(&s.muC, -1) 158 s.mu.Unlock() 159 } 160 161 type marshalledState struct { 162 Data map[string]*json.RawMessage `json:"data"` 163 Changes map[string]*Change `json:"changes"` 164 Tasks map[string]*Task `json:"tasks"` 165 Warnings []*Warning `json:"warnings,omitempty"` 166 167 LastChangeId int `json:"last-change-id"` 168 LastTaskId int `json:"last-task-id"` 169 LastLaneId int `json:"last-lane-id"` 170 } 171 172 // MarshalJSON makes State a json.Marshaller 173 func (s *State) MarshalJSON() ([]byte, error) { 174 s.reading() 175 return json.Marshal(marshalledState{ 176 Data: s.data, 177 Changes: s.changes, 178 Tasks: s.tasks, 179 Warnings: s.flattenWarnings(), 180 181 LastTaskId: s.lastTaskId, 182 LastChangeId: s.lastChangeId, 183 LastLaneId: s.lastLaneId, 184 }) 185 } 186 187 // UnmarshalJSON makes State a json.Unmarshaller 188 func (s *State) UnmarshalJSON(data []byte) error { 189 s.writing() 190 var unmarshalled marshalledState 191 err := json.Unmarshal(data, &unmarshalled) 192 if err != nil { 193 return err 194 } 195 s.data = unmarshalled.Data 196 s.changes = unmarshalled.Changes 197 s.tasks = unmarshalled.Tasks 198 s.unflattenWarnings(unmarshalled.Warnings) 199 s.lastChangeId = unmarshalled.LastChangeId 200 s.lastTaskId = unmarshalled.LastTaskId 201 s.lastLaneId = unmarshalled.LastLaneId 202 // backlink state again 203 for _, t := range s.tasks { 204 t.state = s 205 } 206 for _, chg := range s.changes { 207 chg.state = s 208 chg.finishUnmarshal() 209 } 210 return nil 211 } 212 213 func (s *State) checkpointData() []byte { 214 data, err := json.Marshal(s) 215 if err != nil { 216 // this shouldn't happen, because the actual delicate serializing happens at various Set()s 217 logger.Panicf("internal error: could not marshal state for checkpointing: %v", err) 218 } 219 return data 220 } 221 222 // unlock checkpoint retry parameters (5 mins of retries by default) 223 var ( 224 unlockCheckpointRetryMaxTime = 5 * time.Minute 225 unlockCheckpointRetryInterval = 3 * time.Second 226 ) 227 228 // Unlock releases the state lock and checkpoints the state. 229 // It does not return until the state is correctly checkpointed. 230 // After too many unsuccessful checkpoint attempts, it panics. 231 func (s *State) Unlock() { 232 defer s.unlock() 233 234 if !s.modified || s.backend == nil { 235 return 236 } 237 238 data := s.checkpointData() 239 var err error 240 start := time.Now() 241 for time.Since(start) <= unlockCheckpointRetryMaxTime { 242 if err = s.backend.Checkpoint(data); err == nil { 243 s.modified = false 244 return 245 } 246 time.Sleep(unlockCheckpointRetryInterval) 247 } 248 logger.Panicf("cannot checkpoint even after %v of retries every %v: %v", unlockCheckpointRetryMaxTime, unlockCheckpointRetryInterval, err) 249 } 250 251 // EnsureBefore asks for an ensure pass to happen sooner within duration from now. 252 func (s *State) EnsureBefore(d time.Duration) { 253 if s.backend != nil { 254 s.backend.EnsureBefore(d) 255 } 256 } 257 258 // RequestRestart asks for a restart of the managing process. 259 // The state needs to be locked to request a RestartSystem. 260 func (s *State) RequestRestart(t RestartType) { 261 if s.backend != nil { 262 if t == RestartSystem { 263 if s.bootID == "" { 264 panic("internal error: cannot request a system restart if current boot ID was not provided via VerifyReboot") 265 } 266 s.Set("system-restart-from-boot-id", s.bootID) 267 } 268 s.restartLck.Lock() 269 s.restarting = t 270 s.restartLck.Unlock() 271 s.backend.RequestRestart(t) 272 } 273 } 274 275 // Restarting returns whether a restart was requested with RequestRestart and of which type. 276 func (s *State) Restarting() (bool, RestartType) { 277 s.restartLck.Lock() 278 defer s.restartLck.Unlock() 279 return s.restarting != RestartUnset, s.restarting 280 } 281 282 var ErrExpectedReboot = errors.New("expected reboot did not happen") 283 284 // VerifyReboot checks if the state rembers that a system restart was 285 // requested and whether it succeeded based on the provided current 286 // boot id. It returns ErrExpectedReboot if the expected reboot did 287 // not happen yet. It must be called early in the usage of state and 288 // before an RequestRestart with RestartSystem is attempted. 289 // It must be called with the state lock held. 290 func (s *State) VerifyReboot(curBootID string) error { 291 var fromBootID string 292 err := s.Get("system-restart-from-boot-id", &fromBootID) 293 if err != nil && err != ErrNoState { 294 return err 295 } 296 s.bootID = curBootID 297 if fromBootID == "" { 298 return nil 299 } 300 if fromBootID == curBootID { 301 return ErrExpectedReboot 302 } 303 // we rebooted alright 304 s.ClearReboot() 305 return nil 306 } 307 308 // ClearReboot clears state information about tracking requested reboots. 309 func (s *State) ClearReboot() { 310 s.Set("system-restart-from-boot-id", nil) 311 } 312 313 func MockRestarting(s *State, restarting RestartType) RestartType { 314 s.restartLck.Lock() 315 defer s.restartLck.Unlock() 316 old := s.restarting 317 s.restarting = restarting 318 return old 319 } 320 321 // ErrNoState represents the case of no state entry for a given key. 322 var ErrNoState = errors.New("no state entry for key") 323 324 // Get unmarshals the stored value associated with the provided key 325 // into the value parameter. 326 // It returns ErrNoState if there is no entry for key. 327 func (s *State) Get(key string, value interface{}) error { 328 s.reading() 329 return s.data.get(key, value) 330 } 331 332 // Set associates value with key for future consulting by managers. 333 // The provided value must properly marshal and unmarshal with encoding/json. 334 func (s *State) Set(key string, value interface{}) { 335 s.writing() 336 s.data.set(key, value) 337 } 338 339 // Cached returns the cached value associated with the provided key. 340 // It returns nil if there is no entry for key. 341 func (s *State) Cached(key interface{}) interface{} { 342 s.reading() 343 return s.cache[key] 344 } 345 346 // Cache associates value with key for future consulting by managers. 347 // The cached value is not persisted. 348 func (s *State) Cache(key, value interface{}) { 349 s.reading() // Doesn't touch persisted data. 350 if value == nil { 351 delete(s.cache, key) 352 } else { 353 s.cache[key] = value 354 } 355 } 356 357 // NewChange adds a new change to the state. 358 func (s *State) NewChange(kind, summary string) *Change { 359 s.writing() 360 s.lastChangeId++ 361 id := strconv.Itoa(s.lastChangeId) 362 chg := newChange(s, id, kind, summary) 363 s.changes[id] = chg 364 return chg 365 } 366 367 // NewLane creates a new lane in the state. 368 func (s *State) NewLane() int { 369 s.writing() 370 s.lastLaneId++ 371 return s.lastLaneId 372 } 373 374 // Changes returns all changes currently known to the state. 375 func (s *State) Changes() []*Change { 376 s.reading() 377 res := make([]*Change, 0, len(s.changes)) 378 for _, chg := range s.changes { 379 res = append(res, chg) 380 } 381 return res 382 } 383 384 // Change returns the change for the given ID. 385 func (s *State) Change(id string) *Change { 386 s.reading() 387 return s.changes[id] 388 } 389 390 // NewTask creates a new task. 391 // It usually will be registered with a Change using AddTask or 392 // through a TaskSet. 393 func (s *State) NewTask(kind, summary string) *Task { 394 s.writing() 395 s.lastTaskId++ 396 id := strconv.Itoa(s.lastTaskId) 397 t := newTask(s, id, kind, summary) 398 s.tasks[id] = t 399 return t 400 } 401 402 // Tasks returns all tasks currently known to the state and linked to changes. 403 func (s *State) Tasks() []*Task { 404 s.reading() 405 res := make([]*Task, 0, len(s.tasks)) 406 for _, t := range s.tasks { 407 if t.Change() == nil { // skip unlinked tasks 408 continue 409 } 410 res = append(res, t) 411 } 412 return res 413 } 414 415 // Task returns the task for the given ID if the task has been linked to a change. 416 func (s *State) Task(id string) *Task { 417 s.reading() 418 t := s.tasks[id] 419 if t == nil || t.Change() == nil { 420 return nil 421 } 422 return t 423 } 424 425 // TaskCount returns the number of tasks that currently exist in the state, 426 // whether linked to a change or not. 427 func (s *State) TaskCount() int { 428 s.reading() 429 return len(s.tasks) 430 } 431 432 func (s *State) tasksIn(tids []string) []*Task { 433 res := make([]*Task, len(tids)) 434 for i, tid := range tids { 435 res[i] = s.tasks[tid] 436 } 437 return res 438 } 439 440 // Prune does several cleanup tasks to the in-memory state: 441 // 442 // * it removes changes that became ready for more than pruneWait and aborts 443 // tasks spawned for more than abortWait. 444 // 445 // * it removes tasks unlinked to changes after pruneWait. When there are more 446 // changes than the limit set via "maxReadyChanges" those changes in ready 447 // state will also removed even if they are below the pruneWait duration. 448 // 449 // * it removes expired warnings. 450 func (s *State) Prune(pruneWait, abortWait time.Duration, maxReadyChanges int) { 451 now := time.Now() 452 pruneLimit := now.Add(-pruneWait) 453 abortLimit := now.Add(-abortWait) 454 455 // sort from oldest to newest 456 changes := s.Changes() 457 sort.Sort(byReadyTime(changes)) 458 459 readyChangesCount := 0 460 for i := range changes { 461 // changes are sorted (not-ready sorts first) 462 // so we know we can iterate in reverse and break once we 463 // find a ready time of "zero" 464 chg := changes[len(changes)-i-1] 465 if chg.ReadyTime().IsZero() { 466 break 467 } 468 readyChangesCount++ 469 } 470 471 for k, w := range s.warnings { 472 if w.ExpiredBefore(now) { 473 delete(s.warnings, k) 474 } 475 } 476 477 for _, chg := range changes { 478 spawnTime := chg.SpawnTime() 479 readyTime := chg.ReadyTime() 480 if readyTime.IsZero() { 481 if spawnTime.Before(pruneLimit) && len(chg.Tasks()) == 0 { 482 chg.Abort() 483 delete(s.changes, chg.ID()) 484 } else if spawnTime.Before(abortLimit) { 485 chg.Abort() 486 } 487 continue 488 } 489 // change old or we have too many changes 490 if readyTime.Before(pruneLimit) || readyChangesCount > maxReadyChanges { 491 s.writing() 492 for _, t := range chg.Tasks() { 493 delete(s.tasks, t.ID()) 494 } 495 delete(s.changes, chg.ID()) 496 readyChangesCount-- 497 } 498 } 499 500 for tid, t := range s.tasks { 501 // TODO: this could be done more aggressively 502 if t.Change() == nil && t.SpawnTime().Before(pruneLimit) { 503 s.writing() 504 delete(s.tasks, tid) 505 } 506 } 507 } 508 509 // ReadState returns the state deserialized from r. 510 func ReadState(backend Backend, r io.Reader) (*State, error) { 511 s := new(State) 512 s.Lock() 513 defer s.unlock() 514 d := json.NewDecoder(r) 515 err := d.Decode(&s) 516 if err != nil { 517 return nil, fmt.Errorf("cannot read state: %s", err) 518 } 519 s.backend = backend 520 s.modified = false 521 s.cache = make(map[interface{}]interface{}) 522 return s, err 523 }