github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/apiserver/common/upgradeseries.go (about)

     1  // Copyright 2018 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package common
     5  
     6  import (
     7  	"github.com/juju/errors"
     8  	"github.com/juju/loggo"
     9  	"github.com/juju/names/v5"
    10  
    11  	apiservererrors "github.com/juju/juju/apiserver/errors"
    12  	"github.com/juju/juju/apiserver/facade"
    13  	"github.com/juju/juju/core/model"
    14  	"github.com/juju/juju/core/status"
    15  	"github.com/juju/juju/rpc/params"
    16  	"github.com/juju/juju/state"
    17  )
    18  
    19  //go:generate go run go.uber.org/mock/mockgen -package mocks -destination mocks/upgradeseries.go github.com/juju/juju/apiserver/common UpgradeSeriesBackend,UpgradeSeriesMachine,UpgradeSeriesUnit
    20  
    21  type UpgradeSeriesBackend interface {
    22  	Machine(string) (UpgradeSeriesMachine, error)
    23  	Unit(string) (UpgradeSeriesUnit, error)
    24  }
    25  
    26  // UpgradeSeriesMachine describes machine-receiver state methods
    27  // for executing a series upgrade.
    28  type UpgradeSeriesMachine interface {
    29  	WatchUpgradeSeriesNotifications() (state.NotifyWatcher, error)
    30  	Units() ([]UpgradeSeriesUnit, error)
    31  	UpgradeSeriesStatus() (model.UpgradeSeriesStatus, error)
    32  	SetUpgradeSeriesStatus(model.UpgradeSeriesStatus, string) error
    33  	StartUpgradeSeriesUnitCompletion(string) error
    34  	UpgradeSeriesUnitStatuses() (map[string]state.UpgradeSeriesUnitStatus, error)
    35  	RemoveUpgradeSeriesLock() error
    36  	UpgradeSeriesTarget() (string, error)
    37  	Base() state.Base
    38  	UpdateMachineSeries(base state.Base) error
    39  	SetInstanceStatus(status.StatusInfo) error
    40  }
    41  
    42  // UpgradeSeriesUnit describes unit-receiver state methods
    43  // for executing a series upgrade.
    44  type UpgradeSeriesUnit interface {
    45  	Tag() names.Tag
    46  	AssignedMachineId() (string, error)
    47  	UpgradeSeriesStatus() (model.UpgradeSeriesStatus, string, error)
    48  	SetUpgradeSeriesStatus(model.UpgradeSeriesStatus, string) error
    49  }
    50  
    51  // UpgradeSeriesState implements the UpgradeSeriesBackend indirection
    52  // over state.State.
    53  type UpgradeSeriesState struct {
    54  	St *state.State
    55  }
    56  
    57  func (s UpgradeSeriesState) Machine(id string) (UpgradeSeriesMachine, error) {
    58  	m, err := s.St.Machine(id)
    59  	return &upgradeSeriesMachine{m}, err
    60  }
    61  
    62  func (s UpgradeSeriesState) Unit(id string) (UpgradeSeriesUnit, error) {
    63  	return s.St.Unit(id)
    64  }
    65  
    66  type upgradeSeriesMachine struct {
    67  	*state.Machine
    68  }
    69  
    70  // Units maintains the UpgradeSeriesMachine indirection by wrapping the call to
    71  // state.Machine.Units().
    72  func (m *upgradeSeriesMachine) Units() ([]UpgradeSeriesUnit, error) {
    73  	units, err := m.Machine.Units()
    74  	if err != nil {
    75  		return nil, errors.Trace(err)
    76  	}
    77  
    78  	wrapped := make([]UpgradeSeriesUnit, len(units))
    79  	for i, u := range units {
    80  		wrapped[i] = u
    81  	}
    82  	return wrapped, nil
    83  }
    84  
// UpgradeSeriesAPI serves upgrade-series operations on units and the
// machines they are assigned to.
type UpgradeSeriesAPI struct {
	// backend looks up machines and units; resources holds the watchers
	// registered by WatchUpgradeSeriesNotifications.
	backend   UpgradeSeriesBackend
	resources facade.Resources

	logger loggo.Logger

	// Authorization functions, all set by NewUpgradeSeriesAPI.
	// accessUnitOrMachine is AuthAny(accessUnit, AccessMachine) and guards
	// methods that accept either entity kind; accessUnit guards the
	// unit-only methods. NOTE(review): AccessMachine is exported and not
	// used in this file, presumably so embedding facades can reuse it —
	// confirm before unexporting.
	accessUnitOrMachine GetAuthFunc
	AccessMachine       GetAuthFunc
	accessUnit          GetAuthFunc
}
    95  
    96  // NewUpgradeSeriesAPI returns a new UpgradeSeriesAPI. Currently both
    97  // GetAuthFuncs can used to determine current permissions.
    98  func NewUpgradeSeriesAPI(
    99  	backend UpgradeSeriesBackend,
   100  	resources facade.Resources,
   101  	authorizer facade.Authorizer,
   102  	accessMachine GetAuthFunc,
   103  	accessUnit GetAuthFunc,
   104  	logger loggo.Logger,
   105  ) *UpgradeSeriesAPI {
   106  	logger.Tracef("NewUpgradeSeriesAPI called with %s", authorizer.GetAuthTag())
   107  	return &UpgradeSeriesAPI{
   108  		backend:             backend,
   109  		resources:           resources,
   110  		accessUnitOrMachine: AuthAny(accessUnit, accessMachine),
   111  		AccessMachine:       accessMachine,
   112  		accessUnit:          accessUnit,
   113  		logger:              logger,
   114  	}
   115  }
   116  
   117  // WatchUpgradeSeriesNotifications returns a NotifyWatcher for observing changes to upgrade series locks.
   118  func (u *UpgradeSeriesAPI) WatchUpgradeSeriesNotifications(args params.Entities) (params.NotifyWatchResults, error) {
   119  	u.logger.Tracef("Starting WatchUpgradeSeriesNotifications with %+v", args)
   120  	result := params.NotifyWatchResults{
   121  		Results: make([]params.NotifyWatchResult, len(args.Entities)),
   122  	}
   123  	canAccess, err := u.accessUnitOrMachine()
   124  	if err != nil {
   125  		return params.NotifyWatchResults{}, err
   126  	}
   127  	for i, entity := range args.Entities {
   128  		tag, err := names.ParseTag(entity.Tag)
   129  		if err != nil {
   130  			result.Results[i].Error = apiservererrors.ServerError(apiservererrors.ErrPerm)
   131  			continue
   132  		}
   133  
   134  		if !canAccess(tag) {
   135  			result.Results[i].Error = apiservererrors.ServerError(apiservererrors.ErrPerm)
   136  			continue
   137  		}
   138  		machine, err := u.GetMachine(tag)
   139  		if err != nil {
   140  			result.Results[i].Error = apiservererrors.ServerError(err)
   141  			continue
   142  		}
   143  		w, err := machine.WatchUpgradeSeriesNotifications()
   144  		if err != nil {
   145  			result.Results[i].Error = apiservererrors.ServerError(err)
   146  			continue
   147  		}
   148  		watcherId := u.resources.Register(w)
   149  		result.Results[i].NotifyWatcherId = watcherId
   150  	}
   151  	return result, nil
   152  }
   153  
   154  // UpgradeSeriesUnitStatus returns the current preparation status of an
   155  // upgrading unit.
   156  // If no series upgrade is in progress an error is returned instead.
   157  func (u *UpgradeSeriesAPI) UpgradeSeriesUnitStatus(args params.Entities) (params.UpgradeSeriesStatusResults, error) {
   158  	u.logger.Tracef("Starting UpgradeSeriesUnitStatus with %+v", args)
   159  	return u.unitStatus(args)
   160  }
   161  
   162  // SetUpgradeSeriesUnitStatus sets the upgrade series status of the unit.
   163  // If no upgrade is in progress an error is returned instead.
   164  func (u *UpgradeSeriesAPI) SetUpgradeSeriesUnitStatus(
   165  	args params.UpgradeSeriesStatusParams,
   166  ) (params.ErrorResults, error) {
   167  	u.logger.Tracef("Starting SetUpgradeSeriesUnitStatus with %+v", args)
   168  	return u.setUnitStatus(args)
   169  }
   170  
   171  func (u *UpgradeSeriesAPI) GetMachine(tag names.Tag) (UpgradeSeriesMachine, error) {
   172  	var id string
   173  	switch tag.Kind() {
   174  	case names.MachineTagKind:
   175  		id = tag.Id()
   176  	case names.UnitTagKind:
   177  		unit, err := u.backend.Unit(tag.Id())
   178  		if err != nil {
   179  			return nil, errors.Trace(err)
   180  		}
   181  		id, err = unit.AssignedMachineId()
   182  		if err != nil {
   183  			return nil, errors.Trace(err)
   184  		}
   185  	default:
   186  	}
   187  	return u.backend.Machine(id)
   188  }
   189  
   190  func (u *UpgradeSeriesAPI) getUnit(tag names.Tag) (UpgradeSeriesUnit, error) {
   191  	return u.backend.Unit(tag.Id())
   192  }
   193  
   194  // NewExternalUpgradeSeriesAPI can be used for API registration.
   195  func NewExternalUpgradeSeriesAPI(
   196  	st *state.State,
   197  	resources facade.Resources,
   198  	authorizer facade.Authorizer,
   199  	accessMachine GetAuthFunc,
   200  	accessUnit GetAuthFunc,
   201  	logger loggo.Logger,
   202  ) *UpgradeSeriesAPI {
   203  	return NewUpgradeSeriesAPI(UpgradeSeriesState{st}, resources, authorizer, accessMachine, accessUnit, logger)
   204  }
   205  
   206  func (u *UpgradeSeriesAPI) setUnitStatus(args params.UpgradeSeriesStatusParams) (params.ErrorResults, error) {
   207  	result := params.ErrorResults{
   208  		Results: make([]params.ErrorResult, len(args.Params)),
   209  	}
   210  	canAccess, err := u.accessUnit()
   211  	if err != nil {
   212  		return params.ErrorResults{}, err
   213  	}
   214  	for i, p := range args.Params {
   215  		tag, err := names.ParseUnitTag(p.Entity.Tag)
   216  		if err != nil {
   217  			result.Results[i].Error = apiservererrors.ServerError(apiservererrors.ErrPerm)
   218  			continue
   219  		}
   220  		if !canAccess(tag) {
   221  			result.Results[i].Error = apiservererrors.ServerError(apiservererrors.ErrPerm)
   222  			continue
   223  		}
   224  		unit, err := u.getUnit(tag)
   225  		if err != nil {
   226  			result.Results[i].Error = apiservererrors.ServerError(err)
   227  			continue
   228  		}
   229  
   230  		graph := model.UpgradeSeriesGraph()
   231  		if !graph.ValidState(p.Status) {
   232  			result.Results[i].Error = apiservererrors.ServerError(errors.NotValidf("upgrade series status %q", p.Status))
   233  			continue
   234  		}
   235  
   236  		sts, _, err := unit.UpgradeSeriesStatus()
   237  		if err != nil {
   238  			logger.Tracef("unit upgrade series status not found, fallback to not-started: %v", err)
   239  			sts = model.UpgradeSeriesNotStarted
   240  		}
   241  		if !graph.ValidState(sts) {
   242  			result.Results[i].Error = apiservererrors.ServerError(errors.NotValidf("current upgrade series status %q", sts))
   243  			continue
   244  		}
   245  
   246  		// If attempting to set the same status, we're done.
   247  		// This can happen in situations where the upgrade completion hook
   248  		// fails and requires resolution before re-running.
   249  		if sts == p.Status {
   250  			logger.Debugf("unit %s already has upgrade series status %s", tag.Id(), sts)
   251  			continue
   252  		}
   253  
   254  		fsm, err := model.NewUpgradeSeriesFSM(graph, sts)
   255  		if err != nil {
   256  			result.Results[i].Error = apiservererrors.ServerError(err)
   257  			continue
   258  		}
   259  		if !fsm.TransitionTo(p.Status) {
   260  			result.Results[i].Error = apiservererrors.ServerError(errors.BadRequestf("upgrade series status %q", p.Status))
   261  			continue
   262  		}
   263  
   264  		if err = unit.SetUpgradeSeriesStatus(p.Status, p.Message); err != nil {
   265  			result.Results[i].Error = apiservererrors.ServerError(err)
   266  		}
   267  	}
   268  	return result, nil
   269  }
   270  
   271  func (u *UpgradeSeriesAPI) unitStatus(args params.Entities) (params.UpgradeSeriesStatusResults, error) {
   272  	canAccess, err := u.accessUnit()
   273  	if err != nil {
   274  		return params.UpgradeSeriesStatusResults{}, err
   275  	}
   276  
   277  	results := make([]params.UpgradeSeriesStatusResult, len(args.Entities))
   278  	for i, entity := range args.Entities {
   279  		tag, err := names.ParseUnitTag(entity.Tag)
   280  		if err != nil {
   281  			results[i].Error = apiservererrors.ServerError(apiservererrors.ErrPerm)
   282  			continue
   283  		}
   284  		if !canAccess(tag) {
   285  			results[i].Error = apiservererrors.ServerError(apiservererrors.ErrPerm)
   286  			continue
   287  		}
   288  		unit, err := u.getUnit(tag)
   289  		if err != nil {
   290  			results[i].Error = apiservererrors.ServerError(err)
   291  			continue
   292  		}
   293  		status, target, err := unit.UpgradeSeriesStatus()
   294  		if err != nil {
   295  			results[i].Error = apiservererrors.ServerError(err)
   296  			continue
   297  		}
   298  		results[i].Status = status
   299  		results[i].Target = target
   300  	}
   301  	return params.UpgradeSeriesStatusResults{Results: results}, nil
   302  }