github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/apiserver/common/unitswatcher.go (about) 1 // Copyright 2013 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package common 5 6 import ( 7 "github.com/juju/errors" 8 "github.com/juju/names/v5" 9 10 apiservererrors "github.com/juju/juju/apiserver/errors" 11 "github.com/juju/juju/apiserver/facade" 12 "github.com/juju/juju/rpc/params" 13 "github.com/juju/juju/state" 14 "github.com/juju/juju/state/watcher" 15 ) 16 17 // UnitsWatcher implements a common WatchUnits method for use by 18 // various facades. 19 type UnitsWatcher struct { 20 st state.EntityFinder 21 resources facade.Resources 22 getCanWatch GetAuthFunc 23 } 24 25 // NewUnitsWatcher returns a new UnitsWatcher. The GetAuthFunc will be 26 // used on each invocation of WatchUnits to determine current 27 // permissions. 28 func NewUnitsWatcher(st state.EntityFinder, resources facade.Resources, getCanWatch GetAuthFunc) *UnitsWatcher { 29 return &UnitsWatcher{ 30 st: st, 31 resources: resources, 32 getCanWatch: getCanWatch, 33 } 34 } 35 36 func (u *UnitsWatcher) watchOneEntityUnits(canWatch AuthFunc, tag names.Tag) (params.StringsWatchResult, error) { 37 nothing := params.StringsWatchResult{} 38 if !canWatch(tag) { 39 return nothing, apiservererrors.ErrPerm 40 } 41 entity0, err := u.st.FindEntity(tag) 42 if err != nil { 43 return nothing, err 44 } 45 entity, ok := entity0.(state.UnitsWatcher) 46 if !ok { 47 return nothing, apiservererrors.NotSupportedError(tag, "watching units") 48 } 49 watch := entity.WatchUnits() 50 // Consume the initial event and forward it to the result. 51 if changes, ok := <-watch.Changes(); ok { 52 return params.StringsWatchResult{ 53 StringsWatcherId: u.resources.Register(watch), 54 Changes: changes, 55 }, nil 56 } 57 return nothing, watcher.EnsureErr(watch) 58 } 59 60 // WatchUnits starts a StringsWatcher to watch all units belonging to 61 // to any entity (machine or service) passed in args. 62 func (u *UnitsWatcher) WatchUnits(args params.Entities) (params.StringsWatchResults, error) { 63 result := params.StringsWatchResults{ 64 Results: make([]params.StringsWatchResult, len(args.Entities)), 65 } 66 if len(args.Entities) == 0 { 67 return result, nil 68 } 69 canWatch, err := u.getCanWatch() 70 if err != nil { 71 return params.StringsWatchResults{}, errors.Trace(err) 72 } 73 for i, entity := range args.Entities { 74 tag, err := names.ParseTag(entity.Tag) 75 if err != nil { 76 result.Results[i].Error = apiservererrors.ServerError(apiservererrors.ErrPerm) 77 continue 78 } 79 entityResult, err := u.watchOneEntityUnits(canWatch, tag) 80 result.Results[i] = entityResult 81 result.Results[i].Error = apiservererrors.ServerError(err) 82 } 83 return result, nil 84 }