launchpad.net/~rogpeppe/juju-core/500-errgo-fix@v0.0.0-20140213181702-000000002356/utils/ssh/ssh_gocrypto.go (about)

     1  // Copyright 2013 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package ssh
     5  
import (
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"os/user"
	"strings"

	"code.google.com/p/go.crypto/ssh"

	"launchpad.net/errgo/errors"
	"launchpad.net/juju-core/utils"
)
    18  
    19  const sshDefaultPort = 22
    20  
    21  // GoCryptoClient is an implementation of Client that
    22  // uses the embedded go.crypto/ssh SSH client.
    23  //
    24  // GoCryptoClient is intentionally limited in the
    25  // functionality that it enables, as it is currently
    26  // intended to be used only for non-interactive command
    27  // execution.
    28  type GoCryptoClient struct {
    29  	signers []ssh.Signer
    30  }
    31  
    32  // NewGoCryptoClient creates a new GoCryptoClient.
    33  //
    34  // If no signers are specified, NewGoCryptoClient will
    35  // use the private key generated by LoadClientKeys.
    36  func NewGoCryptoClient(signers ...ssh.Signer) (*GoCryptoClient, error) {
    37  	return &GoCryptoClient{signers: signers}, nil
    38  }
    39  
    40  // Command implements Client.Command.
    41  func (c *GoCryptoClient) Command(host string, command []string, options *Options) *Cmd {
    42  	shellCommand := utils.CommandString(command...)
    43  	signers := c.signers
    44  	if len(signers) == 0 {
    45  		signers = privateKeys()
    46  	}
    47  	user, host := splitUserHost(host)
    48  	port := sshDefaultPort
    49  	if options != nil {
    50  		if options.port != 0 {
    51  			port = options.port
    52  		}
    53  	}
    54  	return &Cmd{impl: &goCryptoCommand{
    55  		signers: signers,
    56  		user:    user,
    57  		addr:    fmt.Sprintf("%s:%d", host, port),
    58  		command: shellCommand,
    59  	}}
    60  }
    61  
    62  // Copy implements Client.Copy.
    63  //
    64  // Copy is currently unimplemented, and will always return an error.
    65  func (c *GoCryptoClient) Copy(source, dest string, options *Options) error {
    66  	return errors.Newf("Copy is not implemented")
    67  }
    68  
    69  type goCryptoCommand struct {
    70  	signers []ssh.Signer
    71  	user    string
    72  	addr    string
    73  	command string
    74  	stdin   io.Reader
    75  	stdout  io.Writer
    76  	stderr  io.Writer
    77  	conn    *ssh.ClientConn
    78  	sess    *ssh.Session
    79  }
    80  
    81  func (c *goCryptoCommand) ensureSession() (*ssh.Session, error) {
    82  	if c.sess != nil {
    83  		return c.sess, nil
    84  	}
    85  	if len(c.signers) == 0 {
    86  		return nil, errors.Newf("no private keys available")
    87  	}
    88  	if c.user == "" {
    89  		currentUser, err := user.Current()
    90  		if err != nil {
    91  			return nil, errors.Notef(err, "getting current user")
    92  		}
    93  		c.user = currentUser.Username
    94  	}
    95  	config := &ssh.ClientConfig{
    96  		User: c.user,
    97  		Auth: []ssh.ClientAuth{
    98  			ssh.ClientAuthKeyring(keyring{c.signers}),
    99  		},
   100  	}
   101  	conn, err := ssh.Dial("tcp", c.addr, config)
   102  	if err != nil {
   103  		return nil, mask(err)
   104  	}
   105  	sess, err := conn.NewSession()
   106  	if err != nil {
   107  		conn.Close()
   108  		return nil, err
   109  	}
   110  	c.conn = conn
   111  	c.sess = sess
   112  	c.sess.Stdin = c.stdin
   113  	c.sess.Stdout = c.stdout
   114  	c.sess.Stderr = c.stderr
   115  	return sess, nil
   116  }
   117  
   118  func (c *goCryptoCommand) Start() error {
   119  	sess, err := c.ensureSession()
   120  	if err != nil {
   121  		return mask(err)
   122  	}
   123  	if c.command == "" {
   124  		return sess.Shell()
   125  	}
   126  	return sess.Start(c.command)
   127  }
   128  
   129  func (c *goCryptoCommand) Close() error {
   130  	if c.sess == nil {
   131  		return nil
   132  	}
   133  	err0 := c.sess.Close()
   134  	err1 := c.conn.Close()
   135  	if err0 == nil {
   136  		err0 = err1
   137  	}
   138  	c.sess = nil
   139  	c.conn = nil
   140  	return err0
   141  }
   142  
   143  func (c *goCryptoCommand) Wait() error {
   144  	if c.sess == nil {
   145  		return errors.Newf("Command has not been started")
   146  	}
   147  	err := c.sess.Wait()
   148  	c.Close()
   149  	return err
   150  }
   151  
   152  func (c *goCryptoCommand) Kill() error {
   153  	if c.sess == nil {
   154  		return errors.Newf("Command has not been started")
   155  	}
   156  	return c.sess.Signal(ssh.SIGKILL)
   157  }
   158  
   159  func (c *goCryptoCommand) SetStdio(stdin io.Reader, stdout, stderr io.Writer) {
   160  	c.stdin = stdin
   161  	c.stdout = stdout
   162  	c.stderr = stderr
   163  }
   164  
   165  func (c *goCryptoCommand) StdinPipe() (io.WriteCloser, io.Reader, error) {
   166  	sess, err := c.ensureSession()
   167  	if err != nil {
   168  		return nil, nil, mask(err)
   169  	}
   170  	wc, err := sess.StdinPipe()
   171  	return wc, sess.Stdin, err
   172  }
   173  
   174  func (c *goCryptoCommand) StdoutPipe() (io.ReadCloser, io.Writer, error) {
   175  	sess, err := c.ensureSession()
   176  	if err != nil {
   177  		return nil, nil, mask(err)
   178  	}
   179  	wc, err := sess.StdoutPipe()
   180  	return ioutil.NopCloser(wc), sess.Stdout, err
   181  }
   182  
   183  func (c *goCryptoCommand) StderrPipe() (io.ReadCloser, io.Writer, error) {
   184  	sess, err := c.ensureSession()
   185  	if err != nil {
   186  		return nil, nil, mask(err)
   187  	}
   188  	wc, err := sess.StderrPipe()
   189  	return ioutil.NopCloser(wc), sess.Stderr, err
   190  }
   191  
   192  // keyring implements ssh.ClientKeyring
   193  type keyring struct {
   194  	signers []ssh.Signer
   195  }
   196  
   197  func (k keyring) Key(i int) (ssh.PublicKey, error) {
   198  	if i < 0 || i >= len(k.signers) {
   199  		// nil key marks the end of the keyring; must not return an error.
   200  		return nil, nil
   201  	}
   202  	return k.signers[i].PublicKey(), nil
   203  }
   204  
   205  func (k keyring) Sign(i int, rand io.Reader, data []byte) ([]byte, error) {
   206  	if i < 0 || i >= len(k.signers) {
   207  		return nil, errors.Newf("no key at position %d", i)
   208  	}
   209  	return k.signers[i].Sign(rand, data)
   210  }
   211  
   212  func splitUserHost(s string) (user, host string) {
   213  	userHost := strings.SplitN(s, "@", 2)
   214  	if len(userHost) == 2 {
   215  		return userHost[0], userHost[1]
   216  	}
   217  	return "", userHost[0]
   218  }