github.com/cloudbase/juju-core@v0.0.0-20140504232958-a7271ac7912f/state/apiserver/client/run.go (about)

     1  // Copyright 2013 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package client
     5  
     6  import (
     7  	"fmt"
     8  	"path/filepath"
     9  	"sort"
    10  	"sync"
    11  	"time"
    12  
    13  	"launchpad.net/juju-core/environs/cloudinit"
    14  	"launchpad.net/juju-core/instance"
    15  	"launchpad.net/juju-core/state"
    16  	"launchpad.net/juju-core/state/api/params"
    17  	"launchpad.net/juju-core/utils"
    18  	"launchpad.net/juju-core/utils/set"
    19  	"launchpad.net/juju-core/utils/ssh"
    20  )
    21  
    22  // remoteParamsForMachine returns a filled in RemoteExec instance
    23  // based on the machine, command and timeout params.  If the machine
    24  // does not have an internal address, the Host is empty. This is caught
    25  // by the function that actually tries to execute the command.
    26  func remoteParamsForMachine(machine *state.Machine, command string, timeout time.Duration) *RemoteExec {
    27  	// magic boolean parameters are bad :-(
    28  	address := instance.SelectInternalAddress(machine.Addresses(), false)
    29  	execParams := &RemoteExec{
    30  		ExecParams: ssh.ExecParams{
    31  			Command: command,
    32  			Timeout: timeout,
    33  		},
    34  		MachineId: machine.Id(),
    35  	}
    36  	if address != "" {
    37  		execParams.Host = fmt.Sprintf("ubuntu@%s", address)
    38  	}
    39  	return execParams
    40  }
    41  
    42  // getAllUnitNames returns a sequence of valid Unit objects from state. If any
    43  // of the service names or unit names are not found, an error is returned.
    44  func getAllUnitNames(st *state.State, units, services []string) (result []*state.Unit, err error) {
    45  	unitsSet := set.NewStrings(units...)
    46  	for _, name := range services {
    47  		service, err := st.Service(name)
    48  		if err != nil {
    49  			return nil, err
    50  		}
    51  		units, err := service.AllUnits()
    52  		if err != nil {
    53  			return nil, err
    54  		}
    55  		for _, unit := range units {
    56  			unitsSet.Add(unit.Name())
    57  		}
    58  	}
    59  	for _, unitName := range unitsSet.Values() {
    60  		unit, err := st.Unit(unitName)
    61  		if err != nil {
    62  			return nil, err
    63  		}
    64  		// We only operate on principal units, and only thise that have an
    65  		// assigned machines.
    66  		if unit.IsPrincipal() {
    67  			if _, err := unit.AssignedMachineId(); err != nil {
    68  				return nil, err
    69  			}
    70  		} else {
    71  			return nil, fmt.Errorf("%s is not a principal unit", unit)
    72  		}
    73  		result = append(result, unit)
    74  	}
    75  	return result, nil
    76  }
    77  
    78  // Run the commands specified on the machines identified through the
    79  // list of machines, units and services.
    80  func (c *Client) Run(run params.RunParams) (results params.RunResults, err error) {
    81  	units, err := getAllUnitNames(c.api.state, run.Units, run.Services)
    82  	if err != nil {
    83  		return results, err
    84  	}
    85  	// We want to create a RemoteExec for each unit and each machine.
    86  	// If we have both a unit and a machine request, we run it twice,
    87  	// once for the unit inside the exec context using juju-run, and
    88  	// the other outside the context just using bash.
    89  	var params []*RemoteExec
    90  	var quotedCommands = utils.ShQuote(run.Commands)
    91  	for _, unit := range units {
    92  		// We know that the unit is both a principal unit, and that it has an
    93  		// assigned machine.
    94  		machineId, _ := unit.AssignedMachineId()
    95  		machine, err := c.api.state.Machine(machineId)
    96  		if err != nil {
    97  			return results, err
    98  		}
    99  		command := fmt.Sprintf("juju-run %s %s", unit.Name(), quotedCommands)
   100  		execParam := remoteParamsForMachine(machine, command, run.Timeout)
   101  		execParam.UnitId = unit.Name()
   102  		params = append(params, execParam)
   103  	}
   104  	for _, machineId := range run.Machines {
   105  		machine, err := c.api.state.Machine(machineId)
   106  		if err != nil {
   107  			return results, err
   108  		}
   109  		command := fmt.Sprintf("juju-run --no-context %s", quotedCommands)
   110  		execParam := remoteParamsForMachine(machine, command, run.Timeout)
   111  		params = append(params, execParam)
   112  	}
   113  	return ParallelExecute(c.api.dataDir, params), nil
   114  }
   115  
   116  // RunOnAllMachines attempts to run the specified command on all the machines.
   117  func (c *Client) RunOnAllMachines(run params.RunParams) (params.RunResults, error) {
   118  	machines, err := c.api.state.AllMachines()
   119  	if err != nil {
   120  		return params.RunResults{}, err
   121  	}
   122  	var params []*RemoteExec
   123  	quotedCommands := utils.ShQuote(run.Commands)
   124  	command := fmt.Sprintf("juju-run --no-context %s", quotedCommands)
   125  	for _, machine := range machines {
   126  		params = append(params, remoteParamsForMachine(machine, command, run.Timeout))
   127  	}
   128  	return ParallelExecute(c.api.dataDir, params), nil
   129  }
   130  
   131  // RemoteExec extends the standard ssh.ExecParams by providing the machine and
   132  // perhaps the unit ids.  These are then returned in the params.RunResult return
   133  // values.
   134  type RemoteExec struct {
   135  	ssh.ExecParams
   136  	MachineId string
   137  	UnitId    string
   138  }
   139  
   140  // ParallelExecute executes all of the requests defined in the params,
   141  // using the system identity stored in the dataDir.
   142  func ParallelExecute(dataDir string, runParams []*RemoteExec) params.RunResults {
   143  	logger.Debugf("exec %#v", runParams)
   144  	var outstanding sync.WaitGroup
   145  	var lock sync.Mutex
   146  	var result []params.RunResult
   147  	identity := filepath.Join(dataDir, cloudinit.SystemIdentity)
   148  	for _, param := range runParams {
   149  		outstanding.Add(1)
   150  		logger.Debugf("exec on %s: %#v", param.MachineId, *param)
   151  		param.IdentityFile = identity
   152  		go func(param *RemoteExec) {
   153  			response, err := ssh.ExecuteCommandOnMachine(param.ExecParams)
   154  			logger.Debugf("reponse from %s: %v (err:%v)", param.MachineId, response, err)
   155  			execResponse := params.RunResult{
   156  				ExecResponse: response,
   157  				MachineId:    param.MachineId,
   158  				UnitId:       param.UnitId,
   159  			}
   160  			if err != nil {
   161  				execResponse.Error = fmt.Sprint(err)
   162  			}
   163  
   164  			lock.Lock()
   165  			defer lock.Unlock()
   166  			result = append(result, execResponse)
   167  			outstanding.Done()
   168  		}(param)
   169  	}
   170  
   171  	outstanding.Wait()
   172  	sort.Sort(MachineOrder(result))
   173  	return params.RunResults{result}
   174  }
   175  
   176  // MachineOrder is used to provide the api to sort the results by the machine
   177  // id.
   178  type MachineOrder []params.RunResult
   179  
   180  func (a MachineOrder) Len() int           { return len(a) }
   181  func (a MachineOrder) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
   182  func (a MachineOrder) Less(i, j int) bool { return a[i].MachineId < a[j].MachineId }