github.com/glycerine/xcryptossh@v7.0.4+incompatible/handshake_test.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"crypto/rand"
    11  	"errors"
    12  	"fmt"
    13  	"io"
    14  	"net"
    15  	"reflect"
    16  	"runtime"
    17  	"strings"
    18  	"sync"
    19  	"testing"
    20  )
    21  
    22  type testChecker struct {
    23  	calls []string
    24  }
    25  
    26  func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
    27  	if dialAddr == "bad" {
    28  		return fmt.Errorf("dialAddr is bad")
    29  	}
    30  
    31  	if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil {
    32  		return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr)
    33  	}
    34  
    35  	t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal()))
    36  
    37  	return nil
    38  }
    39  
    40  // netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
    41  // therefore is buffered (net.Pipe deadlocks if both sides start with
    42  // a write.)
    43  func netPipe() (net.Conn, net.Conn, error) {
    44  	listener, err := net.Listen("tcp", "127.0.0.1:0")
    45  	if err != nil {
    46  		listener, err = net.Listen("tcp", "[::1]:0")
    47  		if err != nil {
    48  			return nil, nil, err
    49  		}
    50  	}
    51  	defer listener.Close()
    52  	c1, err := net.Dial("tcp", listener.Addr().String())
    53  	if err != nil {
    54  		return nil, nil, err
    55  	}
    56  
    57  	c2, err := listener.Accept()
    58  	if err != nil {
    59  		c1.Close()
    60  		return nil, nil, err
    61  	}
    62  
    63  	return c1, c2, nil
    64  }
    65  
    66  // noiseTransport inserts ignore messages to check that the read loop
    67  // and the key exchange filters out these messages.
    68  type noiseTransport struct {
    69  	keyingTransport
    70  }
    71  
    72  func (t *noiseTransport) writePacket(p []byte) error {
    73  	ignore := []byte{msgIgnore}
    74  	if err := t.keyingTransport.writePacket(ignore); err != nil {
    75  		return err
    76  	}
    77  	debug := []byte{msgDebug, 1, 2, 3}
    78  	if err := t.keyingTransport.writePacket(debug); err != nil {
    79  		return err
    80  	}
    81  
    82  	return t.keyingTransport.writePacket(p)
    83  }
    84  
    85  func addNoiseTransport(t keyingTransport) keyingTransport {
    86  	return &noiseTransport{t}
    87  }
    88  
    89  // handshakePair creates two handshakeTransports connected with each
    90  // other. If the noise argument is true, both transports will try to
    91  // confuse the other side by sending ignore and debug messages.
    92  func handshakePair(clientConf *ClientConfig, addr string, noise bool) (client *handshakeTransport, server *handshakeTransport, err error) {
    93  	a, b, err := netPipe()
    94  	if err != nil {
    95  		return nil, nil, err
    96  	}
    97  
    98  	var trC, trS keyingTransport
    99  
   100  	trC = newTransport(a, rand.Reader, true, &clientConf.Config)
   101  	trS = newTransport(b, rand.Reader, false, &clientConf.Config)
   102  	if noise {
   103  		trC = addNoiseTransport(trC)
   104  		trS = addNoiseTransport(trS)
   105  	}
   106  	clientConf.SetDefaults()
   107  
   108  	v := []byte("version")
   109  	ctx := context.Background()
   110  
   111  	client = newClientTransport(ctx, trC, v, v, clientConf, addr, a.RemoteAddr())
   112  
   113  	serverConf := &ServerConfig{Config: Config{Halt: clientConf.Halt}}
   114  	serverConf.AddHostKey(testSigners["ecdsa"])
   115  	serverConf.AddHostKey(testSigners["rsa"])
   116  	serverConf.SetDefaults()
   117  	server = newServerTransport(ctx, trS, v, v, serverConf)
   118  	if server == nil {
   119  		return nil, nil, fmt.Errorf("ssh: shutting down.")
   120  	}
   121  	if err := server.waitSession(ctx); err != nil {
   122  		return nil, nil, fmt.Errorf("server.waitSession: %v", err)
   123  	}
   124  	if err := client.waitSession(ctx); err != nil {
   125  		return nil, nil, fmt.Errorf("client.waitSession: %v", err)
   126  	}
   127  
   128  	return client, server, nil
   129  }
   130  
   131  func TestHandshakeBasic(t *testing.T) {
   132  	defer xtestend(xtestbegin(t))
   133  
   134  	if runtime.GOOS == "plan9" {
   135  		t.Skip("see golang.org/issue/7237")
   136  	}
   137  
   138  	checker := &syncChecker{
   139  		waitCall: make(chan int, 10),
   140  		called:   make(chan int, 10),
   141  	}
   142  
   143  	halt := NewHalter()
   144  	defer halt.RequestStop()
   145  	checker.waitCall <- 1
   146  	trC, trS, err := handshakePair(
   147  		&ClientConfig{
   148  			HostKeyCallback: checker.Check,
   149  			Config: Config{
   150  				Halt: halt},
   151  		}, "addr", false)
   152  
   153  	if err != nil {
   154  		t.Fatalf("handshakePair: %v", err)
   155  	}
   156  
   157  	defer trC.Close()
   158  	defer trS.Close()
   159  
   160  	// Let first kex complete normally.
   161  	<-checker.called
   162  
   163  	clientDone := make(chan int, 0)
   164  	gotHalf := make(chan int, 0)
   165  	const N = 20
   166  
   167  	go func() {
   168  		defer close(clientDone)
   169  		// Client writes a bunch of stuff, and does a key
   170  		// change in the middle. This should not confuse the
   171  		// handshake in progress. We do this twice, so we test
   172  		// that the packet buffer is reset correctly.
   173  		for i := 0; i < N; i++ {
   174  			p := []byte{msgRequestSuccess, byte(i)}
   175  			if err := trC.writePacket(p); err != nil {
   176  				t.Fatalf("sendPacket: %v", err)
   177  			}
   178  			if (i % 10) == 5 {
   179  				<-gotHalf
   180  				// halfway through, we request a key change.
   181  				trC.requestKeyExchange()
   182  
   183  				// Wait until we can be sure the key
   184  				// change has really started before we
   185  				// write more.
   186  				<-checker.called
   187  			}
   188  			if (i % 10) == 7 {
   189  				// write some packets until the kex
   190  				// completes, to test buffering of
   191  				// packets.
   192  				checker.waitCall <- 1
   193  			}
   194  		}
   195  	}()
   196  
   197  	ctx := context.Background()
   198  
   199  	// Server checks that client messages come in cleanly
   200  	i := 0
   201  	err = nil
   202  	for ; i < N; i++ {
   203  		var p []byte
   204  		p, err = trS.readPacket(ctx)
   205  		if err != nil {
   206  			break
   207  		}
   208  		if (i % 10) == 5 {
   209  			gotHalf <- 1
   210  		}
   211  
   212  		want := []byte{msgRequestSuccess, byte(i)}
   213  		if bytes.Compare(p, want) != 0 {
   214  			t.Errorf("message %d: got %v, want %v", i, p, want)
   215  		}
   216  	}
   217  	<-clientDone
   218  	if err != nil && err != io.EOF {
   219  		t.Fatalf("server error: %v", err)
   220  	}
   221  	if i != N {
   222  		t.Errorf("received %d messages, want 10.", i)
   223  	}
   224  
   225  	close(checker.called)
   226  	if _, ok := <-checker.called; ok {
   227  		// If all went well, we registered exactly 2 key changes: one
   228  		// that establishes the session, and one that we requested
   229  		// additionally.
   230  		t.Fatalf("got another host key checks after 2 handshakes")
   231  	}
   232  }
   233  
   234  func TestForceFirstKex(t *testing.T) {
   235  	defer xtestend(xtestbegin(t))
   236  
   237  	halt := NewHalter()
   238  	defer halt.RequestStop()
   239  
   240  	// like handshakePair, but must access the keyingTransport.
   241  	checker := &testChecker{}
   242  	clientConf := &ClientConfig{
   243  		HostKeyCallback: checker.Check,
   244  		Config: Config{
   245  			Halt: halt,
   246  		},
   247  	}
   248  	a, b, err := netPipe()
   249  	if err != nil {
   250  		t.Fatalf("netPipe: %v", err)
   251  	}
   252  
   253  	var trC, trS keyingTransport
   254  
   255  	trC = newTransport(a, rand.Reader, true, nil)
   256  
   257  	// This is the disallowed packet:
   258  	trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth}))
   259  
   260  	// Rest of the setup.
   261  	trS = newTransport(b, rand.Reader, false, nil)
   262  	clientConf.SetDefaults()
   263  
   264  	v := []byte("version")
   265  	ctx := context.Background()
   266  	client := newClientTransport(ctx, trC, v, v, clientConf, "addr", a.RemoteAddr())
   267  
   268  	serverConf := &ServerConfig{Config: Config{Halt: halt}}
   269  	serverConf.AddHostKey(testSigners["ecdsa"])
   270  	serverConf.AddHostKey(testSigners["rsa"])
   271  	serverConf.SetDefaults()
   272  	server := newServerTransport(ctx, trS, v, v, serverConf)
   273  	if server == nil {
   274  		// shut down
   275  		return
   276  	}
   277  
   278  	defer client.Close()
   279  	defer server.Close()
   280  
   281  	// We setup the initial key exchange, but the remote side
   282  	// tries to send serviceRequestMsg in cleartext, which is
   283  	// disallowed.
   284  
   285  	if err := server.waitSession(ctx); err == nil {
   286  		t.Errorf("server first kex init should reject unexpected packet")
   287  	}
   288  }
   289  
   290  func TestHandshakeAutoRekeyWrite(t *testing.T) {
   291  	defer xtestend(xtestbegin(t))
   292  
   293  	halt := NewHalter()
   294  	defer halt.RequestStop()
   295  	checker := &syncChecker{
   296  		called:   make(chan int, 10),
   297  		waitCall: nil,
   298  	}
   299  	clientConf := &ClientConfig{
   300  		HostKeyCallback: checker.Check,
   301  		Config:          Config{Halt: halt},
   302  	}
   303  	clientConf.RekeyThreshold = 500
   304  	trC, trS, err := handshakePair(clientConf, "addr", false)
   305  	if err != nil {
   306  		t.Fatalf("handshakePair: %v", err)
   307  	}
   308  	defer trC.Close()
   309  	defer trS.Close()
   310  
   311  	input := make([]byte, 251)
   312  	input[0] = msgRequestSuccess
   313  	ctx := context.Background()
   314  
   315  	done := make(chan int, 1)
   316  	const numPacket = 5
   317  	go func() {
   318  		defer close(done)
   319  		j := 0
   320  		for ; j < numPacket; j++ {
   321  			if p, err := trS.readPacket(ctx); err != nil {
   322  				break
   323  			} else if !bytes.Equal(input, p) {
   324  				t.Errorf("got packet type %d, want %d", p[0], input[0])
   325  			}
   326  		}
   327  
   328  		if j != numPacket {
   329  			t.Errorf("got %d, want 5 messages", j)
   330  		}
   331  	}()
   332  
   333  	<-checker.called
   334  
   335  	for i := 0; i < numPacket; i++ {
   336  		p := make([]byte, len(input))
   337  		copy(p, input)
   338  		if err := trC.writePacket(p); err != nil {
   339  			t.Errorf("writePacket: %v", err)
   340  		}
   341  		if i == 2 {
   342  			// Make sure the kex is in progress.
   343  			<-checker.called
   344  		}
   345  
   346  	}
   347  	<-done
   348  }
   349  
   350  type syncChecker struct {
   351  	waitCall chan int
   352  	called   chan int
   353  }
   354  
   355  func (c *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
   356  	c.called <- 1
   357  	if c.waitCall != nil {
   358  		<-c.waitCall
   359  	}
   360  	return nil
   361  }
   362  
   363  func TestHandshakeAutoRekeyRead(t *testing.T) {
   364  	defer xtestend(xtestbegin(t))
   365  
   366  	halt := NewHalter()
   367  	defer halt.RequestStop()
   368  
   369  	sync := &syncChecker{
   370  		called:   make(chan int, 2),
   371  		waitCall: nil,
   372  	}
   373  	clientConf := &ClientConfig{
   374  		HostKeyCallback: sync.Check,
   375  		Config:          Config{Halt: halt},
   376  	}
   377  	clientConf.RekeyThreshold = 500
   378  
   379  	trC, trS, err := handshakePair(clientConf, "addr", false)
   380  	if err != nil {
   381  		t.Fatalf("handshakePair: %v", err)
   382  	}
   383  	defer trC.Close()
   384  	defer trS.Close()
   385  
   386  	packet := make([]byte, 501)
   387  	packet[0] = msgRequestSuccess
   388  	if err := trS.writePacket(packet); err != nil {
   389  		t.Fatalf("writePacket: %v", err)
   390  	}
   391  
   392  	// While we read out the packet, a key change will be
   393  	// initiated.
   394  	done := make(chan int, 1)
   395  	ctx := context.Background()
   396  
   397  	go func() {
   398  		defer close(done)
   399  		if _, err := trC.readPacket(ctx); err != nil {
   400  			t.Fatalf("readPacket(client): %v", err)
   401  		}
   402  
   403  	}()
   404  
   405  	<-done
   406  	<-sync.called
   407  }
   408  
   409  // errorKeyingTransport generates errors after a given number of
   410  // read/write operations.
   411  type errorKeyingTransport struct {
   412  	packetConn
   413  	readLeft, writeLeft int
   414  }
   415  
   416  func (n *errorKeyingTransport) prepareKeyChange(context.Context, *algorithms, *kexResult, *Config) error {
   417  	return nil
   418  }
   419  
   420  func (n *errorKeyingTransport) getSessionID() []byte {
   421  	return nil
   422  }
   423  
   424  func (n *errorKeyingTransport) writePacket(packet []byte) error {
   425  	if n.writeLeft == 0 {
   426  		n.Close()
   427  		return errors.New("barf")
   428  	}
   429  
   430  	n.writeLeft--
   431  	return n.packetConn.writePacket(packet)
   432  }
   433  
   434  func (n *errorKeyingTransport) readPacket(ctx context.Context) ([]byte, error) {
   435  	if n.readLeft == 0 {
   436  		n.Close()
   437  		return nil, errors.New("barf")
   438  	}
   439  
   440  	n.readLeft--
   441  
   442  	return n.packetConn.readPacket(ctx)
   443  }
   444  
   445  func TestHandshakeErrorHandlingRead(t *testing.T) {
   446  	defer xtestend(xtestbegin(t))
   447  
   448  	for i := 0; i < 20; i++ {
   449  		testHandshakeErrorHandlingN(t, i, -1, false)
   450  	}
   451  }
   452  
   453  func TestHandshakeErrorHandlingWrite(t *testing.T) {
   454  	defer xtestend(xtestbegin(t))
   455  
   456  	for i := 0; i < 20; i++ {
   457  		testHandshakeErrorHandlingN(t, -1, i, false)
   458  	}
   459  }
   460  
   461  func TestHandshakeErrorHandlingReadCoupled(t *testing.T) {
   462  	defer xtestend(xtestbegin(t))
   463  
   464  	for i := 0; i < 20; i++ {
   465  		testHandshakeErrorHandlingN(t, i, -1, true)
   466  	}
   467  }
   468  
   469  func TestHandshakeErrorHandlingWriteCoupled(t *testing.T) {
   470  	defer xtestend(xtestbegin(t))
   471  
   472  	for i := 0; i < 20; i++ {
   473  		testHandshakeErrorHandlingN(t, -1, i, true)
   474  	}
   475  }
   476  
   477  // testHandshakeErrorHandlingN runs handshakes, injecting errors. If
   478  // handshakeTransport deadlocks, the go runtime will detect it and
   479  // panic.
   480  func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int, coupled bool) {
   481  	msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)})
   482  
   483  	a, b := memPipe()
   484  	defer a.Close()
   485  	defer b.Close()
   486  
   487  	key := testSigners["ecdsa"]
   488  	serverConf := Config{RekeyThreshold: minRekeyThreshold}
   489  	serverConf.SetDefaults()
   490  	defer serverConf.Halt.RequestStop()
   491  	ctx := context.Background()
   492  	serverConn := newHandshakeTransport(ctx, &errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'})
   493  	serverConn.hostKeys = []Signer{key}
   494  	go serverConn.readLoop(ctx)
   495  	go serverConn.kexLoop(ctx)
   496  
   497  	clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold}
   498  	clientConf.SetDefaults()
   499  	defer clientConf.Halt.RequestStop()
   500  	clientConn := newHandshakeTransport(ctx, &errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'})
   501  	clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()}
   502  	clientConn.hostKeyCallback = InsecureIgnoreHostKey()
   503  	go clientConn.readLoop(ctx)
   504  	go clientConn.kexLoop(ctx)
   505  
   506  	var wg sync.WaitGroup
   507  
   508  	for _, hs := range []packetConn{serverConn, clientConn} {
   509  		if !coupled {
   510  			wg.Add(2)
   511  			go func(c packetConn) {
   512  				for i := 0; ; i++ {
   513  					str := fmt.Sprintf("%08x", i) + strings.Repeat("x", int(minRekeyThreshold)/4-8)
   514  					err := c.writePacket(Marshal(&serviceRequestMsg{str}))
   515  					if err != nil {
   516  						break
   517  					}
   518  				}
   519  				wg.Done()
   520  				c.Close()
   521  			}(hs)
   522  			go func(c packetConn) {
   523  				for {
   524  					_, err := c.readPacket(ctx)
   525  					if err != nil {
   526  						break
   527  					}
   528  				}
   529  				wg.Done()
   530  			}(hs)
   531  		} else {
   532  			wg.Add(1)
   533  			go func(c packetConn) {
   534  				for {
   535  					_, err := c.readPacket(ctx)
   536  					if err != nil {
   537  						break
   538  					}
   539  					if err := c.writePacket(msg); err != nil {
   540  						break
   541  					}
   542  
   543  				}
   544  				wg.Done()
   545  			}(hs)
   546  		}
   547  	}
   548  	wg.Wait()
   549  }
   550  
   551  func TestDisconnect(t *testing.T) {
   552  	defer xtestend(xtestbegin(t))
   553  
   554  	halt := NewHalter()
   555  	defer halt.RequestStop()
   556  
   557  	if runtime.GOOS == "plan9" {
   558  		t.Skip("see golang.org/issue/7237")
   559  	}
   560  	checker := &testChecker{}
   561  	trC, trS, err := handshakePair(
   562  		&ClientConfig{
   563  			HostKeyCallback: checker.Check,
   564  			Config:          Config{Halt: halt},
   565  		}, "addr", false)
   566  	if err != nil {
   567  		t.Fatalf("handshakePair: %v", err)
   568  	}
   569  
   570  	defer trC.Close()
   571  	defer trS.Close()
   572  
   573  	trC.writePacket([]byte{msgRequestSuccess, 0, 0})
   574  	errMsg := &disconnectMsg{
   575  		Reason:  42,
   576  		Message: "such is life",
   577  	}
   578  	trC.writePacket(Marshal(errMsg))
   579  	trC.writePacket([]byte{msgRequestSuccess, 0, 0})
   580  	ctx := context.Background()
   581  	packet, err := trS.readPacket(ctx)
   582  	if err != nil {
   583  		t.Fatalf("readPacket 1: %v", err)
   584  	}
   585  	if packet[0] != msgRequestSuccess {
   586  		t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess)
   587  	}
   588  
   589  	_, err = trS.readPacket(ctx)
   590  	if err == nil {
   591  		t.Errorf("readPacket 2 succeeded")
   592  	} else if !reflect.DeepEqual(err, errMsg) {
   593  		t.Errorf("got error %#v, want %#v", err, errMsg)
   594  	}
   595  
   596  	_, err = trS.readPacket(ctx)
   597  	if err == nil {
   598  		t.Errorf("readPacket 3 succeeded")
   599  	}
   600  }
   601  
   602  func TestHandshakeRekeyDefault(t *testing.T) {
   603  	defer xtestend(xtestbegin(t))
   604  
   605  	halt := NewHalter()
   606  	defer halt.RequestStop()
   607  
   608  	clientConf := &ClientConfig{
   609  		Config: Config{
   610  			Ciphers: []string{"aes128-ctr"},
   611  			Halt:    halt,
   612  		},
   613  		HostKeyCallback: InsecureIgnoreHostKey(),
   614  	}
   615  	trC, trS, err := handshakePair(clientConf, "addr", false)
   616  	if err != nil {
   617  		t.Fatalf("handshakePair: %v", err)
   618  	}
   619  	defer trC.Close()
   620  	defer trS.Close()
   621  
   622  	trC.writePacket([]byte{msgRequestSuccess, 0, 0})
   623  	trC.Close()
   624  
   625  	rgb := (1024 + trC.readBytesLeft) >> 30
   626  	wgb := (1024 + trC.writeBytesLeft) >> 30
   627  
   628  	if rgb != 64 {
   629  		t.Errorf("got rekey after %dG read, want 64G", rgb)
   630  	}
   631  	if wgb != 64 {
   632  		t.Errorf("got rekey after %dG write, want 64G", wgb)
   633  	}
   634  }