github.com/altoros/juju-vmware@v0.0.0-20150312064031-f19ae857ccca/utils/ssh/ssh_openssh.go (about)

     1  // Copyright 2013 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package ssh
     5  
     6  import (
     7  	"bytes"
     8  	"fmt"
     9  	"io"
    10  	"os"
    11  	"os/exec"
    12  	"strings"
    13  
    14  	"github.com/juju/utils"
    15  )
    16  
    17  var opensshCommonOptions = []string{"-o", "StrictHostKeyChecking no"}
    18  
    19  // default identities will not be attempted if
    20  // -i is specified and they are not explcitly
    21  // included.
    22  var defaultIdentities = []string{
    23  	"~/.ssh/identity",
    24  	"~/.ssh/id_rsa",
    25  	"~/.ssh/id_dsa",
    26  	"~/.ssh/id_ecdsa",
    27  }
    28  
    29  type opensshCommandKind int
    30  
    31  const (
    32  	sshKind opensshCommandKind = iota
    33  	scpKind
    34  )
    35  
    36  // sshpassWrap wraps the command/args with sshpass if it is found in $PATH
    37  // and the SSHPASS environment variable is set. Otherwise, the original
    38  // command/args are returned.
    39  func sshpassWrap(cmd string, args []string) (string, []string) {
    40  	if os.Getenv("SSHPASS") != "" {
    41  		if path, err := exec.LookPath("sshpass"); err == nil {
    42  			return path, append([]string{"-e", cmd}, args...)
    43  		}
    44  	}
    45  	return cmd, args
    46  }
    47  
    48  // OpenSSHClient is an implementation of Client that
    49  // uses the ssh and scp executables found in $PATH.
    50  type OpenSSHClient struct{}
    51  
    52  // NewOpenSSHClient creates a new OpenSSHClient.
    53  // If the ssh and scp programs cannot be found
    54  // in $PATH, then an error is returned.
    55  func NewOpenSSHClient() (*OpenSSHClient, error) {
    56  	var c OpenSSHClient
    57  	if _, err := exec.LookPath("ssh"); err != nil {
    58  		return nil, err
    59  	}
    60  	if _, err := exec.LookPath("scp"); err != nil {
    61  		return nil, err
    62  	}
    63  	return &c, nil
    64  }
    65  
    66  func opensshOptions(options *Options, commandKind opensshCommandKind) []string {
    67  	args := append([]string{}, opensshCommonOptions...)
    68  	if options == nil {
    69  		options = &Options{}
    70  	}
    71  	if len(options.proxyCommand) > 0 {
    72  		args = append(args, "-o", "ProxyCommand "+utils.CommandString(options.proxyCommand...))
    73  	}
    74  	if !options.passwordAuthAllowed {
    75  		args = append(args, "-o", "PasswordAuthentication no")
    76  	}
    77  
    78  	// We must set ServerAliveInterval or the server may
    79  	// think we've become unresponsive on long running
    80  	// command executions such as "apt-get upgrade".
    81  	args = append(args, "-o", "ServerAliveInterval 30")
    82  
    83  	if options.allocatePTY {
    84  		args = append(args, "-t", "-t") // twice to force
    85  	}
    86  	if options.knownHostsFile != "" {
    87  		args = append(args, "-o", "UserKnownHostsFile "+utils.CommandString(options.knownHostsFile))
    88  	}
    89  	identities := append([]string{}, options.identities...)
    90  	if pk := PrivateKeyFiles(); len(pk) > 0 {
    91  		// Add client keys as implicit identities
    92  		identities = append(identities, pk...)
    93  	}
    94  	// If any identities are specified, the
    95  	// default ones must be explicitly specified.
    96  	if len(identities) > 0 {
    97  		for _, identity := range defaultIdentities {
    98  			path, err := utils.NormalizePath(identity)
    99  			if err != nil {
   100  				logger.Warningf("failed to normalize path %q: %v", identity, err)
   101  				continue
   102  			}
   103  			if _, err := os.Stat(path); err == nil {
   104  				identities = append(identities, path)
   105  			}
   106  		}
   107  	}
   108  	for _, identity := range identities {
   109  		args = append(args, "-i", identity)
   110  	}
   111  	if options.port != 0 {
   112  		port := fmt.Sprint(options.port)
   113  		if commandKind == scpKind {
   114  			// scp uses -P instead of -p (-p means preserve).
   115  			args = append(args, "-P", port)
   116  		} else {
   117  			args = append(args, "-p", port)
   118  		}
   119  	}
   120  	return args
   121  }
   122  
   123  // Command implements Client.Command.
   124  func (c *OpenSSHClient) Command(host string, command []string, options *Options) *Cmd {
   125  	args := opensshOptions(options, sshKind)
   126  	args = append(args, host)
   127  	if len(command) > 0 {
   128  		args = append(args, command...)
   129  	}
   130  	bin, args := sshpassWrap("ssh", args)
   131  	logger.Tracef("running: %s %s", bin, utils.CommandString(args...))
   132  	return &Cmd{impl: &opensshCmd{exec.Command(bin, args...)}}
   133  }
   134  
   135  // Copy implements Client.Copy.
   136  func (c *OpenSSHClient) Copy(args []string, userOptions *Options) error {
   137  	var options Options
   138  	if userOptions != nil {
   139  		options = *userOptions
   140  		options.allocatePTY = false // doesn't make sense for scp
   141  	}
   142  	allArgs := opensshOptions(&options, scpKind)
   143  	allArgs = append(allArgs, args...)
   144  	bin, allArgs := sshpassWrap("scp", allArgs)
   145  	cmd := exec.Command(bin, allArgs...)
   146  	var stderr bytes.Buffer
   147  	cmd.Stderr = &stderr
   148  	logger.Tracef("running: %s %s", bin, utils.CommandString(args...))
   149  	if err := cmd.Run(); err != nil {
   150  		stderr := strings.TrimSpace(stderr.String())
   151  		if len(stderr) > 0 {
   152  			err = fmt.Errorf("%v (%v)", err, stderr)
   153  		}
   154  		return err
   155  	}
   156  	return nil
   157  }
   158  
   159  type opensshCmd struct {
   160  	*exec.Cmd
   161  }
   162  
   163  func (c *opensshCmd) SetStdio(stdin io.Reader, stdout, stderr io.Writer) {
   164  	c.Stdin, c.Stdout, c.Stderr = stdin, stdout, stderr
   165  }
   166  
   167  func (c *opensshCmd) StdinPipe() (io.WriteCloser, io.Reader, error) {
   168  	wc, err := c.Cmd.StdinPipe()
   169  	if err != nil {
   170  		return nil, nil, err
   171  	}
   172  	return wc, c.Stdin, nil
   173  }
   174  
   175  func (c *opensshCmd) StdoutPipe() (io.ReadCloser, io.Writer, error) {
   176  	rc, err := c.Cmd.StdoutPipe()
   177  	if err != nil {
   178  		return nil, nil, err
   179  	}
   180  	return rc, c.Stdout, nil
   181  }
   182  
   183  func (c *opensshCmd) StderrPipe() (io.ReadCloser, io.Writer, error) {
   184  	rc, err := c.Cmd.StderrPipe()
   185  	if err != nil {
   186  		return nil, nil, err
   187  	}
   188  	return rc, c.Stderr, nil
   189  }
   190  
   191  func (c *opensshCmd) Kill() error {
   192  	if c.Process == nil {
   193  		return fmt.Errorf("process has not been started")
   194  	}
   195  	return c.Process.Kill()
   196  }