github.com/axw/juju@v0.0.0-20161005053422-4bd6544d08d4/apiserver/common/unitswatcher.go (about) 1 // Copyright 2013 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package common 5 6 import ( 7 "github.com/juju/errors" 8 "gopkg.in/juju/names.v2" 9 10 "github.com/juju/juju/apiserver/facade" 11 "github.com/juju/juju/apiserver/params" 12 "github.com/juju/juju/state" 13 "github.com/juju/juju/state/watcher" 14 ) 15 16 // UnitsWatcher implements a common WatchUnits method for use by 17 // various facades. 18 type UnitsWatcher struct { 19 st state.EntityFinder 20 resources facade.Resources 21 getCanWatch GetAuthFunc 22 } 23 24 // NewUnitsWatcher returns a new UnitsWatcher. The GetAuthFunc will be 25 // used on each invocation of WatchUnits to determine current 26 // permissions. 27 func NewUnitsWatcher(st state.EntityFinder, resources facade.Resources, getCanWatch GetAuthFunc) *UnitsWatcher { 28 return &UnitsWatcher{ 29 st: st, 30 resources: resources, 31 getCanWatch: getCanWatch, 32 } 33 } 34 35 func (u *UnitsWatcher) watchOneEntityUnits(canWatch AuthFunc, tag names.Tag) (params.StringsWatchResult, error) { 36 nothing := params.StringsWatchResult{} 37 if !canWatch(tag) { 38 return nothing, ErrPerm 39 } 40 entity0, err := u.st.FindEntity(tag) 41 if err != nil { 42 return nothing, err 43 } 44 entity, ok := entity0.(state.UnitsWatcher) 45 if !ok { 46 return nothing, NotSupportedError(tag, "watching units") 47 } 48 watch := entity.WatchUnits() 49 // Consume the initial event and forward it to the result. 50 if changes, ok := <-watch.Changes(); ok { 51 return params.StringsWatchResult{ 52 StringsWatcherId: u.resources.Register(watch), 53 Changes: changes, 54 }, nil 55 } 56 return nothing, watcher.EnsureErr(watch) 57 } 58 59 // WatchUnits starts a StringsWatcher to watch all units belonging to 60 // to any entity (machine or service) passed in args. 61 func (u *UnitsWatcher) WatchUnits(args params.Entities) (params.StringsWatchResults, error) { 62 result := params.StringsWatchResults{ 63 Results: make([]params.StringsWatchResult, len(args.Entities)), 64 } 65 if len(args.Entities) == 0 { 66 return result, nil 67 } 68 canWatch, err := u.getCanWatch() 69 if err != nil { 70 return params.StringsWatchResults{}, errors.Trace(err) 71 } 72 for i, entity := range args.Entities { 73 tag, err := names.ParseTag(entity.Tag) 74 if err != nil { 75 result.Results[i].Error = ServerError(ErrPerm) 76 continue 77 } 78 entityResult, err := u.watchOneEntityUnits(canWatch, tag) 79 result.Results[i] = entityResult 80 result.Results[i].Error = ServerError(err) 81 } 82 return result, nil 83 }