github.com/fff-chain/go-fff@v0.0.0-20220726032732-1c84420b8a99/cmd/devp2p/internal/ethtest/eth66_suiteHelpers.go (about)

     1  // Copyright 2021 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package ethtest
    18  
    19  import (
    20  	"fmt"
    21  	"reflect"
    22  	"time"
    23  
    24  	"github.com/fff-chain/go-fff/core/types"
    25  	"github.com/fff-chain/go-fff/eth/protocols/eth"
    26  	"github.com/fff-chain/go-fff/internal/utesting"
    27  	"github.com/fff-chain/go-fff/p2p"
    28  	"github.com/fff-chain/go-fff/rlp"
    29  	"github.com/stretchr/testify/assert"
    30  )
    31  
    32  func (c *Conn) statusExchange66(t *utesting.T, chain *Chain) Message {
    33  	status := &Status{
    34  		ProtocolVersion: uint32(66),
    35  		NetworkID:       chain.chainConfig.ChainID.Uint64(),
    36  		TD:              chain.TD(chain.Len()),
    37  		Head:            chain.blocks[chain.Len()-1].Hash(),
    38  		Genesis:         chain.blocks[0].Hash(),
    39  		ForkID:          chain.ForkID(),
    40  	}
    41  	return c.statusExchange(t, chain, status)
    42  }
    43  
    44  func (s *Suite) dial66(t *utesting.T) *Conn {
    45  	conn, err := s.dial()
    46  	if err != nil {
    47  		t.Fatalf("could not dial: %v", err)
    48  	}
    49  	conn.caps = append(conn.caps, p2p.Cap{Name: "eth", Version: 66})
    50  	conn.ourHighestProtoVersion = 66
    51  	return conn
    52  }
    53  
    54  func (c *Conn) write66(req eth.Packet, code int) error {
    55  	payload, err := rlp.EncodeToBytes(req)
    56  	if err != nil {
    57  		return err
    58  	}
    59  	_, err = c.Conn.Write(uint64(code), payload)
    60  	return err
    61  }
    62  
    63  func (c *Conn) read66() (uint64, Message) {
    64  	code, rawData, _, err := c.Conn.Read()
    65  	if err != nil {
    66  		return 0, errorf("could not read from connection: %v", err)
    67  	}
    68  
    69  	var msg Message
    70  
    71  	switch int(code) {
    72  	case (Hello{}).Code():
    73  		msg = new(Hello)
    74  
    75  	case (Ping{}).Code():
    76  		msg = new(Ping)
    77  	case (Pong{}).Code():
    78  		msg = new(Pong)
    79  	case (Disconnect{}).Code():
    80  		msg = new(Disconnect)
    81  	case (Status{}).Code():
    82  		msg = new(Status)
    83  	case (GetBlockHeaders{}).Code():
    84  		ethMsg := new(eth.GetBlockHeadersPacket66)
    85  		if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
    86  			return 0, errorf("could not rlp decode message: %v", err)
    87  		}
    88  		return ethMsg.RequestId, GetBlockHeaders(*ethMsg.GetBlockHeadersPacket)
    89  	case (BlockHeaders{}).Code():
    90  		ethMsg := new(eth.BlockHeadersPacket66)
    91  		if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
    92  			return 0, errorf("could not rlp decode message: %v", err)
    93  		}
    94  		return ethMsg.RequestId, BlockHeaders(ethMsg.BlockHeadersPacket)
    95  	case (GetBlockBodies{}).Code():
    96  		ethMsg := new(eth.GetBlockBodiesPacket66)
    97  		if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
    98  			return 0, errorf("could not rlp decode message: %v", err)
    99  		}
   100  		return ethMsg.RequestId, GetBlockBodies(ethMsg.GetBlockBodiesPacket)
   101  	case (BlockBodies{}).Code():
   102  		ethMsg := new(eth.BlockBodiesPacket66)
   103  		if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
   104  			return 0, errorf("could not rlp decode message: %v", err)
   105  		}
   106  		return ethMsg.RequestId, BlockBodies(ethMsg.BlockBodiesPacket)
   107  	case (NewBlock{}).Code():
   108  		msg = new(NewBlock)
   109  	case (NewBlockHashes{}).Code():
   110  		msg = new(NewBlockHashes)
   111  	case (Transactions{}).Code():
   112  		msg = new(Transactions)
   113  	case (NewPooledTransactionHashes{}).Code():
   114  		msg = new(NewPooledTransactionHashes)
   115  	case (GetPooledTransactions{}.Code()):
   116  		ethMsg := new(eth.GetPooledTransactionsPacket66)
   117  		if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
   118  			return 0, errorf("could not rlp decode message: %v", err)
   119  		}
   120  		return ethMsg.RequestId, GetPooledTransactions(ethMsg.GetPooledTransactionsPacket)
   121  	case (PooledTransactions{}.Code()):
   122  		ethMsg := new(eth.PooledTransactionsPacket66)
   123  		if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
   124  			return 0, errorf("could not rlp decode message: %v", err)
   125  		}
   126  		return ethMsg.RequestId, PooledTransactions(ethMsg.PooledTransactionsPacket)
   127  	default:
   128  		msg = errorf("invalid message code: %d", code)
   129  	}
   130  
   131  	if msg != nil {
   132  		if err := rlp.DecodeBytes(rawData, msg); err != nil {
   133  			return 0, errorf("could not rlp decode message: %v", err)
   134  		}
   135  		return 0, msg
   136  	}
   137  	return 0, errorf("invalid message: %s", string(rawData))
   138  }
   139  
   140  func (c *Conn) waitForResponse(chain *Chain, timeout time.Duration, requestID uint64) Message {
   141  	for {
   142  		id, msg := c.readAndServe66(chain, timeout)
   143  		if id == requestID {
   144  			return msg
   145  		}
   146  	}
   147  }
   148  
   149  // ReadAndServe serves GetBlockHeaders requests while waiting
   150  // on another message from the node.
   151  func (c *Conn) readAndServe66(chain *Chain, timeout time.Duration) (uint64, Message) {
   152  	start := time.Now()
   153  	for time.Since(start) < timeout {
   154  		c.SetReadDeadline(time.Now().Add(10 * time.Second))
   155  
   156  		reqID, msg := c.read66()
   157  
   158  		switch msg := msg.(type) {
   159  		case *Ping:
   160  			c.Write(&Pong{})
   161  		case *GetBlockHeaders:
   162  			headers, err := chain.GetHeaders(*msg)
   163  			if err != nil {
   164  				return 0, errorf("could not get headers for inbound header request: %v", err)
   165  			}
   166  			resp := &eth.BlockHeadersPacket66{
   167  				RequestId:          reqID,
   168  				BlockHeadersPacket: eth.BlockHeadersPacket(headers),
   169  			}
   170  			if err := c.write66(resp, BlockHeaders{}.Code()); err != nil {
   171  				return 0, errorf("could not write to connection: %v", err)
   172  			}
   173  		default:
   174  			return reqID, msg
   175  		}
   176  	}
   177  	return 0, errorf("no message received within %v", timeout)
   178  }
   179  
   180  func (s *Suite) setupConnection66(t *utesting.T) *Conn {
   181  	// create conn
   182  	sendConn := s.dial66(t)
   183  	sendConn.handshake(t)
   184  	sendConn.statusExchange66(t, s.chain)
   185  	return sendConn
   186  }
   187  
   188  func (s *Suite) testAnnounce66(t *utesting.T, sendConn, receiveConn *Conn, blockAnnouncement *NewBlock) {
   189  	// Announce the block.
   190  	if err := sendConn.Write(blockAnnouncement); err != nil {
   191  		t.Fatalf("could not write to connection: %v", err)
   192  	}
   193  	s.waitAnnounce66(t, receiveConn, blockAnnouncement)
   194  }
   195  
   196  func (s *Suite) waitAnnounce66(t *utesting.T, conn *Conn, blockAnnouncement *NewBlock) {
   197  	for {
   198  		_, msg := conn.readAndServe66(s.chain, timeout)
   199  		switch msg := msg.(type) {
   200  		case *NewBlock:
   201  			t.Logf("received NewBlock message: %s", pretty.Sdump(msg.Block))
   202  			assert.Equal(t,
   203  				blockAnnouncement.Block.Header(), msg.Block.Header(),
   204  				"wrong block header in announcement",
   205  			)
   206  			assert.Equal(t,
   207  				blockAnnouncement.TD, msg.TD,
   208  				"wrong TD in announcement",
   209  			)
   210  			return
   211  		case *NewBlockHashes:
   212  			blockHashes := *msg
   213  			t.Logf("received NewBlockHashes message: %s", pretty.Sdump(blockHashes))
   214  			assert.Equal(t, blockAnnouncement.Block.Hash(), blockHashes[0].Hash,
   215  				"wrong block hash in announcement",
   216  			)
   217  			return
   218  		case *NewPooledTransactionHashes:
   219  			// ignore old txs being propagated
   220  			continue
   221  		default:
   222  			t.Fatalf("unexpected: %s", pretty.Sdump(msg))
   223  		}
   224  	}
   225  }
   226  
   227  // waitForBlock66 waits for confirmation from the client that it has
   228  // imported the given block.
   229  func (c *Conn) waitForBlock66(block *types.Block) error {
   230  	defer c.SetReadDeadline(time.Time{})
   231  
   232  	c.SetReadDeadline(time.Now().Add(20 * time.Second))
   233  	// note: if the node has not yet imported the block, it will respond
   234  	// to the GetBlockHeaders request with an empty BlockHeaders response,
   235  	// so the GetBlockHeaders request must be sent again until the BlockHeaders
   236  	// response contains the desired header.
   237  	for {
   238  		req := eth.GetBlockHeadersPacket66{
   239  			RequestId: 54,
   240  			GetBlockHeadersPacket: &eth.GetBlockHeadersPacket{
   241  				Origin: eth.HashOrNumber{
   242  					Hash: block.Hash(),
   243  				},
   244  				Amount: 1,
   245  			},
   246  		}
   247  		if err := c.write66(req, GetBlockHeaders{}.Code()); err != nil {
   248  			return err
   249  		}
   250  
   251  		reqID, msg := c.read66()
   252  		// check message
   253  		switch msg := msg.(type) {
   254  		case BlockHeaders:
   255  			// check request ID
   256  			if reqID != req.RequestId {
   257  				return fmt.Errorf("request ID mismatch: wanted %d, got %d", req.RequestId, reqID)
   258  			}
   259  			for _, header := range msg {
   260  				if header.Number.Uint64() == block.NumberU64() {
   261  					return nil
   262  				}
   263  			}
   264  			time.Sleep(100 * time.Millisecond)
   265  		case *NewPooledTransactionHashes:
   266  			// ignore old announcements
   267  			continue
   268  		default:
   269  			return fmt.Errorf("invalid message: %s", pretty.Sdump(msg))
   270  		}
   271  	}
   272  }
   273  
   274  func sendSuccessfulTx66(t *utesting.T, s *Suite, tx *types.Transaction) {
   275  	sendConn := s.setupConnection66(t)
   276  	defer sendConn.Close()
   277  	sendSuccessfulTxWithConn(t, s, tx, sendConn)
   278  }
   279  
   280  // waitForBlockHeadersResponse66 waits for a BlockHeaders message with the given expected request ID
   281  func (s *Suite) waitForBlockHeadersResponse66(conn *Conn, expectedID uint64) (BlockHeaders, error) {
   282  	reqID, msg := conn.readAndServe66(s.chain, timeout)
   283  	switch msg := msg.(type) {
   284  	case BlockHeaders:
   285  		if reqID != expectedID {
   286  			return nil, fmt.Errorf("request ID mismatch: wanted %d, got %d", expectedID, reqID)
   287  		}
   288  		return msg, nil
   289  	default:
   290  		return nil, fmt.Errorf("unexpected: %s", pretty.Sdump(msg))
   291  	}
   292  }
   293  
   294  func (s *Suite) getBlockHeaders66(conn *Conn, req eth.Packet, expectedID uint64) (BlockHeaders, error) {
   295  	if err := conn.write66(req, GetBlockHeaders{}.Code()); err != nil {
   296  		return nil, fmt.Errorf("could not write to connection: %v", err)
   297  	}
   298  	return s.waitForBlockHeadersResponse66(conn, expectedID)
   299  }
   300  
   301  func headersMatch(t *utesting.T, chain *Chain, headers BlockHeaders) bool {
   302  	mismatched := 0
   303  	for _, header := range headers {
   304  		num := header.Number.Uint64()
   305  		t.Logf("received header (%d): %s", num, pretty.Sdump(header.Hash()))
   306  		if !reflect.DeepEqual(chain.blocks[int(num)].Header(), header) {
   307  			mismatched += 1
   308  			t.Logf("received wrong header: %v", pretty.Sdump(header))
   309  		}
   310  	}
   311  	return mismatched == 0
   312  }
   313  
   314  func (s *Suite) sendNextBlock66(t *utesting.T) {
   315  	sendConn, receiveConn := s.setupConnection66(t), s.setupConnection66(t)
   316  	defer sendConn.Close()
   317  	defer receiveConn.Close()
   318  
   319  	// create new block announcement
   320  	nextBlock := len(s.chain.blocks)
   321  	blockAnnouncement := &NewBlock{
   322  		Block: s.fullChain.blocks[nextBlock],
   323  		TD:    s.fullChain.TD(nextBlock + 1),
   324  	}
   325  	// send announcement and wait for node to request the header
   326  	s.testAnnounce66(t, sendConn, receiveConn, blockAnnouncement)
   327  	// wait for client to update its chain
   328  	if err := receiveConn.waitForBlock66(s.fullChain.blocks[nextBlock]); err != nil {
   329  		t.Fatal(err)
   330  	}
   331  	// update test suite chain
   332  	s.chain.blocks = append(s.chain.blocks, s.fullChain.blocks[nextBlock])
   333  }