launchpad.net/~rogpeppe/juju-core/500-errgo-fix@v0.0.0-20140213181702-000000002356/state/apiserver/client/run.go (about)

     1  // Copyright 2013 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package client
     5  
     6  import (
     7  	"fmt"
     8  	"launchpad.net/errgo/errors"
     9  	"launchpad.net/juju-core/state"
    10  	"launchpad.net/juju-core/utils"
    11  	"path/filepath"
    12  	"sort"
    13  	"sync"
    14  	"time"
    15  
    16  	"launchpad.net/juju-core/environs/cloudinit"
    17  	"launchpad.net/juju-core/instance"
    18  	"launchpad.net/juju-core/state/api/params"
    19  	"launchpad.net/juju-core/utils/set"
    20  	"launchpad.net/juju-core/utils/ssh"
    21  )
    22  
    23  // remoteParamsForMachine returns a filled in RemoteExec instance
    24  // based on the machine, command and timeout params.  If the machine
    25  // does not have an internal address, the Host is empty. This is caught
    26  // by the function that actually tries to execute the command.
    27  func remoteParamsForMachine(machine *state.Machine, command string, timeout time.Duration) *RemoteExec {
    28  	// magic boolean parameters are bad :-(
    29  	address := instance.SelectInternalAddress(machine.Addresses(), false)
    30  	execParams := &RemoteExec{
    31  		ExecParams: ssh.ExecParams{
    32  			Command: command,
    33  			Timeout: timeout,
    34  		},
    35  		MachineId: machine.Id(),
    36  	}
    37  	if address != "" {
    38  		execParams.Host = fmt.Sprintf("ubuntu@%s", address)
    39  	}
    40  	return execParams
    41  }
    42  
    43  // getAllUnitNames returns a sequence of valid Unit objects from state. If any
    44  // of the service names or unit names are not found, an error is returned.
    45  func getAllUnitNames(st *state.State, units, services []string) (result []*state.Unit, err error) {
    46  	unitsSet := set.NewStrings(units...)
    47  	for _, name := range services {
    48  		service, err := st.Service(name)
    49  		if err != nil {
    50  			return nil, mask(err)
    51  		}
    52  		units, err := service.AllUnits()
    53  		if err != nil {
    54  			return nil, mask(err)
    55  		}
    56  		for _, unit := range units {
    57  			unitsSet.Add(unit.Name())
    58  		}
    59  	}
    60  	for _, unitName := range unitsSet.Values() {
    61  		unit, err := st.Unit(unitName)
    62  		if err != nil {
    63  			return nil, mask(err)
    64  		}
    65  
    66  		// We only operate on principal units, and only thise that have an
    67  		// assigned machines.
    68  		if unit.IsPrincipal() {
    69  			if _, err := unit.AssignedMachineId(); err != nil {
    70  				return nil, mask(err)
    71  			}
    72  		} else {
    73  			return nil, errors.Newf("%s is not a principal unit", unit)
    74  		}
    75  		result = append(result, unit)
    76  	}
    77  	return result, nil
    78  }
    79  
    80  // Run the commands specified on the machines identified through the
    81  // list of machines, units and services.
    82  func (c *Client) Run(run params.RunParams) (results params.RunResults, err error) {
    83  	units, err := getAllUnitNames(c.api.state, run.Units, run.Services)
    84  	if err != nil {
    85  		return results, mask(err)
    86  	}
    87  
    88  	// We want to create a RemoteExec for each unit and each machine.
    89  	// If we have both a unit and a machine request, we run it twice,
    90  	// once for the unit inside the exec context using juju-run, and
    91  	// the other outside the context just using bash.
    92  	var params []*RemoteExec
    93  	var quotedCommands = utils.ShQuote(run.Commands)
    94  	for _, unit := range units {
    95  		// We know that the unit is both a principal unit, and that it has an
    96  		// assigned machine.
    97  		machineId, _ := unit.AssignedMachineId()
    98  		machine, err := c.api.state.Machine(machineId)
    99  		if err != nil {
   100  			return results, mask(err)
   101  		}
   102  		command := fmt.Sprintf("juju-run %s %s", unit.Name(), quotedCommands)
   103  		execParam := remoteParamsForMachine(machine, command, run.Timeout)
   104  		execParam.UnitId = unit.Name()
   105  		params = append(params, execParam)
   106  	}
   107  	for _, machineId := range run.Machines {
   108  		machine, err := c.api.state.Machine(machineId)
   109  		if err != nil {
   110  			return results, mask(err)
   111  		}
   112  		command := fmt.Sprintf("juju-run --no-context %s", quotedCommands)
   113  		execParam := remoteParamsForMachine(machine, command, run.Timeout)
   114  		params = append(params, execParam)
   115  	}
   116  	return ParallelExecute(c.api.dataDir, params), nil
   117  }
   118  
   119  // RunOnAllMachines attempts to run the specified command on all the machines.
   120  func (c *Client) RunOnAllMachines(run params.RunParams) (params.RunResults, error) {
   121  	machines, err := c.api.state.AllMachines()
   122  	if err != nil {
   123  		return params.RunResults{}, mask(err)
   124  	}
   125  	var params []*RemoteExec
   126  	quotedCommands := utils.ShQuote(run.Commands)
   127  	command := fmt.Sprintf("juju-run --no-context %s", quotedCommands)
   128  	for _, machine := range machines {
   129  		params = append(params, remoteParamsForMachine(machine, command, run.Timeout))
   130  	}
   131  	return ParallelExecute(c.api.dataDir, params), nil
   132  }
   133  
   134  // RemoteExec extends the standard ssh.ExecParams by providing the machine and
   135  // perhaps the unit ids.  These are then returned in the params.RunResult return
   136  // values.
   137  type RemoteExec struct {
   138  	ssh.ExecParams
   139  	MachineId string
   140  	UnitId    string
   141  }
   142  
   143  // ParallelExecute executes all of the requests defined in the params,
   144  // using the system identity stored in the dataDir.
   145  func ParallelExecute(dataDir string, runParams []*RemoteExec) params.RunResults {
   146  	logger.Debugf("exec %#v", runParams)
   147  	var outstanding sync.WaitGroup
   148  	var lock sync.Mutex
   149  	var result []params.RunResult
   150  	identity := filepath.Join(dataDir, cloudinit.SystemIdentity)
   151  	for _, param := range runParams {
   152  		outstanding.Add(1)
   153  		logger.Debugf("exec on %s: %#v", param.MachineId, *param)
   154  		param.IdentityFile = identity
   155  		go func(param *RemoteExec) {
   156  			response, err := ssh.ExecuteCommandOnMachine(param.ExecParams)
   157  			logger.Debugf("reponse from %s: %v (err:%v)", param.MachineId, response, err)
   158  			execResponse := params.RunResult{
   159  				ExecResponse: response,
   160  				MachineId:    param.MachineId,
   161  				UnitId:       param.UnitId,
   162  			}
   163  			if err != nil {
   164  				execResponse.Error = fmt.Sprint(err)
   165  			}
   166  
   167  			lock.Lock()
   168  			defer lock.Unlock()
   169  			result = append(result, execResponse)
   170  			outstanding.Done()
   171  		}(param)
   172  	}
   173  
   174  	outstanding.Wait()
   175  	sort.Sort(MachineOrder(result))
   176  	return params.RunResults{result}
   177  }
   178  
   179  // MachineOrder is used to provide the api to sort the results by the machine
   180  // id.
   181  type MachineOrder []params.RunResult
   182  
   183  func (a MachineOrder) Len() int           { return len(a) }
   184  func (a MachineOrder) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
   185  func (a MachineOrder) Less(i, j int) bool { return a[i].MachineId < a[j].MachineId }