github.com/mook-as/cf-cli@v7.0.0-beta.28.0.20200120190804-b91c115fae48+incompatible/util/clissh/ssh.go (about)

     1  package clissh
     2  
     3  import (
     4  	"crypto/md5"
     5  	"crypto/sha1"
     6  	"crypto/sha256"
     7  	"encoding/base64"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"net"
    12  	"os"
    13  	"os/signal"
    14  	"runtime"
    15  	"strings"
    16  	"sync"
    17  	"syscall"
    18  	"time"
    19  
    20  	"code.cloudfoundry.org/cli/cf/ssh/sigwinch"
    21  	"code.cloudfoundry.org/cli/util/clissh/ssherror"
    22  	"github.com/moby/moby/pkg/term"
    23  	log "github.com/sirupsen/logrus"
    24  	"golang.org/x/crypto/ssh"
    25  )
    26  
    27  const (
    28  	md5FingerprintLength          = 47 // inclusive of space between bytes
    29  	hexSha1FingerprintLength      = 59 // inclusive of space between bytes
    30  	base64Sha256FingerprintLength = 43
    31  
    32  	DefaultKeepAliveInterval = 30 * time.Second
    33  )
    34  
    35  type LocalPortForward struct {
    36  	LocalAddress  string
    37  	RemoteAddress string
    38  }
    39  
    40  type SecureShell struct {
    41  	secureDialer    SecureDialer
    42  	secureClient    SecureClient
    43  	terminalHelper  TerminalHelper
    44  	listenerFactory ListenerFactory
    45  
    46  	localListeners    []net.Listener
    47  	keepAliveInterval time.Duration
    48  }
    49  
    50  func NewDefaultSecureShell() *SecureShell {
    51  	defaultSecureDialer := DefaultSecureDialer()
    52  	defaultTerminalHelper := DefaultTerminalHelper()
    53  	defaultListenerFactory := DefaultListenerFactory()
    54  	return &SecureShell{
    55  		secureDialer:      defaultSecureDialer,
    56  		terminalHelper:    defaultTerminalHelper,
    57  		listenerFactory:   defaultListenerFactory,
    58  		keepAliveInterval: DefaultKeepAliveInterval,
    59  		localListeners:    []net.Listener{},
    60  	}
    61  }
    62  
    63  func NewSecureShell(
    64  	secureDialer SecureDialer,
    65  	terminalHelper TerminalHelper,
    66  	listenerFactory ListenerFactory,
    67  	keepAliveInterval time.Duration,
    68  ) *SecureShell {
    69  	return &SecureShell{
    70  		secureDialer:      secureDialer,
    71  		terminalHelper:    terminalHelper,
    72  		listenerFactory:   listenerFactory,
    73  		keepAliveInterval: keepAliveInterval,
    74  		localListeners:    []net.Listener{},
    75  	}
    76  }
    77  
    78  func (c *SecureShell) Close() error {
    79  	for _, listener := range c.localListeners {
    80  		listener.Close()
    81  	}
    82  	return c.secureClient.Close()
    83  }
    84  
    85  func (c *SecureShell) Connect(username string, passcode string, appSSHEndpoint string, appSSHHostKeyFingerprint string, skipHostValidation bool) error {
    86  	hostKeyCallbackFunction := fingerprintCallback(skipHostValidation, appSSHHostKeyFingerprint)
    87  
    88  	clientConfig := &ssh.ClientConfig{
    89  		User:            username,
    90  		Auth:            []ssh.AuthMethod{ssh.Password(passcode)},
    91  		HostKeyCallback: hostKeyCallbackFunction,
    92  	}
    93  
    94  	secureClient, err := c.secureDialer.Dial("tcp", appSSHEndpoint, clientConfig)
    95  	if err != nil {
    96  		if strings.Contains(err.Error(), "ssh: unable to authenticate") {
    97  			return ssherror.UnableToAuthenticateError{Err: err}
    98  		}
    99  		return err
   100  	}
   101  
   102  	c.secureClient = secureClient
   103  	return nil
   104  }
   105  
   106  func (c *SecureShell) InteractiveSession(commands []string, terminalRequest TTYRequest) error {
   107  	session, err := c.secureClient.NewSession()
   108  	if err != nil {
   109  		return fmt.Errorf("SSH session allocation failed: %s", err.Error())
   110  	}
   111  	defer session.Close()
   112  
   113  	stdin, stdout, stderr := c.terminalHelper.StdStreams()
   114  
   115  	inPipe, err := session.StdinPipe()
   116  	if err != nil {
   117  		return err
   118  	}
   119  
   120  	outPipe, err := session.StdoutPipe()
   121  	if err != nil {
   122  		return err
   123  	}
   124  
   125  	errPipe, err := session.StderrPipe()
   126  	if err != nil {
   127  		return err
   128  	}
   129  
   130  	stdinFd, stdinIsTerminal := c.terminalHelper.GetFdInfo(stdin)
   131  	stdoutFd, stdoutIsTerminal := c.terminalHelper.GetFdInfo(stdout)
   132  
   133  	if c.shouldAllocateTerminal(commands, terminalRequest, stdinIsTerminal) {
   134  		modes := ssh.TerminalModes{
   135  			ssh.ECHO:          1,
   136  			ssh.TTY_OP_ISPEED: 115200,
   137  			ssh.TTY_OP_OSPEED: 115200,
   138  		}
   139  
   140  		width, height := c.getWindowDimensions(stdoutFd)
   141  
   142  		err = session.RequestPty(c.terminalType(), height, width, modes)
   143  		if err != nil {
   144  			return err
   145  		}
   146  
   147  		var state *term.State
   148  		state, err = c.terminalHelper.SetRawTerminal(stdinFd)
   149  		if err == nil {
   150  			defer func() {
   151  				err := c.terminalHelper.RestoreTerminal(stdinFd, state)
   152  				log.Errorln("restore terminal", err)
   153  			}()
   154  		}
   155  	}
   156  
   157  	if len(commands) > 0 {
   158  		cmd := strings.Join(commands, " ")
   159  		err = session.Start(cmd)
   160  		if err != nil {
   161  			return err
   162  		}
   163  	} else {
   164  		err = session.Shell()
   165  		if err != nil {
   166  			return err
   167  		}
   168  	}
   169  
   170  	wg := &sync.WaitGroup{}
   171  	wg.Add(2)
   172  
   173  	go copyAndClose(nil, inPipe, stdin)
   174  	go copyAndDone(wg, stdout, outPipe)
   175  	go copyAndDone(wg, stderr, errPipe)
   176  
   177  	if stdoutIsTerminal {
   178  		resized := make(chan os.Signal, 16)
   179  
   180  		if runtime.GOOS == "windows" {
   181  			ticker := time.NewTicker(250 * time.Millisecond)
   182  			defer ticker.Stop()
   183  
   184  			go func() {
   185  				for range ticker.C {
   186  					resized <- syscall.Signal(-1)
   187  				}
   188  				close(resized)
   189  			}()
   190  		} else {
   191  			signal.Notify(resized, sigwinch.SIGWINCH())
   192  			defer func() { signal.Stop(resized); close(resized) }()
   193  		}
   194  
   195  		go c.resize(resized, session, stdoutFd)
   196  	}
   197  
   198  	keepaliveStopCh := make(chan struct{})
   199  	defer close(keepaliveStopCh)
   200  
   201  	go keepalive(c.secureClient.Conn(), time.NewTicker(c.keepAliveInterval), keepaliveStopCh)
   202  
   203  	result := session.Wait()
   204  	wg.Wait()
   205  	return result
   206  }
   207  
   208  func (c *SecureShell) LocalPortForward(localPortForwardSpecs []LocalPortForward) error {
   209  	for _, spec := range localPortForwardSpecs {
   210  		listener, err := c.listenerFactory.Listen("tcp", spec.LocalAddress)
   211  		if err != nil {
   212  			return err
   213  		}
   214  		c.localListeners = append(c.localListeners, listener)
   215  
   216  		go c.localForwardAcceptLoop(listener, spec.RemoteAddress)
   217  	}
   218  
   219  	return nil
   220  }
   221  
   222  func (c *SecureShell) Wait() error {
   223  	keepaliveStopCh := make(chan struct{})
   224  	defer close(keepaliveStopCh)
   225  
   226  	go keepalive(c.secureClient.Conn(), time.NewTicker(c.keepAliveInterval), keepaliveStopCh)
   227  
   228  	return c.secureClient.Wait()
   229  }
   230  
   231  func (c *SecureShell) getWindowDimensions(terminalFd uintptr) (width int, height int) {
   232  	winSize, err := c.terminalHelper.GetWinsize(terminalFd)
   233  	if err != nil {
   234  		winSize = &term.Winsize{
   235  			Width:  80,
   236  			Height: 43,
   237  		}
   238  	}
   239  
   240  	return int(winSize.Width), int(winSize.Height)
   241  }
   242  
   243  func (c *SecureShell) handleForwardConnection(conn net.Conn, targetAddr string) {
   244  	defer conn.Close()
   245  
   246  	target, err := c.secureClient.Dial("tcp", targetAddr)
   247  	if err != nil {
   248  		fmt.Printf("connect to %s failed: %s\n", targetAddr, err.Error())
   249  		return
   250  	}
   251  	defer target.Close()
   252  
   253  	wg := &sync.WaitGroup{}
   254  	wg.Add(2)
   255  
   256  	go copyAndClose(wg, conn, target)
   257  	go copyAndClose(wg, target, conn)
   258  	wg.Wait()
   259  }
   260  
   261  func (c *SecureShell) localForwardAcceptLoop(listener net.Listener, addr string) {
   262  	defer listener.Close()
   263  
   264  	for {
   265  		conn, err := listener.Accept()
   266  		if err != nil {
   267  			if netErr, ok := err.(net.Error); ok && netErr.Temporary() {
   268  				time.Sleep(100 * time.Millisecond)
   269  				continue
   270  			}
   271  			return
   272  		}
   273  
   274  		go c.handleForwardConnection(conn, addr)
   275  	}
   276  }
   277  
   278  func (c *SecureShell) resize(resized <-chan os.Signal, session SecureSession, terminalFd uintptr) {
   279  	type resizeMessage struct {
   280  		Width       uint32
   281  		Height      uint32
   282  		PixelWidth  uint32
   283  		PixelHeight uint32
   284  	}
   285  
   286  	var previousWidth, previousHeight int
   287  
   288  	for range resized {
   289  		width, height := c.getWindowDimensions(terminalFd)
   290  
   291  		if width == previousWidth && height == previousHeight {
   292  			continue
   293  		}
   294  
   295  		message := resizeMessage{
   296  			Width:  uint32(width),
   297  			Height: uint32(height),
   298  		}
   299  
   300  		_, err := session.SendRequest("window-change", false, ssh.Marshal(message))
   301  		if err != nil {
   302  			log.Errorln("window-change:", err)
   303  		}
   304  
   305  		previousWidth = width
   306  		previousHeight = height
   307  	}
   308  }
   309  
   310  func (c *SecureShell) shouldAllocateTerminal(commands []string, terminalRequest TTYRequest, stdinIsTerminal bool) bool {
   311  	switch terminalRequest {
   312  	case RequestTTYForce:
   313  		return true
   314  	case RequestTTYNo:
   315  		return false
   316  	case RequestTTYYes:
   317  		return stdinIsTerminal
   318  	case RequestTTYAuto:
   319  		return len(commands) == 0 && stdinIsTerminal
   320  	default:
   321  		return false
   322  	}
   323  }
   324  
   325  func (c *SecureShell) terminalType() string {
   326  	term := os.Getenv("TERM")
   327  	if term == "" {
   328  		term = "xterm"
   329  	}
   330  	return term
   331  }
   332  
   333  func base64Sha256Fingerprint(key ssh.PublicKey) string {
   334  	sum := sha256.Sum256(key.Marshal())
   335  	return base64.RawStdEncoding.EncodeToString(sum[:])
   336  }
   337  
   338  func copyAndClose(wg *sync.WaitGroup, dest io.WriteCloser, src io.Reader) {
   339  	_, err := io.Copy(dest, src)
   340  	if err != nil {
   341  		log.Errorln("copy and close:", err)
   342  	}
   343  	_ = dest.Close()
   344  	if wg != nil {
   345  		wg.Done()
   346  	}
   347  }
   348  
   349  func copyAndDone(wg *sync.WaitGroup, dest io.Writer, src io.Reader) {
   350  	_, err := io.Copy(dest, src)
   351  	if err != nil {
   352  		log.Errorln("copy and done:", err)
   353  	}
   354  	wg.Done()
   355  }
   356  
   357  func fingerprintCallback(skipHostValidation bool, expectedFingerprint string) ssh.HostKeyCallback {
   358  	return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
   359  		if skipHostValidation {
   360  			return nil
   361  		}
   362  
   363  		var fingerprint string
   364  
   365  		switch len(expectedFingerprint) {
   366  		case base64Sha256FingerprintLength:
   367  			fingerprint = base64Sha256Fingerprint(key)
   368  		case hexSha1FingerprintLength:
   369  			fingerprint = hexSha1Fingerprint(key)
   370  		case md5FingerprintLength:
   371  			fingerprint = md5Fingerprint(key)
   372  		case 0:
   373  			fingerprint = md5Fingerprint(key)
   374  			return fmt.Errorf("Unable to verify identity of host.\n\nThe fingerprint of the received key was %q.", fingerprint)
   375  		default:
   376  			return errors.New("Unsupported host key fingerprint format")
   377  		}
   378  
   379  		if fingerprint != expectedFingerprint {
   380  			return fmt.Errorf("Host key verification failed.\n\nThe fingerprint of the received key was %q.", fingerprint)
   381  		}
   382  		return nil
   383  	}
   384  }
   385  
   386  func hexSha1Fingerprint(key ssh.PublicKey) string {
   387  	sum := sha1.Sum(key.Marshal())
   388  	return strings.Replace(fmt.Sprintf("% x", sum), " ", ":", -1)
   389  }
   390  
   391  func keepalive(conn ssh.Conn, ticker *time.Ticker, stopCh chan struct{}) {
   392  	for {
   393  		select {
   394  		case <-ticker.C:
   395  			_, _, err := conn.SendRequest("keepalive@cloudfoundry.org", true, nil)
   396  			if err != nil {
   397  				log.Errorln("err sending keep alive:", err)
   398  			}
   399  		case <-stopCh:
   400  			ticker.Stop()
   401  			return
   402  		}
   403  	}
   404  }
   405  
   406  func md5Fingerprint(key ssh.PublicKey) string {
   407  	sum := md5.Sum(key.Marshal())
   408  	return strings.Replace(fmt.Sprintf("% x", sum), " ", ":", -1)
   409  }