github.com/rigado/snapd@v2.42.5-go-mod+incompatible/overlord/state/state.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2016 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  // Package state implements the representation of system state.
    21  package state
    22  
    23  import (
    24  	"encoding/json"
    25  	"errors"
    26  	"fmt"
    27  	"io"
    28  	"sort"
    29  	"strconv"
    30  	"sync"
    31  	"sync/atomic"
    32  	"time"
    33  
    34  	"github.com/snapcore/snapd/logger"
    35  )
    36  
    37  // A Backend is used by State to checkpoint on every unlock operation
    38  // and to mediate requests to ensure the state sooner or request restarts.
    39  type Backend interface {
    40  	Checkpoint(data []byte) error
    41  	EnsureBefore(d time.Duration)
    42  	// TODO: take flags to ask for reboot vs restart?
    43  	RequestRestart(t RestartType)
    44  }
    45  
    46  type customData map[string]*json.RawMessage
    47  
    48  func (data customData) get(key string, value interface{}) error {
    49  	entryJSON := data[key]
    50  	if entryJSON == nil {
    51  		return ErrNoState
    52  	}
    53  	err := json.Unmarshal(*entryJSON, value)
    54  	if err != nil {
    55  		return fmt.Errorf("internal error: could not unmarshal state entry %q: %v", key, err)
    56  	}
    57  	return nil
    58  }
    59  
    60  func (data customData) has(key string) bool {
    61  	return data[key] != nil
    62  }
    63  
    64  func (data customData) set(key string, value interface{}) {
    65  	if value == nil {
    66  		delete(data, key)
    67  		return
    68  	}
    69  	serialized, err := json.Marshal(value)
    70  	if err != nil {
    71  		logger.Panicf("internal error: could not marshal value for state entry %q: %v", key, err)
    72  	}
    73  	entryJSON := json.RawMessage(serialized)
    74  	data[key] = &entryJSON
    75  }
    76  
    77  type RestartType int
    78  
    79  const (
    80  	RestartUnset RestartType = iota
    81  	RestartDaemon
    82  	RestartSystem
    83  	// RestartSocket will restart the daemon so that it goes into
    84  	// socket activation mode.
    85  	RestartSocket
    86  )
    87  
    88  // State represents an evolving system state that persists across restarts.
    89  //
    90  // The State is concurrency-safe, and all reads and writes to it must be
    91  // performed with the state locked. It's a runtime error (panic) to perform
    92  // operations without it.
    93  //
    94  // The state is persisted on every unlock operation via the StateBackend
    95  // it was initialized with.
    96  type State struct {
    97  	mu  sync.Mutex
    98  	muC int32
    99  
   100  	lastTaskId   int
   101  	lastChangeId int
   102  	lastLaneId   int
   103  
   104  	backend  Backend
   105  	data     customData
   106  	changes  map[string]*Change
   107  	tasks    map[string]*Task
   108  	warnings map[string]*Warning
   109  
   110  	modified bool
   111  
   112  	cache map[interface{}]interface{}
   113  
   114  	restarting RestartType
   115  	restartLck sync.Mutex
   116  	bootID     string
   117  }
   118  
   119  // New returns a new empty state.
   120  func New(backend Backend) *State {
   121  	return &State{
   122  		backend:  backend,
   123  		data:     make(customData),
   124  		changes:  make(map[string]*Change),
   125  		tasks:    make(map[string]*Task),
   126  		warnings: make(map[string]*Warning),
   127  		modified: true,
   128  		cache:    make(map[interface{}]interface{}),
   129  	}
   130  }
   131  
   132  // Modified returns whether the state was modified since the last checkpoint.
   133  func (s *State) Modified() bool {
   134  	return s.modified
   135  }
   136  
   137  // Lock acquires the state lock.
   138  func (s *State) Lock() {
   139  	s.mu.Lock()
   140  	atomic.AddInt32(&s.muC, 1)
   141  }
   142  
   143  func (s *State) reading() {
   144  	if atomic.LoadInt32(&s.muC) != 1 {
   145  		panic("internal error: accessing state without lock")
   146  	}
   147  }
   148  
   149  func (s *State) writing() {
   150  	s.modified = true
   151  	if atomic.LoadInt32(&s.muC) != 1 {
   152  		panic("internal error: accessing state without lock")
   153  	}
   154  }
   155  
   156  func (s *State) unlock() {
   157  	atomic.AddInt32(&s.muC, -1)
   158  	s.mu.Unlock()
   159  }
   160  
// marshalledState is the serialized (JSON) form of State used by
// MarshalJSON and UnmarshalJSON. Its field order determines the key order
// of the encoded output.
type marshalledState struct {
	// Data holds the custom key/value entries managed via State.Get/Set.
	Data     map[string]*json.RawMessage `json:"data"`
	Changes  map[string]*Change          `json:"changes"`
	Tasks    map[string]*Task            `json:"tasks"`
	Warnings []*Warning                  `json:"warnings,omitempty"`

	// Last identifiers handed out, so new changes, tasks and lanes keep
	// getting fresh IDs after a reload.
	LastChangeId int `json:"last-change-id"`
	LastTaskId   int `json:"last-task-id"`
	LastLaneId   int `json:"last-lane-id"`
}
   171  
   172  // MarshalJSON makes State a json.Marshaller
   173  func (s *State) MarshalJSON() ([]byte, error) {
   174  	s.reading()
   175  	return json.Marshal(marshalledState{
   176  		Data:     s.data,
   177  		Changes:  s.changes,
   178  		Tasks:    s.tasks,
   179  		Warnings: s.flattenWarnings(),
   180  
   181  		LastTaskId:   s.lastTaskId,
   182  		LastChangeId: s.lastChangeId,
   183  		LastLaneId:   s.lastLaneId,
   184  	})
   185  }
   186  
   187  // UnmarshalJSON makes State a json.Unmarshaller
   188  func (s *State) UnmarshalJSON(data []byte) error {
   189  	s.writing()
   190  	var unmarshalled marshalledState
   191  	err := json.Unmarshal(data, &unmarshalled)
   192  	if err != nil {
   193  		return err
   194  	}
   195  	s.data = unmarshalled.Data
   196  	s.changes = unmarshalled.Changes
   197  	s.tasks = unmarshalled.Tasks
   198  	s.unflattenWarnings(unmarshalled.Warnings)
   199  	s.lastChangeId = unmarshalled.LastChangeId
   200  	s.lastTaskId = unmarshalled.LastTaskId
   201  	s.lastLaneId = unmarshalled.LastLaneId
   202  	// backlink state again
   203  	for _, t := range s.tasks {
   204  		t.state = s
   205  	}
   206  	for _, chg := range s.changes {
   207  		chg.state = s
   208  		chg.finishUnmarshal()
   209  	}
   210  	return nil
   211  }
   212  
   213  func (s *State) checkpointData() []byte {
   214  	data, err := json.Marshal(s)
   215  	if err != nil {
   216  		// this shouldn't happen, because the actual delicate serializing happens at various Set()s
   217  		logger.Panicf("internal error: could not marshal state for checkpointing: %v", err)
   218  	}
   219  	return data
   220  }
   221  
   222  // unlock checkpoint retry parameters (5 mins of retries by default)
   223  var (
   224  	unlockCheckpointRetryMaxTime  = 5 * time.Minute
   225  	unlockCheckpointRetryInterval = 3 * time.Second
   226  )
   227  
   228  // Unlock releases the state lock and checkpoints the state.
   229  // It does not return until the state is correctly checkpointed.
   230  // After too many unsuccessful checkpoint attempts, it panics.
   231  func (s *State) Unlock() {
   232  	defer s.unlock()
   233  
   234  	if !s.modified || s.backend == nil {
   235  		return
   236  	}
   237  
   238  	data := s.checkpointData()
   239  	var err error
   240  	start := time.Now()
   241  	for time.Since(start) <= unlockCheckpointRetryMaxTime {
   242  		if err = s.backend.Checkpoint(data); err == nil {
   243  			s.modified = false
   244  			return
   245  		}
   246  		time.Sleep(unlockCheckpointRetryInterval)
   247  	}
   248  	logger.Panicf("cannot checkpoint even after %v of retries every %v: %v", unlockCheckpointRetryMaxTime, unlockCheckpointRetryInterval, err)
   249  }
   250  
   251  // EnsureBefore asks for an ensure pass to happen sooner within duration from now.
   252  func (s *State) EnsureBefore(d time.Duration) {
   253  	if s.backend != nil {
   254  		s.backend.EnsureBefore(d)
   255  	}
   256  }
   257  
   258  // RequestRestart asks for a restart of the managing process.
   259  // The state needs to be locked to request a RestartSystem.
   260  func (s *State) RequestRestart(t RestartType) {
   261  	if s.backend != nil {
   262  		if t == RestartSystem {
   263  			if s.bootID == "" {
   264  				panic("internal error: cannot request a system restart if current boot ID was not provided via VerifyReboot")
   265  			}
   266  			s.Set("system-restart-from-boot-id", s.bootID)
   267  		}
   268  		s.restartLck.Lock()
   269  		s.restarting = t
   270  		s.restartLck.Unlock()
   271  		s.backend.RequestRestart(t)
   272  	}
   273  }
   274  
   275  // Restarting returns whether a restart was requested with RequestRestart and of which type.
   276  func (s *State) Restarting() (bool, RestartType) {
   277  	s.restartLck.Lock()
   278  	defer s.restartLck.Unlock()
   279  	return s.restarting != RestartUnset, s.restarting
   280  }
   281  
   282  var ErrExpectedReboot = errors.New("expected reboot did not happen")
   283  
   284  // VerifyReboot checks if the state rembers that a system restart was
   285  // requested and whether it succeeded based on the provided current
   286  // boot id.  It returns ErrExpectedReboot if the expected reboot did
   287  // not happen yet.  It must be called early in the usage of state and
   288  // before an RequestRestart with RestartSystem is attempted.
   289  // It must be called with the state lock held.
   290  func (s *State) VerifyReboot(curBootID string) error {
   291  	var fromBootID string
   292  	err := s.Get("system-restart-from-boot-id", &fromBootID)
   293  	if err != nil && err != ErrNoState {
   294  		return err
   295  	}
   296  	s.bootID = curBootID
   297  	if fromBootID == "" {
   298  		return nil
   299  	}
   300  	if fromBootID == curBootID {
   301  		return ErrExpectedReboot
   302  	}
   303  	// we rebooted alright
   304  	s.ClearReboot()
   305  	return nil
   306  }
   307  
   308  // ClearReboot clears state information about tracking requested reboots.
   309  func (s *State) ClearReboot() {
   310  	s.Set("system-restart-from-boot-id", nil)
   311  }
   312  
   313  func MockRestarting(s *State, restarting RestartType) RestartType {
   314  	s.restartLck.Lock()
   315  	defer s.restartLck.Unlock()
   316  	old := s.restarting
   317  	s.restarting = restarting
   318  	return old
   319  }
   320  
   321  // ErrNoState represents the case of no state entry for a given key.
   322  var ErrNoState = errors.New("no state entry for key")
   323  
   324  // Get unmarshals the stored value associated with the provided key
   325  // into the value parameter.
   326  // It returns ErrNoState if there is no entry for key.
   327  func (s *State) Get(key string, value interface{}) error {
   328  	s.reading()
   329  	return s.data.get(key, value)
   330  }
   331  
   332  // Set associates value with key for future consulting by managers.
   333  // The provided value must properly marshal and unmarshal with encoding/json.
   334  func (s *State) Set(key string, value interface{}) {
   335  	s.writing()
   336  	s.data.set(key, value)
   337  }
   338  
   339  // Cached returns the cached value associated with the provided key.
   340  // It returns nil if there is no entry for key.
   341  func (s *State) Cached(key interface{}) interface{} {
   342  	s.reading()
   343  	return s.cache[key]
   344  }
   345  
   346  // Cache associates value with key for future consulting by managers.
   347  // The cached value is not persisted.
   348  func (s *State) Cache(key, value interface{}) {
   349  	s.reading() // Doesn't touch persisted data.
   350  	if value == nil {
   351  		delete(s.cache, key)
   352  	} else {
   353  		s.cache[key] = value
   354  	}
   355  }
   356  
   357  // NewChange adds a new change to the state.
   358  func (s *State) NewChange(kind, summary string) *Change {
   359  	s.writing()
   360  	s.lastChangeId++
   361  	id := strconv.Itoa(s.lastChangeId)
   362  	chg := newChange(s, id, kind, summary)
   363  	s.changes[id] = chg
   364  	return chg
   365  }
   366  
   367  // NewLane creates a new lane in the state.
   368  func (s *State) NewLane() int {
   369  	s.writing()
   370  	s.lastLaneId++
   371  	return s.lastLaneId
   372  }
   373  
   374  // Changes returns all changes currently known to the state.
   375  func (s *State) Changes() []*Change {
   376  	s.reading()
   377  	res := make([]*Change, 0, len(s.changes))
   378  	for _, chg := range s.changes {
   379  		res = append(res, chg)
   380  	}
   381  	return res
   382  }
   383  
   384  // Change returns the change for the given ID.
   385  func (s *State) Change(id string) *Change {
   386  	s.reading()
   387  	return s.changes[id]
   388  }
   389  
   390  // NewTask creates a new task.
   391  // It usually will be registered with a Change using AddTask or
   392  // through a TaskSet.
   393  func (s *State) NewTask(kind, summary string) *Task {
   394  	s.writing()
   395  	s.lastTaskId++
   396  	id := strconv.Itoa(s.lastTaskId)
   397  	t := newTask(s, id, kind, summary)
   398  	s.tasks[id] = t
   399  	return t
   400  }
   401  
   402  // Tasks returns all tasks currently known to the state and linked to changes.
   403  func (s *State) Tasks() []*Task {
   404  	s.reading()
   405  	res := make([]*Task, 0, len(s.tasks))
   406  	for _, t := range s.tasks {
   407  		if t.Change() == nil { // skip unlinked tasks
   408  			continue
   409  		}
   410  		res = append(res, t)
   411  	}
   412  	return res
   413  }
   414  
   415  // Task returns the task for the given ID if the task has been linked to a change.
   416  func (s *State) Task(id string) *Task {
   417  	s.reading()
   418  	t := s.tasks[id]
   419  	if t == nil || t.Change() == nil {
   420  		return nil
   421  	}
   422  	return t
   423  }
   424  
   425  // TaskCount returns the number of tasks that currently exist in the state,
   426  // whether linked to a change or not.
   427  func (s *State) TaskCount() int {
   428  	s.reading()
   429  	return len(s.tasks)
   430  }
   431  
   432  func (s *State) tasksIn(tids []string) []*Task {
   433  	res := make([]*Task, len(tids))
   434  	for i, tid := range tids {
   435  		res[i] = s.tasks[tid]
   436  	}
   437  	return res
   438  }
   439  
   440  // Prune does several cleanup tasks to the in-memory state:
   441  //
   442  //  * it removes changes that became ready for more than pruneWait and aborts
   443  //    tasks spawned for more than abortWait.
   444  //
   445  //  * it removes tasks unlinked to changes after pruneWait. When there are more
   446  //    changes than the limit set via "maxReadyChanges" those changes in ready
   447  //    state will also removed even if they are below the pruneWait duration.
   448  //
   449  //  * it removes expired warnings.
   450  func (s *State) Prune(pruneWait, abortWait time.Duration, maxReadyChanges int) {
   451  	now := time.Now()
   452  	pruneLimit := now.Add(-pruneWait)
   453  	abortLimit := now.Add(-abortWait)
   454  
   455  	// sort from oldest to newest
   456  	changes := s.Changes()
   457  	sort.Sort(byReadyTime(changes))
   458  
   459  	readyChangesCount := 0
   460  	for i := range changes {
   461  		// changes are sorted (not-ready sorts first)
   462  		// so we know we can iterate in reverse and break once we
   463  		// find a ready time of "zero"
   464  		chg := changes[len(changes)-i-1]
   465  		if chg.ReadyTime().IsZero() {
   466  			break
   467  		}
   468  		readyChangesCount++
   469  	}
   470  
   471  	for k, w := range s.warnings {
   472  		if w.ExpiredBefore(now) {
   473  			delete(s.warnings, k)
   474  		}
   475  	}
   476  
   477  	for _, chg := range changes {
   478  		spawnTime := chg.SpawnTime()
   479  		readyTime := chg.ReadyTime()
   480  		if readyTime.IsZero() {
   481  			if spawnTime.Before(pruneLimit) && len(chg.Tasks()) == 0 {
   482  				chg.Abort()
   483  				delete(s.changes, chg.ID())
   484  			} else if spawnTime.Before(abortLimit) {
   485  				chg.Abort()
   486  			}
   487  			continue
   488  		}
   489  		// change old or we have too many changes
   490  		if readyTime.Before(pruneLimit) || readyChangesCount > maxReadyChanges {
   491  			s.writing()
   492  			for _, t := range chg.Tasks() {
   493  				delete(s.tasks, t.ID())
   494  			}
   495  			delete(s.changes, chg.ID())
   496  			readyChangesCount--
   497  		}
   498  	}
   499  
   500  	for tid, t := range s.tasks {
   501  		// TODO: this could be done more aggressively
   502  		if t.Change() == nil && t.SpawnTime().Before(pruneLimit) {
   503  			s.writing()
   504  			delete(s.tasks, tid)
   505  		}
   506  	}
   507  }
   508  
   509  // ReadState returns the state deserialized from r.
   510  func ReadState(backend Backend, r io.Reader) (*State, error) {
   511  	s := new(State)
   512  	s.Lock()
   513  	defer s.unlock()
   514  	d := json.NewDecoder(r)
   515  	err := d.Decode(&s)
   516  	if err != nil {
   517  		return nil, fmt.Errorf("cannot read state: %s", err)
   518  	}
   519  	s.backend = backend
   520  	s.modified = false
   521  	s.cache = make(map[interface{}]interface{})
   522  	return s, err
   523  }