github.com/neatio-net/neatio@v1.7.3-0.20231114194659-f4d7a2226baa/network/p2p/server_test.go (about) 1 package p2p 2 3 import ( 4 "crypto/ecdsa" 5 "errors" 6 "math/rand" 7 "net" 8 "reflect" 9 "testing" 10 "time" 11 12 "github.com/neatio-net/neatio/chain/log" 13 "github.com/neatio-net/neatio/network/p2p/discover" 14 "github.com/neatio-net/neatio/utilities/crypto" 15 "github.com/neatio-net/neatio/utilities/crypto/sha3" 16 ) 17 18 func init() { 19 20 } 21 22 type testTransport struct { 23 id discover.NodeID 24 *rlpx 25 26 closeErr error 27 } 28 29 func newTestTransport(id discover.NodeID, fd net.Conn) transport { 30 wrapped := newRLPX(fd).(*rlpx) 31 wrapped.rw = newRLPXFrameRW(fd, secrets{ 32 MAC: zero16, 33 AES: zero16, 34 IngressMAC: sha3.NewKeccak256(), 35 EgressMAC: sha3.NewKeccak256(), 36 }) 37 return &testTransport{id: id, rlpx: wrapped} 38 } 39 40 func (c *testTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) { 41 return c.id, nil 42 } 43 44 func (c *testTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) { 45 return &protoHandshake{ID: c.id, Name: "test"}, nil 46 } 47 48 func (c *testTransport) close(err error) { 49 c.rlpx.fd.Close() 50 c.closeErr = err 51 } 52 53 func startTestServer(t *testing.T, id discover.NodeID, pf func(*Peer)) *Server { 54 config := Config{ 55 Name: "test", 56 MaxPeers: 10, 57 ListenAddr: "127.0.0.1:0", 58 PrivateKey: newkey(), 59 } 60 server := &Server{ 61 Config: config, 62 newPeerHook: pf, 63 newTransport: func(fd net.Conn) transport { return newTestTransport(id, fd) }, 64 } 65 if err := server.Start(); err != nil { 66 t.Fatalf("Could not start server: %v", err) 67 } 68 return server 69 } 70 71 func TestServerListen(t *testing.T) { 72 73 connected := make(chan *Peer) 74 remid := randomID() 75 srv := startTestServer(t, remid, func(p *Peer) { 76 if p.ID() != remid { 77 t.Error("peer func called with wrong node id") 78 } 79 if p == nil { 80 t.Error("peer func called with nil conn") 81 } 82 connected <- p 83 }) 84 defer close(connected) 85 defer srv.Stop() 86 87 conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second) 88 if err != nil { 89 t.Fatalf("could not dial: %v", err) 90 } 91 defer conn.Close() 92 93 select { 94 case peer := <-connected: 95 if peer.LocalAddr().String() != conn.RemoteAddr().String() { 96 t.Errorf("peer started with wrong conn: got %v, want %v", 97 peer.LocalAddr(), conn.RemoteAddr()) 98 } 99 peers := srv.Peers() 100 if !reflect.DeepEqual(peers, []*Peer{peer}) { 101 t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer}) 102 } 103 case <-time.After(1 * time.Second): 104 t.Error("server did not accept within one second") 105 } 106 } 107 108 func TestServerDial(t *testing.T) { 109 110 listener, err := net.Listen("tcp", "127.0.0.1:0") 111 if err != nil { 112 t.Fatalf("could not setup listener: %v", err) 113 } 114 defer listener.Close() 115 accepted := make(chan net.Conn) 116 go func() { 117 conn, err := listener.Accept() 118 if err != nil { 119 t.Error("accept error:", err) 120 return 121 } 122 accepted <- conn 123 }() 124 125 connected := make(chan *Peer) 126 remid := randomID() 127 srv := startTestServer(t, remid, func(p *Peer) { connected <- p }) 128 defer close(connected) 129 defer srv.Stop() 130 131 tcpAddr := listener.Addr().(*net.TCPAddr) 132 srv.AddPeer(&discover.Node{ID: remid, IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)}) 133 134 select { 135 case conn := <-accepted: 136 defer conn.Close() 137 138 select { 139 case peer := <-connected: 140 if peer.ID() != remid { 141 t.Errorf("peer has wrong id") 142 } 143 if peer.Name() != "test" { 144 t.Errorf("peer has wrong name") 145 } 146 if peer.RemoteAddr().String() != conn.LocalAddr().String() { 147 t.Errorf("peer started with wrong conn: got %v, want %v", 148 peer.RemoteAddr(), conn.LocalAddr()) 149 } 150 peers := srv.Peers() 151 if !reflect.DeepEqual(peers, []*Peer{peer}) { 152 t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer}) 153 } 154 case <-time.After(1 * time.Second): 155 t.Error("server did not launch peer within one second") 156 } 157 158 case <-time.After(1 * time.Second): 159 t.Error("server did not connect within one second") 160 } 161 } 162 163 func TestServerTaskScheduling(t *testing.T) { 164 var ( 165 done = make(chan *testTask) 166 quit, returned = make(chan struct{}), make(chan struct{}) 167 tc = 0 168 tg = taskgen{ 169 newFunc: func(running int, peers map[discover.NodeID]*Peer) []task { 170 tc++ 171 return []task{&testTask{index: tc - 1}} 172 }, 173 doneFunc: func(t task) { 174 select { 175 case done <- t.(*testTask): 176 case <-quit: 177 } 178 }, 179 } 180 ) 181 182 srv := &Server{ 183 Config: Config{MaxPeers: 10}, 184 quit: make(chan struct{}), 185 ntab: fakeTable{}, 186 running: true, 187 log: log.New(), 188 } 189 srv.loopWG.Add(1) 190 go func() { 191 srv.run(tg) 192 close(returned) 193 }() 194 195 var gotdone []*testTask 196 for i := 0; i < 100; i++ { 197 gotdone = append(gotdone, <-done) 198 } 199 for i, task := range gotdone { 200 if task.index != i { 201 t.Errorf("task %d has wrong index, got %d", i, task.index) 202 break 203 } 204 if !task.called { 205 t.Errorf("task %d was not called", i) 206 break 207 } 208 } 209 210 close(quit) 211 srv.Stop() 212 select { 213 case <-returned: 214 case <-time.After(500 * time.Millisecond): 215 t.Error("Server.run did not return within 500ms") 216 } 217 } 218 219 func TestServerManyTasks(t *testing.T) { 220 alltasks := make([]task, 300) 221 for i := range alltasks { 222 alltasks[i] = &testTask{index: i} 223 } 224 225 var ( 226 srv = &Server{ 227 quit: make(chan struct{}), 228 ntab: fakeTable{}, 229 running: true, 230 log: log.New(), 231 } 232 done = make(chan *testTask) 233 start, end = 0, 0 234 ) 235 defer srv.Stop() 236 srv.loopWG.Add(1) 237 go srv.run(taskgen{ 238 newFunc: func(running int, peers map[discover.NodeID]*Peer) []task { 239 start, end = end, end+maxActiveDialTasks+10 240 if end > len(alltasks) { 241 end = len(alltasks) 242 } 243 return alltasks[start:end] 244 }, 245 doneFunc: func(tt task) { 246 done <- tt.(*testTask) 247 }, 248 }) 249 250 doneset := make(map[int]bool) 251 timeout := time.After(2 * time.Second) 252 for len(doneset) < len(alltasks) { 253 select { 254 case tt := <-done: 255 if doneset[tt.index] { 256 t.Errorf("task %d got done more than once", tt.index) 257 } else { 258 doneset[tt.index] = true 259 } 260 case <-timeout: 261 t.Errorf("%d of %d tasks got done within 2s", len(doneset), len(alltasks)) 262 for i := 0; i < len(alltasks); i++ { 263 if !doneset[i] { 264 t.Logf("task %d not done", i) 265 } 266 } 267 return 268 } 269 } 270 } 271 272 type taskgen struct { 273 newFunc func(running int, peers map[discover.NodeID]*Peer) []task 274 doneFunc func(task) 275 } 276 277 func (tg taskgen) newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task { 278 return tg.newFunc(running, peers) 279 } 280 func (tg taskgen) taskDone(t task, now time.Time) { 281 tg.doneFunc(t) 282 } 283 func (tg taskgen) addStatic(*discover.Node) { 284 } 285 func (tg taskgen) removeStatic(*discover.Node) { 286 } 287 288 type testTask struct { 289 index int 290 called bool 291 } 292 293 func (t *testTask) Do(srv *Server) { 294 t.called = true 295 } 296 297 func TestServerAtCap(t *testing.T) { 298 trustedID := randomID() 299 srv := &Server{ 300 Config: Config{ 301 PrivateKey: newkey(), 302 MaxPeers: 10, 303 NoDial: true, 304 TrustedNodes: []*discover.Node{{ID: trustedID}}, 305 }, 306 } 307 if err := srv.Start(); err != nil { 308 t.Fatalf("could not start: %v", err) 309 } 310 defer srv.Stop() 311 312 newconn := func(id discover.NodeID) *conn { 313 fd, _ := net.Pipe() 314 tx := newTestTransport(id, fd) 315 return &conn{fd: fd, transport: tx, flags: inboundConn, id: id, cont: make(chan error)} 316 } 317 318 for i := 0; i < 10; i++ { 319 c := newconn(randomID()) 320 if err := srv.checkpoint(c, srv.addpeer); err != nil { 321 t.Fatalf("could not add conn %d: %v", i, err) 322 } 323 } 324 325 c := newconn(randomID()) 326 if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers { 327 t.Error("wrong error for insert:", err) 328 } 329 330 c = newconn(trustedID) 331 if err := srv.checkpoint(c, srv.posthandshake); err != nil { 332 t.Error("unexpected error for trusted conn @posthandshake:", err) 333 } 334 if !c.is(trustedConn) { 335 t.Error("Server did not set trusted flag") 336 } 337 338 } 339 340 func TestServerSetupConn(t *testing.T) { 341 id := randomID() 342 srvkey := newkey() 343 srvid := discover.PubkeyID(&srvkey.PublicKey) 344 tests := []struct { 345 dontstart bool 346 tt *setupTransport 347 flags connFlag 348 dialDest *discover.Node 349 350 wantCloseErr error 351 wantCalls string 352 }{ 353 { 354 dontstart: true, 355 tt: &setupTransport{id: id}, 356 wantCalls: "close,", 357 wantCloseErr: errServerStopped, 358 }, 359 { 360 tt: &setupTransport{id: id, encHandshakeErr: errors.New("read error")}, 361 flags: inboundConn, 362 wantCalls: "doEncHandshake,close,", 363 wantCloseErr: errors.New("read error"), 364 }, 365 { 366 tt: &setupTransport{id: id}, 367 dialDest: &discover.Node{ID: randomID()}, 368 flags: dynDialedConn, 369 wantCalls: "doEncHandshake,close,", 370 wantCloseErr: DiscUnexpectedIdentity, 371 }, 372 { 373 tt: &setupTransport{id: id, phs: &protoHandshake{ID: randomID()}}, 374 dialDest: &discover.Node{ID: id}, 375 flags: dynDialedConn, 376 wantCalls: "doEncHandshake,doProtoHandshake,close,", 377 wantCloseErr: DiscUnexpectedIdentity, 378 }, 379 { 380 tt: &setupTransport{id: id, protoHandshakeErr: errors.New("foo")}, 381 dialDest: &discover.Node{ID: id}, 382 flags: dynDialedConn, 383 wantCalls: "doEncHandshake,doProtoHandshake,close,", 384 wantCloseErr: errors.New("foo"), 385 }, 386 { 387 tt: &setupTransport{id: srvid, phs: &protoHandshake{ID: srvid}}, 388 flags: inboundConn, 389 wantCalls: "doEncHandshake,close,", 390 wantCloseErr: DiscSelf, 391 }, 392 { 393 tt: &setupTransport{id: id, phs: &protoHandshake{ID: id}}, 394 flags: inboundConn, 395 wantCalls: "doEncHandshake,doProtoHandshake,close,", 396 wantCloseErr: DiscUselessPeer, 397 }, 398 } 399 400 for i, test := range tests { 401 srv := &Server{ 402 Config: Config{ 403 PrivateKey: srvkey, 404 MaxPeers: 10, 405 NoDial: true, 406 Protocols: []Protocol{discard}, 407 }, 408 newTransport: func(fd net.Conn) transport { return test.tt }, 409 log: log.New(), 410 } 411 if !test.dontstart { 412 if err := srv.Start(); err != nil { 413 t.Fatalf("couldn't start server: %v", err) 414 } 415 } 416 p1, _ := net.Pipe() 417 srv.SetupConn(p1, test.flags, test.dialDest) 418 if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) { 419 t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr) 420 } 421 if test.tt.calls != test.wantCalls { 422 t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls) 423 } 424 } 425 } 426 427 type setupTransport struct { 428 id discover.NodeID 429 encHandshakeErr error 430 431 phs *protoHandshake 432 protoHandshakeErr error 433 434 calls string 435 closeErr error 436 } 437 438 func (c *setupTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) { 439 c.calls += "doEncHandshake," 440 return c.id, c.encHandshakeErr 441 } 442 func (c *setupTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) { 443 c.calls += "doProtoHandshake," 444 if c.protoHandshakeErr != nil { 445 return nil, c.protoHandshakeErr 446 } 447 return c.phs, nil 448 } 449 func (c *setupTransport) close(err error) { 450 c.calls += "close," 451 c.closeErr = err 452 } 453 454 func (c *setupTransport) WriteMsg(Msg) error { 455 panic("WriteMsg called on setupTransport") 456 } 457 func (c *setupTransport) ReadMsg() (Msg, error) { 458 panic("ReadMsg called on setupTransport") 459 } 460 461 func newkey() *ecdsa.PrivateKey { 462 key, err := crypto.GenerateKey() 463 if err != nil { 464 panic("couldn't generate key: " + err.Error()) 465 } 466 return key 467 } 468 469 func randomID() (id discover.NodeID) { 470 for i := range id { 471 id[i] = byte(rand.Intn(255)) 472 } 473 return id 474 }