github.com/theQRL/go-zond@v0.2.1/cmd/devp2p/internal/zondtest/conn.go (about)

     1  package zondtest
     2  
// TODO(now.youtrack.cloud/issue/TGZ-6): the remainder of this file is commented out; see this issue.
     4  /*
     5  import (
     6  	"crypto/ecdsa"
     7  	"errors"
     8  	"fmt"
     9  	"net"
    10  	"reflect"
    11  	"time"
    12  
    13  	"github.com/davecgh/go-spew/spew"
    14  	"github.com/theQRL/go-zond/crypto"
    15  	"github.com/theQRL/go-zond/p2p"
    16  	"github.com/theQRL/go-zond/p2p/rlpx"
    17  	"github.com/theQRL/go-zond/rlp"
    18  	"github.com/theQRL/go-zond/zond/protocols/snap"
    19  	"github.com/theQRL/go-zond/zond/protocols/zond"
    20  )
    21  
    22  var (
    23  	pretty = spew.ConfigState{
    24  		Indent:                  "  ",
    25  		DisableCapacities:       true,
    26  		DisablePointerAddresses: true,
    27  		SortKeys:                true,
    28  	}
    29  	timeout = 2 * time.Second
    30  )
    31  
    32  // dial attempts to dial the given node and perform a handshake, returning the
    33  // created Conn if successful.
    34  func (s *Suite) dial() (*Conn, error) {
    35  	key, _ := crypto.GenerateKey()
    36  	return s.dialAs(key)
    37  }
    38  
    39  // dialAs attempts to dial a given node and perform a handshake using the given
    40  // private key.
    41  func (s *Suite) dialAs(key *ecdsa.PrivateKey) (*Conn, error) {
    42  	tcpEndpoint, _ := s.Dest.TCPEndpoint()
    43  	fd, err := net.Dial("tcp", tcpEndpoint.String())
    44  	if err != nil {
    45  		return nil, err
    46  	}
    47  	conn := Conn{Conn: rlpx.NewConn(fd, s.Dest.Pubkey())}
    48  	conn.ourKey = key
    49  	_, err = conn.Handshake(conn.ourKey)
    50  	if err != nil {
    51  		conn.Close()
    52  		return nil, err
    53  	}
    54  	conn.caps = []p2p.Cap{
    55  		{Name: "eth", Version: 67},
    56  		{Name: "eth", Version: 68},
    57  	}
    58  	conn.ourHighestProtoVersion = 68
    59  	return &conn, nil
    60  }
    61  
    62  // dialSnap creates a connection with snap/1 capability.
    63  func (s *Suite) dialSnap() (*Conn, error) {
    64  	conn, err := s.dial()
    65  	if err != nil {
    66  		return nil, fmt.Errorf("dial failed: %v", err)
    67  	}
    68  	conn.caps = append(conn.caps, p2p.Cap{Name: "snap", Version: 1})
    69  	conn.ourHighestSnapProtoVersion = 1
    70  	return conn, nil
    71  }
    72  
// Conn represents an individual connection with a peer. It embeds the RLPx
// connection and tracks the capabilities this side advertises together with
// the protocol versions negotiated during the devp2p handshake.
type Conn struct {
	*rlpx.Conn
	// ourKey is the private key used for the RLPx handshake; its public key
	// is also sent as the ID in our Hello message.
	ourKey                     *ecdsa.PrivateKey
	// negotiatedProtoVersion is the highest eth version shared with the
	// peer, set by negotiateEthProtocol; zero means none was agreed.
	negotiatedProtoVersion     uint
	// negotiatedSnapProtoVersion is the highest snap version shared with the
	// peer, set by negotiateEthProtocol; zero means none was agreed.
	negotiatedSnapProtoVersion uint
	// ourHighestProtoVersion is the highest eth version we are willing to
	// speak (dialAs sets it to 68).
	ourHighestProtoVersion     uint
	// ourHighestSnapProtoVersion is the highest snap version we are willing
	// to speak; it stays zero unless dialSnap was used.
	ourHighestSnapProtoVersion uint
	// caps are the capabilities advertised in our Hello message.
	caps                       []p2p.Cap
}
    83  
    84  // Read reads a packet from the connection.
    85  func (c *Conn) Read() (uint64, []byte, error) {
    86  	c.SetReadDeadline(time.Now().Add(timeout))
    87  	code, data, _, err := c.Conn.Read()
    88  	if err != nil {
    89  		return 0, nil, err
    90  	}
    91  	return code, data, nil
    92  }
    93  
    94  // ReadMsg attempts to read a devp2p message with a specific code.
    95  func (c *Conn) ReadMsg(proto Proto, code uint64, msg any) error {
    96  	c.SetReadDeadline(time.Now().Add(timeout))
    97  	for {
    98  		got, data, err := c.Read()
    99  		if err != nil {
   100  			return err
   101  		}
   102  		if protoOffset(proto)+code == got {
   103  			return rlp.DecodeBytes(data, msg)
   104  		}
   105  	}
   106  }
   107  
   108  // Write writes a eth packet to the connection.
   109  func (c *Conn) Write(proto Proto, code uint64, msg any) error {
   110  	c.SetWriteDeadline(time.Now().Add(timeout))
   111  	payload, err := rlp.EncodeToBytes(msg)
   112  	if err != nil {
   113  		return err
   114  	}
   115  	_, err = c.Conn.Write(protoOffset(proto)+code, payload)
   116  	return err
   117  }
   118  
   119  // ReadEth reads an Eth sub-protocol wire message.
   120  func (c *Conn) ReadEth() (any, error) {
   121  	c.SetReadDeadline(time.Now().Add(timeout))
   122  	for {
   123  		code, data, _, err := c.Conn.Read()
   124  		if err != nil {
   125  			return nil, err
   126  		}
   127  		if code == pingMsg {
   128  			c.Write(baseProto, pongMsg, []byte{})
   129  			continue
   130  		}
   131  		if getProto(code) != ethProto {
   132  			// Read until eth message.
   133  			continue
   134  		}
   135  		code -= baseProtoLen
   136  
   137  		var msg any
   138  		switch int(code) {
   139  		case zond.StatusMsg:
   140  			msg = new(zond.StatusPacket)
   141  		case zond.GetBlockHeadersMsg:
   142  			msg = new(zond.GetBlockHeadersPacket)
   143  		case zond.BlockHeadersMsg:
   144  			msg = new(zond.BlockHeadersPacket)
   145  		case zond.GetBlockBodiesMsg:
   146  			msg = new(zond.GetBlockBodiesPacket)
   147  		case zond.BlockBodiesMsg:
   148  			msg = new(zond.BlockBodiesPacket)
   149  		case zond.TransactionsMsg:
   150  			msg = new(zond.TransactionsPacket)
   151  		case zond.NewPooledTransactionHashesMsg:
   152  			msg = new(zond.NewPooledTransactionHashesPacket)
   153  		case zond.GetPooledTransactionsMsg:
   154  			msg = new(zond.GetPooledTransactionsPacket)
   155  		case zond.PooledTransactionsMsg:
   156  			msg = new(zond.PooledTransactionsPacket)
   157  		default:
   158  			panic(fmt.Sprintf("unhandled eth msg code %d", code))
   159  		}
   160  		if err := rlp.DecodeBytes(data, msg); err != nil {
   161  			return nil, fmt.Errorf("unable to decode eth msg: %v", err)
   162  		}
   163  		return msg, nil
   164  	}
   165  }
   166  
   167  // ReadSnap reads a snap/1 response with the given id from the connection.
   168  func (c *Conn) ReadSnap() (any, error) {
   169  	c.SetReadDeadline(time.Now().Add(timeout))
   170  	for {
   171  		code, data, _, err := c.Conn.Read()
   172  		if err != nil {
   173  			return nil, err
   174  		}
   175  		if getProto(code) != snapProto {
   176  			// Read until snap message.
   177  			continue
   178  		}
   179  		code -= baseProtoLen + ethProtoLen
   180  
   181  		var msg any
   182  		switch int(code) {
   183  		case snap.GetAccountRangeMsg:
   184  			msg = new(snap.GetAccountRangePacket)
   185  		case snap.AccountRangeMsg:
   186  			msg = new(snap.AccountRangePacket)
   187  		case snap.GetStorageRangesMsg:
   188  			msg = new(snap.GetStorageRangesPacket)
   189  		case snap.StorageRangesMsg:
   190  			msg = new(snap.StorageRangesPacket)
   191  		case snap.GetByteCodesMsg:
   192  			msg = new(snap.GetByteCodesPacket)
   193  		case snap.ByteCodesMsg:
   194  			msg = new(snap.ByteCodesPacket)
   195  		case snap.GetTrieNodesMsg:
   196  			msg = new(snap.GetTrieNodesPacket)
   197  		case snap.TrieNodesMsg:
   198  			msg = new(snap.TrieNodesPacket)
   199  		default:
   200  			panic(fmt.Errorf("unhandled snap code: %d", code))
   201  		}
   202  		if err := rlp.DecodeBytes(data, msg); err != nil {
   203  			return nil, fmt.Errorf("could not rlp decode message: %v", err)
   204  		}
   205  		return msg, nil
   206  	}
   207  }
   208  
   209  // peer performs both the protocol handshake and the status message
   210  // exchange with the node in order to peer with it.
   211  func (c *Conn) peer(chain *Chain, status *zond.StatusPacket) error {
   212  	if err := c.handshake(); err != nil {
   213  		return fmt.Errorf("handshake failed: %v", err)
   214  	}
   215  	if err := c.statusExchange(chain, status); err != nil {
   216  		return fmt.Errorf("status exchange failed: %v", err)
   217  	}
   218  	return nil
   219  }
   220  
   221  // handshake performs a protocol handshake with the node.
   222  func (c *Conn) handshake() error {
   223  	// Write hello to client.
   224  	pub0 := crypto.FromECDSAPub(&c.ourKey.PublicKey)[1:]
   225  	ourHandshake := &protoHandshake{
   226  		Version: 5,
   227  		Caps:    c.caps,
   228  		ID:      pub0,
   229  	}
   230  	if err := c.Write(baseProto, handshakeMsg, ourHandshake); err != nil {
   231  		return fmt.Errorf("write to connection failed: %v", err)
   232  	}
   233  	// Read hello from client.
   234  	code, data, err := c.Read()
   235  	if err != nil {
   236  		return fmt.Errorf("erroring reading handshake: %v", err)
   237  	}
   238  	switch code {
   239  	case handshakeMsg:
   240  		msg := new(protoHandshake)
   241  		if err := rlp.DecodeBytes(data, &msg); err != nil {
   242  			return fmt.Errorf("error decoding handshake msg: %v", err)
   243  		}
   244  		// Set snappy if version is at least 5.
   245  		if msg.Version >= 5 {
   246  			c.SetSnappy(true)
   247  		}
   248  		c.negotiateEthProtocol(msg.Caps)
   249  		if c.negotiatedProtoVersion == 0 {
   250  			return fmt.Errorf("could not negotiate eth protocol (remote caps: %v, local eth version: %v)", msg.Caps, c.ourHighestProtoVersion)
   251  		}
   252  		// If we require snap, verify that it was negotiated.
   253  		if c.ourHighestSnapProtoVersion != c.negotiatedSnapProtoVersion {
   254  			return fmt.Errorf("could not negotiate snap protocol (remote caps: %v, local snap version: %v)", msg.Caps, c.ourHighestSnapProtoVersion)
   255  		}
   256  		return nil
   257  	default:
   258  		return fmt.Errorf("bad handshake: got msg code %d", code)
   259  	}
   260  }
   261  
   262  // negotiateEthProtocol sets the Conn's eth protocol version to highest
   263  // advertised capability from peer.
   264  func (c *Conn) negotiateEthProtocol(caps []p2p.Cap) {
   265  	var highestEthVersion uint
   266  	var highestSnapVersion uint
   267  	for _, capability := range caps {
   268  		switch capability.Name {
   269  		case "eth":
   270  			if capability.Version > highestEthVersion && capability.Version <= c.ourHighestProtoVersion {
   271  				highestEthVersion = capability.Version
   272  			}
   273  		case "snap":
   274  			if capability.Version > highestSnapVersion && capability.Version <= c.ourHighestSnapProtoVersion {
   275  				highestSnapVersion = capability.Version
   276  			}
   277  		}
   278  	}
   279  	c.negotiatedProtoVersion = highestEthVersion
   280  	c.negotiatedSnapProtoVersion = highestSnapVersion
   281  }
   282  
   283  // statusExchange performs a `Status` message exchange with the given node.
   284  func (c *Conn) statusExchange(chain *Chain, status *zond.StatusPacket) error {
   285  loop:
   286  	for {
   287  		code, data, err := c.Read()
   288  		if err != nil {
   289  			return fmt.Errorf("failed to read from connection: %w", err)
   290  		}
   291  		switch code {
   292  		case zond.StatusMsg + protoOffset(zondProto):
   293  			msg := new(zond.StatusPacket)
   294  			if err := rlp.DecodeBytes(data, &msg); err != nil {
   295  				return fmt.Errorf("error decoding status packet: %w", err)
   296  			}
   297  			if have, want := msg.Head, chain.blocks[chain.Len()-1].Hash(); have != want {
   298  				return fmt.Errorf("wrong head block in status, want:  %#x (block %d) have %#x",
   299  					want, chain.blocks[chain.Len()-1].NumberU64(), have)
   300  			}
   301  			if have, want := msg.TD.Cmp(chain.TD()), 0; have != want {
   302  				return fmt.Errorf("wrong TD in status: have %v want %v", have, want)
   303  			}
   304  			if have, want := msg.ForkID, chain.ForkID(); !reflect.DeepEqual(have, want) {
   305  				return fmt.Errorf("wrong fork ID in status: have %v, want %v", have, want)
   306  			}
   307  			if have, want := msg.ProtocolVersion, c.ourHighestProtoVersion; have != uint32(want) {
   308  				return fmt.Errorf("wrong protocol version: have %v, want %v", have, want)
   309  			}
   310  			break loop
   311  		case discMsg:
   312  			var msg []p2p.DiscReason
   313  			if rlp.DecodeBytes(data, &msg); len(msg) == 0 {
   314  				return errors.New("invalid disconnect message")
   315  			}
   316  			return fmt.Errorf("disconnect received: %v", pretty.Sdump(msg))
   317  		case pingMsg:
   318  			// TODO (renaynay): in the future, this should be an error
   319  			// (PINGs should not be a response upon fresh connection)
   320  			c.Write(baseProto, pongMsg, nil)
   321  		default:
   322  			return fmt.Errorf("bad status message: code %d", code)
   323  		}
   324  	}
   325  	// make sure zond protocol version is set for negotiation
   326  	if c.negotiatedProtoVersion == 0 {
   327  		return errors.New("zond protocol version must be set in Conn")
   328  	}
   329  	if status == nil {
   330  		// default status message
   331  		status = &zond.StatusPacket{
   332  			ProtocolVersion: uint32(c.negotiatedProtoVersion),
   333  			NetworkID:       chain.config.ChainID.Uint64(),
   334  			Head:            chain.blocks[chain.Len()-1].Hash(),
   335  			Genesis:         chain.blocks[0].Hash(),
   336  			ForkID:          chain.ForkID(),
   337  		}
   338  	}
   339  	if err := c.Write(zondProto, zond.StatusMsg, status); err != nil {
   340  		return fmt.Errorf("write to connection failed: %v", err)
   341  	}
   342  	return nil
   343  }
   344  */