github.com/niedbalski/juju@v0.0.0-20190215020005-8ff100488e47/cmd/juju/commands/ssh_common.go (about)

     1  // Copyright 2016 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package commands
     5  
     6  import (
     7  	"bufio"
     8  	"io"
     9  	"io/ioutil"
    10  	"net"
    11  	"os"
    12  	"os/exec"
    13  	"strings"
    14  	"time"
    15  
    16  	"github.com/juju/collections/set"
    17  	"github.com/juju/errors"
    18  	"github.com/juju/gnuflag"
    19  	"github.com/juju/utils"
    20  	"github.com/juju/utils/ssh"
    21  	"gopkg.in/juju/names.v2"
    22  
    23  	"github.com/juju/juju/api/sshclient"
    24  	"github.com/juju/juju/apiserver/params"
    25  	"github.com/juju/juju/cmd/modelcmd"
    26  	"github.com/juju/juju/network"
    27  	jujussh "github.com/juju/juju/network/ssh"
    28  )
    29  
    30  // SSHCommon implements functionality shared by sshCommand, SCPCommand
    31  // and DebugHooksCommand.
    32  type SSHCommon struct {
    33  	modelcmd.ModelCommandBase
    34  	modelcmd.IAASOnlyCommand
    35  	proxy           bool
    36  	noHostKeyChecks bool
    37  	Target          string
    38  	Args            []string
    39  	apiClient       sshAPIClient
    40  	apiAddr         string
    41  	knownHostsPath  string
    42  	hostChecker     jujussh.ReachableChecker
    43  	forceAPIv1      bool
    44  }
    45  
    46  const jujuSSHClientForceAPIv1 = "JUJU_SSHCLIENT_API_V1"
    47  
    48  type sshAPIClient interface {
    49  	BestAPIVersion() int
    50  	PublicAddress(target string) (string, error)
    51  	PrivateAddress(target string) (string, error)
    52  	AllAddresses(target string) ([]string, error)
    53  	PublicKeys(target string) ([]string, error)
    54  	Proxy() (bool, error)
    55  	Close() error
    56  }
    57  
    58  type resolvedTarget struct {
    59  	user   string
    60  	entity string
    61  	host   string
    62  }
    63  
    64  func (t *resolvedTarget) userHost() string {
    65  	if t.user == "" {
    66  		return t.host
    67  	}
    68  	return t.user + "@" + t.host
    69  }
    70  
    71  func (t *resolvedTarget) isAgent() bool {
    72  	return targetIsAgent(t.entity)
    73  }
    74  
    75  // attemptStarter is an interface corresponding to utils.AttemptStrategy
    76  //
    77  // TODO(katco): 2016-08-09: lp:1611427
    78  type attemptStarter interface {
    79  	Start() attempt
    80  }
    81  
    82  type attempt interface {
    83  	Next() bool
    84  }
    85  
    86  // TODO(katco): 2016-08-09: lp:1611427
    87  type attemptStrategy utils.AttemptStrategy
    88  
    89  func (s attemptStrategy) Start() attempt {
    90  	// TODO(katco): 2016-08-09: lp:1611427
    91  	return utils.AttemptStrategy(s).Start()
    92  }
    93  
    94  const (
    95  	// SSHRetryDelay is the time to wait for an SSH connection to be established
    96  	// to a single endpoint of a target.
    97  	SSHRetryDelay = 500 * time.Millisecond
    98  
    99  	// SSHTimeout is the time to wait for before giving up trying to establish
   100  	// an SSH connection to a target, after retrying.
   101  	SSHTimeout = 5 * time.Second
   102  
   103  	// SSHPort is the TCP port used for SSH connections.
   104  	SSHPort = 22
   105  )
   106  
   107  var sshHostFromTargetAttemptStrategy attemptStarter = attemptStrategy{
   108  	Total: SSHTimeout,
   109  	Delay: SSHRetryDelay,
   110  }
   111  
   112  func (c *SSHCommon) SetFlags(f *gnuflag.FlagSet) {
   113  	c.ModelCommandBase.SetFlags(f)
   114  	f.BoolVar(&c.proxy, "proxy", false, "Proxy through the API server")
   115  	f.BoolVar(&c.noHostKeyChecks, "no-host-key-checks", false, "Skip host key checking (INSECURE)")
   116  }
   117  
   118  // defaultReachableChecker returns a jujussh.ReachableChecker with a connection
   119  // timeout of SSHRetryDelay and an overall timout of SSHTimeout
   120  func defaultReachableChecker() jujussh.ReachableChecker {
   121  	return jujussh.NewReachableChecker(&net.Dialer{Timeout: SSHRetryDelay}, SSHTimeout)
   122  }
   123  
   124  func (c *SSHCommon) setHostChecker(checker jujussh.ReachableChecker) {
   125  	if checker == nil {
   126  		checker = defaultReachableChecker()
   127  	}
   128  	c.hostChecker = checker
   129  }
   130  
   131  // initRun initializes the API connection if required, and determines
   132  // if SSH proxying is required. It must be called at the top of the
   133  // command's Run method.
   134  //
   135  // The apiClient, apiAddr and proxy fields are initialized after this call.
   136  func (c *SSHCommon) initRun() error {
   137  	if err := c.ensureAPIClient(); err != nil {
   138  		return errors.Trace(err)
   139  	}
   140  
   141  	if proxy, err := c.proxySSH(); err != nil {
   142  		return errors.Trace(err)
   143  	} else {
   144  		c.proxy = proxy
   145  	}
   146  
   147  	// Used mostly for testing, but useful for debugging and/or
   148  	// backwards-compatibility with some scripts.
   149  	c.forceAPIv1 = os.Getenv(jujuSSHClientForceAPIv1) != ""
   150  	return nil
   151  }
   152  
   153  // cleanupRun removes the temporary SSH known_hosts file (if one was
   154  // created) and closes the API connection. It must be called at the
   155  // end of the command's Run (i.e. as a defer).
   156  func (c *SSHCommon) cleanupRun() {
   157  	if c.knownHostsPath != "" {
   158  		os.Remove(c.knownHostsPath)
   159  		c.knownHostsPath = ""
   160  	}
   161  	if c.apiClient != nil {
   162  		c.apiClient.Close()
   163  		c.apiClient = nil
   164  	}
   165  }
   166  
   167  // getSSHOptions configures SSH options based on command line
   168  // arguments and the SSH targets specified.
   169  func (c *SSHCommon) getSSHOptions(enablePty bool, targets ...*resolvedTarget) (*ssh.Options, error) {
   170  	var options ssh.Options
   171  
   172  	if c.noHostKeyChecks {
   173  		options.SetStrictHostKeyChecking(ssh.StrictHostChecksNo)
   174  		options.SetKnownHostsFile(os.DevNull)
   175  	} else {
   176  		knownHostsPath, err := c.generateKnownHosts(targets)
   177  		if err != nil {
   178  			return nil, errors.Trace(err)
   179  		}
   180  
   181  		// There might not be a custom known_hosts file if the SSH
   182  		// targets are specified using arbitrary hostnames or
   183  		// addresses. In this case, the user's personal known_hosts
   184  		// file is used.
   185  
   186  		if knownHostsPath != "" {
   187  			// When a known_hosts file has been generated, enforce
   188  			// strict host key checking.
   189  			options.SetStrictHostKeyChecking(ssh.StrictHostChecksYes)
   190  			options.SetKnownHostsFile(knownHostsPath)
   191  		}
   192  	}
   193  
   194  	if enablePty {
   195  		options.EnablePTY()
   196  	}
   197  
   198  	if c.proxy {
   199  		if err := c.setProxyCommand(&options); err != nil {
   200  			return nil, err
   201  		}
   202  	}
   203  
   204  	return &options, nil
   205  }
   206  
   207  // generateKnownHosts takes the provided targets, retrieves the SSH
   208  // public host keys for them and generates a temporary known_hosts
   209  // file for them.
   210  func (c *SSHCommon) generateKnownHosts(targets []*resolvedTarget) (string, error) {
   211  	knownHosts := newKnownHostsBuilder()
   212  	agentCount := 0
   213  	nonAgentCount := 0
   214  	for _, target := range targets {
   215  		if target.isAgent() {
   216  			agentCount++
   217  			keys, err := c.apiClient.PublicKeys(target.entity)
   218  			if err != nil {
   219  				return "", errors.Annotatef(err, "retrieving SSH host keys for %q", target.entity)
   220  			}
   221  			knownHosts.add(target.host, keys)
   222  		} else {
   223  			nonAgentCount++
   224  		}
   225  	}
   226  
   227  	if agentCount > 0 && nonAgentCount > 0 {
   228  		return "", errors.New("can't determine host keys for all targets: consider --no-host-key-checks")
   229  	}
   230  
   231  	if knownHosts.size() == 0 {
   232  		// No public keys to write so exit early.
   233  		return "", nil
   234  	}
   235  
   236  	f, err := ioutil.TempFile("", "ssh_known_hosts")
   237  	if err != nil {
   238  		return "", errors.Annotate(err, "creating known hosts file")
   239  	}
   240  	defer f.Close()
   241  	c.knownHostsPath = f.Name() // Record for later deletion
   242  	if knownHosts.write(f); err != nil {
   243  		return "", errors.Trace(err)
   244  	}
   245  	return c.knownHostsPath, nil
   246  }
   247  
   248  // proxySSH returns false if both c.proxy and the proxy-ssh model
   249  // configuration are false -- otherwise it returns true.
   250  func (c *SSHCommon) proxySSH() (bool, error) {
   251  	if c.proxy {
   252  		// No need to check the API if user explicitly requested
   253  		// proxying.
   254  		return true, nil
   255  	}
   256  	proxy, err := c.apiClient.Proxy()
   257  	if err != nil {
   258  		return false, errors.Trace(err)
   259  	}
   260  	logger.Debugf("proxy-ssh is %v", proxy)
   261  	return proxy, nil
   262  }
   263  
   264  // setProxyCommand sets the proxy command option.
   265  func (c *SSHCommon) setProxyCommand(options *ssh.Options) error {
   266  	apiServerHost, _, err := net.SplitHostPort(c.apiAddr)
   267  	if err != nil {
   268  		return errors.Errorf("failed to get proxy address: %v", err)
   269  	}
   270  	juju, err := getJujuExecutable()
   271  	if err != nil {
   272  		return errors.Errorf("failed to get juju executable path: %v", err)
   273  	}
   274  
   275  	modelName, err := c.ModelName()
   276  	if err != nil {
   277  		return errors.Trace(err)
   278  	}
   279  	// TODO(mjs) 2016-05-09 LP #1579592 - It would be good to check the
   280  	// host key of the controller machine being used for proxying
   281  	// here. This isn't too serious as all traffic passing through the
   282  	// controller host is encrypted and the host key of the ultimate
   283  	// target host is verified but it would still be better to perform
   284  	// this extra level of checking.
   285  	options.SetProxyCommand(
   286  		juju, "ssh",
   287  		"--model="+modelName,
   288  		"--proxy=false",
   289  		"--no-host-key-checks",
   290  		"--pty=false",
   291  		"ubuntu@"+apiServerHost,
   292  		"-q",
   293  		"nc %h %p",
   294  	)
   295  	return nil
   296  }
   297  
   298  func (c *SSHCommon) ensureAPIClient() error {
   299  	if c.apiClient != nil {
   300  		return nil
   301  	}
   302  	return errors.Trace(c.initAPIClient())
   303  }
   304  
   305  // initAPIClient initialises the API connection.
   306  func (c *SSHCommon) initAPIClient() error {
   307  	conn, err := c.NewAPIRoot()
   308  	if err != nil {
   309  		return errors.Trace(err)
   310  	}
   311  	c.apiClient = sshclient.NewFacade(conn)
   312  	c.apiAddr = conn.Addr()
   313  	return nil
   314  }
   315  
   316  func (c *SSHCommon) resolveTarget(target string) (*resolvedTarget, error) {
   317  	out, ok := c.resolveAsAgent(target)
   318  	if !ok {
   319  		// Not a machine or unit agent target - use directly.
   320  		return out, nil
   321  	}
   322  
   323  	getAddress := c.reachableAddressGetter
   324  	if c.apiClient.BestAPIVersion() < 2 || c.forceAPIv1 {
   325  		logger.Debugf("using legacy SSHClient API v1: no support for AllAddresses()")
   326  		getAddress = c.legacyAddressGetter
   327  	} else if c.proxy {
   328  		// Ideally a reachability scan would be done from the
   329  		// controller's perspective but that isn't possible yet, so
   330  		// fall back to the legacy mode (i.e. use the instance's
   331  		// "private" address).
   332  		//
   333  		// This is in some ways better anyway as a both the external
   334  		// and internal addresses of an instance (if it has both) are
   335  		// likely to be accessible from the controller. With a
   336  		// reachability scan juju ssh could inadvertently end up using
   337  		// the public address when it really should be using the
   338  		// internal/private address.
   339  		logger.Debugf("proxy-ssh enabled so not doing reachability scan")
   340  		getAddress = c.legacyAddressGetter
   341  	}
   342  
   343  	return c.resolveWithRetry(*out, getAddress)
   344  }
   345  
   346  func (c *SSHCommon) resolveAsAgent(target string) (*resolvedTarget, bool) {
   347  	out := new(resolvedTarget)
   348  	out.user, out.entity = splitUserTarget(target)
   349  	isAgent := out.isAgent()
   350  
   351  	if !isAgent {
   352  		// Not a machine/unit agent target: resolve - use as-is.
   353  		out.host = out.entity
   354  	} else if out.user == "" {
   355  		out.user = "ubuntu"
   356  	}
   357  
   358  	return out, isAgent
   359  }
   360  
   361  type addressGetterFunc func(target string) (string, error)
   362  
   363  func (c *SSHCommon) resolveWithRetry(target resolvedTarget, getAddress addressGetterFunc) (*resolvedTarget, error) {
   364  	// A target may not initially have an address (e.g. the
   365  	// address updater hasn't yet run), so we must do this in
   366  	// a loop.
   367  	var err error
   368  	out := &target
   369  	for a := sshHostFromTargetAttemptStrategy.Start(); a.Next(); {
   370  		out.host, err = getAddress(out.entity)
   371  		if errors.IsNotFound(err) || params.IsCodeNotFound(err) {
   372  			// Catch issues like passing invalid machine/unit IDs early.
   373  			return nil, errors.Trace(err)
   374  		}
   375  
   376  		if err != nil {
   377  			logger.Debugf("getting target %q address(es) failed: %v (retrying)", out.entity, err)
   378  			continue
   379  		}
   380  
   381  		logger.Debugf("using target %q address %q", out.entity, out.host)
   382  		return out, nil
   383  	}
   384  
   385  	return nil, errors.Trace(err)
   386  }
   387  
   388  // legacyAddressGetter returns the preferred public or private address of the
   389  // given entity (private when c.proxy is true), using the apiClient. Only used
   390  // when the SSHClient API facade v2 is not available or when proxy-ssh is set.
   391  func (c *SSHCommon) legacyAddressGetter(entity string) (string, error) {
   392  	if c.proxy {
   393  		return c.apiClient.PrivateAddress(entity)
   394  	}
   395  
   396  	return c.apiClient.PublicAddress(entity)
   397  }
   398  
   399  // reachableAddressGetter dials all addresses of the given entity, returning the
   400  // first one that succeeds. Only used with SSHClient API facade v2 or later is
   401  // available. It does not try to dial if only one address is available.
   402  func (c *SSHCommon) reachableAddressGetter(entity string) (string, error) {
   403  	addresses, err := c.apiClient.AllAddresses(entity)
   404  	if err != nil {
   405  		return "", errors.Trace(err)
   406  	} else if len(addresses) == 0 {
   407  		return "", network.NoAddressError("available")
   408  	} else if len(addresses) == 1 {
   409  		logger.Debugf("Only one SSH address provided (%s), using it without probing", addresses[0])
   410  		return addresses[0], nil
   411  	}
   412  	publicKeys := []string{}
   413  	if !c.noHostKeyChecks {
   414  		publicKeys, err = c.apiClient.PublicKeys(entity)
   415  		if err != nil {
   416  			return "", errors.Annotatef(err, "retrieving SSH host keys for %q", entity)
   417  		}
   418  	}
   419  
   420  	hostPorts := network.NewHostPorts(SSHPort, addresses...)
   421  	usableHPs := network.FilterUnusableHostPorts(hostPorts)
   422  	bestHP, err := c.hostChecker.FindHost(usableHPs, publicKeys)
   423  	if err != nil {
   424  		return "", errors.Trace(err)
   425  	}
   426  
   427  	return bestHP.Address.Value, nil
   428  }
   429  
   430  // AllowInterspersedFlags for ssh/scp is set to false so that
   431  // flags after the unit name are passed through to ssh, for eg.
   432  // `juju ssh -v application-name/0 uname -a`.
   433  func (c *SSHCommon) AllowInterspersedFlags() bool {
   434  	return false
   435  }
   436  
   437  // getJujuExecutable returns the path to the juju
   438  // executable, or an error if it could not be found.
   439  var getJujuExecutable = func() (string, error) {
   440  	return exec.LookPath(os.Args[0])
   441  }
   442  
   443  func targetIsAgent(target string) bool {
   444  	return names.IsValidMachine(target) || names.IsValidUnit(target)
   445  }
   446  
   447  func splitUserTarget(target string) (string, string) {
   448  	if i := strings.IndexRune(target, '@'); i != -1 {
   449  		return target[:i], target[i+1:]
   450  	}
   451  	return "", target
   452  }
   453  
   454  func newKnownHostsBuilder() *knownHostsBuilder {
   455  	return &knownHostsBuilder{
   456  		seen: set.NewStrings(),
   457  	}
   458  }
   459  
   460  // knownHostsBuilder supports the construction of a SSH known_hosts file.
   461  type knownHostsBuilder struct {
   462  	lines []string
   463  	seen  set.Strings
   464  }
   465  
   466  func (b *knownHostsBuilder) add(host string, keys []string) {
   467  	if b.seen.Contains(host) {
   468  		return
   469  	}
   470  	b.seen.Add(host)
   471  	for _, key := range keys {
   472  		b.lines = append(b.lines, host+" "+key+"\n")
   473  	}
   474  }
   475  
   476  func (b *knownHostsBuilder) write(w io.Writer) error {
   477  	bufw := bufio.NewWriter(w)
   478  	for _, line := range b.lines {
   479  		_, err := bufw.WriteString(line)
   480  		if err != nil {
   481  			return errors.Annotate(err, "writing known hosts file")
   482  		}
   483  	}
   484  	bufw.Flush()
   485  	return nil
   486  }
   487  
   488  func (b *knownHostsBuilder) size() int {
   489  	return len(b.lines)
   490  }