github.com/cloudfoundry-attic/ltc@v0.0.0-20151123212628-098adc7919fc/ssh/sshapi/client_test.go (about)

     1  package sshapi_test
     2  
import (
	"bytes"
	"errors"
	"io"
	"net"
	"os"
	"reflect"
	"sync"

	. "github.com/onsi/ginkgo"
	. "github.com/onsi/gomega"

	"github.com/cloudfoundry-incubator/ltc/ssh/sshapi"
	"github.com/cloudfoundry-incubator/ltc/ssh/sshapi/mocks"
	"golang.org/x/crypto/ssh"
)
    18  
    19  var _ = Describe(".New", func() {
    20  	It("should create a client using the package DialFunc", func() {
    21  		origDial := sshapi.DialFunc
    22  		defer func() { sshapi.DialFunc = origDial }()
    23  
    24  		dialCalled := false
    25  		sshClient := &ssh.Client{}
    26  
    27  		sshapi.DialFunc = func(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
    28  			Expect(network).To(Equal("tcp"))
    29  			Expect(addr).To(Equal("some-host"))
    30  			Expect(config.User).To(Equal("some-ssh-user"))
    31  
    32  			Expect(config.Auth).To(HaveLen(1))
    33  
    34  			actualSecret := reflect.ValueOf(config.Auth[0]).Call([]reflect.Value{})[0].Interface()
    35  			Expect(actualSecret).To(Equal("some-user:some-password"))
    36  
    37  			dialCalled = true
    38  
    39  			return sshClient, nil
    40  		}
    41  
    42  		client, err := sshapi.New("some-ssh-user", "some-user", "some-password", "some-host")
    43  		Expect(err).NotTo(HaveOccurred())
    44  		Expect(client.Dialer == sshClient).To(BeTrue())
    45  		Expect(client.SSHSessionFactory.(*sshapi.CryptoSSHSessionFactory).Client == sshClient).To(BeTrue())
    46  		Expect(client.Stdin).To(Equal(os.Stdin))
    47  		Expect(client.Stdout).To(Equal(os.Stdout))
    48  		Expect(client.Stderr).To(Equal(os.Stderr))
    49  
    50  		Expect(dialCalled).To(BeTrue())
    51  	})
    52  
    53  	Context("when dialing fails", func() {
    54  		It("should return an error", func() {
    55  			origDial := sshapi.DialFunc
    56  			defer func() { sshapi.DialFunc = origDial }()
    57  
    58  			sshapi.DialFunc = func(network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) {
    59  				return &ssh.Client{}, errors.New("some error")
    60  			}
    61  
    62  			_, err := sshapi.New("some-ssh-user", "some-user", "some-password", "some-host")
    63  			Expect(err).To(MatchError("some error"))
    64  		})
    65  	})
    66  })
    67  
    68  type mockConn struct {
    69  	io.Reader
    70  	io.Writer
    71  	nilNetConn
    72  	closed bool
    73  }
    74  
    75  type nilNetConn struct {
    76  	net.Conn
    77  }
    78  
    79  func (m *mockConn) Close() error {
    80  	m.closed = true
    81  	return nil
    82  }
    83  
    84  var _ = Describe("Client", func() {
    85  	Describe("#Forward", func() {
    86  		var (
    87  			client     *sshapi.Client
    88  			fakeDialer *mocks.FakeDialer
    89  		)
    90  
    91  		BeforeEach(func() {
    92  			fakeDialer = &mocks.FakeDialer{}
    93  			client = &sshapi.Client{Dialer: fakeDialer}
    94  		})
    95  
    96  		It("should dial a remote connection", func() {
    97  			localConn := &mockConn{Reader: &bytes.Buffer{}, Writer: &bytes.Buffer{}}
    98  			remoteConn := &mockConn{Reader: &bytes.Buffer{}, Writer: &bytes.Buffer{}}
    99  			fakeDialer.DialReturns(remoteConn, nil)
   100  
   101  			Expect(client.Forward(localConn, "some remote address")).To(Succeed())
   102  
   103  			Expect(fakeDialer.DialCallCount()).To(Equal(1))
   104  			protocol, address := fakeDialer.DialArgsForCall(0)
   105  			Expect(protocol).To(Equal("tcp"))
   106  			Expect(address).To(Equal("some remote address"))
   107  		})
   108  
   109  		It("should copy data in both directions", func() {
   110  			localConnBuffer := &bytes.Buffer{}
   111  			remoteConnBuffer := &bytes.Buffer{}
   112  			localConn := &mockConn{Reader: bytes.NewBufferString("some local data"), Writer: localConnBuffer}
   113  			remoteConn := &mockConn{Reader: bytes.NewBufferString("some remote data"), Writer: remoteConnBuffer}
   114  			fakeDialer.DialReturns(remoteConn, nil)
   115  
   116  			Expect(client.Forward(localConn, "some remote address")).To(Succeed())
   117  
   118  			Expect(localConn.closed).To(BeTrue())
   119  			Expect(remoteConn.closed).To(BeTrue())
   120  			Expect(localConnBuffer.String()).To(Equal("some remote data"))
   121  			Expect(remoteConnBuffer.String()).To(Equal("some local data"))
   122  		})
   123  
   124  		Context("when dialing a remote connection fails", func() {
   125  			It("should return an error", func() {
   126  				fakeDialer.DialReturns(nil, errors.New("some error"))
   127  				err := client.Forward(nil, "some remote address")
   128  				Expect(err).To(MatchError("some error"))
   129  			})
   130  		})
   131  	})
   132  
   133  	Describe("#Open", func() {
   134  		var (
   135  			client             *sshapi.Client
   136  			mockSession        *mocks.FakeSSHSession
   137  			mockSessionFactory *mocks.FakeSSHSessionFactory
   138  			originalTerm       string
   139  		)
   140  
   141  		BeforeEach(func() {
   142  			originalTerm = os.Getenv("TERM")
   143  			mockSessionFactory = &mocks.FakeSSHSessionFactory{}
   144  			client = &sshapi.Client{
   145  				SSHSessionFactory: mockSessionFactory,
   146  			}
   147  			mockSession = &mocks.FakeSSHSession{}
   148  			mockSessionFactory.NewReturns(mockSession, nil)
   149  		})
   150  
   151  		AfterEach(func() {
   152  			os.Setenv("TERM", originalTerm)
   153  		})
   154  
   155  		It("should open a new session", func() {
   156  			os.Setenv("TERM", "some term")
   157  
   158  			client.Stdin = bytes.NewBufferString("some client in data")
   159  			client.Stdout = &bytes.Buffer{}
   160  			client.Stderr = &bytes.Buffer{}
   161  			mockSessionStdinBuffer := &bytes.Buffer{}
   162  			mockSessionStdin := &mockConn{Writer: mockSessionStdinBuffer}
   163  			mockSessionStdout := bytes.NewBufferString("some session out data")
   164  			mockSessionStderr := bytes.NewBufferString("some session err data")
   165  			mockSession.StdinPipeReturns(mockSessionStdin, nil)
   166  			mockSession.StdoutPipeReturns(mockSessionStdout, nil)
   167  			mockSession.StderrPipeReturns(mockSessionStderr, nil)
   168  
   169  			_, err := client.Open(100, 200, true)
   170  			Expect(err).NotTo(HaveOccurred())
   171  
   172  			Expect(mockSession.RequestPtyCallCount()).To(Equal(1))
   173  			termType, height, width, modes := mockSession.RequestPtyArgsForCall(0)
   174  			Expect(termType).To(Equal("some term"))
   175  			Expect(height).To(Equal(200))
   176  			Expect(width).To(Equal(100))
   177  			Expect(modes[ssh.ECHO]).To(Equal(uint32(1)))
   178  			Expect(modes[ssh.TTY_OP_ISPEED]).To(Equal(uint32(115200)))
   179  			Expect(modes[ssh.TTY_OP_OSPEED]).To(Equal(uint32(115200)))
   180  
   181  			Eventually(mockSessionStdinBuffer.String).Should(Equal("some client in data"))
   182  			Eventually(client.Stdout.(*bytes.Buffer).String).Should(Equal("some session out data"))
   183  			Eventually(client.Stderr.(*bytes.Buffer).String).Should(Equal("some session err data"))
   184  			Eventually(func() bool { return mockSessionStdin.closed }).Should(BeTrue())
   185  		})
   186  
   187  		It("should not request a pty when desirePTY is false", func() {
   188  			client.Stdin = bytes.NewBufferString("some client in data")
   189  			client.Stdout = &bytes.Buffer{}
   190  			client.Stderr = &bytes.Buffer{}
   191  			mockSessionStdinBuffer := &bytes.Buffer{}
   192  			mockSessionStdin := &mockConn{Writer: mockSessionStdinBuffer}
   193  			mockSessionStdout := bytes.NewBufferString("some session out data")
   194  			mockSessionStderr := bytes.NewBufferString("some session err data")
   195  			mockSession.StdinPipeReturns(mockSessionStdin, nil)
   196  			mockSession.StdoutPipeReturns(mockSessionStdout, nil)
   197  			mockSession.StderrPipeReturns(mockSessionStderr, nil)
   198  
   199  			_, err := client.Open(100, 200, false)
   200  			Expect(err).NotTo(HaveOccurred())
   201  
   202  			Expect(mockSession.RequestPtyCallCount()).To(Equal(0))
   203  		})
   204  
   205  		It("should request a pty when desirePTY is true", func() {
   206  			client.Stdin = bytes.NewBufferString("some client in data")
   207  			client.Stdout = &bytes.Buffer{}
   208  			client.Stderr = &bytes.Buffer{}
   209  			mockSessionStdinBuffer := &bytes.Buffer{}
   210  			mockSessionStdin := &mockConn{Writer: mockSessionStdinBuffer}
   211  			mockSessionStdout := bytes.NewBufferString("some session out data")
   212  			mockSessionStderr := bytes.NewBufferString("some session err data")
   213  			mockSession.StdinPipeReturns(mockSessionStdin, nil)
   214  			mockSession.StdoutPipeReturns(mockSessionStdout, nil)
   215  			mockSession.StderrPipeReturns(mockSessionStderr, nil)
   216  
   217  			_, err := client.Open(100, 200, true)
   218  			Expect(err).NotTo(HaveOccurred())
   219  
   220  			Expect(mockSession.RequestPtyCallCount()).To(Equal(1))
   221  		})
   222  
   223  		Context("when we fail to open a new session", func() {
   224  			It("should return an error", func() {
   225  				mockSessionFactory.NewReturns(nil, errors.New("some error"))
   226  				_, err := client.Open(100, 200, true)
   227  				Expect(err).To(MatchError("some error"))
   228  			})
   229  		})
   230  
   231  		Context("when we fail to open any of the session pipes", func() {
   232  			It("should return an error", func() {
   233  				mockSession.StderrPipeReturns(nil, errors.New("some stderr error"))
   234  				_, err := client.Open(100, 200, true)
   235  				Expect(err).To(MatchError("some stderr error"))
   236  
   237  				mockSession.StdoutPipeReturns(nil, errors.New("some stdout error"))
   238  				_, err = client.Open(100, 200, true)
   239  				Expect(err).To(MatchError("some stdout error"))
   240  
   241  				mockSession.StdinPipeReturns(nil, errors.New("some stdin error"))
   242  				_, err = client.Open(100, 200, true)
   243  				Expect(err).To(MatchError("some stdin error"))
   244  			})
   245  		})
   246  
   247  		Context("when requesting a PTY fails", func() {
   248  			It("should return an error", func() {
   249  				mockSession.RequestPtyReturns(errors.New("some error"))
   250  				_, err := client.Open(100, 200, true)
   251  				Expect(err).To(MatchError("some error"))
   252  
   253  			})
   254  		})
   255  	})
   256  })