github.com/mattyw/juju@v0.0.0-20140610034352-732aecd63861/state/apiserver/client/run.go (about)

     1  // Copyright 2013 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package client
     5  
     6  import (
     7  	"fmt"
     8  	"path/filepath"
     9  	"sort"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/juju/utils"
    14  	"github.com/juju/utils/set"
    15  
    16  	"github.com/juju/juju/agent"
    17  	"github.com/juju/juju/instance"
    18  	"github.com/juju/juju/state"
    19  	"github.com/juju/juju/state/api/params"
    20  	"github.com/juju/juju/state/apiserver/common"
    21  	"github.com/juju/juju/utils/ssh"
    22  )
    23  
    24  // remoteParamsForMachine returns a filled in RemoteExec instance
    25  // based on the machine, command and timeout params.  If the machine
    26  // does not have an internal address, the Host is empty. This is caught
    27  // by the function that actually tries to execute the command.
    28  func remoteParamsForMachine(machine *state.Machine, command string, timeout time.Duration) *RemoteExec {
    29  	// magic boolean parameters are bad :-(
    30  	address := instance.SelectInternalAddress(machine.Addresses(), false)
    31  	execParams := &RemoteExec{
    32  		ExecParams: ssh.ExecParams{
    33  			Command: command,
    34  			Timeout: timeout,
    35  		},
    36  		MachineId: machine.Id(),
    37  	}
    38  	if address != "" {
    39  		execParams.Host = fmt.Sprintf("ubuntu@%s", address)
    40  	}
    41  	return execParams
    42  }
    43  
    44  // getAllUnitNames returns a sequence of valid Unit objects from state. If any
    45  // of the service names or unit names are not found, an error is returned.
    46  func getAllUnitNames(st *state.State, units, services []string) (result []*state.Unit, err error) {
    47  	unitsSet := set.NewStrings(units...)
    48  	for _, name := range services {
    49  		service, err := st.Service(name)
    50  		if err != nil {
    51  			return nil, err
    52  		}
    53  		units, err := service.AllUnits()
    54  		if err != nil {
    55  			return nil, err
    56  		}
    57  		for _, unit := range units {
    58  			unitsSet.Add(unit.Name())
    59  		}
    60  	}
    61  	for _, unitName := range unitsSet.Values() {
    62  		unit, err := st.Unit(unitName)
    63  		if err != nil {
    64  			return nil, err
    65  		}
    66  		// We only operate on principal units, and only thise that have an
    67  		// assigned machines.
    68  		if unit.IsPrincipal() {
    69  			if _, err := unit.AssignedMachineId(); err != nil {
    70  				return nil, err
    71  			}
    72  		} else {
    73  			return nil, fmt.Errorf("%s is not a principal unit", unit)
    74  		}
    75  		result = append(result, unit)
    76  	}
    77  	return result, nil
    78  }
    79  
    80  func (c *Client) getDataDir() string {
    81  	dataResource, ok := c.api.resources.Get("dataDir").(common.StringResource)
    82  	if !ok {
    83  		return ""
    84  	}
    85  	return dataResource.String()
    86  }
    87  
    88  // Run the commands specified on the machines identified through the
    89  // list of machines, units and services.
    90  func (c *Client) Run(run params.RunParams) (results params.RunResults, err error) {
    91  	units, err := getAllUnitNames(c.api.state, run.Units, run.Services)
    92  	if err != nil {
    93  		return results, err
    94  	}
    95  	// We want to create a RemoteExec for each unit and each machine.
    96  	// If we have both a unit and a machine request, we run it twice,
    97  	// once for the unit inside the exec context using juju-run, and
    98  	// the other outside the context just using bash.
    99  	var params []*RemoteExec
   100  	var quotedCommands = utils.ShQuote(run.Commands)
   101  	for _, unit := range units {
   102  		// We know that the unit is both a principal unit, and that it has an
   103  		// assigned machine.
   104  		machineId, _ := unit.AssignedMachineId()
   105  		machine, err := c.api.state.Machine(machineId)
   106  		if err != nil {
   107  			return results, err
   108  		}
   109  		command := fmt.Sprintf("juju-run %s %s", unit.Name(), quotedCommands)
   110  		execParam := remoteParamsForMachine(machine, command, run.Timeout)
   111  		execParam.UnitId = unit.Name()
   112  		params = append(params, execParam)
   113  	}
   114  	for _, machineId := range run.Machines {
   115  		machine, err := c.api.state.Machine(machineId)
   116  		if err != nil {
   117  			return results, err
   118  		}
   119  		command := fmt.Sprintf("juju-run --no-context %s", quotedCommands)
   120  		execParam := remoteParamsForMachine(machine, command, run.Timeout)
   121  		params = append(params, execParam)
   122  	}
   123  	return ParallelExecute(c.getDataDir(), params), nil
   124  }
   125  
   126  // RunOnAllMachines attempts to run the specified command on all the machines.
   127  func (c *Client) RunOnAllMachines(run params.RunParams) (params.RunResults, error) {
   128  	machines, err := c.api.state.AllMachines()
   129  	if err != nil {
   130  		return params.RunResults{}, err
   131  	}
   132  	var params []*RemoteExec
   133  	quotedCommands := utils.ShQuote(run.Commands)
   134  	command := fmt.Sprintf("juju-run --no-context %s", quotedCommands)
   135  	for _, machine := range machines {
   136  		params = append(params, remoteParamsForMachine(machine, command, run.Timeout))
   137  	}
   138  	return ParallelExecute(c.getDataDir(), params), nil
   139  }
   140  
   141  // RemoteExec extends the standard ssh.ExecParams by providing the machine and
   142  // perhaps the unit ids.  These are then returned in the params.RunResult return
   143  // values.
   144  type RemoteExec struct {
   145  	ssh.ExecParams
   146  	MachineId string
   147  	UnitId    string
   148  }
   149  
   150  // ParallelExecute executes all of the requests defined in the params,
   151  // using the system identity stored in the dataDir.
   152  func ParallelExecute(dataDir string, runParams []*RemoteExec) params.RunResults {
   153  	logger.Debugf("exec %#v", runParams)
   154  	var outstanding sync.WaitGroup
   155  	var lock sync.Mutex
   156  	var result []params.RunResult
   157  	identity := filepath.Join(dataDir, agent.SystemIdentity)
   158  	for _, param := range runParams {
   159  		outstanding.Add(1)
   160  		logger.Debugf("exec on %s: %#v", param.MachineId, *param)
   161  		param.IdentityFile = identity
   162  		go func(param *RemoteExec) {
   163  			response, err := ssh.ExecuteCommandOnMachine(param.ExecParams)
   164  			logger.Debugf("reponse from %s: %v (err:%v)", param.MachineId, response, err)
   165  			execResponse := params.RunResult{
   166  				ExecResponse: response,
   167  				MachineId:    param.MachineId,
   168  				UnitId:       param.UnitId,
   169  			}
   170  			if err != nil {
   171  				execResponse.Error = fmt.Sprint(err)
   172  			}
   173  
   174  			lock.Lock()
   175  			defer lock.Unlock()
   176  			result = append(result, execResponse)
   177  			outstanding.Done()
   178  		}(param)
   179  	}
   180  
   181  	outstanding.Wait()
   182  	sort.Sort(MachineOrder(result))
   183  	return params.RunResults{result}
   184  }
   185  
   186  // MachineOrder is used to provide the api to sort the results by the machine
   187  // id.
   188  type MachineOrder []params.RunResult
   189  
   190  func (a MachineOrder) Len() int           { return len(a) }
   191  func (a MachineOrder) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
   192  func (a MachineOrder) Less(i, j int) bool { return a[i].MachineId < a[j].MachineId }