github.com/maenmax/kairep@v0.0.0-20210218001208-55bf3df36788/src/golang.org/x/crypto/ssh/handshake_test.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/rand"
    10  	"errors"
    11  	"fmt"
    12  	"net"
    13  	"reflect"
    14  	"runtime"
    15  	"strings"
    16  	"sync"
    17  	"testing"
    18  )
    19  
    20  type testChecker struct {
    21  	calls []string
    22  }
    23  
    24  func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
    25  	if dialAddr == "bad" {
    26  		return fmt.Errorf("dialAddr is bad")
    27  	}
    28  
    29  	if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil {
    30  		return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr)
    31  	}
    32  
    33  	t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal()))
    34  
    35  	return nil
    36  }
    37  
    38  // netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
    39  // therefore is buffered (net.Pipe deadlocks if both sides start with
    40  // a write.)
    41  func netPipe() (net.Conn, net.Conn, error) {
    42  	listener, err := net.Listen("tcp", "127.0.0.1:0")
    43  	if err != nil {
    44  		return nil, nil, err
    45  	}
    46  	defer listener.Close()
    47  	c1, err := net.Dial("tcp", listener.Addr().String())
    48  	if err != nil {
    49  		return nil, nil, err
    50  	}
    51  
    52  	c2, err := listener.Accept()
    53  	if err != nil {
    54  		c1.Close()
    55  		return nil, nil, err
    56  	}
    57  
    58  	return c1, c2, nil
    59  }
    60  
    61  func handshakePair(clientConf *ClientConfig, addr string) (client *handshakeTransport, server *handshakeTransport, err error) {
    62  	a, b, err := netPipe()
    63  	if err != nil {
    64  		return nil, nil, err
    65  	}
    66  
    67  	trC := newTransport(a, rand.Reader, true)
    68  	trS := newTransport(b, rand.Reader, false)
    69  	clientConf.SetDefaults()
    70  
    71  	v := []byte("version")
    72  	client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr())
    73  
    74  	serverConf := &ServerConfig{}
    75  	serverConf.AddHostKey(testSigners["ecdsa"])
    76  	serverConf.AddHostKey(testSigners["rsa"])
    77  	serverConf.SetDefaults()
    78  	server = newServerTransport(trS, v, v, serverConf)
    79  
    80  	return client, server, nil
    81  }
    82  
    83  func TestHandshakeBasic(t *testing.T) {
    84  	if runtime.GOOS == "plan9" {
    85  		t.Skip("see golang.org/issue/7237")
    86  	}
    87  	checker := &testChecker{}
    88  	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
    89  	if err != nil {
    90  		t.Fatalf("handshakePair: %v", err)
    91  	}
    92  
    93  	defer trC.Close()
    94  	defer trS.Close()
    95  
    96  	go func() {
    97  		// Client writes a bunch of stuff, and does a key
    98  		// change in the middle. This should not confuse the
    99  		// handshake in progress
   100  		for i := 0; i < 10; i++ {
   101  			p := []byte{msgRequestSuccess, byte(i)}
   102  			if err := trC.writePacket(p); err != nil {
   103  				t.Fatalf("sendPacket: %v", err)
   104  			}
   105  			if i == 5 {
   106  				// halfway through, we request a key change.
   107  				err := trC.sendKexInit(subsequentKeyExchange)
   108  				if err != nil {
   109  					t.Fatalf("sendKexInit: %v", err)
   110  				}
   111  			}
   112  		}
   113  		trC.Close()
   114  	}()
   115  
   116  	// Server checks that client messages come in cleanly
   117  	i := 0
   118  	for {
   119  		p, err := trS.readPacket()
   120  		if err != nil {
   121  			break
   122  		}
   123  		if p[0] == msgNewKeys {
   124  			continue
   125  		}
   126  		want := []byte{msgRequestSuccess, byte(i)}
   127  		if bytes.Compare(p, want) != 0 {
   128  			t.Errorf("message %d: got %q, want %q", i, p, want)
   129  		}
   130  		i++
   131  	}
   132  	if i != 10 {
   133  		t.Errorf("received %d messages, want 10.", i)
   134  	}
   135  
   136  	// If all went well, we registered exactly 1 key change.
   137  	if len(checker.calls) != 1 {
   138  		t.Fatalf("got %d host key checks, want 1", len(checker.calls))
   139  	}
   140  
   141  	pub := testSigners["ecdsa"].PublicKey()
   142  	want := fmt.Sprintf("%s %v %s %x", "addr", trC.remoteAddr, pub.Type(), pub.Marshal())
   143  	if want != checker.calls[0] {
   144  		t.Errorf("got %q want %q for host key check", checker.calls[0], want)
   145  	}
   146  }
   147  
   148  func TestHandshakeError(t *testing.T) {
   149  	checker := &testChecker{}
   150  	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "bad")
   151  	if err != nil {
   152  		t.Fatalf("handshakePair: %v", err)
   153  	}
   154  	defer trC.Close()
   155  	defer trS.Close()
   156  
   157  	// send a packet
   158  	packet := []byte{msgRequestSuccess, 42}
   159  	if err := trC.writePacket(packet); err != nil {
   160  		t.Errorf("writePacket: %v", err)
   161  	}
   162  
   163  	// Now request a key change.
   164  	err = trC.sendKexInit(subsequentKeyExchange)
   165  	if err != nil {
   166  		t.Errorf("sendKexInit: %v", err)
   167  	}
   168  
   169  	// the key change will fail, and afterwards we can't write.
   170  	if err := trC.writePacket([]byte{msgRequestSuccess, 43}); err == nil {
   171  		t.Errorf("writePacket after botched rekey succeeded.")
   172  	}
   173  
   174  	readback, err := trS.readPacket()
   175  	if err != nil {
   176  		t.Fatalf("server closed too soon: %v", err)
   177  	}
   178  	if bytes.Compare(readback, packet) != 0 {
   179  		t.Errorf("got %q want %q", readback, packet)
   180  	}
   181  	readback, err = trS.readPacket()
   182  	if err == nil {
   183  		t.Errorf("got a message %q after failed key change", readback)
   184  	}
   185  }
   186  
   187  func TestForceFirstKex(t *testing.T) {
   188  	checker := &testChecker{}
   189  	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
   190  	if err != nil {
   191  		t.Fatalf("handshakePair: %v", err)
   192  	}
   193  
   194  	defer trC.Close()
   195  	defer trS.Close()
   196  
   197  	trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth}))
   198  
   199  	// We setup the initial key exchange, but the remote side
   200  	// tries to send serviceRequestMsg in cleartext, which is
   201  	// disallowed.
   202  
   203  	err = trS.sendKexInit(firstKeyExchange)
   204  	if err == nil {
   205  		t.Errorf("server first kex init should reject unexpected packet")
   206  	}
   207  }
   208  
   209  func TestHandshakeTwice(t *testing.T) {
   210  	checker := &testChecker{}
   211  	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
   212  	if err != nil {
   213  		t.Fatalf("handshakePair: %v", err)
   214  	}
   215  
   216  	defer trC.Close()
   217  	defer trS.Close()
   218  
   219  	// Both sides should ask for the first key exchange first.
   220  	err = trS.sendKexInit(firstKeyExchange)
   221  	if err != nil {
   222  		t.Errorf("server sendKexInit: %v", err)
   223  	}
   224  
   225  	err = trC.sendKexInit(firstKeyExchange)
   226  	if err != nil {
   227  		t.Errorf("client sendKexInit: %v", err)
   228  	}
   229  
   230  	sent := 0
   231  	// send a packet
   232  	packet := make([]byte, 5)
   233  	packet[0] = msgRequestSuccess
   234  	if err := trC.writePacket(packet); err != nil {
   235  		t.Errorf("writePacket: %v", err)
   236  	}
   237  	sent++
   238  
   239  	// Send another packet. Use a fresh one, since writePacket destroys.
   240  	packet = make([]byte, 5)
   241  	packet[0] = msgRequestSuccess
   242  	if err := trC.writePacket(packet); err != nil {
   243  		t.Errorf("writePacket: %v", err)
   244  	}
   245  	sent++
   246  
   247  	// 2nd key change.
   248  	err = trC.sendKexInit(subsequentKeyExchange)
   249  	if err != nil {
   250  		t.Errorf("sendKexInit: %v", err)
   251  	}
   252  
   253  	packet = make([]byte, 5)
   254  	packet[0] = msgRequestSuccess
   255  	if err := trC.writePacket(packet); err != nil {
   256  		t.Errorf("writePacket: %v", err)
   257  	}
   258  	sent++
   259  
   260  	packet = make([]byte, 5)
   261  	packet[0] = msgRequestSuccess
   262  	for i := 0; i < sent; i++ {
   263  		msg, err := trS.readPacket()
   264  		if err != nil {
   265  			t.Fatalf("server closed too soon: %v", err)
   266  		}
   267  
   268  		if bytes.Compare(msg, packet) != 0 {
   269  			t.Errorf("packet %d: got %q want %q", i, msg, packet)
   270  		}
   271  	}
   272  	if len(checker.calls) != 2 {
   273  		t.Errorf("got %d key changes, want 2", len(checker.calls))
   274  	}
   275  }
   276  
   277  func TestHandshakeAutoRekeyWrite(t *testing.T) {
   278  	checker := &testChecker{}
   279  	clientConf := &ClientConfig{HostKeyCallback: checker.Check}
   280  	clientConf.RekeyThreshold = 500
   281  	trC, trS, err := handshakePair(clientConf, "addr")
   282  	if err != nil {
   283  		t.Fatalf("handshakePair: %v", err)
   284  	}
   285  	defer trC.Close()
   286  	defer trS.Close()
   287  
   288  	for i := 0; i < 5; i++ {
   289  		packet := make([]byte, 251)
   290  		packet[0] = msgRequestSuccess
   291  		if err := trC.writePacket(packet); err != nil {
   292  			t.Errorf("writePacket: %v", err)
   293  		}
   294  	}
   295  
   296  	j := 0
   297  	for ; j < 5; j++ {
   298  		_, err := trS.readPacket()
   299  		if err != nil {
   300  			break
   301  		}
   302  	}
   303  
   304  	if j != 5 {
   305  		t.Errorf("got %d, want 5 messages", j)
   306  	}
   307  
   308  	if len(checker.calls) != 2 {
   309  		t.Errorf("got %d key changes, wanted 2", len(checker.calls))
   310  	}
   311  }
   312  
   313  type syncChecker struct {
   314  	called chan int
   315  }
   316  
   317  func (t *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
   318  	t.called <- 1
   319  	return nil
   320  }
   321  
   322  func TestHandshakeAutoRekeyRead(t *testing.T) {
   323  	sync := &syncChecker{make(chan int, 2)}
   324  	clientConf := &ClientConfig{
   325  		HostKeyCallback: sync.Check,
   326  	}
   327  	clientConf.RekeyThreshold = 500
   328  
   329  	trC, trS, err := handshakePair(clientConf, "addr")
   330  	if err != nil {
   331  		t.Fatalf("handshakePair: %v", err)
   332  	}
   333  	defer trC.Close()
   334  	defer trS.Close()
   335  
   336  	packet := make([]byte, 501)
   337  	packet[0] = msgRequestSuccess
   338  	if err := trS.writePacket(packet); err != nil {
   339  		t.Fatalf("writePacket: %v", err)
   340  	}
   341  	// While we read out the packet, a key change will be
   342  	// initiated.
   343  	if _, err := trC.readPacket(); err != nil {
   344  		t.Fatalf("readPacket(client): %v", err)
   345  	}
   346  
   347  	<-sync.called
   348  }
   349  
   350  // errorKeyingTransport generates errors after a given number of
   351  // read/write operations.
   352  type errorKeyingTransport struct {
   353  	packetConn
   354  	readLeft, writeLeft int
   355  }
   356  
   357  func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error {
   358  	return nil
   359  }
   360  func (n *errorKeyingTransport) getSessionID() []byte {
   361  	return nil
   362  }
   363  
   364  func (n *errorKeyingTransport) writePacket(packet []byte) error {
   365  	if n.writeLeft == 0 {
   366  		n.Close()
   367  		return errors.New("barf")
   368  	}
   369  
   370  	n.writeLeft--
   371  	return n.packetConn.writePacket(packet)
   372  }
   373  
   374  func (n *errorKeyingTransport) readPacket() ([]byte, error) {
   375  	if n.readLeft == 0 {
   376  		n.Close()
   377  		return nil, errors.New("barf")
   378  	}
   379  
   380  	n.readLeft--
   381  	return n.packetConn.readPacket()
   382  }
   383  
   384  func TestHandshakeErrorHandlingRead(t *testing.T) {
   385  	for i := 0; i < 20; i++ {
   386  		testHandshakeErrorHandlingN(t, i, -1)
   387  	}
   388  }
   389  
   390  func TestHandshakeErrorHandlingWrite(t *testing.T) {
   391  	for i := 0; i < 20; i++ {
   392  		testHandshakeErrorHandlingN(t, -1, i)
   393  	}
   394  }
   395  
   396  // testHandshakeErrorHandlingN runs handshakes, injecting errors. If
   397  // handshakeTransport deadlocks, the go runtime will detect it and
   398  // panic.
   399  func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int) {
   400  	msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)})
   401  
   402  	a, b := memPipe()
   403  	defer a.Close()
   404  	defer b.Close()
   405  
   406  	key := testSigners["ecdsa"]
   407  	serverConf := Config{RekeyThreshold: minRekeyThreshold}
   408  	serverConf.SetDefaults()
   409  	serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'})
   410  	serverConn.hostKeys = []Signer{key}
   411  	go serverConn.readLoop()
   412  
   413  	clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold}
   414  	clientConf.SetDefaults()
   415  	clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'})
   416  	clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()}
   417  	go clientConn.readLoop()
   418  
   419  	var wg sync.WaitGroup
   420  	wg.Add(4)
   421  
   422  	for _, hs := range []packetConn{serverConn, clientConn} {
   423  		go func(c packetConn) {
   424  			for {
   425  				err := c.writePacket(msg)
   426  				if err != nil {
   427  					break
   428  				}
   429  			}
   430  			wg.Done()
   431  		}(hs)
   432  		go func(c packetConn) {
   433  			for {
   434  				_, err := c.readPacket()
   435  				if err != nil {
   436  					break
   437  				}
   438  			}
   439  			wg.Done()
   440  		}(hs)
   441  	}
   442  
   443  	wg.Wait()
   444  }
   445  
   446  func TestDisconnect(t *testing.T) {
   447  	if runtime.GOOS == "plan9" {
   448  		t.Skip("see golang.org/issue/7237")
   449  	}
   450  	checker := &testChecker{}
   451  	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
   452  	if err != nil {
   453  		t.Fatalf("handshakePair: %v", err)
   454  	}
   455  
   456  	defer trC.Close()
   457  	defer trS.Close()
   458  
   459  	trC.writePacket([]byte{msgRequestSuccess, 0, 0})
   460  	errMsg := &disconnectMsg{
   461  		Reason:  42,
   462  		Message: "such is life",
   463  	}
   464  	trC.writePacket(Marshal(errMsg))
   465  	trC.writePacket([]byte{msgRequestSuccess, 0, 0})
   466  
   467  	packet, err := trS.readPacket()
   468  	if err != nil {
   469  		t.Fatalf("readPacket 1: %v", err)
   470  	}
   471  	if packet[0] != msgRequestSuccess {
   472  		t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess)
   473  	}
   474  
   475  	_, err = trS.readPacket()
   476  	if err == nil {
   477  		t.Errorf("readPacket 2 succeeded")
   478  	} else if !reflect.DeepEqual(err, errMsg) {
   479  		t.Errorf("got error %#v, want %#v", err, errMsg)
   480  	}
   481  
   482  	_, err = trS.readPacket()
   483  	if err == nil {
   484  		t.Errorf("readPacket 3 succeeded")
   485  	}
   486  }