github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/apiserver/common/reboot.go (about) 1 // Copyright 2014 Canonical Ltd. 2 // Copyright 2014 Cloudbase Solutions 3 // Licensed under the AGPLv3, see LICENCE file for details. 4 5 package common 6 7 import ( 8 "github.com/juju/errors" 9 "github.com/juju/names/v5" 10 11 apiservererrors "github.com/juju/juju/apiserver/errors" 12 "github.com/juju/juju/rpc/params" 13 "github.com/juju/juju/state" 14 ) 15 16 // RebootRequester implements the RequestReboot API method 17 type RebootRequester struct { 18 st state.EntityFinder 19 auth GetAuthFunc 20 } 21 22 func NewRebootRequester(st state.EntityFinder, auth GetAuthFunc) *RebootRequester { 23 return &RebootRequester{ 24 st: st, 25 auth: auth, 26 } 27 } 28 29 func (r *RebootRequester) oneRequest(tag names.Tag) error { 30 entity0, err := r.st.FindEntity(tag) 31 if err != nil { 32 return err 33 } 34 entity, ok := entity0.(state.RebootFlagSetter) 35 if !ok { 36 return apiservererrors.NotSupportedError(tag, "request reboot") 37 } 38 return entity.SetRebootFlag(true) 39 } 40 41 // RequestReboot sets the reboot flag on the provided machines 42 func (r *RebootRequester) RequestReboot(args params.Entities) (params.ErrorResults, error) { 43 result := params.ErrorResults{ 44 Results: make([]params.ErrorResult, len(args.Entities)), 45 } 46 if len(args.Entities) == 0 { 47 return result, nil 48 } 49 auth, err := r.auth() 50 if err != nil { 51 return params.ErrorResults{}, errors.Trace(err) 52 } 53 for i, entity := range args.Entities { 54 tag, err := names.ParseTag(entity.Tag) 55 if err != nil { 56 result.Results[i].Error = apiservererrors.ServerError(apiservererrors.ErrPerm) 57 continue 58 } 59 err = apiservererrors.ErrPerm 60 if auth(tag) { 61 err = r.oneRequest(tag) 62 } 63 result.Results[i].Error = apiservererrors.ServerError(err) 64 } 65 return result, nil 66 } 67 68 // RebootActionGetter implements the GetRebootAction API method 69 type RebootActionGetter struct { 70 st state.EntityFinder 71 auth GetAuthFunc 72 } 73 74 func NewRebootActionGetter(st state.EntityFinder, auth GetAuthFunc) *RebootActionGetter { 75 return &RebootActionGetter{ 76 st: st, 77 auth: auth, 78 } 79 } 80 81 func (r *RebootActionGetter) getOneAction(tag names.Tag) (params.RebootAction, error) { 82 entity0, err := r.st.FindEntity(tag) 83 if err != nil { 84 return "", err 85 } 86 entity, ok := entity0.(state.RebootActionGetter) 87 if !ok { 88 return "", apiservererrors.NotSupportedError(tag, "request reboot") 89 } 90 rAction, err := entity.ShouldRebootOrShutdown() 91 if err != nil { 92 return params.ShouldDoNothing, err 93 } 94 return params.RebootAction(rAction), nil 95 } 96 97 // GetRebootAction returns the action a machine agent should take. 98 // If a reboot flag is set on the machine, then that machine is 99 // expected to reboot (params.ShouldReboot). 100 // a reboot flag set on the machine parent or grandparent, will 101 // cause the machine to shutdown (params.ShouldShutdown). 102 // If no reboot flag is set, the machine should do nothing (params.ShouldDoNothing). 103 func (r *RebootActionGetter) GetRebootAction(args params.Entities) (params.RebootActionResults, error) { 104 result := params.RebootActionResults{ 105 Results: make([]params.RebootActionResult, len(args.Entities)), 106 } 107 if len(args.Entities) == 0 { 108 return result, nil 109 } 110 auth, err := r.auth() 111 if err != nil { 112 return params.RebootActionResults{}, errors.Trace(err) 113 } 114 for i, entity := range args.Entities { 115 tag, err := names.ParseTag(entity.Tag) 116 if err != nil { 117 result.Results[i].Error = apiservererrors.ServerError(apiservererrors.ErrPerm) 118 continue 119 } 120 err = apiservererrors.ErrPerm 121 if auth(tag) { 122 result.Results[i].Result, err = r.getOneAction(tag) 123 } 124 result.Results[i].Error = apiservererrors.ServerError(err) 125 } 126 return result, nil 127 } 128 129 // RebootFlagClearer implements the ClearReboot API call 130 type RebootFlagClearer struct { 131 st state.EntityFinder 132 auth GetAuthFunc 133 } 134 135 func NewRebootFlagClearer(st state.EntityFinder, auth GetAuthFunc) *RebootFlagClearer { 136 return &RebootFlagClearer{ 137 st: st, 138 auth: auth, 139 } 140 } 141 142 func (r *RebootFlagClearer) clearOneFlag(tag names.Tag) error { 143 entity0, err := r.st.FindEntity(tag) 144 if err != nil { 145 return err 146 } 147 entity, ok := entity0.(state.RebootFlagSetter) 148 if !ok { 149 return apiservererrors.NotSupportedError(tag, "clear reboot flag") 150 } 151 return entity.SetRebootFlag(false) 152 } 153 154 // ClearReboot will clear the reboot flag on provided machines, if it exists. 155 func (r *RebootFlagClearer) ClearReboot(args params.Entities) (params.ErrorResults, error) { 156 result := params.ErrorResults{ 157 Results: make([]params.ErrorResult, len(args.Entities)), 158 } 159 if len(args.Entities) == 0 { 160 return result, nil 161 } 162 auth, err := r.auth() 163 if err != nil { 164 return params.ErrorResults{}, errors.Trace(err) 165 } 166 for i, entity := range args.Entities { 167 tag, err := names.ParseTag(entity.Tag) 168 if err != nil { 169 result.Results[i].Error = apiservererrors.ServerError(apiservererrors.ErrPerm) 170 continue 171 } 172 err = apiservererrors.ErrPerm 173 if auth(tag) { 174 err = r.clearOneFlag(tag) 175 } 176 result.Results[i].Error = apiservererrors.ServerError(err) 177 } 178 return result, nil 179 }