github.com/Lephar/snapd@v0.0.0-20210825215435-c7fba9cef4d2/overlord/state/state.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2016-2020 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  // Package state implements the representation of system state.
    21  package state
    22  
    23  import (
    24  	"encoding/json"
    25  	"errors"
    26  	"fmt"
    27  	"io"
    28  	"sort"
    29  	"strconv"
    30  	"sync"
    31  	"sync/atomic"
    32  	"time"
    33  
    34  	"github.com/snapcore/snapd/logger"
    35  )
    36  
// A Backend is used by State to checkpoint on every unlock operation
// and to mediate requests to ensure the state sooner or request restarts.
type Backend interface {
	// Checkpoint is handed the serialized state by State.Unlock; a
	// non-nil error makes Unlock retry the checkpoint.
	Checkpoint(data []byte) error
	// EnsureBefore receives the duration given to State.EnsureBefore.
	EnsureBefore(d time.Duration)
	// TODO: take flags to ask for reboot vs restart?
	// RequestRestart receives the restart type given to
	// State.RequestRestart.
	RequestRestart(t RestartType)
}
    45  
// customData holds the key/value entries of the state, each value kept
// in its raw JSON-encoded form.
type customData map[string]*json.RawMessage
    47  
    48  func (data customData) get(key string, value interface{}) error {
    49  	entryJSON := data[key]
    50  	if entryJSON == nil {
    51  		return ErrNoState
    52  	}
    53  	err := json.Unmarshal(*entryJSON, value)
    54  	if err != nil {
    55  		return fmt.Errorf("internal error: could not unmarshal state entry %q: %v", key, err)
    56  	}
    57  	return nil
    58  }
    59  
    60  func (data customData) has(key string) bool {
    61  	return data[key] != nil
    62  }
    63  
    64  func (data customData) set(key string, value interface{}) {
    65  	if value == nil {
    66  		delete(data, key)
    67  		return
    68  	}
    69  	serialized, err := json.Marshal(value)
    70  	if err != nil {
    71  		logger.Panicf("internal error: could not marshal value for state entry %q: %v", key, err)
    72  	}
    73  	entryJSON := json.RawMessage(serialized)
    74  	data[key] = &entryJSON
    75  }
    76  
// RestartType is the kind of restart asked for via State.RequestRestart.
type RestartType int

const (
	// RestartUnset is the zero value; State.Restarting reports false
	// while this is the recorded type.
	RestartUnset RestartType = iota
	// RestartDaemon asks for the daemon to be restarted
	// (NOTE(review): exact handling is up to the Backend — confirm).
	RestartDaemon
	// RestartSystem asks for a system restart; RequestRestart records
	// the current boot ID in the state for it.
	RestartSystem
	// RestartSystemNow is like RestartSystem but action is immediate
	RestartSystemNow
	// RestartSocket will restart the daemon so that it goes into
	// socket activation mode.
	RestartSocket
	// StopDaemon just stops the daemon (used with image pre-seeding)
	StopDaemon
	// RestartSystemHaltNow will shutdown --halt the system asap
	RestartSystemHaltNow
	// RestartSystemPoweroffNow will shutdown --poweroff the system asap
	RestartSystemPoweroffNow
)
    95  
// State represents an evolving system state that persists across restarts.
//
// The State is concurrency-safe, and all reads and writes to it must be
// performed with the state locked. It's a runtime error (panic) to perform
// operations without it.
//
// The state is persisted on every unlock operation via the Backend
// it was initialized with.
type State struct {
	// mu is the state lock taken by Lock and released by Unlock/unlock.
	mu sync.Mutex
	// muC is incremented after mu is acquired and decremented before it
	// is released; reading() and writing() panic unless it equals 1.
	muC int32

	// last*Id are the most recently handed out task, change and lane
	// numbers; New* methods increment them to derive the next one.
	lastTaskId   int
	lastChangeId int
	lastLaneId   int

	// backend receives checkpoints and ensure/restart requests; may be nil.
	backend Backend
	// data holds the custom entries managed through Get and Set.
	data customData
	// changes and tasks are keyed by their string IDs.
	changes map[string]*Change
	tasks   map[string]*Task
	// warnings is pruned of expired entries by Prune.
	warnings map[string]*Warning

	// modified is set by writing() and cleared by Unlock after a
	// successful checkpoint.
	modified bool

	// cache holds values stored via Cache; it is not serialized.
	cache map[interface{}]interface{}

	// restarting is the last requested restart type, guarded by restartLck.
	restarting RestartType
	restartLck sync.Mutex
	// bootID is the current boot ID as provided to VerifyReboot.
	bootID string
}
   126  
   127  // New returns a new empty state.
   128  func New(backend Backend) *State {
   129  	return &State{
   130  		backend:  backend,
   131  		data:     make(customData),
   132  		changes:  make(map[string]*Change),
   133  		tasks:    make(map[string]*Task),
   134  		warnings: make(map[string]*Warning),
   135  		modified: true,
   136  		cache:    make(map[interface{}]interface{}),
   137  	}
   138  }
   139  
   140  // Modified returns whether the state was modified since the last checkpoint.
   141  func (s *State) Modified() bool {
   142  	return s.modified
   143  }
   144  
   145  // Lock acquires the state lock.
   146  func (s *State) Lock() {
   147  	s.mu.Lock()
   148  	atomic.AddInt32(&s.muC, 1)
   149  }
   150  
   151  func (s *State) reading() {
   152  	if atomic.LoadInt32(&s.muC) != 1 {
   153  		panic("internal error: accessing state without lock")
   154  	}
   155  }
   156  
   157  func (s *State) writing() {
   158  	s.modified = true
   159  	if atomic.LoadInt32(&s.muC) != 1 {
   160  		panic("internal error: accessing state without lock")
   161  	}
   162  }
   163  
   164  func (s *State) unlock() {
   165  	atomic.AddInt32(&s.muC, -1)
   166  	s.mu.Unlock()
   167  }
   168  
// marshalledState is the JSON shape of State as produced by MarshalJSON
// and consumed by UnmarshalJSON.
type marshalledState struct {
	Data     map[string]*json.RawMessage `json:"data"`
	Changes  map[string]*Change          `json:"changes"`
	Tasks    map[string]*Task            `json:"tasks"`
	// Warnings is the flattened (slice) form of State.warnings.
	Warnings []*Warning                  `json:"warnings,omitempty"`

	LastChangeId int `json:"last-change-id"`
	LastTaskId   int `json:"last-task-id"`
	LastLaneId   int `json:"last-lane-id"`
}
   179  
   180  // MarshalJSON makes State a json.Marshaller
   181  func (s *State) MarshalJSON() ([]byte, error) {
   182  	s.reading()
   183  	return json.Marshal(marshalledState{
   184  		Data:     s.data,
   185  		Changes:  s.changes,
   186  		Tasks:    s.tasks,
   187  		Warnings: s.flattenWarnings(),
   188  
   189  		LastTaskId:   s.lastTaskId,
   190  		LastChangeId: s.lastChangeId,
   191  		LastLaneId:   s.lastLaneId,
   192  	})
   193  }
   194  
   195  // UnmarshalJSON makes State a json.Unmarshaller
   196  func (s *State) UnmarshalJSON(data []byte) error {
   197  	s.writing()
   198  	var unmarshalled marshalledState
   199  	err := json.Unmarshal(data, &unmarshalled)
   200  	if err != nil {
   201  		return err
   202  	}
   203  	s.data = unmarshalled.Data
   204  	s.changes = unmarshalled.Changes
   205  	s.tasks = unmarshalled.Tasks
   206  	s.unflattenWarnings(unmarshalled.Warnings)
   207  	s.lastChangeId = unmarshalled.LastChangeId
   208  	s.lastTaskId = unmarshalled.LastTaskId
   209  	s.lastLaneId = unmarshalled.LastLaneId
   210  	// backlink state again
   211  	for _, t := range s.tasks {
   212  		t.state = s
   213  	}
   214  	for _, chg := range s.changes {
   215  		chg.state = s
   216  		chg.finishUnmarshal()
   217  	}
   218  	return nil
   219  }
   220  
   221  func (s *State) checkpointData() []byte {
   222  	data, err := json.Marshal(s)
   223  	if err != nil {
   224  		// this shouldn't happen, because the actual delicate serializing happens at various Set()s
   225  		logger.Panicf("internal error: could not marshal state for checkpointing: %v", err)
   226  	}
   227  	return data
   228  }
   229  
// unlock checkpoint retry parameters (5 mins of retries by default)
var (
	// unlockCheckpointRetryMaxTime bounds how long Unlock keeps retrying
	// a failing checkpoint before it panics.
	unlockCheckpointRetryMaxTime  = 5 * time.Minute
	// unlockCheckpointRetryInterval is the pause between two attempts.
	unlockCheckpointRetryInterval = 3 * time.Second
)
   235  
   236  // Unlock releases the state lock and checkpoints the state.
   237  // It does not return until the state is correctly checkpointed.
   238  // After too many unsuccessful checkpoint attempts, it panics.
   239  func (s *State) Unlock() {
   240  	defer s.unlock()
   241  
   242  	if !s.modified || s.backend == nil {
   243  		return
   244  	}
   245  
   246  	data := s.checkpointData()
   247  	var err error
   248  	start := time.Now()
   249  	for time.Since(start) <= unlockCheckpointRetryMaxTime {
   250  		if err = s.backend.Checkpoint(data); err == nil {
   251  			s.modified = false
   252  			return
   253  		}
   254  		time.Sleep(unlockCheckpointRetryInterval)
   255  	}
   256  	logger.Panicf("cannot checkpoint even after %v of retries every %v: %v", unlockCheckpointRetryMaxTime, unlockCheckpointRetryInterval, err)
   257  }
   258  
   259  // EnsureBefore asks for an ensure pass to happen sooner within duration from now.
   260  func (s *State) EnsureBefore(d time.Duration) {
   261  	if s.backend != nil {
   262  		s.backend.EnsureBefore(d)
   263  	}
   264  }
   265  
   266  // RequestRestart asks for a restart of the managing process.
   267  // The state needs to be locked to request a RestartSystem.
   268  func (s *State) RequestRestart(t RestartType) {
   269  	if s.backend != nil {
   270  		switch t {
   271  		case RestartSystem, RestartSystemNow, RestartSystemHaltNow, RestartSystemPoweroffNow:
   272  			if s.bootID == "" {
   273  				panic("internal error: cannot request a system restart if current boot ID was not provided via VerifyReboot")
   274  			}
   275  			s.Set("system-restart-from-boot-id", s.bootID)
   276  		}
   277  		s.restartLck.Lock()
   278  		s.restarting = t
   279  		s.restartLck.Unlock()
   280  		s.backend.RequestRestart(t)
   281  	}
   282  }
   283  
   284  // Restarting returns whether a restart was requested with RequestRestart and of which type.
   285  func (s *State) Restarting() (bool, RestartType) {
   286  	s.restartLck.Lock()
   287  	defer s.restartLck.Unlock()
   288  	return s.restarting != RestartUnset, s.restarting
   289  }
   290  
// ErrExpectedReboot is returned by VerifyReboot when the boot ID recorded
// for a requested system restart still matches the current boot ID.
var ErrExpectedReboot = errors.New("expected reboot did not happen")
   292  
   293  // VerifyReboot checks if the state remembers that a system restart was
   294  // requested and whether it succeeded based on the provided current
   295  // boot id.  It returns ErrExpectedReboot if the expected reboot did
   296  // not happen yet.  It must be called early in the usage of state and
   297  // before an RequestRestart with RestartSystem is attempted.
   298  // It must be called with the state lock held.
   299  func (s *State) VerifyReboot(curBootID string) error {
   300  	var fromBootID string
   301  	err := s.Get("system-restart-from-boot-id", &fromBootID)
   302  	if err != nil && err != ErrNoState {
   303  		return err
   304  	}
   305  	s.bootID = curBootID
   306  	if fromBootID == "" {
   307  		return nil
   308  	}
   309  	if fromBootID == curBootID {
   310  		return ErrExpectedReboot
   311  	}
   312  	// we rebooted alright
   313  	s.ClearReboot()
   314  	return nil
   315  }
   316  
   317  // ClearReboot clears state information about tracking requested reboots.
   318  func (s *State) ClearReboot() {
   319  	s.Set("system-restart-from-boot-id", nil)
   320  }
   321  
   322  func MockRestarting(s *State, restarting RestartType) RestartType {
   323  	s.restartLck.Lock()
   324  	defer s.restartLck.Unlock()
   325  	old := s.restarting
   326  	s.restarting = restarting
   327  	return old
   328  }
   329  
// ErrNoState represents the case of no state entry for a given key.
// It is what Get returns when nothing is stored under the key.
var ErrNoState = errors.New("no state entry for key")
   332  
   333  // Get unmarshals the stored value associated with the provided key
   334  // into the value parameter.
   335  // It returns ErrNoState if there is no entry for key.
   336  func (s *State) Get(key string, value interface{}) error {
   337  	s.reading()
   338  	return s.data.get(key, value)
   339  }
   340  
   341  // Set associates value with key for future consulting by managers.
   342  // The provided value must properly marshal and unmarshal with encoding/json.
   343  func (s *State) Set(key string, value interface{}) {
   344  	s.writing()
   345  	s.data.set(key, value)
   346  }
   347  
   348  // Cached returns the cached value associated with the provided key.
   349  // It returns nil if there is no entry for key.
   350  func (s *State) Cached(key interface{}) interface{} {
   351  	s.reading()
   352  	return s.cache[key]
   353  }
   354  
   355  // Cache associates value with key for future consulting by managers.
   356  // The cached value is not persisted.
   357  func (s *State) Cache(key, value interface{}) {
   358  	s.reading() // Doesn't touch persisted data.
   359  	if value == nil {
   360  		delete(s.cache, key)
   361  	} else {
   362  		s.cache[key] = value
   363  	}
   364  }
   365  
   366  // NewChange adds a new change to the state.
   367  func (s *State) NewChange(kind, summary string) *Change {
   368  	s.writing()
   369  	s.lastChangeId++
   370  	id := strconv.Itoa(s.lastChangeId)
   371  	chg := newChange(s, id, kind, summary)
   372  	s.changes[id] = chg
   373  	return chg
   374  }
   375  
   376  // NewLane creates a new lane in the state.
   377  func (s *State) NewLane() int {
   378  	s.writing()
   379  	s.lastLaneId++
   380  	return s.lastLaneId
   381  }
   382  
   383  // Changes returns all changes currently known to the state.
   384  func (s *State) Changes() []*Change {
   385  	s.reading()
   386  	res := make([]*Change, 0, len(s.changes))
   387  	for _, chg := range s.changes {
   388  		res = append(res, chg)
   389  	}
   390  	return res
   391  }
   392  
   393  // Change returns the change for the given ID.
   394  func (s *State) Change(id string) *Change {
   395  	s.reading()
   396  	return s.changes[id]
   397  }
   398  
   399  // NewTask creates a new task.
   400  // It usually will be registered with a Change using AddTask or
   401  // through a TaskSet.
   402  func (s *State) NewTask(kind, summary string) *Task {
   403  	s.writing()
   404  	s.lastTaskId++
   405  	id := strconv.Itoa(s.lastTaskId)
   406  	t := newTask(s, id, kind, summary)
   407  	s.tasks[id] = t
   408  	return t
   409  }
   410  
   411  // Tasks returns all tasks currently known to the state and linked to changes.
   412  func (s *State) Tasks() []*Task {
   413  	s.reading()
   414  	res := make([]*Task, 0, len(s.tasks))
   415  	for _, t := range s.tasks {
   416  		if t.Change() == nil { // skip unlinked tasks
   417  			continue
   418  		}
   419  		res = append(res, t)
   420  	}
   421  	return res
   422  }
   423  
   424  // Task returns the task for the given ID if the task has been linked to a change.
   425  func (s *State) Task(id string) *Task {
   426  	s.reading()
   427  	t := s.tasks[id]
   428  	if t == nil || t.Change() == nil {
   429  		return nil
   430  	}
   431  	return t
   432  }
   433  
   434  // TaskCount returns the number of tasks that currently exist in the state,
   435  // whether linked to a change or not.
   436  func (s *State) TaskCount() int {
   437  	s.reading()
   438  	return len(s.tasks)
   439  }
   440  
   441  func (s *State) tasksIn(tids []string) []*Task {
   442  	res := make([]*Task, len(tids))
   443  	for i, tid := range tids {
   444  		res[i] = s.tasks[tid]
   445  	}
   446  	return res
   447  }
   448  
   449  // Prune does several cleanup tasks to the in-memory state:
   450  //
   451  //  * it removes changes that became ready for more than pruneWait and aborts
   452  //    tasks spawned for more than abortWait.
   453  //
   454  //  * it removes tasks unlinked to changes after pruneWait. When there are more
   455  //    changes than the limit set via "maxReadyChanges" those changes in ready
   456  //    state will also removed even if they are below the pruneWait duration.
   457  //
   458  //  * it removes expired warnings.
   459  func (s *State) Prune(startOfOperation time.Time, pruneWait, abortWait time.Duration, maxReadyChanges int) {
   460  	now := time.Now()
   461  	pruneLimit := now.Add(-pruneWait)
   462  	abortLimit := now.Add(-abortWait)
   463  
   464  	// sort from oldest to newest
   465  	changes := s.Changes()
   466  	sort.Sort(byReadyTime(changes))
   467  
   468  	readyChangesCount := 0
   469  	for i := range changes {
   470  		// changes are sorted (not-ready sorts first)
   471  		// so we know we can iterate in reverse and break once we
   472  		// find a ready time of "zero"
   473  		chg := changes[len(changes)-i-1]
   474  		if chg.ReadyTime().IsZero() {
   475  			break
   476  		}
   477  		readyChangesCount++
   478  	}
   479  
   480  	for k, w := range s.warnings {
   481  		if w.ExpiredBefore(now) {
   482  			delete(s.warnings, k)
   483  		}
   484  	}
   485  
   486  	for _, chg := range changes {
   487  		readyTime := chg.ReadyTime()
   488  		spawnTime := chg.SpawnTime()
   489  		if spawnTime.Before(startOfOperation) {
   490  			spawnTime = startOfOperation
   491  		}
   492  		if readyTime.IsZero() {
   493  			if spawnTime.Before(pruneLimit) && len(chg.Tasks()) == 0 {
   494  				chg.Abort()
   495  				delete(s.changes, chg.ID())
   496  			} else if spawnTime.Before(abortLimit) {
   497  				chg.Abort()
   498  			}
   499  			continue
   500  		}
   501  		// change old or we have too many changes
   502  		if readyTime.Before(pruneLimit) || readyChangesCount > maxReadyChanges {
   503  			s.writing()
   504  			for _, t := range chg.Tasks() {
   505  				delete(s.tasks, t.ID())
   506  			}
   507  			delete(s.changes, chg.ID())
   508  			readyChangesCount--
   509  		}
   510  	}
   511  
   512  	for tid, t := range s.tasks {
   513  		// TODO: this could be done more aggressively
   514  		if t.Change() == nil && t.SpawnTime().Before(pruneLimit) {
   515  			s.writing()
   516  			delete(s.tasks, tid)
   517  		}
   518  	}
   519  }
   520  
   521  // GetMaybeTimings implements timings.GetSaver
   522  func (s *State) GetMaybeTimings(timings interface{}) error {
   523  	err := s.Get("timings", timings)
   524  	if err != nil && err != ErrNoState {
   525  		return err
   526  	}
   527  	return nil
   528  }
   529  
   530  // SaveTimings implements timings.GetSaver
   531  func (s *State) SaveTimings(timings interface{}) {
   532  	s.Set("timings", timings)
   533  }
   534  
   535  // ReadState returns the state deserialized from r.
   536  func ReadState(backend Backend, r io.Reader) (*State, error) {
   537  	s := new(State)
   538  	s.Lock()
   539  	defer s.unlock()
   540  	d := json.NewDecoder(r)
   541  	err := d.Decode(&s)
   542  	if err != nil {
   543  		return nil, fmt.Errorf("cannot read state: %s", err)
   544  	}
   545  	s.backend = backend
   546  	s.modified = false
   547  	s.cache = make(map[interface{}]interface{})
   548  	return s, err
   549  }