github.com/graybobo/golang.org-package-offline-cache@v0.0.0-20200626051047-6608995c132f/x/crypto/ssh/handshake_test.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/rand"
    10  	"errors"
    11  	"fmt"
    12  	"net"
    13  	"runtime"
    14  	"strings"
    15  	"sync"
    16  	"testing"
    17  )
    18  
    19  type testChecker struct {
    20  	calls []string
    21  }
    22  
    23  func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
    24  	if dialAddr == "bad" {
    25  		return fmt.Errorf("dialAddr is bad")
    26  	}
    27  
    28  	if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil {
    29  		return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr)
    30  	}
    31  
    32  	t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal()))
    33  
    34  	return nil
    35  }
    36  
    37  // netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
    38  // therefore is buffered (net.Pipe deadlocks if both sides start with
    39  // a write.)
    40  func netPipe() (net.Conn, net.Conn, error) {
    41  	listener, err := net.Listen("tcp", "127.0.0.1:0")
    42  	if err != nil {
    43  		return nil, nil, err
    44  	}
    45  	defer listener.Close()
    46  	c1, err := net.Dial("tcp", listener.Addr().String())
    47  	if err != nil {
    48  		return nil, nil, err
    49  	}
    50  
    51  	c2, err := listener.Accept()
    52  	if err != nil {
    53  		c1.Close()
    54  		return nil, nil, err
    55  	}
    56  
    57  	return c1, c2, nil
    58  }
    59  
    60  func handshakePair(clientConf *ClientConfig, addr string) (client *handshakeTransport, server *handshakeTransport, err error) {
    61  	a, b, err := netPipe()
    62  	if err != nil {
    63  		return nil, nil, err
    64  	}
    65  
    66  	trC := newTransport(a, rand.Reader, true)
    67  	trS := newTransport(b, rand.Reader, false)
    68  	clientConf.SetDefaults()
    69  
    70  	v := []byte("version")
    71  	client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr())
    72  
    73  	serverConf := &ServerConfig{}
    74  	serverConf.AddHostKey(testSigners["ecdsa"])
    75  	serverConf.AddHostKey(testSigners["rsa"])
    76  	serverConf.SetDefaults()
    77  	server = newServerTransport(trS, v, v, serverConf)
    78  
    79  	return client, server, nil
    80  }
    81  
    82  func TestHandshakeBasic(t *testing.T) {
    83  	if runtime.GOOS == "plan9" {
    84  		t.Skip("see golang.org/issue/7237")
    85  	}
    86  	checker := &testChecker{}
    87  	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
    88  	if err != nil {
    89  		t.Fatalf("handshakePair: %v", err)
    90  	}
    91  
    92  	defer trC.Close()
    93  	defer trS.Close()
    94  
    95  	go func() {
    96  		// Client writes a bunch of stuff, and does a key
    97  		// change in the middle. This should not confuse the
    98  		// handshake in progress
    99  		for i := 0; i < 10; i++ {
   100  			p := []byte{msgRequestSuccess, byte(i)}
   101  			if err := trC.writePacket(p); err != nil {
   102  				t.Fatalf("sendPacket: %v", err)
   103  			}
   104  			if i == 5 {
   105  				// halfway through, we request a key change.
   106  				_, _, err := trC.sendKexInit()
   107  				if err != nil {
   108  					t.Fatalf("sendKexInit: %v", err)
   109  				}
   110  			}
   111  		}
   112  		trC.Close()
   113  	}()
   114  
   115  	// Server checks that client messages come in cleanly
   116  	i := 0
   117  	for {
   118  		p, err := trS.readPacket()
   119  		if err != nil {
   120  			break
   121  		}
   122  		if p[0] == msgNewKeys {
   123  			continue
   124  		}
   125  		want := []byte{msgRequestSuccess, byte(i)}
   126  		if bytes.Compare(p, want) != 0 {
   127  			t.Errorf("message %d: got %q, want %q", i, p, want)
   128  		}
   129  		i++
   130  	}
   131  	if i != 10 {
   132  		t.Errorf("received %d messages, want 10.", i)
   133  	}
   134  
   135  	// If all went well, we registered exactly 1 key change.
   136  	if len(checker.calls) != 1 {
   137  		t.Fatalf("got %d host key checks, want 1", len(checker.calls))
   138  	}
   139  
   140  	pub := testSigners["ecdsa"].PublicKey()
   141  	want := fmt.Sprintf("%s %v %s %x", "addr", trC.remoteAddr, pub.Type(), pub.Marshal())
   142  	if want != checker.calls[0] {
   143  		t.Errorf("got %q want %q for host key check", checker.calls[0], want)
   144  	}
   145  }
   146  
   147  func TestHandshakeError(t *testing.T) {
   148  	checker := &testChecker{}
   149  	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "bad")
   150  	if err != nil {
   151  		t.Fatalf("handshakePair: %v", err)
   152  	}
   153  	defer trC.Close()
   154  	defer trS.Close()
   155  
   156  	// send a packet
   157  	packet := []byte{msgRequestSuccess, 42}
   158  	if err := trC.writePacket(packet); err != nil {
   159  		t.Errorf("writePacket: %v", err)
   160  	}
   161  
   162  	// Now request a key change.
   163  	_, _, err = trC.sendKexInit()
   164  	if err != nil {
   165  		t.Errorf("sendKexInit: %v", err)
   166  	}
   167  
   168  	// the key change will fail, and afterwards we can't write.
   169  	if err := trC.writePacket([]byte{msgRequestSuccess, 43}); err == nil {
   170  		t.Errorf("writePacket after botched rekey succeeded.")
   171  	}
   172  
   173  	readback, err := trS.readPacket()
   174  	if err != nil {
   175  		t.Fatalf("server closed too soon: %v", err)
   176  	}
   177  	if bytes.Compare(readback, packet) != 0 {
   178  		t.Errorf("got %q want %q", readback, packet)
   179  	}
   180  	readback, err = trS.readPacket()
   181  	if err == nil {
   182  		t.Errorf("got a message %q after failed key change", readback)
   183  	}
   184  }
   185  
   186  func TestHandshakeTwice(t *testing.T) {
   187  	checker := &testChecker{}
   188  	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
   189  	if err != nil {
   190  		t.Fatalf("handshakePair: %v", err)
   191  	}
   192  
   193  	defer trC.Close()
   194  	defer trS.Close()
   195  
   196  	// send a packet
   197  	packet := make([]byte, 5)
   198  	packet[0] = msgRequestSuccess
   199  	if err := trC.writePacket(packet); err != nil {
   200  		t.Errorf("writePacket: %v", err)
   201  	}
   202  
   203  	// Now request a key change.
   204  	_, _, err = trC.sendKexInit()
   205  	if err != nil {
   206  		t.Errorf("sendKexInit: %v", err)
   207  	}
   208  
   209  	// Send another packet. Use a fresh one, since writePacket destroys.
   210  	packet = make([]byte, 5)
   211  	packet[0] = msgRequestSuccess
   212  	if err := trC.writePacket(packet); err != nil {
   213  		t.Errorf("writePacket: %v", err)
   214  	}
   215  
   216  	// 2nd key change.
   217  	_, _, err = trC.sendKexInit()
   218  	if err != nil {
   219  		t.Errorf("sendKexInit: %v", err)
   220  	}
   221  
   222  	packet = make([]byte, 5)
   223  	packet[0] = msgRequestSuccess
   224  	if err := trC.writePacket(packet); err != nil {
   225  		t.Errorf("writePacket: %v", err)
   226  	}
   227  
   228  	packet = make([]byte, 5)
   229  	packet[0] = msgRequestSuccess
   230  	for i := 0; i < 5; i++ {
   231  		msg, err := trS.readPacket()
   232  		if err != nil {
   233  			t.Fatalf("server closed too soon: %v", err)
   234  		}
   235  		if msg[0] == msgNewKeys {
   236  			continue
   237  		}
   238  
   239  		if bytes.Compare(msg, packet) != 0 {
   240  			t.Errorf("packet %d: got %q want %q", i, msg, packet)
   241  		}
   242  	}
   243  	if len(checker.calls) != 2 {
   244  		t.Errorf("got %d key changes, want 2", len(checker.calls))
   245  	}
   246  }
   247  
   248  func TestHandshakeAutoRekeyWrite(t *testing.T) {
   249  	checker := &testChecker{}
   250  	clientConf := &ClientConfig{HostKeyCallback: checker.Check}
   251  	clientConf.RekeyThreshold = 500
   252  	trC, trS, err := handshakePair(clientConf, "addr")
   253  	if err != nil {
   254  		t.Fatalf("handshakePair: %v", err)
   255  	}
   256  	defer trC.Close()
   257  	defer trS.Close()
   258  
   259  	for i := 0; i < 5; i++ {
   260  		packet := make([]byte, 251)
   261  		packet[0] = msgRequestSuccess
   262  		if err := trC.writePacket(packet); err != nil {
   263  			t.Errorf("writePacket: %v", err)
   264  		}
   265  	}
   266  
   267  	j := 0
   268  	for ; j < 5; j++ {
   269  		_, err := trS.readPacket()
   270  		if err != nil {
   271  			break
   272  		}
   273  	}
   274  
   275  	if j != 5 {
   276  		t.Errorf("got %d, want 5 messages", j)
   277  	}
   278  
   279  	if len(checker.calls) != 2 {
   280  		t.Errorf("got %d key changes, wanted 2", len(checker.calls))
   281  	}
   282  }
   283  
   284  type syncChecker struct {
   285  	called chan int
   286  }
   287  
   288  func (t *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
   289  	t.called <- 1
   290  	return nil
   291  }
   292  
   293  func TestHandshakeAutoRekeyRead(t *testing.T) {
   294  	sync := &syncChecker{make(chan int, 2)}
   295  	clientConf := &ClientConfig{
   296  		HostKeyCallback: sync.Check,
   297  	}
   298  	clientConf.RekeyThreshold = 500
   299  
   300  	trC, trS, err := handshakePair(clientConf, "addr")
   301  	if err != nil {
   302  		t.Fatalf("handshakePair: %v", err)
   303  	}
   304  	defer trC.Close()
   305  	defer trS.Close()
   306  
   307  	packet := make([]byte, 501)
   308  	packet[0] = msgRequestSuccess
   309  	if err := trS.writePacket(packet); err != nil {
   310  		t.Fatalf("writePacket: %v", err)
   311  	}
   312  	// While we read out the packet, a key change will be
   313  	// initiated.
   314  	if _, err := trC.readPacket(); err != nil {
   315  		t.Fatalf("readPacket(client): %v", err)
   316  	}
   317  
   318  	<-sync.called
   319  }
   320  
   321  // errorKeyingTransport generates errors after a given number of
   322  // read/write operations.
   323  type errorKeyingTransport struct {
   324  	packetConn
   325  	readLeft, writeLeft int
   326  }
   327  
   328  func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error {
   329  	return nil
   330  }
   331  func (n *errorKeyingTransport) getSessionID() []byte {
   332  	return nil
   333  }
   334  
   335  func (n *errorKeyingTransport) writePacket(packet []byte) error {
   336  	if n.writeLeft == 0 {
   337  		n.Close()
   338  		return errors.New("barf")
   339  	}
   340  
   341  	n.writeLeft--
   342  	return n.packetConn.writePacket(packet)
   343  }
   344  
   345  func (n *errorKeyingTransport) readPacket() ([]byte, error) {
   346  	if n.readLeft == 0 {
   347  		n.Close()
   348  		return nil, errors.New("barf")
   349  	}
   350  
   351  	n.readLeft--
   352  	return n.packetConn.readPacket()
   353  }
   354  
   355  func TestHandshakeErrorHandlingRead(t *testing.T) {
   356  	for i := 0; i < 20; i++ {
   357  		testHandshakeErrorHandlingN(t, i, -1)
   358  	}
   359  }
   360  
   361  func TestHandshakeErrorHandlingWrite(t *testing.T) {
   362  	for i := 0; i < 20; i++ {
   363  		testHandshakeErrorHandlingN(t, -1, i)
   364  	}
   365  }
   366  
   367  // testHandshakeErrorHandlingN runs handshakes, injecting errors. If
   368  // handshakeTransport deadlocks, the go runtime will detect it and
   369  // panic.
   370  func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int) {
   371  	msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)})
   372  
   373  	a, b := memPipe()
   374  	defer a.Close()
   375  	defer b.Close()
   376  
   377  	key := testSigners["ecdsa"]
   378  	serverConf := Config{RekeyThreshold: minRekeyThreshold}
   379  	serverConf.SetDefaults()
   380  	serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'})
   381  	serverConn.hostKeys = []Signer{key}
   382  	go serverConn.readLoop()
   383  
   384  	clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold}
   385  	clientConf.SetDefaults()
   386  	clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'})
   387  	clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()}
   388  	go clientConn.readLoop()
   389  
   390  	var wg sync.WaitGroup
   391  	wg.Add(4)
   392  
   393  	for _, hs := range []packetConn{serverConn, clientConn} {
   394  		go func(c packetConn) {
   395  			for {
   396  				err := c.writePacket(msg)
   397  				if err != nil {
   398  					break
   399  				}
   400  			}
   401  			wg.Done()
   402  		}(hs)
   403  		go func(c packetConn) {
   404  			for {
   405  				_, err := c.readPacket()
   406  				if err != nil {
   407  					break
   408  				}
   409  			}
   410  			wg.Done()
   411  		}(hs)
   412  	}
   413  
   414  	wg.Wait()
   415  }