github.com/rogpeppe/juju@v0.0.0-20140613142852-6337964b789e/state/apiserver/keyupdater/authorisedkeys.go (about) 1 // Copyright 2013 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package keyupdater 5 6 import ( 7 "github.com/juju/errors" 8 9 "github.com/juju/juju/state" 10 "github.com/juju/juju/state/api/params" 11 "github.com/juju/juju/state/apiserver/common" 12 "github.com/juju/juju/state/watcher" 13 "github.com/juju/juju/utils/ssh" 14 ) 15 16 // KeyUpdater defines the methods on the keyupdater API end point. 17 type KeyUpdater interface { 18 AuthorisedKeys(args params.Entities) (params.StringsResults, error) 19 WatchAuthorisedKeys(args params.Entities) (params.NotifyWatchResults, error) 20 } 21 22 // KeyUpdaterAPI implements the KeyUpdater interface and is the concrete 23 // implementation of the api end point. 24 type KeyUpdaterAPI struct { 25 state *state.State 26 resources *common.Resources 27 authorizer common.Authorizer 28 getCanRead common.GetAuthFunc 29 } 30 31 var _ KeyUpdater = (*KeyUpdaterAPI)(nil) 32 33 // NewKeyUpdaterAPI creates a new server-side keyupdater API end point. 34 func NewKeyUpdaterAPI( 35 st *state.State, 36 resources *common.Resources, 37 authorizer common.Authorizer, 38 ) (*KeyUpdaterAPI, error) { 39 // Only machine agents have access to the keyupdater service. 40 if !authorizer.AuthMachineAgent() { 41 return nil, common.ErrPerm 42 } 43 // No-one else except the machine itself can only read a machine's own credentials. 44 getCanRead := func() (common.AuthFunc, error) { 45 return authorizer.AuthOwner, nil 46 } 47 return &KeyUpdaterAPI{state: st, resources: resources, authorizer: authorizer, getCanRead: getCanRead}, nil 48 } 49 50 // WatchAuthorisedKeys starts a watcher to track changes to the authorised ssh keys 51 // for the specified machines. 52 // The current implementation relies on global authorised keys being stored in the environment config. 53 // This will change as new user management and authorisation functionality is added. 54 func (api *KeyUpdaterAPI) WatchAuthorisedKeys(arg params.Entities) (params.NotifyWatchResults, error) { 55 results := make([]params.NotifyWatchResult, len(arg.Entities)) 56 57 canRead, err := api.getCanRead() 58 if err != nil { 59 return params.NotifyWatchResults{}, err 60 } 61 for i, entity := range arg.Entities { 62 // 1. Check permissions 63 if !canRead(entity.Tag) { 64 results[i].Error = common.ServerError(common.ErrPerm) 65 continue 66 } 67 // 2. Check entity exists 68 if _, err := api.state.FindEntity(entity.Tag); err != nil { 69 if errors.IsNotFound(err) { 70 results[i].Error = common.ServerError(common.ErrPerm) 71 } else { 72 results[i].Error = common.ServerError(err) 73 } 74 continue 75 } 76 // 3. Watch fr changes 77 var err error 78 watch := api.state.WatchForEnvironConfigChanges() 79 // Consume the initial event. 80 if _, ok := <-watch.Changes(); ok { 81 results[i].NotifyWatcherId = api.resources.Register(watch) 82 } else { 83 err = watcher.MustErr(watch) 84 } 85 results[i].Error = common.ServerError(err) 86 } 87 return params.NotifyWatchResults{Results: results}, nil 88 } 89 90 // AuthorisedKeys reports the authorised ssh keys for the specified machines. 91 // The current implementation relies on global authorised keys being stored in the environment config. 92 // This will change as new user management and authorisation functionality is added. 93 func (api *KeyUpdaterAPI) AuthorisedKeys(arg params.Entities) (params.StringsResults, error) { 94 if len(arg.Entities) == 0 { 95 return params.StringsResults{}, nil 96 } 97 results := make([]params.StringsResult, len(arg.Entities)) 98 99 // For now, authorised keys are global, common to all machines. 100 var keys []string 101 config, configErr := api.state.EnvironConfig() 102 if configErr == nil { 103 keys = ssh.SplitAuthorisedKeys(config.AuthorizedKeys()) 104 } 105 106 canRead, err := api.getCanRead() 107 if err != nil { 108 return params.StringsResults{}, err 109 } 110 for i, entity := range arg.Entities { 111 // 1. Check permissions 112 if !canRead(entity.Tag) { 113 results[i].Error = common.ServerError(common.ErrPerm) 114 continue 115 } 116 // 2. Check entity exists 117 if _, err := api.state.FindEntity(entity.Tag); err != nil { 118 if errors.IsNotFound(err) { 119 results[i].Error = common.ServerError(common.ErrPerm) 120 } else { 121 results[i].Error = common.ServerError(err) 122 } 123 continue 124 } 125 // 3. Get keys 126 var err error 127 if configErr == nil { 128 results[i].Result = keys 129 } else { 130 err = configErr 131 } 132 results[i].Error = common.ServerError(err) 133 } 134 return params.StringsResults{Results: results}, nil 135 }