launchpad.net/~rogpeppe/juju-core/500-errgo-fix@v0.0.0-20140213181702-000000002356/utils/ssh/ssh_openssh.go (about)

     1  // Copyright 2013 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package ssh
     5  
     6  import (
     7  	"bytes"
     8  	"fmt"
     9  	"io"
    10  	"os"
    11  	"os/exec"
    12  	"strings"
    13  
    14  	"launchpad.net/errgo/errors"
    15  	"launchpad.net/juju-core/utils"
    16  )
    17  
    18  var opensshCommonOptions = []string{"-o", "StrictHostKeyChecking no"}
    19  
    20  // default identities will not be attempted if
    21  // -i is specified and they are not explcitly
    22  // included.
    23  var defaultIdentities = []string{
    24  	"~/.ssh/identity",
    25  	"~/.ssh/id_rsa",
    26  	"~/.ssh/id_dsa",
    27  	"~/.ssh/id_ecdsa",
    28  }
    29  
    30  type opensshCommandKind int
    31  
    32  const (
    33  	sshKind opensshCommandKind = iota
    34  	scpKind
    35  )
    36  
    37  // sshpassWrap wraps the command/args with sshpass if it is found in $PATH
    38  // and the SSHPASS environment variable is set. Otherwise, the original
    39  // command/args are returned.
    40  func sshpassWrap(cmd string, args []string) (string, []string) {
    41  	if os.Getenv("SSHPASS") != "" {
    42  		if path, err := exec.LookPath("sshpass"); err == nil {
    43  			return path, append([]string{"-e", cmd}, args...)
    44  		}
    45  	}
    46  	return cmd, args
    47  }
    48  
    49  // OpenSSHClient is an implementation of Client that
    50  // uses the ssh and scp executables found in $PATH.
    51  type OpenSSHClient struct{}
    52  
    53  // NewOpenSSHClient creates a new OpenSSHClient.
    54  // If the ssh and scp programs cannot be found
    55  // in $PATH, then an error is returned.
    56  func NewOpenSSHClient() (*OpenSSHClient, error) {
    57  	var c OpenSSHClient
    58  	if _, err := exec.LookPath("ssh"); err != nil {
    59  		return nil, mask(err)
    60  	}
    61  	if _, err := exec.LookPath("scp"); err != nil {
    62  		return nil, mask(err)
    63  	}
    64  	return &c, nil
    65  }
    66  
    67  func opensshOptions(options *Options, commandKind opensshCommandKind) []string {
    68  	args := append([]string{}, opensshCommonOptions...)
    69  	if options == nil {
    70  		options = &Options{}
    71  	}
    72  	if !options.passwordAuthAllowed {
    73  		args = append(args, "-o", "PasswordAuthentication no")
    74  	}
    75  	if options.allocatePTY {
    76  		args = append(args, "-t")
    77  	}
    78  	identities := append([]string{}, options.identities...)
    79  	if pk := PrivateKeyFiles(); len(pk) > 0 {
    80  		// Add client keys as implicit identities
    81  		identities = append(identities, pk...)
    82  	}
    83  	// If any identities are specified, the
    84  	// default ones must be explicitly specified.
    85  	if len(identities) > 0 {
    86  		for _, identity := range defaultIdentities {
    87  			path, err := utils.NormalizePath(identity)
    88  			if err != nil {
    89  				logger.Warningf("failed to normalize path %q: %v", identity, err)
    90  				continue
    91  			}
    92  			if _, err := os.Stat(path); err == nil {
    93  				identities = append(identities, path)
    94  			}
    95  		}
    96  	}
    97  	for _, identity := range identities {
    98  		args = append(args, "-i", identity)
    99  	}
   100  	if options.port != 0 {
   101  		if commandKind == scpKind {
   102  			// scp uses -P instead of -p (-p means preserve).
   103  			args = append(args, "-P")
   104  		} else {
   105  			args = append(args, "-p")
   106  		}
   107  		args = append(args, fmt.Sprint(options.port))
   108  	}
   109  	return args
   110  }
   111  
   112  // Command implements Client.Command.
   113  func (c *OpenSSHClient) Command(host string, command []string, options *Options) *Cmd {
   114  	args := opensshOptions(options, sshKind)
   115  	args = append(args, host)
   116  	if len(command) > 0 {
   117  		args = append(args, "--")
   118  		args = append(args, command...)
   119  	}
   120  	bin, args := sshpassWrap("ssh", args)
   121  	return &Cmd{impl: &opensshCmd{exec.Command(bin, args...)}}
   122  }
   123  
   124  // Copy implements Client.Copy.
   125  func (c *OpenSSHClient) Copy(source, dest string, userOptions *Options) error {
   126  	var options Options
   127  	if userOptions != nil {
   128  		options = *userOptions
   129  		options.allocatePTY = false // doesn't make sense for scp
   130  	}
   131  	args := opensshOptions(&options, scpKind)
   132  	args = append(args, source, dest)
   133  	bin, args := sshpassWrap("scp", args)
   134  	cmd := exec.Command(bin, args...)
   135  	var stderr bytes.Buffer
   136  	cmd.Stderr = &stderr
   137  	if err := cmd.Run(); err != nil {
   138  		stderr := strings.TrimSpace(stderr.String())
   139  		if len(stderr) > 0 {
   140  			err = errors.Newf("%v (%v)", err, stderr)
   141  		}
   142  		return err
   143  	}
   144  	return nil
   145  }
   146  
   147  type opensshCmd struct {
   148  	*exec.Cmd
   149  }
   150  
   151  func (c *opensshCmd) SetStdio(stdin io.Reader, stdout, stderr io.Writer) {
   152  	c.Stdin, c.Stdout, c.Stderr = stdin, stdout, stderr
   153  }
   154  
   155  func (c *opensshCmd) StdinPipe() (io.WriteCloser, io.Reader, error) {
   156  	wc, err := c.Cmd.StdinPipe()
   157  	if err != nil {
   158  		return nil, nil, mask(err)
   159  	}
   160  	return wc, c.Stdin, nil
   161  }
   162  
   163  func (c *opensshCmd) StdoutPipe() (io.ReadCloser, io.Writer, error) {
   164  	rc, err := c.Cmd.StdoutPipe()
   165  	if err != nil {
   166  		return nil, nil, mask(err)
   167  	}
   168  	return rc, c.Stdout, nil
   169  }
   170  
   171  func (c *opensshCmd) StderrPipe() (io.ReadCloser, io.Writer, error) {
   172  	rc, err := c.Cmd.StderrPipe()
   173  	if err != nil {
   174  		return nil, nil, mask(err)
   175  	}
   176  	return rc, c.Stderr, nil
   177  }
   178  
   179  func (c *opensshCmd) Kill() error {
   180  	if c.Process == nil {
   181  		return errors.Newf("process has not been started")
   182  	}
   183  	return c.Process.Kill()
   184  }