github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/apiserver/common/ensuredead.go (about)

     1  // Copyright 2013 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package common
     5  
     6  import (
     7  	"github.com/juju/errors"
     8  	"github.com/juju/names/v5"
     9  
    10  	apiservererrors "github.com/juju/juju/apiserver/errors"
    11  	"github.com/juju/juju/rpc/params"
    12  	"github.com/juju/juju/state"
    13  )
    14  
    15  // DeadEnsurer implements a common EnsureDead method that the
    16  // various facades can share instead of each writing their own.
    17  type DeadEnsurer struct {
    18  	st           state.EntityFinder // used to look up the entity behind each tag
    19  	afterDead    func(names.Tag) // optional hook, called with the tag after an entity's EnsureDead succeeds; may be nil
    20  	getCanModify GetAuthFunc // called on every EnsureDead invocation to obtain the permission check
    21  }
    22  
    23  // NewDeadEnsurer returns a new DeadEnsurer. The GetAuthFunc will be
    24  // used on each invocation of EnsureDead to determine current
    25  // permissions.
    26  func NewDeadEnsurer(st state.EntityFinder, afterDead func(names.Tag), getCanModify GetAuthFunc) *DeadEnsurer {
    27  	return &DeadEnsurer{
    28  		st:           st,
    29  		afterDead:    afterDead,
    30  		getCanModify: getCanModify,
    31  	}
    32  }
    33  
    34  func (d *DeadEnsurer) ensureEntityDead(tag names.Tag) error {
    35  	entity0, err := d.st.FindEntity(tag)
    36  	if err != nil {
    37  		return err
    38  	}
    39  	entity, ok := entity0.(state.EnsureDeader)
    40  	if !ok {
    41  		return apiservererrors.NotSupportedError(tag, "ensuring death")
    42  	}
    43  	if err := entity.EnsureDead(); err != nil {
    44  		return errors.Trace(err)
    45  	}
    46  	if d.afterDead != nil {
    47  		d.afterDead(tag)
    48  	}
    49  	return nil
    50  }
    51  
    52  // EnsureDead calls EnsureDead on each given entity from state. It
    53  // will fail if the entity is not present. If it's Alive, nothing will
    54  // happen (see state/EnsureDead() for units or machines).
    55  func (d *DeadEnsurer) EnsureDead(args params.Entities) (params.ErrorResults, error) {
    56  	result := params.ErrorResults{
    57  		Results: make([]params.ErrorResult, len(args.Entities)),
    58  	}
    59  	if len(args.Entities) == 0 {
    60  		return result, nil
    61  	}
    62  	canModify, err := d.getCanModify()
    63  	if err != nil {
    64  		return params.ErrorResults{}, errors.Trace(err)
    65  	}
    66  	for i, entity := range args.Entities {
    67  		tag, err := names.ParseTag(entity.Tag)
    68  		if err != nil {
    69  			return params.ErrorResults{}, errors.Trace(err)
    70  		}
    71  
    72  		err = apiservererrors.ErrPerm
    73  		if canModify(tag) {
    74  			err = d.ensureEntityDead(tag)
    75  		}
    76  		result.Results[i].Error = apiservererrors.ServerError(err)
    77  	}
    78  	return result, nil
    79  }