github.com/axw/juju@v0.0.0-20161005053422-4bd6544d08d4/cmd/juju/commands/ssh_common.go (about)

     1  // Copyright 2016 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package commands
     5  
     6  import (
     7  	"bufio"
     8  	"io"
     9  	"io/ioutil"
    10  	"net"
    11  	"os"
    12  	"os/exec"
    13  	"strings"
    14  	"time"
    15  
    16  	"github.com/juju/errors"
    17  	"github.com/juju/gnuflag"
    18  	"github.com/juju/utils"
    19  	"github.com/juju/utils/set"
    20  	"github.com/juju/utils/ssh"
    21  	"gopkg.in/juju/names.v2"
    22  
    23  	"github.com/juju/juju/api/sshclient"
    24  	"github.com/juju/juju/cmd/modelcmd"
    25  )
    26  
    27  // SSHCommon implements functionality shared by sshCommand, SCPCommand
    28  // and DebugHooksCommand.
    29  type SSHCommon struct {
    30  	modelcmd.ModelCommandBase
    31  	proxy           bool
    32  	pty             bool
    33  	noHostKeyChecks bool
    34  	Target          string
    35  	Args            []string
    36  	apiClient       sshAPIClient
    37  	apiAddr         string
    38  	knownHostsPath  string
    39  }
    40  
    41  type sshAPIClient interface {
    42  	PublicAddress(target string) (string, error)
    43  	PrivateAddress(target string) (string, error)
    44  	PublicKeys(target string) ([]string, error)
    45  	Proxy() (bool, error)
    46  	Close() error
    47  }
    48  
    49  type resolvedTarget struct {
    50  	user   string
    51  	entity string
    52  	host   string
    53  }
    54  
    55  func (t *resolvedTarget) userHost() string {
    56  	if t.user == "" {
    57  		return t.host
    58  	}
    59  	return t.user + "@" + t.host
    60  }
    61  
    62  func (t *resolvedTarget) isAgent() bool {
    63  	return targetIsAgent(t.entity)
    64  }
    65  
    66  // attemptStarter is an interface corresponding to utils.AttemptStrategy
    67  //
    68  // TODO(katco): 2016-08-09: lp:1611427
    69  type attemptStarter interface {
    70  	Start() attempt
    71  }
    72  
    73  type attempt interface {
    74  	Next() bool
    75  }
    76  
    77  // TODO(katco): 2016-08-09: lp:1611427
    78  type attemptStrategy utils.AttemptStrategy
    79  
    80  func (s attemptStrategy) Start() attempt {
    81  	// TODO(katco): 2016-08-09: lp:1611427
    82  	return utils.AttemptStrategy(s).Start()
    83  }
    84  
    85  var sshHostFromTargetAttemptStrategy attemptStarter = attemptStrategy{
    86  	Total: 5 * time.Second,
    87  	Delay: 500 * time.Millisecond,
    88  }
    89  
    90  func (c *SSHCommon) SetFlags(f *gnuflag.FlagSet) {
    91  	c.ModelCommandBase.SetFlags(f)
    92  	f.BoolVar(&c.proxy, "proxy", false, "Proxy through the API server")
    93  	f.BoolVar(&c.pty, "pty", true, "Enable pseudo-tty allocation")
    94  	f.BoolVar(&c.noHostKeyChecks, "no-host-key-checks", false, "Skip host key checking (INSECURE)")
    95  }
    96  
    97  // initRun initializes the API connection if required, and determines
    98  // if SSH proxying is required. It must be called at the top of the
    99  // command's Run method.
   100  //
   101  // The apiClient, apiAddr and proxy fields are initialized after this
   102  // call.
   103  func (c *SSHCommon) initRun() error {
   104  	if err := c.ensureAPIClient(); err != nil {
   105  		return errors.Trace(err)
   106  	}
   107  	if proxy, err := c.proxySSH(); err != nil {
   108  		return errors.Trace(err)
   109  	} else {
   110  		c.proxy = proxy
   111  	}
   112  	return nil
   113  }
   114  
   115  // cleanupRun removes the temporary SSH known_hosts file (if one was
   116  // created) and closes the API connection. It must be called at the
   117  // end of the command's Run (i.e. as a defer).
   118  func (c *SSHCommon) cleanupRun() {
   119  	if c.knownHostsPath != "" {
   120  		os.Remove(c.knownHostsPath)
   121  		c.knownHostsPath = ""
   122  	}
   123  	if c.apiClient != nil {
   124  		c.apiClient.Close()
   125  		c.apiClient = nil
   126  	}
   127  }
   128  
   129  // getSSHOptions configures SSH options based on command line
   130  // arguments and the SSH targets specified.
   131  func (c *SSHCommon) getSSHOptions(enablePty bool, targets ...*resolvedTarget) (*ssh.Options, error) {
   132  	var options ssh.Options
   133  
   134  	if c.noHostKeyChecks {
   135  		options.SetStrictHostKeyChecking(ssh.StrictHostChecksNo)
   136  		options.SetKnownHostsFile("/dev/null")
   137  	} else {
   138  		knownHostsPath, err := c.generateKnownHosts(targets)
   139  		if err != nil {
   140  			return nil, errors.Trace(err)
   141  		}
   142  
   143  		// There might not be a custom known_hosts file if the SSH
   144  		// targets are specified using arbitrary hostnames or
   145  		// addresses. In this case, the user's personal known_hosts
   146  		// file is used.
   147  
   148  		if knownHostsPath != "" {
   149  			// When a known_hosts file has been generated, enforce
   150  			// strict host key checking.
   151  			options.SetStrictHostKeyChecking(ssh.StrictHostChecksYes)
   152  			options.SetKnownHostsFile(knownHostsPath)
   153  		} else {
   154  			// If the user's personal known_hosts is used, also use
   155  			// the user's personal StrictHostKeyChecking preferences.
   156  			options.SetStrictHostKeyChecking(ssh.StrictHostChecksUnset)
   157  		}
   158  	}
   159  
   160  	if enablePty {
   161  		options.EnablePTY()
   162  	}
   163  
   164  	if c.proxy {
   165  		if err := c.setProxyCommand(&options); err != nil {
   166  			return nil, err
   167  		}
   168  	}
   169  
   170  	return &options, nil
   171  }
   172  
   173  // generateKnownHosts takes the provided targets, retrieves the SSH
   174  // public host keys for them and generates a temporary known_hosts
   175  // file for them.
   176  func (c *SSHCommon) generateKnownHosts(targets []*resolvedTarget) (string, error) {
   177  	knownHosts := newKnownHostsBuilder()
   178  	agentCount := 0
   179  	nonAgentCount := 0
   180  	for _, target := range targets {
   181  		if target.isAgent() {
   182  			agentCount++
   183  			keys, err := c.apiClient.PublicKeys(target.entity)
   184  			if err != nil {
   185  				return "", errors.Annotatef(err, "retrieving SSH host keys for %q", target.entity)
   186  			}
   187  			knownHosts.add(target.host, keys)
   188  		} else {
   189  			nonAgentCount++
   190  		}
   191  	}
   192  
   193  	if agentCount > 0 && nonAgentCount > 0 {
   194  		return "", errors.New("can't determine host keys for all targets: consider --no-host-key-checks")
   195  	}
   196  
   197  	if knownHosts.size() == 0 {
   198  		// No public keys to write so exit early.
   199  		return "", nil
   200  	}
   201  
   202  	f, err := ioutil.TempFile("", "ssh_known_hosts")
   203  	if err != nil {
   204  		return "", errors.Annotate(err, "creating known hosts file")
   205  	}
   206  	defer f.Close()
   207  	c.knownHostsPath = f.Name() // Record for later deletion
   208  	if knownHosts.write(f); err != nil {
   209  		return "", errors.Trace(err)
   210  	}
   211  	return c.knownHostsPath, nil
   212  }
   213  
   214  // proxySSH returns false if both c.proxy and the proxy-ssh model
   215  // configuration are false -- otherwise it returns true.
   216  func (c *SSHCommon) proxySSH() (bool, error) {
   217  	if c.proxy {
   218  		// No need to check the API if user explictly requested
   219  		// proxying.
   220  		return true, nil
   221  	}
   222  	proxy, err := c.apiClient.Proxy()
   223  	if err != nil {
   224  		return false, errors.Trace(err)
   225  	}
   226  	logger.Debugf("proxy-ssh is %v", proxy)
   227  	return proxy, nil
   228  }
   229  
   230  // setProxyCommand sets the proxy command option.
   231  func (c *SSHCommon) setProxyCommand(options *ssh.Options) error {
   232  	apiServerHost, _, err := net.SplitHostPort(c.apiAddr)
   233  	if err != nil {
   234  		return errors.Errorf("failed to get proxy address: %v", err)
   235  	}
   236  	juju, err := getJujuExecutable()
   237  	if err != nil {
   238  		return errors.Errorf("failed to get juju executable path: %v", err)
   239  	}
   240  
   241  	// TODO(mjs) 2016-05-09 LP #1579592 - It would be good to check the
   242  	// host key of the controller machine being used for proxying
   243  	// here. This isn't too serious as all traffic passing through the
   244  	// controller host is encrypted and the host key of the ultimate
   245  	// target host is verified but it would still be better to perform
   246  	// this extra level of checking.
   247  	options.SetProxyCommand(
   248  		juju, "ssh",
   249  		"--proxy=false",
   250  		"--no-host-key-checks",
   251  		"--pty=false",
   252  		"ubuntu@"+apiServerHost,
   253  		"-q",
   254  		"nc %h %p",
   255  	)
   256  	return nil
   257  }
   258  
   259  func (c *SSHCommon) ensureAPIClient() error {
   260  	if c.apiClient != nil {
   261  		return nil
   262  	}
   263  	return errors.Trace(c.initAPIClient())
   264  }
   265  
   266  // initAPIClient initialises the API connection.
   267  func (c *SSHCommon) initAPIClient() error {
   268  	conn, err := c.NewAPIRoot()
   269  	if err != nil {
   270  		return errors.Trace(err)
   271  	}
   272  	c.apiClient = sshclient.NewFacade(conn)
   273  	c.apiAddr = conn.Addr()
   274  	return nil
   275  }
   276  
   277  func (c *SSHCommon) resolveTarget(target string) (*resolvedTarget, error) {
   278  	out := new(resolvedTarget)
   279  	out.user, out.entity = splitUserTarget(target)
   280  
   281  	// If the target is neither a machine nor a unit assume it's a
   282  	// hostname and try it directly.
   283  	if !targetIsAgent(out.entity) {
   284  		out.host = out.entity
   285  		return out, nil
   286  	}
   287  
   288  	if out.user == "" {
   289  		out.user = "ubuntu"
   290  	}
   291  
   292  	// A target may not initially have an address (e.g. the
   293  	// address updater hasn't yet run), so we must do this in
   294  	// a loop.
   295  	var err error
   296  	for a := sshHostFromTargetAttemptStrategy.Start(); a.Next(); {
   297  		if c.proxy {
   298  			out.host, err = c.apiClient.PrivateAddress(out.entity)
   299  		} else {
   300  			out.host, err = c.apiClient.PublicAddress(out.entity)
   301  		}
   302  		if err == nil {
   303  			return out, nil
   304  		}
   305  	}
   306  	return nil, err
   307  }
   308  
   309  // AllowInterspersedFlags for ssh/scp is set to false so that
   310  // flags after the unit name are passed through to ssh, for eg.
   311  // `juju ssh -v application-name/0 uname -a`.
   312  func (c *SSHCommon) AllowInterspersedFlags() bool {
   313  	return false
   314  }
   315  
   316  // getJujuExecutable returns the path to the juju
   317  // executable, or an error if it could not be found.
   318  var getJujuExecutable = func() (string, error) {
   319  	return exec.LookPath(os.Args[0])
   320  }
   321  
   322  func targetIsAgent(target string) bool {
   323  	return names.IsValidMachine(target) || names.IsValidUnit(target)
   324  }
   325  
   326  func splitUserTarget(target string) (string, string) {
   327  	if i := strings.IndexRune(target, '@'); i != -1 {
   328  		return target[:i], target[i+1:]
   329  	}
   330  	return "", target
   331  }
   332  
   333  func newKnownHostsBuilder() *knownHostsBuilder {
   334  	return &knownHostsBuilder{
   335  		seen: set.NewStrings(),
   336  	}
   337  }
   338  
   339  // knownHostsBuilder supports the construction of a SSH known_hosts file.
   340  type knownHostsBuilder struct {
   341  	lines []string
   342  	seen  set.Strings
   343  }
   344  
   345  func (b *knownHostsBuilder) add(host string, keys []string) {
   346  	if b.seen.Contains(host) {
   347  		return
   348  	}
   349  	b.seen.Add(host)
   350  	for _, key := range keys {
   351  		b.lines = append(b.lines, host+" "+key+"\n")
   352  	}
   353  }
   354  
   355  func (b *knownHostsBuilder) write(w io.Writer) error {
   356  	bufw := bufio.NewWriter(w)
   357  	for _, line := range b.lines {
   358  		_, err := bufw.WriteString(line)
   359  		if err != nil {
   360  			return errors.Annotate(err, "writing known hosts file")
   361  		}
   362  	}
   363  	bufw.Flush()
   364  	return nil
   365  }
   366  
   367  func (b *knownHostsBuilder) size() int {
   368  	return len(b.lines)
   369  }