github.com/mattyw/juju@v0.0.0-20140610034352-732aecd63861/state/apiserver/client/run.go (about) 1 // Copyright 2013 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package client 5 6 import ( 7 "fmt" 8 "path/filepath" 9 "sort" 10 "sync" 11 "time" 12 13 "github.com/juju/utils" 14 "github.com/juju/utils/set" 15 16 "github.com/juju/juju/agent" 17 "github.com/juju/juju/instance" 18 "github.com/juju/juju/state" 19 "github.com/juju/juju/state/api/params" 20 "github.com/juju/juju/state/apiserver/common" 21 "github.com/juju/juju/utils/ssh" 22 ) 23 24 // remoteParamsForMachine returns a filled in RemoteExec instance 25 // based on the machine, command and timeout params. If the machine 26 // does not have an internal address, the Host is empty. This is caught 27 // by the function that actually tries to execute the command. 28 func remoteParamsForMachine(machine *state.Machine, command string, timeout time.Duration) *RemoteExec { 29 // magic boolean parameters are bad :-( 30 address := instance.SelectInternalAddress(machine.Addresses(), false) 31 execParams := &RemoteExec{ 32 ExecParams: ssh.ExecParams{ 33 Command: command, 34 Timeout: timeout, 35 }, 36 MachineId: machine.Id(), 37 } 38 if address != "" { 39 execParams.Host = fmt.Sprintf("ubuntu@%s", address) 40 } 41 return execParams 42 } 43 44 // getAllUnitNames returns a sequence of valid Unit objects from state. If any 45 // of the service names or unit names are not found, an error is returned. 46 func getAllUnitNames(st *state.State, units, services []string) (result []*state.Unit, err error) { 47 unitsSet := set.NewStrings(units...) 48 for _, name := range services { 49 service, err := st.Service(name) 50 if err != nil { 51 return nil, err 52 } 53 units, err := service.AllUnits() 54 if err != nil { 55 return nil, err 56 } 57 for _, unit := range units { 58 unitsSet.Add(unit.Name()) 59 } 60 } 61 for _, unitName := range unitsSet.Values() { 62 unit, err := st.Unit(unitName) 63 if err != nil { 64 return nil, err 65 } 66 // We only operate on principal units, and only thise that have an 67 // assigned machines. 68 if unit.IsPrincipal() { 69 if _, err := unit.AssignedMachineId(); err != nil { 70 return nil, err 71 } 72 } else { 73 return nil, fmt.Errorf("%s is not a principal unit", unit) 74 } 75 result = append(result, unit) 76 } 77 return result, nil 78 } 79 80 func (c *Client) getDataDir() string { 81 dataResource, ok := c.api.resources.Get("dataDir").(common.StringResource) 82 if !ok { 83 return "" 84 } 85 return dataResource.String() 86 } 87 88 // Run the commands specified on the machines identified through the 89 // list of machines, units and services. 90 func (c *Client) Run(run params.RunParams) (results params.RunResults, err error) { 91 units, err := getAllUnitNames(c.api.state, run.Units, run.Services) 92 if err != nil { 93 return results, err 94 } 95 // We want to create a RemoteExec for each unit and each machine. 96 // If we have both a unit and a machine request, we run it twice, 97 // once for the unit inside the exec context using juju-run, and 98 // the other outside the context just using bash. 99 var params []*RemoteExec 100 var quotedCommands = utils.ShQuote(run.Commands) 101 for _, unit := range units { 102 // We know that the unit is both a principal unit, and that it has an 103 // assigned machine. 104 machineId, _ := unit.AssignedMachineId() 105 machine, err := c.api.state.Machine(machineId) 106 if err != nil { 107 return results, err 108 } 109 command := fmt.Sprintf("juju-run %s %s", unit.Name(), quotedCommands) 110 execParam := remoteParamsForMachine(machine, command, run.Timeout) 111 execParam.UnitId = unit.Name() 112 params = append(params, execParam) 113 } 114 for _, machineId := range run.Machines { 115 machine, err := c.api.state.Machine(machineId) 116 if err != nil { 117 return results, err 118 } 119 command := fmt.Sprintf("juju-run --no-context %s", quotedCommands) 120 execParam := remoteParamsForMachine(machine, command, run.Timeout) 121 params = append(params, execParam) 122 } 123 return ParallelExecute(c.getDataDir(), params), nil 124 } 125 126 // RunOnAllMachines attempts to run the specified command on all the machines. 127 func (c *Client) RunOnAllMachines(run params.RunParams) (params.RunResults, error) { 128 machines, err := c.api.state.AllMachines() 129 if err != nil { 130 return params.RunResults{}, err 131 } 132 var params []*RemoteExec 133 quotedCommands := utils.ShQuote(run.Commands) 134 command := fmt.Sprintf("juju-run --no-context %s", quotedCommands) 135 for _, machine := range machines { 136 params = append(params, remoteParamsForMachine(machine, command, run.Timeout)) 137 } 138 return ParallelExecute(c.getDataDir(), params), nil 139 } 140 141 // RemoteExec extends the standard ssh.ExecParams by providing the machine and 142 // perhaps the unit ids. These are then returned in the params.RunResult return 143 // values. 144 type RemoteExec struct { 145 ssh.ExecParams 146 MachineId string 147 UnitId string 148 } 149 150 // ParallelExecute executes all of the requests defined in the params, 151 // using the system identity stored in the dataDir. 152 func ParallelExecute(dataDir string, runParams []*RemoteExec) params.RunResults { 153 logger.Debugf("exec %#v", runParams) 154 var outstanding sync.WaitGroup 155 var lock sync.Mutex 156 var result []params.RunResult 157 identity := filepath.Join(dataDir, agent.SystemIdentity) 158 for _, param := range runParams { 159 outstanding.Add(1) 160 logger.Debugf("exec on %s: %#v", param.MachineId, *param) 161 param.IdentityFile = identity 162 go func(param *RemoteExec) { 163 response, err := ssh.ExecuteCommandOnMachine(param.ExecParams) 164 logger.Debugf("reponse from %s: %v (err:%v)", param.MachineId, response, err) 165 execResponse := params.RunResult{ 166 ExecResponse: response, 167 MachineId: param.MachineId, 168 UnitId: param.UnitId, 169 } 170 if err != nil { 171 execResponse.Error = fmt.Sprint(err) 172 } 173 174 lock.Lock() 175 defer lock.Unlock() 176 result = append(result, execResponse) 177 outstanding.Done() 178 }(param) 179 } 180 181 outstanding.Wait() 182 sort.Sort(MachineOrder(result)) 183 return params.RunResults{result} 184 } 185 186 // MachineOrder is used to provide the api to sort the results by the machine 187 // id. 188 type MachineOrder []params.RunResult 189 190 func (a MachineOrder) Len() int { return len(a) } 191 func (a MachineOrder) Swap(i, j int) { a[i], a[j] = a[j], a[i] } 192 func (a MachineOrder) Less(i, j int) bool { return a[i].MachineId < a[j].MachineId }