github.com/juliankolbe/go-ethereum@v1.9.992/cmd/devp2p/internal/ethtest/eth66_suiteHelpers.go (about)

     1  // Copyright 2021 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package ethtest
    18  
    19  import (
    20  	"fmt"
    21  	"time"
    22  
    23  	"github.com/juliankolbe/go-ethereum/core/types"
    24  	"github.com/juliankolbe/go-ethereum/eth/protocols/eth"
    25  	"github.com/juliankolbe/go-ethereum/internal/utesting"
    26  	"github.com/juliankolbe/go-ethereum/p2p"
    27  	"github.com/juliankolbe/go-ethereum/rlp"
    28  	"github.com/stretchr/testify/assert"
    29  )
    30  
    31  func (c *Conn) statusExchange66(t *utesting.T, chain *Chain) Message {
    32  	status := &Status{
    33  		ProtocolVersion: uint32(66),
    34  		NetworkID:       chain.chainConfig.ChainID.Uint64(),
    35  		TD:              chain.TD(chain.Len()),
    36  		Head:            chain.blocks[chain.Len()-1].Hash(),
    37  		Genesis:         chain.blocks[0].Hash(),
    38  		ForkID:          chain.ForkID(),
    39  	}
    40  	return c.statusExchange(t, chain, status)
    41  }
    42  
    43  func (s *Suite) dial66(t *utesting.T) *Conn {
    44  	conn, err := s.dial()
    45  	if err != nil {
    46  		t.Fatalf("could not dial: %v", err)
    47  	}
    48  	conn.caps = append(conn.caps, p2p.Cap{Name: "eth", Version: 66})
    49  	return conn
    50  }
    51  
    52  func (c *Conn) write66(req eth.Packet, code int) error {
    53  	payload, err := rlp.EncodeToBytes(req)
    54  	if err != nil {
    55  		return err
    56  	}
    57  	_, err = c.Conn.Write(uint64(code), payload)
    58  	return err
    59  }
    60  
    61  func (c *Conn) read66() (uint64, Message) {
    62  	code, rawData, _, err := c.Conn.Read()
    63  	if err != nil {
    64  		return 0, errorf("could not read from connection: %v", err)
    65  	}
    66  
    67  	var msg Message
    68  
    69  	switch int(code) {
    70  	case (Hello{}).Code():
    71  		msg = new(Hello)
    72  
    73  	case (Ping{}).Code():
    74  		msg = new(Ping)
    75  	case (Pong{}).Code():
    76  		msg = new(Pong)
    77  	case (Disconnect{}).Code():
    78  		msg = new(Disconnect)
    79  	case (Status{}).Code():
    80  		msg = new(Status)
    81  	case (GetBlockHeaders{}).Code():
    82  		ethMsg := new(eth.GetBlockHeadersPacket66)
    83  		if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
    84  			return 0, errorf("could not rlp decode message: %v", err)
    85  		}
    86  		return ethMsg.RequestId, GetBlockHeaders(*ethMsg.GetBlockHeadersPacket)
    87  	case (BlockHeaders{}).Code():
    88  		ethMsg := new(eth.BlockHeadersPacket66)
    89  		if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
    90  			return 0, errorf("could not rlp decode message: %v", err)
    91  		}
    92  		return ethMsg.RequestId, BlockHeaders(ethMsg.BlockHeadersPacket)
    93  	case (GetBlockBodies{}).Code():
    94  		ethMsg := new(eth.GetBlockBodiesPacket66)
    95  		if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
    96  			return 0, errorf("could not rlp decode message: %v", err)
    97  		}
    98  		return ethMsg.RequestId, GetBlockBodies(ethMsg.GetBlockBodiesPacket)
    99  	case (BlockBodies{}).Code():
   100  		ethMsg := new(eth.BlockBodiesPacket66)
   101  		if err := rlp.DecodeBytes(rawData, ethMsg); err != nil {
   102  			return 0, errorf("could not rlp decode message: %v", err)
   103  		}
   104  		return ethMsg.RequestId, BlockBodies(ethMsg.BlockBodiesPacket)
   105  	case (NewBlock{}).Code():
   106  		msg = new(NewBlock)
   107  	case (NewBlockHashes{}).Code():
   108  		msg = new(NewBlockHashes)
   109  	case (Transactions{}).Code():
   110  		msg = new(Transactions)
   111  	case (NewPooledTransactionHashes{}).Code():
   112  		msg = new(NewPooledTransactionHashes)
   113  	default:
   114  		msg = errorf("invalid message code: %d", code)
   115  	}
   116  
   117  	if msg != nil {
   118  		if err := rlp.DecodeBytes(rawData, msg); err != nil {
   119  			return 0, errorf("could not rlp decode message: %v", err)
   120  		}
   121  		return 0, msg
   122  	}
   123  	return 0, errorf("invalid message: %s", string(rawData))
   124  }
   125  
   126  // ReadAndServe serves GetBlockHeaders requests while waiting
   127  // on another message from the node.
   128  func (c *Conn) readAndServe66(chain *Chain, timeout time.Duration) (uint64, Message) {
   129  	start := time.Now()
   130  	for time.Since(start) < timeout {
   131  		timeout := time.Now().Add(10 * time.Second)
   132  		c.SetReadDeadline(timeout)
   133  
   134  		reqID, msg := c.read66()
   135  
   136  		switch msg := msg.(type) {
   137  		case *Ping:
   138  			c.Write(&Pong{})
   139  		case *GetBlockHeaders:
   140  			headers, err := chain.GetHeaders(*msg)
   141  			if err != nil {
   142  				return 0, errorf("could not get headers for inbound header request: %v", err)
   143  			}
   144  
   145  			if err := c.Write(headers); err != nil {
   146  				return 0, errorf("could not write to connection: %v", err)
   147  			}
   148  		default:
   149  			return reqID, msg
   150  		}
   151  	}
   152  	return 0, errorf("no message received within %v", timeout)
   153  }
   154  
   155  func (s *Suite) setupConnection66(t *utesting.T) *Conn {
   156  	// create conn
   157  	sendConn := s.dial66(t)
   158  	sendConn.handshake(t)
   159  	sendConn.statusExchange66(t, s.chain)
   160  	return sendConn
   161  }
   162  
   163  func (s *Suite) testAnnounce66(t *utesting.T, sendConn, receiveConn *Conn, blockAnnouncement *NewBlock) {
   164  	// Announce the block.
   165  	if err := sendConn.Write(blockAnnouncement); err != nil {
   166  		t.Fatalf("could not write to connection: %v", err)
   167  	}
   168  	s.waitAnnounce66(t, receiveConn, blockAnnouncement)
   169  }
   170  
   171  func (s *Suite) waitAnnounce66(t *utesting.T, conn *Conn, blockAnnouncement *NewBlock) {
   172  	timeout := 20 * time.Second
   173  	_, msg := conn.readAndServe66(s.chain, timeout)
   174  	switch msg := msg.(type) {
   175  	case *NewBlock:
   176  		t.Logf("received NewBlock message: %s", pretty.Sdump(msg.Block))
   177  		assert.Equal(t,
   178  			blockAnnouncement.Block.Header(), msg.Block.Header(),
   179  			"wrong block header in announcement",
   180  		)
   181  		assert.Equal(t,
   182  			blockAnnouncement.TD, msg.TD,
   183  			"wrong TD in announcement",
   184  		)
   185  	case *NewBlockHashes:
   186  		blockHashes := *msg
   187  		t.Logf("received NewBlockHashes message: %s", pretty.Sdump(blockHashes))
   188  		assert.Equal(t, blockAnnouncement.Block.Hash(), blockHashes[0].Hash,
   189  			"wrong block hash in announcement",
   190  		)
   191  	default:
   192  		t.Fatalf("unexpected: %s", pretty.Sdump(msg))
   193  	}
   194  }
   195  
   196  // waitForBlock66 waits for confirmation from the client that it has
   197  // imported the given block.
   198  func (c *Conn) waitForBlock66(block *types.Block) error {
   199  	defer c.SetReadDeadline(time.Time{})
   200  
   201  	timeout := time.Now().Add(20 * time.Second)
   202  	c.SetReadDeadline(timeout)
   203  	for {
   204  		req := eth.GetBlockHeadersPacket66{
   205  			RequestId: 54,
   206  			GetBlockHeadersPacket: &eth.GetBlockHeadersPacket{
   207  				Origin: eth.HashOrNumber{
   208  					Hash: block.Hash(),
   209  				},
   210  				Amount: 1,
   211  			},
   212  		}
   213  		if err := c.write66(req, GetBlockHeaders{}.Code()); err != nil {
   214  			return err
   215  		}
   216  
   217  		reqID, msg := c.read66()
   218  		// check message
   219  		switch msg := msg.(type) {
   220  		case BlockHeaders:
   221  			// check request ID
   222  			if reqID != req.RequestId {
   223  				return fmt.Errorf("request ID mismatch: wanted %d, got %d", req.RequestId, reqID)
   224  			}
   225  			if len(msg) > 0 {
   226  				return nil
   227  			}
   228  			time.Sleep(100 * time.Millisecond)
   229  		default:
   230  			return fmt.Errorf("invalid message: %s", pretty.Sdump(msg))
   231  		}
   232  	}
   233  }
   234  
   235  func sendSuccessfulTx66(t *utesting.T, s *Suite, tx *types.Transaction) {
   236  	sendConn := s.setupConnection66(t)
   237  	sendSuccessfulTxWithConn(t, s, tx, sendConn)
   238  }
   239  
   240  func sendFailingTx66(t *utesting.T, s *Suite, tx *types.Transaction) {
   241  	sendConn, recvConn := s.setupConnection66(t), s.setupConnection66(t)
   242  	sendFailingTxWithConns(t, s, tx, sendConn, recvConn)
   243  }
   244  
   245  func (s *Suite) getBlockHeaders66(t *utesting.T, conn *Conn, req eth.Packet, expectedID uint64) BlockHeaders {
   246  	if err := conn.write66(req, GetBlockHeaders{}.Code()); err != nil {
   247  		t.Fatalf("could not write to connection: %v", err)
   248  	}
   249  	// check block headers response
   250  	reqID, msg := conn.readAndServe66(s.chain, timeout)
   251  
   252  	switch msg := msg.(type) {
   253  	case BlockHeaders:
   254  		if reqID != expectedID {
   255  			t.Fatalf("request ID mismatch: wanted %d, got %d", expectedID, reqID)
   256  		}
   257  		return msg
   258  	default:
   259  		t.Fatalf("unexpected: %s", pretty.Sdump(msg))
   260  		return nil
   261  	}
   262  }
   263  
   264  func headersMatch(t *utesting.T, chain *Chain, headers BlockHeaders) {
   265  	for _, header := range headers {
   266  		num := header.Number.Uint64()
   267  		t.Logf("received header (%d): %s", num, pretty.Sdump(header.Hash()))
   268  		assert.Equal(t, chain.blocks[int(num)].Header(), header)
   269  	}
   270  }