github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/worker/multiwatcher/manifold.go (about)

     1  // Copyright 2019 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package multiwatcher
     5  
     6  import (
     7  	"time"
     8  
     9  	"github.com/juju/errors"
    10  	"github.com/juju/worker/v3"
    11  	"github.com/juju/worker/v3/dependency"
    12  	"github.com/prometheus/client_golang/prometheus"
    13  
    14  	"github.com/juju/juju/core/multiwatcher"
    15  	"github.com/juju/juju/state"
    16  	workerstate "github.com/juju/juju/worker/state"
    17  )
    18  
    19  // Logger describes the logging methods used in this package by the worker.
    20  type Logger interface {
    21  	IsTraceEnabled() bool
    22  	Tracef(string, ...interface{})
    23  	Errorf(string, ...interface{})
    24  	Criticalf(string, ...interface{})
    25  }
    26  
    27  // Clock describes the time methods used in this package by the worker.
    28  type Clock interface {
    29  	Now() time.Time
    30  }
    31  
    32  // ManifoldConfig holds the information necessary to run a model cache worker in
    33  // a dependency.Engine.
    34  type ManifoldConfig struct {
    35  	StateName string
    36  	Clock     Clock
    37  	Logger    Logger
    38  
    39  	// NOTE: what metrics do we want to expose here?
    40  	// loop restart count for one.
    41  	PrometheusRegisterer prometheus.Registerer
    42  
    43  	NewWorker     func(Config) (worker.Worker, error)
    44  	NewAllWatcher func(*state.StatePool) (state.AllWatcherBacking, error)
    45  }
    46  
    47  // Validate validates the manifold configuration.
    48  func (config ManifoldConfig) Validate() error {
    49  	if config.StateName == "" {
    50  		return errors.NotValidf("empty StateName")
    51  	}
    52  	if config.Clock == nil {
    53  		return errors.NotValidf("missing Clock")
    54  	}
    55  	if config.Logger == nil {
    56  		return errors.NotValidf("missing Logger")
    57  	}
    58  	if config.PrometheusRegisterer == nil {
    59  		return errors.NotValidf("missing PrometheusRegisterer")
    60  	}
    61  	if config.NewWorker == nil {
    62  		return errors.NotValidf("missing NewWorker func")
    63  	}
    64  	if config.NewAllWatcher == nil {
    65  		return errors.NotValidf("missing NewAllWatcher func")
    66  	}
    67  	return nil
    68  }
    69  
    70  // Manifold returns a dependency.Manifold that will run a model cache
    71  // worker. The manifold outputs a *cache.Controller, primarily for
    72  // the apiserver to depend on and use.
    73  func Manifold(config ManifoldConfig) dependency.Manifold {
    74  	return dependency.Manifold{
    75  		Inputs: []string{
    76  			config.StateName,
    77  		},
    78  		Start:  config.start,
    79  		Output: WorkerFactory,
    80  	}
    81  }
    82  
    83  // start is a method on ManifoldConfig because it's more readable than a closure.
    84  func (config ManifoldConfig) start(context dependency.Context) (worker.Worker, error) {
    85  	if err := config.Validate(); err != nil {
    86  		return nil, errors.Trace(err)
    87  	}
    88  	var stTracker workerstate.StateTracker
    89  	if err := context.Get(config.StateName, &stTracker); err != nil {
    90  		return nil, errors.Trace(err)
    91  	}
    92  
    93  	pool, err := stTracker.Use()
    94  	if err != nil {
    95  		return nil, errors.Trace(err)
    96  	}
    97  
    98  	allWatcher, err := config.NewAllWatcher(pool)
    99  	if err != nil {
   100  		return nil, errors.Trace(err)
   101  	}
   102  
   103  	w, err := config.NewWorker(Config{
   104  		Clock:                config.Clock,
   105  		Logger:               config.Logger,
   106  		Backing:              allWatcher,
   107  		PrometheusRegisterer: config.PrometheusRegisterer,
   108  		Cleanup:              func() { _ = stTracker.Done() },
   109  	})
   110  	if err != nil {
   111  		_ = stTracker.Done()
   112  		return nil, errors.Trace(err)
   113  	}
   114  	return w, nil
   115  }
   116  
   117  // WorkerFactory extracts a Factory from a *Worker.
   118  func WorkerFactory(in worker.Worker, out interface{}) error {
   119  	inWorker, _ := in.(*Worker)
   120  	if inWorker == nil {
   121  		return errors.Errorf("in should be a %T; got %T", inWorker, in)
   122  	}
   123  
   124  	switch outPointer := out.(type) {
   125  	case *multiwatcher.Factory:
   126  		// The worker itself is the factory.
   127  		*outPointer = inWorker
   128  	default:
   129  		return errors.Errorf("out should be *multiwatcher.Factory; got %T", out)
   130  	}
   131  	return nil
   132  }