github.com/sleungcy-sap/cli@v7.1.0+incompatible/cf/ssh/ssh.go (about)

     1  package sshCmd
     2  
     3  import (
     4  	"crypto/md5"
     5  	"crypto/sha1"
     6  	"crypto/sha256"
     7  	"encoding/base64"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"net"
    12  	"os"
    13  	"os/signal"
    14  	"runtime"
    15  	"strings"
    16  	"sync"
    17  	"syscall"
    18  	"time"
    19  
    20  	"golang.org/x/crypto/ssh"
    21  
    22  	"code.cloudfoundry.org/cli/cf/models"
    23  	"code.cloudfoundry.org/cli/cf/ssh/options"
    24  	"code.cloudfoundry.org/cli/cf/ssh/sigwinch"
    25  	"code.cloudfoundry.org/cli/cf/ssh/terminal"
    26  	"github.com/moby/moby/pkg/term"
    27  )
    28  
    29  const (
    30  	md5FingerprintLength          = 47 // inclusive of space between bytes
    31  	hexSha1FingerprintLength      = 59 // inclusive of space between bytes
    32  	base64Sha256FingerprintLength = 43
    33  )
    34  
    35  //go:generate counterfeiter . SecureShell
    36  
    37  type SecureShell interface {
    38  	Connect(opts *options.SSHOptions) error
    39  	InteractiveSession() error
    40  	LocalPortForward() error
    41  	Wait() error
    42  	Close() error
    43  }
    44  
    45  //go:generate counterfeiter . SecureDialer
    46  
    47  type SecureDialer interface {
    48  	Dial(network, address string, config *ssh.ClientConfig) (SecureClient, error)
    49  }
    50  
    51  //go:generate counterfeiter . SecureClient
    52  
    53  type SecureClient interface {
    54  	NewSession() (SecureSession, error)
    55  	Conn() ssh.Conn
    56  	Dial(network, address string) (net.Conn, error)
    57  	Wait() error
    58  	Close() error
    59  }
    60  
    61  //go:generate counterfeiter . ListenerFactory
    62  
    63  type ListenerFactory interface {
    64  	Listen(network, address string) (net.Listener, error)
    65  }
    66  
    67  //go:generate counterfeiter . SecureSession
    68  
    69  type SecureSession interface {
    70  	RequestPty(term string, height, width int, termModes ssh.TerminalModes) error
    71  	SendRequest(name string, wantReply bool, payload []byte) (bool, error)
    72  	StdinPipe() (io.WriteCloser, error)
    73  	StdoutPipe() (io.Reader, error)
    74  	StderrPipe() (io.Reader, error)
    75  	Start(command string) error
    76  	Shell() error
    77  	Wait() error
    78  	Close() error
    79  }
    80  
    81  type secureShell struct {
    82  	secureDialer           SecureDialer
    83  	terminalHelper         terminal.TerminalHelper
    84  	listenerFactory        ListenerFactory
    85  	keepAliveInterval      time.Duration
    86  	app                    models.Application
    87  	sshEndpointFingerprint string
    88  	sshEndpoint            string
    89  	token                  string
    90  	secureClient           SecureClient
    91  	opts                   *options.SSHOptions
    92  
    93  	localListeners []net.Listener
    94  }
    95  
    96  func NewSecureShell(
    97  	secureDialer SecureDialer,
    98  	terminalHelper terminal.TerminalHelper,
    99  	listenerFactory ListenerFactory,
   100  	keepAliveInterval time.Duration,
   101  	app models.Application,
   102  	sshEndpointFingerprint string,
   103  	sshEndpoint string,
   104  	token string,
   105  ) SecureShell {
   106  	return &secureShell{
   107  		secureDialer:           secureDialer,
   108  		terminalHelper:         terminalHelper,
   109  		listenerFactory:        listenerFactory,
   110  		keepAliveInterval:      keepAliveInterval,
   111  		app:                    app,
   112  		sshEndpointFingerprint: sshEndpointFingerprint,
   113  		sshEndpoint:            sshEndpoint,
   114  		token:                  token,
   115  		localListeners:         []net.Listener{},
   116  	}
   117  }
   118  
   119  func (c *secureShell) Connect(opts *options.SSHOptions) error {
   120  	err := c.validateTarget(opts)
   121  	if err != nil {
   122  		return err
   123  	}
   124  
   125  	clientConfig := &ssh.ClientConfig{
   126  		User: fmt.Sprintf("cf:%s/%d", c.app.GUID, opts.Index),
   127  		Auth: []ssh.AuthMethod{
   128  			ssh.Password(c.token),
   129  		},
   130  		HostKeyCallback: fingerprintCallback(opts, c.sshEndpointFingerprint),
   131  	}
   132  
   133  	secureClient, err := c.secureDialer.Dial("tcp", c.sshEndpoint, clientConfig)
   134  	if err != nil {
   135  		return err
   136  	}
   137  
   138  	c.secureClient = secureClient
   139  	c.opts = opts
   140  	return nil
   141  }
   142  
   143  func (c *secureShell) Close() error {
   144  	for _, listener := range c.localListeners {
   145  		_ = listener.Close()
   146  	}
   147  	return c.secureClient.Close()
   148  }
   149  
   150  func (c *secureShell) LocalPortForward() error {
   151  	for _, forwardSpec := range c.opts.ForwardSpecs {
   152  		listener, err := c.listenerFactory.Listen("tcp", forwardSpec.ListenAddress)
   153  		if err != nil {
   154  			return err
   155  		}
   156  		c.localListeners = append(c.localListeners, listener)
   157  
   158  		go c.localForwardAcceptLoop(listener, forwardSpec.ConnectAddress)
   159  	}
   160  
   161  	return nil
   162  }
   163  
   164  func (c *secureShell) localForwardAcceptLoop(listener net.Listener, addr string) {
   165  	defer listener.Close()
   166  
   167  	for {
   168  		conn, err := listener.Accept()
   169  		if err != nil {
   170  			if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
   171  				time.Sleep(100 * time.Millisecond)
   172  				continue
   173  			}
   174  			return
   175  		}
   176  
   177  		go c.handleForwardConnection(conn, addr)
   178  	}
   179  }
   180  
   181  func (c *secureShell) handleForwardConnection(conn net.Conn, targetAddr string) {
   182  	defer conn.Close()
   183  
   184  	target, err := c.secureClient.Dial("tcp", targetAddr)
   185  	if err != nil {
   186  		fmt.Printf("connect to %s failed: %s\n", targetAddr, err.Error())
   187  		return
   188  	}
   189  	defer target.Close()
   190  
   191  	wg := &sync.WaitGroup{}
   192  	wg.Add(2)
   193  
   194  	go copyAndClose(wg, conn, target)
   195  	go copyAndClose(wg, target, conn)
   196  	wg.Wait()
   197  }
   198  
   199  func copyAndClose(wg *sync.WaitGroup, dest io.WriteCloser, src io.Reader) {
   200  	_, _ = io.Copy(dest, src)
   201  	_ = dest.Close()
   202  	if wg != nil {
   203  		wg.Done()
   204  	}
   205  }
   206  
   207  func copyAndDone(wg *sync.WaitGroup, dest io.Writer, src io.Reader) {
   208  	_, _ = io.Copy(dest, src)
   209  	wg.Done()
   210  }
   211  
   212  func (c *secureShell) InteractiveSession() error {
   213  	var err error
   214  
   215  	secureClient := c.secureClient
   216  	opts := c.opts
   217  
   218  	session, err := secureClient.NewSession()
   219  	if err != nil {
   220  		return fmt.Errorf("SSH session allocation failed: %s", err.Error())
   221  	}
   222  	defer session.Close()
   223  
   224  	stdin, stdout, stderr := c.terminalHelper.StdStreams()
   225  
   226  	inPipe, err := session.StdinPipe()
   227  	if err != nil {
   228  		return err
   229  	}
   230  
   231  	outPipe, err := session.StdoutPipe()
   232  	if err != nil {
   233  		return err
   234  	}
   235  
   236  	errPipe, err := session.StderrPipe()
   237  	if err != nil {
   238  		return err
   239  	}
   240  
   241  	stdinFd, stdinIsTerminal := c.terminalHelper.GetFdInfo(stdin)
   242  	stdoutFd, stdoutIsTerminal := c.terminalHelper.GetFdInfo(stdout)
   243  
   244  	if c.shouldAllocateTerminal(opts, stdinIsTerminal) {
   245  		modes := ssh.TerminalModes{
   246  			ssh.ECHO:          1,
   247  			ssh.TTY_OP_ISPEED: 115200,
   248  			ssh.TTY_OP_OSPEED: 115200,
   249  		}
   250  
   251  		width, height := c.getWindowDimensions(stdoutFd)
   252  
   253  		err = session.RequestPty(c.terminalType(), height, width, modes)
   254  		if err != nil {
   255  			return err
   256  		}
   257  
   258  		var state *term.State
   259  		state, err = c.terminalHelper.SetRawTerminal(stdinFd)
   260  		if err == nil {
   261  			defer c.terminalHelper.RestoreTerminal(stdinFd, state)
   262  		}
   263  	}
   264  
   265  	if len(opts.Command) != 0 {
   266  		cmd := strings.Join(opts.Command, " ")
   267  		err = session.Start(cmd)
   268  		if err != nil {
   269  			return err
   270  		}
   271  	} else {
   272  		err = session.Shell()
   273  		if err != nil {
   274  			return err
   275  		}
   276  	}
   277  
   278  	wg := &sync.WaitGroup{}
   279  	wg.Add(2)
   280  
   281  	go copyAndClose(nil, inPipe, stdin)
   282  	go copyAndDone(wg, stdout, outPipe)
   283  	go copyAndDone(wg, stderr, errPipe)
   284  
   285  	if stdoutIsTerminal {
   286  		resized := make(chan os.Signal, 16)
   287  
   288  		if runtime.GOOS == "windows" {
   289  			ticker := time.NewTicker(250 * time.Millisecond)
   290  			defer ticker.Stop()
   291  
   292  			go func() {
   293  				for range ticker.C {
   294  					resized <- syscall.Signal(-1)
   295  				}
   296  				close(resized)
   297  			}()
   298  		} else {
   299  			signal.Notify(resized, sigwinch.SIGWINCH())
   300  			defer func() { signal.Stop(resized); close(resized) }()
   301  		}
   302  
   303  		go c.resize(resized, session, stdoutFd)
   304  	}
   305  
   306  	keepaliveStopCh := make(chan struct{})
   307  	defer close(keepaliveStopCh)
   308  
   309  	go keepalive(secureClient.Conn(), time.NewTicker(c.keepAliveInterval), keepaliveStopCh)
   310  
   311  	result := session.Wait()
   312  	wg.Wait()
   313  	return result
   314  }
   315  
   316  func (c *secureShell) Wait() error {
   317  	keepaliveStopCh := make(chan struct{})
   318  	defer close(keepaliveStopCh)
   319  
   320  	go keepalive(c.secureClient.Conn(), time.NewTicker(c.keepAliveInterval), keepaliveStopCh)
   321  
   322  	return c.secureClient.Wait()
   323  }
   324  
   325  func (c *secureShell) validateTarget(opts *options.SSHOptions) error {
   326  	if strings.ToUpper(c.app.State) != "STARTED" {
   327  		return fmt.Errorf("Application %q is not in the STARTED state", opts.AppName)
   328  	}
   329  
   330  	return nil
   331  }
   332  
   333  func md5Fingerprint(key ssh.PublicKey) string {
   334  	sum := md5.Sum(key.Marshal())
   335  	return strings.Replace(fmt.Sprintf("% x", sum), " ", ":", -1)
   336  }
   337  
   338  func hexSha1Fingerprint(key ssh.PublicKey) string {
   339  	sum := sha1.Sum(key.Marshal())
   340  	return strings.Replace(fmt.Sprintf("% x", sum), " ", ":", -1)
   341  }
   342  
   343  func base64Sha256Fingerprint(key ssh.PublicKey) string {
   344  	sum := sha256.Sum256(key.Marshal())
   345  	return base64.RawStdEncoding.EncodeToString(sum[:])
   346  }
   347  
   348  func fingerprintCallback(opts *options.SSHOptions, expectedFingerprint string) ssh.HostKeyCallback {
   349  	return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
   350  
   351  		if opts.SkipHostValidation {
   352  			return nil
   353  		}
   354  		var fingerprint string
   355  
   356  		switch len(expectedFingerprint) {
   357  		case base64Sha256FingerprintLength:
   358  			fingerprint = base64Sha256Fingerprint(key)
   359  		case hexSha1FingerprintLength:
   360  			fingerprint = hexSha1Fingerprint(key)
   361  		case md5FingerprintLength:
   362  			fingerprint = md5Fingerprint(key)
   363  		case 0:
   364  			fingerprint = md5Fingerprint(key)
   365  			return fmt.Errorf("Unable to verify identity of host.\n\nThe fingerprint of the received key was %q.", fingerprint)
   366  		default:
   367  			return errors.New("Unsupported host key fingerprint format")
   368  		}
   369  
   370  		if fingerprint != expectedFingerprint {
   371  			return fmt.Errorf("Host key verification failed.\n\nThe fingerprint of the received key was %q.", fingerprint)
   372  		}
   373  		return nil
   374  	}
   375  }
   376  
   377  func (c *secureShell) shouldAllocateTerminal(opts *options.SSHOptions, stdinIsTerminal bool) bool {
   378  	switch opts.TerminalRequest {
   379  	case options.RequestTTYForce:
   380  		return true
   381  	case options.RequestTTYNo:
   382  		return false
   383  	case options.RequestTTYYes:
   384  		return stdinIsTerminal
   385  	case options.RequestTTYAuto:
   386  		return len(opts.Command) == 0 && stdinIsTerminal
   387  	default:
   388  		return false
   389  	}
   390  }
   391  
   392  func (c *secureShell) resize(resized <-chan os.Signal, session SecureSession, terminalFd uintptr) {
   393  	type resizeMessage struct {
   394  		Width       uint32
   395  		Height      uint32
   396  		PixelWidth  uint32
   397  		PixelHeight uint32
   398  	}
   399  
   400  	var previousWidth, previousHeight int
   401  
   402  	for range resized {
   403  		width, height := c.getWindowDimensions(terminalFd)
   404  
   405  		if width == previousWidth && height == previousHeight {
   406  			continue
   407  		}
   408  
   409  		message := resizeMessage{
   410  			Width:  uint32(width),
   411  			Height: uint32(height),
   412  		}
   413  
   414  		_, _ = session.SendRequest("window-change", false, ssh.Marshal(message))
   415  
   416  		previousWidth = width
   417  		previousHeight = height
   418  	}
   419  }
   420  
   421  func keepalive(conn ssh.Conn, ticker *time.Ticker, stopCh chan struct{}) {
   422  	for {
   423  		select {
   424  		case <-ticker.C:
   425  			_, _, _ = conn.SendRequest("keepalive@cloudfoundry.org", true, nil)
   426  		case <-stopCh:
   427  			ticker.Stop()
   428  			return
   429  		}
   430  	}
   431  }
   432  
   433  func (c *secureShell) terminalType() string {
   434  	term := os.Getenv("TERM")
   435  	if term == "" {
   436  		term = "xterm"
   437  	}
   438  	return term
   439  }
   440  
   441  func (c *secureShell) getWindowDimensions(terminalFd uintptr) (width int, height int) {
   442  	winSize, err := c.terminalHelper.GetWinsize(terminalFd)
   443  	if err != nil {
   444  		winSize = &term.Winsize{
   445  			Width:  80,
   446  			Height: 43,
   447  		}
   448  	}
   449  
   450  	return int(winSize.Width), int(winSize.Height)
   451  }
   452  
   453  type secureDialer struct{}
   454  
   455  func (d *secureDialer) Dial(network string, address string, config *ssh.ClientConfig) (SecureClient, error) {
   456  	client, err := ssh.Dial(network, address, config)
   457  	if err != nil {
   458  		return nil, err
   459  	}
   460  
   461  	return &secureClient{client: client}, nil
   462  }
   463  
   464  func DefaultSecureDialer() SecureDialer {
   465  	return &secureDialer{}
   466  }
   467  
   468  type secureClient struct{ client *ssh.Client }
   469  
   470  func (sc *secureClient) Close() error   { return sc.client.Close() }
   471  func (sc *secureClient) Conn() ssh.Conn { return sc.client.Conn }
   472  func (sc *secureClient) Wait() error    { return sc.client.Wait() }
   473  func (sc *secureClient) Dial(n, addr string) (net.Conn, error) {
   474  	return sc.client.Dial(n, addr)
   475  }
   476  func (sc *secureClient) NewSession() (SecureSession, error) {
   477  	return sc.client.NewSession()
   478  }
   479  
   480  type listenerFactory struct{}
   481  
   482  func (lf *listenerFactory) Listen(network, address string) (net.Listener, error) {
   483  	return net.Listen(network, address)
   484  }
   485  
   486  func DefaultListenerFactory() ListenerFactory {
   487  	return &listenerFactory{}
   488  }