gopkg.in/bitherhq/go-bither.v1@v1.7.1/p2p/server_test.go (about)

     1  // Copyright 2014 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package p2p
    18  
    19  import (
    20  	"crypto/ecdsa"
    21  	"errors"
    22  	"math/rand"
    23  	"net"
    24  	"reflect"
    25  	"testing"
    26  	"time"
    27  
    28  	"github.com/ethereum/go-ethereum/crypto"
    29  	"github.com/ethereum/go-ethereum/crypto/sha3"
    30  	"github.com/ethereum/go-ethereum/p2p/discover"
    31  )
    32  
    33  func init() {
    34  	// log.Root().SetHandler(log.LvlFilterHandler(log.LvlError, log.StreamHandler(os.Stderr, log.TerminalFormat(false))))
    35  }
    36  
    37  type testTransport struct {
    38  	id discover.NodeID
    39  	*rlpx
    40  
    41  	closeErr error
    42  }
    43  
    44  func newTestTransport(id discover.NodeID, fd net.Conn) transport {
    45  	wrapped := newRLPX(fd).(*rlpx)
    46  	wrapped.rw = newRLPXFrameRW(fd, secrets{
    47  		MAC:        zero16,
    48  		AES:        zero16,
    49  		IngressMAC: sha3.NewKeccak256(),
    50  		EgressMAC:  sha3.NewKeccak256(),
    51  	})
    52  	return &testTransport{id: id, rlpx: wrapped}
    53  }
    54  
    55  func (c *testTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) {
    56  	return c.id, nil
    57  }
    58  
    59  func (c *testTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
    60  	return &protoHandshake{ID: c.id, Name: "test"}, nil
    61  }
    62  
    63  func (c *testTransport) close(err error) {
    64  	c.rlpx.fd.Close()
    65  	c.closeErr = err
    66  }
    67  
    68  func startTestServer(t *testing.T, id discover.NodeID, pf func(*Peer)) *Server {
    69  	config := Config{
    70  		Name:       "test",
    71  		MaxPeers:   10,
    72  		ListenAddr: "127.0.0.1:0",
    73  		PrivateKey: newkey(),
    74  	}
    75  	server := &Server{
    76  		Config:       config,
    77  		newPeerHook:  pf,
    78  		newTransport: func(fd net.Conn) transport { return newTestTransport(id, fd) },
    79  	}
    80  	if err := server.Start(); err != nil {
    81  		t.Fatalf("Could not start server: %v", err)
    82  	}
    83  	return server
    84  }
    85  
    86  func TestServerListen(t *testing.T) {
    87  	// start the test server
    88  	connected := make(chan *Peer)
    89  	remid := randomID()
    90  	srv := startTestServer(t, remid, func(p *Peer) {
    91  		if p.ID() != remid {
    92  			t.Error("peer func called with wrong node id")
    93  		}
    94  		if p == nil {
    95  			t.Error("peer func called with nil conn")
    96  		}
    97  		connected <- p
    98  	})
    99  	defer close(connected)
   100  	defer srv.Stop()
   101  
   102  	// dial the test server
   103  	conn, err := net.DialTimeout("tcp", srv.ListenAddr, 5*time.Second)
   104  	if err != nil {
   105  		t.Fatalf("could not dial: %v", err)
   106  	}
   107  	defer conn.Close()
   108  
   109  	select {
   110  	case peer := <-connected:
   111  		if peer.LocalAddr().String() != conn.RemoteAddr().String() {
   112  			t.Errorf("peer started with wrong conn: got %v, want %v",
   113  				peer.LocalAddr(), conn.RemoteAddr())
   114  		}
   115  		peers := srv.Peers()
   116  		if !reflect.DeepEqual(peers, []*Peer{peer}) {
   117  			t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
   118  		}
   119  	case <-time.After(1 * time.Second):
   120  		t.Error("server did not accept within one second")
   121  	}
   122  }
   123  
   124  func TestServerDial(t *testing.T) {
   125  	// run a one-shot TCP server to handle the connection.
   126  	listener, err := net.Listen("tcp", "127.0.0.1:0")
   127  	if err != nil {
   128  		t.Fatalf("could not setup listener: %v", err)
   129  	}
   130  	defer listener.Close()
   131  	accepted := make(chan net.Conn)
   132  	go func() {
   133  		conn, err := listener.Accept()
   134  		if err != nil {
   135  			t.Error("accept error:", err)
   136  			return
   137  		}
   138  		accepted <- conn
   139  	}()
   140  
   141  	// start the server
   142  	connected := make(chan *Peer)
   143  	remid := randomID()
   144  	srv := startTestServer(t, remid, func(p *Peer) { connected <- p })
   145  	defer close(connected)
   146  	defer srv.Stop()
   147  
   148  	// tell the server to connect
   149  	tcpAddr := listener.Addr().(*net.TCPAddr)
   150  	srv.AddPeer(&discover.Node{ID: remid, IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)})
   151  
   152  	select {
   153  	case conn := <-accepted:
   154  		defer conn.Close()
   155  
   156  		select {
   157  		case peer := <-connected:
   158  			if peer.ID() != remid {
   159  				t.Errorf("peer has wrong id")
   160  			}
   161  			if peer.Name() != "test" {
   162  				t.Errorf("peer has wrong name")
   163  			}
   164  			if peer.RemoteAddr().String() != conn.LocalAddr().String() {
   165  				t.Errorf("peer started with wrong conn: got %v, want %v",
   166  					peer.RemoteAddr(), conn.LocalAddr())
   167  			}
   168  			peers := srv.Peers()
   169  			if !reflect.DeepEqual(peers, []*Peer{peer}) {
   170  				t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
   171  			}
   172  		case <-time.After(1 * time.Second):
   173  			t.Error("server did not launch peer within one second")
   174  		}
   175  
   176  	case <-time.After(1 * time.Second):
   177  		t.Error("server did not connect within one second")
   178  	}
   179  }
   180  
   181  // This test checks that tasks generated by dialstate are
   182  // actually executed and taskdone is called for them.
   183  func TestServerTaskScheduling(t *testing.T) {
   184  	var (
   185  		done           = make(chan *testTask)
   186  		quit, returned = make(chan struct{}), make(chan struct{})
   187  		tc             = 0
   188  		tg             = taskgen{
   189  			newFunc: func(running int, peers map[discover.NodeID]*Peer) []task {
   190  				tc++
   191  				return []task{&testTask{index: tc - 1}}
   192  			},
   193  			doneFunc: func(t task) {
   194  				select {
   195  				case done <- t.(*testTask):
   196  				case <-quit:
   197  				}
   198  			},
   199  		}
   200  	)
   201  
   202  	// The Server in this test isn't actually running
   203  	// because we're only interested in what run does.
   204  	srv := &Server{
   205  		Config:  Config{MaxPeers: 10},
   206  		quit:    make(chan struct{}),
   207  		ntab:    fakeTable{},
   208  		running: true,
   209  	}
   210  	srv.loopWG.Add(1)
   211  	go func() {
   212  		srv.run(tg)
   213  		close(returned)
   214  	}()
   215  
   216  	var gotdone []*testTask
   217  	for i := 0; i < 100; i++ {
   218  		gotdone = append(gotdone, <-done)
   219  	}
   220  	for i, task := range gotdone {
   221  		if task.index != i {
   222  			t.Errorf("task %d has wrong index, got %d", i, task.index)
   223  			break
   224  		}
   225  		if !task.called {
   226  			t.Errorf("task %d was not called", i)
   227  			break
   228  		}
   229  	}
   230  
   231  	close(quit)
   232  	srv.Stop()
   233  	select {
   234  	case <-returned:
   235  	case <-time.After(500 * time.Millisecond):
   236  		t.Error("Server.run did not return within 500ms")
   237  	}
   238  }
   239  
   240  // This test checks that Server doesn't drop tasks,
   241  // even if newTasks returns more than the maximum number of tasks.
   242  func TestServerManyTasks(t *testing.T) {
   243  	alltasks := make([]task, 300)
   244  	for i := range alltasks {
   245  		alltasks[i] = &testTask{index: i}
   246  	}
   247  
   248  	var (
   249  		srv        = &Server{quit: make(chan struct{}), ntab: fakeTable{}, running: true}
   250  		done       = make(chan *testTask)
   251  		start, end = 0, 0
   252  	)
   253  	defer srv.Stop()
   254  	srv.loopWG.Add(1)
   255  	go srv.run(taskgen{
   256  		newFunc: func(running int, peers map[discover.NodeID]*Peer) []task {
   257  			start, end = end, end+maxActiveDialTasks+10
   258  			if end > len(alltasks) {
   259  				end = len(alltasks)
   260  			}
   261  			return alltasks[start:end]
   262  		},
   263  		doneFunc: func(tt task) {
   264  			done <- tt.(*testTask)
   265  		},
   266  	})
   267  
   268  	doneset := make(map[int]bool)
   269  	timeout := time.After(2 * time.Second)
   270  	for len(doneset) < len(alltasks) {
   271  		select {
   272  		case tt := <-done:
   273  			if doneset[tt.index] {
   274  				t.Errorf("task %d got done more than once", tt.index)
   275  			} else {
   276  				doneset[tt.index] = true
   277  			}
   278  		case <-timeout:
   279  			t.Errorf("%d of %d tasks got done within 2s", len(doneset), len(alltasks))
   280  			for i := 0; i < len(alltasks); i++ {
   281  				if !doneset[i] {
   282  					t.Logf("task %d not done", i)
   283  				}
   284  			}
   285  			return
   286  		}
   287  	}
   288  }
   289  
   290  type taskgen struct {
   291  	newFunc  func(running int, peers map[discover.NodeID]*Peer) []task
   292  	doneFunc func(task)
   293  }
   294  
   295  func (tg taskgen) newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task {
   296  	return tg.newFunc(running, peers)
   297  }
   298  func (tg taskgen) taskDone(t task, now time.Time) {
   299  	tg.doneFunc(t)
   300  }
   301  func (tg taskgen) addStatic(*discover.Node) {
   302  }
   303  func (tg taskgen) removeStatic(*discover.Node) {
   304  }
   305  
   306  type testTask struct {
   307  	index  int
   308  	called bool
   309  }
   310  
   311  func (t *testTask) Do(srv *Server) {
   312  	t.called = true
   313  }
   314  
   315  // This test checks that connections are disconnected
   316  // just after the encryption handshake when the server is
   317  // at capacity. Trusted connections should still be accepted.
   318  func TestServerAtCap(t *testing.T) {
   319  	trustedID := randomID()
   320  	srv := &Server{
   321  		Config: Config{
   322  			PrivateKey:   newkey(),
   323  			MaxPeers:     10,
   324  			NoDial:       true,
   325  			TrustedNodes: []*discover.Node{{ID: trustedID}},
   326  		},
   327  	}
   328  	if err := srv.Start(); err != nil {
   329  		t.Fatalf("could not start: %v", err)
   330  	}
   331  	defer srv.Stop()
   332  
   333  	newconn := func(id discover.NodeID) *conn {
   334  		fd, _ := net.Pipe()
   335  		tx := newTestTransport(id, fd)
   336  		return &conn{fd: fd, transport: tx, flags: inboundConn, id: id, cont: make(chan error)}
   337  	}
   338  
   339  	// Inject a few connections to fill up the peer set.
   340  	for i := 0; i < 10; i++ {
   341  		c := newconn(randomID())
   342  		if err := srv.checkpoint(c, srv.addpeer); err != nil {
   343  			t.Fatalf("could not add conn %d: %v", i, err)
   344  		}
   345  	}
   346  	// Try inserting a non-trusted connection.
   347  	c := newconn(randomID())
   348  	if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers {
   349  		t.Error("wrong error for insert:", err)
   350  	}
   351  	// Try inserting a trusted connection.
   352  	c = newconn(trustedID)
   353  	if err := srv.checkpoint(c, srv.posthandshake); err != nil {
   354  		t.Error("unexpected error for trusted conn @posthandshake:", err)
   355  	}
   356  	if !c.is(trustedConn) {
   357  		t.Error("Server did not set trusted flag")
   358  	}
   359  
   360  }
   361  
   362  func TestServerSetupConn(t *testing.T) {
   363  	id := randomID()
   364  	srvkey := newkey()
   365  	srvid := discover.PubkeyID(&srvkey.PublicKey)
   366  	tests := []struct {
   367  		dontstart bool
   368  		tt        *setupTransport
   369  		flags     connFlag
   370  		dialDest  *discover.Node
   371  
   372  		wantCloseErr error
   373  		wantCalls    string
   374  	}{
   375  		{
   376  			dontstart:    true,
   377  			tt:           &setupTransport{id: id},
   378  			wantCalls:    "close,",
   379  			wantCloseErr: errServerStopped,
   380  		},
   381  		{
   382  			tt:           &setupTransport{id: id, encHandshakeErr: errors.New("read error")},
   383  			flags:        inboundConn,
   384  			wantCalls:    "doEncHandshake,close,",
   385  			wantCloseErr: errors.New("read error"),
   386  		},
   387  		{
   388  			tt:           &setupTransport{id: id},
   389  			dialDest:     &discover.Node{ID: randomID()},
   390  			flags:        dynDialedConn,
   391  			wantCalls:    "doEncHandshake,close,",
   392  			wantCloseErr: DiscUnexpectedIdentity,
   393  		},
   394  		{
   395  			tt:           &setupTransport{id: id, phs: &protoHandshake{ID: randomID()}},
   396  			dialDest:     &discover.Node{ID: id},
   397  			flags:        dynDialedConn,
   398  			wantCalls:    "doEncHandshake,doProtoHandshake,close,",
   399  			wantCloseErr: DiscUnexpectedIdentity,
   400  		},
   401  		{
   402  			tt:           &setupTransport{id: id, protoHandshakeErr: errors.New("foo")},
   403  			dialDest:     &discover.Node{ID: id},
   404  			flags:        dynDialedConn,
   405  			wantCalls:    "doEncHandshake,doProtoHandshake,close,",
   406  			wantCloseErr: errors.New("foo"),
   407  		},
   408  		{
   409  			tt:           &setupTransport{id: srvid, phs: &protoHandshake{ID: srvid}},
   410  			flags:        inboundConn,
   411  			wantCalls:    "doEncHandshake,close,",
   412  			wantCloseErr: DiscSelf,
   413  		},
   414  		{
   415  			tt:           &setupTransport{id: id, phs: &protoHandshake{ID: id}},
   416  			flags:        inboundConn,
   417  			wantCalls:    "doEncHandshake,doProtoHandshake,close,",
   418  			wantCloseErr: DiscUselessPeer,
   419  		},
   420  	}
   421  
   422  	for i, test := range tests {
   423  		srv := &Server{
   424  			Config: Config{
   425  				PrivateKey: srvkey,
   426  				MaxPeers:   10,
   427  				NoDial:     true,
   428  				Protocols:  []Protocol{discard},
   429  			},
   430  			newTransport: func(fd net.Conn) transport { return test.tt },
   431  		}
   432  		if !test.dontstart {
   433  			if err := srv.Start(); err != nil {
   434  				t.Fatalf("couldn't start server: %v", err)
   435  			}
   436  		}
   437  		p1, _ := net.Pipe()
   438  		srv.SetupConn(p1, test.flags, test.dialDest)
   439  		if !reflect.DeepEqual(test.tt.closeErr, test.wantCloseErr) {
   440  			t.Errorf("test %d: close error mismatch: got %q, want %q", i, test.tt.closeErr, test.wantCloseErr)
   441  		}
   442  		if test.tt.calls != test.wantCalls {
   443  			t.Errorf("test %d: calls mismatch: got %q, want %q", i, test.tt.calls, test.wantCalls)
   444  		}
   445  	}
   446  }
   447  
   448  type setupTransport struct {
   449  	id              discover.NodeID
   450  	encHandshakeErr error
   451  
   452  	phs               *protoHandshake
   453  	protoHandshakeErr error
   454  
   455  	calls    string
   456  	closeErr error
   457  }
   458  
   459  func (c *setupTransport) doEncHandshake(prv *ecdsa.PrivateKey, dialDest *discover.Node) (discover.NodeID, error) {
   460  	c.calls += "doEncHandshake,"
   461  	return c.id, c.encHandshakeErr
   462  }
   463  func (c *setupTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
   464  	c.calls += "doProtoHandshake,"
   465  	if c.protoHandshakeErr != nil {
   466  		return nil, c.protoHandshakeErr
   467  	}
   468  	return c.phs, nil
   469  }
   470  func (c *setupTransport) close(err error) {
   471  	c.calls += "close,"
   472  	c.closeErr = err
   473  }
   474  
   475  // setupConn shouldn't write to/read from the connection.
   476  func (c *setupTransport) WriteMsg(Msg) error {
   477  	panic("WriteMsg called on setupTransport")
   478  }
   479  func (c *setupTransport) ReadMsg() (Msg, error) {
   480  	panic("ReadMsg called on setupTransport")
   481  }
   482  
   483  func newkey() *ecdsa.PrivateKey {
   484  	key, err := crypto.GenerateKey()
   485  	if err != nil {
   486  		panic("couldn't generate key: " + err.Error())
   487  	}
   488  	return key
   489  }
   490  
   491  func randomID() (id discover.NodeID) {
   492  	for i := range id {
   493  		id[i] = byte(rand.Intn(255))
   494  	}
   495  	return id
   496  }