github.com/Thanhphan1147/cloudfoundry-cli@v7.1.0+incompatible/util/clissh/ssh_test.go (about)

     1  // +build !windows,!386
     2  
     3  // Skipping 386 because lager uses UInt64 in Session().
     4  // Skipping Windows because the tests use Unix/Linux-only syscalls.
     5  // We should refactor out these conflicts so this package can be tested on multiple platforms.
     6  
     7  package clissh_test
     8  
     9  import (
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"net"
    14  	"os"
    15  	"sync"
    16  	"syscall"
    17  	"time"
    18  
    19  	. "code.cloudfoundry.org/cli/util/clissh"
    20  	"code.cloudfoundry.org/cli/util/clissh/clisshfakes"
    21  	"code.cloudfoundry.org/cli/util/clissh/ssherror"
    22  	"code.cloudfoundry.org/diego-ssh/server"
    23  	fake_server "code.cloudfoundry.org/diego-ssh/server/fakes"
    24  	"code.cloudfoundry.org/diego-ssh/test_helpers"
    25  	"code.cloudfoundry.org/diego-ssh/test_helpers/fake_io"
    26  	"code.cloudfoundry.org/diego-ssh/test_helpers/fake_net"
    27  	"code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh"
    28  	"code.cloudfoundry.org/lager/lagertest"
    29  	"github.com/kr/pty"
    30  	"github.com/moby/moby/pkg/term"
    31  	. "github.com/onsi/ginkgo"
    32  	. "github.com/onsi/gomega"
    33  	"golang.org/x/crypto/ssh"
    34  )
    35  
    36  func BlockAcceptOnClose(fake *fake_net.FakeListener) {
    37  	waitUntilClosed := make(chan bool)
    38  	fake.AcceptStub = func() (net.Conn, error) {
    39  		<-waitUntilClosed
    40  		return nil, errors.New("banana")
    41  	}
    42  
    43  	var once sync.Once
    44  	fake.CloseStub = func() error {
    45  		once.Do(func() {
    46  			close(waitUntilClosed)
    47  		})
    48  		return nil
    49  	}
    50  }
    51  
    52  var _ = Describe("CLI SSH", func() {
    53  	var (
    54  		fakeSecureDialer    *clisshfakes.FakeSecureDialer
    55  		fakeSecureClient    *clisshfakes.FakeSecureClient
    56  		fakeTerminalHelper  *clisshfakes.FakeTerminalHelper
    57  		fakeListenerFactory *clisshfakes.FakeListenerFactory
    58  		fakeSecureSession   *clisshfakes.FakeSecureSession
    59  
    60  		fakeConnection *fake_ssh.FakeConn
    61  		stdinPipe      *fake_io.FakeWriteCloser
    62  		stdoutPipe     *fake_io.FakeReader
    63  		stderrPipe     *fake_io.FakeReader
    64  		secureShell    *SecureShell
    65  
    66  		username               string
    67  		passcode               string
    68  		sshEndpoint            string
    69  		sshEndpointFingerprint string
    70  		skipHostValidation     bool
    71  		commands               []string
    72  		terminalRequest        TTYRequest
    73  		keepAliveDuration      time.Duration
    74  	)
    75  
    76  	BeforeEach(func() {
    77  		fakeSecureDialer = new(clisshfakes.FakeSecureDialer)
    78  		fakeSecureClient = new(clisshfakes.FakeSecureClient)
    79  		fakeTerminalHelper = new(clisshfakes.FakeTerminalHelper)
    80  		fakeListenerFactory = new(clisshfakes.FakeListenerFactory)
    81  		fakeSecureSession = new(clisshfakes.FakeSecureSession)
    82  
    83  		fakeConnection = new(fake_ssh.FakeConn)
    84  		stdinPipe = new(fake_io.FakeWriteCloser)
    85  		stdoutPipe = new(fake_io.FakeReader)
    86  		stderrPipe = new(fake_io.FakeReader)
    87  
    88  		fakeListenerFactory.ListenStub = net.Listen
    89  		fakeSecureClient.NewSessionReturns(fakeSecureSession, nil)
    90  		fakeSecureClient.ConnReturns(fakeConnection)
    91  		fakeSecureDialer.DialReturns(fakeSecureClient, nil)
    92  
    93  		stdinPipe.WriteStub = func(p []byte) (int, error) {
    94  			return len(p), nil
    95  		}
    96  		fakeSecureSession.StdinPipeReturns(stdinPipe, nil)
    97  
    98  		stdoutPipe.ReadStub = func(p []byte) (int, error) {
    99  			return 0, io.EOF
   100  		}
   101  		fakeSecureSession.StdoutPipeReturns(stdoutPipe, nil)
   102  
   103  		stderrPipe.ReadStub = func(p []byte) (int, error) {
   104  			return 0, io.EOF
   105  		}
   106  		fakeSecureSession.StderrPipeReturns(stderrPipe, nil)
   107  
   108  		username = "some-user"
   109  		passcode = "some-passcode"
   110  		sshEndpoint = "some-endpoint"
   111  		sshEndpointFingerprint = "some-fingerprint"
   112  		skipHostValidation = false
   113  		commands = []string{}
   114  		terminalRequest = RequestTTYAuto
   115  		keepAliveDuration = DefaultKeepAliveInterval
   116  	})
   117  
   118  	JustBeforeEach(func() {
   119  		secureShell = NewSecureShell(
   120  			fakeSecureDialer,
   121  			fakeTerminalHelper,
   122  			fakeListenerFactory,
   123  			keepAliveDuration,
   124  		)
   125  	})
   126  
   127  	Describe("Connect", func() {
   128  		var connectErr error
   129  
   130  		JustBeforeEach(func() {
   131  			connectErr = secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation)
   132  		})
   133  
   134  		When("dialing succeeds", func() {
   135  			It("creates the ssh client", func() {
   136  				Expect(connectErr).ToNot(HaveOccurred())
   137  
   138  				Expect(fakeSecureDialer.DialCallCount()).To(Equal(1))
   139  				protocolArg, sshEndpointArg, sshConfigArg := fakeSecureDialer.DialArgsForCall(0)
   140  				Expect(protocolArg).To(Equal("tcp"))
   141  				Expect(sshEndpointArg).To(Equal(sshEndpoint))
   142  				Expect(sshConfigArg.User).To(Equal(username))
   143  				Expect(sshConfigArg.Auth).To(HaveLen(1))
   144  				Expect(sshConfigArg.HostKeyCallback).ToNot(BeNil())
   145  			})
   146  		})
   147  
   148  		When("dialing fails", func() {
   149  			var dialError error
   150  
   151  			When("the error is a generic Dial error", func() {
   152  				BeforeEach(func() {
   153  					dialError = errors.New("woops")
   154  					fakeSecureDialer.DialReturns(nil, dialError)
   155  				})
   156  
   157  				It("returns the dial error", func() {
   158  					Expect(connectErr).To(Equal(dialError))
   159  					Expect(fakeSecureDialer.DialCallCount()).To(Equal(1))
   160  				})
   161  			})
   162  
   163  			When("the dialing error is a golang 'unable to authenticate' error", func() {
   164  				BeforeEach(func() {
   165  					dialError = fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", []string{"none", "password"})
   166  					fakeSecureDialer.DialReturns(nil, dialError)
   167  				})
   168  
   169  				It("returns an UnableToAuthenticateError", func() {
   170  					Expect(connectErr).To(MatchError(ssherror.UnableToAuthenticateError{Err: dialError}))
   171  					Expect(fakeSecureDialer.DialCallCount()).To(Equal(1))
   172  				})
   173  			})
   174  		})
   175  	})
   176  
   177  	Describe("InteractiveSession", func() {
   178  		var (
   179  			stdin          *fake_io.FakeReadCloser
   180  			stdout, stderr *fake_io.FakeWriter
   181  
   182  			sessionErr                error
   183  			interactiveSessionInvoker func(secureShell *SecureShell)
   184  		)
   185  
   186  		BeforeEach(func() {
   187  			stdin = new(fake_io.FakeReadCloser)
   188  			stdout = new(fake_io.FakeWriter)
   189  			stderr = new(fake_io.FakeWriter)
   190  
   191  			fakeTerminalHelper.StdStreamsReturns(stdin, stdout, stderr)
   192  			interactiveSessionInvoker = func(secureShell *SecureShell) {
   193  				sessionErr = secureShell.InteractiveSession(commands, terminalRequest)
   194  			}
   195  		})
   196  
   197  		JustBeforeEach(func() {
   198  			connectErr := secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation)
   199  			Expect(connectErr).NotTo(HaveOccurred())
   200  			interactiveSessionInvoker(secureShell)
   201  		})
   202  
   203  		When("host key validation is enabled", func() {
   204  			var (
   205  				callback func(hostname string, remote net.Addr, key ssh.PublicKey) error
   206  				addr     net.Addr
   207  			)
   208  
   209  			BeforeEach(func() {
   210  				skipHostValidation = false
   211  			})
   212  
   213  			JustBeforeEach(func() {
   214  				Expect(fakeSecureDialer.DialCallCount()).To(Equal(1))
   215  				_, _, config := fakeSecureDialer.DialArgsForCall(0)
   216  				callback = config.HostKeyCallback
   217  
   218  				listener, err := net.Listen("tcp", "localhost:0")
   219  				Expect(err).NotTo(HaveOccurred())
   220  
   221  				addr = listener.Addr()
   222  				listener.Close()
   223  			})
   224  
   225  			When("the md5 fingerprint matches", func() {
   226  				BeforeEach(func() {
   227  					sshEndpointFingerprint = "41:ce:56:e6:9c:42:a9:c6:9e:68:ac:e3:4d:f6:38:79"
   228  				})
   229  
   230  				It("does not return an error", func() {
   231  					Expect(callback("", addr, TestHostKey.PublicKey())).ToNot(HaveOccurred())
   232  				})
   233  			})
   234  
   235  			When("the hex sha1 fingerprint matches", func() {
   236  				BeforeEach(func() {
   237  					sshEndpointFingerprint = "a8:e2:67:cb:ea:2a:6e:23:a1:72:ce:8f:07:92:15:ee:1f:82:f8:ca"
   238  				})
   239  
   240  				It("does not return an error", func() {
   241  					Expect(callback("", addr, TestHostKey.PublicKey())).ToNot(HaveOccurred())
   242  				})
   243  			})
   244  
   245  			When("the base64 sha256 fingerprint matches", func() {
   246  				BeforeEach(func() {
   247  					sshEndpointFingerprint = "sp/jrLuj66r+yrLDUKZdJU5tdzt4mq/UaSiNBjpgr+8"
   248  				})
   249  
   250  				It("does not return an error", func() {
   251  					Expect(callback("", addr, TestHostKey.PublicKey())).ToNot(HaveOccurred())
   252  				})
   253  			})
   254  
   255  			When("the base64 SHA256 fingerprint does not match", func() {
   256  				BeforeEach(func() {
   257  					sshEndpointFingerprint = "0000000000000000000000000000000000000000000"
   258  				})
   259  
   260  				It("returns an error'", func() {
   261  					err := callback("", addr, TestHostKey.PublicKey())
   262  					Expect(err).To(MatchError(MatchRegexp(`Host key verification failed\.`)))
   263  					Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\"")))
   264  				})
   265  			})
   266  
   267  			When("the hex SHA1 fingerprint does not match", func() {
   268  				BeforeEach(func() {
   269  					sshEndpointFingerprint = "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00"
   270  				})
   271  
   272  				It("returns an error'", func() {
   273  					err := callback("", addr, TestHostKey.PublicKey())
   274  					Expect(err).To(MatchError(MatchRegexp(`Host key verification failed\.`)))
   275  					Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\"")))
   276  				})
   277  			})
   278  
   279  			When("the MD5 fingerprint does not match", func() {
   280  				BeforeEach(func() {
   281  					sshEndpointFingerprint = "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00"
   282  				})
   283  
   284  				It("returns an error'", func() {
   285  					err := callback("", addr, TestHostKey.PublicKey())
   286  					Expect(err).To(MatchError(MatchRegexp(`Host key verification failed\.`)))
   287  					Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\"")))
   288  				})
   289  			})
   290  
   291  			When("no fingerprint is present in endpoint info", func() {
   292  				BeforeEach(func() {
   293  					sshEndpointFingerprint = ""
   294  					sshEndpoint = ""
   295  				})
   296  
   297  				It("returns an error'", func() {
   298  					err := callback("", addr, TestHostKey.PublicKey())
   299  					Expect(err).To(MatchError(MatchRegexp(`Unable to verify identity of host\.`)))
   300  					Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\"")))
   301  				})
   302  			})
   303  
   304  			When("the fingerprint length doesn't make sense", func() {
   305  				BeforeEach(func() {
   306  					sshEndpointFingerprint = "garbage"
   307  				})
   308  
   309  				It("returns an error", func() {
   310  					err := callback("", addr, TestHostKey.PublicKey())
   311  					Eventually(err).Should(MatchError(MatchRegexp("Unsupported host key fingerprint format")))
   312  				})
   313  			})
   314  		})
   315  
   316  		When("the skip host validation flag is set", func() {
   317  			BeforeEach(func() {
   318  				skipHostValidation = true
   319  			})
   320  
   321  			It("the HostKeyCallback on the Config to always return nil", func() {
   322  				Expect(fakeSecureDialer.DialCallCount()).To(Equal(1))
   323  
   324  				_, _, config := fakeSecureDialer.DialArgsForCall(0)
   325  				Expect(config.HostKeyCallback("some-hostname", nil, nil)).To(BeNil())
   326  			})
   327  		})
   328  
   329  		// TODO: see if it's possible to test the piping between the ssh client's input and outputs and the UI object we pass in
   330  		When("dialing is successful", func() {
   331  			It("creates a new secure shell session", func() {
   332  				Expect(fakeSecureClient.NewSessionCallCount()).To(Equal(1))
   333  			})
   334  
   335  			It("closes the session", func() {
   336  				Expect(fakeSecureSession.CloseCallCount()).To(Equal(1))
   337  			})
   338  
   339  			It("gets a stdin pipe for the session", func() {
   340  				Expect(fakeSecureSession.StdinPipeCallCount()).To(Equal(1))
   341  			})
   342  
   343  			When("getting the stdin pipe fails", func() {
   344  				BeforeEach(func() {
   345  					fakeSecureSession.StdinPipeReturns(nil, errors.New("woops"))
   346  				})
   347  
   348  				It("returns the error", func() {
   349  					Expect(sessionErr).Should(MatchError("woops"))
   350  				})
   351  			})
   352  
   353  			It("gets a stdout pipe for the session", func() {
   354  				Expect(fakeSecureSession.StdoutPipeCallCount()).To(Equal(1))
   355  			})
   356  
   357  			When("getting the stdout pipe fails", func() {
   358  				BeforeEach(func() {
   359  					fakeSecureSession.StdoutPipeReturns(nil, errors.New("woops"))
   360  				})
   361  
   362  				It("returns the error", func() {
   363  					Expect(sessionErr).Should(MatchError("woops"))
   364  				})
   365  			})
   366  
   367  			It("gets a stderr pipe for the session", func() {
   368  				Expect(fakeSecureSession.StderrPipeCallCount()).To(Equal(1))
   369  			})
   370  
   371  			When("getting the stderr pipe fails", func() {
   372  				BeforeEach(func() {
   373  					fakeSecureSession.StderrPipeReturns(nil, errors.New("woops"))
   374  				})
   375  
   376  				It("returns the error", func() {
   377  					Expect(sessionErr).Should(MatchError("woops"))
   378  				})
   379  			})
   380  		})
   381  
   382  		When("stdin is a terminal", func() {
   383  			var master *os.File
   384  
   385  			BeforeEach(func() {
   386  				var err error
   387  				master, _, err = pty.Open()
   388  				Expect(err).NotTo(HaveOccurred())
   389  
   390  				terminalRequest = RequestTTYForce
   391  
   392  				terminalHelper := DefaultTerminalHelper()
   393  				fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo
   394  				fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize
   395  			})
   396  
   397  			AfterEach(func() {
   398  				master.Close()
   399  			})
   400  
   401  			When("a command is not specified", func() {
   402  				var terminalType string
   403  
   404  				BeforeEach(func() {
   405  					terminalType = os.Getenv("TERM")
   406  					os.Setenv("TERM", "test-terminal-type")
   407  
   408  					winsize := &term.Winsize{Width: 1024, Height: 256}
   409  					fakeTerminalHelper.GetWinsizeReturns(winsize, nil)
   410  
   411  					fakeSecureSession.ShellStub = func() error {
   412  						Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1))
   413  						Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(0))
   414  						return nil
   415  					}
   416  				})
   417  
   418  				AfterEach(func() {
   419  					os.Setenv("TERM", terminalType)
   420  				})
   421  
   422  				It("requests a pty with the correct terminal type, window size, and modes", func() {
   423  					Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1))
   424  					Expect(fakeTerminalHelper.GetWinsizeCallCount()).To(Equal(1))
   425  
   426  					termType, height, width, modes := fakeSecureSession.RequestPtyArgsForCall(0)
   427  					Expect(termType).To(Equal("test-terminal-type"))
   428  					Expect(height).To(Equal(256))
   429  					Expect(width).To(Equal(1024))
   430  
   431  					expectedModes := ssh.TerminalModes{
   432  						ssh.ECHO:          1,
   433  						ssh.TTY_OP_ISPEED: 115200,
   434  						ssh.TTY_OP_OSPEED: 115200,
   435  					}
   436  					Expect(modes).To(Equal(expectedModes))
   437  				})
   438  
   439  				When("the TERM environment variable is not set", func() {
   440  					BeforeEach(func() {
   441  						os.Unsetenv("TERM")
   442  					})
   443  
   444  					It("requests a pty with the default terminal type", func() {
   445  						Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1))
   446  
   447  						termType, _, _, _ := fakeSecureSession.RequestPtyArgsForCall(0)
   448  						Expect(termType).To(Equal("xterm"))
   449  					})
   450  				})
   451  
   452  				It("puts the terminal into raw mode and restores it after running the shell", func() {
   453  					Expect(fakeSecureSession.ShellCallCount()).To(Equal(1))
   454  					Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1))
   455  					Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(1))
   456  				})
   457  
   458  				When("the pty allocation fails", func() {
   459  					var ptyError error
   460  
   461  					BeforeEach(func() {
   462  						ptyError = errors.New("pty allocation error")
   463  						fakeSecureSession.RequestPtyReturns(ptyError)
   464  					})
   465  
   466  					It("returns the error", func() {
   467  						Expect(sessionErr).To(Equal(ptyError))
   468  					})
   469  				})
   470  
   471  				When("placing the terminal into raw mode fails", func() {
   472  					BeforeEach(func() {
   473  						fakeTerminalHelper.SetRawTerminalReturns(nil, errors.New("woops"))
   474  					})
   475  
   476  					It("keeps calm and carries on", func() {
   477  						Expect(fakeSecureSession.ShellCallCount()).To(Equal(1))
   478  					})
   479  
   480  					It("does not not restore the terminal", func() {
   481  						Expect(fakeSecureSession.ShellCallCount()).To(Equal(1))
   482  						Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1))
   483  						Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(0))
   484  					})
   485  				})
   486  			})
   487  
   488  			When("a command is specified", func() {
   489  				BeforeEach(func() {
   490  					commands = []string{"echo", "-n", "hello"}
   491  				})
   492  
   493  				When("a terminal is requested", func() {
   494  					BeforeEach(func() {
   495  						terminalRequest = RequestTTYYes
   496  						fakeTerminalHelper.GetFdInfoReturns(0, true)
   497  					})
   498  
   499  					It("requests a pty", func() {
   500  						Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1))
   501  					})
   502  				})
   503  
   504  				When("a terminal is not explicitly requested", func() {
   505  					BeforeEach(func() {
   506  						terminalRequest = RequestTTYAuto
   507  					})
   508  
   509  					It("does not request a pty", func() {
   510  						Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0))
   511  					})
   512  				})
   513  			})
   514  		})
   515  
   516  		When("stdin is not a terminal", func() {
   517  			BeforeEach(func() {
   518  				stdin.ReadStub = func(p []byte) (int, error) {
   519  					return 0, io.EOF
   520  				}
   521  
   522  				terminalHelper := DefaultTerminalHelper()
   523  				fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo
   524  				fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize
   525  			})
   526  
   527  			When("a terminal is not requested", func() {
   528  				It("does not request a pty", func() {
   529  					Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0))
   530  				})
   531  			})
   532  
   533  			When("a terminal is requested", func() {
   534  				BeforeEach(func() {
   535  					terminalRequest = RequestTTYYes
   536  				})
   537  
   538  				It("does not request a pty", func() {
   539  					Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0))
   540  				})
   541  			})
   542  		})
   543  
   544  		PWhen("a terminal is forced", func() {
   545  			BeforeEach(func() {
   546  				terminalRequest = RequestTTYForce
   547  			})
   548  
   549  			It("requests a pty", func() {
   550  				Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1))
   551  			})
   552  		})
   553  
   554  		When("a terminal is disabled", func() {
   555  			BeforeEach(func() {
   556  				terminalRequest = RequestTTYNo
   557  			})
   558  
   559  			It("does not request a pty", func() {
   560  				Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0))
   561  			})
   562  		})
   563  
   564  		When("a command is not specified", func() {
   565  			It("requests an interactive shell", func() {
   566  				Expect(fakeSecureSession.ShellCallCount()).To(Equal(1))
   567  			})
   568  
   569  			When("the shell request returns an error", func() {
   570  				BeforeEach(func() {
   571  					fakeSecureSession.ShellReturns(errors.New("oh bother"))
   572  				})
   573  
   574  				It("returns the error", func() {
   575  					Expect(sessionErr).To(MatchError("oh bother"))
   576  				})
   577  			})
   578  		})
   579  
   580  		When("a command is specifed", func() {
   581  			BeforeEach(func() {
   582  				commands = []string{"echo", "-n", "hello"}
   583  			})
   584  
   585  			It("starts the command", func() {
   586  				Expect(fakeSecureSession.StartCallCount()).To(Equal(1))
   587  				Expect(fakeSecureSession.StartArgsForCall(0)).To(Equal("echo -n hello"))
   588  			})
   589  
   590  			When("the command fails to start", func() {
   591  				BeforeEach(func() {
   592  					fakeSecureSession.StartReturns(errors.New("oh well"))
   593  				})
   594  
   595  				It("returns the error", func() {
   596  					Expect(sessionErr).To(MatchError("oh well"))
   597  				})
   598  			})
   599  		})
   600  
   601  		When("the shell or command has started", func() {
   602  			BeforeEach(func() {
   603  				stdin.ReadStub = func(p []byte) (int, error) {
   604  					p[0] = 0
   605  					return 1, io.EOF
   606  				}
   607  				stdinPipe.WriteStub = func(p []byte) (int, error) {
   608  					defer GinkgoRecover()
   609  					Expect(p[0]).To(Equal(byte(0)))
   610  					return 1, nil
   611  				}
   612  
   613  				stdoutPipe.ReadStub = func(p []byte) (int, error) {
   614  					p[0] = 1
   615  					return 1, io.EOF
   616  				}
   617  				stdout.WriteStub = func(p []byte) (int, error) {
   618  					defer GinkgoRecover()
   619  					Expect(p[0]).To(Equal(byte(1)))
   620  					return 1, nil
   621  				}
   622  
   623  				stderrPipe.ReadStub = func(p []byte) (int, error) {
   624  					p[0] = 2
   625  					return 1, io.EOF
   626  				}
   627  				stderr.WriteStub = func(p []byte) (int, error) {
   628  					defer GinkgoRecover()
   629  					Expect(p[0]).To(Equal(byte(2)))
   630  					return 1, nil
   631  				}
   632  
   633  				fakeSecureSession.StdinPipeReturns(stdinPipe, nil)
   634  				fakeSecureSession.StdoutPipeReturns(stdoutPipe, nil)
   635  				fakeSecureSession.StderrPipeReturns(stderrPipe, nil)
   636  
   637  				fakeSecureSession.WaitReturns(errors.New("error result"))
   638  			})
   639  
   640  			It("copies data from the stdin stream to the session stdin pipe", func() {
   641  				Eventually(stdin.ReadCallCount).Should(Equal(1))
   642  				Eventually(stdinPipe.WriteCallCount).Should(Equal(1))
   643  			})
   644  
   645  			It("copies data from the session stdout pipe to the stdout stream", func() {
   646  				Eventually(stdoutPipe.ReadCallCount).Should(Equal(1))
   647  				Eventually(stdout.WriteCallCount).Should(Equal(1))
   648  			})
   649  
   650  			It("copies data from the session stderr pipe to the stderr stream", func() {
   651  				Eventually(stderrPipe.ReadCallCount).Should(Equal(1))
   652  				Eventually(stderr.WriteCallCount).Should(Equal(1))
   653  			})
   654  
   655  			It("waits for the session to end", func() {
   656  				Expect(fakeSecureSession.WaitCallCount()).To(Equal(1))
   657  			})
   658  
   659  			It("returns the result from wait", func() {
   660  				Expect(sessionErr).To(MatchError("error result"))
   661  			})
   662  
   663  			When("the session terminates before stream copies complete", func() {
   664  				var sessionErrorCh chan error
   665  
   666  				BeforeEach(func() {
   667  					sessionErrorCh = make(chan error, 1)
   668  
   669  					interactiveSessionInvoker = func(secureShell *SecureShell) {
   670  						go func() {
   671  							sessionErrorCh <- secureShell.InteractiveSession(commands, terminalRequest)
   672  						}()
   673  					}
   674  
   675  					stdoutPipe.ReadStub = func(p []byte) (int, error) {
   676  						defer GinkgoRecover()
   677  						Eventually(fakeSecureSession.WaitCallCount).Should(Equal(1))
   678  						Consistently(sessionErrorCh).ShouldNot(Receive())
   679  
   680  						p[0] = 1
   681  						return 1, io.EOF
   682  					}
   683  
   684  					stderrPipe.ReadStub = func(p []byte) (int, error) {
   685  						defer GinkgoRecover()
   686  						Eventually(fakeSecureSession.WaitCallCount).Should(Equal(1))
   687  						Consistently(sessionErrorCh).ShouldNot(Receive())
   688  
   689  						p[0] = 2
   690  						return 1, io.EOF
   691  					}
   692  				})
   693  
   694  				It("waits for the copies to complete", func() {
   695  					Eventually(sessionErrorCh).Should(Receive())
   696  					Expect(stdoutPipe.ReadCallCount()).To(Equal(1))
   697  					Expect(stderrPipe.ReadCallCount()).To(Equal(1))
   698  				})
   699  			})
   700  
   701  			When("stdin is closed", func() {
   702  				BeforeEach(func() {
   703  					stdin.ReadStub = func(p []byte) (int, error) {
   704  						defer GinkgoRecover()
   705  						Consistently(stdinPipe.CloseCallCount).Should(Equal(0))
   706  						p[0] = 0
   707  						return 1, io.EOF
   708  					}
   709  				})
   710  
   711  				It("closes the stdinPipe", func() {
   712  					Eventually(stdinPipe.CloseCallCount).Should(Equal(1))
   713  				})
   714  			})
   715  		})
   716  
   717  		When("stdout is a terminal and a window size change occurs", func() {
   718  			var master, slave *os.File
   719  
   720  			BeforeEach(func() {
   721  				var err error
   722  				master, slave, err = pty.Open()
   723  				Expect(err).NotTo(HaveOccurred())
   724  
   725  				terminalHelper := DefaultTerminalHelper()
   726  				fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo
   727  				fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize
   728  				fakeTerminalHelper.StdStreamsReturns(stdin, slave, stderr)
   729  
   730  				winsize := &term.Winsize{Height: 100, Width: 100}
   731  				err = term.SetWinsize(slave.Fd(), winsize)
   732  				Expect(err).NotTo(HaveOccurred())
   733  
   734  				fakeSecureSession.WaitStub = func() error {
   735  					fakeSecureSession.SendRequestCallCount()
   736  					Expect(fakeSecureSession.SendRequestCallCount()).To(Equal(0))
   737  
   738  					// No dimension change
   739  					for i := 0; i < 3; i++ {
   740  						winsize := &term.Winsize{Height: 100, Width: 100}
   741  						err = term.SetWinsize(slave.Fd(), winsize)
   742  						Expect(err).NotTo(HaveOccurred())
   743  					}
   744  
   745  					winsize := &term.Winsize{Height: 100, Width: 200}
   746  					err = term.SetWinsize(slave.Fd(), winsize)
   747  					Expect(err).NotTo(HaveOccurred())
   748  
   749  					err = syscall.Kill(syscall.Getpid(), syscall.SIGWINCH)
   750  					Expect(err).NotTo(HaveOccurred())
   751  
   752  					Eventually(fakeSecureSession.SendRequestCallCount).Should(Equal(1))
   753  					return nil
   754  				}
   755  			})
   756  
   757  			AfterEach(func() {
   758  				master.Close()
   759  				slave.Close()
   760  			})
   761  
   762  			It("sends window change events when the window dimensions change", func() {
   763  				Expect(fakeSecureSession.SendRequestCallCount()).To(Equal(1))
   764  
   765  				requestType, wantReply, message := fakeSecureSession.SendRequestArgsForCall(0)
   766  				Expect(requestType).To(Equal("window-change"))
   767  				Expect(wantReply).To(BeFalse())
   768  
   769  				type resizeMessage struct {
   770  					Width       uint32
   771  					Height      uint32
   772  					PixelWidth  uint32
   773  					PixelHeight uint32
   774  				}
   775  				var resizeMsg resizeMessage
   776  
   777  				err := ssh.Unmarshal(message, &resizeMsg)
   778  				Expect(err).NotTo(HaveOccurred())
   779  
   780  				Expect(resizeMsg).To(Equal(resizeMessage{Height: 100, Width: 200}))
   781  			})
   782  		})
   783  
   784  		Describe("keep alive messages", func() {
   785  			var times []time.Time
   786  			var timesCh chan []time.Time
   787  			var done chan struct{}
   788  
   789  			BeforeEach(func() {
   790  				keepAliveDuration = 100 * time.Millisecond
   791  
   792  				times = []time.Time{}
   793  				timesCh = make(chan []time.Time, 1)
   794  				done = make(chan struct{}, 1)
   795  
   796  				fakeConnection.SendRequestStub = func(reqName string, wantReply bool, message []byte) (bool, []byte, error) {
   797  					Expect(reqName).To(Equal("keepalive@cloudfoundry.org"))
   798  					Expect(wantReply).To(BeTrue())
   799  					Expect(message).To(BeNil())
   800  
   801  					times = append(times, time.Now())
   802  					if len(times) == 3 {
   803  						timesCh <- times
   804  						close(done)
   805  					}
   806  					return true, nil, nil
   807  				}
   808  
   809  				fakeSecureSession.WaitStub = func() error {
   810  					Eventually(done).Should(BeClosed())
   811  					return nil
   812  				}
   813  			})
   814  
   815  			PIt("sends keep alive messages at the expected interval", func() {
   816  				times := <-timesCh
   817  				Expect(times[2]).To(BeTemporally("~", times[0].Add(200*time.Millisecond), 160*time.Millisecond))
   818  			})
   819  		})
   820  	})
   821  
   822  	Describe("LocalPortForward", func() {
   823  		var (
   824  			forwardErr error
   825  
   826  			echoAddress  string
   827  			echoListener *fake_net.FakeListener
   828  			echoHandler  *fake_server.FakeConnectionHandler
   829  			echoServer   *server.Server
   830  
   831  			localAddress string
   832  
   833  			realLocalListener net.Listener
   834  			fakeLocalListener *fake_net.FakeListener
   835  
   836  			forwardSpecs []LocalPortForward
   837  		)
   838  
   839  		BeforeEach(func() {
   840  			logger := lagertest.NewTestLogger("test")
   841  
   842  			var err error
   843  			realLocalListener, err = net.Listen("tcp", "127.0.0.1:0")
   844  			Expect(err).NotTo(HaveOccurred())
   845  
   846  			localAddress = realLocalListener.Addr().String()
   847  			fakeListenerFactory.ListenReturns(realLocalListener, nil)
   848  
   849  			echoHandler = new(fake_server.FakeConnectionHandler)
   850  			echoHandler.HandleConnectionStub = func(conn net.Conn) {
   851  				io.Copy(conn, conn) //nolint:errcheck
   852  				conn.Close()
   853  			}
   854  
   855  			realListener, err := net.Listen("tcp", "127.0.0.1:0")
   856  			Expect(err).NotTo(HaveOccurred())
   857  			echoAddress = realListener.Addr().String()
   858  
   859  			echoListener = new(fake_net.FakeListener)
   860  			echoListener.AcceptStub = realListener.Accept
   861  			echoListener.CloseStub = realListener.Close
   862  			echoListener.AddrStub = realListener.Addr
   863  
   864  			fakeLocalListener = new(fake_net.FakeListener)
   865  			fakeLocalListener.AcceptReturns(nil, errors.New("Not Accepting Connections"))
   866  
   867  			echoServer = server.NewServer(logger.Session("echo"), "", echoHandler)
   868  			err = echoServer.SetListener(echoListener)
   869  			Expect(err).NotTo(HaveOccurred())
   870  			go echoServer.Serve()
   871  
   872  			forwardSpecs = []LocalPortForward{{
   873  				RemoteAddress: echoAddress,
   874  				LocalAddress:  localAddress,
   875  			}}
   876  
   877  			fakeSecureClient.DialStub = net.Dial
   878  		})
   879  
   880  		JustBeforeEach(func() {
   881  			connectErr := secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation)
   882  			Expect(connectErr).NotTo(HaveOccurred())
   883  
   884  			forwardErr = secureShell.LocalPortForward(forwardSpecs)
   885  		})
   886  
   887  		AfterEach(func() {
   888  			err := secureShell.Close()
   889  			Expect(err).NotTo(HaveOccurred())
   890  			echoServer.Shutdown()
   891  
   892  			realLocalListener.Close()
   893  		})
   894  
   895  		validateConnectivity := func(addr string) {
   896  			conn, err := net.Dial("tcp", addr)
   897  			Expect(err).NotTo(HaveOccurred())
   898  
   899  			msg := fmt.Sprintf("Hello from %s\n", addr)
   900  			n, err := conn.Write([]byte(msg))
   901  			Expect(err).NotTo(HaveOccurred())
   902  			Expect(n).To(Equal(len(msg)))
   903  
   904  			response := make([]byte, len(msg))
   905  			n, err = conn.Read(response)
   906  			Expect(err).NotTo(HaveOccurred())
   907  			Expect(n).To(Equal(len(msg)))
   908  
   909  			err = conn.Close()
   910  			Expect(err).NotTo(HaveOccurred())
   911  
   912  			Expect(response).To(Equal([]byte(msg)))
   913  		}
   914  
   915  		It("dials the connect address when a local connection is made", func() {
   916  			Expect(forwardErr).NotTo(HaveOccurred())
   917  
   918  			conn, err := net.Dial("tcp", localAddress)
   919  			Expect(err).NotTo(HaveOccurred())
   920  
   921  			Eventually(echoListener.AcceptCallCount).Should(BeNumerically(">=", 1))
   922  			Eventually(fakeSecureClient.DialCallCount).Should(Equal(1))
   923  
   924  			network, addr := fakeSecureClient.DialArgsForCall(0)
   925  			Expect(network).To(Equal("tcp"))
   926  			Expect(addr).To(Equal(echoAddress))
   927  
   928  			Expect(conn.Close()).NotTo(HaveOccurred())
   929  		})
   930  
		// End-to-end check: a message written to the forwarded local address
		// must come back unchanged.
		It("copies data between the local and remote connections", func() {
			validateConnectivity(localAddress)
		})
   934  
   935  		When("a local connection is already open", func() {
   936  			var conn net.Conn
   937  
   938  			JustBeforeEach(func() {
   939  				var err error
   940  				conn, err = net.Dial("tcp", localAddress)
   941  				Expect(err).NotTo(HaveOccurred())
   942  			})
   943  
   944  			AfterEach(func() {
   945  				err := conn.Close()
   946  				Expect(err).NotTo(HaveOccurred())
   947  			})
   948  
   949  			It("allows for new incoming connections as well", func() {
   950  				validateConnectivity(localAddress)
   951  			})
   952  		})
   953  
   954  		When("there are multiple port forward specs", func() {
   955  			When("provided a real listener", func() {
   956  				var (
   957  					realLocalListener2 net.Listener
   958  					localAddress2      string
   959  				)
   960  
   961  				BeforeEach(func() {
   962  					var err error
   963  					realLocalListener2, err = net.Listen("tcp", "127.0.0.1:0")
   964  					Expect(err).NotTo(HaveOccurred())
   965  
   966  					localAddress2 = realLocalListener2.Addr().String()
   967  
   968  					fakeListenerFactory.ListenStub = func(network, addr string) (net.Listener, error) {
   969  						if addr == localAddress {
   970  							return realLocalListener, nil
   971  						}
   972  
   973  						if addr == localAddress2 {
   974  							return realLocalListener2, nil
   975  						}
   976  
   977  						return nil, errors.New("unexpected address")
   978  					}
   979  
   980  					forwardSpecs = []LocalPortForward{
   981  						{
   982  							RemoteAddress: echoAddress,
   983  							LocalAddress:  localAddress,
   984  						},
   985  						{
   986  							RemoteAddress: echoAddress,
   987  							LocalAddress:  localAddress2,
   988  						},
   989  					}
   990  				})
   991  
   992  				AfterEach(func() {
   993  					realLocalListener2.Close()
   994  				})
   995  
   996  				It("listens to all the things", func() {
   997  					Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(2))
   998  
   999  					network, addr := fakeListenerFactory.ListenArgsForCall(0)
  1000  					Expect(network).To(Equal("tcp"))
  1001  					Expect(addr).To(Equal(localAddress))
  1002  
  1003  					network, addr = fakeListenerFactory.ListenArgsForCall(1)
  1004  					Expect(network).To(Equal("tcp"))
  1005  					Expect(addr).To(Equal(localAddress2))
  1006  				})
  1007  
  1008  				It("forwards to the correct target", func() {
  1009  					validateConnectivity(localAddress)
  1010  					validateConnectivity(localAddress2)
  1011  				})
  1012  			})
  1013  
  1014  			When("the secure client is closed", func() {
  1015  				var (
  1016  					fakeLocalListener1 *fake_net.FakeListener
  1017  					fakeLocalListener2 *fake_net.FakeListener
  1018  				)
  1019  
  1020  				BeforeEach(func() {
  1021  					forwardSpecs = []LocalPortForward{
  1022  						{
  1023  							RemoteAddress: echoAddress,
  1024  							LocalAddress:  localAddress,
  1025  						},
  1026  						{
  1027  							RemoteAddress: echoAddress,
  1028  							LocalAddress:  localAddress,
  1029  						},
  1030  					}
  1031  
  1032  					fakeLocalListener1 = new(fake_net.FakeListener)
  1033  					BlockAcceptOnClose(fakeLocalListener1)
  1034  					fakeListenerFactory.ListenReturnsOnCall(0, fakeLocalListener1, nil)
  1035  
  1036  					fakeLocalListener2 = new(fake_net.FakeListener)
  1037  					BlockAcceptOnClose(fakeLocalListener2)
  1038  					fakeListenerFactory.ListenReturnsOnCall(1, fakeLocalListener2, nil)
  1039  				})
  1040  
  1041  				It("closes the listeners", func() {
  1042  					Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(2))
  1043  					Eventually(fakeLocalListener1.AcceptCallCount).Should(Equal(1))
  1044  					Eventually(fakeLocalListener2.AcceptCallCount).Should(Equal(1))
  1045  
  1046  					err := secureShell.Close()
  1047  					Expect(err).NotTo(HaveOccurred())
  1048  					Eventually(fakeLocalListener1.CloseCallCount).Should(Equal(2))
  1049  					Eventually(fakeLocalListener2.CloseCallCount).Should(Equal(2))
  1050  				})
  1051  			})
  1052  		})
  1053  
  1054  		When("listen fails", func() {
  1055  			BeforeEach(func() {
  1056  				fakeListenerFactory.ListenReturns(nil, errors.New("failure is an option"))
  1057  			})
  1058  
  1059  			It("returns the error", func() {
  1060  				Expect(forwardErr).To(MatchError("failure is an option"))
  1061  			})
  1062  		})
  1063  
  1064  		When("the client it closed", func() {
  1065  			BeforeEach(func() {
  1066  				fakeLocalListener = new(fake_net.FakeListener)
  1067  				BlockAcceptOnClose(fakeLocalListener)
  1068  
  1069  				fakeListenerFactory.ListenReturns(fakeLocalListener, nil)
  1070  			})
  1071  
  1072  			It("closes the listener when the client is closed", func() {
  1073  				Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(1))
  1074  				Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(1))
  1075  
  1076  				err := secureShell.Close()
  1077  				Expect(err).NotTo(HaveOccurred())
  1078  				Eventually(fakeLocalListener.CloseCallCount).Should(Equal(2))
  1079  			})
  1080  		})
  1081  
  1082  		When("accept fails", func() {
  1083  			var fakeConn *fake_net.FakeConn
  1084  
  1085  			BeforeEach(func() {
  1086  				fakeConn = &fake_net.FakeConn{}
  1087  				fakeConn.ReadReturns(0, io.EOF)
  1088  
  1089  				fakeListenerFactory.ListenReturns(fakeLocalListener, nil)
  1090  			})
  1091  
  1092  			Context("with a permanent error", func() {
  1093  				BeforeEach(func() {
  1094  					fakeLocalListener.AcceptReturns(nil, errors.New("boom"))
  1095  				})
  1096  
  1097  				It("stops trying to accept connections", func() {
  1098  					Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(1))
  1099  					Consistently(fakeLocalListener.AcceptCallCount).Should(Equal(1))
  1100  					Expect(fakeLocalListener.CloseCallCount()).To(Equal(1))
  1101  				})
  1102  			})
  1103  
  1104  			Context("with a temporary error", func() {
  1105  				var timeCh chan time.Time
  1106  
  1107  				BeforeEach(func() {
  1108  					timeCh = make(chan time.Time, 3)
  1109  
  1110  					fakeLocalListener.AcceptStub = func() (net.Conn, error) {
  1111  						timeCh := timeCh
  1112  						if fakeLocalListener.AcceptCallCount() > 3 {
  1113  							close(timeCh)
  1114  							return nil, test_helpers.NewTestNetError(false, false)
  1115  						} else {
  1116  							timeCh <- time.Now()
  1117  							return nil, test_helpers.NewTestNetError(false, true)
  1118  						}
  1119  					}
  1120  				})
  1121  
  1122  				PIt("retries connecting after a short delay", func() {
  1123  					Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(3))
  1124  					Expect(timeCh).To(HaveLen(3))
  1125  
  1126  					times := make([]time.Time, 0)
  1127  					for t := range timeCh {
  1128  						times = append(times, t)
  1129  					}
  1130  
  1131  					Expect(times[1]).To(BeTemporally("~", times[0].Add(115*time.Millisecond), 80*time.Millisecond))
  1132  					Expect(times[2]).To(BeTemporally("~", times[1].Add(115*time.Millisecond), 100*time.Millisecond))
  1133  				})
  1134  			})
  1135  		})
  1136  
  1137  		When("dialing the connect address fails", func() {
  1138  			var fakeTarget *fake_net.FakeConn
  1139  
  1140  			BeforeEach(func() {
  1141  				fakeTarget = &fake_net.FakeConn{}
  1142  				fakeSecureClient.DialReturns(fakeTarget, errors.New("boom"))
  1143  			})
  1144  
  1145  			It("does not call close on the target connection", func() {
  1146  				Consistently(fakeTarget.CloseCallCount).Should(Equal(0))
  1147  			})
  1148  		})
  1149  	})
  1150  
  1151  	Describe("Wait", func() {
  1152  		var waitErr error
  1153  
  1154  		JustBeforeEach(func() {
  1155  			connectErr := secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation)
  1156  			Expect(connectErr).NotTo(HaveOccurred())
  1157  
  1158  			waitErr = secureShell.Wait()
  1159  		})
  1160  
  1161  		It("calls wait on the secureClient", func() {
  1162  			Expect(waitErr).NotTo(HaveOccurred())
  1163  			Expect(fakeSecureClient.WaitCallCount()).To(Equal(1))
  1164  		})
  1165  
  1166  		Describe("keep alive messages", func() {
  1167  			var times []time.Time
  1168  			var timesCh chan []time.Time
  1169  			var done chan struct{}
  1170  
  1171  			BeforeEach(func() {
  1172  				keepAliveDuration = 100 * time.Millisecond
  1173  
  1174  				times = []time.Time{}
  1175  				timesCh = make(chan []time.Time, 1)
  1176  				done = make(chan struct{}, 1)
  1177  
  1178  				fakeConnection.SendRequestStub = func(reqName string, wantReply bool, message []byte) (bool, []byte, error) {
  1179  					Expect(reqName).To(Equal("keepalive@cloudfoundry.org"))
  1180  					Expect(wantReply).To(BeTrue())
  1181  					Expect(message).To(BeNil())
  1182  
  1183  					times = append(times, time.Now())
  1184  					if len(times) == 3 {
  1185  						timesCh <- times
  1186  						close(done)
  1187  					}
  1188  					return true, nil, nil
  1189  				}
  1190  
  1191  				fakeSecureClient.WaitStub = func() error {
  1192  					Eventually(done).Should(BeClosed())
  1193  					return nil
  1194  				}
  1195  			})
  1196  
  1197  			PIt("sends keep alive messages at the expected interval", func() {
  1198  				Expect(waitErr).NotTo(HaveOccurred())
  1199  				times := <-timesCh
  1200  				Expect(times[2]).To(BeTemporally("~", times[0].Add(200*time.Millisecond), 100*time.Millisecond))
  1201  			})
  1202  		})
  1203  	})
  1204  
  1205  	Describe("Close", func() {
  1206  		JustBeforeEach(func() {
  1207  			connectErr := secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation)
  1208  			Expect(connectErr).NotTo(HaveOccurred())
  1209  		})
  1210  
  1211  		It("calls close on the secureClient", func() {
  1212  			err := secureShell.Close()
  1213  			Expect(err).NotTo(HaveOccurred())
  1214  
  1215  			Expect(fakeSecureClient.CloseCallCount()).To(Equal(1))
  1216  		})
  1217  	})
  1218  })