github.com/juju/juju@v0.0.0-20240327075706-a90865de2538/worker/controlsocket/manifold.go (about)

     1  // Copyright 2023 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package controlsocket
     5  
     6  import (
     7  	"github.com/juju/errors"
     8  	"github.com/juju/worker/v3"
     9  	"github.com/juju/worker/v3/dependency"
    10  
    11  	"github.com/juju/juju/state"
    12  	"github.com/juju/juju/worker/common"
    13  	workerstate "github.com/juju/juju/worker/state"
    14  )
    15  
    16  // ManifoldConfig describes the dependencies required by the controlsocket worker.
    17  type ManifoldConfig struct {
    18  	StateName  string
    19  	Logger     Logger
    20  	NewWorker  func(Config) (worker.Worker, error)
    21  	SocketName string
    22  }
    23  
    24  // Manifold returns a Manifold that encapsulates the controlsocket worker.
    25  func Manifold(config ManifoldConfig) dependency.Manifold {
    26  	return dependency.Manifold{
    27  		Inputs: []string{
    28  			config.StateName,
    29  		},
    30  		Start: config.start,
    31  	}
    32  }
    33  
    34  // Validate is called by start to check for bad configuration.
    35  func (cfg ManifoldConfig) Validate() error {
    36  	if cfg.StateName == "" {
    37  		return errors.NotValidf("empty StateName")
    38  	}
    39  	if cfg.Logger == nil {
    40  		return errors.NotValidf("nil Logger")
    41  	}
    42  	if cfg.NewWorker == nil {
    43  		return errors.NotValidf("nil NewWorker func")
    44  	}
    45  	if cfg.SocketName == "" {
    46  		return errors.NotValidf("empty SocketName")
    47  	}
    48  	return nil
    49  }
    50  
    51  // start is a StartFunc for a Worker manifold.
    52  func (cfg ManifoldConfig) start(context dependency.Context) (_ worker.Worker, err error) {
    53  	if err = cfg.Validate(); err != nil {
    54  		return nil, errors.Trace(err)
    55  	}
    56  
    57  	var stTracker workerstate.StateTracker
    58  	if err = context.Get(cfg.StateName, &stTracker); err != nil {
    59  		return nil, errors.Trace(err)
    60  	}
    61  
    62  	var statePool *state.StatePool
    63  	statePool, err = stTracker.Use()
    64  	if err != nil {
    65  		return nil, errors.Trace(err)
    66  	}
    67  	// Make sure we clean up state objects if an error occurs.
    68  	defer func() {
    69  		if err != nil {
    70  			_ = stTracker.Done()
    71  		}
    72  	}()
    73  
    74  	var st *state.State
    75  	st, err = statePool.SystemState()
    76  	if err != nil {
    77  		return nil, errors.Trace(err)
    78  	}
    79  
    80  	var w worker.Worker
    81  	w, err = cfg.NewWorker(Config{
    82  		State:      stateShim{st},
    83  		Logger:     cfg.Logger,
    84  		SocketName: cfg.SocketName,
    85  	})
    86  	if err != nil {
    87  		return nil, errors.Trace(err)
    88  	}
    89  	return common.NewCleanupWorker(w, func() { _ = stTracker.Done() }), nil
    90  }