github.com/dcarley/cf-cli@v6.24.1-0.20170220111324-4225ff346898+incompatible/cf/ssh/ssh_test.go (about)

     1  // +build !windows,!386
     2  
// Skipping 386 because lager uses UInt64 in Session().
// Skipping Windows because the tests use Unix/Linux-only syscalls.
// We should refactor out these conflicts so this package can be tested on multiple platforms.
     6  
     7  package sshCmd_test
     8  
     9  import (
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"net"
    14  	"os"
    15  	"syscall"
    16  	"time"
    17  
    18  	"code.cloudfoundry.org/cli/cf/models"
    19  	"code.cloudfoundry.org/cli/cf/ssh"
    20  	"code.cloudfoundry.org/cli/cf/ssh/options"
    21  	"code.cloudfoundry.org/cli/cf/ssh/sshfakes"
    22  	"code.cloudfoundry.org/cli/cf/ssh/terminal"
    23  	"code.cloudfoundry.org/cli/cf/ssh/terminal/terminalfakes"
    24  	"code.cloudfoundry.org/diego-ssh/server"
    25  	fake_server "code.cloudfoundry.org/diego-ssh/server/fakes"
    26  	"code.cloudfoundry.org/diego-ssh/test_helpers"
    27  	"code.cloudfoundry.org/diego-ssh/test_helpers/fake_io"
    28  	"code.cloudfoundry.org/diego-ssh/test_helpers/fake_net"
    29  	"code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh"
    30  	"code.cloudfoundry.org/lager/lagertest"
    31  	"github.com/docker/docker/pkg/term"
    32  	"github.com/kr/pty"
    33  	"golang.org/x/crypto/ssh"
    34  
    35  	. "github.com/onsi/ginkgo"
    36  	. "github.com/onsi/gomega"
    37  )
    38  
    39  var _ = Describe("SSH", func() {
    40  	var (
    41  		fakeTerminalHelper  *terminalfakes.FakeTerminalHelper
    42  		fakeListenerFactory *sshfakes.FakeListenerFactory
    43  
    44  		fakeConnection    *fake_ssh.FakeConn
    45  		fakeSecureClient  *sshfakes.FakeSecureClient
    46  		fakeSecureDialer  *sshfakes.FakeSecureDialer
    47  		fakeSecureSession *sshfakes.FakeSecureSession
    48  
    49  		terminalHelper    terminal.TerminalHelper
    50  		keepAliveDuration time.Duration
    51  		secureShell       sshCmd.SecureShell
    52  
    53  		stdinPipe *fake_io.FakeWriteCloser
    54  
    55  		currentApp             models.Application
    56  		sshEndpointFingerprint string
    57  		sshEndpoint            string
    58  		token                  string
    59  	)
    60  
    61  	BeforeEach(func() {
    62  		fakeTerminalHelper = new(terminalfakes.FakeTerminalHelper)
    63  		terminalHelper = terminal.DefaultHelper()
    64  
    65  		fakeListenerFactory = new(sshfakes.FakeListenerFactory)
    66  		fakeListenerFactory.ListenStub = net.Listen
    67  
    68  		keepAliveDuration = 30 * time.Second
    69  
    70  		currentApp = models.Application{}
    71  		sshEndpoint = ""
    72  		sshEndpointFingerprint = ""
    73  		token = ""
    74  
    75  		fakeConnection = new(fake_ssh.FakeConn)
    76  		fakeSecureClient = new(sshfakes.FakeSecureClient)
    77  		fakeSecureDialer = new(sshfakes.FakeSecureDialer)
    78  		fakeSecureSession = new(sshfakes.FakeSecureSession)
    79  
    80  		fakeSecureDialer.DialReturns(fakeSecureClient, nil)
    81  		fakeSecureClient.NewSessionReturns(fakeSecureSession, nil)
    82  		fakeSecureClient.ConnReturns(fakeConnection)
    83  
    84  		stdinPipe = &fake_io.FakeWriteCloser{}
    85  		stdinPipe.WriteStub = func(p []byte) (int, error) {
    86  			return len(p), nil
    87  		}
    88  
    89  		stdoutPipe := &fake_io.FakeReader{}
    90  		stdoutPipe.ReadStub = func(p []byte) (int, error) {
    91  			return 0, io.EOF
    92  		}
    93  
    94  		stderrPipe := &fake_io.FakeReader{}
    95  		stderrPipe.ReadStub = func(p []byte) (int, error) {
    96  			return 0, io.EOF
    97  		}
    98  
    99  		fakeSecureSession.StdinPipeReturns(stdinPipe, nil)
   100  		fakeSecureSession.StdoutPipeReturns(stdoutPipe, nil)
   101  		fakeSecureSession.StderrPipeReturns(stderrPipe, nil)
   102  	})
   103  
   104  	JustBeforeEach(func() {
   105  		secureShell = sshCmd.NewSecureShell(
   106  			fakeSecureDialer,
   107  			terminalHelper,
   108  			fakeListenerFactory,
   109  			keepAliveDuration,
   110  			currentApp,
   111  			sshEndpointFingerprint,
   112  			sshEndpoint,
   113  			token,
   114  		)
   115  	})
   116  
   117  	Describe("Validation", func() {
   118  		var connectErr error
   119  		var opts *options.SSHOptions
   120  
   121  		BeforeEach(func() {
   122  			opts = &options.SSHOptions{
   123  				AppName: "app-1",
   124  			}
   125  		})
   126  
   127  		JustBeforeEach(func() {
   128  			connectErr = secureShell.Connect(opts)
   129  		})
   130  
   131  		Context("when the app model and endpoint info are successfully acquired", func() {
   132  			BeforeEach(func() {
   133  				token = ""
   134  				currentApp.State = "STARTED"
   135  				currentApp.Diego = true
   136  			})
   137  
   138  			Context("when the app is not in the 'STARTED' state", func() {
   139  				BeforeEach(func() {
   140  					currentApp.State = "STOPPED"
   141  					currentApp.Diego = true
   142  				})
   143  
   144  				It("returns an error", func() {
   145  					Expect(connectErr).To(MatchError(MatchRegexp("Application.*not in the STARTED state")))
   146  				})
   147  			})
   148  
   149  			Context("when the app is not a Diego app", func() {
   150  				BeforeEach(func() {
   151  					currentApp.State = "STARTED"
   152  					currentApp.Diego = false
   153  				})
   154  
   155  				It("returns an error", func() {
   156  					Expect(connectErr).To(MatchError(MatchRegexp("Application.*not running on Diego")))
   157  				})
   158  			})
   159  
   160  			Context("when dialing fails", func() {
   161  				var dialError = errors.New("woops")
   162  
   163  				BeforeEach(func() {
   164  					fakeSecureDialer.DialReturns(nil, dialError)
   165  				})
   166  
   167  				It("returns the dial error", func() {
   168  					Expect(connectErr).To(Equal(dialError))
   169  					Expect(fakeSecureDialer.DialCallCount()).To(Equal(1))
   170  				})
   171  			})
   172  		})
   173  	})
   174  
   175  	Describe("InteractiveSession", func() {
   176  		var opts *options.SSHOptions
   177  		var sessionError error
   178  		var interactiveSessionInvoker func(secureShell sshCmd.SecureShell)
   179  
   180  		BeforeEach(func() {
   181  			sshEndpoint = "ssh.example.com:22"
   182  
   183  			opts = &options.SSHOptions{
   184  				AppName: "app-name",
   185  				Index:   2,
   186  			}
   187  
   188  			currentApp.State = "STARTED"
   189  			currentApp.Diego = true
   190  			currentApp.GUID = "app-guid"
   191  			token = "bearer token"
   192  
   193  			interactiveSessionInvoker = func(secureShell sshCmd.SecureShell) {
   194  				sessionError = secureShell.InteractiveSession()
   195  			}
   196  		})
   197  
   198  		JustBeforeEach(func() {
   199  			connectErr := secureShell.Connect(opts)
   200  			Expect(connectErr).NotTo(HaveOccurred())
   201  			interactiveSessionInvoker(secureShell)
   202  		})
   203  
   204  		It("dials the correct endpoint as the correct user", func() {
   205  			Expect(fakeSecureDialer.DialCallCount()).To(Equal(1))
   206  
   207  			network, address, config := fakeSecureDialer.DialArgsForCall(0)
   208  			Expect(network).To(Equal("tcp"))
   209  			Expect(address).To(Equal("ssh.example.com:22"))
   210  			Expect(config.Auth).NotTo(BeEmpty())
   211  			Expect(config.User).To(Equal("cf:app-guid/2"))
   212  			Expect(config.HostKeyCallback).NotTo(BeNil())
   213  		})
   214  
   215  		Context("when host key validation is enabled", func() {
   216  			var callback func(hostname string, remote net.Addr, key ssh.PublicKey) error
   217  			var addr net.Addr
   218  
   219  			JustBeforeEach(func() {
   220  				Expect(fakeSecureDialer.DialCallCount()).To(Equal(1))
   221  				_, _, config := fakeSecureDialer.DialArgsForCall(0)
   222  				callback = config.HostKeyCallback
   223  
   224  				listener, err := net.Listen("tcp", "localhost:0")
   225  				Expect(err).NotTo(HaveOccurred())
   226  
   227  				addr = listener.Addr()
   228  				listener.Close()
   229  			})
   230  
   231  			Context("when the SHA1 fingerprint does not match", func() {
   232  				BeforeEach(func() {
   233  					sshEndpointFingerprint = "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00"
   234  				})
   235  
   236  				It("returns an error'", func() {
   237  					err := callback("", addr, TestHostKey.PublicKey())
   238  					Expect(err).To(MatchError(MatchRegexp("Host key verification failed\\.")))
   239  					Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\"")))
   240  				})
   241  			})
   242  
   243  			Context("when the MD5 fingerprint does not match", func() {
   244  				BeforeEach(func() {
   245  					sshEndpointFingerprint = "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00"
   246  				})
   247  
   248  				It("returns an error'", func() {
   249  					err := callback("", addr, TestHostKey.PublicKey())
   250  					Expect(err).To(MatchError(MatchRegexp("Host key verification failed\\.")))
   251  					Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\"")))
   252  				})
   253  			})
   254  
   255  			Context("when no fingerprint is present in endpoint info", func() {
   256  				BeforeEach(func() {
   257  					sshEndpointFingerprint = ""
   258  					sshEndpoint = ""
   259  				})
   260  
   261  				It("returns an error'", func() {
   262  					err := callback("", addr, TestHostKey.PublicKey())
   263  					Expect(err).To(MatchError(MatchRegexp("Unable to verify identity of host\\.")))
   264  					Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\"")))
   265  				})
   266  			})
   267  
   268  			Context("when the fingerprint length doesn't make sense", func() {
   269  				BeforeEach(func() {
   270  					sshEndpointFingerprint = "garbage"
   271  				})
   272  
   273  				It("returns an error", func() {
   274  					err := callback("", addr, TestHostKey.PublicKey())
   275  					Eventually(err).Should(MatchError(MatchRegexp("Unsupported host key fingerprint format")))
   276  				})
   277  			})
   278  		})
   279  
   280  		Context("when the skip host validation flag is set", func() {
   281  			BeforeEach(func() {
   282  				opts.SkipHostValidation = true
   283  			})
   284  
   285  			It("removes the HostKeyCallback from the client config", func() {
   286  				Expect(fakeSecureDialer.DialCallCount()).To(Equal(1))
   287  
   288  				_, _, config := fakeSecureDialer.DialArgsForCall(0)
   289  				Expect(config.HostKeyCallback).To(BeNil())
   290  			})
   291  		})
   292  
   293  		Context("when dialing is successful", func() {
   294  			BeforeEach(func() {
   295  				fakeTerminalHelper.StdStreamsStub = terminalHelper.StdStreams
   296  				terminalHelper = fakeTerminalHelper
   297  			})
   298  
   299  			It("creates a new secure shell session", func() {
   300  				Expect(fakeSecureClient.NewSessionCallCount()).To(Equal(1))
   301  			})
   302  
   303  			It("closes the session", func() {
   304  				Expect(fakeSecureSession.CloseCallCount()).To(Equal(1))
   305  			})
   306  
   307  			It("allocates standard streams", func() {
   308  				Expect(fakeTerminalHelper.StdStreamsCallCount()).To(Equal(1))
   309  			})
   310  
   311  			It("gets a stdin pipe for the session", func() {
   312  				Expect(fakeSecureSession.StdinPipeCallCount()).To(Equal(1))
   313  			})
   314  
   315  			Context("when getting the stdin pipe fails", func() {
   316  				BeforeEach(func() {
   317  					fakeSecureSession.StdinPipeReturns(nil, errors.New("woops"))
   318  				})
   319  
   320  				It("returns the error", func() {
   321  					Expect(sessionError).Should(MatchError("woops"))
   322  				})
   323  			})
   324  
   325  			It("gets a stdout pipe for the session", func() {
   326  				Expect(fakeSecureSession.StdoutPipeCallCount()).To(Equal(1))
   327  			})
   328  
   329  			Context("when getting the stdout pipe fails", func() {
   330  				BeforeEach(func() {
   331  					fakeSecureSession.StdoutPipeReturns(nil, errors.New("woops"))
   332  				})
   333  
   334  				It("returns the error", func() {
   335  					Expect(sessionError).Should(MatchError("woops"))
   336  				})
   337  			})
   338  
   339  			It("gets a stderr pipe for the session", func() {
   340  				Expect(fakeSecureSession.StderrPipeCallCount()).To(Equal(1))
   341  			})
   342  
   343  			Context("when getting the stderr pipe fails", func() {
   344  				BeforeEach(func() {
   345  					fakeSecureSession.StderrPipeReturns(nil, errors.New("woops"))
   346  				})
   347  
   348  				It("returns the error", func() {
   349  					Expect(sessionError).Should(MatchError("woops"))
   350  				})
   351  			})
   352  		})
   353  
   354  		Context("when stdin is a terminal", func() {
   355  			var master, slave *os.File
   356  
   357  			BeforeEach(func() {
   358  				_, stdout, stderr := terminalHelper.StdStreams()
   359  
   360  				var err error
   361  				master, slave, err = pty.Open()
   362  				Expect(err).NotTo(HaveOccurred())
   363  
   364  				fakeTerminalHelper.IsTerminalStub = terminalHelper.IsTerminal
   365  				fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo
   366  				fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize
   367  				fakeTerminalHelper.StdStreamsReturns(slave, stdout, stderr)
   368  				terminalHelper = fakeTerminalHelper
   369  			})
   370  
   371  			AfterEach(func() {
   372  				master.Close()
   373  				// slave.Close() // race
   374  			})
   375  
   376  			Context("when a command is not specified", func() {
   377  				var terminalType string
   378  
   379  				BeforeEach(func() {
   380  					terminalType = os.Getenv("TERM")
   381  					os.Setenv("TERM", "test-terminal-type")
   382  
   383  					winsize := &term.Winsize{Width: 1024, Height: 256}
   384  					fakeTerminalHelper.GetWinsizeReturns(winsize, nil)
   385  
   386  					fakeSecureSession.ShellStub = func() error {
   387  						Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1))
   388  						Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(0))
   389  						return nil
   390  					}
   391  				})
   392  
   393  				AfterEach(func() {
   394  					os.Setenv("TERM", terminalType)
   395  				})
   396  
   397  				It("requests a pty with the correct terminal type, window size, and modes", func() {
   398  					Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1))
   399  					Expect(fakeTerminalHelper.GetWinsizeCallCount()).To(Equal(1))
   400  
   401  					termType, height, width, modes := fakeSecureSession.RequestPtyArgsForCall(0)
   402  					Expect(termType).To(Equal("test-terminal-type"))
   403  					Expect(height).To(Equal(256))
   404  					Expect(width).To(Equal(1024))
   405  
   406  					expectedModes := ssh.TerminalModes{
   407  						ssh.ECHO:          1,
   408  						ssh.TTY_OP_ISPEED: 115200,
   409  						ssh.TTY_OP_OSPEED: 115200,
   410  					}
   411  					Expect(modes).To(Equal(expectedModes))
   412  				})
   413  
   414  				Context("when the TERM environment variable is not set", func() {
   415  					BeforeEach(func() {
   416  						os.Unsetenv("TERM")
   417  					})
   418  
   419  					It("requests a pty with the default terminal type", func() {
   420  						Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1))
   421  
   422  						termType, _, _, _ := fakeSecureSession.RequestPtyArgsForCall(0)
   423  						Expect(termType).To(Equal("xterm"))
   424  					})
   425  				})
   426  
   427  				It("puts the terminal into raw mode and restores it after running the shell", func() {
   428  					Expect(fakeSecureSession.ShellCallCount()).To(Equal(1))
   429  					Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1))
   430  					Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(1))
   431  				})
   432  
   433  				Context("when the pty allocation fails", func() {
   434  					var ptyError error
   435  
   436  					BeforeEach(func() {
   437  						ptyError = errors.New("pty allocation error")
   438  						fakeSecureSession.RequestPtyReturns(ptyError)
   439  					})
   440  
   441  					It("returns the error", func() {
   442  						Expect(sessionError).To(Equal(ptyError))
   443  					})
   444  				})
   445  
   446  				Context("when placing the terminal into raw mode fails", func() {
   447  					BeforeEach(func() {
   448  						fakeTerminalHelper.SetRawTerminalReturns(nil, errors.New("woops"))
   449  					})
   450  
   451  					It("keeps calm and carries on", func() {
   452  						Expect(fakeSecureSession.ShellCallCount()).To(Equal(1))
   453  					})
   454  
   455  					It("does not not restore the terminal", func() {
   456  						Expect(fakeSecureSession.ShellCallCount()).To(Equal(1))
   457  						Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1))
   458  						Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(0))
   459  					})
   460  				})
   461  			})
   462  
   463  			Context("when a command is specified", func() {
   464  				BeforeEach(func() {
   465  					opts.Command = []string{"echo", "-n", "hello"}
   466  				})
   467  
   468  				Context("when a terminal is requested", func() {
   469  					BeforeEach(func() {
   470  						opts.TerminalRequest = options.RequestTTYYes
   471  					})
   472  
   473  					It("requests a pty", func() {
   474  						Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1))
   475  					})
   476  				})
   477  
   478  				Context("when a terminal is not explicitly requested", func() {
   479  					It("does not request a pty", func() {
   480  						Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0))
   481  					})
   482  				})
   483  			})
   484  		})
   485  
   486  		Context("when stdin is not a terminal", func() {
   487  			BeforeEach(func() {
   488  				_, stdout, stderr := terminalHelper.StdStreams()
   489  
   490  				stdin := &fake_io.FakeReadCloser{}
   491  				stdin.ReadStub = func(p []byte) (int, error) {
   492  					return 0, io.EOF
   493  				}
   494  
   495  				fakeTerminalHelper.IsTerminalStub = terminalHelper.IsTerminal
   496  				fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo
   497  				fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize
   498  				fakeTerminalHelper.StdStreamsReturns(stdin, stdout, stderr)
   499  				terminalHelper = fakeTerminalHelper
   500  			})
   501  
   502  			Context("when a terminal is not requested", func() {
   503  				It("does not request a pty", func() {
   504  					Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0))
   505  				})
   506  			})
   507  
   508  			Context("when a terminal is requested", func() {
   509  				BeforeEach(func() {
   510  					opts.TerminalRequest = options.RequestTTYYes
   511  				})
   512  
   513  				It("does not request a pty", func() {
   514  					Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0))
   515  				})
   516  			})
   517  		})
   518  
   519  		Context("when a terminal is forced", func() {
   520  			BeforeEach(func() {
   521  				opts.TerminalRequest = options.RequestTTYForce
   522  			})
   523  
   524  			It("requests a pty", func() {
   525  				Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1))
   526  			})
   527  		})
   528  
   529  		Context("when a terminal is disabled", func() {
   530  			BeforeEach(func() {
   531  				opts.TerminalRequest = options.RequestTTYNo
   532  			})
   533  
   534  			It("does not request a pty", func() {
   535  				Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0))
   536  			})
   537  		})
   538  
   539  		Context("when a command is not specified", func() {
   540  			It("requests an interactive shell", func() {
   541  				Expect(fakeSecureSession.ShellCallCount()).To(Equal(1))
   542  			})
   543  
   544  			Context("when the shell request returns an error", func() {
   545  				BeforeEach(func() {
   546  					fakeSecureSession.ShellReturns(errors.New("oh bother"))
   547  				})
   548  
   549  				It("returns the error", func() {
   550  					Expect(sessionError).To(MatchError("oh bother"))
   551  				})
   552  			})
   553  		})
   554  
   555  		Context("when a command is specifed", func() {
   556  			BeforeEach(func() {
   557  				opts.Command = []string{"echo", "-n", "hello"}
   558  			})
   559  
   560  			It("starts the command", func() {
   561  				Expect(fakeSecureSession.StartCallCount()).To(Equal(1))
   562  				Expect(fakeSecureSession.StartArgsForCall(0)).To(Equal("echo -n hello"))
   563  			})
   564  
   565  			Context("when the command fails to start", func() {
   566  				BeforeEach(func() {
   567  					fakeSecureSession.StartReturns(errors.New("oh well"))
   568  				})
   569  
   570  				It("returns the error", func() {
   571  					Expect(sessionError).To(MatchError("oh well"))
   572  				})
   573  			})
   574  		})
   575  
   576  		Context("when the shell or command has started", func() {
   577  			var (
   578  				stdin                  *fake_io.FakeReadCloser
   579  				stdout, stderr         *fake_io.FakeWriter
   580  				stdinPipe              *fake_io.FakeWriteCloser
   581  				stdoutPipe, stderrPipe *fake_io.FakeReader
   582  			)
   583  
   584  			BeforeEach(func() {
   585  				stdin = &fake_io.FakeReadCloser{}
   586  				stdin.ReadStub = func(p []byte) (int, error) {
   587  					p[0] = 0
   588  					return 1, io.EOF
   589  				}
   590  				stdinPipe = &fake_io.FakeWriteCloser{}
   591  				stdinPipe.WriteStub = func(p []byte) (int, error) {
   592  					defer GinkgoRecover()
   593  					Expect(p[0]).To(Equal(byte(0)))
   594  					return 1, nil
   595  				}
   596  
   597  				stdoutPipe = &fake_io.FakeReader{}
   598  				stdoutPipe.ReadStub = func(p []byte) (int, error) {
   599  					p[0] = 1
   600  					return 1, io.EOF
   601  				}
   602  				stdout = &fake_io.FakeWriter{}
   603  				stdout.WriteStub = func(p []byte) (int, error) {
   604  					defer GinkgoRecover()
   605  					Expect(p[0]).To(Equal(byte(1)))
   606  					return 1, nil
   607  				}
   608  
   609  				stderrPipe = &fake_io.FakeReader{}
   610  				stderrPipe.ReadStub = func(p []byte) (int, error) {
   611  					p[0] = 2
   612  					return 1, io.EOF
   613  				}
   614  				stderr = &fake_io.FakeWriter{}
   615  				stderr.WriteStub = func(p []byte) (int, error) {
   616  					defer GinkgoRecover()
   617  					Expect(p[0]).To(Equal(byte(2)))
   618  					return 1, nil
   619  				}
   620  
   621  				fakeTerminalHelper.StdStreamsReturns(stdin, stdout, stderr)
   622  				terminalHelper = fakeTerminalHelper
   623  
   624  				fakeSecureSession.StdinPipeReturns(stdinPipe, nil)
   625  				fakeSecureSession.StdoutPipeReturns(stdoutPipe, nil)
   626  				fakeSecureSession.StderrPipeReturns(stderrPipe, nil)
   627  
   628  				fakeSecureSession.WaitReturns(errors.New("error result"))
   629  			})
   630  
   631  			It("copies data from the stdin stream to the session stdin pipe", func() {
   632  				Eventually(stdin.ReadCallCount).Should(Equal(1))
   633  				Eventually(stdinPipe.WriteCallCount).Should(Equal(1))
   634  			})
   635  
   636  			It("copies data from the session stdout pipe to the stdout stream", func() {
   637  				Eventually(stdoutPipe.ReadCallCount).Should(Equal(1))
   638  				Eventually(stdout.WriteCallCount).Should(Equal(1))
   639  			})
   640  
   641  			It("copies data from the session stderr pipe to the stderr stream", func() {
   642  				Eventually(stderrPipe.ReadCallCount).Should(Equal(1))
   643  				Eventually(stderr.WriteCallCount).Should(Equal(1))
   644  			})
   645  
   646  			It("waits for the session to end", func() {
   647  				Expect(fakeSecureSession.WaitCallCount()).To(Equal(1))
   648  			})
   649  
   650  			It("returns the result from wait", func() {
   651  				Expect(sessionError).To(MatchError("error result"))
   652  			})
   653  
   654  			Context("when the session terminates before stream copies complete", func() {
   655  				var sessionErrorCh chan error
   656  
   657  				BeforeEach(func() {
   658  					sessionErrorCh = make(chan error, 1)
   659  
   660  					interactiveSessionInvoker = func(secureShell sshCmd.SecureShell) {
   661  						go func() { sessionErrorCh <- secureShell.InteractiveSession() }()
   662  					}
   663  
   664  					stdoutPipe.ReadStub = func(p []byte) (int, error) {
   665  						defer GinkgoRecover()
   666  						Eventually(fakeSecureSession.WaitCallCount).Should(Equal(1))
   667  						Consistently(sessionErrorCh).ShouldNot(Receive())
   668  
   669  						p[0] = 1
   670  						return 1, io.EOF
   671  					}
   672  
   673  					stderrPipe.ReadStub = func(p []byte) (int, error) {
   674  						defer GinkgoRecover()
   675  						Eventually(fakeSecureSession.WaitCallCount).Should(Equal(1))
   676  						Consistently(sessionErrorCh).ShouldNot(Receive())
   677  
   678  						p[0] = 2
   679  						return 1, io.EOF
   680  					}
   681  				})
   682  
   683  				It("waits for the copies to complete", func() {
   684  					Eventually(sessionErrorCh).Should(Receive())
   685  					Expect(stdoutPipe.ReadCallCount()).To(Equal(1))
   686  					Expect(stderrPipe.ReadCallCount()).To(Equal(1))
   687  				})
   688  			})
   689  
   690  			Context("when stdin is closed", func() {
   691  				BeforeEach(func() {
   692  					stdin.ReadStub = func(p []byte) (int, error) {
   693  						defer GinkgoRecover()
   694  						Consistently(stdinPipe.CloseCallCount).Should(Equal(0))
   695  						p[0] = 0
   696  						return 1, io.EOF
   697  					}
   698  				})
   699  
   700  				It("closes the stdinPipe", func() {
   701  					Eventually(stdinPipe.CloseCallCount).Should(Equal(1))
   702  				})
   703  			})
   704  		})
   705  
   706  		Context("when stdout is a terminal and a window size change occurs", func() {
   707  			var master, slave *os.File
   708  
   709  			BeforeEach(func() {
   710  				stdin, _, stderr := terminalHelper.StdStreams()
   711  
   712  				var err error
   713  				master, slave, err = pty.Open()
   714  				Expect(err).NotTo(HaveOccurred())
   715  
   716  				fakeTerminalHelper.IsTerminalStub = terminalHelper.IsTerminal
   717  				fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo
   718  				fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize
   719  				fakeTerminalHelper.StdStreamsReturns(stdin, slave, stderr)
   720  				terminalHelper = fakeTerminalHelper
   721  
   722  				winsize := &term.Winsize{Height: 100, Width: 100}
   723  				err = term.SetWinsize(slave.Fd(), winsize)
   724  				Expect(err).NotTo(HaveOccurred())
   725  
   726  				fakeSecureSession.WaitStub = func() error {
   727  					fakeSecureSession.SendRequestCallCount()
   728  					Expect(fakeSecureSession.SendRequestCallCount()).To(Equal(0))
   729  
   730  					// No dimension change
   731  					for i := 0; i < 3; i++ {
   732  						winsize := &term.Winsize{Height: 100, Width: 100}
   733  						err = term.SetWinsize(slave.Fd(), winsize)
   734  						Expect(err).NotTo(HaveOccurred())
   735  					}
   736  
   737  					winsize := &term.Winsize{Height: 100, Width: 200}
   738  					err = term.SetWinsize(slave.Fd(), winsize)
   739  					Expect(err).NotTo(HaveOccurred())
   740  
   741  					err = syscall.Kill(syscall.Getpid(), syscall.SIGWINCH)
   742  					Expect(err).NotTo(HaveOccurred())
   743  
   744  					Eventually(fakeSecureSession.SendRequestCallCount).Should(Equal(1))
   745  					return nil
   746  				}
   747  			})
   748  
   749  			AfterEach(func() {
   750  				master.Close()
   751  				slave.Close()
   752  			})
   753  
   754  			It("sends window change events when the window dimensions change", func() {
   755  				Expect(fakeSecureSession.SendRequestCallCount()).To(Equal(1))
   756  
   757  				requestType, wantReply, message := fakeSecureSession.SendRequestArgsForCall(0)
   758  				Expect(requestType).To(Equal("window-change"))
   759  				Expect(wantReply).To(BeFalse())
   760  
   761  				type resizeMessage struct {
   762  					Width       uint32
   763  					Height      uint32
   764  					PixelWidth  uint32
   765  					PixelHeight uint32
   766  				}
   767  				var resizeMsg resizeMessage
   768  
   769  				err := ssh.Unmarshal(message, &resizeMsg)
   770  				Expect(err).NotTo(HaveOccurred())
   771  
   772  				Expect(resizeMsg).To(Equal(resizeMessage{Height: 100, Width: 200}))
   773  			})
   774  		})
   775  
   776  		Describe("keep alive messages", func() {
   777  			var times []time.Time
   778  			var timesCh chan []time.Time
   779  			var done chan struct{}
   780  
   781  			BeforeEach(func() {
   782  				keepAliveDuration = 100 * time.Millisecond
   783  
   784  				times = []time.Time{}
   785  				timesCh = make(chan []time.Time, 1)
   786  				done = make(chan struct{}, 1)
   787  
   788  				fakeConnection.SendRequestStub = func(reqName string, wantReply bool, message []byte) (bool, []byte, error) {
   789  					Expect(reqName).To(Equal("keepalive@cloudfoundry.org"))
   790  					Expect(wantReply).To(BeTrue())
   791  					Expect(message).To(BeNil())
   792  
   793  					times = append(times, time.Now())
   794  					if len(times) == 3 {
   795  						timesCh <- times
   796  						close(done)
   797  					}
   798  					return true, nil, nil
   799  				}
   800  
   801  				fakeSecureSession.WaitStub = func() error {
   802  					Eventually(done).Should(BeClosed())
   803  					return nil
   804  				}
   805  			})
   806  
   807  			It("sends keep alive messages at the expected interval", func() {
   808  				times := <-timesCh
   809  				Expect(times[2]).To(BeTemporally("~", times[0].Add(200*time.Millisecond), 100*time.Millisecond))
   810  			})
   811  		})
   812  	})
   813  
   814  	Describe("LocalPortForward", func() {
   815  		var (
   816  			opts              *options.SSHOptions
   817  			localForwardError error
   818  
   819  			echoAddress  string
   820  			echoListener *fake_net.FakeListener
   821  			echoHandler  *fake_server.FakeConnectionHandler
   822  			echoServer   *server.Server
   823  
   824  			localAddress string
   825  
   826  			realLocalListener net.Listener
   827  			fakeLocalListener *fake_net.FakeListener
   828  		)
   829  
   830  		BeforeEach(func() {
   831  			logger := lagertest.NewTestLogger("test")
   832  
   833  			var err error
   834  			realLocalListener, err = net.Listen("tcp", "127.0.0.1:0")
   835  			Expect(err).NotTo(HaveOccurred())
   836  
   837  			localAddress = realLocalListener.Addr().String()
   838  			fakeListenerFactory.ListenReturns(realLocalListener, nil)
   839  
   840  			echoHandler = &fake_server.FakeConnectionHandler{}
   841  			echoHandler.HandleConnectionStub = func(conn net.Conn) {
   842  				io.Copy(conn, conn)
   843  				conn.Close()
   844  			}
   845  
   846  			realListener, err := net.Listen("tcp", "127.0.0.1:0")
   847  			Expect(err).NotTo(HaveOccurred())
   848  			echoAddress = realListener.Addr().String()
   849  
   850  			echoListener = &fake_net.FakeListener{}
   851  			echoListener.AcceptStub = realListener.Accept
   852  			echoListener.CloseStub = realListener.Close
   853  			echoListener.AddrStub = realListener.Addr
   854  
   855  			fakeLocalListener = &fake_net.FakeListener{}
   856  			fakeLocalListener.AcceptReturns(nil, errors.New("Not Accepting Connections"))
   857  
   858  			echoServer = server.NewServer(logger.Session("echo"), "", echoHandler)
   859  			echoServer.SetListener(echoListener)
   860  			go echoServer.Serve()
   861  
   862  			opts = &options.SSHOptions{
   863  				AppName: "app-1",
   864  				ForwardSpecs: []options.ForwardSpec{{
   865  					ListenAddress:  localAddress,
   866  					ConnectAddress: echoAddress,
   867  				}},
   868  			}
   869  
   870  			currentApp.State = "STARTED"
   871  			currentApp.Diego = true
   872  
   873  			sshEndpointFingerprint = ""
   874  			sshEndpoint = ""
   875  
   876  			token = ""
   877  
   878  			fakeSecureClient.DialStub = net.Dial
   879  		})
   880  
   881  		JustBeforeEach(func() {
   882  			connectErr := secureShell.Connect(opts)
   883  			Expect(connectErr).NotTo(HaveOccurred())
   884  
   885  			localForwardError = secureShell.LocalPortForward()
   886  		})
   887  
   888  		AfterEach(func() {
   889  			err := secureShell.Close()
   890  			Expect(err).NotTo(HaveOccurred())
   891  			echoServer.Shutdown()
   892  
   893  			realLocalListener.Close()
   894  		})
   895  
   896  		validateConnectivity := func(addr string) {
   897  			conn, err := net.Dial("tcp", addr)
   898  			Expect(err).NotTo(HaveOccurred())
   899  
   900  			msg := fmt.Sprintf("Hello from %s\n", addr)
   901  			n, err := conn.Write([]byte(msg))
   902  			Expect(err).NotTo(HaveOccurred())
   903  			Expect(n).To(Equal(len(msg)))
   904  
   905  			response := make([]byte, len(msg))
   906  			n, err = conn.Read(response)
   907  			Expect(err).NotTo(HaveOccurred())
   908  			Expect(n).To(Equal(len(msg)))
   909  
   910  			err = conn.Close()
   911  			Expect(err).NotTo(HaveOccurred())
   912  
   913  			Expect(response).To(Equal([]byte(msg)))
   914  		}
   915  
   916  		It("dials the connect address when a local connection is made", func() {
   917  			Expect(localForwardError).NotTo(HaveOccurred())
   918  
   919  			conn, err := net.Dial("tcp", localAddress)
   920  			Expect(err).NotTo(HaveOccurred())
   921  
   922  			Eventually(echoListener.AcceptCallCount).Should(BeNumerically(">=", 1))
   923  			Eventually(fakeSecureClient.DialCallCount).Should(Equal(1))
   924  
   925  			network, addr := fakeSecureClient.DialArgsForCall(0)
   926  			Expect(network).To(Equal("tcp"))
   927  			Expect(addr).To(Equal(echoAddress))
   928  
   929  			Expect(conn.Close()).NotTo(HaveOccurred())
   930  		})
   931  
   932  		It("copies data between the local and remote connections", func() {
   933  			validateConnectivity(localAddress)
   934  		})
   935  
   936  		Context("when a local connection is already open", func() {
   937  			var (
   938  				conn net.Conn
   939  				err  error
   940  			)
   941  
   942  			JustBeforeEach(func() {
   943  				conn, err = net.Dial("tcp", localAddress)
   944  				Expect(err).NotTo(HaveOccurred())
   945  			})
   946  
   947  			AfterEach(func() {
   948  				err = conn.Close()
   949  				Expect(err).NotTo(HaveOccurred())
   950  			})
   951  
   952  			It("allows for new incoming connections as well", func() {
   953  				validateConnectivity(localAddress)
   954  			})
   955  		})
   956  
   957  		Context("when there are multiple port forward specs", func() {
   958  			var realLocalListener2 net.Listener
   959  			var localAddress2 string
   960  
   961  			BeforeEach(func() {
   962  				var err error
   963  				realLocalListener2, err = net.Listen("tcp", "127.0.0.1:0")
   964  				Expect(err).NotTo(HaveOccurred())
   965  
   966  				localAddress2 = realLocalListener2.Addr().String()
   967  
   968  				fakeListenerFactory.ListenStub = func(network, addr string) (net.Listener, error) {
   969  					if addr == localAddress {
   970  						return realLocalListener, nil
   971  					}
   972  
   973  					if addr == localAddress2 {
   974  						return realLocalListener2, nil
   975  					}
   976  
   977  					return nil, errors.New("unexpected address")
   978  				}
   979  
   980  				opts = &options.SSHOptions{
   981  					AppName: "app-1",
   982  					ForwardSpecs: []options.ForwardSpec{{
   983  						ListenAddress:  localAddress,
   984  						ConnectAddress: echoAddress,
   985  					}, {
   986  						ListenAddress:  localAddress2,
   987  						ConnectAddress: echoAddress,
   988  					}},
   989  				}
   990  			})
   991  
   992  			AfterEach(func() {
   993  				realLocalListener2.Close()
   994  			})
   995  
   996  			It("listens to all the things", func() {
   997  				Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(2))
   998  
   999  				network, addr := fakeListenerFactory.ListenArgsForCall(0)
  1000  				Expect(network).To(Equal("tcp"))
  1001  				Expect(addr).To(Equal(localAddress))
  1002  
  1003  				network, addr = fakeListenerFactory.ListenArgsForCall(1)
  1004  				Expect(network).To(Equal("tcp"))
  1005  				Expect(addr).To(Equal(localAddress2))
  1006  			})
  1007  
  1008  			It("forwards to the correct target", func() {
  1009  				validateConnectivity(localAddress)
  1010  				validateConnectivity(localAddress2)
  1011  			})
  1012  
  1013  			Context("when the secure client is closed", func() {
  1014  				BeforeEach(func() {
  1015  					fakeListenerFactory.ListenReturns(fakeLocalListener, nil)
  1016  					fakeLocalListener.AcceptReturns(nil, errors.New("not accepting connections"))
  1017  				})
  1018  
  1019  				It("closes the listeners ", func() {
  1020  					Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(2))
  1021  					Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(2))
  1022  
  1023  					originalCloseCount := fakeLocalListener.CloseCallCount()
  1024  					err := secureShell.Close()
  1025  					Expect(err).NotTo(HaveOccurred())
  1026  					Expect(fakeLocalListener.CloseCallCount()).Should(Equal(originalCloseCount + 2))
  1027  				})
  1028  			})
  1029  		})
  1030  
  1031  		Context("when listen fails", func() {
  1032  			BeforeEach(func() {
  1033  				fakeListenerFactory.ListenReturns(nil, errors.New("failure is an option"))
  1034  			})
  1035  
  1036  			It("returns the error", func() {
  1037  				Expect(localForwardError).To(MatchError("failure is an option"))
  1038  			})
  1039  		})
  1040  
  1041  		Context("when the client it closed", func() {
  1042  			BeforeEach(func() {
  1043  				fakeListenerFactory.ListenReturns(fakeLocalListener, nil)
  1044  				fakeLocalListener.AcceptReturns(nil, errors.New("not accepting and connections"))
  1045  			})
  1046  
  1047  			It("closes the listener when the client is closed", func() {
  1048  				Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(1))
  1049  				Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(1))
  1050  
  1051  				originalCloseCount := fakeLocalListener.CloseCallCount()
  1052  				err := secureShell.Close()
  1053  				Expect(err).NotTo(HaveOccurred())
  1054  				Expect(fakeLocalListener.CloseCallCount()).Should(Equal(originalCloseCount + 1))
  1055  			})
  1056  		})
  1057  
  1058  		Context("when accept fails", func() {
  1059  			var fakeConn *fake_net.FakeConn
  1060  			BeforeEach(func() {
  1061  				fakeConn = &fake_net.FakeConn{}
  1062  				fakeConn.ReadReturns(0, io.EOF)
  1063  
  1064  				fakeListenerFactory.ListenReturns(fakeLocalListener, nil)
  1065  			})
  1066  
  1067  			Context("with a permanent error", func() {
  1068  				BeforeEach(func() {
  1069  					fakeLocalListener.AcceptReturns(nil, errors.New("boom"))
  1070  				})
  1071  
  1072  				It("stops trying to accept connections", func() {
  1073  					Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(1))
  1074  					Consistently(fakeLocalListener.AcceptCallCount).Should(Equal(1))
  1075  					Expect(fakeLocalListener.CloseCallCount()).To(Equal(1))
  1076  				})
  1077  			})
  1078  
  1079  			Context("with a temporary error", func() {
  1080  				var timeCh chan time.Time
  1081  
  1082  				BeforeEach(func() {
  1083  					timeCh = make(chan time.Time, 3)
  1084  
  1085  					fakeLocalListener.AcceptStub = func() (net.Conn, error) {
  1086  						timeCh := timeCh
  1087  						if fakeLocalListener.AcceptCallCount() > 3 {
  1088  							close(timeCh)
  1089  							return nil, test_helpers.NewTestNetError(false, false)
  1090  						} else {
  1091  							timeCh <- time.Now()
  1092  							return nil, test_helpers.NewTestNetError(false, true)
  1093  						}
  1094  					}
  1095  				})
  1096  
  1097  				It("retries connecting after a short delay", func() {
  1098  					Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(3))
  1099  					Expect(timeCh).To(HaveLen(3))
  1100  
  1101  					times := make([]time.Time, 0)
  1102  					for t := range timeCh {
  1103  						times = append(times, t)
  1104  					}
  1105  
  1106  					Expect(times[1]).To(BeTemporally("~", times[0].Add(115*time.Millisecond), 30*time.Millisecond))
  1107  					Expect(times[2]).To(BeTemporally("~", times[1].Add(115*time.Millisecond), 30*time.Millisecond))
  1108  				})
  1109  			})
  1110  		})
  1111  
  1112  		Context("when dialing the connect address fails", func() {
  1113  			var fakeTarget *fake_net.FakeConn
  1114  
  1115  			BeforeEach(func() {
  1116  				fakeTarget = &fake_net.FakeConn{}
  1117  				fakeSecureClient.DialReturns(fakeTarget, errors.New("boom"))
  1118  			})
  1119  
  1120  			It("does not call close on the target connection", func() {
  1121  				Consistently(fakeTarget.CloseCallCount).Should(Equal(0))
  1122  			})
  1123  		})
  1124  	})
  1125  
  1126  	Describe("Wait", func() {
  1127  		var opts *options.SSHOptions
  1128  		var waitErr error
  1129  
  1130  		BeforeEach(func() {
  1131  			opts = &options.SSHOptions{
  1132  				AppName: "app-1",
  1133  			}
  1134  
  1135  			currentApp.State = "STARTED"
  1136  			currentApp.Diego = true
  1137  
  1138  			sshEndpointFingerprint = ""
  1139  			sshEndpoint = ""
  1140  
  1141  			token = ""
  1142  		})
  1143  
  1144  		JustBeforeEach(func() {
  1145  			connectErr := secureShell.Connect(opts)
  1146  			Expect(connectErr).NotTo(HaveOccurred())
  1147  
  1148  			waitErr = secureShell.Wait()
  1149  		})
  1150  
  1151  		It("calls wait on the secureClient", func() {
  1152  			Expect(waitErr).NotTo(HaveOccurred())
  1153  			Expect(fakeSecureClient.WaitCallCount()).To(Equal(1))
  1154  		})
  1155  
  1156  		Describe("keep alive messages", func() {
  1157  			var times []time.Time
  1158  			var timesCh chan []time.Time
  1159  			var done chan struct{}
  1160  
  1161  			BeforeEach(func() {
  1162  				keepAliveDuration = 100 * time.Millisecond
  1163  
  1164  				times = []time.Time{}
  1165  				timesCh = make(chan []time.Time, 1)
  1166  				done = make(chan struct{}, 1)
  1167  
  1168  				fakeConnection.SendRequestStub = func(reqName string, wantReply bool, message []byte) (bool, []byte, error) {
  1169  					Expect(reqName).To(Equal("keepalive@cloudfoundry.org"))
  1170  					Expect(wantReply).To(BeTrue())
  1171  					Expect(message).To(BeNil())
  1172  
  1173  					times = append(times, time.Now())
  1174  					if len(times) == 3 {
  1175  						timesCh <- times
  1176  						close(done)
  1177  					}
  1178  					return true, nil, nil
  1179  				}
  1180  
  1181  				fakeSecureClient.WaitStub = func() error {
  1182  					Eventually(done).Should(BeClosed())
  1183  					return nil
  1184  				}
  1185  			})
  1186  
  1187  			It("sends keep alive messages at the expected interval", func() {
  1188  				Expect(waitErr).NotTo(HaveOccurred())
  1189  				times := <-timesCh
  1190  				Expect(times[2]).To(BeTemporally("~", times[0].Add(200*time.Millisecond), 100*time.Millisecond))
  1191  			})
  1192  		})
  1193  	})
  1194  
  1195  	Describe("Close", func() {
  1196  		var opts *options.SSHOptions
  1197  
  1198  		BeforeEach(func() {
  1199  			opts = &options.SSHOptions{
  1200  				AppName: "app-1",
  1201  			}
  1202  
  1203  			currentApp.State = "STARTED"
  1204  			currentApp.Diego = true
  1205  
  1206  			sshEndpointFingerprint = ""
  1207  			sshEndpoint = ""
  1208  
  1209  			token = ""
  1210  		})
  1211  
  1212  		JustBeforeEach(func() {
  1213  			connectErr := secureShell.Connect(opts)
  1214  			Expect(connectErr).NotTo(HaveOccurred())
  1215  		})
  1216  
  1217  		It("calls close on the secureClient", func() {
  1218  			err := secureShell.Close()
  1219  			Expect(err).NotTo(HaveOccurred())
  1220  
  1221  			Expect(fakeSecureClient.CloseCallCount()).To(Equal(1))
  1222  		})
  1223  	})
  1224  })