github.com/cloudbase/juju-core@v0.0.0-20140504232958-a7271ac7912f/utils/ssh/ssh_gocrypto.go (about)

     1  // Copyright 2013 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package ssh
     5  
     6  import (
     7  	"fmt"
     8  	"io"
     9  	"io/ioutil"
    10  	"os/user"
    11  	"strings"
    12  
    13  	"code.google.com/p/go.crypto/ssh"
    14  
    15  	"launchpad.net/juju-core/utils"
    16  )
    17  
// sshDefaultPort is the port Command connects to when options is nil or
// does not specify a non-zero port.
const sshDefaultPort = 22
    19  
// GoCryptoClient is an implementation of Client that
// uses the embedded go.crypto/ssh SSH client.
//
// GoCryptoClient is intentionally limited in the
// functionality that it enables, as it is currently
// intended to be used only for non-interactive command
// execution. In particular, Copy is not supported.
type GoCryptoClient struct {
	// signers are the keys used for public-key authentication. When
	// empty, Command falls back to privateKeys() each time it is called.
	signers []ssh.Signer
}
    30  
    31  // NewGoCryptoClient creates a new GoCryptoClient.
    32  //
    33  // If no signers are specified, NewGoCryptoClient will
    34  // use the private key generated by LoadClientKeys.
    35  func NewGoCryptoClient(signers ...ssh.Signer) (*GoCryptoClient, error) {
    36  	return &GoCryptoClient{signers: signers}, nil
    37  }
    38  
    39  // Command implements Client.Command.
    40  func (c *GoCryptoClient) Command(host string, command []string, options *Options) *Cmd {
    41  	shellCommand := utils.CommandString(command...)
    42  	signers := c.signers
    43  	if len(signers) == 0 {
    44  		signers = privateKeys()
    45  	}
    46  	user, host := splitUserHost(host)
    47  	port := sshDefaultPort
    48  	if options != nil {
    49  		if options.port != 0 {
    50  			port = options.port
    51  		}
    52  	}
    53  	logger.Debugf(`running (equivalent of): ssh "%s@%s" -p %d '%s'`, user, host, port, shellCommand)
    54  	return &Cmd{impl: &goCryptoCommand{
    55  		signers: signers,
    56  		user:    user,
    57  		addr:    fmt.Sprintf("%s:%d", host, port),
    58  		command: shellCommand,
    59  	}}
    60  }
    61  
    62  // Copy implements Client.Copy.
    63  //
    64  // Copy is currently unimplemented, and will always return an error.
    65  func (c *GoCryptoClient) Copy(targets, extraArgs []string, options *Options) error {
    66  	return fmt.Errorf("scp command is not implemented (OpenSSH scp not available in PATH)")
    67  }
    68  
// goCryptoCommand is the command implementation behind the Cmd values
// returned by GoCryptoClient.Command. The SSH connection and session are
// opened lazily by ensureSession and released by Close (which Wait calls).
type goCryptoCommand struct {
	// signers are offered to the server for public-key authentication.
	signers []ssh.Signer
	// user is the remote login name; if empty, ensureSession replaces it
	// with the current local user's name.
	user string
	// addr is the "host:port" address handed to sshDial.
	addr string
	// command is the shell command to run; if empty, Start requests a
	// shell instead.
	command string
	// stdin, stdout and stderr are recorded by SetStdio and copied onto
	// the session when ensureSession creates it.
	stdin  io.Reader
	stdout io.Writer
	stderr io.Writer
	// conn and sess are nil until ensureSession succeeds and are reset to
	// nil by Close.
	conn *ssh.ClientConn
	sess *ssh.Session
}
    80  
// sshDial opens the SSH client connection used by ensureSession. It is a
// variable rather than a direct call to ssh.Dial, presumably so tests can
// substitute a fake dialer (TODO confirm).
var sshDial = ssh.Dial
    82  
    83  func (c *goCryptoCommand) ensureSession() (*ssh.Session, error) {
    84  	if c.sess != nil {
    85  		return c.sess, nil
    86  	}
    87  	if len(c.signers) == 0 {
    88  		return nil, fmt.Errorf("no private keys available")
    89  	}
    90  	if c.user == "" {
    91  		currentUser, err := user.Current()
    92  		if err != nil {
    93  			return nil, fmt.Errorf("getting current user: %v", err)
    94  		}
    95  		c.user = currentUser.Username
    96  	}
    97  	config := &ssh.ClientConfig{
    98  		User: c.user,
    99  		Auth: []ssh.ClientAuth{
   100  			ssh.ClientAuthKeyring(keyring{c.signers}),
   101  		},
   102  	}
   103  	conn, err := sshDial("tcp", c.addr, config)
   104  	if err != nil {
   105  		return nil, err
   106  	}
   107  	sess, err := conn.NewSession()
   108  	if err != nil {
   109  		conn.Close()
   110  		return nil, err
   111  	}
   112  	c.conn = conn
   113  	c.sess = sess
   114  	c.sess.Stdin = c.stdin
   115  	c.sess.Stdout = c.stdout
   116  	c.sess.Stderr = c.stderr
   117  	return sess, nil
   118  }
   119  
   120  func (c *goCryptoCommand) Start() error {
   121  	sess, err := c.ensureSession()
   122  	if err != nil {
   123  		return err
   124  	}
   125  	if c.command == "" {
   126  		return sess.Shell()
   127  	}
   128  	return sess.Start(c.command)
   129  }
   130  
   131  func (c *goCryptoCommand) Close() error {
   132  	if c.sess == nil {
   133  		return nil
   134  	}
   135  	err0 := c.sess.Close()
   136  	err1 := c.conn.Close()
   137  	if err0 == nil {
   138  		err0 = err1
   139  	}
   140  	c.sess = nil
   141  	c.conn = nil
   142  	return err0
   143  }
   144  
   145  func (c *goCryptoCommand) Wait() error {
   146  	if c.sess == nil {
   147  		return fmt.Errorf("Command has not been started")
   148  	}
   149  	err := c.sess.Wait()
   150  	c.Close()
   151  	return err
   152  }
   153  
   154  func (c *goCryptoCommand) Kill() error {
   155  	if c.sess == nil {
   156  		return fmt.Errorf("Command has not been started")
   157  	}
   158  	return c.sess.Signal(ssh.SIGKILL)
   159  }
   160  
   161  func (c *goCryptoCommand) SetStdio(stdin io.Reader, stdout, stderr io.Writer) {
   162  	c.stdin = stdin
   163  	c.stdout = stdout
   164  	c.stderr = stderr
   165  }
   166  
   167  func (c *goCryptoCommand) StdinPipe() (io.WriteCloser, io.Reader, error) {
   168  	sess, err := c.ensureSession()
   169  	if err != nil {
   170  		return nil, nil, err
   171  	}
   172  	wc, err := sess.StdinPipe()
   173  	return wc, sess.Stdin, err
   174  }
   175  
   176  func (c *goCryptoCommand) StdoutPipe() (io.ReadCloser, io.Writer, error) {
   177  	sess, err := c.ensureSession()
   178  	if err != nil {
   179  		return nil, nil, err
   180  	}
   181  	wc, err := sess.StdoutPipe()
   182  	return ioutil.NopCloser(wc), sess.Stdout, err
   183  }
   184  
   185  func (c *goCryptoCommand) StderrPipe() (io.ReadCloser, io.Writer, error) {
   186  	sess, err := c.ensureSession()
   187  	if err != nil {
   188  		return nil, nil, err
   189  	}
   190  	wc, err := sess.StderrPipe()
   191  	return ioutil.NopCloser(wc), sess.Stderr, err
   192  }
   193  
// keyring implements ssh.ClientKeyring on top of a fixed list of signers;
// keyring index i refers to signers[i].
type keyring struct {
	signers []ssh.Signer
}
   198  
   199  func (k keyring) Key(i int) (ssh.PublicKey, error) {
   200  	if i < 0 || i >= len(k.signers) {
   201  		// nil key marks the end of the keyring; must not return an error.
   202  		return nil, nil
   203  	}
   204  	return k.signers[i].PublicKey(), nil
   205  }
   206  
   207  func (k keyring) Sign(i int, rand io.Reader, data []byte) ([]byte, error) {
   208  	if i < 0 || i >= len(k.signers) {
   209  		return nil, fmt.Errorf("no key at position %d", i)
   210  	}
   211  	return k.signers[i].Sign(rand, data)
   212  }
   213  
   214  func splitUserHost(s string) (user, host string) {
   215  	userHost := strings.SplitN(s, "@", 2)
   216  	if len(userHost) == 2 {
   217  		return userHost[0], userHost[1]
   218  	}
   219  	return "", userHost[0]
   220  }