github.com/psiphon-Labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/crypto/ssh/handshake_test.go (about) 1 // Copyright 2013 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package ssh 6 7 import ( 8 "bytes" 9 "crypto/rand" 10 "errors" 11 "fmt" 12 "io" 13 "net" 14 "reflect" 15 "runtime" 16 "strings" 17 "sync" 18 "testing" 19 ) 20 21 type testChecker struct { 22 calls []string 23 } 24 25 func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { 26 if dialAddr == "bad" { 27 return fmt.Errorf("dialAddr is bad") 28 } 29 30 if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil { 31 return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr) 32 } 33 34 t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal())) 35 36 return nil 37 } 38 39 // netPipe is analogous to net.Pipe, but it uses a real net.Conn, and 40 // therefore is buffered (net.Pipe deadlocks if both sides start with 41 // a write.) 42 func netPipe() (net.Conn, net.Conn, error) { 43 listener, err := net.Listen("tcp", "127.0.0.1:0") 44 if err != nil { 45 listener, err = net.Listen("tcp", "[::1]:0") 46 if err != nil { 47 return nil, nil, err 48 } 49 } 50 defer listener.Close() 51 c1, err := net.Dial("tcp", listener.Addr().String()) 52 if err != nil { 53 return nil, nil, err 54 } 55 56 c2, err := listener.Accept() 57 if err != nil { 58 c1.Close() 59 return nil, nil, err 60 } 61 62 return c1, c2, nil 63 } 64 65 // noiseTransport inserts ignore messages to check that the read loop 66 // and the key exchange filters out these messages. 67 type noiseTransport struct { 68 keyingTransport 69 } 70 71 func (t *noiseTransport) writePacket(p []byte) error { 72 ignore := []byte{msgIgnore} 73 if err := t.keyingTransport.writePacket(ignore); err != nil { 74 return err 75 } 76 debug := []byte{msgDebug, 1, 2, 3} 77 if err := t.keyingTransport.writePacket(debug); err != nil { 78 return err 79 } 80 81 return t.keyingTransport.writePacket(p) 82 } 83 84 func addNoiseTransport(t keyingTransport) keyingTransport { 85 return &noiseTransport{t} 86 } 87 88 // handshakePair creates two handshakeTransports connected with each 89 // other. If the noise argument is true, both transports will try to 90 // confuse the other side by sending ignore and debug messages. 91 func handshakePair(clientConf *ClientConfig, addr string, noise bool) (client *handshakeTransport, server *handshakeTransport, err error) { 92 a, b, err := netPipe() 93 if err != nil { 94 return nil, nil, err 95 } 96 97 var trC, trS keyingTransport 98 99 trC = newTransport(a, rand.Reader, true) 100 trS = newTransport(b, rand.Reader, false) 101 if noise { 102 trC = addNoiseTransport(trC) 103 trS = addNoiseTransport(trS) 104 } 105 clientConf.SetDefaults() 106 107 v := []byte("version") 108 client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr()) 109 110 serverConf := &ServerConfig{} 111 serverConf.AddHostKey(testSigners["ecdsa"]) 112 serverConf.AddHostKey(testSigners["rsa"]) 113 serverConf.SetDefaults() 114 server = newServerTransport(trS, v, v, serverConf) 115 116 if err := server.waitSession(); err != nil { 117 return nil, nil, fmt.Errorf("server.waitSession: %v", err) 118 } 119 if err := client.waitSession(); err != nil { 120 return nil, nil, fmt.Errorf("client.waitSession: %v", err) 121 } 122 123 return client, server, nil 124 } 125 126 func TestHandshakeBasic(t *testing.T) { 127 if runtime.GOOS == "plan9" { 128 t.Skip("see golang.org/issue/7237") 129 } 130 131 checker := &syncChecker{ 132 waitCall: make(chan int, 10), 133 called: make(chan int, 10), 134 } 135 136 checker.waitCall <- 1 137 trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false) 138 if err != nil { 139 t.Fatalf("handshakePair: %v", err) 140 } 141 142 defer trC.Close() 143 defer trS.Close() 144 145 // Let first kex complete normally. 146 <-checker.called 147 148 clientDone := make(chan int, 0) 149 gotHalf := make(chan int, 0) 150 const N = 20 151 152 go func() { 153 defer close(clientDone) 154 // Client writes a bunch of stuff, and does a key 155 // change in the middle. This should not confuse the 156 // handshake in progress. We do this twice, so we test 157 // that the packet buffer is reset correctly. 158 for i := 0; i < N; i++ { 159 p := []byte{msgRequestSuccess, byte(i)} 160 if err := trC.writePacket(p); err != nil { 161 t.Fatalf("sendPacket: %v", err) 162 } 163 if (i % 10) == 5 { 164 <-gotHalf 165 // halfway through, we request a key change. 166 trC.requestKeyExchange() 167 168 // Wait until we can be sure the key 169 // change has really started before we 170 // write more. 171 <-checker.called 172 } 173 if (i % 10) == 7 { 174 // write some packets until the kex 175 // completes, to test buffering of 176 // packets. 177 checker.waitCall <- 1 178 } 179 } 180 }() 181 182 // Server checks that client messages come in cleanly 183 i := 0 184 err = nil 185 for ; i < N; i++ { 186 var p []byte 187 p, err = trS.readPacket() 188 if err != nil { 189 break 190 } 191 if (i % 10) == 5 { 192 gotHalf <- 1 193 } 194 195 want := []byte{msgRequestSuccess, byte(i)} 196 if bytes.Compare(p, want) != 0 { 197 t.Errorf("message %d: got %v, want %v", i, p, want) 198 } 199 } 200 <-clientDone 201 if err != nil && err != io.EOF { 202 t.Fatalf("server error: %v", err) 203 } 204 if i != N { 205 t.Errorf("received %d messages, want 10.", i) 206 } 207 208 close(checker.called) 209 if _, ok := <-checker.called; ok { 210 // If all went well, we registered exactly 2 key changes: one 211 // that establishes the session, and one that we requested 212 // additionally. 213 t.Fatalf("got another host key checks after 2 handshakes") 214 } 215 } 216 217 func TestForceFirstKex(t *testing.T) { 218 // like handshakePair, but must access the keyingTransport. 219 checker := &testChecker{} 220 clientConf := &ClientConfig{HostKeyCallback: checker.Check} 221 a, b, err := netPipe() 222 if err != nil { 223 t.Fatalf("netPipe: %v", err) 224 } 225 226 var trC, trS keyingTransport 227 228 trC = newTransport(a, rand.Reader, true) 229 230 // This is the disallowed packet: 231 trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})) 232 233 // Rest of the setup. 234 trS = newTransport(b, rand.Reader, false) 235 clientConf.SetDefaults() 236 237 v := []byte("version") 238 client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr()) 239 240 serverConf := &ServerConfig{} 241 serverConf.AddHostKey(testSigners["ecdsa"]) 242 serverConf.AddHostKey(testSigners["rsa"]) 243 serverConf.SetDefaults() 244 server := newServerTransport(trS, v, v, serverConf) 245 246 defer client.Close() 247 defer server.Close() 248 249 // We setup the initial key exchange, but the remote side 250 // tries to send serviceRequestMsg in cleartext, which is 251 // disallowed. 252 253 if err := server.waitSession(); err == nil { 254 t.Errorf("server first kex init should reject unexpected packet") 255 } 256 } 257 258 func TestHandshakeAutoRekeyWrite(t *testing.T) { 259 checker := &syncChecker{ 260 called: make(chan int, 10), 261 waitCall: nil, 262 } 263 clientConf := &ClientConfig{HostKeyCallback: checker.Check} 264 clientConf.RekeyThreshold = 500 265 trC, trS, err := handshakePair(clientConf, "addr", false) 266 if err != nil { 267 t.Fatalf("handshakePair: %v", err) 268 } 269 defer trC.Close() 270 defer trS.Close() 271 272 input := make([]byte, 251) 273 input[0] = msgRequestSuccess 274 275 done := make(chan int, 1) 276 const numPacket = 5 277 go func() { 278 defer close(done) 279 j := 0 280 for ; j < numPacket; j++ { 281 if p, err := trS.readPacket(); err != nil { 282 break 283 } else if !bytes.Equal(input, p) { 284 t.Errorf("got packet type %d, want %d", p[0], input[0]) 285 } 286 } 287 288 if j != numPacket { 289 t.Errorf("got %d, want 5 messages", j) 290 } 291 }() 292 293 <-checker.called 294 295 for i := 0; i < numPacket; i++ { 296 p := make([]byte, len(input)) 297 copy(p, input) 298 if err := trC.writePacket(p); err != nil { 299 t.Errorf("writePacket: %v", err) 300 } 301 if i == 2 { 302 // Make sure the kex is in progress. 303 <-checker.called 304 } 305 306 } 307 <-done 308 } 309 310 type syncChecker struct { 311 waitCall chan int 312 called chan int 313 } 314 315 func (c *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { 316 c.called <- 1 317 if c.waitCall != nil { 318 <-c.waitCall 319 } 320 return nil 321 } 322 323 func TestHandshakeAutoRekeyRead(t *testing.T) { 324 sync := &syncChecker{ 325 called: make(chan int, 2), 326 waitCall: nil, 327 } 328 clientConf := &ClientConfig{ 329 HostKeyCallback: sync.Check, 330 } 331 clientConf.RekeyThreshold = 500 332 333 trC, trS, err := handshakePair(clientConf, "addr", false) 334 if err != nil { 335 t.Fatalf("handshakePair: %v", err) 336 } 337 defer trC.Close() 338 defer trS.Close() 339 340 packet := make([]byte, 501) 341 packet[0] = msgRequestSuccess 342 if err := trS.writePacket(packet); err != nil { 343 t.Fatalf("writePacket: %v", err) 344 } 345 346 // While we read out the packet, a key change will be 347 // initiated. 348 done := make(chan int, 1) 349 go func() { 350 defer close(done) 351 if _, err := trC.readPacket(); err != nil { 352 t.Fatalf("readPacket(client): %v", err) 353 } 354 355 }() 356 357 <-done 358 <-sync.called 359 } 360 361 // errorKeyingTransport generates errors after a given number of 362 // read/write operations. 363 type errorKeyingTransport struct { 364 packetConn 365 readLeft, writeLeft int 366 } 367 368 func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error { 369 return nil 370 } 371 372 func (n *errorKeyingTransport) getSessionID() []byte { 373 return nil 374 } 375 376 func (n *errorKeyingTransport) writePacket(packet []byte) error { 377 if n.writeLeft == 0 { 378 n.Close() 379 return errors.New("barf") 380 } 381 382 n.writeLeft-- 383 return n.packetConn.writePacket(packet) 384 } 385 386 func (n *errorKeyingTransport) readPacket() ([]byte, error) { 387 if n.readLeft == 0 { 388 n.Close() 389 return nil, errors.New("barf") 390 } 391 392 n.readLeft-- 393 return n.packetConn.readPacket() 394 } 395 396 func TestHandshakeErrorHandlingRead(t *testing.T) { 397 for i := 0; i < 20; i++ { 398 testHandshakeErrorHandlingN(t, i, -1, false) 399 } 400 } 401 402 func TestHandshakeErrorHandlingWrite(t *testing.T) { 403 for i := 0; i < 20; i++ { 404 testHandshakeErrorHandlingN(t, -1, i, false) 405 } 406 } 407 408 func TestHandshakeErrorHandlingReadCoupled(t *testing.T) { 409 for i := 0; i < 20; i++ { 410 testHandshakeErrorHandlingN(t, i, -1, true) 411 } 412 } 413 414 func TestHandshakeErrorHandlingWriteCoupled(t *testing.T) { 415 for i := 0; i < 20; i++ { 416 testHandshakeErrorHandlingN(t, -1, i, true) 417 } 418 } 419 420 // testHandshakeErrorHandlingN runs handshakes, injecting errors. If 421 // handshakeTransport deadlocks, the go runtime will detect it and 422 // panic. 423 func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int, coupled bool) { 424 if runtime.GOOS == "js" && runtime.GOARCH == "wasm" { 425 t.Skip("skipping on js/wasm; see golang.org/issue/32840") 426 } 427 msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)}) 428 429 a, b := memPipe() 430 defer a.Close() 431 defer b.Close() 432 433 key := testSigners["ecdsa"] 434 serverConf := Config{RekeyThreshold: minRekeyThreshold} 435 serverConf.SetDefaults() 436 serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'}) 437 serverConn.hostKeys = []Signer{key} 438 go serverConn.readLoop() 439 go serverConn.kexLoop() 440 441 clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold} 442 clientConf.SetDefaults() 443 clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'}) 444 clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()} 445 clientConn.hostKeyCallback = InsecureIgnoreHostKey() 446 go clientConn.readLoop() 447 go clientConn.kexLoop() 448 449 var wg sync.WaitGroup 450 451 for _, hs := range []packetConn{serverConn, clientConn} { 452 if !coupled { 453 wg.Add(2) 454 go func(c packetConn) { 455 for i := 0; ; i++ { 456 str := fmt.Sprintf("%08x", i) + strings.Repeat("x", int(minRekeyThreshold)/4-8) 457 err := c.writePacket(Marshal(&serviceRequestMsg{str})) 458 if err != nil { 459 break 460 } 461 } 462 wg.Done() 463 c.Close() 464 }(hs) 465 go func(c packetConn) { 466 for { 467 _, err := c.readPacket() 468 if err != nil { 469 break 470 } 471 } 472 wg.Done() 473 }(hs) 474 } else { 475 wg.Add(1) 476 go func(c packetConn) { 477 for { 478 _, err := c.readPacket() 479 if err != nil { 480 break 481 } 482 if err := c.writePacket(msg); err != nil { 483 break 484 } 485 486 } 487 wg.Done() 488 }(hs) 489 } 490 } 491 wg.Wait() 492 } 493 494 func TestDisconnect(t *testing.T) { 495 if runtime.GOOS == "plan9" { 496 t.Skip("see golang.org/issue/7237") 497 } 498 checker := &testChecker{} 499 trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false) 500 if err != nil { 501 t.Fatalf("handshakePair: %v", err) 502 } 503 504 defer trC.Close() 505 defer trS.Close() 506 507 trC.writePacket([]byte{msgRequestSuccess, 0, 0}) 508 errMsg := &disconnectMsg{ 509 Reason: 42, 510 Message: "such is life", 511 } 512 trC.writePacket(Marshal(errMsg)) 513 trC.writePacket([]byte{msgRequestSuccess, 0, 0}) 514 515 packet, err := trS.readPacket() 516 if err != nil { 517 t.Fatalf("readPacket 1: %v", err) 518 } 519 if packet[0] != msgRequestSuccess { 520 t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess) 521 } 522 523 _, err = trS.readPacket() 524 if err == nil { 525 t.Errorf("readPacket 2 succeeded") 526 } else if !reflect.DeepEqual(err, errMsg) { 527 t.Errorf("got error %#v, want %#v", err, errMsg) 528 } 529 530 _, err = trS.readPacket() 531 if err == nil { 532 t.Errorf("readPacket 3 succeeded") 533 } 534 } 535 536 func TestHandshakeRekeyDefault(t *testing.T) { 537 clientConf := &ClientConfig{ 538 Config: Config{ 539 Ciphers: []string{"aes128-ctr"}, 540 }, 541 HostKeyCallback: InsecureIgnoreHostKey(), 542 } 543 trC, trS, err := handshakePair(clientConf, "addr", false) 544 if err != nil { 545 t.Fatalf("handshakePair: %v", err) 546 } 547 defer trC.Close() 548 defer trS.Close() 549 550 trC.writePacket([]byte{msgRequestSuccess, 0, 0}) 551 trC.Close() 552 553 rgb := (1024 + trC.readBytesLeft) >> 30 554 wgb := (1024 + trC.writeBytesLeft) >> 30 555 556 if rgb != 64 { 557 t.Errorf("got rekey after %dG read, want 64G", rgb) 558 } 559 if wgb != 64 { 560 t.Errorf("got rekey after %dG write, want 64G", wgb) 561 } 562 }