github.com/altoros/juju-vmware@v0.0.0-20150312064031-f19ae857ccca/utils/ssh/ssh.go (about)

     1  // Copyright 2013 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  // Package ssh contains utilities for dealing with SSH connections,
     5  // key management, and so on. All SSH-based command executions in
     6  // Juju should use the Command/ScpCommand functions in this package.
     7  //
     8  package ssh
     9  
    10  import (
    11  	"bytes"
    12  	"errors"
    13  	"io"
    14  	"os/exec"
    15  	"syscall"
    16  
    17  	"github.com/juju/cmd"
    18  	je "github.com/juju/errors"
    19  )
    20  
    21  // Options is a client-implementation independent SSH options set.
    22  type Options struct {
    23  	// proxyCommand specifies the command to
    24  	// execute to proxy SSH traffic through.
    25  	proxyCommand []string
    26  	// ssh server port; zero means use the default (22)
    27  	port int
    28  	// no PTY forced by default
    29  	allocatePTY bool
    30  	// password authentication is disallowed by default
    31  	passwordAuthAllowed bool
    32  	// identities is a sequence of paths to private key/identity files
    33  	// to use when attempting to login. A client implementaton may attempt
    34  	// with additional identities, but must give preference to these
    35  	identities []string
    36  	// knownHostsFile is a path to a file in which to save the host's
    37  	// fingerprint.
    38  	knownHostsFile string
    39  }
    40  
    41  // SetProxyCommand sets a command to execute to proxy traffic through.
    42  func (o *Options) SetProxyCommand(command ...string) {
    43  	o.proxyCommand = append([]string{}, command...)
    44  }
    45  
    46  // SetPort sets the SSH server port to connect to.
    47  func (o *Options) SetPort(port int) {
    48  	o.port = port
    49  }
    50  
    51  // EnablePTY forces the allocation of a pseudo-TTY.
    52  //
    53  // Forcing a pseudo-TTY is required, for example, for sudo
    54  // prompts on the target host.
    55  func (o *Options) EnablePTY() {
    56  	o.allocatePTY = true
    57  }
    58  
    59  // SetKnownHostsFile sets the host's fingerprint to be saved in the given file.
    60  //
    61  // Host fingerprints are saved in ~/.ssh/known_hosts by default.
    62  func (o *Options) SetKnownHostsFile(file string) {
    63  	o.knownHostsFile = file
    64  }
    65  
    66  // AllowPasswordAuthentication allows the SSH
    67  // client to prompt the user for a password.
    68  //
    69  // Password authentication is disallowed by default.
    70  func (o *Options) AllowPasswordAuthentication() {
    71  	o.passwordAuthAllowed = true
    72  }
    73  
    74  // SetIdentities sets a sequence of paths to private key/identity files
    75  // to use when attempting login. Client implementations may attempt to
    76  // use additional identities, but must give preference to the ones
    77  // specified here.
    78  func (o *Options) SetIdentities(identityFiles ...string) {
    79  	o.identities = append([]string{}, identityFiles...)
    80  }
    81  
    82  // Client is an interface for SSH clients to implement
    83  type Client interface {
    84  	// Command returns a Command for executing a command
    85  	// on the specified host. Each Command is executed
    86  	// within its own SSH session.
    87  	//
    88  	// Host is specified in the format [user@]host.
    89  	Command(host string, command []string, options *Options) *Cmd
    90  
    91  	// Copy copies file(s) between local and remote host(s).
    92  	// Paths are specified in the scp format, [[user@]host:]path. If
    93  	// any extra arguments are specified in extraArgs, they are passed
    94  	// verbatim.
    95  	Copy(args []string, options *Options) error
    96  }
    97  
    98  // Cmd represents a command to be (or being) executed
    99  // on a remote host.
   100  type Cmd struct {
   101  	Stdin  io.Reader
   102  	Stdout io.Writer
   103  	Stderr io.Writer
   104  	impl   command
   105  }
   106  
   107  func newCmd(impl command) *Cmd {
   108  	return &Cmd{impl: impl}
   109  }
   110  
   111  // CombinedOutput runs the command, and returns the
   112  // combined stdout/stderr output and result of
   113  // executing the command.
   114  func (c *Cmd) CombinedOutput() ([]byte, error) {
   115  	if c.Stdout != nil {
   116  		return nil, errors.New("ssh: Stdout already set")
   117  	}
   118  	if c.Stderr != nil {
   119  		return nil, errors.New("ssh: Stderr already set")
   120  	}
   121  	var b bytes.Buffer
   122  	c.Stdout = &b
   123  	c.Stderr = &b
   124  	err := c.Run()
   125  	return b.Bytes(), err
   126  }
   127  
   128  // Output runs the command, and returns the stdout
   129  // output and result of executing the command.
   130  func (c *Cmd) Output() ([]byte, error) {
   131  	if c.Stdout != nil {
   132  		return nil, errors.New("ssh: Stdout already set")
   133  	}
   134  	var b bytes.Buffer
   135  	c.Stdout = &b
   136  	err := c.Run()
   137  	return b.Bytes(), err
   138  }
   139  
   140  // Run runs the command, and returns the result as an error.
   141  func (c *Cmd) Run() error {
   142  	if err := c.Start(); err != nil {
   143  		return err
   144  	}
   145  	err := c.Wait()
   146  	if exitError, ok := err.(*exec.ExitError); ok && exitError != nil {
   147  		status := exitError.ProcessState.Sys().(syscall.WaitStatus)
   148  		if status.Exited() {
   149  			return cmd.NewRcPassthroughError(status.ExitStatus())
   150  		}
   151  	}
   152  	return err
   153  }
   154  
   155  // Start starts the command running, but does not wait for
   156  // it to complete. If the command could not be started, an
   157  // error is returned.
   158  func (c *Cmd) Start() error {
   159  	c.impl.SetStdio(c.Stdin, c.Stdout, c.Stderr)
   160  	return c.impl.Start()
   161  }
   162  
   163  // Wait waits for the started command to complete,
   164  // and returns the result as an error.
   165  func (c *Cmd) Wait() error {
   166  	return c.impl.Wait()
   167  }
   168  
   169  // Kill kills the started command.
   170  func (c *Cmd) Kill() error {
   171  	return c.impl.Kill()
   172  }
   173  
   174  // StdinPipe creates a pipe and connects it to
   175  // the command's stdin. The read end of the pipe
   176  // is assigned to c.Stdin.
   177  func (c *Cmd) StdinPipe() (io.WriteCloser, error) {
   178  	wc, r, err := c.impl.StdinPipe()
   179  	if err != nil {
   180  		return nil, err
   181  	}
   182  	c.Stdin = r
   183  	return wc, nil
   184  }
   185  
   186  // StdoutPipe creates a pipe and connects it to
   187  // the command's stdout. The write end of the pipe
   188  // is assigned to c.Stdout.
   189  func (c *Cmd) StdoutPipe() (io.ReadCloser, error) {
   190  	rc, w, err := c.impl.StdoutPipe()
   191  	if err != nil {
   192  		return nil, err
   193  	}
   194  	c.Stdout = w
   195  	return rc, nil
   196  }
   197  
   198  // StderrPipe creates a pipe and connects it to
   199  // the command's stderr. The write end of the pipe
   200  // is assigned to c.Stderr.
   201  func (c *Cmd) StderrPipe() (io.ReadCloser, error) {
   202  	rc, w, err := c.impl.StderrPipe()
   203  	if err != nil {
   204  		return nil, err
   205  	}
   206  	c.Stderr = w
   207  	return rc, nil
   208  }
   209  
   210  // command is an implementation-specific representation of a
   211  // command prepared to execute against a specific host.
   212  type command interface {
   213  	Start() error
   214  	Wait() error
   215  	Kill() error
   216  	SetStdio(stdin io.Reader, stdout, stderr io.Writer)
   217  	StdinPipe() (io.WriteCloser, io.Reader, error)
   218  	StdoutPipe() (io.ReadCloser, io.Writer, error)
   219  	StderrPipe() (io.ReadCloser, io.Writer, error)
   220  }
   221  
   222  // DefaultClient is the default SSH client for the process.
   223  //
   224  // If the OpenSSH client is found in $PATH, then it will be
   225  // used for DefaultClient; otherwise, DefaultClient will use
   226  // an embedded client based on go.crypto/ssh.
   227  var DefaultClient Client
   228  
   229  // chosenClient holds the type of SSH client created for
   230  // DefaultClient, so that we can log it in Command or Copy.
   231  var chosenClient string
   232  
   233  func init() {
   234  	initDefaultClient()
   235  }
   236  
   237  func initDefaultClient() {
   238  	if client, err := NewOpenSSHClient(); err == nil {
   239  		DefaultClient = client
   240  		chosenClient = "OpenSSH"
   241  	} else if client, err := NewGoCryptoClient(); err == nil {
   242  		DefaultClient = client
   243  		chosenClient = "go.crypto (embedded)"
   244  	}
   245  }
   246  
   247  // Command is a short-cut for DefaultClient.Command.
   248  func Command(host string, command []string, options *Options) *Cmd {
   249  	logger.Debugf("using %s ssh client", chosenClient)
   250  	return DefaultClient.Command(host, command, options)
   251  }
   252  
   253  // Copy is a short-cut for DefaultClient.Copy.
   254  func Copy(args []string, options *Options) error {
   255  	logger.Debugf("using %s ssh client", chosenClient)
   256  	return DefaultClient.Copy(args, options)
   257  }
   258  
   259  // CopyReader sends the reader's data to a file on the remote host over SSH.
   260  func CopyReader(host, filename string, r io.Reader, options *Options) error {
   261  	logger.Debugf("using %s ssh client", chosenClient)
   262  	return copyReader(DefaultClient, host, filename, r, options)
   263  }
   264  
   265  func copyReader(client Client, host, filename string, r io.Reader, options *Options) error {
   266  	cmd := client.Command(host, []string{"cat - > " + filename}, options)
   267  	cmd.Stdin = r
   268  	return je.Trace(cmd.Run())
   269  }