github.com/cloudbase/juju-core@v0.0.0-20140504232958-a7271ac7912f/utils/ssh/ssh_openssh.go (about)

     1  // Copyright 2013 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package ssh
     5  
import (
	"bytes"
	"fmt"
	"io"
	"os"
	"os/exec"
	"sort"
	"strings"

	"launchpad.net/juju-core/utils"
)
    16  
    17  var opensshCommonOptions = map[string][]string{"-o": []string{"StrictHostKeyChecking no"}}
    18  
    19  // default identities will not be attempted if
    20  // -i is specified and they are not explcitly
    21  // included.
    22  var defaultIdentities = []string{
    23  	"~/.ssh/identity",
    24  	"~/.ssh/id_rsa",
    25  	"~/.ssh/id_dsa",
    26  	"~/.ssh/id_ecdsa",
    27  }
    28  
    29  type opensshCommandKind int
    30  
    31  const (
    32  	sshKind opensshCommandKind = iota
    33  	scpKind
    34  )
    35  
    36  // sshpassWrap wraps the command/args with sshpass if it is found in $PATH
    37  // and the SSHPASS environment variable is set. Otherwise, the original
    38  // command/args are returned.
    39  func sshpassWrap(cmd string, args []string) (string, []string) {
    40  	if os.Getenv("SSHPASS") != "" {
    41  		if path, err := exec.LookPath("sshpass"); err == nil {
    42  			return path, append([]string{"-e", cmd}, args...)
    43  		}
    44  	}
    45  	return cmd, args
    46  }
    47  
    48  // OpenSSHClient is an implementation of Client that
    49  // uses the ssh and scp executables found in $PATH.
    50  type OpenSSHClient struct{}
    51  
    52  // NewOpenSSHClient creates a new OpenSSHClient.
    53  // If the ssh and scp programs cannot be found
    54  // in $PATH, then an error is returned.
    55  func NewOpenSSHClient() (*OpenSSHClient, error) {
    56  	var c OpenSSHClient
    57  	if _, err := exec.LookPath("ssh"); err != nil {
    58  		return nil, err
    59  	}
    60  	if _, err := exec.LookPath("scp"); err != nil {
    61  		return nil, err
    62  	}
    63  	return &c, nil
    64  }
    65  
    66  func opensshOptions(options *Options, commandKind opensshCommandKind) map[string][]string {
    67  	args := make(map[string][]string)
    68  	for k, v := range opensshCommonOptions {
    69  		args[k] = v
    70  	}
    71  	if options == nil {
    72  		options = &Options{}
    73  	}
    74  	if !options.passwordAuthAllowed {
    75  		args["-o"] = append(args["-o"], "PasswordAuthentication no")
    76  	}
    77  	if options.allocatePTY {
    78  		args["-t"] = []string{}
    79  	}
    80  	identities := append([]string{}, options.identities...)
    81  	if pk := PrivateKeyFiles(); len(pk) > 0 {
    82  		// Add client keys as implicit identities
    83  		identities = append(identities, pk...)
    84  	}
    85  	// If any identities are specified, the
    86  	// default ones must be explicitly specified.
    87  	if len(identities) > 0 {
    88  		for _, identity := range defaultIdentities {
    89  			path, err := utils.NormalizePath(identity)
    90  			if err != nil {
    91  				logger.Warningf("failed to normalize path %q: %v", identity, err)
    92  				continue
    93  			}
    94  			if _, err := os.Stat(path); err == nil {
    95  				identities = append(identities, path)
    96  			}
    97  		}
    98  	}
    99  	for _, identity := range identities {
   100  		args["-i"] = append(args["-i"], identity)
   101  	}
   102  	if options.port != 0 {
   103  		port := fmt.Sprint(options.port)
   104  		if commandKind == scpKind {
   105  			// scp uses -P instead of -p (-p means preserve).
   106  			args["-P"] = []string{port}
   107  		} else {
   108  			args["-p"] = []string{port}
   109  		}
   110  	}
   111  	return args
   112  }
   113  
   114  func expandArgs(args map[string][]string, quote bool) []string {
   115  	var list []string
   116  	for opt, vals := range args {
   117  		if len(vals) == 0 {
   118  			list = append(list, opt)
   119  			if opt == "-t" {
   120  				// In order to force a PTY to be allocated, we need to
   121  				// pass -t twice.
   122  				list = append(list, opt)
   123  			}
   124  		}
   125  		for _, val := range vals {
   126  			list = append(list, opt)
   127  			if quote {
   128  				val = fmt.Sprintf("%q", val)
   129  			}
   130  			list = append(list, val)
   131  		}
   132  	}
   133  	return list
   134  }
   135  
   136  // Command implements Client.Command.
   137  func (c *OpenSSHClient) Command(host string, command []string, options *Options) *Cmd {
   138  	opts := opensshOptions(options, sshKind)
   139  	args := expandArgs(opts, false)
   140  	args = append(args, host)
   141  	if len(command) > 0 {
   142  		args = append(args, command...)
   143  	}
   144  	bin, args := sshpassWrap("ssh", args)
   145  	optsList := strings.Join(expandArgs(opts, true), " ")
   146  	fullCommand := strings.Join(command, " ")
   147  	logger.Debugf("running: %s %s %q '%s'", bin, optsList, host, fullCommand)
   148  	return &Cmd{impl: &opensshCmd{exec.Command(bin, args...)}}
   149  }
   150  
   151  // Copy implements Client.Copy.
   152  func (c *OpenSSHClient) Copy(targets, extraArgs []string, userOptions *Options) error {
   153  	var options Options
   154  	if userOptions != nil {
   155  		options = *userOptions
   156  		options.allocatePTY = false // doesn't make sense for scp
   157  	}
   158  	opts := opensshOptions(&options, scpKind)
   159  	args := expandArgs(opts, false)
   160  	args = append(args, extraArgs...)
   161  	args = append(args, targets...)
   162  	bin, args := sshpassWrap("scp", args)
   163  	cmd := exec.Command(bin, args...)
   164  	var stderr bytes.Buffer
   165  	cmd.Stderr = &stderr
   166  	allOpts := append(expandArgs(opts, true), extraArgs...)
   167  	optsList := strings.Join(allOpts, " ")
   168  	targetList := `"` + strings.Join(targets, `" "`) + `"`
   169  	logger.Debugf("running: %s %s %s", bin, optsList, targetList)
   170  	if err := cmd.Run(); err != nil {
   171  		stderr := strings.TrimSpace(stderr.String())
   172  		if len(stderr) > 0 {
   173  			err = fmt.Errorf("%v (%v)", err, stderr)
   174  		}
   175  		return err
   176  	}
   177  	return nil
   178  }
   179  
   180  type opensshCmd struct {
   181  	*exec.Cmd
   182  }
   183  
   184  func (c *opensshCmd) SetStdio(stdin io.Reader, stdout, stderr io.Writer) {
   185  	c.Stdin, c.Stdout, c.Stderr = stdin, stdout, stderr
   186  }
   187  
   188  func (c *opensshCmd) StdinPipe() (io.WriteCloser, io.Reader, error) {
   189  	wc, err := c.Cmd.StdinPipe()
   190  	if err != nil {
   191  		return nil, nil, err
   192  	}
   193  	return wc, c.Stdin, nil
   194  }
   195  
   196  func (c *opensshCmd) StdoutPipe() (io.ReadCloser, io.Writer, error) {
   197  	rc, err := c.Cmd.StdoutPipe()
   198  	if err != nil {
   199  		return nil, nil, err
   200  	}
   201  	return rc, c.Stdout, nil
   202  }
   203  
   204  func (c *opensshCmd) StderrPipe() (io.ReadCloser, io.Writer, error) {
   205  	rc, err := c.Cmd.StderrPipe()
   206  	if err != nil {
   207  		return nil, nil, err
   208  	}
   209  	return rc, c.Stderr, nil
   210  }
   211  
   212  func (c *opensshCmd) Kill() error {
   213  	if c.Process == nil {
   214  		return fmt.Errorf("process has not been started")
   215  	}
   216  	return c.Process.Kill()
   217  }