github.com/DaAlbrecht/cf-cli@v0.0.0-20231128151943-1fe19bb400b9/cf/ssh/ssh_test.go (about) 1 //go:build !windows && !386 2 // +build !windows,!386 3 4 // Skipping 386 because lager uses UInt64 in Session(). 5 // Skipping windows because the test uses a Unix/Linux-only syscall. 6 // TODO: refactor out these conflicts so this package can be tested on multiple platforms. 7 8 package sshCmd_test 9 10 import ( 11 "errors" 12 "fmt" 13 "io" 14 "net" 15 "os" 16 "syscall" 17 "time" 18 19 "code.cloudfoundry.org/cli/cf/models" 20 sshCmd "code.cloudfoundry.org/cli/cf/ssh" 21 "code.cloudfoundry.org/cli/cf/ssh/options" 22 "code.cloudfoundry.org/cli/cf/ssh/sshfakes" 23 "code.cloudfoundry.org/cli/cf/ssh/terminal" 24 "code.cloudfoundry.org/cli/cf/ssh/terminal/terminalfakes" 25 "code.cloudfoundry.org/diego-ssh/server" 26 fake_server "code.cloudfoundry.org/diego-ssh/server/fakes" 27 "code.cloudfoundry.org/diego-ssh/test_helpers" 28 "code.cloudfoundry.org/diego-ssh/test_helpers/fake_io" 29 "code.cloudfoundry.org/diego-ssh/test_helpers/fake_net" 30 "code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh" 31 "code.cloudfoundry.org/lager/lagertest" 32 "github.com/kr/pty" 33 "github.com/moby/term" 34 "golang.org/x/crypto/ssh" 35 36 . "github.com/onsi/ginkgo" 37 . 
"github.com/onsi/gomega" 38 ) 39 40 var _ = Describe("SSH", func() { 41 var ( 42 fakeTerminalHelper *terminalfakes.FakeTerminalHelper 43 fakeListenerFactory *sshfakes.FakeListenerFactory 44 45 fakeConnection *fake_ssh.FakeConn 46 fakeSecureClient *sshfakes.FakeSecureClient 47 fakeSecureDialer *sshfakes.FakeSecureDialer 48 fakeSecureSession *sshfakes.FakeSecureSession 49 50 terminalHelper terminal.TerminalHelper 51 keepAliveDuration time.Duration 52 secureShell sshCmd.SecureShell 53 54 stdinPipe *fake_io.FakeWriteCloser 55 56 currentApp models.Application 57 sshEndpointFingerprint string 58 sshEndpoint string 59 token string 60 ) 61 62 BeforeEach(func() { 63 fakeTerminalHelper = new(terminalfakes.FakeTerminalHelper) 64 terminalHelper = terminal.DefaultHelper() 65 66 fakeListenerFactory = new(sshfakes.FakeListenerFactory) 67 fakeListenerFactory.ListenStub = net.Listen 68 69 keepAliveDuration = 30 * time.Second 70 71 currentApp = models.Application{} 72 sshEndpoint = "" 73 sshEndpointFingerprint = "" 74 token = "" 75 76 fakeConnection = new(fake_ssh.FakeConn) 77 fakeSecureClient = new(sshfakes.FakeSecureClient) 78 fakeSecureDialer = new(sshfakes.FakeSecureDialer) 79 fakeSecureSession = new(sshfakes.FakeSecureSession) 80 81 fakeSecureDialer.DialReturns(fakeSecureClient, nil) 82 fakeSecureClient.NewSessionReturns(fakeSecureSession, nil) 83 fakeSecureClient.ConnReturns(fakeConnection) 84 85 stdinPipe = &fake_io.FakeWriteCloser{} 86 stdinPipe.WriteStub = func(p []byte) (int, error) { 87 return len(p), nil 88 } 89 90 stdoutPipe := &fake_io.FakeReader{} 91 stdoutPipe.ReadStub = func(p []byte) (int, error) { 92 return 0, io.EOF 93 } 94 95 stderrPipe := &fake_io.FakeReader{} 96 stderrPipe.ReadStub = func(p []byte) (int, error) { 97 return 0, io.EOF 98 } 99 100 fakeSecureSession.StdinPipeReturns(stdinPipe, nil) 101 fakeSecureSession.StdoutPipeReturns(stdoutPipe, nil) 102 fakeSecureSession.StderrPipeReturns(stderrPipe, nil) 103 }) 104 105 JustBeforeEach(func() { 106 secureShell 
= sshCmd.NewSecureShell( 107 fakeSecureDialer, 108 terminalHelper, 109 fakeListenerFactory, 110 keepAliveDuration, 111 currentApp, 112 sshEndpointFingerprint, 113 sshEndpoint, 114 token, 115 ) 116 }) 117 118 Describe("Validation", func() { 119 var connectErr error 120 var opts *options.SSHOptions 121 122 BeforeEach(func() { 123 opts = &options.SSHOptions{ 124 AppName: "app-1", 125 } 126 }) 127 128 JustBeforeEach(func() { 129 connectErr = secureShell.Connect(opts) 130 }) 131 132 Context("when the app model and endpoint info are successfully acquired", func() { 133 BeforeEach(func() { 134 token = "" 135 currentApp.State = "STARTED" 136 }) 137 138 Context("when the app is not in the 'STARTED' state", func() { 139 BeforeEach(func() { 140 currentApp.State = "STOPPED" 141 }) 142 143 It("returns an error", func() { 144 Expect(connectErr).To(MatchError(MatchRegexp("Application.*not in the STARTED state"))) 145 }) 146 }) 147 148 Context("when dialing fails", func() { 149 var dialError = errors.New("woops") 150 151 BeforeEach(func() { 152 fakeSecureDialer.DialReturns(nil, dialError) 153 }) 154 155 It("returns the dial error", func() { 156 Expect(connectErr).To(Equal(dialError)) 157 Expect(fakeSecureDialer.DialCallCount()).To(Equal(1)) 158 }) 159 }) 160 }) 161 }) 162 163 Describe("InteractiveSession", func() { 164 var opts *options.SSHOptions 165 var sessionError error 166 var interactiveSessionInvoker func(secureShell sshCmd.SecureShell) 167 168 BeforeEach(func() { 169 sshEndpoint = "ssh.example.com:22" 170 171 opts = &options.SSHOptions{ 172 AppName: "app-name", 173 Index: 2, 174 } 175 176 currentApp.State = "STARTED" 177 currentApp.GUID = "app-guid" 178 token = "bearer token" 179 180 interactiveSessionInvoker = func(secureShell sshCmd.SecureShell) { 181 sessionError = secureShell.InteractiveSession() 182 } 183 }) 184 185 JustBeforeEach(func() { 186 connectErr := secureShell.Connect(opts) 187 Expect(connectErr).NotTo(HaveOccurred()) 188 
interactiveSessionInvoker(secureShell) 189 }) 190 191 It("dials the correct endpoint as the correct user", func() { 192 Expect(fakeSecureDialer.DialCallCount()).To(Equal(1)) 193 194 network, address, config := fakeSecureDialer.DialArgsForCall(0) 195 Expect(network).To(Equal("tcp")) 196 Expect(address).To(Equal("ssh.example.com:22")) 197 Expect(config.Auth).NotTo(BeEmpty()) 198 Expect(config.User).To(Equal("cf:app-guid/2")) 199 Expect(config.HostKeyCallback).NotTo(BeNil()) 200 }) 201 202 Context("when host key validation is enabled", func() { 203 var callback func(hostname string, remote net.Addr, key ssh.PublicKey) error 204 var addr net.Addr 205 206 JustBeforeEach(func() { 207 Expect(fakeSecureDialer.DialCallCount()).To(Equal(1)) 208 _, _, config := fakeSecureDialer.DialArgsForCall(0) 209 callback = config.HostKeyCallback 210 211 listener, err := net.Listen("tcp", "localhost:0") 212 Expect(err).NotTo(HaveOccurred()) 213 214 addr = listener.Addr() 215 listener.Close() 216 }) 217 218 Context("when the md5 fingerprint matches", func() { 219 BeforeEach(func() { 220 sshEndpointFingerprint = "41:ce:56:e6:9c:42:a9:c6:9e:68:ac:e3:4d:f6:38:79" 221 }) 222 223 It("does not return an error", func() { 224 Expect(callback("", addr, TestHostKey.PublicKey())).ToNot(HaveOccurred()) 225 }) 226 }) 227 228 Context("when the hex sha1 fingerprint matches", func() { 229 BeforeEach(func() { 230 sshEndpointFingerprint = "a8:e2:67:cb:ea:2a:6e:23:a1:72:ce:8f:07:92:15:ee:1f:82:f8:ca" 231 }) 232 233 It("does not return an error", func() { 234 Expect(callback("", addr, TestHostKey.PublicKey())).ToNot(HaveOccurred()) 235 }) 236 }) 237 238 Context("when the base64 sha256 fingerprint matches", func() { 239 BeforeEach(func() { 240 sshEndpointFingerprint = "sp/jrLuj66r+yrLDUKZdJU5tdzt4mq/UaSiNBjpgr+8" 241 }) 242 243 It("does not return an error", func() { 244 Expect(callback("", addr, TestHostKey.PublicKey())).ToNot(HaveOccurred()) 245 }) 246 }) 247 248 Context("when the base64 SHA256 fingerprint 
does not match", func() { 249 BeforeEach(func() { 250 sshEndpointFingerprint = "0000000000000000000000000000000000000000000" 251 }) 252 253 It("returns an error'", func() { 254 err := callback("", addr, TestHostKey.PublicKey()) 255 Expect(err).To(MatchError(MatchRegexp("Host key verification failed\\."))) 256 Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\""))) 257 }) 258 }) 259 260 Context("when the hex SHA1 fingerprint does not match", func() { 261 BeforeEach(func() { 262 sshEndpointFingerprint = "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00" 263 }) 264 265 It("returns an error'", func() { 266 err := callback("", addr, TestHostKey.PublicKey()) 267 Expect(err).To(MatchError(MatchRegexp("Host key verification failed\\."))) 268 Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\""))) 269 }) 270 }) 271 272 Context("when the MD5 fingerprint does not match", func() { 273 BeforeEach(func() { 274 sshEndpointFingerprint = "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00" 275 }) 276 277 It("returns an error'", func() { 278 err := callback("", addr, TestHostKey.PublicKey()) 279 Expect(err).To(MatchError(MatchRegexp("Host key verification failed\\."))) 280 Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\""))) 281 }) 282 }) 283 284 Context("when no fingerprint is present in endpoint info", func() { 285 BeforeEach(func() { 286 sshEndpointFingerprint = "" 287 sshEndpoint = "" 288 }) 289 290 It("returns an error'", func() { 291 err := callback("", addr, TestHostKey.PublicKey()) 292 Expect(err).To(MatchError(MatchRegexp("Unable to verify identity of host\\."))) 293 Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\""))) 294 }) 295 }) 296 297 Context("when the fingerprint length doesn't make sense", func() { 298 BeforeEach(func() { 299 sshEndpointFingerprint = "garbage" 300 }) 301 302 It("returns an error", func() { 303 err := 
callback("", addr, TestHostKey.PublicKey()) 304 Eventually(err).Should(MatchError(MatchRegexp("Unsupported host key fingerprint format"))) 305 }) 306 }) 307 }) 308 309 Context("when the skip host validation flag is set", func() { 310 BeforeEach(func() { 311 opts.SkipHostValidation = true 312 }) 313 314 It("removes the HostKeyCallback from the client config", func() { 315 Expect(fakeSecureDialer.DialCallCount()).To(Equal(1)) 316 317 _, _, config := fakeSecureDialer.DialArgsForCall(0) 318 Expect(config.HostKeyCallback("some-addr", nil, nil)).To(BeNil()) 319 }) 320 }) 321 322 Context("when dialing is successful", func() { 323 BeforeEach(func() { 324 fakeTerminalHelper.StdStreamsStub = terminalHelper.StdStreams 325 terminalHelper = fakeTerminalHelper 326 }) 327 328 It("creates a new secure shell session", func() { 329 Expect(fakeSecureClient.NewSessionCallCount()).To(Equal(1)) 330 }) 331 332 It("closes the session", func() { 333 Expect(fakeSecureSession.CloseCallCount()).To(Equal(1)) 334 }) 335 336 It("allocates standard streams", func() { 337 Expect(fakeTerminalHelper.StdStreamsCallCount()).To(Equal(1)) 338 }) 339 340 It("gets a stdin pipe for the session", func() { 341 Expect(fakeSecureSession.StdinPipeCallCount()).To(Equal(1)) 342 }) 343 344 Context("when getting the stdin pipe fails", func() { 345 BeforeEach(func() { 346 fakeSecureSession.StdinPipeReturns(nil, errors.New("woops")) 347 }) 348 349 It("returns the error", func() { 350 Expect(sessionError).Should(MatchError("woops")) 351 }) 352 }) 353 354 It("gets a stdout pipe for the session", func() { 355 Expect(fakeSecureSession.StdoutPipeCallCount()).To(Equal(1)) 356 }) 357 358 Context("when getting the stdout pipe fails", func() { 359 BeforeEach(func() { 360 fakeSecureSession.StdoutPipeReturns(nil, errors.New("woops")) 361 }) 362 363 It("returns the error", func() { 364 Expect(sessionError).Should(MatchError("woops")) 365 }) 366 }) 367 368 It("gets a stderr pipe for the session", func() { 369 
Expect(fakeSecureSession.StderrPipeCallCount()).To(Equal(1)) 370 }) 371 372 Context("when getting the stderr pipe fails", func() { 373 BeforeEach(func() { 374 fakeSecureSession.StderrPipeReturns(nil, errors.New("woops")) 375 }) 376 377 It("returns the error", func() { 378 Expect(sessionError).Should(MatchError("woops")) 379 }) 380 }) 381 }) 382 383 Context("when stdin is a terminal", func() { 384 var pseudoterminal, terminal *os.File 385 386 BeforeEach(func() { 387 _, stdout, stderr := terminalHelper.StdStreams() 388 389 var err error 390 pseudoterminal, terminal, err = pty.Open() 391 Expect(err).NotTo(HaveOccurred()) 392 393 fakeTerminalHelper.IsTerminalStub = terminalHelper.IsTerminal 394 fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo 395 fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize 396 fakeTerminalHelper.StdStreamsReturns(terminal, stdout, stderr) 397 terminalHelper = fakeTerminalHelper 398 }) 399 400 AfterEach(func() { 401 pseudoterminal.Close() 402 // terminal.Close() // race 403 }) 404 405 Context("when a command is not specified", func() { 406 var terminalType string 407 408 BeforeEach(func() { 409 terminalType = os.Getenv("TERM") 410 os.Setenv("TERM", "test-terminal-type") 411 412 winsize := &term.Winsize{Width: 1024, Height: 256} 413 fakeTerminalHelper.GetWinsizeReturns(winsize, nil) 414 415 fakeSecureSession.ShellStub = func() error { 416 Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1)) 417 Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(0)) 418 return nil 419 } 420 }) 421 422 AfterEach(func() { 423 os.Setenv("TERM", terminalType) 424 }) 425 426 It("requests a pty with the correct terminal type, window size, and modes", func() { 427 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1)) 428 Expect(fakeTerminalHelper.GetWinsizeCallCount()).To(Equal(1)) 429 430 termType, height, width, modes := fakeSecureSession.RequestPtyArgsForCall(0) 431 Expect(termType).To(Equal("test-terminal-type")) 
432 Expect(height).To(Equal(256)) 433 Expect(width).To(Equal(1024)) 434 435 expectedModes := ssh.TerminalModes{ 436 ssh.ECHO: 1, 437 ssh.TTY_OP_ISPEED: 115200, 438 ssh.TTY_OP_OSPEED: 115200, 439 } 440 Expect(modes).To(Equal(expectedModes)) 441 }) 442 443 Context("when the TERM environment variable is not set", func() { 444 BeforeEach(func() { 445 os.Unsetenv("TERM") 446 }) 447 448 It("requests a pty with the default terminal type", func() { 449 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1)) 450 451 termType, _, _, _ := fakeSecureSession.RequestPtyArgsForCall(0) 452 Expect(termType).To(Equal("xterm")) 453 }) 454 }) 455 456 It("puts the terminal into raw mode and restores it after running the shell", func() { 457 Expect(fakeSecureSession.ShellCallCount()).To(Equal(1)) 458 Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1)) 459 Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(1)) 460 }) 461 462 Context("when the pty allocation fails", func() { 463 var ptyError error 464 465 BeforeEach(func() { 466 ptyError = errors.New("pty allocation error") 467 fakeSecureSession.RequestPtyReturns(ptyError) 468 }) 469 470 It("returns the error", func() { 471 Expect(sessionError).To(Equal(ptyError)) 472 }) 473 }) 474 475 Context("when placing the terminal into raw mode fails", func() { 476 BeforeEach(func() { 477 fakeTerminalHelper.SetRawTerminalReturns(nil, errors.New("woops")) 478 }) 479 480 It("keeps calm and carries on", func() { 481 Expect(fakeSecureSession.ShellCallCount()).To(Equal(1)) 482 }) 483 484 It("does not restore the terminal", func() { 485 Expect(fakeSecureSession.ShellCallCount()).To(Equal(1)) 486 Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1)) 487 Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(0)) 488 }) 489 }) 490 }) 491 492 Context("when a command is specified", func() { 493 BeforeEach(func() { 494 opts.Command = []string{"echo", "-n", "hello"} 495 }) 496 497 Context("when a terminal is 
requested", func() { 498 BeforeEach(func() { 499 opts.TerminalRequest = options.RequestTTYYes 500 }) 501 502 It("requests a pty", func() { 503 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1)) 504 }) 505 }) 506 507 Context("when a terminal is not explicitly requested", func() { 508 It("does not request a pty", func() { 509 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0)) 510 }) 511 }) 512 }) 513 }) 514 515 Context("when stdin is not a terminal", func() { 516 BeforeEach(func() { 517 _, stdout, stderr := terminalHelper.StdStreams() 518 519 stdin := &fake_io.FakeReadCloser{} 520 stdin.ReadStub = func(p []byte) (int, error) { 521 return 0, io.EOF 522 } 523 524 fakeTerminalHelper.IsTerminalStub = terminalHelper.IsTerminal 525 fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo 526 fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize 527 fakeTerminalHelper.StdStreamsReturns(stdin, stdout, stderr) 528 terminalHelper = fakeTerminalHelper 529 }) 530 531 Context("when a terminal is not requested", func() { 532 It("does not request a pty", func() { 533 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0)) 534 }) 535 }) 536 537 Context("when a terminal is requested", func() { 538 BeforeEach(func() { 539 opts.TerminalRequest = options.RequestTTYYes 540 }) 541 542 It("does not request a pty", func() { 543 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0)) 544 }) 545 }) 546 }) 547 548 Context("when a terminal is forced", func() { 549 BeforeEach(func() { 550 opts.TerminalRequest = options.RequestTTYForce 551 }) 552 553 It("requests a pty", func() { 554 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1)) 555 }) 556 }) 557 558 Context("when a terminal is disabled", func() { 559 BeforeEach(func() { 560 opts.TerminalRequest = options.RequestTTYNo 561 }) 562 563 It("does not request a pty", func() { 564 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0)) 565 }) 566 }) 567 568 Context("when a command is not 
specified", func() { 569 It("requests an interactive shell", func() { 570 Expect(fakeSecureSession.ShellCallCount()).To(Equal(1)) 571 }) 572 573 Context("when the shell request returns an error", func() { 574 BeforeEach(func() { 575 fakeSecureSession.ShellReturns(errors.New("oh bother")) 576 }) 577 578 It("returns the error", func() { 579 Expect(sessionError).To(MatchError("oh bother")) 580 }) 581 }) 582 }) 583 584 Context("when a command is specified", func() { 585 BeforeEach(func() { 586 opts.Command = []string{"echo", "-n", "hello"} 587 }) 588 589 It("starts the command", func() { 590 Expect(fakeSecureSession.StartCallCount()).To(Equal(1)) 591 Expect(fakeSecureSession.StartArgsForCall(0)).To(Equal("echo -n hello")) 592 }) 593 594 Context("when the command fails to start", func() { 595 BeforeEach(func() { 596 fakeSecureSession.StartReturns(errors.New("oh well")) 597 }) 598 599 It("returns the error", func() { 600 Expect(sessionError).To(MatchError("oh well")) 601 }) 602 }) 603 }) 604 605 Context("when the shell or command has started", func() { 606 var ( 607 stdin *fake_io.FakeReadCloser 608 stdout, stderr *fake_io.FakeWriter 609 stdinPipe *fake_io.FakeWriteCloser 610 stdoutPipe, stderrPipe *fake_io.FakeReader 611 ) 612 613 BeforeEach(func() { 614 stdin = &fake_io.FakeReadCloser{} 615 stdin.ReadStub = func(p []byte) (int, error) { 616 p[0] = 0 617 return 1, io.EOF 618 } 619 stdinPipe = &fake_io.FakeWriteCloser{} 620 stdinPipe.WriteStub = func(p []byte) (int, error) { 621 defer GinkgoRecover() 622 Expect(p[0]).To(Equal(byte(0))) 623 return 1, nil 624 } 625 626 stdoutPipe = &fake_io.FakeReader{} 627 stdoutPipe.ReadStub = func(p []byte) (int, error) { 628 p[0] = 1 629 return 1, io.EOF 630 } 631 stdout = &fake_io.FakeWriter{} 632 stdout.WriteStub = func(p []byte) (int, error) { 633 defer GinkgoRecover() 634 Expect(p[0]).To(Equal(byte(1))) 635 return 1, nil 636 } 637 638 stderrPipe = &fake_io.FakeReader{} 639 stderrPipe.ReadStub = func(p []byte) (int, error) { 640 
p[0] = 2 641 return 1, io.EOF 642 } 643 stderr = &fake_io.FakeWriter{} 644 stderr.WriteStub = func(p []byte) (int, error) { 645 defer GinkgoRecover() 646 Expect(p[0]).To(Equal(byte(2))) 647 return 1, nil 648 } 649 650 fakeTerminalHelper.StdStreamsReturns(stdin, stdout, stderr) 651 terminalHelper = fakeTerminalHelper 652 653 fakeSecureSession.StdinPipeReturns(stdinPipe, nil) 654 fakeSecureSession.StdoutPipeReturns(stdoutPipe, nil) 655 fakeSecureSession.StderrPipeReturns(stderrPipe, nil) 656 657 fakeSecureSession.WaitReturns(errors.New("error result")) 658 }) 659 660 It("copies data from the stdin stream to the session stdin pipe", func() { 661 Eventually(stdin.ReadCallCount).Should(Equal(1)) 662 Eventually(stdinPipe.WriteCallCount).Should(Equal(1)) 663 }) 664 665 It("copies data from the session stdout pipe to the stdout stream", func() { 666 Eventually(stdoutPipe.ReadCallCount).Should(Equal(1)) 667 Eventually(stdout.WriteCallCount).Should(Equal(1)) 668 }) 669 670 It("copies data from the session stderr pipe to the stderr stream", func() { 671 Eventually(stderrPipe.ReadCallCount).Should(Equal(1)) 672 Eventually(stderr.WriteCallCount).Should(Equal(1)) 673 }) 674 675 It("waits for the session to end", func() { 676 Expect(fakeSecureSession.WaitCallCount()).To(Equal(1)) 677 }) 678 679 It("returns the result from wait", func() { 680 Expect(sessionError).To(MatchError("error result")) 681 }) 682 683 Context("when the session terminates before stream copies complete", func() { 684 var sessionErrorCh chan error 685 686 BeforeEach(func() { 687 sessionErrorCh = make(chan error, 1) 688 689 interactiveSessionInvoker = func(secureShell sshCmd.SecureShell) { 690 go func() { sessionErrorCh <- secureShell.InteractiveSession() }() 691 } 692 693 stdoutPipe.ReadStub = func(p []byte) (int, error) { 694 defer GinkgoRecover() 695 Eventually(fakeSecureSession.WaitCallCount).Should(Equal(1)) 696 Consistently(sessionErrorCh).ShouldNot(Receive()) 697 698 p[0] = 1 699 return 1, io.EOF 700 } 
701 702 stderrPipe.ReadStub = func(p []byte) (int, error) { 703 defer GinkgoRecover() 704 Eventually(fakeSecureSession.WaitCallCount).Should(Equal(1)) 705 Consistently(sessionErrorCh).ShouldNot(Receive()) 706 707 p[0] = 2 708 return 1, io.EOF 709 } 710 }) 711 712 It("waits for the copies to complete", func() { 713 Eventually(sessionErrorCh).Should(Receive()) 714 Expect(stdoutPipe.ReadCallCount()).To(Equal(1)) 715 Expect(stderrPipe.ReadCallCount()).To(Equal(1)) 716 }) 717 }) 718 719 Context("when stdin is closed", func() { 720 BeforeEach(func() { 721 stdin.ReadStub = func(p []byte) (int, error) { 722 defer GinkgoRecover() 723 Consistently(stdinPipe.CloseCallCount).Should(Equal(0)) 724 p[0] = 0 725 return 1, io.EOF 726 } 727 }) 728 729 It("closes the stdinPipe", func() { 730 Eventually(stdinPipe.CloseCallCount).Should(Equal(1)) 731 }) 732 }) 733 }) 734 735 Context("when stdout is a terminal and a window size change occurs", func() { 736 var pseudoterminal, terminal *os.File 737 738 BeforeEach(func() { 739 stdin, _, stderr := terminalHelper.StdStreams() 740 741 var err error 742 pseudoterminal, terminal, err = pty.Open() 743 Expect(err).NotTo(HaveOccurred()) 744 745 fakeTerminalHelper.IsTerminalStub = terminalHelper.IsTerminal 746 fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo 747 fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize 748 fakeTerminalHelper.StdStreamsReturns(stdin, terminal, stderr) 749 terminalHelper = fakeTerminalHelper 750 751 winsize := &term.Winsize{Height: 100, Width: 100} 752 err = term.SetWinsize(terminal.Fd(), winsize) 753 Expect(err).NotTo(HaveOccurred()) 754 755 fakeSecureSession.WaitStub = func() error { 756 fakeSecureSession.SendRequestCallCount() 757 Expect(fakeSecureSession.SendRequestCallCount()).To(Equal(0)) 758 759 // No dimension change 760 for i := 0; i < 3; i++ { 761 winsize := &term.Winsize{Height: 100, Width: 100} 762 err = term.SetWinsize(terminal.Fd(), winsize) 763 Expect(err).NotTo(HaveOccurred()) 764 } 
765 766 winsize := &term.Winsize{Height: 100, Width: 200} 767 err = term.SetWinsize(terminal.Fd(), winsize) 768 Expect(err).NotTo(HaveOccurred()) 769 770 err = syscall.Kill(syscall.Getpid(), syscall.SIGWINCH) 771 Expect(err).NotTo(HaveOccurred()) 772 773 Eventually(fakeSecureSession.SendRequestCallCount).Should(Equal(1)) 774 return nil 775 } 776 }) 777 778 AfterEach(func() { 779 pseudoterminal.Close() 780 terminal.Close() 781 }) 782 783 It("sends window change events when the window dimensions change", func() { 784 Expect(fakeSecureSession.SendRequestCallCount()).To(Equal(1)) 785 786 requestType, wantReply, message := fakeSecureSession.SendRequestArgsForCall(0) 787 Expect(requestType).To(Equal("window-change")) 788 Expect(wantReply).To(BeFalse()) 789 790 type resizeMessage struct { 791 Width uint32 792 Height uint32 793 PixelWidth uint32 794 PixelHeight uint32 795 } 796 var resizeMsg resizeMessage 797 798 err := ssh.Unmarshal(message, &resizeMsg) 799 Expect(err).NotTo(HaveOccurred()) 800 801 Expect(resizeMsg).To(Equal(resizeMessage{Height: 100, Width: 200})) 802 }) 803 }) 804 805 Describe("keep alive messages", func() { 806 var times []time.Time 807 var timesCh chan []time.Time 808 var done chan struct{} 809 810 BeforeEach(func() { 811 keepAliveDuration = 100 * time.Millisecond 812 813 times = []time.Time{} 814 timesCh = make(chan []time.Time, 1) 815 done = make(chan struct{}, 1) 816 817 fakeConnection.SendRequestStub = func(reqName string, wantReply bool, message []byte) (bool, []byte, error) { 818 Expect(reqName).To(Equal("keepalive@cloudfoundry.org")) 819 Expect(wantReply).To(BeTrue()) 820 Expect(message).To(BeNil()) 821 822 times = append(times, time.Now()) 823 if len(times) == 3 { 824 timesCh <- times 825 close(done) 826 } 827 return true, nil, nil 828 } 829 830 fakeSecureSession.WaitStub = func() error { 831 Eventually(done).Should(BeClosed()) 832 return nil 833 } 834 }) 835 836 It("sends keep alive messages at the expected interval", func() { 837 times := 
<-timesCh 838 839 // Expected interval time = 100 msec 840 Expect(times[1]).To(BeTemporally(">=", times[0])) 841 Expect(times[1]).To(BeTemporally("<=", times[0].Add(500*time.Millisecond))) 842 Expect(times[2]).To(BeTemporally(">=", times[1])) 843 Expect(times[2]).To(BeTemporally("<=", times[1].Add(500*time.Millisecond))) 844 }) 845 }) 846 }) 847 848 Describe("LocalPortForward", func() { 849 var ( 850 opts *options.SSHOptions 851 localForwardError error 852 853 echoAddress string 854 echoListener *fake_net.FakeListener 855 echoHandler *fake_server.FakeConnectionHandler 856 echoServer *server.Server 857 858 localAddress string 859 860 realLocalListener net.Listener 861 fakeLocalListener *fake_net.FakeListener 862 ) 863 864 BeforeEach(func() { 865 logger := lagertest.NewTestLogger("test") 866 867 var err error 868 realLocalListener, err = net.Listen("tcp", "127.0.0.1:0") 869 Expect(err).NotTo(HaveOccurred()) 870 871 localAddress = realLocalListener.Addr().String() 872 fakeListenerFactory.ListenReturns(realLocalListener, nil) 873 874 echoHandler = &fake_server.FakeConnectionHandler{} 875 echoHandler.HandleConnectionStub = func(conn net.Conn) { 876 io.Copy(conn, conn) 877 conn.Close() 878 } 879 880 realListener, err := net.Listen("tcp", "127.0.0.1:0") 881 Expect(err).NotTo(HaveOccurred()) 882 echoAddress = realListener.Addr().String() 883 884 echoListener = &fake_net.FakeListener{} 885 echoListener.AcceptStub = realListener.Accept 886 echoListener.CloseStub = realListener.Close 887 echoListener.AddrStub = realListener.Addr 888 889 fakeLocalListener = &fake_net.FakeListener{} 890 fakeLocalListener.AcceptReturns(nil, errors.New("Not Accepting Connections")) 891 892 echoServer = server.NewServer(logger.Session("echo"), "", echoHandler) 893 echoServer.SetListener(echoListener) 894 go echoServer.Serve() 895 896 opts = &options.SSHOptions{ 897 AppName: "app-1", 898 ForwardSpecs: []options.ForwardSpec{{ 899 ListenAddress: localAddress, 900 ConnectAddress: echoAddress, 901 }}, 
902 } 903 904 currentApp.State = "STARTED" 905 906 sshEndpointFingerprint = "" 907 sshEndpoint = "" 908 909 token = "" 910 911 fakeSecureClient.DialStub = net.Dial 912 }) 913 914 JustBeforeEach(func() { 915 connectErr := secureShell.Connect(opts) 916 Expect(connectErr).NotTo(HaveOccurred()) 917 918 localForwardError = secureShell.LocalPortForward() 919 }) 920 921 AfterEach(func() { 922 err := secureShell.Close() 923 Expect(err).NotTo(HaveOccurred()) 924 echoServer.Shutdown() 925 926 realLocalListener.Close() 927 }) 928 929 validateConnectivity := func(addr string) { 930 conn, err := net.Dial("tcp", addr) 931 Expect(err).NotTo(HaveOccurred()) 932 933 msg := fmt.Sprintf("Hello from %s\n", addr) 934 n, err := conn.Write([]byte(msg)) 935 Expect(err).NotTo(HaveOccurred()) 936 Expect(n).To(Equal(len(msg))) 937 938 response := make([]byte, len(msg)) 939 n, err = conn.Read(response) 940 Expect(err).NotTo(HaveOccurred()) 941 Expect(n).To(Equal(len(msg))) 942 943 err = conn.Close() 944 Expect(err).NotTo(HaveOccurred()) 945 946 Expect(response).To(Equal([]byte(msg))) 947 } 948 949 It("dials the connect address when a local connection is made", func() { 950 Expect(localForwardError).NotTo(HaveOccurred()) 951 952 conn, err := net.Dial("tcp", localAddress) 953 Expect(err).NotTo(HaveOccurred()) 954 955 Eventually(echoListener.AcceptCallCount).Should(BeNumerically(">=", 1)) 956 Eventually(fakeSecureClient.DialCallCount).Should(Equal(1)) 957 958 network, addr := fakeSecureClient.DialArgsForCall(0) 959 Expect(network).To(Equal("tcp")) 960 Expect(addr).To(Equal(echoAddress)) 961 962 Expect(conn.Close()).NotTo(HaveOccurred()) 963 }) 964 965 It("copies data between the local and remote connections", func() { 966 validateConnectivity(localAddress) 967 }) 968 969 Context("when a local connection is already open", func() { 970 var ( 971 conn net.Conn 972 err error 973 ) 974 975 JustBeforeEach(func() { 976 conn, err = net.Dial("tcp", localAddress) 977 Expect(err).NotTo(HaveOccurred()) 978 
})

			// Closes conn (declared outside this excerpt) after each spec and
			// fails the spec if closing reports an error.
			AfterEach(func() {
				err = conn.Close()
				Expect(err).NotTo(HaveOccurred())
			})

			It("allows for new incoming connections as well", func() {
				validateConnectivity(localAddress)
			})
		})

		// Two forward specs with distinct listen addresses that both target
		// echoAddress. The listener factory hands back a pre-opened real
		// listener for each known address and rejects anything else.
		Context("when there are multiple port forward specs", func() {
			var realLocalListener2 net.Listener
			var localAddress2 string

			BeforeEach(func() {
				var err error
				// Port 0 lets the OS pick a free port; the chosen address is
				// read back from the listener.
				realLocalListener2, err = net.Listen("tcp", "127.0.0.1:0")
				Expect(err).NotTo(HaveOccurred())

				localAddress2 = realLocalListener2.Addr().String()

				fakeListenerFactory.ListenStub = func(network, addr string) (net.Listener, error) {
					switch addr {
					case localAddress:
						return realLocalListener, nil
					case localAddress2:
						return realLocalListener2, nil
					default:
						return nil, errors.New("unexpected address")
					}
				}

				opts = &options.SSHOptions{
					AppName: "app-1",
					ForwardSpecs: []options.ForwardSpec{{
						ListenAddress:  localAddress,
						ConnectAddress: echoAddress,
					}, {
						ListenAddress:  localAddress2,
						ConnectAddress: echoAddress,
					}},
				}
			})

			AfterEach(func() {
				// Best-effort cleanup: the error is deliberately discarded.
				// NOTE(review): the listener may already have been closed by
				// the code under test, in which case Close reports an error
				// that is not interesting here — confirm.
				_ = realLocalListener2.Close()
			})

			It("listens to all the things", func() {
				Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(2))

				network, addr := fakeListenerFactory.ListenArgsForCall(0)
				Expect(network).To(Equal("tcp"))
				Expect(addr).To(Equal(localAddress))

				network, addr = fakeListenerFactory.ListenArgsForCall(1)
				Expect(network).To(Equal("tcp"))
				Expect(addr).To(Equal(localAddress2))
			})

			It("forwards to the correct target", func() {
				validateConnectivity(localAddress)
				validateConnectivity(localAddress2)
			})

			Context("when the secure client is closed", func() {
				BeforeEach(func() {
					// Both forward specs now receive the same fake listener,
					// whose Accept always fails.
					fakeListenerFactory.ListenReturns(fakeLocalListener, nil)
					fakeLocalListener.AcceptReturns(nil, errors.New("not accepting connections"))
				})

				It("closes the listeners ", func() {
					Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(2))
					Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(2))

					// Compare against the count taken just before Close so
					// that earlier closes are not attributed to it.
					originalCloseCount := fakeLocalListener.CloseCallCount()
					err := secureShell.Close()
					Expect(err).NotTo(HaveOccurred())
					Expect(fakeLocalListener.CloseCallCount()).To(Equal(originalCloseCount + 2))
				})
			})
		})

		// The listener factory fails outright; the error must surface in
		// localForwardError (assigned outside this excerpt).
		Context("when listen fails", func() {
			BeforeEach(func() {
				fakeListenerFactory.ListenReturns(nil, errors.New("failure is an option"))
			})

			It("returns the error", func() {
				Expect(localForwardError).To(MatchError("failure is an option"))
			})
		})

		Context("when the client it closed", func() {
			BeforeEach(func() {
				fakeListenerFactory.ListenReturns(fakeLocalListener, nil)
				fakeLocalListener.AcceptReturns(nil, errors.New("not accepting and connections"))
			})

			It("closes the listener when the client is closed", func() {
				Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(1))
				Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(1))

				// Measure the delta caused by Close rather than an absolute count.
				originalCloseCount := fakeLocalListener.CloseCallCount()
				err := secureShell.Close()
				Expect(err).NotTo(HaveOccurred())
				Expect(fakeLocalListener.CloseCallCount()).To(Equal(originalCloseCount + 1))
			})
		})

		// Accept on the fake listener fails, either permanently or with a
		// temporary net error.
		Context("when accept fails", func() {
			BeforeEach(func() {
				fakeListenerFactory.ListenReturns(fakeLocalListener, nil)
			})

			Context("with a permanent error", func() {
				BeforeEach(func() {
					fakeLocalListener.AcceptReturns(nil, errors.New("boom"))
				})

				It("stops trying to accept connections", func() {
					Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(1))
					Consistently(fakeLocalListener.AcceptCallCount).Should(Equal(1))
					Expect(fakeLocalListener.CloseCallCount()).To(Equal(1))
				})
			})

			Context("with a temporary error", func() {
				// Receives one timestamp per Accept call that returned the
				// temporary error; closed by the stub on its fourth call.
				var timeCh chan time.Time

				BeforeEach(func() {
					timeCh = make(chan time.Time, 3)

					fakeLocalListener.AcceptStub = func() (net.Conn, error) {
						// From the fourth call on, stop recording and return a
						// different error so the channel can be drained with
						// range. NOTE(review): the second NewTestNetError
						// argument presumably selects "temporary" — confirm
						// against test_helpers.
						if fakeLocalListener.AcceptCallCount() > 3 {
							close(timeCh)
							return nil, test_helpers.NewTestNetError(false, false)
						}

						timeCh <- time.Now()
						return nil, test_helpers.NewTestNetError(false, true)
					}
				})

				It("retries connecting after a short delay", func() {
					Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(3))
					// Poll instead of asserting once: the call count can be
					// observed as 3 before the stub has sent its third
					// timestamp. NOTE(review): this assumes the fake records
					// the call before invoking the stub — confirm in the
					// generated fake.
					Eventually(timeCh).Should(HaveLen(3))

					// range ends once the stub closes the channel on its
					// fourth invocation.
					var times []time.Time
					for t := range timeCh {
						times = append(times, t)
					}

					// Expected interval time = 100 msec
					Expect(times[1]).To(BeTemporally(">=", times[0]))
					Expect(times[1]).To(BeTemporally("<=", times[0].Add(500*time.Millisecond)))
					Expect(times[2]).To(BeTemporally(">=", times[1]))
					Expect(times[2]).To(BeTemporally("<=", times[1].Add(500*time.Millisecond)))
				})
			})
		})

		// Dial on the secure client returns both a connection and an error;
		// the returned connection must be left untouched.
		Context("when dialing the connect address fails", func() {
			var fakeTarget *fake_net.FakeConn

			BeforeEach(func() {
				fakeTarget = &fake_net.FakeConn{}
				fakeSecureClient.DialReturns(fakeTarget, errors.New("boom"))
			})

			It("does not call close on the target connection", func() {
				Consistently(fakeTarget.CloseCallCount).Should(Equal(0))
			})
		})
	})

	// Wait specs: connect with minimal options, then call Wait and keep its
	// result in waitErr for the individual specs to inspect.
	Describe("Wait", func() {
		var opts *options.SSHOptions
		var waitErr error

		BeforeEach(func() {
			opts = &options.SSHOptions{
				AppName: "app-1",
			}

			currentApp.State = "STARTED"

			sshEndpointFingerprint = ""
			sshEndpoint = ""

			token = ""
		})

		JustBeforeEach(func() {
			connectErr := secureShell.Connect(opts)
			Expect(connectErr).NotTo(HaveOccurred())

			waitErr = secureShell.Wait()
		})

		It("calls wait on the secureClient", func() {
			Expect(waitErr).NotTo(HaveOccurred())
			Expect(fakeSecureClient.WaitCallCount()).To(Equal(1))
		})

		Describe("keep alive messages", func() {
			// times collects one timestamp per keep-alive request; after the
			// third request the slice is handed over on timesCh and done is
			// closed so the stubbed Wait can return.
			var times []time.Time
			var timesCh chan []time.Time
			var done chan struct{}

			BeforeEach(func() {
				keepAliveDuration = 100 * time.Millisecond

				times = nil
				timesCh = make(chan []time.Time, 1)
				// Pure signalling channel: it is only ever closed.
				done = make(chan struct{})

				fakeConnection.SendRequestStub = func(reqName string, wantReply bool, message []byte) (bool, []byte, error) {
					// WaitStub blocks the spec's goroutine until done is
					// closed below, so this stub necessarily runs on another
					// goroutine; assertion failures there must be recovered
					// and reported to Ginkgo.
					defer GinkgoRecover()

					Expect(reqName).To(Equal("keepalive@cloudfoundry.org"))
					Expect(wantReply).To(BeTrue())
					Expect(message).To(BeNil())

					times = append(times, time.Now())
					// Exactly-equal check: later requests neither resend nor
					// close done a second time.
					if len(times) == 3 {
						timesCh <- times
						close(done)
					}
					return true, nil, nil
				}

				// Hold Wait open until three keep-alive requests were seen.
				fakeSecureClient.WaitStub = func() error {
					Eventually(done).Should(BeClosed())
					return nil
				}
			})

			It("sends keep alive messages at the expected interval", func() {
				Expect(waitErr).NotTo(HaveOccurred())
				sent := <-timesCh

				// Expected interval time = 100 msec
				Expect(sent[1]).To(BeTemporally(">=", sent[0]))
				Expect(sent[1]).To(BeTemporally("<=", sent[0].Add(500*time.Millisecond)))
				Expect(sent[2]).To(BeTemporally(">=", sent[1]))
				Expect(sent[2]).To(BeTemporally("<=", sent[1].Add(500*time.Millisecond)))
			})
		})
	})

	// Close specs: connect with minimal options, then verify that Close is
	// forwarded to the secure client.
	Describe("Close", func() {
		var opts *options.SSHOptions

		BeforeEach(func() {
			opts = &options.SSHOptions{
				AppName: "app-1",
			}

			currentApp.State = "STARTED"

			sshEndpointFingerprint = ""
			sshEndpoint = ""

			token = ""
		})

		JustBeforeEach(func() {
			connectErr := secureShell.Connect(opts)
			Expect(connectErr).NotTo(HaveOccurred())
		})

		It("calls close on the secureClient", func() {
			err := secureShell.Close()
			Expect(err).NotTo(HaveOccurred())

			Expect(fakeSecureClient.CloseCallCount()).To(Equal(1))
		})
	})
})