github.com/dcarley/cf-cli@v6.24.1-0.20170220111324-4225ff346898+incompatible/cf/ssh/ssh_test.go (about) 1 // +build !windows,!386 2 3 // skipping 386 because lager uses UInt64 in Session() 4 // skipping windows because Unix/Linux only syscall in test. 5 // should refactor out the conflicts so we could test this package in multi platforms. 6 7 package sshCmd_test 8 9 import ( 10 "errors" 11 "fmt" 12 "io" 13 "net" 14 "os" 15 "syscall" 16 "time" 17 18 "code.cloudfoundry.org/cli/cf/models" 19 "code.cloudfoundry.org/cli/cf/ssh" 20 "code.cloudfoundry.org/cli/cf/ssh/options" 21 "code.cloudfoundry.org/cli/cf/ssh/sshfakes" 22 "code.cloudfoundry.org/cli/cf/ssh/terminal" 23 "code.cloudfoundry.org/cli/cf/ssh/terminal/terminalfakes" 24 "code.cloudfoundry.org/diego-ssh/server" 25 fake_server "code.cloudfoundry.org/diego-ssh/server/fakes" 26 "code.cloudfoundry.org/diego-ssh/test_helpers" 27 "code.cloudfoundry.org/diego-ssh/test_helpers/fake_io" 28 "code.cloudfoundry.org/diego-ssh/test_helpers/fake_net" 29 "code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh" 30 "code.cloudfoundry.org/lager/lagertest" 31 "github.com/docker/docker/pkg/term" 32 "github.com/kr/pty" 33 "golang.org/x/crypto/ssh" 34 35 . "github.com/onsi/ginkgo" 36 . "github.com/onsi/gomega" 37 ) 38 39 var _ = Describe("SSH", func() { 40 var ( 41 fakeTerminalHelper *terminalfakes.FakeTerminalHelper 42 fakeListenerFactory *sshfakes.FakeListenerFactory 43 44 fakeConnection *fake_ssh.FakeConn 45 fakeSecureClient *sshfakes.FakeSecureClient 46 fakeSecureDialer *sshfakes.FakeSecureDialer 47 fakeSecureSession *sshfakes.FakeSecureSession 48 49 terminalHelper terminal.TerminalHelper 50 keepAliveDuration time.Duration 51 secureShell sshCmd.SecureShell 52 53 stdinPipe *fake_io.FakeWriteCloser 54 55 currentApp models.Application 56 sshEndpointFingerprint string 57 sshEndpoint string 58 token string 59 ) 60 61 BeforeEach(func() { 62 fakeTerminalHelper = new(terminalfakes.FakeTerminalHelper) 63 terminalHelper = terminal.DefaultHelper() 64 65 fakeListenerFactory = new(sshfakes.FakeListenerFactory) 66 fakeListenerFactory.ListenStub = net.Listen 67 68 keepAliveDuration = 30 * time.Second 69 70 currentApp = models.Application{} 71 sshEndpoint = "" 72 sshEndpointFingerprint = "" 73 token = "" 74 75 fakeConnection = new(fake_ssh.FakeConn) 76 fakeSecureClient = new(sshfakes.FakeSecureClient) 77 fakeSecureDialer = new(sshfakes.FakeSecureDialer) 78 fakeSecureSession = new(sshfakes.FakeSecureSession) 79 80 fakeSecureDialer.DialReturns(fakeSecureClient, nil) 81 fakeSecureClient.NewSessionReturns(fakeSecureSession, nil) 82 fakeSecureClient.ConnReturns(fakeConnection) 83 84 stdinPipe = &fake_io.FakeWriteCloser{} 85 stdinPipe.WriteStub = func(p []byte) (int, error) { 86 return len(p), nil 87 } 88 89 stdoutPipe := &fake_io.FakeReader{} 90 stdoutPipe.ReadStub = func(p []byte) (int, error) { 91 return 0, io.EOF 92 } 93 94 stderrPipe := &fake_io.FakeReader{} 95 stderrPipe.ReadStub = func(p []byte) (int, error) { 96 return 0, io.EOF 97 } 98 99 fakeSecureSession.StdinPipeReturns(stdinPipe, nil) 100 fakeSecureSession.StdoutPipeReturns(stdoutPipe, nil) 101 fakeSecureSession.StderrPipeReturns(stderrPipe, nil) 102 }) 103 104 JustBeforeEach(func() { 105 secureShell = sshCmd.NewSecureShell( 106 fakeSecureDialer, 107 terminalHelper, 108 fakeListenerFactory, 109 keepAliveDuration, 110 currentApp, 111 sshEndpointFingerprint, 112 sshEndpoint, 113 token, 114 ) 115 }) 116 117 Describe("Validation", func() { 118 var connectErr error 119 var opts *options.SSHOptions 120 121 BeforeEach(func() { 122 opts = &options.SSHOptions{ 123 AppName: "app-1", 124 } 125 }) 126 127 JustBeforeEach(func() { 128 connectErr = secureShell.Connect(opts) 129 }) 130 131 Context("when the app model and endpoint info are successfully acquired", func() { 132 BeforeEach(func() { 133 token = "" 134 currentApp.State = "STARTED" 135 currentApp.Diego = true 136 }) 137 138 Context("when the app is not in the 'STARTED' state", func() { 139 BeforeEach(func() { 140 currentApp.State = "STOPPED" 141 currentApp.Diego = true 142 }) 143 144 It("returns an error", func() { 145 Expect(connectErr).To(MatchError(MatchRegexp("Application.*not in the STARTED state"))) 146 }) 147 }) 148 149 Context("when the app is not a Diego app", func() { 150 BeforeEach(func() { 151 currentApp.State = "STARTED" 152 currentApp.Diego = false 153 }) 154 155 It("returns an error", func() { 156 Expect(connectErr).To(MatchError(MatchRegexp("Application.*not running on Diego"))) 157 }) 158 }) 159 160 Context("when dialing fails", func() { 161 var dialError = errors.New("woops") 162 163 BeforeEach(func() { 164 fakeSecureDialer.DialReturns(nil, dialError) 165 }) 166 167 It("returns the dial error", func() { 168 Expect(connectErr).To(Equal(dialError)) 169 Expect(fakeSecureDialer.DialCallCount()).To(Equal(1)) 170 }) 171 }) 172 }) 173 }) 174 175 Describe("InteractiveSession", func() { 176 var opts *options.SSHOptions 177 var sessionError error 178 var interactiveSessionInvoker func(secureShell sshCmd.SecureShell) 179 180 BeforeEach(func() { 181 sshEndpoint = "ssh.example.com:22" 182 183 opts = &options.SSHOptions{ 184 AppName: "app-name", 185 Index: 2, 186 } 187 188 currentApp.State = "STARTED" 189 currentApp.Diego = true 190 currentApp.GUID = "app-guid" 191 token = "bearer token" 192 193 interactiveSessionInvoker = func(secureShell sshCmd.SecureShell) { 194 sessionError = secureShell.InteractiveSession() 195 } 196 }) 197 198 JustBeforeEach(func() { 199 connectErr := secureShell.Connect(opts) 200 Expect(connectErr).NotTo(HaveOccurred()) 201 interactiveSessionInvoker(secureShell) 202 }) 203 204 It("dials the correct endpoint as the correct user", func() { 205 Expect(fakeSecureDialer.DialCallCount()).To(Equal(1)) 206 207 network, address, config := fakeSecureDialer.DialArgsForCall(0) 208 Expect(network).To(Equal("tcp")) 209 Expect(address).To(Equal("ssh.example.com:22")) 210 Expect(config.Auth).NotTo(BeEmpty()) 211 Expect(config.User).To(Equal("cf:app-guid/2")) 212 Expect(config.HostKeyCallback).NotTo(BeNil()) 213 }) 214 215 Context("when host key validation is enabled", func() { 216 var callback func(hostname string, remote net.Addr, key ssh.PublicKey) error 217 var addr net.Addr 218 219 JustBeforeEach(func() { 220 Expect(fakeSecureDialer.DialCallCount()).To(Equal(1)) 221 _, _, config := fakeSecureDialer.DialArgsForCall(0) 222 callback = config.HostKeyCallback 223 224 listener, err := net.Listen("tcp", "localhost:0") 225 Expect(err).NotTo(HaveOccurred()) 226 227 addr = listener.Addr() 228 listener.Close() 229 }) 230 231 Context("when the SHA1 fingerprint does not match", func() { 232 BeforeEach(func() { 233 sshEndpointFingerprint = "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00" 234 }) 235 236 It("returns an error'", func() { 237 err := callback("", addr, TestHostKey.PublicKey()) 238 Expect(err).To(MatchError(MatchRegexp("Host key verification failed\\."))) 239 Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\""))) 240 }) 241 }) 242 243 Context("when the MD5 fingerprint does not match", func() { 244 BeforeEach(func() { 245 sshEndpointFingerprint = "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00" 246 }) 247 248 It("returns an error'", func() { 249 err := callback("", addr, TestHostKey.PublicKey()) 250 Expect(err).To(MatchError(MatchRegexp("Host key verification failed\\."))) 251 Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\""))) 252 }) 253 }) 254 255 Context("when no fingerprint is present in endpoint info", func() { 256 BeforeEach(func() { 257 sshEndpointFingerprint = "" 258 sshEndpoint = "" 259 }) 260 261 It("returns an error'", func() { 262 err := callback("", addr, TestHostKey.PublicKey()) 263 Expect(err).To(MatchError(MatchRegexp("Unable to verify identity of host\\."))) 264 Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\""))) 265 }) 266 }) 267 268 Context("when the fingerprint length doesn't make sense", func() { 269 BeforeEach(func() { 270 sshEndpointFingerprint = "garbage" 271 }) 272 273 It("returns an error", func() { 274 err := callback("", addr, TestHostKey.PublicKey()) 275 Eventually(err).Should(MatchError(MatchRegexp("Unsupported host key fingerprint format"))) 276 }) 277 }) 278 }) 279 280 Context("when the skip host validation flag is set", func() { 281 BeforeEach(func() { 282 opts.SkipHostValidation = true 283 }) 284 285 It("removes the HostKeyCallback from the client config", func() { 286 Expect(fakeSecureDialer.DialCallCount()).To(Equal(1)) 287 288 _, _, config := fakeSecureDialer.DialArgsForCall(0) 289 Expect(config.HostKeyCallback).To(BeNil()) 290 }) 291 }) 292 293 Context("when dialing is successful", func() { 294 BeforeEach(func() { 295 fakeTerminalHelper.StdStreamsStub = terminalHelper.StdStreams 296 terminalHelper = fakeTerminalHelper 297 }) 298 299 It("creates a new secure shell session", func() { 300 Expect(fakeSecureClient.NewSessionCallCount()).To(Equal(1)) 301 }) 302 303 It("closes the session", func() { 304 Expect(fakeSecureSession.CloseCallCount()).To(Equal(1)) 305 }) 306 307 It("allocates standard streams", func() { 308 Expect(fakeTerminalHelper.StdStreamsCallCount()).To(Equal(1)) 309 }) 310 311 It("gets a stdin pipe for the session", func() { 312 Expect(fakeSecureSession.StdinPipeCallCount()).To(Equal(1)) 313 }) 314 315 Context("when getting the stdin pipe fails", func() { 316 BeforeEach(func() { 317 fakeSecureSession.StdinPipeReturns(nil, errors.New("woops")) 318 }) 319 320 It("returns the error", func() { 321 Expect(sessionError).Should(MatchError("woops")) 322 }) 323 }) 324 325 It("gets a stdout pipe for the session", func() { 326 Expect(fakeSecureSession.StdoutPipeCallCount()).To(Equal(1)) 327 }) 328 329 Context("when getting the stdout pipe fails", func() { 330 BeforeEach(func() { 331 fakeSecureSession.StdoutPipeReturns(nil, errors.New("woops")) 332 }) 333 334 It("returns the error", func() { 335 Expect(sessionError).Should(MatchError("woops")) 336 }) 337 }) 338 339 It("gets a stderr pipe for the session", func() { 340 Expect(fakeSecureSession.StderrPipeCallCount()).To(Equal(1)) 341 }) 342 343 Context("when getting the stderr pipe fails", func() { 344 BeforeEach(func() { 345 fakeSecureSession.StderrPipeReturns(nil, errors.New("woops")) 346 }) 347 348 It("returns the error", func() { 349 Expect(sessionError).Should(MatchError("woops")) 350 }) 351 }) 352 }) 353 354 Context("when stdin is a terminal", func() { 355 var master, slave *os.File 356 357 BeforeEach(func() { 358 _, stdout, stderr := terminalHelper.StdStreams() 359 360 var err error 361 master, slave, err = pty.Open() 362 Expect(err).NotTo(HaveOccurred()) 363 364 fakeTerminalHelper.IsTerminalStub = terminalHelper.IsTerminal 365 fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo 366 fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize 367 fakeTerminalHelper.StdStreamsReturns(slave, stdout, stderr) 368 terminalHelper = fakeTerminalHelper 369 }) 370 371 AfterEach(func() { 372 master.Close() 373 // slave.Close() // race 374 }) 375 376 Context("when a command is not specified", func() { 377 var terminalType string 378 379 BeforeEach(func() { 380 terminalType = os.Getenv("TERM") 381 os.Setenv("TERM", "test-terminal-type") 382 383 winsize := &term.Winsize{Width: 1024, Height: 256} 384 fakeTerminalHelper.GetWinsizeReturns(winsize, nil) 385 386 fakeSecureSession.ShellStub = func() error { 387 Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1)) 388 Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(0)) 389 return nil 390 } 391 }) 392 393 AfterEach(func() { 394 os.Setenv("TERM", terminalType) 395 }) 396 397 It("requests a pty with the correct terminal type, window size, and modes", func() { 398 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1)) 399 Expect(fakeTerminalHelper.GetWinsizeCallCount()).To(Equal(1)) 400 401 termType, height, width, modes := fakeSecureSession.RequestPtyArgsForCall(0) 402 Expect(termType).To(Equal("test-terminal-type")) 403 Expect(height).To(Equal(256)) 404 Expect(width).To(Equal(1024)) 405 406 expectedModes := ssh.TerminalModes{ 407 ssh.ECHO: 1, 408 ssh.TTY_OP_ISPEED: 115200, 409 ssh.TTY_OP_OSPEED: 115200, 410 } 411 Expect(modes).To(Equal(expectedModes)) 412 }) 413 414 Context("when the TERM environment variable is not set", func() { 415 BeforeEach(func() { 416 os.Unsetenv("TERM") 417 }) 418 419 It("requests a pty with the default terminal type", func() { 420 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1)) 421 422 termType, _, _, _ := fakeSecureSession.RequestPtyArgsForCall(0) 423 Expect(termType).To(Equal("xterm")) 424 }) 425 }) 426 427 It("puts the terminal into raw mode and restores it after running the shell", func() { 428 Expect(fakeSecureSession.ShellCallCount()).To(Equal(1)) 429 Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1)) 430 Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(1)) 431 }) 432 433 Context("when the pty allocation fails", func() { 434 var ptyError error 435 436 BeforeEach(func() { 437 ptyError = errors.New("pty allocation error") 438 fakeSecureSession.RequestPtyReturns(ptyError) 439 }) 440 441 It("returns the error", func() { 442 Expect(sessionError).To(Equal(ptyError)) 443 }) 444 }) 445 446 Context("when placing the terminal into raw mode fails", func() { 447 BeforeEach(func() { 448 fakeTerminalHelper.SetRawTerminalReturns(nil, errors.New("woops")) 449 }) 450 451 It("keeps calm and carries on", func() { 452 Expect(fakeSecureSession.ShellCallCount()).To(Equal(1)) 453 }) 454 455 It("does not not restore the terminal", func() { 456 Expect(fakeSecureSession.ShellCallCount()).To(Equal(1)) 457 Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1)) 458 Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(0)) 459 }) 460 }) 461 }) 462 463 Context("when a command is specified", func() { 464 BeforeEach(func() { 465 opts.Command = []string{"echo", "-n", "hello"} 466 }) 467 468 Context("when a terminal is requested", func() { 469 BeforeEach(func() { 470 opts.TerminalRequest = options.RequestTTYYes 471 }) 472 473 It("requests a pty", func() { 474 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1)) 475 }) 476 }) 477 478 Context("when a terminal is not explicitly requested", func() { 479 It("does not request a pty", func() { 480 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0)) 481 }) 482 }) 483 }) 484 }) 485 486 Context("when stdin is not a terminal", func() { 487 BeforeEach(func() { 488 _, stdout, stderr := terminalHelper.StdStreams() 489 490 stdin := &fake_io.FakeReadCloser{} 491 stdin.ReadStub = func(p []byte) (int, error) { 492 return 0, io.EOF 493 } 494 495 fakeTerminalHelper.IsTerminalStub = terminalHelper.IsTerminal 496 fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo 497 fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize 498 fakeTerminalHelper.StdStreamsReturns(stdin, stdout, stderr) 499 terminalHelper = fakeTerminalHelper 500 }) 501 502 Context("when a terminal is not requested", func() { 503 It("does not request a pty", func() { 504 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0)) 505 }) 506 }) 507 508 Context("when a terminal is requested", func() { 509 BeforeEach(func() { 510 opts.TerminalRequest = options.RequestTTYYes 511 }) 512 513 It("does not request a pty", func() { 514 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0)) 515 }) 516 }) 517 }) 518 519 Context("when a terminal is forced", func() { 520 BeforeEach(func() { 521 opts.TerminalRequest = options.RequestTTYForce 522 }) 523 524 It("requests a pty", func() { 525 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1)) 526 }) 527 }) 528 529 Context("when a terminal is disabled", func() { 530 BeforeEach(func() { 531 opts.TerminalRequest = options.RequestTTYNo 532 }) 533 534 It("does not request a pty", func() { 535 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0)) 536 }) 537 }) 538 539 Context("when a command is not specified", func() { 540 It("requests an interactive shell", func() { 541 Expect(fakeSecureSession.ShellCallCount()).To(Equal(1)) 542 }) 543 544 Context("when the shell request returns an error", func() { 545 BeforeEach(func() { 546 fakeSecureSession.ShellReturns(errors.New("oh bother")) 547 }) 548 549 It("returns the error", func() { 550 Expect(sessionError).To(MatchError("oh bother")) 551 }) 552 }) 553 }) 554 555 Context("when a command is specifed", func() { 556 BeforeEach(func() { 557 opts.Command = []string{"echo", "-n", "hello"} 558 }) 559 560 It("starts the command", func() { 561 Expect(fakeSecureSession.StartCallCount()).To(Equal(1)) 562 Expect(fakeSecureSession.StartArgsForCall(0)).To(Equal("echo -n hello")) 563 }) 564 565 Context("when the command fails to start", func() { 566 BeforeEach(func() { 567 fakeSecureSession.StartReturns(errors.New("oh well")) 568 }) 569 570 It("returns the error", func() { 571 Expect(sessionError).To(MatchError("oh well")) 572 }) 573 }) 574 }) 575 576 Context("when the shell or command has started", func() { 577 var ( 578 stdin *fake_io.FakeReadCloser 579 stdout, stderr *fake_io.FakeWriter 580 stdinPipe *fake_io.FakeWriteCloser 581 stdoutPipe, stderrPipe *fake_io.FakeReader 582 ) 583 584 BeforeEach(func() { 585 stdin = &fake_io.FakeReadCloser{} 586 stdin.ReadStub = func(p []byte) (int, error) { 587 p[0] = 0 588 return 1, io.EOF 589 } 590 stdinPipe = &fake_io.FakeWriteCloser{} 591 stdinPipe.WriteStub = func(p []byte) (int, error) { 592 defer GinkgoRecover() 593 Expect(p[0]).To(Equal(byte(0))) 594 return 1, nil 595 } 596 597 stdoutPipe = &fake_io.FakeReader{} 598 stdoutPipe.ReadStub = func(p []byte) (int, error) { 599 p[0] = 1 600 return 1, io.EOF 601 } 602 stdout = &fake_io.FakeWriter{} 603 stdout.WriteStub = func(p []byte) (int, error) { 604 defer GinkgoRecover() 605 Expect(p[0]).To(Equal(byte(1))) 606 return 1, nil 607 } 608 609 stderrPipe = &fake_io.FakeReader{} 610 stderrPipe.ReadStub = func(p []byte) (int, error) { 611 p[0] = 2 612 return 1, io.EOF 613 } 614 stderr = &fake_io.FakeWriter{} 615 stderr.WriteStub = func(p []byte) (int, error) { 616 defer GinkgoRecover() 617 Expect(p[0]).To(Equal(byte(2))) 618 return 1, nil 619 } 620 621 fakeTerminalHelper.StdStreamsReturns(stdin, stdout, stderr) 622 terminalHelper = fakeTerminalHelper 623 624 fakeSecureSession.StdinPipeReturns(stdinPipe, nil) 625 fakeSecureSession.StdoutPipeReturns(stdoutPipe, nil) 626 fakeSecureSession.StderrPipeReturns(stderrPipe, nil) 627 628 fakeSecureSession.WaitReturns(errors.New("error result")) 629 }) 630 631 It("copies data from the stdin stream to the session stdin pipe", func() { 632 Eventually(stdin.ReadCallCount).Should(Equal(1)) 633 Eventually(stdinPipe.WriteCallCount).Should(Equal(1)) 634 }) 635 636 It("copies data from the session stdout pipe to the stdout stream", func() { 637 Eventually(stdoutPipe.ReadCallCount).Should(Equal(1)) 638 Eventually(stdout.WriteCallCount).Should(Equal(1)) 639 }) 640 641 It("copies data from the session stderr pipe to the stderr stream", func() { 642 Eventually(stderrPipe.ReadCallCount).Should(Equal(1)) 643 Eventually(stderr.WriteCallCount).Should(Equal(1)) 644 }) 645 646 It("waits for the session to end", func() { 647 Expect(fakeSecureSession.WaitCallCount()).To(Equal(1)) 648 }) 649 650 It("returns the result from wait", func() { 651 Expect(sessionError).To(MatchError("error result")) 652 }) 653 654 Context("when the session terminates before stream copies complete", func() { 655 var sessionErrorCh chan error 656 657 BeforeEach(func() { 658 sessionErrorCh = make(chan error, 1) 659 660 interactiveSessionInvoker = func(secureShell sshCmd.SecureShell) { 661 go func() { sessionErrorCh <- secureShell.InteractiveSession() }() 662 } 663 664 stdoutPipe.ReadStub = func(p []byte) (int, error) { 665 defer GinkgoRecover() 666 Eventually(fakeSecureSession.WaitCallCount).Should(Equal(1)) 667 Consistently(sessionErrorCh).ShouldNot(Receive()) 668 669 p[0] = 1 670 return 1, io.EOF 671 } 672 673 stderrPipe.ReadStub = func(p []byte) (int, error) { 674 defer GinkgoRecover() 675 Eventually(fakeSecureSession.WaitCallCount).Should(Equal(1)) 676 Consistently(sessionErrorCh).ShouldNot(Receive()) 677 678 p[0] = 2 679 return 1, io.EOF 680 } 681 }) 682 683 It("waits for the copies to complete", func() { 684 Eventually(sessionErrorCh).Should(Receive()) 685 Expect(stdoutPipe.ReadCallCount()).To(Equal(1)) 686 Expect(stderrPipe.ReadCallCount()).To(Equal(1)) 687 }) 688 }) 689 690 Context("when stdin is closed", func() { 691 BeforeEach(func() { 692 stdin.ReadStub = func(p []byte) (int, error) { 693 defer GinkgoRecover() 694 Consistently(stdinPipe.CloseCallCount).Should(Equal(0)) 695 p[0] = 0 696 return 1, io.EOF 697 } 698 }) 699 700 It("closes the stdinPipe", func() { 701 Eventually(stdinPipe.CloseCallCount).Should(Equal(1)) 702 }) 703 }) 704 }) 705 706 Context("when stdout is a terminal and a window size change occurs", func() { 707 var master, slave *os.File 708 709 BeforeEach(func() { 710 stdin, _, stderr := terminalHelper.StdStreams() 711 712 var err error 713 master, slave, err = pty.Open() 714 Expect(err).NotTo(HaveOccurred()) 715 716 fakeTerminalHelper.IsTerminalStub = terminalHelper.IsTerminal 717 fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo 718 fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize 719 fakeTerminalHelper.StdStreamsReturns(stdin, slave, stderr) 720 terminalHelper = fakeTerminalHelper 721 722 winsize := &term.Winsize{Height: 100, Width: 100} 723 err = term.SetWinsize(slave.Fd(), winsize) 724 Expect(err).NotTo(HaveOccurred()) 725 726 fakeSecureSession.WaitStub = func() error { 727 fakeSecureSession.SendRequestCallCount() 728 Expect(fakeSecureSession.SendRequestCallCount()).To(Equal(0)) 729 730 // No dimension change 731 for i := 0; i < 3; i++ { 732 winsize := &term.Winsize{Height: 100, Width: 100} 733 err = term.SetWinsize(slave.Fd(), winsize) 734 Expect(err).NotTo(HaveOccurred()) 735 } 736 737 winsize := &term.Winsize{Height: 100, Width: 200} 738 err = term.SetWinsize(slave.Fd(), winsize) 739 Expect(err).NotTo(HaveOccurred()) 740 741 err = syscall.Kill(syscall.Getpid(), syscall.SIGWINCH) 742 Expect(err).NotTo(HaveOccurred()) 743 744 Eventually(fakeSecureSession.SendRequestCallCount).Should(Equal(1)) 745 return nil 746 } 747 }) 748 749 AfterEach(func() { 750 master.Close() 751 slave.Close() 752 }) 753 754 It("sends window change events when the window dimensions change", func() { 755 Expect(fakeSecureSession.SendRequestCallCount()).To(Equal(1)) 756 757 requestType, wantReply, message := fakeSecureSession.SendRequestArgsForCall(0) 758 Expect(requestType).To(Equal("window-change")) 759 Expect(wantReply).To(BeFalse()) 760 761 type resizeMessage struct { 762 Width uint32 763 Height uint32 764 PixelWidth uint32 765 PixelHeight uint32 766 } 767 var resizeMsg resizeMessage 768 769 err := ssh.Unmarshal(message, &resizeMsg) 770 Expect(err).NotTo(HaveOccurred()) 771 772 Expect(resizeMsg).To(Equal(resizeMessage{Height: 100, Width: 200})) 773 }) 774 }) 775 776 Describe("keep alive messages", func() { 777 var times []time.Time 778 var timesCh chan []time.Time 779 var done chan struct{} 780 781 BeforeEach(func() { 782 keepAliveDuration = 100 * time.Millisecond 783 784 times = []time.Time{} 785 timesCh = make(chan []time.Time, 1) 786 done = make(chan struct{}, 1) 787 788 fakeConnection.SendRequestStub = func(reqName string, wantReply bool, message []byte) (bool, []byte, error) { 789 Expect(reqName).To(Equal("keepalive@cloudfoundry.org")) 790 Expect(wantReply).To(BeTrue()) 791 Expect(message).To(BeNil()) 792 793 times = append(times, time.Now()) 794 if len(times) == 3 { 795 timesCh <- times 796 close(done) 797 } 798 return true, nil, nil 799 } 800 801 fakeSecureSession.WaitStub = func() error { 802 Eventually(done).Should(BeClosed()) 803 return nil 804 } 805 }) 806 807 It("sends keep alive messages at the expected interval", func() { 808 times := <-timesCh 809 Expect(times[2]).To(BeTemporally("~", times[0].Add(200*time.Millisecond), 100*time.Millisecond)) 810 }) 811 }) 812 }) 813 814 Describe("LocalPortForward", func() { 815 var ( 816 opts *options.SSHOptions 817 localForwardError error 818 819 echoAddress string 820 echoListener *fake_net.FakeListener 821 echoHandler *fake_server.FakeConnectionHandler 822 echoServer *server.Server 823 824 localAddress string 825 826 realLocalListener net.Listener 827 fakeLocalListener *fake_net.FakeListener 828 ) 829 830 BeforeEach(func() { 831 logger := lagertest.NewTestLogger("test") 832 833 var err error 834 realLocalListener, err = net.Listen("tcp", "127.0.0.1:0") 835 Expect(err).NotTo(HaveOccurred()) 836 837 localAddress = realLocalListener.Addr().String() 838 fakeListenerFactory.ListenReturns(realLocalListener, nil) 839 840 echoHandler = &fake_server.FakeConnectionHandler{} 841 echoHandler.HandleConnectionStub = func(conn net.Conn) { 842 io.Copy(conn, conn) 843 conn.Close() 844 } 845 846 realListener, err := net.Listen("tcp", "127.0.0.1:0") 847 Expect(err).NotTo(HaveOccurred()) 848 echoAddress = realListener.Addr().String() 849 850 echoListener = &fake_net.FakeListener{} 851 echoListener.AcceptStub = realListener.Accept 852 echoListener.CloseStub = realListener.Close 853 echoListener.AddrStub = realListener.Addr 854 855 fakeLocalListener = &fake_net.FakeListener{} 856 fakeLocalListener.AcceptReturns(nil, errors.New("Not Accepting Connections")) 857 858 echoServer = server.NewServer(logger.Session("echo"), "", echoHandler) 859 echoServer.SetListener(echoListener) 860 go echoServer.Serve() 861 862 opts = &options.SSHOptions{ 863 AppName: "app-1", 864 ForwardSpecs: []options.ForwardSpec{{ 865 ListenAddress: localAddress, 866 ConnectAddress: echoAddress, 867 }}, 868 } 869 870 currentApp.State = "STARTED" 871 currentApp.Diego = true 872 873 sshEndpointFingerprint = "" 874 sshEndpoint = "" 875 876 token = "" 877 878 fakeSecureClient.DialStub = net.Dial 879 }) 880 881 JustBeforeEach(func() { 882 connectErr := secureShell.Connect(opts) 883 Expect(connectErr).NotTo(HaveOccurred()) 884 885 localForwardError = secureShell.LocalPortForward() 886 }) 887 888 AfterEach(func() { 889 err := secureShell.Close() 890 Expect(err).NotTo(HaveOccurred()) 891 echoServer.Shutdown() 892 893 realLocalListener.Close() 894 }) 895 896 validateConnectivity := func(addr string) { 897 conn, err := net.Dial("tcp", addr) 898 Expect(err).NotTo(HaveOccurred()) 899 900 msg := fmt.Sprintf("Hello from %s\n", addr) 901 n, err := conn.Write([]byte(msg)) 902 Expect(err).NotTo(HaveOccurred()) 903 Expect(n).To(Equal(len(msg))) 904 905 response := make([]byte, len(msg)) 906 n, err = conn.Read(response) 907 Expect(err).NotTo(HaveOccurred()) 908 Expect(n).To(Equal(len(msg))) 909 910 err = conn.Close() 911 Expect(err).NotTo(HaveOccurred()) 912 913 Expect(response).To(Equal([]byte(msg))) 914 } 915 916 It("dials the connect address when a local connection is made", func() { 917 Expect(localForwardError).NotTo(HaveOccurred()) 918 919 conn, err := net.Dial("tcp", localAddress) 920 Expect(err).NotTo(HaveOccurred()) 921 922 Eventually(echoListener.AcceptCallCount).Should(BeNumerically(">=", 1)) 923 Eventually(fakeSecureClient.DialCallCount).Should(Equal(1)) 924 925 network, addr := fakeSecureClient.DialArgsForCall(0) 926 Expect(network).To(Equal("tcp")) 927 Expect(addr).To(Equal(echoAddress)) 928 929 Expect(conn.Close()).NotTo(HaveOccurred()) 930 }) 931 932 It("copies data between the local and remote connections", func() { 933 validateConnectivity(localAddress) 934 }) 935 936 Context("when a local connection is already open", func() { 937 var ( 938 conn net.Conn 939 err error 940 ) 941 942 JustBeforeEach(func() { 943 conn, err = net.Dial("tcp", localAddress) 944 Expect(err).NotTo(HaveOccurred()) 945 }) 946 947 AfterEach(func() { 948 err = conn.Close() 949 Expect(err).NotTo(HaveOccurred()) 950 }) 951 952 It("allows for new incoming connections as well", func() { 953 validateConnectivity(localAddress) 954 }) 955 }) 956 957 Context("when there are multiple port forward specs", func() { 958 var realLocalListener2 net.Listener 959 var localAddress2 string 960 961 BeforeEach(func() { 962 var err error 963 realLocalListener2, err = net.Listen("tcp", "127.0.0.1:0") 964 Expect(err).NotTo(HaveOccurred()) 965 966 localAddress2 = realLocalListener2.Addr().String() 967 968 fakeListenerFactory.ListenStub = func(network, addr string) (net.Listener, error) { 969 if addr == localAddress { 970 return realLocalListener, nil 971 } 972 973 if addr == localAddress2 { 974 return realLocalListener2, nil 975 } 976 977 return nil, errors.New("unexpected address") 978 } 979 980 opts = &options.SSHOptions{ 981 AppName: "app-1", 982 ForwardSpecs: []options.ForwardSpec{{ 983 ListenAddress: localAddress, 984 ConnectAddress: echoAddress, 985 }, { 986 ListenAddress: localAddress2, 987 ConnectAddress: echoAddress, 988 }}, 989 } 990 }) 991 992 AfterEach(func() { 993 realLocalListener2.Close() 994 }) 995 996 It("listens to all the things", func() { 997 Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(2)) 998 999 network, addr := fakeListenerFactory.ListenArgsForCall(0) 1000 Expect(network).To(Equal("tcp")) 1001 Expect(addr).To(Equal(localAddress)) 1002 1003 network, addr = fakeListenerFactory.ListenArgsForCall(1) 1004 Expect(network).To(Equal("tcp")) 1005 Expect(addr).To(Equal(localAddress2)) 1006 }) 1007 1008 It("forwards to the correct target", func() { 1009 validateConnectivity(localAddress) 1010 validateConnectivity(localAddress2) 1011 }) 1012 1013 Context("when the secure client is closed", func() { 1014 BeforeEach(func() { 1015 fakeListenerFactory.ListenReturns(fakeLocalListener, nil) 1016 fakeLocalListener.AcceptReturns(nil, errors.New("not accepting connections")) 1017 }) 1018 1019 It("closes the listeners ", func() { 1020 Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(2)) 1021 Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(2)) 1022 1023 originalCloseCount := fakeLocalListener.CloseCallCount() 1024 err := secureShell.Close() 1025 Expect(err).NotTo(HaveOccurred()) 1026 Expect(fakeLocalListener.CloseCallCount()).Should(Equal(originalCloseCount + 2)) 1027 }) 1028 }) 1029 }) 1030 1031 Context("when listen fails", func() { 1032 BeforeEach(func() { 1033 fakeListenerFactory.ListenReturns(nil, errors.New("failure is an option")) 1034 }) 1035 1036 It("returns the error", func() { 1037 Expect(localForwardError).To(MatchError("failure is an option")) 1038 }) 1039 }) 1040 1041 Context("when the client it closed", func() { 1042 BeforeEach(func() { 1043 fakeListenerFactory.ListenReturns(fakeLocalListener, nil) 1044 fakeLocalListener.AcceptReturns(nil, errors.New("not accepting and connections")) 1045 }) 1046 1047 It("closes the listener when the client is closed", func() { 1048 Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(1)) 1049 Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(1)) 1050 1051 originalCloseCount := fakeLocalListener.CloseCallCount() 1052 err := secureShell.Close() 1053 Expect(err).NotTo(HaveOccurred()) 1054 Expect(fakeLocalListener.CloseCallCount()).Should(Equal(originalCloseCount + 1)) 1055 }) 1056 }) 1057 1058 Context("when accept fails", func() { 1059 var fakeConn *fake_net.FakeConn 1060 BeforeEach(func() { 1061 fakeConn = &fake_net.FakeConn{} 1062 fakeConn.ReadReturns(0, io.EOF) 1063 1064 fakeListenerFactory.ListenReturns(fakeLocalListener, nil) 1065 }) 1066 1067 Context("with a permanent error", func() { 1068 BeforeEach(func() { 1069 fakeLocalListener.AcceptReturns(nil, errors.New("boom")) 1070 }) 1071 1072 It("stops trying to accept connections", func() { 1073 Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(1)) 1074 Consistently(fakeLocalListener.AcceptCallCount).Should(Equal(1)) 1075 Expect(fakeLocalListener.CloseCallCount()).To(Equal(1)) 1076 }) 1077 }) 1078 1079 Context("with a temporary error", func() { 1080 var timeCh chan time.Time 1081 1082 BeforeEach(func() { 1083 timeCh = make(chan time.Time, 3) 1084 1085 fakeLocalListener.AcceptStub = func() (net.Conn, error) { 1086 timeCh := timeCh 1087 if fakeLocalListener.AcceptCallCount() > 3 { 1088 close(timeCh) 1089 return nil, test_helpers.NewTestNetError(false, false) 1090 } else { 1091 timeCh <- time.Now() 1092 return nil, test_helpers.NewTestNetError(false, true) 1093 } 1094 } 1095 }) 1096 1097 It("retries connecting after a short delay", func() { 1098 Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(3)) 1099 Expect(timeCh).To(HaveLen(3)) 1100 1101 times := make([]time.Time, 0) 1102 for t := range timeCh { 1103 times = append(times, t) 1104 } 1105 1106 Expect(times[1]).To(BeTemporally("~", times[0].Add(115*time.Millisecond), 30*time.Millisecond)) 1107 Expect(times[2]).To(BeTemporally("~", times[1].Add(115*time.Millisecond), 30*time.Millisecond)) 1108 }) 1109 }) 1110 }) 1111 1112 Context("when dialing the connect address fails", func() { 1113 var fakeTarget *fake_net.FakeConn 1114 1115 BeforeEach(func() { 1116 fakeTarget = &fake_net.FakeConn{} 1117 fakeSecureClient.DialReturns(fakeTarget, errors.New("boom")) 1118 }) 1119 1120 It("does not call close on the target connection", func() { 1121 Consistently(fakeTarget.CloseCallCount).Should(Equal(0)) 1122 }) 1123 }) 1124 }) 1125 1126 Describe("Wait", func() { 1127 var opts *options.SSHOptions 1128 var waitErr error 1129 1130 BeforeEach(func() { 1131 opts = &options.SSHOptions{ 1132 AppName: "app-1", 1133 } 1134 1135 currentApp.State = "STARTED" 1136 currentApp.Diego = true 1137 1138 sshEndpointFingerprint = "" 1139 sshEndpoint = "" 1140 1141 token = "" 1142 }) 1143 1144 JustBeforeEach(func() { 1145 connectErr := secureShell.Connect(opts) 1146 Expect(connectErr).NotTo(HaveOccurred()) 1147 1148 waitErr = secureShell.Wait() 1149 }) 1150 1151 It("calls wait on the secureClient", func() { 1152 Expect(waitErr).NotTo(HaveOccurred()) 1153 Expect(fakeSecureClient.WaitCallCount()).To(Equal(1)) 1154 }) 1155 1156 Describe("keep alive messages", func() { 1157 var times []time.Time 1158 var timesCh chan []time.Time 1159 var done chan struct{} 1160 1161 BeforeEach(func() { 1162 keepAliveDuration = 100 * time.Millisecond 1163 1164 times = []time.Time{} 1165 timesCh = make(chan []time.Time, 1) 1166 done = make(chan struct{}, 1) 1167 1168 fakeConnection.SendRequestStub = func(reqName string, wantReply bool, message []byte) (bool, []byte, error) { 1169 Expect(reqName).To(Equal("keepalive@cloudfoundry.org")) 1170 Expect(wantReply).To(BeTrue()) 1171 Expect(message).To(BeNil()) 1172 1173 times = append(times, time.Now()) 1174 if len(times) == 3 { 1175 timesCh <- times 1176 close(done) 1177 } 1178 return true, nil, nil 1179 } 1180 1181 fakeSecureClient.WaitStub = func() error { 1182 Eventually(done).Should(BeClosed()) 1183 return nil 1184 } 1185 }) 1186 1187 It("sends keep alive messages at the expected interval", func() { 1188 Expect(waitErr).NotTo(HaveOccurred()) 1189 times := <-timesCh 1190 Expect(times[2]).To(BeTemporally("~", times[0].Add(200*time.Millisecond), 100*time.Millisecond)) 1191 }) 1192 }) 1193 }) 1194 1195 Describe("Close", func() { 1196 var opts *options.SSHOptions 1197 1198 BeforeEach(func() { 1199 opts = &options.SSHOptions{ 1200 AppName: "app-1", 1201 } 1202 1203 currentApp.State = "STARTED" 1204 currentApp.Diego = true 1205 1206 sshEndpointFingerprint = "" 1207 sshEndpoint = "" 1208 1209 token = "" 1210 }) 1211 1212 JustBeforeEach(func() { 1213 connectErr := secureShell.Connect(opts) 1214 Expect(connectErr).NotTo(HaveOccurred()) 1215 }) 1216 1217 It("calls close on the secureClient", func() { 1218 err := secureShell.Close() 1219 Expect(err).NotTo(HaveOccurred()) 1220 1221 Expect(fakeSecureClient.CloseCallCount()).To(Equal(1)) 1222 }) 1223 }) 1224 })