github.com/DaAlbrecht/cf-cli@v0.0.0-20231128151943-1fe19bb400b9/util/clissh/ssh_test.go (about)

     1  //go:build !windows && !386
     2  // +build !windows,!386
     3  
     4  // skipping 386 because lager uses UInt64 in Session()
     5  // skipping windows because Unix/Linux only syscall in test.
     6  // should refactor out the conflicts so we can test this package on multiple platforms.
     7  
     8  package clissh_test
     9  
    10  import (
    11  	"errors"
    12  	"fmt"
    13  	"io"
    14  	"net"
    15  	"os"
    16  	"sync"
    17  	"syscall"
    18  	"time"
    19  
    20  	. "code.cloudfoundry.org/cli/util/clissh"
    21  	"code.cloudfoundry.org/cli/util/clissh/clisshfakes"
    22  	"code.cloudfoundry.org/cli/util/clissh/ssherror"
    23  	"code.cloudfoundry.org/diego-ssh/server"
    24  	fake_server "code.cloudfoundry.org/diego-ssh/server/fakes"
    25  	"code.cloudfoundry.org/diego-ssh/test_helpers"
    26  	"code.cloudfoundry.org/diego-ssh/test_helpers/fake_io"
    27  	"code.cloudfoundry.org/diego-ssh/test_helpers/fake_net"
    28  	"code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh"
    29  	"code.cloudfoundry.org/lager/lagertest"
    30  	"github.com/kr/pty"
    31  	"github.com/moby/term"
    32  	. "github.com/onsi/ginkgo"
    33  	. "github.com/onsi/gomega"
    34  	"golang.org/x/crypto/ssh"
    35  )
    36  
    37  func BlockAcceptOnClose(fake *fake_net.FakeListener) {
    38  	waitUntilClosed := make(chan bool)
    39  	fake.AcceptStub = func() (net.Conn, error) {
    40  		<-waitUntilClosed
    41  		return nil, errors.New("banana")
    42  	}
    43  
    44  	var once sync.Once
    45  	fake.CloseStub = func() error {
    46  		once.Do(func() {
    47  			close(waitUntilClosed)
    48  		})
    49  		return nil
    50  	}
    51  }
    52  
    53  var _ = Describe("CLI SSH", func() {
    54  	var (
    55  		fakeSecureDialer    *clisshfakes.FakeSecureDialer
    56  		fakeSecureClient    *clisshfakes.FakeSecureClient
    57  		fakeTerminalHelper  *clisshfakes.FakeTerminalHelper
    58  		fakeListenerFactory *clisshfakes.FakeListenerFactory
    59  		fakeSecureSession   *clisshfakes.FakeSecureSession
    60  
    61  		fakeConnection *fake_ssh.FakeConn
    62  		stdinPipe      *fake_io.FakeWriteCloser
    63  		stdoutPipe     *fake_io.FakeReader
    64  		stderrPipe     *fake_io.FakeReader
    65  		secureShell    *SecureShell
    66  
    67  		username               string
    68  		passcode               string
    69  		sshEndpoint            string
    70  		sshEndpointFingerprint string
    71  		skipHostValidation     bool
    72  		commands               []string
    73  		terminalRequest        TTYRequest
    74  		keepAliveDuration      time.Duration
    75  	)
    76  
    77  	BeforeEach(func() {
    78  		fakeSecureDialer = new(clisshfakes.FakeSecureDialer)
    79  		fakeSecureClient = new(clisshfakes.FakeSecureClient)
    80  		fakeTerminalHelper = new(clisshfakes.FakeTerminalHelper)
    81  		fakeListenerFactory = new(clisshfakes.FakeListenerFactory)
    82  		fakeSecureSession = new(clisshfakes.FakeSecureSession)
    83  
    84  		fakeConnection = new(fake_ssh.FakeConn)
    85  		stdinPipe = new(fake_io.FakeWriteCloser)
    86  		stdoutPipe = new(fake_io.FakeReader)
    87  		stderrPipe = new(fake_io.FakeReader)
    88  
    89  		fakeListenerFactory.ListenStub = net.Listen
    90  		fakeSecureClient.NewSessionReturns(fakeSecureSession, nil)
    91  		fakeSecureClient.ConnReturns(fakeConnection)
    92  		fakeSecureDialer.DialReturns(fakeSecureClient, nil)
    93  
    94  		stdinPipe.WriteStub = func(p []byte) (int, error) {
    95  			return len(p), nil
    96  		}
    97  		fakeSecureSession.StdinPipeReturns(stdinPipe, nil)
    98  
    99  		stdoutPipe.ReadStub = func(p []byte) (int, error) {
   100  			return 0, io.EOF
   101  		}
   102  		fakeSecureSession.StdoutPipeReturns(stdoutPipe, nil)
   103  
   104  		stderrPipe.ReadStub = func(p []byte) (int, error) {
   105  			return 0, io.EOF
   106  		}
   107  		fakeSecureSession.StderrPipeReturns(stderrPipe, nil)
   108  
   109  		username = "some-user"
   110  		passcode = "some-passcode"
   111  		sshEndpoint = "some-endpoint"
   112  		sshEndpointFingerprint = "some-fingerprint"
   113  		skipHostValidation = false
   114  		commands = []string{}
   115  		terminalRequest = RequestTTYAuto
   116  		keepAliveDuration = DefaultKeepAliveInterval
   117  	})
   118  
   119  	JustBeforeEach(func() {
   120  		secureShell = NewSecureShell(
   121  			fakeSecureDialer,
   122  			fakeTerminalHelper,
   123  			fakeListenerFactory,
   124  			keepAliveDuration,
   125  		)
   126  	})
   127  
   128  	Describe("Connect", func() {
   129  		var connectErr error
   130  
   131  		JustBeforeEach(func() {
   132  			connectErr = secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation)
   133  		})
   134  
   135  		When("dialing succeeds", func() {
   136  			It("creates the ssh client", func() {
   137  				Expect(connectErr).ToNot(HaveOccurred())
   138  
   139  				Expect(fakeSecureDialer.DialCallCount()).To(Equal(1))
   140  				protocolArg, sshEndpointArg, sshConfigArg := fakeSecureDialer.DialArgsForCall(0)
   141  				Expect(protocolArg).To(Equal("tcp"))
   142  				Expect(sshEndpointArg).To(Equal(sshEndpoint))
   143  				Expect(sshConfigArg.User).To(Equal(username))
   144  				Expect(sshConfigArg.Auth).To(HaveLen(1))
   145  				Expect(sshConfigArg.HostKeyCallback).ToNot(BeNil())
   146  			})
   147  		})
   148  
   149  		When("dialing fails", func() {
   150  			var dialError error
   151  
   152  			When("the error is a generic Dial error", func() {
   153  				BeforeEach(func() {
   154  					dialError = errors.New("woops")
   155  					fakeSecureDialer.DialReturns(nil, dialError)
   156  				})
   157  
   158  				It("returns the dial error", func() {
   159  					Expect(connectErr).To(Equal(dialError))
   160  					Expect(fakeSecureDialer.DialCallCount()).To(Equal(1))
   161  				})
   162  			})
   163  
   164  			When("the dialing error is a golang 'unable to authenticate' error", func() {
   165  				BeforeEach(func() {
   166  					dialError = fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", []string{"none", "password"})
   167  					fakeSecureDialer.DialReturns(nil, dialError)
   168  				})
   169  
   170  				It("returns an UnableToAuthenticateError", func() {
   171  					Expect(connectErr).To(MatchError(ssherror.UnableToAuthenticateError{Err: dialError}))
   172  					Expect(fakeSecureDialer.DialCallCount()).To(Equal(1))
   173  				})
   174  			})
   175  		})
   176  	})
   177  
   178  	Describe("InteractiveSession", func() {
   179  		var (
   180  			stdin          *fake_io.FakeReadCloser
   181  			stdout, stderr *fake_io.FakeWriter
   182  
   183  			sessionErr                error
   184  			interactiveSessionInvoker func(secureShell *SecureShell)
   185  		)
   186  
   187  		BeforeEach(func() {
   188  			stdin = new(fake_io.FakeReadCloser)
   189  			stdout = new(fake_io.FakeWriter)
   190  			stderr = new(fake_io.FakeWriter)
   191  
   192  			fakeTerminalHelper.StdStreamsReturns(stdin, stdout, stderr)
   193  			interactiveSessionInvoker = func(secureShell *SecureShell) {
   194  				sessionErr = secureShell.InteractiveSession(commands, terminalRequest)
   195  			}
   196  		})
   197  
   198  		JustBeforeEach(func() {
   199  			connectErr := secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation)
   200  			Expect(connectErr).NotTo(HaveOccurred())
   201  			interactiveSessionInvoker(secureShell)
   202  		})
   203  
   204  		When("host key validation is enabled", func() {
   205  			var (
   206  				callback func(hostname string, remote net.Addr, key ssh.PublicKey) error
   207  				addr     net.Addr
   208  			)
   209  
   210  			BeforeEach(func() {
   211  				skipHostValidation = false
   212  			})
   213  
   214  			JustBeforeEach(func() {
   215  				Expect(fakeSecureDialer.DialCallCount()).To(Equal(1))
   216  				_, _, config := fakeSecureDialer.DialArgsForCall(0)
   217  				callback = config.HostKeyCallback
   218  
   219  				listener, err := net.Listen("tcp", "localhost:0")
   220  				Expect(err).NotTo(HaveOccurred())
   221  
   222  				addr = listener.Addr()
   223  				listener.Close()
   224  			})
   225  
   226  			When("the md5 fingerprint matches", func() {
   227  				BeforeEach(func() {
   228  					sshEndpointFingerprint = "41:ce:56:e6:9c:42:a9:c6:9e:68:ac:e3:4d:f6:38:79"
   229  				})
   230  
   231  				It("does not return an error", func() {
   232  					Expect(callback("", addr, TestHostKey.PublicKey())).ToNot(HaveOccurred())
   233  				})
   234  			})
   235  
   236  			When("the hex sha1 fingerprint matches", func() {
   237  				BeforeEach(func() {
   238  					sshEndpointFingerprint = "a8:e2:67:cb:ea:2a:6e:23:a1:72:ce:8f:07:92:15:ee:1f:82:f8:ca"
   239  				})
   240  
   241  				It("does not return an error", func() {
   242  					Expect(callback("", addr, TestHostKey.PublicKey())).ToNot(HaveOccurred())
   243  				})
   244  			})
   245  
   246  			When("the base64 sha256 fingerprint matches", func() {
   247  				BeforeEach(func() {
   248  					sshEndpointFingerprint = "sp/jrLuj66r+yrLDUKZdJU5tdzt4mq/UaSiNBjpgr+8"
   249  				})
   250  
   251  				It("does not return an error", func() {
   252  					Expect(callback("", addr, TestHostKey.PublicKey())).ToNot(HaveOccurred())
   253  				})
   254  			})
   255  
   256  			When("the base64 SHA256 fingerprint does not match", func() {
   257  				BeforeEach(func() {
   258  					sshEndpointFingerprint = "0000000000000000000000000000000000000000000"
   259  				})
   260  
   261  				It("returns an error'", func() {
   262  					err := callback("", addr, TestHostKey.PublicKey())
   263  					Expect(err).To(MatchError(MatchRegexp(`Host key verification failed\.`)))
   264  					Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\"")))
   265  				})
   266  			})
   267  
   268  			When("the hex SHA1 fingerprint does not match", func() {
   269  				BeforeEach(func() {
   270  					sshEndpointFingerprint = "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00"
   271  				})
   272  
   273  				It("returns an error'", func() {
   274  					err := callback("", addr, TestHostKey.PublicKey())
   275  					Expect(err).To(MatchError(MatchRegexp(`Host key verification failed\.`)))
   276  					Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\"")))
   277  				})
   278  			})
   279  
   280  			When("the MD5 fingerprint does not match", func() {
   281  				BeforeEach(func() {
   282  					sshEndpointFingerprint = "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00"
   283  				})
   284  
   285  				It("returns an error'", func() {
   286  					err := callback("", addr, TestHostKey.PublicKey())
   287  					Expect(err).To(MatchError(MatchRegexp(`Host key verification failed\.`)))
   288  					Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\"")))
   289  				})
   290  			})
   291  
   292  			When("no fingerprint is present in endpoint info", func() {
   293  				BeforeEach(func() {
   294  					sshEndpointFingerprint = ""
   295  					sshEndpoint = ""
   296  				})
   297  
   298  				It("returns an error'", func() {
   299  					err := callback("", addr, TestHostKey.PublicKey())
   300  					Expect(err).To(MatchError(MatchRegexp(`Unable to verify identity of host\.`)))
   301  					Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\"")))
   302  				})
   303  			})
   304  
   305  			When("the fingerprint length doesn't make sense", func() {
   306  				BeforeEach(func() {
   307  					sshEndpointFingerprint = "garbage"
   308  				})
   309  
   310  				It("returns an error", func() {
   311  					err := callback("", addr, TestHostKey.PublicKey())
   312  					Eventually(err).Should(MatchError(MatchRegexp("Unsupported host key fingerprint format")))
   313  				})
   314  			})
   315  		})
   316  
   317  		When("the skip host validation flag is set", func() {
   318  			BeforeEach(func() {
   319  				skipHostValidation = true
   320  			})
   321  
   322  			It("the HostKeyCallback on the Config to always return nil", func() {
   323  				Expect(fakeSecureDialer.DialCallCount()).To(Equal(1))
   324  
   325  				_, _, config := fakeSecureDialer.DialArgsForCall(0)
   326  				Expect(config.HostKeyCallback("some-hostname", nil, nil)).To(BeNil())
   327  			})
   328  		})
   329  
   330  		// TODO: see if it's possible to test the piping between the ss client input and outputs and the UI object we pass in
   331  		When("dialing is successful", func() {
   332  			It("creates a new secure shell session", func() {
   333  				Expect(fakeSecureClient.NewSessionCallCount()).To(Equal(1))
   334  			})
   335  
   336  			It("closes the session", func() {
   337  				Expect(fakeSecureSession.CloseCallCount()).To(Equal(1))
   338  			})
   339  
   340  			It("gets a stdin pipe for the session", func() {
   341  				Expect(fakeSecureSession.StdinPipeCallCount()).To(Equal(1))
   342  			})
   343  
   344  			When("getting the stdin pipe fails", func() {
   345  				BeforeEach(func() {
   346  					fakeSecureSession.StdinPipeReturns(nil, errors.New("woops"))
   347  				})
   348  
   349  				It("returns the error", func() {
   350  					Expect(sessionErr).Should(MatchError("woops"))
   351  				})
   352  			})
   353  
   354  			It("gets a stdout pipe for the session", func() {
   355  				Expect(fakeSecureSession.StdoutPipeCallCount()).To(Equal(1))
   356  			})
   357  
   358  			When("getting the stdout pipe fails", func() {
   359  				BeforeEach(func() {
   360  					fakeSecureSession.StdoutPipeReturns(nil, errors.New("woops"))
   361  				})
   362  
   363  				It("returns the error", func() {
   364  					Expect(sessionErr).Should(MatchError("woops"))
   365  				})
   366  			})
   367  
   368  			It("gets a stderr pipe for the session", func() {
   369  				Expect(fakeSecureSession.StderrPipeCallCount()).To(Equal(1))
   370  			})
   371  
   372  			When("getting the stderr pipe fails", func() {
   373  				BeforeEach(func() {
   374  					fakeSecureSession.StderrPipeReturns(nil, errors.New("woops"))
   375  				})
   376  
   377  				It("returns the error", func() {
   378  					Expect(sessionErr).Should(MatchError("woops"))
   379  				})
   380  			})
   381  		})
   382  
   383  		When("stdin is a terminal", func() {
   384  			var terminal *os.File
   385  
   386  			BeforeEach(func() {
   387  				var err error
   388  				terminal, _, err = pty.Open()
   389  				Expect(err).NotTo(HaveOccurred())
   390  
   391  				terminalRequest = RequestTTYForce
   392  
   393  				terminalHelper := DefaultTerminalHelper()
   394  				fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo
   395  				fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize
   396  			})
   397  
   398  			AfterEach(func() {
   399  				terminal.Close()
   400  			})
   401  
   402  			When("a command is not specified", func() {
   403  				var terminalType string
   404  
   405  				BeforeEach(func() {
   406  					terminalType = os.Getenv("TERM")
   407  					os.Setenv("TERM", "test-terminal-type")
   408  
   409  					winsize := &term.Winsize{Width: 1024, Height: 256}
   410  					fakeTerminalHelper.GetWinsizeReturns(winsize, nil)
   411  
   412  					fakeSecureSession.ShellStub = func() error {
   413  						Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1))
   414  						Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(0))
   415  						return nil
   416  					}
   417  				})
   418  
   419  				AfterEach(func() {
   420  					os.Setenv("TERM", terminalType)
   421  				})
   422  
   423  				It("requests a pty with the correct terminal type, window size, and modes", func() {
   424  					Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1))
   425  					Expect(fakeTerminalHelper.GetWinsizeCallCount()).To(Equal(1))
   426  
   427  					termType, height, width, modes := fakeSecureSession.RequestPtyArgsForCall(0)
   428  					Expect(termType).To(Equal("test-terminal-type"))
   429  					Expect(height).To(Equal(256))
   430  					Expect(width).To(Equal(1024))
   431  
   432  					expectedModes := ssh.TerminalModes{
   433  						ssh.ECHO:          1,
   434  						ssh.TTY_OP_ISPEED: 115200,
   435  						ssh.TTY_OP_OSPEED: 115200,
   436  					}
   437  					Expect(modes).To(Equal(expectedModes))
   438  				})
   439  
   440  				When("the TERM environment variable is not set", func() {
   441  					BeforeEach(func() {
   442  						os.Unsetenv("TERM")
   443  					})
   444  
   445  					It("requests a pty with the default terminal type", func() {
   446  						Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1))
   447  
   448  						termType, _, _, _ := fakeSecureSession.RequestPtyArgsForCall(0)
   449  						Expect(termType).To(Equal("xterm"))
   450  					})
   451  				})
   452  
   453  				It("puts the terminal into raw mode and restores it after running the shell", func() {
   454  					Expect(fakeSecureSession.ShellCallCount()).To(Equal(1))
   455  					Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1))
   456  					Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(1))
   457  				})
   458  
   459  				When("the pty allocation fails", func() {
   460  					var ptyError error
   461  
   462  					BeforeEach(func() {
   463  						ptyError = errors.New("pty allocation error")
   464  						fakeSecureSession.RequestPtyReturns(ptyError)
   465  					})
   466  
   467  					It("returns the error", func() {
   468  						Expect(sessionErr).To(Equal(ptyError))
   469  					})
   470  				})
   471  
   472  				When("placing the terminal into raw mode fails", func() {
   473  					BeforeEach(func() {
   474  						fakeTerminalHelper.SetRawTerminalReturns(nil, errors.New("woops"))
   475  					})
   476  
   477  					It("keeps calm and carries on", func() {
   478  						Expect(fakeSecureSession.ShellCallCount()).To(Equal(1))
   479  					})
   480  
   481  					It("does not restore the terminal", func() {
   482  						Expect(fakeSecureSession.ShellCallCount()).To(Equal(1))
   483  						Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1))
   484  						Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(0))
   485  					})
   486  				})
   487  			})
   488  
   489  			When("a command is specified", func() {
   490  				BeforeEach(func() {
   491  					commands = []string{"echo", "-n", "hello"}
   492  				})
   493  
   494  				When("a terminal is requested", func() {
   495  					BeforeEach(func() {
   496  						terminalRequest = RequestTTYYes
   497  						fakeTerminalHelper.GetFdInfoReturns(0, true)
   498  					})
   499  
   500  					It("requests a pty", func() {
   501  						Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1))
   502  					})
   503  				})
   504  
   505  				When("a terminal is not explicitly requested", func() {
   506  					BeforeEach(func() {
   507  						terminalRequest = RequestTTYAuto
   508  					})
   509  
   510  					It("does not request a pty", func() {
   511  						Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0))
   512  					})
   513  				})
   514  			})
   515  		})
   516  
   517  		When("stdin is not a terminal", func() {
   518  			BeforeEach(func() {
   519  				stdin.ReadStub = func(p []byte) (int, error) {
   520  					return 0, io.EOF
   521  				}
   522  
   523  				terminalHelper := DefaultTerminalHelper()
   524  				fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo
   525  				fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize
   526  			})
   527  
   528  			When("a terminal is not requested", func() {
   529  				It("does not request a pty", func() {
   530  					Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0))
   531  				})
   532  			})
   533  
   534  			When("a terminal is requested", func() {
   535  				BeforeEach(func() {
   536  					terminalRequest = RequestTTYYes
   537  				})
   538  
   539  				It("does not request a pty", func() {
   540  					Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0))
   541  				})
   542  			})
   543  		})
   544  
   545  		PWhen("a terminal is forced", func() {
   546  			BeforeEach(func() {
   547  				terminalRequest = RequestTTYForce
   548  			})
   549  
   550  			It("requests a pty", func() {
   551  				Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1))
   552  			})
   553  		})
   554  
   555  		When("a terminal is disabled", func() {
   556  			BeforeEach(func() {
   557  				terminalRequest = RequestTTYNo
   558  			})
   559  
   560  			It("does not request a pty", func() {
   561  				Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0))
   562  			})
   563  		})
   564  
   565  		When("a command is not specified", func() {
   566  			It("requests an interactive shell", func() {
   567  				Expect(fakeSecureSession.ShellCallCount()).To(Equal(1))
   568  			})
   569  
   570  			When("the shell request returns an error", func() {
   571  				BeforeEach(func() {
   572  					fakeSecureSession.ShellReturns(errors.New("oh bother"))
   573  				})
   574  
   575  				It("returns the error", func() {
   576  					Expect(sessionErr).To(MatchError("oh bother"))
   577  				})
   578  			})
   579  		})
   580  
   581  		When("a command is specified", func() {
   582  			BeforeEach(func() {
   583  				commands = []string{"echo", "-n", "hello"}
   584  			})
   585  
   586  			It("starts the command", func() {
   587  				Expect(fakeSecureSession.StartCallCount()).To(Equal(1))
   588  				Expect(fakeSecureSession.StartArgsForCall(0)).To(Equal("echo -n hello"))
   589  			})
   590  
   591  			When("the command fails to start", func() {
   592  				BeforeEach(func() {
   593  					fakeSecureSession.StartReturns(errors.New("oh well"))
   594  				})
   595  
   596  				It("returns the error", func() {
   597  					Expect(sessionErr).To(MatchError("oh well"))
   598  				})
   599  			})
   600  		})
   601  
   602  		When("the shell or command has started", func() {
   603  			BeforeEach(func() {
   604  				stdin.ReadStub = func(p []byte) (int, error) {
   605  					p[0] = 0
   606  					return 1, io.EOF
   607  				}
   608  				stdinPipe.WriteStub = func(p []byte) (int, error) {
   609  					defer GinkgoRecover()
   610  					Expect(p[0]).To(Equal(byte(0)))
   611  					return 1, nil
   612  				}
   613  
   614  				stdoutPipe.ReadStub = func(p []byte) (int, error) {
   615  					p[0] = 1
   616  					return 1, io.EOF
   617  				}
   618  				stdout.WriteStub = func(p []byte) (int, error) {
   619  					defer GinkgoRecover()
   620  					Expect(p[0]).To(Equal(byte(1)))
   621  					return 1, nil
   622  				}
   623  
   624  				stderrPipe.ReadStub = func(p []byte) (int, error) {
   625  					p[0] = 2
   626  					return 1, io.EOF
   627  				}
   628  				stderr.WriteStub = func(p []byte) (int, error) {
   629  					defer GinkgoRecover()
   630  					Expect(p[0]).To(Equal(byte(2)))
   631  					return 1, nil
   632  				}
   633  
   634  				fakeSecureSession.StdinPipeReturns(stdinPipe, nil)
   635  				fakeSecureSession.StdoutPipeReturns(stdoutPipe, nil)
   636  				fakeSecureSession.StderrPipeReturns(stderrPipe, nil)
   637  
   638  				fakeSecureSession.WaitReturns(errors.New("error result"))
   639  			})
   640  
   641  			It("copies data from the stdin stream to the session stdin pipe", func() {
   642  				Eventually(stdin.ReadCallCount).Should(Equal(1))
   643  				Eventually(stdinPipe.WriteCallCount).Should(Equal(1))
   644  			})
   645  
   646  			It("copies data from the session stdout pipe to the stdout stream", func() {
   647  				Eventually(stdoutPipe.ReadCallCount).Should(Equal(1))
   648  				Eventually(stdout.WriteCallCount).Should(Equal(1))
   649  			})
   650  
   651  			It("copies data from the session stderr pipe to the stderr stream", func() {
   652  				Eventually(stderrPipe.ReadCallCount).Should(Equal(1))
   653  				Eventually(stderr.WriteCallCount).Should(Equal(1))
   654  			})
   655  
   656  			It("waits for the session to end", func() {
   657  				Expect(fakeSecureSession.WaitCallCount()).To(Equal(1))
   658  			})
   659  
   660  			It("returns the result from wait", func() {
   661  				Expect(sessionErr).To(MatchError("error result"))
   662  			})
   663  
   664  			When("the session terminates before stream copies complete", func() {
   665  				var sessionErrorCh chan error
   666  
   667  				BeforeEach(func() {
   668  					sessionErrorCh = make(chan error, 1)
   669  
   670  					interactiveSessionInvoker = func(secureShell *SecureShell) {
   671  						go func() {
   672  							sessionErrorCh <- secureShell.InteractiveSession(commands, terminalRequest)
   673  						}()
   674  					}
   675  
   676  					stdoutPipe.ReadStub = func(p []byte) (int, error) {
   677  						defer GinkgoRecover()
   678  						Eventually(fakeSecureSession.WaitCallCount).Should(Equal(1))
   679  						Consistently(sessionErrorCh).ShouldNot(Receive())
   680  
   681  						p[0] = 1
   682  						return 1, io.EOF
   683  					}
   684  
   685  					stderrPipe.ReadStub = func(p []byte) (int, error) {
   686  						defer GinkgoRecover()
   687  						Eventually(fakeSecureSession.WaitCallCount).Should(Equal(1))
   688  						Consistently(sessionErrorCh).ShouldNot(Receive())
   689  
   690  						p[0] = 2
   691  						return 1, io.EOF
   692  					}
   693  				})
   694  
   695  				It("waits for the copies to complete", func() {
   696  					Eventually(sessionErrorCh).Should(Receive())
   697  					Expect(stdoutPipe.ReadCallCount()).To(Equal(1))
   698  					Expect(stderrPipe.ReadCallCount()).To(Equal(1))
   699  				})
   700  			})
   701  
   702  			When("stdin is closed", func() {
   703  				BeforeEach(func() {
   704  					stdin.ReadStub = func(p []byte) (int, error) {
   705  						defer GinkgoRecover()
   706  						Consistently(stdinPipe.CloseCallCount).Should(Equal(0))
   707  						p[0] = 0
   708  						return 1, io.EOF
   709  					}
   710  				})
   711  
   712  				It("closes the stdinPipe", func() {
   713  					Eventually(stdinPipe.CloseCallCount).Should(Equal(1))
   714  				})
   715  			})
   716  		})
   717  
   718  		When("stdout is a terminal and a window size change occurs", func() {
   719  			var pseudoterminal, terminal *os.File
   720  
   721  			BeforeEach(func() {
   722  				var err error
   723  				pseudoterminal, terminal, err = pty.Open()
   724  				Expect(err).NotTo(HaveOccurred())
   725  
   726  				terminalHelper := DefaultTerminalHelper()
   727  				fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo
   728  				fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize
   729  				fakeTerminalHelper.StdStreamsReturns(stdin, terminal, stderr)
   730  
   731  				winsize := &term.Winsize{Height: 100, Width: 100}
   732  				err = term.SetWinsize(terminal.Fd(), winsize)
   733  				Expect(err).NotTo(HaveOccurred())
   734  
   735  				fakeSecureSession.WaitStub = func() error {
   736  					fakeSecureSession.SendRequestCallCount()
   737  					Expect(fakeSecureSession.SendRequestCallCount()).To(Equal(0))
   738  
   739  					// No dimension change
   740  					for i := 0; i < 3; i++ {
   741  						winsize := &term.Winsize{Height: 100, Width: 100}
   742  						err = term.SetWinsize(terminal.Fd(), winsize)
   743  						Expect(err).NotTo(HaveOccurred())
   744  					}
   745  
   746  					winsize := &term.Winsize{Height: 100, Width: 200}
   747  					err = term.SetWinsize(terminal.Fd(), winsize)
   748  					Expect(err).NotTo(HaveOccurred())
   749  
   750  					err = syscall.Kill(syscall.Getpid(), syscall.SIGWINCH)
   751  					Expect(err).NotTo(HaveOccurred())
   752  
   753  					Eventually(fakeSecureSession.SendRequestCallCount).Should(Equal(1))
   754  					return nil
   755  				}
   756  			})
   757  
   758  			AfterEach(func() {
   759  				pseudoterminal.Close()
   760  				terminal.Close()
   761  			})
   762  
   763  			It("sends window change events when the window dimensions change", func() {
   764  				Expect(fakeSecureSession.SendRequestCallCount()).To(Equal(1))
   765  
   766  				requestType, wantReply, message := fakeSecureSession.SendRequestArgsForCall(0)
   767  				Expect(requestType).To(Equal("window-change"))
   768  				Expect(wantReply).To(BeFalse())
   769  
   770  				type resizeMessage struct {
   771  					Width       uint32
   772  					Height      uint32
   773  					PixelWidth  uint32
   774  					PixelHeight uint32
   775  				}
   776  				var resizeMsg resizeMessage
   777  
   778  				err := ssh.Unmarshal(message, &resizeMsg)
   779  				Expect(err).NotTo(HaveOccurred())
   780  
   781  				Expect(resizeMsg).To(Equal(resizeMessage{Height: 100, Width: 200}))
   782  			})
   783  		})
   784  
   785  		Describe("keep alive messages", func() {
   786  			var times []time.Time
   787  			var timesCh chan []time.Time
   788  			var done chan struct{}
   789  
   790  			BeforeEach(func() {
   791  				keepAliveDuration = 100 * time.Millisecond
   792  
   793  				times = []time.Time{}
   794  				timesCh = make(chan []time.Time, 1)
   795  				done = make(chan struct{}, 1)
   796  
   797  				fakeConnection.SendRequestStub = func(reqName string, wantReply bool, message []byte) (bool, []byte, error) {
   798  					Expect(reqName).To(Equal("keepalive@cloudfoundry.org"))
   799  					Expect(wantReply).To(BeTrue())
   800  					Expect(message).To(BeNil())
   801  
   802  					times = append(times, time.Now())
   803  					if len(times) == 3 {
   804  						timesCh <- times
   805  						close(done)
   806  					}
   807  					return true, nil, nil
   808  				}
   809  
   810  				fakeSecureSession.WaitStub = func() error {
   811  					Eventually(done).Should(BeClosed())
   812  					return nil
   813  				}
   814  			})
   815  
   816  			PIt("sends keep alive messages at the expected interval", func() {
   817  				times := <-timesCh
   818  				Expect(times[2]).To(BeTemporally("~", times[0].Add(200*time.Millisecond), 160*time.Millisecond))
   819  			})
   820  		})
   821  	})
   822  
   823  	Describe("LocalPortForward", func() {
   824  		var (
   825  			forwardErr error
   826  
   827  			echoAddress  string
   828  			echoListener *fake_net.FakeListener
   829  			echoHandler  *fake_server.FakeConnectionHandler
   830  			echoServer   *server.Server
   831  
   832  			localAddress string
   833  
   834  			realLocalListener net.Listener
   835  			fakeLocalListener *fake_net.FakeListener
   836  
   837  			forwardSpecs []LocalPortForward
   838  		)
   839  
   840  		BeforeEach(func() {
   841  			logger := lagertest.NewTestLogger("test")
   842  
   843  			var err error
   844  			realLocalListener, err = net.Listen("tcp", "127.0.0.1:0")
   845  			Expect(err).NotTo(HaveOccurred())
   846  
   847  			localAddress = realLocalListener.Addr().String()
   848  			fakeListenerFactory.ListenReturns(realLocalListener, nil)
   849  
   850  			echoHandler = new(fake_server.FakeConnectionHandler)
   851  			echoHandler.HandleConnectionStub = func(conn net.Conn) {
   852  				io.Copy(conn, conn) //nolint:errcheck
   853  				conn.Close()
   854  			}
   855  
   856  			realListener, err := net.Listen("tcp", "127.0.0.1:0")
   857  			Expect(err).NotTo(HaveOccurred())
   858  			echoAddress = realListener.Addr().String()
   859  
   860  			echoListener = new(fake_net.FakeListener)
   861  			echoListener.AcceptStub = realListener.Accept
   862  			echoListener.CloseStub = realListener.Close
   863  			echoListener.AddrStub = realListener.Addr
   864  
   865  			fakeLocalListener = new(fake_net.FakeListener)
   866  			fakeLocalListener.AcceptReturns(nil, errors.New("Not Accepting Connections"))
   867  
   868  			echoServer = server.NewServer(logger.Session("echo"), "", echoHandler)
   869  			err = echoServer.SetListener(echoListener)
   870  			Expect(err).NotTo(HaveOccurred())
   871  			go echoServer.Serve()
   872  
   873  			forwardSpecs = []LocalPortForward{{
   874  				RemoteAddress: echoAddress,
   875  				LocalAddress:  localAddress,
   876  			}}
   877  
   878  			fakeSecureClient.DialStub = net.Dial
   879  		})
   880  
   881  		JustBeforeEach(func() {
   882  			connectErr := secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation)
   883  			Expect(connectErr).NotTo(HaveOccurred())
   884  
   885  			forwardErr = secureShell.LocalPortForward(forwardSpecs)
   886  		})
   887  
   888  		AfterEach(func() {
   889  			err := secureShell.Close()
   890  			Expect(err).NotTo(HaveOccurred())
   891  			echoServer.Shutdown()
   892  
   893  			realLocalListener.Close()
   894  		})
   895  
   896  		validateConnectivity := func(addr string) {
   897  			conn, err := net.Dial("tcp", addr)
   898  			Expect(err).NotTo(HaveOccurred())
   899  
   900  			msg := fmt.Sprintf("Hello from %s\n", addr)
   901  			n, err := conn.Write([]byte(msg))
   902  			Expect(err).NotTo(HaveOccurred())
   903  			Expect(n).To(Equal(len(msg)))
   904  
   905  			response := make([]byte, len(msg))
   906  			n, err = conn.Read(response)
   907  			Expect(err).NotTo(HaveOccurred())
   908  			Expect(n).To(Equal(len(msg)))
   909  
   910  			err = conn.Close()
   911  			Expect(err).NotTo(HaveOccurred())
   912  
   913  			Expect(response).To(Equal([]byte(msg)))
   914  		}
   915  
   916  		It("dials the connect address when a local connection is made", func() {
   917  			Expect(forwardErr).NotTo(HaveOccurred())
   918  
   919  			conn, err := net.Dial("tcp", localAddress)
   920  			Expect(err).NotTo(HaveOccurred())
   921  
   922  			Eventually(echoListener.AcceptCallCount).Should(BeNumerically(">=", 1))
   923  			Eventually(fakeSecureClient.DialCallCount).Should(Equal(1))
   924  
   925  			network, addr := fakeSecureClient.DialArgsForCall(0)
   926  			Expect(network).To(Equal("tcp"))
   927  			Expect(addr).To(Equal(echoAddress))
   928  
   929  			Expect(conn.Close()).NotTo(HaveOccurred())
   930  		})
   931  
   932  		It("copies data between the local and remote connections", func() {
   933  			validateConnectivity(localAddress)
   934  		})
   935  
   936  		When("a local connection is already open", func() {
   937  			var conn net.Conn
   938  
   939  			JustBeforeEach(func() {
   940  				var err error
   941  				conn, err = net.Dial("tcp", localAddress)
   942  				Expect(err).NotTo(HaveOccurred())
   943  			})
   944  
   945  			AfterEach(func() {
   946  				err := conn.Close()
   947  				Expect(err).NotTo(HaveOccurred())
   948  			})
   949  
   950  			It("allows for new incoming connections as well", func() {
   951  				validateConnectivity(localAddress)
   952  			})
   953  		})
   954  
   955  		When("there are multiple port forward specs", func() {
   956  			When("provided a real listener", func() {
   957  				var (
   958  					realLocalListener2 net.Listener
   959  					localAddress2      string
   960  				)
   961  
   962  				BeforeEach(func() {
   963  					var err error
   964  					realLocalListener2, err = net.Listen("tcp", "127.0.0.1:0")
   965  					Expect(err).NotTo(HaveOccurred())
   966  
   967  					localAddress2 = realLocalListener2.Addr().String()
   968  
   969  					fakeListenerFactory.ListenStub = func(network, addr string) (net.Listener, error) {
   970  						if addr == localAddress {
   971  							return realLocalListener, nil
   972  						}
   973  
   974  						if addr == localAddress2 {
   975  							return realLocalListener2, nil
   976  						}
   977  
   978  						return nil, errors.New("unexpected address")
   979  					}
   980  
   981  					forwardSpecs = []LocalPortForward{
   982  						{
   983  							RemoteAddress: echoAddress,
   984  							LocalAddress:  localAddress,
   985  						},
   986  						{
   987  							RemoteAddress: echoAddress,
   988  							LocalAddress:  localAddress2,
   989  						},
   990  					}
   991  				})
   992  
   993  				AfterEach(func() {
   994  					realLocalListener2.Close()
   995  				})
   996  
   997  				It("listens to all the things", func() {
   998  					Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(2))
   999  
  1000  					network, addr := fakeListenerFactory.ListenArgsForCall(0)
  1001  					Expect(network).To(Equal("tcp"))
  1002  					Expect(addr).To(Equal(localAddress))
  1003  
  1004  					network, addr = fakeListenerFactory.ListenArgsForCall(1)
  1005  					Expect(network).To(Equal("tcp"))
  1006  					Expect(addr).To(Equal(localAddress2))
  1007  				})
  1008  
  1009  				It("forwards to the correct target", func() {
  1010  					validateConnectivity(localAddress)
  1011  					validateConnectivity(localAddress2)
  1012  				})
  1013  			})
  1014  
  1015  			When("the secure client is closed", func() {
  1016  				var (
  1017  					fakeLocalListener1 *fake_net.FakeListener
  1018  					fakeLocalListener2 *fake_net.FakeListener
  1019  				)
  1020  
  1021  				BeforeEach(func() {
  1022  					forwardSpecs = []LocalPortForward{
  1023  						{
  1024  							RemoteAddress: echoAddress,
  1025  							LocalAddress:  localAddress,
  1026  						},
  1027  						{
  1028  							RemoteAddress: echoAddress,
  1029  							LocalAddress:  localAddress,
  1030  						},
  1031  					}
  1032  
  1033  					fakeLocalListener1 = new(fake_net.FakeListener)
  1034  					BlockAcceptOnClose(fakeLocalListener1)
  1035  					fakeListenerFactory.ListenReturnsOnCall(0, fakeLocalListener1, nil)
  1036  
  1037  					fakeLocalListener2 = new(fake_net.FakeListener)
  1038  					BlockAcceptOnClose(fakeLocalListener2)
  1039  					fakeListenerFactory.ListenReturnsOnCall(1, fakeLocalListener2, nil)
  1040  				})
  1041  
  1042  				It("closes the listeners", func() {
  1043  					Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(2))
  1044  					Eventually(fakeLocalListener1.AcceptCallCount).Should(Equal(1))
  1045  					Eventually(fakeLocalListener2.AcceptCallCount).Should(Equal(1))
  1046  
  1047  					err := secureShell.Close()
  1048  					Expect(err).NotTo(HaveOccurred())
  1049  					Eventually(fakeLocalListener1.CloseCallCount).Should(Equal(2))
  1050  					Eventually(fakeLocalListener2.CloseCallCount).Should(Equal(2))
  1051  				})
  1052  			})
  1053  		})
  1054  
  1055  		When("listen fails", func() {
  1056  			BeforeEach(func() {
  1057  				fakeListenerFactory.ListenReturns(nil, errors.New("failure is an option"))
  1058  			})
  1059  
  1060  			It("returns the error", func() {
  1061  				Expect(forwardErr).To(MatchError("failure is an option"))
  1062  			})
  1063  		})
  1064  
  1065  		When("the client it closed", func() {
  1066  			BeforeEach(func() {
  1067  				fakeLocalListener = new(fake_net.FakeListener)
  1068  				BlockAcceptOnClose(fakeLocalListener)
  1069  
  1070  				fakeListenerFactory.ListenReturns(fakeLocalListener, nil)
  1071  			})
  1072  
  1073  			It("closes the listener when the client is closed", func() {
  1074  				Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(1))
  1075  				Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(1))
  1076  
  1077  				err := secureShell.Close()
  1078  				Expect(err).NotTo(HaveOccurred())
  1079  				Eventually(fakeLocalListener.CloseCallCount).Should(Equal(2))
  1080  			})
  1081  		})
  1082  
  1083  		When("accept fails", func() {
  1084  			var fakeConn *fake_net.FakeConn
  1085  
  1086  			BeforeEach(func() {
  1087  				fakeConn = &fake_net.FakeConn{}
  1088  				fakeConn.ReadReturns(0, io.EOF)
  1089  
  1090  				fakeListenerFactory.ListenReturns(fakeLocalListener, nil)
  1091  			})
  1092  
  1093  			Context("with a permanent error", func() {
  1094  				BeforeEach(func() {
  1095  					fakeLocalListener.AcceptReturns(nil, errors.New("boom"))
  1096  				})
  1097  
  1098  				It("stops trying to accept connections", func() {
  1099  					Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(1))
  1100  					Consistently(fakeLocalListener.AcceptCallCount).Should(Equal(1))
  1101  					Expect(fakeLocalListener.CloseCallCount()).To(Equal(1))
  1102  				})
  1103  			})
  1104  
  1105  			Context("with a temporary error", func() {
  1106  				var timeCh chan time.Time
  1107  
  1108  				BeforeEach(func() {
  1109  					timeCh = make(chan time.Time, 3)
  1110  
  1111  					fakeLocalListener.AcceptStub = func() (net.Conn, error) {
  1112  						timeCh := timeCh
  1113  						if fakeLocalListener.AcceptCallCount() > 3 {
  1114  							close(timeCh)
  1115  							return nil, test_helpers.NewTestNetError(false, false)
  1116  						} else {
  1117  							timeCh <- time.Now()
  1118  							return nil, test_helpers.NewTestNetError(false, true)
  1119  						}
  1120  					}
  1121  				})
  1122  
  1123  				PIt("retries connecting after a short delay", func() {
  1124  					Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(3))
  1125  					Expect(timeCh).To(HaveLen(3))
  1126  
  1127  					times := make([]time.Time, 0)
  1128  					for t := range timeCh {
  1129  						times = append(times, t)
  1130  					}
  1131  
  1132  					Expect(times[1]).To(BeTemporally("~", times[0].Add(115*time.Millisecond), 80*time.Millisecond))
  1133  					Expect(times[2]).To(BeTemporally("~", times[1].Add(115*time.Millisecond), 100*time.Millisecond))
  1134  				})
  1135  			})
  1136  		})
  1137  
  1138  		When("dialing the connect address fails", func() {
  1139  			var fakeTarget *fake_net.FakeConn
  1140  
  1141  			BeforeEach(func() {
  1142  				fakeTarget = &fake_net.FakeConn{}
  1143  				fakeSecureClient.DialReturns(fakeTarget, errors.New("boom"))
  1144  			})
  1145  
  1146  			It("does not call close on the target connection", func() {
  1147  				Consistently(fakeTarget.CloseCallCount).Should(Equal(0))
  1148  			})
  1149  		})
  1150  	})
  1151  
  1152  	Describe("Wait", func() {
  1153  		var waitErr error
  1154  
  1155  		JustBeforeEach(func() {
  1156  			connectErr := secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation)
  1157  			Expect(connectErr).NotTo(HaveOccurred())
  1158  
  1159  			waitErr = secureShell.Wait()
  1160  		})
  1161  
  1162  		It("calls wait on the secureClient", func() {
  1163  			Expect(waitErr).NotTo(HaveOccurred())
  1164  			Expect(fakeSecureClient.WaitCallCount()).To(Equal(1))
  1165  		})
  1166  
  1167  		Describe("keep alive messages", func() {
  1168  			var times []time.Time
  1169  			var timesCh chan []time.Time
  1170  			var done chan struct{}
  1171  
  1172  			BeforeEach(func() {
  1173  				keepAliveDuration = 100 * time.Millisecond
  1174  
  1175  				times = []time.Time{}
  1176  				timesCh = make(chan []time.Time, 1)
  1177  				done = make(chan struct{}, 1)
  1178  
  1179  				fakeConnection.SendRequestStub = func(reqName string, wantReply bool, message []byte) (bool, []byte, error) {
  1180  					Expect(reqName).To(Equal("keepalive@cloudfoundry.org"))
  1181  					Expect(wantReply).To(BeTrue())
  1182  					Expect(message).To(BeNil())
  1183  
  1184  					times = append(times, time.Now())
  1185  					if len(times) == 3 {
  1186  						timesCh <- times
  1187  						close(done)
  1188  					}
  1189  					return true, nil, nil
  1190  				}
  1191  
  1192  				fakeSecureClient.WaitStub = func() error {
  1193  					Eventually(done).Should(BeClosed())
  1194  					return nil
  1195  				}
  1196  			})
  1197  
  1198  			PIt("sends keep alive messages at the expected interval", func() {
  1199  				Expect(waitErr).NotTo(HaveOccurred())
  1200  				times := <-timesCh
  1201  				Expect(times[2]).To(BeTemporally("~", times[0].Add(200*time.Millisecond), 100*time.Millisecond))
  1202  			})
  1203  		})
  1204  	})
  1205  
  1206  	Describe("Close", func() {
  1207  		JustBeforeEach(func() {
  1208  			connectErr := secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation)
  1209  			Expect(connectErr).NotTo(HaveOccurred())
  1210  		})
  1211  
  1212  		It("calls close on the secureClient", func() {
  1213  			err := secureShell.Close()
  1214  			Expect(err).NotTo(HaveOccurred())
  1215  
  1216  			Expect(fakeSecureClient.CloseCallCount()).To(Equal(1))
  1217  		})
  1218  	})
  1219  })