github.com/ethereum/go-ethereum@v1.16.1/cmd/devp2p/internal/ethtest/conn.go (about)

     1  // Copyright 2023 The go-ethereum Authors
     2  // This file is part of go-ethereum.
     3  //
     4  // go-ethereum is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // go-ethereum is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU General Public License
    15  // along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package ethtest
    18  
    19  import (
    20  	"crypto/ecdsa"
    21  	"errors"
    22  	"fmt"
    23  	"net"
    24  	"reflect"
    25  	"time"
    26  
    27  	"github.com/davecgh/go-spew/spew"
    28  	"github.com/ethereum/go-ethereum/crypto"
    29  	"github.com/ethereum/go-ethereum/eth/protocols/eth"
    30  	"github.com/ethereum/go-ethereum/eth/protocols/snap"
    31  	"github.com/ethereum/go-ethereum/p2p"
    32  	"github.com/ethereum/go-ethereum/p2p/rlpx"
    33  	"github.com/ethereum/go-ethereum/rlp"
    34  )
    35  
    36  var (
    37  	pretty = spew.ConfigState{
    38  		Indent:                  "  ",
    39  		DisableCapacities:       true,
    40  		DisablePointerAddresses: true,
    41  		SortKeys:                true,
    42  	}
    43  	timeout = 2 * time.Second
    44  )
    45  
    46  // dial attempts to dial the given node and perform a handshake, returning the
    47  // created Conn if successful.
    48  func (s *Suite) dial() (*Conn, error) {
    49  	key, _ := crypto.GenerateKey()
    50  	return s.dialAs(key)
    51  }
    52  
    53  // dialAs attempts to dial a given node and perform a handshake using the given
    54  // private key.
    55  func (s *Suite) dialAs(key *ecdsa.PrivateKey) (*Conn, error) {
    56  	tcpEndpoint, _ := s.Dest.TCPEndpoint()
    57  	fd, err := net.Dial("tcp", tcpEndpoint.String())
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  	conn := Conn{Conn: rlpx.NewConn(fd, s.Dest.Pubkey())}
    62  	conn.ourKey = key
    63  	_, err = conn.Handshake(conn.ourKey)
    64  	if err != nil {
    65  		conn.Close()
    66  		return nil, err
    67  	}
    68  	conn.caps = []p2p.Cap{
    69  		{Name: "eth", Version: 69},
    70  	}
    71  	conn.ourHighestProtoVersion = 69
    72  	return &conn, nil
    73  }
    74  
    75  // dialSnap creates a connection with snap/1 capability.
    76  func (s *Suite) dialSnap() (*Conn, error) {
    77  	conn, err := s.dial()
    78  	if err != nil {
    79  		return nil, fmt.Errorf("dial failed: %v", err)
    80  	}
    81  	conn.caps = append(conn.caps, p2p.Cap{Name: "snap", Version: 1})
    82  	conn.ourHighestSnapProtoVersion = 1
    83  	return conn, nil
    84  }
    85  
    86  // Conn represents an individual connection with a peer
    87  type Conn struct {
    88  	*rlpx.Conn
    89  	ourKey                     *ecdsa.PrivateKey
    90  	negotiatedProtoVersion     uint
    91  	negotiatedSnapProtoVersion uint
    92  	ourHighestProtoVersion     uint
    93  	ourHighestSnapProtoVersion uint
    94  	caps                       []p2p.Cap
    95  }
    96  
    97  // Read reads a packet from the connection.
    98  func (c *Conn) Read() (uint64, []byte, error) {
    99  	c.SetReadDeadline(time.Now().Add(timeout))
   100  	code, data, _, err := c.Conn.Read()
   101  	if err != nil {
   102  		return 0, nil, err
   103  	}
   104  	return code, data, nil
   105  }
   106  
   107  // ReadMsg attempts to read a devp2p message with a specific code.
   108  func (c *Conn) ReadMsg(proto Proto, code uint64, msg any) error {
   109  	c.SetReadDeadline(time.Now().Add(timeout))
   110  	for {
   111  		got, data, err := c.Read()
   112  		if err != nil {
   113  			return err
   114  		}
   115  		if protoOffset(proto)+code == got {
   116  			return rlp.DecodeBytes(data, msg)
   117  		}
   118  	}
   119  }
   120  
   121  // Write writes a eth packet to the connection.
   122  func (c *Conn) Write(proto Proto, code uint64, msg any) error {
   123  	c.SetWriteDeadline(time.Now().Add(timeout))
   124  	payload, err := rlp.EncodeToBytes(msg)
   125  	if err != nil {
   126  		return err
   127  	}
   128  	_, err = c.Conn.Write(protoOffset(proto)+code, payload)
   129  	return err
   130  }
   131  
   132  var errDisc error = fmt.Errorf("disconnect")
   133  
   134  // ReadEth reads an Eth sub-protocol wire message.
   135  func (c *Conn) ReadEth() (any, error) {
   136  	c.SetReadDeadline(time.Now().Add(timeout))
   137  	for {
   138  		code, data, _, err := c.Conn.Read()
   139  		if code == discMsg {
   140  			return nil, errDisc
   141  		}
   142  		if err != nil {
   143  			return nil, err
   144  		}
   145  		if code == pingMsg {
   146  			c.Write(baseProto, pongMsg, []byte{})
   147  			continue
   148  		}
   149  		if getProto(code) != ethProto {
   150  			// Read until eth message.
   151  			continue
   152  		}
   153  		code -= baseProtoLen
   154  
   155  		var msg any
   156  		switch int(code) {
   157  		case eth.StatusMsg:
   158  			msg = new(eth.StatusPacket69)
   159  		case eth.GetBlockHeadersMsg:
   160  			msg = new(eth.GetBlockHeadersPacket)
   161  		case eth.BlockHeadersMsg:
   162  			msg = new(eth.BlockHeadersPacket)
   163  		case eth.GetBlockBodiesMsg:
   164  			msg = new(eth.GetBlockBodiesPacket)
   165  		case eth.BlockBodiesMsg:
   166  			msg = new(eth.BlockBodiesPacket)
   167  		case eth.NewBlockMsg:
   168  			msg = new(eth.NewBlockPacket)
   169  		case eth.NewBlockHashesMsg:
   170  			msg = new(eth.NewBlockHashesPacket)
   171  		case eth.TransactionsMsg:
   172  			msg = new(eth.TransactionsPacket)
   173  		case eth.NewPooledTransactionHashesMsg:
   174  			msg = new(eth.NewPooledTransactionHashesPacket)
   175  		case eth.GetPooledTransactionsMsg:
   176  			msg = new(eth.GetPooledTransactionsPacket)
   177  		case eth.PooledTransactionsMsg:
   178  			msg = new(eth.PooledTransactionsPacket)
   179  		default:
   180  			panic(fmt.Sprintf("unhandled eth msg code %d", code))
   181  		}
   182  		if err := rlp.DecodeBytes(data, msg); err != nil {
   183  			return nil, fmt.Errorf("unable to decode eth msg: %v", err)
   184  		}
   185  		return msg, nil
   186  	}
   187  }
   188  
   189  // ReadSnap reads a snap/1 response with the given id from the connection.
   190  func (c *Conn) ReadSnap() (any, error) {
   191  	c.SetReadDeadline(time.Now().Add(timeout))
   192  	for {
   193  		code, data, _, err := c.Conn.Read()
   194  		if err != nil {
   195  			return nil, err
   196  		}
   197  		if getProto(code) != snapProto {
   198  			// Read until snap message.
   199  			continue
   200  		}
   201  		code -= baseProtoLen + ethProtoLen
   202  
   203  		var msg any
   204  		switch int(code) {
   205  		case snap.GetAccountRangeMsg:
   206  			msg = new(snap.GetAccountRangePacket)
   207  		case snap.AccountRangeMsg:
   208  			msg = new(snap.AccountRangePacket)
   209  		case snap.GetStorageRangesMsg:
   210  			msg = new(snap.GetStorageRangesPacket)
   211  		case snap.StorageRangesMsg:
   212  			msg = new(snap.StorageRangesPacket)
   213  		case snap.GetByteCodesMsg:
   214  			msg = new(snap.GetByteCodesPacket)
   215  		case snap.ByteCodesMsg:
   216  			msg = new(snap.ByteCodesPacket)
   217  		case snap.GetTrieNodesMsg:
   218  			msg = new(snap.GetTrieNodesPacket)
   219  		case snap.TrieNodesMsg:
   220  			msg = new(snap.TrieNodesPacket)
   221  		default:
   222  			panic(fmt.Errorf("unhandled snap code: %d", code))
   223  		}
   224  		if err := rlp.DecodeBytes(data, msg); err != nil {
   225  			return nil, fmt.Errorf("could not rlp decode message: %v", err)
   226  		}
   227  		return msg, nil
   228  	}
   229  }
   230  
   231  // dialAndPeer creates a peer connection and runs the handshake.
   232  func (s *Suite) dialAndPeer(status *eth.StatusPacket69) (*Conn, error) {
   233  	c, err := s.dial()
   234  	if err != nil {
   235  		return nil, err
   236  	}
   237  	if err = c.peer(s.chain, status); err != nil {
   238  		c.Close()
   239  	}
   240  	return c, err
   241  }
   242  
   243  // peer performs both the protocol handshake and the status message
   244  // exchange with the node in order to peer with it.
   245  func (c *Conn) peer(chain *Chain, status *eth.StatusPacket69) error {
   246  	if err := c.handshake(); err != nil {
   247  		return fmt.Errorf("handshake failed: %v", err)
   248  	}
   249  	if err := c.statusExchange(chain, status); err != nil {
   250  		return fmt.Errorf("status exchange failed: %v", err)
   251  	}
   252  	return nil
   253  }
   254  
   255  // handshake performs a protocol handshake with the node.
   256  func (c *Conn) handshake() error {
   257  	// Write hello to client.
   258  	pub0 := crypto.FromECDSAPub(&c.ourKey.PublicKey)[1:]
   259  	ourHandshake := &protoHandshake{
   260  		Version: 5,
   261  		Caps:    c.caps,
   262  		ID:      pub0,
   263  	}
   264  	if err := c.Write(baseProto, handshakeMsg, ourHandshake); err != nil {
   265  		return fmt.Errorf("write to connection failed: %v", err)
   266  	}
   267  	// Read hello from client.
   268  	code, data, err := c.Read()
   269  	if err != nil {
   270  		return fmt.Errorf("erroring reading handshake: %v", err)
   271  	}
   272  	switch code {
   273  	case handshakeMsg:
   274  		msg := new(protoHandshake)
   275  		if err := rlp.DecodeBytes(data, &msg); err != nil {
   276  			return fmt.Errorf("error decoding handshake msg: %v", err)
   277  		}
   278  		// Set snappy if version is at least 5.
   279  		if msg.Version >= 5 {
   280  			c.SetSnappy(true)
   281  		}
   282  		c.negotiateEthProtocol(msg.Caps)
   283  		if c.negotiatedProtoVersion == 0 {
   284  			return fmt.Errorf("could not negotiate eth protocol (remote caps: %v, local eth version: %v)", msg.Caps, c.ourHighestProtoVersion)
   285  		}
   286  		// If we require snap, verify that it was negotiated.
   287  		if c.ourHighestSnapProtoVersion != c.negotiatedSnapProtoVersion {
   288  			return fmt.Errorf("could not negotiate snap protocol (remote caps: %v, local snap version: %v)", msg.Caps, c.ourHighestSnapProtoVersion)
   289  		}
   290  		return nil
   291  	default:
   292  		return fmt.Errorf("bad handshake: got msg code %d", code)
   293  	}
   294  }
   295  
   296  // negotiateEthProtocol sets the Conn's eth protocol version to highest
   297  // advertised capability from peer.
   298  func (c *Conn) negotiateEthProtocol(caps []p2p.Cap) {
   299  	var highestEthVersion uint
   300  	var highestSnapVersion uint
   301  	for _, capability := range caps {
   302  		switch capability.Name {
   303  		case "eth":
   304  			if capability.Version > highestEthVersion && capability.Version <= c.ourHighestProtoVersion {
   305  				highestEthVersion = capability.Version
   306  			}
   307  		case "snap":
   308  			if capability.Version > highestSnapVersion && capability.Version <= c.ourHighestSnapProtoVersion {
   309  				highestSnapVersion = capability.Version
   310  			}
   311  		}
   312  	}
   313  	c.negotiatedProtoVersion = highestEthVersion
   314  	c.negotiatedSnapProtoVersion = highestSnapVersion
   315  }
   316  
   317  // statusExchange performs a `Status` message exchange with the given node.
   318  func (c *Conn) statusExchange(chain *Chain, status *eth.StatusPacket69) error {
   319  loop:
   320  	for {
   321  		code, data, err := c.Read()
   322  		if err != nil {
   323  			return fmt.Errorf("failed to read from connection: %w", err)
   324  		}
   325  		switch code {
   326  		case eth.StatusMsg + protoOffset(ethProto):
   327  			msg := new(eth.StatusPacket69)
   328  			if err := rlp.DecodeBytes(data, &msg); err != nil {
   329  				return fmt.Errorf("error decoding status packet: %w", err)
   330  			}
   331  			if have, want := msg.LatestBlock, chain.blocks[chain.Len()-1].NumberU64(); have != want {
   332  				return fmt.Errorf("wrong head block in status, want: %d, have %d",
   333  					want, have)
   334  			}
   335  			if have, want := msg.LatestBlockHash, chain.blocks[chain.Len()-1].Hash(); have != want {
   336  				return fmt.Errorf("wrong head block in status, want: %#x (block %d) have %#x",
   337  					want, chain.blocks[chain.Len()-1].NumberU64(), have)
   338  			}
   339  			if have, want := msg.ForkID, chain.ForkID(); !reflect.DeepEqual(have, want) {
   340  				return fmt.Errorf("wrong fork ID in status: have %v, want %v", have, want)
   341  			}
   342  			if have, want := msg.ProtocolVersion, c.ourHighestProtoVersion; have != uint32(want) {
   343  				return fmt.Errorf("wrong protocol version: have %v, want %v", have, want)
   344  			}
   345  			break loop
   346  		case discMsg:
   347  			var msg []p2p.DiscReason
   348  			if rlp.DecodeBytes(data, &msg); len(msg) == 0 {
   349  				return errors.New("invalid disconnect message")
   350  			}
   351  			return fmt.Errorf("disconnect received: %v", pretty.Sdump(msg))
   352  		case pingMsg:
   353  			// TODO (renaynay): in the future, this should be an error
   354  			// (PINGs should not be a response upon fresh connection)
   355  			c.Write(baseProto, pongMsg, nil)
   356  		default:
   357  			return fmt.Errorf("bad status message: code %d", code)
   358  		}
   359  	}
   360  	// make sure eth protocol version is set for negotiation
   361  	if c.negotiatedProtoVersion == 0 {
   362  		return errors.New("eth protocol version must be set in Conn")
   363  	}
   364  	if status == nil {
   365  		// default status message
   366  		status = &eth.StatusPacket69{
   367  			ProtocolVersion: uint32(c.negotiatedProtoVersion),
   368  			NetworkID:       chain.config.ChainID.Uint64(),
   369  			Genesis:         chain.blocks[0].Hash(),
   370  			ForkID:          chain.ForkID(),
   371  			EarliestBlock:   0,
   372  			LatestBlock:     chain.blocks[chain.Len()-1].NumberU64(),
   373  			LatestBlockHash: chain.blocks[chain.Len()-1].Hash(),
   374  		}
   375  	}
   376  	if err := c.Write(ethProto, eth.StatusMsg, status); err != nil {
   377  		return fmt.Errorf("write to connection failed: %v", err)
   378  	}
   379  	return nil
   380  }