github.com/ablease/cli@v6.37.1-0.20180613014814-3adbb7d7fb19+incompatible/util/clissh/ssh_test.go (about) 1 // +build !windows,!386 2 3 // skipping 386 because lager uses UInt64 in Session() 4 // skipping windows because Unix/Linux only syscall in test. 5 // should refactor out the conflicts so we could test this package in multi platforms. 6 7 package clissh_test 8 9 import ( 10 "errors" 11 "fmt" 12 "io" 13 "net" 14 "os" 15 "sync" 16 "syscall" 17 "time" 18 19 "code.cloudfoundry.org/cli/util/clissh/clisshfakes" 20 "code.cloudfoundry.org/cli/util/clissh/ssherror" 21 "code.cloudfoundry.org/diego-ssh/server" 22 fake_server "code.cloudfoundry.org/diego-ssh/server/fakes" 23 "code.cloudfoundry.org/diego-ssh/test_helpers" 24 "code.cloudfoundry.org/diego-ssh/test_helpers/fake_io" 25 "code.cloudfoundry.org/diego-ssh/test_helpers/fake_net" 26 "code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh" 27 "code.cloudfoundry.org/lager/lagertest" 28 "github.com/kr/pty" 29 "github.com/moby/moby/pkg/term" 30 "golang.org/x/crypto/ssh" 31 32 . "code.cloudfoundry.org/cli/util/clissh" 33 . "github.com/onsi/ginkgo" 34 . "github.com/onsi/gomega" 35 ) 36 37 func BlockAcceptOnClose(fake *fake_net.FakeListener) { 38 waitUntilClosed := make(chan bool) 39 fake.AcceptStub = func() (net.Conn, error) { 40 <-waitUntilClosed 41 return nil, errors.New("banana") 42 } 43 44 var once sync.Once 45 fake.CloseStub = func() error { 46 once.Do(func() { 47 close(waitUntilClosed) 48 }) 49 return nil 50 } 51 } 52 53 var _ = Describe("CLI SSH", func() { 54 var ( 55 fakeSecureDialer *clisshfakes.FakeSecureDialer 56 fakeSecureClient *clisshfakes.FakeSecureClient 57 fakeTerminalHelper *clisshfakes.FakeTerminalHelper 58 fakeListenerFactory *clisshfakes.FakeListenerFactory 59 fakeSecureSession *clisshfakes.FakeSecureSession 60 61 fakeConnection *fake_ssh.FakeConn 62 stdinPipe *fake_io.FakeWriteCloser 63 stdoutPipe *fake_io.FakeReader 64 stderrPipe *fake_io.FakeReader 65 secureShell *SecureShell 66 67 username string 68 passcode string 69 sshEndpoint string 70 sshEndpointFingerprint string 71 skipHostValidation bool 72 commands []string 73 terminalRequest TTYRequest 74 keepAliveDuration time.Duration 75 ) 76 77 BeforeEach(func() { 78 fakeSecureDialer = new(clisshfakes.FakeSecureDialer) 79 fakeSecureClient = new(clisshfakes.FakeSecureClient) 80 fakeTerminalHelper = new(clisshfakes.FakeTerminalHelper) 81 fakeListenerFactory = new(clisshfakes.FakeListenerFactory) 82 fakeSecureSession = new(clisshfakes.FakeSecureSession) 83 84 fakeConnection = new(fake_ssh.FakeConn) 85 stdinPipe = new(fake_io.FakeWriteCloser) 86 stdoutPipe = new(fake_io.FakeReader) 87 stderrPipe = new(fake_io.FakeReader) 88 89 fakeListenerFactory.ListenStub = net.Listen 90 fakeSecureClient.NewSessionReturns(fakeSecureSession, nil) 91 fakeSecureClient.ConnReturns(fakeConnection) 92 fakeSecureDialer.DialReturns(fakeSecureClient, nil) 93 94 stdinPipe.WriteStub = func(p []byte) (int, error) { 95 return len(p), nil 96 } 97 fakeSecureSession.StdinPipeReturns(stdinPipe, nil) 98 99 stdoutPipe.ReadStub = func(p []byte) (int, error) { 100 return 0, io.EOF 101 } 102 fakeSecureSession.StdoutPipeReturns(stdoutPipe, nil) 103 104 stderrPipe.ReadStub = func(p []byte) (int, error) { 105 return 0, io.EOF 106 } 107 fakeSecureSession.StderrPipeReturns(stderrPipe, nil) 108 109 username = "some-user" 110 passcode = "some-passcode" 111 sshEndpoint = "some-endpoint" 112 sshEndpointFingerprint = "some-fingerprint" 113 skipHostValidation = false 114 commands = []string{} 115 terminalRequest = RequestTTYAuto 116 keepAliveDuration = DefaultKeepAliveInterval 117 }) 118 119 JustBeforeEach(func() { 120 secureShell = NewSecureShell( 121 fakeSecureDialer, 122 fakeTerminalHelper, 123 fakeListenerFactory, 124 keepAliveDuration, 125 ) 126 }) 127 128 Describe("Connect", func() { 129 var connectErr error 130 131 JustBeforeEach(func() { 132 connectErr = secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation) 133 }) 134 135 Context("when dialing succeeds", func() { 136 It("creates the ssh client", func() { 137 Expect(connectErr).ToNot(HaveOccurred()) 138 139 Expect(fakeSecureDialer.DialCallCount()).To(Equal(1)) 140 protocolArg, sshEndpointArg, sshConfigArg := fakeSecureDialer.DialArgsForCall(0) 141 Expect(protocolArg).To(Equal("tcp")) 142 Expect(sshEndpointArg).To(Equal(sshEndpoint)) 143 Expect(sshConfigArg.User).To(Equal(username)) 144 Expect(sshConfigArg.Auth).To(HaveLen(1)) 145 Expect(sshConfigArg.HostKeyCallback).ToNot(BeNil()) 146 }) 147 }) 148 149 Context("when dialing fails", func() { 150 var dialError error 151 152 Context("when the error is a generic Dial error", func() { 153 BeforeEach(func() { 154 dialError = errors.New("woops") 155 fakeSecureDialer.DialReturns(nil, dialError) 156 }) 157 158 It("returns the dial error", func() { 159 Expect(connectErr).To(Equal(dialError)) 160 Expect(fakeSecureDialer.DialCallCount()).To(Equal(1)) 161 }) 162 }) 163 164 Context("when the dialing error is a golang 'unable to authenticate' error", func() { 165 BeforeEach(func() { 166 dialError = fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", []string{"none", "password"}) 167 fakeSecureDialer.DialReturns(nil, dialError) 168 }) 169 170 It("returns an UnableToAuthenticateError", func() { 171 Expect(connectErr).To(MatchError(ssherror.UnableToAuthenticateError{Err: dialError})) 172 Expect(fakeSecureDialer.DialCallCount()).To(Equal(1)) 173 }) 174 }) 175 }) 176 }) 177 178 Describe("InteractiveSession", func() { 179 var ( 180 stdin *fake_io.FakeReadCloser 181 stdout, stderr *fake_io.FakeWriter 182 183 sessionErr error 184 interactiveSessionInvoker func(secureShell *SecureShell) 185 ) 186 187 BeforeEach(func() { 188 stdin = new(fake_io.FakeReadCloser) 189 stdout = new(fake_io.FakeWriter) 190 stderr = new(fake_io.FakeWriter) 191 192 fakeTerminalHelper.StdStreamsReturns(stdin, stdout, stderr) 193 interactiveSessionInvoker = func(secureShell *SecureShell) { 194 sessionErr = secureShell.InteractiveSession(commands, terminalRequest) 195 } 196 }) 197 198 JustBeforeEach(func() { 199 connectErr := secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation) 200 Expect(connectErr).NotTo(HaveOccurred()) 201 interactiveSessionInvoker(secureShell) 202 }) 203 204 Context("when host key validation is enabled", func() { 205 var ( 206 callback func(hostname string, remote net.Addr, key ssh.PublicKey) error 207 addr net.Addr 208 ) 209 210 BeforeEach(func() { 211 skipHostValidation = false 212 }) 213 214 JustBeforeEach(func() { 215 Expect(fakeSecureDialer.DialCallCount()).To(Equal(1)) 216 _, _, config := fakeSecureDialer.DialArgsForCall(0) 217 callback = config.HostKeyCallback 218 219 listener, err := net.Listen("tcp", "localhost:0") 220 Expect(err).NotTo(HaveOccurred()) 221 222 addr = listener.Addr() 223 listener.Close() 224 }) 225 226 Context("when the md5 fingerprint matches", func() { 227 BeforeEach(func() { 228 sshEndpointFingerprint = "41:ce:56:e6:9c:42:a9:c6:9e:68:ac:e3:4d:f6:38:79" 229 }) 230 231 It("does not return an error", func() { 232 Expect(callback("", addr, TestHostKey.PublicKey())).ToNot(HaveOccurred()) 233 }) 234 }) 235 236 Context("when the hex sha1 fingerprint matches", func() { 237 BeforeEach(func() { 238 sshEndpointFingerprint = "a8:e2:67:cb:ea:2a:6e:23:a1:72:ce:8f:07:92:15:ee:1f:82:f8:ca" 239 }) 240 241 It("does not return an error", func() { 242 Expect(callback("", addr, TestHostKey.PublicKey())).ToNot(HaveOccurred()) 243 }) 244 }) 245 246 Context("when the base64 sha256 fingerprint matches", func() { 247 BeforeEach(func() { 248 sshEndpointFingerprint = "sp/jrLuj66r+yrLDUKZdJU5tdzt4mq/UaSiNBjpgr+8" 249 }) 250 251 It("does not return an error", func() { 252 Expect(callback("", addr, TestHostKey.PublicKey())).ToNot(HaveOccurred()) 253 }) 254 }) 255 256 Context("when the base64 SHA256 fingerprint does not match", func() { 257 BeforeEach(func() { 258 sshEndpointFingerprint = "0000000000000000000000000000000000000000000" 259 }) 260 261 It("returns an error'", func() { 262 err := callback("", addr, TestHostKey.PublicKey()) 263 Expect(err).To(MatchError(MatchRegexp("Host key verification failed\\."))) 264 Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\""))) 265 }) 266 }) 267 268 Context("when the hex SHA1 fingerprint does not match", func() { 269 BeforeEach(func() { 270 sshEndpointFingerprint = "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00" 271 }) 272 273 It("returns an error'", func() { 274 err := callback("", addr, TestHostKey.PublicKey()) 275 Expect(err).To(MatchError(MatchRegexp("Host key verification failed\\."))) 276 Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\""))) 277 }) 278 }) 279 280 Context("when the MD5 fingerprint does not match", func() { 281 BeforeEach(func() { 282 sshEndpointFingerprint = "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00" 283 }) 284 285 It("returns an error'", func() { 286 err := callback("", addr, TestHostKey.PublicKey()) 287 Expect(err).To(MatchError(MatchRegexp("Host key verification failed\\."))) 288 Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\""))) 289 }) 290 }) 291 292 Context("when no fingerprint is present in endpoint info", func() { 293 BeforeEach(func() { 294 sshEndpointFingerprint = "" 295 sshEndpoint = "" 296 }) 297 298 It("returns an error'", func() { 299 err := callback("", addr, TestHostKey.PublicKey()) 300 Expect(err).To(MatchError(MatchRegexp("Unable to verify identity of host\\."))) 301 Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\""))) 302 }) 303 }) 304 305 Context("when the fingerprint length doesn't make sense", func() { 306 BeforeEach(func() { 307 sshEndpointFingerprint = "garbage" 308 }) 309 310 It("returns an error", func() { 311 err := callback("", addr, TestHostKey.PublicKey()) 312 Eventually(err).Should(MatchError(MatchRegexp("Unsupported host key fingerprint format"))) 313 }) 314 }) 315 }) 316 317 Context("when the skip host validation flag is set", func() { 318 BeforeEach(func() { 319 skipHostValidation = true 320 }) 321 322 It("the HostKeyCallback on the Config to always return nil", func() { 323 Expect(fakeSecureDialer.DialCallCount()).To(Equal(1)) 324 325 _, _, config := fakeSecureDialer.DialArgsForCall(0) 326 Expect(config.HostKeyCallback("some-hostname", nil, nil)).To(BeNil()) 327 }) 328 }) 329 330 // TODO: see if it's possible to test the piping between the ss client input and outputs and the UI object we pass in 331 Context("when dialing is successful", func() { 332 It("creates a new secure shell session", func() { 333 Expect(fakeSecureClient.NewSessionCallCount()).To(Equal(1)) 334 }) 335 336 It("closes the session", func() { 337 Expect(fakeSecureSession.CloseCallCount()).To(Equal(1)) 338 }) 339 340 It("gets a stdin pipe for the session", func() { 341 Expect(fakeSecureSession.StdinPipeCallCount()).To(Equal(1)) 342 }) 343 344 Context("when getting the stdin pipe fails", func() { 345 BeforeEach(func() { 346 fakeSecureSession.StdinPipeReturns(nil, errors.New("woops")) 347 }) 348 349 It("returns the error", func() { 350 Expect(sessionErr).Should(MatchError("woops")) 351 }) 352 }) 353 354 It("gets a stdout pipe for the session", func() { 355 Expect(fakeSecureSession.StdoutPipeCallCount()).To(Equal(1)) 356 }) 357 358 Context("when getting the stdout pipe fails", func() { 359 BeforeEach(func() { 360 fakeSecureSession.StdoutPipeReturns(nil, errors.New("woops")) 361 }) 362 363 It("returns the error", func() { 364 Expect(sessionErr).Should(MatchError("woops")) 365 }) 366 }) 367 368 It("gets a stderr pipe for the session", func() { 369 Expect(fakeSecureSession.StderrPipeCallCount()).To(Equal(1)) 370 }) 371 372 Context("when getting the stderr pipe fails", func() { 373 BeforeEach(func() { 374 fakeSecureSession.StderrPipeReturns(nil, errors.New("woops")) 375 }) 376 377 It("returns the error", func() { 378 Expect(sessionErr).Should(MatchError("woops")) 379 }) 380 }) 381 }) 382 383 Context("when stdin is a terminal", func() { 384 var master, slave *os.File 385 386 BeforeEach(func() { 387 var err error 388 master, slave, err = pty.Open() 389 Expect(err).NotTo(HaveOccurred()) 390 391 terminalRequest = RequestTTYForce 392 393 terminalHelper := DefaultTerminalHelper() 394 fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo 395 fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize 396 }) 397 398 AfterEach(func() { 399 master.Close() 400 // slave.Close() // race 401 }) 402 403 Context("when a command is not specified", func() { 404 var terminalType string 405 406 BeforeEach(func() { 407 terminalType = os.Getenv("TERM") 408 os.Setenv("TERM", "test-terminal-type") 409 410 winsize := &term.Winsize{Width: 1024, Height: 256} 411 fakeTerminalHelper.GetWinsizeReturns(winsize, nil) 412 413 fakeSecureSession.ShellStub = func() error { 414 Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1)) 415 Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(0)) 416 return nil 417 } 418 }) 419 420 AfterEach(func() { 421 os.Setenv("TERM", terminalType) 422 }) 423 424 It("requests a pty with the correct terminal type, window size, and modes", func() { 425 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1)) 426 Expect(fakeTerminalHelper.GetWinsizeCallCount()).To(Equal(1)) 427 428 termType, height, width, modes := fakeSecureSession.RequestPtyArgsForCall(0) 429 Expect(termType).To(Equal("test-terminal-type")) 430 Expect(height).To(Equal(256)) 431 Expect(width).To(Equal(1024)) 432 433 expectedModes := ssh.TerminalModes{ 434 ssh.ECHO: 1, 435 ssh.TTY_OP_ISPEED: 115200, 436 ssh.TTY_OP_OSPEED: 115200, 437 } 438 Expect(modes).To(Equal(expectedModes)) 439 }) 440 441 Context("when the TERM environment variable is not set", func() { 442 BeforeEach(func() { 443 os.Unsetenv("TERM") 444 }) 445 446 It("requests a pty with the default terminal type", func() { 447 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1)) 448 449 termType, _, _, _ := fakeSecureSession.RequestPtyArgsForCall(0) 450 Expect(termType).To(Equal("xterm")) 451 }) 452 }) 453 454 It("puts the terminal into raw mode and restores it after running the shell", func() { 455 Expect(fakeSecureSession.ShellCallCount()).To(Equal(1)) 456 Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1)) 457 Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(1)) 458 }) 459 460 Context("when the pty allocation fails", func() { 461 var ptyError error 462 463 BeforeEach(func() { 464 ptyError = errors.New("pty allocation error") 465 fakeSecureSession.RequestPtyReturns(ptyError) 466 }) 467 468 It("returns the error", func() { 469 Expect(sessionErr).To(Equal(ptyError)) 470 }) 471 }) 472 473 Context("when placing the terminal into raw mode fails", func() { 474 BeforeEach(func() { 475 fakeTerminalHelper.SetRawTerminalReturns(nil, errors.New("woops")) 476 }) 477 478 It("keeps calm and carries on", func() { 479 Expect(fakeSecureSession.ShellCallCount()).To(Equal(1)) 480 }) 481 482 It("does not not restore the terminal", func() { 483 Expect(fakeSecureSession.ShellCallCount()).To(Equal(1)) 484 Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1)) 485 Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(0)) 486 }) 487 }) 488 }) 489 490 Context("when a command is specified", func() { 491 BeforeEach(func() { 492 commands = []string{"echo", "-n", "hello"} 493 }) 494 495 Context("when a terminal is requested", func() { 496 BeforeEach(func() { 497 terminalRequest = RequestTTYYes 498 fakeTerminalHelper.GetFdInfoReturns(0, true) 499 }) 500 501 It("requests a pty", func() { 502 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1)) 503 }) 504 }) 505 506 Context("when a terminal is not explicitly requested", func() { 507 BeforeEach(func() { 508 terminalRequest = RequestTTYAuto 509 }) 510 511 It("does not request a pty", func() { 512 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0)) 513 }) 514 }) 515 }) 516 }) 517 518 Context("when stdin is not a terminal", func() { 519 BeforeEach(func() { 520 stdin.ReadStub = func(p []byte) (int, error) { 521 return 0, io.EOF 522 } 523 524 terminalHelper := DefaultTerminalHelper() 525 fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo 526 fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize 527 }) 528 529 Context("when a terminal is not requested", func() { 530 It("does not request a pty", func() { 531 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0)) 532 }) 533 }) 534 535 Context("when a terminal is requested", func() { 536 BeforeEach(func() { 537 terminalRequest = RequestTTYYes 538 }) 539 540 It("does not request a pty", func() { 541 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0)) 542 }) 543 }) 544 }) 545 546 PContext("when a terminal is forced", func() { 547 BeforeEach(func() { 548 terminalRequest = RequestTTYForce 549 }) 550 551 It("requests a pty", func() { 552 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1)) 553 }) 554 }) 555 556 Context("when a terminal is disabled", func() { 557 BeforeEach(func() { 558 terminalRequest = RequestTTYNo 559 }) 560 561 It("does not request a pty", func() { 562 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0)) 563 }) 564 }) 565 566 Context("when a command is not specified", func() { 567 It("requests an interactive shell", func() { 568 Expect(fakeSecureSession.ShellCallCount()).To(Equal(1)) 569 }) 570 571 Context("when the shell request returns an error", func() { 572 BeforeEach(func() { 573 fakeSecureSession.ShellReturns(errors.New("oh bother")) 574 }) 575 576 It("returns the error", func() { 577 Expect(sessionErr).To(MatchError("oh bother")) 578 }) 579 }) 580 }) 581 582 Context("when a command is specifed", func() { 583 BeforeEach(func() { 584 commands = []string{"echo", "-n", "hello"} 585 }) 586 587 It("starts the command", func() { 588 Expect(fakeSecureSession.StartCallCount()).To(Equal(1)) 589 Expect(fakeSecureSession.StartArgsForCall(0)).To(Equal("echo -n hello")) 590 }) 591 592 Context("when the command fails to start", func() { 593 BeforeEach(func() { 594 fakeSecureSession.StartReturns(errors.New("oh well")) 595 }) 596 597 It("returns the error", func() { 598 Expect(sessionErr).To(MatchError("oh well")) 599 }) 600 }) 601 }) 602 603 Context("when the shell or command has started", func() { 604 BeforeEach(func() { 605 stdin.ReadStub = func(p []byte) (int, error) { 606 p[0] = 0 607 return 1, io.EOF 608 } 609 stdinPipe.WriteStub = func(p []byte) (int, error) { 610 defer GinkgoRecover() 611 Expect(p[0]).To(Equal(byte(0))) 612 return 1, nil 613 } 614 615 stdoutPipe.ReadStub = func(p []byte) (int, error) { 616 p[0] = 1 617 return 1, io.EOF 618 } 619 stdout.WriteStub = func(p []byte) (int, error) { 620 defer GinkgoRecover() 621 Expect(p[0]).To(Equal(byte(1))) 622 return 1, nil 623 } 624 625 stderrPipe.ReadStub = func(p []byte) (int, error) { 626 p[0] = 2 627 return 1, io.EOF 628 } 629 stderr.WriteStub = func(p []byte) (int, error) { 630 defer GinkgoRecover() 631 Expect(p[0]).To(Equal(byte(2))) 632 return 1, nil 633 } 634 635 fakeSecureSession.StdinPipeReturns(stdinPipe, nil) 636 fakeSecureSession.StdoutPipeReturns(stdoutPipe, nil) 637 fakeSecureSession.StderrPipeReturns(stderrPipe, nil) 638 639 fakeSecureSession.WaitReturns(errors.New("error result")) 640 }) 641 642 It("copies data from the stdin stream to the session stdin pipe", func() { 643 Eventually(stdin.ReadCallCount).Should(Equal(1)) 644 Eventually(stdinPipe.WriteCallCount).Should(Equal(1)) 645 }) 646 647 It("copies data from the session stdout pipe to the stdout stream", func() { 648 Eventually(stdoutPipe.ReadCallCount).Should(Equal(1)) 649 Eventually(stdout.WriteCallCount).Should(Equal(1)) 650 }) 651 652 It("copies data from the session stderr pipe to the stderr stream", func() { 653 Eventually(stderrPipe.ReadCallCount).Should(Equal(1)) 654 Eventually(stderr.WriteCallCount).Should(Equal(1)) 655 }) 656 657 It("waits for the session to end", func() { 658 Expect(fakeSecureSession.WaitCallCount()).To(Equal(1)) 659 }) 660 661 It("returns the result from wait", func() { 662 Expect(sessionErr).To(MatchError("error result")) 663 }) 664 665 Context("when the session terminates before stream copies complete", func() { 666 var sessionErrorCh chan error 667 668 BeforeEach(func() { 669 sessionErrorCh = make(chan error, 1) 670 671 interactiveSessionInvoker = func(secureShell *SecureShell) { 672 go func() { 673 sessionErrorCh <- secureShell.InteractiveSession(commands, terminalRequest) 674 }() 675 } 676 677 stdoutPipe.ReadStub = func(p []byte) (int, error) { 678 defer GinkgoRecover() 679 Eventually(fakeSecureSession.WaitCallCount).Should(Equal(1)) 680 Consistently(sessionErrorCh).ShouldNot(Receive()) 681 682 p[0] = 1 683 return 1, io.EOF 684 } 685 686 stderrPipe.ReadStub = func(p []byte) (int, error) { 687 defer GinkgoRecover() 688 Eventually(fakeSecureSession.WaitCallCount).Should(Equal(1)) 689 Consistently(sessionErrorCh).ShouldNot(Receive()) 690 691 p[0] = 2 692 return 1, io.EOF 693 } 694 }) 695 696 It("waits for the copies to complete", func() { 697 Eventually(sessionErrorCh).Should(Receive()) 698 Expect(stdoutPipe.ReadCallCount()).To(Equal(1)) 699 Expect(stderrPipe.ReadCallCount()).To(Equal(1)) 700 }) 701 }) 702 703 Context("when stdin is closed", func() { 704 BeforeEach(func() { 705 stdin.ReadStub = func(p []byte) (int, error) { 706 defer GinkgoRecover() 707 Consistently(stdinPipe.CloseCallCount).Should(Equal(0)) 708 p[0] = 0 709 return 1, io.EOF 710 } 711 }) 712 713 It("closes the stdinPipe", func() { 714 Eventually(stdinPipe.CloseCallCount).Should(Equal(1)) 715 }) 716 }) 717 }) 718 719 Context("when stdout is a terminal and a window size change occurs", func() { 720 var master, slave *os.File 721 722 BeforeEach(func() { 723 var err error 724 master, slave, err = pty.Open() 725 Expect(err).NotTo(HaveOccurred()) 726 727 terminalHelper := DefaultTerminalHelper() 728 fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo 729 fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize 730 fakeTerminalHelper.StdStreamsReturns(stdin, slave, stderr) 731 732 winsize := &term.Winsize{Height: 100, Width: 100} 733 err = term.SetWinsize(slave.Fd(), winsize) 734 Expect(err).NotTo(HaveOccurred()) 735 736 fakeSecureSession.WaitStub = func() error { 737 fakeSecureSession.SendRequestCallCount() 738 Expect(fakeSecureSession.SendRequestCallCount()).To(Equal(0)) 739 740 // No dimension change 741 for i := 0; i < 3; i++ { 742 winsize := &term.Winsize{Height: 100, Width: 100} 743 err = term.SetWinsize(slave.Fd(), winsize) 744 Expect(err).NotTo(HaveOccurred()) 745 } 746 747 winsize := &term.Winsize{Height: 100, Width: 200} 748 err = term.SetWinsize(slave.Fd(), winsize) 749 Expect(err).NotTo(HaveOccurred()) 750 751 err = syscall.Kill(syscall.Getpid(), syscall.SIGWINCH) 752 Expect(err).NotTo(HaveOccurred()) 753 754 Eventually(fakeSecureSession.SendRequestCallCount).Should(Equal(1)) 755 return nil 756 } 757 }) 758 759 AfterEach(func() { 760 master.Close() 761 slave.Close() 762 }) 763 764 It("sends window change events when the window dimensions change", func() { 765 Expect(fakeSecureSession.SendRequestCallCount()).To(Equal(1)) 766 767 requestType, wantReply, message := fakeSecureSession.SendRequestArgsForCall(0) 768 Expect(requestType).To(Equal("window-change")) 769 Expect(wantReply).To(BeFalse()) 770 771 type resizeMessage struct { 772 Width uint32 773 Height uint32 774 PixelWidth uint32 775 PixelHeight uint32 776 } 777 var resizeMsg resizeMessage 778 779 err := ssh.Unmarshal(message, &resizeMsg) 780 Expect(err).NotTo(HaveOccurred()) 781 782 Expect(resizeMsg).To(Equal(resizeMessage{Height: 100, Width: 200})) 783 }) 784 }) 785 786 Describe("keep alive messages", func() { 787 var times []time.Time 788 var timesCh chan []time.Time 789 var done chan struct{} 790 791 BeforeEach(func() { 792 keepAliveDuration = 100 * time.Millisecond 793 794 times = []time.Time{} 795 timesCh = make(chan []time.Time, 1) 796 done = make(chan struct{}, 1) 797 798 fakeConnection.SendRequestStub = func(reqName string, wantReply bool, message []byte) (bool, []byte, error) { 799 Expect(reqName).To(Equal("keepalive@cloudfoundry.org")) 800 Expect(wantReply).To(BeTrue()) 801 Expect(message).To(BeNil()) 802 803 times = append(times, time.Now()) 804 if len(times) == 3 { 805 timesCh <- times 806 close(done) 807 } 808 return true, nil, nil 809 } 810 811 fakeSecureSession.WaitStub = func() error { 812 Eventually(done).Should(BeClosed()) 813 return nil 814 } 815 }) 816 817 PIt("sends keep alive messages at the expected interval", func() { 818 times := <-timesCh 819 Expect(times[2]).To(BeTemporally("~", times[0].Add(200*time.Millisecond), 160*time.Millisecond)) 820 }) 821 }) 822 }) 823 824 Describe("LocalPortForward", func() { 825 var ( 826 forwardErr error 827 828 echoAddress string 829 echoListener *fake_net.FakeListener 830 echoHandler *fake_server.FakeConnectionHandler 831 echoServer *server.Server 832 833 localAddress string 834 835 realLocalListener net.Listener 836 fakeLocalListener *fake_net.FakeListener 837 838 forwardSpecs []LocalPortForward 839 ) 840 841 BeforeEach(func() { 842 logger := lagertest.NewTestLogger("test") 843 844 var err error 845 realLocalListener, err = net.Listen("tcp", "127.0.0.1:0") 846 Expect(err).NotTo(HaveOccurred()) 847 848 localAddress = realLocalListener.Addr().String() 849 fakeListenerFactory.ListenReturns(realLocalListener, nil) 850 851 echoHandler = new(fake_server.FakeConnectionHandler) 852 echoHandler.HandleConnectionStub = func(conn net.Conn) { 853 io.Copy(conn, conn) 854 conn.Close() 855 } 856 857 realListener, err := net.Listen("tcp", "127.0.0.1:0") 858 Expect(err).NotTo(HaveOccurred()) 859 echoAddress = realListener.Addr().String() 860 861 echoListener = new(fake_net.FakeListener) 862 echoListener.AcceptStub = realListener.Accept 863 echoListener.CloseStub = realListener.Close 864 echoListener.AddrStub = realListener.Addr 865 866 fakeLocalListener = new(fake_net.FakeListener) 867 fakeLocalListener.AcceptReturns(nil, errors.New("Not Accepting Connections")) 868 869 echoServer = server.NewServer(logger.Session("echo"), "", echoHandler) 870 echoServer.SetListener(echoListener) 871 go echoServer.Serve() 872 873 forwardSpecs = []LocalPortForward{{ 874 RemoteAddress: echoAddress, 875 LocalAddress: localAddress, 876 }} 877 878 fakeSecureClient.DialStub = net.Dial 879 }) 880 881 JustBeforeEach(func() { 882 connectErr := secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation) 883 Expect(connectErr).NotTo(HaveOccurred()) 884 885 forwardErr = secureShell.LocalPortForward(forwardSpecs) 886 }) 887 888 AfterEach(func() { 889 err := secureShell.Close() 890 Expect(err).NotTo(HaveOccurred()) 891 echoServer.Shutdown() 892 893 realLocalListener.Close() 894 }) 895 896 validateConnectivity := func(addr string) { 897 conn, err := net.Dial("tcp", addr) 898 Expect(err).NotTo(HaveOccurred()) 899 900 msg := fmt.Sprintf("Hello from %s\n", addr) 901 n, err := conn.Write([]byte(msg)) 902 Expect(err).NotTo(HaveOccurred()) 903 Expect(n).To(Equal(len(msg))) 904 905 response := make([]byte, len(msg)) 906 n, err = conn.Read(response) 907 Expect(err).NotTo(HaveOccurred()) 908 Expect(n).To(Equal(len(msg))) 909 910 err = conn.Close() 911 Expect(err).NotTo(HaveOccurred()) 912 913 Expect(response).To(Equal([]byte(msg))) 914 } 915 916 It("dials the connect address when a local connection is made", func() { 917 Expect(forwardErr).NotTo(HaveOccurred()) 918 919 conn, err := net.Dial("tcp", localAddress) 920 Expect(err).NotTo(HaveOccurred()) 921 922 Eventually(echoListener.AcceptCallCount).Should(BeNumerically(">=", 1)) 923 Eventually(fakeSecureClient.DialCallCount).Should(Equal(1)) 924 925 network, addr := fakeSecureClient.DialArgsForCall(0) 926 Expect(network).To(Equal("tcp")) 927 Expect(addr).To(Equal(echoAddress)) 928 929 Expect(conn.Close()).NotTo(HaveOccurred()) 930 }) 931 932 It("copies data between the local and remote connections", func() { 933 validateConnectivity(localAddress) 934 }) 935 936 Context("when a local connection is already open", func() { 937 var conn net.Conn 938 939 JustBeforeEach(func() { 940 var err error 941 conn, err = net.Dial("tcp", localAddress) 942 Expect(err).NotTo(HaveOccurred()) 943 }) 944 945 AfterEach(func() { 946 err := conn.Close() 947 Expect(err).NotTo(HaveOccurred()) 948 }) 949 950 It("allows for new incoming connections as well", func() { 951 validateConnectivity(localAddress) 952 }) 953 }) 954 955 Context("when there are multiple port forward specs", func() { 956 Context("when provided a real listener", func() { 957 var ( 958 realLocalListener2 net.Listener 959 localAddress2 string 960 ) 961 962 BeforeEach(func() { 963 var err error 964 realLocalListener2, err = net.Listen("tcp", "127.0.0.1:0") 965 Expect(err).NotTo(HaveOccurred()) 966 967 localAddress2 = realLocalListener2.Addr().String() 968 969 fakeListenerFactory.ListenStub = func(network, addr string) (net.Listener, error) { 970 if addr == localAddress { 971 return realLocalListener, nil 972 } 973 974 if addr == localAddress2 { 975 return realLocalListener2, nil 976 } 977 978 return nil, errors.New("unexpected address") 979 } 980 981 forwardSpecs = []LocalPortForward{ 982 { 983 RemoteAddress: echoAddress, 984 LocalAddress: localAddress, 985 }, 986 { 987 RemoteAddress: echoAddress, 988 LocalAddress: localAddress2, 989 }, 990 } 991 }) 992 993 AfterEach(func() { 994 realLocalListener2.Close() 995 }) 996 997 It("listens to all the things", func() { 998 Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(2)) 999 1000 network, addr := fakeListenerFactory.ListenArgsForCall(0) 1001 Expect(network).To(Equal("tcp")) 1002 Expect(addr).To(Equal(localAddress)) 1003 1004 network, addr = fakeListenerFactory.ListenArgsForCall(1) 1005 Expect(network).To(Equal("tcp")) 1006 Expect(addr).To(Equal(localAddress2)) 1007 }) 1008 1009 It("forwards to the correct target", func() { 1010 validateConnectivity(localAddress) 1011 validateConnectivity(localAddress2) 1012 }) 1013 }) 1014 1015 Context("when the secure client is closed", func() { 1016 var ( 1017 fakeLocalListener1 *fake_net.FakeListener 1018 fakeLocalListener2 *fake_net.FakeListener 1019 ) 1020 1021 BeforeEach(func() { 1022 forwardSpecs = []LocalPortForward{ 1023 { 1024 RemoteAddress: echoAddress, 1025 LocalAddress: localAddress, 1026 }, 1027 { 1028 RemoteAddress: echoAddress, 1029 LocalAddress: localAddress, 1030 }, 1031 } 1032 1033 fakeLocalListener1 = new(fake_net.FakeListener) 1034 BlockAcceptOnClose(fakeLocalListener1) 1035 fakeListenerFactory.ListenReturnsOnCall(0, fakeLocalListener1, nil) 1036 1037 fakeLocalListener2 = new(fake_net.FakeListener) 1038 BlockAcceptOnClose(fakeLocalListener2) 1039 fakeListenerFactory.ListenReturnsOnCall(1, fakeLocalListener2, nil) 1040 }) 1041 1042 It("closes the listeners", func() { 1043 Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(2)) 1044 Eventually(fakeLocalListener1.AcceptCallCount).Should(Equal(1)) 1045 Eventually(fakeLocalListener2.AcceptCallCount).Should(Equal(1)) 1046 1047 err := secureShell.Close() 1048 Expect(err).NotTo(HaveOccurred()) 1049 Eventually(fakeLocalListener1.CloseCallCount).Should(Equal(2)) 1050 Eventually(fakeLocalListener2.CloseCallCount).Should(Equal(2)) 1051 }) 1052 }) 1053 }) 1054 1055 Context("when listen fails", func() { 1056 BeforeEach(func() { 1057 fakeListenerFactory.ListenReturns(nil, errors.New("failure is an option")) 1058 }) 1059 1060 It("returns the error", func() { 1061 Expect(forwardErr).To(MatchError("failure is an option")) 1062 }) 1063 }) 1064 1065 Context("when the client it closed", func() { 1066 BeforeEach(func() { 1067 fakeLocalListener = new(fake_net.FakeListener) 1068 BlockAcceptOnClose(fakeLocalListener) 1069 1070 fakeListenerFactory.ListenReturns(fakeLocalListener, nil) 1071 }) 1072 1073 It("closes the listener when the client is closed", func() { 1074 Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(1)) 1075 Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(1)) 1076 1077 err := secureShell.Close() 1078 Expect(err).NotTo(HaveOccurred()) 1079 Eventually(fakeLocalListener.CloseCallCount).Should(Equal(2)) 1080 }) 1081 }) 1082 1083 Context("when accept fails", func() { 1084 var fakeConn *fake_net.FakeConn 1085 1086 BeforeEach(func() { 1087 fakeConn = &fake_net.FakeConn{} 1088 fakeConn.ReadReturns(0, io.EOF) 1089 1090 fakeListenerFactory.ListenReturns(fakeLocalListener, nil) 1091 }) 1092 1093 Context("with a permanent error", func() { 1094 BeforeEach(func() { 1095 fakeLocalListener.AcceptReturns(nil, errors.New("boom")) 1096 }) 1097 1098 It("stops trying to accept connections", func() { 1099 Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(1)) 1100 Consistently(fakeLocalListener.AcceptCallCount).Should(Equal(1)) 1101 Expect(fakeLocalListener.CloseCallCount()).To(Equal(1)) 1102 }) 1103 }) 1104 1105 Context("with a temporary error", func() { 1106 var timeCh chan time.Time 1107 1108 BeforeEach(func() { 1109 timeCh = make(chan time.Time, 3) 1110 1111 fakeLocalListener.AcceptStub = func() (net.Conn, error) { 1112 timeCh := timeCh 1113 if fakeLocalListener.AcceptCallCount() > 3 { 1114 close(timeCh) 1115 return nil, test_helpers.NewTestNetError(false, false) 1116 } else { 1117 timeCh <- time.Now() 1118 return nil, test_helpers.NewTestNetError(false, true) 1119 } 1120 } 1121 }) 1122 1123 PIt("retries connecting after a short delay", func() { 1124 Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(3)) 1125 Expect(timeCh).To(HaveLen(3)) 1126 1127 times := make([]time.Time, 0) 1128 for t := range timeCh { 1129 times = append(times, t) 1130 } 1131 1132 Expect(times[1]).To(BeTemporally("~", times[0].Add(115*time.Millisecond), 80*time.Millisecond)) 1133 Expect(times[2]).To(BeTemporally("~", times[1].Add(115*time.Millisecond), 100*time.Millisecond)) 1134 }) 1135 }) 1136 }) 1137 1138 Context("when dialing the connect address fails", func() { 1139 var fakeTarget *fake_net.FakeConn 1140 1141 BeforeEach(func() { 1142 fakeTarget = &fake_net.FakeConn{} 1143 fakeSecureClient.DialReturns(fakeTarget, errors.New("boom")) 1144 }) 1145 1146 It("does not call close on the target connection", func() { 1147 Consistently(fakeTarget.CloseCallCount).Should(Equal(0)) 1148 }) 1149 }) 1150 }) 1151 1152 Describe("Wait", func() { 1153 var waitErr error 1154 1155 JustBeforeEach(func() { 1156 connectErr := secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation) 1157 Expect(connectErr).NotTo(HaveOccurred()) 1158 1159 waitErr = secureShell.Wait() 1160 }) 1161 1162 It("calls wait on the secureClient", func() { 1163 Expect(waitErr).NotTo(HaveOccurred()) 1164 Expect(fakeSecureClient.WaitCallCount()).To(Equal(1)) 1165 }) 1166 1167 Describe("keep alive messages", func() { 1168 var times []time.Time 1169 var timesCh chan []time.Time 1170 var done chan struct{} 1171 1172 BeforeEach(func() { 1173 keepAliveDuration = 100 * time.Millisecond 1174 1175 times = []time.Time{} 1176 timesCh = make(chan []time.Time, 1) 1177 done = make(chan struct{}, 1) 1178 1179 fakeConnection.SendRequestStub = func(reqName string, wantReply bool, message []byte) (bool, []byte, error) { 1180 Expect(reqName).To(Equal("keepalive@cloudfoundry.org")) 1181 Expect(wantReply).To(BeTrue()) 1182 Expect(message).To(BeNil()) 1183 1184 times = append(times, time.Now()) 1185 if len(times) == 3 { 1186 timesCh <- times 1187 close(done) 1188 } 1189 return true, nil, nil 1190 } 1191 1192 fakeSecureClient.WaitStub = func() error { 1193 Eventually(done).Should(BeClosed()) 1194 return nil 1195 } 1196 }) 1197 1198 PIt("sends keep alive messages at the expected interval", func() { 1199 Expect(waitErr).NotTo(HaveOccurred()) 1200 times := <-timesCh 1201 Expect(times[2]).To(BeTemporally("~", times[0].Add(200*time.Millisecond), 100*time.Millisecond)) 1202 }) 1203 }) 1204 }) 1205 1206 Describe("Close", func() { 1207 JustBeforeEach(func() { 1208 connectErr := secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation) 1209 Expect(connectErr).NotTo(HaveOccurred()) 1210 }) 1211 1212 It("calls close on the secureClient", func() { 1213 err := secureShell.Close() 1214 Expect(err).NotTo(HaveOccurred()) 1215 1216 Expect(fakeSecureClient.CloseCallCount()).To(Equal(1)) 1217 }) 1218 }) 1219 })