github.com/neatio-net/neatio@v1.7.3-0.20231114194659-f4d7a2226baa/network/p2p/server_test.go (about)

     1  package p2p
     2  
     3  import (
     4  	"crypto/ecdsa"
     5  	"errors"
     6  	"math/rand"
     7  	"net"
     8  	"reflect"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/neatio-net/neatio/chain/log"
    13  	"github.com/neatio-net/neatio/network/p2p/discover"
    14  	"github.com/neatio-net/neatio/utilities/crypto"
    15  	"github.com/neatio-net/neatio/utilities/crypto/sha3"
    16  )
    17  
    18  func init() {
    19  
    20  }
    21  
    22  type testTransport struct {
    23  	id discover.NodeID
    24  	*rlpx
    25  
    26  	closeErr error
    27  }
    28  
    29  func newTestTransport(id discover.NodeID, fd net.Conn) transport {
    30  	wrapped := newRLPX(fd).(*rlpx)
    31  	wrapped.rw = newRLPXFrameRW(fd, secrets{
    32  		MAC:        zero16,
    33  		AES:        zero16,
    34  		IngressMAC: sha3.NewKeccak256(),
    35  		EgressMAC:  sha3.NewKeccak256(),
    36  	})
    37  	return &testTransport{id: id, rlpx: wrapped}
    38  }
    39  
    40  func (c *testTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) {
    41  	return c.id, nil
    42  }
    43  
    44  func (c *testTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
    45  	return &protoHandshake{ID: c.id, Name: "test"}, nil
    46  }
    47  
    48  func (c *testTransport) close(err error) {
    49  	c.rlpx.fd.Close()
    50  	c.closeErr = err
    51  }
    52  
    53  func startTestServer(t *testing.T, id discover.NodeID, pf func(*Peer)) *Server {
    54  	config := Config{
    55  		Name:       "test",
    56  		MaxPeers:   10,
    57  		ListenAddr: "127.0.0.1:0",
    58  		PrivateKey: newkey(),
    59  	}
    60  	server := &Server{
    61  		Config:       config,
    62  		newPeerHook:  pf,
    63  		newTransport: func(fd net.Conn) transport { return newTestTransport(id, fd) },
    64  	}
    65  	if err := server.Start(); err != nil {
    66  		t.Fatalf("Could not start server: %v", err)
    67  	}
    68  	return server
    69  }
    70  
    71  func TestServerListen(t *testing.T) {
    72  
    73  	connected := make(chan *Peer)
    74  	remid := randomID()
    75  	srv := startTestServer(t, remid, func(p *Peer) {
    76  		if p.ID() != remid {
    77  			t.Error("peer func called with wrong node id")
    78  		}
    79  		if p == nil {
    80  			t.Error("peer func called with nil conn")
    81  		}
    82  		connected <- p
    83  	})
    84  	defer close(connected)
    85  	defer srv.Stop()
    86  
    87  	conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second)
    88  	if err != nil {
    89  		t.Fatalf("could not dial: %v", err)
    90  	}
    91  	defer conn.Close()
    92  
    93  	select {
    94  	case peer := <-connected:
    95  		if peer.LocalAddr().String() != conn.RemoteAddr().String() {
    96  			t.Errorf("peer started with wrong conn: got %v, want %v",
    97  				peer.LocalAddr(), conn.RemoteAddr())
    98  		}
    99  		peers := srv.Peers()
   100  		if !reflect.DeepEqual(peers, []*Peer{peer}) {
   101  			t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
   102  		}
   103  	case <-time.After(1 * time.Second):
   104  		t.Error("server did not accept within one second")
   105  	}
   106  }
   107  
   108  func TestServerDial(t *testing.T) {
   109  
   110  	listener, err := net.Listen("tcp", "127.0.0.1:0")
   111  	if err != nil {
   112  		t.Fatalf("could not setup listener: %v", err)
   113  	}
   114  	defer listener.Close()
   115  	accepted := make(chan net.Conn)
   116  	go func() {
   117  		conn, err := listener.Accept()
   118  		if err != nil {
   119  			t.Error("accept error:", err)
   120  			return
   121  		}
   122  		accepted <- conn
   123  	}()
   124  
   125  	connected := make(chan *Peer)
   126  	remid := randomID()
   127  	srv := startTestServer(t, remid, func(p *Peer) { connected <- p })
   128  	defer close(connected)
   129  	defer srv.Stop()
   130  
   131  	tcpAddr := listener.Addr().(*net.TCPAddr)
   132  	srv.AddPeer(&discover.Node{ID: remid, IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)})
   133  
   134  	select {
   135  	case conn := <-accepted:
   136  		defer conn.Close()
   137  
   138  		select {
   139  		case peer := <-connected:
   140  			if peer.ID() != remid {
   141  				t.Errorf("peer has wrong id")
   142  			}
   143  			if peer.Name() != "test" {
   144  				t.Errorf("peer has wrong name")
   145  			}
   146  			if peer.RemoteAddr().String() != conn.LocalAddr().String() {
   147  				t.Errorf("peer started with wrong conn: got %v, want %v",
   148  					peer.RemoteAddr(), conn.LocalAddr())
   149  			}
   150  			peers := srv.Peers()
   151  			if !reflect.DeepEqual(peers, []*Peer{peer}) {
   152  				t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
   153  			}
   154  		case <-time.After(1 * time.Second):
   155  			t.Error("server did not launch peer within one second")
   156  		}
   157  
   158  	case <-time.After(1 * time.Second):
   159  		t.Error("server did not connect within one second")
   160  	}
   161  }
   162  
   163  func TestServerTaskScheduling(t *testing.T) {
   164  	var (
   165  		done           = make(chan *testTask)
   166  		quit, returned = make(chan struct{}), make(chan struct{})
   167  		tc             = 0
   168  		tg             = taskgen{
   169  			newFunc: func(running int, peers map[discover.NodeID]*Peer) []task {
   170  				tc++
   171  				return []task{&testTask{index: tc - 1}}
   172  			},
   173  			doneFunc: func(t task) {
   174  				select {
   175  				case done <- t.(*testTask):
   176  				case <-quit:
   177  				}
   178  			},
   179  		}
   180  	)
   181  
   182  	srv := &Server{
   183  		Config:  Config{MaxPeers: 10},
   184  		quit:    make(chan struct{}),
   185  		ntab:    fakeTable{},
   186  		running: true,
   187  		log:     log.New(),
   188  	}
   189  	srv.loopWG.Add(1)
   190  	go func() {
   191  		srv.run(tg)
   192  		close(returned)
   193  	}()
   194  
   195  	var gotdone []*testTask
   196  	for i := 0; i < 100; i++ {
   197  		gotdone = append(gotdone, <-done)
   198  	}
   199  	for i, task := range gotdone {
   200  		if task.index != i {
   201  			t.Errorf("task %d has wrong index, got %d", i, task.index)
   202  			break
   203  		}
   204  		if !task.called {
   205  			t.Errorf("task %d was not called", i)
   206  			break
   207  		}
   208  	}
   209  
   210  	close(quit)
   211  	srv.Stop()
   212  	select {
   213  	case <-returned:
   214  	case <-time.After(500 * time.Millisecond):
   215  		t.Error("Server.run did not return within 500ms")
   216  	}
   217  }
   218  
   219  func TestServerManyTasks(t *testing.T) {
   220  	alltasks := make([]task, 300)
   221  	for i := range alltasks {
   222  		alltasks[i] = &testTask{index: i}
   223  	}
   224  
   225  	var (
   226  		srv = &Server{
   227  			quit:    make(chan struct{}),
   228  			ntab:    fakeTable{},
   229  			running: true,
   230  			log:     log.New(),
   231  		}
   232  		done       = make(chan *testTask)
   233  		start, end = 0, 0
   234  	)
   235  	defer srv.Stop()
   236  	srv.loopWG.Add(1)
   237  	go srv.run(taskgen{
   238  		newFunc: func(running int, peers map[discover.NodeID]*Peer) []task {
   239  			start, end = end, end+maxActiveDialTasks+10
   240  			if end > len(alltasks) {
   241  				end = len(alltasks)
   242  			}
   243  			return alltasks[start:end]
   244  		},
   245  		doneFunc: func(tt task) {
   246  			done <- tt.(*testTask)
   247  		},
   248  	})
   249  
   250  	doneset := make(map[int]bool)
   251  	timeout := time.After(2 * time.Second)
   252  	for len(doneset) < len(alltasks) {
   253  		select {
   254  		case tt := <-done:
   255  			if doneset[tt.index] {
   256  				t.Errorf("task %d got done more than once", tt.index)
   257  			} else {
   258  				doneset[tt.index] = true
   259  			}
   260  		case <-timeout:
   261  			t.Errorf("%d of %d tasks got done within 2s", len(doneset), len(alltasks))
   262  			for i := 0; i < len(alltasks); i++ {
   263  				if !doneset[i] {
   264  					t.Logf("task %d not done", i)
   265  				}
   266  			}
   267  			return
   268  		}
   269  	}
   270  }
   271  
   272  type taskgen struct {
   273  	newFunc  func(running int, peers map[discover.NodeID]*Peer) []task
   274  	doneFunc func(task)
   275  }
   276  
   277  func (tg taskgen) newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task {
   278  	return tg.newFunc(running, peers)
   279  }
   280  func (tg taskgen) taskDone(t task, now time.Time) {
   281  	tg.doneFunc(t)
   282  }
   283  func (tg taskgen) addStatic(*discover.Node) {
   284  }
   285  func (tg taskgen) removeStatic(*discover.Node) {
   286  }
   287  
   288  type testTask struct {
   289  	index  int
   290  	called bool
   291  }
   292  
   293  func (t *testTask) Do(srv *Server) {
   294  	t.called = true
   295  }
   296  
   297  func TestServerAtCap(t *testing.T) {
   298  	trustedID := randomID()
   299  	srv := &Server{
   300  		Config: Config{
   301  			PrivateKey:   newkey(),
   302  			MaxPeers:     10,
   303  			NoDial:       true,
   304  			TrustedNodes: []*discover.Node{{ID: trustedID}},
   305  		},
   306  	}
   307  	if err := srv.Start(); err != nil {
   308  		t.Fatalf("could not start: %v", err)
   309  	}
   310  	defer srv.Stop()
   311  
   312  	newconn := func(id discover.NodeID) *conn {
   313  		fd, _ := net.Pipe()
   314  		tx := newTestTransport(id, fd)
   315  		return &conn{fd: fd, transport: tx, flags: inboundConn, id: id, cont: make(chan error)}
   316  	}
   317  
   318  	for i := 0; i < 10; i++ {
   319  		c := newconn(randomID())
   320  		if err := srv.checkpoint(c, srv.addpeer); err != nil {
   321  			t.Fatalf("could not add conn %d: %v", i, err)
   322  		}
   323  	}
   324  
   325  	c := newconn(randomID())
   326  	if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers {
   327  		t.Error("wrong error for insert:", err)
   328  	}
   329  
   330  	c = newconn(trustedID)
   331  	if err := srv.checkpoint(c, srv.posthandshake); err != nil {
   332  		t.Error("unexpected error for trusted conn @posthandshake:", err)
   333  	}
   334  	if !c.is(trustedConn) {
   335  		t.Error("Server did not set trusted flag")
   336  	}
   337  
   338  }
   339  
   340  func TestServerSetupConn(t *testing.T) {
   341  	id := randomID()
   342  	srvkey := newkey()
   343  	srvid := discover.PubkeyID(&srvkey.PublicKey)
   344  	tests := []struct {
   345  		dontstart bool
   346  		tt        *setupTransport
   347  		flags     connFlag
   348  		dialDest  *discover.Node
   349  
   350  		wantCloseErr error
   351  		wantCalls    string
   352  	}{
   353  		{
   354  			dontstart:    true,
   355  			tt:           &setupTransport{id: id},
   356  			wantCalls:    "close,",
   357  			wantCloseErr: errServerStopped,
   358  		},
   359  		{
   360  			tt:           &setupTransport{id: id, encHandshakeErr: errors.New("read error")},
   361  			flags:        inboundConn,
   362  			wantCalls:    "doEncHandshake,close,",
   363  			wantCloseErr: errors.New("read error"),
   364  		},
   365  		{
   366  			tt:           &setupTransport{id: id},
   367  			dialDest:     &discover.Node{ID: randomID()},
   368  			flags:        dynDialedConn,
   369  			wantCalls:    "doEncHandshake,close,",
   370  			wantCloseErr: DiscUnexpectedIdentity,
   371  		},
   372  		{
   373  			tt:           &setupTransport{id: id, phs: &protoHandshake{ID: randomID()}},
   374  			dialDest:     &discover.Node{ID: id},
   375  			flags:        dynDialedConn,
   376  			wantCalls:    "doEncHandshake,doProtoHandshake,close,",
   377  			wantCloseErr: DiscUnexpectedIdentity,
   378  		},
   379  		{
   380  			tt:           &setupTransport{id: id, protoHandshakeErr: errors.New("foo")},
   381  			dialDest:     &discover.Node{ID: id},
   382  			flags:        dynDialedConn,
   383  			wantCalls:    "doEncHandshake,doProtoHandshake,close,",
   384  			wantCloseErr: errors.New("foo"),
   385  		},
   386  		{
   387  			tt:           &setupTransport{id: srvid, phs: &protoHandshake{ID: srvid}},
   388  			flags:        inboundConn,
   389  			wantCalls:    "doEncHandshake,close,",
   390  			wantCloseErr: DiscSelf,
   391  		},
   392  		{
   393  			tt:           &setupTransport{id: id, phs: &protoHandshake{ID: id}},
   394  			flags:        inboundConn,
   395  			wantCalls:    "doEncHandshake,doProtoHandshake,close,",
   396  			wantCloseErr: DiscUselessPeer,
   397  		},
   398  	}
   399  
   400  	for i, test := range tests {
   401  		srv := &Server{
   402  			Config: Config{
   403  				PrivateKey: srvkey,
   404  				MaxPeers:   10,
   405  				NoDial:     true,
   406  				Protocols:  []Protocol{discard},
   407  			},
   408  			newTransport: func(fd net.Conn) transport { return test.tt },
   409  			log:          log.New(),
   410  		}
   411  		if !test.dontstart {
   412  			if err := srv.Start(); err != nil {
   413  				t.Fatalf("couldn't start server: %v", err)
   414  			}
   415  		}
   416  		p1, _ := net.Pipe()
   417  		srv.SetupConn(p1, test.flags, test.dialDest)
   418  		if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) {
   419  			t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr)
   420  		}
   421  		if test.tt.calls != test.wantCalls {
   422  			t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls)
   423  		}
   424  	}
   425  }
   426  
   427  type setupTransport struct {
   428  	id              discover.NodeID
   429  	encHandshakeErr error
   430  
   431  	phs               *protoHandshake
   432  	protoHandshakeErr error
   433  
   434  	calls    string
   435  	closeErr error
   436  }
   437  
   438  func (c *setupTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) {
   439  	c.calls += "doEncHandshake,"
   440  	return c.id, c.encHandshakeErr
   441  }
   442  func (c *setupTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
   443  	c.calls += "doProtoHandshake,"
   444  	if c.protoHandshakeErr != nil {
   445  		return nil, c.protoHandshakeErr
   446  	}
   447  	return c.phs, nil
   448  }
   449  func (c *setupTransport) close(err error) {
   450  	c.calls += "close,"
   451  	c.closeErr = err
   452  }
   453  
   454  func (c *setupTransport) WriteMsg(Msg) error {
   455  	panic("WriteMsg called on setupTransport")
   456  }
   457  func (c *setupTransport) ReadMsg() (Msg, error) {
   458  	panic("ReadMsg called on setupTransport")
   459  }
   460  
   461  func newkey() *ecdsa.PrivateKey {
   462  	key, err := crypto.GenerateKey()
   463  	if err != nil {
   464  		panic("couldn't generate key: " + err.Error())
   465  	}
   466  	return key
   467  }
   468  
   469  func randomID() (id discover.NodeID) {
   470  	for i := range id {
   471  		id[i] = byte(rand.Intn(255))
   472  	}
   473  	return id
   474  }