github.com/aquanetwork/aquachain@v1.7.8/p2p/server_test.go (about) 1 // Copyright 2014 The aquachain Authors 2 // This file is part of the aquachain library. 3 // 4 // The aquachain library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The aquachain library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the aquachain library. If not, see <http://www.gnu.org/licenses/>. 16 17 package p2p 18 19 import ( 20 "crypto/ecdsa" 21 "errors" 22 "math/rand" 23 "net" 24 "reflect" 25 "testing" 26 "time" 27 28 "gitlab.com/aquachain/aquachain/common/log" 29 "gitlab.com/aquachain/aquachain/crypto" 30 "gitlab.com/aquachain/aquachain/crypto/sha3" 31 "gitlab.com/aquachain/aquachain/p2p/discover" 32 ) 33 34 func init() { 35 // log.Root().SetHandler(log.LvlFilterHandler(log.LvlError, log.StreamHandler(os.Stderr, log.TerminalFormat(false)))) 36 } 37 38 type testTransport struct { 39 id discover.NodeID 40 *rlpx 41 42 closeErr error 43 } 44 45 func newTestTransport(id discover.NodeID, fd net.Conn) transport { 46 wrapped := newRLPX(fd).(*rlpx) 47 wrapped.rw = newRLPXFrameRW(fd, secrets{ 48 MAC: zero16, 49 AES: zero16, 50 IngressMAC: sha3.NewKeccak256(), 51 EgressMAC: sha3.NewKeccak256(), 52 }) 53 return &testTransport{id: id, rlpx: wrapped} 54 } 55 56 func (c *testTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) { 57 return c.id, nil 58 } 59 60 func (c *testTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) { 61 return &protoHandshake{ID: c.id, Name: "test"}, nil 62 } 63 64 func (c *testTransport) close(err error) { 65 c.rlpx.fd.Close() 66 c.closeErr = err 67 } 68 69 func startTestServer(t *testing.T, id discover.NodeID, pf func(*Peer)) *Server { 70 config := Config{ 71 Name: "test", 72 MaxPeers: 10, 73 ListenAddr: "127.0.0.1:0", 74 PrivateKey: newkey(), 75 } 76 server := &Server{ 77 Config: config, 78 newPeerHook: pf, 79 newTransport: func(fd net.Conn) transport { return newTestTransport(id, fd) }, 80 } 81 if err := server.Start(); err != nil { 82 t.Fatalf("Could not start server: %v", err) 83 } 84 return server 85 } 86 87 func TestServerListen(t *testing.T) { 88 // start the test server 89 connected := make(chan *Peer) 90 remid := randomID() 91 srv := startTestServer(t, remid, func(p *Peer) { 92 if p.ID() != remid { 93 t.Error("peer func called with wrong node id") 94 } 95 if p == nil { 96 t.Error("peer func called with nil conn") 97 } 98 connected <- p 99 }) 100 defer close(connected) 101 defer srv.Stop() 102 103 // dial the test server 104 conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second) 105 if err != nil { 106 t.Fatalf("could not dial: %v", err) 107 } 108 defer conn.Close() 109 110 select { 111 case peer := <-connected: 112 if peer.LocalAddr().String() != conn.RemoteAddr().String() { 113 t.Errorf("peer started with wrong conn: got %v, want %v", 114 peer.LocalAddr(), conn.RemoteAddr()) 115 } 116 peers := srv.Peers() 117 if !reflect.DeepEqual(peers, []*Peer{peer}) { 118 t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer}) 119 } 120 case <-time.After(1 * time.Second): 121 t.Error("server did not accept within one second") 122 } 123 } 124 125 func TestServerDial(t *testing.T) { 126 // run a one-shot TCP server to handle the connection. 127 listener, err := net.Listen("tcp", "127.0.0.1:0") 128 if err != nil { 129 t.Fatalf("could not setup listener: %v", err) 130 } 131 defer listener.Close() 132 accepted := make(chan net.Conn) 133 go func() { 134 conn, err := listener.Accept() 135 if err != nil { 136 t.Error("accept error:", err) 137 return 138 } 139 accepted <- conn 140 }() 141 142 // start the server 143 connected := make(chan *Peer) 144 remid := randomID() 145 srv := startTestServer(t, remid, func(p *Peer) { connected <- p }) 146 defer close(connected) 147 defer srv.Stop() 148 149 // tell the server to connect 150 tcpAddr := listener.Addr().(*net.TCPAddr) 151 srv.AddPeer(&discover.Node{ID: remid, IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)}) 152 153 select { 154 case conn := <-accepted: 155 defer conn.Close() 156 157 select { 158 case peer := <-connected: 159 if peer.ID() != remid { 160 t.Errorf("peer has wrong id") 161 } 162 if peer.Name() != "test" { 163 t.Errorf("peer has wrong name") 164 } 165 if peer.RemoteAddr().String() != conn.LocalAddr().String() { 166 t.Errorf("peer started with wrong conn: got %v, want %v", 167 peer.RemoteAddr(), conn.LocalAddr()) 168 } 169 peers := srv.Peers() 170 if !reflect.DeepEqual(peers, []*Peer{peer}) { 171 t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer}) 172 } 173 case <-time.After(1 * time.Second): 174 t.Error("server did not launch peer within one second") 175 } 176 177 case <-time.After(1 * time.Second): 178 t.Error("server did not connect within one second") 179 } 180 } 181 182 // This test checks that tasks generated by dialstate are 183 // actually executed and taskdone is called for them. 184 func TestServerTaskScheduling(t *testing.T) { 185 var ( 186 done = make(chan *testTask) 187 quit, returned = make(chan struct{}), make(chan struct{}) 188 tc = 0 189 tg = taskgen{ 190 newFunc: func(running int, peers map[discover.NodeID]*Peer) []task { 191 tc++ 192 return []task{&testTask{index: tc - 1}} 193 }, 194 doneFunc: func(t task) { 195 select { 196 case done <- t.(*testTask): 197 case <-quit: 198 } 199 }, 200 } 201 ) 202 203 // The Server in this test isn't actually running 204 // because we're only interested in what run does. 205 srv := &Server{ 206 Config: Config{MaxPeers: 10}, 207 quit: make(chan struct{}), 208 ntab: fakeTable{}, 209 running: true, 210 log: log.New(), 211 } 212 srv.loopWG.Add(1) 213 go func() { 214 srv.run(tg) 215 close(returned) 216 }() 217 218 var gotdone []*testTask 219 for i := 0; i < 100; i++ { 220 gotdone = append(gotdone, <-done) 221 } 222 for i, task := range gotdone { 223 if task.index != i { 224 t.Errorf("task %d has wrong index, got %d", i, task.index) 225 break 226 } 227 if !task.called { 228 t.Errorf("task %d was not called", i) 229 break 230 } 231 } 232 233 close(quit) 234 srv.Stop() 235 select { 236 case <-returned: 237 case <-time.After(500 * time.Millisecond): 238 t.Error("Server.run did not return within 500ms") 239 } 240 } 241 242 // This test checks that Server doesn't drop tasks, 243 // even if newTasks returns more than the maximum number of tasks. 244 func TestServerManyTasks(t *testing.T) { 245 alltasks := make([]task, 300) 246 for i := range alltasks { 247 alltasks[i] = &testTask{index: i} 248 } 249 250 var ( 251 srv = &Server{ 252 quit: make(chan struct{}), 253 ntab: fakeTable{}, 254 running: true, 255 log: log.New(), 256 } 257 done = make(chan *testTask) 258 start, end = 0, 0 259 ) 260 defer srv.Stop() 261 srv.loopWG.Add(1) 262 go srv.run(taskgen{ 263 newFunc: func(running int, peers map[discover.NodeID]*Peer) []task { 264 start, end = end, end+maxActiveDialTasks+10 265 if end > len(alltasks) { 266 end = len(alltasks) 267 } 268 return alltasks[start:end] 269 }, 270 doneFunc: func(tt task) { 271 done <- tt.(*testTask) 272 }, 273 }) 274 275 doneset := make(map[int]bool) 276 timeout := time.After(2 * time.Second) 277 for len(doneset) < len(alltasks) { 278 select { 279 case tt := <-done: 280 if doneset[tt.index] { 281 t.Errorf("task %d got done more than once", tt.index) 282 } else { 283 doneset[tt.index] = true 284 } 285 case <-timeout: 286 t.Errorf("%d of %d tasks got done within 2s", len(doneset), len(alltasks)) 287 for i := 0; i < len(alltasks); i++ { 288 if !doneset[i] { 289 t.Logf("task %d not done", i) 290 } 291 } 292 return 293 } 294 } 295 } 296 297 type taskgen struct { 298 newFunc func(running int, peers map[discover.NodeID]*Peer) []task 299 doneFunc func(task) 300 } 301 302 func (tg taskgen) newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task { 303 return tg.newFunc(running, peers) 304 } 305 func (tg taskgen) taskDone(t task, now time.Time) { 306 tg.doneFunc(t) 307 } 308 func (tg taskgen) addStatic(*discover.Node) { 309 } 310 func (tg taskgen) removeStatic(*discover.Node) { 311 } 312 313 type testTask struct { 314 index int 315 called bool 316 } 317 318 func (t *testTask) Do(srv *Server) { 319 t.called = true 320 } 321 322 // This test checks that connections are disconnected 323 // just after the encryption handshake when the server is 324 // at capacity. Trusted connections should still be accepted. 325 func TestServerAtCap(t *testing.T) { 326 trustedID := randomID() 327 srv := &Server{ 328 Config: Config{ 329 PrivateKey: newkey(), 330 MaxPeers: 10, 331 NoDial: true, 332 TrustedNodes: []*discover.Node{{ID: trustedID}}, 333 }, 334 } 335 if err := srv.Start(); err != nil { 336 t.Fatalf("could not start: %v", err) 337 } 338 defer srv.Stop() 339 340 newconn := func(id discover.NodeID) *conn { 341 fd, _ := net.Pipe() 342 tx := newTestTransport(id, fd) 343 return &conn{fd: fd, transport: tx, flags: inboundConn, id: id, cont: make(chan error)} 344 } 345 346 // Inject a few connections to fill up the peer set. 347 for i := 0; i < 10; i++ { 348 c := newconn(randomID()) 349 if err := srv.checkpoint(c, srv.addpeer); err != nil { 350 t.Fatalf("could not add conn %d: %v", i, err) 351 } 352 } 353 // Try inserting a non-trusted connection. 354 c := newconn(randomID()) 355 if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers { 356 t.Error("wrong error for insert:", err) 357 } 358 // Try inserting a trusted connection. 359 c = newconn(trustedID) 360 if err := srv.checkpoint(c, srv.posthandshake); err != nil { 361 t.Error("unexpected error for trusted conn @posthandshake:", err) 362 } 363 if !c.is(trustedConn) { 364 t.Error("Server did not set trusted flag") 365 } 366 367 } 368 369 func TestServerSetupConn(t *testing.T) { 370 id := randomID() 371 srvkey := newkey() 372 srvid := discover.PubkeyID(&srvkey.PublicKey) 373 tests := []struct { 374 dontstart bool 375 tt *setupTransport 376 flags connFlag 377 dialDest *discover.Node 378 379 wantCloseErr error 380 wantCalls string 381 }{ 382 { 383 dontstart: true, 384 tt: &setupTransport{id: id}, 385 wantCalls: "close,", 386 wantCloseErr: errServerStopped, 387 }, 388 { 389 tt: &setupTransport{id: id, encHandshakeErr: errors.New("read error")}, 390 flags: inboundConn, 391 wantCalls: "doEncHandshake,close,", 392 wantCloseErr: errors.New("read error"), 393 }, 394 { 395 tt: &setupTransport{id: id}, 396 dialDest: &discover.Node{ID: randomID()}, 397 flags: dynDialedConn, 398 wantCalls: "doEncHandshake,close,", 399 wantCloseErr: DiscUnexpectedIdentity, 400 }, 401 { 402 tt: &setupTransport{id: id, phs: &protoHandshake{ID: randomID()}}, 403 dialDest: &discover.Node{ID: id}, 404 flags: dynDialedConn, 405 wantCalls: "doEncHandshake,doProtoHandshake,close,", 406 wantCloseErr: DiscUnexpectedIdentity, 407 }, 408 { 409 tt: &setupTransport{id: id, protoHandshakeErr: errors.New("foo")}, 410 dialDest: &discover.Node{ID: id}, 411 flags: dynDialedConn, 412 wantCalls: "doEncHandshake,doProtoHandshake,close,", 413 wantCloseErr: errors.New("foo"), 414 }, 415 { 416 tt: &setupTransport{id: srvid, phs: &protoHandshake{ID: srvid}}, 417 flags: inboundConn, 418 wantCalls: "doEncHandshake,close,", 419 wantCloseErr: DiscSelf, 420 }, 421 { 422 tt: &setupTransport{id: id, phs: &protoHandshake{ID: id}}, 423 flags: inboundConn, 424 wantCalls: "doEncHandshake,doProtoHandshake,close,", 425 wantCloseErr: DiscUselessPeer, 426 }, 427 } 428 429 for i, test := range tests { 430 srv := &Server{ 431 Config: Config{ 432 PrivateKey: srvkey, 433 MaxPeers: 10, 434 NoDial: true, 435 Protocols: []Protocol{discard}, 436 }, 437 newTransport: func(fd net.Conn) transport { return test.tt }, 438 log: log.New(), 439 } 440 if !test.dontstart { 441 if err := srv.Start(); err != nil { 442 t.Fatalf("couldn't start server: %v", err) 443 } 444 } 445 p1, _ := net.Pipe() 446 srv.SetupConn(p1, test.flags, test.dialDest) 447 if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) { 448 t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr) 449 } 450 if test.tt.calls != test.wantCalls { 451 t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls) 452 } 453 } 454 } 455 456 type setupTransport struct { 457 id discover.NodeID 458 encHandshakeErr error 459 460 phs *protoHandshake 461 protoHandshakeErr error 462 463 calls string 464 closeErr error 465 } 466 467 func (c *setupTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) { 468 c.calls += "doEncHandshake," 469 return c.id, c.encHandshakeErr 470 } 471 func (c *setupTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) { 472 c.calls += "doProtoHandshake," 473 if c.protoHandshakeErr != nil { 474 return nil, c.protoHandshakeErr 475 } 476 return c.phs, nil 477 } 478 func (c *setupTransport) close(err error) { 479 c.calls += "close," 480 c.closeErr = err 481 } 482 483 // setupConn shouldn't write to/read from the connection. 484 func (c *setupTransport) WriteMsg(Msg) error { 485 panic("WriteMsg called on setupTransport") 486 } 487 func (c *setupTransport) ReadMsg() (Msg, error) { 488 panic("ReadMsg called on setupTransport") 489 } 490 491 func newkey() *ecdsa.PrivateKey { 492 key, err := crypto.GenerateKey() 493 if err != nil { 494 panic("couldn't generate key: " + err.Error()) 495 } 496 return key 497 } 498 499 func randomID() (id discover.NodeID) { 500 for i := range id { 501 id[i] = byte(rand.Intn(255)) 502 } 503 return id 504 }