github.com/devops-filetransfer/sshego@v7.0.4+incompatible/_vendor/golang.org/x/crypto/ssh/handshake_test.go (about) 1 // Copyright 2013 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package ssh 6 7 import ( 8 "bytes" 9 "crypto/rand" 10 "errors" 11 "fmt" 12 "io" 13 "net" 14 "reflect" 15 "runtime" 16 "strings" 17 "sync" 18 "testing" 19 ) 20 21 type testChecker struct { 22 calls []string 23 } 24 25 func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { 26 if dialAddr == "bad" { 27 return fmt.Errorf("dialAddr is bad") 28 } 29 30 if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil { 31 return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr) 32 } 33 34 t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal())) 35 36 return nil 37 } 38 39 // netPipe is analogous to net.Pipe, but it uses a real net.Conn, and 40 // therefore is buffered (net.Pipe deadlocks if both sides start with 41 // a write.) 42 func netPipe() (net.Conn, net.Conn, error) { 43 listener, err := net.Listen("tcp", ":0") 44 if err != nil { 45 return nil, nil, err 46 } 47 defer listener.Close() 48 c1, err := net.Dial("tcp", listener.Addr().String()) 49 if err != nil { 50 return nil, nil, err 51 } 52 53 c2, err := listener.Accept() 54 if err != nil { 55 c1.Close() 56 return nil, nil, err 57 } 58 59 return c1, c2, nil 60 } 61 62 // noiseTransport inserts ignore messages to check that the read loop 63 // and the key exchange filters out these messages. 64 type noiseTransport struct { 65 keyingTransport 66 } 67 68 func (t *noiseTransport) writePacket(p []byte) error { 69 ignore := []byte{msgIgnore} 70 if err := t.keyingTransport.writePacket(ignore); err != nil { 71 return err 72 } 73 debug := []byte{msgDebug, 1, 2, 3} 74 if err := t.keyingTransport.writePacket(debug); err != nil { 75 return err 76 } 77 78 return t.keyingTransport.writePacket(p) 79 } 80 81 func addNoiseTransport(t keyingTransport) keyingTransport { 82 return &noiseTransport{t} 83 } 84 85 // handshakePair creates two handshakeTransports connected with each 86 // other. If the noise argument is true, both transports will try to 87 // confuse the other side by sending ignore and debug messages. 88 func handshakePair(clientConf *ClientConfig, addr string, noise bool) (client *handshakeTransport, server *handshakeTransport, err error) { 89 a, b, err := netPipe() 90 if err != nil { 91 return nil, nil, err 92 } 93 94 var trC, trS keyingTransport 95 96 trC = newTransport(a, rand.Reader, true) 97 trS = newTransport(b, rand.Reader, false) 98 if noise { 99 trC = addNoiseTransport(trC) 100 trS = addNoiseTransport(trS) 101 } 102 clientConf.SetDefaults() 103 104 v := []byte("version") 105 client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr()) 106 107 serverConf := &ServerConfig{} 108 serverConf.AddHostKey(testSigners["ecdsa"]) 109 serverConf.AddHostKey(testSigners["rsa"]) 110 serverConf.SetDefaults() 111 server = newServerTransport(trS, v, v, serverConf) 112 113 if err := server.waitSession(); err != nil { 114 return nil, nil, fmt.Errorf("server.waitSession: %v", err) 115 } 116 if err := client.waitSession(); err != nil { 117 return nil, nil, fmt.Errorf("client.waitSession: %v", err) 118 } 119 120 return client, server, nil 121 } 122 123 func TestHandshakeBasic(t *testing.T) { 124 if runtime.GOOS == "plan9" { 125 t.Skip("see golang.org/issue/7237") 126 } 127 128 checker := &syncChecker{ 129 waitCall: make(chan int, 10), 130 called: make(chan int, 10), 131 } 132 133 checker.waitCall <- 1 134 trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false) 135 if err != nil { 136 t.Fatalf("handshakePair: %v", err) 137 } 138 139 defer trC.Close() 140 defer trS.Close() 141 142 // Let first kex complete normally. 143 <-checker.called 144 145 clientDone := make(chan int, 0) 146 gotHalf := make(chan int, 0) 147 const N = 20 148 149 go func() { 150 defer close(clientDone) 151 // Client writes a bunch of stuff, and does a key 152 // change in the middle. This should not confuse the 153 // handshake in progress. We do this twice, so we test 154 // that the packet buffer is reset correctly. 155 for i := 0; i < N; i++ { 156 p := []byte{msgRequestSuccess, byte(i)} 157 if err := trC.writePacket(p); err != nil { 158 t.Fatalf("sendPacket: %v", err) 159 } 160 if (i % 10) == 5 { 161 <-gotHalf 162 // halfway through, we request a key change. 163 trC.requestKeyExchange() 164 165 // Wait until we can be sure the key 166 // change has really started before we 167 // write more. 168 <-checker.called 169 } 170 if (i % 10) == 7 { 171 // write some packets until the kex 172 // completes, to test buffering of 173 // packets. 174 checker.waitCall <- 1 175 } 176 } 177 }() 178 179 // Server checks that client messages come in cleanly 180 i := 0 181 err = nil 182 for ; i < N; i++ { 183 var p []byte 184 p, err = trS.readPacket() 185 if err != nil { 186 break 187 } 188 if (i % 10) == 5 { 189 gotHalf <- 1 190 } 191 192 want := []byte{msgRequestSuccess, byte(i)} 193 if bytes.Compare(p, want) != 0 { 194 t.Errorf("message %d: got %v, want %v", i, p, want) 195 } 196 } 197 <-clientDone 198 if err != nil && err != io.EOF { 199 t.Fatalf("server error: %v", err) 200 } 201 if i != N { 202 t.Errorf("received %d messages, want 10.", i) 203 } 204 205 close(checker.called) 206 if _, ok := <-checker.called; ok { 207 // If all went well, we registered exactly 2 key changes: one 208 // that establishes the session, and one that we requested 209 // additionally. 210 t.Fatalf("got another host key checks after 2 handshakes") 211 } 212 } 213 214 func TestForceFirstKex(t *testing.T) { 215 // like handshakePair, but must access the keyingTransport. 216 checker := &testChecker{} 217 clientConf := &ClientConfig{HostKeyCallback: checker.Check} 218 a, b, err := netPipe() 219 if err != nil { 220 t.Fatalf("netPipe: %v", err) 221 } 222 223 var trC, trS keyingTransport 224 225 trC = newTransport(a, rand.Reader, true) 226 227 // This is the disallowed packet: 228 trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})) 229 230 // Rest of the setup. 231 trS = newTransport(b, rand.Reader, false) 232 clientConf.SetDefaults() 233 234 v := []byte("version") 235 client := newClientTransport(trC, v, v, clientConf, "addr", a.RemoteAddr()) 236 237 serverConf := &ServerConfig{} 238 serverConf.AddHostKey(testSigners["ecdsa"]) 239 serverConf.AddHostKey(testSigners["rsa"]) 240 serverConf.SetDefaults() 241 server := newServerTransport(trS, v, v, serverConf) 242 243 defer client.Close() 244 defer server.Close() 245 246 // We setup the initial key exchange, but the remote side 247 // tries to send serviceRequestMsg in cleartext, which is 248 // disallowed. 249 250 if err := server.waitSession(); err == nil { 251 t.Errorf("server first kex init should reject unexpected packet") 252 } 253 } 254 255 func TestHandshakeAutoRekeyWrite(t *testing.T) { 256 checker := &syncChecker{ 257 called: make(chan int, 10), 258 waitCall: nil, 259 } 260 clientConf := &ClientConfig{HostKeyCallback: checker.Check} 261 clientConf.RekeyThreshold = 500 262 trC, trS, err := handshakePair(clientConf, "addr", false) 263 if err != nil { 264 t.Fatalf("handshakePair: %v", err) 265 } 266 defer trC.Close() 267 defer trS.Close() 268 269 input := make([]byte, 251) 270 input[0] = msgRequestSuccess 271 272 done := make(chan int, 1) 273 const numPacket = 5 274 go func() { 275 defer close(done) 276 j := 0 277 for ; j < numPacket; j++ { 278 if p, err := trS.readPacket(); err != nil { 279 break 280 } else if !bytes.Equal(input, p) { 281 t.Errorf("got packet type %d, want %d", p[0], input[0]) 282 } 283 } 284 285 if j != numPacket { 286 t.Errorf("got %d, want 5 messages", j) 287 } 288 }() 289 290 <-checker.called 291 292 for i := 0; i < numPacket; i++ { 293 p := make([]byte, len(input)) 294 copy(p, input) 295 if err := trC.writePacket(p); err != nil { 296 t.Errorf("writePacket: %v", err) 297 } 298 if i == 2 { 299 // Make sure the kex is in progress. 300 <-checker.called 301 } 302 303 } 304 <-done 305 } 306 307 type syncChecker struct { 308 waitCall chan int 309 called chan int 310 } 311 312 func (c *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { 313 c.called <- 1 314 if c.waitCall != nil { 315 <-c.waitCall 316 } 317 return nil 318 } 319 320 func TestHandshakeAutoRekeyRead(t *testing.T) { 321 sync := &syncChecker{ 322 called: make(chan int, 2), 323 waitCall: nil, 324 } 325 clientConf := &ClientConfig{ 326 HostKeyCallback: sync.Check, 327 } 328 clientConf.RekeyThreshold = 500 329 330 trC, trS, err := handshakePair(clientConf, "addr", false) 331 if err != nil { 332 t.Fatalf("handshakePair: %v", err) 333 } 334 defer trC.Close() 335 defer trS.Close() 336 337 packet := make([]byte, 501) 338 packet[0] = msgRequestSuccess 339 if err := trS.writePacket(packet); err != nil { 340 t.Fatalf("writePacket: %v", err) 341 } 342 343 // While we read out the packet, a key change will be 344 // initiated. 345 done := make(chan int, 1) 346 go func() { 347 defer close(done) 348 if _, err := trC.readPacket(); err != nil { 349 t.Fatalf("readPacket(client): %v", err) 350 } 351 352 }() 353 354 <-done 355 <-sync.called 356 } 357 358 // errorKeyingTransport generates errors after a given number of 359 // read/write operations. 360 type errorKeyingTransport struct { 361 packetConn 362 readLeft, writeLeft int 363 } 364 365 func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error { 366 return nil 367 } 368 369 func (n *errorKeyingTransport) getSessionID() []byte { 370 return nil 371 } 372 373 func (n *errorKeyingTransport) writePacket(packet []byte) error { 374 if n.writeLeft == 0 { 375 n.Close() 376 return errors.New("barf") 377 } 378 379 n.writeLeft-- 380 return n.packetConn.writePacket(packet) 381 } 382 383 func (n *errorKeyingTransport) readPacket() ([]byte, error) { 384 if n.readLeft == 0 { 385 n.Close() 386 return nil, errors.New("barf") 387 } 388 389 n.readLeft-- 390 return n.packetConn.readPacket() 391 } 392 393 func TestHandshakeErrorHandlingRead(t *testing.T) { 394 for i := 0; i < 20; i++ { 395 testHandshakeErrorHandlingN(t, i, -1, false) 396 } 397 } 398 399 func TestHandshakeErrorHandlingWrite(t *testing.T) { 400 for i := 0; i < 20; i++ { 401 testHandshakeErrorHandlingN(t, -1, i, false) 402 } 403 } 404 405 func TestHandshakeErrorHandlingReadCoupled(t *testing.T) { 406 for i := 0; i < 20; i++ { 407 testHandshakeErrorHandlingN(t, i, -1, true) 408 } 409 } 410 411 func TestHandshakeErrorHandlingWriteCoupled(t *testing.T) { 412 for i := 0; i < 20; i++ { 413 testHandshakeErrorHandlingN(t, -1, i, true) 414 } 415 } 416 417 // testHandshakeErrorHandlingN runs handshakes, injecting errors. If 418 // handshakeTransport deadlocks, the go runtime will detect it and 419 // panic. 420 func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int, coupled bool) { 421 msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)}) 422 423 a, b := memPipe() 424 defer a.Close() 425 defer b.Close() 426 427 key := testSigners["ecdsa"] 428 serverConf := Config{RekeyThreshold: minRekeyThreshold} 429 serverConf.SetDefaults() 430 serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'}) 431 serverConn.hostKeys = []Signer{key} 432 go serverConn.readLoop() 433 go serverConn.kexLoop() 434 435 clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold} 436 clientConf.SetDefaults() 437 clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'}) 438 clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()} 439 clientConn.hostKeyCallback = InsecureIgnoreHostKey() 440 go clientConn.readLoop() 441 go clientConn.kexLoop() 442 443 var wg sync.WaitGroup 444 445 for _, hs := range []packetConn{serverConn, clientConn} { 446 if !coupled { 447 wg.Add(2) 448 go func(c packetConn) { 449 for i := 0; ; i++ { 450 str := fmt.Sprintf("%08x", i) + strings.Repeat("x", int(minRekeyThreshold)/4-8) 451 err := c.writePacket(Marshal(&serviceRequestMsg{str})) 452 if err != nil { 453 break 454 } 455 } 456 wg.Done() 457 c.Close() 458 }(hs) 459 go func(c packetConn) { 460 for { 461 _, err := c.readPacket() 462 if err != nil { 463 break 464 } 465 } 466 wg.Done() 467 }(hs) 468 } else { 469 wg.Add(1) 470 go func(c packetConn) { 471 for { 472 _, err := c.readPacket() 473 if err != nil { 474 break 475 } 476 if err := c.writePacket(msg); err != nil { 477 break 478 } 479 480 } 481 wg.Done() 482 }(hs) 483 } 484 } 485 wg.Wait() 486 } 487 488 func TestDisconnect(t *testing.T) { 489 if runtime.GOOS == "plan9" { 490 t.Skip("see golang.org/issue/7237") 491 } 492 checker := &testChecker{} 493 trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr", false) 494 if err != nil { 495 t.Fatalf("handshakePair: %v", err) 496 } 497 498 defer trC.Close() 499 defer trS.Close() 500 501 trC.writePacket([]byte{msgRequestSuccess, 0, 0}) 502 errMsg := &disconnectMsg{ 503 Reason: 42, 504 Message: "such is life", 505 } 506 trC.writePacket(Marshal(errMsg)) 507 trC.writePacket([]byte{msgRequestSuccess, 0, 0}) 508 509 packet, err := trS.readPacket() 510 if err != nil { 511 t.Fatalf("readPacket 1: %v", err) 512 } 513 if packet[0] != msgRequestSuccess { 514 t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess) 515 } 516 517 _, err = trS.readPacket() 518 if err == nil { 519 t.Errorf("readPacket 2 succeeded") 520 } else if !reflect.DeepEqual(err, errMsg) { 521 t.Errorf("got error %#v, want %#v", err, errMsg) 522 } 523 524 _, err = trS.readPacket() 525 if err == nil { 526 t.Errorf("readPacket 3 succeeded") 527 } 528 } 529 530 func TestHandshakeRekeyDefault(t *testing.T) { 531 clientConf := &ClientConfig{ 532 Config: Config{ 533 Ciphers: []string{"aes128-ctr"}, 534 }, 535 HostKeyCallback: InsecureIgnoreHostKey(), 536 } 537 trC, trS, err := handshakePair(clientConf, "addr", false) 538 if err != nil { 539 t.Fatalf("handshakePair: %v", err) 540 } 541 defer trC.Close() 542 defer trS.Close() 543 544 trC.writePacket([]byte{msgRequestSuccess, 0, 0}) 545 trC.Close() 546 547 rgb := (1024 + trC.readBytesLeft) >> 30 548 wgb := (1024 + trC.writeBytesLeft) >> 30 549 550 if rgb != 64 { 551 t.Errorf("got rekey after %dG read, want 64G", rgb) 552 } 553 if wgb != 64 { 554 t.Errorf("got rekey after %dG write, want 64G", wgb) 555 } 556 }