github.1485827954.workers.dev/ethereum/go-ethereum@v1.14.3/cmd/devp2p/internal/ethtest/conn.go (about)

     1  // Copyright 2023 The go-ethereum Authors
     2  // This file is part of go-ethereum.
     3  //
     4  // go-ethereum is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // go-ethereum is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU General Public License
    15  // along with go-ethereum. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package ethtest
    18  
    19  import (
    20  	"crypto/ecdsa"
    21  	"errors"
    22  	"fmt"
    23  	"net"
    24  	"reflect"
    25  	"time"
    26  
    27  	"github.com/davecgh/go-spew/spew"
    28  	"github.com/ethereum/go-ethereum/crypto"
    29  	"github.com/ethereum/go-ethereum/eth/protocols/eth"
    30  	"github.com/ethereum/go-ethereum/eth/protocols/snap"
    31  	"github.com/ethereum/go-ethereum/p2p"
    32  	"github.com/ethereum/go-ethereum/p2p/rlpx"
    33  	"github.com/ethereum/go-ethereum/rlp"
    34  )
    35  
    36  var (
    37  	pretty = spew.ConfigState{
    38  		Indent:                  "  ",
    39  		DisableCapacities:       true,
    40  		DisablePointerAddresses: true,
    41  		SortKeys:                true,
    42  	}
    43  	timeout = 2 * time.Second
    44  )
    45  
    46  // dial attempts to dial the given node and perform a handshake, returning the
    47  // created Conn if successful.
    48  func (s *Suite) dial() (*Conn, error) {
    49  	key, _ := crypto.GenerateKey()
    50  	return s.dialAs(key)
    51  }
    52  
    53  // dialAs attempts to dial a given node and perform a handshake using the given
    54  // private key.
    55  func (s *Suite) dialAs(key *ecdsa.PrivateKey) (*Conn, error) {
    56  	fd, err := net.Dial("tcp", fmt.Sprintf("%v:%d", s.Dest.IP(), s.Dest.TCP()))
    57  	if err != nil {
    58  		return nil, err
    59  	}
    60  	conn := Conn{Conn: rlpx.NewConn(fd, s.Dest.Pubkey())}
    61  	conn.ourKey = key
    62  	_, err = conn.Handshake(conn.ourKey)
    63  	if err != nil {
    64  		conn.Close()
    65  		return nil, err
    66  	}
    67  	conn.caps = []p2p.Cap{
    68  		{Name: "eth", Version: 67},
    69  		{Name: "eth", Version: 68},
    70  	}
    71  	conn.ourHighestProtoVersion = 68
    72  	return &conn, nil
    73  }
    74  
    75  // dialSnap creates a connection with snap/1 capability.
    76  func (s *Suite) dialSnap() (*Conn, error) {
    77  	conn, err := s.dial()
    78  	if err != nil {
    79  		return nil, fmt.Errorf("dial failed: %v", err)
    80  	}
    81  	conn.caps = append(conn.caps, p2p.Cap{Name: "snap", Version: 1})
    82  	conn.ourHighestSnapProtoVersion = 1
    83  	return conn, nil
    84  }
    85  
    86  // Conn represents an individual connection with a peer
    87  type Conn struct {
    88  	*rlpx.Conn
    89  	ourKey                     *ecdsa.PrivateKey
    90  	negotiatedProtoVersion     uint
    91  	negotiatedSnapProtoVersion uint
    92  	ourHighestProtoVersion     uint
    93  	ourHighestSnapProtoVersion uint
    94  	caps                       []p2p.Cap
    95  }
    96  
    97  // Read reads a packet from the connection.
    98  func (c *Conn) Read() (uint64, []byte, error) {
    99  	c.SetReadDeadline(time.Now().Add(timeout))
   100  	code, data, _, err := c.Conn.Read()
   101  	if err != nil {
   102  		return 0, nil, err
   103  	}
   104  	return code, data, nil
   105  }
   106  
   107  // ReadMsg attempts to read a devp2p message with a specific code.
   108  func (c *Conn) ReadMsg(proto Proto, code uint64, msg any) error {
   109  	c.SetReadDeadline(time.Now().Add(timeout))
   110  	for {
   111  		got, data, err := c.Read()
   112  		if err != nil {
   113  			return err
   114  		}
   115  		if protoOffset(proto)+code == got {
   116  			return rlp.DecodeBytes(data, msg)
   117  		}
   118  	}
   119  }
   120  
   121  // Write writes a eth packet to the connection.
   122  func (c *Conn) Write(proto Proto, code uint64, msg any) error {
   123  	c.SetWriteDeadline(time.Now().Add(timeout))
   124  	payload, err := rlp.EncodeToBytes(msg)
   125  	if err != nil {
   126  		return err
   127  	}
   128  	_, err = c.Conn.Write(protoOffset(proto)+code, payload)
   129  	return err
   130  }
   131  
   132  // ReadEth reads an Eth sub-protocol wire message.
   133  func (c *Conn) ReadEth() (any, error) {
   134  	c.SetReadDeadline(time.Now().Add(timeout))
   135  	for {
   136  		code, data, _, err := c.Conn.Read()
   137  		if err != nil {
   138  			return nil, err
   139  		}
   140  		if code == pingMsg {
   141  			c.Write(baseProto, pongMsg, []byte{})
   142  			continue
   143  		}
   144  		if getProto(code) != ethProto {
   145  			// Read until eth message.
   146  			continue
   147  		}
   148  		code -= baseProtoLen
   149  
   150  		var msg any
   151  		switch int(code) {
   152  		case eth.StatusMsg:
   153  			msg = new(eth.StatusPacket)
   154  		case eth.GetBlockHeadersMsg:
   155  			msg = new(eth.GetBlockHeadersPacket)
   156  		case eth.BlockHeadersMsg:
   157  			msg = new(eth.BlockHeadersPacket)
   158  		case eth.GetBlockBodiesMsg:
   159  			msg = new(eth.GetBlockBodiesPacket)
   160  		case eth.BlockBodiesMsg:
   161  			msg = new(eth.BlockBodiesPacket)
   162  		case eth.NewBlockMsg:
   163  			msg = new(eth.NewBlockPacket)
   164  		case eth.NewBlockHashesMsg:
   165  			msg = new(eth.NewBlockHashesPacket)
   166  		case eth.TransactionsMsg:
   167  			msg = new(eth.TransactionsPacket)
   168  		case eth.NewPooledTransactionHashesMsg:
   169  			msg = new(eth.NewPooledTransactionHashesPacket)
   170  		case eth.GetPooledTransactionsMsg:
   171  			msg = new(eth.GetPooledTransactionsPacket)
   172  		case eth.PooledTransactionsMsg:
   173  			msg = new(eth.PooledTransactionsPacket)
   174  		default:
   175  			panic(fmt.Sprintf("unhandled eth msg code %d", code))
   176  		}
   177  		if err := rlp.DecodeBytes(data, msg); err != nil {
   178  			return nil, fmt.Errorf("unable to decode eth msg: %v", err)
   179  		}
   180  		return msg, nil
   181  	}
   182  }
   183  
   184  // ReadSnap reads a snap/1 response with the given id from the connection.
   185  func (c *Conn) ReadSnap() (any, error) {
   186  	c.SetReadDeadline(time.Now().Add(timeout))
   187  	for {
   188  		code, data, _, err := c.Conn.Read()
   189  		if err != nil {
   190  			return nil, err
   191  		}
   192  		if getProto(code) != snapProto {
   193  			// Read until snap message.
   194  			continue
   195  		}
   196  		code -= baseProtoLen + ethProtoLen
   197  
   198  		var msg any
   199  		switch int(code) {
   200  		case snap.GetAccountRangeMsg:
   201  			msg = new(snap.GetAccountRangePacket)
   202  		case snap.AccountRangeMsg:
   203  			msg = new(snap.AccountRangePacket)
   204  		case snap.GetStorageRangesMsg:
   205  			msg = new(snap.GetStorageRangesPacket)
   206  		case snap.StorageRangesMsg:
   207  			msg = new(snap.StorageRangesPacket)
   208  		case snap.GetByteCodesMsg:
   209  			msg = new(snap.GetByteCodesPacket)
   210  		case snap.ByteCodesMsg:
   211  			msg = new(snap.ByteCodesPacket)
   212  		case snap.GetTrieNodesMsg:
   213  			msg = new(snap.GetTrieNodesPacket)
   214  		case snap.TrieNodesMsg:
   215  			msg = new(snap.TrieNodesPacket)
   216  		default:
   217  			panic(fmt.Errorf("unhandled snap code: %d", code))
   218  		}
   219  		if err := rlp.DecodeBytes(data, msg); err != nil {
   220  			return nil, fmt.Errorf("could not rlp decode message: %v", err)
   221  		}
   222  		return msg, nil
   223  	}
   224  }
   225  
   226  // peer performs both the protocol handshake and the status message
   227  // exchange with the node in order to peer with it.
   228  func (c *Conn) peer(chain *Chain, status *eth.StatusPacket) error {
   229  	if err := c.handshake(); err != nil {
   230  		return fmt.Errorf("handshake failed: %v", err)
   231  	}
   232  	if err := c.statusExchange(chain, status); err != nil {
   233  		return fmt.Errorf("status exchange failed: %v", err)
   234  	}
   235  	return nil
   236  }
   237  
   238  // handshake performs a protocol handshake with the node.
   239  func (c *Conn) handshake() error {
   240  	// Write hello to client.
   241  	pub0 := crypto.FromECDSAPub(&c.ourKey.PublicKey)[1:]
   242  	ourHandshake := &protoHandshake{
   243  		Version: 5,
   244  		Caps:    c.caps,
   245  		ID:      pub0,
   246  	}
   247  	if err := c.Write(baseProto, handshakeMsg, ourHandshake); err != nil {
   248  		return fmt.Errorf("write to connection failed: %v", err)
   249  	}
   250  	// Read hello from client.
   251  	code, data, err := c.Read()
   252  	if err != nil {
   253  		return fmt.Errorf("erroring reading handshake: %v", err)
   254  	}
   255  	switch code {
   256  	case handshakeMsg:
   257  		msg := new(protoHandshake)
   258  		if err := rlp.DecodeBytes(data, &msg); err != nil {
   259  			return fmt.Errorf("error decoding handshake msg: %v", err)
   260  		}
   261  		// Set snappy if version is at least 5.
   262  		if msg.Version >= 5 {
   263  			c.SetSnappy(true)
   264  		}
   265  		c.negotiateEthProtocol(msg.Caps)
   266  		if c.negotiatedProtoVersion == 0 {
   267  			return fmt.Errorf("could not negotiate eth protocol (remote caps: %v, local eth version: %v)", msg.Caps, c.ourHighestProtoVersion)
   268  		}
   269  		// If we require snap, verify that it was negotiated.
   270  		if c.ourHighestSnapProtoVersion != c.negotiatedSnapProtoVersion {
   271  			return fmt.Errorf("could not negotiate snap protocol (remote caps: %v, local snap version: %v)", msg.Caps, c.ourHighestSnapProtoVersion)
   272  		}
   273  		return nil
   274  	default:
   275  		return fmt.Errorf("bad handshake: got msg code %d", code)
   276  	}
   277  }
   278  
   279  // negotiateEthProtocol sets the Conn's eth protocol version to highest
   280  // advertised capability from peer.
   281  func (c *Conn) negotiateEthProtocol(caps []p2p.Cap) {
   282  	var highestEthVersion uint
   283  	var highestSnapVersion uint
   284  	for _, capability := range caps {
   285  		switch capability.Name {
   286  		case "eth":
   287  			if capability.Version > highestEthVersion && capability.Version <= c.ourHighestProtoVersion {
   288  				highestEthVersion = capability.Version
   289  			}
   290  		case "snap":
   291  			if capability.Version > highestSnapVersion && capability.Version <= c.ourHighestSnapProtoVersion {
   292  				highestSnapVersion = capability.Version
   293  			}
   294  		}
   295  	}
   296  	c.negotiatedProtoVersion = highestEthVersion
   297  	c.negotiatedSnapProtoVersion = highestSnapVersion
   298  }
   299  
   300  // statusExchange performs a `Status` message exchange with the given node.
   301  func (c *Conn) statusExchange(chain *Chain, status *eth.StatusPacket) error {
   302  loop:
   303  	for {
   304  		code, data, err := c.Read()
   305  		if err != nil {
   306  			return fmt.Errorf("failed to read from connection: %w", err)
   307  		}
   308  		switch code {
   309  		case eth.StatusMsg + protoOffset(ethProto):
   310  			msg := new(eth.StatusPacket)
   311  			if err := rlp.DecodeBytes(data, &msg); err != nil {
   312  				return fmt.Errorf("error decoding status packet: %w", err)
   313  			}
   314  			if have, want := msg.Head, chain.blocks[chain.Len()-1].Hash(); have != want {
   315  				return fmt.Errorf("wrong head block in status, want:  %#x (block %d) have %#x",
   316  					want, chain.blocks[chain.Len()-1].NumberU64(), have)
   317  			}
   318  			if have, want := msg.TD.Cmp(chain.TD()), 0; have != want {
   319  				return fmt.Errorf("wrong TD in status: have %v want %v", have, want)
   320  			}
   321  			if have, want := msg.ForkID, chain.ForkID(); !reflect.DeepEqual(have, want) {
   322  				return fmt.Errorf("wrong fork ID in status: have %v, want %v", have, want)
   323  			}
   324  			if have, want := msg.ProtocolVersion, c.ourHighestProtoVersion; have != uint32(want) {
   325  				return fmt.Errorf("wrong protocol version: have %v, want %v", have, want)
   326  			}
   327  			break loop
   328  		case discMsg:
   329  			var msg []p2p.DiscReason
   330  			if rlp.DecodeBytes(data, &msg); len(msg) == 0 {
   331  				return errors.New("invalid disconnect message")
   332  			}
   333  			return fmt.Errorf("disconnect received: %v", pretty.Sdump(msg))
   334  		case pingMsg:
   335  			// TODO (renaynay): in the future, this should be an error
   336  			// (PINGs should not be a response upon fresh connection)
   337  			c.Write(baseProto, pongMsg, nil)
   338  		default:
   339  			return fmt.Errorf("bad status message: code %d", code)
   340  		}
   341  	}
   342  	// make sure eth protocol version is set for negotiation
   343  	if c.negotiatedProtoVersion == 0 {
   344  		return errors.New("eth protocol version must be set in Conn")
   345  	}
   346  	if status == nil {
   347  		// default status message
   348  		status = &eth.StatusPacket{
   349  			ProtocolVersion: uint32(c.negotiatedProtoVersion),
   350  			NetworkID:       chain.config.ChainID.Uint64(),
   351  			TD:              chain.TD(),
   352  			Head:            chain.blocks[chain.Len()-1].Hash(),
   353  			Genesis:         chain.blocks[0].Hash(),
   354  			ForkID:          chain.ForkID(),
   355  		}
   356  	}
   357  	if err := c.Write(ethProto, eth.StatusMsg, status); err != nil {
   358  		return fmt.Errorf("write to connection failed: %v", err)
   359  	}
   360  	return nil
   361  }