github.com/ablease/cli@v6.37.1-0.20180613014814-3adbb7d7fb19+incompatible/util/clissh/ssh_test.go (about)

     1  // +build !windows,!386
     2  
     3  // Skipping 386 because lager uses UInt64 in Session().
     4  // Skipping Windows because the tests rely on Unix/Linux-only syscalls.
     5  // TODO: refactor out these conflicts so this package can be tested on multiple platforms.
     6  
     7  package clissh_test
     8  
     9  import (
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"net"
    14  	"os"
    15  	"sync"
    16  	"syscall"
    17  	"time"
    18  
    19  	"code.cloudfoundry.org/cli/util/clissh/clisshfakes"
    20  	"code.cloudfoundry.org/cli/util/clissh/ssherror"
    21  	"code.cloudfoundry.org/diego-ssh/server"
    22  	fake_server "code.cloudfoundry.org/diego-ssh/server/fakes"
    23  	"code.cloudfoundry.org/diego-ssh/test_helpers"
    24  	"code.cloudfoundry.org/diego-ssh/test_helpers/fake_io"
    25  	"code.cloudfoundry.org/diego-ssh/test_helpers/fake_net"
    26  	"code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh"
    27  	"code.cloudfoundry.org/lager/lagertest"
    28  	"github.com/kr/pty"
    29  	"github.com/moby/moby/pkg/term"
    30  	"golang.org/x/crypto/ssh"
    31  
    32  	. "code.cloudfoundry.org/cli/util/clissh"
    33  	. "github.com/onsi/ginkgo"
    34  	. "github.com/onsi/gomega"
    35  )
    36  
    37  func BlockAcceptOnClose(fake *fake_net.FakeListener) {
    38  	waitUntilClosed := make(chan bool)
    39  	fake.AcceptStub = func() (net.Conn, error) {
    40  		<-waitUntilClosed
    41  		return nil, errors.New("banana")
    42  	}
    43  
    44  	var once sync.Once
    45  	fake.CloseStub = func() error {
    46  		once.Do(func() {
    47  			close(waitUntilClosed)
    48  		})
    49  		return nil
    50  	}
    51  }
    52  
    53  var _ = Describe("CLI SSH", func() {
    54  	var (
    55  		fakeSecureDialer    *clisshfakes.FakeSecureDialer
    56  		fakeSecureClient    *clisshfakes.FakeSecureClient
    57  		fakeTerminalHelper  *clisshfakes.FakeTerminalHelper
    58  		fakeListenerFactory *clisshfakes.FakeListenerFactory
    59  		fakeSecureSession   *clisshfakes.FakeSecureSession
    60  
    61  		fakeConnection *fake_ssh.FakeConn
    62  		stdinPipe      *fake_io.FakeWriteCloser
    63  		stdoutPipe     *fake_io.FakeReader
    64  		stderrPipe     *fake_io.FakeReader
    65  		secureShell    *SecureShell
    66  
    67  		username               string
    68  		passcode               string
    69  		sshEndpoint            string
    70  		sshEndpointFingerprint string
    71  		skipHostValidation     bool
    72  		commands               []string
    73  		terminalRequest        TTYRequest
    74  		keepAliveDuration      time.Duration
    75  	)
    76  
    77  	BeforeEach(func() {
    78  		fakeSecureDialer = new(clisshfakes.FakeSecureDialer)
    79  		fakeSecureClient = new(clisshfakes.FakeSecureClient)
    80  		fakeTerminalHelper = new(clisshfakes.FakeTerminalHelper)
    81  		fakeListenerFactory = new(clisshfakes.FakeListenerFactory)
    82  		fakeSecureSession = new(clisshfakes.FakeSecureSession)
    83  
    84  		fakeConnection = new(fake_ssh.FakeConn)
    85  		stdinPipe = new(fake_io.FakeWriteCloser)
    86  		stdoutPipe = new(fake_io.FakeReader)
    87  		stderrPipe = new(fake_io.FakeReader)
    88  
    89  		fakeListenerFactory.ListenStub = net.Listen
    90  		fakeSecureClient.NewSessionReturns(fakeSecureSession, nil)
    91  		fakeSecureClient.ConnReturns(fakeConnection)
    92  		fakeSecureDialer.DialReturns(fakeSecureClient, nil)
    93  
    94  		stdinPipe.WriteStub = func(p []byte) (int, error) {
    95  			return len(p), nil
    96  		}
    97  		fakeSecureSession.StdinPipeReturns(stdinPipe, nil)
    98  
    99  		stdoutPipe.ReadStub = func(p []byte) (int, error) {
   100  			return 0, io.EOF
   101  		}
   102  		fakeSecureSession.StdoutPipeReturns(stdoutPipe, nil)
   103  
   104  		stderrPipe.ReadStub = func(p []byte) (int, error) {
   105  			return 0, io.EOF
   106  		}
   107  		fakeSecureSession.StderrPipeReturns(stderrPipe, nil)
   108  
   109  		username = "some-user"
   110  		passcode = "some-passcode"
   111  		sshEndpoint = "some-endpoint"
   112  		sshEndpointFingerprint = "some-fingerprint"
   113  		skipHostValidation = false
   114  		commands = []string{}
   115  		terminalRequest = RequestTTYAuto
   116  		keepAliveDuration = DefaultKeepAliveInterval
   117  	})
   118  
   119  	JustBeforeEach(func() {
   120  		secureShell = NewSecureShell(
   121  			fakeSecureDialer,
   122  			fakeTerminalHelper,
   123  			fakeListenerFactory,
   124  			keepAliveDuration,
   125  		)
   126  	})
   127  
   128  	Describe("Connect", func() {
   129  		var connectErr error
   130  
   131  		JustBeforeEach(func() {
   132  			connectErr = secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation)
   133  		})
   134  
   135  		Context("when dialing succeeds", func() {
   136  			It("creates the ssh client", func() {
   137  				Expect(connectErr).ToNot(HaveOccurred())
   138  
   139  				Expect(fakeSecureDialer.DialCallCount()).To(Equal(1))
   140  				protocolArg, sshEndpointArg, sshConfigArg := fakeSecureDialer.DialArgsForCall(0)
   141  				Expect(protocolArg).To(Equal("tcp"))
   142  				Expect(sshEndpointArg).To(Equal(sshEndpoint))
   143  				Expect(sshConfigArg.User).To(Equal(username))
   144  				Expect(sshConfigArg.Auth).To(HaveLen(1))
   145  				Expect(sshConfigArg.HostKeyCallback).ToNot(BeNil())
   146  			})
   147  		})
   148  
   149  		Context("when dialing fails", func() {
   150  			var dialError error
   151  
   152  			Context("when the error is a generic Dial error", func() {
   153  				BeforeEach(func() {
   154  					dialError = errors.New("woops")
   155  					fakeSecureDialer.DialReturns(nil, dialError)
   156  				})
   157  
   158  				It("returns the dial error", func() {
   159  					Expect(connectErr).To(Equal(dialError))
   160  					Expect(fakeSecureDialer.DialCallCount()).To(Equal(1))
   161  				})
   162  			})
   163  
   164  			Context("when the dialing error is a golang 'unable to authenticate' error", func() {
   165  				BeforeEach(func() {
   166  					dialError = fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", []string{"none", "password"})
   167  					fakeSecureDialer.DialReturns(nil, dialError)
   168  				})
   169  
   170  				It("returns an UnableToAuthenticateError", func() {
   171  					Expect(connectErr).To(MatchError(ssherror.UnableToAuthenticateError{Err: dialError}))
   172  					Expect(fakeSecureDialer.DialCallCount()).To(Equal(1))
   173  				})
   174  			})
   175  		})
   176  	})
   177  
   178  	Describe("InteractiveSession", func() {
   179  		var (
   180  			stdin          *fake_io.FakeReadCloser
   181  			stdout, stderr *fake_io.FakeWriter
   182  
   183  			sessionErr                error
   184  			interactiveSessionInvoker func(secureShell *SecureShell)
   185  		)
   186  
   187  		BeforeEach(func() {
   188  			stdin = new(fake_io.FakeReadCloser)
   189  			stdout = new(fake_io.FakeWriter)
   190  			stderr = new(fake_io.FakeWriter)
   191  
   192  			fakeTerminalHelper.StdStreamsReturns(stdin, stdout, stderr)
   193  			interactiveSessionInvoker = func(secureShell *SecureShell) {
   194  				sessionErr = secureShell.InteractiveSession(commands, terminalRequest)
   195  			}
   196  		})
   197  
   198  		JustBeforeEach(func() {
   199  			connectErr := secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation)
   200  			Expect(connectErr).NotTo(HaveOccurred())
   201  			interactiveSessionInvoker(secureShell)
   202  		})
   203  
   204  		Context("when host key validation is enabled", func() {
   205  			var (
   206  				callback func(hostname string, remote net.Addr, key ssh.PublicKey) error
   207  				addr     net.Addr
   208  			)
   209  
   210  			BeforeEach(func() {
   211  				skipHostValidation = false
   212  			})
   213  
   214  			JustBeforeEach(func() {
   215  				Expect(fakeSecureDialer.DialCallCount()).To(Equal(1))
   216  				_, _, config := fakeSecureDialer.DialArgsForCall(0)
   217  				callback = config.HostKeyCallback
   218  
   219  				listener, err := net.Listen("tcp", "localhost:0")
   220  				Expect(err).NotTo(HaveOccurred())
   221  
   222  				addr = listener.Addr()
   223  				listener.Close()
   224  			})
   225  
   226  			Context("when the md5 fingerprint matches", func() {
   227  				BeforeEach(func() {
   228  					sshEndpointFingerprint = "41:ce:56:e6:9c:42:a9:c6:9e:68:ac:e3:4d:f6:38:79"
   229  				})
   230  
   231  				It("does not return an error", func() {
   232  					Expect(callback("", addr, TestHostKey.PublicKey())).ToNot(HaveOccurred())
   233  				})
   234  			})
   235  
   236  			Context("when the hex sha1 fingerprint matches", func() {
   237  				BeforeEach(func() {
   238  					sshEndpointFingerprint = "a8:e2:67:cb:ea:2a:6e:23:a1:72:ce:8f:07:92:15:ee:1f:82:f8:ca"
   239  				})
   240  
   241  				It("does not return an error", func() {
   242  					Expect(callback("", addr, TestHostKey.PublicKey())).ToNot(HaveOccurred())
   243  				})
   244  			})
   245  
   246  			Context("when the base64 sha256 fingerprint matches", func() {
   247  				BeforeEach(func() {
   248  					sshEndpointFingerprint = "sp/jrLuj66r+yrLDUKZdJU5tdzt4mq/UaSiNBjpgr+8"
   249  				})
   250  
   251  				It("does not return an error", func() {
   252  					Expect(callback("", addr, TestHostKey.PublicKey())).ToNot(HaveOccurred())
   253  				})
   254  			})
   255  
   256  			Context("when the base64 SHA256 fingerprint does not match", func() {
   257  				BeforeEach(func() {
   258  					sshEndpointFingerprint = "0000000000000000000000000000000000000000000"
   259  				})
   260  
   261  				It("returns an error'", func() {
   262  					err := callback("", addr, TestHostKey.PublicKey())
   263  					Expect(err).To(MatchError(MatchRegexp("Host key verification failed\\.")))
   264  					Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\"")))
   265  				})
   266  			})
   267  
   268  			Context("when the hex SHA1 fingerprint does not match", func() {
   269  				BeforeEach(func() {
   270  					sshEndpointFingerprint = "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00"
   271  				})
   272  
   273  				It("returns an error'", func() {
   274  					err := callback("", addr, TestHostKey.PublicKey())
   275  					Expect(err).To(MatchError(MatchRegexp("Host key verification failed\\.")))
   276  					Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\"")))
   277  				})
   278  			})
   279  
   280  			Context("when the MD5 fingerprint does not match", func() {
   281  				BeforeEach(func() {
   282  					sshEndpointFingerprint = "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00"
   283  				})
   284  
   285  				It("returns an error'", func() {
   286  					err := callback("", addr, TestHostKey.PublicKey())
   287  					Expect(err).To(MatchError(MatchRegexp("Host key verification failed\\.")))
   288  					Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\"")))
   289  				})
   290  			})
   291  
   292  			Context("when no fingerprint is present in endpoint info", func() {
   293  				BeforeEach(func() {
   294  					sshEndpointFingerprint = ""
   295  					sshEndpoint = ""
   296  				})
   297  
   298  				It("returns an error'", func() {
   299  					err := callback("", addr, TestHostKey.PublicKey())
   300  					Expect(err).To(MatchError(MatchRegexp("Unable to verify identity of host\\.")))
   301  					Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\"")))
   302  				})
   303  			})
   304  
   305  			Context("when the fingerprint length doesn't make sense", func() {
   306  				BeforeEach(func() {
   307  					sshEndpointFingerprint = "garbage"
   308  				})
   309  
   310  				It("returns an error", func() {
   311  					err := callback("", addr, TestHostKey.PublicKey())
   312  					Eventually(err).Should(MatchError(MatchRegexp("Unsupported host key fingerprint format")))
   313  				})
   314  			})
   315  		})
   316  
   317  		Context("when the skip host validation flag is set", func() {
   318  			BeforeEach(func() {
   319  				skipHostValidation = true
   320  			})
   321  
   322  			It("the HostKeyCallback on the Config to always return nil", func() {
   323  				Expect(fakeSecureDialer.DialCallCount()).To(Equal(1))
   324  
   325  				_, _, config := fakeSecureDialer.DialArgsForCall(0)
   326  				Expect(config.HostKeyCallback("some-hostname", nil, nil)).To(BeNil())
   327  			})
   328  		})
   329  
   330  		// TODO: see if it's possible to test the piping between the ssh client input and outputs and the UI object we pass in
   331  		Context("when dialing is successful", func() {
   332  			It("creates a new secure shell session", func() {
   333  				Expect(fakeSecureClient.NewSessionCallCount()).To(Equal(1))
   334  			})
   335  
   336  			It("closes the session", func() {
   337  				Expect(fakeSecureSession.CloseCallCount()).To(Equal(1))
   338  			})
   339  
   340  			It("gets a stdin pipe for the session", func() {
   341  				Expect(fakeSecureSession.StdinPipeCallCount()).To(Equal(1))
   342  			})
   343  
   344  			Context("when getting the stdin pipe fails", func() {
   345  				BeforeEach(func() {
   346  					fakeSecureSession.StdinPipeReturns(nil, errors.New("woops"))
   347  				})
   348  
   349  				It("returns the error", func() {
   350  					Expect(sessionErr).Should(MatchError("woops"))
   351  				})
   352  			})
   353  
   354  			It("gets a stdout pipe for the session", func() {
   355  				Expect(fakeSecureSession.StdoutPipeCallCount()).To(Equal(1))
   356  			})
   357  
   358  			Context("when getting the stdout pipe fails", func() {
   359  				BeforeEach(func() {
   360  					fakeSecureSession.StdoutPipeReturns(nil, errors.New("woops"))
   361  				})
   362  
   363  				It("returns the error", func() {
   364  					Expect(sessionErr).Should(MatchError("woops"))
   365  				})
   366  			})
   367  
   368  			It("gets a stderr pipe for the session", func() {
   369  				Expect(fakeSecureSession.StderrPipeCallCount()).To(Equal(1))
   370  			})
   371  
   372  			Context("when getting the stderr pipe fails", func() {
   373  				BeforeEach(func() {
   374  					fakeSecureSession.StderrPipeReturns(nil, errors.New("woops"))
   375  				})
   376  
   377  				It("returns the error", func() {
   378  					Expect(sessionErr).Should(MatchError("woops"))
   379  				})
   380  			})
   381  		})
   382  
   383  		Context("when stdin is a terminal", func() {
   384  			var master, slave *os.File
   385  
   386  			BeforeEach(func() {
   387  				var err error
   388  				master, slave, err = pty.Open()
   389  				Expect(err).NotTo(HaveOccurred())
   390  
   391  				terminalRequest = RequestTTYForce
   392  
   393  				terminalHelper := DefaultTerminalHelper()
   394  				fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo
   395  				fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize
   396  			})
   397  
   398  			AfterEach(func() {
   399  				master.Close()
   400  				// slave.Close() // race
   401  			})
   402  
   403  			Context("when a command is not specified", func() {
   404  				var terminalType string
   405  
   406  				BeforeEach(func() {
   407  					terminalType = os.Getenv("TERM")
   408  					os.Setenv("TERM", "test-terminal-type")
   409  
   410  					winsize := &term.Winsize{Width: 1024, Height: 256}
   411  					fakeTerminalHelper.GetWinsizeReturns(winsize, nil)
   412  
   413  					fakeSecureSession.ShellStub = func() error {
   414  						Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1))
   415  						Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(0))
   416  						return nil
   417  					}
   418  				})
   419  
   420  				AfterEach(func() {
   421  					os.Setenv("TERM", terminalType)
   422  				})
   423  
   424  				It("requests a pty with the correct terminal type, window size, and modes", func() {
   425  					Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1))
   426  					Expect(fakeTerminalHelper.GetWinsizeCallCount()).To(Equal(1))
   427  
   428  					termType, height, width, modes := fakeSecureSession.RequestPtyArgsForCall(0)
   429  					Expect(termType).To(Equal("test-terminal-type"))
   430  					Expect(height).To(Equal(256))
   431  					Expect(width).To(Equal(1024))
   432  
   433  					expectedModes := ssh.TerminalModes{
   434  						ssh.ECHO:          1,
   435  						ssh.TTY_OP_ISPEED: 115200,
   436  						ssh.TTY_OP_OSPEED: 115200,
   437  					}
   438  					Expect(modes).To(Equal(expectedModes))
   439  				})
   440  
   441  				Context("when the TERM environment variable is not set", func() {
   442  					BeforeEach(func() {
   443  						os.Unsetenv("TERM")
   444  					})
   445  
   446  					It("requests a pty with the default terminal type", func() {
   447  						Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1))
   448  
   449  						termType, _, _, _ := fakeSecureSession.RequestPtyArgsForCall(0)
   450  						Expect(termType).To(Equal("xterm"))
   451  					})
   452  				})
   453  
   454  				It("puts the terminal into raw mode and restores it after running the shell", func() {
   455  					Expect(fakeSecureSession.ShellCallCount()).To(Equal(1))
   456  					Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1))
   457  					Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(1))
   458  				})
   459  
   460  				Context("when the pty allocation fails", func() {
   461  					var ptyError error
   462  
   463  					BeforeEach(func() {
   464  						ptyError = errors.New("pty allocation error")
   465  						fakeSecureSession.RequestPtyReturns(ptyError)
   466  					})
   467  
   468  					It("returns the error", func() {
   469  						Expect(sessionErr).To(Equal(ptyError))
   470  					})
   471  				})
   472  
   473  				Context("when placing the terminal into raw mode fails", func() {
   474  					BeforeEach(func() {
   475  						fakeTerminalHelper.SetRawTerminalReturns(nil, errors.New("woops"))
   476  					})
   477  
   478  					It("keeps calm and carries on", func() {
   479  						Expect(fakeSecureSession.ShellCallCount()).To(Equal(1))
   480  					})
   481  
   482  					It("does not not restore the terminal", func() {
   483  						Expect(fakeSecureSession.ShellCallCount()).To(Equal(1))
   484  						Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1))
   485  						Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(0))
   486  					})
   487  				})
   488  			})
   489  
   490  			Context("when a command is specified", func() {
   491  				BeforeEach(func() {
   492  					commands = []string{"echo", "-n", "hello"}
   493  				})
   494  
   495  				Context("when a terminal is requested", func() {
   496  					BeforeEach(func() {
   497  						terminalRequest = RequestTTYYes
   498  						fakeTerminalHelper.GetFdInfoReturns(0, true)
   499  					})
   500  
   501  					It("requests a pty", func() {
   502  						Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1))
   503  					})
   504  				})
   505  
   506  				Context("when a terminal is not explicitly requested", func() {
   507  					BeforeEach(func() {
   508  						terminalRequest = RequestTTYAuto
   509  					})
   510  
   511  					It("does not request a pty", func() {
   512  						Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0))
   513  					})
   514  				})
   515  			})
   516  		})
   517  
   518  		Context("when stdin is not a terminal", func() {
   519  			BeforeEach(func() {
   520  				stdin.ReadStub = func(p []byte) (int, error) {
   521  					return 0, io.EOF
   522  				}
   523  
   524  				terminalHelper := DefaultTerminalHelper()
   525  				fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo
   526  				fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize
   527  			})
   528  
   529  			Context("when a terminal is not requested", func() {
   530  				It("does not request a pty", func() {
   531  					Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0))
   532  				})
   533  			})
   534  
   535  			Context("when a terminal is requested", func() {
   536  				BeforeEach(func() {
   537  					terminalRequest = RequestTTYYes
   538  				})
   539  
   540  				It("does not request a pty", func() {
   541  					Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0))
   542  				})
   543  			})
   544  		})
   545  
   546  		PContext("when a terminal is forced", func() {
   547  			BeforeEach(func() {
   548  				terminalRequest = RequestTTYForce
   549  			})
   550  
   551  			It("requests a pty", func() {
   552  				Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1))
   553  			})
   554  		})
   555  
   556  		Context("when a terminal is disabled", func() {
   557  			BeforeEach(func() {
   558  				terminalRequest = RequestTTYNo
   559  			})
   560  
   561  			It("does not request a pty", func() {
   562  				Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0))
   563  			})
   564  		})
   565  
   566  		Context("when a command is not specified", func() {
   567  			It("requests an interactive shell", func() {
   568  				Expect(fakeSecureSession.ShellCallCount()).To(Equal(1))
   569  			})
   570  
   571  			Context("when the shell request returns an error", func() {
   572  				BeforeEach(func() {
   573  					fakeSecureSession.ShellReturns(errors.New("oh bother"))
   574  				})
   575  
   576  				It("returns the error", func() {
   577  					Expect(sessionErr).To(MatchError("oh bother"))
   578  				})
   579  			})
   580  		})
   581  
   582  		Context("when a command is specifed", func() {
   583  			BeforeEach(func() {
   584  				commands = []string{"echo", "-n", "hello"}
   585  			})
   586  
   587  			It("starts the command", func() {
   588  				Expect(fakeSecureSession.StartCallCount()).To(Equal(1))
   589  				Expect(fakeSecureSession.StartArgsForCall(0)).To(Equal("echo -n hello"))
   590  			})
   591  
   592  			Context("when the command fails to start", func() {
   593  				BeforeEach(func() {
   594  					fakeSecureSession.StartReturns(errors.New("oh well"))
   595  				})
   596  
   597  				It("returns the error", func() {
   598  					Expect(sessionErr).To(MatchError("oh well"))
   599  				})
   600  			})
   601  		})
   602  
   603  		Context("when the shell or command has started", func() {
   604  			BeforeEach(func() {
   605  				stdin.ReadStub = func(p []byte) (int, error) {
   606  					p[0] = 0
   607  					return 1, io.EOF
   608  				}
   609  				stdinPipe.WriteStub = func(p []byte) (int, error) {
   610  					defer GinkgoRecover()
   611  					Expect(p[0]).To(Equal(byte(0)))
   612  					return 1, nil
   613  				}
   614  
   615  				stdoutPipe.ReadStub = func(p []byte) (int, error) {
   616  					p[0] = 1
   617  					return 1, io.EOF
   618  				}
   619  				stdout.WriteStub = func(p []byte) (int, error) {
   620  					defer GinkgoRecover()
   621  					Expect(p[0]).To(Equal(byte(1)))
   622  					return 1, nil
   623  				}
   624  
   625  				stderrPipe.ReadStub = func(p []byte) (int, error) {
   626  					p[0] = 2
   627  					return 1, io.EOF
   628  				}
   629  				stderr.WriteStub = func(p []byte) (int, error) {
   630  					defer GinkgoRecover()
   631  					Expect(p[0]).To(Equal(byte(2)))
   632  					return 1, nil
   633  				}
   634  
   635  				fakeSecureSession.StdinPipeReturns(stdinPipe, nil)
   636  				fakeSecureSession.StdoutPipeReturns(stdoutPipe, nil)
   637  				fakeSecureSession.StderrPipeReturns(stderrPipe, nil)
   638  
   639  				fakeSecureSession.WaitReturns(errors.New("error result"))
   640  			})
   641  
   642  			It("copies data from the stdin stream to the session stdin pipe", func() {
   643  				Eventually(stdin.ReadCallCount).Should(Equal(1))
   644  				Eventually(stdinPipe.WriteCallCount).Should(Equal(1))
   645  			})
   646  
   647  			It("copies data from the session stdout pipe to the stdout stream", func() {
   648  				Eventually(stdoutPipe.ReadCallCount).Should(Equal(1))
   649  				Eventually(stdout.WriteCallCount).Should(Equal(1))
   650  			})
   651  
   652  			It("copies data from the session stderr pipe to the stderr stream", func() {
   653  				Eventually(stderrPipe.ReadCallCount).Should(Equal(1))
   654  				Eventually(stderr.WriteCallCount).Should(Equal(1))
   655  			})
   656  
   657  			It("waits for the session to end", func() {
   658  				Expect(fakeSecureSession.WaitCallCount()).To(Equal(1))
   659  			})
   660  
   661  			It("returns the result from wait", func() {
   662  				Expect(sessionErr).To(MatchError("error result"))
   663  			})
   664  
   665  			Context("when the session terminates before stream copies complete", func() {
   666  				var sessionErrorCh chan error
   667  
   668  				BeforeEach(func() {
   669  					sessionErrorCh = make(chan error, 1)
   670  
   671  					interactiveSessionInvoker = func(secureShell *SecureShell) {
   672  						go func() {
   673  							sessionErrorCh <- secureShell.InteractiveSession(commands, terminalRequest)
   674  						}()
   675  					}
   676  
   677  					stdoutPipe.ReadStub = func(p []byte) (int, error) {
   678  						defer GinkgoRecover()
   679  						Eventually(fakeSecureSession.WaitCallCount).Should(Equal(1))
   680  						Consistently(sessionErrorCh).ShouldNot(Receive())
   681  
   682  						p[0] = 1
   683  						return 1, io.EOF
   684  					}
   685  
   686  					stderrPipe.ReadStub = func(p []byte) (int, error) {
   687  						defer GinkgoRecover()
   688  						Eventually(fakeSecureSession.WaitCallCount).Should(Equal(1))
   689  						Consistently(sessionErrorCh).ShouldNot(Receive())
   690  
   691  						p[0] = 2
   692  						return 1, io.EOF
   693  					}
   694  				})
   695  
   696  				It("waits for the copies to complete", func() {
   697  					Eventually(sessionErrorCh).Should(Receive())
   698  					Expect(stdoutPipe.ReadCallCount()).To(Equal(1))
   699  					Expect(stderrPipe.ReadCallCount()).To(Equal(1))
   700  				})
   701  			})
   702  
   703  			Context("when stdin is closed", func() {
   704  				BeforeEach(func() {
   705  					stdin.ReadStub = func(p []byte) (int, error) {
   706  						defer GinkgoRecover()
   707  						Consistently(stdinPipe.CloseCallCount).Should(Equal(0))
   708  						p[0] = 0
   709  						return 1, io.EOF
   710  					}
   711  				})
   712  
   713  				It("closes the stdinPipe", func() {
   714  					Eventually(stdinPipe.CloseCallCount).Should(Equal(1))
   715  				})
   716  			})
   717  		})
   718  
   719  		Context("when stdout is a terminal and a window size change occurs", func() {
   720  			var master, slave *os.File
   721  
   722  			BeforeEach(func() {
   723  				var err error
   724  				master, slave, err = pty.Open()
   725  				Expect(err).NotTo(HaveOccurred())
   726  
   727  				terminalHelper := DefaultTerminalHelper()
   728  				fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo
   729  				fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize
   730  				fakeTerminalHelper.StdStreamsReturns(stdin, slave, stderr)
   731  
   732  				winsize := &term.Winsize{Height: 100, Width: 100}
   733  				err = term.SetWinsize(slave.Fd(), winsize)
   734  				Expect(err).NotTo(HaveOccurred())
   735  
   736  				fakeSecureSession.WaitStub = func() error {
   737  					fakeSecureSession.SendRequestCallCount()
   738  					Expect(fakeSecureSession.SendRequestCallCount()).To(Equal(0))
   739  
   740  					// No dimension change
   741  					for i := 0; i < 3; i++ {
   742  						winsize := &term.Winsize{Height: 100, Width: 100}
   743  						err = term.SetWinsize(slave.Fd(), winsize)
   744  						Expect(err).NotTo(HaveOccurred())
   745  					}
   746  
   747  					winsize := &term.Winsize{Height: 100, Width: 200}
   748  					err = term.SetWinsize(slave.Fd(), winsize)
   749  					Expect(err).NotTo(HaveOccurred())
   750  
   751  					err = syscall.Kill(syscall.Getpid(), syscall.SIGWINCH)
   752  					Expect(err).NotTo(HaveOccurred())
   753  
   754  					Eventually(fakeSecureSession.SendRequestCallCount).Should(Equal(1))
   755  					return nil
   756  				}
   757  			})
   758  
   759  			AfterEach(func() {
   760  				master.Close()
   761  				slave.Close()
   762  			})
   763  
   764  			It("sends window change events when the window dimensions change", func() {
   765  				Expect(fakeSecureSession.SendRequestCallCount()).To(Equal(1))
   766  
   767  				requestType, wantReply, message := fakeSecureSession.SendRequestArgsForCall(0)
   768  				Expect(requestType).To(Equal("window-change"))
   769  				Expect(wantReply).To(BeFalse())
   770  
   771  				type resizeMessage struct {
   772  					Width       uint32
   773  					Height      uint32
   774  					PixelWidth  uint32
   775  					PixelHeight uint32
   776  				}
   777  				var resizeMsg resizeMessage
   778  
   779  				err := ssh.Unmarshal(message, &resizeMsg)
   780  				Expect(err).NotTo(HaveOccurred())
   781  
   782  				Expect(resizeMsg).To(Equal(resizeMessage{Height: 100, Width: 200}))
   783  			})
   784  		})
   785  
   786  		Describe("keep alive messages", func() {
   787  			var times []time.Time
   788  			var timesCh chan []time.Time
   789  			var done chan struct{}
   790  
   791  			BeforeEach(func() {
   792  				keepAliveDuration = 100 * time.Millisecond
   793  
   794  				times = []time.Time{}
   795  				timesCh = make(chan []time.Time, 1)
   796  				done = make(chan struct{}, 1)
   797  
   798  				fakeConnection.SendRequestStub = func(reqName string, wantReply bool, message []byte) (bool, []byte, error) {
   799  					Expect(reqName).To(Equal("keepalive@cloudfoundry.org"))
   800  					Expect(wantReply).To(BeTrue())
   801  					Expect(message).To(BeNil())
   802  
   803  					times = append(times, time.Now())
   804  					if len(times) == 3 {
   805  						timesCh <- times
   806  						close(done)
   807  					}
   808  					return true, nil, nil
   809  				}
   810  
   811  				fakeSecureSession.WaitStub = func() error {
   812  					Eventually(done).Should(BeClosed())
   813  					return nil
   814  				}
   815  			})
   816  
   817  			PIt("sends keep alive messages at the expected interval", func() {
   818  				times := <-timesCh
   819  				Expect(times[2]).To(BeTemporally("~", times[0].Add(200*time.Millisecond), 160*time.Millisecond))
   820  			})
   821  		})
   822  	})
   823  
   824  	Describe("LocalPortForward", func() {
   825  		var (
   826  			forwardErr error
   827  
   828  			echoAddress  string
   829  			echoListener *fake_net.FakeListener
   830  			echoHandler  *fake_server.FakeConnectionHandler
   831  			echoServer   *server.Server
   832  
   833  			localAddress string
   834  
   835  			realLocalListener net.Listener
   836  			fakeLocalListener *fake_net.FakeListener
   837  
   838  			forwardSpecs []LocalPortForward
   839  		)
   840  
   841  		BeforeEach(func() {
   842  			logger := lagertest.NewTestLogger("test")
   843  
   844  			var err error
   845  			realLocalListener, err = net.Listen("tcp", "127.0.0.1:0")
   846  			Expect(err).NotTo(HaveOccurred())
   847  
   848  			localAddress = realLocalListener.Addr().String()
   849  			fakeListenerFactory.ListenReturns(realLocalListener, nil)
   850  
   851  			echoHandler = new(fake_server.FakeConnectionHandler)
   852  			echoHandler.HandleConnectionStub = func(conn net.Conn) {
   853  				io.Copy(conn, conn)
   854  				conn.Close()
   855  			}
   856  
   857  			realListener, err := net.Listen("tcp", "127.0.0.1:0")
   858  			Expect(err).NotTo(HaveOccurred())
   859  			echoAddress = realListener.Addr().String()
   860  
   861  			echoListener = new(fake_net.FakeListener)
   862  			echoListener.AcceptStub = realListener.Accept
   863  			echoListener.CloseStub = realListener.Close
   864  			echoListener.AddrStub = realListener.Addr
   865  
   866  			fakeLocalListener = new(fake_net.FakeListener)
   867  			fakeLocalListener.AcceptReturns(nil, errors.New("Not Accepting Connections"))
   868  
   869  			echoServer = server.NewServer(logger.Session("echo"), "", echoHandler)
   870  			echoServer.SetListener(echoListener)
   871  			go echoServer.Serve()
   872  
   873  			forwardSpecs = []LocalPortForward{{
   874  				RemoteAddress: echoAddress,
   875  				LocalAddress:  localAddress,
   876  			}}
   877  
   878  			fakeSecureClient.DialStub = net.Dial
   879  		})
   880  
   881  		JustBeforeEach(func() {
   882  			connectErr := secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation)
   883  			Expect(connectErr).NotTo(HaveOccurred())
   884  
   885  			forwardErr = secureShell.LocalPortForward(forwardSpecs)
   886  		})
   887  
   888  		AfterEach(func() {
   889  			err := secureShell.Close()
   890  			Expect(err).NotTo(HaveOccurred())
   891  			echoServer.Shutdown()
   892  
   893  			realLocalListener.Close()
   894  		})
   895  
   896  		validateConnectivity := func(addr string) {
   897  			conn, err := net.Dial("tcp", addr)
   898  			Expect(err).NotTo(HaveOccurred())
   899  
   900  			msg := fmt.Sprintf("Hello from %s\n", addr)
   901  			n, err := conn.Write([]byte(msg))
   902  			Expect(err).NotTo(HaveOccurred())
   903  			Expect(n).To(Equal(len(msg)))
   904  
   905  			response := make([]byte, len(msg))
   906  			n, err = conn.Read(response)
   907  			Expect(err).NotTo(HaveOccurred())
   908  			Expect(n).To(Equal(len(msg)))
   909  
   910  			err = conn.Close()
   911  			Expect(err).NotTo(HaveOccurred())
   912  
   913  			Expect(response).To(Equal([]byte(msg)))
   914  		}
   915  
   916  		It("dials the connect address when a local connection is made", func() {
   917  			Expect(forwardErr).NotTo(HaveOccurred())
   918  
   919  			conn, err := net.Dial("tcp", localAddress)
   920  			Expect(err).NotTo(HaveOccurred())
   921  
   922  			Eventually(echoListener.AcceptCallCount).Should(BeNumerically(">=", 1))
   923  			Eventually(fakeSecureClient.DialCallCount).Should(Equal(1))
   924  
   925  			network, addr := fakeSecureClient.DialArgsForCall(0)
   926  			Expect(network).To(Equal("tcp"))
   927  			Expect(addr).To(Equal(echoAddress))
   928  
   929  			Expect(conn.Close()).NotTo(HaveOccurred())
   930  		})
   931  
   932  		It("copies data between the local and remote connections", func() {
   933  			validateConnectivity(localAddress)
   934  		})
   935  
   936  		Context("when a local connection is already open", func() {
   937  			var conn net.Conn
   938  
   939  			JustBeforeEach(func() {
   940  				var err error
   941  				conn, err = net.Dial("tcp", localAddress)
   942  				Expect(err).NotTo(HaveOccurred())
   943  			})
   944  
   945  			AfterEach(func() {
   946  				err := conn.Close()
   947  				Expect(err).NotTo(HaveOccurred())
   948  			})
   949  
   950  			It("allows for new incoming connections as well", func() {
   951  				validateConnectivity(localAddress)
   952  			})
   953  		})
   954  
   955  		Context("when there are multiple port forward specs", func() {
   956  			Context("when provided a real listener", func() {
   957  				var (
   958  					realLocalListener2 net.Listener
   959  					localAddress2      string
   960  				)
   961  
   962  				BeforeEach(func() {
   963  					var err error
   964  					realLocalListener2, err = net.Listen("tcp", "127.0.0.1:0")
   965  					Expect(err).NotTo(HaveOccurred())
   966  
   967  					localAddress2 = realLocalListener2.Addr().String()
   968  
   969  					fakeListenerFactory.ListenStub = func(network, addr string) (net.Listener, error) {
   970  						if addr == localAddress {
   971  							return realLocalListener, nil
   972  						}
   973  
   974  						if addr == localAddress2 {
   975  							return realLocalListener2, nil
   976  						}
   977  
   978  						return nil, errors.New("unexpected address")
   979  					}
   980  
   981  					forwardSpecs = []LocalPortForward{
   982  						{
   983  							RemoteAddress: echoAddress,
   984  							LocalAddress:  localAddress,
   985  						},
   986  						{
   987  							RemoteAddress: echoAddress,
   988  							LocalAddress:  localAddress2,
   989  						},
   990  					}
   991  				})
   992  
   993  				AfterEach(func() {
   994  					realLocalListener2.Close()
   995  				})
   996  
   997  				It("listens to all the things", func() {
   998  					Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(2))
   999  
  1000  					network, addr := fakeListenerFactory.ListenArgsForCall(0)
  1001  					Expect(network).To(Equal("tcp"))
  1002  					Expect(addr).To(Equal(localAddress))
  1003  
  1004  					network, addr = fakeListenerFactory.ListenArgsForCall(1)
  1005  					Expect(network).To(Equal("tcp"))
  1006  					Expect(addr).To(Equal(localAddress2))
  1007  				})
  1008  
  1009  				It("forwards to the correct target", func() {
  1010  					validateConnectivity(localAddress)
  1011  					validateConnectivity(localAddress2)
  1012  				})
  1013  			})
  1014  
  1015  			Context("when the secure client is closed", func() {
  1016  				var (
  1017  					fakeLocalListener1 *fake_net.FakeListener
  1018  					fakeLocalListener2 *fake_net.FakeListener
  1019  				)
  1020  
  1021  				BeforeEach(func() {
  1022  					forwardSpecs = []LocalPortForward{
  1023  						{
  1024  							RemoteAddress: echoAddress,
  1025  							LocalAddress:  localAddress,
  1026  						},
  1027  						{
  1028  							RemoteAddress: echoAddress,
  1029  							LocalAddress:  localAddress,
  1030  						},
  1031  					}
  1032  
  1033  					fakeLocalListener1 = new(fake_net.FakeListener)
  1034  					BlockAcceptOnClose(fakeLocalListener1)
  1035  					fakeListenerFactory.ListenReturnsOnCall(0, fakeLocalListener1, nil)
  1036  
  1037  					fakeLocalListener2 = new(fake_net.FakeListener)
  1038  					BlockAcceptOnClose(fakeLocalListener2)
  1039  					fakeListenerFactory.ListenReturnsOnCall(1, fakeLocalListener2, nil)
  1040  				})
  1041  
  1042  				It("closes the listeners", func() {
  1043  					Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(2))
  1044  					Eventually(fakeLocalListener1.AcceptCallCount).Should(Equal(1))
  1045  					Eventually(fakeLocalListener2.AcceptCallCount).Should(Equal(1))
  1046  
  1047  					err := secureShell.Close()
  1048  					Expect(err).NotTo(HaveOccurred())
  1049  					Eventually(fakeLocalListener1.CloseCallCount).Should(Equal(2))
  1050  					Eventually(fakeLocalListener2.CloseCallCount).Should(Equal(2))
  1051  				})
  1052  			})
  1053  		})
  1054  
  1055  		Context("when listen fails", func() {
  1056  			BeforeEach(func() {
  1057  				fakeListenerFactory.ListenReturns(nil, errors.New("failure is an option"))
  1058  			})
  1059  
  1060  			It("returns the error", func() {
  1061  				Expect(forwardErr).To(MatchError("failure is an option"))
  1062  			})
  1063  		})
  1064  
  1065  		Context("when the client it closed", func() {
  1066  			BeforeEach(func() {
  1067  				fakeLocalListener = new(fake_net.FakeListener)
  1068  				BlockAcceptOnClose(fakeLocalListener)
  1069  
  1070  				fakeListenerFactory.ListenReturns(fakeLocalListener, nil)
  1071  			})
  1072  
  1073  			It("closes the listener when the client is closed", func() {
  1074  				Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(1))
  1075  				Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(1))
  1076  
  1077  				err := secureShell.Close()
  1078  				Expect(err).NotTo(HaveOccurred())
  1079  				Eventually(fakeLocalListener.CloseCallCount).Should(Equal(2))
  1080  			})
  1081  		})
  1082  
  1083  		Context("when accept fails", func() {
  1084  			var fakeConn *fake_net.FakeConn
  1085  
  1086  			BeforeEach(func() {
  1087  				fakeConn = &fake_net.FakeConn{}
  1088  				fakeConn.ReadReturns(0, io.EOF)
  1089  
  1090  				fakeListenerFactory.ListenReturns(fakeLocalListener, nil)
  1091  			})
  1092  
  1093  			Context("with a permanent error", func() {
  1094  				BeforeEach(func() {
  1095  					fakeLocalListener.AcceptReturns(nil, errors.New("boom"))
  1096  				})
  1097  
  1098  				It("stops trying to accept connections", func() {
  1099  					Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(1))
  1100  					Consistently(fakeLocalListener.AcceptCallCount).Should(Equal(1))
  1101  					Expect(fakeLocalListener.CloseCallCount()).To(Equal(1))
  1102  				})
  1103  			})
  1104  
  1105  			Context("with a temporary error", func() {
  1106  				var timeCh chan time.Time
  1107  
  1108  				BeforeEach(func() {
  1109  					timeCh = make(chan time.Time, 3)
  1110  
  1111  					fakeLocalListener.AcceptStub = func() (net.Conn, error) {
  1112  						timeCh := timeCh
  1113  						if fakeLocalListener.AcceptCallCount() > 3 {
  1114  							close(timeCh)
  1115  							return nil, test_helpers.NewTestNetError(false, false)
  1116  						} else {
  1117  							timeCh <- time.Now()
  1118  							return nil, test_helpers.NewTestNetError(false, true)
  1119  						}
  1120  					}
  1121  				})
  1122  
  1123  				PIt("retries connecting after a short delay", func() {
  1124  					Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(3))
  1125  					Expect(timeCh).To(HaveLen(3))
  1126  
  1127  					times := make([]time.Time, 0)
  1128  					for t := range timeCh {
  1129  						times = append(times, t)
  1130  					}
  1131  
  1132  					Expect(times[1]).To(BeTemporally("~", times[0].Add(115*time.Millisecond), 80*time.Millisecond))
  1133  					Expect(times[2]).To(BeTemporally("~", times[1].Add(115*time.Millisecond), 100*time.Millisecond))
  1134  				})
  1135  			})
  1136  		})
  1137  
  1138  		Context("when dialing the connect address fails", func() {
  1139  			var fakeTarget *fake_net.FakeConn
  1140  
  1141  			BeforeEach(func() {
  1142  				fakeTarget = &fake_net.FakeConn{}
  1143  				fakeSecureClient.DialReturns(fakeTarget, errors.New("boom"))
  1144  			})
  1145  
  1146  			It("does not call close on the target connection", func() {
  1147  				Consistently(fakeTarget.CloseCallCount).Should(Equal(0))
  1148  			})
  1149  		})
  1150  	})
  1151  
  1152  	Describe("Wait", func() {
  1153  		var waitErr error
  1154  
  1155  		JustBeforeEach(func() {
  1156  			connectErr := secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation)
  1157  			Expect(connectErr).NotTo(HaveOccurred())
  1158  
  1159  			waitErr = secureShell.Wait()
  1160  		})
  1161  
  1162  		It("calls wait on the secureClient", func() {
  1163  			Expect(waitErr).NotTo(HaveOccurred())
  1164  			Expect(fakeSecureClient.WaitCallCount()).To(Equal(1))
  1165  		})
  1166  
  1167  		Describe("keep alive messages", func() {
  1168  			var times []time.Time
  1169  			var timesCh chan []time.Time
  1170  			var done chan struct{}
  1171  
  1172  			BeforeEach(func() {
  1173  				keepAliveDuration = 100 * time.Millisecond
  1174  
  1175  				times = []time.Time{}
  1176  				timesCh = make(chan []time.Time, 1)
  1177  				done = make(chan struct{}, 1)
  1178  
  1179  				fakeConnection.SendRequestStub = func(reqName string, wantReply bool, message []byte) (bool, []byte, error) {
  1180  					Expect(reqName).To(Equal("keepalive@cloudfoundry.org"))
  1181  					Expect(wantReply).To(BeTrue())
  1182  					Expect(message).To(BeNil())
  1183  
  1184  					times = append(times, time.Now())
  1185  					if len(times) == 3 {
  1186  						timesCh <- times
  1187  						close(done)
  1188  					}
  1189  					return true, nil, nil
  1190  				}
  1191  
  1192  				fakeSecureClient.WaitStub = func() error {
  1193  					Eventually(done).Should(BeClosed())
  1194  					return nil
  1195  				}
  1196  			})
  1197  
  1198  			PIt("sends keep alive messages at the expected interval", func() {
  1199  				Expect(waitErr).NotTo(HaveOccurred())
  1200  				times := <-timesCh
  1201  				Expect(times[2]).To(BeTemporally("~", times[0].Add(200*time.Millisecond), 100*time.Millisecond))
  1202  			})
  1203  		})
  1204  	})
  1205  
  1206  	Describe("Close", func() {
  1207  		JustBeforeEach(func() {
  1208  			connectErr := secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation)
  1209  			Expect(connectErr).NotTo(HaveOccurred())
  1210  		})
  1211  
  1212  		It("calls close on the secureClient", func() {
  1213  			err := secureShell.Close()
  1214  			Expect(err).NotTo(HaveOccurred())
  1215  
  1216  			Expect(fakeSecureClient.CloseCallCount()).To(Equal(1))
  1217  		})
  1218  	})
  1219  })