github.com/altoros/juju-vmware@v0.0.0-20150312064031-f19ae857ccca/apiserver/common/reboot.go (about)

     1  package common
     2  
     3  import (
     4  	"github.com/juju/errors"
     5  	"github.com/juju/juju/apiserver/params"
     6  	"github.com/juju/juju/state"
     7  	"github.com/juju/names"
     8  )
     9  
    10  // RebootRequester implements the RequestReboot API method
    11  type RebootRequester struct {
    12  	st   state.EntityFinder
    13  	auth GetAuthFunc
    14  }
    15  
    16  func NewRebootRequester(st state.EntityFinder, auth GetAuthFunc) *RebootRequester {
    17  	return &RebootRequester{
    18  		st:   st,
    19  		auth: auth,
    20  	}
    21  }
    22  
    23  func (r *RebootRequester) oneRequest(tag names.Tag) error {
    24  	entity0, err := r.st.FindEntity(tag)
    25  	if err != nil {
    26  		return err
    27  	}
    28  	entity, ok := entity0.(state.RebootFlagSetter)
    29  	if !ok {
    30  		return NotSupportedError(tag, "request reboot")
    31  	}
    32  	return entity.SetRebootFlag(true)
    33  }
    34  
    35  // RequestReboot sets the reboot flag on the provided machines
    36  func (r *RebootRequester) RequestReboot(args params.Entities) (params.ErrorResults, error) {
    37  	result := params.ErrorResults{
    38  		Results: make([]params.ErrorResult, len(args.Entities)),
    39  	}
    40  	if len(args.Entities) == 0 {
    41  		return result, nil
    42  	}
    43  	auth, err := r.auth()
    44  	if err != nil {
    45  		return params.ErrorResults{}, errors.Trace(err)
    46  	}
    47  	for i, entity := range args.Entities {
    48  		tag, err := names.ParseTag(entity.Tag)
    49  		if err != nil {
    50  			result.Results[i].Error = ServerError(ErrPerm)
    51  			continue
    52  		}
    53  		err = ErrPerm
    54  		if auth(tag) {
    55  			err = r.oneRequest(tag)
    56  		}
    57  		result.Results[i].Error = ServerError(err)
    58  	}
    59  	return result, nil
    60  }
    61  
    62  // RebootActionGetter implements the GetRebootAction API method
    63  type RebootActionGetter struct {
    64  	st   state.EntityFinder
    65  	auth GetAuthFunc
    66  }
    67  
    68  func NewRebootActionGetter(st state.EntityFinder, auth GetAuthFunc) *RebootActionGetter {
    69  	return &RebootActionGetter{
    70  		st:   st,
    71  		auth: auth,
    72  	}
    73  }
    74  
    75  func (r *RebootActionGetter) getOneAction(tag names.Tag) (params.RebootAction, error) {
    76  	entity0, err := r.st.FindEntity(tag)
    77  	if err != nil {
    78  		return "", err
    79  	}
    80  	entity, ok := entity0.(state.RebootActionGetter)
    81  	if !ok {
    82  		return "", NotSupportedError(tag, "request reboot")
    83  	}
    84  	rAction, err := entity.ShouldRebootOrShutdown()
    85  	if err != nil {
    86  		return params.ShouldDoNothing, err
    87  	}
    88  	return params.RebootAction(rAction), nil
    89  }
    90  
    91  // GetRebootAction returns the action a machine agent should take.
    92  // If a reboot flag is set on the machine, then that machine is
    93  // expected to reboot (params.ShouldReboot).
    94  // a reboot flag set on the machine parent or grandparent, will
    95  // cause the machine to shutdown (params.ShouldShutdown).
    96  // If no reboot flag is set, the machine should do nothing (params.ShouldDoNothing).
    97  func (r *RebootActionGetter) GetRebootAction(args params.Entities) (params.RebootActionResults, error) {
    98  	result := params.RebootActionResults{
    99  		Results: make([]params.RebootActionResult, len(args.Entities)),
   100  	}
   101  	if len(args.Entities) == 0 {
   102  		return result, nil
   103  	}
   104  	auth, err := r.auth()
   105  	if err != nil {
   106  		return params.RebootActionResults{}, errors.Trace(err)
   107  	}
   108  	for i, entity := range args.Entities {
   109  		tag, err := names.ParseTag(entity.Tag)
   110  		if err != nil {
   111  			result.Results[i].Error = ServerError(ErrPerm)
   112  			continue
   113  		}
   114  		err = ErrPerm
   115  		if auth(tag) {
   116  			result.Results[i].Result, err = r.getOneAction(tag)
   117  		}
   118  		result.Results[i].Error = ServerError(err)
   119  	}
   120  	return result, nil
   121  }
   122  
   123  // RebootFlagClearer implements the ClearReboot API call
   124  type RebootFlagClearer struct {
   125  	st   state.EntityFinder
   126  	auth GetAuthFunc
   127  }
   128  
   129  func NewRebootFlagClearer(st state.EntityFinder, auth GetAuthFunc) *RebootFlagClearer {
   130  	return &RebootFlagClearer{
   131  		st:   st,
   132  		auth: auth,
   133  	}
   134  }
   135  
   136  func (r *RebootFlagClearer) clearOneFlag(tag names.Tag) error {
   137  	entity0, err := r.st.FindEntity(tag)
   138  	if err != nil {
   139  		return err
   140  	}
   141  	entity, ok := entity0.(state.RebootFlagSetter)
   142  	if !ok {
   143  		return NotSupportedError(tag, "clear reboot flag")
   144  	}
   145  	return entity.SetRebootFlag(false)
   146  }
   147  
   148  // ClearReboot will clear the reboot flag on provided machines, if it exists.
   149  func (r *RebootFlagClearer) ClearReboot(args params.Entities) (params.ErrorResults, error) {
   150  	result := params.ErrorResults{
   151  		Results: make([]params.ErrorResult, len(args.Entities)),
   152  	}
   153  	if len(args.Entities) == 0 {
   154  		return result, nil
   155  	}
   156  	auth, err := r.auth()
   157  	if err != nil {
   158  		return params.ErrorResults{}, errors.Trace(err)
   159  	}
   160  	for i, entity := range args.Entities {
   161  		tag, err := names.ParseTag(entity.Tag)
   162  		if err != nil {
   163  			result.Results[i].Error = ServerError(ErrPerm)
   164  			continue
   165  		}
   166  		err = ErrPerm
   167  		if auth(tag) {
   168  			err = r.clearOneFlag(tag)
   169  		}
   170  		result.Results[i].Error = ServerError(err)
   171  	}
   172  	return result, nil
   173  }