github.com/david-imola/snapd@v0.0.0-20210611180407-2de8ddeece6d/overlord/hookstate/context.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2016 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package hookstate
    21  
    22  import (
    23  	"bytes"
    24  	"encoding/json"
    25  	"fmt"
    26  	"sync"
    27  	"sync/atomic"
    28  	"time"
    29  
    30  	"github.com/snapcore/snapd/jsonutil"
    31  	"github.com/snapcore/snapd/logger"
    32  	"github.com/snapcore/snapd/overlord/state"
    33  	"github.com/snapcore/snapd/randutil"
    34  	"github.com/snapcore/snapd/snap"
    35  )
    36  
    37  // Context represents the context under which the snap is calling back into snapd.
    38  // It is associated with a task when the callback is happening from within a hook,
    39  // or otherwise considered an ephemeral context in that its associated data will
    40  // be discarded once that individual call is finished.
// Context represents the context under which the snap is calling back into snapd.
// It is associated with a task when the callback is happening from within a hook,
// or otherwise considered an ephemeral context in that its associated data will
// be discarded once that individual call is finished.
//
// Accessors that check the lock (Set/Get, Cache/Cached, OnDone/Done,
// Logf/Errorf) panic unless the context is held via Lock.
type Context struct {
	// task is the associated task; nil for ephemeral contexts (see IsEphemeral).
	task    *state.Task
	// state is locked by Lock (after mutex) and returned by State.
	state   *state.State
	// setup supplies the snap instance name, revision, hook name and timeout.
	setup   *HookSetup
	// id is the context ID; NewContext generates a random one when none is given.
	id      string
	// handler is returned by Handler; newEphemeralHookContextWithData leaves it nil.
	handler Handler

	// cache holds non-persisted values for Cache/Cached. For ephemeral
	// contexts it also holds the Set/Get data under "ephemeral-context".
	cache  map[interface{}]interface{}
	// onDone lists the callbacks registered with OnDone, run in order by Done.
	onDone []func() error

	// mutex serializes users of the context; Lock takes it before the state lock.
	mutex        sync.Mutex
	// mutexChecker counts outstanding Lock calls. reading/writing load it
	// atomically to detect access without the lock.
	mutexChecker int32
}
    54  
    55  // NewContext returns a new context associated with the provided task or
    56  // an ephemeral context if task is nil.
    57  //
    58  // A random ID is generated if contextID is empty.
    59  func NewContext(task *state.Task, state *state.State, setup *HookSetup, handler Handler, contextID string) (*Context, error) {
    60  	if contextID == "" {
    61  		var err error
    62  		contextID, err = randutil.CryptoToken(32)
    63  		if err != nil {
    64  			return nil, err
    65  		}
    66  	}
    67  
    68  	return &Context{
    69  		task:    task,
    70  		state:   state,
    71  		setup:   setup,
    72  		id:      contextID,
    73  		handler: handler,
    74  		cache:   make(map[interface{}]interface{}),
    75  	}, nil
    76  }
    77  
    78  func newEphemeralHookContextWithData(st *state.State, setup *HookSetup, contextData map[string]interface{}) (*Context, error) {
    79  	context, err := NewContext(nil, st, setup, nil, "")
    80  	if err != nil {
    81  		return nil, err
    82  	}
    83  	if contextData != nil {
    84  		serialized, err := json.Marshal(contextData)
    85  		if err != nil {
    86  			return nil, err
    87  		}
    88  		var data map[string]*json.RawMessage
    89  		if err := json.Unmarshal(serialized, &data); err != nil {
    90  			return nil, err
    91  		}
    92  		context.cache["ephemeral-context"] = data
    93  	}
    94  	return context, nil
    95  }
    96  
    97  // InstanceName returns the name of the snap instance containing the hook.
    98  func (c *Context) InstanceName() string {
    99  	return c.setup.Snap
   100  }
   101  
   102  // SnapRevision returns the revision of the snap containing the hook.
   103  func (c *Context) SnapRevision() snap.Revision {
   104  	return c.setup.Revision
   105  }
   106  
   107  // Task returns the task associated with the hook or (nil, false) if the context is ephemeral
   108  // and task is not available.
   109  func (c *Context) Task() (*state.Task, bool) {
   110  	return c.task, c.task != nil
   111  }
   112  
   113  // HookName returns the name of the hook in this context.
   114  func (c *Context) HookName() string {
   115  	return c.setup.Hook
   116  }
   117  
   118  // Timeout returns the maximum time this hook can run
   119  func (c *Context) Timeout() time.Duration {
   120  	return c.setup.Timeout
   121  }
   122  
   123  // ID returns the ID of the context.
   124  func (c *Context) ID() string {
   125  	return c.id
   126  }
   127  
   128  // Handler returns the handler for this context
   129  func (c *Context) Handler() Handler {
   130  	return c.handler
   131  }
   132  
// Lock acquires the lock for this context (required for Set/Get,
// Cache/Cached, Logf/Errorf and OnDone/Done).
//
// The context's own mutex is taken first, then the state lock. Finally
// mutexChecker is incremented so that reading/writing can verify the lock is
// held. Unlock undoes these steps in reverse order.
func (c *Context) Lock() {
	c.mutex.Lock()
	c.state.Lock()
	atomic.AddInt32(&c.mutexChecker, 1)
}
   140  
// Unlock releases the lock for this context acquired by Lock.
//
// mutexChecker is decremented first, while both locks are still held. Then the
// state lock and finally the context mutex are released, which is the reverse
// of the order used by Lock.
func (c *Context) Unlock() {
	atomic.AddInt32(&c.mutexChecker, -1)
	c.state.Unlock()
	c.mutex.Unlock()
}
   147  
   148  func (c *Context) reading() {
   149  	if atomic.LoadInt32(&c.mutexChecker) != 1 {
   150  		panic("internal error: accessing context without lock")
   151  	}
   152  }
   153  
   154  func (c *Context) writing() {
   155  	if atomic.LoadInt32(&c.mutexChecker) != 1 {
   156  		panic("internal error: accessing context without lock")
   157  	}
   158  }
   159  
   160  // Set associates value with key. The provided value must properly marshal and
   161  // unmarshal with encoding/json. Note that the context needs to be locked and
   162  // unlocked by the caller.
   163  func (c *Context) Set(key string, value interface{}) {
   164  	c.writing()
   165  
   166  	var data map[string]*json.RawMessage
   167  	if c.IsEphemeral() {
   168  		data, _ = c.cache["ephemeral-context"].(map[string]*json.RawMessage)
   169  	} else {
   170  		if err := c.task.Get("hook-context", &data); err != nil && err != state.ErrNoState {
   171  			panic(fmt.Sprintf("internal error: cannot unmarshal context: %v", err))
   172  		}
   173  	}
   174  	if data == nil {
   175  		data = make(map[string]*json.RawMessage)
   176  	}
   177  
   178  	marshalledValue, err := json.Marshal(value)
   179  	if err != nil {
   180  		panic(fmt.Sprintf("internal error: cannot marshal context value for %q: %s", key, err))
   181  	}
   182  	raw := json.RawMessage(marshalledValue)
   183  	data[key] = &raw
   184  
   185  	if c.IsEphemeral() {
   186  		c.cache["ephemeral-context"] = data
   187  	} else {
   188  		c.task.Set("hook-context", data)
   189  	}
   190  }
   191  
   192  // Get unmarshals the stored value associated with the provided key into the
   193  // value parameter. Note that the context needs to be locked/unlocked by the
   194  // caller.
   195  func (c *Context) Get(key string, value interface{}) error {
   196  	c.reading()
   197  
   198  	var data map[string]*json.RawMessage
   199  	if c.IsEphemeral() {
   200  		data, _ = c.cache["ephemeral-context"].(map[string]*json.RawMessage)
   201  		if data == nil {
   202  			return state.ErrNoState
   203  		}
   204  	} else {
   205  		if err := c.task.Get("hook-context", &data); err != nil {
   206  			return err
   207  		}
   208  	}
   209  
   210  	raw, ok := data[key]
   211  	if !ok {
   212  		return state.ErrNoState
   213  	}
   214  
   215  	err := jsonutil.DecodeWithNumber(bytes.NewReader(*raw), &value)
   216  	if err != nil {
   217  		return fmt.Errorf("cannot unmarshal context value for %q: %s", key, err)
   218  	}
   219  
   220  	return nil
   221  }
   222  
   223  // State returns the state contained within the context
   224  func (c *Context) State() *state.State {
   225  	return c.state
   226  }
   227  
   228  // Cached returns the cached value associated with the provided key. It returns
   229  // nil if there is no entry for key. Note that the context needs to be locked
   230  // and unlocked by the caller.
   231  func (c *Context) Cached(key interface{}) interface{} {
   232  	c.reading()
   233  
   234  	return c.cache[key]
   235  }
   236  
   237  // Cache associates value with key. The cached value is not persisted. Note that
   238  // the context needs to be locked/unlocked by the caller.
   239  func (c *Context) Cache(key, value interface{}) {
   240  	c.writing()
   241  
   242  	c.cache[key] = value
   243  }
   244  
   245  // OnDone requests the provided function to be run once the context knows it's
   246  // complete. This can be called multiple times; each function will be called in
   247  // the order in which they were added. Note that the context needs to be locked
   248  // and unlocked by the caller.
   249  func (c *Context) OnDone(f func() error) {
   250  	c.writing()
   251  
   252  	c.onDone = append(c.onDone, f)
   253  }
   254  
   255  // Done is called to notify the context that its hook has exited successfully.
   256  // It will call all of the functions added in OnDone (even if one of them
   257  // returns an error) and will return the first error encountered. Note that the
   258  // context needs to be locked/unlocked by the caller.
   259  func (c *Context) Done() error {
   260  	c.reading()
   261  
   262  	var firstErr error
   263  	for _, f := range c.onDone {
   264  		if err := f(); err != nil && firstErr == nil {
   265  			firstErr = err
   266  		}
   267  	}
   268  
   269  	return firstErr
   270  }
   271  
   272  func (c *Context) IsEphemeral() bool {
   273  	return c.task == nil
   274  }
   275  
   276  // ChangeID returns change ID for non-ephemeral context
   277  // or empty string otherwise.
   278  func (c *Context) ChangeID() string {
   279  	if task, ok := c.Task(); ok {
   280  		if chg := task.Change(); chg != nil {
   281  			return chg.ID()
   282  		}
   283  	}
   284  	return ""
   285  }
   286  
   287  // Logf logs to the context, either to the logger for ephemeral contexts
   288  // or the task log.
   289  //
   290  // Context must be locked.
   291  func (c *Context) Logf(fmt string, args ...interface{}) {
   292  	c.writing()
   293  	if c.IsEphemeral() {
   294  		logger.Noticef(fmt, args...)
   295  	} else {
   296  		c.task.Logf(fmt, args...)
   297  	}
   298  }
   299  
   300  // Errorf logs errors to the context, either to the logger for
   301  // ephemeral contexts or the task log.
   302  //
   303  // Context must be locked.
   304  func (c *Context) Errorf(fmt string, args ...interface{}) {
   305  	c.writing()
   306  	if c.IsEphemeral() {
   307  		// XXX: loger has no Errorf() :/
   308  		logger.Noticef(fmt, args...)
   309  	} else {
   310  		c.task.Errorf(fmt, args...)
   311  	}
   312  }