github.com/wallyworld/juju@v0.0.0-20161013125918-6cf1bc9d917a/apiserver/common/reboot.go (about)

     1  // Copyright 2014 Canonical Ltd.
     2  // Copyright 2014 Cloudbase Solutions
     3  // Licensed under the AGPLv3, see LICENCE file for details.
     4  
     5  package common
     6  
     7  import (
     8  	"github.com/juju/errors"
     9  	"gopkg.in/juju/names.v2"
    10  
    11  	"github.com/juju/juju/apiserver/params"
    12  	"github.com/juju/juju/state"
    13  )
    14  
// RebootRequester implements the RequestReboot API method.
// It is built with NewRebootRequester and embeds no state of its own
// beyond the two collaborators below.
type RebootRequester struct {
	// st is used to look up the entity a tag refers to (via FindEntity).
	st state.EntityFinder
	// auth is called once per API request to obtain the function that
	// decides whether the caller may operate on a given tag.
	auth GetAuthFunc
}
    20  
    21  func NewRebootRequester(st state.EntityFinder, auth GetAuthFunc) *RebootRequester {
    22  	return &RebootRequester{
    23  		st:   st,
    24  		auth: auth,
    25  	}
    26  }
    27  
    28  func (r *RebootRequester) oneRequest(tag names.Tag) error {
    29  	entity0, err := r.st.FindEntity(tag)
    30  	if err != nil {
    31  		return err
    32  	}
    33  	entity, ok := entity0.(state.RebootFlagSetter)
    34  	if !ok {
    35  		return NotSupportedError(tag, "request reboot")
    36  	}
    37  	return entity.SetRebootFlag(true)
    38  }
    39  
    40  // RequestReboot sets the reboot flag on the provided machines
    41  func (r *RebootRequester) RequestReboot(args params.Entities) (params.ErrorResults, error) {
    42  	result := params.ErrorResults{
    43  		Results: make([]params.ErrorResult, len(args.Entities)),
    44  	}
    45  	if len(args.Entities) == 0 {
    46  		return result, nil
    47  	}
    48  	auth, err := r.auth()
    49  	if err != nil {
    50  		return params.ErrorResults{}, errors.Trace(err)
    51  	}
    52  	for i, entity := range args.Entities {
    53  		tag, err := names.ParseTag(entity.Tag)
    54  		if err != nil {
    55  			result.Results[i].Error = ServerError(ErrPerm)
    56  			continue
    57  		}
    58  		err = ErrPerm
    59  		if auth(tag) {
    60  			err = r.oneRequest(tag)
    61  		}
    62  		result.Results[i].Error = ServerError(err)
    63  	}
    64  	return result, nil
    65  }
    66  
// RebootActionGetter implements the GetRebootAction API method.
type RebootActionGetter struct {
	// st is used to look up the entity a tag refers to (via FindEntity).
	st state.EntityFinder
	// auth is called once per API request to obtain the function that
	// decides whether the caller may operate on a given tag.
	auth GetAuthFunc
}
    72  
    73  func NewRebootActionGetter(st state.EntityFinder, auth GetAuthFunc) *RebootActionGetter {
    74  	return &RebootActionGetter{
    75  		st:   st,
    76  		auth: auth,
    77  	}
    78  }
    79  
    80  func (r *RebootActionGetter) getOneAction(tag names.Tag) (params.RebootAction, error) {
    81  	entity0, err := r.st.FindEntity(tag)
    82  	if err != nil {
    83  		return "", err
    84  	}
    85  	entity, ok := entity0.(state.RebootActionGetter)
    86  	if !ok {
    87  		return "", NotSupportedError(tag, "request reboot")
    88  	}
    89  	rAction, err := entity.ShouldRebootOrShutdown()
    90  	if err != nil {
    91  		return params.ShouldDoNothing, err
    92  	}
    93  	return params.RebootAction(rAction), nil
    94  }
    95  
    96  // GetRebootAction returns the action a machine agent should take.
    97  // If a reboot flag is set on the machine, then that machine is
    98  // expected to reboot (params.ShouldReboot).
    99  // a reboot flag set on the machine parent or grandparent, will
   100  // cause the machine to shutdown (params.ShouldShutdown).
   101  // If no reboot flag is set, the machine should do nothing (params.ShouldDoNothing).
   102  func (r *RebootActionGetter) GetRebootAction(args params.Entities) (params.RebootActionResults, error) {
   103  	result := params.RebootActionResults{
   104  		Results: make([]params.RebootActionResult, len(args.Entities)),
   105  	}
   106  	if len(args.Entities) == 0 {
   107  		return result, nil
   108  	}
   109  	auth, err := r.auth()
   110  	if err != nil {
   111  		return params.RebootActionResults{}, errors.Trace(err)
   112  	}
   113  	for i, entity := range args.Entities {
   114  		tag, err := names.ParseTag(entity.Tag)
   115  		if err != nil {
   116  			result.Results[i].Error = ServerError(ErrPerm)
   117  			continue
   118  		}
   119  		err = ErrPerm
   120  		if auth(tag) {
   121  			result.Results[i].Result, err = r.getOneAction(tag)
   122  		}
   123  		result.Results[i].Error = ServerError(err)
   124  	}
   125  	return result, nil
   126  }
   127  
// RebootFlagClearer implements the ClearReboot API call.
type RebootFlagClearer struct {
	// st is used to look up the entity a tag refers to (via FindEntity).
	st state.EntityFinder
	// auth is called once per API request to obtain the function that
	// decides whether the caller may operate on a given tag.
	auth GetAuthFunc
}
   133  
   134  func NewRebootFlagClearer(st state.EntityFinder, auth GetAuthFunc) *RebootFlagClearer {
   135  	return &RebootFlagClearer{
   136  		st:   st,
   137  		auth: auth,
   138  	}
   139  }
   140  
   141  func (r *RebootFlagClearer) clearOneFlag(tag names.Tag) error {
   142  	entity0, err := r.st.FindEntity(tag)
   143  	if err != nil {
   144  		return err
   145  	}
   146  	entity, ok := entity0.(state.RebootFlagSetter)
   147  	if !ok {
   148  		return NotSupportedError(tag, "clear reboot flag")
   149  	}
   150  	return entity.SetRebootFlag(false)
   151  }
   152  
   153  // ClearReboot will clear the reboot flag on provided machines, if it exists.
   154  func (r *RebootFlagClearer) ClearReboot(args params.Entities) (params.ErrorResults, error) {
   155  	result := params.ErrorResults{
   156  		Results: make([]params.ErrorResult, len(args.Entities)),
   157  	}
   158  	if len(args.Entities) == 0 {
   159  		return result, nil
   160  	}
   161  	auth, err := r.auth()
   162  	if err != nil {
   163  		return params.ErrorResults{}, errors.Trace(err)
   164  	}
   165  	for i, entity := range args.Entities {
   166  		tag, err := names.ParseTag(entity.Tag)
   167  		if err != nil {
   168  			result.Results[i].Error = ServerError(ErrPerm)
   169  			continue
   170  		}
   171  		err = ErrPerm
   172  		if auth(tag) {
   173  			err = r.clearOneFlag(tag)
   174  		}
   175  		result.Results[i].Error = ServerError(err)
   176  	}
   177  	return result, nil
   178  }