github.com/axw/juju@v0.0.0-20161005053422-4bd6544d08d4/cmd/juju/commands/ssh_common.go (about) 1 // Copyright 2016 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package commands 5 6 import ( 7 "bufio" 8 "io" 9 "io/ioutil" 10 "net" 11 "os" 12 "os/exec" 13 "strings" 14 "time" 15 16 "github.com/juju/errors" 17 "github.com/juju/gnuflag" 18 "github.com/juju/utils" 19 "github.com/juju/utils/set" 20 "github.com/juju/utils/ssh" 21 "gopkg.in/juju/names.v2" 22 23 "github.com/juju/juju/api/sshclient" 24 "github.com/juju/juju/cmd/modelcmd" 25 ) 26 27 // SSHCommon implements functionality shared by sshCommand, SCPCommand 28 // and DebugHooksCommand. 29 type SSHCommon struct { 30 modelcmd.ModelCommandBase 31 proxy bool 32 pty bool 33 noHostKeyChecks bool 34 Target string 35 Args []string 36 apiClient sshAPIClient 37 apiAddr string 38 knownHostsPath string 39 } 40 41 type sshAPIClient interface { 42 PublicAddress(target string) (string, error) 43 PrivateAddress(target string) (string, error) 44 PublicKeys(target string) ([]string, error) 45 Proxy() (bool, error) 46 Close() error 47 } 48 49 type resolvedTarget struct { 50 user string 51 entity string 52 host string 53 } 54 55 func (t *resolvedTarget) userHost() string { 56 if t.user == "" { 57 return t.host 58 } 59 return t.user + "@" + t.host 60 } 61 62 func (t *resolvedTarget) isAgent() bool { 63 return targetIsAgent(t.entity) 64 } 65 66 // attemptStarter is an interface corresponding to utils.AttemptStrategy 67 // 68 // TODO(katco): 2016-08-09: lp:1611427 69 type attemptStarter interface { 70 Start() attempt 71 } 72 73 type attempt interface { 74 Next() bool 75 } 76 77 // TODO(katco): 2016-08-09: lp:1611427 78 type attemptStrategy utils.AttemptStrategy 79 80 func (s attemptStrategy) Start() attempt { 81 // TODO(katco): 2016-08-09: lp:1611427 82 return utils.AttemptStrategy(s).Start() 83 } 84 85 var sshHostFromTargetAttemptStrategy attemptStarter = attemptStrategy{ 86 Total: 5 * time.Second, 87 Delay: 500 * time.Millisecond, 88 } 89 90 func (c *SSHCommon) SetFlags(f *gnuflag.FlagSet) { 91 c.ModelCommandBase.SetFlags(f) 92 f.BoolVar(&c.proxy, "proxy", false, "Proxy through the API server") 93 f.BoolVar(&c.pty, "pty", true, "Enable pseudo-tty allocation") 94 f.BoolVar(&c.noHostKeyChecks, "no-host-key-checks", false, "Skip host key checking (INSECURE)") 95 } 96 97 // initRun initializes the API connection if required, and determines 98 // if SSH proxying is required. It must be called at the top of the 99 // command's Run method. 100 // 101 // The apiClient, apiAddr and proxy fields are initialized after this 102 // call. 103 func (c *SSHCommon) initRun() error { 104 if err := c.ensureAPIClient(); err != nil { 105 return errors.Trace(err) 106 } 107 if proxy, err := c.proxySSH(); err != nil { 108 return errors.Trace(err) 109 } else { 110 c.proxy = proxy 111 } 112 return nil 113 } 114 115 // cleanupRun removes the temporary SSH known_hosts file (if one was 116 // created) and closes the API connection. It must be called at the 117 // end of the command's Run (i.e. as a defer). 118 func (c *SSHCommon) cleanupRun() { 119 if c.knownHostsPath != "" { 120 os.Remove(c.knownHostsPath) 121 c.knownHostsPath = "" 122 } 123 if c.apiClient != nil { 124 c.apiClient.Close() 125 c.apiClient = nil 126 } 127 } 128 129 // getSSHOptions configures SSH options based on command line 130 // arguments and the SSH targets specified. 131 func (c *SSHCommon) getSSHOptions(enablePty bool, targets ...*resolvedTarget) (*ssh.Options, error) { 132 var options ssh.Options 133 134 if c.noHostKeyChecks { 135 options.SetStrictHostKeyChecking(ssh.StrictHostChecksNo) 136 options.SetKnownHostsFile("/dev/null") 137 } else { 138 knownHostsPath, err := c.generateKnownHosts(targets) 139 if err != nil { 140 return nil, errors.Trace(err) 141 } 142 143 // There might not be a custom known_hosts file if the SSH 144 // targets are specified using arbitrary hostnames or 145 // addresses. In this case, the user's personal known_hosts 146 // file is used. 147 148 if knownHostsPath != "" { 149 // When a known_hosts file has been generated, enforce 150 // strict host key checking. 151 options.SetStrictHostKeyChecking(ssh.StrictHostChecksYes) 152 options.SetKnownHostsFile(knownHostsPath) 153 } else { 154 // If the user's personal known_hosts is used, also use 155 // the user's personal StrictHostKeyChecking preferences. 156 options.SetStrictHostKeyChecking(ssh.StrictHostChecksUnset) 157 } 158 } 159 160 if enablePty { 161 options.EnablePTY() 162 } 163 164 if c.proxy { 165 if err := c.setProxyCommand(&options); err != nil { 166 return nil, err 167 } 168 } 169 170 return &options, nil 171 } 172 173 // generateKnownHosts takes the provided targets, retrieves the SSH 174 // public host keys for them and generates a temporary known_hosts 175 // file for them. 176 func (c *SSHCommon) generateKnownHosts(targets []*resolvedTarget) (string, error) { 177 knownHosts := newKnownHostsBuilder() 178 agentCount := 0 179 nonAgentCount := 0 180 for _, target := range targets { 181 if target.isAgent() { 182 agentCount++ 183 keys, err := c.apiClient.PublicKeys(target.entity) 184 if err != nil { 185 return "", errors.Annotatef(err, "retrieving SSH host keys for %q", target.entity) 186 } 187 knownHosts.add(target.host, keys) 188 } else { 189 nonAgentCount++ 190 } 191 } 192 193 if agentCount > 0 && nonAgentCount > 0 { 194 return "", errors.New("can't determine host keys for all targets: consider --no-host-key-checks") 195 } 196 197 if knownHosts.size() == 0 { 198 // No public keys to write so exit early. 199 return "", nil 200 } 201 202 f, err := ioutil.TempFile("", "ssh_known_hosts") 203 if err != nil { 204 return "", errors.Annotate(err, "creating known hosts file") 205 } 206 defer f.Close() 207 c.knownHostsPath = f.Name() // Record for later deletion 208 if knownHosts.write(f); err != nil { 209 return "", errors.Trace(err) 210 } 211 return c.knownHostsPath, nil 212 } 213 214 // proxySSH returns false if both c.proxy and the proxy-ssh model 215 // configuration are false -- otherwise it returns true. 216 func (c *SSHCommon) proxySSH() (bool, error) { 217 if c.proxy { 218 // No need to check the API if user explictly requested 219 // proxying. 220 return true, nil 221 } 222 proxy, err := c.apiClient.Proxy() 223 if err != nil { 224 return false, errors.Trace(err) 225 } 226 logger.Debugf("proxy-ssh is %v", proxy) 227 return proxy, nil 228 } 229 230 // setProxyCommand sets the proxy command option. 231 func (c *SSHCommon) setProxyCommand(options *ssh.Options) error { 232 apiServerHost, _, err := net.SplitHostPort(c.apiAddr) 233 if err != nil { 234 return errors.Errorf("failed to get proxy address: %v", err) 235 } 236 juju, err := getJujuExecutable() 237 if err != nil { 238 return errors.Errorf("failed to get juju executable path: %v", err) 239 } 240 241 // TODO(mjs) 2016-05-09 LP #1579592 - It would be good to check the 242 // host key of the controller machine being used for proxying 243 // here. This isn't too serious as all traffic passing through the 244 // controller host is encrypted and the host key of the ultimate 245 // target host is verified but it would still be better to perform 246 // this extra level of checking. 247 options.SetProxyCommand( 248 juju, "ssh", 249 "--proxy=false", 250 "--no-host-key-checks", 251 "--pty=false", 252 "ubuntu@"+apiServerHost, 253 "-q", 254 "nc %h %p", 255 ) 256 return nil 257 } 258 259 func (c *SSHCommon) ensureAPIClient() error { 260 if c.apiClient != nil { 261 return nil 262 } 263 return errors.Trace(c.initAPIClient()) 264 } 265 266 // initAPIClient initialises the API connection. 267 func (c *SSHCommon) initAPIClient() error { 268 conn, err := c.NewAPIRoot() 269 if err != nil { 270 return errors.Trace(err) 271 } 272 c.apiClient = sshclient.NewFacade(conn) 273 c.apiAddr = conn.Addr() 274 return nil 275 } 276 277 func (c *SSHCommon) resolveTarget(target string) (*resolvedTarget, error) { 278 out := new(resolvedTarget) 279 out.user, out.entity = splitUserTarget(target) 280 281 // If the target is neither a machine nor a unit assume it's a 282 // hostname and try it directly. 283 if !targetIsAgent(out.entity) { 284 out.host = out.entity 285 return out, nil 286 } 287 288 if out.user == "" { 289 out.user = "ubuntu" 290 } 291 292 // A target may not initially have an address (e.g. the 293 // address updater hasn't yet run), so we must do this in 294 // a loop. 295 var err error 296 for a := sshHostFromTargetAttemptStrategy.Start(); a.Next(); { 297 if c.proxy { 298 out.host, err = c.apiClient.PrivateAddress(out.entity) 299 } else { 300 out.host, err = c.apiClient.PublicAddress(out.entity) 301 } 302 if err == nil { 303 return out, nil 304 } 305 } 306 return nil, err 307 } 308 309 // AllowInterspersedFlags for ssh/scp is set to false so that 310 // flags after the unit name are passed through to ssh, for eg. 311 // `juju ssh -v application-name/0 uname -a`. 312 func (c *SSHCommon) AllowInterspersedFlags() bool { 313 return false 314 } 315 316 // getJujuExecutable returns the path to the juju 317 // executable, or an error if it could not be found. 318 var getJujuExecutable = func() (string, error) { 319 return exec.LookPath(os.Args[0]) 320 } 321 322 func targetIsAgent(target string) bool { 323 return names.IsValidMachine(target) || names.IsValidUnit(target) 324 } 325 326 func splitUserTarget(target string) (string, string) { 327 if i := strings.IndexRune(target, '@'); i != -1 { 328 return target[:i], target[i+1:] 329 } 330 return "", target 331 } 332 333 func newKnownHostsBuilder() *knownHostsBuilder { 334 return &knownHostsBuilder{ 335 seen: set.NewStrings(), 336 } 337 } 338 339 // knownHostsBuilder supports the construction of a SSH known_hosts file. 340 type knownHostsBuilder struct { 341 lines []string 342 seen set.Strings 343 } 344 345 func (b *knownHostsBuilder) add(host string, keys []string) { 346 if b.seen.Contains(host) { 347 return 348 } 349 b.seen.Add(host) 350 for _, key := range keys { 351 b.lines = append(b.lines, host+" "+key+"\n") 352 } 353 } 354 355 func (b *knownHostsBuilder) write(w io.Writer) error { 356 bufw := bufio.NewWriter(w) 357 for _, line := range b.lines { 358 _, err := bufw.WriteString(line) 359 if err != nil { 360 return errors.Annotate(err, "writing known hosts file") 361 } 362 } 363 bufw.Flush() 364 return nil 365 } 366 367 func (b *knownHostsBuilder) size() int { 368 return len(b.lines) 369 }