github.com/juju/juju@v0.0.0-20240327075706-a90865de2538/worker/state/manifold.go (about)

     1  // Copyright 2016 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package state
     5  
     6  import (
     7  	stdcontext "context"
     8  	"time"
     9  
    10  	"github.com/juju/errors"
    11  	"github.com/juju/loggo"
    12  	"github.com/juju/worker/v3"
    13  	"github.com/juju/worker/v3/catacomb"
    14  	"github.com/juju/worker/v3/dependency"
    15  
    16  	coreagent "github.com/juju/juju/agent"
    17  	"github.com/juju/juju/state"
    18  )
    19  
    20  var logger = loggo.GetLogger("juju.worker.state")
    21  
    22  // ManifoldConfig provides the dependencies for Manifold.
    23  type ManifoldConfig struct {
    24  	AgentName              string
    25  	StateConfigWatcherName string
    26  	OpenStatePool          func(stdcontext.Context, coreagent.Config) (*state.StatePool, error)
    27  	PingInterval           time.Duration
    28  
    29  	// SetStatePool is called with the state pool when it is created,
    30  	// and called again with nil just before the state pool is closed.
    31  	// This is used for publishing the state pool to the agent's
    32  	// introspection worker, which runs outside of the dependency
    33  	// engine; hence the manifold's Output cannot be relied upon.
    34  	SetStatePool func(*state.StatePool)
    35  }
    36  
    37  // Validate validates the manifold configuration.
    38  func (config ManifoldConfig) Validate() error {
    39  	if config.AgentName == "" {
    40  		return errors.NotValidf("empty AgentName")
    41  	}
    42  	if config.StateConfigWatcherName == "" {
    43  		return errors.NotValidf("empty StateConfigWatcherName")
    44  	}
    45  	if config.OpenStatePool == nil {
    46  		return errors.NotValidf("nil OpenStatePool")
    47  	}
    48  	if config.SetStatePool == nil {
    49  		return errors.NotValidf("nil SetStatePool")
    50  	}
    51  	return nil
    52  }
    53  
    54  const defaultPingInterval = 15 * time.Second
    55  
    56  // Manifold returns a manifold whose worker which wraps a
    57  // *state.State, which is in turn wrapper by a StateTracker.  It will
    58  // exit if the State's associated mongodb session dies.
    59  func Manifold(config ManifoldConfig) dependency.Manifold {
    60  	return dependency.Manifold{
    61  		Inputs: []string{
    62  			config.AgentName,
    63  			config.StateConfigWatcherName,
    64  		},
    65  		Start: func(context dependency.Context) (worker.Worker, error) {
    66  			if err := config.Validate(); err != nil {
    67  				return nil, errors.Trace(err)
    68  			}
    69  
    70  			// Get the agent.
    71  			var agent coreagent.Agent
    72  			if err := context.Get(config.AgentName, &agent); err != nil {
    73  				return nil, err
    74  			}
    75  
    76  			// Confirm we're running in a state server by asking the
    77  			// stateconfigwatcher manifold.
    78  			var haveStateConfig bool
    79  			if err := context.Get(config.StateConfigWatcherName, &haveStateConfig); err != nil {
    80  				return nil, err
    81  			}
    82  			if !haveStateConfig {
    83  				return nil, errors.Annotate(dependency.ErrMissing, "no StateServingInfo in config")
    84  			}
    85  
    86  			pool, err := config.OpenStatePool(stdcontext.Background(), agent.CurrentConfig())
    87  			if err != nil {
    88  				return nil, errors.Trace(err)
    89  			}
    90  			stTracker := newStateTracker(pool)
    91  
    92  			pingInterval := config.PingInterval
    93  			if pingInterval == 0 {
    94  				pingInterval = defaultPingInterval
    95  			}
    96  
    97  			w := &stateWorker{
    98  				stTracker:    stTracker,
    99  				pingInterval: pingInterval,
   100  				setStatePool: config.SetStatePool,
   101  			}
   102  			if err := catacomb.Invoke(catacomb.Plan{
   103  				Site: &w.catacomb,
   104  				Work: w.loop,
   105  			}); err != nil {
   106  				if err := stTracker.Done(); err != nil {
   107  					logger.Warningf("error releasing state: %v", err)
   108  				}
   109  				return nil, errors.Trace(err)
   110  			}
   111  			return w, nil
   112  		},
   113  		Output: outputFunc,
   114  	}
   115  }
   116  
   117  // outputFunc extracts a *StateTracker from a *stateWorker.
   118  func outputFunc(in worker.Worker, out interface{}) error {
   119  	inWorker, _ := in.(*stateWorker)
   120  	if inWorker == nil {
   121  		return errors.Errorf("in should be a %T; got %T", inWorker, in)
   122  	}
   123  
   124  	switch outPointer := out.(type) {
   125  	case *StateTracker:
   126  		*outPointer = inWorker.stTracker
   127  	default:
   128  		return errors.Errorf("out should be *StateTracker; got %T", out)
   129  	}
   130  	return nil
   131  }