github.com/hugh712/snapd@v0.0.0-20200910133618-1a99902bd583/overlord/state/state.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2016-2020 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  // Package state implements the representation of system state.
    21  package state
    22  
    23  import (
    24  	"encoding/json"
    25  	"errors"
    26  	"fmt"
    27  	"io"
    28  	"sort"
    29  	"strconv"
    30  	"sync"
    31  	"sync/atomic"
    32  	"time"
    33  
    34  	"github.com/snapcore/snapd/logger"
    35  )
    36  
// A Backend is used by State to checkpoint on every unlock operation
// and to mediate requests to ensure the state sooner or request restarts.
type Backend interface {
	// Checkpoint is handed the JSON-serialized state by State.Unlock.
	// A non-nil error makes Unlock retry the checkpoint.
	Checkpoint(data []byte) error
	// EnsureBefore receives the duration passed to State.EnsureBefore.
	EnsureBefore(d time.Duration)
	// RequestRestart receives the restart type passed to
	// State.RequestRestart.
	// TODO: take flags to ask for reboot vs restart?
	RequestRestart(t RestartType)
}
    45  
// customData maps state entry keys to their JSON-serialized values.
type customData map[string]*json.RawMessage
    47  
    48  func (data customData) get(key string, value interface{}) error {
    49  	entryJSON := data[key]
    50  	if entryJSON == nil {
    51  		return ErrNoState
    52  	}
    53  	err := json.Unmarshal(*entryJSON, value)
    54  	if err != nil {
    55  		return fmt.Errorf("internal error: could not unmarshal state entry %q: %v", key, err)
    56  	}
    57  	return nil
    58  }
    59  
    60  func (data customData) has(key string) bool {
    61  	return data[key] != nil
    62  }
    63  
    64  func (data customData) set(key string, value interface{}) {
    65  	if value == nil {
    66  		delete(data, key)
    67  		return
    68  	}
    69  	serialized, err := json.Marshal(value)
    70  	if err != nil {
    71  		logger.Panicf("internal error: could not marshal value for state entry %q: %v", key, err)
    72  	}
    73  	entryJSON := json.RawMessage(serialized)
    74  	data[key] = &entryJSON
    75  }
    76  
// RestartType identifies the kind of restart requested through
// State.RequestRestart and reported by State.Restarting.
type RestartType int
    78  
const (
	// RestartUnset is the zero value; State.Restarting reports false
	// while the recorded restart type is RestartUnset.
	RestartUnset RestartType = iota
	// RestartDaemon presumably requests a restart of the daemon only
	// (NOTE(review): inferred from the name, confirm with the backend).
	RestartDaemon
	// RestartSystem requests a system restart; State.RequestRestart
	// records the current boot ID in the state before forwarding it.
	RestartSystem
	// RestartSystemNow is like RestartSystem but action is immediate
	RestartSystemNow
	// RestartSocket will restart the daemon so that it goes into
	// socket activation mode.
	RestartSocket
	// StopDaemon just stops the daemon (used with image pre-seeding)
	StopDaemon
)
    91  
// State represents an evolving system state that persists across restarts.
//
// The State is concurrency-safe, and all reads and writes to it must be
// performed with the state locked. It's a runtime error (panic) to perform
// operations without it.
//
// The state is persisted on every unlock operation via the Backend
// it was initialized with.
type State struct {
	// mu is the state lock taken by Lock and released by unlock.
	mu  sync.Mutex
	// muC is incremented by Lock and decremented by unlock; reading and
	// writing panic unless it is exactly 1.
	muC int32

	// last IDs handed out by NewTask, NewChange and NewLane
	lastTaskId   int
	lastChangeId int
	lastLaneId   int

	// backend may be nil, in which case Unlock does not checkpoint and
	// EnsureBefore/RequestRestart do nothing.
	backend  Backend
	// data holds the persisted key/value entries accessed via Get and Set.
	data     customData
	changes  map[string]*Change
	tasks    map[string]*Task
	warnings map[string]*Warning

	// modified is set by writing and cleared by Unlock after a
	// successful checkpoint.
	modified bool

	// cache holds the values managed via Cache and Cached; it is not
	// part of the marshalled state.
	cache map[interface{}]interface{}

	// restarting is the last restart type requested; it is guarded by
	// restartLck rather than by the state lock.
	restarting RestartType
	restartLck sync.Mutex
	// bootID is the current boot ID as provided to VerifyReboot.
	bootID     string
}
   122  
   123  // New returns a new empty state.
   124  func New(backend Backend) *State {
   125  	return &State{
   126  		backend:  backend,
   127  		data:     make(customData),
   128  		changes:  make(map[string]*Change),
   129  		tasks:    make(map[string]*Task),
   130  		warnings: make(map[string]*Warning),
   131  		modified: true,
   132  		cache:    make(map[interface{}]interface{}),
   133  	}
   134  }
   135  
// Modified returns whether the state was modified since the last checkpoint.
// Unlike most accessors it does not itself verify that the state lock
// is held.
func (s *State) Modified() bool {
	return s.modified
}
   140  
// Lock acquires the state lock.
func (s *State) Lock() {
	s.mu.Lock()
	// bump the counter only once the mutex is held, so that reading
	// and writing can check that the state is locked
	atomic.AddInt32(&s.muC, 1)
}
   146  
   147  func (s *State) reading() {
   148  	if atomic.LoadInt32(&s.muC) != 1 {
   149  		panic("internal error: accessing state without lock")
   150  	}
   151  }
   152  
   153  func (s *State) writing() {
   154  	s.modified = true
   155  	if atomic.LoadInt32(&s.muC) != 1 {
   156  		panic("internal error: accessing state without lock")
   157  	}
   158  }
   159  
// unlock releases the state lock without checkpointing.
func (s *State) unlock() {
	// drop the counter while the mutex is still held, mirroring Lock
	atomic.AddInt32(&s.muC, -1)
	s.mu.Unlock()
}
   164  
// marshalledState is the JSON representation of State that is produced
// by MarshalJSON and consumed by UnmarshalJSON.
type marshalledState struct {
	Data     map[string]*json.RawMessage `json:"data"`
	Changes  map[string]*Change          `json:"changes"`
	Tasks    map[string]*Task            `json:"tasks"`
	// Warnings is the flattened form of the state's warnings map.
	Warnings []*Warning                  `json:"warnings,omitempty"`

	LastChangeId int `json:"last-change-id"`
	LastTaskId   int `json:"last-task-id"`
	LastLaneId   int `json:"last-lane-id"`
}
   175  
   176  // MarshalJSON makes State a json.Marshaller
   177  func (s *State) MarshalJSON() ([]byte, error) {
   178  	s.reading()
   179  	return json.Marshal(marshalledState{
   180  		Data:     s.data,
   181  		Changes:  s.changes,
   182  		Tasks:    s.tasks,
   183  		Warnings: s.flattenWarnings(),
   184  
   185  		LastTaskId:   s.lastTaskId,
   186  		LastChangeId: s.lastChangeId,
   187  		LastLaneId:   s.lastLaneId,
   188  	})
   189  }
   190  
   191  // UnmarshalJSON makes State a json.Unmarshaller
   192  func (s *State) UnmarshalJSON(data []byte) error {
   193  	s.writing()
   194  	var unmarshalled marshalledState
   195  	err := json.Unmarshal(data, &unmarshalled)
   196  	if err != nil {
   197  		return err
   198  	}
   199  	s.data = unmarshalled.Data
   200  	s.changes = unmarshalled.Changes
   201  	s.tasks = unmarshalled.Tasks
   202  	s.unflattenWarnings(unmarshalled.Warnings)
   203  	s.lastChangeId = unmarshalled.LastChangeId
   204  	s.lastTaskId = unmarshalled.LastTaskId
   205  	s.lastLaneId = unmarshalled.LastLaneId
   206  	// backlink state again
   207  	for _, t := range s.tasks {
   208  		t.state = s
   209  	}
   210  	for _, chg := range s.changes {
   211  		chg.state = s
   212  		chg.finishUnmarshal()
   213  	}
   214  	return nil
   215  }
   216  
   217  func (s *State) checkpointData() []byte {
   218  	data, err := json.Marshal(s)
   219  	if err != nil {
   220  		// this shouldn't happen, because the actual delicate serializing happens at various Set()s
   221  		logger.Panicf("internal error: could not marshal state for checkpointing: %v", err)
   222  	}
   223  	return data
   224  }
   225  
// unlock checkpoint retry parameters (5 mins of retries by default)
var (
	// unlockCheckpointRetryMaxTime bounds how long Unlock keeps retrying
	// a failing checkpoint before it panics.
	unlockCheckpointRetryMaxTime  = 5 * time.Minute
	// unlockCheckpointRetryInterval is how long Unlock sleeps after a
	// failed checkpoint attempt.
	unlockCheckpointRetryInterval = 3 * time.Second
)
   231  
   232  // Unlock releases the state lock and checkpoints the state.
   233  // It does not return until the state is correctly checkpointed.
   234  // After too many unsuccessful checkpoint attempts, it panics.
   235  func (s *State) Unlock() {
   236  	defer s.unlock()
   237  
   238  	if !s.modified || s.backend == nil {
   239  		return
   240  	}
   241  
   242  	data := s.checkpointData()
   243  	var err error
   244  	start := time.Now()
   245  	for time.Since(start) <= unlockCheckpointRetryMaxTime {
   246  		if err = s.backend.Checkpoint(data); err == nil {
   247  			s.modified = false
   248  			return
   249  		}
   250  		time.Sleep(unlockCheckpointRetryInterval)
   251  	}
   252  	logger.Panicf("cannot checkpoint even after %v of retries every %v: %v", unlockCheckpointRetryMaxTime, unlockCheckpointRetryInterval, err)
   253  }
   254  
   255  // EnsureBefore asks for an ensure pass to happen sooner within duration from now.
   256  func (s *State) EnsureBefore(d time.Duration) {
   257  	if s.backend != nil {
   258  		s.backend.EnsureBefore(d)
   259  	}
   260  }
   261  
   262  // RequestRestart asks for a restart of the managing process.
   263  // The state needs to be locked to request a RestartSystem.
   264  func (s *State) RequestRestart(t RestartType) {
   265  	if s.backend != nil {
   266  		if t == RestartSystem || t == RestartSystemNow {
   267  			if s.bootID == "" {
   268  				panic("internal error: cannot request a system restart if current boot ID was not provided via VerifyReboot")
   269  			}
   270  			s.Set("system-restart-from-boot-id", s.bootID)
   271  		}
   272  		s.restartLck.Lock()
   273  		s.restarting = t
   274  		s.restartLck.Unlock()
   275  		s.backend.RequestRestart(t)
   276  	}
   277  }
   278  
   279  // Restarting returns whether a restart was requested with RequestRestart and of which type.
   280  func (s *State) Restarting() (bool, RestartType) {
   281  	s.restartLck.Lock()
   282  	defer s.restartLck.Unlock()
   283  	return s.restarting != RestartUnset, s.restarting
   284  }
   285  
// ErrExpectedReboot is returned by VerifyReboot when the state remembers
// a requested system restart but the current boot ID still matches the
// one the restart was requested from.
var ErrExpectedReboot = errors.New("expected reboot did not happen")
   287  
   288  // VerifyReboot checks if the state remembers that a system restart was
   289  // requested and whether it succeeded based on the provided current
   290  // boot id.  It returns ErrExpectedReboot if the expected reboot did
   291  // not happen yet.  It must be called early in the usage of state and
   292  // before an RequestRestart with RestartSystem is attempted.
   293  // It must be called with the state lock held.
   294  func (s *State) VerifyReboot(curBootID string) error {
   295  	var fromBootID string
   296  	err := s.Get("system-restart-from-boot-id", &fromBootID)
   297  	if err != nil && err != ErrNoState {
   298  		return err
   299  	}
   300  	s.bootID = curBootID
   301  	if fromBootID == "" {
   302  		return nil
   303  	}
   304  	if fromBootID == curBootID {
   305  		return ErrExpectedReboot
   306  	}
   307  	// we rebooted alright
   308  	s.ClearReboot()
   309  	return nil
   310  }
   311  
// ClearReboot clears state information about tracking requested reboots.
// It does so by setting the "system-restart-from-boot-id" entry to nil
// via Set, so it requires the state lock to be held.
func (s *State) ClearReboot() {
	s.Set("system-restart-from-boot-id", nil)
}
   316  
   317  func MockRestarting(s *State, restarting RestartType) RestartType {
   318  	s.restartLck.Lock()
   319  	defer s.restartLck.Unlock()
   320  	old := s.restarting
   321  	s.restarting = restarting
   322  	return old
   323  }
   324  
// ErrNoState represents the case of no state entry for a given key.
// It is what Get returns when nothing is stored under the key.
var ErrNoState = errors.New("no state entry for key")
   327  
// Get unmarshals the stored value associated with the provided key
// into the value parameter.
// It returns ErrNoState if there is no entry for key.
// It panics if the state lock is not held.
func (s *State) Get(key string, value interface{}) error {
	s.reading()
	return s.data.get(key, value)
}
   335  
// Set associates value with key for future consulting by managers.
// The provided value must properly marshal and unmarshal with encoding/json.
// Setting a nil value removes the entry for key.
// It panics if the state lock is not held.
func (s *State) Set(key string, value interface{}) {
	s.writing()
	s.data.set(key, value)
}
   342  
// Cached returns the cached value associated with the provided key.
// It returns nil if there is no entry for key.
// It panics if the state lock is not held.
func (s *State) Cached(key interface{}) interface{} {
	s.reading()
	return s.cache[key]
}
   349  
   350  // Cache associates value with key for future consulting by managers.
   351  // The cached value is not persisted.
   352  func (s *State) Cache(key, value interface{}) {
   353  	s.reading() // Doesn't touch persisted data.
   354  	if value == nil {
   355  		delete(s.cache, key)
   356  	} else {
   357  		s.cache[key] = value
   358  	}
   359  }
   360  
   361  // NewChange adds a new change to the state.
   362  func (s *State) NewChange(kind, summary string) *Change {
   363  	s.writing()
   364  	s.lastChangeId++
   365  	id := strconv.Itoa(s.lastChangeId)
   366  	chg := newChange(s, id, kind, summary)
   367  	s.changes[id] = chg
   368  	return chg
   369  }
   370  
   371  // NewLane creates a new lane in the state.
   372  func (s *State) NewLane() int {
   373  	s.writing()
   374  	s.lastLaneId++
   375  	return s.lastLaneId
   376  }
   377  
   378  // Changes returns all changes currently known to the state.
   379  func (s *State) Changes() []*Change {
   380  	s.reading()
   381  	res := make([]*Change, 0, len(s.changes))
   382  	for _, chg := range s.changes {
   383  		res = append(res, chg)
   384  	}
   385  	return res
   386  }
   387  
// Change returns the change for the given ID.
// It returns nil if the state has no change with that ID.
func (s *State) Change(id string) *Change {
	s.reading()
	return s.changes[id]
}
   393  
   394  // NewTask creates a new task.
   395  // It usually will be registered with a Change using AddTask or
   396  // through a TaskSet.
   397  func (s *State) NewTask(kind, summary string) *Task {
   398  	s.writing()
   399  	s.lastTaskId++
   400  	id := strconv.Itoa(s.lastTaskId)
   401  	t := newTask(s, id, kind, summary)
   402  	s.tasks[id] = t
   403  	return t
   404  }
   405  
   406  // Tasks returns all tasks currently known to the state and linked to changes.
   407  func (s *State) Tasks() []*Task {
   408  	s.reading()
   409  	res := make([]*Task, 0, len(s.tasks))
   410  	for _, t := range s.tasks {
   411  		if t.Change() == nil { // skip unlinked tasks
   412  			continue
   413  		}
   414  		res = append(res, t)
   415  	}
   416  	return res
   417  }
   418  
   419  // Task returns the task for the given ID if the task has been linked to a change.
   420  func (s *State) Task(id string) *Task {
   421  	s.reading()
   422  	t := s.tasks[id]
   423  	if t == nil || t.Change() == nil {
   424  		return nil
   425  	}
   426  	return t
   427  }
   428  
// TaskCount returns the number of tasks that currently exist in the state,
// whether linked to a change or not.
// It panics if the state lock is not held.
func (s *State) TaskCount() int {
	s.reading()
	return len(s.tasks)
}
   435  
   436  func (s *State) tasksIn(tids []string) []*Task {
   437  	res := make([]*Task, len(tids))
   438  	for i, tid := range tids {
   439  		res[i] = s.tasks[tid]
   440  	}
   441  	return res
   442  }
   443  
// Prune does several cleanup tasks to the in-memory state:
//
//  * it removes changes, together with their tasks, that became ready
//    more than pruneWait ago. When there are more ready changes than
//    maxReadyChanges, ready changes are also removed even if they became
//    ready less than pruneWait ago, until the count is back at the limit.
//
//  * it aborts changes that are not ready and were spawned more than
//    abortWait ago. Not-ready changes spawned more than pruneWait ago
//    that have no tasks are aborted and removed.
//
//  * it removes tasks not linked to any change that were spawned more
//    than pruneWait ago.
//
//  * it removes expired warnings.
//
// For the purpose of these checks a change spawned before
// startOfOperation is treated as if spawned at startOfOperation.
func (s *State) Prune(startOfOperation time.Time, pruneWait, abortWait time.Duration, maxReadyChanges int) {
	now := time.Now()
	pruneLimit := now.Add(-pruneWait)
	abortLimit := now.Add(-abortWait)

	// sort from oldest to newest
	changes := s.Changes()
	sort.Sort(byReadyTime(changes))

	// count how many changes are ready
	readyChangesCount := 0
	for i := range changes {
		// changes are sorted (not-ready sorts first)
		// so we know we can iterate in reverse and break once we
		// find a ready time of "zero"
		chg := changes[len(changes)-i-1]
		if chg.ReadyTime().IsZero() {
			break
		}
		readyChangesCount++
	}

	// drop warnings that expired before now
	for k, w := range s.warnings {
		if w.ExpiredBefore(now) {
			delete(s.warnings, k)
		}
	}

	for _, chg := range changes {
		readyTime := chg.ReadyTime()
		spawnTime := chg.SpawnTime()
		// never consider a change older than the start of operation
		if spawnTime.Before(startOfOperation) {
			spawnTime = startOfOperation
		}
		if readyTime.IsZero() {
			// change is not ready yet
			if spawnTime.Before(pruneLimit) && len(chg.Tasks()) == 0 {
				// old and without tasks: abort and forget it
				chg.Abort()
				delete(s.changes, chg.ID())
			} else if spawnTime.Before(abortLimit) {
				chg.Abort()
			}
			continue
		}
		// change old or we have too many changes
		if readyTime.Before(pruneLimit) || readyChangesCount > maxReadyChanges {
			s.writing()
			for _, t := range chg.Tasks() {
				delete(s.tasks, t.ID())
			}
			delete(s.changes, chg.ID())
			readyChangesCount--
		}
	}

	// finally drop old tasks that are not linked to any change
	for tid, t := range s.tasks {
		// TODO: this could be done more aggressively
		if t.Change() == nil && t.SpawnTime().Before(pruneLimit) {
			s.writing()
			delete(s.tasks, tid)
		}
	}
}
   515  
   516  // GetMaybeTimings implements timings.GetSaver
   517  func (s *State) GetMaybeTimings(timings interface{}) error {
   518  	err := s.Get("timings", timings)
   519  	if err != nil && err != ErrNoState {
   520  		return err
   521  	}
   522  	return nil
   523  }
   524  
// SaveTimings implements timings.GetSaver.
// It stores timings under the "timings" state entry via Set.
func (s *State) SaveTimings(timings interface{}) {
	s.Set("timings", timings)
}
   529  
   530  // ReadState returns the state deserialized from r.
   531  func ReadState(backend Backend, r io.Reader) (*State, error) {
   532  	s := new(State)
   533  	s.Lock()
   534  	defer s.unlock()
   535  	d := json.NewDecoder(r)
   536  	err := d.Decode(&s)
   537  	if err != nil {
   538  		return nil, fmt.Errorf("cannot read state: %s", err)
   539  	}
   540  	s.backend = backend
   541  	s.modified = false
   542  	s.cache = make(map[interface{}]interface{})
   543  	return s, err
   544  }