
     1  // Modifications Copyright 2018 The klaytn Authors
     2  // Copyright 2014 The go-ethereum Authors
     3  // This file is part of the go-ethereum library.
     4  //
     5  // The go-ethereum library is free software: you can redistribute it and/or modify
     6  // it under the terms of the GNU Lesser General Public License as published by
     7  // the Free Software Foundation, either version 3 of the License, or
     8  // (at your option) any later version.
     9  //
    10  // The go-ethereum library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Lesser General Public License for more details.
    14  //
    15  // You should have received a copy of the GNU Lesser General Public License
// along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    17  //
    18  // This file is derived from p2p/server_test.go (2018/06/04).
    19  // Modified and improved for the klaytn development.
    21  package p2p
    23  import (
    24  	"crypto/ecdsa"
    25  	"errors"
    26  	"math/rand"
    27  	"net"
    28  	"reflect"
    29  	"testing"
    30  	"time"
    32  	""
    33  	""
    34  	""
    35  	""
    36  	""
    37  )
// init is intentionally empty. To see server log output while debugging these
// tests, re-enable the handler below; it references os.Stderr, so "os" would
// also need to be imported.
func init() {
	// log.Root().SetHandler(logger.LvlFilterHandler(logger.LvlError, logger.StreamHandler(os.Stderr, logger.TerminalFormat(false))))
}
// testTransport is the transport used by the server tests. It embeds a real
// rlpxTransport, so message I/O goes over the wrapped connection, but it
// overrides the handshake methods (doEncHandshake, doProtoHandshake,
// doConnTypeHandshake) with canned results.
type testTransport struct {
	// id is the node identity the fake handshakes present for the remote side.
	id discover.NodeID
	*rlpxTransport
	// mutichannel is reported as Multichannel by doProtoHandshake.
	mutichannel bool

	// closeErr holds the error passed to the most recent close call.
	closeErr error
}
    51  func newTestTransport(id discover.NodeID, fd net.Conn, dialDest *ecdsa.PublicKey, mutichannel bool) transport {
    52  	wrapped := newRLPX(fd, dialDest).(*rlpxTransport)
    53  	wrapped.conn.InitWithSecrets(rlpx.Secrets{
    54  		MAC:        make([]byte, 16),
    55  		AES:        make([]byte, 16),
    56  		IngressMAC: sha3.NewKeccak256(),
    57  		EgressMAC:  sha3.NewKeccak256(),
    58  	})
    59  	return &testTransport{id: id, rlpxTransport: wrapped, mutichannel: mutichannel}
    60  }
    62  func (c *testTransport) doEncHandshake(prv *ecdsa.PrivateKey) (*ecdsa.PublicKey, error) {
    63  	remoteKey, _ :=
    64  	return remoteKey, nil
    65  }
    67  func (c *testTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
    68  	return &protoHandshake{ID:, Name: "test", Multichannel: c.mutichannel}, nil
    69  }
    71  func (c *testTransport) doConnTypeHandshake(myConnType common.ConnType) (common.ConnType, error) {
    72  	return 1, nil
    73  }
    75  func (c *testTransport) close(err error) {
    76  	c.conn.Close()
    77  	c.closeErr = err
    78  }
    80  func startTestServer(t *testing.T, id discover.NodeID, pf func(*Peer), config *Config) Server {
    81  	config.Name = "test"
    82  	config.MaxPhysicalConnections = 10
    83  	config.ListenAddr = ""
    84  	config.PrivateKey = newkey()
    85  	server := &SingleChannelServer{
    86  		&BaseServer{
    87  			Config:      *config,
    88  			newPeerHook: pf,
    89  			newTransport: func(fd net.Conn, dialDest *ecdsa.PublicKey) transport {
    90  				return newTestTransport(id, fd, dialDest, false)
    91  			},
    92  		},
    93  	}
    94  	if err := server.Start(); err != nil {
    95  		t.Fatalf("Could not start server: %v", err)
    96  	}
    97  	return server
    98  }
   100  func startTestMultiChannelServer(t *testing.T, id discover.NodeID, pf func(*Peer), config *Config) Server {
   101  	config.Name = "test"
   102  	config.MaxPhysicalConnections = 10
   103  	config.PrivateKey = newkey()
   105  	listeners := make([]net.Listener, 0, len(config.SubListenAddr)+1)
   106  	listenAddrs := make([]string, 0, len(config.SubListenAddr)+1)
   107  	listenAddrs = append(listenAddrs, config.ListenAddr)
   108  	listenAddrs = append(listenAddrs, config.SubListenAddr...)
   110  	server := &MultiChannelServer{
   111  		BaseServer: &BaseServer{
   112  			Config:      *config,
   113  			newPeerHook: pf,
   114  			newTransport: func(fd net.Conn, dialDest *ecdsa.PublicKey) transport {
   115  				return newTestTransport(id, fd, dialDest, true)
   116  			},
   117  		},
   118  		listeners:      listeners,
   119  		ListenAddrs:    listenAddrs,
   120  		CandidateConns: make(map[discover.NodeID][]*conn),
   121  	}
   122  	if err := server.Start(); err != nil {
   123  		t.Fatalf("Could not start server: %v", err)
   124  	}
   125  	return server
   126  }
   128  func makeconn(fd net.Conn, id discover.NodeID) *conn {
   129  	dialDest, _ := id.Pubkey()
   130  	tx := newTestTransport(id, fd, dialDest, false)
   131  	return &conn{fd: fd, transport: tx, flags: staticDialedConn, conntype: common.ConnTypeUndefined, id: id, cont: make(chan error)}
   132  }
   134  func makeMultiChannelConn(fd net.Conn, id discover.NodeID) *conn {
   135  	dialDest, _ := id.Pubkey()
   136  	tx := newTestTransport(id, fd, dialDest, true)
   137  	return &conn{fd: fd, transport: tx, flags: staticDialedConn, conntype: common.ConnTypeUndefined, id: id, cont: make(chan error), multiChannel: true}
   138  }
   140  func TestServerListen(t *testing.T) {
   141  	// start the test server
   142  	connected := make(chan *Peer)
   143  	remid := discover.PubkeyID(&newkey().PublicKey)
   144  	srv := startTestServer(t, remid, func(p *Peer) {
   145  		if p.ID() != remid {
   146  			t.Error("peer func called with wrong node id")
   147  		}
   148  		if p == nil {
   149  			t.Error("peer func called with nil conn")
   150  		}
   151  		connected <- p
   152  	}, &Config{})
   153  	defer close(connected)
   154  	defer srv.Stop()
   156  	// dial the test server
   157  	conn, err := net.DialTimeout("tcp", srv.GetListenAddress()[ConnDefault], 5*time.Second)
   158  	if err != nil {
   159  		t.Fatalf("could not dial: %v", err)
   160  	}
   161  	c := makeconn(conn, randomID())
   162  	c.doConnTypeHandshake(c.conntype)
   164  	defer conn.Close()
   166  	select {
   167  	case peer := <-connected:
   168  		if peer.LocalAddr().String() != conn.RemoteAddr().String() {
   169  			t.Errorf("peer started with wrong conn: got %v, want %v",
   170  				peer.LocalAddr(), conn.RemoteAddr())
   171  		}
   173  		peers := srv.Peers()
   174  		if !reflect.DeepEqual(peers, []*Peer{peer}) {
   175  			t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
   176  		}
   177  	case <-time.After(5 * time.Second):
   178  		t.Error("server did not accept within one second")
   179  	}
   180  }
   182  func TestMultiChannelServerListen(t *testing.T) {
   183  	// start the test server
   184  	connected := make(chan *Peer)
   185  	remid := discover.PubkeyID(&newkey().PublicKey)
   186  	config := &Config{ListenAddr: "", SubListenAddr: []string{""}}
   187  	srv := startTestMultiChannelServer(t, remid, func(p *Peer) {
   188  		if p.ID() != remid {
   189  			t.Error("peer func called with wrong node id")
   190  		}
   191  		if p == nil {
   192  			t.Error("peer func called with nil conn")
   193  		}
   194  		connected <- p
   195  	}, config)
   196  	defer close(connected)
   197  	defer srv.Stop()
   199  	// dial the test server
   200  	var defaultConn net.Conn
   202  	for i, address := range srv.GetListenAddress() {
   203  		conn, err := net.DialTimeout("tcp", address, 5*time.Second)
   204  		defer conn.Close()
   206  		if i == ConnDefault {
   207  			defaultConn = conn
   208  		}
   210  		if err != nil {
   211  			t.Fatalf("could not dial: %v", err)
   212  		}
   213  	}
   215  	select {
   216  	case peer := <-connected:
   217  		if peer.LocalAddr().String() != defaultConn.RemoteAddr().String() {
   218  			t.Errorf("peer started with wrong conn: got %v, want %v",
   219  				peer.LocalAddr(), defaultConn.RemoteAddr())
   220  		}
   222  		peers := srv.Peers()
   223  		if !reflect.DeepEqual(peers, []*Peer{peer}) {
   224  			t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
   225  		}
   226  	case <-time.After(5 * time.Second):
   227  		t.Error("server did not accept within five second")
   228  	}
   229  }
   231  func TestServerNoListen(t *testing.T) {
   232  	// start the test server
   233  	connected := make(chan *Peer)
   234  	remid := discover.PubkeyID(&newkey().PublicKey)
   235  	srv := startTestServer(t, remid, func(p *Peer) {
   236  		if p.ID() != remid {
   237  			t.Error("peer func called with wrong node id")
   238  		}
   239  		if p == nil {
   240  			t.Error("peer func called with nil conn")
   241  		}
   242  		connected <- p
   243  	}, &Config{NoListen: true})
   244  	defer close(connected)
   245  	defer srv.Stop()
   247  	// dial the test server that will be failed
   248  	_, err := net.DialTimeout("tcp", srv.GetListenAddress()[ConnDefault], 10*time.Millisecond)
   249  	if err == nil {
   250  		t.Fatalf("server started with listening")
   251  	}
   252  }
   254  func TestServerDial(t *testing.T) {
   255  	// run a one-shot TCP server to handle the connection.
   256  	listener, err := net.Listen("tcp", "")
   257  	if err != nil {
   258  		t.Fatalf("could not setup listener: %v", err)
   259  	}
   260  	defer listener.Close()
   261  	accepted := make(chan net.Conn)
   262  	go func() {
   263  		conn, err := listener.Accept()
   264  		if err != nil {
   265  			t.Error("accept error:", err)
   266  			return
   267  		}
   269  		c := makeconn(conn, discover.PubkeyID(&newkey().PublicKey))
   270  		c.doConnTypeHandshake(c.conntype)
   271  		accepted <- conn
   272  	}()
   274  	// start the server
   275  	connected := make(chan *Peer)
   276  	remid := discover.PubkeyID(&newkey().PublicKey)
   277  	srv := startTestServer(t, remid, func(p *Peer) { connected <- p }, &Config{})
   278  	defer close(connected)
   279  	defer srv.Stop()
   281  	// tell the server to connect
   282  	tcpAddr := listener.Addr().(*net.TCPAddr)
   283  	srv.AddPeer(&discover.Node{ID: remid, IP: tcpAddr.IP, TCP: uint16(tcpAddr.Port)})
   285  	select {
   286  	case conn := <-accepted:
   287  		defer conn.Close()
   289  		select {
   290  		case peer := <-connected:
   291  			if peer.ID() != remid {
   292  				t.Errorf("peer has wrong id")
   293  			}
   294  			if peer.Name() != "test" {
   295  				t.Errorf("peer has wrong name")
   296  			}
   297  			if peer.RemoteAddr().String() != conn.LocalAddr().String() {
   298  				t.Errorf("peer started with wrong conn: got %v, want %v",
   299  					peer.RemoteAddr(), conn.LocalAddr())
   300  			}
   301  			peers := srv.Peers()
   302  			if !reflect.DeepEqual(peers, []*Peer{peer}) {
   303  				t.Errorf("Peers mismatch: got %v, want %v", peers, []*Peer{peer})
   304  			}
   305  		case <-time.After(1 * time.Second):
   306  			t.Error("server did not launch peer within one second")
   307  		}
   309  	case <-time.After(1 * time.Second):
   310  		t.Error("server did not connect within one second")
   311  	}
   312  }
   314  // This test checks that tasks generated by dialstate are
   315  // actually executed and taskdone is called for them.
   316  func TestServerTaskScheduling(t *testing.T) {
   317  	var (
   318  		done           = make(chan *testTask)
   319  		quit, returned = make(chan struct{}), make(chan struct{})
   320  		tc             = 0
   321  		tg             = taskgen{
   322  			newFunc: func(running int, peers map[discover.NodeID]*Peer) []task {
   323  				tc++
   324  				return []task{&testTask{index: tc - 1}}
   325  			},
   326  			doneFunc: func(t task) {
   327  				select {
   328  				case done <- t.(*testTask):
   329  				case <-quit:
   330  				}
   331  			},
   332  		}
   333  	)
   335  	// The Server in this test isn't actually running
   336  	// because we're only interested in what run does.
   337  	srv := &SingleChannelServer{
   338  		&BaseServer{
   339  			Config:  Config{MaxPhysicalConnections: 10},
   340  			quit:    make(chan struct{}),
   341  			ntab:    fakeTable{},
   342  			running: true,
   343  			logger:  logger.NewWith(),
   344  		},
   345  	}
   346  	srv.loopWG.Add(1)
   347  	go func() {
   349  		close(returned)
   350  	}()
   352  	var gotdone []*testTask
   353  	for i := 0; i < 100; i++ {
   354  		gotdone = append(gotdone, <-done)
   355  	}
   356  	for i, task := range gotdone {
   357  		if task.index != i {
   358  			t.Errorf("task %d has wrong index, got %d", i, task.index)
   359  			break
   360  		}
   361  		if !task.called {
   362  			t.Errorf("task %d was not called", i)
   363  			break
   364  		}
   365  	}
   367  	close(quit)
   368  	srv.Stop()
   369  	select {
   370  	case <-returned:
   371  	case <-time.After(500 * time.Millisecond):
   372  		t.Error(" did not return within 500ms")
   373  	}
   374  }
   376  // This test checks that Server doesn't drop tasks,
   377  // even if newTasks returns more than the maximum number of tasks.
   378  func TestServerManyTasks(t *testing.T) {
   379  	alltasks := make([]task, 300)
   380  	for i := range alltasks {
   381  		alltasks[i] = &testTask{index: i}
   382  	}
   384  	var (
   385  		srv = &SingleChannelServer{
   386  			&BaseServer{
   387  				quit:    make(chan struct{}),
   388  				ntab:    fakeTable{},
   389  				running: true,
   390  				logger:  logger.NewWith(),
   391  			},
   392  		}
   393  		done       = make(chan *testTask)
   394  		start, end = 0, 0
   395  	)
   396  	defer srv.Stop()
   397  	srv.loopWG.Add(1)
   398  	go{
   399  		newFunc: func(running int, peers map[discover.NodeID]*Peer) []task {
   400  			start, end = end, end+maxActiveDialTasks+10
   401  			if end > len(alltasks) {
   402  				end = len(alltasks)
   403  			}
   404  			return alltasks[start:end]
   405  		},
   406  		doneFunc: func(tt task) {
   407  			done <- tt.(*testTask)
   408  		},
   409  	})
   411  	doneset := make(map[int]bool)
   412  	timeout := time.After(2 * time.Second)
   413  	for len(doneset) < len(alltasks) {
   414  		select {
   415  		case tt := <-done:
   416  			if doneset[tt.index] {
   417  				t.Errorf("task %d got done more than once", tt.index)
   418  			} else {
   419  				doneset[tt.index] = true
   420  			}
   421  		case <-timeout:
   422  			t.Errorf("%d of %d tasks got done within 2s", len(doneset), len(alltasks))
   423  			for i := 0; i < len(alltasks); i++ {
   424  				if !doneset[i] {
   425  					t.Logf("task %d not done", i)
   426  				}
   427  			}
   428  			return
   429  		}
   430  	}
   431  }
   433  type taskgen struct {
   434  	newFunc  func(running int, peers map[discover.NodeID]*Peer) []task
   435  	doneFunc func(task)
   436  }
   438  func (tg taskgen) newTasks(running int, peers map[discover.NodeID]*Peer, now time.Time) []task {
   439  	return tg.newFunc(running, peers)
   440  }
   442  func (tg taskgen) taskDone(t task, now time.Time) {
   443  	tg.doneFunc(t)
   444  }
   446  func (tg taskgen) addStatic(*discover.Node) {
   447  }
   449  func (tg taskgen) removeStatic(*discover.Node) {
   450  }
   452  type testTask struct {
   453  	index  int
   454  	called bool
   455  }
   457  func (t *testTask) Do(srv Server) {
   458  	t.called = true
   459  }
   461  // This test checks that connections are disconnected
   462  // just after the encryption handshake when the server is
   463  // at capacity. Trusted connections should still be accepted.
   464  func TestServerAtCap(t *testing.T) {
   465  	trustedID := randomID()
   466  	srv := &SingleChannelServer{
   467  		BaseServer: &BaseServer{
   468  			Config: Config{
   469  				PrivateKey:             newkey(),
   470  				MaxPhysicalConnections: 10,
   471  				NoDial:                 true,
   472  				TrustedNodes:           []*discover.Node{{ID: trustedID}},
   473  			},
   474  		},
   475  	}
   476  	if err := srv.Start(); err != nil {
   477  		t.Fatalf("could not start: %v", err)
   478  	}
   479  	defer srv.Stop()
   481  	newconn := func(id discover.NodeID) *conn {
   482  		fd, _ := net.Pipe()
   483  		tx := newTestTransport(id, fd, nil, false)
   484  		return &conn{fd: fd, transport: tx, flags: inboundConn, conntype: common.ConnTypeUndefined, id: id, cont: make(chan error)}
   485  	}
   487  	// Inject a few connections to fill up the peer set.
   488  	for i := 0; i < 10; i++ {
   489  		c := newconn(randomID())
   490  		if err := srv.checkpoint(c, srv.addpeer); err != nil {
   491  			t.Fatalf("could not add conn %d: %v", i, err)
   492  		}
   493  	}
   494  	// Try inserting a non-trusted connection.
   495  	c := newconn(randomID())
   496  	if err := srv.checkpoint(c, srv.posthandshake); err != DiscTooManyPeers {
   497  		t.Error("wrong error for insert:", err)
   498  	}
   499  	// Try inserting a trusted connection.
   500  	c = newconn(trustedID)
   501  	if err := srv.checkpoint(c, srv.posthandshake); err != nil {
   502  		t.Error("unexpected error for trusted conn @posthandshake:", err)
   503  	}
   504  	if ! {
   505  		t.Error("Server did not set trusted flag")
   506  	}
   507  }
   509  func TestServerSetupConn(t *testing.T) {
   510  	var (
   511  		id     = discover.PubkeyID(&newkey().PublicKey)
   512  		srvkey = newkey()
   513  		srvid  = discover.PubkeyID(&srvkey.PublicKey)
   514  	)
   516  	tests := []struct {
   517  		dontstart bool
   518  		tt        *setupTransport
   519  		flags     connFlag
   520  		dialDest  *discover.Node
   522  		wantCloseErr error
   523  		wantCalls    string
   524  	}{
   525  		{
   526  			dontstart:    true,
   527  			tt:           &setupTransport{id: id},
   528  			wantCalls:    "close,",
   529  			wantCloseErr: errServerStopped,
   530  		},
   531  		{
   532  			tt:           &setupTransport{id: id, encHandshakeErr: errors.New("read error")},
   533  			flags:        inboundConn,
   534  			wantCalls:    "doEncHandshake,close,",
   535  			wantCloseErr: errors.New("read error"),
   536  		},
   537  		{
   538  			tt:           &setupTransport{id: id, phs: &protoHandshake{ID: randomID()}},
   539  			dialDest:     &discover.Node{ID: id, NType: discover.NodeType(common.ENDPOINTNODE)},
   540  			flags:        dynDialedConn,
   541  			wantCalls:    "doEncHandshake,doProtoHandshake,close,",
   542  			wantCloseErr: DiscUnexpectedIdentity,
   543  		},
   544  		{
   545  			tt:           &setupTransport{id: id, protoHandshakeErr: errors.New("foo")},
   546  			dialDest:     &discover.Node{ID: id, NType: discover.NodeType(common.ENDPOINTNODE)},
   547  			flags:        dynDialedConn,
   548  			wantCalls:    "doEncHandshake,doProtoHandshake,close,",
   549  			wantCloseErr: errors.New("foo"),
   550  		},
   551  		{
   552  			tt:           &setupTransport{id: srvid, phs: &protoHandshake{ID: srvid}},
   553  			flags:        inboundConn,
   554  			wantCalls:    "doEncHandshake,close,",
   555  			wantCloseErr: DiscSelf,
   556  		},
   557  		{
   558  			tt:           &setupTransport{id: id, phs: &protoHandshake{ID: id}},
   559  			flags:        inboundConn,
   560  			wantCalls:    "doEncHandshake,doProtoHandshake,close,",
   561  			wantCloseErr: DiscUselessPeer,
   562  		},
   563  	}
   565  	for i, test := range tests {
   566  		srv := &SingleChannelServer{
   567  			&BaseServer{
   568  				Config: Config{
   569  					PrivateKey:             srvkey,
   570  					MaxPhysicalConnections: 10,
   571  					NoDial:                 true,
   572  					Protocols:              []Protocol{discard},
   573  					ConnectionType:         1, // ENDPOINTNODE
   574  				},
   575  				newTransport: func(fd net.Conn, dialDest *ecdsa.PublicKey) transport { return },
   576  				logger:       logger.NewWith(),
   577  			},
   578  		}
   579  		if !test.dontstart {
   580  			if err := srv.Start(); err != nil {
   581  				t.Fatalf("couldn't start server: %v", err)
   582  			}
   583  		}
   584  		p1, _ := net.Pipe()
   585  		srv.SetupConn(p1, test.flags, test.dialDest)
   586  		if !reflect.DeepEqual(, test.wantCloseErr) {
   587  			t.Errorf("test %d: close error mismatch: got %q, want %q", i,, test.wantCloseErr)
   588  		}
   589  		if != test.wantCalls {
   590  			t.Errorf("test %d: calls mismatch: got %q, want %q", i,, test.wantCalls)
   591  		}
   592  	}
   593  }
// setupTransport is a fake transport for TestServerSetupConn. It returns the
// canned results configured in its fields, and it appends "<method>," to calls
// for each doEncHandshake, doProtoHandshake and close invocation.
type setupTransport struct {
	// Encryption handshake: id is the remote identity, and encHandshakeErr is
	// returned by doEncHandshake.
	id              discover.NodeID
	encHandshakeErr error

	// Protocol handshake: doProtoHandshake returns protoHandshakeErr if it is
	// non-nil, otherwise phs.
	phs               *protoHandshake
	protoHandshakeErr error

	// Recorded state: calls holds the method-call trace, and closeErr holds
	// the error passed to close.
	calls    string
	closeErr error
}
   606  func (c *setupTransport) doConnTypeHandshake(myConnType common.ConnType) (common.ConnType, error) {
   607  	return 1, nil
   608  }
   610  func (c *setupTransport) doEncHandshake(prv *ecdsa.PrivateKey) (*ecdsa.PublicKey, error) {
   611  	c.calls += "doEncHandshake,"
   612  	pubkey, _ :=
   613  	return pubkey, c.encHandshakeErr
   614  }
   616  func (c *setupTransport) doProtoHandshake(our *protoHandshake) (*protoHandshake, error) {
   617  	c.calls += "doProtoHandshake,"
   618  	if c.protoHandshakeErr != nil {
   619  		return nil, c.protoHandshakeErr
   620  	}
   621  	return c.phs, nil
   622  }
   624  func (c *setupTransport) close(err error) {
   625  	c.calls += "close,"
   626  	c.closeErr = err
   627  }
   629  // setupConn shouldn't write to/read from the connection.
   630  func (c *setupTransport) WriteMsg(Msg) error {
   631  	panic("WriteMsg called on setupTransport")
   632  }
   634  func (c *setupTransport) ReadMsg() (Msg, error) {
   635  	panic("ReadMsg called on setupTransport")
   636  }
   638  func newkey() *ecdsa.PrivateKey {
   639  	key, err := crypto.GenerateKey()
   640  	if err != nil {
   641  		panic("couldn't generate key: " + err.Error())
   642  	}
   643  	return key
   644  }
   646  func randomID() (id discover.NodeID) {
   647  	for i := range id {
   648  		id[i] = byte(rand.Intn(255))
   649  	}
   650  	return id
   651  }