github.com/cloudfoundry-attic/ltc@v0.0.0-20151123212628-098adc7919fc/ssh/sshapi/client.go (about)

     1  package sshapi
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"net"
     7  	"os"
     8  	"sync"
     9  	"time"
    10  
    11  	"golang.org/x/crypto/ssh"
    12  )
    13  
// DialFunc is the function New calls to open the SSH connection. It is a
// package-level variable initialised to ssh.Dial, so it can be reassigned
// (presumably by tests, to avoid real network access — confirm with callers).
var DialFunc = ssh.Dial
    15  
// Dialer opens a connection to addr on the network named by n and returns it
// as a net.Conn. Client.Forward calls it with the network "tcp" to reach the
// remote end of a forwarded connection.
//
//go:generate counterfeiter -o mocks/fake_dialer.go . Dialer
type Dialer interface {
	Dial(n, addr string) (net.Conn, error)
}
    20  
// SSHSessionFactory creates SSH sessions. Client.Open calls New once per
// invocation and builds the returned Session around the result.
//
//go:generate counterfeiter -o mocks/fake_ssh_session_factory.go . SSHSessionFactory
type SSHSessionFactory interface {
	New() (SSHSession, error)
}
    25  
// Client forwards connections and opens sessions over an SSH connection.
//
// Dialer is used by Forward to connect to the remote address, and
// SSHSessionFactory is used by Open to create the underlying session. Stdin is
// copied into a session's standard input, while the session's standard output
// and standard error are copied to Stdout and Stderr respectively. New sets the
// three streams to the process's os.Stdin, os.Stdout and os.Stderr.
type Client struct {
	Dialer            Dialer
	SSHSessionFactory SSHSessionFactory
	Stdin             io.Reader
	Stdout            io.Writer
	Stderr            io.Writer
}
    33  
    34  func New(user, authUser, authPassword, address string) (*Client, error) {
    35  	config := &ssh.ClientConfig{
    36  		User: user,
    37  		Auth: []ssh.AuthMethod{ssh.Password(fmt.Sprintf("%s:%s", authUser, authPassword))},
    38  	}
    39  
    40  	client, err := DialFunc("tcp", address, config)
    41  	if err != nil {
    42  		return nil, err
    43  	}
    44  
    45  	return &Client{
    46  		Dialer:            client,
    47  		SSHSessionFactory: &CryptoSSHSessionFactory{client},
    48  		Stdin:             os.Stdin,
    49  		Stdout:            os.Stdout,
    50  		Stderr:            os.Stderr,
    51  	}, nil
    52  }
    53  
    54  func (c *Client) Forward(localConn io.ReadWriteCloser, remoteAddress string) error {
    55  	remoteConn, err := c.Dialer.Dial("tcp", remoteAddress)
    56  	if err != nil {
    57  		return err
    58  	}
    59  
    60  	wg := &sync.WaitGroup{}
    61  	wg.Add(2)
    62  
    63  	go copyAndClose(wg, localConn, remoteConn)
    64  	go copyAndClose(wg, remoteConn, localConn)
    65  	wg.Wait()
    66  
    67  	return nil
    68  }
    69  
    70  func (c *Client) Open(width, height int, desirePTY bool) (*Session, error) {
    71  	session, err := c.SSHSessionFactory.New()
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  
    76  	sessionIn, err := session.StdinPipe()
    77  	if err != nil {
    78  		return nil, err
    79  	}
    80  
    81  	sessionOut, err := session.StdoutPipe()
    82  	if err != nil {
    83  		return nil, err
    84  	}
    85  
    86  	sessionErr, err := session.StderrPipe()
    87  	if err != nil {
    88  		return nil, err
    89  	}
    90  
    91  	if desirePTY {
    92  		modes := ssh.TerminalModes{
    93  			ssh.ECHO:          1,
    94  			ssh.TTY_OP_ISPEED: 115200,
    95  			ssh.TTY_OP_OSPEED: 115200,
    96  		}
    97  
    98  		terminalType := os.Getenv("TERM")
    99  		if terminalType == "" {
   100  			terminalType = "xterm"
   101  		}
   102  
   103  		if err := session.RequestPty(terminalType, height, width, modes); err != nil {
   104  			return nil, err
   105  		}
   106  	}
   107  
   108  	go copyAndClose(nil, sessionIn, c.Stdin)
   109  	go io.Copy(c.Stdout, sessionOut)
   110  	go io.Copy(c.Stderr, sessionErr)
   111  
   112  	return &Session{time.NewTicker(30 * time.Second), session, session}, nil
   113  }
   114  
   115  func copyAndClose(wg *sync.WaitGroup, dest io.WriteCloser, src io.Reader) {
   116  	io.Copy(dest, src)
   117  	dest.Close()
   118  	if wg != nil {
   119  		wg.Done()
   120  	}
   121  }