github.com/loggregator/cli@v6.33.1-0.20180224010324-82334f081791+incompatible/util/clissh/ssh_test.go (about)

     1  // +build !windows,!386
     2  
     3  // Skipping 386 because lager uses UInt64 in Session().
     4  // Skipping Windows because the tests use Unix/Linux-only syscalls.
     5  // TODO: refactor out these conflicts so this package can be tested on multiple platforms.
     6  
     7  package clissh_test
     8  
     9  import (
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"net"
    14  	"os"
    15  	"syscall"
    16  	"time"
    17  
    18  	"code.cloudfoundry.org/cli/util/clissh/clisshfakes"
    19  	"code.cloudfoundry.org/cli/util/clissh/ssherror"
    20  	"code.cloudfoundry.org/diego-ssh/server"
    21  	fake_server "code.cloudfoundry.org/diego-ssh/server/fakes"
    22  	"code.cloudfoundry.org/diego-ssh/test_helpers"
    23  	"code.cloudfoundry.org/diego-ssh/test_helpers/fake_io"
    24  	"code.cloudfoundry.org/diego-ssh/test_helpers/fake_net"
    25  	"code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh"
    26  	"code.cloudfoundry.org/lager/lagertest"
    27  	"github.com/kr/pty"
    28  	"github.com/moby/moby/pkg/term"
    29  	"golang.org/x/crypto/ssh"
    30  
    31  	. "code.cloudfoundry.org/cli/util/clissh"
    32  	. "github.com/onsi/ginkgo"
    33  	. "github.com/onsi/gomega"
    34  )
    35  
    36  var _ = Describe("CLI SSH", func() {
    37  	var (
    38  		fakeSecureDialer    *clisshfakes.FakeSecureDialer
    39  		fakeSecureClient    *clisshfakes.FakeSecureClient
    40  		fakeTerminalHelper  *clisshfakes.FakeTerminalHelper
    41  		fakeListenerFactory *clisshfakes.FakeListenerFactory
    42  		fakeSecureSession   *clisshfakes.FakeSecureSession
    43  
    44  		fakeConnection *fake_ssh.FakeConn
    45  		stdinPipe      *fake_io.FakeWriteCloser
    46  		stdoutPipe     *fake_io.FakeReader
    47  		stderrPipe     *fake_io.FakeReader
    48  		secureShell    *SecureShell
    49  
    50  		username               string
    51  		passcode               string
    52  		sshEndpoint            string
    53  		sshEndpointFingerprint string
    54  		skipHostValidation     bool
    55  		commands               []string
    56  		terminalRequest        TTYRequest
    57  		keepAliveDuration      time.Duration
    58  	)
    59  
    60  	BeforeEach(func() {
    61  		fakeSecureDialer = new(clisshfakes.FakeSecureDialer)
    62  		fakeSecureClient = new(clisshfakes.FakeSecureClient)
    63  		fakeTerminalHelper = new(clisshfakes.FakeTerminalHelper)
    64  		fakeListenerFactory = new(clisshfakes.FakeListenerFactory)
    65  		fakeSecureSession = new(clisshfakes.FakeSecureSession)
    66  
    67  		fakeConnection = new(fake_ssh.FakeConn)
    68  		stdinPipe = new(fake_io.FakeWriteCloser)
    69  		stdoutPipe = new(fake_io.FakeReader)
    70  		stderrPipe = new(fake_io.FakeReader)
    71  
    72  		fakeListenerFactory.ListenStub = net.Listen
    73  		fakeSecureClient.NewSessionReturns(fakeSecureSession, nil)
    74  		fakeSecureClient.ConnReturns(fakeConnection)
    75  		fakeSecureDialer.DialReturns(fakeSecureClient, nil)
    76  
    77  		stdinPipe.WriteStub = func(p []byte) (int, error) {
    78  			return len(p), nil
    79  		}
    80  		fakeSecureSession.StdinPipeReturns(stdinPipe, nil)
    81  
    82  		stdoutPipe.ReadStub = func(p []byte) (int, error) {
    83  			return 0, io.EOF
    84  		}
    85  		fakeSecureSession.StdoutPipeReturns(stdoutPipe, nil)
    86  
    87  		stderrPipe.ReadStub = func(p []byte) (int, error) {
    88  			return 0, io.EOF
    89  		}
    90  		fakeSecureSession.StderrPipeReturns(stderrPipe, nil)
    91  
    92  		username = "some-user"
    93  		passcode = "some-passcode"
    94  		sshEndpoint = "some-endpoint"
    95  		sshEndpointFingerprint = "some-fingerprint"
    96  		skipHostValidation = false
    97  		commands = []string{}
    98  		terminalRequest = RequestTTYAuto
    99  		keepAliveDuration = DefaultKeepAliveInterval
   100  	})
   101  
   102  	JustBeforeEach(func() {
   103  		secureShell = NewSecureShell(
   104  			fakeSecureDialer,
   105  			fakeTerminalHelper,
   106  			fakeListenerFactory,
   107  			keepAliveDuration,
   108  		)
   109  	})
   110  
   111  	Describe("Connect", func() {
   112  		var connectErr error
   113  
   114  		JustBeforeEach(func() {
   115  			connectErr = secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation)
   116  		})
   117  
   118  		Context("when dialing succeeds", func() {
   119  			It("creates the ssh client", func() {
   120  				Expect(connectErr).ToNot(HaveOccurred())
   121  
   122  				Expect(fakeSecureDialer.DialCallCount()).To(Equal(1))
   123  				protocolArg, sshEndpointArg, sshConfigArg := fakeSecureDialer.DialArgsForCall(0)
   124  				Expect(protocolArg).To(Equal("tcp"))
   125  				Expect(sshEndpointArg).To(Equal(sshEndpoint))
   126  				Expect(sshConfigArg.User).To(Equal(username))
   127  				Expect(sshConfigArg.Auth).To(HaveLen(1))
   128  				Expect(sshConfigArg.HostKeyCallback).ToNot(BeNil())
   129  			})
   130  		})
   131  
   132  		Context("when dialing fails", func() {
   133  			var dialError error
   134  
   135  			Context("when the error is a generic Dial error", func() {
   136  				BeforeEach(func() {
   137  					dialError = errors.New("woops")
   138  					fakeSecureDialer.DialReturns(nil, dialError)
   139  				})
   140  
   141  				It("returns the dial error", func() {
   142  					Expect(connectErr).To(Equal(dialError))
   143  					Expect(fakeSecureDialer.DialCallCount()).To(Equal(1))
   144  				})
   145  			})
   146  
   147  			Context("when the dialing error is a golang 'unable to authenticate' error", func() {
   148  				BeforeEach(func() {
   149  					dialError = fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", []string{"none", "password"})
   150  					fakeSecureDialer.DialReturns(nil, dialError)
   151  				})
   152  
   153  				It("returns an UnableToAuthenticateError", func() {
   154  					Expect(connectErr).To(MatchError(ssherror.UnableToAuthenticateError{Err: dialError}))
   155  					Expect(fakeSecureDialer.DialCallCount()).To(Equal(1))
   156  				})
   157  			})
   158  		})
   159  	})
   160  
   161  	Describe("InteractiveSession", func() {
   162  		var (
   163  			stdin          *fake_io.FakeReadCloser
   164  			stdout, stderr *fake_io.FakeWriter
   165  
   166  			sessionErr                error
   167  			interactiveSessionInvoker func(secureShell *SecureShell)
   168  		)
   169  
   170  		BeforeEach(func() {
   171  			stdin = new(fake_io.FakeReadCloser)
   172  			stdout = new(fake_io.FakeWriter)
   173  			stderr = new(fake_io.FakeWriter)
   174  
   175  			fakeTerminalHelper.StdStreamsReturns(stdin, stdout, stderr)
   176  			interactiveSessionInvoker = func(secureShell *SecureShell) {
   177  				sessionErr = secureShell.InteractiveSession(commands, terminalRequest)
   178  			}
   179  		})
   180  
   181  		JustBeforeEach(func() {
   182  			connectErr := secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation)
   183  			Expect(connectErr).NotTo(HaveOccurred())
   184  			interactiveSessionInvoker(secureShell)
   185  		})
   186  
   187  		Context("when host key validation is enabled", func() {
   188  			var (
   189  				callback func(hostname string, remote net.Addr, key ssh.PublicKey) error
   190  				addr     net.Addr
   191  			)
   192  
   193  			BeforeEach(func() {
   194  				skipHostValidation = false
   195  			})
   196  
   197  			JustBeforeEach(func() {
   198  				Expect(fakeSecureDialer.DialCallCount()).To(Equal(1))
   199  				_, _, config := fakeSecureDialer.DialArgsForCall(0)
   200  				callback = config.HostKeyCallback
   201  
   202  				listener, err := net.Listen("tcp", "localhost:0")
   203  				Expect(err).NotTo(HaveOccurred())
   204  
   205  				addr = listener.Addr()
   206  				listener.Close()
   207  			})
   208  
   209  			Context("when the md5 fingerprint matches", func() {
   210  				BeforeEach(func() {
   211  					sshEndpointFingerprint = "41:ce:56:e6:9c:42:a9:c6:9e:68:ac:e3:4d:f6:38:79"
   212  				})
   213  
   214  				It("does not return an error", func() {
   215  					Expect(callback("", addr, TestHostKey.PublicKey())).ToNot(HaveOccurred())
   216  				})
   217  			})
   218  
   219  			Context("when the hex sha1 fingerprint matches", func() {
   220  				BeforeEach(func() {
   221  					sshEndpointFingerprint = "a8:e2:67:cb:ea:2a:6e:23:a1:72:ce:8f:07:92:15:ee:1f:82:f8:ca"
   222  				})
   223  
   224  				It("does not return an error", func() {
   225  					Expect(callback("", addr, TestHostKey.PublicKey())).ToNot(HaveOccurred())
   226  				})
   227  			})
   228  
   229  			Context("when the base64 sha256 fingerprint matches", func() {
   230  				BeforeEach(func() {
   231  					sshEndpointFingerprint = "sp/jrLuj66r+yrLDUKZdJU5tdzt4mq/UaSiNBjpgr+8"
   232  				})
   233  
   234  				It("does not return an error", func() {
   235  					Expect(callback("", addr, TestHostKey.PublicKey())).ToNot(HaveOccurred())
   236  				})
   237  			})
   238  
   239  			Context("when the base64 SHA256 fingerprint does not match", func() {
   240  				BeforeEach(func() {
   241  					sshEndpointFingerprint = "0000000000000000000000000000000000000000000"
   242  				})
   243  
   244  				It("returns an error'", func() {
   245  					err := callback("", addr, TestHostKey.PublicKey())
   246  					Expect(err).To(MatchError(MatchRegexp("Host key verification failed\\.")))
   247  					Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\"")))
   248  				})
   249  			})
   250  
   251  			Context("when the hex SHA1 fingerprint does not match", func() {
   252  				BeforeEach(func() {
   253  					sshEndpointFingerprint = "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00"
   254  				})
   255  
   256  				It("returns an error'", func() {
   257  					err := callback("", addr, TestHostKey.PublicKey())
   258  					Expect(err).To(MatchError(MatchRegexp("Host key verification failed\\.")))
   259  					Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\"")))
   260  				})
   261  			})
   262  
   263  			Context("when the MD5 fingerprint does not match", func() {
   264  				BeforeEach(func() {
   265  					sshEndpointFingerprint = "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00"
   266  				})
   267  
   268  				It("returns an error'", func() {
   269  					err := callback("", addr, TestHostKey.PublicKey())
   270  					Expect(err).To(MatchError(MatchRegexp("Host key verification failed\\.")))
   271  					Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\"")))
   272  				})
   273  			})
   274  
   275  			Context("when no fingerprint is present in endpoint info", func() {
   276  				BeforeEach(func() {
   277  					sshEndpointFingerprint = ""
   278  					sshEndpoint = ""
   279  				})
   280  
   281  				It("returns an error'", func() {
   282  					err := callback("", addr, TestHostKey.PublicKey())
   283  					Expect(err).To(MatchError(MatchRegexp("Unable to verify identity of host\\.")))
   284  					Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\"")))
   285  				})
   286  			})
   287  
   288  			Context("when the fingerprint length doesn't make sense", func() {
   289  				BeforeEach(func() {
   290  					sshEndpointFingerprint = "garbage"
   291  				})
   292  
   293  				It("returns an error", func() {
   294  					err := callback("", addr, TestHostKey.PublicKey())
   295  					Eventually(err).Should(MatchError(MatchRegexp("Unsupported host key fingerprint format")))
   296  				})
   297  			})
   298  		})
   299  
   300  		Context("when the skip host validation flag is set", func() {
   301  			BeforeEach(func() {
   302  				skipHostValidation = true
   303  			})
   304  
   305  			It("the HostKeyCallback on the Config to always return nil", func() {
   306  				Expect(fakeSecureDialer.DialCallCount()).To(Equal(1))
   307  
   308  				_, _, config := fakeSecureDialer.DialArgsForCall(0)
   309  				Expect(config.HostKeyCallback("some-hostname", nil, nil)).To(BeNil())
   310  			})
   311  		})
   312  
   313  		// TODO: see if it's possible to test the piping between the ssh client's inputs and outputs and the UI object we pass in
   314  		Context("when dialing is successful", func() {
   315  			It("creates a new secure shell session", func() {
   316  				Expect(fakeSecureClient.NewSessionCallCount()).To(Equal(1))
   317  			})
   318  
   319  			It("closes the session", func() {
   320  				Expect(fakeSecureSession.CloseCallCount()).To(Equal(1))
   321  			})
   322  
   323  			It("gets a stdin pipe for the session", func() {
   324  				Expect(fakeSecureSession.StdinPipeCallCount()).To(Equal(1))
   325  			})
   326  
   327  			Context("when getting the stdin pipe fails", func() {
   328  				BeforeEach(func() {
   329  					fakeSecureSession.StdinPipeReturns(nil, errors.New("woops"))
   330  				})
   331  
   332  				It("returns the error", func() {
   333  					Expect(sessionErr).Should(MatchError("woops"))
   334  				})
   335  			})
   336  
   337  			It("gets a stdout pipe for the session", func() {
   338  				Expect(fakeSecureSession.StdoutPipeCallCount()).To(Equal(1))
   339  			})
   340  
   341  			Context("when getting the stdout pipe fails", func() {
   342  				BeforeEach(func() {
   343  					fakeSecureSession.StdoutPipeReturns(nil, errors.New("woops"))
   344  				})
   345  
   346  				It("returns the error", func() {
   347  					Expect(sessionErr).Should(MatchError("woops"))
   348  				})
   349  			})
   350  
   351  			It("gets a stderr pipe for the session", func() {
   352  				Expect(fakeSecureSession.StderrPipeCallCount()).To(Equal(1))
   353  			})
   354  
   355  			Context("when getting the stderr pipe fails", func() {
   356  				BeforeEach(func() {
   357  					fakeSecureSession.StderrPipeReturns(nil, errors.New("woops"))
   358  				})
   359  
   360  				It("returns the error", func() {
   361  					Expect(sessionErr).Should(MatchError("woops"))
   362  				})
   363  			})
   364  		})
   365  
   366  		Context("when stdin is a terminal", func() {
   367  			var master, slave *os.File
   368  
   369  			BeforeEach(func() {
   370  				var err error
   371  				master, slave, err = pty.Open()
   372  				Expect(err).NotTo(HaveOccurred())
   373  
   374  				terminalRequest = RequestTTYForce
   375  
   376  				terminalHelper := DefaultTerminalHelper()
   377  				fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo
   378  				fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize
   379  			})
   380  
   381  			AfterEach(func() {
   382  				master.Close()
   383  				// slave.Close() // race
   384  			})
   385  
   386  			Context("when a command is not specified", func() {
   387  				var terminalType string
   388  
   389  				BeforeEach(func() {
   390  					terminalType = os.Getenv("TERM")
   391  					os.Setenv("TERM", "test-terminal-type")
   392  
   393  					winsize := &term.Winsize{Width: 1024, Height: 256}
   394  					fakeTerminalHelper.GetWinsizeReturns(winsize, nil)
   395  
   396  					fakeSecureSession.ShellStub = func() error {
   397  						Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1))
   398  						Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(0))
   399  						return nil
   400  					}
   401  				})
   402  
   403  				AfterEach(func() {
   404  					os.Setenv("TERM", terminalType)
   405  				})
   406  
   407  				It("requests a pty with the correct terminal type, window size, and modes", func() {
   408  					Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1))
   409  					Expect(fakeTerminalHelper.GetWinsizeCallCount()).To(Equal(1))
   410  
   411  					termType, height, width, modes := fakeSecureSession.RequestPtyArgsForCall(0)
   412  					Expect(termType).To(Equal("test-terminal-type"))
   413  					Expect(height).To(Equal(256))
   414  					Expect(width).To(Equal(1024))
   415  
   416  					expectedModes := ssh.TerminalModes{
   417  						ssh.ECHO:          1,
   418  						ssh.TTY_OP_ISPEED: 115200,
   419  						ssh.TTY_OP_OSPEED: 115200,
   420  					}
   421  					Expect(modes).To(Equal(expectedModes))
   422  				})
   423  
   424  				Context("when the TERM environment variable is not set", func() {
   425  					BeforeEach(func() {
   426  						os.Unsetenv("TERM")
   427  					})
   428  
   429  					It("requests a pty with the default terminal type", func() {
   430  						Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1))
   431  
   432  						termType, _, _, _ := fakeSecureSession.RequestPtyArgsForCall(0)
   433  						Expect(termType).To(Equal("xterm"))
   434  					})
   435  				})
   436  
   437  				It("puts the terminal into raw mode and restores it after running the shell", func() {
   438  					Expect(fakeSecureSession.ShellCallCount()).To(Equal(1))
   439  					Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1))
   440  					Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(1))
   441  				})
   442  
   443  				Context("when the pty allocation fails", func() {
   444  					var ptyError error
   445  
   446  					BeforeEach(func() {
   447  						ptyError = errors.New("pty allocation error")
   448  						fakeSecureSession.RequestPtyReturns(ptyError)
   449  					})
   450  
   451  					It("returns the error", func() {
   452  						Expect(sessionErr).To(Equal(ptyError))
   453  					})
   454  				})
   455  
   456  				Context("when placing the terminal into raw mode fails", func() {
   457  					BeforeEach(func() {
   458  						fakeTerminalHelper.SetRawTerminalReturns(nil, errors.New("woops"))
   459  					})
   460  
   461  					It("keeps calm and carries on", func() {
   462  						Expect(fakeSecureSession.ShellCallCount()).To(Equal(1))
   463  					})
   464  
   465  					It("does not not restore the terminal", func() {
   466  						Expect(fakeSecureSession.ShellCallCount()).To(Equal(1))
   467  						Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1))
   468  						Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(0))
   469  					})
   470  				})
   471  			})
   472  
   473  			Context("when a command is specified", func() {
   474  				BeforeEach(func() {
   475  					commands = []string{"echo", "-n", "hello"}
   476  				})
   477  
   478  				Context("when a terminal is requested", func() {
   479  					BeforeEach(func() {
   480  						terminalRequest = RequestTTYYes
   481  						fakeTerminalHelper.GetFdInfoReturns(0, true)
   482  					})
   483  
   484  					It("requests a pty", func() {
   485  						Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1))
   486  					})
   487  				})
   488  
   489  				Context("when a terminal is not explicitly requested", func() {
   490  					BeforeEach(func() {
   491  						terminalRequest = RequestTTYAuto
   492  					})
   493  
   494  					It("does not request a pty", func() {
   495  						Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0))
   496  					})
   497  				})
   498  			})
   499  		})
   500  
   501  		Context("when stdin is not a terminal", func() {
   502  			BeforeEach(func() {
   503  				stdin.ReadStub = func(p []byte) (int, error) {
   504  					return 0, io.EOF
   505  				}
   506  
   507  				terminalHelper := DefaultTerminalHelper()
   508  				fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo
   509  				fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize
   510  			})
   511  
   512  			Context("when a terminal is not requested", func() {
   513  				It("does not request a pty", func() {
   514  					Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0))
   515  				})
   516  			})
   517  
   518  			Context("when a terminal is requested", func() {
   519  				BeforeEach(func() {
   520  					terminalRequest = RequestTTYYes
   521  				})
   522  
   523  				It("does not request a pty", func() {
   524  					Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0))
   525  				})
   526  			})
   527  		})
   528  
   529  		PContext("when a terminal is forced", func() {
   530  			BeforeEach(func() {
   531  				terminalRequest = RequestTTYForce
   532  			})
   533  
   534  			It("requests a pty", func() {
   535  				Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1))
   536  			})
   537  		})
   538  
   539  		Context("when a terminal is disabled", func() {
   540  			BeforeEach(func() {
   541  				terminalRequest = RequestTTYNo
   542  			})
   543  
   544  			It("does not request a pty", func() {
   545  				Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0))
   546  			})
   547  		})
   548  
   549  		Context("when a command is not specified", func() {
   550  			It("requests an interactive shell", func() {
   551  				Expect(fakeSecureSession.ShellCallCount()).To(Equal(1))
   552  			})
   553  
   554  			Context("when the shell request returns an error", func() {
   555  				BeforeEach(func() {
   556  					fakeSecureSession.ShellReturns(errors.New("oh bother"))
   557  				})
   558  
   559  				It("returns the error", func() {
   560  					Expect(sessionErr).To(MatchError("oh bother"))
   561  				})
   562  			})
   563  		})
   564  
   565  		Context("when a command is specifed", func() {
   566  			BeforeEach(func() {
   567  				commands = []string{"echo", "-n", "hello"}
   568  			})
   569  
   570  			It("starts the command", func() {
   571  				Expect(fakeSecureSession.StartCallCount()).To(Equal(1))
   572  				Expect(fakeSecureSession.StartArgsForCall(0)).To(Equal("echo -n hello"))
   573  			})
   574  
   575  			Context("when the command fails to start", func() {
   576  				BeforeEach(func() {
   577  					fakeSecureSession.StartReturns(errors.New("oh well"))
   578  				})
   579  
   580  				It("returns the error", func() {
   581  					Expect(sessionErr).To(MatchError("oh well"))
   582  				})
   583  			})
   584  		})
   585  
   586  		Context("when the shell or command has started", func() {
   587  			BeforeEach(func() {
   588  				stdin.ReadStub = func(p []byte) (int, error) {
   589  					p[0] = 0
   590  					return 1, io.EOF
   591  				}
   592  				stdinPipe.WriteStub = func(p []byte) (int, error) {
   593  					defer GinkgoRecover()
   594  					Expect(p[0]).To(Equal(byte(0)))
   595  					return 1, nil
   596  				}
   597  
   598  				stdoutPipe.ReadStub = func(p []byte) (int, error) {
   599  					p[0] = 1
   600  					return 1, io.EOF
   601  				}
   602  				stdout.WriteStub = func(p []byte) (int, error) {
   603  					defer GinkgoRecover()
   604  					Expect(p[0]).To(Equal(byte(1)))
   605  					return 1, nil
   606  				}
   607  
   608  				stderrPipe.ReadStub = func(p []byte) (int, error) {
   609  					p[0] = 2
   610  					return 1, io.EOF
   611  				}
   612  				stderr.WriteStub = func(p []byte) (int, error) {
   613  					defer GinkgoRecover()
   614  					Expect(p[0]).To(Equal(byte(2)))
   615  					return 1, nil
   616  				}
   617  
   618  				fakeSecureSession.StdinPipeReturns(stdinPipe, nil)
   619  				fakeSecureSession.StdoutPipeReturns(stdoutPipe, nil)
   620  				fakeSecureSession.StderrPipeReturns(stderrPipe, nil)
   621  
   622  				fakeSecureSession.WaitReturns(errors.New("error result"))
   623  			})
   624  
   625  			It("copies data from the stdin stream to the session stdin pipe", func() {
   626  				Eventually(stdin.ReadCallCount).Should(Equal(1))
   627  				Eventually(stdinPipe.WriteCallCount).Should(Equal(1))
   628  			})
   629  
   630  			It("copies data from the session stdout pipe to the stdout stream", func() {
   631  				Eventually(stdoutPipe.ReadCallCount).Should(Equal(1))
   632  				Eventually(stdout.WriteCallCount).Should(Equal(1))
   633  			})
   634  
   635  			It("copies data from the session stderr pipe to the stderr stream", func() {
   636  				Eventually(stderrPipe.ReadCallCount).Should(Equal(1))
   637  				Eventually(stderr.WriteCallCount).Should(Equal(1))
   638  			})
   639  
   640  			It("waits for the session to end", func() {
   641  				Expect(fakeSecureSession.WaitCallCount()).To(Equal(1))
   642  			})
   643  
   644  			It("returns the result from wait", func() {
   645  				Expect(sessionErr).To(MatchError("error result"))
   646  			})
   647  
   648  			Context("when the session terminates before stream copies complete", func() {
   649  				var sessionErrorCh chan error
   650  
   651  				BeforeEach(func() {
   652  					sessionErrorCh = make(chan error, 1)
   653  
   654  					interactiveSessionInvoker = func(secureShell *SecureShell) {
   655  						go func() {
   656  							sessionErrorCh <- secureShell.InteractiveSession(commands, terminalRequest)
   657  						}()
   658  					}
   659  
   660  					stdoutPipe.ReadStub = func(p []byte) (int, error) {
   661  						defer GinkgoRecover()
   662  						Eventually(fakeSecureSession.WaitCallCount).Should(Equal(1))
   663  						Consistently(sessionErrorCh).ShouldNot(Receive())
   664  
   665  						p[0] = 1
   666  						return 1, io.EOF
   667  					}
   668  
   669  					stderrPipe.ReadStub = func(p []byte) (int, error) {
   670  						defer GinkgoRecover()
   671  						Eventually(fakeSecureSession.WaitCallCount).Should(Equal(1))
   672  						Consistently(sessionErrorCh).ShouldNot(Receive())
   673  
   674  						p[0] = 2
   675  						return 1, io.EOF
   676  					}
   677  				})
   678  
   679  				It("waits for the copies to complete", func() {
   680  					Eventually(sessionErrorCh).Should(Receive())
   681  					Expect(stdoutPipe.ReadCallCount()).To(Equal(1))
   682  					Expect(stderrPipe.ReadCallCount()).To(Equal(1))
   683  				})
   684  			})
   685  
   686  			Context("when stdin is closed", func() {
   687  				BeforeEach(func() {
   688  					stdin.ReadStub = func(p []byte) (int, error) {
   689  						defer GinkgoRecover()
   690  						Consistently(stdinPipe.CloseCallCount).Should(Equal(0))
   691  						p[0] = 0
   692  						return 1, io.EOF
   693  					}
   694  				})
   695  
   696  				It("closes the stdinPipe", func() {
   697  					Eventually(stdinPipe.CloseCallCount).Should(Equal(1))
   698  				})
   699  			})
   700  		})
   701  
   702  		Context("when stdout is a terminal and a window size change occurs", func() {
   703  			var master, slave *os.File
   704  
   705  			BeforeEach(func() {
   706  				var err error
   707  				master, slave, err = pty.Open()
   708  				Expect(err).NotTo(HaveOccurred())
   709  
   710  				terminalHelper := DefaultTerminalHelper()
   711  				fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo
   712  				fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize
   713  				fakeTerminalHelper.StdStreamsReturns(stdin, slave, stderr)
   714  
   715  				winsize := &term.Winsize{Height: 100, Width: 100}
   716  				err = term.SetWinsize(slave.Fd(), winsize)
   717  				Expect(err).NotTo(HaveOccurred())
   718  
   719  				fakeSecureSession.WaitStub = func() error {
   720  					fakeSecureSession.SendRequestCallCount()
   721  					Expect(fakeSecureSession.SendRequestCallCount()).To(Equal(0))
   722  
   723  					// No dimension change
   724  					for i := 0; i < 3; i++ {
   725  						winsize := &term.Winsize{Height: 100, Width: 100}
   726  						err = term.SetWinsize(slave.Fd(), winsize)
   727  						Expect(err).NotTo(HaveOccurred())
   728  					}
   729  
   730  					winsize := &term.Winsize{Height: 100, Width: 200}
   731  					err = term.SetWinsize(slave.Fd(), winsize)
   732  					Expect(err).NotTo(HaveOccurred())
   733  
   734  					err = syscall.Kill(syscall.Getpid(), syscall.SIGWINCH)
   735  					Expect(err).NotTo(HaveOccurred())
   736  
   737  					Eventually(fakeSecureSession.SendRequestCallCount).Should(Equal(1))
   738  					return nil
   739  				}
   740  			})
   741  
   742  			AfterEach(func() {
   743  				master.Close()
   744  				slave.Close()
   745  			})
   746  
   747  			It("sends window change events when the window dimensions change", func() {
   748  				Expect(fakeSecureSession.SendRequestCallCount()).To(Equal(1))
   749  
   750  				requestType, wantReply, message := fakeSecureSession.SendRequestArgsForCall(0)
   751  				Expect(requestType).To(Equal("window-change"))
   752  				Expect(wantReply).To(BeFalse())
   753  
   754  				type resizeMessage struct {
   755  					Width       uint32
   756  					Height      uint32
   757  					PixelWidth  uint32
   758  					PixelHeight uint32
   759  				}
   760  				var resizeMsg resizeMessage
   761  
   762  				err := ssh.Unmarshal(message, &resizeMsg)
   763  				Expect(err).NotTo(HaveOccurred())
   764  
   765  				Expect(resizeMsg).To(Equal(resizeMessage{Height: 100, Width: 200}))
   766  			})
   767  		})
   768  
   769  		Describe("keep alive messages", func() {
   770  			var times []time.Time
   771  			var timesCh chan []time.Time
   772  			var done chan struct{}
   773  
   774  			BeforeEach(func() {
   775  				keepAliveDuration = 100 * time.Millisecond
   776  
   777  				times = []time.Time{}
   778  				timesCh = make(chan []time.Time, 1)
   779  				done = make(chan struct{}, 1)
   780  
   781  				fakeConnection.SendRequestStub = func(reqName string, wantReply bool, message []byte) (bool, []byte, error) {
   782  					Expect(reqName).To(Equal("keepalive@cloudfoundry.org"))
   783  					Expect(wantReply).To(BeTrue())
   784  					Expect(message).To(BeNil())
   785  
   786  					times = append(times, time.Now())
   787  					if len(times) == 3 {
   788  						timesCh <- times
   789  						close(done)
   790  					}
   791  					return true, nil, nil
   792  				}
   793  
   794  				fakeSecureSession.WaitStub = func() error {
   795  					Eventually(done).Should(BeClosed())
   796  					return nil
   797  				}
   798  			})
   799  
   800  			PIt("sends keep alive messages at the expected interval", func() {
   801  				times := <-timesCh
   802  				Expect(times[2]).To(BeTemporally("~", times[0].Add(200*time.Millisecond), 160*time.Millisecond))
   803  			})
   804  		})
   805  	})
   806  
   807  	Describe("LocalPortForward", func() {
   808  		var (
   809  			forwardErr error
   810  
   811  			echoAddress  string
   812  			echoListener *fake_net.FakeListener
   813  			echoHandler  *fake_server.FakeConnectionHandler
   814  			echoServer   *server.Server
   815  
   816  			localAddress string
   817  
   818  			realLocalListener net.Listener
   819  			fakeLocalListener *fake_net.FakeListener
   820  
   821  			forwardSpecs []LocalPortForward
   822  		)
   823  
   824  		BeforeEach(func() {
   825  			logger := lagertest.NewTestLogger("test")
   826  
   827  			var err error
   828  			realLocalListener, err = net.Listen("tcp", "127.0.0.1:0")
   829  			Expect(err).NotTo(HaveOccurred())
   830  
   831  			localAddress = realLocalListener.Addr().String()
   832  			fakeListenerFactory.ListenReturns(realLocalListener, nil)
   833  
   834  			echoHandler = &fake_server.FakeConnectionHandler{}
   835  			echoHandler.HandleConnectionStub = func(conn net.Conn) {
   836  				io.Copy(conn, conn)
   837  				conn.Close()
   838  			}
   839  
   840  			realListener, err := net.Listen("tcp", "127.0.0.1:0")
   841  			Expect(err).NotTo(HaveOccurred())
   842  			echoAddress = realListener.Addr().String()
   843  
   844  			echoListener = &fake_net.FakeListener{}
   845  			echoListener.AcceptStub = realListener.Accept
   846  			echoListener.CloseStub = realListener.Close
   847  			echoListener.AddrStub = realListener.Addr
   848  
   849  			fakeLocalListener = &fake_net.FakeListener{}
   850  			fakeLocalListener.AcceptReturns(nil, errors.New("Not Accepting Connections"))
   851  
   852  			echoServer = server.NewServer(logger.Session("echo"), "", echoHandler)
   853  			echoServer.SetListener(echoListener)
   854  			go echoServer.Serve()
   855  
   856  			forwardSpecs = []LocalPortForward{{
   857  				RemoteAddress: echoAddress,
   858  				LocalAddress:  localAddress,
   859  			}}
   860  
   861  			fakeSecureClient.DialStub = net.Dial
   862  		})
   863  
   864  		JustBeforeEach(func() {
   865  			connectErr := secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation)
   866  			Expect(connectErr).NotTo(HaveOccurred())
   867  
   868  			forwardErr = secureShell.LocalPortForward(forwardSpecs)
   869  		})
   870  
   871  		AfterEach(func() {
   872  			err := secureShell.Close()
   873  			Expect(err).NotTo(HaveOccurred())
   874  			echoServer.Shutdown()
   875  
   876  			realLocalListener.Close()
   877  		})
   878  
   879  		validateConnectivity := func(addr string) {
   880  			conn, err := net.Dial("tcp", addr)
   881  			Expect(err).NotTo(HaveOccurred())
   882  
   883  			msg := fmt.Sprintf("Hello from %s\n", addr)
   884  			n, err := conn.Write([]byte(msg))
   885  			Expect(err).NotTo(HaveOccurred())
   886  			Expect(n).To(Equal(len(msg)))
   887  
   888  			response := make([]byte, len(msg))
   889  			n, err = conn.Read(response)
   890  			Expect(err).NotTo(HaveOccurred())
   891  			Expect(n).To(Equal(len(msg)))
   892  
   893  			err = conn.Close()
   894  			Expect(err).NotTo(HaveOccurred())
   895  
   896  			Expect(response).To(Equal([]byte(msg)))
   897  		}
   898  
   899  		It("dials the connect address when a local connection is made", func() {
   900  			Expect(forwardErr).NotTo(HaveOccurred())
   901  
   902  			conn, err := net.Dial("tcp", localAddress)
   903  			Expect(err).NotTo(HaveOccurred())
   904  
   905  			Eventually(echoListener.AcceptCallCount).Should(BeNumerically(">=", 1))
   906  			Eventually(fakeSecureClient.DialCallCount).Should(Equal(1))
   907  
   908  			network, addr := fakeSecureClient.DialArgsForCall(0)
   909  			Expect(network).To(Equal("tcp"))
   910  			Expect(addr).To(Equal(echoAddress))
   911  
   912  			Expect(conn.Close()).NotTo(HaveOccurred())
   913  		})
   914  
   915  		It("copies data between the local and remote connections", func() {
   916  			validateConnectivity(localAddress)
   917  		})
   918  
   919  		Context("when a local connection is already open", func() {
   920  			var conn net.Conn
   921  
   922  			JustBeforeEach(func() {
   923  				var err error
   924  				conn, err = net.Dial("tcp", localAddress)
   925  				Expect(err).NotTo(HaveOccurred())
   926  			})
   927  
   928  			AfterEach(func() {
   929  				err := conn.Close()
   930  				Expect(err).NotTo(HaveOccurred())
   931  			})
   932  
   933  			It("allows for new incoming connections as well", func() {
   934  				validateConnectivity(localAddress)
   935  			})
   936  		})
   937  
   938  		Context("when there are multiple port forward specs", func() {
   939  			var (
   940  				realLocalListener2 net.Listener
   941  				localAddress2      string
   942  			)
   943  
   944  			BeforeEach(func() {
   945  				var err error
   946  				realLocalListener2, err = net.Listen("tcp", "127.0.0.1:0")
   947  				Expect(err).NotTo(HaveOccurred())
   948  
   949  				localAddress2 = realLocalListener2.Addr().String()
   950  
   951  				fakeListenerFactory.ListenStub = func(network, addr string) (net.Listener, error) {
   952  					if addr == localAddress {
   953  						return realLocalListener, nil
   954  					}
   955  
   956  					if addr == localAddress2 {
   957  						return realLocalListener2, nil
   958  					}
   959  
   960  					return nil, errors.New("unexpected address")
   961  				}
   962  
   963  				forwardSpecs = []LocalPortForward{
   964  					{
   965  						RemoteAddress: echoAddress,
   966  						LocalAddress:  localAddress,
   967  					},
   968  					{
   969  						RemoteAddress: echoAddress,
   970  						LocalAddress:  localAddress2,
   971  					},
   972  				}
   973  			})
   974  
   975  			AfterEach(func() {
   976  				realLocalListener2.Close()
   977  			})
   978  
   979  			It("listens to all the things", func() {
   980  				Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(2))
   981  
   982  				network, addr := fakeListenerFactory.ListenArgsForCall(0)
   983  				Expect(network).To(Equal("tcp"))
   984  				Expect(addr).To(Equal(localAddress))
   985  
   986  				network, addr = fakeListenerFactory.ListenArgsForCall(1)
   987  				Expect(network).To(Equal("tcp"))
   988  				Expect(addr).To(Equal(localAddress2))
   989  			})
   990  
   991  			It("forwards to the correct target", func() {
   992  				validateConnectivity(localAddress)
   993  				validateConnectivity(localAddress2)
   994  			})
   995  
   996  			Context("when the secure client is closed", func() {
   997  				BeforeEach(func() {
   998  					fakeListenerFactory.ListenReturns(fakeLocalListener, nil)
   999  					fakeLocalListener.AcceptReturns(nil, errors.New("not accepting connections"))
  1000  				})
  1001  
  1002  				It("closes the listeners ", func() {
  1003  					Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(2))
  1004  					Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(2))
  1005  
  1006  					originalCloseCount := fakeLocalListener.CloseCallCount()
  1007  					err := secureShell.Close()
  1008  					Expect(err).NotTo(HaveOccurred())
  1009  					Expect(fakeLocalListener.CloseCallCount()).Should(Equal(originalCloseCount + 2))
  1010  				})
  1011  			})
  1012  		})
  1013  
  1014  		Context("when listen fails", func() {
  1015  			BeforeEach(func() {
  1016  				fakeListenerFactory.ListenReturns(nil, errors.New("failure is an option"))
  1017  			})
  1018  
  1019  			It("returns the error", func() {
  1020  				Expect(forwardErr).To(MatchError("failure is an option"))
  1021  			})
  1022  		})
  1023  
  1024  		Context("when the client it closed", func() {
  1025  			BeforeEach(func() {
  1026  				fakeListenerFactory.ListenReturns(fakeLocalListener, nil)
  1027  				fakeLocalListener.AcceptReturns(nil, errors.New("not accepting and connections"))
  1028  			})
  1029  
  1030  			It("closes the listener when the client is closed", func() {
  1031  				Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(1))
  1032  				Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(1))
  1033  
  1034  				originalCloseCount := fakeLocalListener.CloseCallCount()
  1035  				err := secureShell.Close()
  1036  				Expect(err).NotTo(HaveOccurred())
  1037  				Expect(fakeLocalListener.CloseCallCount()).Should(Equal(originalCloseCount + 1))
  1038  			})
  1039  		})
  1040  
  1041  		Context("when accept fails", func() {
  1042  			var fakeConn *fake_net.FakeConn
  1043  
  1044  			BeforeEach(func() {
  1045  				fakeConn = &fake_net.FakeConn{}
  1046  				fakeConn.ReadReturns(0, io.EOF)
  1047  
  1048  				fakeListenerFactory.ListenReturns(fakeLocalListener, nil)
  1049  			})
  1050  
  1051  			Context("with a permanent error", func() {
  1052  				BeforeEach(func() {
  1053  					fakeLocalListener.AcceptReturns(nil, errors.New("boom"))
  1054  				})
  1055  
  1056  				It("stops trying to accept connections", func() {
  1057  					Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(1))
  1058  					Consistently(fakeLocalListener.AcceptCallCount).Should(Equal(1))
  1059  					Expect(fakeLocalListener.CloseCallCount()).To(Equal(1))
  1060  				})
  1061  			})
  1062  
  1063  			Context("with a temporary error", func() {
  1064  				var timeCh chan time.Time
  1065  
  1066  				BeforeEach(func() {
  1067  					timeCh = make(chan time.Time, 3)
  1068  
  1069  					fakeLocalListener.AcceptStub = func() (net.Conn, error) {
  1070  						timeCh := timeCh
  1071  						if fakeLocalListener.AcceptCallCount() > 3 {
  1072  							close(timeCh)
  1073  							return nil, test_helpers.NewTestNetError(false, false)
  1074  						} else {
  1075  							timeCh <- time.Now()
  1076  							return nil, test_helpers.NewTestNetError(false, true)
  1077  						}
  1078  					}
  1079  				})
  1080  
  1081  				PIt("retries connecting after a short delay", func() {
  1082  					Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(3))
  1083  					Expect(timeCh).To(HaveLen(3))
  1084  
  1085  					times := make([]time.Time, 0)
  1086  					for t := range timeCh {
  1087  						times = append(times, t)
  1088  					}
  1089  
  1090  					Expect(times[1]).To(BeTemporally("~", times[0].Add(115*time.Millisecond), 80*time.Millisecond))
  1091  					Expect(times[2]).To(BeTemporally("~", times[1].Add(115*time.Millisecond), 100*time.Millisecond))
  1092  				})
  1093  			})
  1094  		})
  1095  
  1096  		Context("when dialing the connect address fails", func() {
  1097  			var fakeTarget *fake_net.FakeConn
  1098  
  1099  			BeforeEach(func() {
  1100  				fakeTarget = &fake_net.FakeConn{}
  1101  				fakeSecureClient.DialReturns(fakeTarget, errors.New("boom"))
  1102  			})
  1103  
  1104  			It("does not call close on the target connection", func() {
  1105  				Consistently(fakeTarget.CloseCallCount).Should(Equal(0))
  1106  			})
  1107  		})
  1108  	})
  1109  
  1110  	Describe("Wait", func() {
  1111  		var waitErr error
  1112  
  1113  		JustBeforeEach(func() {
  1114  			connectErr := secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation)
  1115  			Expect(connectErr).NotTo(HaveOccurred())
  1116  
  1117  			waitErr = secureShell.Wait()
  1118  		})
  1119  
  1120  		It("calls wait on the secureClient", func() {
  1121  			Expect(waitErr).NotTo(HaveOccurred())
  1122  			Expect(fakeSecureClient.WaitCallCount()).To(Equal(1))
  1123  		})
  1124  
  1125  		Describe("keep alive messages", func() {
  1126  			var times []time.Time
  1127  			var timesCh chan []time.Time
  1128  			var done chan struct{}
  1129  
  1130  			BeforeEach(func() {
  1131  				keepAliveDuration = 100 * time.Millisecond
  1132  
  1133  				times = []time.Time{}
  1134  				timesCh = make(chan []time.Time, 1)
  1135  				done = make(chan struct{}, 1)
  1136  
  1137  				fakeConnection.SendRequestStub = func(reqName string, wantReply bool, message []byte) (bool, []byte, error) {
  1138  					Expect(reqName).To(Equal("keepalive@cloudfoundry.org"))
  1139  					Expect(wantReply).To(BeTrue())
  1140  					Expect(message).To(BeNil())
  1141  
  1142  					times = append(times, time.Now())
  1143  					if len(times) == 3 {
  1144  						timesCh <- times
  1145  						close(done)
  1146  					}
  1147  					return true, nil, nil
  1148  				}
  1149  
  1150  				fakeSecureClient.WaitStub = func() error {
  1151  					Eventually(done).Should(BeClosed())
  1152  					return nil
  1153  				}
  1154  			})
  1155  
  1156  			PIt("sends keep alive messages at the expected interval", func() {
  1157  				Expect(waitErr).NotTo(HaveOccurred())
  1158  				times := <-timesCh
  1159  				Expect(times[2]).To(BeTemporally("~", times[0].Add(200*time.Millisecond), 100*time.Millisecond))
  1160  			})
  1161  		})
  1162  	})
  1163  
  1164  	Describe("Close", func() {
  1165  		JustBeforeEach(func() {
  1166  			connectErr := secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation)
  1167  			Expect(connectErr).NotTo(HaveOccurred())
  1168  		})
  1169  
  1170  		It("calls close on the secureClient", func() {
  1171  			err := secureShell.Close()
  1172  			Expect(err).NotTo(HaveOccurred())
  1173  
  1174  			Expect(fakeSecureClient.CloseCallCount()).To(Equal(1))
  1175  		})
  1176  	})
  1177  })