github.com/glycerine/xcryptossh@v7.0.4+incompatible/handshake_test.go (about) 1 // Copyright 2013 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package ssh 6 7 import ( 8 "bytes" 9 "context" 10 "crypto/rand" 11 "errors" 12 "fmt" 13 "io" 14 "net" 15 "reflect" 16 "runtime" 17 "strings" 18 "sync" 19 "testing" 20 ) 21 22 type testChecker struct { 23 calls []string 24 } 25 26 func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { 27 if dialAddr == "bad" { 28 return fmt.Errorf("dialAddr is bad") 29 } 30 31 if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil { 32 return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr) 33 } 34 35 t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal())) 36 37 return nil 38 } 39 40 // netPipe is analogous to net.Pipe, but it uses a real net.Conn, and 41 // therefore is buffered (net.Pipe deadlocks if both sides start with 42 // a write.) 43 func netPipe() (net.Conn, net.Conn, error) { 44 listener, err := net.Listen("tcp", "127.0.0.1:0") 45 if err != nil { 46 listener, err = net.Listen("tcp", "[::1]:0") 47 if err != nil { 48 return nil, nil, err 49 } 50 } 51 defer listener.Close() 52 c1, err := net.Dial("tcp", listener.Addr().String()) 53 if err != nil { 54 return nil, nil, err 55 } 56 57 c2, err := listener.Accept() 58 if err != nil { 59 c1.Close() 60 return nil, nil, err 61 } 62 63 return c1, c2, nil 64 } 65 66 // noiseTransport inserts ignore messages to check that the read loop 67 // and the key exchange filters out these messages. 68 type noiseTransport struct { 69 keyingTransport 70 } 71 72 func (t *noiseTransport) writePacket(p []byte) error { 73 ignore := []byte{msgIgnore} 74 if err := t.keyingTransport.writePacket(ignore); err != nil { 75 return err 76 } 77 debug := []byte{msgDebug, 1, 2, 3} 78 if err := t.keyingTransport.writePacket(debug); err != nil { 79 return err 80 } 81 82 return t.keyingTransport.writePacket(p) 83 } 84 85 func addNoiseTransport(t keyingTransport) keyingTransport { 86 return &noiseTransport{t} 87 } 88 89 // handshakePair creates two handshakeTransports connected with each 90 // other. If the noise argument is true, both transports will try to 91 // confuse the other side by sending ignore and debug messages. 92 func handshakePair(clientConf *ClientConfig, addr string, noise bool) (client *handshakeTransport, server *handshakeTransport, err error) { 93 a, b, err := netPipe() 94 if err != nil { 95 return nil, nil, err 96 } 97 98 var trC, trS keyingTransport 99 100 trC = newTransport(a, rand.Reader, true, &clientConf.Config) 101 trS = newTransport(b, rand.Reader, false, &clientConf.Config) 102 if noise { 103 trC = addNoiseTransport(trC) 104 trS = addNoiseTransport(trS) 105 } 106 clientConf.SetDefaults() 107 108 v := []byte("version") 109 ctx := context.Background() 110 111 client = newClientTransport(ctx, trC, v, v, clientConf, addr, a.RemoteAddr()) 112 113 serverConf := &ServerConfig{Config: Config{Halt: clientConf.Halt}} 114 serverConf.AddHostKey(testSigners["ecdsa"]) 115 serverConf.AddHostKey(testSigners["rsa"]) 116 serverConf.SetDefaults() 117 server = newServerTransport(ctx, trS, v, v, serverConf) 118 if server == nil { 119 return nil, nil, fmt.Errorf("ssh: shutting down.") 120 } 121 if err := server.waitSession(ctx); err != nil { 122 return nil, nil, fmt.Errorf("server.waitSession: %v", err) 123 } 124 if err := client.waitSession(ctx); err != nil { 125 return nil, nil, fmt.Errorf("client.waitSession: %v", err) 126 } 127 128 return client, server, nil 129 } 130 131 func TestHandshakeBasic(t *testing.T) { 132 defer xtestend(xtestbegin(t)) 133 134 if runtime.GOOS == "plan9" { 135 t.Skip("see golang.org/issue/7237") 136 } 137 138 checker := &syncChecker{ 139 waitCall: make(chan int, 10), 140 called: make(chan int, 10), 141 } 142 143 halt := NewHalter() 144 defer halt.RequestStop() 145 checker.waitCall <- 1 146 trC, trS, err := handshakePair( 147 &ClientConfig{ 148 HostKeyCallback: checker.Check, 149 Config: Config{ 150 Halt: halt}, 151 }, "addr", false) 152 153 if err != nil { 154 t.Fatalf("handshakePair: %v", err) 155 } 156 157 defer trC.Close() 158 defer trS.Close() 159 160 // Let first kex complete normally. 161 <-checker.called 162 163 clientDone := make(chan int, 0) 164 gotHalf := make(chan int, 0) 165 const N = 20 166 167 go func() { 168 defer close(clientDone) 169 // Client writes a bunch of stuff, and does a key 170 // change in the middle. This should not confuse the 171 // handshake in progress. We do this twice, so we test 172 // that the packet buffer is reset correctly. 173 for i := 0; i < N; i++ { 174 p := []byte{msgRequestSuccess, byte(i)} 175 if err := trC.writePacket(p); err != nil { 176 t.Fatalf("sendPacket: %v", err) 177 } 178 if (i % 10) == 5 { 179 <-gotHalf 180 // halfway through, we request a key change. 181 trC.requestKeyExchange() 182 183 // Wait until we can be sure the key 184 // change has really started before we 185 // write more. 186 <-checker.called 187 } 188 if (i % 10) == 7 { 189 // write some packets until the kex 190 // completes, to test buffering of 191 // packets. 192 checker.waitCall <- 1 193 } 194 } 195 }() 196 197 ctx := context.Background() 198 199 // Server checks that client messages come in cleanly 200 i := 0 201 err = nil 202 for ; i < N; i++ { 203 var p []byte 204 p, err = trS.readPacket(ctx) 205 if err != nil { 206 break 207 } 208 if (i % 10) == 5 { 209 gotHalf <- 1 210 } 211 212 want := []byte{msgRequestSuccess, byte(i)} 213 if bytes.Compare(p, want) != 0 { 214 t.Errorf("message %d: got %v, want %v", i, p, want) 215 } 216 } 217 <-clientDone 218 if err != nil && err != io.EOF { 219 t.Fatalf("server error: %v", err) 220 } 221 if i != N { 222 t.Errorf("received %d messages, want 10.", i) 223 } 224 225 close(checker.called) 226 if _, ok := <-checker.called; ok { 227 // If all went well, we registered exactly 2 key changes: one 228 // that establishes the session, and one that we requested 229 // additionally. 230 t.Fatalf("got another host key checks after 2 handshakes") 231 } 232 } 233 234 func TestForceFirstKex(t *testing.T) { 235 defer xtestend(xtestbegin(t)) 236 237 halt := NewHalter() 238 defer halt.RequestStop() 239 240 // like handshakePair, but must access the keyingTransport. 241 checker := &testChecker{} 242 clientConf := &ClientConfig{ 243 HostKeyCallback: checker.Check, 244 Config: Config{ 245 Halt: halt, 246 }, 247 } 248 a, b, err := netPipe() 249 if err != nil { 250 t.Fatalf("netPipe: %v", err) 251 } 252 253 var trC, trS keyingTransport 254 255 trC = newTransport(a, rand.Reader, true, nil) 256 257 // This is the disallowed packet: 258 trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})) 259 260 // Rest of the setup. 261 trS = newTransport(b, rand.Reader, false, nil) 262 clientConf.SetDefaults() 263 264 v := []byte("version") 265 ctx := context.Background() 266 client := newClientTransport(ctx, trC, v, v, clientConf, "addr", a.RemoteAddr()) 267 268 serverConf := &ServerConfig{Config: Config{Halt: halt}} 269 serverConf.AddHostKey(testSigners["ecdsa"]) 270 serverConf.AddHostKey(testSigners["rsa"]) 271 serverConf.SetDefaults() 272 server := newServerTransport(ctx, trS, v, v, serverConf) 273 if server == nil { 274 // shut down 275 return 276 } 277 278 defer client.Close() 279 defer server.Close() 280 281 // We setup the initial key exchange, but the remote side 282 // tries to send serviceRequestMsg in cleartext, which is 283 // disallowed. 284 285 if err := server.waitSession(ctx); err == nil { 286 t.Errorf("server first kex init should reject unexpected packet") 287 } 288 } 289 290 func TestHandshakeAutoRekeyWrite(t *testing.T) { 291 defer xtestend(xtestbegin(t)) 292 293 halt := NewHalter() 294 defer halt.RequestStop() 295 checker := &syncChecker{ 296 called: make(chan int, 10), 297 waitCall: nil, 298 } 299 clientConf := &ClientConfig{ 300 HostKeyCallback: checker.Check, 301 Config: Config{Halt: halt}, 302 } 303 clientConf.RekeyThreshold = 500 304 trC, trS, err := handshakePair(clientConf, "addr", false) 305 if err != nil { 306 t.Fatalf("handshakePair: %v", err) 307 } 308 defer trC.Close() 309 defer trS.Close() 310 311 input := make([]byte, 251) 312 input[0] = msgRequestSuccess 313 ctx := context.Background() 314 315 done := make(chan int, 1) 316 const numPacket = 5 317 go func() { 318 defer close(done) 319 j := 0 320 for ; j < numPacket; j++ { 321 if p, err := trS.readPacket(ctx); err != nil { 322 break 323 } else if !bytes.Equal(input, p) { 324 t.Errorf("got packet type %d, want %d", p[0], input[0]) 325 } 326 } 327 328 if j != numPacket { 329 t.Errorf("got %d, want 5 messages", j) 330 } 331 }() 332 333 <-checker.called 334 335 for i := 0; i < numPacket; i++ { 336 p := make([]byte, len(input)) 337 copy(p, input) 338 if err := trC.writePacket(p); err != nil { 339 t.Errorf("writePacket: %v", err) 340 } 341 if i == 2 { 342 // Make sure the kex is in progress. 343 <-checker.called 344 } 345 346 } 347 <-done 348 } 349 350 type syncChecker struct { 351 waitCall chan int 352 called chan int 353 } 354 355 func (c *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { 356 c.called <- 1 357 if c.waitCall != nil { 358 <-c.waitCall 359 } 360 return nil 361 } 362 363 func TestHandshakeAutoRekeyRead(t *testing.T) { 364 defer xtestend(xtestbegin(t)) 365 366 halt := NewHalter() 367 defer halt.RequestStop() 368 369 sync := &syncChecker{ 370 called: make(chan int, 2), 371 waitCall: nil, 372 } 373 clientConf := &ClientConfig{ 374 HostKeyCallback: sync.Check, 375 Config: Config{Halt: halt}, 376 } 377 clientConf.RekeyThreshold = 500 378 379 trC, trS, err := handshakePair(clientConf, "addr", false) 380 if err != nil { 381 t.Fatalf("handshakePair: %v", err) 382 } 383 defer trC.Close() 384 defer trS.Close() 385 386 packet := make([]byte, 501) 387 packet[0] = msgRequestSuccess 388 if err := trS.writePacket(packet); err != nil { 389 t.Fatalf("writePacket: %v", err) 390 } 391 392 // While we read out the packet, a key change will be 393 // initiated. 394 done := make(chan int, 1) 395 ctx := context.Background() 396 397 go func() { 398 defer close(done) 399 if _, err := trC.readPacket(ctx); err != nil { 400 t.Fatalf("readPacket(client): %v", err) 401 } 402 403 }() 404 405 <-done 406 <-sync.called 407 } 408 409 // errorKeyingTransport generates errors after a given number of 410 // read/write operations. 411 type errorKeyingTransport struct { 412 packetConn 413 readLeft, writeLeft int 414 } 415 416 func (n *errorKeyingTransport) prepareKeyChange(context.Context, *algorithms, *kexResult, *Config) error { 417 return nil 418 } 419 420 func (n *errorKeyingTransport) getSessionID() []byte { 421 return nil 422 } 423 424 func (n *errorKeyingTransport) writePacket(packet []byte) error { 425 if n.writeLeft == 0 { 426 n.Close() 427 return errors.New("barf") 428 } 429 430 n.writeLeft-- 431 return n.packetConn.writePacket(packet) 432 } 433 434 func (n *errorKeyingTransport) readPacket(ctx context.Context) ([]byte, error) { 435 if n.readLeft == 0 { 436 n.Close() 437 return nil, errors.New("barf") 438 } 439 440 n.readLeft-- 441 442 return n.packetConn.readPacket(ctx) 443 } 444 445 func TestHandshakeErrorHandlingRead(t *testing.T) { 446 defer xtestend(xtestbegin(t)) 447 448 for i := 0; i < 20; i++ { 449 testHandshakeErrorHandlingN(t, i, -1, false) 450 } 451 } 452 453 func TestHandshakeErrorHandlingWrite(t *testing.T) { 454 defer xtestend(xtestbegin(t)) 455 456 for i := 0; i < 20; i++ { 457 testHandshakeErrorHandlingN(t, -1, i, false) 458 } 459 } 460 461 func TestHandshakeErrorHandlingReadCoupled(t *testing.T) { 462 defer xtestend(xtestbegin(t)) 463 464 for i := 0; i < 20; i++ { 465 testHandshakeErrorHandlingN(t, i, -1, true) 466 } 467 } 468 469 func TestHandshakeErrorHandlingWriteCoupled(t *testing.T) { 470 defer xtestend(xtestbegin(t)) 471 472 for i := 0; i < 20; i++ { 473 testHandshakeErrorHandlingN(t, -1, i, true) 474 } 475 } 476 477 // testHandshakeErrorHandlingN runs handshakes, injecting errors. If 478 // handshakeTransport deadlocks, the go runtime will detect it and 479 // panic. 480 func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int, coupled bool) { 481 msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)}) 482 483 a, b := memPipe() 484 defer a.Close() 485 defer b.Close() 486 487 key := testSigners["ecdsa"] 488 serverConf := Config{RekeyThreshold: minRekeyThreshold} 489 serverConf.SetDefaults() 490 defer serverConf.Halt.RequestStop() 491 ctx := context.Background() 492 serverConn := newHandshakeTransport(ctx, &errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'}) 493 serverConn.hostKeys = []Signer{key} 494 go serverConn.readLoop(ctx) 495 go serverConn.kexLoop(ctx) 496 497 clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold} 498 clientConf.SetDefaults() 499 defer clientConf.Halt.RequestStop() 500 clientConn := newHandshakeTransport(ctx, &errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'}) 501 clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()} 502 clientConn.hostKeyCallback = InsecureIgnoreHostKey() 503 go clientConn.readLoop(ctx) 504 go clientConn.kexLoop(ctx) 505 506 var wg sync.WaitGroup 507 508 for _, hs := range []packetConn{serverConn, clientConn} { 509 if !coupled { 510 wg.Add(2) 511 go func(c packetConn) { 512 for i := 0; ; i++ { 513 str := fmt.Sprintf("%08x", i) + strings.Repeat("x", int(minRekeyThreshold)/4-8) 514 err := c.writePacket(Marshal(&serviceRequestMsg{str})) 515 if err != nil { 516 break 517 } 518 } 519 wg.Done() 520 c.Close() 521 }(hs) 522 go func(c packetConn) { 523 for { 524 _, err := c.readPacket(ctx) 525 if err != nil { 526 break 527 } 528 } 529 wg.Done() 530 }(hs) 531 } else { 532 wg.Add(1) 533 go func(c packetConn) { 534 for { 535 _, err := c.readPacket(ctx) 536 if err != nil { 537 break 538 } 539 if err := c.writePacket(msg); err != nil { 540 break 541 } 542 543 } 544 wg.Done() 545 }(hs) 546 } 547 } 548 wg.Wait() 549 } 550 551 func TestDisconnect(t *testing.T) { 552 defer xtestend(xtestbegin(t)) 553 554 halt := NewHalter() 555 defer halt.RequestStop() 556 557 if runtime.GOOS == "plan9" { 558 t.Skip("see golang.org/issue/7237") 559 } 560 checker := &testChecker{} 561 trC, trS, err := handshakePair( 562 &ClientConfig{ 563 HostKeyCallback: checker.Check, 564 Config: Config{Halt: halt}, 565 }, "addr", false) 566 if err != nil { 567 t.Fatalf("handshakePair: %v", err) 568 } 569 570 defer trC.Close() 571 defer trS.Close() 572 573 trC.writePacket([]byte{msgRequestSuccess, 0, 0}) 574 errMsg := &disconnectMsg{ 575 Reason: 42, 576 Message: "such is life", 577 } 578 trC.writePacket(Marshal(errMsg)) 579 trC.writePacket([]byte{msgRequestSuccess, 0, 0}) 580 ctx := context.Background() 581 packet, err := trS.readPacket(ctx) 582 if err != nil { 583 t.Fatalf("readPacket 1: %v", err) 584 } 585 if packet[0] != msgRequestSuccess { 586 t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess) 587 } 588 589 _, err = trS.readPacket(ctx) 590 if err == nil { 591 t.Errorf("readPacket 2 succeeded") 592 } else if !reflect.DeepEqual(err, errMsg) { 593 t.Errorf("got error %#v, want %#v", err, errMsg) 594 } 595 596 _, err = trS.readPacket(ctx) 597 if err == nil { 598 t.Errorf("readPacket 3 succeeded") 599 } 600 } 601 602 func TestHandshakeRekeyDefault(t *testing.T) { 603 defer xtestend(xtestbegin(t)) 604 605 halt := NewHalter() 606 defer halt.RequestStop() 607 608 clientConf := &ClientConfig{ 609 Config: Config{ 610 Ciphers: []string{"aes128-ctr"}, 611 Halt: halt, 612 }, 613 HostKeyCallback: InsecureIgnoreHostKey(), 614 } 615 trC, trS, err := handshakePair(clientConf, "addr", false) 616 if err != nil { 617 t.Fatalf("handshakePair: %v", err) 618 } 619 defer trC.Close() 620 defer trS.Close() 621 622 trC.writePacket([]byte{msgRequestSuccess, 0, 0}) 623 trC.Close() 624 625 rgb := (1024 + trC.readBytesLeft) >> 30 626 wgb := (1024 + trC.writeBytesLeft) >> 30 627 628 if rgb != 64 { 629 t.Errorf("got rekey after %dG read, want 64G", rgb) 630 } 631 if wgb != 64 { 632 t.Errorf("got rekey after %dG write, want 64G", wgb) 633 } 634 }