github.com/sberex/go-sberex@v1.8.2-0.20181113200658-ed96ac38f7d7/p2p/server_test.go (about) 1 // This file is part of the go-sberex library. The go-sberex library is 2 // free software: you can redistribute it and/or modify it under the terms 3 // of the GNU Lesser General Public License as published by the Free 4 // Software Foundation, either version 3 of the License, or (at your option) 5 // any later version. 6 // 7 // The go-sberex library is distributed in the hope that it will be useful, 8 // but WITHOUT ANY WARRANTY; without even the implied warranty of 9 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser 10 // General Public License <http://www.gnu.org/licenses/> for more details. 11 12 package p2p 13 14 import ( 15 "crypto/ecdsa" 16 "errors" 17 "math/rand" 18 "net" 19 "reflect" 20 "testing" 21 "time" 22 23 "github.com/Sberex/go-sberex/crypto" 24 "github.com/Sberex/go-sberex/crypto/sha3" 25 "github.com/Sberex/go-sberex/log" 26 "github.com/Sberex/go-sberex/p2p/discover" 27 ) 28 29 func init() { 30 // log.Root().SetHandler(log.LvlFilterHandler(log.LvlError, log.StreamHandler(os.Stderr, log.TerminalFormat(false)))) 31 } 32 33 type testTransport struct { 34 id discover.NodeID 35 *rlpx 36 37 closeErr error 38 } 39 40 func newTestTransport(id discover.NodeID, fd net.Conn) transport { 41 wrapped := newRLPX(fd).(*rlpx) 42 wrapped.rw = newRLPXFrameRW(fd, secrets{ 43 MAC: zero16, 44 AES: zero16, 45 IngressMAC: sha3.NewKeccak256(), 46 EgressMAC: sha3.NewKeccak256(), 47 }) 48 return &testTransport{id: id, rlpx: wrapped} 49 } 50 51 func (c *testTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) { 52 return c.id, nil 53 } 54 55 func (c *testTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) { 56 return &protoHandshake{ID: c.id, Name: "test"}, nil 57 } 58 59 func (c *testTransport) close(err error) { 60 c.rlpx.fd.Close() 61 c.closeErr = err 62 } 63 64 func startTestServer(t *testing.T, id discover.NodeID, pf func(*Peer)) *Server { 65 config := Config{ 66 Name: "test", 67 MaxPeers: 10, 68 ListenAddr: "127.0.0.1:0", 69 PrivateKey: newkey(), 70 } 71 server := &Server{ 72 Config: config, 73 newPeerHook: pf, 74 newTransport: func(fd net.Conn) transport { return newTestTransport(id, fd) }, 75 } 76 if err := server.Start(); err != nil { 77 t.Fatalf("Could not start server: %v", err) 78 } 79 return server 80 } 81 82 func TestServerListen(t *testing.T) { 83 // start the test server 84 connected := make(chan *Peer) 85 remid := randomID() 86 srv := startTestServer(t, remid, func(p *Peer) { 87 if p.ID() != remid { 88 t.Error("peer func called with wrong node id") 89 } 90 if p == nil { 91 t.Error("peer func called with nil conn") 92 } 93 connected <- p 94 }) 95 defer close(connected) 96 defer srv.Stop() 97 98 // dial the test server 99 conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second) 100 if err != nil { 101 t.Fatalf("could not dial: %v", err) 102 } 103 defer conn.Close() 104 105 select { 106 case peer := <-connected: 107 if peer.LocalAddr().String() != conn.RemoteAddr().String() { 108 t.Errorf("peer started with wrong conn: got %v, want %v", 109 peer.LocalAddr(), conn.RemoteAddr()) 110 } 111 peers := srv.Peers() 112 if !reflect.DeepEqual(peers, []*Peer{peer}) { 113 t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer}) 114 } 115 case <-time.After(1 * time.Second): 116 t.Error("server did not accept within one second") 117 } 118 } 119 120 func TestServerDial(t *testing.T) { 121 // run a one-shot TCP server to handle the connection. 122 listener, err := net.Listen("tcp", "127.0.0.1:0") 123 if err != nil { 124 t.Fatalf("could not setup listener: %v", err) 125 } 126 defer listener.Close() 127 accepted := make(chan net.Conn) 128 go func() { 129 conn, err := listener.Accept() 130 if err != nil { 131 t.Error("accept error:", err) 132 return 133 } 134 accepted <- conn 135 }() 136 137 // start the server 138 connected := make(chan *Peer) 139 remid := randomID() 140 srv := startTestServer(t, remid, func(p *Peer) { connected <- p }) 141 defer close(connected) 142 defer srv.Stop() 143 144 // tell the server to connect 145 tcpAddr := listener.Addr().(*net.TCPAddr) 146 srv.AddPeer(&discover.Node{ID: remid, IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)}) 147 148 select { 149 case conn := <-accepted: 150 defer conn.Close() 151 152 select { 153 case peer := <-connected: 154 if peer.ID() != remid { 155 t.Errorf("peer has wrong id") 156 } 157 if peer.Name() != "test" { 158 t.Errorf("peer has wrong name") 159 } 160 if peer.RemoteAddr().String() != conn.LocalAddr().String() { 161 t.Errorf("peer started with wrong conn: got %v, want %v", 162 peer.RemoteAddr(), conn.LocalAddr()) 163 } 164 peers := srv.Peers() 165 if !reflect.DeepEqual(peers, []*Peer{peer}) { 166 t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer}) 167 } 168 case <-time.After(1 * time.Second): 169 t.Error("server did not launch peer within one second") 170 } 171 172 case <-time.After(1 * time.Second): 173 t.Error("server did not connect within one second") 174 } 175 } 176 177 // This test checks that tasks generated by dialstate are 178 // actually executed and taskdone is called for them. 179 func TestServerTaskScheduling(t *testing.T) { 180 var ( 181 done = make(chan *testTask) 182 quit, returned = make(chan struct{}), make(chan struct{}) 183 tc = 0 184 tg = taskgen{ 185 newFunc: func(running int, peers map[discover.NodeID]*Peer) []task { 186 tc++ 187 return []task{&testTask{index: tc - 1}} 188 }, 189 doneFunc: func(t task) { 190 select { 191 case done <- t.(*testTask): 192 case <-quit: 193 } 194 }, 195 } 196 ) 197 198 // The Server in this test isn't actually running 199 // because we're only interested in what run does. 200 srv := &Server{ 201 Config: Config{MaxPeers: 10}, 202 quit: make(chan struct{}), 203 ntab: fakeTable{}, 204 running: true, 205 log: log.New(), 206 } 207 srv.loopWG.Add(1) 208 go func() { 209 srv.run(tg) 210 close(returned) 211 }() 212 213 var gotdone []*testTask 214 for i := 0; i < 100; i++ { 215 gotdone = append(gotdone, <-done) 216 } 217 for i, task := range gotdone { 218 if task.index != i { 219 t.Errorf("task %d has wrong index, got %d", i, task.index) 220 break 221 } 222 if !task.called { 223 t.Errorf("task %d was not called", i) 224 break 225 } 226 } 227 228 close(quit) 229 srv.Stop() 230 select { 231 case <-returned: 232 case <-time.After(500 * time.Millisecond): 233 t.Error("Server.run did not return within 500ms") 234 } 235 } 236 237 // This test checks that Server doesn't drop tasks, 238 // even if newTasks returns more than the maximum number of tasks. 239 func TestServerManyTasks(t *testing.T) { 240 alltasks := make([]task, 300) 241 for i := range alltasks { 242 alltasks[i] = &testTask{index: i} 243 } 244 245 var ( 246 srv = &Server{ 247 quit: make(chan struct{}), 248 ntab: fakeTable{}, 249 running: true, 250 log: log.New(), 251 } 252 done = make(chan *testTask) 253 start, end = 0, 0 254 ) 255 defer srv.Stop() 256 srv.loopWG.Add(1) 257 go srv.run(taskgen{ 258 newFunc: func(running int, peers map[discover.NodeID]*Peer) []task { 259 start, end = end, end+maxActiveDialTasks+10 260 if end > len(alltasks) { 261 end = len(alltasks) 262 } 263 return alltasks[start:end] 264 }, 265 doneFunc: func(tt task) { 266 done <- tt.(*testTask) 267 }, 268 }) 269 270 doneset := make(map[int]bool) 271 timeout := time.After(2 * time.Second) 272 for len(doneset) < len(alltasks) { 273 select { 274 case tt := <-done: 275 if doneset[tt.index] { 276 t.Errorf("task %d got done more than once", tt.index) 277 } else { 278 doneset[tt.index] = true 279 } 280 case <-timeout: 281 t.Errorf("%d of %d tasks got done within 2s", len(doneset), len(alltasks)) 282 for i := 0; i < len(alltasks); i++ { 283 if !doneset[i] { 284 t.Logf("task %d not done", i) 285 } 286 } 287 return 288 } 289 } 290 } 291 292 type taskgen struct { 293 newFunc func(running int, peers map[discover.NodeID]*Peer) []task 294 doneFunc func(task) 295 } 296 297 func (tg taskgen) newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task { 298 return tg.newFunc(running, peers) 299 } 300 func (tg taskgen) taskDone(t task, now time.Time) { 301 tg.doneFunc(t) 302 } 303 func (tg taskgen) addStatic(*discover.Node) { 304 } 305 func (tg taskgen) removeStatic(*discover.Node) { 306 } 307 308 type testTask struct { 309 index int 310 called bool 311 } 312 313 func (t *testTask) Do(srv *Server) { 314 t.called = true 315 } 316 317 // This test checks that connections are disconnected 318 // just after the encryption handshake when the server is 319 // at capacity. Trusted connections should still be accepted. 320 func TestServerAtCap(t *testing.T) { 321 trustedID := randomID() 322 srv := &Server{ 323 Config: Config{ 324 PrivateKey: newkey(), 325 MaxPeers: 10, 326 NoDial: true, 327 TrustedNodes: []*discover.Node{{ID: trustedID}}, 328 }, 329 } 330 if err := srv.Start(); err != nil { 331 t.Fatalf("could not start: %v", err) 332 } 333 defer srv.Stop() 334 335 newconn := func(id discover.NodeID) *conn { 336 fd, _ := net.Pipe() 337 tx := newTestTransport(id, fd) 338 return &conn{fd: fd, transport: tx, flags: inboundConn, id: id, cont: make(chan error)} 339 } 340 341 // Inject a few connections to fill up the peer set. 342 for i := 0; i < 10; i++ { 343 c := newconn(randomID()) 344 if err := srv.checkpoint(c, srv.addpeer); err != nil { 345 t.Fatalf("could not add conn %d: %v", i, err) 346 } 347 } 348 // Try inserting a non-trusted connection. 349 c := newconn(randomID()) 350 if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers { 351 t.Error("wrong error for insert:", err) 352 } 353 // Try inserting a trusted connection. 354 c = newconn(trustedID) 355 if err := srv.checkpoint(c, srv.posthandshake); err != nil { 356 t.Error("unexpected error for trusted conn @posthandshake:", err) 357 } 358 if !c.is(trustedConn) { 359 t.Error("Server did not set trusted flag") 360 } 361 362 } 363 364 func TestServerSetupConn(t *testing.T) { 365 id := randomID() 366 srvkey := newkey() 367 srvid := discover.PubkeyID(&srvkey.PublicKey) 368 tests := []struct { 369 dontstart bool 370 tt *setupTransport 371 flags connFlag 372 dialDest *discover.Node 373 374 wantCloseErr error 375 wantCalls string 376 }{ 377 { 378 dontstart: true, 379 tt: &setupTransport{id: id}, 380 wantCalls: "close,", 381 wantCloseErr: errServerStopped, 382 }, 383 { 384 tt: &setupTransport{id: id, encHandshakeErr: errors.New("read error")}, 385 flags: inboundConn, 386 wantCalls: "doEncHandshake,close,", 387 wantCloseErr: errors.New("read error"), 388 }, 389 { 390 tt: &setupTransport{id: id}, 391 dialDest: &discover.Node{ID: randomID()}, 392 flags: dynDialedConn, 393 wantCalls: "doEncHandshake,close,", 394 wantCloseErr: DiscUnexpectedIdentity, 395 }, 396 { 397 tt: &setupTransport{id: id, phs: &protoHandshake{ID: randomID()}}, 398 dialDest: &discover.Node{ID: id}, 399 flags: dynDialedConn, 400 wantCalls: "doEncHandshake,doProtoHandshake,close,", 401 wantCloseErr: DiscUnexpectedIdentity, 402 }, 403 { 404 tt: &setupTransport{id: id, protoHandshakeErr: errors.New("foo")}, 405 dialDest: &discover.Node{ID: id}, 406 flags: dynDialedConn, 407 wantCalls: "doEncHandshake,doProtoHandshake,close,", 408 wantCloseErr: errors.New("foo"), 409 }, 410 { 411 tt: &setupTransport{id: srvid, phs: &protoHandshake{ID: srvid}}, 412 flags: inboundConn, 413 wantCalls: "doEncHandshake,close,", 414 wantCloseErr: DiscSelf, 415 }, 416 { 417 tt: &setupTransport{id: id, phs: &protoHandshake{ID: id}}, 418 flags: inboundConn, 419 wantCalls: "doEncHandshake,doProtoHandshake,close,", 420 wantCloseErr: DiscUselessPeer, 421 }, 422 } 423 424 for i, test := range tests { 425 srv := &Server{ 426 Config: Config{ 427 PrivateKey: srvkey, 428 MaxPeers: 10, 429 NoDial: true, 430 Protocols: []Protocol{discard}, 431 }, 432 newTransport: func(fd net.Conn) transport { return test.tt }, 433 log: log.New(), 434 } 435 if !test.dontstart { 436 if err := srv.Start(); err != nil { 437 t.Fatalf("couldn't start server: %v", err) 438 } 439 } 440 p1, _ := net.Pipe() 441 srv.SetupConn(p1, test.flags, test.dialDest) 442 if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) { 443 t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr) 444 } 445 if test.tt.calls != test.wantCalls { 446 t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls) 447 } 448 } 449 } 450 451 type setupTransport struct { 452 id discover.NodeID 453 encHandshakeErr error 454 455 phs *protoHandshake 456 protoHandshakeErr error 457 458 calls string 459 closeErr error 460 } 461 462 func (c *setupTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) { 463 c.calls += "doEncHandshake," 464 return c.id, c.encHandshakeErr 465 } 466 func (c *setupTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) { 467 c.calls += "doProtoHandshake," 468 if c.protoHandshakeErr != nil { 469 return nil, c.protoHandshakeErr 470 } 471 return c.phs, nil 472 } 473 func (c *setupTransport) close(err error) { 474 c.calls += "close," 475 c.closeErr = err 476 } 477 478 // setupConn shouldn't write to/read from the connection. 479 func (c *setupTransport) WriteMsg(Msg) error { 480 panic("WriteMsg called on setupTransport") 481 } 482 func (c *setupTransport) ReadMsg() (Msg, error) { 483 panic("ReadMsg called on setupTransport") 484 } 485 486 func newkey() *ecdsa.PrivateKey { 487 key, err := crypto.GenerateKey() 488 if err != nil { 489 panic("couldn't generate key: " + err.Error()) 490 } 491 return key 492 } 493 494 func randomID() (id discover.NodeID) { 495 for i := range id { 496 id[i] = byte(rand.Intn(255)) 497 } 498 return id 499 }