gitlab.com/aquachain/aquachain@v1.17.16-rc3.0.20221018032414-e3ddf1e1c055/p2p/server_test.go (about) 1 // Copyright 2018 The aquachain Authors 2 // This file is part of the aquachain library. 3 // 4 // The aquachain library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The aquachain library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the aquachain library. If not, see <http://www.gnu.org/licenses/>. 16 17 package p2p 18 19 import ( 20 "crypto/ecdsa" 21 "errors" 22 "math/rand" 23 "net" 24 "reflect" 25 "testing" 26 "time" 27 28 "gitlab.com/aquachain/aquachain/common/log" 29 "gitlab.com/aquachain/aquachain/crypto" 30 "gitlab.com/aquachain/aquachain/crypto/sha3" 31 "gitlab.com/aquachain/aquachain/p2p/discover" 32 ) 33 34 func init() { 35 // log.Root().SetHandler(log.LvlFilterHandler(log.LvlError, log.StreamHandler(os.Stderr, log.TerminalFormat(false)))) 36 } 37 38 type testTransport struct { 39 id discover.NodeID 40 *rlpx 41 42 closeErr error 43 } 44 45 func newTestTransport(id discover.NodeID, fd net.Conn) transport { 46 wrapped := newRLPX(fd).(*rlpx) 47 wrapped.rw = newRLPXFrameRW(fd, secrets{ 48 MAC: zero16, 49 AES: zero16, 50 IngressMAC: sha3.NewKeccak256(), 51 EgressMAC: sha3.NewKeccak256(), 52 }) 53 return &testTransport{id: id, rlpx: wrapped} 54 } 55 56 func (c *testTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) { 57 return c.id, nil 58 } 59 60 func (c *testTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) { 61 return &protoHandshake{ID: c.id, Name: "test"}, nil 62 } 63 64 func (c *testTransport) close(err error) { 65 c.rlpx.fd.Close() 66 c.closeErr = err 67 } 68 69 func startTestServer(t *testing.T, id discover.NodeID, pf func(*Peer)) *Server { 70 config := Config{ 71 Name: "test", 72 MaxPeers: 10, 73 ListenAddr: "127.0.0.1:0", 74 PrivateKey: newkey(), 75 ChainId: 222, 76 } 77 server := &Server{ 78 Config: config, 79 newPeerHook: pf, 80 newTransport: func(fd net.Conn) transport { return newTestTransport(id, fd) }, 81 } 82 if err := server.Start(); err != nil { 83 t.Fatalf("Could not start server: %v", err) 84 } 85 return server 86 } 87 88 func TestServerListen(t *testing.T) { 89 // start the test server 90 connected := make(chan *Peer) 91 remid := randomID() 92 srv := startTestServer(t, remid, func(p *Peer) { 93 if p.ID() != remid { 94 t.Error("peer func called with wrong node id") 95 } 96 if p == nil { 97 t.Error("peer func called with nil conn") 98 } 99 connected <- p 100 }) 101 defer close(connected) 102 defer srv.Stop() 103 104 // dial the test server 105 conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second) 106 if err != nil { 107 t.Fatalf("could not dial: %v", err) 108 } 109 defer conn.Close() 110 111 select { 112 case peer := <-connected: 113 if peer.LocalAddr().String() != conn.RemoteAddr().String() { 114 t.Errorf("peer started with wrong conn: got %v, want %v", 115 peer.LocalAddr(), conn.RemoteAddr()) 116 } 117 peers := srv.Peers() 118 if !reflect.DeepEqual(peers, []*Peer{peer}) { 119 t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer}) 120 } 121 case <-time.After(1 * time.Second): 122 t.Error("server did not accept within one second") 123 } 124 } 125 126 func TestServerDial(t *testing.T) { 127 // run a one-shot TCP server to handle the connection. 128 listener, err := net.Listen("tcp", "127.0.0.1:0") 129 if err != nil { 130 t.Fatalf("could not setup listener: %v", err) 131 } 132 defer listener.Close() 133 accepted := make(chan net.Conn) 134 go func() { 135 conn, err := listener.Accept() 136 if err != nil { 137 t.Error("accept error:", err) 138 return 139 } 140 accepted <- conn 141 }() 142 143 // start the server 144 connected := make(chan *Peer) 145 remid := randomID() 146 srv := startTestServer(t, remid, func(p *Peer) { connected <- p }) 147 defer close(connected) 148 defer srv.Stop() 149 150 // tell the server to connect 151 tcpAddr := listener.Addr().(*net.TCPAddr) 152 srv.AddPeer(&discover.Node{ID: remid, IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)}) 153 154 select { 155 case conn := <-accepted: 156 defer conn.Close() 157 158 select { 159 case peer := <-connected: 160 if peer.ID() != remid { 161 t.Errorf("peer has wrong id") 162 } 163 if peer.Name() != "test" { 164 t.Errorf("peer has wrong name") 165 } 166 if peer.RemoteAddr().String() != conn.LocalAddr().String() { 167 t.Errorf("peer started with wrong conn: got %v, want %v", 168 peer.RemoteAddr(), conn.LocalAddr()) 169 } 170 peers := srv.Peers() 171 if !reflect.DeepEqual(peers, []*Peer{peer}) { 172 t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer}) 173 } 174 case <-time.After(1 * time.Second): 175 t.Error("server did not launch peer within one second") 176 } 177 178 case <-time.After(1 * time.Second): 179 t.Error("server did not connect within one second") 180 } 181 } 182 183 // This test checks that tasks generated by dialstate are 184 // actually executed and taskdone is called for them. 185 func TestServerTaskScheduling(t *testing.T) { 186 var ( 187 done = make(chan *testTask) 188 quit, returned = make(chan struct{}), make(chan struct{}) 189 tc = 0 190 tg = taskgen{ 191 newFunc: func(running int, peers map[discover.NodeID]*Peer) []task { 192 tc++ 193 return []task{&testTask{index: tc - 1}} 194 }, 195 doneFunc: func(t task) { 196 select { 197 case done <- t.(*testTask): 198 case <-quit: 199 } 200 }, 201 } 202 ) 203 204 // The Server in this test isn't actually running 205 // because we're only interested in what run does. 206 srv := &Server{ 207 Config: Config{MaxPeers: 10}, 208 quit: make(chan struct{}), 209 ntab: fakeTable{}, 210 running: true, 211 log: log.New(), 212 } 213 srv.loopWG.Add(1) 214 go func() { 215 srv.run(tg) 216 close(returned) 217 }() 218 219 var gotdone []*testTask 220 for i := 0; i < 100; i++ { 221 gotdone = append(gotdone, <-done) 222 } 223 for i, task := range gotdone { 224 if task.index != i { 225 t.Errorf("task %d has wrong index, got %d", i, task.index) 226 break 227 } 228 if !task.called { 229 t.Errorf("task %d was not called", i) 230 break 231 } 232 } 233 234 close(quit) 235 srv.Stop() 236 select { 237 case <-returned: 238 case <-time.After(500 * time.Millisecond): 239 t.Error("Server.run did not return within 500ms") 240 } 241 } 242 243 // This test checks that Server doesn't drop tasks, 244 // even if newTasks returns more than the maximum number of tasks. 245 func TestServerManyTasks(t *testing.T) { 246 alltasks := make([]task, 300) 247 for i := range alltasks { 248 alltasks[i] = &testTask{index: i} 249 } 250 251 var ( 252 srv = &Server{ 253 quit: make(chan struct{}), 254 ntab: fakeTable{}, 255 running: true, 256 log: log.New(), 257 } 258 done = make(chan *testTask) 259 start, end = 0, 0 260 ) 261 defer srv.Stop() 262 srv.loopWG.Add(1) 263 go srv.run(taskgen{ 264 newFunc: func(running int, peers map[discover.NodeID]*Peer) []task { 265 start, end = end, end+maxActiveDialTasks+10 266 if end > len(alltasks) { 267 end = len(alltasks) 268 } 269 return alltasks[start:end] 270 }, 271 doneFunc: func(tt task) { 272 done <- tt.(*testTask) 273 }, 274 }) 275 276 doneset := make(map[int]bool) 277 timeout := time.After(2 * time.Second) 278 for len(doneset) < len(alltasks) { 279 select { 280 case tt := <-done: 281 if doneset[tt.index] { 282 t.Errorf("task %d got done more than once", tt.index) 283 } else { 284 doneset[tt.index] = true 285 } 286 case <-timeout: 287 t.Errorf("%d of %d tasks got done within 2s", len(doneset), len(alltasks)) 288 for i := 0; i < len(alltasks); i++ { 289 if !doneset[i] { 290 t.Logf("task %d not done", i) 291 } 292 } 293 return 294 } 295 } 296 } 297 298 type taskgen struct { 299 newFunc func(running int, peers map[discover.NodeID]*Peer) []task 300 doneFunc func(task) 301 } 302 303 func (tg taskgen) newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task { 304 return tg.newFunc(running, peers) 305 } 306 func (tg taskgen) taskDone(t task, now time.Time) { 307 tg.doneFunc(t) 308 } 309 func (tg taskgen) addStatic(*discover.Node) { 310 } 311 func (tg taskgen) removeStatic(*discover.Node) { 312 } 313 314 type testTask struct { 315 index int 316 called bool 317 } 318 319 func (t *testTask) Do(srv *Server) { 320 t.called = true 321 } 322 323 // This test checks that connections are disconnected 324 // just after the encryption handshake when the server is 325 // at capacity. Trusted connections should still be accepted. 326 func TestServerAtCap(t *testing.T) { 327 trustedID := randomID() 328 srv := &Server{ 329 Config: Config{ 330 PrivateKey: newkey(), 331 MaxPeers: 10, 332 NoDial: true, 333 TrustedNodes: []*discover.Node{{ID: trustedID}}, 334 ChainId: 333, 335 }, 336 } 337 if err := srv.Start(); err != nil { 338 t.Fatalf("could not start: %v", err) 339 } 340 defer srv.Stop() 341 342 newconn := func(id discover.NodeID) *conn { 343 fd, _ := net.Pipe() 344 tx := newTestTransport(id, fd) 345 return &conn{fd: fd, transport: tx, flags: inboundConn, id: id, cont: make(chan error)} 346 } 347 348 // Inject a few connections to fill up the peer set. 349 for i := 0; i < 10; i++ { 350 c := newconn(randomID()) 351 if err := srv.checkpoint(c, srv.addpeer); err != nil { 352 t.Fatalf("could not add conn %d: %v", i, err) 353 } 354 } 355 // Try inserting a non-trusted connection. 356 c := newconn(randomID()) 357 if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers { 358 t.Error("wrong error for insert:", err) 359 } 360 // Try inserting a trusted connection. 361 c = newconn(trustedID) 362 if err := srv.checkpoint(c, srv.posthandshake); err != nil { 363 t.Error("unexpected error for trusted conn @posthandshake:", err) 364 } 365 if !c.is(trustedConn) { 366 t.Error("Server did not set trusted flag") 367 } 368 369 } 370 371 func TestServerSetupConn(t *testing.T) { 372 id := randomID() 373 srvkey := newkey() 374 srvid := discover.PubkeyID(&srvkey.PublicKey) 375 tests := []struct { 376 dontstart bool 377 tt *setupTransport 378 flags connFlag 379 dialDest *discover.Node 380 381 wantCloseErr error 382 wantCalls string 383 }{ 384 { 385 dontstart: true, 386 tt: &setupTransport{id: id}, 387 wantCalls: "close,", 388 wantCloseErr: errServerStopped, 389 }, 390 { 391 tt: &setupTransport{id: id, encHandshakeErr: errors.New("read error")}, 392 flags: inboundConn, 393 wantCalls: "doEncHandshake,close,", 394 wantCloseErr: errors.New("read error"), 395 }, 396 { 397 tt: &setupTransport{id: id}, 398 dialDest: &discover.Node{ID: randomID()}, 399 flags: dynDialedConn, 400 wantCalls: "doEncHandshake,close,", 401 wantCloseErr: DiscUnexpectedIdentity, 402 }, 403 { 404 tt: &setupTransport{id: id, phs: &protoHandshake{ID: randomID()}}, 405 dialDest: &discover.Node{ID: id}, 406 flags: dynDialedConn, 407 wantCalls: "doEncHandshake,doProtoHandshake,close,", 408 wantCloseErr: DiscUnexpectedIdentity, 409 }, 410 { 411 tt: &setupTransport{id: id, protoHandshakeErr: errors.New("foo")}, 412 dialDest: &discover.Node{ID: id}, 413 flags: dynDialedConn, 414 wantCalls: "doEncHandshake,doProtoHandshake,close,", 415 wantCloseErr: errors.New("foo"), 416 }, 417 { 418 tt: &setupTransport{id: srvid, phs: &protoHandshake{ID: srvid}}, 419 flags: inboundConn, 420 wantCalls: "doEncHandshake,close,", 421 wantCloseErr: DiscSelf, 422 }, 423 { 424 tt: &setupTransport{id: id, phs: &protoHandshake{ID: id}}, 425 flags: inboundConn, 426 wantCalls: "doEncHandshake,doProtoHandshake,close,", 427 wantCloseErr: DiscUselessPeer, 428 }, 429 } 430 431 for i, test := range tests { 432 srv := &Server{ 433 Config: Config{ 434 PrivateKey: srvkey, 435 MaxPeers: 10, 436 NoDial: true, 437 Protocols: []Protocol{discard}, 438 ChainId: 444, 439 }, 440 newTransport: func(fd net.Conn) transport { return test.tt }, 441 log: log.New(), 442 } 443 if !test.dontstart { 444 if err := srv.Start(); err != nil { 445 t.Fatalf("couldn't start server: %v", err) 446 } 447 } 448 p1, _ := net.Pipe() 449 srv.SetupConn(p1, test.flags, test.dialDest) 450 if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) { 451 t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr) 452 } 453 if test.tt.calls != test.wantCalls { 454 t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls) 455 } 456 } 457 } 458 459 type setupTransport struct { 460 id discover.NodeID 461 encHandshakeErr error 462 463 phs *protoHandshake 464 protoHandshakeErr error 465 466 calls string 467 closeErr error 468 } 469 470 func (c *setupTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) { 471 c.calls += "doEncHandshake," 472 return c.id, c.encHandshakeErr 473 } 474 func (c *setupTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) { 475 c.calls += "doProtoHandshake," 476 if c.protoHandshakeErr != nil { 477 return nil, c.protoHandshakeErr 478 } 479 return c.phs, nil 480 } 481 func (c *setupTransport) close(err error) { 482 c.calls += "close," 483 c.closeErr = err 484 } 485 486 // setupConn shouldn't write to/read from the connection. 487 func (c *setupTransport) WriteMsg(Msg) error { 488 panic("WriteMsg called on setupTransport") 489 } 490 func (c *setupTransport) ReadMsg() (Msg, error) { 491 panic("ReadMsg called on setupTransport") 492 } 493 494 func newkey() *ecdsa.PrivateKey { 495 key, err := crypto.GenerateKey() 496 if err != nil { 497 panic("couldn't generate key: " + err.Error()) 498 } 499 return key 500 } 501 502 func randomID() (id discover.NodeID) { 503 for i := range id { 504 id[i] = byte(rand.Intn(255)) 505 } 506 return id 507 }