github.com/devops-filetransfer/sshego@v7.0.4+incompatible/_vendor/golang.org/x/crypto/ssh/handshake_test.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/rand"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"net"
    14  	"reflect"
    15  	"runtime"
    16  	"strings"
    17  	"sync"
    18  	"testing"
    19  )
    20  
    21  type testChecker struct {
    22  	calls []string
    23  }
    24  
    25  func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
    26  	if dialAddr == "bad" {
    27  		return fmt.Errorf("dialAddr is bad")
    28  	}
    29  
    30  	if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil {
    31  		return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr)
    32  	}
    33  
    34  	t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal()))
    35  
    36  	return nil
    37  }
    38  
    39  // netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
    40  // therefore is buffered (net.Pipe deadlocks if both sides start with
    41  // a write.)
    42  func netPipe() (net.Conn, net.Conn, error) {
    43  	listener, err := net.Listen("tcp", ":0")
    44  	if err != nil {
    45  		return nil, nil, err
    46  	}
    47  	defer listener.Close()
    48  	c1, err := net.Dial("tcp", listener.Addr().String())
    49  	if err != nil {
    50  		return nil, nil, err
    51  	}
    52  
    53  	c2, err := listener.Accept()
    54  	if err != nil {
    55  		c1.Close()
    56  		return nil, nil, err
    57  	}
    58  
    59  	return c1, c2, nil
    60  }
    61  
    62  // noiseTransport inserts ignore messages to check that the read loop
    63  // and the key exchange filters out these messages.
    64  type noiseTransport struct {
    65  	keyingTransport
    66  }
    67  
    68  func (t *noiseTransport) writePacket(p []byte) error {
    69  	ignore := []byte{msgIgnore}
    70  	if err := t.keyingTransport.writePacket(ignore); err != nil {
    71  		return err
    72  	}
    73  	debug := []byte{msgDebug, 1, 2, 3}
    74  	if err := t.keyingTransport.writePacket(debug); err != nil {
    75  		return err
    76  	}
    77  
    78  	return t.keyingTransport.writePacket(p)
    79  }
    80  
    81  func addNoiseTransport(t keyingTransport) keyingTransport {
    82  	return &noiseTransport{t}
    83  }
    84  
    85  // handshakePair creates two handshakeTransports connected with each
    86  // other. If the noise argument is true, both transports will try to
    87  // confuse the other side by sending ignore and debug messages.
    88  func handshakePair(clientConf *ClientConfig, addr string, noise bool) (client *handshakeTransport, server *handshakeTransport, err error) {
    89  	a, b, err := netPipe()
    90  	if err != nil {
    91  		return nil, nil, err
    92  	}
    93  
    94  	var trC, trS keyingTransport
    95  
    96  	trC = newTransport(a, rand.Reader, true)
    97  	trS = newTransport(b, rand.Reader, false)
    98  	if noise {
    99  		trC = addNoiseTransport(trC)
   100  		trS = addNoiseTransport(trS)
   101  	}
   102  	clientConf.SetDefaults()
   103  
   104  	v := []byte("version")
   105  	client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr())
   106  
   107  	serverConf := &ServerConfig{}
   108  	serverConf.AddHostKey(testSigners["ecdsa"])
   109  	serverConf.AddHostKey(testSigners["rsa"])
   110  	serverConf.SetDefaults()
   111  	server = newServerTransport(trS, v, v, serverConf)
   112  
   113  	if err := server.waitSession(); err != nil {
   114  		return nil, nil, fmt.Errorf("server.waitSession: %v", err)
   115  	}
   116  	if err := client.waitSession(); err != nil {
   117  		return nil, nil, fmt.Errorf("client.waitSession: %v", err)
   118  	}
   119  
   120  	return client, server, nil
   121  }
   122  
   123  func TestHandshakeBasic(t *testing.T) {
   124  	if runtime.GOOS == "plan9" {
   125  		t.Skip("see golang.org/issue/7237")
   126  	}
   127  
   128  	checker := &syncChecker{
   129  		waitCall: make(chan int, 10),
   130  		called:   make(chan int, 10),
   131  	}
   132  
   133  	checker.waitCall <- 1
   134  	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
   135  	if err != nil {
   136  		t.Fatalf("handshakePair: %v", err)
   137  	}
   138  
   139  	defer trC.Close()
   140  	defer trS.Close()
   141  
   142  	// Let first kex complete normally.
   143  	<-checker.called
   144  
   145  	clientDone := make(chan int, 0)
   146  	gotHalf := make(chan int, 0)
   147  	const N = 20
   148  
   149  	go func() {
   150  		defer close(clientDone)
   151  		// Client writes a bunch of stuff, and does a key
   152  		// change in the middle. This should not confuse the
   153  		// handshake in progress. We do this twice, so we test
   154  		// that the packet buffer is reset correctly.
   155  		for i := 0; i < N; i++ {
   156  			p := []byte{msgRequestSuccess, byte(i)}
   157  			if err := trC.writePacket(p); err != nil {
   158  				t.Fatalf("sendPacket: %v", err)
   159  			}
   160  			if (i % 10) == 5 {
   161  				<-gotHalf
   162  				// halfway through, we request a key change.
   163  				trC.requestKeyExchange()
   164  
   165  				// Wait until we can be sure the key
   166  				// change has really started before we
   167  				// write more.
   168  				<-checker.called
   169  			}
   170  			if (i % 10) == 7 {
   171  				// write some packets until the kex
   172  				// completes, to test buffering of
   173  				// packets.
   174  				checker.waitCall <- 1
   175  			}
   176  		}
   177  	}()
   178  
   179  	// Server checks that client messages come in cleanly
   180  	i := 0
   181  	err = nil
   182  	for ; i < N; i++ {
   183  		var p []byte
   184  		p, err = trS.readPacket()
   185  		if err != nil {
   186  			break
   187  		}
   188  		if (i % 10) == 5 {
   189  			gotHalf <- 1
   190  		}
   191  
   192  		want := []byte{msgRequestSuccess, byte(i)}
   193  		if bytes.Compare(p, want) != 0 {
   194  			t.Errorf("message %d: got %v, want %v", i, p, want)
   195  		}
   196  	}
   197  	<-clientDone
   198  	if err != nil && err != io.EOF {
   199  		t.Fatalf("server error: %v", err)
   200  	}
   201  	if i != N {
   202  		t.Errorf("received %d messages, want 10.", i)
   203  	}
   204  
   205  	close(checker.called)
   206  	if _, ok := <-checker.called; ok {
   207  		// If all went well, we registered exactly 2 key changes: one
   208  		// that establishes the session, and one that we requested
   209  		// additionally.
   210  		t.Fatalf("got another host key checks after 2 handshakes")
   211  	}
   212  }
   213  
   214  func TestForceFirstKex(t *testing.T) {
   215  	// like handshakePair, but must access the keyingTransport.
   216  	checker := &testChecker{}
   217  	clientConf := &ClientConfig{HostKeyCallback: checker.Check}
   218  	a, b, err := netPipe()
   219  	if err != nil {
   220  		t.Fatalf("netPipe: %v", err)
   221  	}
   222  
   223  	var trC, trS keyingTransport
   224  
   225  	trC = newTransport(a, rand.Reader, true)
   226  
   227  	// This is the disallowed packet:
   228  	trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth}))
   229  
   230  	// Rest of the setup.
   231  	trS = newTransport(b, rand.Reader, false)
   232  	clientConf.SetDefaults()
   233  
   234  	v := []byte("version")
   235  	client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr())
   236  
   237  	serverConf := &ServerConfig{}
   238  	serverConf.AddHostKey(testSigners["ecdsa"])
   239  	serverConf.AddHostKey(testSigners["rsa"])
   240  	serverConf.SetDefaults()
   241  	server := newServerTransport(trS, v, v, serverConf)
   242  
   243  	defer client.Close()
   244  	defer server.Close()
   245  
   246  	// We setup the initial key exchange, but the remote side
   247  	// tries to send serviceRequestMsg in cleartext, which is
   248  	// disallowed.
   249  
   250  	if err := server.waitSession(); err == nil {
   251  		t.Errorf("server first kex init should reject unexpected packet")
   252  	}
   253  }
   254  
   255  func TestHandshakeAutoRekeyWrite(t *testing.T) {
   256  	checker := &syncChecker{
   257  		called:   make(chan int, 10),
   258  		waitCall: nil,
   259  	}
   260  	clientConf := &ClientConfig{HostKeyCallback: checker.Check}
   261  	clientConf.RekeyThreshold = 500
   262  	trC, trS, err := handshakePair(clientConf, "addr", false)
   263  	if err != nil {
   264  		t.Fatalf("handshakePair: %v", err)
   265  	}
   266  	defer trC.Close()
   267  	defer trS.Close()
   268  
   269  	input := make([]byte, 251)
   270  	input[0] = msgRequestSuccess
   271  
   272  	done := make(chan int, 1)
   273  	const numPacket = 5
   274  	go func() {
   275  		defer close(done)
   276  		j := 0
   277  		for ; j < numPacket; j++ {
   278  			if p, err := trS.readPacket(); err != nil {
   279  				break
   280  			} else if !bytes.Equal(input, p) {
   281  				t.Errorf("got packet type %d, want %d", p[0], input[0])
   282  			}
   283  		}
   284  
   285  		if j != numPacket {
   286  			t.Errorf("got %d, want 5 messages", j)
   287  		}
   288  	}()
   289  
   290  	<-checker.called
   291  
   292  	for i := 0; i < numPacket; i++ {
   293  		p := make([]byte, len(input))
   294  		copy(p, input)
   295  		if err := trC.writePacket(p); err != nil {
   296  			t.Errorf("writePacket: %v", err)
   297  		}
   298  		if i == 2 {
   299  			// Make sure the kex is in progress.
   300  			<-checker.called
   301  		}
   302  
   303  	}
   304  	<-done
   305  }
   306  
   307  type syncChecker struct {
   308  	waitCall chan int
   309  	called   chan int
   310  }
   311  
   312  func (c *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
   313  	c.called <- 1
   314  	if c.waitCall != nil {
   315  		<-c.waitCall
   316  	}
   317  	return nil
   318  }
   319  
   320  func TestHandshakeAutoRekeyRead(t *testing.T) {
   321  	sync := &syncChecker{
   322  		called:   make(chan int, 2),
   323  		waitCall: nil,
   324  	}
   325  	clientConf := &ClientConfig{
   326  		HostKeyCallback: sync.Check,
   327  	}
   328  	clientConf.RekeyThreshold = 500
   329  
   330  	trC, trS, err := handshakePair(clientConf, "addr", false)
   331  	if err != nil {
   332  		t.Fatalf("handshakePair: %v", err)
   333  	}
   334  	defer trC.Close()
   335  	defer trS.Close()
   336  
   337  	packet := make([]byte, 501)
   338  	packet[0] = msgRequestSuccess
   339  	if err := trS.writePacket(packet); err != nil {
   340  		t.Fatalf("writePacket: %v", err)
   341  	}
   342  
   343  	// While we read out the packet, a key change will be
   344  	// initiated.
   345  	done := make(chan int, 1)
   346  	go func() {
   347  		defer close(done)
   348  		if _, err := trC.readPacket(); err != nil {
   349  			t.Fatalf("readPacket(client): %v", err)
   350  		}
   351  
   352  	}()
   353  
   354  	<-done
   355  	<-sync.called
   356  }
   357  
   358  // errorKeyingTransport generates errors after a given number of
   359  // read/write operations.
   360  type errorKeyingTransport struct {
   361  	packetConn
   362  	readLeft, writeLeft int
   363  }
   364  
   365  func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error {
   366  	return nil
   367  }
   368  
   369  func (n *errorKeyingTransport) getSessionID() []byte {
   370  	return nil
   371  }
   372  
   373  func (n *errorKeyingTransport) writePacket(packet []byte) error {
   374  	if n.writeLeft == 0 {
   375  		n.Close()
   376  		return errors.New("barf")
   377  	}
   378  
   379  	n.writeLeft--
   380  	return n.packetConn.writePacket(packet)
   381  }
   382  
   383  func (n *errorKeyingTransport) readPacket() ([]byte, error) {
   384  	if n.readLeft == 0 {
   385  		n.Close()
   386  		return nil, errors.New("barf")
   387  	}
   388  
   389  	n.readLeft--
   390  	return n.packetConn.readPacket()
   391  }
   392  
   393  func TestHandshakeErrorHandlingRead(t *testing.T) {
   394  	for i := 0; i < 20; i++ {
   395  		testHandshakeErrorHandlingN(t, i, -1, false)
   396  	}
   397  }
   398  
   399  func TestHandshakeErrorHandlingWrite(t *testing.T) {
   400  	for i := 0; i < 20; i++ {
   401  		testHandshakeErrorHandlingN(t, -1, i, false)
   402  	}
   403  }
   404  
   405  func TestHandshakeErrorHandlingReadCoupled(t *testing.T) {
   406  	for i := 0; i < 20; i++ {
   407  		testHandshakeErrorHandlingN(t, i, -1, true)
   408  	}
   409  }
   410  
   411  func TestHandshakeErrorHandlingWriteCoupled(t *testing.T) {
   412  	for i := 0; i < 20; i++ {
   413  		testHandshakeErrorHandlingN(t, -1, i, true)
   414  	}
   415  }
   416  
   417  // testHandshakeErrorHandlingN runs handshakes, injecting errors. If
   418  // handshakeTransport deadlocks, the go runtime will detect it and
   419  // panic.
   420  func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int, coupled bool) {
   421  	msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)})
   422  
   423  	a, b := memPipe()
   424  	defer a.Close()
   425  	defer b.Close()
   426  
   427  	key := testSigners["ecdsa"]
   428  	serverConf := Config{RekeyThreshold: minRekeyThreshold}
   429  	serverConf.SetDefaults()
   430  	serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'})
   431  	serverConn.hostKeys = []Signer{key}
   432  	go serverConn.readLoop()
   433  	go serverConn.kexLoop()
   434  
   435  	clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold}
   436  	clientConf.SetDefaults()
   437  	clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'})
   438  	clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()}
   439  	clientConn.hostKeyCallback = InsecureIgnoreHostKey()
   440  	go clientConn.readLoop()
   441  	go clientConn.kexLoop()
   442  
   443  	var wg sync.WaitGroup
   444  
   445  	for _, hs := range []packetConn{serverConn, clientConn} {
   446  		if !coupled {
   447  			wg.Add(2)
   448  			go func(c packetConn) {
   449  				for i := 0; ; i++ {
   450  					str := fmt.Sprintf("%08x", i) + strings.Repeat("x", int(minRekeyThreshold)/4-8)
   451  					err := c.writePacket(Marshal(&serviceRequestMsg{str}))
   452  					if err != nil {
   453  						break
   454  					}
   455  				}
   456  				wg.Done()
   457  				c.Close()
   458  			}(hs)
   459  			go func(c packetConn) {
   460  				for {
   461  					_, err := c.readPacket()
   462  					if err != nil {
   463  						break
   464  					}
   465  				}
   466  				wg.Done()
   467  			}(hs)
   468  		} else {
   469  			wg.Add(1)
   470  			go func(c packetConn) {
   471  				for {
   472  					_, err := c.readPacket()
   473  					if err != nil {
   474  						break
   475  					}
   476  					if err := c.writePacket(msg); err != nil {
   477  						break
   478  					}
   479  
   480  				}
   481  				wg.Done()
   482  			}(hs)
   483  		}
   484  	}
   485  	wg.Wait()
   486  }
   487  
   488  func TestDisconnect(t *testing.T) {
   489  	if runtime.GOOS == "plan9" {
   490  		t.Skip("see golang.org/issue/7237")
   491  	}
   492  	checker := &testChecker{}
   493  	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
   494  	if err != nil {
   495  		t.Fatalf("handshakePair: %v", err)
   496  	}
   497  
   498  	defer trC.Close()
   499  	defer trS.Close()
   500  
   501  	trC.writePacket([]byte{msgRequestSuccess, 0, 0})
   502  	errMsg := &disconnectMsg{
   503  		Reason:  42,
   504  		Message: "such is life",
   505  	}
   506  	trC.writePacket(Marshal(errMsg))
   507  	trC.writePacket([]byte{msgRequestSuccess, 0, 0})
   508  
   509  	packet, err := trS.readPacket()
   510  	if err != nil {
   511  		t.Fatalf("readPacket 1: %v", err)
   512  	}
   513  	if packet[0] != msgRequestSuccess {
   514  		t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess)
   515  	}
   516  
   517  	_, err = trS.readPacket()
   518  	if err == nil {
   519  		t.Errorf("readPacket 2 succeeded")
   520  	} else if !reflect.DeepEqual(err, errMsg) {
   521  		t.Errorf("got error %#v, want %#v", err, errMsg)
   522  	}
   523  
   524  	_, err = trS.readPacket()
   525  	if err == nil {
   526  		t.Errorf("readPacket 3 succeeded")
   527  	}
   528  }
   529  
   530  func TestHandshakeRekeyDefault(t *testing.T) {
   531  	clientConf := &ClientConfig{
   532  		Config: Config{
   533  			Ciphers: []string{"aes128-ctr"},
   534  		},
   535  		HostKeyCallback: InsecureIgnoreHostKey(),
   536  	}
   537  	trC, trS, err := handshakePair(clientConf, "addr", false)
   538  	if err != nil {
   539  		t.Fatalf("handshakePair: %v", err)
   540  	}
   541  	defer trC.Close()
   542  	defer trS.Close()
   543  
   544  	trC.writePacket([]byte{msgRequestSuccess, 0, 0})
   545  	trC.Close()
   546  
   547  	rgb := (1024 + trC.readBytesLeft) >> 30
   548  	wgb := (1024 + trC.writeBytesLeft) >> 30
   549  
   550  	if rgb != 64 {
   551  		t.Errorf("got rekey after %dG read, want 64G", rgb)
   552  	}
   553  	if wgb != 64 {
   554  		t.Errorf("got rekey after %dG write, want 64G", wgb)
   555  	}
   556  }