github.com/sleungcy-sap/cli@v7.1.0+incompatible/cf/ssh/ssh_test.go (about) 1 // +build !windows,!386 2 3 // skipping 386 because lager uses UInt64 in Session() 4 // skipping windows because Unix/Linux only syscall in test. 5 // should refactor out the conflicts so we could test this package in multi platforms. 6 7 package sshCmd_test 8 9 import ( 10 "errors" 11 "fmt" 12 "io" 13 "net" 14 "os" 15 "syscall" 16 "time" 17 18 "code.cloudfoundry.org/cli/cf/models" 19 sshCmd "code.cloudfoundry.org/cli/cf/ssh" 20 "code.cloudfoundry.org/cli/cf/ssh/options" 21 "code.cloudfoundry.org/cli/cf/ssh/sshfakes" 22 "code.cloudfoundry.org/cli/cf/ssh/terminal" 23 "code.cloudfoundry.org/cli/cf/ssh/terminal/terminalfakes" 24 "code.cloudfoundry.org/diego-ssh/server" 25 fake_server "code.cloudfoundry.org/diego-ssh/server/fakes" 26 "code.cloudfoundry.org/diego-ssh/test_helpers" 27 "code.cloudfoundry.org/diego-ssh/test_helpers/fake_io" 28 "code.cloudfoundry.org/diego-ssh/test_helpers/fake_net" 29 "code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh" 30 "code.cloudfoundry.org/lager/lagertest" 31 "github.com/kr/pty" 32 "github.com/moby/moby/pkg/term" 33 "golang.org/x/crypto/ssh" 34 35 . "github.com/onsi/ginkgo" 36 . "github.com/onsi/gomega" 37 ) 38 39 var _ = Describe("SSH", func() { 40 var ( 41 fakeTerminalHelper *terminalfakes.FakeTerminalHelper 42 fakeListenerFactory *sshfakes.FakeListenerFactory 43 44 fakeConnection *fake_ssh.FakeConn 45 fakeSecureClient *sshfakes.FakeSecureClient 46 fakeSecureDialer *sshfakes.FakeSecureDialer 47 fakeSecureSession *sshfakes.FakeSecureSession 48 49 terminalHelper terminal.TerminalHelper 50 keepAliveDuration time.Duration 51 secureShell sshCmd.SecureShell 52 53 stdinPipe *fake_io.FakeWriteCloser 54 55 currentApp models.Application 56 sshEndpointFingerprint string 57 sshEndpoint string 58 token string 59 ) 60 61 BeforeEach(func() { 62 fakeTerminalHelper = new(terminalfakes.FakeTerminalHelper) 63 terminalHelper = terminal.DefaultHelper() 64 65 fakeListenerFactory = new(sshfakes.FakeListenerFactory) 66 fakeListenerFactory.ListenStub = net.Listen 67 68 keepAliveDuration = 30 * time.Second 69 70 currentApp = models.Application{} 71 sshEndpoint = "" 72 sshEndpointFingerprint = "" 73 token = "" 74 75 fakeConnection = new(fake_ssh.FakeConn) 76 fakeSecureClient = new(sshfakes.FakeSecureClient) 77 fakeSecureDialer = new(sshfakes.FakeSecureDialer) 78 fakeSecureSession = new(sshfakes.FakeSecureSession) 79 80 fakeSecureDialer.DialReturns(fakeSecureClient, nil) 81 fakeSecureClient.NewSessionReturns(fakeSecureSession, nil) 82 fakeSecureClient.ConnReturns(fakeConnection) 83 84 stdinPipe = &fake_io.FakeWriteCloser{} 85 stdinPipe.WriteStub = func(p []byte) (int, error) { 86 return len(p), nil 87 } 88 89 stdoutPipe := &fake_io.FakeReader{} 90 stdoutPipe.ReadStub = func(p []byte) (int, error) { 91 return 0, io.EOF 92 } 93 94 stderrPipe := &fake_io.FakeReader{} 95 stderrPipe.ReadStub = func(p []byte) (int, error) { 96 return 0, io.EOF 97 } 98 99 fakeSecureSession.StdinPipeReturns(stdinPipe, nil) 100 fakeSecureSession.StdoutPipeReturns(stdoutPipe, nil) 101 fakeSecureSession.StderrPipeReturns(stderrPipe, nil) 102 }) 103 104 JustBeforeEach(func() { 105 secureShell = sshCmd.NewSecureShell( 106 fakeSecureDialer, 107 terminalHelper, 108 fakeListenerFactory, 109 keepAliveDuration, 110 currentApp, 111 sshEndpointFingerprint, 112 sshEndpoint, 113 token, 114 ) 115 }) 116 117 Describe("Validation", func() { 118 var connectErr error 119 var opts *options.SSHOptions 120 121 BeforeEach(func() { 122 opts = &options.SSHOptions{ 123 AppName: "app-1", 124 } 125 }) 126 127 JustBeforeEach(func() { 128 connectErr = secureShell.Connect(opts) 129 }) 130 131 Context("when the app model and endpoint info are successfully acquired", func() { 132 BeforeEach(func() { 133 token = "" 134 currentApp.State = "STARTED" 135 }) 136 137 Context("when the app is not in the 'STARTED' state", func() { 138 BeforeEach(func() { 139 currentApp.State = "STOPPED" 140 }) 141 142 It("returns an error", func() { 143 Expect(connectErr).To(MatchError(MatchRegexp("Application.*not in the STARTED state"))) 144 }) 145 }) 146 147 Context("when dialing fails", func() { 148 var dialError = errors.New("woops") 149 150 BeforeEach(func() { 151 fakeSecureDialer.DialReturns(nil, dialError) 152 }) 153 154 It("returns the dial error", func() { 155 Expect(connectErr).To(Equal(dialError)) 156 Expect(fakeSecureDialer.DialCallCount()).To(Equal(1)) 157 }) 158 }) 159 }) 160 }) 161 162 Describe("InteractiveSession", func() { 163 var opts *options.SSHOptions 164 var sessionError error 165 var interactiveSessionInvoker func(secureShell sshCmd.SecureShell) 166 167 BeforeEach(func() { 168 sshEndpoint = "ssh.example.com:22" 169 170 opts = &options.SSHOptions{ 171 AppName: "app-name", 172 Index: 2, 173 } 174 175 currentApp.State = "STARTED" 176 currentApp.GUID = "app-guid" 177 token = "bearer token" 178 179 interactiveSessionInvoker = func(secureShell sshCmd.SecureShell) { 180 sessionError = secureShell.InteractiveSession() 181 } 182 }) 183 184 JustBeforeEach(func() { 185 connectErr := secureShell.Connect(opts) 186 Expect(connectErr).NotTo(HaveOccurred()) 187 interactiveSessionInvoker(secureShell) 188 }) 189 190 It("dials the correct endpoint as the correct user", func() { 191 Expect(fakeSecureDialer.DialCallCount()).To(Equal(1)) 192 193 network, address, config := fakeSecureDialer.DialArgsForCall(0) 194 Expect(network).To(Equal("tcp")) 195 Expect(address).To(Equal("ssh.example.com:22")) 196 Expect(config.Auth).NotTo(BeEmpty()) 197 Expect(config.User).To(Equal("cf:app-guid/2")) 198 Expect(config.HostKeyCallback).NotTo(BeNil()) 199 }) 200 201 Context("when host key validation is enabled", func() { 202 var callback func(hostname string, remote net.Addr, key ssh.PublicKey) error 203 var addr net.Addr 204 205 JustBeforeEach(func() { 206 Expect(fakeSecureDialer.DialCallCount()).To(Equal(1)) 207 _, _, config := fakeSecureDialer.DialArgsForCall(0) 208 callback = config.HostKeyCallback 209 210 listener, err := net.Listen("tcp", "localhost:0") 211 Expect(err).NotTo(HaveOccurred()) 212 213 addr = listener.Addr() 214 listener.Close() 215 }) 216 217 Context("when the md5 fingerprint matches", func() { 218 BeforeEach(func() { 219 sshEndpointFingerprint = "41:ce:56:e6:9c:42:a9:c6:9e:68:ac:e3:4d:f6:38:79" 220 }) 221 222 It("does not return an error", func() { 223 Expect(callback("", addr, TestHostKey.PublicKey())).ToNot(HaveOccurred()) 224 }) 225 }) 226 227 Context("when the hex sha1 fingerprint matches", func() { 228 BeforeEach(func() { 229 sshEndpointFingerprint = "a8:e2:67:cb:ea:2a:6e:23:a1:72:ce:8f:07:92:15:ee:1f:82:f8:ca" 230 }) 231 232 It("does not return an error", func() { 233 Expect(callback("", addr, TestHostKey.PublicKey())).ToNot(HaveOccurred()) 234 }) 235 }) 236 237 Context("when the base64 sha256 fingerprint matches", func() { 238 BeforeEach(func() { 239 sshEndpointFingerprint = "sp/jrLuj66r+yrLDUKZdJU5tdzt4mq/UaSiNBjpgr+8" 240 }) 241 242 It("does not return an error", func() { 243 Expect(callback("", addr, TestHostKey.PublicKey())).ToNot(HaveOccurred()) 244 }) 245 }) 246 247 Context("when the base64 SHA256 fingerprint does not match", func() { 248 BeforeEach(func() { 249 sshEndpointFingerprint = "0000000000000000000000000000000000000000000" 250 }) 251 252 It("returns an error'", func() { 253 err := callback("", addr, TestHostKey.PublicKey()) 254 Expect(err).To(MatchError(MatchRegexp("Host key verification failed\\."))) 255 Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\""))) 256 }) 257 }) 258 259 Context("when the hex SHA1 fingerprint does not match", func() { 260 BeforeEach(func() { 261 sshEndpointFingerprint = "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00" 262 }) 263 264 It("returns an error'", func() { 265 err := callback("", addr, TestHostKey.PublicKey()) 266 Expect(err).To(MatchError(MatchRegexp("Host key verification failed\\."))) 267 Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\""))) 268 }) 269 }) 270 271 Context("when the MD5 fingerprint does not match", func() { 272 BeforeEach(func() { 273 sshEndpointFingerprint = "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00" 274 }) 275 276 It("returns an error'", func() { 277 err := callback("", addr, TestHostKey.PublicKey()) 278 Expect(err).To(MatchError(MatchRegexp("Host key verification failed\\."))) 279 Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\""))) 280 }) 281 }) 282 283 Context("when no fingerprint is present in endpoint info", func() { 284 BeforeEach(func() { 285 sshEndpointFingerprint = "" 286 sshEndpoint = "" 287 }) 288 289 It("returns an error'", func() { 290 err := callback("", addr, TestHostKey.PublicKey()) 291 Expect(err).To(MatchError(MatchRegexp("Unable to verify identity of host\\."))) 292 Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\""))) 293 }) 294 }) 295 296 Context("when the fingerprint length doesn't make sense", func() { 297 BeforeEach(func() { 298 sshEndpointFingerprint = "garbage" 299 }) 300 301 It("returns an error", func() { 302 err := callback("", addr, TestHostKey.PublicKey()) 303 Eventually(err).Should(MatchError(MatchRegexp("Unsupported host key fingerprint format"))) 304 }) 305 }) 306 }) 307 308 Context("when the skip host validation flag is set", func() { 309 BeforeEach(func() { 310 opts.SkipHostValidation = true 311 }) 312 313 It("removes the HostKeyCallback from the client config", func() { 314 Expect(fakeSecureDialer.DialCallCount()).To(Equal(1)) 315 316 _, _, config := fakeSecureDialer.DialArgsForCall(0) 317 Expect(config.HostKeyCallback("some-addr", nil, nil)).To(BeNil()) 318 }) 319 }) 320 321 Context("when dialing is successful", func() { 322 BeforeEach(func() { 323 fakeTerminalHelper.StdStreamsStub = terminalHelper.StdStreams 324 terminalHelper = fakeTerminalHelper 325 }) 326 327 It("creates a new secure shell session", func() { 328 Expect(fakeSecureClient.NewSessionCallCount()).To(Equal(1)) 329 }) 330 331 It("closes the session", func() { 332 Expect(fakeSecureSession.CloseCallCount()).To(Equal(1)) 333 }) 334 335 It("allocates standard streams", func() { 336 Expect(fakeTerminalHelper.StdStreamsCallCount()).To(Equal(1)) 337 }) 338 339 It("gets a stdin pipe for the session", func() { 340 Expect(fakeSecureSession.StdinPipeCallCount()).To(Equal(1)) 341 }) 342 343 Context("when getting the stdin pipe fails", func() { 344 BeforeEach(func() { 345 fakeSecureSession.StdinPipeReturns(nil, errors.New("woops")) 346 }) 347 348 It("returns the error", func() { 349 Expect(sessionError).Should(MatchError("woops")) 350 }) 351 }) 352 353 It("gets a stdout pipe for the session", func() { 354 Expect(fakeSecureSession.StdoutPipeCallCount()).To(Equal(1)) 355 }) 356 357 Context("when getting the stdout pipe fails", func() { 358 BeforeEach(func() { 359 fakeSecureSession.StdoutPipeReturns(nil, errors.New("woops")) 360 }) 361 362 It("returns the error", func() { 363 Expect(sessionError).Should(MatchError("woops")) 364 }) 365 }) 366 367 It("gets a stderr pipe for the session", func() { 368 Expect(fakeSecureSession.StderrPipeCallCount()).To(Equal(1)) 369 }) 370 371 Context("when getting the stderr pipe fails", func() { 372 BeforeEach(func() { 373 fakeSecureSession.StderrPipeReturns(nil, errors.New("woops")) 374 }) 375 376 It("returns the error", func() { 377 Expect(sessionError).Should(MatchError("woops")) 378 }) 379 }) 380 }) 381 382 Context("when stdin is a terminal", func() { 383 var master, slave *os.File 384 385 BeforeEach(func() { 386 _, stdout, stderr := terminalHelper.StdStreams() 387 388 var err error 389 master, slave, err = pty.Open() 390 Expect(err).NotTo(HaveOccurred()) 391 392 fakeTerminalHelper.IsTerminalStub = terminalHelper.IsTerminal 393 fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo 394 fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize 395 fakeTerminalHelper.StdStreamsReturns(slave, stdout, stderr) 396 terminalHelper = fakeTerminalHelper 397 }) 398 399 AfterEach(func() { 400 master.Close() 401 // slave.Close() // race 402 }) 403 404 Context("when a command is not specified", func() { 405 var terminalType string 406 407 BeforeEach(func() { 408 terminalType = os.Getenv("TERM") 409 os.Setenv("TERM", "test-terminal-type") 410 411 winsize := &term.Winsize{Width: 1024, Height: 256} 412 fakeTerminalHelper.GetWinsizeReturns(winsize, nil) 413 414 fakeSecureSession.ShellStub = func() error { 415 Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1)) 416 Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(0)) 417 return nil 418 } 419 }) 420 421 AfterEach(func() { 422 os.Setenv("TERM", terminalType) 423 }) 424 425 It("requests a pty with the correct terminal type, window size, and modes", func() { 426 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1)) 427 Expect(fakeTerminalHelper.GetWinsizeCallCount()).To(Equal(1)) 428 429 termType, height, width, modes := fakeSecureSession.RequestPtyArgsForCall(0) 430 Expect(termType).To(Equal("test-terminal-type")) 431 Expect(height).To(Equal(256)) 432 Expect(width).To(Equal(1024)) 433 434 expectedModes := ssh.TerminalModes{ 435 ssh.ECHO: 1, 436 ssh.TTY_OP_ISPEED: 115200, 437 ssh.TTY_OP_OSPEED: 115200, 438 } 439 Expect(modes).To(Equal(expectedModes)) 440 }) 441 442 Context("when the TERM environment variable is not set", func() { 443 BeforeEach(func() { 444 os.Unsetenv("TERM") 445 }) 446 447 It("requests a pty with the default terminal type", func() { 448 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1)) 449 450 termType, _, _, _ := fakeSecureSession.RequestPtyArgsForCall(0) 451 Expect(termType).To(Equal("xterm")) 452 }) 453 }) 454 455 It("puts the terminal into raw mode and restores it after running the shell", func() { 456 Expect(fakeSecureSession.ShellCallCount()).To(Equal(1)) 457 Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1)) 458 Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(1)) 459 }) 460 461 Context("when the pty allocation fails", func() { 462 var ptyError error 463 464 BeforeEach(func() { 465 ptyError = errors.New("pty allocation error") 466 fakeSecureSession.RequestPtyReturns(ptyError) 467 }) 468 469 It("returns the error", func() { 470 Expect(sessionError).To(Equal(ptyError)) 471 }) 472 }) 473 474 Context("when placing the terminal into raw mode fails", func() { 475 BeforeEach(func() { 476 fakeTerminalHelper.SetRawTerminalReturns(nil, errors.New("woops")) 477 }) 478 479 It("keeps calm and carries on", func() { 480 Expect(fakeSecureSession.ShellCallCount()).To(Equal(1)) 481 }) 482 483 It("does not not restore the terminal", func() { 484 Expect(fakeSecureSession.ShellCallCount()).To(Equal(1)) 485 Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1)) 486 Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(0)) 487 }) 488 }) 489 }) 490 491 Context("when a command is specified", func() { 492 BeforeEach(func() { 493 opts.Command = []string{"echo", "-n", "hello"} 494 }) 495 496 Context("when a terminal is requested", func() { 497 BeforeEach(func() { 498 opts.TerminalRequest = options.RequestTTYYes 499 }) 500 501 It("requests a pty", func() { 502 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1)) 503 }) 504 }) 505 506 Context("when a terminal is not explicitly requested", func() { 507 It("does not request a pty", func() { 508 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0)) 509 }) 510 }) 511 }) 512 }) 513 514 Context("when stdin is not a terminal", func() { 515 BeforeEach(func() { 516 _, stdout, stderr := terminalHelper.StdStreams() 517 518 stdin := &fake_io.FakeReadCloser{} 519 stdin.ReadStub = func(p []byte) (int, error) { 520 return 0, io.EOF 521 } 522 523 fakeTerminalHelper.IsTerminalStub = terminalHelper.IsTerminal 524 fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo 525 fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize 526 fakeTerminalHelper.StdStreamsReturns(stdin, stdout, stderr) 527 terminalHelper = fakeTerminalHelper 528 }) 529 530 Context("when a terminal is not requested", func() { 531 It("does not request a pty", func() { 532 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0)) 533 }) 534 }) 535 536 Context("when a terminal is requested", func() { 537 BeforeEach(func() { 538 opts.TerminalRequest = options.RequestTTYYes 539 }) 540 541 It("does not request a pty", func() { 542 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0)) 543 }) 544 }) 545 }) 546 547 Context("when a terminal is forced", func() { 548 BeforeEach(func() { 549 opts.TerminalRequest = options.RequestTTYForce 550 }) 551 552 It("requests a pty", func() { 553 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1)) 554 }) 555 }) 556 557 Context("when a terminal is disabled", func() { 558 BeforeEach(func() { 559 opts.TerminalRequest = options.RequestTTYNo 560 }) 561 562 It("does not request a pty", func() { 563 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0)) 564 }) 565 }) 566 567 Context("when a command is not specified", func() { 568 It("requests an interactive shell", func() { 569 Expect(fakeSecureSession.ShellCallCount()).To(Equal(1)) 570 }) 571 572 Context("when the shell request returns an error", func() { 573 BeforeEach(func() { 574 fakeSecureSession.ShellReturns(errors.New("oh bother")) 575 }) 576 577 It("returns the error", func() { 578 Expect(sessionError).To(MatchError("oh bother")) 579 }) 580 }) 581 }) 582 583 Context("when a command is specifed", func() { 584 BeforeEach(func() { 585 opts.Command = []string{"echo", "-n", "hello"} 586 }) 587 588 It("starts the command", func() { 589 Expect(fakeSecureSession.StartCallCount()).To(Equal(1)) 590 Expect(fakeSecureSession.StartArgsForCall(0)).To(Equal("echo -n hello")) 591 }) 592 593 Context("when the command fails to start", func() { 594 BeforeEach(func() { 595 fakeSecureSession.StartReturns(errors.New("oh well")) 596 }) 597 598 It("returns the error", func() { 599 Expect(sessionError).To(MatchError("oh well")) 600 }) 601 }) 602 }) 603 604 Context("when the shell or command has started", func() { 605 var ( 606 stdin *fake_io.FakeReadCloser 607 stdout, stderr *fake_io.FakeWriter 608 stdinPipe *fake_io.FakeWriteCloser 609 stdoutPipe, stderrPipe *fake_io.FakeReader 610 ) 611 612 BeforeEach(func() { 613 stdin = &fake_io.FakeReadCloser{} 614 stdin.ReadStub = func(p []byte) (int, error) { 615 p[0] = 0 616 return 1, io.EOF 617 } 618 stdinPipe = &fake_io.FakeWriteCloser{} 619 stdinPipe.WriteStub = func(p []byte) (int, error) { 620 defer GinkgoRecover() 621 Expect(p[0]).To(Equal(byte(0))) 622 return 1, nil 623 } 624 625 stdoutPipe = &fake_io.FakeReader{} 626 stdoutPipe.ReadStub = func(p []byte) (int, error) { 627 p[0] = 1 628 return 1, io.EOF 629 } 630 stdout = &fake_io.FakeWriter{} 631 stdout.WriteStub = func(p []byte) (int, error) { 632 defer GinkgoRecover() 633 Expect(p[0]).To(Equal(byte(1))) 634 return 1, nil 635 } 636 637 stderrPipe = &fake_io.FakeReader{} 638 stderrPipe.ReadStub = func(p []byte) (int, error) { 639 p[0] = 2 640 return 1, io.EOF 641 } 642 stderr = &fake_io.FakeWriter{} 643 stderr.WriteStub = func(p []byte) (int, error) { 644 defer GinkgoRecover() 645 Expect(p[0]).To(Equal(byte(2))) 646 return 1, nil 647 } 648 649 fakeTerminalHelper.StdStreamsReturns(stdin, stdout, stderr) 650 terminalHelper = fakeTerminalHelper 651 652 fakeSecureSession.StdinPipeReturns(stdinPipe, nil) 653 fakeSecureSession.StdoutPipeReturns(stdoutPipe, nil) 654 fakeSecureSession.StderrPipeReturns(stderrPipe, nil) 655 656 fakeSecureSession.WaitReturns(errors.New("error result")) 657 }) 658 659 It("copies data from the stdin stream to the session stdin pipe", func() { 660 Eventually(stdin.ReadCallCount).Should(Equal(1)) 661 Eventually(stdinPipe.WriteCallCount).Should(Equal(1)) 662 }) 663 664 It("copies data from the session stdout pipe to the stdout stream", func() { 665 Eventually(stdoutPipe.ReadCallCount).Should(Equal(1)) 666 Eventually(stdout.WriteCallCount).Should(Equal(1)) 667 }) 668 669 It("copies data from the session stderr pipe to the stderr stream", func() { 670 Eventually(stderrPipe.ReadCallCount).Should(Equal(1)) 671 Eventually(stderr.WriteCallCount).Should(Equal(1)) 672 }) 673 674 It("waits for the session to end", func() { 675 Expect(fakeSecureSession.WaitCallCount()).To(Equal(1)) 676 }) 677 678 It("returns the result from wait", func() { 679 Expect(sessionError).To(MatchError("error result")) 680 }) 681 682 Context("when the session terminates before stream copies complete", func() { 683 var sessionErrorCh chan error 684 685 BeforeEach(func() { 686 sessionErrorCh = make(chan error, 1) 687 688 interactiveSessionInvoker = func(secureShell sshCmd.SecureShell) { 689 go func() { sessionErrorCh <- secureShell.InteractiveSession() }() 690 } 691 692 stdoutPipe.ReadStub = func(p []byte) (int, error) { 693 defer GinkgoRecover() 694 Eventually(fakeSecureSession.WaitCallCount).Should(Equal(1)) 695 Consistently(sessionErrorCh).ShouldNot(Receive()) 696 697 p[0] = 1 698 return 1, io.EOF 699 } 700 701 stderrPipe.ReadStub = func(p []byte) (int, error) { 702 defer GinkgoRecover() 703 Eventually(fakeSecureSession.WaitCallCount).Should(Equal(1)) 704 Consistently(sessionErrorCh).ShouldNot(Receive()) 705 706 p[0] = 2 707 return 1, io.EOF 708 } 709 }) 710 711 It("waits for the copies to complete", func() { 712 Eventually(sessionErrorCh).Should(Receive()) 713 Expect(stdoutPipe.ReadCallCount()).To(Equal(1)) 714 Expect(stderrPipe.ReadCallCount()).To(Equal(1)) 715 }) 716 }) 717 718 Context("when stdin is closed", func() { 719 BeforeEach(func() { 720 stdin.ReadStub = func(p []byte) (int, error) { 721 defer GinkgoRecover() 722 Consistently(stdinPipe.CloseCallCount).Should(Equal(0)) 723 p[0] = 0 724 return 1, io.EOF 725 } 726 }) 727 728 It("closes the stdinPipe", func() { 729 Eventually(stdinPipe.CloseCallCount).Should(Equal(1)) 730 }) 731 }) 732 }) 733 734 Context("when stdout is a terminal and a window size change occurs", func() { 735 var master, slave *os.File 736 737 BeforeEach(func() { 738 stdin, _, stderr := terminalHelper.StdStreams() 739 740 var err error 741 master, slave, err = pty.Open() 742 Expect(err).NotTo(HaveOccurred()) 743 744 fakeTerminalHelper.IsTerminalStub = terminalHelper.IsTerminal 745 fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo 746 fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize 747 fakeTerminalHelper.StdStreamsReturns(stdin, slave, stderr) 748 terminalHelper = fakeTerminalHelper 749 750 winsize := &term.Winsize{Height: 100, Width: 100} 751 err = term.SetWinsize(slave.Fd(), winsize) 752 Expect(err).NotTo(HaveOccurred()) 753 754 fakeSecureSession.WaitStub = func() error { 755 fakeSecureSession.SendRequestCallCount() 756 Expect(fakeSecureSession.SendRequestCallCount()).To(Equal(0)) 757 758 // No dimension change 759 for i := 0; i < 3; i++ { 760 winsize := &term.Winsize{Height: 100, Width: 100} 761 err = term.SetWinsize(slave.Fd(), winsize) 762 Expect(err).NotTo(HaveOccurred()) 763 } 764 765 winsize := &term.Winsize{Height: 100, Width: 200} 766 err = term.SetWinsize(slave.Fd(), winsize) 767 Expect(err).NotTo(HaveOccurred()) 768 769 err = syscall.Kill(syscall.Getpid(), syscall.SIGWINCH) 770 Expect(err).NotTo(HaveOccurred()) 771 772 Eventually(fakeSecureSession.SendRequestCallCount).Should(Equal(1)) 773 return nil 774 } 775 }) 776 777 AfterEach(func() { 778 master.Close() 779 slave.Close() 780 }) 781 782 It("sends window change events when the window dimensions change", func() { 783 Expect(fakeSecureSession.SendRequestCallCount()).To(Equal(1)) 784 785 requestType, wantReply, message := fakeSecureSession.SendRequestArgsForCall(0) 786 Expect(requestType).To(Equal("window-change")) 787 Expect(wantReply).To(BeFalse()) 788 789 type resizeMessage struct { 790 Width uint32 791 Height uint32 792 PixelWidth uint32 793 PixelHeight uint32 794 } 795 var resizeMsg resizeMessage 796 797 err := ssh.Unmarshal(message, &resizeMsg) 798 Expect(err).NotTo(HaveOccurred()) 799 800 Expect(resizeMsg).To(Equal(resizeMessage{Height: 100, Width: 200})) 801 }) 802 }) 803 804 Describe("keep alive messages", func() { 805 var times []time.Time 806 var timesCh chan []time.Time 807 var done chan struct{} 808 809 BeforeEach(func() { 810 keepAliveDuration = 100 * time.Millisecond 811 812 times = []time.Time{} 813 timesCh = make(chan []time.Time, 1) 814 done = make(chan struct{}, 1) 815 816 fakeConnection.SendRequestStub = func(reqName string, wantReply bool, message []byte) (bool, []byte, error) { 817 Expect(reqName).To(Equal("keepalive@cloudfoundry.org")) 818 Expect(wantReply).To(BeTrue()) 819 Expect(message).To(BeNil()) 820 821 times = append(times, time.Now()) 822 if len(times) == 3 { 823 timesCh <- times 824 close(done) 825 } 826 return true, nil, nil 827 } 828 829 fakeSecureSession.WaitStub = func() error { 830 Eventually(done).Should(BeClosed()) 831 return nil 832 } 833 }) 834 835 It("sends keep alive messages at the expected interval", func() { 836 times := <-timesCh 837 838 // Expected interval time = 100 msec 839 Expect(times[1]).To(BeTemporally(">=", times[0])) 840 Expect(times[1]).To(BeTemporally("<=", times[0].Add(500*time.Millisecond))) 841 Expect(times[2]).To(BeTemporally(">=", times[1])) 842 Expect(times[2]).To(BeTemporally("<=", times[1].Add(500*time.Millisecond))) 843 }) 844 }) 845 }) 846 847 Describe("LocalPortForward", func() { 848 var ( 849 opts *options.SSHOptions 850 localForwardError error 851 852 echoAddress string 853 echoListener *fake_net.FakeListener 854 echoHandler *fake_server.FakeConnectionHandler 855 echoServer *server.Server 856 857 localAddress string 858 859 realLocalListener net.Listener 860 fakeLocalListener *fake_net.FakeListener 861 ) 862 863 BeforeEach(func() { 864 logger := lagertest.NewTestLogger("test") 865 866 var err error 867 realLocalListener, err = net.Listen("tcp", "127.0.0.1:0") 868 Expect(err).NotTo(HaveOccurred()) 869 870 localAddress = realLocalListener.Addr().String() 871 fakeListenerFactory.ListenReturns(realLocalListener, nil) 872 873 echoHandler = &fake_server.FakeConnectionHandler{} 874 echoHandler.HandleConnectionStub = func(conn net.Conn) { 875 io.Copy(conn, conn) 876 conn.Close() 877 } 878 879 realListener, err := net.Listen("tcp", "127.0.0.1:0") 880 Expect(err).NotTo(HaveOccurred()) 881 echoAddress = realListener.Addr().String() 882 883 echoListener = &fake_net.FakeListener{} 884 echoListener.AcceptStub = realListener.Accept 885 echoListener.CloseStub = realListener.Close 886 echoListener.AddrStub = realListener.Addr 887 888 fakeLocalListener = &fake_net.FakeListener{} 889 fakeLocalListener.AcceptReturns(nil, errors.New("Not Accepting Connections")) 890 891 echoServer = server.NewServer(logger.Session("echo"), "", echoHandler) 892 echoServer.SetListener(echoListener) 893 go echoServer.Serve() 894 895 opts = &options.SSHOptions{ 896 AppName: "app-1", 897 ForwardSpecs: []options.ForwardSpec{{ 898 ListenAddress: localAddress, 899 ConnectAddress: echoAddress, 900 }}, 901 } 902 903 currentApp.State = "STARTED" 904 905 sshEndpointFingerprint = "" 906 sshEndpoint = "" 907 908 token = "" 909 910 fakeSecureClient.DialStub = net.Dial 911 }) 912 913 JustBeforeEach(func() { 914 connectErr := secureShell.Connect(opts) 915 Expect(connectErr).NotTo(HaveOccurred()) 916 917 localForwardError = secureShell.LocalPortForward() 918 }) 919 920 AfterEach(func() { 921 err := secureShell.Close() 922 Expect(err).NotTo(HaveOccurred()) 923 echoServer.Shutdown() 924 925 realLocalListener.Close() 926 }) 927 928 validateConnectivity := func(addr string) { 929 conn, err := net.Dial("tcp", addr) 930 Expect(err).NotTo(HaveOccurred()) 931 932 msg := fmt.Sprintf("Hello from %s\n", addr) 933 n, err := conn.Write([]byte(msg)) 934 Expect(err).NotTo(HaveOccurred()) 935 Expect(n).To(Equal(len(msg))) 936 937 response := make([]byte, len(msg)) 938 n, err = conn.Read(response) 939 Expect(err).NotTo(HaveOccurred()) 940 Expect(n).To(Equal(len(msg))) 941 942 err = conn.Close() 943 Expect(err).NotTo(HaveOccurred()) 944 945 Expect(response).To(Equal([]byte(msg))) 946 } 947 948 It("dials the connect address when a local connection is made", func() { 949 Expect(localForwardError).NotTo(HaveOccurred()) 950 951 conn, err := net.Dial("tcp", localAddress) 952 Expect(err).NotTo(HaveOccurred()) 953 954 Eventually(echoListener.AcceptCallCount).Should(BeNumerically(">=", 1)) 955 Eventually(fakeSecureClient.DialCallCount).Should(Equal(1)) 956 957 network, addr := fakeSecureClient.DialArgsForCall(0) 958 Expect(network).To(Equal("tcp")) 959 Expect(addr).To(Equal(echoAddress)) 960 961 Expect(conn.Close()).NotTo(HaveOccurred()) 962 }) 963 964 It("copies data between the local and remote connections", func() { 965 validateConnectivity(localAddress) 966 }) 967 968 Context("when a local connection is already open", func() { 969 var ( 970 conn net.Conn 971 err error 972 ) 973 974 JustBeforeEach(func() { 975 conn, err = net.Dial("tcp", localAddress) 976 Expect(err).NotTo(HaveOccurred()) 977 }) 978 979 AfterEach(func() { 980 err = conn.Close() 981 Expect(err).NotTo(HaveOccurred()) 982 }) 983 984 It("allows for new incoming connections as well", func() { 985 validateConnectivity(localAddress) 986 }) 987 }) 988 989 Context("when there are multiple port forward specs", func() { 990 var realLocalListener2 net.Listener 991 var localAddress2 string 992 993 BeforeEach(func() { 994 var err error 995 realLocalListener2, err = net.Listen("tcp", "127.0.0.1:0") 996 Expect(err).NotTo(HaveOccurred()) 997 998 localAddress2 = realLocalListener2.Addr().String() 999 1000 fakeListenerFactory.ListenStub = func(network, addr string) (net.Listener, error) { 1001 if addr == localAddress { 1002 return realLocalListener, nil 1003 } 1004 1005 if addr == localAddress2 { 1006 return realLocalListener2, nil 1007 } 1008 1009 return nil, errors.New("unexpected address") 1010 } 1011 1012 opts = &options.SSHOptions{ 1013 AppName: "app-1", 1014 ForwardSpecs: []options.ForwardSpec{{ 1015 ListenAddress: localAddress, 1016 ConnectAddress: echoAddress, 1017 }, { 1018 ListenAddress: localAddress2, 1019 ConnectAddress: echoAddress, 1020 }}, 1021 } 1022 }) 1023 1024 AfterEach(func() { 1025 realLocalListener2.Close() 1026 }) 1027 1028 It("listens to all the things", func() { 1029 Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(2)) 1030 1031 network, addr := fakeListenerFactory.ListenArgsForCall(0) 1032 Expect(network).To(Equal("tcp")) 1033 Expect(addr).To(Equal(localAddress)) 1034 1035 network, addr = fakeListenerFactory.ListenArgsForCall(1) 1036 Expect(network).To(Equal("tcp")) 1037 Expect(addr).To(Equal(localAddress2)) 1038 }) 1039 1040 It("forwards to the correct target", func() { 1041 validateConnectivity(localAddress) 1042 validateConnectivity(localAddress2) 1043 }) 1044 1045 Context("when the secure client is closed", func() { 1046 BeforeEach(func() { 1047 fakeListenerFactory.ListenReturns(fakeLocalListener, nil) 1048 fakeLocalListener.AcceptReturns(nil, errors.New("not accepting connections")) 1049 }) 1050 1051 It("closes the listeners ", func() { 1052 Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(2)) 1053 Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(2)) 1054 1055 originalCloseCount := fakeLocalListener.CloseCallCount() 1056 err := secureShell.Close() 1057 Expect(err).NotTo(HaveOccurred()) 1058 Expect(fakeLocalListener.CloseCallCount()).Should(Equal(originalCloseCount + 2)) 1059 }) 1060 }) 1061 }) 1062 1063 Context("when listen fails", func() { 1064 BeforeEach(func() { 1065 fakeListenerFactory.ListenReturns(nil, errors.New("failure is an option")) 1066 }) 1067 1068 It("returns the error", func() { 1069 Expect(localForwardError).To(MatchError("failure is an option")) 1070 }) 1071 }) 1072 1073 Context("when the client it closed", func() { 1074 BeforeEach(func() { 1075 fakeListenerFactory.ListenReturns(fakeLocalListener, nil) 1076 fakeLocalListener.AcceptReturns(nil, errors.New("not accepting and connections")) 1077 }) 1078 1079 It("closes the listener when the client is closed", func() { 1080 Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(1)) 1081 Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(1)) 1082 1083 originalCloseCount := fakeLocalListener.CloseCallCount() 1084 err := secureShell.Close() 1085 Expect(err).NotTo(HaveOccurred()) 1086 Expect(fakeLocalListener.CloseCallCount()).Should(Equal(originalCloseCount + 1)) 1087 }) 1088 }) 1089 1090 Context("when accept fails", func() { 1091 var fakeConn *fake_net.FakeConn 1092 BeforeEach(func() { 1093 fakeConn = &fake_net.FakeConn{} 1094 fakeConn.ReadReturns(0, io.EOF) 1095 1096 fakeListenerFactory.ListenReturns(fakeLocalListener, nil) 1097 }) 1098 1099 Context("with a permanent error", func() { 1100 BeforeEach(func() { 1101 fakeLocalListener.AcceptReturns(nil, errors.New("boom")) 1102 }) 1103 1104 It("stops trying to accept connections", func() { 1105 Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(1)) 1106 Consistently(fakeLocalListener.AcceptCallCount).Should(Equal(1)) 1107 Expect(fakeLocalListener.CloseCallCount()).To(Equal(1)) 1108 }) 1109 }) 1110 1111 Context("with a temporary error", func() { 1112 var timeCh chan time.Time 1113 1114 BeforeEach(func() { 1115 timeCh = make(chan time.Time, 3) 1116 1117 fakeLocalListener.AcceptStub = func() (net.Conn, error) { 1118 timeCh := timeCh 1119 if fakeLocalListener.AcceptCallCount() > 3 { 1120 close(timeCh) 1121 return nil, test_helpers.NewTestNetError(false, false) 1122 } else { 1123 timeCh <- time.Now() 1124 return nil, test_helpers.NewTestNetError(false, true) 1125 } 1126 } 1127 }) 1128 1129 It("retries connecting after a short delay", func() { 1130 Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(3)) 1131 Expect(timeCh).To(HaveLen(3)) 1132 1133 times := make([]time.Time, 0) 1134 for t := range timeCh { 1135 times = append(times, t) 1136 } 1137 1138 // Expected interval time = 100 msec 1139 Expect(times[1]).To(BeTemporally(">=", times[0])) 1140 Expect(times[1]).To(BeTemporally("<=", times[0].Add(500*time.Millisecond))) 1141 Expect(times[2]).To(BeTemporally(">=", times[1])) 1142 Expect(times[2]).To(BeTemporally("<=", times[1].Add(500*time.Millisecond))) 1143 }) 1144 }) 1145 }) 1146 1147 Context("when dialing the connect address fails", func() { 1148 var fakeTarget *fake_net.FakeConn 1149 1150 BeforeEach(func() { 1151 fakeTarget = &fake_net.FakeConn{} 1152 fakeSecureClient.DialReturns(fakeTarget, errors.New("boom")) 1153 }) 1154 1155 It("does not call close on the target connection", func() { 1156 Consistently(fakeTarget.CloseCallCount).Should(Equal(0)) 1157 }) 1158 }) 1159 }) 1160 1161 Describe("Wait", func() { 1162 var opts *options.SSHOptions 1163 var waitErr error 1164 1165 BeforeEach(func() { 1166 opts = &options.SSHOptions{ 1167 AppName: "app-1", 1168 } 1169 1170 currentApp.State = "STARTED" 1171 1172 sshEndpointFingerprint = "" 1173 sshEndpoint = "" 1174 1175 token = "" 1176 }) 1177 1178 JustBeforeEach(func() { 1179 connectErr := secureShell.Connect(opts) 1180 Expect(connectErr).NotTo(HaveOccurred()) 1181 1182 waitErr = secureShell.Wait() 1183 }) 1184 1185 It("calls wait on the secureClient", func() { 1186 Expect(waitErr).NotTo(HaveOccurred()) 1187 Expect(fakeSecureClient.WaitCallCount()).To(Equal(1)) 1188 }) 1189 1190 Describe("keep alive messages", func() { 1191 var times []time.Time 1192 var timesCh chan []time.Time 1193 var done chan struct{} 1194 1195 BeforeEach(func() { 1196 keepAliveDuration = 100 * time.Millisecond 1197 1198 times = []time.Time{} 1199 timesCh = make(chan []time.Time, 1) 1200 done = make(chan struct{}, 1) 1201 1202 fakeConnection.SendRequestStub = func(reqName string, wantReply bool, message []byte) (bool, []byte, error) { 1203 Expect(reqName).To(Equal("keepalive@cloudfoundry.org")) 1204 Expect(wantReply).To(BeTrue()) 1205 Expect(message).To(BeNil()) 1206 1207 times = append(times, time.Now()) 1208 if len(times) == 3 { 1209 timesCh <- times 1210 close(done) 1211 } 1212 return true, nil, nil 1213 } 1214 1215 fakeSecureClient.WaitStub = func() error { 1216 Eventually(done).Should(BeClosed()) 1217 return nil 1218 } 1219 }) 1220 1221 It("sends keep alive messages at the expected interval", func() { 1222 Expect(waitErr).NotTo(HaveOccurred()) 1223 times := <-timesCh 1224 1225 // Expected interval time = 100 msec 1226 Expect(times[1]).To(BeTemporally(">=", times[0])) 1227 Expect(times[1]).To(BeTemporally("<=", times[0].Add(500*time.Millisecond))) 1228 Expect(times[2]).To(BeTemporally(">=", times[1])) 1229 Expect(times[2]).To(BeTemporally("<=", times[1].Add(500*time.Millisecond))) 1230 }) 1231 }) 1232 }) 1233 1234 Describe("Close", func() { 1235 var opts *options.SSHOptions 1236 1237 BeforeEach(func() { 1238 opts = &options.SSHOptions{ 1239 AppName: "app-1", 1240 } 1241 1242 currentApp.State = "STARTED" 1243 1244 sshEndpointFingerprint = "" 1245 sshEndpoint = "" 1246 1247 token = "" 1248 }) 1249 1250 JustBeforeEach(func() { 1251 connectErr := secureShell.Connect(opts) 1252 Expect(connectErr).NotTo(HaveOccurred()) 1253 }) 1254 1255 It("calls close on the secureClient", func() { 1256 err := secureShell.Close() 1257 Expect(err).NotTo(HaveOccurred()) 1258 1259 Expect(fakeSecureClient.CloseCallCount()).To(Equal(1)) 1260 }) 1261 }) 1262 })