github.com/Night-mk/quorum@v21.1.0+incompatible/p2p/server_test.go (about) 1 // Copyright 2014 The go-ethereum Authors 2 // This file is part of the go-ethereum library. 3 // 4 // The go-ethereum library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The go-ethereum library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. 16 17 package p2p 18 19 import ( 20 "crypto/ecdsa" 21 "errors" 22 "io" 23 "io/ioutil" 24 "math/rand" 25 "net" 26 "os" 27 "path" 28 "reflect" 29 "testing" 30 "time" 31 32 "github.com/stretchr/testify/assert" 33 "golang.org/x/crypto/sha3" 34 35 "github.com/ethereum/go-ethereum/crypto" 36 "github.com/ethereum/go-ethereum/internal/testlog" 37 "github.com/ethereum/go-ethereum/log" 38 "github.com/ethereum/go-ethereum/p2p/enode" 39 "github.com/ethereum/go-ethereum/p2p/enr" 40 "github.com/ethereum/go-ethereum/params" 41 ) 42 43 // func init() { 44 // log.Root().SetHandler(log.LvlFilterHandler(log.LvlTrace, log.StreamHandler(os.Stderr, log.TerminalFormat(false)))) 45 // } 46 47 type testTransport struct { 48 rpub *ecdsa.PublicKey 49 *rlpx 50 51 closeErr error 52 } 53 54 func newTestTransport(rpub *ecdsa.PublicKey, fd net.Conn) transport { 55 wrapped := newRLPX(fd).(*rlpx) 56 wrapped.rw = newRLPXFrameRW(fd, secrets{ 57 MAC: zero16, 58 AES: zero16, 59 IngressMAC: sha3.NewLegacyKeccak256(), 60 EgressMAC: sha3.NewLegacyKeccak256(), 61 }) 62 return &testTransport{rpub: rpub, rlpx: wrapped} 63 } 64 65 func (c *testTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *ecdsa.PublicKey) (*ecdsa.PublicKey, error) { 66 return c.rpub, nil 67 } 68 69 func (c *testTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) { 70 pubkey := crypto.FromECDSAPub(c.rpub)[1:] 71 return &protoHandshake{ID: pubkey, Name: "test"}, nil 72 } 73 74 func (c *testTransport) close(err error) { 75 c.rlpx.fd.Close() 76 c.closeErr = err 77 } 78 79 func startTestServer(t *testing.T, remoteKey *ecdsa.PublicKey, pf func(*Peer)) *Server { 80 config := Config{ 81 Name: "test", 82 MaxPeers: 10, 83 ListenAddr: "127.0.0.1:0", 84 PrivateKey: newkey(), 85 Logger: testlog.Logger(t, log.LvlTrace), 86 } 87 server := &Server{ 88 Config: config, 89 newPeerHook: pf, 90 newTransport: func(fd net.Conn) transport { return newTestTransport(remoteKey, fd) }, 91 } 92 if err := server.Start(); err != nil { 93 t.Fatalf("Could not start server: %v", err) 94 } 95 return server 96 } 97 98 func TestServerListen(t *testing.T) { 99 // start the test server 100 connected := make(chan *Peer) 101 remid := &newkey().PublicKey 102 srv := startTestServer(t, remid, func(p *Peer) { 103 if p.ID() != enode.PubkeyToIDV4(remid) { 104 t.Error("peer func called with wrong node id") 105 } 106 connected <- p 107 }) 108 defer close(connected) 109 defer srv.Stop() 110 111 // dial the test server 112 conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second) 113 if err != nil { 114 t.Fatalf("could not dial: %v", err) 115 } 116 defer conn.Close() 117 118 select { 119 case peer := <-connected: 120 if peer.LocalAddr().String() != conn.RemoteAddr().String() { 121 t.Errorf("peer started with wrong conn: got %v, want %v", 122 peer.LocalAddr(), conn.RemoteAddr()) 123 } 124 peers := srv.Peers() 125 if !reflect.DeepEqual(peers, []*Peer{peer}) { 126 t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer}) 127 } 128 case <-time.After(1 * time.Second): 129 t.Error("server did not accept within one second") 130 } 131 } 132 133 func TestServerDial(t *testing.T) { 134 // run a one-shot TCP server to handle the connection. 135 listener, err := net.Listen("tcp", "127.0.0.1:0") 136 if err != nil { 137 t.Fatalf("could not setup listener: %v", err) 138 } 139 defer listener.Close() 140 accepted := make(chan net.Conn) 141 go func() { 142 conn, err := listener.Accept() 143 if err != nil { 144 t.Error("accept error:", err) 145 return 146 } 147 accepted <- conn 148 }() 149 150 // start the server 151 connected := make(chan *Peer) 152 remid := &newkey().PublicKey 153 srv := startTestServer(t, remid, func(p *Peer) { connected <- p }) 154 defer close(connected) 155 defer srv.Stop() 156 157 // tell the server to connect 158 tcpAddr := listener.Addr().(*net.TCPAddr) 159 node := enode.NewV4(remid, tcpAddr.IP, tcpAddr.Port, 0) 160 srv.AddPeer(node) 161 162 select { 163 case conn := <-accepted: 164 defer conn.Close() 165 166 select { 167 case peer := <-connected: 168 if peer.ID() != enode.PubkeyToIDV4(remid) { 169 t.Errorf("peer has wrong id") 170 } 171 if peer.Name() != "test" { 172 t.Errorf("peer has wrong name") 173 } 174 if peer.RemoteAddr().String() != conn.LocalAddr().String() { 175 t.Errorf("peer started with wrong conn: got %v, want %v", 176 peer.RemoteAddr(), conn.LocalAddr()) 177 } 178 peers := srv.Peers() 179 if !reflect.DeepEqual(peers, []*Peer{peer}) { 180 t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer}) 181 } 182 183 // Test AddTrustedPeer/RemoveTrustedPeer and changing Trusted flags 184 // Particularly for race conditions on changing the flag state. 185 if peer := srv.Peers()[0]; peer.Info().Network.Trusted { 186 t.Errorf("peer is trusted prematurely: %v", peer) 187 } 188 done := make(chan bool) 189 go func() { 190 srv.AddTrustedPeer(node) 191 if peer := srv.Peers()[0]; !peer.Info().Network.Trusted { 192 t.Errorf("peer is not trusted after AddTrustedPeer: %v", peer) 193 } 194 srv.RemoveTrustedPeer(node) 195 if peer := srv.Peers()[0]; peer.Info().Network.Trusted { 196 t.Errorf("peer is trusted after RemoveTrustedPeer: %v", peer) 197 } 198 done <- true 199 }() 200 // Trigger potential race conditions 201 peer = srv.Peers()[0] 202 _ = peer.Inbound() 203 _ = peer.Info() 204 <-done 205 case <-time.After(1 * time.Second): 206 t.Error("server did not launch peer within one second") 207 } 208 209 case <-time.After(1 * time.Second): 210 t.Error("server did not connect within one second") 211 } 212 } 213 214 // This test checks that tasks generated by dialstate are 215 // actually executed and taskdone is called for them. 216 func TestServerTaskScheduling(t *testing.T) { 217 var ( 218 done = make(chan *testTask) 219 quit, returned = make(chan struct{}), make(chan struct{}) 220 tc = 0 221 tg = taskgen{ 222 newFunc: func(running int, peers map[enode.ID]*Peer) []task { 223 tc++ 224 return []task{&testTask{index: tc - 1}} 225 }, 226 doneFunc: func(t task) { 227 select { 228 case done <- t.(*testTask): 229 case <-quit: 230 } 231 }, 232 } 233 ) 234 235 // The Server in this test isn't actually running 236 // because we're only interested in what run does. 237 db, _ := enode.OpenDB("") 238 srv := &Server{ 239 Config: Config{MaxPeers: 10}, 240 localnode: enode.NewLocalNode(db, newkey()), 241 nodedb: db, 242 discmix: enode.NewFairMix(0), 243 quit: make(chan struct{}), 244 running: true, 245 log: log.New(), 246 } 247 srv.loopWG.Add(1) 248 go func() { 249 srv.run(tg) 250 close(returned) 251 }() 252 253 var gotdone []*testTask 254 for i := 0; i < 100; i++ { 255 gotdone = append(gotdone, <-done) 256 } 257 for i, task := range gotdone { 258 if task.index != i { 259 t.Errorf("task %d has wrong index, got %d", i, task.index) 260 break 261 } 262 if !task.called { 263 t.Errorf("task %d was not called", i) 264 break 265 } 266 } 267 268 close(quit) 269 srv.Stop() 270 select { 271 case <-returned: 272 case <-time.After(500 * time.Millisecond): 273 t.Error("Server.run did not return within 500ms") 274 } 275 } 276 277 // This test checks that Server doesn't drop tasks, 278 // even if newTasks returns more than the maximum number of tasks. 279 func TestServerManyTasks(t *testing.T) { 280 alltasks := make([]task, 300) 281 for i := range alltasks { 282 alltasks[i] = &testTask{index: i} 283 } 284 285 var ( 286 db, _ = enode.OpenDB("") 287 srv = &Server{ 288 quit: make(chan struct{}), 289 localnode: enode.NewLocalNode(db, newkey()), 290 nodedb: db, 291 running: true, 292 log: log.New(), 293 discmix: enode.NewFairMix(0), 294 } 295 done = make(chan *testTask) 296 start, end = 0, 0 297 ) 298 defer srv.Stop() 299 srv.loopWG.Add(1) 300 go srv.run(taskgen{ 301 newFunc: func(running int, peers map[enode.ID]*Peer) []task { 302 start, end = end, end+maxActiveDialTasks+10 303 if end > len(alltasks) { 304 end = len(alltasks) 305 } 306 return alltasks[start:end] 307 }, 308 doneFunc: func(tt task) { 309 done <- tt.(*testTask) 310 }, 311 }) 312 313 doneset := make(map[int]bool) 314 timeout := time.After(2 * time.Second) 315 for len(doneset) < len(alltasks) { 316 select { 317 case tt := <-done: 318 if doneset[tt.index] { 319 t.Errorf("task %d got done more than once", tt.index) 320 } else { 321 doneset[tt.index] = true 322 } 323 case <-timeout: 324 t.Errorf("%d of %d tasks got done within 2s", len(doneset), len(alltasks)) 325 for i := 0; i < len(alltasks); i++ { 326 if !doneset[i] { 327 t.Logf("task %d not done", i) 328 } 329 } 330 return 331 } 332 } 333 } 334 335 type taskgen struct { 336 newFunc func(running int, peers map[enode.ID]*Peer) []task 337 doneFunc func(task) 338 } 339 340 func (tg taskgen) newTasks(running int, peers map[enode.ID]*Peer, now time.Time) []task { 341 return tg.newFunc(running, peers) 342 } 343 func (tg taskgen) taskDone(t task, now time.Time) { 344 tg.doneFunc(t) 345 } 346 func (tg taskgen) addStatic(*enode.Node) { 347 } 348 func (tg taskgen) removeStatic(*enode.Node) { 349 } 350 351 type testTask struct { 352 index int 353 called bool 354 } 355 356 func (t *testTask) Do(srv *Server) { 357 t.called = true 358 } 359 360 // This test checks that connections are disconnected 361 // just after the encryption handshake when the server is 362 // at capacity. Trusted connections should still be accepted. 363 func TestServerAtCap(t *testing.T) { 364 trustedNode := newkey() 365 trustedID := enode.PubkeyToIDV4(&trustedNode.PublicKey) 366 srv := &Server{ 367 Config: Config{ 368 PrivateKey: newkey(), 369 MaxPeers: 10, 370 NoDial: true, 371 NoDiscovery: true, 372 TrustedNodes: []*enode.Node{newNode(trustedID, nil)}, 373 }, 374 } 375 if err := srv.Start(); err != nil { 376 t.Fatalf("could not start: %v", err) 377 } 378 defer srv.Stop() 379 380 newconn := func(id enode.ID) *conn { 381 fd, _ := net.Pipe() 382 tx := newTestTransport(&trustedNode.PublicKey, fd) 383 node := enode.SignNull(new(enr.Record), id) 384 return &conn{fd: fd, transport: tx, flags: inboundConn, node: node, cont: make(chan error)} 385 } 386 387 // Inject a few connections to fill up the peer set. 388 for i := 0; i < 10; i++ { 389 c := newconn(randomID()) 390 if err := srv.checkpoint(c, srv.checkpointAddPeer); err != nil { 391 t.Fatalf("could not add conn %d: %v", i, err) 392 } 393 } 394 // Try inserting a non-trusted connection. 395 anotherID := randomID() 396 c := newconn(anotherID) 397 if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != DiscTooManyPeers { 398 t.Error("wrong error for insert:", err) 399 } 400 // Try inserting a trusted connection. 401 c = newconn(trustedID) 402 if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != nil { 403 t.Error("unexpected error for trusted conn @posthandshake:", err) 404 } 405 if !c.is(trustedConn) { 406 t.Error("Server did not set trusted flag") 407 } 408 409 // Remove from trusted set and try again 410 srv.RemoveTrustedPeer(newNode(trustedID, nil)) 411 c = newconn(trustedID) 412 if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != DiscTooManyPeers { 413 t.Error("wrong error for insert:", err) 414 } 415 416 // Add anotherID to trusted set and try again 417 srv.AddTrustedPeer(newNode(anotherID, nil)) 418 c = newconn(anotherID) 419 if err := srv.checkpoint(c, srv.checkpointPostHandshake); err != nil { 420 t.Error("unexpected error for trusted conn @posthandshake:", err) 421 } 422 if !c.is(trustedConn) { 423 t.Error("Server did not set trusted flag") 424 } 425 } 426 427 func TestServerPeerLimits(t *testing.T) { 428 srvkey := newkey() 429 clientkey := newkey() 430 clientnode := enode.NewV4(&clientkey.PublicKey, nil, 0, 0) 431 432 var tp = &setupTransport{ 433 pubkey: &clientkey.PublicKey, 434 phs: protoHandshake{ 435 ID: crypto.FromECDSAPub(&clientkey.PublicKey)[1:], 436 // Force "DiscUselessPeer" due to unmatching caps 437 // Caps: []Cap{discard.cap()}, 438 }, 439 } 440 441 srv := &Server{ 442 Config: Config{ 443 PrivateKey: srvkey, 444 MaxPeers: 0, 445 NoDial: true, 446 NoDiscovery: true, 447 Protocols: []Protocol{discard}, 448 }, 449 newTransport: func(fd net.Conn) transport { return tp }, 450 log: log.New(), 451 } 452 if err := srv.Start(); err != nil { 453 t.Fatalf("couldn't start server: %v", err) 454 } 455 defer srv.Stop() 456 457 // Check that server is full (MaxPeers=0) 458 flags := dynDialedConn 459 dialDest := clientnode 460 conn, _ := net.Pipe() 461 srv.SetupConn(conn, flags, dialDest) 462 if tp.closeErr != DiscTooManyPeers { 463 t.Errorf("unexpected close error: %q", tp.closeErr) 464 } 465 conn.Close() 466 467 srv.AddTrustedPeer(clientnode) 468 469 // Check that server allows a trusted peer despite being full. 470 conn, _ = net.Pipe() 471 srv.SetupConn(conn, flags, dialDest) 472 if tp.closeErr == DiscTooManyPeers { 473 t.Errorf("failed to bypass MaxPeers with trusted node: %q", tp.closeErr) 474 } 475 476 if tp.closeErr != DiscUselessPeer { 477 t.Errorf("unexpected close error: %q", tp.closeErr) 478 } 479 conn.Close() 480 481 srv.RemoveTrustedPeer(clientnode) 482 483 // Check that server is full again. 484 conn, _ = net.Pipe() 485 srv.SetupConn(conn, flags, dialDest) 486 if tp.closeErr != DiscTooManyPeers { 487 t.Errorf("unexpected close error: %q", tp.closeErr) 488 } 489 conn.Close() 490 } 491 492 func TestServerSetupConn(t *testing.T) { 493 var ( 494 clientkey, srvkey = newkey(), newkey() 495 clientpub = &clientkey.PublicKey 496 srvpub = &srvkey.PublicKey 497 ) 498 tests := []struct { 499 dontstart bool 500 tt *setupTransport 501 flags connFlag 502 dialDest *enode.Node 503 504 wantCloseErr error 505 wantCalls string 506 }{ 507 { 508 dontstart: true, 509 tt: &setupTransport{pubkey: clientpub}, 510 wantCalls: "close,", 511 wantCloseErr: errServerStopped, 512 }, 513 { 514 tt: &setupTransport{pubkey: clientpub, encHandshakeErr: errors.New("read error")}, 515 flags: inboundConn, 516 wantCalls: "doEncHandshake,close,", 517 wantCloseErr: errors.New("read error"), 518 }, 519 { 520 tt: &setupTransport{pubkey: clientpub}, 521 dialDest: enode.NewV4(&newkey().PublicKey, nil, 0, 0), 522 flags: dynDialedConn, 523 wantCalls: "doEncHandshake,close,", 524 wantCloseErr: DiscUnexpectedIdentity, 525 }, 526 { 527 tt: &setupTransport{pubkey: clientpub, phs: protoHandshake{ID: randomID().Bytes()}}, 528 dialDest: enode.NewV4(clientpub, nil, 0, 0), 529 flags: dynDialedConn, 530 wantCalls: "doEncHandshake,doProtoHandshake,close,", 531 wantCloseErr: DiscUnexpectedIdentity, 532 }, 533 { 534 tt: &setupTransport{pubkey: clientpub, protoHandshakeErr: errors.New("foo")}, 535 dialDest: enode.NewV4(clientpub, nil, 0, 0), 536 flags: dynDialedConn, 537 wantCalls: "doEncHandshake,doProtoHandshake,close,", 538 wantCloseErr: errors.New("foo"), 539 }, 540 { 541 tt: &setupTransport{pubkey: srvpub, phs: protoHandshake{ID: crypto.FromECDSAPub(srvpub)[1:]}}, 542 flags: inboundConn, 543 wantCalls: "doEncHandshake,close,", 544 wantCloseErr: DiscSelf, 545 }, 546 { 547 tt: &setupTransport{pubkey: clientpub, phs: protoHandshake{ID: crypto.FromECDSAPub(clientpub)[1:]}}, 548 flags: inboundConn, 549 wantCalls: "doEncHandshake,doProtoHandshake,close,", 550 wantCloseErr: DiscUselessPeer, 551 }, 552 } 553 554 for i, test := range tests { 555 t.Run(test.wantCalls, func(t *testing.T) { 556 cfg := Config{ 557 PrivateKey: srvkey, 558 MaxPeers: 10, 559 NoDial: true, 560 NoDiscovery: true, 561 Protocols: []Protocol{discard}, 562 Logger: testlog.Logger(t, log.LvlTrace), 563 } 564 srv := &Server{ 565 Config: cfg, 566 newTransport: func(fd net.Conn) transport { return test.tt }, 567 log: cfg.Logger, 568 } 569 if !test.dontstart { 570 if err := srv.Start(); err != nil { 571 t.Fatalf("couldn't start server: %v", err) 572 } 573 defer srv.Stop() 574 } 575 p1, _ := net.Pipe() 576 srv.SetupConn(p1, test.flags, test.dialDest) 577 if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) { 578 t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr) 579 } 580 if test.tt.calls != test.wantCalls { 581 t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls) 582 } 583 }) 584 } 585 } 586 587 func TestServerSetupConn_whenNotInRaftCluster(t *testing.T) { 588 var ( 589 clientkey, srvkey = newkey(), newkey() 590 clientpub = &clientkey.PublicKey 591 ) 592 593 clientNode := enode.NewV4(clientpub, nil, 0, 0) 594 srv := &Server{ 595 Config: Config{ 596 PrivateKey: srvkey, 597 NoDiscovery: true, 598 }, 599 newTransport: func(fd net.Conn) transport { return newTestTransport(clientpub, fd) }, 600 log: log.New(), 601 checkPeerInRaft: func(node *enode.Node) bool { 602 return false 603 }, 604 } 605 if err := srv.Start(); err != nil { 606 t.Fatalf("couldn't start server: %v", err) 607 } 608 defer srv.Stop() 609 p1, _ := net.Pipe() 610 err := srv.SetupConn(p1, inboundConn, clientNode) 611 612 assert.IsType(t, &peerError{}, err) 613 perr := err.(*peerError) 614 t.Log(perr.Error()) 615 assert.Equal(t, errNotInRaftCluster, perr.code) 616 } 617 618 func TestServerSetupConn_whenNotPermissioned(t *testing.T) { 619 tmpDir, err := ioutil.TempDir("", "") 620 if err != nil { 621 t.Fatal(err) 622 } 623 defer func() { _ = os.RemoveAll(tmpDir) }() 624 if err := ioutil.WriteFile(path.Join(tmpDir, params.PERMISSIONED_CONFIG), []byte("[]"), 0644); err != nil { 625 t.Fatal(err) 626 } 627 var ( 628 clientkey, srvkey = newkey(), newkey() 629 clientpub = &clientkey.PublicKey 630 ) 631 clientNode := enode.NewV4(clientpub, nil, 0, 0) 632 srv := &Server{ 633 Config: Config{ 634 PrivateKey: srvkey, 635 NoDiscovery: true, 636 DataDir: tmpDir, 637 EnableNodePermission: true, 638 }, 639 newTransport: func(fd net.Conn) transport { return newTestTransport(clientpub, fd) }, 640 log: log.New(), 641 } 642 if err := srv.Start(); err != nil { 643 t.Fatalf("couldn't start server: %v", err) 644 } 645 defer srv.Stop() 646 p1, _ := net.Pipe() 647 err = srv.SetupConn(p1, inboundConn, clientNode) 648 649 assert.IsType(t, &peerError{}, err) 650 perr := err.(*peerError) 651 t.Log(perr.Error()) 652 assert.Equal(t, errPermissionDenied, perr.code) 653 } 654 655 type setupTransport struct { 656 pubkey *ecdsa.PublicKey 657 encHandshakeErr error 658 phs protoHandshake 659 protoHandshakeErr error 660 661 calls string 662 closeErr error 663 } 664 665 func (c *setupTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *ecdsa.PublicKey) (*ecdsa.PublicKey, error) { 666 c.calls += "doEncHandshake," 667 return c.pubkey, c.encHandshakeErr 668 } 669 670 func (c *setupTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) { 671 c.calls += "doProtoHandshake," 672 if c.protoHandshakeErr != nil { 673 return nil, c.protoHandshakeErr 674 } 675 return &c.phs, nil 676 } 677 func (c *setupTransport) close(err error) { 678 c.calls += "close," 679 c.closeErr = err 680 } 681 682 // setupConn shouldn't write to/read from the connection. 683 func (c *setupTransport) WriteMsg(Msg) error { 684 panic("WriteMsg called on setupTransport") 685 } 686 func (c *setupTransport) ReadMsg() (Msg, error) { 687 panic("ReadMsg called on setupTransport") 688 } 689 690 func newkey() *ecdsa.PrivateKey { 691 key, err := crypto.GenerateKey() 692 if err != nil { 693 panic("couldn't generate key: " + err.Error()) 694 } 695 return key 696 } 697 698 func randomID() (id enode.ID) { 699 for i := range id { 700 id[i] = byte(rand.Intn(255)) 701 } 702 return id 703 } 704 705 // This test checks that inbound connections are throttled by IP. 706 func TestServerInboundThrottle(t *testing.T) { 707 const timeout = 5 * time.Second 708 newTransportCalled := make(chan struct{}) 709 srv := &Server{ 710 Config: Config{ 711 PrivateKey: newkey(), 712 ListenAddr: "127.0.0.1:0", 713 MaxPeers: 10, 714 NoDial: true, 715 NoDiscovery: true, 716 Protocols: []Protocol{discard}, 717 Logger: testlog.Logger(t, log.LvlTrace), 718 }, 719 newTransport: func(fd net.Conn) transport { 720 newTransportCalled <- struct{}{} 721 return newRLPX(fd) 722 }, 723 listenFunc: func(network, laddr string) (net.Listener, error) { 724 fakeAddr := &net.TCPAddr{IP: net.IP{95, 33, 21, 2}, Port: 4444} 725 return listenFakeAddr(network, laddr, fakeAddr) 726 }, 727 } 728 if err := srv.Start(); err != nil { 729 t.Fatal("can't start: ", err) 730 } 731 defer srv.Stop() 732 733 // Dial the test server. 734 conn, err := net.DialTimeout("tcp", srv.ListenAddr, timeout) 735 if err != nil { 736 t.Fatalf("could not dial: %v", err) 737 } 738 select { 739 case <-newTransportCalled: 740 // OK 741 case <-time.After(timeout): 742 t.Error("newTransport not called") 743 } 744 conn.Close() 745 746 // Dial again. This time the server should close the connection immediately. 747 connClosed := make(chan struct{}) 748 conn, err = net.DialTimeout("tcp", srv.ListenAddr, timeout) 749 if err != nil { 750 t.Fatalf("could not dial: %v", err) 751 } 752 defer conn.Close() 753 go func() { 754 conn.SetDeadline(time.Now().Add(timeout)) 755 buf := make([]byte, 10) 756 if n, err := conn.Read(buf); err != io.EOF || n != 0 { 757 t.Errorf("expected io.EOF and n == 0, got error %q and n == %d", err, n) 758 } 759 connClosed <- struct{}{} 760 }() 761 select { 762 case <-connClosed: 763 // OK 764 case <-newTransportCalled: 765 t.Error("newTransport called for second attempt") 766 case <-time.After(timeout): 767 t.Error("connection not closed within timeout") 768 } 769 } 770 771 func listenFakeAddr(network, laddr string, remoteAddr net.Addr) (net.Listener, error) { 772 l, err := net.Listen(network, laddr) 773 if err == nil { 774 l = &fakeAddrListener{l, remoteAddr} 775 } 776 return l, err 777 } 778 779 // fakeAddrListener is a listener that creates connections with a mocked remote address. 780 type fakeAddrListener struct { 781 net.Listener 782 remoteAddr net.Addr 783 } 784 785 type fakeAddrConn struct { 786 net.Conn 787 remoteAddr net.Addr 788 } 789 790 func (l *fakeAddrListener) Accept() (net.Conn, error) { 791 c, err := l.Listener.Accept() 792 if err != nil { 793 return nil, err 794 } 795 return &fakeAddrConn{c, l.remoteAddr}, nil 796 } 797 798 func (c *fakeAddrConn) RemoteAddr() net.Addr { 799 return c.remoteAddr 800 }