github.com/dcarley/cf-cli@v6.24.1-0.20170220111324-4225ff346898+incompatible/cf/ssh/ssh.go (about)

     1  package sshCmd
     2  
     3  import (
     4  	"crypto/md5"
     5  	"crypto/sha1"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net"
    10  	"os"
    11  	"os/signal"
    12  	"runtime"
    13  	"strings"
    14  	"sync"
    15  	"syscall"
    16  	"time"
    17  
    18  	"golang.org/x/crypto/ssh"
    19  
    20  	"code.cloudfoundry.org/cli/cf/models"
    21  	"code.cloudfoundry.org/cli/cf/ssh/options"
    22  	"code.cloudfoundry.org/cli/cf/ssh/sigwinch"
    23  	"code.cloudfoundry.org/cli/cf/ssh/terminal"
    24  	"github.com/docker/docker/pkg/term"
    25  )
    26  
    27  const (
    28  	md5FingerprintLength  = 47 // inclusive of space between bytes
    29  	sha1FingerprintLength = 59 // inclusive of space between bytes
    30  )
    31  
    32  //go:generate counterfeiter . SecureShell
    33  
    34  type SecureShell interface {
    35  	Connect(opts *options.SSHOptions) error
    36  	InteractiveSession() error
    37  	LocalPortForward() error
    38  	Wait() error
    39  	Close() error
    40  }
    41  
    42  //go:generate counterfeiter . SecureDialer
    43  
    44  type SecureDialer interface {
    45  	Dial(network, address string, config *ssh.ClientConfig) (SecureClient, error)
    46  }
    47  
    48  //go:generate counterfeiter . SecureClient
    49  
    50  type SecureClient interface {
    51  	NewSession() (SecureSession, error)
    52  	Conn() ssh.Conn
    53  	Dial(network, address string) (net.Conn, error)
    54  	Wait() error
    55  	Close() error
    56  }
    57  
    58  //go:generate counterfeiter . ListenerFactory
    59  
    60  type ListenerFactory interface {
    61  	Listen(network, address string) (net.Listener, error)
    62  }
    63  
    64  //go:generate counterfeiter . SecureSession
    65  
    66  type SecureSession interface {
    67  	RequestPty(term string, height, width int, termModes ssh.TerminalModes) error
    68  	SendRequest(name string, wantReply bool, payload []byte) (bool, error)
    69  	StdinPipe() (io.WriteCloser, error)
    70  	StdoutPipe() (io.Reader, error)
    71  	StderrPipe() (io.Reader, error)
    72  	Start(command string) error
    73  	Shell() error
    74  	Wait() error
    75  	Close() error
    76  }
    77  
    78  type secureShell struct {
    79  	secureDialer           SecureDialer
    80  	terminalHelper         terminal.TerminalHelper
    81  	listenerFactory        ListenerFactory
    82  	keepAliveInterval      time.Duration
    83  	app                    models.Application
    84  	sshEndpointFingerprint string
    85  	sshEndpoint            string
    86  	token                  string
    87  	secureClient           SecureClient
    88  	opts                   *options.SSHOptions
    89  
    90  	localListeners []net.Listener
    91  }
    92  
    93  func NewSecureShell(
    94  	secureDialer SecureDialer,
    95  	terminalHelper terminal.TerminalHelper,
    96  	listenerFactory ListenerFactory,
    97  	keepAliveInterval time.Duration,
    98  	app models.Application,
    99  	sshEndpointFingerprint string,
   100  	sshEndpoint string,
   101  	token string,
   102  ) SecureShell {
   103  	return &secureShell{
   104  		secureDialer:      secureDialer,
   105  		terminalHelper:    terminalHelper,
   106  		listenerFactory:   listenerFactory,
   107  		keepAliveInterval: keepAliveInterval,
   108  		app:               app,
   109  		sshEndpointFingerprint: sshEndpointFingerprint,
   110  		sshEndpoint:            sshEndpoint,
   111  		token:                  token,
   112  		localListeners:         []net.Listener{},
   113  	}
   114  }
   115  
   116  func (c *secureShell) Connect(opts *options.SSHOptions) error {
   117  	err := c.validateTarget(opts)
   118  	if err != nil {
   119  		return err
   120  	}
   121  
   122  	clientConfig := &ssh.ClientConfig{
   123  		User: fmt.Sprintf("cf:%s/%d", c.app.GUID, opts.Index),
   124  		Auth: []ssh.AuthMethod{
   125  			ssh.Password(c.token),
   126  		},
   127  		HostKeyCallback: fingerprintCallback(opts, c.sshEndpointFingerprint),
   128  	}
   129  
   130  	secureClient, err := c.secureDialer.Dial("tcp", c.sshEndpoint, clientConfig)
   131  	if err != nil {
   132  		return err
   133  	}
   134  
   135  	c.secureClient = secureClient
   136  	c.opts = opts
   137  	return nil
   138  }
   139  
   140  func (c *secureShell) Close() error {
   141  	for _, listener := range c.localListeners {
   142  		_ = listener.Close()
   143  	}
   144  	return c.secureClient.Close()
   145  }
   146  
   147  func (c *secureShell) LocalPortForward() error {
   148  	for _, forwardSpec := range c.opts.ForwardSpecs {
   149  		listener, err := c.listenerFactory.Listen("tcp", forwardSpec.ListenAddress)
   150  		if err != nil {
   151  			return err
   152  		}
   153  		c.localListeners = append(c.localListeners, listener)
   154  
   155  		go c.localForwardAcceptLoop(listener, forwardSpec.ConnectAddress)
   156  	}
   157  
   158  	return nil
   159  }
   160  
   161  func (c *secureShell) localForwardAcceptLoop(listener net.Listener, addr string) {
   162  	defer listener.Close()
   163  
   164  	for {
   165  		conn, err := listener.Accept()
   166  		if err != nil {
   167  			if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
   168  				time.Sleep(100 * time.Millisecond)
   169  				continue
   170  			}
   171  			return
   172  		}
   173  
   174  		go c.handleForwardConnection(conn, addr)
   175  	}
   176  }
   177  
   178  func (c *secureShell) handleForwardConnection(conn net.Conn, targetAddr string) {
   179  	defer conn.Close()
   180  
   181  	target, err := c.secureClient.Dial("tcp", targetAddr)
   182  	if err != nil {
   183  		fmt.Printf("connect to %s failed: %s\n", targetAddr, err.Error())
   184  		return
   185  	}
   186  	defer target.Close()
   187  
   188  	wg := &sync.WaitGroup{}
   189  	wg.Add(2)
   190  
   191  	go copyAndClose(wg, conn, target)
   192  	go copyAndClose(wg, target, conn)
   193  	wg.Wait()
   194  }
   195  
   196  func copyAndClose(wg *sync.WaitGroup, dest io.WriteCloser, src io.Reader) {
   197  	_, _ = io.Copy(dest, src)
   198  	_ = dest.Close()
   199  	if wg != nil {
   200  		wg.Done()
   201  	}
   202  }
   203  
   204  func copyAndDone(wg *sync.WaitGroup, dest io.Writer, src io.Reader) {
   205  	_, _ = io.Copy(dest, src)
   206  	wg.Done()
   207  }
   208  
   209  func (c *secureShell) InteractiveSession() error {
   210  	var err error
   211  
   212  	secureClient := c.secureClient
   213  	opts := c.opts
   214  
   215  	session, err := secureClient.NewSession()
   216  	if err != nil {
   217  		return fmt.Errorf("SSH session allocation failed: %s", err.Error())
   218  	}
   219  	defer session.Close()
   220  
   221  	stdin, stdout, stderr := c.terminalHelper.StdStreams()
   222  
   223  	inPipe, err := session.StdinPipe()
   224  	if err != nil {
   225  		return err
   226  	}
   227  
   228  	outPipe, err := session.StdoutPipe()
   229  	if err != nil {
   230  		return err
   231  	}
   232  
   233  	errPipe, err := session.StderrPipe()
   234  	if err != nil {
   235  		return err
   236  	}
   237  
   238  	stdinFd, stdinIsTerminal := c.terminalHelper.GetFdInfo(stdin)
   239  	stdoutFd, stdoutIsTerminal := c.terminalHelper.GetFdInfo(stdout)
   240  
   241  	if c.shouldAllocateTerminal(opts, stdinIsTerminal) {
   242  		modes := ssh.TerminalModes{
   243  			ssh.ECHO:          1,
   244  			ssh.TTY_OP_ISPEED: 115200,
   245  			ssh.TTY_OP_OSPEED: 115200,
   246  		}
   247  
   248  		width, height := c.getWindowDimensions(stdoutFd)
   249  
   250  		err = session.RequestPty(c.terminalType(), height, width, modes)
   251  		if err != nil {
   252  			return err
   253  		}
   254  
   255  		var state *term.State
   256  		state, err = c.terminalHelper.SetRawTerminal(stdinFd)
   257  		if err == nil {
   258  			defer c.terminalHelper.RestoreTerminal(stdinFd, state)
   259  		}
   260  	}
   261  
   262  	if len(opts.Command) != 0 {
   263  		cmd := strings.Join(opts.Command, " ")
   264  		err = session.Start(cmd)
   265  		if err != nil {
   266  			return err
   267  		}
   268  	} else {
   269  		err = session.Shell()
   270  		if err != nil {
   271  			return err
   272  		}
   273  	}
   274  
   275  	wg := &sync.WaitGroup{}
   276  	wg.Add(2)
   277  
   278  	go copyAndClose(nil, inPipe, stdin)
   279  	go copyAndDone(wg, stdout, outPipe)
   280  	go copyAndDone(wg, stderr, errPipe)
   281  
   282  	if stdoutIsTerminal {
   283  		resized := make(chan os.Signal, 16)
   284  
   285  		if runtime.GOOS == "windows" {
   286  			ticker := time.NewTicker(250 * time.Millisecond)
   287  			defer ticker.Stop()
   288  
   289  			go func() {
   290  				for range ticker.C {
   291  					resized <- syscall.Signal(-1)
   292  				}
   293  				close(resized)
   294  			}()
   295  		} else {
   296  			signal.Notify(resized, sigwinch.SIGWINCH())
   297  			defer func() { signal.Stop(resized); close(resized) }()
   298  		}
   299  
   300  		go c.resize(resized, session, stdoutFd)
   301  	}
   302  
   303  	keepaliveStopCh := make(chan struct{})
   304  	defer close(keepaliveStopCh)
   305  
   306  	go keepalive(secureClient.Conn(), time.NewTicker(c.keepAliveInterval), keepaliveStopCh)
   307  
   308  	result := session.Wait()
   309  	wg.Wait()
   310  	return result
   311  }
   312  
   313  func (c *secureShell) Wait() error {
   314  	keepaliveStopCh := make(chan struct{})
   315  	defer close(keepaliveStopCh)
   316  
   317  	go keepalive(c.secureClient.Conn(), time.NewTicker(c.keepAliveInterval), keepaliveStopCh)
   318  
   319  	return c.secureClient.Wait()
   320  }
   321  
   322  func (c *secureShell) validateTarget(opts *options.SSHOptions) error {
   323  	if strings.ToUpper(c.app.State) != "STARTED" {
   324  		return fmt.Errorf("Application %q is not in the STARTED state", opts.AppName)
   325  	}
   326  
   327  	if !c.app.Diego {
   328  		return fmt.Errorf("Application %q is not running on Diego", opts.AppName)
   329  	}
   330  
   331  	return nil
   332  }
   333  
   334  func md5Fingerprint(key ssh.PublicKey) string {
   335  	sum := md5.Sum(key.Marshal())
   336  	return strings.Replace(fmt.Sprintf("% x", sum), " ", ":", -1)
   337  }
   338  
   339  func sha1Fingerprint(key ssh.PublicKey) string {
   340  	sum := sha1.Sum(key.Marshal())
   341  	return strings.Replace(fmt.Sprintf("% x", sum), " ", ":", -1)
   342  }
   343  
   344  type hostKeyCallback func(hostname string, remote net.Addr, key ssh.PublicKey) error
   345  
   346  func fingerprintCallback(opts *options.SSHOptions, expectedFingerprint string) hostKeyCallback {
   347  	if opts.SkipHostValidation {
   348  		return nil
   349  	}
   350  
   351  	return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
   352  		switch len(expectedFingerprint) {
   353  		case sha1FingerprintLength:
   354  			fingerprint := sha1Fingerprint(key)
   355  			if fingerprint != expectedFingerprint {
   356  				return fmt.Errorf("Host key verification failed.\n\nThe fingerprint of the received key was %q.", fingerprint)
   357  			}
   358  		case md5FingerprintLength:
   359  			fingerprint := md5Fingerprint(key)
   360  			if fingerprint != expectedFingerprint {
   361  				return fmt.Errorf("Host key verification failed.\n\nThe fingerprint of the received key was %q.", fingerprint)
   362  			}
   363  		case 0:
   364  			fingerprint := md5Fingerprint(key)
   365  			return fmt.Errorf("Unable to verify identity of host.\n\nThe fingerprint of the received key was %q.", fingerprint)
   366  		default:
   367  			return errors.New("Unsupported host key fingerprint format")
   368  		}
   369  		return nil
   370  	}
   371  }
   372  
   373  func (c *secureShell) shouldAllocateTerminal(opts *options.SSHOptions, stdinIsTerminal bool) bool {
   374  	switch opts.TerminalRequest {
   375  	case options.RequestTTYForce:
   376  		return true
   377  	case options.RequestTTYNo:
   378  		return false
   379  	case options.RequestTTYYes:
   380  		return stdinIsTerminal
   381  	case options.RequestTTYAuto:
   382  		return len(opts.Command) == 0 && stdinIsTerminal
   383  	default:
   384  		return false
   385  	}
   386  }
   387  
   388  func (c *secureShell) resize(resized <-chan os.Signal, session SecureSession, terminalFd uintptr) {
   389  	type resizeMessage struct {
   390  		Width       uint32
   391  		Height      uint32
   392  		PixelWidth  uint32
   393  		PixelHeight uint32
   394  	}
   395  
   396  	var previousWidth, previousHeight int
   397  
   398  	for range resized {
   399  		width, height := c.getWindowDimensions(terminalFd)
   400  
   401  		if width == previousWidth && height == previousHeight {
   402  			continue
   403  		}
   404  
   405  		message := resizeMessage{
   406  			Width:  uint32(width),
   407  			Height: uint32(height),
   408  		}
   409  
   410  		_, _ = session.SendRequest("window-change", false, ssh.Marshal(message))
   411  
   412  		previousWidth = width
   413  		previousHeight = height
   414  	}
   415  }
   416  
   417  func keepalive(conn ssh.Conn, ticker *time.Ticker, stopCh chan struct{}) {
   418  	for {
   419  		select {
   420  		case <-ticker.C:
   421  			_, _, _ = conn.SendRequest("keepalive@cloudfoundry.org", true, nil)
   422  		case <-stopCh:
   423  			ticker.Stop()
   424  			return
   425  		}
   426  	}
   427  }
   428  
   429  func (c *secureShell) terminalType() string {
   430  	term := os.Getenv("TERM")
   431  	if term == "" {
   432  		term = "xterm"
   433  	}
   434  	return term
   435  }
   436  
   437  func (c *secureShell) getWindowDimensions(terminalFd uintptr) (width int, height int) {
   438  	winSize, err := c.terminalHelper.GetWinsize(terminalFd)
   439  	if err != nil {
   440  		winSize = &term.Winsize{
   441  			Width:  80,
   442  			Height: 43,
   443  		}
   444  	}
   445  
   446  	return int(winSize.Width), int(winSize.Height)
   447  }
   448  
   449  type secureDialer struct{}
   450  
   451  func (d *secureDialer) Dial(network string, address string, config *ssh.ClientConfig) (SecureClient, error) {
   452  	client, err := ssh.Dial(network, address, config)
   453  	if err != nil {
   454  		return nil, err
   455  	}
   456  
   457  	return &secureClient{client: client}, nil
   458  }
   459  
   460  func DefaultSecureDialer() SecureDialer {
   461  	return &secureDialer{}
   462  }
   463  
   464  type secureClient struct{ client *ssh.Client }
   465  
   466  func (sc *secureClient) Close() error   { return sc.client.Close() }
   467  func (sc *secureClient) Conn() ssh.Conn { return sc.client.Conn }
   468  func (sc *secureClient) Wait() error    { return sc.client.Wait() }
   469  func (sc *secureClient) Dial(n, addr string) (net.Conn, error) {
   470  	return sc.client.Dial(n, addr)
   471  }
   472  func (sc *secureClient) NewSession() (SecureSession, error) {
   473  	return sc.client.NewSession()
   474  }
   475  
   476  type listenerFactory struct{}
   477  
   478  func (lf *listenerFactory) Listen(network, address string) (net.Listener, error) {
   479  	return net.Listen(network, address)
   480  }
   481  
   482  func DefaultListenerFactory() ListenerFactory {
   483  	return &listenerFactory{}
   484  }