github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/apiserver/common/unitswatcher.go (about)

     1  // Copyright 2013 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package common
     5  
     6  import (
     7  	"github.com/juju/errors"
     8  	"github.com/juju/names/v5"
     9  
    10  	apiservererrors "github.com/juju/juju/apiserver/errors"
    11  	"github.com/juju/juju/apiserver/facade"
    12  	"github.com/juju/juju/rpc/params"
    13  	"github.com/juju/juju/state"
    14  	"github.com/juju/juju/state/watcher"
    15  )
    16  
// UnitsWatcher implements a common WatchUnits method for use by
// various facades.
type UnitsWatcher struct {
	// st is used to look up each requested entity by its tag.
	st          state.EntityFinder
	// resources is where newly started watchers are registered; the
	// returned registration ID is handed back to the client.
	resources   facade.Resources
	// getCanWatch is invoked once per WatchUnits call to obtain the
	// authorization check applied to every requested tag.
	getCanWatch GetAuthFunc
}
    24  
    25  // NewUnitsWatcher returns a new UnitsWatcher. The GetAuthFunc will be
    26  // used on each invocation of WatchUnits to determine current
    27  // permissions.
    28  func NewUnitsWatcher(st state.EntityFinder, resources facade.Resources, getCanWatch GetAuthFunc) *UnitsWatcher {
    29  	return &UnitsWatcher{
    30  		st:          st,
    31  		resources:   resources,
    32  		getCanWatch: getCanWatch,
    33  	}
    34  }
    35  
    36  func (u *UnitsWatcher) watchOneEntityUnits(canWatch AuthFunc, tag names.Tag) (params.StringsWatchResult, error) {
    37  	nothing := params.StringsWatchResult{}
    38  	if !canWatch(tag) {
    39  		return nothing, apiservererrors.ErrPerm
    40  	}
    41  	entity0, err := u.st.FindEntity(tag)
    42  	if err != nil {
    43  		return nothing, err
    44  	}
    45  	entity, ok := entity0.(state.UnitsWatcher)
    46  	if !ok {
    47  		return nothing, apiservererrors.NotSupportedError(tag, "watching units")
    48  	}
    49  	watch := entity.WatchUnits()
    50  	// Consume the initial event and forward it to the result.
    51  	if changes, ok := <-watch.Changes(); ok {
    52  		return params.StringsWatchResult{
    53  			StringsWatcherId: u.resources.Register(watch),
    54  			Changes:          changes,
    55  		}, nil
    56  	}
    57  	return nothing, watcher.EnsureErr(watch)
    58  }
    59  
    60  // WatchUnits starts a StringsWatcher to watch all units belonging to
    61  // to any entity (machine or service) passed in args.
    62  func (u *UnitsWatcher) WatchUnits(args params.Entities) (params.StringsWatchResults, error) {
    63  	result := params.StringsWatchResults{
    64  		Results: make([]params.StringsWatchResult, len(args.Entities)),
    65  	}
    66  	if len(args.Entities) == 0 {
    67  		return result, nil
    68  	}
    69  	canWatch, err := u.getCanWatch()
    70  	if err != nil {
    71  		return params.StringsWatchResults{}, errors.Trace(err)
    72  	}
    73  	for i, entity := range args.Entities {
    74  		tag, err := names.ParseTag(entity.Tag)
    75  		if err != nil {
    76  			result.Results[i].Error = apiservererrors.ServerError(apiservererrors.ErrPerm)
    77  			continue
    78  		}
    79  		entityResult, err := u.watchOneEntityUnits(canWatch, tag)
    80  		result.Results[i] = entityResult
    81  		result.Results[i].Error = apiservererrors.ServerError(err)
    82  	}
    83  	return result, nil
    84  }