github.com/ethereum/go-ethereum@v1.16.1/cmd/devp2p/internal/v5test/discv5tests.go (about)

     1  // Copyright 2020 The go-ethereum Authors
     2  // This file is part of go-ethereum.
     3  //
     4  // go-ethereum is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // go-ethereum is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU General Public License
    15  // along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package v5test
    18  
    19  import (
    20  	"bytes"
    21  	"net"
    22  	"slices"
    23  	"sync"
    24  	"time"
    25  
    26  	"github.com/ethereum/go-ethereum/internal/utesting"
    27  	"github.com/ethereum/go-ethereum/p2p/discover/v5wire"
    28  	"github.com/ethereum/go-ethereum/p2p/enode"
    29  	"github.com/ethereum/go-ethereum/p2p/netutil"
    30  )
    31  
    32  // Suite is the discv5 test suite.
    33  type Suite struct {
    34  	Dest             *enode.Node
    35  	Listen1, Listen2 string // listening addresses
    36  }
    37  
    38  func (s *Suite) listen1(log logger) (*conn, net.PacketConn) {
    39  	c := newConn(s.Dest, log)
    40  	l := c.listen(s.Listen1)
    41  	return c, l
    42  }
    43  
    44  func (s *Suite) listen2(log logger) (*conn, net.PacketConn, net.PacketConn) {
    45  	c := newConn(s.Dest, log)
    46  	l1, l2 := c.listen(s.Listen1), c.listen(s.Listen2)
    47  	return c, l1, l2
    48  }
    49  
    50  func (s *Suite) AllTests() []utesting.Test {
    51  	return []utesting.Test{
    52  		{Name: "Ping", Fn: s.TestPing},
    53  		{Name: "PingLargeRequestID", Fn: s.TestPingLargeRequestID},
    54  		{Name: "PingMultiIP", Fn: s.TestPingMultiIP},
    55  		{Name: "PingHandshakeInterrupted", Fn: s.TestPingHandshakeInterrupted},
    56  		{Name: "TalkRequest", Fn: s.TestTalkRequest},
    57  		{Name: "FindnodeZeroDistance", Fn: s.TestFindnodeZeroDistance},
    58  		{Name: "FindnodeResults", Fn: s.TestFindnodeResults},
    59  	}
    60  }
    61  
    62  func (s *Suite) TestPing(t *utesting.T) {
    63  	t.Log(`This test is just a sanity check. It sends PING and expects a PONG response.`)
    64  
    65  	conn, l1 := s.listen1(t)
    66  	defer conn.close()
    67  
    68  	ping := &v5wire.Ping{ReqID: conn.nextReqID()}
    69  	switch resp := conn.reqresp(l1, ping).(type) {
    70  	case *v5wire.Pong:
    71  		checkPong(t, resp, ping, l1)
    72  	default:
    73  		t.Fatal("expected PONG, got", resp.Name())
    74  	}
    75  }
    76  
    77  func checkPong(t *utesting.T, pong *v5wire.Pong, ping *v5wire.Ping, c net.PacketConn) {
    78  	if !bytes.Equal(pong.ReqID, ping.ReqID) {
    79  		t.Fatalf("wrong request ID %x in PONG, want %x", pong.ReqID, ping.ReqID)
    80  	}
    81  	if !pong.ToIP.Equal(laddr(c).IP) {
    82  		t.Fatalf("wrong destination IP %v in PONG, want %v", pong.ToIP, laddr(c).IP)
    83  	}
    84  	if int(pong.ToPort) != laddr(c).Port {
    85  		t.Fatalf("wrong destination port %v in PONG, want %v", pong.ToPort, laddr(c).Port)
    86  	}
    87  }
    88  
    89  func (s *Suite) TestPingLargeRequestID(t *utesting.T) {
    90  	t.Log(`This test sends PING with a 9-byte request ID, which isn't allowed by the spec.
    91  The remote node should not respond.`)
    92  
    93  	conn, l1 := s.listen1(t)
    94  	defer conn.close()
    95  
    96  	ping := &v5wire.Ping{ReqID: make([]byte, 9)}
    97  	switch resp := conn.reqresp(l1, ping).(type) {
    98  	case *v5wire.Pong:
    99  		t.Errorf("PONG response with unknown request ID %x", resp.ReqID)
   100  	case *readError:
   101  		if resp.err == v5wire.ErrInvalidReqID {
   102  			t.Error("response with oversized request ID")
   103  		} else if !netutil.IsTimeout(resp.err) {
   104  			t.Error(resp)
   105  		}
   106  	}
   107  }
   108  
   109  func (s *Suite) TestPingMultiIP(t *utesting.T) {
   110  	t.Log(`This test establishes a session from one IP as usual. The session is then reused
   111  on another IP, which shouldn't work. The remote node should respond with WHOAREYOU for
   112  the attempt from a different IP.`)
   113  
   114  	conn, l1, l2 := s.listen2(t)
   115  	defer conn.close()
   116  
   117  	// Create the session on l1.
   118  	ping := &v5wire.Ping{ReqID: conn.nextReqID()}
   119  	resp := conn.reqresp(l1, ping)
   120  	if resp.Kind() != v5wire.PongMsg {
   121  		t.Fatal("expected PONG, got", resp)
   122  	}
   123  	checkPong(t, resp.(*v5wire.Pong), ping, l1)
   124  
   125  	// Send on l2. This reuses the session because there is only one codec.
   126  	t.Log("sending ping from alternate IP", l2.LocalAddr())
   127  	ping2 := &v5wire.Ping{ReqID: conn.nextReqID()}
   128  	conn.write(l2, ping2, nil)
   129  	switch resp := conn.read(l2).(type) {
   130  	case *v5wire.Pong:
   131  		t.Fatalf("remote responded to PING from %v for session on IP %v", laddr(l2).IP, laddr(l1).IP)
   132  	case *v5wire.Whoareyou:
   133  		t.Logf("got WHOAREYOU for new session as expected")
   134  		resp.Node = s.Dest
   135  		conn.write(l2, ping2, resp)
   136  	default:
   137  		t.Fatal("expected WHOAREYOU, got", resp)
   138  	}
   139  
   140  	// Catch the PONG on l2.
   141  	switch resp := conn.read(l2).(type) {
   142  	case *v5wire.Pong:
   143  		checkPong(t, resp, ping2, l2)
   144  	default:
   145  		t.Fatal("expected PONG, got", resp)
   146  	}
   147  
   148  	// Try on l1 again.
   149  	ping3 := &v5wire.Ping{ReqID: conn.nextReqID()}
   150  	conn.write(l1, ping3, nil)
   151  	switch resp := conn.read(l1).(type) {
   152  	case *v5wire.Pong:
   153  		t.Fatalf("remote responded to PING from %v for session on IP %v", laddr(l1).IP, laddr(l2).IP)
   154  	case *v5wire.Whoareyou:
   155  		t.Logf("got WHOAREYOU for new session as expected")
   156  	default:
   157  		t.Fatal("expected WHOAREYOU, got", resp)
   158  	}
   159  }
   160  
   161  // TestPingHandshakeInterrupted starts a handshake, but doesn't finish it and sends a second ordinary message
   162  // packet instead of a handshake message packet. The remote node should respond with
   163  // another WHOAREYOU challenge for the second packet.
   164  func (s *Suite) TestPingHandshakeInterrupted(t *utesting.T) {
   165  	t.Log(`TestPingHandshakeInterrupted starts a handshake, but doesn't finish it and sends a second ordinary message
   166  packet instead of a handshake message packet. The remote node should respond with
   167  another WHOAREYOU challenge for the second packet.`)
   168  
   169  	conn, l1 := s.listen1(t)
   170  	defer conn.close()
   171  
   172  	// First PING triggers challenge.
   173  	ping := &v5wire.Ping{ReqID: conn.nextReqID()}
   174  	conn.write(l1, ping, nil)
   175  	switch resp := conn.read(l1).(type) {
   176  	case *v5wire.Whoareyou:
   177  		t.Logf("got WHOAREYOU for PING")
   178  	default:
   179  		t.Fatal("expected WHOAREYOU, got", resp)
   180  	}
   181  
   182  	// Send second PING.
   183  	ping2 := &v5wire.Ping{ReqID: conn.nextReqID()}
   184  	switch resp := conn.reqresp(l1, ping2).(type) {
   185  	case *v5wire.Pong:
   186  		checkPong(t, resp, ping2, l1)
   187  	default:
   188  		t.Fatal("expected WHOAREYOU, got", resp)
   189  	}
   190  }
   191  
   192  func (s *Suite) TestTalkRequest(t *utesting.T) {
   193  	t.Log(`This test sends some examples of TALKREQ with a protocol-id of "test-protocol"
   194  and expects an empty TALKRESP response.`)
   195  
   196  	conn, l1 := s.listen1(t)
   197  	defer conn.close()
   198  
   199  	// Non-empty request ID.
   200  	id := conn.nextReqID()
   201  	resp := conn.reqresp(l1, &v5wire.TalkRequest{ReqID: id, Protocol: "test-protocol"})
   202  	switch resp := resp.(type) {
   203  	case *v5wire.TalkResponse:
   204  		if !bytes.Equal(resp.ReqID, id) {
   205  			t.Fatalf("wrong request ID %x in TALKRESP, want %x", resp.ReqID, id)
   206  		}
   207  		if len(resp.Message) > 0 {
   208  			t.Fatalf("non-empty message %x in TALKRESP", resp.Message)
   209  		}
   210  	default:
   211  		t.Fatal("expected TALKRESP, got", resp.Name())
   212  	}
   213  
   214  	// Empty request ID.
   215  	t.Log("sending TALKREQ with empty request-id")
   216  	resp = conn.reqresp(l1, &v5wire.TalkRequest{Protocol: "test-protocol"})
   217  	switch resp := resp.(type) {
   218  	case *v5wire.TalkResponse:
   219  		if len(resp.ReqID) > 0 {
   220  			t.Fatalf("wrong request ID %x in TALKRESP, want empty byte array", resp.ReqID)
   221  		}
   222  		if len(resp.Message) > 0 {
   223  			t.Fatalf("non-empty message %x in TALKRESP", resp.Message)
   224  		}
   225  	default:
   226  		t.Fatal("expected TALKRESP, got", resp.Name())
   227  	}
   228  }
   229  
   230  func (s *Suite) TestFindnodeZeroDistance(t *utesting.T) {
   231  	t.Log(`This test checks that the remote node returns itself for FINDNODE with distance zero.`)
   232  
   233  	conn, l1 := s.listen1(t)
   234  	defer conn.close()
   235  
   236  	nodes, err := conn.findnode(l1, []uint{0})
   237  	if err != nil {
   238  		t.Fatal(err)
   239  	}
   240  	if len(nodes) != 1 {
   241  		t.Fatalf("remote returned more than one node for FINDNODE [0]")
   242  	}
   243  	if nodes[0].ID() != conn.remote.ID() {
   244  		t.Errorf("ID of response node is %v, want %v", nodes[0].ID(), conn.remote.ID())
   245  	}
   246  }
   247  
   248  func (s *Suite) TestFindnodeResults(t *utesting.T) {
   249  	t.Log(`This test pings the node under test from multiple other endpoints and node identities
   250  (the 'bystanders'). After waiting for them to be accepted into the remote table, the test checks
   251  that they are returned by FINDNODE.`)
   252  
   253  	// Create bystanders.
   254  	nodes := make([]*bystander, 5)
   255  	added := make(chan enode.ID, len(nodes))
   256  	for i := range nodes {
   257  		nodes[i] = newBystander(t, s, added)
   258  		defer nodes[i].close()
   259  	}
   260  
   261  	// Get them added to the remote table.
   262  	timeout := 60 * time.Second
   263  	timeoutCh := time.After(timeout)
   264  	for count := 0; count < len(nodes); {
   265  		select {
   266  		case id := <-added:
   267  			t.Logf("bystander node %v added to remote table", id)
   268  			count++
   269  		case <-timeoutCh:
   270  			t.Errorf("remote added %d bystander nodes in %v, need %d to continue", count, timeout, len(nodes))
   271  			t.Logf("this can happen if the node has a non-empty table from previous runs")
   272  			return
   273  		}
   274  	}
   275  	t.Logf("all %d bystander nodes were added", len(nodes))
   276  
   277  	// Collect our nodes by distance.
   278  	var dists []uint
   279  	expect := make(map[enode.ID]*enode.Node)
   280  	for _, bn := range nodes {
   281  		n := bn.conn.localNode.Node()
   282  		expect[n.ID()] = n
   283  		d := uint(enode.LogDist(n.ID(), s.Dest.ID()))
   284  		if !slices.Contains(dists, d) {
   285  			dists = append(dists, d)
   286  		}
   287  	}
   288  
   289  	// Send FINDNODE for all distances.
   290  	t.Log("requesting nodes")
   291  	conn, l1 := s.listen1(t)
   292  	defer conn.close()
   293  	foundNodes, err := conn.findnode(l1, dists)
   294  	if err != nil {
   295  		t.Fatal(err)
   296  	}
   297  	t.Logf("remote returned %d nodes for distance list %v", len(foundNodes), dists)
   298  	for _, n := range foundNodes {
   299  		delete(expect, n.ID())
   300  	}
   301  	if len(expect) > 0 {
   302  		t.Errorf("missing %d nodes in FINDNODE result", len(expect))
   303  		t.Logf("this can happen if the test is run multiple times in quick succession")
   304  		t.Logf("and the remote node hasn't removed dead nodes from previous runs yet")
   305  	} else {
   306  		t.Logf("all %d expected nodes were returned", len(nodes))
   307  	}
   308  }
   309  
   310  // A bystander is a node whose only purpose is filling a spot in the remote table.
   311  type bystander struct {
   312  	dest *enode.Node
   313  	conn *conn
   314  	l    net.PacketConn
   315  
   316  	addedCh chan enode.ID
   317  	done    sync.WaitGroup
   318  }
   319  
   320  func newBystander(t *utesting.T, s *Suite, added chan enode.ID) *bystander {
   321  	conn, l := s.listen1(t)
   322  	conn.setEndpoint(l) // bystander nodes need IP/port to get pinged
   323  	bn := &bystander{
   324  		conn:    conn,
   325  		l:       l,
   326  		dest:    s.Dest,
   327  		addedCh: added,
   328  	}
   329  	bn.done.Add(1)
   330  	go bn.loop()
   331  	return bn
   332  }
   333  
   334  // id returns the node ID of the bystander.
   335  func (bn *bystander) id() enode.ID {
   336  	return bn.conn.localNode.ID()
   337  }
   338  
// close shuts down the bystander and blocks until its loop goroutine has
// returned. The order matters: the connection is closed first so that the
// loop can exit. Presumably the pending read then fails with a non-temporary
// error, which is what loop returns on. Only then does close wait on done.
func (bn *bystander) close() {
	bn.conn.close()
	bn.done.Wait()
}
   344  
   345  // loop answers packets from the remote node until quit.
   346  func (bn *bystander) loop() {
   347  	defer bn.done.Done()
   348  
   349  	var (
   350  		lastPing time.Time
   351  		wasAdded bool
   352  	)
   353  	for {
   354  		// Ping the remote node.
   355  		if !wasAdded && time.Since(lastPing) > 10*time.Second {
   356  			bn.conn.reqresp(bn.l, &v5wire.Ping{
   357  				ReqID:  bn.conn.nextReqID(),
   358  				ENRSeq: bn.dest.Seq(),
   359  			})
   360  			lastPing = time.Now()
   361  		}
   362  		// Answer packets.
   363  		switch p := bn.conn.read(bn.l).(type) {
   364  		case *v5wire.Ping:
   365  			bn.conn.write(bn.l, &v5wire.Pong{
   366  				ReqID:  p.ReqID,
   367  				ENRSeq: bn.conn.localNode.Seq(),
   368  				ToIP:   bn.dest.IP(),
   369  				ToPort: uint16(bn.dest.UDP()),
   370  			}, nil)
   371  			wasAdded = true
   372  			bn.notifyAdded()
   373  		case *v5wire.Findnode:
   374  			bn.conn.write(bn.l, &v5wire.Nodes{ReqID: p.ReqID, RespCount: 1}, nil)
   375  			wasAdded = true
   376  			bn.notifyAdded()
   377  		case *v5wire.TalkRequest:
   378  			bn.conn.write(bn.l, &v5wire.TalkResponse{ReqID: p.ReqID}, nil)
   379  		case *readError:
   380  			if !netutil.IsTemporaryError(p.err) {
   381  				bn.conn.logf("shutting down: %v", p.err)
   382  				return
   383  			}
   384  		}
   385  	}
   386  }
   387  
   388  func (bn *bystander) notifyAdded() {
   389  	if bn.addedCh != nil {
   390  		bn.addedCh <- bn.id()
   391  		bn.addedCh = nil
   392  	}
   393  }