github.com/loggregator/cli@v6.33.1-0.20180224010324-82334f081791+incompatible/cf/ssh/ssh_test.go (about) 1 // +build !windows,!386 2 3 // skipping 386 because lager uses UInt64 in Session() 4 // skipping windows because the tests use Unix/Linux-only syscalls. 5 // should refactor out the conflicts so we can test this package on multiple platforms. 6 7 package sshCmd_test 8 9 import ( 10 "errors" 11 "fmt" 12 "io" 13 "net" 14 "os" 15 "syscall" 16 "time" 17 18 "code.cloudfoundry.org/cli/cf/models" 19 "code.cloudfoundry.org/cli/cf/ssh" 20 "code.cloudfoundry.org/cli/cf/ssh/options" 21 "code.cloudfoundry.org/cli/cf/ssh/sshfakes" 22 "code.cloudfoundry.org/cli/cf/ssh/terminal" 23 "code.cloudfoundry.org/cli/cf/ssh/terminal/terminalfakes" 24 "code.cloudfoundry.org/diego-ssh/server" 25 fake_server "code.cloudfoundry.org/diego-ssh/server/fakes" 26 "code.cloudfoundry.org/diego-ssh/test_helpers" 27 "code.cloudfoundry.org/diego-ssh/test_helpers/fake_io" 28 "code.cloudfoundry.org/diego-ssh/test_helpers/fake_net" 29 "code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh" 30 "code.cloudfoundry.org/lager/lagertest" 31 "github.com/kr/pty" 32 "github.com/moby/moby/pkg/term" 33 "golang.org/x/crypto/ssh" 34 35 . "github.com/onsi/ginkgo" 36 . 
"github.com/onsi/gomega" 37 ) 38 39 var _ = Describe("SSH", func() { 40 var ( 41 fakeTerminalHelper *terminalfakes.FakeTerminalHelper 42 fakeListenerFactory *sshfakes.FakeListenerFactory 43 44 fakeConnection *fake_ssh.FakeConn 45 fakeSecureClient *sshfakes.FakeSecureClient 46 fakeSecureDialer *sshfakes.FakeSecureDialer 47 fakeSecureSession *sshfakes.FakeSecureSession 48 49 terminalHelper terminal.TerminalHelper 50 keepAliveDuration time.Duration 51 secureShell sshCmd.SecureShell 52 53 stdinPipe *fake_io.FakeWriteCloser 54 55 currentApp models.Application 56 sshEndpointFingerprint string 57 sshEndpoint string 58 token string 59 ) 60 61 BeforeEach(func() { 62 fakeTerminalHelper = new(terminalfakes.FakeTerminalHelper) 63 terminalHelper = terminal.DefaultHelper() 64 65 fakeListenerFactory = new(sshfakes.FakeListenerFactory) 66 fakeListenerFactory.ListenStub = net.Listen 67 68 keepAliveDuration = 30 * time.Second 69 70 currentApp = models.Application{} 71 sshEndpoint = "" 72 sshEndpointFingerprint = "" 73 token = "" 74 75 fakeConnection = new(fake_ssh.FakeConn) 76 fakeSecureClient = new(sshfakes.FakeSecureClient) 77 fakeSecureDialer = new(sshfakes.FakeSecureDialer) 78 fakeSecureSession = new(sshfakes.FakeSecureSession) 79 80 fakeSecureDialer.DialReturns(fakeSecureClient, nil) 81 fakeSecureClient.NewSessionReturns(fakeSecureSession, nil) 82 fakeSecureClient.ConnReturns(fakeConnection) 83 84 stdinPipe = &fake_io.FakeWriteCloser{} 85 stdinPipe.WriteStub = func(p []byte) (int, error) { 86 return len(p), nil 87 } 88 89 stdoutPipe := &fake_io.FakeReader{} 90 stdoutPipe.ReadStub = func(p []byte) (int, error) { 91 return 0, io.EOF 92 } 93 94 stderrPipe := &fake_io.FakeReader{} 95 stderrPipe.ReadStub = func(p []byte) (int, error) { 96 return 0, io.EOF 97 } 98 99 fakeSecureSession.StdinPipeReturns(stdinPipe, nil) 100 fakeSecureSession.StdoutPipeReturns(stdoutPipe, nil) 101 fakeSecureSession.StderrPipeReturns(stderrPipe, nil) 102 }) 103 104 JustBeforeEach(func() { 105 secureShell 
= sshCmd.NewSecureShell( 106 fakeSecureDialer, 107 terminalHelper, 108 fakeListenerFactory, 109 keepAliveDuration, 110 currentApp, 111 sshEndpointFingerprint, 112 sshEndpoint, 113 token, 114 ) 115 }) 116 117 Describe("Validation", func() { 118 var connectErr error 119 var opts *options.SSHOptions 120 121 BeforeEach(func() { 122 opts = &options.SSHOptions{ 123 AppName: "app-1", 124 } 125 }) 126 127 JustBeforeEach(func() { 128 connectErr = secureShell.Connect(opts) 129 }) 130 131 Context("when the app model and endpoint info are successfully acquired", func() { 132 BeforeEach(func() { 133 token = "" 134 currentApp.State = "STARTED" 135 currentApp.Diego = true 136 }) 137 138 Context("when the app is not in the 'STARTED' state", func() { 139 BeforeEach(func() { 140 currentApp.State = "STOPPED" 141 currentApp.Diego = true 142 }) 143 144 It("returns an error", func() { 145 Expect(connectErr).To(MatchError(MatchRegexp("Application.*not in the STARTED state"))) 146 }) 147 }) 148 149 Context("when the app is not a Diego app", func() { 150 BeforeEach(func() { 151 currentApp.State = "STARTED" 152 currentApp.Diego = false 153 }) 154 155 It("returns an error", func() { 156 Expect(connectErr).To(MatchError(MatchRegexp("Application.*not running on Diego"))) 157 }) 158 }) 159 160 Context("when dialing fails", func() { 161 var dialError = errors.New("woops") 162 163 BeforeEach(func() { 164 fakeSecureDialer.DialReturns(nil, dialError) 165 }) 166 167 It("returns the dial error", func() { 168 Expect(connectErr).To(Equal(dialError)) 169 Expect(fakeSecureDialer.DialCallCount()).To(Equal(1)) 170 }) 171 }) 172 }) 173 }) 174 175 Describe("InteractiveSession", func() { 176 var opts *options.SSHOptions 177 var sessionError error 178 var interactiveSessionInvoker func(secureShell sshCmd.SecureShell) 179 180 BeforeEach(func() { 181 sshEndpoint = "ssh.example.com:22" 182 183 opts = &options.SSHOptions{ 184 AppName: "app-name", 185 Index: 2, 186 } 187 188 currentApp.State = "STARTED" 189 
currentApp.Diego = true 190 currentApp.GUID = "app-guid" 191 token = "bearer token" 192 193 interactiveSessionInvoker = func(secureShell sshCmd.SecureShell) { 194 sessionError = secureShell.InteractiveSession() 195 } 196 }) 197 198 JustBeforeEach(func() { 199 connectErr := secureShell.Connect(opts) 200 Expect(connectErr).NotTo(HaveOccurred()) 201 interactiveSessionInvoker(secureShell) 202 }) 203 204 It("dials the correct endpoint as the correct user", func() { 205 Expect(fakeSecureDialer.DialCallCount()).To(Equal(1)) 206 207 network, address, config := fakeSecureDialer.DialArgsForCall(0) 208 Expect(network).To(Equal("tcp")) 209 Expect(address).To(Equal("ssh.example.com:22")) 210 Expect(config.Auth).NotTo(BeEmpty()) 211 Expect(config.User).To(Equal("cf:app-guid/2")) 212 Expect(config.HostKeyCallback).NotTo(BeNil()) 213 }) 214 215 Context("when host key validation is enabled", func() { 216 var callback func(hostname string, remote net.Addr, key ssh.PublicKey) error 217 var addr net.Addr 218 219 JustBeforeEach(func() { 220 Expect(fakeSecureDialer.DialCallCount()).To(Equal(1)) 221 _, _, config := fakeSecureDialer.DialArgsForCall(0) 222 callback = config.HostKeyCallback 223 224 listener, err := net.Listen("tcp", "localhost:0") 225 Expect(err).NotTo(HaveOccurred()) 226 227 addr = listener.Addr() 228 listener.Close() 229 }) 230 231 Context("when the md5 fingerprint matches", func() { 232 BeforeEach(func() { 233 sshEndpointFingerprint = "41:ce:56:e6:9c:42:a9:c6:9e:68:ac:e3:4d:f6:38:79" 234 }) 235 236 It("does not return an error", func() { 237 Expect(callback("", addr, TestHostKey.PublicKey())).ToNot(HaveOccurred()) 238 }) 239 }) 240 241 Context("when the hex sha1 fingerprint matches", func() { 242 BeforeEach(func() { 243 sshEndpointFingerprint = "a8:e2:67:cb:ea:2a:6e:23:a1:72:ce:8f:07:92:15:ee:1f:82:f8:ca" 244 }) 245 246 It("does not return an error", func() { 247 Expect(callback("", addr, TestHostKey.PublicKey())).ToNot(HaveOccurred()) 248 }) 249 }) 250 251 Context("when 
the base64 sha256 fingerprint matches", func() { 252 BeforeEach(func() { 253 sshEndpointFingerprint = "sp/jrLuj66r+yrLDUKZdJU5tdzt4mq/UaSiNBjpgr+8" 254 }) 255 256 It("does not return an error", func() { 257 Expect(callback("", addr, TestHostKey.PublicKey())).ToNot(HaveOccurred()) 258 }) 259 }) 260 261 Context("when the base64 SHA256 fingerprint does not match", func() { 262 BeforeEach(func() { 263 sshEndpointFingerprint = "0000000000000000000000000000000000000000000" 264 }) 265 266 It("returns an error'", func() { 267 err := callback("", addr, TestHostKey.PublicKey()) 268 Expect(err).To(MatchError(MatchRegexp("Host key verification failed\\."))) 269 Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\""))) 270 }) 271 }) 272 273 Context("when the hex SHA1 fingerprint does not match", func() { 274 BeforeEach(func() { 275 sshEndpointFingerprint = "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00" 276 }) 277 278 It("returns an error'", func() { 279 err := callback("", addr, TestHostKey.PublicKey()) 280 Expect(err).To(MatchError(MatchRegexp("Host key verification failed\\."))) 281 Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\""))) 282 }) 283 }) 284 285 Context("when the MD5 fingerprint does not match", func() { 286 BeforeEach(func() { 287 sshEndpointFingerprint = "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00" 288 }) 289 290 It("returns an error'", func() { 291 err := callback("", addr, TestHostKey.PublicKey()) 292 Expect(err).To(MatchError(MatchRegexp("Host key verification failed\\."))) 293 Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\""))) 294 }) 295 }) 296 297 Context("when no fingerprint is present in endpoint info", func() { 298 BeforeEach(func() { 299 sshEndpointFingerprint = "" 300 sshEndpoint = "" 301 }) 302 303 It("returns an error'", func() { 304 err := callback("", addr, TestHostKey.PublicKey()) 305 
Expect(err).To(MatchError(MatchRegexp("Unable to verify identity of host\\."))) 306 Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\""))) 307 }) 308 }) 309 310 Context("when the fingerprint length doesn't make sense", func() { 311 BeforeEach(func() { 312 sshEndpointFingerprint = "garbage" 313 }) 314 315 It("returns an error", func() { 316 err := callback("", addr, TestHostKey.PublicKey()) 317 Eventually(err).Should(MatchError(MatchRegexp("Unsupported host key fingerprint format"))) 318 }) 319 }) 320 }) 321 322 Context("when the skip host validation flag is set", func() { 323 BeforeEach(func() { 324 opts.SkipHostValidation = true 325 }) 326 327 It("removes the HostKeyCallback from the client config", func() { 328 Expect(fakeSecureDialer.DialCallCount()).To(Equal(1)) 329 330 _, _, config := fakeSecureDialer.DialArgsForCall(0) 331 Expect(config.HostKeyCallback("some-addr", nil, nil)).To(BeNil()) 332 }) 333 }) 334 335 Context("when dialing is successful", func() { 336 BeforeEach(func() { 337 fakeTerminalHelper.StdStreamsStub = terminalHelper.StdStreams 338 terminalHelper = fakeTerminalHelper 339 }) 340 341 It("creates a new secure shell session", func() { 342 Expect(fakeSecureClient.NewSessionCallCount()).To(Equal(1)) 343 }) 344 345 It("closes the session", func() { 346 Expect(fakeSecureSession.CloseCallCount()).To(Equal(1)) 347 }) 348 349 It("allocates standard streams", func() { 350 Expect(fakeTerminalHelper.StdStreamsCallCount()).To(Equal(1)) 351 }) 352 353 It("gets a stdin pipe for the session", func() { 354 Expect(fakeSecureSession.StdinPipeCallCount()).To(Equal(1)) 355 }) 356 357 Context("when getting the stdin pipe fails", func() { 358 BeforeEach(func() { 359 fakeSecureSession.StdinPipeReturns(nil, errors.New("woops")) 360 }) 361 362 It("returns the error", func() { 363 Expect(sessionError).Should(MatchError("woops")) 364 }) 365 }) 366 367 It("gets a stdout pipe for the session", func() { 368 
Expect(fakeSecureSession.StdoutPipeCallCount()).To(Equal(1)) 369 }) 370 371 Context("when getting the stdout pipe fails", func() { 372 BeforeEach(func() { 373 fakeSecureSession.StdoutPipeReturns(nil, errors.New("woops")) 374 }) 375 376 It("returns the error", func() { 377 Expect(sessionError).Should(MatchError("woops")) 378 }) 379 }) 380 381 It("gets a stderr pipe for the session", func() { 382 Expect(fakeSecureSession.StderrPipeCallCount()).To(Equal(1)) 383 }) 384 385 Context("when getting the stderr pipe fails", func() { 386 BeforeEach(func() { 387 fakeSecureSession.StderrPipeReturns(nil, errors.New("woops")) 388 }) 389 390 It("returns the error", func() { 391 Expect(sessionError).Should(MatchError("woops")) 392 }) 393 }) 394 }) 395 396 Context("when stdin is a terminal", func() { 397 var master, slave *os.File 398 399 BeforeEach(func() { 400 _, stdout, stderr := terminalHelper.StdStreams() 401 402 var err error 403 master, slave, err = pty.Open() 404 Expect(err).NotTo(HaveOccurred()) 405 406 fakeTerminalHelper.IsTerminalStub = terminalHelper.IsTerminal 407 fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo 408 fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize 409 fakeTerminalHelper.StdStreamsReturns(slave, stdout, stderr) 410 terminalHelper = fakeTerminalHelper 411 }) 412 413 AfterEach(func() { 414 master.Close() 415 // slave.Close() // race 416 }) 417 418 Context("when a command is not specified", func() { 419 var terminalType string 420 421 BeforeEach(func() { 422 terminalType = os.Getenv("TERM") 423 os.Setenv("TERM", "test-terminal-type") 424 425 winsize := &term.Winsize{Width: 1024, Height: 256} 426 fakeTerminalHelper.GetWinsizeReturns(winsize, nil) 427 428 fakeSecureSession.ShellStub = func() error { 429 Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1)) 430 Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(0)) 431 return nil 432 } 433 }) 434 435 AfterEach(func() { 436 os.Setenv("TERM", terminalType) 437 }) 
438 439 It("requests a pty with the correct terminal type, window size, and modes", func() { 440 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1)) 441 Expect(fakeTerminalHelper.GetWinsizeCallCount()).To(Equal(1)) 442 443 termType, height, width, modes := fakeSecureSession.RequestPtyArgsForCall(0) 444 Expect(termType).To(Equal("test-terminal-type")) 445 Expect(height).To(Equal(256)) 446 Expect(width).To(Equal(1024)) 447 448 expectedModes := ssh.TerminalModes{ 449 ssh.ECHO: 1, 450 ssh.TTY_OP_ISPEED: 115200, 451 ssh.TTY_OP_OSPEED: 115200, 452 } 453 Expect(modes).To(Equal(expectedModes)) 454 }) 455 456 Context("when the TERM environment variable is not set", func() { 457 BeforeEach(func() { 458 os.Unsetenv("TERM") 459 }) 460 461 It("requests a pty with the default terminal type", func() { 462 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1)) 463 464 termType, _, _, _ := fakeSecureSession.RequestPtyArgsForCall(0) 465 Expect(termType).To(Equal("xterm")) 466 }) 467 }) 468 469 It("puts the terminal into raw mode and restores it after running the shell", func() { 470 Expect(fakeSecureSession.ShellCallCount()).To(Equal(1)) 471 Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1)) 472 Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(1)) 473 }) 474 475 Context("when the pty allocation fails", func() { 476 var ptyError error 477 478 BeforeEach(func() { 479 ptyError = errors.New("pty allocation error") 480 fakeSecureSession.RequestPtyReturns(ptyError) 481 }) 482 483 It("returns the error", func() { 484 Expect(sessionError).To(Equal(ptyError)) 485 }) 486 }) 487 488 Context("when placing the terminal into raw mode fails", func() { 489 BeforeEach(func() { 490 fakeTerminalHelper.SetRawTerminalReturns(nil, errors.New("woops")) 491 }) 492 493 It("keeps calm and carries on", func() { 494 Expect(fakeSecureSession.ShellCallCount()).To(Equal(1)) 495 }) 496 497 It("does not not restore the terminal", func() { 498 
Expect(fakeSecureSession.ShellCallCount()).To(Equal(1)) 499 Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1)) 500 Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(0)) 501 }) 502 }) 503 }) 504 505 Context("when a command is specified", func() { 506 BeforeEach(func() { 507 opts.Command = []string{"echo", "-n", "hello"} 508 }) 509 510 Context("when a terminal is requested", func() { 511 BeforeEach(func() { 512 opts.TerminalRequest = options.RequestTTYYes 513 }) 514 515 It("requests a pty", func() { 516 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1)) 517 }) 518 }) 519 520 Context("when a terminal is not explicitly requested", func() { 521 It("does not request a pty", func() { 522 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0)) 523 }) 524 }) 525 }) 526 }) 527 528 Context("when stdin is not a terminal", func() { 529 BeforeEach(func() { 530 _, stdout, stderr := terminalHelper.StdStreams() 531 532 stdin := &fake_io.FakeReadCloser{} 533 stdin.ReadStub = func(p []byte) (int, error) { 534 return 0, io.EOF 535 } 536 537 fakeTerminalHelper.IsTerminalStub = terminalHelper.IsTerminal 538 fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo 539 fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize 540 fakeTerminalHelper.StdStreamsReturns(stdin, stdout, stderr) 541 terminalHelper = fakeTerminalHelper 542 }) 543 544 Context("when a terminal is not requested", func() { 545 It("does not request a pty", func() { 546 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0)) 547 }) 548 }) 549 550 Context("when a terminal is requested", func() { 551 BeforeEach(func() { 552 opts.TerminalRequest = options.RequestTTYYes 553 }) 554 555 It("does not request a pty", func() { 556 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0)) 557 }) 558 }) 559 }) 560 561 Context("when a terminal is forced", func() { 562 BeforeEach(func() { 563 opts.TerminalRequest = options.RequestTTYForce 564 }) 565 566 It("requests a pty", 
func() { 567 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1)) 568 }) 569 }) 570 571 Context("when a terminal is disabled", func() { 572 BeforeEach(func() { 573 opts.TerminalRequest = options.RequestTTYNo 574 }) 575 576 It("does not request a pty", func() { 577 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0)) 578 }) 579 }) 580 581 Context("when a command is not specified", func() { 582 It("requests an interactive shell", func() { 583 Expect(fakeSecureSession.ShellCallCount()).To(Equal(1)) 584 }) 585 586 Context("when the shell request returns an error", func() { 587 BeforeEach(func() { 588 fakeSecureSession.ShellReturns(errors.New("oh bother")) 589 }) 590 591 It("returns the error", func() { 592 Expect(sessionError).To(MatchError("oh bother")) 593 }) 594 }) 595 }) 596 597 Context("when a command is specifed", func() { 598 BeforeEach(func() { 599 opts.Command = []string{"echo", "-n", "hello"} 600 }) 601 602 It("starts the command", func() { 603 Expect(fakeSecureSession.StartCallCount()).To(Equal(1)) 604 Expect(fakeSecureSession.StartArgsForCall(0)).To(Equal("echo -n hello")) 605 }) 606 607 Context("when the command fails to start", func() { 608 BeforeEach(func() { 609 fakeSecureSession.StartReturns(errors.New("oh well")) 610 }) 611 612 It("returns the error", func() { 613 Expect(sessionError).To(MatchError("oh well")) 614 }) 615 }) 616 }) 617 618 Context("when the shell or command has started", func() { 619 var ( 620 stdin *fake_io.FakeReadCloser 621 stdout, stderr *fake_io.FakeWriter 622 stdinPipe *fake_io.FakeWriteCloser 623 stdoutPipe, stderrPipe *fake_io.FakeReader 624 ) 625 626 BeforeEach(func() { 627 stdin = &fake_io.FakeReadCloser{} 628 stdin.ReadStub = func(p []byte) (int, error) { 629 p[0] = 0 630 return 1, io.EOF 631 } 632 stdinPipe = &fake_io.FakeWriteCloser{} 633 stdinPipe.WriteStub = func(p []byte) (int, error) { 634 defer GinkgoRecover() 635 Expect(p[0]).To(Equal(byte(0))) 636 return 1, nil 637 } 638 639 stdoutPipe = 
&fake_io.FakeReader{} 640 stdoutPipe.ReadStub = func(p []byte) (int, error) { 641 p[0] = 1 642 return 1, io.EOF 643 } 644 stdout = &fake_io.FakeWriter{} 645 stdout.WriteStub = func(p []byte) (int, error) { 646 defer GinkgoRecover() 647 Expect(p[0]).To(Equal(byte(1))) 648 return 1, nil 649 } 650 651 stderrPipe = &fake_io.FakeReader{} 652 stderrPipe.ReadStub = func(p []byte) (int, error) { 653 p[0] = 2 654 return 1, io.EOF 655 } 656 stderr = &fake_io.FakeWriter{} 657 stderr.WriteStub = func(p []byte) (int, error) { 658 defer GinkgoRecover() 659 Expect(p[0]).To(Equal(byte(2))) 660 return 1, nil 661 } 662 663 fakeTerminalHelper.StdStreamsReturns(stdin, stdout, stderr) 664 terminalHelper = fakeTerminalHelper 665 666 fakeSecureSession.StdinPipeReturns(stdinPipe, nil) 667 fakeSecureSession.StdoutPipeReturns(stdoutPipe, nil) 668 fakeSecureSession.StderrPipeReturns(stderrPipe, nil) 669 670 fakeSecureSession.WaitReturns(errors.New("error result")) 671 }) 672 673 It("copies data from the stdin stream to the session stdin pipe", func() { 674 Eventually(stdin.ReadCallCount).Should(Equal(1)) 675 Eventually(stdinPipe.WriteCallCount).Should(Equal(1)) 676 }) 677 678 It("copies data from the session stdout pipe to the stdout stream", func() { 679 Eventually(stdoutPipe.ReadCallCount).Should(Equal(1)) 680 Eventually(stdout.WriteCallCount).Should(Equal(1)) 681 }) 682 683 It("copies data from the session stderr pipe to the stderr stream", func() { 684 Eventually(stderrPipe.ReadCallCount).Should(Equal(1)) 685 Eventually(stderr.WriteCallCount).Should(Equal(1)) 686 }) 687 688 It("waits for the session to end", func() { 689 Expect(fakeSecureSession.WaitCallCount()).To(Equal(1)) 690 }) 691 692 It("returns the result from wait", func() { 693 Expect(sessionError).To(MatchError("error result")) 694 }) 695 696 Context("when the session terminates before stream copies complete", func() { 697 var sessionErrorCh chan error 698 699 BeforeEach(func() { 700 sessionErrorCh = make(chan error, 1) 701 702 
interactiveSessionInvoker = func(secureShell sshCmd.SecureShell) { 703 go func() { sessionErrorCh <- secureShell.InteractiveSession() }() 704 } 705 706 stdoutPipe.ReadStub = func(p []byte) (int, error) { 707 defer GinkgoRecover() 708 Eventually(fakeSecureSession.WaitCallCount).Should(Equal(1)) 709 Consistently(sessionErrorCh).ShouldNot(Receive()) 710 711 p[0] = 1 712 return 1, io.EOF 713 } 714 715 stderrPipe.ReadStub = func(p []byte) (int, error) { 716 defer GinkgoRecover() 717 Eventually(fakeSecureSession.WaitCallCount).Should(Equal(1)) 718 Consistently(sessionErrorCh).ShouldNot(Receive()) 719 720 p[0] = 2 721 return 1, io.EOF 722 } 723 }) 724 725 It("waits for the copies to complete", func() { 726 Eventually(sessionErrorCh).Should(Receive()) 727 Expect(stdoutPipe.ReadCallCount()).To(Equal(1)) 728 Expect(stderrPipe.ReadCallCount()).To(Equal(1)) 729 }) 730 }) 731 732 Context("when stdin is closed", func() { 733 BeforeEach(func() { 734 stdin.ReadStub = func(p []byte) (int, error) { 735 defer GinkgoRecover() 736 Consistently(stdinPipe.CloseCallCount).Should(Equal(0)) 737 p[0] = 0 738 return 1, io.EOF 739 } 740 }) 741 742 It("closes the stdinPipe", func() { 743 Eventually(stdinPipe.CloseCallCount).Should(Equal(1)) 744 }) 745 }) 746 }) 747 748 Context("when stdout is a terminal and a window size change occurs", func() { 749 var master, slave *os.File 750 751 BeforeEach(func() { 752 stdin, _, stderr := terminalHelper.StdStreams() 753 754 var err error 755 master, slave, err = pty.Open() 756 Expect(err).NotTo(HaveOccurred()) 757 758 fakeTerminalHelper.IsTerminalStub = terminalHelper.IsTerminal 759 fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo 760 fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize 761 fakeTerminalHelper.StdStreamsReturns(stdin, slave, stderr) 762 terminalHelper = fakeTerminalHelper 763 764 winsize := &term.Winsize{Height: 100, Width: 100} 765 err = term.SetWinsize(slave.Fd(), winsize) 766 Expect(err).NotTo(HaveOccurred()) 767 768 
fakeSecureSession.WaitStub = func() error { 769 fakeSecureSession.SendRequestCallCount() 770 Expect(fakeSecureSession.SendRequestCallCount()).To(Equal(0)) 771 772 // No dimension change 773 for i := 0; i < 3; i++ { 774 winsize := &term.Winsize{Height: 100, Width: 100} 775 err = term.SetWinsize(slave.Fd(), winsize) 776 Expect(err).NotTo(HaveOccurred()) 777 } 778 779 winsize := &term.Winsize{Height: 100, Width: 200} 780 err = term.SetWinsize(slave.Fd(), winsize) 781 Expect(err).NotTo(HaveOccurred()) 782 783 err = syscall.Kill(syscall.Getpid(), syscall.SIGWINCH) 784 Expect(err).NotTo(HaveOccurred()) 785 786 Eventually(fakeSecureSession.SendRequestCallCount).Should(Equal(1)) 787 return nil 788 } 789 }) 790 791 AfterEach(func() { 792 master.Close() 793 slave.Close() 794 }) 795 796 It("sends window change events when the window dimensions change", func() { 797 Expect(fakeSecureSession.SendRequestCallCount()).To(Equal(1)) 798 799 requestType, wantReply, message := fakeSecureSession.SendRequestArgsForCall(0) 800 Expect(requestType).To(Equal("window-change")) 801 Expect(wantReply).To(BeFalse()) 802 803 type resizeMessage struct { 804 Width uint32 805 Height uint32 806 PixelWidth uint32 807 PixelHeight uint32 808 } 809 var resizeMsg resizeMessage 810 811 err := ssh.Unmarshal(message, &resizeMsg) 812 Expect(err).NotTo(HaveOccurred()) 813 814 Expect(resizeMsg).To(Equal(resizeMessage{Height: 100, Width: 200})) 815 }) 816 }) 817 818 Describe("keep alive messages", func() { 819 var times []time.Time 820 var timesCh chan []time.Time 821 var done chan struct{} 822 823 BeforeEach(func() { 824 keepAliveDuration = 100 * time.Millisecond 825 826 times = []time.Time{} 827 timesCh = make(chan []time.Time, 1) 828 done = make(chan struct{}, 1) 829 830 fakeConnection.SendRequestStub = func(reqName string, wantReply bool, message []byte) (bool, []byte, error) { 831 Expect(reqName).To(Equal("keepalive@cloudfoundry.org")) 832 Expect(wantReply).To(BeTrue()) 833 Expect(message).To(BeNil()) 834 
835 times = append(times, time.Now()) 836 if len(times) == 3 { 837 timesCh <- times 838 close(done) 839 } 840 return true, nil, nil 841 } 842 843 fakeSecureSession.WaitStub = func() error { 844 Eventually(done).Should(BeClosed()) 845 return nil 846 } 847 }) 848 849 It("sends keep alive messages at the expected interval", func() { 850 times := <-timesCh 851 Expect(times[2]).To(BeTemporally("~", times[0].Add(200*time.Millisecond), 100*time.Millisecond)) 852 }) 853 }) 854 }) 855 856 Describe("LocalPortForward", func() { 857 var ( 858 opts *options.SSHOptions 859 localForwardError error 860 861 echoAddress string 862 echoListener *fake_net.FakeListener 863 echoHandler *fake_server.FakeConnectionHandler 864 echoServer *server.Server 865 866 localAddress string 867 868 realLocalListener net.Listener 869 fakeLocalListener *fake_net.FakeListener 870 ) 871 872 BeforeEach(func() { 873 logger := lagertest.NewTestLogger("test") 874 875 var err error 876 realLocalListener, err = net.Listen("tcp", "127.0.0.1:0") 877 Expect(err).NotTo(HaveOccurred()) 878 879 localAddress = realLocalListener.Addr().String() 880 fakeListenerFactory.ListenReturns(realLocalListener, nil) 881 882 echoHandler = &fake_server.FakeConnectionHandler{} 883 echoHandler.HandleConnectionStub = func(conn net.Conn) { 884 io.Copy(conn, conn) 885 conn.Close() 886 } 887 888 realListener, err := net.Listen("tcp", "127.0.0.1:0") 889 Expect(err).NotTo(HaveOccurred()) 890 echoAddress = realListener.Addr().String() 891 892 echoListener = &fake_net.FakeListener{} 893 echoListener.AcceptStub = realListener.Accept 894 echoListener.CloseStub = realListener.Close 895 echoListener.AddrStub = realListener.Addr 896 897 fakeLocalListener = &fake_net.FakeListener{} 898 fakeLocalListener.AcceptReturns(nil, errors.New("Not Accepting Connections")) 899 900 echoServer = server.NewServer(logger.Session("echo"), "", echoHandler) 901 echoServer.SetListener(echoListener) 902 go echoServer.Serve() 903 904 opts = &options.SSHOptions{ 905 
AppName: "app-1", 906 ForwardSpecs: []options.ForwardSpec{{ 907 ListenAddress: localAddress, 908 ConnectAddress: echoAddress, 909 }}, 910 } 911 912 currentApp.State = "STARTED" 913 currentApp.Diego = true 914 915 sshEndpointFingerprint = "" 916 sshEndpoint = "" 917 918 token = "" 919 920 fakeSecureClient.DialStub = net.Dial 921 }) 922 923 JustBeforeEach(func() { 924 connectErr := secureShell.Connect(opts) 925 Expect(connectErr).NotTo(HaveOccurred()) 926 927 localForwardError = secureShell.LocalPortForward() 928 }) 929 930 AfterEach(func() { 931 err := secureShell.Close() 932 Expect(err).NotTo(HaveOccurred()) 933 echoServer.Shutdown() 934 935 realLocalListener.Close() 936 }) 937 938 validateConnectivity := func(addr string) { 939 conn, err := net.Dial("tcp", addr) 940 Expect(err).NotTo(HaveOccurred()) 941 942 msg := fmt.Sprintf("Hello from %s\n", addr) 943 n, err := conn.Write([]byte(msg)) 944 Expect(err).NotTo(HaveOccurred()) 945 Expect(n).To(Equal(len(msg))) 946 947 response := make([]byte, len(msg)) 948 n, err = conn.Read(response) 949 Expect(err).NotTo(HaveOccurred()) 950 Expect(n).To(Equal(len(msg))) 951 952 err = conn.Close() 953 Expect(err).NotTo(HaveOccurred()) 954 955 Expect(response).To(Equal([]byte(msg))) 956 } 957 958 It("dials the connect address when a local connection is made", func() { 959 Expect(localForwardError).NotTo(HaveOccurred()) 960 961 conn, err := net.Dial("tcp", localAddress) 962 Expect(err).NotTo(HaveOccurred()) 963 964 Eventually(echoListener.AcceptCallCount).Should(BeNumerically(">=", 1)) 965 Eventually(fakeSecureClient.DialCallCount).Should(Equal(1)) 966 967 network, addr := fakeSecureClient.DialArgsForCall(0) 968 Expect(network).To(Equal("tcp")) 969 Expect(addr).To(Equal(echoAddress)) 970 971 Expect(conn.Close()).NotTo(HaveOccurred()) 972 }) 973 974 It("copies data between the local and remote connections", func() { 975 validateConnectivity(localAddress) 976 }) 977 978 Context("when a local connection is already open", func() { 979 
var ( 980 conn net.Conn 981 err error 982 ) 983 984 JustBeforeEach(func() { 985 conn, err = net.Dial("tcp", localAddress) 986 Expect(err).NotTo(HaveOccurred()) 987 }) 988 989 AfterEach(func() { 990 err = conn.Close() 991 Expect(err).NotTo(HaveOccurred()) 992 }) 993 994 It("allows for new incoming connections as well", func() { 995 validateConnectivity(localAddress) 996 }) 997 }) 998 999 Context("when there are multiple port forward specs", func() { 1000 var realLocalListener2 net.Listener 1001 var localAddress2 string 1002 1003 BeforeEach(func() { 1004 var err error 1005 realLocalListener2, err = net.Listen("tcp", "127.0.0.1:0") 1006 Expect(err).NotTo(HaveOccurred()) 1007 1008 localAddress2 = realLocalListener2.Addr().String() 1009 1010 fakeListenerFactory.ListenStub = func(network, addr string) (net.Listener, error) { 1011 if addr == localAddress { 1012 return realLocalListener, nil 1013 } 1014 1015 if addr == localAddress2 { 1016 return realLocalListener2, nil 1017 } 1018 1019 return nil, errors.New("unexpected address") 1020 } 1021 1022 opts = &options.SSHOptions{ 1023 AppName: "app-1", 1024 ForwardSpecs: []options.ForwardSpec{{ 1025 ListenAddress: localAddress, 1026 ConnectAddress: echoAddress, 1027 }, { 1028 ListenAddress: localAddress2, 1029 ConnectAddress: echoAddress, 1030 }}, 1031 } 1032 }) 1033 1034 AfterEach(func() { 1035 realLocalListener2.Close() 1036 }) 1037 1038 It("listens to all the things", func() { 1039 Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(2)) 1040 1041 network, addr := fakeListenerFactory.ListenArgsForCall(0) 1042 Expect(network).To(Equal("tcp")) 1043 Expect(addr).To(Equal(localAddress)) 1044 1045 network, addr = fakeListenerFactory.ListenArgsForCall(1) 1046 Expect(network).To(Equal("tcp")) 1047 Expect(addr).To(Equal(localAddress2)) 1048 }) 1049 1050 It("forwards to the correct target", func() { 1051 validateConnectivity(localAddress) 1052 validateConnectivity(localAddress2) 1053 }) 1054 1055 Context("when the secure 
client is closed", func() { 1056 BeforeEach(func() { 1057 fakeListenerFactory.ListenReturns(fakeLocalListener, nil) 1058 fakeLocalListener.AcceptReturns(nil, errors.New("not accepting connections")) 1059 }) 1060 1061 It("closes the listeners ", func() { 1062 Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(2)) 1063 Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(2)) 1064 1065 originalCloseCount := fakeLocalListener.CloseCallCount() 1066 err := secureShell.Close() 1067 Expect(err).NotTo(HaveOccurred()) 1068 Expect(fakeLocalListener.CloseCallCount()).Should(Equal(originalCloseCount + 2)) 1069 }) 1070 }) 1071 }) 1072 1073 Context("when listen fails", func() { 1074 BeforeEach(func() { 1075 fakeListenerFactory.ListenReturns(nil, errors.New("failure is an option")) 1076 }) 1077 1078 It("returns the error", func() { 1079 Expect(localForwardError).To(MatchError("failure is an option")) 1080 }) 1081 }) 1082 1083 Context("when the client it closed", func() { 1084 BeforeEach(func() { 1085 fakeListenerFactory.ListenReturns(fakeLocalListener, nil) 1086 fakeLocalListener.AcceptReturns(nil, errors.New("not accepting and connections")) 1087 }) 1088 1089 It("closes the listener when the client is closed", func() { 1090 Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(1)) 1091 Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(1)) 1092 1093 originalCloseCount := fakeLocalListener.CloseCallCount() 1094 err := secureShell.Close() 1095 Expect(err).NotTo(HaveOccurred()) 1096 Expect(fakeLocalListener.CloseCallCount()).Should(Equal(originalCloseCount + 1)) 1097 }) 1098 }) 1099 1100 Context("when accept fails", func() { 1101 var fakeConn *fake_net.FakeConn 1102 BeforeEach(func() { 1103 fakeConn = &fake_net.FakeConn{} 1104 fakeConn.ReadReturns(0, io.EOF) 1105 1106 fakeListenerFactory.ListenReturns(fakeLocalListener, nil) 1107 }) 1108 1109 Context("with a permanent error", func() { 1110 BeforeEach(func() { 1111 fakeLocalListener.AcceptReturns(nil, 
errors.New("boom")) 1112 }) 1113 1114 It("stops trying to accept connections", func() { 1115 Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(1)) 1116 Consistently(fakeLocalListener.AcceptCallCount).Should(Equal(1)) 1117 Expect(fakeLocalListener.CloseCallCount()).To(Equal(1)) 1118 }) 1119 }) 1120 1121 Context("with a temporary error", func() { 1122 var timeCh chan time.Time 1123 1124 BeforeEach(func() { 1125 timeCh = make(chan time.Time, 3) 1126 1127 fakeLocalListener.AcceptStub = func() (net.Conn, error) { 1128 timeCh := timeCh 1129 if fakeLocalListener.AcceptCallCount() > 3 { 1130 close(timeCh) 1131 return nil, test_helpers.NewTestNetError(false, false) 1132 } else { 1133 timeCh <- time.Now() 1134 return nil, test_helpers.NewTestNetError(false, true) 1135 } 1136 } 1137 }) 1138 1139 It("retries connecting after a short delay", func() { 1140 Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(3)) 1141 Expect(timeCh).To(HaveLen(3)) 1142 1143 times := make([]time.Time, 0) 1144 for t := range timeCh { 1145 times = append(times, t) 1146 } 1147 1148 Expect(times[1]).To(BeTemporally("~", times[0].Add(115*time.Millisecond), 30*time.Millisecond)) 1149 Expect(times[2]).To(BeTemporally("~", times[1].Add(115*time.Millisecond), 30*time.Millisecond)) 1150 }) 1151 }) 1152 }) 1153 1154 Context("when dialing the connect address fails", func() { 1155 var fakeTarget *fake_net.FakeConn 1156 1157 BeforeEach(func() { 1158 fakeTarget = &fake_net.FakeConn{} 1159 fakeSecureClient.DialReturns(fakeTarget, errors.New("boom")) 1160 }) 1161 1162 It("does not call close on the target connection", func() { 1163 Consistently(fakeTarget.CloseCallCount).Should(Equal(0)) 1164 }) 1165 }) 1166 }) 1167 1168 Describe("Wait", func() { 1169 var opts *options.SSHOptions 1170 var waitErr error 1171 1172 BeforeEach(func() { 1173 opts = &options.SSHOptions{ 1174 AppName: "app-1", 1175 } 1176 1177 currentApp.State = "STARTED" 1178 currentApp.Diego = true 1179 1180 sshEndpointFingerprint = "" 
1181 sshEndpoint = "" 1182 1183 token = "" 1184 }) 1185 1186 JustBeforeEach(func() { 1187 connectErr := secureShell.Connect(opts) 1188 Expect(connectErr).NotTo(HaveOccurred()) 1189 1190 waitErr = secureShell.Wait() 1191 }) 1192 1193 It("calls wait on the secureClient", func() { 1194 Expect(waitErr).NotTo(HaveOccurred()) 1195 Expect(fakeSecureClient.WaitCallCount()).To(Equal(1)) 1196 }) 1197 1198 Describe("keep alive messages", func() { 1199 var times []time.Time 1200 var timesCh chan []time.Time 1201 var done chan struct{} 1202 1203 BeforeEach(func() { 1204 keepAliveDuration = 100 * time.Millisecond 1205 1206 times = []time.Time{} 1207 timesCh = make(chan []time.Time, 1) 1208 done = make(chan struct{}, 1) 1209 1210 fakeConnection.SendRequestStub = func(reqName string, wantReply bool, message []byte) (bool, []byte, error) { 1211 Expect(reqName).To(Equal("keepalive@cloudfoundry.org")) 1212 Expect(wantReply).To(BeTrue()) 1213 Expect(message).To(BeNil()) 1214 1215 times = append(times, time.Now()) 1216 if len(times) == 3 { 1217 timesCh <- times 1218 close(done) 1219 } 1220 return true, nil, nil 1221 } 1222 1223 fakeSecureClient.WaitStub = func() error { 1224 Eventually(done).Should(BeClosed()) 1225 return nil 1226 } 1227 }) 1228 1229 It("sends keep alive messages at the expected interval", func() { 1230 Expect(waitErr).NotTo(HaveOccurred()) 1231 times := <-timesCh 1232 Expect(times[2]).To(BeTemporally("~", times[0].Add(200*time.Millisecond), 100*time.Millisecond)) 1233 }) 1234 }) 1235 }) 1236 1237 Describe("Close", func() { 1238 var opts *options.SSHOptions 1239 1240 BeforeEach(func() { 1241 opts = &options.SSHOptions{ 1242 AppName: "app-1", 1243 } 1244 1245 currentApp.State = "STARTED" 1246 currentApp.Diego = true 1247 1248 sshEndpointFingerprint = "" 1249 sshEndpoint = "" 1250 1251 token = "" 1252 }) 1253 1254 JustBeforeEach(func() { 1255 connectErr := secureShell.Connect(opts) 1256 Expect(connectErr).NotTo(HaveOccurred()) 1257 }) 1258 1259 It("calls close on the 
secureClient", func() { 1260 err := secureShell.Close() 1261 Expect(err).NotTo(HaveOccurred()) 1262 1263 Expect(fakeSecureClient.CloseCallCount()).To(Equal(1)) 1264 }) 1265 }) 1266 })