github.com/makyo/juju@v0.0.0-20160425123129-2608902037e9/apiserver/common/reboot.go (about)

     1  package common
     2  
     3  import (
     4  	"github.com/juju/errors"
     5  	"github.com/juju/names"
     6  
     7  	"github.com/juju/juju/apiserver/params"
     8  	"github.com/juju/juju/state"
     9  )
    10  
    11  // RebootRequester implements the RequestReboot API method
    12  type RebootRequester struct {
    13  	st   state.EntityFinder
    14  	auth GetAuthFunc
    15  }
    16  
    17  func NewRebootRequester(st state.EntityFinder, auth GetAuthFunc) *RebootRequester {
    18  	return &RebootRequester{
    19  		st:   st,
    20  		auth: auth,
    21  	}
    22  }
    23  
    24  func (r *RebootRequester) oneRequest(tag names.Tag) error {
    25  	entity0, err := r.st.FindEntity(tag)
    26  	if err != nil {
    27  		return err
    28  	}
    29  	entity, ok := entity0.(state.RebootFlagSetter)
    30  	if !ok {
    31  		return NotSupportedError(tag, "request reboot")
    32  	}
    33  	return entity.SetRebootFlag(true)
    34  }
    35  
    36  // RequestReboot sets the reboot flag on the provided machines
    37  func (r *RebootRequester) RequestReboot(args params.Entities) (params.ErrorResults, error) {
    38  	result := params.ErrorResults{
    39  		Results: make([]params.ErrorResult, len(args.Entities)),
    40  	}
    41  	if len(args.Entities) == 0 {
    42  		return result, nil
    43  	}
    44  	auth, err := r.auth()
    45  	if err != nil {
    46  		return params.ErrorResults{}, errors.Trace(err)
    47  	}
    48  	for i, entity := range args.Entities {
    49  		tag, err := names.ParseTag(entity.Tag)
    50  		if err != nil {
    51  			result.Results[i].Error = ServerError(ErrPerm)
    52  			continue
    53  		}
    54  		err = ErrPerm
    55  		if auth(tag) {
    56  			err = r.oneRequest(tag)
    57  		}
    58  		result.Results[i].Error = ServerError(err)
    59  	}
    60  	return result, nil
    61  }
    62  
    63  // RebootActionGetter implements the GetRebootAction API method
    64  type RebootActionGetter struct {
    65  	st   state.EntityFinder
    66  	auth GetAuthFunc
    67  }
    68  
    69  func NewRebootActionGetter(st state.EntityFinder, auth GetAuthFunc) *RebootActionGetter {
    70  	return &RebootActionGetter{
    71  		st:   st,
    72  		auth: auth,
    73  	}
    74  }
    75  
    76  func (r *RebootActionGetter) getOneAction(tag names.Tag) (params.RebootAction, error) {
    77  	entity0, err := r.st.FindEntity(tag)
    78  	if err != nil {
    79  		return "", err
    80  	}
    81  	entity, ok := entity0.(state.RebootActionGetter)
    82  	if !ok {
    83  		return "", NotSupportedError(tag, "request reboot")
    84  	}
    85  	rAction, err := entity.ShouldRebootOrShutdown()
    86  	if err != nil {
    87  		return params.ShouldDoNothing, err
    88  	}
    89  	return params.RebootAction(rAction), nil
    90  }
    91  
    92  // GetRebootAction returns the action a machine agent should take.
    93  // If a reboot flag is set on the machine, then that machine is
    94  // expected to reboot (params.ShouldReboot).
    95  // a reboot flag set on the machine parent or grandparent, will
    96  // cause the machine to shutdown (params.ShouldShutdown).
    97  // If no reboot flag is set, the machine should do nothing (params.ShouldDoNothing).
    98  func (r *RebootActionGetter) GetRebootAction(args params.Entities) (params.RebootActionResults, error) {
    99  	result := params.RebootActionResults{
   100  		Results: make([]params.RebootActionResult, len(args.Entities)),
   101  	}
   102  	if len(args.Entities) == 0 {
   103  		return result, nil
   104  	}
   105  	auth, err := r.auth()
   106  	if err != nil {
   107  		return params.RebootActionResults{}, errors.Trace(err)
   108  	}
   109  	for i, entity := range args.Entities {
   110  		tag, err := names.ParseTag(entity.Tag)
   111  		if err != nil {
   112  			result.Results[i].Error = ServerError(ErrPerm)
   113  			continue
   114  		}
   115  		err = ErrPerm
   116  		if auth(tag) {
   117  			result.Results[i].Result, err = r.getOneAction(tag)
   118  		}
   119  		result.Results[i].Error = ServerError(err)
   120  	}
   121  	return result, nil
   122  }
   123  
   124  // RebootFlagClearer implements the ClearReboot API call
   125  type RebootFlagClearer struct {
   126  	st   state.EntityFinder
   127  	auth GetAuthFunc
   128  }
   129  
   130  func NewRebootFlagClearer(st state.EntityFinder, auth GetAuthFunc) *RebootFlagClearer {
   131  	return &RebootFlagClearer{
   132  		st:   st,
   133  		auth: auth,
   134  	}
   135  }
   136  
   137  func (r *RebootFlagClearer) clearOneFlag(tag names.Tag) error {
   138  	entity0, err := r.st.FindEntity(tag)
   139  	if err != nil {
   140  		return err
   141  	}
   142  	entity, ok := entity0.(state.RebootFlagSetter)
   143  	if !ok {
   144  		return NotSupportedError(tag, "clear reboot flag")
   145  	}
   146  	return entity.SetRebootFlag(false)
   147  }
   148  
   149  // ClearReboot will clear the reboot flag on provided machines, if it exists.
   150  func (r *RebootFlagClearer) ClearReboot(args params.Entities) (params.ErrorResults, error) {
   151  	result := params.ErrorResults{
   152  		Results: make([]params.ErrorResult, len(args.Entities)),
   153  	}
   154  	if len(args.Entities) == 0 {
   155  		return result, nil
   156  	}
   157  	auth, err := r.auth()
   158  	if err != nil {
   159  		return params.ErrorResults{}, errors.Trace(err)
   160  	}
   161  	for i, entity := range args.Entities {
   162  		tag, err := names.ParseTag(entity.Tag)
   163  		if err != nil {
   164  			result.Results[i].Error = ServerError(ErrPerm)
   165  			continue
   166  		}
   167  		err = ErrPerm
   168  		if auth(tag) {
   169  			err = r.clearOneFlag(tag)
   170  		}
   171  		result.Results[i].Error = ServerError(err)
   172  	}
   173  	return result, nil
   174  }