github.com/deis/deis@v1.13.5-0.20170519182049-1d9e59fbdbfc/Godeps/_workspace/src/golang.org/x/crypto/ssh/handshake_test.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssh
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/rand"
    10  	"fmt"
    11  	"net"
    12  	"testing"
    13  )
    14  
    15  type testChecker struct {
    16  	calls []string
    17  }
    18  
    19  func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
    20  	if dialAddr == "bad" {
    21  		return fmt.Errorf("dialAddr is bad")
    22  	}
    23  
    24  	if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil {
    25  		return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr)
    26  	}
    27  
    28  	t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal()))
    29  
    30  	return nil
    31  }
    32  
    33  // netPipe is analogous to net.Pipe, but it uses a real net.Conn, and
    34  // therefore is buffered (net.Pipe deadlocks if both sides start with
    35  // a write.)
    36  func netPipe() (net.Conn, net.Conn, error) {
    37  	listener, err := net.Listen("tcp", "127.0.0.1:0")
    38  	if err != nil {
    39  		return nil, nil, err
    40  	}
    41  	defer listener.Close()
    42  	c1, err := net.Dial("tcp", listener.Addr().String())
    43  	if err != nil {
    44  		return nil, nil, err
    45  	}
    46  
    47  	c2, err := listener.Accept()
    48  	if err != nil {
    49  		c1.Close()
    50  		return nil, nil, err
    51  	}
    52  
    53  	return c1, c2, nil
    54  }
    55  
    56  func handshakePair(clientConf *ClientConfig, addr string) (client *handshakeTransport, server *handshakeTransport, err error) {
    57  	a, b, err := netPipe()
    58  	if err != nil {
    59  		return nil, nil, err
    60  	}
    61  
    62  	trC := newTransport(a, rand.Reader, true)
    63  	trS := newTransport(b, rand.Reader, false)
    64  	clientConf.SetDefaults()
    65  
    66  	v := []byte("version")
    67  	client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr())
    68  
    69  	serverConf := &ServerConfig{}
    70  	serverConf.AddHostKey(testSigners["ecdsa"])
    71  	serverConf.SetDefaults()
    72  	server = newServerTransport(trS, v, v, serverConf)
    73  
    74  	return client, server, nil
    75  }
    76  
    77  func TestHandshakeBasic(t *testing.T) {
    78  	checker := &testChecker{}
    79  	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
    80  	if err != nil {
    81  		t.Fatalf("handshakePair: %v", err)
    82  	}
    83  
    84  	defer trC.Close()
    85  	defer trS.Close()
    86  
    87  	go func() {
    88  		// Client writes a bunch of stuff, and does a key
    89  		// change in the middle. This should not confuse the
    90  		// handshake in progress
    91  		for i := 0; i < 10; i++ {
    92  			p := []byte{msgRequestSuccess, byte(i)}
    93  			if err := trC.writePacket(p); err != nil {
    94  				t.Fatalf("sendPacket: %v", err)
    95  			}
    96  			if i == 5 {
    97  				// halfway through, we request a key change.
    98  				_, _, err := trC.sendKexInit()
    99  				if err != nil {
   100  					t.Fatalf("sendKexInit: %v", err)
   101  				}
   102  			}
   103  		}
   104  		trC.Close()
   105  	}()
   106  
   107  	// Server checks that client messages come in cleanly
   108  	i := 0
   109  	for {
   110  		p, err := trS.readPacket()
   111  		if err != nil {
   112  			break
   113  		}
   114  		if p[0] == msgNewKeys {
   115  			continue
   116  		}
   117  		want := []byte{msgRequestSuccess, byte(i)}
   118  		if bytes.Compare(p, want) != 0 {
   119  			t.Errorf("message %d: got %q, want %q", i, p, want)
   120  		}
   121  		i++
   122  	}
   123  	if i != 10 {
   124  		t.Errorf("received %d messages, want 10.", i)
   125  	}
   126  
   127  	// If all went well, we registered exactly 1 key change.
   128  	if len(checker.calls) != 1 {
   129  		t.Fatalf("got %d host key checks, want 1", len(checker.calls))
   130  	}
   131  
   132  	pub := testSigners["ecdsa"].PublicKey()
   133  	want := fmt.Sprintf("%s %v %s %x", "addr", trC.remoteAddr, pub.Type(), pub.Marshal())
   134  	if want != checker.calls[0] {
   135  		t.Errorf("got %q want %q for host key check", checker.calls[0], want)
   136  	}
   137  }
   138  
   139  func TestHandshakeError(t *testing.T) {
   140  	checker := &testChecker{}
   141  	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "bad")
   142  	if err != nil {
   143  		t.Fatalf("handshakePair: %v", err)
   144  	}
   145  	defer trC.Close()
   146  	defer trS.Close()
   147  
   148  	// send a packet
   149  	packet := []byte{msgRequestSuccess, 42}
   150  	if err := trC.writePacket(packet); err != nil {
   151  		t.Errorf("writePacket: %v", err)
   152  	}
   153  
   154  	// Now request a key change.
   155  	_, _, err = trC.sendKexInit()
   156  	if err != nil {
   157  		t.Errorf("sendKexInit: %v", err)
   158  	}
   159  
   160  	// the key change will fail, and afterwards we can't write.
   161  	if err := trC.writePacket([]byte{msgRequestSuccess, 43}); err == nil {
   162  		t.Errorf("writePacket after botched rekey succeeded.")
   163  	}
   164  
   165  	readback, err := trS.readPacket()
   166  	if err != nil {
   167  		t.Fatalf("server closed too soon: %v", err)
   168  	}
   169  	if bytes.Compare(readback, packet) != 0 {
   170  		t.Errorf("got %q want %q", readback, packet)
   171  	}
   172  	readback, err = trS.readPacket()
   173  	if err == nil {
   174  		t.Errorf("got a message %q after failed key change", readback)
   175  	}
   176  }
   177  
   178  func TestHandshakeTwice(t *testing.T) {
   179  	checker := &testChecker{}
   180  	trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr")
   181  	if err != nil {
   182  		t.Fatalf("handshakePair: %v", err)
   183  	}
   184  
   185  	defer trC.Close()
   186  	defer trS.Close()
   187  
   188  	// send a packet
   189  	packet := make([]byte, 5)
   190  	packet[0] = msgRequestSuccess
   191  	if err := trC.writePacket(packet); err != nil {
   192  		t.Errorf("writePacket: %v", err)
   193  	}
   194  
   195  	// Now request a key change.
   196  	_, _, err = trC.sendKexInit()
   197  	if err != nil {
   198  		t.Errorf("sendKexInit: %v", err)
   199  	}
   200  
   201  	// Send another packet. Use a fresh one, since writePacket destroys.
   202  	packet = make([]byte, 5)
   203  	packet[0] = msgRequestSuccess
   204  	if err := trC.writePacket(packet); err != nil {
   205  		t.Errorf("writePacket: %v", err)
   206  	}
   207  
   208  	// 2nd key change.
   209  	_, _, err = trC.sendKexInit()
   210  	if err != nil {
   211  		t.Errorf("sendKexInit: %v", err)
   212  	}
   213  
   214  	packet = make([]byte, 5)
   215  	packet[0] = msgRequestSuccess
   216  	if err := trC.writePacket(packet); err != nil {
   217  		t.Errorf("writePacket: %v", err)
   218  	}
   219  
   220  	packet = make([]byte, 5)
   221  	packet[0] = msgRequestSuccess
   222  	for i := 0; i < 5; i++ {
   223  		msg, err := trS.readPacket()
   224  		if err != nil {
   225  			t.Fatalf("server closed too soon: %v", err)
   226  		}
   227  		if msg[0] == msgNewKeys {
   228  			continue
   229  		}
   230  
   231  		if bytes.Compare(msg, packet) != 0 {
   232  			t.Errorf("packet %d: got %q want %q", i, msg, packet)
   233  		}
   234  	}
   235  	if len(checker.calls) != 2 {
   236  		t.Errorf("got %d key changes, want 2", len(checker.calls))
   237  	}
   238  }
   239  
   240  func TestHandshakeAutoRekeyWrite(t *testing.T) {
   241  	checker := &testChecker{}
   242  	clientConf := &ClientConfig{HostKeyCallback: checker.Check}
   243  	clientConf.RekeyThreshold = 500
   244  	trC, trS, err := handshakePair(clientConf, "addr")
   245  	if err != nil {
   246  		t.Fatalf("handshakePair: %v", err)
   247  	}
   248  	defer trC.Close()
   249  	defer trS.Close()
   250  
   251  	for i := 0; i < 5; i++ {
   252  		packet := make([]byte, 251)
   253  		packet[0] = msgRequestSuccess
   254  		if err := trC.writePacket(packet); err != nil {
   255  			t.Errorf("writePacket: %v", err)
   256  		}
   257  	}
   258  
   259  	j := 0
   260  	for ; j < 5; j++ {
   261  		_, err := trS.readPacket()
   262  		if err != nil {
   263  			break
   264  		}
   265  	}
   266  
   267  	if j != 5 {
   268  		t.Errorf("got %d, want 5 messages", j)
   269  	}
   270  
   271  	if len(checker.calls) != 2 {
   272  		t.Errorf("got %d key changes, wanted 2", len(checker.calls))
   273  	}
   274  }
   275  
   276  type syncChecker struct {
   277  	called chan int
   278  }
   279  
   280  func (t *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error {
   281  	t.called <- 1
   282  	return nil
   283  }
   284  
   285  func TestHandshakeAutoRekeyRead(t *testing.T) {
   286  	sync := &syncChecker{make(chan int, 2)}
   287  	clientConf := &ClientConfig{
   288  		HostKeyCallback: sync.Check,
   289  	}
   290  	clientConf.RekeyThreshold = 500
   291  
   292  	trC, trS, err := handshakePair(clientConf, "addr")
   293  	if err != nil {
   294  		t.Fatalf("handshakePair: %v", err)
   295  	}
   296  	defer trC.Close()
   297  	defer trS.Close()
   298  
   299  	packet := make([]byte, 501)
   300  	packet[0] = msgRequestSuccess
   301  	if err := trS.writePacket(packet); err != nil {
   302  		t.Fatalf("writePacket: %v", err)
   303  	}
   304  	// While we read out the packet, a key change will be
   305  	// initiated.
   306  	if _, err := trC.readPacket(); err != nil {
   307  		t.Fatalf("readPacket(client): %v", err)
   308  	}
   309  
   310  	<-sync.called
   311  }