github.com/cloudfoundry-attic/ltc@v0.0.0-20151123212628-098adc7919fc/ssh/sshapi/client_test.go (about) 1 package sshapi_test 2 3 import ( 4 "bytes" 5 "errors" 6 "io" 7 "net" 8 "os" 9 "reflect" 10 11 . "github.com/onsi/ginkgo" 12 . "github.com/onsi/gomega" 13 14 "github.com/cloudfoundry-incubator/ltc/ssh/sshapi" 15 "github.com/cloudfoundry-incubator/ltc/ssh/sshapi/mocks" 16 "golang.org/x/crypto/ssh" 17 ) 18 19 var _ = Describe(".New", func() { 20 It("should create a client using the package DialFunc", func() { 21 origDial := sshapi.DialFunc 22 defer func() { sshapi.DialFunc = origDial }() 23 24 dialCalled := false 25 sshClient := &ssh.Client{} 26 27 sshapi.DialFunc = func(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { 28 Expect(network).To(Equal("tcp")) 29 Expect(addr).To(Equal("some-host")) 30 Expect(config.User).To(Equal("some-ssh-user")) 31 32 Expect(config.Auth).To(HaveLen(1)) 33 34 actualSecret := reflect.ValueOf(config.Auth[0]).Call([]reflect.Value{})[0].Interface() 35 Expect(actualSecret).To(Equal("some-user:some-password")) 36 37 dialCalled = true 38 39 return sshClient, nil 40 } 41 42 client, err := sshapi.New("some-ssh-user", "some-user", "some-password", "some-host") 43 Expect(err).NotTo(HaveOccurred()) 44 Expect(client.Dialer == sshClient).To(BeTrue()) 45 Expect(client.SSHSessionFactory.(*sshapi.CryptoSSHSessionFactory).Client == sshClient).To(BeTrue()) 46 Expect(client.Stdin).To(Equal(os.Stdin)) 47 Expect(client.Stdout).To(Equal(os.Stdout)) 48 Expect(client.Stderr).To(Equal(os.Stderr)) 49 50 Expect(dialCalled).To(BeTrue()) 51 }) 52 53 Context("when dialing fails", func() { 54 It("should return an error", func() { 55 origDial := sshapi.DialFunc 56 defer func() { sshapi.DialFunc = origDial }() 57 58 sshapi.DialFunc = func(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { 59 return &ssh.Client{}, errors.New("some error") 60 } 61 62 _, err := sshapi.New("some-ssh-user", "some-user", "some-password", "some-host") 63 Expect(err).To(MatchError("some error")) 64 }) 65 }) 66 }) 67 68 type mockConn struct { 69 io.Reader 70 io.Writer 71 nilNetConn 72 closed bool 73 } 74 75 type nilNetConn struct { 76 net.Conn 77 } 78 79 func (m *mockConn) Close() error { 80 m.closed = true 81 return nil 82 } 83 84 var _ = Describe("Client", func() { 85 Describe("#Forward", func() { 86 var ( 87 client *sshapi.Client 88 fakeDialer *mocks.FakeDialer 89 ) 90 91 BeforeEach(func() { 92 fakeDialer = &mocks.FakeDialer{} 93 client = &sshapi.Client{Dialer: fakeDialer} 94 }) 95 96 It("should dial a remote connection", func() { 97 localConn := &mockConn{Reader: &bytes.Buffer{}, Writer: &bytes.Buffer{}} 98 remoteConn := &mockConn{Reader: &bytes.Buffer{}, Writer: &bytes.Buffer{}} 99 fakeDialer.DialReturns(remoteConn, nil) 100 101 Expect(client.Forward(localConn, "some remote address")).To(Succeed()) 102 103 Expect(fakeDialer.DialCallCount()).To(Equal(1)) 104 protocol, address := fakeDialer.DialArgsForCall(0) 105 Expect(protocol).To(Equal("tcp")) 106 Expect(address).To(Equal("some remote address")) 107 }) 108 109 It("should copy data in both directions", func() { 110 localConnBuffer := &bytes.Buffer{} 111 remoteConnBuffer := &bytes.Buffer{} 112 localConn := &mockConn{Reader: bytes.NewBufferString("some local data"), Writer: localConnBuffer} 113 remoteConn := &mockConn{Reader: bytes.NewBufferString("some remote data"), Writer: remoteConnBuffer} 114 fakeDialer.DialReturns(remoteConn, nil) 115 116 Expect(client.Forward(localConn, "some remote address")).To(Succeed()) 117 118 Expect(localConn.closed).To(BeTrue()) 119 Expect(remoteConn.closed).To(BeTrue()) 120 Expect(localConnBuffer.String()).To(Equal("some remote data")) 121 Expect(remoteConnBuffer.String()).To(Equal("some local data")) 122 }) 123 124 Context("when dialing a remote connection fails", func() { 125 It("should return an error", func() { 126 fakeDialer.DialReturns(nil, errors.New("some error")) 127 err := client.Forward(nil, "some remote address") 128 Expect(err).To(MatchError("some error")) 129 }) 130 }) 131 }) 132 133 Describe("#Open", func() { 134 var ( 135 client *sshapi.Client 136 mockSession *mocks.FakeSSHSession 137 mockSessionFactory *mocks.FakeSSHSessionFactory 138 originalTerm string 139 ) 140 141 BeforeEach(func() { 142 originalTerm = os.Getenv("TERM") 143 mockSessionFactory = &mocks.FakeSSHSessionFactory{} 144 client = &sshapi.Client{ 145 SSHSessionFactory: mockSessionFactory, 146 } 147 mockSession = &mocks.FakeSSHSession{} 148 mockSessionFactory.NewReturns(mockSession, nil) 149 }) 150 151 AfterEach(func() { 152 os.Setenv("TERM", originalTerm) 153 }) 154 155 It("should open a new session", func() { 156 os.Setenv("TERM", "some term") 157 158 client.Stdin = bytes.NewBufferString("some client in data") 159 client.Stdout = &bytes.Buffer{} 160 client.Stderr = &bytes.Buffer{} 161 mockSessionStdinBuffer := &bytes.Buffer{} 162 mockSessionStdin := &mockConn{Writer: mockSessionStdinBuffer} 163 mockSessionStdout := bytes.NewBufferString("some session out data") 164 mockSessionStderr := bytes.NewBufferString("some session err data") 165 mockSession.StdinPipeReturns(mockSessionStdin, nil) 166 mockSession.StdoutPipeReturns(mockSessionStdout, nil) 167 mockSession.StderrPipeReturns(mockSessionStderr, nil) 168 169 _, err := client.Open(100, 200, true) 170 Expect(err).NotTo(HaveOccurred()) 171 172 Expect(mockSession.RequestPtyCallCount()).To(Equal(1)) 173 termType, height, width, modes := mockSession.RequestPtyArgsForCall(0) 174 Expect(termType).To(Equal("some term")) 175 Expect(height).To(Equal(200)) 176 Expect(width).To(Equal(100)) 177 Expect(modes[ssh.ECHO]).To(Equal(uint32(1))) 178 Expect(modes[ssh.TTY_OP_ISPEED]).To(Equal(uint32(115200))) 179 Expect(modes[ssh.TTY_OP_OSPEED]).To(Equal(uint32(115200))) 180 181 Eventually(mockSessionStdinBuffer.String).Should(Equal("some client in data")) 182 Eventually(client.Stdout.(*bytes.Buffer).String).Should(Equal("some session out data")) 183 Eventually(client.Stderr.(*bytes.Buffer).String).Should(Equal("some session err data")) 184 Eventually(func() bool { return mockSessionStdin.closed }).Should(BeTrue()) 185 }) 186 187 It("should not request a pty when desirePTY is false", func() { 188 client.Stdin = bytes.NewBufferString("some client in data") 189 client.Stdout = &bytes.Buffer{} 190 client.Stderr = &bytes.Buffer{} 191 mockSessionStdinBuffer := &bytes.Buffer{} 192 mockSessionStdin := &mockConn{Writer: mockSessionStdinBuffer} 193 mockSessionStdout := bytes.NewBufferString("some session out data") 194 mockSessionStderr := bytes.NewBufferString("some session err data") 195 mockSession.StdinPipeReturns(mockSessionStdin, nil) 196 mockSession.StdoutPipeReturns(mockSessionStdout, nil) 197 mockSession.StderrPipeReturns(mockSessionStderr, nil) 198 199 _, err := client.Open(100, 200, false) 200 Expect(err).NotTo(HaveOccurred()) 201 202 Expect(mockSession.RequestPtyCallCount()).To(Equal(0)) 203 }) 204 205 It("should request a pty when desirePTY is true", func() { 206 client.Stdin = bytes.NewBufferString("some client in data") 207 client.Stdout = &bytes.Buffer{} 208 client.Stderr = &bytes.Buffer{} 209 mockSessionStdinBuffer := &bytes.Buffer{} 210 mockSessionStdin := &mockConn{Writer: mockSessionStdinBuffer} 211 mockSessionStdout := bytes.NewBufferString("some session out data") 212 mockSessionStderr := bytes.NewBufferString("some session err data") 213 mockSession.StdinPipeReturns(mockSessionStdin, nil) 214 mockSession.StdoutPipeReturns(mockSessionStdout, nil) 215 mockSession.StderrPipeReturns(mockSessionStderr, nil) 216 217 _, err := client.Open(100, 200, true) 218 Expect(err).NotTo(HaveOccurred()) 219 220 Expect(mockSession.RequestPtyCallCount()).To(Equal(1)) 221 }) 222 223 Context("when we fail to open a new session", func() { 224 It("should return an error", func() { 225 mockSessionFactory.NewReturns(nil, errors.New("some error")) 226 _, err := client.Open(100, 200, true) 227 Expect(err).To(MatchError("some error")) 228 }) 229 }) 230 231 Context("when we fail to open any of the session pipes", func() { 232 It("should return an error", func() { 233 mockSession.StderrPipeReturns(nil, errors.New("some stderr error")) 234 _, err := client.Open(100, 200, true) 235 Expect(err).To(MatchError("some stderr error")) 236 237 mockSession.StdoutPipeReturns(nil, errors.New("some stdout error")) 238 _, err = client.Open(100, 200, true) 239 Expect(err).To(MatchError("some stdout error")) 240 241 mockSession.StdinPipeReturns(nil, errors.New("some stdin error")) 242 _, err = client.Open(100, 200, true) 243 Expect(err).To(MatchError("some stdin error")) 244 }) 245 }) 246 247 Context("when requesting a PTY fails", func() { 248 It("should return an error", func() { 249 mockSession.RequestPtyReturns(errors.New("some error")) 250 _, err := client.Open(100, 200, true) 251 Expect(err).To(MatchError("some error")) 252 253 }) 254 }) 255 }) 256 })