github.com/Pankov404/juju@v0.0.0-20150703034450-be266991dceb/utils/ssh/ssh_openssh.go (about)

     1  // Copyright 2013 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package ssh
     5  
     6  import (
     7  	"bytes"
     8  	"fmt"
     9  	"io"
    10  	"os"
    11  	"os/exec"
    12  	"strings"
    13  
    14  	"github.com/juju/utils"
    15  )
    16  
    17  var opensshCommonOptions = []string{"-o", "StrictHostKeyChecking no"}
    18  
    19  // default identities will not be attempted if
    20  // -i is specified and they are not explcitly
    21  // included.
    22  var defaultIdentities = []string{
    23  	"~/.ssh/identity",
    24  	"~/.ssh/id_rsa",
    25  	"~/.ssh/id_dsa",
    26  	"~/.ssh/id_ecdsa",
    27  }
    28  
    29  type opensshCommandKind int
    30  
    31  const (
    32  	sshKind opensshCommandKind = iota
    33  	scpKind
    34  )
    35  
    36  // sshpassWrap wraps the command/args with sshpass if it is found in $PATH
    37  // and the SSHPASS environment variable is set. Otherwise, the original
    38  // command/args are returned.
    39  func sshpassWrap(cmd string, args []string) (string, []string) {
    40  	if os.Getenv("SSHPASS") != "" {
    41  		if path, err := exec.LookPath("sshpass"); err == nil {
    42  			return path, append([]string{"-e", cmd}, args...)
    43  		}
    44  	}
    45  	return cmd, args
    46  }
    47  
    48  // OpenSSHClient is an implementation of Client that
    49  // uses the ssh and scp executables found in $PATH.
    50  type OpenSSHClient struct{}
    51  
    52  // NewOpenSSHClient creates a new OpenSSHClient.
    53  // If the ssh and scp programs cannot be found
    54  // in $PATH, then an error is returned.
    55  func NewOpenSSHClient() (*OpenSSHClient, error) {
    56  	var c OpenSSHClient
    57  	if _, err := exec.LookPath("ssh"); err != nil {
    58  		return nil, err
    59  	}
    60  	if _, err := exec.LookPath("scp"); err != nil {
    61  		return nil, err
    62  	}
    63  	return &c, nil
    64  }
    65  
    66  func opensshOptions(options *Options, commandKind opensshCommandKind) []string {
    67  	args := append([]string{}, opensshCommonOptions...)
    68  	if options == nil {
    69  		options = &Options{}
    70  	}
    71  	if len(options.proxyCommand) > 0 {
    72  		args = append(args, "-o", "ProxyCommand "+utils.CommandString(options.proxyCommand...))
    73  	}
    74  	if !options.passwordAuthAllowed {
    75  		args = append(args, "-o", "PasswordAuthentication no")
    76  	}
    77  
    78  	// We must set ServerAliveInterval or the server may
    79  	// think we've become unresponsive on long running
    80  	// command executions such as "apt-get upgrade".
    81  	args = append(args, "-o", "ServerAliveInterval 30")
    82  
    83  	if options.allocatePTY {
    84  		args = append(args, "-t", "-t") // twice to force
    85  	}
    86  	if options.knownHostsFile != "" {
    87  		args = append(args, "-o", "UserKnownHostsFile "+utils.CommandString(options.knownHostsFile))
    88  	}
    89  	identities := append([]string{}, options.identities...)
    90  	if pk := PrivateKeyFiles(); len(pk) > 0 {
    91  		// Add client keys as implicit identities
    92  		identities = append(identities, pk...)
    93  	}
    94  	// If any identities are specified, the
    95  	// default ones must be explicitly specified.
    96  	if len(identities) > 0 {
    97  		// Restrict SSH to only the explicitly provided identity files.
    98  		// Otherwise we may run out of authentication attempts if the
    99  		// user has many identity files.
   100  		args = append(args, "-o", "IdentitiesOnly yes")
   101  		for _, identity := range defaultIdentities {
   102  			path, err := utils.NormalizePath(identity)
   103  			if err != nil {
   104  				logger.Warningf("failed to normalize path %q: %v", identity, err)
   105  				continue
   106  			}
   107  			if _, err := os.Stat(path); err == nil {
   108  				identities = append(identities, path)
   109  			}
   110  		}
   111  	}
   112  	for _, identity := range identities {
   113  		args = append(args, "-i", identity)
   114  	}
   115  	if options.port != 0 {
   116  		port := fmt.Sprint(options.port)
   117  		if commandKind == scpKind {
   118  			// scp uses -P instead of -p (-p means preserve).
   119  			args = append(args, "-P", port)
   120  		} else {
   121  			args = append(args, "-p", port)
   122  		}
   123  	}
   124  	return args
   125  }
   126  
   127  // Command implements Client.Command.
   128  func (c *OpenSSHClient) Command(host string, command []string, options *Options) *Cmd {
   129  	args := opensshOptions(options, sshKind)
   130  	args = append(args, host)
   131  	if len(command) > 0 {
   132  		args = append(args, command...)
   133  	}
   134  	bin, args := sshpassWrap("ssh", args)
   135  	logger.Tracef("running: %s %s", bin, utils.CommandString(args...))
   136  	return &Cmd{impl: &opensshCmd{exec.Command(bin, args...)}}
   137  }
   138  
   139  // Copy implements Client.Copy.
   140  func (c *OpenSSHClient) Copy(args []string, userOptions *Options) error {
   141  	var options Options
   142  	if userOptions != nil {
   143  		options = *userOptions
   144  		options.allocatePTY = false // doesn't make sense for scp
   145  	}
   146  	allArgs := opensshOptions(&options, scpKind)
   147  	allArgs = append(allArgs, args...)
   148  	bin, allArgs := sshpassWrap("scp", allArgs)
   149  	cmd := exec.Command(bin, allArgs...)
   150  	var stderr bytes.Buffer
   151  	cmd.Stderr = &stderr
   152  	logger.Tracef("running: %s %s", bin, utils.CommandString(args...))
   153  	if err := cmd.Run(); err != nil {
   154  		stderr := strings.TrimSpace(stderr.String())
   155  		if len(stderr) > 0 {
   156  			err = fmt.Errorf("%v (%v)", err, stderr)
   157  		}
   158  		return err
   159  	}
   160  	return nil
   161  }
   162  
   163  type opensshCmd struct {
   164  	*exec.Cmd
   165  }
   166  
   167  func (c *opensshCmd) SetStdio(stdin io.Reader, stdout, stderr io.Writer) {
   168  	c.Stdin, c.Stdout, c.Stderr = stdin, stdout, stderr
   169  }
   170  
   171  func (c *opensshCmd) StdinPipe() (io.WriteCloser, io.Reader, error) {
   172  	wc, err := c.Cmd.StdinPipe()
   173  	if err != nil {
   174  		return nil, nil, err
   175  	}
   176  	return wc, c.Stdin, nil
   177  }
   178  
   179  func (c *opensshCmd) StdoutPipe() (io.ReadCloser, io.Writer, error) {
   180  	rc, err := c.Cmd.StdoutPipe()
   181  	if err != nil {
   182  		return nil, nil, err
   183  	}
   184  	return rc, c.Stdout, nil
   185  }
   186  
   187  func (c *opensshCmd) StderrPipe() (io.ReadCloser, io.Writer, error) {
   188  	rc, err := c.Cmd.StderrPipe()
   189  	if err != nil {
   190  		return nil, nil, err
   191  	}
   192  	return rc, c.Stderr, nil
   193  }
   194  
   195  func (c *opensshCmd) Kill() error {
   196  	if c.Process == nil {
   197  		return fmt.Errorf("process has not been started")
   198  	}
   199  	return c.Process.Kill()
   200  }