github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/apiserver/common/reboot.go (about)

     1  // Copyright 2014 Canonical Ltd.
     2  // Copyright 2014 Cloudbase Solutions
     3  // Licensed under the AGPLv3, see LICENCE file for details.
     4  
     5  package common
     6  
     7  import (
     8  	"github.com/juju/errors"
     9  	"github.com/juju/names/v5"
    10  
    11  	apiservererrors "github.com/juju/juju/apiserver/errors"
    12  	"github.com/juju/juju/rpc/params"
    13  	"github.com/juju/juju/state"
    14  )
    15  
    16  // RebootRequester implements the RequestReboot API method
    17  type RebootRequester struct {
    18  	st   state.EntityFinder
    19  	auth GetAuthFunc
    20  }
    21  
    22  func NewRebootRequester(st state.EntityFinder, auth GetAuthFunc) *RebootRequester {
    23  	return &RebootRequester{
    24  		st:   st,
    25  		auth: auth,
    26  	}
    27  }
    28  
    29  func (r *RebootRequester) oneRequest(tag names.Tag) error {
    30  	entity0, err := r.st.FindEntity(tag)
    31  	if err != nil {
    32  		return err
    33  	}
    34  	entity, ok := entity0.(state.RebootFlagSetter)
    35  	if !ok {
    36  		return apiservererrors.NotSupportedError(tag, "request reboot")
    37  	}
    38  	return entity.SetRebootFlag(true)
    39  }
    40  
    41  // RequestReboot sets the reboot flag on the provided machines
    42  func (r *RebootRequester) RequestReboot(args params.Entities) (params.ErrorResults, error) {
    43  	result := params.ErrorResults{
    44  		Results: make([]params.ErrorResult, len(args.Entities)),
    45  	}
    46  	if len(args.Entities) == 0 {
    47  		return result, nil
    48  	}
    49  	auth, err := r.auth()
    50  	if err != nil {
    51  		return params.ErrorResults{}, errors.Trace(err)
    52  	}
    53  	for i, entity := range args.Entities {
    54  		tag, err := names.ParseTag(entity.Tag)
    55  		if err != nil {
    56  			result.Results[i].Error = apiservererrors.ServerError(apiservererrors.ErrPerm)
    57  			continue
    58  		}
    59  		err = apiservererrors.ErrPerm
    60  		if auth(tag) {
    61  			err = r.oneRequest(tag)
    62  		}
    63  		result.Results[i].Error = apiservererrors.ServerError(err)
    64  	}
    65  	return result, nil
    66  }
    67  
    68  // RebootActionGetter implements the GetRebootAction API method
    69  type RebootActionGetter struct {
    70  	st   state.EntityFinder
    71  	auth GetAuthFunc
    72  }
    73  
    74  func NewRebootActionGetter(st state.EntityFinder, auth GetAuthFunc) *RebootActionGetter {
    75  	return &RebootActionGetter{
    76  		st:   st,
    77  		auth: auth,
    78  	}
    79  }
    80  
    81  func (r *RebootActionGetter) getOneAction(tag names.Tag) (params.RebootAction, error) {
    82  	entity0, err := r.st.FindEntity(tag)
    83  	if err != nil {
    84  		return "", err
    85  	}
    86  	entity, ok := entity0.(state.RebootActionGetter)
    87  	if !ok {
    88  		return "", apiservererrors.NotSupportedError(tag, "request reboot")
    89  	}
    90  	rAction, err := entity.ShouldRebootOrShutdown()
    91  	if err != nil {
    92  		return params.ShouldDoNothing, err
    93  	}
    94  	return params.RebootAction(rAction), nil
    95  }
    96  
    97  // GetRebootAction returns the action a machine agent should take.
    98  // If a reboot flag is set on the machine, then that machine is
    99  // expected to reboot (params.ShouldReboot).
   100  // a reboot flag set on the machine parent or grandparent, will
   101  // cause the machine to shutdown (params.ShouldShutdown).
   102  // If no reboot flag is set, the machine should do nothing (params.ShouldDoNothing).
   103  func (r *RebootActionGetter) GetRebootAction(args params.Entities) (params.RebootActionResults, error) {
   104  	result := params.RebootActionResults{
   105  		Results: make([]params.RebootActionResult, len(args.Entities)),
   106  	}
   107  	if len(args.Entities) == 0 {
   108  		return result, nil
   109  	}
   110  	auth, err := r.auth()
   111  	if err != nil {
   112  		return params.RebootActionResults{}, errors.Trace(err)
   113  	}
   114  	for i, entity := range args.Entities {
   115  		tag, err := names.ParseTag(entity.Tag)
   116  		if err != nil {
   117  			result.Results[i].Error = apiservererrors.ServerError(apiservererrors.ErrPerm)
   118  			continue
   119  		}
   120  		err = apiservererrors.ErrPerm
   121  		if auth(tag) {
   122  			result.Results[i].Result, err = r.getOneAction(tag)
   123  		}
   124  		result.Results[i].Error = apiservererrors.ServerError(err)
   125  	}
   126  	return result, nil
   127  }
   128  
   129  // RebootFlagClearer implements the ClearReboot API call
   130  type RebootFlagClearer struct {
   131  	st   state.EntityFinder
   132  	auth GetAuthFunc
   133  }
   134  
   135  func NewRebootFlagClearer(st state.EntityFinder, auth GetAuthFunc) *RebootFlagClearer {
   136  	return &RebootFlagClearer{
   137  		st:   st,
   138  		auth: auth,
   139  	}
   140  }
   141  
   142  func (r *RebootFlagClearer) clearOneFlag(tag names.Tag) error {
   143  	entity0, err := r.st.FindEntity(tag)
   144  	if err != nil {
   145  		return err
   146  	}
   147  	entity, ok := entity0.(state.RebootFlagSetter)
   148  	if !ok {
   149  		return apiservererrors.NotSupportedError(tag, "clear reboot flag")
   150  	}
   151  	return entity.SetRebootFlag(false)
   152  }
   153  
   154  // ClearReboot will clear the reboot flag on provided machines, if it exists.
   155  func (r *RebootFlagClearer) ClearReboot(args params.Entities) (params.ErrorResults, error) {
   156  	result := params.ErrorResults{
   157  		Results: make([]params.ErrorResult, len(args.Entities)),
   158  	}
   159  	if len(args.Entities) == 0 {
   160  		return result, nil
   161  	}
   162  	auth, err := r.auth()
   163  	if err != nil {
   164  		return params.ErrorResults{}, errors.Trace(err)
   165  	}
   166  	for i, entity := range args.Entities {
   167  		tag, err := names.ParseTag(entity.Tag)
   168  		if err != nil {
   169  			result.Results[i].Error = apiservererrors.ServerError(apiservererrors.ErrPerm)
   170  			continue
   171  		}
   172  		err = apiservererrors.ErrPerm
   173  		if auth(tag) {
   174  			err = r.clearOneFlag(tag)
   175  		}
   176  		result.Results[i].Error = apiservererrors.ServerError(err)
   177  	}
   178  	return result, nil
   179  }