github.com/jenspinney/cli@v6.42.1-0.20190207184520-7450c600020e+incompatible/util/clissh/ssh.go (about)

     1  package clissh
     2  
     3  import (
     4  	"crypto/md5"
     5  	"crypto/sha1"
     6  	"crypto/sha256"
     7  	"encoding/base64"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"net"
    12  	"os"
    13  	"os/signal"
    14  	"runtime"
    15  	"strings"
    16  	"sync"
    17  	"syscall"
    18  	"time"
    19  
    20  	"code.cloudfoundry.org/cli/util/clissh/sigwinch"
    21  	"code.cloudfoundry.org/cli/util/clissh/ssherror"
    22  	"github.com/moby/moby/pkg/term"
    23  	"golang.org/x/crypto/ssh"
    24  )
    25  
    26  type TTYRequest int
    27  
    28  const (
    29  	RequestTTYAuto TTYRequest = iota
    30  	RequestTTYNo
    31  	RequestTTYYes
    32  	RequestTTYForce
    33  )
    34  
    35  const (
    36  	md5FingerprintLength          = 47 // inclusive of space between bytes
    37  	hexSha1FingerprintLength      = 59 // inclusive of space between bytes
    38  	base64Sha256FingerprintLength = 43
    39  
    40  	DefaultKeepAliveInterval = 30 * time.Second
    41  )
    42  
// LocalPortForward describes one local-to-remote TCP tunnel for
// SecureShell.LocalPortForward: a listener is opened on LocalAddress and each
// accepted connection is forwarded through the SSH client to RemoteAddress.
type LocalPortForward struct {
	// LocalAddress is passed to ListenerFactory.Listen with network "tcp".
	LocalAddress string
	// RemoteAddress is dialed over the SSH connection (network "tcp") for
	// every accepted local connection.
	RemoteAddress string
}
    47  
//go:generate counterfeiter . SecureDialer

// SecureDialer opens SSH client connections. Connect uses it to dial the app
// SSH endpoint over "tcp" with the ClientConfig it builds; an error whose text
// contains "ssh: unable to authenticate" is treated as an auth failure.
type SecureDialer interface {
	Dial(network, address string, config *ssh.ClientConfig) (SecureClient, error)
}

//go:generate counterfeiter . SecureClient

// SecureClient is an established SSH client connection as SecureShell uses it.
type SecureClient interface {
	// NewSession opens the session used by InteractiveSession.
	NewSession() (SecureSession, error)
	// Conn exposes the underlying connection; keepalive requests are sent on it.
	Conn() ssh.Conn
	// Dial opens a connection to address through the SSH connection; used for
	// local port forwarding.
	Dial(network, address string) (net.Conn, error)
	// Wait blocks in SecureShell.Wait; presumably until the connection ends.
	Wait() error
	// Close is called from SecureShell.Close.
	Close() error
}

//go:generate counterfeiter . TerminalHelper

// TerminalHelper abstracts the local terminal so it can be faked in tests.
type TerminalHelper interface {
	// GetFdInfo reports the file descriptor behind in and whether it is a terminal.
	GetFdInfo(in interface{}) (fd uintptr, isTerminal bool)
	// SetRawTerminal switches fd to raw mode, returning the state to restore.
	SetRawTerminal(fd uintptr) (*term.State, error)
	// RestoreTerminal puts fd back into a state captured by SetRawTerminal.
	RestoreTerminal(fd uintptr, state *term.State) error
	// GetWinsize reads the current window size of fd.
	GetWinsize(fd uintptr) (*term.Winsize, error)
	// StdStreams returns the local stdin, stdout and stderr to connect to a session.
	StdStreams() (stdin io.ReadCloser, stdout io.Writer, stderr io.Writer)
}

//go:generate counterfeiter . ListenerFactory

// ListenerFactory creates the local listeners used for port forwarding.
type ListenerFactory interface {
	Listen(network, address string) (net.Listener, error)
}

//go:generate counterfeiter . SecureSession

// SecureSession is a single SSH session as used by InteractiveSession and
// resize. Its method set looks like that of *ssh.Session from
// golang.org/x/crypto/ssh — presumably that is the production implementation;
// confirm.
type SecureSession interface {
	RequestPty(term string, height, width int, termModes ssh.TerminalModes) error
	SendRequest(name string, wantReply bool, payload []byte) (bool, error)
	StdinPipe() (io.WriteCloser, error)
	StdoutPipe() (io.Reader, error)
	StderrPipe() (io.Reader, error)
	Start(command string) error
	Shell() error
	Wait() error
	Close() error
}
    93  
// SecureShell drives an SSH connection to an app instance. Connect dials it,
// InteractiveSession runs commands or a shell over it, LocalPortForward
// tunnels local TCP listeners through it, Wait blocks on it, and Close tears
// it down. Build one with NewSecureShell or NewDefaultSecureShell.
type SecureShell struct {
	secureDialer SecureDialer
	// secureClient is set by a successful Connect; the constructors leave it nil.
	secureClient    SecureClient
	terminalHelper  TerminalHelper
	listenerFactory ListenerFactory

	// localListeners collects every listener opened by LocalPortForward so
	// Close can shut them down.
	localListeners []net.Listener
	// keepAliveInterval is the period of the keepalive ticker started by
	// InteractiveSession and Wait.
	keepAliveInterval time.Duration
}
   103  
   104  func NewDefaultSecureShell() *SecureShell {
   105  	defaultSecureDialer := DefaultSecureDialer()
   106  	defaultTerminalHelper := DefaultTerminalHelper()
   107  	defaultListenerFactory := DefaultListenerFactory()
   108  	return &SecureShell{
   109  		secureDialer:      defaultSecureDialer,
   110  		terminalHelper:    defaultTerminalHelper,
   111  		listenerFactory:   defaultListenerFactory,
   112  		keepAliveInterval: DefaultKeepAliveInterval,
   113  		localListeners:    []net.Listener{},
   114  	}
   115  }
   116  
   117  func NewSecureShell(
   118  	secureDialer SecureDialer,
   119  	terminalHelper TerminalHelper,
   120  	listenerFactory ListenerFactory,
   121  	keepAliveInterval time.Duration,
   122  ) *SecureShell {
   123  	return &SecureShell{
   124  		secureDialer:      secureDialer,
   125  		terminalHelper:    terminalHelper,
   126  		listenerFactory:   listenerFactory,
   127  		keepAliveInterval: keepAliveInterval,
   128  		localListeners:    []net.Listener{},
   129  	}
   130  }
   131  
   132  func (c *SecureShell) Connect(username string, passcode string, appSSHEndpoint string, appSSHHostKeyFingerprint string, skipHostValidation bool) error {
   133  	hostKeyCallbackFunction := fingerprintCallback(skipHostValidation, appSSHHostKeyFingerprint)
   134  
   135  	clientConfig := &ssh.ClientConfig{
   136  		User:            username,
   137  		Auth:            []ssh.AuthMethod{ssh.Password(passcode)},
   138  		HostKeyCallback: hostKeyCallbackFunction,
   139  	}
   140  
   141  	secureClient, err := c.secureDialer.Dial("tcp", appSSHEndpoint, clientConfig)
   142  	if err != nil {
   143  		if strings.Contains(err.Error(), "ssh: unable to authenticate") {
   144  			return ssherror.UnableToAuthenticateError{Err: err}
   145  		}
   146  		return err
   147  	}
   148  
   149  	c.secureClient = secureClient
   150  	return nil
   151  }
   152  
   153  func (c *SecureShell) Close() error {
   154  	for _, listener := range c.localListeners {
   155  		listener.Close()
   156  	}
   157  	return c.secureClient.Close()
   158  }
   159  
   160  func (c *SecureShell) LocalPortForward(localPortForwardSpecs []LocalPortForward) error {
   161  	for _, spec := range localPortForwardSpecs {
   162  		listener, err := c.listenerFactory.Listen("tcp", spec.LocalAddress)
   163  		if err != nil {
   164  			return err
   165  		}
   166  		c.localListeners = append(c.localListeners, listener)
   167  
   168  		go c.localForwardAcceptLoop(listener, spec.RemoteAddress)
   169  	}
   170  
   171  	return nil
   172  }
   173  
   174  func (c *SecureShell) localForwardAcceptLoop(listener net.Listener, addr string) {
   175  	defer listener.Close()
   176  
   177  	for {
   178  		conn, err := listener.Accept()
   179  		if err != nil {
   180  			if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
   181  				time.Sleep(100 * time.Millisecond)
   182  				continue
   183  			}
   184  			return
   185  		}
   186  
   187  		go c.handleForwardConnection(conn, addr)
   188  	}
   189  }
   190  
   191  func (c *SecureShell) handleForwardConnection(conn net.Conn, targetAddr string) {
   192  	defer conn.Close()
   193  
   194  	target, err := c.secureClient.Dial("tcp", targetAddr)
   195  	if err != nil {
   196  		fmt.Printf("connect to %s failed: %s\n", targetAddr, err.Error())
   197  		return
   198  	}
   199  	defer target.Close()
   200  
   201  	wg := &sync.WaitGroup{}
   202  	wg.Add(2)
   203  
   204  	go copyAndClose(wg, conn, target)
   205  	go copyAndClose(wg, target, conn)
   206  	wg.Wait()
   207  }
   208  
   209  func copyAndClose(wg *sync.WaitGroup, dest io.WriteCloser, src io.Reader) {
   210  	_, _ = io.Copy(dest, src)
   211  	_ = dest.Close()
   212  	if wg != nil {
   213  		wg.Done()
   214  	}
   215  }
   216  
   217  func copyAndDone(wg *sync.WaitGroup, dest io.Writer, src io.Reader) {
   218  	_, _ = io.Copy(dest, src)
   219  	wg.Done()
   220  }
   221  
   222  func (c *SecureShell) InteractiveSession(commands []string, terminalRequest TTYRequest) error {
   223  	session, err := c.secureClient.NewSession()
   224  	if err != nil {
   225  		return fmt.Errorf("SSH session allocation failed: %s", err.Error())
   226  	}
   227  	defer session.Close()
   228  
   229  	stdin, stdout, stderr := c.terminalHelper.StdStreams()
   230  
   231  	inPipe, err := session.StdinPipe()
   232  	if err != nil {
   233  		return err
   234  	}
   235  
   236  	outPipe, err := session.StdoutPipe()
   237  	if err != nil {
   238  		return err
   239  	}
   240  
   241  	errPipe, err := session.StderrPipe()
   242  	if err != nil {
   243  		return err
   244  	}
   245  
   246  	stdinFd, stdinIsTerminal := c.terminalHelper.GetFdInfo(stdin)
   247  	stdoutFd, stdoutIsTerminal := c.terminalHelper.GetFdInfo(stdout)
   248  
   249  	if c.shouldAllocateTerminal(commands, terminalRequest, stdinIsTerminal) {
   250  		modes := ssh.TerminalModes{
   251  			ssh.ECHO:          1,
   252  			ssh.TTY_OP_ISPEED: 115200,
   253  			ssh.TTY_OP_OSPEED: 115200,
   254  		}
   255  
   256  		width, height := c.getWindowDimensions(stdoutFd)
   257  
   258  		err = session.RequestPty(c.terminalType(), height, width, modes)
   259  		if err != nil {
   260  			return err
   261  		}
   262  
   263  		var state *term.State
   264  		state, err = c.terminalHelper.SetRawTerminal(stdinFd)
   265  		if err == nil {
   266  			defer c.terminalHelper.RestoreTerminal(stdinFd, state)
   267  		}
   268  	}
   269  
   270  	if len(commands) > 0 {
   271  		cmd := strings.Join(commands, " ")
   272  		err = session.Start(cmd)
   273  		if err != nil {
   274  			return err
   275  		}
   276  	} else {
   277  		err = session.Shell()
   278  		if err != nil {
   279  			return err
   280  		}
   281  	}
   282  
   283  	wg := &sync.WaitGroup{}
   284  	wg.Add(2)
   285  
   286  	go copyAndClose(nil, inPipe, stdin)
   287  	go copyAndDone(wg, stdout, outPipe)
   288  	go copyAndDone(wg, stderr, errPipe)
   289  
   290  	if stdoutIsTerminal {
   291  		resized := make(chan os.Signal, 16)
   292  
   293  		if runtime.GOOS == "windows" {
   294  			ticker := time.NewTicker(250 * time.Millisecond)
   295  			defer ticker.Stop()
   296  
   297  			go func() {
   298  				for range ticker.C {
   299  					resized <- syscall.Signal(-1)
   300  				}
   301  				close(resized)
   302  			}()
   303  		} else {
   304  			signal.Notify(resized, sigwinch.SIGWINCH())
   305  			defer func() { signal.Stop(resized); close(resized) }()
   306  		}
   307  
   308  		go c.resize(resized, session, stdoutFd)
   309  	}
   310  
   311  	keepaliveStopCh := make(chan struct{})
   312  	defer close(keepaliveStopCh)
   313  
   314  	go keepalive(c.secureClient.Conn(), time.NewTicker(c.keepAliveInterval), keepaliveStopCh)
   315  
   316  	result := session.Wait()
   317  	wg.Wait()
   318  	return result
   319  }
   320  
   321  func (c *SecureShell) Wait() error {
   322  	keepaliveStopCh := make(chan struct{})
   323  	defer close(keepaliveStopCh)
   324  
   325  	go keepalive(c.secureClient.Conn(), time.NewTicker(c.keepAliveInterval), keepaliveStopCh)
   326  
   327  	return c.secureClient.Wait()
   328  }
   329  
   330  func md5Fingerprint(key ssh.PublicKey) string {
   331  	sum := md5.Sum(key.Marshal())
   332  	return strings.Replace(fmt.Sprintf("% x", sum), " ", ":", -1)
   333  }
   334  
   335  func hexSha1Fingerprint(key ssh.PublicKey) string {
   336  	sum := sha1.Sum(key.Marshal())
   337  	return strings.Replace(fmt.Sprintf("% x", sum), " ", ":", -1)
   338  }
   339  
   340  func base64Sha256Fingerprint(key ssh.PublicKey) string {
   341  	sum := sha256.Sum256(key.Marshal())
   342  	return base64.RawStdEncoding.EncodeToString(sum[:])
   343  }
   344  
   345  func fingerprintCallback(skipHostValidation bool, expectedFingerprint string) ssh.HostKeyCallback {
   346  	return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
   347  		if skipHostValidation {
   348  			return nil
   349  		}
   350  
   351  		var fingerprint string
   352  
   353  		switch len(expectedFingerprint) {
   354  		case base64Sha256FingerprintLength:
   355  			fingerprint = base64Sha256Fingerprint(key)
   356  		case hexSha1FingerprintLength:
   357  			fingerprint = hexSha1Fingerprint(key)
   358  		case md5FingerprintLength:
   359  			fingerprint = md5Fingerprint(key)
   360  		case 0:
   361  			fingerprint = md5Fingerprint(key)
   362  			return fmt.Errorf("Unable to verify identity of host.\n\nThe fingerprint of the received key was %q.", fingerprint)
   363  		default:
   364  			return errors.New("Unsupported host key fingerprint format")
   365  		}
   366  
   367  		if fingerprint != expectedFingerprint {
   368  			return fmt.Errorf("Host key verification failed.\n\nThe fingerprint of the received key was %q.", fingerprint)
   369  		}
   370  		return nil
   371  	}
   372  }
   373  
   374  func (c *SecureShell) shouldAllocateTerminal(commands []string, terminalRequest TTYRequest, stdinIsTerminal bool) bool {
   375  	switch terminalRequest {
   376  	case RequestTTYForce:
   377  		return true
   378  	case RequestTTYNo:
   379  		return false
   380  	case RequestTTYYes:
   381  		return stdinIsTerminal
   382  	case RequestTTYAuto:
   383  		return len(commands) == 0 && stdinIsTerminal
   384  	default:
   385  		return false
   386  	}
   387  }
   388  
   389  func (c *SecureShell) resize(resized <-chan os.Signal, session SecureSession, terminalFd uintptr) {
   390  	type resizeMessage struct {
   391  		Width       uint32
   392  		Height      uint32
   393  		PixelWidth  uint32
   394  		PixelHeight uint32
   395  	}
   396  
   397  	var previousWidth, previousHeight int
   398  
   399  	for range resized {
   400  		width, height := c.getWindowDimensions(terminalFd)
   401  
   402  		if width == previousWidth && height == previousHeight {
   403  			continue
   404  		}
   405  
   406  		message := resizeMessage{
   407  			Width:  uint32(width),
   408  			Height: uint32(height),
   409  		}
   410  
   411  		_, _ = session.SendRequest("window-change", false, ssh.Marshal(message))
   412  
   413  		previousWidth = width
   414  		previousHeight = height
   415  	}
   416  }
   417  
   418  func keepalive(conn ssh.Conn, ticker *time.Ticker, stopCh chan struct{}) {
   419  	for {
   420  		select {
   421  		case <-ticker.C:
   422  			_, _, _ = conn.SendRequest("keepalive@cloudfoundry.org", true, nil)
   423  		case <-stopCh:
   424  			ticker.Stop()
   425  			return
   426  		}
   427  	}
   428  }
   429  
   430  func (c *SecureShell) terminalType() string {
   431  	term := os.Getenv("TERM")
   432  	if term == "" {
   433  		term = "xterm"
   434  	}
   435  	return term
   436  }
   437  
   438  func (c *SecureShell) getWindowDimensions(terminalFd uintptr) (width int, height int) {
   439  	winSize, err := c.terminalHelper.GetWinsize(terminalFd)
   440  	if err != nil {
   441  		winSize = &term.Winsize{
   442  			Width:  80,
   443  			Height: 43,
   444  		}
   445  	}
   446  
   447  	return int(winSize.Width), int(winSize.Height)
   448  }