
     1  // Copyright 2015 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     4  package context
     6  import (
     7  	"sort"
     9  	""
    10  	""
    12  	""
    13  )
    15  var logger = loggo.GetLogger("juju.payload.context")
// APIClient represents the API needs of a Context.
//
// Every method operates on a batch and returns one payload.Result per
// request item in addition to an error for the call as a whole. The
// Context methods in this file check both the returned error and the
// Error field of the results.
type APIClient interface {
	// List requests the payload info for the given IDs.
	// NewContextAPI calls it without any IDs and records every
	// payload that comes back.
	List(fullIDs ...string) ([]payload.Result, error)
	// Track sends a request to update state with the provided payloads.
	Track(payloads ...payload.Payload) ([]payload.Result, error)
	// Untrack requests that the payloads with the given IDs no longer
	// be tracked.
	Untrack(fullIDs ...string) ([]payload.Result, error)
	// SetStatus sets the status for the given IDs.
	SetStatus(status string, fullIDs ...string) ([]payload.Result, error)
}
    29  // TODO(ericsnow) Rename Get and Set to more specifically describe what
    30  // they are for.
// Component provides the hook context data specific to payloads.
// *Context is the implementation provided by this package.
type Component interface {
	// Get returns the payload info corresponding to the given
	// class and ID.
	Get(class, id string) (*payload.Payload, error)
	// Track records the payload info in the hook context.
	Track(payload payload.Payload) error
	// Untrack removes the payload from our list of payloads to track.
	Untrack(class, id string) error
	// SetStatus sets the status of the payload identified by the
	// given class and ID.
	SetStatus(class, id, status string) error
	// List returns the list of registered payload IDs.
	List() ([]string, error)
	// Flush pushes the hook context data out to state.
	Flush() error
}
// Compile-time assertion that *Context satisfies the Component interface.
var _ Component = (*Context)(nil)
// Context is the payload portion of the hook context.
//
// It holds two maps, both keyed by the payload's full ID:
//   - payloads contains the payloads already known to state (loaded by
//     NewContextAPI and extended by a successful Flush);
//   - updates contains payloads recorded by Track that have not been
//     flushed yet. Entries in updates take precedence over payloads
//     when the two are read together.
//
// api is the client used to talk to state. dataDir is stored at
// construction time; no method in this file reads it.
type Context struct {
	api     APIClient
	dataDir string
	// TODO(ericsnow) Use the Juju ID for the key rather than Info.ID().
	payloads map[string]payload.Payload
	updates  map[string]payload.Payload
}
    59  // NewContext returns a new jujuc.ContextComponent for payloads.
    60  func NewContext(api APIClient, dataDir string) *Context {
    61  	return &Context{
    62  		api:      api,
    63  		dataDir:  dataDir,
    64  		payloads: make(map[string]payload.Payload),
    65  		updates:  make(map[string]payload.Payload),
    66  	}
    67  }
    69  // NewContextAPI returns a new jujuc.ContextComponent for payloads.
    70  func NewContextAPI(api APIClient, dataDir string) (*Context, error) {
    71  	results, err := api.List()
    72  	if err != nil {
    73  		return nil, errors.Trace(err)
    74  	}
    76  	ctx := NewContext(api, dataDir)
    77  	for _, result := range results {
    78  		pl := result.Payload
    79  		// TODO(ericsnow) Use id instead of pl.FullID().
    80  		ctx.payloads[pl.FullID()] = pl.Payload
    81  	}
    82  	return ctx, nil
    83  }
// HookContext is the portion of jujuc.Context used in this package.
type HookContext interface {
	// Component implements jujuc.Context. It takes a component name
	// and returns the matching Component; ContextComponent calls it
	// with payload.ComponentName.
	Component(string) (Component, error)
}
    91  // ContextComponent returns the hook context for the payload
    92  // payload component.
    93  func ContextComponent(ctx HookContext) (Component, error) {
    94  	compCtx, err := ctx.Component(payload.ComponentName)
    95  	if errors.IsNotFound(err) {
    96  		return nil, errors.Errorf("component %q not registered", payload.ComponentName)
    97  	}
    98  	if err != nil {
    99  		return nil, errors.Trace(err)
   100  	}
   101  	if compCtx == nil {
   102  		return nil, errors.Errorf("component %q disabled", payload.ComponentName)
   103  	}
   104  	return compCtx, nil
   105  }
   107  // TODO(ericsnow) Should we build in refreshes in all the methods?
   109  // Payloads returns the payloads known to the context.
   110  func (c *Context) Payloads() ([]payload.Payload, error) {
   111  	payloads := mergePayloadMaps(c.payloads, c.updates)
   112  	var newPayloads []payload.Payload
   113  	for _, pl := range payloads {
   114  		newPayloads = append(newPayloads, pl)
   115  	}
   117  	return newPayloads, nil
   118  }
   120  func mergePayloadMaps(payloads, updates map[string]payload.Payload) map[string]payload.Payload {
   121  	// At this point payloads and updates have already been checked for
   122  	// nil values so we won't see any here.
   123  	result := make(map[string]payload.Payload)
   124  	for k, v := range payloads {
   125  		result[k] = v
   126  	}
   127  	for k, v := range updates {
   128  		result[k] = v
   129  	}
   130  	return result
   131  }
   133  // Get returns the payload info corresponding to the given ID.
   134  func (c *Context) Get(class, id string) (*payload.Payload, error) {
   135  	fullID := payload.BuildID(class, id)
   136  	logger.Tracef("getting %q from hook context", fullID)
   138  	actual, ok := c.updates[fullID]
   139  	if !ok {
   140  		actual, ok = c.payloads[fullID]
   141  		if !ok {
   142  			return nil, errors.NotFoundf("%s", fullID)
   143  		}
   144  	}
   145  	return &actual, nil
   146  }
   148  // List returns the sorted names of all registered payloads.
   149  func (c *Context) List() ([]string, error) {
   150  	logger.Tracef("listing all payloads in hook context")
   152  	payloads, err := c.Payloads()
   153  	if err != nil {
   154  		return nil, errors.Trace(err)
   155  	}
   156  	if len(payloads) == 0 {
   157  		return nil, nil
   158  	}
   159  	var ids []string
   160  	for _, wl := range payloads {
   161  		ids = append(ids, wl.FullID())
   162  	}
   163  	sort.Strings(ids)
   164  	return ids, nil
   165  }
   167  // Track records the payload info in the hook context.
   168  func (c *Context) Track(pl payload.Payload) error {
   169  	logger.Tracef("adding %q to hook context: %#v", pl.FullID(), pl)
   171  	if err := pl.Validate(); err != nil {
   172  		return errors.Trace(err)
   173  	}
   175  	// TODO(ericsnow) We are likely missing mechanisim for local persistence.
   176  	id := pl.FullID()
   177  	c.updates[id] = pl
   178  	return nil
   179  }
   181  // Untrack tells juju to stop tracking this payload.
   182  func (c *Context) Untrack(class, id string) error {
   183  	fullID := payload.BuildID(class, id)
   184  	logger.Tracef("Calling untrack on payload context %q", fullID)
   186  	res, err := c.api.Untrack(fullID)
   187  	if err != nil {
   188  		return errors.Trace(err)
   189  	}
   190  	// TODO(ericsnow) We should not ignore a 0-len result.
   191  	if len(res) > 0 && res[0].Error != nil {
   192  		return errors.Trace(res[0].Error)
   193  	}
   194  	delete(c.payloads, id)
   196  	return nil
   197  }
   199  // SetStatus sets the identified payload's status.
   200  func (c *Context) SetStatus(class, id, status string) error {
   201  	fullID := payload.BuildID(class, id)
   202  	logger.Tracef("Calling status-set on payload context %q", fullID)
   204  	res, err := c.api.SetStatus(status, fullID)
   205  	if err != nil {
   206  		return errors.Trace(err)
   207  	}
   208  	// TODO(ericsnow) We should not ignore a 0-len result.
   209  	if len(res) > 0 && res[0].Error != nil {
   210  		// In a hook context, the case where the specified payload does
   211  		// not exist is a special one. A hook tool is how a charm author
   212  		// communicates the state of the charm. So returning an error
   213  		// here in the "missing" case makes less sense than in other
   214  		// places. We could simply ignore any error that surfaces for
   215  		// that case. However, returning the error communicates to the
   216  		// charm author that what they're trying to communicate doesn't
   217  		// make sense.
   218  		return errors.Trace(res[0].Error)
   219  	}
   221  	return nil
   222  }
   224  // TODO(ericsnow) The context machinery is not actually using this yet.
   226  // Flush implements jujuc.ContextComponent. In this case that means all
   227  // added and updated payload.Payload in the hook context are pushed to
   228  // Juju state via the API.
   229  func (c *Context) Flush() error {
   230  	logger.Tracef("flushing from hook context to state")
   231  	// TODO(natefinch): make this a noop and move this code into set.
   233  	if len(c.updates) > 0 {
   234  		var updates []payload.Payload
   235  		for _, pl := range c.updates {
   236  			updates = append(updates, pl)
   237  		}
   239  		res, err := c.api.Track(updates...)
   240  		if err != nil {
   241  			return errors.Trace(err)
   242  		}
   243  		if len(res) > 0 && res[0].Error != nil {
   244  			return errors.Trace(res[0].Error)
   245  		}
   247  		for k, v := range c.updates {
   248  			c.payloads[k] = v
   249  		}
   250  		c.updates = map[string]payload.Payload{}
   251  	}
   252  	return nil
   253  }