github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/apiserver/facades/client/action/run.go (about)

     1  // Copyright 2013 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package action
     5  
     6  import (
     7  	"fmt"
     8  	"strings"
     9  	"time"
    10  
    11  	"github.com/juju/collections/set"
    12  	"github.com/juju/errors"
    13  	"github.com/juju/names/v5"
    14  
    15  	apiservererrors "github.com/juju/juju/apiserver/errors"
    16  	"github.com/juju/juju/core/actions"
    17  	"github.com/juju/juju/rpc/params"
    18  	"github.com/juju/juju/state"
    19  )
    20  
    21  // Run the commands specified on the machines identified through the
    22  // list of machines, units and services.
    23  func (a *ActionAPI) Run(run params.RunParams) (results params.EnqueuedActions, err error) {
    24  	if err := a.checkCanAdmin(); err != nil {
    25  		return results, err
    26  	}
    27  
    28  	if err := a.check.ChangeAllowed(); err != nil {
    29  		return results, errors.Trace(err)
    30  	}
    31  
    32  	units, err := a.getAllUnitNames(run.Units, run.Applications)
    33  	if err != nil {
    34  		return results, errors.Trace(err)
    35  	}
    36  
    37  	machines := make([]names.Tag, len(run.Machines))
    38  	for i, machineId := range run.Machines {
    39  		if !names.IsValidMachine(machineId) {
    40  			return results, errors.Errorf("invalid machine id %q", machineId)
    41  		}
    42  		machines[i] = names.NewMachineTag(machineId)
    43  	}
    44  
    45  	actionParams, err := a.createRunActionsParams(append(units, machines...), run.Commands, run.Timeout, run.WorkloadContext, run.Parallel, run.ExecutionGroup)
    46  	if err != nil {
    47  		return results, errors.Trace(err)
    48  	}
    49  	return a.EnqueueOperation(actionParams)
    50  }
    51  
    52  // RunOnAllMachines attempts to run the specified command on all the machines.
    53  func (a *ActionAPI) RunOnAllMachines(run params.RunParams) (results params.EnqueuedActions, err error) {
    54  	if err := a.checkCanAdmin(); err != nil {
    55  		return results, err
    56  	}
    57  
    58  	if err := a.check.ChangeAllowed(); err != nil {
    59  		return results, errors.Trace(err)
    60  	}
    61  
    62  	m, err := a.state.Model()
    63  	if err != nil {
    64  		return results, errors.Trace(err)
    65  	}
    66  	if m.Type() != state.ModelTypeIAAS {
    67  		return results, errors.Errorf("cannot run on all machines with a %s model", m.Type())
    68  	}
    69  
    70  	machines, err := a.state.AllMachines()
    71  	if err != nil {
    72  		return results, err
    73  	}
    74  	machineTags := make([]names.Tag, len(machines))
    75  	for i, machine := range machines {
    76  		machineTags[i] = machine.Tag()
    77  	}
    78  
    79  	actionParams, err := a.createRunActionsParams(machineTags, run.Commands, run.Timeout, false, run.Parallel, run.ExecutionGroup)
    80  	if err != nil {
    81  		return results, errors.Trace(err)
    82  	}
    83  	return a.EnqueueOperation(actionParams)
    84  }
    85  
    86  func (a *ActionAPI) createRunActionsParams(
    87  	actionReceiverTags []names.Tag,
    88  	quotedCommands string,
    89  	timeout time.Duration,
    90  	workloadContext bool,
    91  	parallel *bool,
    92  	executionGroup *string,
    93  ) (params.Actions, error) {
    94  	apiActionParams := params.Actions{Actions: []params.Action{}}
    95  
    96  	if actions.HasJujuExecAction(quotedCommands) {
    97  		return apiActionParams, errors.NewNotSupported(nil, fmt.Sprintf("cannot use %q as an action command", quotedCommands))
    98  	}
    99  
   100  	actionParams := map[string]interface{}{}
   101  	actionParams["command"] = quotedCommands
   102  	actionParams["timeout"] = timeout.Nanoseconds()
   103  	actionParams["workload-context"] = workloadContext
   104  
   105  	for _, tag := range actionReceiverTags {
   106  		apiActionParams.Actions = append(apiActionParams.Actions, params.Action{
   107  			Receiver:       tag.String(),
   108  			Name:           actions.JujuExecActionName,
   109  			Parameters:     actionParams,
   110  			Parallel:       parallel,
   111  			ExecutionGroup: executionGroup,
   112  		})
   113  	}
   114  
   115  	return apiActionParams, nil
   116  }
   117  
   118  // getAllUnitNames returns a sequence of valid Unit objects from state. If any
   119  // of the application names or unit names are not found, an error is returned.
   120  func (a *ActionAPI) getAllUnitNames(units, applications []string) (result []names.Tag, err error) {
   121  	var leaders map[string]string
   122  	getLeader := func(appName string) (string, error) {
   123  		if leaders == nil {
   124  			var err error
   125  			leaders, err = a.leadership.Leaders()
   126  			if err != nil {
   127  				return "", err
   128  			}
   129  		}
   130  		if leader, ok := leaders[appName]; ok {
   131  			return leader, nil
   132  		}
   133  		return "", errors.Errorf("could not determine leader for %q", appName)
   134  	}
   135  
   136  	// Replace units matching $app/leader with the appropriate unit for
   137  	// the leader.
   138  	unitsSet := set.NewStrings()
   139  	for _, unit := range units {
   140  		if !strings.HasSuffix(unit, "leader") {
   141  			unitsSet.Add(unit)
   142  			continue
   143  		}
   144  
   145  		app := strings.Split(unit, "/")[0]
   146  		leaderUnit, err := getLeader(app)
   147  		if err != nil {
   148  			return nil, apiservererrors.ServerError(err)
   149  		}
   150  
   151  		unitsSet.Add(leaderUnit)
   152  	}
   153  
   154  	for _, name := range applications {
   155  		service, err := a.state.Application(name)
   156  		if err != nil {
   157  			return nil, err
   158  		}
   159  		units, err := service.AllUnits()
   160  		if err != nil {
   161  			return nil, err
   162  		}
   163  		for _, unit := range units {
   164  			unitsSet.Add(unit.Name())
   165  		}
   166  	}
   167  	for _, unitName := range unitsSet.SortedValues() {
   168  		if !names.IsValidUnit(unitName) {
   169  			return nil, errors.Errorf("invalid unit name %q", unitName)
   170  		}
   171  		result = append(result, names.NewUnitTag(unitName))
   172  	}
   173  	return result, nil
   174  }