github.com/maenmax/kairep@v0.0.0-20210218001208-55bf3df36788/src/golang.org/x/crypto/ssh/handshake_test.go (about) 1 // Copyright 2013 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package ssh 6 7 import ( 8 "bytes" 9 "crypto/rand" 10 "errors" 11 "fmt" 12 "net" 13 "reflect" 14 "runtime" 15 "strings" 16 "sync" 17 "testing" 18 ) 19 20 type testChecker struct { 21 calls []string 22 } 23 24 func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { 25 if dialAddr == "bad" { 26 return fmt.Errorf("dialAddr is bad") 27 } 28 29 if tcpAddr, ok := addr.(*net.TCPAddr); !ok || tcpAddr == nil { 30 return fmt.Errorf("testChecker: got %T want *net.TCPAddr", addr) 31 } 32 33 t.calls = append(t.calls, fmt.Sprintf("%s %v %s %x", dialAddr, addr, key.Type(), key.Marshal())) 34 35 return nil 36 } 37 38 // netPipe is analogous to net.Pipe, but it uses a real net.Conn, and 39 // therefore is buffered (net.Pipe deadlocks if both sides start with 40 // a write.) 41 func netPipe() (net.Conn, net.Conn, error) { 42 listener, err := net.Listen("tcp", "127.0.0.1:0") 43 if err != nil { 44 return nil, nil, err 45 } 46 defer listener.Close() 47 c1, err := net.Dial("tcp", listener.Addr().String()) 48 if err != nil { 49 return nil, nil, err 50 } 51 52 c2, err := listener.Accept() 53 if err != nil { 54 c1.Close() 55 return nil, nil, err 56 } 57 58 return c1, c2, nil 59 } 60 61 func handshakePair(clientConf *ClientConfig, addr string) (client *handshakeTransport, server *handshakeTransport, err error) { 62 a, b, err := netPipe() 63 if err != nil { 64 return nil, nil, err 65 } 66 67 trC := newTransport(a, rand.Reader, true) 68 trS := newTransport(b, rand.Reader, false) 69 clientConf.SetDefaults() 70 71 v := []byte("version") 72 client = newClientTransport(trC, v, v, clientConf, addr, a.RemoteAddr()) 73 74 serverConf := &ServerConfig{} 75 serverConf.AddHostKey(testSigners["ecdsa"]) 76 serverConf.AddHostKey(testSigners["rsa"]) 77 serverConf.SetDefaults() 78 server = newServerTransport(trS, v, v, serverConf) 79 80 return client, server, nil 81 } 82 83 func TestHandshakeBasic(t *testing.T) { 84 if runtime.GOOS == "plan9" { 85 t.Skip("see golang.org/issue/7237") 86 } 87 checker := &testChecker{} 88 trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") 89 if err != nil { 90 t.Fatalf("handshakePair: %v", err) 91 } 92 93 defer trC.Close() 94 defer trS.Close() 95 96 go func() { 97 // Client writes a bunch of stuff, and does a key 98 // change in the middle. This should not confuse the 99 // handshake in progress 100 for i := 0; i < 10; i++ { 101 p := []byte{msgRequestSuccess, byte(i)} 102 if err := trC.writePacket(p); err != nil { 103 t.Fatalf("sendPacket: %v", err) 104 } 105 if i == 5 { 106 // halfway through, we request a key change. 107 err := trC.sendKexInit(subsequentKeyExchange) 108 if err != nil { 109 t.Fatalf("sendKexInit: %v", err) 110 } 111 } 112 } 113 trC.Close() 114 }() 115 116 // Server checks that client messages come in cleanly 117 i := 0 118 for { 119 p, err := trS.readPacket() 120 if err != nil { 121 break 122 } 123 if p[0] == msgNewKeys { 124 continue 125 } 126 want := []byte{msgRequestSuccess, byte(i)} 127 if bytes.Compare(p, want) != 0 { 128 t.Errorf("message %d: got %q, want %q", i, p, want) 129 } 130 i++ 131 } 132 if i != 10 { 133 t.Errorf("received %d messages, want 10.", i) 134 } 135 136 // If all went well, we registered exactly 1 key change. 137 if len(checker.calls) != 1 { 138 t.Fatalf("got %d host key checks, want 1", len(checker.calls)) 139 } 140 141 pub := testSigners["ecdsa"].PublicKey() 142 want := fmt.Sprintf("%s %v %s %x", "addr", trC.remoteAddr, pub.Type(), pub.Marshal()) 143 if want != checker.calls[0] { 144 t.Errorf("got %q want %q for host key check", checker.calls[0], want) 145 } 146 } 147 148 func TestHandshakeError(t *testing.T) { 149 checker := &testChecker{} 150 trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "bad") 151 if err != nil { 152 t.Fatalf("handshakePair: %v", err) 153 } 154 defer trC.Close() 155 defer trS.Close() 156 157 // send a packet 158 packet := []byte{msgRequestSuccess, 42} 159 if err := trC.writePacket(packet); err != nil { 160 t.Errorf("writePacket: %v", err) 161 } 162 163 // Now request a key change. 164 err = trC.sendKexInit(subsequentKeyExchange) 165 if err != nil { 166 t.Errorf("sendKexInit: %v", err) 167 } 168 169 // the key change will fail, and afterwards we can't write. 170 if err := trC.writePacket([]byte{msgRequestSuccess, 43}); err == nil { 171 t.Errorf("writePacket after botched rekey succeeded.") 172 } 173 174 readback, err := trS.readPacket() 175 if err != nil { 176 t.Fatalf("server closed too soon: %v", err) 177 } 178 if bytes.Compare(readback, packet) != 0 { 179 t.Errorf("got %q want %q", readback, packet) 180 } 181 readback, err = trS.readPacket() 182 if err == nil { 183 t.Errorf("got a message %q after failed key change", readback) 184 } 185 } 186 187 func TestForceFirstKex(t *testing.T) { 188 checker := &testChecker{} 189 trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") 190 if err != nil { 191 t.Fatalf("handshakePair: %v", err) 192 } 193 194 defer trC.Close() 195 defer trS.Close() 196 197 trC.writePacket(Marshal(&serviceRequestMsg{serviceUserAuth})) 198 199 // We setup the initial key exchange, but the remote side 200 // tries to send serviceRequestMsg in cleartext, which is 201 // disallowed. 202 203 err = trS.sendKexInit(firstKeyExchange) 204 if err == nil { 205 t.Errorf("server first kex init should reject unexpected packet") 206 } 207 } 208 209 func TestHandshakeTwice(t *testing.T) { 210 checker := &testChecker{} 211 trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") 212 if err != nil { 213 t.Fatalf("handshakePair: %v", err) 214 } 215 216 defer trC.Close() 217 defer trS.Close() 218 219 // Both sides should ask for the first key exchange first. 220 err = trS.sendKexInit(firstKeyExchange) 221 if err != nil { 222 t.Errorf("server sendKexInit: %v", err) 223 } 224 225 err = trC.sendKexInit(firstKeyExchange) 226 if err != nil { 227 t.Errorf("client sendKexInit: %v", err) 228 } 229 230 sent := 0 231 // send a packet 232 packet := make([]byte, 5) 233 packet[0] = msgRequestSuccess 234 if err := trC.writePacket(packet); err != nil { 235 t.Errorf("writePacket: %v", err) 236 } 237 sent++ 238 239 // Send another packet. Use a fresh one, since writePacket destroys. 240 packet = make([]byte, 5) 241 packet[0] = msgRequestSuccess 242 if err := trC.writePacket(packet); err != nil { 243 t.Errorf("writePacket: %v", err) 244 } 245 sent++ 246 247 // 2nd key change. 248 err = trC.sendKexInit(subsequentKeyExchange) 249 if err != nil { 250 t.Errorf("sendKexInit: %v", err) 251 } 252 253 packet = make([]byte, 5) 254 packet[0] = msgRequestSuccess 255 if err := trC.writePacket(packet); err != nil { 256 t.Errorf("writePacket: %v", err) 257 } 258 sent++ 259 260 packet = make([]byte, 5) 261 packet[0] = msgRequestSuccess 262 for i := 0; i < sent; i++ { 263 msg, err := trS.readPacket() 264 if err != nil { 265 t.Fatalf("server closed too soon: %v", err) 266 } 267 268 if bytes.Compare(msg, packet) != 0 { 269 t.Errorf("packet %d: got %q want %q", i, msg, packet) 270 } 271 } 272 if len(checker.calls) != 2 { 273 t.Errorf("got %d key changes, want 2", len(checker.calls)) 274 } 275 } 276 277 func TestHandshakeAutoRekeyWrite(t *testing.T) { 278 checker := &testChecker{} 279 clientConf := &ClientConfig{HostKeyCallback: checker.Check} 280 clientConf.RekeyThreshold = 500 281 trC, trS, err := handshakePair(clientConf, "addr") 282 if err != nil { 283 t.Fatalf("handshakePair: %v", err) 284 } 285 defer trC.Close() 286 defer trS.Close() 287 288 for i := 0; i < 5; i++ { 289 packet := make([]byte, 251) 290 packet[0] = msgRequestSuccess 291 if err := trC.writePacket(packet); err != nil { 292 t.Errorf("writePacket: %v", err) 293 } 294 } 295 296 j := 0 297 for ; j < 5; j++ { 298 _, err := trS.readPacket() 299 if err != nil { 300 break 301 } 302 } 303 304 if j != 5 { 305 t.Errorf("got %d, want 5 messages", j) 306 } 307 308 if len(checker.calls) != 2 { 309 t.Errorf("got %d key changes, wanted 2", len(checker.calls)) 310 } 311 } 312 313 type syncChecker struct { 314 called chan int 315 } 316 317 func (t *syncChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error { 318 t.called <- 1 319 return nil 320 } 321 322 func TestHandshakeAutoRekeyRead(t *testing.T) { 323 sync := &syncChecker{make(chan int, 2)} 324 clientConf := &ClientConfig{ 325 HostKeyCallback: sync.Check, 326 } 327 clientConf.RekeyThreshold = 500 328 329 trC, trS, err := handshakePair(clientConf, "addr") 330 if err != nil { 331 t.Fatalf("handshakePair: %v", err) 332 } 333 defer trC.Close() 334 defer trS.Close() 335 336 packet := make([]byte, 501) 337 packet[0] = msgRequestSuccess 338 if err := trS.writePacket(packet); err != nil { 339 t.Fatalf("writePacket: %v", err) 340 } 341 // While we read out the packet, a key change will be 342 // initiated. 343 if _, err := trC.readPacket(); err != nil { 344 t.Fatalf("readPacket(client): %v", err) 345 } 346 347 <-sync.called 348 } 349 350 // errorKeyingTransport generates errors after a given number of 351 // read/write operations. 352 type errorKeyingTransport struct { 353 packetConn 354 readLeft, writeLeft int 355 } 356 357 func (n *errorKeyingTransport) prepareKeyChange(*algorithms, *kexResult) error { 358 return nil 359 } 360 func (n *errorKeyingTransport) getSessionID() []byte { 361 return nil 362 } 363 364 func (n *errorKeyingTransport) writePacket(packet []byte) error { 365 if n.writeLeft == 0 { 366 n.Close() 367 return errors.New("barf") 368 } 369 370 n.writeLeft-- 371 return n.packetConn.writePacket(packet) 372 } 373 374 func (n *errorKeyingTransport) readPacket() ([]byte, error) { 375 if n.readLeft == 0 { 376 n.Close() 377 return nil, errors.New("barf") 378 } 379 380 n.readLeft-- 381 return n.packetConn.readPacket() 382 } 383 384 func TestHandshakeErrorHandlingRead(t *testing.T) { 385 for i := 0; i < 20; i++ { 386 testHandshakeErrorHandlingN(t, i, -1) 387 } 388 } 389 390 func TestHandshakeErrorHandlingWrite(t *testing.T) { 391 for i := 0; i < 20; i++ { 392 testHandshakeErrorHandlingN(t, -1, i) 393 } 394 } 395 396 // testHandshakeErrorHandlingN runs handshakes, injecting errors. If 397 // handshakeTransport deadlocks, the go runtime will detect it and 398 // panic. 399 func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int) { 400 msg := Marshal(&serviceRequestMsg{strings.Repeat("x", int(minRekeyThreshold)/4)}) 401 402 a, b := memPipe() 403 defer a.Close() 404 defer b.Close() 405 406 key := testSigners["ecdsa"] 407 serverConf := Config{RekeyThreshold: minRekeyThreshold} 408 serverConf.SetDefaults() 409 serverConn := newHandshakeTransport(&errorKeyingTransport{a, readLimit, writeLimit}, &serverConf, []byte{'a'}, []byte{'b'}) 410 serverConn.hostKeys = []Signer{key} 411 go serverConn.readLoop() 412 413 clientConf := Config{RekeyThreshold: 10 * minRekeyThreshold} 414 clientConf.SetDefaults() 415 clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'}) 416 clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()} 417 go clientConn.readLoop() 418 419 var wg sync.WaitGroup 420 wg.Add(4) 421 422 for _, hs := range []packetConn{serverConn, clientConn} { 423 go func(c packetConn) { 424 for { 425 err := c.writePacket(msg) 426 if err != nil { 427 break 428 } 429 } 430 wg.Done() 431 }(hs) 432 go func(c packetConn) { 433 for { 434 _, err := c.readPacket() 435 if err != nil { 436 break 437 } 438 } 439 wg.Done() 440 }(hs) 441 } 442 443 wg.Wait() 444 } 445 446 func TestDisconnect(t *testing.T) { 447 if runtime.GOOS == "plan9" { 448 t.Skip("see golang.org/issue/7237") 449 } 450 checker := &testChecker{} 451 trC, trS, err := handshakePair(&ClientConfig{HostKeyCallback: checker.Check}, "addr") 452 if err != nil { 453 t.Fatalf("handshakePair: %v", err) 454 } 455 456 defer trC.Close() 457 defer trS.Close() 458 459 trC.writePacket([]byte{msgRequestSuccess, 0, 0}) 460 errMsg := &disconnectMsg{ 461 Reason: 42, 462 Message: "such is life", 463 } 464 trC.writePacket(Marshal(errMsg)) 465 trC.writePacket([]byte{msgRequestSuccess, 0, 0}) 466 467 packet, err := trS.readPacket() 468 if err != nil { 469 t.Fatalf("readPacket 1: %v", err) 470 } 471 if packet[0] != msgRequestSuccess { 472 t.Errorf("got packet %v, want packet type %d", packet, msgRequestSuccess) 473 } 474 475 _, err = trS.readPacket() 476 if err == nil { 477 t.Errorf("readPacket 2 succeeded") 478 } else if !reflect.DeepEqual(err, errMsg) { 479 t.Errorf("got error %#v, want %#v", err, errMsg) 480 } 481 482 _, err = trS.readPacket() 483 if err == nil { 484 t.Errorf("readPacket 3 succeeded") 485 } 486 }