// Modifications Copyright 2018 The klaytn Authors
// Copyright 2014 The go-ethereum Authors
// This file is part of the go-ethereum library.
//
// The go-ethereum library is free software: you can redistribute it and/or modify
// it under the terms of the GNU Lesser General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
//
// You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
//
// This file is derived from p2p/server_test.go (2018/06/04).
// Modified and improved for the klaytn development.
20 21 package p2p 22 23 import ( 24 "crypto/ecdsa" 25 "errors" 26 "math/rand" 27 "net" 28 "reflect" 29 "testing" 30 "time" 31 32 "github.com/klaytn/klaytn/common" 33 "github.com/klaytn/klaytn/crypto" 34 "github.com/klaytn/klaytn/crypto/sha3" 35 "github.com/klaytn/klaytn/networks/p2p/discover" 36 "github.com/klaytn/klaytn/networks/p2p/rlpx" 37 ) 38 39 func init() { 40 // log.Root().SetHandler(logger.LvlFilterHandler(logger.LvlError, logger.StreamHandler(os.Stderr, logger.TerminalFormat(false)))) 41 } 42 43 type testTransport struct { 44 id discover.NodeID 45 *rlpxTransport 46 mutichannel bool 47 48 closeErr error 49 } 50 51 func newTestTransport(id discover.NodeID, fd net.Conn, dialDest *ecdsa.PublicKey, mutichannel bool) transport { 52 wrapped := newRLPX(fd, dialDest).(*rlpxTransport) 53 wrapped.conn.InitWithSecrets(rlpx.Secrets{ 54 MAC: make([]byte, 16), 55 AES: make([]byte, 16), 56 IngressMAC: sha3.NewKeccak256(), 57 EgressMAC: sha3.NewKeccak256(), 58 }) 59 return &testTransport{id: id, rlpxTransport: wrapped, mutichannel: mutichannel} 60 } 61 62 func (c *testTransport) doEncHandshake(prv *ecdsa.PrivateKey) (*ecdsa.PublicKey, error) { 63 remoteKey, _ := c.id.Pubkey() 64 return remoteKey, nil 65 } 66 67 func (c *testTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) { 68 return &protoHandshake{ID: c.id, Name: "test", Multichannel: c.mutichannel}, nil 69 } 70 71 func (c *testTransport) doConnTypeHandshake(myConnType common.ConnType) (common.ConnType, error) { 72 return 1, nil 73 } 74 75 func (c *testTransport) close(err error) { 76 c.conn.Close() 77 c.closeErr = err 78 } 79 80 func startTestServer(t *testing.T, id discover.NodeID, pf func(*Peer), config *Config) Server { 81 config.Name = "test" 82 config.MaxPhysicalConnections = 10 83 config.ListenAddr = "127.0.0.1:0" 84 config.PrivateKey = newkey() 85 server := &SingleChannelServer{ 86 &BaseServer{ 87 Config: *config, 88 newPeerHook: pf, 89 newTransport: func(fd net.Conn, dialDest 
*ecdsa.PublicKey) transport { 90 return newTestTransport(id, fd, dialDest, false) 91 }, 92 }, 93 } 94 if err := server.Start(); err != nil { 95 t.Fatalf("Could not start server: %v", err) 96 } 97 return server 98 } 99 100 func startTestMultiChannelServer(t *testing.T, id discover.NodeID, pf func(*Peer), config *Config) Server { 101 config.Name = "test" 102 config.MaxPhysicalConnections = 10 103 config.PrivateKey = newkey() 104 105 listeners := make([]net.Listener, 0, len(config.SubListenAddr)+1) 106 listenAddrs := make([]string, 0, len(config.SubListenAddr)+1) 107 listenAddrs = append(listenAddrs, config.ListenAddr) 108 listenAddrs = append(listenAddrs, config.SubListenAddr...) 109 110 server := &MultiChannelServer{ 111 BaseServer: &BaseServer{ 112 Config: *config, 113 newPeerHook: pf, 114 newTransport: func(fd net.Conn, dialDest *ecdsa.PublicKey) transport { 115 return newTestTransport(id, fd, dialDest, true) 116 }, 117 }, 118 listeners: listeners, 119 ListenAddrs: listenAddrs, 120 CandidateConns: make(map[discover.NodeID][]*conn), 121 } 122 if err := server.Start(); err != nil { 123 t.Fatalf("Could not start server: %v", err) 124 } 125 return server 126 } 127 128 func makeconn(fd net.Conn, id discover.NodeID) *conn { 129 dialDest, _ := id.Pubkey() 130 tx := newTestTransport(id, fd, dialDest, false) 131 return &conn{fd: fd, transport: tx, flags: staticDialedConn, conntype: common.ConnTypeUndefined, id: id, cont: make(chan error)} 132 } 133 134 func makeMultiChannelConn(fd net.Conn, id discover.NodeID) *conn { 135 dialDest, _ := id.Pubkey() 136 tx := newTestTransport(id, fd, dialDest, true) 137 return &conn{fd: fd, transport: tx, flags: staticDialedConn, conntype: common.ConnTypeUndefined, id: id, cont: make(chan error), multiChannel: true} 138 } 139 140 func TestServerListen(t *testing.T) { 141 // start the test server 142 connected := make(chan *Peer) 143 remid := discover.PubkeyID(&newkey().PublicKey) 144 srv := startTestServer(t, remid, func(p *Peer) { 145 if 
p.ID() != remid { 146 t.Error("peer func called with wrong node id") 147 } 148 if p == nil { 149 t.Error("peer func called with nil conn") 150 } 151 connected <- p 152 }, &Config{}) 153 defer close(connected) 154 defer srv.Stop() 155 156 // dial the test server 157 conn, err := net.DialTimeout("tcp", srv.GetListenAddress()[ConnDefault], 5*time.Second) 158 if err != nil { 159 t.Fatalf("could not dial: %v", err) 160 } 161 c := makeconn(conn, randomID()) 162 c.doConnTypeHandshake(c.conntype) 163 164 defer conn.Close() 165 166 select { 167 case peer := <-connected: 168 if peer.LocalAddr().String() != conn.RemoteAddr().String() { 169 t.Errorf("peer started with wrong conn: got %v, want %v", 170 peer.LocalAddr(), conn.RemoteAddr()) 171 } 172 173 peers := srv.Peers() 174 if !reflect.DeepEqual(peers, []*Peer{peer}) { 175 t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer}) 176 } 177 case <-time.After(5 * time.Second): 178 t.Error("server did not accept within one second") 179 } 180 } 181 182 func TestMultiChannelServerListen(t *testing.T) { 183 // start the test server 184 connected := make(chan *Peer) 185 remid := discover.PubkeyID(&newkey().PublicKey) 186 config := &Config{ListenAddr: "127.0.0.1:33331", SubListenAddr: []string{"127.0.0.1:33333"}} 187 srv := startTestMultiChannelServer(t, remid, func(p *Peer) { 188 if p.ID() != remid { 189 t.Error("peer func called with wrong node id") 190 } 191 if p == nil { 192 t.Error("peer func called with nil conn") 193 } 194 connected <- p 195 }, config) 196 defer close(connected) 197 defer srv.Stop() 198 199 // dial the test server 200 var defaultConn net.Conn 201 202 for i, address := range srv.GetListenAddress() { 203 conn, err := net.DialTimeout("tcp", address, 5*time.Second) 204 defer conn.Close() 205 206 if i == ConnDefault { 207 defaultConn = conn 208 } 209 210 if err != nil { 211 t.Fatalf("could not dial: %v", err) 212 } 213 } 214 215 select { 216 case peer := <-connected: 217 if peer.LocalAddr().String() != 
defaultConn.RemoteAddr().String() { 218 t.Errorf("peer started with wrong conn: got %v, want %v", 219 peer.LocalAddr(), defaultConn.RemoteAddr()) 220 } 221 222 peers := srv.Peers() 223 if !reflect.DeepEqual(peers, []*Peer{peer}) { 224 t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer}) 225 } 226 case <-time.After(5 * time.Second): 227 t.Error("server did not accept within five second") 228 } 229 } 230 231 func TestServerNoListen(t *testing.T) { 232 // start the test server 233 connected := make(chan *Peer) 234 remid := discover.PubkeyID(&newkey().PublicKey) 235 srv := startTestServer(t, remid, func(p *Peer) { 236 if p.ID() != remid { 237 t.Error("peer func called with wrong node id") 238 } 239 if p == nil { 240 t.Error("peer func called with nil conn") 241 } 242 connected <- p 243 }, &Config{NoListen: true}) 244 defer close(connected) 245 defer srv.Stop() 246 247 // dial the test server that will be failed 248 _, err := net.DialTimeout("tcp", srv.GetListenAddress()[ConnDefault], 10*time.Millisecond) 249 if err == nil { 250 t.Fatalf("server started with listening") 251 } 252 } 253 254 func TestServerDial(t *testing.T) { 255 // run a one-shot TCP server to handle the connection. 
256 listener, err := net.Listen("tcp", "127.0.0.1:0") 257 if err != nil { 258 t.Fatalf("could not setup listener: %v", err) 259 } 260 defer listener.Close() 261 accepted := make(chan net.Conn) 262 go func() { 263 conn, err := listener.Accept() 264 if err != nil { 265 t.Error("accept error:", err) 266 return 267 } 268 269 c := makeconn(conn, discover.PubkeyID(&newkey().PublicKey)) 270 c.doConnTypeHandshake(c.conntype) 271 accepted <- conn 272 }() 273 274 // start the server 275 connected := make(chan *Peer) 276 remid := discover.PubkeyID(&newkey().PublicKey) 277 srv := startTestServer(t, remid, func(p *Peer) { connected <- p }, &Config{}) 278 defer close(connected) 279 defer srv.Stop() 280 281 // tell the server to connect 282 tcpAddr := listener.Addr().(*net.TCPAddr) 283 srv.AddPeer(&discover.Node{ID: remid, IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)}) 284 285 select { 286 case conn := <-accepted: 287 defer conn.Close() 288 289 select { 290 case peer := <-connected: 291 if peer.ID() != remid { 292 t.Errorf("peer has wrong id") 293 } 294 if peer.Name() != "test" { 295 t.Errorf("peer has wrong name") 296 } 297 if peer.RemoteAddr().String() != conn.LocalAddr().String() { 298 t.Errorf("peer started with wrong conn: got %v, want %v", 299 peer.RemoteAddr(), conn.LocalAddr()) 300 } 301 peers := srv.Peers() 302 if !reflect.DeepEqual(peers, []*Peer{peer}) { 303 t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer}) 304 } 305 case <-time.After(1 * time.Second): 306 t.Error("server did not launch peer within one second") 307 } 308 309 case <-time.After(1 * time.Second): 310 t.Error("server did not connect within one second") 311 } 312 } 313 314 // This test checks that tasks generated by dialstate are 315 // actually executed and taskdone is called for them. 
316 func TestServerTaskScheduling(t *testing.T) { 317 var ( 318 done = make(chan *testTask) 319 quit, returned = make(chan struct{}), make(chan struct{}) 320 tc = 0 321 tg = taskgen{ 322 newFunc: func(running int, peers map[discover.NodeID]*Peer) []task { 323 tc++ 324 return []task{&testTask{index: tc - 1}} 325 }, 326 doneFunc: func(t task) { 327 select { 328 case done <- t.(*testTask): 329 case <-quit: 330 } 331 }, 332 } 333 ) 334 335 // The Server in this test isn't actually running 336 // because we're only interested in what run does. 337 srv := &SingleChannelServer{ 338 &BaseServer{ 339 Config: Config{MaxPhysicalConnections: 10}, 340 quit: make(chan struct{}), 341 ntab: fakeTable{}, 342 running: true, 343 logger: logger.NewWith(), 344 }, 345 } 346 srv.loopWG.Add(1) 347 go func() { 348 srv.run(tg) 349 close(returned) 350 }() 351 352 var gotdone []*testTask 353 for i := 0; i < 100; i++ { 354 gotdone = append(gotdone, <-done) 355 } 356 for i, task := range gotdone { 357 if task.index != i { 358 t.Errorf("task %d has wrong index, got %d", i, task.index) 359 break 360 } 361 if !task.called { 362 t.Errorf("task %d was not called", i) 363 break 364 } 365 } 366 367 close(quit) 368 srv.Stop() 369 select { 370 case <-returned: 371 case <-time.After(500 * time.Millisecond): 372 t.Error("Server.run did not return within 500ms") 373 } 374 } 375 376 // This test checks that Server doesn't drop tasks, 377 // even if newTasks returns more than the maximum number of tasks. 
378 func TestServerManyTasks(t *testing.T) { 379 alltasks := make([]task, 300) 380 for i := range alltasks { 381 alltasks[i] = &testTask{index: i} 382 } 383 384 var ( 385 srv = &SingleChannelServer{ 386 &BaseServer{ 387 quit: make(chan struct{}), 388 ntab: fakeTable{}, 389 running: true, 390 logger: logger.NewWith(), 391 }, 392 } 393 done = make(chan *testTask) 394 start, end = 0, 0 395 ) 396 defer srv.Stop() 397 srv.loopWG.Add(1) 398 go srv.run(taskgen{ 399 newFunc: func(running int, peers map[discover.NodeID]*Peer) []task { 400 start, end = end, end+maxActiveDialTasks+10 401 if end > len(alltasks) { 402 end = len(alltasks) 403 } 404 return alltasks[start:end] 405 }, 406 doneFunc: func(tt task) { 407 done <- tt.(*testTask) 408 }, 409 }) 410 411 doneset := make(map[int]bool) 412 timeout := time.After(2 * time.Second) 413 for len(doneset) < len(alltasks) { 414 select { 415 case tt := <-done: 416 if doneset[tt.index] { 417 t.Errorf("task %d got done more than once", tt.index) 418 } else { 419 doneset[tt.index] = true 420 } 421 case <-timeout: 422 t.Errorf("%d of %d tasks got done within 2s", len(doneset), len(alltasks)) 423 for i := 0; i < len(alltasks); i++ { 424 if !doneset[i] { 425 t.Logf("task %d not done", i) 426 } 427 } 428 return 429 } 430 } 431 } 432 433 type taskgen struct { 434 newFunc func(running int, peers map[discover.NodeID]*Peer) []task 435 doneFunc func(task) 436 } 437 438 func (tg taskgen) newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task { 439 return tg.newFunc(running, peers) 440 } 441 442 func (tg taskgen) taskDone(t task, now time.Time) { 443 tg.doneFunc(t) 444 } 445 446 func (tg taskgen) addStatic(*discover.Node) { 447 } 448 449 func (tg taskgen) removeStatic(*discover.Node) { 450 } 451 452 type testTask struct { 453 index int 454 called bool 455 } 456 457 func (t *testTask) Do(srv Server) { 458 t.called = true 459 } 460 461 // This test checks that connections are disconnected 462 // just after the encryption 
handshake when the server is 463 // at capacity. Trusted connections should still be accepted. 464 func TestServerAtCap(t *testing.T) { 465 trustedID := randomID() 466 srv := &SingleChannelServer{ 467 BaseServer: &BaseServer{ 468 Config: Config{ 469 PrivateKey: newkey(), 470 MaxPhysicalConnections: 10, 471 NoDial: true, 472 TrustedNodes: []*discover.Node{{ID: trustedID}}, 473 }, 474 }, 475 } 476 if err := srv.Start(); err != nil { 477 t.Fatalf("could not start: %v", err) 478 } 479 defer srv.Stop() 480 481 newconn := func(id discover.NodeID) *conn { 482 fd, _ := net.Pipe() 483 tx := newTestTransport(id, fd, nil, false) 484 return &conn{fd: fd, transport: tx, flags: inboundConn, conntype: common.ConnTypeUndefined, id: id, cont: make(chan error)} 485 } 486 487 // Inject a few connections to fill up the peer set. 488 for i := 0; i < 10; i++ { 489 c := newconn(randomID()) 490 if err := srv.checkpoint(c, srv.addpeer); err != nil { 491 t.Fatalf("could not add conn %d: %v", i, err) 492 } 493 } 494 // Try inserting a non-trusted connection. 495 c := newconn(randomID()) 496 if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers { 497 t.Error("wrong error for insert:", err) 498 } 499 // Try inserting a trusted connection. 
500 c = newconn(trustedID) 501 if err := srv.checkpoint(c, srv.posthandshake); err != nil { 502 t.Error("unexpected error for trusted conn @posthandshake:", err) 503 } 504 if !c.is(trustedConn) { 505 t.Error("Server did not set trusted flag") 506 } 507 } 508 509 func TestServerSetupConn(t *testing.T) { 510 var ( 511 id = discover.PubkeyID(&newkey().PublicKey) 512 srvkey = newkey() 513 srvid = discover.PubkeyID(&srvkey.PublicKey) 514 ) 515 516 tests := []struct { 517 dontstart bool 518 tt *setupTransport 519 flags connFlag 520 dialDest *discover.Node 521 522 wantCloseErr error 523 wantCalls string 524 }{ 525 { 526 dontstart: true, 527 tt: &setupTransport{id: id}, 528 wantCalls: "close,", 529 wantCloseErr: errServerStopped, 530 }, 531 { 532 tt: &setupTransport{id: id, encHandshakeErr: errors.New("read error")}, 533 flags: inboundConn, 534 wantCalls: "doEncHandshake,close,", 535 wantCloseErr: errors.New("read error"), 536 }, 537 { 538 tt: &setupTransport{id: id, phs: &protoHandshake{ID: randomID()}}, 539 dialDest: &discover.Node{ID: id, NType: discover.NodeType(common.ENDPOINTNODE)}, 540 flags: dynDialedConn, 541 wantCalls: "doEncHandshake,doProtoHandshake,close,", 542 wantCloseErr: DiscUnexpectedIdentity, 543 }, 544 { 545 tt: &setupTransport{id: id, protoHandshakeErr: errors.New("foo")}, 546 dialDest: &discover.Node{ID: id, NType: discover.NodeType(common.ENDPOINTNODE)}, 547 flags: dynDialedConn, 548 wantCalls: "doEncHandshake,doProtoHandshake,close,", 549 wantCloseErr: errors.New("foo"), 550 }, 551 { 552 tt: &setupTransport{id: srvid, phs: &protoHandshake{ID: srvid}}, 553 flags: inboundConn, 554 wantCalls: "doEncHandshake,close,", 555 wantCloseErr: DiscSelf, 556 }, 557 { 558 tt: &setupTransport{id: id, phs: &protoHandshake{ID: id}}, 559 flags: inboundConn, 560 wantCalls: "doEncHandshake,doProtoHandshake,close,", 561 wantCloseErr: DiscUselessPeer, 562 }, 563 } 564 565 for i, test := range tests { 566 srv := &SingleChannelServer{ 567 &BaseServer{ 568 Config: Config{ 
569 PrivateKey: srvkey, 570 MaxPhysicalConnections: 10, 571 NoDial: true, 572 Protocols: []Protocol{discard}, 573 ConnectionType: 1, // ENDPOINTNODE 574 }, 575 newTransport: func(fd net.Conn, dialDest *ecdsa.PublicKey) transport { return test.tt }, 576 logger: logger.NewWith(), 577 }, 578 } 579 if !test.dontstart { 580 if err := srv.Start(); err != nil { 581 t.Fatalf("couldn't start server: %v", err) 582 } 583 } 584 p1, _ := net.Pipe() 585 srv.SetupConn(p1, test.flags, test.dialDest) 586 if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) { 587 t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr) 588 } 589 if test.tt.calls != test.wantCalls { 590 t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls) 591 } 592 } 593 } 594 595 type setupTransport struct { 596 id discover.NodeID 597 encHandshakeErr error 598 599 phs *protoHandshake 600 protoHandshakeErr error 601 602 calls string 603 closeErr error 604 } 605 606 func (c *setupTransport) doConnTypeHandshake(myConnType common.ConnType) (common.ConnType, error) { 607 return 1, nil 608 } 609 610 func (c *setupTransport) doEncHandshake(prv *ecdsa.PrivateKey) (*ecdsa.PublicKey, error) { 611 c.calls += "doEncHandshake," 612 pubkey, _ := c.id.Pubkey() 613 return pubkey, c.encHandshakeErr 614 } 615 616 func (c *setupTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) { 617 c.calls += "doProtoHandshake," 618 if c.protoHandshakeErr != nil { 619 return nil, c.protoHandshakeErr 620 } 621 return c.phs, nil 622 } 623 624 func (c *setupTransport) close(err error) { 625 c.calls += "close," 626 c.closeErr = err 627 } 628 629 // setupConn shouldn't write to/read from the connection. 
630 func (c *setupTransport) WriteMsg(Msg) error { 631 panic("WriteMsg called on setupTransport") 632 } 633 634 func (c *setupTransport) ReadMsg() (Msg, error) { 635 panic("ReadMsg called on setupTransport") 636 } 637 638 func newkey() *ecdsa.PrivateKey { 639 key, err := crypto.GenerateKey() 640 if err != nil { 641 panic("couldn't generate key: " + err.Error()) 642 } 643 return key 644 } 645 646 func randomID() (id discover.NodeID) { 647 for i := range id { 648 id[i] = byte(rand.Intn(255)) 649 } 650 return id 651 }