// Source: github.com/klaytn/klaytn@v1.10.2/networks/p2p/server_test.go
//
// Modifications Copyright 2018 The klaytn Authors
// Copyright 2014 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
//
// This file is derived from p2p/server_test.go (2018/06/04).
// Modified and improved for the klaytn development.
20 21 package p2p 22 23 import ( 24 "crypto/ecdsa" 25 "errors" 26 "math/rand" 27 "net" 28 "reflect" 29 "testing" 30 "time" 31 32 "github.com/klaytn/klaytn/common" 33 "github.com/klaytn/klaytn/crypto" 34 "github.com/klaytn/klaytn/crypto/sha3" 35 "github.com/klaytn/klaytn/networks/p2p/discover" 36 ) 37 38 func init() { 39 // log.Root().SetHandler(logger.LvlFilterHandler(logger.LvlError, logger.StreamHandler(os.Stderr, logger.TerminalFormat(false)))) 40 } 41 42 type testTransport struct { 43 id discover.NodeID 44 *rlpx 45 mutichannel bool 46 47 closeErr error 48 } 49 50 func newTestTransport(id discover.NodeID, fd net.Conn, mutichannel bool) transport { 51 wrapped := newRLPX(fd).(*rlpx) 52 wrapped.rw = newRLPXFrameRW(fd, secrets{ 53 MAC: zero16, 54 AES: zero16, 55 IngressMAC: sha3.NewKeccak256(), 56 EgressMAC: sha3.NewKeccak256(), 57 }) 58 return &testTransport{id: id, rlpx: wrapped, mutichannel: mutichannel} 59 } 60 61 func (c *testTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) { 62 return c.id, nil 63 } 64 65 func (c *testTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) { 66 return &protoHandshake{ID: c.id, Name: "test", Multichannel: c.mutichannel}, nil 67 } 68 69 func (c *testTransport) doConnTypeHandshake(myConnType common.ConnType) (common.ConnType, error) { 70 return 1, nil 71 } 72 73 func (c *testTransport) close(err error) { 74 c.rlpx.fd.Close() 75 c.closeErr = err 76 } 77 78 func startTestServer(t *testing.T, id discover.NodeID, pf func(*Peer), config *Config) Server { 79 config.Name = "test" 80 config.MaxPhysicalConnections = 10 81 config.ListenAddr = "127.0.0.1:0" 82 config.PrivateKey = newkey() 83 server := &SingleChannelServer{ 84 &BaseServer{ 85 Config: *config, 86 newPeerHook: pf, 87 newTransport: func(fd net.Conn) transport { 88 return newTestTransport(id, fd, false) 89 }, 90 }, 91 } 92 if err := server.Start(); err != nil { 93 t.Fatalf("Could not start server: %v", err) 94 } 
95 return server 96 } 97 98 func startTestMultiChannelServer(t *testing.T, id discover.NodeID, pf func(*Peer), config *Config) Server { 99 config.Name = "test" 100 config.MaxPhysicalConnections = 10 101 config.PrivateKey = newkey() 102 103 listeners := make([]net.Listener, 0, len(config.SubListenAddr)+1) 104 listenAddrs := make([]string, 0, len(config.SubListenAddr)+1) 105 listenAddrs = append(listenAddrs, config.ListenAddr) 106 listenAddrs = append(listenAddrs, config.SubListenAddr...) 107 108 server := &MultiChannelServer{ 109 BaseServer: &BaseServer{ 110 Config: *config, 111 newPeerHook: pf, 112 newTransport: func(fd net.Conn) transport { 113 return newTestTransport(id, fd, true) 114 }, 115 }, 116 listeners: listeners, 117 ListenAddrs: listenAddrs, 118 CandidateConns: make(map[discover.NodeID][]*conn), 119 } 120 if err := server.Start(); err != nil { 121 t.Fatalf("Could not start server: %v", err) 122 } 123 return server 124 } 125 126 func makeconn(fd net.Conn, id discover.NodeID) *conn { 127 tx := newTestTransport(id, fd, false) 128 return &conn{fd: fd, transport: tx, flags: staticDialedConn, conntype: common.ConnTypeUndefined, id: id, cont: make(chan error)} 129 } 130 131 func makeMultiChannelConn(fd net.Conn, id discover.NodeID) *conn { 132 tx := newTestTransport(id, fd, true) 133 return &conn{fd: fd, transport: tx, flags: staticDialedConn, conntype: common.ConnTypeUndefined, id: id, cont: make(chan error), multiChannel: true} 134 } 135 136 func TestServerListen(t *testing.T) { 137 // start the test server 138 connected := make(chan *Peer) 139 remid := randomID() 140 srv := startTestServer(t, remid, func(p *Peer) { 141 if p.ID() != remid { 142 t.Error("peer func called with wrong node id") 143 } 144 if p == nil { 145 t.Error("peer func called with nil conn") 146 } 147 connected <- p 148 }, &Config{}) 149 defer close(connected) 150 defer srv.Stop() 151 152 // dial the test server 153 conn, err := net.DialTimeout("tcp", srv.GetListenAddress()[ConnDefault], 
5*time.Second) 154 if err != nil { 155 t.Fatalf("could not dial: %v", err) 156 } 157 c := makeconn(conn, randomID()) 158 c.doConnTypeHandshake(c.conntype) 159 160 defer conn.Close() 161 162 select { 163 case peer := <-connected: 164 if peer.LocalAddr().String() != conn.RemoteAddr().String() { 165 t.Errorf("peer started with wrong conn: got %v, want %v", 166 peer.LocalAddr(), conn.RemoteAddr()) 167 } 168 169 peers := srv.Peers() 170 if !reflect.DeepEqual(peers, []*Peer{peer}) { 171 t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer}) 172 } 173 case <-time.After(5 * time.Second): 174 t.Error("server did not accept within one second") 175 } 176 } 177 178 func TestMultiChannelServerListen(t *testing.T) { 179 // start the test server 180 connected := make(chan *Peer) 181 remid := randomID() 182 config := &Config{ListenAddr: "127.0.0.1:33331", SubListenAddr: []string{"127.0.0.1:33333"}} 183 srv := startTestMultiChannelServer(t, remid, func(p *Peer) { 184 if p.ID() != remid { 185 t.Error("peer func called with wrong node id") 186 } 187 if p == nil { 188 t.Error("peer func called with nil conn") 189 } 190 connected <- p 191 }, config) 192 defer close(connected) 193 defer srv.Stop() 194 195 // dial the test server 196 var defaultConn net.Conn 197 198 for i, address := range srv.GetListenAddress() { 199 conn, err := net.DialTimeout("tcp", address, 5*time.Second) 200 defer conn.Close() 201 202 if i == ConnDefault { 203 defaultConn = conn 204 } 205 206 if err != nil { 207 t.Fatalf("could not dial: %v", err) 208 } 209 210 c := makeMultiChannelConn(conn, randomID()) 211 c.doConnTypeHandshake(c.conntype) 212 } 213 214 select { 215 case peer := <-connected: 216 if peer.LocalAddr().String() != defaultConn.RemoteAddr().String() { 217 t.Errorf("peer started with wrong conn: got %v, want %v", 218 peer.LocalAddr(), defaultConn.RemoteAddr()) 219 } 220 221 peers := srv.Peers() 222 if !reflect.DeepEqual(peers, []*Peer{peer}) { 223 t.Errorf("Peers mismatch: got %v, want %v", 
peers, []*Peer{peer}) 224 } 225 case <-time.After(5 * time.Second): 226 t.Error("server did not accept within five second") 227 } 228 } 229 230 func TestServerNoListen(t *testing.T) { 231 // start the test server 232 connected := make(chan *Peer) 233 remid := randomID() 234 srv := startTestServer(t, remid, func(p *Peer) { 235 if p.ID() != remid { 236 t.Error("peer func called with wrong node id") 237 } 238 if p == nil { 239 t.Error("peer func called with nil conn") 240 } 241 connected <- p 242 }, &Config{NoListen: true}) 243 defer close(connected) 244 defer srv.Stop() 245 246 // dial the test server that will be failed 247 _, err := net.DialTimeout("tcp", srv.GetListenAddress()[ConnDefault], 10*time.Millisecond) 248 if err == nil { 249 t.Fatalf("server started with listening") 250 } 251 } 252 253 func TestServerDial(t *testing.T) { 254 // run a one-shot TCP server to handle the connection. 255 listener, err := net.Listen("tcp", "127.0.0.1:0") 256 if err != nil { 257 t.Fatalf("could not setup listener: %v", err) 258 } 259 defer listener.Close() 260 accepted := make(chan net.Conn) 261 go func() { 262 conn, err := listener.Accept() 263 if err != nil { 264 t.Error("accept error:", err) 265 return 266 } 267 268 c := makeconn(conn, randomID()) 269 c.doConnTypeHandshake(c.conntype) 270 accepted <- conn 271 }() 272 273 // start the server 274 connected := make(chan *Peer) 275 remid := randomID() 276 srv := startTestServer(t, remid, func(p *Peer) { connected <- p }, &Config{}) 277 defer close(connected) 278 defer srv.Stop() 279 280 // tell the server to connect 281 tcpAddr := listener.Addr().(*net.TCPAddr) 282 srv.AddPeer(&discover.Node{ID: remid, IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)}) 283 284 select { 285 case conn := <-accepted: 286 defer conn.Close() 287 288 select { 289 case peer := <-connected: 290 if peer.ID() != remid { 291 t.Errorf("peer has wrong id") 292 } 293 if peer.Name() != "test" { 294 t.Errorf("peer has wrong name") 295 } 296 if 
peer.RemoteAddr().String() != conn.LocalAddr().String() { 297 t.Errorf("peer started with wrong conn: got %v, want %v", 298 peer.RemoteAddr(), conn.LocalAddr()) 299 } 300 peers := srv.Peers() 301 if !reflect.DeepEqual(peers, []*Peer{peer}) { 302 t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer}) 303 } 304 case <-time.After(1 * time.Second): 305 t.Error("server did not launch peer within one second") 306 } 307 308 case <-time.After(1 * time.Second): 309 t.Error("server did not connect within one second") 310 } 311 } 312 313 // This test checks that tasks generated by dialstate are 314 // actually executed and taskdone is called for them. 315 func TestServerTaskScheduling(t *testing.T) { 316 var ( 317 done = make(chan *testTask) 318 quit, returned = make(chan struct{}), make(chan struct{}) 319 tc = 0 320 tg = taskgen{ 321 newFunc: func(running int, peers map[discover.NodeID]*Peer) []task { 322 tc++ 323 return []task{&testTask{index: tc - 1}} 324 }, 325 doneFunc: func(t task) { 326 select { 327 case done <- t.(*testTask): 328 case <-quit: 329 } 330 }, 331 } 332 ) 333 334 // The Server in this test isn't actually running 335 // because we're only interested in what run does. 
336 srv := &SingleChannelServer{ 337 &BaseServer{ 338 Config: Config{MaxPhysicalConnections: 10}, 339 quit: make(chan struct{}), 340 ntab: fakeTable{}, 341 running: true, 342 logger: logger.NewWith(), 343 }, 344 } 345 srv.loopWG.Add(1) 346 go func() { 347 srv.run(tg) 348 close(returned) 349 }() 350 351 var gotdone []*testTask 352 for i := 0; i < 100; i++ { 353 gotdone = append(gotdone, <-done) 354 } 355 for i, task := range gotdone { 356 if task.index != i { 357 t.Errorf("task %d has wrong index, got %d", i, task.index) 358 break 359 } 360 if !task.called { 361 t.Errorf("task %d was not called", i) 362 break 363 } 364 } 365 366 close(quit) 367 srv.Stop() 368 select { 369 case <-returned: 370 case <-time.After(500 * time.Millisecond): 371 t.Error("Server.run did not return within 500ms") 372 } 373 } 374 375 // This test checks that Server doesn't drop tasks, 376 // even if newTasks returns more than the maximum number of tasks. 377 func TestServerManyTasks(t *testing.T) { 378 alltasks := make([]task, 300) 379 for i := range alltasks { 380 alltasks[i] = &testTask{index: i} 381 } 382 383 var ( 384 srv = &SingleChannelServer{ 385 &BaseServer{ 386 quit: make(chan struct{}), 387 ntab: fakeTable{}, 388 running: true, 389 logger: logger.NewWith(), 390 }, 391 } 392 done = make(chan *testTask) 393 start, end = 0, 0 394 ) 395 defer srv.Stop() 396 srv.loopWG.Add(1) 397 go srv.run(taskgen{ 398 newFunc: func(running int, peers map[discover.NodeID]*Peer) []task { 399 start, end = end, end+maxActiveDialTasks+10 400 if end > len(alltasks) { 401 end = len(alltasks) 402 } 403 return alltasks[start:end] 404 }, 405 doneFunc: func(tt task) { 406 done <- tt.(*testTask) 407 }, 408 }) 409 410 doneset := make(map[int]bool) 411 timeout := time.After(2 * time.Second) 412 for len(doneset) < len(alltasks) { 413 select { 414 case tt := <-done: 415 if doneset[tt.index] { 416 t.Errorf("task %d got done more than once", tt.index) 417 } else { 418 doneset[tt.index] = true 419 } 420 case <-timeout: 
421 t.Errorf("%d of %d tasks got done within 2s", len(doneset), len(alltasks)) 422 for i := 0; i < len(alltasks); i++ { 423 if !doneset[i] { 424 t.Logf("task %d not done", i) 425 } 426 } 427 return 428 } 429 } 430 } 431 432 type taskgen struct { 433 newFunc func(running int, peers map[discover.NodeID]*Peer) []task 434 doneFunc func(task) 435 } 436 437 func (tg taskgen) newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task { 438 return tg.newFunc(running, peers) 439 } 440 441 func (tg taskgen) taskDone(t task, now time.Time) { 442 tg.doneFunc(t) 443 } 444 445 func (tg taskgen) addStatic(*discover.Node) { 446 } 447 448 func (tg taskgen) removeStatic(*discover.Node) { 449 } 450 451 type testTask struct { 452 index int 453 called bool 454 } 455 456 func (t *testTask) Do(srv Server) { 457 t.called = true 458 } 459 460 // This test checks that connections are disconnected 461 // just after the encryption handshake when the server is 462 // at capacity. Trusted connections should still be accepted. 463 func TestServerAtCap(t *testing.T) { 464 trustedID := randomID() 465 srv := &SingleChannelServer{ 466 BaseServer: &BaseServer{ 467 Config: Config{ 468 PrivateKey: newkey(), 469 MaxPhysicalConnections: 10, 470 NoDial: true, 471 TrustedNodes: []*discover.Node{{ID: trustedID}}, 472 }, 473 }, 474 } 475 if err := srv.Start(); err != nil { 476 t.Fatalf("could not start: %v", err) 477 } 478 defer srv.Stop() 479 480 newconn := func(id discover.NodeID) *conn { 481 fd, _ := net.Pipe() 482 tx := newTestTransport(id, fd, false) 483 return &conn{fd: fd, transport: tx, flags: inboundConn, conntype: common.ConnTypeUndefined, id: id, cont: make(chan error)} 484 } 485 486 // Inject a few connections to fill up the peer set. 487 for i := 0; i < 10; i++ { 488 c := newconn(randomID()) 489 if err := srv.checkpoint(c, srv.addpeer); err != nil { 490 t.Fatalf("could not add conn %d: %v", i, err) 491 } 492 } 493 // Try inserting a non-trusted connection. 
494 c := newconn(randomID()) 495 if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers { 496 t.Error("wrong error for insert:", err) 497 } 498 // Try inserting a trusted connection. 499 c = newconn(trustedID) 500 if err := srv.checkpoint(c, srv.posthandshake); err != nil { 501 t.Error("unexpected error for trusted conn @posthandshake:", err) 502 } 503 if !c.is(trustedConn) { 504 t.Error("Server did not set trusted flag") 505 } 506 } 507 508 func TestServerSetupConn(t *testing.T) { 509 id := randomID() 510 srvkey := newkey() 511 srvid := discover.PubkeyID(&srvkey.PublicKey) 512 tests := []struct { 513 dontstart bool 514 tt *setupTransport 515 flags connFlag 516 dialDest *discover.Node 517 518 wantCloseErr error 519 wantCalls string 520 }{ 521 { 522 dontstart: true, 523 tt: &setupTransport{id: id}, 524 wantCalls: "close,", 525 wantCloseErr: errServerStopped, 526 }, 527 { 528 tt: &setupTransport{id: id, encHandshakeErr: errors.New("read error")}, 529 flags: inboundConn, 530 wantCalls: "doEncHandshake,close,", 531 wantCloseErr: errors.New("read error"), 532 }, 533 { 534 tt: &setupTransport{id: id}, 535 dialDest: &discover.Node{ID: randomID(), NType: discover.NodeType(common.ENDPOINTNODE)}, 536 flags: dynDialedConn, 537 wantCalls: "doEncHandshake,close,", 538 wantCloseErr: DiscUnexpectedIdentity, 539 }, 540 { 541 tt: &setupTransport{id: id, phs: &protoHandshake{ID: randomID()}}, 542 dialDest: &discover.Node{ID: id, NType: discover.NodeType(common.ENDPOINTNODE)}, 543 flags: dynDialedConn, 544 wantCalls: "doEncHandshake,doProtoHandshake,close,", 545 wantCloseErr: DiscUnexpectedIdentity, 546 }, 547 { 548 tt: &setupTransport{id: id, protoHandshakeErr: errors.New("foo")}, 549 dialDest: &discover.Node{ID: id, NType: discover.NodeType(common.ENDPOINTNODE)}, 550 flags: dynDialedConn, 551 wantCalls: "doEncHandshake,doProtoHandshake,close,", 552 wantCloseErr: errors.New("foo"), 553 }, 554 { 555 tt: &setupTransport{id: srvid, phs: &protoHandshake{ID: srvid}}, 556 
flags: inboundConn, 557 wantCalls: "doEncHandshake,close,", 558 wantCloseErr: DiscSelf, 559 }, 560 { 561 tt: &setupTransport{id: id, phs: &protoHandshake{ID: id}}, 562 flags: inboundConn, 563 wantCalls: "doEncHandshake,doProtoHandshake,close,", 564 wantCloseErr: DiscUselessPeer, 565 }, 566 } 567 568 for i, test := range tests { 569 srv := &SingleChannelServer{ 570 &BaseServer{ 571 Config: Config{ 572 PrivateKey: srvkey, 573 MaxPhysicalConnections: 10, 574 NoDial: true, 575 Protocols: []Protocol{discard}, 576 ConnectionType: 1, // ENDPOINTNODE 577 }, 578 newTransport: func(fd net.Conn) transport { return test.tt }, 579 logger: logger.NewWith(), 580 }, 581 } 582 if !test.dontstart { 583 if err := srv.Start(); err != nil { 584 t.Fatalf("couldn't start server: %v", err) 585 } 586 } 587 p1, _ := net.Pipe() 588 srv.SetupConn(p1, test.flags, test.dialDest) 589 if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) { 590 t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr) 591 } 592 if test.tt.calls != test.wantCalls { 593 t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls) 594 } 595 } 596 } 597 598 type setupTransport struct { 599 id discover.NodeID 600 encHandshakeErr error 601 602 phs *protoHandshake 603 protoHandshakeErr error 604 605 calls string 606 closeErr error 607 } 608 609 func (c *setupTransport) doConnTypeHandshake(myConnType common.ConnType) (common.ConnType, error) { 610 return 1, nil 611 } 612 613 func (c *setupTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) { 614 c.calls += "doEncHandshake," 615 return c.id, c.encHandshakeErr 616 } 617 618 func (c *setupTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) { 619 c.calls += "doProtoHandshake," 620 if c.protoHandshakeErr != nil { 621 return nil, c.protoHandshakeErr 622 } 623 return c.phs, nil 624 } 625 626 func (c *setupTransport) close(err error) { 627 
c.calls += "close," 628 c.closeErr = err 629 } 630 631 // setupConn shouldn't write to/read from the connection. 632 func (c *setupTransport) WriteMsg(Msg) error { 633 panic("WriteMsg called on setupTransport") 634 } 635 636 func (c *setupTransport) ReadMsg() (Msg, error) { 637 panic("ReadMsg called on setupTransport") 638 } 639 640 func newkey() *ecdsa.PrivateKey { 641 key, err := crypto.GenerateKey() 642 if err != nil { 643 panic("couldn't generate key: " + err.Error()) 644 } 645 return key 646 } 647 648 func randomID() (id discover.NodeID) { 649 for i := range id { 650 id[i] = byte(rand.Intn(255)) 651 } 652 return id 653 }