github.com/loggregator/cli@v6.33.1-0.20180224010324-82334f081791+incompatible/util/clissh/ssh_test.go (about) 1 // +build !windows,!386 2 3 // skipping 386 because lager uses UInt64 in Session() 4 // skipping windows because Unix/Linux only syscall in test. 5 // should refactor out the conflicts so we could test this package in multi platforms. 6 7 package clissh_test 8 9 import ( 10 "errors" 11 "fmt" 12 "io" 13 "net" 14 "os" 15 "syscall" 16 "time" 17 18 "code.cloudfoundry.org/cli/util/clissh/clisshfakes" 19 "code.cloudfoundry.org/cli/util/clissh/ssherror" 20 "code.cloudfoundry.org/diego-ssh/server" 21 fake_server "code.cloudfoundry.org/diego-ssh/server/fakes" 22 "code.cloudfoundry.org/diego-ssh/test_helpers" 23 "code.cloudfoundry.org/diego-ssh/test_helpers/fake_io" 24 "code.cloudfoundry.org/diego-ssh/test_helpers/fake_net" 25 "code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh" 26 "code.cloudfoundry.org/lager/lagertest" 27 "github.com/kr/pty" 28 "github.com/moby/moby/pkg/term" 29 "golang.org/x/crypto/ssh" 30 31 . "code.cloudfoundry.org/cli/util/clissh" 32 . "github.com/onsi/ginkgo" 33 . "github.com/onsi/gomega" 34 ) 35 36 var _ = Describe("CLI SSH", func() { 37 var ( 38 fakeSecureDialer *clisshfakes.FakeSecureDialer 39 fakeSecureClient *clisshfakes.FakeSecureClient 40 fakeTerminalHelper *clisshfakes.FakeTerminalHelper 41 fakeListenerFactory *clisshfakes.FakeListenerFactory 42 fakeSecureSession *clisshfakes.FakeSecureSession 43 44 fakeConnection *fake_ssh.FakeConn 45 stdinPipe *fake_io.FakeWriteCloser 46 stdoutPipe *fake_io.FakeReader 47 stderrPipe *fake_io.FakeReader 48 secureShell *SecureShell 49 50 username string 51 passcode string 52 sshEndpoint string 53 sshEndpointFingerprint string 54 skipHostValidation bool 55 commands []string 56 terminalRequest TTYRequest 57 keepAliveDuration time.Duration 58 ) 59 60 BeforeEach(func() { 61 fakeSecureDialer = new(clisshfakes.FakeSecureDialer) 62 fakeSecureClient = new(clisshfakes.FakeSecureClient) 63 fakeTerminalHelper = new(clisshfakes.FakeTerminalHelper) 64 fakeListenerFactory = new(clisshfakes.FakeListenerFactory) 65 fakeSecureSession = new(clisshfakes.FakeSecureSession) 66 67 fakeConnection = new(fake_ssh.FakeConn) 68 stdinPipe = new(fake_io.FakeWriteCloser) 69 stdoutPipe = new(fake_io.FakeReader) 70 stderrPipe = new(fake_io.FakeReader) 71 72 fakeListenerFactory.ListenStub = net.Listen 73 fakeSecureClient.NewSessionReturns(fakeSecureSession, nil) 74 fakeSecureClient.ConnReturns(fakeConnection) 75 fakeSecureDialer.DialReturns(fakeSecureClient, nil) 76 77 stdinPipe.WriteStub = func(p []byte) (int, error) { 78 return len(p), nil 79 } 80 fakeSecureSession.StdinPipeReturns(stdinPipe, nil) 81 82 stdoutPipe.ReadStub = func(p []byte) (int, error) { 83 return 0, io.EOF 84 } 85 fakeSecureSession.StdoutPipeReturns(stdoutPipe, nil) 86 87 stderrPipe.ReadStub = func(p []byte) (int, error) { 88 return 0, io.EOF 89 } 90 fakeSecureSession.StderrPipeReturns(stderrPipe, nil) 91 92 username = "some-user" 93 passcode = "some-passcode" 94 sshEndpoint = "some-endpoint" 95 sshEndpointFingerprint = "some-fingerprint" 96 skipHostValidation = false 97 commands = []string{} 98 terminalRequest = RequestTTYAuto 99 keepAliveDuration = DefaultKeepAliveInterval 100 }) 101 102 JustBeforeEach(func() { 103 secureShell = NewSecureShell( 104 fakeSecureDialer, 105 fakeTerminalHelper, 106 fakeListenerFactory, 107 keepAliveDuration, 108 ) 109 }) 110 111 Describe("Connect", func() { 112 var connectErr error 113 114 JustBeforeEach(func() { 115 connectErr = secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation) 116 }) 117 118 Context("when dialing succeeds", func() { 119 It("creates the ssh client", func() { 120 Expect(connectErr).ToNot(HaveOccurred()) 121 122 Expect(fakeSecureDialer.DialCallCount()).To(Equal(1)) 123 protocolArg, sshEndpointArg, sshConfigArg := fakeSecureDialer.DialArgsForCall(0) 124 Expect(protocolArg).To(Equal("tcp")) 125 Expect(sshEndpointArg).To(Equal(sshEndpoint)) 126 Expect(sshConfigArg.User).To(Equal(username)) 127 Expect(sshConfigArg.Auth).To(HaveLen(1)) 128 Expect(sshConfigArg.HostKeyCallback).ToNot(BeNil()) 129 }) 130 }) 131 132 Context("when dialing fails", func() { 133 var dialError error 134 135 Context("when the error is a generic Dial error", func() { 136 BeforeEach(func() { 137 dialError = errors.New("woops") 138 fakeSecureDialer.DialReturns(nil, dialError) 139 }) 140 141 It("returns the dial error", func() { 142 Expect(connectErr).To(Equal(dialError)) 143 Expect(fakeSecureDialer.DialCallCount()).To(Equal(1)) 144 }) 145 }) 146 147 Context("when the dialing error is a golang 'unable to authenticate' error", func() { 148 BeforeEach(func() { 149 dialError = fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", []string{"none", "password"}) 150 fakeSecureDialer.DialReturns(nil, dialError) 151 }) 152 153 It("returns an UnableToAuthenticateError", func() { 154 Expect(connectErr).To(MatchError(ssherror.UnableToAuthenticateError{Err: dialError})) 155 Expect(fakeSecureDialer.DialCallCount()).To(Equal(1)) 156 }) 157 }) 158 }) 159 }) 160 161 Describe("InteractiveSession", func() { 162 var ( 163 stdin *fake_io.FakeReadCloser 164 stdout, stderr *fake_io.FakeWriter 165 166 sessionErr error 167 interactiveSessionInvoker func(secureShell *SecureShell) 168 ) 169 170 BeforeEach(func() { 171 stdin = new(fake_io.FakeReadCloser) 172 stdout = new(fake_io.FakeWriter) 173 stderr = new(fake_io.FakeWriter) 174 175 fakeTerminalHelper.StdStreamsReturns(stdin, stdout, stderr) 176 interactiveSessionInvoker = func(secureShell *SecureShell) { 177 sessionErr = secureShell.InteractiveSession(commands, terminalRequest) 178 } 179 }) 180 181 JustBeforeEach(func() { 182 connectErr := secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation) 183 Expect(connectErr).NotTo(HaveOccurred()) 184 interactiveSessionInvoker(secureShell) 185 }) 186 187 Context("when host key validation is enabled", func() { 188 var ( 189 callback func(hostname string, remote net.Addr, key ssh.PublicKey) error 190 addr net.Addr 191 ) 192 193 BeforeEach(func() { 194 skipHostValidation = false 195 }) 196 197 JustBeforeEach(func() { 198 Expect(fakeSecureDialer.DialCallCount()).To(Equal(1)) 199 _, _, config := fakeSecureDialer.DialArgsForCall(0) 200 callback = config.HostKeyCallback 201 202 listener, err := net.Listen("tcp", "localhost:0") 203 Expect(err).NotTo(HaveOccurred()) 204 205 addr = listener.Addr() 206 listener.Close() 207 }) 208 209 Context("when the md5 fingerprint matches", func() { 210 BeforeEach(func() { 211 sshEndpointFingerprint = "41:ce:56:e6:9c:42:a9:c6:9e:68:ac:e3:4d:f6:38:79" 212 }) 213 214 It("does not return an error", func() { 215 Expect(callback("", addr, TestHostKey.PublicKey())).ToNot(HaveOccurred()) 216 }) 217 }) 218 219 Context("when the hex sha1 fingerprint matches", func() { 220 BeforeEach(func() { 221 sshEndpointFingerprint = "a8:e2:67:cb:ea:2a:6e:23:a1:72:ce:8f:07:92:15:ee:1f:82:f8:ca" 222 }) 223 224 It("does not return an error", func() { 225 Expect(callback("", addr, TestHostKey.PublicKey())).ToNot(HaveOccurred()) 226 }) 227 }) 228 229 Context("when the base64 sha256 fingerprint matches", func() { 230 BeforeEach(func() { 231 sshEndpointFingerprint = "sp/jrLuj66r+yrLDUKZdJU5tdzt4mq/UaSiNBjpgr+8" 232 }) 233 234 It("does not return an error", func() { 235 Expect(callback("", addr, TestHostKey.PublicKey())).ToNot(HaveOccurred()) 236 }) 237 }) 238 239 Context("when the base64 SHA256 fingerprint does not match", func() { 240 BeforeEach(func() { 241 sshEndpointFingerprint = "0000000000000000000000000000000000000000000" 242 }) 243 244 It("returns an error'", func() { 245 err := callback("", addr, TestHostKey.PublicKey()) 246 Expect(err).To(MatchError(MatchRegexp("Host key verification failed\\."))) 247 Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\""))) 248 }) 249 }) 250 251 Context("when the hex SHA1 fingerprint does not match", func() { 252 BeforeEach(func() { 253 sshEndpointFingerprint = "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00" 254 }) 255 256 It("returns an error'", func() { 257 err := callback("", addr, TestHostKey.PublicKey()) 258 Expect(err).To(MatchError(MatchRegexp("Host key verification failed\\."))) 259 Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\""))) 260 }) 261 }) 262 263 Context("when the MD5 fingerprint does not match", func() { 264 BeforeEach(func() { 265 sshEndpointFingerprint = "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00" 266 }) 267 268 It("returns an error'", func() { 269 err := callback("", addr, TestHostKey.PublicKey()) 270 Expect(err).To(MatchError(MatchRegexp("Host key verification failed\\."))) 271 Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\""))) 272 }) 273 }) 274 275 Context("when no fingerprint is present in endpoint info", func() { 276 BeforeEach(func() { 277 sshEndpointFingerprint = "" 278 sshEndpoint = "" 279 }) 280 281 It("returns an error'", func() { 282 err := callback("", addr, TestHostKey.PublicKey()) 283 Expect(err).To(MatchError(MatchRegexp("Unable to verify identity of host\\."))) 284 Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\""))) 285 }) 286 }) 287 288 Context("when the fingerprint length doesn't make sense", func() { 289 BeforeEach(func() { 290 sshEndpointFingerprint = "garbage" 291 }) 292 293 It("returns an error", func() { 294 err := callback("", addr, TestHostKey.PublicKey()) 295 Eventually(err).Should(MatchError(MatchRegexp("Unsupported host key fingerprint format"))) 296 }) 297 }) 298 }) 299 300 Context("when the skip host validation flag is set", func() { 301 BeforeEach(func() { 302 skipHostValidation = true 303 }) 304 305 It("the HostKeyCallback on the Config to always return nil", func() { 306 Expect(fakeSecureDialer.DialCallCount()).To(Equal(1)) 307 308 _, _, config := fakeSecureDialer.DialArgsForCall(0) 309 Expect(config.HostKeyCallback("some-hostname", nil, nil)).To(BeNil()) 310 }) 311 }) 312 313 // TODO: see if it's possible to test the piping between the ss client input and outputs and the UI object we pass in 314 Context("when dialing is successful", func() { 315 It("creates a new secure shell session", func() { 316 Expect(fakeSecureClient.NewSessionCallCount()).To(Equal(1)) 317 }) 318 319 It("closes the session", func() { 320 Expect(fakeSecureSession.CloseCallCount()).To(Equal(1)) 321 }) 322 323 It("gets a stdin pipe for the session", func() { 324 Expect(fakeSecureSession.StdinPipeCallCount()).To(Equal(1)) 325 }) 326 327 Context("when getting the stdin pipe fails", func() { 328 BeforeEach(func() { 329 fakeSecureSession.StdinPipeReturns(nil, errors.New("woops")) 330 }) 331 332 It("returns the error", func() { 333 Expect(sessionErr).Should(MatchError("woops")) 334 }) 335 }) 336 337 It("gets a stdout pipe for the session", func() { 338 Expect(fakeSecureSession.StdoutPipeCallCount()).To(Equal(1)) 339 }) 340 341 Context("when getting the stdout pipe fails", func() { 342 BeforeEach(func() { 343 fakeSecureSession.StdoutPipeReturns(nil, errors.New("woops")) 344 }) 345 346 It("returns the error", func() { 347 Expect(sessionErr).Should(MatchError("woops")) 348 }) 349 }) 350 351 It("gets a stderr pipe for the session", func() { 352 Expect(fakeSecureSession.StderrPipeCallCount()).To(Equal(1)) 353 }) 354 355 Context("when getting the stderr pipe fails", func() { 356 BeforeEach(func() { 357 fakeSecureSession.StderrPipeReturns(nil, errors.New("woops")) 358 }) 359 360 It("returns the error", func() { 361 Expect(sessionErr).Should(MatchError("woops")) 362 }) 363 }) 364 }) 365 366 Context("when stdin is a terminal", func() { 367 var master, slave *os.File 368 369 BeforeEach(func() { 370 var err error 371 master, slave, err = pty.Open() 372 Expect(err).NotTo(HaveOccurred()) 373 374 terminalRequest = RequestTTYForce 375 376 terminalHelper := DefaultTerminalHelper() 377 fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo 378 fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize 379 }) 380 381 AfterEach(func() { 382 master.Close() 383 // slave.Close() // race 384 }) 385 386 Context("when a command is not specified", func() { 387 var terminalType string 388 389 BeforeEach(func() { 390 terminalType = os.Getenv("TERM") 391 os.Setenv("TERM", "test-terminal-type") 392 393 winsize := &term.Winsize{Width: 1024, Height: 256} 394 fakeTerminalHelper.GetWinsizeReturns(winsize, nil) 395 396 fakeSecureSession.ShellStub = func() error { 397 Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1)) 398 Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(0)) 399 return nil 400 } 401 }) 402 403 AfterEach(func() { 404 os.Setenv("TERM", terminalType) 405 }) 406 407 It("requests a pty with the correct terminal type, window size, and modes", func() { 408 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1)) 409 Expect(fakeTerminalHelper.GetWinsizeCallCount()).To(Equal(1)) 410 411 termType, height, width, modes := fakeSecureSession.RequestPtyArgsForCall(0) 412 Expect(termType).To(Equal("test-terminal-type")) 413 Expect(height).To(Equal(256)) 414 Expect(width).To(Equal(1024)) 415 416 expectedModes := ssh.TerminalModes{ 417 ssh.ECHO: 1, 418 ssh.TTY_OP_ISPEED: 115200, 419 ssh.TTY_OP_OSPEED: 115200, 420 } 421 Expect(modes).To(Equal(expectedModes)) 422 }) 423 424 Context("when the TERM environment variable is not set", func() { 425 BeforeEach(func() { 426 os.Unsetenv("TERM") 427 }) 428 429 It("requests a pty with the default terminal type", func() { 430 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1)) 431 432 termType, _, _, _ := fakeSecureSession.RequestPtyArgsForCall(0) 433 Expect(termType).To(Equal("xterm")) 434 }) 435 }) 436 437 It("puts the terminal into raw mode and restores it after running the shell", func() { 438 Expect(fakeSecureSession.ShellCallCount()).To(Equal(1)) 439 Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1)) 440 Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(1)) 441 }) 442 443 Context("when the pty allocation fails", func() { 444 var ptyError error 445 446 BeforeEach(func() { 447 ptyError = errors.New("pty allocation error") 448 fakeSecureSession.RequestPtyReturns(ptyError) 449 }) 450 451 It("returns the error", func() { 452 Expect(sessionErr).To(Equal(ptyError)) 453 }) 454 }) 455 456 Context("when placing the terminal into raw mode fails", func() { 457 BeforeEach(func() { 458 fakeTerminalHelper.SetRawTerminalReturns(nil, errors.New("woops")) 459 }) 460 461 It("keeps calm and carries on", func() { 462 Expect(fakeSecureSession.ShellCallCount()).To(Equal(1)) 463 }) 464 465 It("does not not restore the terminal", func() { 466 Expect(fakeSecureSession.ShellCallCount()).To(Equal(1)) 467 Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1)) 468 Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(0)) 469 }) 470 }) 471 }) 472 473 Context("when a command is specified", func() { 474 BeforeEach(func() { 475 commands = []string{"echo", "-n", "hello"} 476 }) 477 478 Context("when a terminal is requested", func() { 479 BeforeEach(func() { 480 terminalRequest = RequestTTYYes 481 fakeTerminalHelper.GetFdInfoReturns(0, true) 482 }) 483 484 It("requests a pty", func() { 485 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1)) 486 }) 487 }) 488 489 Context("when a terminal is not explicitly requested", func() { 490 BeforeEach(func() { 491 terminalRequest = RequestTTYAuto 492 }) 493 494 It("does not request a pty", func() { 495 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0)) 496 }) 497 }) 498 }) 499 }) 500 501 Context("when stdin is not a terminal", func() { 502 BeforeEach(func() { 503 stdin.ReadStub = func(p []byte) (int, error) { 504 return 0, io.EOF 505 } 506 507 terminalHelper := DefaultTerminalHelper() 508 fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo 509 fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize 510 }) 511 512 Context("when a terminal is not requested", func() { 513 It("does not request a pty", func() { 514 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0)) 515 }) 516 }) 517 518 Context("when a terminal is requested", func() { 519 BeforeEach(func() { 520 terminalRequest = RequestTTYYes 521 }) 522 523 It("does not request a pty", func() { 524 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0)) 525 }) 526 }) 527 }) 528 529 PContext("when a terminal is forced", func() { 530 BeforeEach(func() { 531 terminalRequest = RequestTTYForce 532 }) 533 534 It("requests a pty", func() { 535 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1)) 536 }) 537 }) 538 539 Context("when a terminal is disabled", func() { 540 BeforeEach(func() { 541 terminalRequest = RequestTTYNo 542 }) 543 544 It("does not request a pty", func() { 545 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0)) 546 }) 547 }) 548 549 Context("when a command is not specified", func() { 550 It("requests an interactive shell", func() { 551 Expect(fakeSecureSession.ShellCallCount()).To(Equal(1)) 552 }) 553 554 Context("when the shell request returns an error", func() { 555 BeforeEach(func() { 556 fakeSecureSession.ShellReturns(errors.New("oh bother")) 557 }) 558 559 It("returns the error", func() { 560 Expect(sessionErr).To(MatchError("oh bother")) 561 }) 562 }) 563 }) 564 565 Context("when a command is specifed", func() { 566 BeforeEach(func() { 567 commands = []string{"echo", "-n", "hello"} 568 }) 569 570 It("starts the command", func() { 571 Expect(fakeSecureSession.StartCallCount()).To(Equal(1)) 572 Expect(fakeSecureSession.StartArgsForCall(0)).To(Equal("echo -n hello")) 573 }) 574 575 Context("when the command fails to start", func() { 576 BeforeEach(func() { 577 fakeSecureSession.StartReturns(errors.New("oh well")) 578 }) 579 580 It("returns the error", func() { 581 Expect(sessionErr).To(MatchError("oh well")) 582 }) 583 }) 584 }) 585 586 Context("when the shell or command has started", func() { 587 BeforeEach(func() { 588 stdin.ReadStub = func(p []byte) (int, error) { 589 p[0] = 0 590 return 1, io.EOF 591 } 592 stdinPipe.WriteStub = func(p []byte) (int, error) { 593 defer GinkgoRecover() 594 Expect(p[0]).To(Equal(byte(0))) 595 return 1, nil 596 } 597 598 stdoutPipe.ReadStub = func(p []byte) (int, error) { 599 p[0] = 1 600 return 1, io.EOF 601 } 602 stdout.WriteStub = func(p []byte) (int, error) { 603 defer GinkgoRecover() 604 Expect(p[0]).To(Equal(byte(1))) 605 return 1, nil 606 } 607 608 stderrPipe.ReadStub = func(p []byte) (int, error) { 609 p[0] = 2 610 return 1, io.EOF 611 } 612 stderr.WriteStub = func(p []byte) (int, error) { 613 defer GinkgoRecover() 614 Expect(p[0]).To(Equal(byte(2))) 615 return 1, nil 616 } 617 618 fakeSecureSession.StdinPipeReturns(stdinPipe, nil) 619 fakeSecureSession.StdoutPipeReturns(stdoutPipe, nil) 620 fakeSecureSession.StderrPipeReturns(stderrPipe, nil) 621 622 fakeSecureSession.WaitReturns(errors.New("error result")) 623 }) 624 625 It("copies data from the stdin stream to the session stdin pipe", func() { 626 Eventually(stdin.ReadCallCount).Should(Equal(1)) 627 Eventually(stdinPipe.WriteCallCount).Should(Equal(1)) 628 }) 629 630 It("copies data from the session stdout pipe to the stdout stream", func() { 631 Eventually(stdoutPipe.ReadCallCount).Should(Equal(1)) 632 Eventually(stdout.WriteCallCount).Should(Equal(1)) 633 }) 634 635 It("copies data from the session stderr pipe to the stderr stream", func() { 636 Eventually(stderrPipe.ReadCallCount).Should(Equal(1)) 637 Eventually(stderr.WriteCallCount).Should(Equal(1)) 638 }) 639 640 It("waits for the session to end", func() { 641 Expect(fakeSecureSession.WaitCallCount()).To(Equal(1)) 642 }) 643 644 It("returns the result from wait", func() { 645 Expect(sessionErr).To(MatchError("error result")) 646 }) 647 648 Context("when the session terminates before stream copies complete", func() { 649 var sessionErrorCh chan error 650 651 BeforeEach(func() { 652 sessionErrorCh = make(chan error, 1) 653 654 interactiveSessionInvoker = func(secureShell *SecureShell) { 655 go func() { 656 sessionErrorCh <- secureShell.InteractiveSession(commands, terminalRequest) 657 }() 658 } 659 660 stdoutPipe.ReadStub = func(p []byte) (int, error) { 661 defer GinkgoRecover() 662 Eventually(fakeSecureSession.WaitCallCount).Should(Equal(1)) 663 Consistently(sessionErrorCh).ShouldNot(Receive()) 664 665 p[0] = 1 666 return 1, io.EOF 667 } 668 669 stderrPipe.ReadStub = func(p []byte) (int, error) { 670 defer GinkgoRecover() 671 Eventually(fakeSecureSession.WaitCallCount).Should(Equal(1)) 672 Consistently(sessionErrorCh).ShouldNot(Receive()) 673 674 p[0] = 2 675 return 1, io.EOF 676 } 677 }) 678 679 It("waits for the copies to complete", func() { 680 Eventually(sessionErrorCh).Should(Receive()) 681 Expect(stdoutPipe.ReadCallCount()).To(Equal(1)) 682 Expect(stderrPipe.ReadCallCount()).To(Equal(1)) 683 }) 684 }) 685 686 Context("when stdin is closed", func() { 687 BeforeEach(func() { 688 stdin.ReadStub = func(p []byte) (int, error) { 689 defer GinkgoRecover() 690 Consistently(stdinPipe.CloseCallCount).Should(Equal(0)) 691 p[0] = 0 692 return 1, io.EOF 693 } 694 }) 695 696 It("closes the stdinPipe", func() { 697 Eventually(stdinPipe.CloseCallCount).Should(Equal(1)) 698 }) 699 }) 700 }) 701 702 Context("when stdout is a terminal and a window size change occurs", func() { 703 var master, slave *os.File 704 705 BeforeEach(func() { 706 var err error 707 master, slave, err = pty.Open() 708 Expect(err).NotTo(HaveOccurred()) 709 710 terminalHelper := DefaultTerminalHelper() 711 fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo 712 fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize 713 fakeTerminalHelper.StdStreamsReturns(stdin, slave, stderr) 714 715 winsize := &term.Winsize{Height: 100, Width: 100} 716 err = term.SetWinsize(slave.Fd(), winsize) 717 Expect(err).NotTo(HaveOccurred()) 718 719 fakeSecureSession.WaitStub = func() error { 720 fakeSecureSession.SendRequestCallCount() 721 Expect(fakeSecureSession.SendRequestCallCount()).To(Equal(0)) 722 723 // No dimension change 724 for i := 0; i < 3; i++ { 725 winsize := &term.Winsize{Height: 100, Width: 100} 726 err = term.SetWinsize(slave.Fd(), winsize) 727 Expect(err).NotTo(HaveOccurred()) 728 } 729 730 winsize := &term.Winsize{Height: 100, Width: 200} 731 err = term.SetWinsize(slave.Fd(), winsize) 732 Expect(err).NotTo(HaveOccurred()) 733 734 err = syscall.Kill(syscall.Getpid(), syscall.SIGWINCH) 735 Expect(err).NotTo(HaveOccurred()) 736 737 Eventually(fakeSecureSession.SendRequestCallCount).Should(Equal(1)) 738 return nil 739 } 740 }) 741 742 AfterEach(func() { 743 master.Close() 744 slave.Close() 745 }) 746 747 It("sends window change events when the window dimensions change", func() { 748 Expect(fakeSecureSession.SendRequestCallCount()).To(Equal(1)) 749 750 requestType, wantReply, message := fakeSecureSession.SendRequestArgsForCall(0) 751 Expect(requestType).To(Equal("window-change")) 752 Expect(wantReply).To(BeFalse()) 753 754 type resizeMessage struct { 755 Width uint32 756 Height uint32 757 PixelWidth uint32 758 PixelHeight uint32 759 } 760 var resizeMsg resizeMessage 761 762 err := ssh.Unmarshal(message, &resizeMsg) 763 Expect(err).NotTo(HaveOccurred()) 764 765 Expect(resizeMsg).To(Equal(resizeMessage{Height: 100, Width: 200})) 766 }) 767 }) 768 769 Describe("keep alive messages", func() { 770 var times []time.Time 771 var timesCh chan []time.Time 772 var done chan struct{} 773 774 BeforeEach(func() { 775 keepAliveDuration = 100 * time.Millisecond 776 777 times = []time.Time{} 778 timesCh = make(chan []time.Time, 1) 779 done = make(chan struct{}, 1) 780 781 fakeConnection.SendRequestStub = func(reqName string, wantReply bool, message []byte) (bool, []byte, error) { 782 Expect(reqName).To(Equal("keepalive@cloudfoundry.org")) 783 Expect(wantReply).To(BeTrue()) 784 Expect(message).To(BeNil()) 785 786 times = append(times, time.Now()) 787 if len(times) == 3 { 788 timesCh <- times 789 close(done) 790 } 791 return true, nil, nil 792 } 793 794 fakeSecureSession.WaitStub = func() error { 795 Eventually(done).Should(BeClosed()) 796 return nil 797 } 798 }) 799 800 PIt("sends keep alive messages at the expected interval", func() { 801 times := <-timesCh 802 Expect(times[2]).To(BeTemporally("~", times[0].Add(200*time.Millisecond), 160*time.Millisecond)) 803 }) 804 }) 805 }) 806 807 Describe("LocalPortForward", func() { 808 var ( 809 forwardErr error 810 811 echoAddress string 812 echoListener *fake_net.FakeListener 813 echoHandler *fake_server.FakeConnectionHandler 814 echoServer *server.Server 815 816 localAddress string 817 818 realLocalListener net.Listener 819 fakeLocalListener *fake_net.FakeListener 820 821 forwardSpecs []LocalPortForward 822 ) 823 824 BeforeEach(func() { 825 logger := lagertest.NewTestLogger("test") 826 827 var err error 828 realLocalListener, err = net.Listen("tcp", "127.0.0.1:0") 829 Expect(err).NotTo(HaveOccurred()) 830 831 localAddress = realLocalListener.Addr().String() 832 fakeListenerFactory.ListenReturns(realLocalListener, nil) 833 834 echoHandler = &fake_server.FakeConnectionHandler{} 835 echoHandler.HandleConnectionStub = func(conn net.Conn) { 836 io.Copy(conn, conn) 837 conn.Close() 838 } 839 840 realListener, err := net.Listen("tcp", "127.0.0.1:0") 841 Expect(err).NotTo(HaveOccurred()) 842 echoAddress = realListener.Addr().String() 843 844 echoListener = &fake_net.FakeListener{} 845 echoListener.AcceptStub = realListener.Accept 846 echoListener.CloseStub = realListener.Close 847 echoListener.AddrStub = realListener.Addr 848 849 fakeLocalListener = &fake_net.FakeListener{} 850 fakeLocalListener.AcceptReturns(nil, errors.New("Not Accepting Connections")) 851 852 echoServer = server.NewServer(logger.Session("echo"), "", echoHandler) 853 echoServer.SetListener(echoListener) 854 go echoServer.Serve() 855 856 forwardSpecs = []LocalPortForward{{ 857 RemoteAddress: echoAddress, 858 LocalAddress: localAddress, 859 }} 860 861 fakeSecureClient.DialStub = net.Dial 862 }) 863 864 JustBeforeEach(func() { 865 connectErr := secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation) 866 Expect(connectErr).NotTo(HaveOccurred()) 867 868 forwardErr = secureShell.LocalPortForward(forwardSpecs) 869 }) 870 871 AfterEach(func() { 872 err := secureShell.Close() 873 Expect(err).NotTo(HaveOccurred()) 874 echoServer.Shutdown() 875 876 realLocalListener.Close() 877 }) 878 879 validateConnectivity := func(addr string) { 880 conn, err := net.Dial("tcp", addr) 881 Expect(err).NotTo(HaveOccurred()) 882 883 msg := fmt.Sprintf("Hello from %s\n", addr) 884 n, err := conn.Write([]byte(msg)) 885 Expect(err).NotTo(HaveOccurred()) 886 Expect(n).To(Equal(len(msg))) 887 888 response := make([]byte, len(msg)) 889 n, err = conn.Read(response) 890 Expect(err).NotTo(HaveOccurred()) 891 Expect(n).To(Equal(len(msg))) 892 893 err = conn.Close() 894 Expect(err).NotTo(HaveOccurred()) 895 896 Expect(response).To(Equal([]byte(msg))) 897 } 898 899 It("dials the connect address when a local connection is made", func() { 900 Expect(forwardErr).NotTo(HaveOccurred()) 901 902 conn, err := net.Dial("tcp", localAddress) 903 Expect(err).NotTo(HaveOccurred()) 904 905 Eventually(echoListener.AcceptCallCount).Should(BeNumerically(">=", 1)) 906 Eventually(fakeSecureClient.DialCallCount).Should(Equal(1)) 907 908 network, addr := fakeSecureClient.DialArgsForCall(0) 909 Expect(network).To(Equal("tcp")) 910 Expect(addr).To(Equal(echoAddress)) 911 912 Expect(conn.Close()).NotTo(HaveOccurred()) 913 }) 914 915 It("copies data between the local and remote connections", func() { 916 validateConnectivity(localAddress) 917 }) 918 919 Context("when a local connection is already open", func() { 920 var conn net.Conn 921 922 JustBeforeEach(func() { 923 var err error 924 conn, err = net.Dial("tcp", localAddress) 925 Expect(err).NotTo(HaveOccurred()) 926 }) 927 928 AfterEach(func() { 929 err := conn.Close() 930 Expect(err).NotTo(HaveOccurred()) 931 }) 932 933 It("allows for new incoming connections as well", func() { 934 validateConnectivity(localAddress) 935 }) 936 }) 937 938 Context("when there are multiple port forward specs", func() { 939 var ( 940 realLocalListener2 net.Listener 941 localAddress2 string 942 ) 943 944 BeforeEach(func() { 945 var err error 946 realLocalListener2, err = net.Listen("tcp", "127.0.0.1:0") 947 Expect(err).NotTo(HaveOccurred()) 948 949 localAddress2 = realLocalListener2.Addr().String() 950 951 fakeListenerFactory.ListenStub = func(network, addr string) (net.Listener, error) { 952 if addr == localAddress { 953 return realLocalListener, nil 954 } 955 956 if addr == localAddress2 { 957 return realLocalListener2, nil 958 } 959 960 return nil, errors.New("unexpected address") 961 } 962 963 forwardSpecs = []LocalPortForward{ 964 { 965 RemoteAddress: echoAddress, 966 LocalAddress: localAddress, 967 }, 968 { 969 RemoteAddress: echoAddress, 970 LocalAddress: localAddress2, 971 }, 972 } 973 }) 974 975 AfterEach(func() { 976 realLocalListener2.Close() 977 }) 978 979 It("listens to all the things", func() { 980 Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(2)) 981 982 network, addr := fakeListenerFactory.ListenArgsForCall(0) 983 Expect(network).To(Equal("tcp")) 984 Expect(addr).To(Equal(localAddress)) 985 986 network, addr = fakeListenerFactory.ListenArgsForCall(1) 987 Expect(network).To(Equal("tcp")) 988 Expect(addr).To(Equal(localAddress2)) 989 }) 990 991 It("forwards to the correct target", func() { 992 validateConnectivity(localAddress) 993 validateConnectivity(localAddress2) 994 }) 995 996 Context("when the secure client is closed", func() { 997 BeforeEach(func() { 998 fakeListenerFactory.ListenReturns(fakeLocalListener, nil) 999 fakeLocalListener.AcceptReturns(nil, errors.New("not accepting connections")) 1000 }) 1001 1002 It("closes the listeners ", func() { 1003 Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(2)) 1004 Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(2)) 1005 1006 originalCloseCount := fakeLocalListener.CloseCallCount() 1007 err := secureShell.Close() 1008 Expect(err).NotTo(HaveOccurred()) 1009 Expect(fakeLocalListener.CloseCallCount()).Should(Equal(originalCloseCount + 2)) 1010 }) 1011 }) 1012 }) 1013 1014 Context("when listen fails", func() { 1015 BeforeEach(func() { 1016 fakeListenerFactory.ListenReturns(nil, errors.New("failure is an option")) 1017 }) 1018 1019 It("returns the error", func() { 1020 Expect(forwardErr).To(MatchError("failure is an option")) 1021 }) 1022 }) 1023 1024 Context("when the client it closed", func() { 1025 BeforeEach(func() { 1026 fakeListenerFactory.ListenReturns(fakeLocalListener, nil) 1027 fakeLocalListener.AcceptReturns(nil, errors.New("not accepting and connections")) 1028 }) 1029 1030 It("closes the listener when the client is closed", func() { 1031 Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(1)) 1032 Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(1)) 1033 1034 originalCloseCount := fakeLocalListener.CloseCallCount() 1035 err := secureShell.Close() 1036 Expect(err).NotTo(HaveOccurred()) 1037 Expect(fakeLocalListener.CloseCallCount()).Should(Equal(originalCloseCount + 1)) 1038 }) 1039 }) 1040 1041 Context("when accept fails", func() { 1042 var fakeConn *fake_net.FakeConn 1043 1044 BeforeEach(func() { 1045 fakeConn = &fake_net.FakeConn{} 1046 fakeConn.ReadReturns(0, io.EOF) 1047 1048 fakeListenerFactory.ListenReturns(fakeLocalListener, nil) 1049 }) 1050 1051 Context("with a permanent error", func() { 1052 BeforeEach(func() { 1053 fakeLocalListener.AcceptReturns(nil, errors.New("boom")) 1054 }) 1055 1056 It("stops trying to accept connections", func() { 1057 Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(1)) 1058 Consistently(fakeLocalListener.AcceptCallCount).Should(Equal(1)) 1059 Expect(fakeLocalListener.CloseCallCount()).To(Equal(1)) 1060 }) 1061 }) 1062 1063 Context("with a temporary error", func() { 1064 var timeCh chan time.Time 1065 1066 BeforeEach(func() { 1067 timeCh = make(chan time.Time, 3) 1068 1069 fakeLocalListener.AcceptStub = func() (net.Conn, error) { 1070 timeCh := timeCh 1071 if fakeLocalListener.AcceptCallCount() > 3 { 1072 close(timeCh) 1073 return nil, test_helpers.NewTestNetError(false, false) 1074 } else { 1075 timeCh <- time.Now() 1076 return nil, test_helpers.NewTestNetError(false, true) 1077 } 1078 } 1079 }) 1080 1081 PIt("retries connecting after a short delay", func() { 1082 Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(3)) 1083 Expect(timeCh).To(HaveLen(3)) 1084 1085 times := make([]time.Time, 0) 1086 for t := range timeCh { 1087 times = append(times, t) 1088 } 1089 1090 Expect(times[1]).To(BeTemporally("~", times[0].Add(115*time.Millisecond), 80*time.Millisecond)) 1091 Expect(times[2]).To(BeTemporally("~", times[1].Add(115*time.Millisecond), 100*time.Millisecond)) 1092 }) 1093 }) 1094 }) 1095 1096 Context("when dialing the connect address fails", func() { 1097 var fakeTarget *fake_net.FakeConn 1098 1099 BeforeEach(func() { 1100 fakeTarget = &fake_net.FakeConn{} 1101 fakeSecureClient.DialReturns(fakeTarget, errors.New("boom")) 1102 }) 1103 1104 It("does not call close on the target connection", func() { 1105 Consistently(fakeTarget.CloseCallCount).Should(Equal(0)) 1106 }) 1107 }) 1108 }) 1109 1110 Describe("Wait", func() { 1111 var waitErr error 1112 1113 JustBeforeEach(func() { 1114 connectErr := secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation) 1115 Expect(connectErr).NotTo(HaveOccurred()) 1116 1117 waitErr = secureShell.Wait() 1118 }) 1119 1120 It("calls wait on the secureClient", func() { 1121 Expect(waitErr).NotTo(HaveOccurred()) 1122 Expect(fakeSecureClient.WaitCallCount()).To(Equal(1)) 1123 }) 1124 1125 Describe("keep alive messages", func() { 1126 var times []time.Time 1127 var timesCh chan []time.Time 1128 var done chan struct{} 1129 1130 BeforeEach(func() { 1131 keepAliveDuration = 100 * time.Millisecond 1132 1133 times = []time.Time{} 1134 timesCh = make(chan []time.Time, 1) 1135 done = make(chan struct{}, 1) 1136 1137 fakeConnection.SendRequestStub = func(reqName string, wantReply bool, message []byte) (bool, []byte, error) { 1138 Expect(reqName).To(Equal("keepalive@cloudfoundry.org")) 1139 Expect(wantReply).To(BeTrue()) 1140 Expect(message).To(BeNil()) 1141 1142 times = append(times, time.Now()) 1143 if len(times) == 3 { 1144 timesCh <- times 1145 close(done) 1146 } 1147 return true, nil, nil 1148 } 1149 1150 fakeSecureClient.WaitStub = func() error { 1151 Eventually(done).Should(BeClosed()) 1152 return nil 1153 } 1154 }) 1155 1156 PIt("sends keep alive messages at the expected interval", func() { 1157 Expect(waitErr).NotTo(HaveOccurred()) 1158 times := <-timesCh 1159 Expect(times[2]).To(BeTemporally("~", times[0].Add(200*time.Millisecond), 100*time.Millisecond)) 1160 }) 1161 }) 1162 }) 1163 1164 Describe("Close", func() { 1165 JustBeforeEach(func() { 1166 connectErr := secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation) 1167 Expect(connectErr).NotTo(HaveOccurred()) 1168 }) 1169 1170 It("calls close on the secureClient", func() { 1171 err := secureShell.Close() 1172 Expect(err).NotTo(HaveOccurred()) 1173 1174 Expect(fakeSecureClient.CloseCallCount()).To(Equal(1)) 1175 }) 1176 }) 1177 })