github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/api/agent/reboot/reboot.go (about) 1 // Copyright 2014 Cloudbase Solutions 2 // Copyright 2014 Canonical Ltd. 3 // Licensed under the AGPLv3, see LICENCE file for details. 4 5 package reboot 6 7 import ( 8 "github.com/juju/errors" 9 "github.com/juju/names/v5" 10 11 "github.com/juju/juju/api" 12 "github.com/juju/juju/api/base" 13 apiwatcher "github.com/juju/juju/api/watcher" 14 "github.com/juju/juju/core/watcher" 15 "github.com/juju/juju/rpc/params" 16 ) 17 18 // State provides access to an reboot worker's view of the state. 19 // NOTE: This is defined as an interface due to PPC64 bug #1533469 - 20 // if it were a type build errors happen (due to a linker bug). 21 type State interface { 22 // WatchForRebootEvent returns a watcher.NotifyWatcher that 23 // reacts to reboot flag changes. 24 WatchForRebootEvent() (watcher.NotifyWatcher, error) 25 26 // RequestReboot sets the reboot flag for the calling machine. 27 RequestReboot() error 28 29 // ClearReboot clears the reboot flag for the calling machine. 30 ClearReboot() error 31 32 // GetRebootAction returns the reboot action for the calling machine. 33 GetRebootAction() (params.RebootAction, error) 34 } 35 36 var _ State = (*state)(nil) 37 38 // state implements State. 39 type state struct { 40 machineTag names.Tag 41 facade base.FacadeCaller 42 } 43 44 // NewState returns a version of the state that provides functionality 45 // required by the reboot worker. 46 func NewState(caller base.APICaller, machineTag names.MachineTag) State { 47 return &state{ 48 facade: base.NewFacadeCaller(caller, "Reboot"), 49 machineTag: machineTag, 50 } 51 } 52 53 // ConnectionReboot returns access to the Reboot API 54 func NewFromConnection(c api.Connection) (State, error) { 55 switch tag := c.AuthTag().(type) { 56 case names.MachineTag: 57 return NewState(c, tag), nil 58 default: 59 return nil, errors.Errorf("expected names.MachineTag, got %T", tag) 60 } 61 } 62 63 // WatchForRebootEvent implements State.WatchForRebootEvent 64 func (st *state) WatchForRebootEvent() (watcher.NotifyWatcher, error) { 65 var result params.NotifyWatchResult 66 67 if err := st.facade.FacadeCall("WatchForRebootEvent", nil, &result); err != nil { 68 return nil, err 69 } 70 if result.Error != nil { 71 return nil, result.Error 72 } 73 74 w := apiwatcher.NewNotifyWatcher(st.facade.RawAPICaller(), result) 75 return w, nil 76 } 77 78 // RequestReboot implements State.RequestReboot 79 func (st *state) RequestReboot() error { 80 var results params.ErrorResults 81 args := params.Entities{ 82 Entities: []params.Entity{{Tag: st.machineTag.String()}}, 83 } 84 85 err := st.facade.FacadeCall("RequestReboot", args, &results) 86 if err != nil { 87 return errors.Trace(err) 88 } 89 if len(results.Results) != 1 { 90 return errors.Errorf("expected 1 result, got %d", len(results.Results)) 91 } 92 93 if results.Results[0].Error != nil { 94 return errors.Trace(results.Results[0].Error) 95 } 96 return nil 97 } 98 99 // ClearReboot implements State.ClearReboot 100 func (st *state) ClearReboot() error { 101 var results params.ErrorResults 102 args := params.Entities{ 103 Entities: []params.Entity{{Tag: st.machineTag.String()}}, 104 } 105 106 err := st.facade.FacadeCall("ClearReboot", args, &results) 107 if err != nil { 108 return errors.Trace(err) 109 } 110 111 if len(results.Results) != 1 { 112 return errors.Errorf("expected 1 result, got %d", len(results.Results)) 113 } 114 115 if results.Results[0].Error != nil { 116 return errors.Trace(results.Results[0].Error) 117 } 118 119 return nil 120 } 121 122 // GetRebootAction implements State.GetRebootAction 123 func (st *state) GetRebootAction() (params.RebootAction, error) { 124 var results params.RebootActionResults 125 args := params.Entities{ 126 Entities: []params.Entity{{Tag: st.machineTag.String()}}, 127 } 128 129 err := st.facade.FacadeCall("GetRebootAction", args, &results) 130 if err != nil { 131 return params.ShouldDoNothing, err 132 } 133 if len(results.Results) != 1 { 134 return params.ShouldDoNothing, errors.Errorf("expected 1 result, got %d", len(results.Results)) 135 } 136 137 if results.Results[0].Error != nil { 138 return params.ShouldDoNothing, errors.Trace(results.Results[0].Error) 139 } 140 141 return results.Results[0].Result, nil 142 }