github.com/loafoe/cli@v7.1.0+incompatible/util/clissh/ssh_test.go (about) 1 // +build !windows,!386 2 3 // skipping 386 because lager uses UInt64 in Session() 4 // skipping windows because Unix/Linux only syscall in test. 5 // should refactor out the conflicts so we could test this package in multi platforms. 6 7 package clissh_test 8 9 import ( 10 "errors" 11 "fmt" 12 "io" 13 "net" 14 "os" 15 "sync" 16 "syscall" 17 "time" 18 19 . "code.cloudfoundry.org/cli/util/clissh" 20 "code.cloudfoundry.org/cli/util/clissh/clisshfakes" 21 "code.cloudfoundry.org/cli/util/clissh/ssherror" 22 "code.cloudfoundry.org/diego-ssh/server" 23 fake_server "code.cloudfoundry.org/diego-ssh/server/fakes" 24 "code.cloudfoundry.org/diego-ssh/test_helpers" 25 "code.cloudfoundry.org/diego-ssh/test_helpers/fake_io" 26 "code.cloudfoundry.org/diego-ssh/test_helpers/fake_net" 27 "code.cloudfoundry.org/diego-ssh/test_helpers/fake_ssh" 28 "code.cloudfoundry.org/lager/lagertest" 29 "github.com/kr/pty" 30 "github.com/moby/moby/pkg/term" 31 . "github.com/onsi/ginkgo" 32 . "github.com/onsi/gomega" 33 "golang.org/x/crypto/ssh" 34 ) 35 36 func BlockAcceptOnClose(fake *fake_net.FakeListener) { 37 waitUntilClosed := make(chan bool) 38 fake.AcceptStub = func() (net.Conn, error) { 39 <-waitUntilClosed 40 return nil, errors.New("banana") 41 } 42 43 var once sync.Once 44 fake.CloseStub = func() error { 45 once.Do(func() { 46 close(waitUntilClosed) 47 }) 48 return nil 49 } 50 } 51 52 var _ = Describe("CLI SSH", func() { 53 var ( 54 fakeSecureDialer *clisshfakes.FakeSecureDialer 55 fakeSecureClient *clisshfakes.FakeSecureClient 56 fakeTerminalHelper *clisshfakes.FakeTerminalHelper 57 fakeListenerFactory *clisshfakes.FakeListenerFactory 58 fakeSecureSession *clisshfakes.FakeSecureSession 59 60 fakeConnection *fake_ssh.FakeConn 61 stdinPipe *fake_io.FakeWriteCloser 62 stdoutPipe *fake_io.FakeReader 63 stderrPipe *fake_io.FakeReader 64 secureShell *SecureShell 65 66 username string 67 passcode string 68 sshEndpoint string 69 sshEndpointFingerprint string 70 skipHostValidation bool 71 commands []string 72 terminalRequest TTYRequest 73 keepAliveDuration time.Duration 74 ) 75 76 BeforeEach(func() { 77 fakeSecureDialer = new(clisshfakes.FakeSecureDialer) 78 fakeSecureClient = new(clisshfakes.FakeSecureClient) 79 fakeTerminalHelper = new(clisshfakes.FakeTerminalHelper) 80 fakeListenerFactory = new(clisshfakes.FakeListenerFactory) 81 fakeSecureSession = new(clisshfakes.FakeSecureSession) 82 83 fakeConnection = new(fake_ssh.FakeConn) 84 stdinPipe = new(fake_io.FakeWriteCloser) 85 stdoutPipe = new(fake_io.FakeReader) 86 stderrPipe = new(fake_io.FakeReader) 87 88 fakeListenerFactory.ListenStub = net.Listen 89 fakeSecureClient.NewSessionReturns(fakeSecureSession, nil) 90 fakeSecureClient.ConnReturns(fakeConnection) 91 fakeSecureDialer.DialReturns(fakeSecureClient, nil) 92 93 stdinPipe.WriteStub = func(p []byte) (int, error) { 94 return len(p), nil 95 } 96 fakeSecureSession.StdinPipeReturns(stdinPipe, nil) 97 98 stdoutPipe.ReadStub = func(p []byte) (int, error) { 99 return 0, io.EOF 100 } 101 fakeSecureSession.StdoutPipeReturns(stdoutPipe, nil) 102 103 stderrPipe.ReadStub = func(p []byte) (int, error) { 104 return 0, io.EOF 105 } 106 fakeSecureSession.StderrPipeReturns(stderrPipe, nil) 107 108 username = "some-user" 109 passcode = "some-passcode" 110 sshEndpoint = "some-endpoint" 111 sshEndpointFingerprint = "some-fingerprint" 112 skipHostValidation = false 113 commands = []string{} 114 terminalRequest = RequestTTYAuto 115 keepAliveDuration = DefaultKeepAliveInterval 116 }) 117 118 JustBeforeEach(func() { 119 secureShell = NewSecureShell( 120 fakeSecureDialer, 121 fakeTerminalHelper, 122 fakeListenerFactory, 123 keepAliveDuration, 124 ) 125 }) 126 127 Describe("Connect", func() { 128 var connectErr error 129 130 JustBeforeEach(func() { 131 connectErr = secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation) 132 }) 133 134 When("dialing succeeds", func() { 135 It("creates the ssh client", func() { 136 Expect(connectErr).ToNot(HaveOccurred()) 137 138 Expect(fakeSecureDialer.DialCallCount()).To(Equal(1)) 139 protocolArg, sshEndpointArg, sshConfigArg := fakeSecureDialer.DialArgsForCall(0) 140 Expect(protocolArg).To(Equal("tcp")) 141 Expect(sshEndpointArg).To(Equal(sshEndpoint)) 142 Expect(sshConfigArg.User).To(Equal(username)) 143 Expect(sshConfigArg.Auth).To(HaveLen(1)) 144 Expect(sshConfigArg.HostKeyCallback).ToNot(BeNil()) 145 }) 146 }) 147 148 When("dialing fails", func() { 149 var dialError error 150 151 When("the error is a generic Dial error", func() { 152 BeforeEach(func() { 153 dialError = errors.New("woops") 154 fakeSecureDialer.DialReturns(nil, dialError) 155 }) 156 157 It("returns the dial error", func() { 158 Expect(connectErr).To(Equal(dialError)) 159 Expect(fakeSecureDialer.DialCallCount()).To(Equal(1)) 160 }) 161 }) 162 163 When("the dialing error is a golang 'unable to authenticate' error", func() { 164 BeforeEach(func() { 165 dialError = fmt.Errorf("ssh: unable to authenticate, attempted methods %v, no supported methods remain", []string{"none", "password"}) 166 fakeSecureDialer.DialReturns(nil, dialError) 167 }) 168 169 It("returns an UnableToAuthenticateError", func() { 170 Expect(connectErr).To(MatchError(ssherror.UnableToAuthenticateError{Err: dialError})) 171 Expect(fakeSecureDialer.DialCallCount()).To(Equal(1)) 172 }) 173 }) 174 }) 175 }) 176 177 Describe("InteractiveSession", func() { 178 var ( 179 stdin *fake_io.FakeReadCloser 180 stdout, stderr *fake_io.FakeWriter 181 182 sessionErr error 183 interactiveSessionInvoker func(secureShell *SecureShell) 184 ) 185 186 BeforeEach(func() { 187 stdin = new(fake_io.FakeReadCloser) 188 stdout = new(fake_io.FakeWriter) 189 stderr = new(fake_io.FakeWriter) 190 191 fakeTerminalHelper.StdStreamsReturns(stdin, stdout, stderr) 192 interactiveSessionInvoker = func(secureShell *SecureShell) { 193 sessionErr = secureShell.InteractiveSession(commands, terminalRequest) 194 } 195 }) 196 197 JustBeforeEach(func() { 198 connectErr := secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation) 199 Expect(connectErr).NotTo(HaveOccurred()) 200 interactiveSessionInvoker(secureShell) 201 }) 202 203 When("host key validation is enabled", func() { 204 var ( 205 callback func(hostname string, remote net.Addr, key ssh.PublicKey) error 206 addr net.Addr 207 ) 208 209 BeforeEach(func() { 210 skipHostValidation = false 211 }) 212 213 JustBeforeEach(func() { 214 Expect(fakeSecureDialer.DialCallCount()).To(Equal(1)) 215 _, _, config := fakeSecureDialer.DialArgsForCall(0) 216 callback = config.HostKeyCallback 217 218 listener, err := net.Listen("tcp", "localhost:0") 219 Expect(err).NotTo(HaveOccurred()) 220 221 addr = listener.Addr() 222 listener.Close() 223 }) 224 225 When("the md5 fingerprint matches", func() { 226 BeforeEach(func() { 227 sshEndpointFingerprint = "41:ce:56:e6:9c:42:a9:c6:9e:68:ac:e3:4d:f6:38:79" 228 }) 229 230 It("does not return an error", func() { 231 Expect(callback("", addr, TestHostKey.PublicKey())).ToNot(HaveOccurred()) 232 }) 233 }) 234 235 When("the hex sha1 fingerprint matches", func() { 236 BeforeEach(func() { 237 sshEndpointFingerprint = "a8:e2:67:cb:ea:2a:6e:23:a1:72:ce:8f:07:92:15:ee:1f:82:f8:ca" 238 }) 239 240 It("does not return an error", func() { 241 Expect(callback("", addr, TestHostKey.PublicKey())).ToNot(HaveOccurred()) 242 }) 243 }) 244 245 When("the base64 sha256 fingerprint matches", func() { 246 BeforeEach(func() { 247 sshEndpointFingerprint = "sp/jrLuj66r+yrLDUKZdJU5tdzt4mq/UaSiNBjpgr+8" 248 }) 249 250 It("does not return an error", func() { 251 Expect(callback("", addr, TestHostKey.PublicKey())).ToNot(HaveOccurred()) 252 }) 253 }) 254 255 When("the base64 SHA256 fingerprint does not match", func() { 256 BeforeEach(func() { 257 sshEndpointFingerprint = "0000000000000000000000000000000000000000000" 258 }) 259 260 It("returns an error'", func() { 261 err := callback("", addr, TestHostKey.PublicKey()) 262 Expect(err).To(MatchError(MatchRegexp(`Host key verification failed\.`))) 263 Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\""))) 264 }) 265 }) 266 267 When("the hex SHA1 fingerprint does not match", func() { 268 BeforeEach(func() { 269 sshEndpointFingerprint = "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00" 270 }) 271 272 It("returns an error'", func() { 273 err := callback("", addr, TestHostKey.PublicKey()) 274 Expect(err).To(MatchError(MatchRegexp(`Host key verification failed\.`))) 275 Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\""))) 276 }) 277 }) 278 279 When("the MD5 fingerprint does not match", func() { 280 BeforeEach(func() { 281 sshEndpointFingerprint = "00:00:00:00:00:00:00:00:00:00:00:00:00:00:00:00" 282 }) 283 284 It("returns an error'", func() { 285 err := callback("", addr, TestHostKey.PublicKey()) 286 Expect(err).To(MatchError(MatchRegexp(`Host key verification failed\.`))) 287 Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\""))) 288 }) 289 }) 290 291 When("no fingerprint is present in endpoint info", func() { 292 BeforeEach(func() { 293 sshEndpointFingerprint = "" 294 sshEndpoint = "" 295 }) 296 297 It("returns an error'", func() { 298 err := callback("", addr, TestHostKey.PublicKey()) 299 Expect(err).To(MatchError(MatchRegexp(`Unable to verify identity of host\.`))) 300 Expect(err).To(MatchError(MatchRegexp("The fingerprint of the received key was \".*\""))) 301 }) 302 }) 303 304 When("the fingerprint length doesn't make sense", func() { 305 BeforeEach(func() { 306 sshEndpointFingerprint = "garbage" 307 }) 308 309 It("returns an error", func() { 310 err := callback("", addr, TestHostKey.PublicKey()) 311 Eventually(err).Should(MatchError(MatchRegexp("Unsupported host key fingerprint format"))) 312 }) 313 }) 314 }) 315 316 When("the skip host validation flag is set", func() { 317 BeforeEach(func() { 318 skipHostValidation = true 319 }) 320 321 It("the HostKeyCallback on the Config to always return nil", func() { 322 Expect(fakeSecureDialer.DialCallCount()).To(Equal(1)) 323 324 _, _, config := fakeSecureDialer.DialArgsForCall(0) 325 Expect(config.HostKeyCallback("some-hostname", nil, nil)).To(BeNil()) 326 }) 327 }) 328 329 // TODO: see if it's possible to test the piping between the ss client input and outputs and the UI object we pass in 330 When("dialing is successful", func() { 331 It("creates a new secure shell session", func() { 332 Expect(fakeSecureClient.NewSessionCallCount()).To(Equal(1)) 333 }) 334 335 It("closes the session", func() { 336 Expect(fakeSecureSession.CloseCallCount()).To(Equal(1)) 337 }) 338 339 It("gets a stdin pipe for the session", func() { 340 Expect(fakeSecureSession.StdinPipeCallCount()).To(Equal(1)) 341 }) 342 343 When("getting the stdin pipe fails", func() { 344 BeforeEach(func() { 345 fakeSecureSession.StdinPipeReturns(nil, errors.New("woops")) 346 }) 347 348 It("returns the error", func() { 349 Expect(sessionErr).Should(MatchError("woops")) 350 }) 351 }) 352 353 It("gets a stdout pipe for the session", func() { 354 Expect(fakeSecureSession.StdoutPipeCallCount()).To(Equal(1)) 355 }) 356 357 When("getting the stdout pipe fails", func() { 358 BeforeEach(func() { 359 fakeSecureSession.StdoutPipeReturns(nil, errors.New("woops")) 360 }) 361 362 It("returns the error", func() { 363 Expect(sessionErr).Should(MatchError("woops")) 364 }) 365 }) 366 367 It("gets a stderr pipe for the session", func() { 368 Expect(fakeSecureSession.StderrPipeCallCount()).To(Equal(1)) 369 }) 370 371 When("getting the stderr pipe fails", func() { 372 BeforeEach(func() { 373 fakeSecureSession.StderrPipeReturns(nil, errors.New("woops")) 374 }) 375 376 It("returns the error", func() { 377 Expect(sessionErr).Should(MatchError("woops")) 378 }) 379 }) 380 }) 381 382 When("stdin is a terminal", func() { 383 var master *os.File 384 385 BeforeEach(func() { 386 var err error 387 master, _, err = pty.Open() 388 Expect(err).NotTo(HaveOccurred()) 389 390 terminalRequest = RequestTTYForce 391 392 terminalHelper := DefaultTerminalHelper() 393 fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo 394 fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize 395 }) 396 397 AfterEach(func() { 398 master.Close() 399 }) 400 401 When("a command is not specified", func() { 402 var terminalType string 403 404 BeforeEach(func() { 405 terminalType = os.Getenv("TERM") 406 os.Setenv("TERM", "test-terminal-type") 407 408 winsize := &term.Winsize{Width: 1024, Height: 256} 409 fakeTerminalHelper.GetWinsizeReturns(winsize, nil) 410 411 fakeSecureSession.ShellStub = func() error { 412 Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1)) 413 Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(0)) 414 return nil 415 } 416 }) 417 418 AfterEach(func() { 419 os.Setenv("TERM", terminalType) 420 }) 421 422 It("requests a pty with the correct terminal type, window size, and modes", func() { 423 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1)) 424 Expect(fakeTerminalHelper.GetWinsizeCallCount()).To(Equal(1)) 425 426 termType, height, width, modes := fakeSecureSession.RequestPtyArgsForCall(0) 427 Expect(termType).To(Equal("test-terminal-type")) 428 Expect(height).To(Equal(256)) 429 Expect(width).To(Equal(1024)) 430 431 expectedModes := ssh.TerminalModes{ 432 ssh.ECHO: 1, 433 ssh.TTY_OP_ISPEED: 115200, 434 ssh.TTY_OP_OSPEED: 115200, 435 } 436 Expect(modes).To(Equal(expectedModes)) 437 }) 438 439 When("the TERM environment variable is not set", func() { 440 BeforeEach(func() { 441 os.Unsetenv("TERM") 442 }) 443 444 It("requests a pty with the default terminal type", func() { 445 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1)) 446 447 termType, _, _, _ := fakeSecureSession.RequestPtyArgsForCall(0) 448 Expect(termType).To(Equal("xterm")) 449 }) 450 }) 451 452 It("puts the terminal into raw mode and restores it after running the shell", func() { 453 Expect(fakeSecureSession.ShellCallCount()).To(Equal(1)) 454 Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1)) 455 Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(1)) 456 }) 457 458 When("the pty allocation fails", func() { 459 var ptyError error 460 461 BeforeEach(func() { 462 ptyError = errors.New("pty allocation error") 463 fakeSecureSession.RequestPtyReturns(ptyError) 464 }) 465 466 It("returns the error", func() { 467 Expect(sessionErr).To(Equal(ptyError)) 468 }) 469 }) 470 471 When("placing the terminal into raw mode fails", func() { 472 BeforeEach(func() { 473 fakeTerminalHelper.SetRawTerminalReturns(nil, errors.New("woops")) 474 }) 475 476 It("keeps calm and carries on", func() { 477 Expect(fakeSecureSession.ShellCallCount()).To(Equal(1)) 478 }) 479 480 It("does not not restore the terminal", func() { 481 Expect(fakeSecureSession.ShellCallCount()).To(Equal(1)) 482 Expect(fakeTerminalHelper.SetRawTerminalCallCount()).To(Equal(1)) 483 Expect(fakeTerminalHelper.RestoreTerminalCallCount()).To(Equal(0)) 484 }) 485 }) 486 }) 487 488 When("a command is specified", func() { 489 BeforeEach(func() { 490 commands = []string{"echo", "-n", "hello"} 491 }) 492 493 When("a terminal is requested", func() { 494 BeforeEach(func() { 495 terminalRequest = RequestTTYYes 496 fakeTerminalHelper.GetFdInfoReturns(0, true) 497 }) 498 499 It("requests a pty", func() { 500 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1)) 501 }) 502 }) 503 504 When("a terminal is not explicitly requested", func() { 505 BeforeEach(func() { 506 terminalRequest = RequestTTYAuto 507 }) 508 509 It("does not request a pty", func() { 510 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0)) 511 }) 512 }) 513 }) 514 }) 515 516 When("stdin is not a terminal", func() { 517 BeforeEach(func() { 518 stdin.ReadStub = func(p []byte) (int, error) { 519 return 0, io.EOF 520 } 521 522 terminalHelper := DefaultTerminalHelper() 523 fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo 524 fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize 525 }) 526 527 When("a terminal is not requested", func() { 528 It("does not request a pty", func() { 529 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0)) 530 }) 531 }) 532 533 When("a terminal is requested", func() { 534 BeforeEach(func() { 535 terminalRequest = RequestTTYYes 536 }) 537 538 It("does not request a pty", func() { 539 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0)) 540 }) 541 }) 542 }) 543 544 PWhen("a terminal is forced", func() { 545 BeforeEach(func() { 546 terminalRequest = RequestTTYForce 547 }) 548 549 It("requests a pty", func() { 550 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(1)) 551 }) 552 }) 553 554 When("a terminal is disabled", func() { 555 BeforeEach(func() { 556 terminalRequest = RequestTTYNo 557 }) 558 559 It("does not request a pty", func() { 560 Expect(fakeSecureSession.RequestPtyCallCount()).To(Equal(0)) 561 }) 562 }) 563 564 When("a command is not specified", func() { 565 It("requests an interactive shell", func() { 566 Expect(fakeSecureSession.ShellCallCount()).To(Equal(1)) 567 }) 568 569 When("the shell request returns an error", func() { 570 BeforeEach(func() { 571 fakeSecureSession.ShellReturns(errors.New("oh bother")) 572 }) 573 574 It("returns the error", func() { 575 Expect(sessionErr).To(MatchError("oh bother")) 576 }) 577 }) 578 }) 579 580 When("a command is specifed", func() { 581 BeforeEach(func() { 582 commands = []string{"echo", "-n", "hello"} 583 }) 584 585 It("starts the command", func() { 586 Expect(fakeSecureSession.StartCallCount()).To(Equal(1)) 587 Expect(fakeSecureSession.StartArgsForCall(0)).To(Equal("echo -n hello")) 588 }) 589 590 When("the command fails to start", func() { 591 BeforeEach(func() { 592 fakeSecureSession.StartReturns(errors.New("oh well")) 593 }) 594 595 It("returns the error", func() { 596 Expect(sessionErr).To(MatchError("oh well")) 597 }) 598 }) 599 }) 600 601 When("the shell or command has started", func() { 602 BeforeEach(func() { 603 stdin.ReadStub = func(p []byte) (int, error) { 604 p[0] = 0 605 return 1, io.EOF 606 } 607 stdinPipe.WriteStub = func(p []byte) (int, error) { 608 defer GinkgoRecover() 609 Expect(p[0]).To(Equal(byte(0))) 610 return 1, nil 611 } 612 613 stdoutPipe.ReadStub = func(p []byte) (int, error) { 614 p[0] = 1 615 return 1, io.EOF 616 } 617 stdout.WriteStub = func(p []byte) (int, error) { 618 defer GinkgoRecover() 619 Expect(p[0]).To(Equal(byte(1))) 620 return 1, nil 621 } 622 623 stderrPipe.ReadStub = func(p []byte) (int, error) { 624 p[0] = 2 625 return 1, io.EOF 626 } 627 stderr.WriteStub = func(p []byte) (int, error) { 628 defer GinkgoRecover() 629 Expect(p[0]).To(Equal(byte(2))) 630 return 1, nil 631 } 632 633 fakeSecureSession.StdinPipeReturns(stdinPipe, nil) 634 fakeSecureSession.StdoutPipeReturns(stdoutPipe, nil) 635 fakeSecureSession.StderrPipeReturns(stderrPipe, nil) 636 637 fakeSecureSession.WaitReturns(errors.New("error result")) 638 }) 639 640 It("copies data from the stdin stream to the session stdin pipe", func() { 641 Eventually(stdin.ReadCallCount).Should(Equal(1)) 642 Eventually(stdinPipe.WriteCallCount).Should(Equal(1)) 643 }) 644 645 It("copies data from the session stdout pipe to the stdout stream", func() { 646 Eventually(stdoutPipe.ReadCallCount).Should(Equal(1)) 647 Eventually(stdout.WriteCallCount).Should(Equal(1)) 648 }) 649 650 It("copies data from the session stderr pipe to the stderr stream", func() { 651 Eventually(stderrPipe.ReadCallCount).Should(Equal(1)) 652 Eventually(stderr.WriteCallCount).Should(Equal(1)) 653 }) 654 655 It("waits for the session to end", func() { 656 Expect(fakeSecureSession.WaitCallCount()).To(Equal(1)) 657 }) 658 659 It("returns the result from wait", func() { 660 Expect(sessionErr).To(MatchError("error result")) 661 }) 662 663 When("the session terminates before stream copies complete", func() { 664 var sessionErrorCh chan error 665 666 BeforeEach(func() { 667 sessionErrorCh = make(chan error, 1) 668 669 interactiveSessionInvoker = func(secureShell *SecureShell) { 670 go func() { 671 sessionErrorCh <- secureShell.InteractiveSession(commands, terminalRequest) 672 }() 673 } 674 675 stdoutPipe.ReadStub = func(p []byte) (int, error) { 676 defer GinkgoRecover() 677 Eventually(fakeSecureSession.WaitCallCount).Should(Equal(1)) 678 Consistently(sessionErrorCh).ShouldNot(Receive()) 679 680 p[0] = 1 681 return 1, io.EOF 682 } 683 684 stderrPipe.ReadStub = func(p []byte) (int, error) { 685 defer GinkgoRecover() 686 Eventually(fakeSecureSession.WaitCallCount).Should(Equal(1)) 687 Consistently(sessionErrorCh).ShouldNot(Receive()) 688 689 p[0] = 2 690 return 1, io.EOF 691 } 692 }) 693 694 It("waits for the copies to complete", func() { 695 Eventually(sessionErrorCh).Should(Receive()) 696 Expect(stdoutPipe.ReadCallCount()).To(Equal(1)) 697 Expect(stderrPipe.ReadCallCount()).To(Equal(1)) 698 }) 699 }) 700 701 When("stdin is closed", func() { 702 BeforeEach(func() { 703 stdin.ReadStub = func(p []byte) (int, error) { 704 defer GinkgoRecover() 705 Consistently(stdinPipe.CloseCallCount).Should(Equal(0)) 706 p[0] = 0 707 return 1, io.EOF 708 } 709 }) 710 711 It("closes the stdinPipe", func() { 712 Eventually(stdinPipe.CloseCallCount).Should(Equal(1)) 713 }) 714 }) 715 }) 716 717 When("stdout is a terminal and a window size change occurs", func() { 718 var master, slave *os.File 719 720 BeforeEach(func() { 721 var err error 722 master, slave, err = pty.Open() 723 Expect(err).NotTo(HaveOccurred()) 724 725 terminalHelper := DefaultTerminalHelper() 726 fakeTerminalHelper.GetFdInfoStub = terminalHelper.GetFdInfo 727 fakeTerminalHelper.GetWinsizeStub = terminalHelper.GetWinsize 728 fakeTerminalHelper.StdStreamsReturns(stdin, slave, stderr) 729 730 winsize := &term.Winsize{Height: 100, Width: 100} 731 err = term.SetWinsize(slave.Fd(), winsize) 732 Expect(err).NotTo(HaveOccurred()) 733 734 fakeSecureSession.WaitStub = func() error { 735 fakeSecureSession.SendRequestCallCount() 736 Expect(fakeSecureSession.SendRequestCallCount()).To(Equal(0)) 737 738 // No dimension change 739 for i := 0; i < 3; i++ { 740 winsize := &term.Winsize{Height: 100, Width: 100} 741 err = term.SetWinsize(slave.Fd(), winsize) 742 Expect(err).NotTo(HaveOccurred()) 743 } 744 745 winsize := &term.Winsize{Height: 100, Width: 200} 746 err = term.SetWinsize(slave.Fd(), winsize) 747 Expect(err).NotTo(HaveOccurred()) 748 749 err = syscall.Kill(syscall.Getpid(), syscall.SIGWINCH) 750 Expect(err).NotTo(HaveOccurred()) 751 752 Eventually(fakeSecureSession.SendRequestCallCount).Should(Equal(1)) 753 return nil 754 } 755 }) 756 757 AfterEach(func() { 758 master.Close() 759 slave.Close() 760 }) 761 762 It("sends window change events when the window dimensions change", func() { 763 Expect(fakeSecureSession.SendRequestCallCount()).To(Equal(1)) 764 765 requestType, wantReply, message := fakeSecureSession.SendRequestArgsForCall(0) 766 Expect(requestType).To(Equal("window-change")) 767 Expect(wantReply).To(BeFalse()) 768 769 type resizeMessage struct { 770 Width uint32 771 Height uint32 772 PixelWidth uint32 773 PixelHeight uint32 774 } 775 var resizeMsg resizeMessage 776 777 err := ssh.Unmarshal(message, &resizeMsg) 778 Expect(err).NotTo(HaveOccurred()) 779 780 Expect(resizeMsg).To(Equal(resizeMessage{Height: 100, Width: 200})) 781 }) 782 }) 783 784 Describe("keep alive messages", func() { 785 var times []time.Time 786 var timesCh chan []time.Time 787 var done chan struct{} 788 789 BeforeEach(func() { 790 keepAliveDuration = 100 * time.Millisecond 791 792 times = []time.Time{} 793 timesCh = make(chan []time.Time, 1) 794 done = make(chan struct{}, 1) 795 796 fakeConnection.SendRequestStub = func(reqName string, wantReply bool, message []byte) (bool, []byte, error) { 797 Expect(reqName).To(Equal("keepalive@cloudfoundry.org")) 798 Expect(wantReply).To(BeTrue()) 799 Expect(message).To(BeNil()) 800 801 times = append(times, time.Now()) 802 if len(times) == 3 { 803 timesCh <- times 804 close(done) 805 } 806 return true, nil, nil 807 } 808 809 fakeSecureSession.WaitStub = func() error { 810 Eventually(done).Should(BeClosed()) 811 return nil 812 } 813 }) 814 815 PIt("sends keep alive messages at the expected interval", func() { 816 times := <-timesCh 817 Expect(times[2]).To(BeTemporally("~", times[0].Add(200*time.Millisecond), 160*time.Millisecond)) 818 }) 819 }) 820 }) 821 822 Describe("LocalPortForward", func() { 823 var ( 824 forwardErr error 825 826 echoAddress string 827 echoListener *fake_net.FakeListener 828 echoHandler *fake_server.FakeConnectionHandler 829 echoServer *server.Server 830 831 localAddress string 832 833 realLocalListener net.Listener 834 fakeLocalListener *fake_net.FakeListener 835 836 forwardSpecs []LocalPortForward 837 ) 838 839 BeforeEach(func() { 840 logger := lagertest.NewTestLogger("test") 841 842 var err error 843 realLocalListener, err = net.Listen("tcp", "127.0.0.1:0") 844 Expect(err).NotTo(HaveOccurred()) 845 846 localAddress = realLocalListener.Addr().String() 847 fakeListenerFactory.ListenReturns(realLocalListener, nil) 848 849 echoHandler = new(fake_server.FakeConnectionHandler) 850 echoHandler.HandleConnectionStub = func(conn net.Conn) { 851 io.Copy(conn, conn) //nolint:errcheck 852 conn.Close() 853 } 854 855 realListener, err := net.Listen("tcp", "127.0.0.1:0") 856 Expect(err).NotTo(HaveOccurred()) 857 echoAddress = realListener.Addr().String() 858 859 echoListener = new(fake_net.FakeListener) 860 echoListener.AcceptStub = realListener.Accept 861 echoListener.CloseStub = realListener.Close 862 echoListener.AddrStub = realListener.Addr 863 864 fakeLocalListener = new(fake_net.FakeListener) 865 fakeLocalListener.AcceptReturns(nil, errors.New("Not Accepting Connections")) 866 867 echoServer = server.NewServer(logger.Session("echo"), "", echoHandler) 868 err = echoServer.SetListener(echoListener) 869 Expect(err).NotTo(HaveOccurred()) 870 go echoServer.Serve() 871 872 forwardSpecs = []LocalPortForward{{ 873 RemoteAddress: echoAddress, 874 LocalAddress: localAddress, 875 }} 876 877 fakeSecureClient.DialStub = net.Dial 878 }) 879 880 JustBeforeEach(func() { 881 connectErr := secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation) 882 Expect(connectErr).NotTo(HaveOccurred()) 883 884 forwardErr = secureShell.LocalPortForward(forwardSpecs) 885 }) 886 887 AfterEach(func() { 888 err := secureShell.Close() 889 Expect(err).NotTo(HaveOccurred()) 890 echoServer.Shutdown() 891 892 realLocalListener.Close() 893 }) 894 895 validateConnectivity := func(addr string) { 896 conn, err := net.Dial("tcp", addr) 897 Expect(err).NotTo(HaveOccurred()) 898 899 msg := fmt.Sprintf("Hello from %s\n", addr) 900 n, err := conn.Write([]byte(msg)) 901 Expect(err).NotTo(HaveOccurred()) 902 Expect(n).To(Equal(len(msg))) 903 904 response := make([]byte, len(msg)) 905 n, err = conn.Read(response) 906 Expect(err).NotTo(HaveOccurred()) 907 Expect(n).To(Equal(len(msg))) 908 909 err = conn.Close() 910 Expect(err).NotTo(HaveOccurred()) 911 912 Expect(response).To(Equal([]byte(msg))) 913 } 914 915 It("dials the connect address when a local connection is made", func() { 916 Expect(forwardErr).NotTo(HaveOccurred()) 917 918 conn, err := net.Dial("tcp", localAddress) 919 Expect(err).NotTo(HaveOccurred()) 920 921 Eventually(echoListener.AcceptCallCount).Should(BeNumerically(">=", 1)) 922 Eventually(fakeSecureClient.DialCallCount).Should(Equal(1)) 923 924 network, addr := fakeSecureClient.DialArgsForCall(0) 925 Expect(network).To(Equal("tcp")) 926 Expect(addr).To(Equal(echoAddress)) 927 928 Expect(conn.Close()).NotTo(HaveOccurred()) 929 }) 930 931 It("copies data between the local and remote connections", func() { 932 validateConnectivity(localAddress) 933 }) 934 935 When("a local connection is already open", func() { 936 var conn net.Conn 937 938 JustBeforeEach(func() { 939 var err error 940 conn, err = net.Dial("tcp", localAddress) 941 Expect(err).NotTo(HaveOccurred()) 942 }) 943 944 AfterEach(func() { 945 err := conn.Close() 946 Expect(err).NotTo(HaveOccurred()) 947 }) 948 949 It("allows for new incoming connections as well", func() { 950 validateConnectivity(localAddress) 951 }) 952 }) 953 954 When("there are multiple port forward specs", func() { 955 When("provided a real listener", func() { 956 var ( 957 realLocalListener2 net.Listener 958 localAddress2 string 959 ) 960 961 BeforeEach(func() { 962 var err error 963 realLocalListener2, err = net.Listen("tcp", "127.0.0.1:0") 964 Expect(err).NotTo(HaveOccurred()) 965 966 localAddress2 = realLocalListener2.Addr().String() 967 968 fakeListenerFactory.ListenStub = func(network, addr string) (net.Listener, error) { 969 if addr == localAddress { 970 return realLocalListener, nil 971 } 972 973 if addr == localAddress2 { 974 return realLocalListener2, nil 975 } 976 977 return nil, errors.New("unexpected address") 978 } 979 980 forwardSpecs = []LocalPortForward{ 981 { 982 RemoteAddress: echoAddress, 983 LocalAddress: localAddress, 984 }, 985 { 986 RemoteAddress: echoAddress, 987 LocalAddress: localAddress2, 988 }, 989 } 990 }) 991 992 AfterEach(func() { 993 realLocalListener2.Close() 994 }) 995 996 It("listens to all the things", func() { 997 Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(2)) 998 999 network, addr := fakeListenerFactory.ListenArgsForCall(0) 1000 Expect(network).To(Equal("tcp")) 1001 Expect(addr).To(Equal(localAddress)) 1002 1003 network, addr = fakeListenerFactory.ListenArgsForCall(1) 1004 Expect(network).To(Equal("tcp")) 1005 Expect(addr).To(Equal(localAddress2)) 1006 }) 1007 1008 It("forwards to the correct target", func() { 1009 validateConnectivity(localAddress) 1010 validateConnectivity(localAddress2) 1011 }) 1012 }) 1013 1014 When("the secure client is closed", func() { 1015 var ( 1016 fakeLocalListener1 *fake_net.FakeListener 1017 fakeLocalListener2 *fake_net.FakeListener 1018 ) 1019 1020 BeforeEach(func() { 1021 forwardSpecs = []LocalPortForward{ 1022 { 1023 RemoteAddress: echoAddress, 1024 LocalAddress: localAddress, 1025 }, 1026 { 1027 RemoteAddress: echoAddress, 1028 LocalAddress: localAddress, 1029 }, 1030 } 1031 1032 fakeLocalListener1 = new(fake_net.FakeListener) 1033 BlockAcceptOnClose(fakeLocalListener1) 1034 fakeListenerFactory.ListenReturnsOnCall(0, fakeLocalListener1, nil) 1035 1036 fakeLocalListener2 = new(fake_net.FakeListener) 1037 BlockAcceptOnClose(fakeLocalListener2) 1038 fakeListenerFactory.ListenReturnsOnCall(1, fakeLocalListener2, nil) 1039 }) 1040 1041 It("closes the listeners", func() { 1042 Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(2)) 1043 Eventually(fakeLocalListener1.AcceptCallCount).Should(Equal(1)) 1044 Eventually(fakeLocalListener2.AcceptCallCount).Should(Equal(1)) 1045 1046 err := secureShell.Close() 1047 Expect(err).NotTo(HaveOccurred()) 1048 Eventually(fakeLocalListener1.CloseCallCount).Should(Equal(2)) 1049 Eventually(fakeLocalListener2.CloseCallCount).Should(Equal(2)) 1050 }) 1051 }) 1052 }) 1053 1054 When("listen fails", func() { 1055 BeforeEach(func() { 1056 fakeListenerFactory.ListenReturns(nil, errors.New("failure is an option")) 1057 }) 1058 1059 It("returns the error", func() { 1060 Expect(forwardErr).To(MatchError("failure is an option")) 1061 }) 1062 }) 1063 1064 When("the client it closed", func() { 1065 BeforeEach(func() { 1066 fakeLocalListener = new(fake_net.FakeListener) 1067 BlockAcceptOnClose(fakeLocalListener) 1068 1069 fakeListenerFactory.ListenReturns(fakeLocalListener, nil) 1070 }) 1071 1072 It("closes the listener when the client is closed", func() { 1073 Eventually(fakeListenerFactory.ListenCallCount).Should(Equal(1)) 1074 Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(1)) 1075 1076 err := secureShell.Close() 1077 Expect(err).NotTo(HaveOccurred()) 1078 Eventually(fakeLocalListener.CloseCallCount).Should(Equal(2)) 1079 }) 1080 }) 1081 1082 When("accept fails", func() { 1083 var fakeConn *fake_net.FakeConn 1084 1085 BeforeEach(func() { 1086 fakeConn = &fake_net.FakeConn{} 1087 fakeConn.ReadReturns(0, io.EOF) 1088 1089 fakeListenerFactory.ListenReturns(fakeLocalListener, nil) 1090 }) 1091 1092 Context("with a permanent error", func() { 1093 BeforeEach(func() { 1094 fakeLocalListener.AcceptReturns(nil, errors.New("boom")) 1095 }) 1096 1097 It("stops trying to accept connections", func() { 1098 Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(1)) 1099 Consistently(fakeLocalListener.AcceptCallCount).Should(Equal(1)) 1100 Expect(fakeLocalListener.CloseCallCount()).To(Equal(1)) 1101 }) 1102 }) 1103 1104 Context("with a temporary error", func() { 1105 var timeCh chan time.Time 1106 1107 BeforeEach(func() { 1108 timeCh = make(chan time.Time, 3) 1109 1110 fakeLocalListener.AcceptStub = func() (net.Conn, error) { 1111 timeCh := timeCh 1112 if fakeLocalListener.AcceptCallCount() > 3 { 1113 close(timeCh) 1114 return nil, test_helpers.NewTestNetError(false, false) 1115 } else { 1116 timeCh <- time.Now() 1117 return nil, test_helpers.NewTestNetError(false, true) 1118 } 1119 } 1120 }) 1121 1122 PIt("retries connecting after a short delay", func() { 1123 Eventually(fakeLocalListener.AcceptCallCount).Should(Equal(3)) 1124 Expect(timeCh).To(HaveLen(3)) 1125 1126 times := make([]time.Time, 0) 1127 for t := range timeCh { 1128 times = append(times, t) 1129 } 1130 1131 Expect(times[1]).To(BeTemporally("~", times[0].Add(115*time.Millisecond), 80*time.Millisecond)) 1132 Expect(times[2]).To(BeTemporally("~", times[1].Add(115*time.Millisecond), 100*time.Millisecond)) 1133 }) 1134 }) 1135 }) 1136 1137 When("dialing the connect address fails", func() { 1138 var fakeTarget *fake_net.FakeConn 1139 1140 BeforeEach(func() { 1141 fakeTarget = &fake_net.FakeConn{} 1142 fakeSecureClient.DialReturns(fakeTarget, errors.New("boom")) 1143 }) 1144 1145 It("does not call close on the target connection", func() { 1146 Consistently(fakeTarget.CloseCallCount).Should(Equal(0)) 1147 }) 1148 }) 1149 }) 1150 1151 Describe("Wait", func() { 1152 var waitErr error 1153 1154 JustBeforeEach(func() { 1155 connectErr := secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation) 1156 Expect(connectErr).NotTo(HaveOccurred()) 1157 1158 waitErr = secureShell.Wait() 1159 }) 1160 1161 It("calls wait on the secureClient", func() { 1162 Expect(waitErr).NotTo(HaveOccurred()) 1163 Expect(fakeSecureClient.WaitCallCount()).To(Equal(1)) 1164 }) 1165 1166 Describe("keep alive messages", func() { 1167 var times []time.Time 1168 var timesCh chan []time.Time 1169 var done chan struct{} 1170 1171 BeforeEach(func() { 1172 keepAliveDuration = 100 * time.Millisecond 1173 1174 times = []time.Time{} 1175 timesCh = make(chan []time.Time, 1) 1176 done = make(chan struct{}, 1) 1177 1178 fakeConnection.SendRequestStub = func(reqName string, wantReply bool, message []byte) (bool, []byte, error) { 1179 Expect(reqName).To(Equal("keepalive@cloudfoundry.org")) 1180 Expect(wantReply).To(BeTrue()) 1181 Expect(message).To(BeNil()) 1182 1183 times = append(times, time.Now()) 1184 if len(times) == 3 { 1185 timesCh <- times 1186 close(done) 1187 } 1188 return true, nil, nil 1189 } 1190 1191 fakeSecureClient.WaitStub = func() error { 1192 Eventually(done).Should(BeClosed()) 1193 return nil 1194 } 1195 }) 1196 1197 PIt("sends keep alive messages at the expected interval", func() { 1198 Expect(waitErr).NotTo(HaveOccurred()) 1199 times := <-timesCh 1200 Expect(times[2]).To(BeTemporally("~", times[0].Add(200*time.Millisecond), 100*time.Millisecond)) 1201 }) 1202 }) 1203 }) 1204 1205 Describe("Close", func() { 1206 JustBeforeEach(func() { 1207 connectErr := secureShell.Connect(username, passcode, sshEndpoint, sshEndpointFingerprint, skipHostValidation) 1208 Expect(connectErr).NotTo(HaveOccurred()) 1209 }) 1210 1211 It("calls close on the secureClient", func() { 1212 err := secureShell.Close() 1213 Expect(err).NotTo(HaveOccurred()) 1214 1215 Expect(fakeSecureClient.CloseCallCount()).To(Equal(1)) 1216 }) 1217 }) 1218 })