github.com/Psiphon-Labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/crypto/ssh/handshake_test.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/rand"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"net"
    14  	"reflect"
    15  	"runtime"
    16  	"strings"
    17  	"sync"
    18  	"testing"
    19  )
    20  
    21  type testChecker struct {
    22  	calls []string
    23  }
    24  
    25  func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
    26  	if dialAddr == "bad" {
    27  		return fmt.Errorf("dialAddr is bad")
    28  	}
    29  
    30  	if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil {
    31  		return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr)
    32  	}
    33  
    34  	t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal()))
    35  
    36  	return nil
    37  }
    38  
    39  // netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
    40  // therefore is buffered (net.Pipe deadlocks if both sides start with
    41  // a write.)
    42  func netPipe() (net.Conn, net.Conn, error) {
    43  	listener, err := net.Listen("tcp", "127.0.0.1:0")
    44  	if err != nil {
    45  		listener, err = net.Listen("tcp", "[::1]:0")
    46  		if err != nil {
    47  			return nil, nil, err
    48  		}
    49  	}
    50  	defer listener.Close()
    51  	c1, err := net.Dial("tcp", listener.Addr().String())
    52  	if err != nil {
    53  		return nil, nil, err
    54  	}
    55  
    56  	c2, err := listener.Accept()
    57  	if err != nil {
    58  		c1.Close()
    59  		return nil, nil, err
    60  	}
    61  
    62  	return c1, c2, nil
    63  }
    64  
    65  // noiseTransport inserts ignore messages to check that the read loop
    66  // and the key exchange filters out these messages.
    67  type noiseTransport struct {
    68  	keyingTransport
    69  }
    70  
    71  func (t *noiseTransport) writePacket(p []byte) error {
    72  	ignore := []byte{msgIgnore}
    73  	if err := t.keyingTransport.writePacket(ignore); err != nil {
    74  		return err
    75  	}
    76  	debug := []byte{msgDebug, 1, 2, 3}
    77  	if err := t.keyingTransport.writePacket(debug); err != nil {
    78  		return err
    79  	}
    80  
    81  	return t.keyingTransport.writePacket(p)
    82  }
    83  
    84  func addNoiseTransport(t keyingTransport) keyingTransport {
    85  	return &noiseTransport{t}
    86  }
    87  
    88  // handshakePair creates two handshakeTransports connected with each
    89  // other. If the noise argument is true, both transports will try to
    90  // confuse the other side by sending ignore and debug messages.
    91  func handshakePair(clientConf *ClientConfig, addr string, noise bool) (client *handshakeTransport, server *handshakeTransport, err error) {
    92  	a, b, err := netPipe()
    93  	if err != nil {
    94  		return nil, nil, err
    95  	}
    96  
    97  	var trC, trS keyingTransport
    98  
    99  	trC = newTransport(a, rand.Reader, true)
   100  	trS = newTransport(b, rand.Reader, false)
   101  	if noise {
   102  		trC = addNoiseTransport(trC)
   103  		trS = addNoiseTransport(trS)
   104  	}
   105  	clientConf.SetDefaults()
   106  
   107  	v := []byte("version")
   108  	client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr())
   109  
   110  	serverConf := &ServerConfig{}
   111  	serverConf.AddHostKey(testSigners["ecdsa"])
   112  	serverConf.AddHostKey(testSigners["rsa"])
   113  	serverConf.SetDefaults()
   114  	server = newServerTransport(trS, v, v, serverConf)
   115  
   116  	if err := server.waitSession(); err != nil {
   117  		return nil, nil, fmt.Errorf("server.waitSession: %v", err)
   118  	}
   119  	if err := client.waitSession(); err != nil {
   120  		return nil, nil, fmt.Errorf("client.waitSession: %v", err)
   121  	}
   122  
   123  	return client, server, nil
   124  }
   125  
   126  func TestHandshakeBasic(t *testing.T) {
   127  	if runtime.GOOS == "plan9" {
   128  		t.Skip("see golang.org/issue/7237")
   129  	}
   130  
   131  	checker := &syncChecker{
   132  		waitCall: make(chan int, 10),
   133  		called:   make(chan int, 10),
   134  	}
   135  
   136  	checker.waitCall <- 1
   137  	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
   138  	if err != nil {
   139  		t.Fatalf("handshakePair: %v", err)
   140  	}
   141  
   142  	defer trC.Close()
   143  	defer trS.Close()
   144  
   145  	// Let first kex complete normally.
   146  	<-checker.called
   147  
   148  	clientDone := make(chan int, 0)
   149  	gotHalf := make(chan int, 0)
   150  	const N = 20
   151  
   152  	go func() {
   153  		defer close(clientDone)
   154  		// Client writes a bunch of stuff, and does a key
   155  		// change in the middle. This should not confuse the
   156  		// handshake in progress. We do this twice, so we test
   157  		// that the packet buffer is reset correctly.
   158  		for i := 0; i < N; i++ {
   159  			p := []byte{msgRequestSuccess, byte(i)}
   160  			if err := trC.writePacket(p); err != nil {
   161  				t.Fatalf("sendPacket: %v", err)
   162  			}
   163  			if (i % 10) == 5 {
   164  				<-gotHalf
   165  				// halfway through, we request a key change.
   166  				trC.requestKeyExchange()
   167  
   168  				// Wait until we can be sure the key
   169  				// change has really started before we
   170  				// write more.
   171  				<-checker.called
   172  			}
   173  			if (i % 10) == 7 {
   174  				// write some packets until the kex
   175  				// completes, to test buffering of
   176  				// packets.
   177  				checker.waitCall <- 1
   178  			}
   179  		}
   180  	}()
   181  
   182  	// Server checks that client messages come in cleanly
   183  	i := 0
   184  	err = nil
   185  	for ; i < N; i++ {
   186  		var p []byte
   187  		p, err = trS.readPacket()
   188  		if err != nil {
   189  			break
   190  		}
   191  		if (i % 10) == 5 {
   192  			gotHalf <- 1
   193  		}
   194  
   195  		want := []byte{msgRequestSuccess, byte(i)}
   196  		if bytes.Compare(p, want) != 0 {
   197  			t.Errorf("message %d: got %v, want %v", i, p, want)
   198  		}
   199  	}
   200  	<-clientDone
   201  	if err != nil && err != io.EOF {
   202  		t.Fatalf("server error: %v", err)
   203  	}
   204  	if i != N {
   205  		t.Errorf("received %d messages, want 10.", i)
   206  	}
   207  
   208  	close(checker.called)
   209  	if _, ok := <-checker.called; ok {
   210  		// If all went well, we registered exactly 2 key changes: one
   211  		// that establishes the session, and one that we requested
   212  		// additionally.
   213  		t.Fatalf("got another host key checks after 2 handshakes")
   214  	}
   215  }
   216  
   217  func TestForceFirstKex(t *testing.T) {
   218  	// like handshakePair, but must access the keyingTransport.
   219  	checker := &testChecker{}
   220  	clientConf := &ClientConfig{HostKeyCallback: checker.Check}
   221  	a, b, err := netPipe()
   222  	if err != nil {
   223  		t.Fatalf("netPipe: %v", err)
   224  	}
   225  
   226  	var trC, trS keyingTransport
   227  
   228  	trC = newTransport(a, rand.Reader, true)
   229  
   230  	// This is the disallowed packet:
   231  	trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth}))
   232  
   233  	// Rest of the setup.
   234  	trS = newTransport(b, rand.Reader, false)
   235  	clientConf.SetDefaults()
   236  
   237  	v := []byte("version")
   238  	client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr())
   239  
   240  	serverConf := &ServerConfig{}
   241  	serverConf.AddHostKey(testSigners["ecdsa"])
   242  	serverConf.AddHostKey(testSigners["rsa"])
   243  	serverConf.SetDefaults()
   244  	server := newServerTransport(trS, v, v, serverConf)
   245  
   246  	defer client.Close()
   247  	defer server.Close()
   248  
   249  	// We setup the initial key exchange, but the remote side
   250  	// tries to send serviceRequestMsg in cleartext, which is
   251  	// disallowed.
   252  
   253  	if err := server.waitSession(); err == nil {
   254  		t.Errorf("server first kex init should reject unexpected packet")
   255  	}
   256  }
   257  
   258  func TestHandshakeAutoRekeyWrite(t *testing.T) {
   259  	checker := &syncChecker{
   260  		called:   make(chan int, 10),
   261  		waitCall: nil,
   262  	}
   263  	clientConf := &ClientConfig{HostKeyCallback: checker.Check}
   264  	clientConf.RekeyThreshold = 500
   265  	trC, trS, err := handshakePair(clientConf, "addr", false)
   266  	if err != nil {
   267  		t.Fatalf("handshakePair: %v", err)
   268  	}
   269  	defer trC.Close()
   270  	defer trS.Close()
   271  
   272  	input := make([]byte, 251)
   273  	input[0] = msgRequestSuccess
   274  
   275  	done := make(chan int, 1)
   276  	const numPacket = 5
   277  	go func() {
   278  		defer close(done)
   279  		j := 0
   280  		for ; j < numPacket; j++ {
   281  			if p, err := trS.readPacket(); err != nil {
   282  				break
   283  			} else if !bytes.Equal(input, p) {
   284  				t.Errorf("got packet type %d, want %d", p[0], input[0])
   285  			}
   286  		}
   287  
   288  		if j != numPacket {
   289  			t.Errorf("got %d, want 5 messages", j)
   290  		}
   291  	}()
   292  
   293  	<-checker.called
   294  
   295  	for i := 0; i < numPacket; i++ {
   296  		p := make([]byte, len(input))
   297  		copy(p, input)
   298  		if err := trC.writePacket(p); err != nil {
   299  			t.Errorf("writePacket: %v", err)
   300  		}
   301  		if i == 2 {
   302  			// Make sure the kex is in progress.
   303  			<-checker.called
   304  		}
   305  
   306  	}
   307  	<-done
   308  }
   309  
   310  type syncChecker struct {
   311  	waitCall chan int
   312  	called   chan int
   313  }
   314  
   315  func (c *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
   316  	c.called <- 1
   317  	if c.waitCall != nil {
   318  		<-c.waitCall
   319  	}
   320  	return nil
   321  }
   322  
   323  func TestHandshakeAutoRekeyRead(t *testing.T) {
   324  	sync := &syncChecker{
   325  		called:   make(chan int, 2),
   326  		waitCall: nil,
   327  	}
   328  	clientConf := &ClientConfig{
   329  		HostKeyCallback: sync.Check,
   330  	}
   331  	clientConf.RekeyThreshold = 500
   332  
   333  	trC, trS, err := handshakePair(clientConf, "addr", false)
   334  	if err != nil {
   335  		t.Fatalf("handshakePair: %v", err)
   336  	}
   337  	defer trC.Close()
   338  	defer trS.Close()
   339  
   340  	packet := make([]byte, 501)
   341  	packet[0] = msgRequestSuccess
   342  	if err := trS.writePacket(packet); err != nil {
   343  		t.Fatalf("writePacket: %v", err)
   344  	}
   345  
   346  	// While we read out the packet, a key change will be
   347  	// initiated.
   348  	done := make(chan int, 1)
   349  	go func() {
   350  		defer close(done)
   351  		if _, err := trC.readPacket(); err != nil {
   352  			t.Fatalf("readPacket(client): %v", err)
   353  		}
   354  
   355  	}()
   356  
   357  	<-done
   358  	<-sync.called
   359  }
   360  
   361  // errorKeyingTransport generates errors after a given number of
   362  // read/write operations.
   363  type errorKeyingTransport struct {
   364  	packetConn
   365  	readLeft, writeLeft int
   366  }
   367  
   368  func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error {
   369  	return nil
   370  }
   371  
   372  func (n *errorKeyingTransport) getSessionID() []byte {
   373  	return nil
   374  }
   375  
   376  func (n *errorKeyingTransport) writePacket(packet []byte) error {
   377  	if n.writeLeft == 0 {
   378  		n.Close()
   379  		return errors.New("barf")
   380  	}
   381  
   382  	n.writeLeft--
   383  	return n.packetConn.writePacket(packet)
   384  }
   385  
   386  func (n *errorKeyingTransport) readPacket() ([]byte, error) {
   387  	if n.readLeft == 0 {
   388  		n.Close()
   389  		return nil, errors.New("barf")
   390  	}
   391  
   392  	n.readLeft--
   393  	return n.packetConn.readPacket()
   394  }
   395  
   396  func TestHandshakeErrorHandlingRead(t *testing.T) {
   397  	for i := 0; i < 20; i++ {
   398  		testHandshakeErrorHandlingN(t, i, -1, false)
   399  	}
   400  }
   401  
   402  func TestHandshakeErrorHandlingWrite(t *testing.T) {
   403  	for i := 0; i < 20; i++ {
   404  		testHandshakeErrorHandlingN(t, -1, i, false)
   405  	}
   406  }
   407  
   408  func TestHandshakeErrorHandlingReadCoupled(t *testing.T) {
   409  	for i := 0; i < 20; i++ {
   410  		testHandshakeErrorHandlingN(t, i, -1, true)
   411  	}
   412  }
   413  
   414  func TestHandshakeErrorHandlingWriteCoupled(t *testing.T) {
   415  	for i := 0; i < 20; i++ {
   416  		testHandshakeErrorHandlingN(t, -1, i, true)
   417  	}
   418  }
   419  
// testHandshakeErrorHandlingN runs a server and a client handshakeTransport
// over an in-memory pipe while injecting transport errors on the server
// side. It returns once every worker goroutine has stopped.
//
// readLimit and writeLimit are the numbers of reads and writes the server's
// errorKeyingTransport allows before it closes the connection and fails.
// -1 means unlimited. The client's transport never fails on its own.
//
// coupled selects the traffic pattern:
//   - false: each side has one goroutine that writes until an error and
//     then closes its conn, plus one that reads until an error.
//   - true: each side has one goroutine that alternates read then write of
//     msg until either fails.
//
// If handshakeTransport deadlocks, wg.Wait below never returns. The test
// then hangs, or the go runtime detects that all goroutines are asleep and
// panics.
func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int, coupled bool) {
	if runtime.GOOS == "js" && runtime.GOARCH == "wasm" {
		t.Skip("skipping on js/wasm; see golang.org/issue/32840")
	}
	// Payload of minRekeyThreshold/4 bytes. Presumably sized so the
	// server's rekey threshold (minRekeyThreshold) is crossed every few
	// packets, forcing repeated key exchanges.
	msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)})

	a, b := memPipe()
	defer a.Close()
	defer b.Close()

	// Server: error-injecting transport over a, rekeying at the minimum
	// threshold. Its read and kex loops are started by hand here.
	key := testSigners["ecdsa"]
	serverConf := Config{RekeyThreshold: minRekeyThreshold}
	serverConf.SetDefaults()
	serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'})
	serverConn.hostKeys = []Signer{key}
	go serverConn.readLoop()
	go serverConn.kexLoop()

	// Client: unlimited transport over b with a 10x larger threshold,
	// accepting any host key of the server's key type.
	clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold}
	clientConf.SetDefaults()
	clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'})
	clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()}
	clientConn.hostKeyCallback = InsecureIgnoreHostKey()
	go clientConn.readLoop()
	go clientConn.kexLoop()

	var wg sync.WaitGroup

	for _, hs := range []packetConn{serverConn, clientConn} {
		if !coupled {
			wg.Add(2)
			// Writer: sends distinct packets (8-hex-digit counter prefix,
			// total payload length equal to msg's) until a write fails,
			// then closes the conn.
			go func(c packetConn) {
				for i := 0; ; i++ {
					str := fmt.Sprintf("%08x", i) + strings.Repeat("x", int(minRekeyThreshold)/4-8)
					err := c.writePacket(Marshal(&serviceRequestMsg{str}))
					if err != nil {
						break
					}
				}
				wg.Done()
				c.Close()
			}(hs)
			// Reader: discards packets until a read fails.
			go func(c packetConn) {
				for {
					_, err := c.readPacket()
					if err != nil {
						break
					}
				}
				wg.Done()
			}(hs)
		} else {
			wg.Add(1)
			// Echo-style worker: after each successful read, write msg.
			// Stops at the first read or write error.
			go func(c packetConn) {
				for {
					_, err := c.readPacket()
					if err != nil {
						break
					}
					if err := c.writePacket(msg); err != nil {
						break
					}

				}
				wg.Done()
			}(hs)
		}
	}
	wg.Wait()
}
   493  
   494  func TestDisconnect(t *testing.T) {
   495  	if runtime.GOOS == "plan9" {
   496  		t.Skip("see golang.org/issue/7237")
   497  	}
   498  	checker := &testChecker{}
   499  	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false)
   500  	if err != nil {
   501  		t.Fatalf("handshakePair: %v", err)
   502  	}
   503  
   504  	defer trC.Close()
   505  	defer trS.Close()
   506  
   507  	trC.writePacket([]byte{msgRequestSuccess, 0, 0})
   508  	errMsg := &disconnectMsg{
   509  		Reason:  42,
   510  		Message: "such is life",
   511  	}
   512  	trC.writePacket(Marshal(errMsg))
   513  	trC.writePacket([]byte{msgRequestSuccess, 0, 0})
   514  
   515  	packet, err := trS.readPacket()
   516  	if err != nil {
   517  		t.Fatalf("readPacket 1: %v", err)
   518  	}
   519  	if packet[0] != msgRequestSuccess {
   520  		t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess)
   521  	}
   522  
   523  	_, err = trS.readPacket()
   524  	if err == nil {
   525  		t.Errorf("readPacket 2 succeeded")
   526  	} else if !reflect.DeepEqual(err, errMsg) {
   527  		t.Errorf("got error %#v, want %#v", err, errMsg)
   528  	}
   529  
   530  	_, err = trS.readPacket()
   531  	if err == nil {
   532  		t.Errorf("readPacket 3 succeeded")
   533  	}
   534  }
   535  
   536  func TestHandshakeRekeyDefault(t *testing.T) {
   537  	clientConf := &ClientConfig{
   538  		Config: Config{
   539  			Ciphers: []string{"aes128-ctr"},
   540  		},
   541  		HostKeyCallback: InsecureIgnoreHostKey(),
   542  	}
   543  	trC, trS, err := handshakePair(clientConf, "addr", false)
   544  	if err != nil {
   545  		t.Fatalf("handshakePair: %v", err)
   546  	}
   547  	defer trC.Close()
   548  	defer trS.Close()
   549  
   550  	trC.writePacket([]byte{msgRequestSuccess, 0, 0})
   551  	trC.Close()
   552  
   553  	rgb := (1024 + trC.readBytesLeft) >> 30
   554  	wgb := (1024 + trC.writeBytesLeft) >> 30
   555  
   556  	if rgb != 64 {
   557  		t.Errorf("got rekey after %dG read, want 64G", rgb)
   558  	}
   559  	if wgb != 64 {
   560  		t.Errorf("got rekey after %dG write, want 64G", wgb)
   561  	}
   562  }