github.com/blocknative/go-ethereum@v1.9.7/p2p/discv5/udp.go (about)

     1  // Copyright 2016 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package discv5
    18  
    19  import (
    20  	"bytes"
    21  	"crypto/ecdsa"
    22  	"errors"
    23  	"fmt"
    24  	"net"
    25  	"time"
    26  
    27  	"github.com/ethereum/go-ethereum/common"
    28  	"github.com/ethereum/go-ethereum/crypto"
    29  	"github.com/ethereum/go-ethereum/log"
    30  	"github.com/ethereum/go-ethereum/p2p/nat"
    31  	"github.com/ethereum/go-ethereum/p2p/netutil"
    32  	"github.com/ethereum/go-ethereum/rlp"
    33  )
    34  
// Version is the protocol version number sent in the Version field of every
// ping produced by sendPing.
const Version = 4

// Errors
var (
	// errPacketTooSmall is returned by decodePacket when a datagram is shorter
	// than the frame header plus the one-byte packet type.
	errPacketTooSmall = errors.New("too small")
	// errBadPrefix is returned by decodePacket when a datagram does not start
	// with versionPrefix.
	errBadPrefix = errors.New("bad prefix")
	// errTimeout is not used in this file; presumably reported elsewhere in
	// the package when an RPC reply does not arrive in time — TODO confirm.
	errTimeout = errors.New("RPC timeout")
)

// Timeouts
const (
	// respTimeout is not used in this file — presumably the reply wait time
	// used by the network loop; TODO confirm.
	respTimeout = 500 * time.Millisecond
	// expiration is added to the current time to compute the Expiration
	// (Unix seconds) carried by outgoing packets.
	expiration = 20 * time.Second

	driftThreshold = 10 * time.Second // Allowed clock drift before warning user
)
    51  
// RPC request structures.
//
// These values are RLP-encoded by encodePacket and decoded by decodePacket.
// RLP encodes struct fields as a list in declaration order, so the field
// order below is part of the wire format and must not be changed.
type (
	// ping is the packet sent by sendPing. From carries our own advertised
	// endpoint and To the endpoint the ping is addressed to.
	ping struct {
		Version    uint
		From, To   rpcEndpoint
		Expiration uint64

		// v5
		Topics []Topic

		// Ignore additional fields (for forward compatibility).
		Rest []rlp.RawValue `rlp:"tail"`
	}

	// pong is the reply to ping.
	pong struct {
		// This field should mirror the UDP envelope address
		// of the ping packet, which provides a way to discover
		// the external address (after NAT).
		To rpcEndpoint

		ReplyTok   []byte // This contains the hash of the ping packet.
		Expiration uint64 // Absolute timestamp at which the packet becomes invalid.

		// v5
		TopicHash    common.Hash
		TicketSerial uint32
		WaitPeriods  []uint32

		// Ignore additional fields (for forward compatibility).
		Rest []rlp.RawValue `rlp:"tail"`
	}

	// findnode is a query for nodes close to the given target.
	findnode struct {
		Target     NodeID // doesn't need to be an actual public key
		Expiration uint64
		// Ignore additional fields (for forward compatibility).
		Rest []rlp.RawValue `rlp:"tail"`
	}

	// findnodeHash is a query for nodes close to the given target hash.
	findnodeHash struct {
		Target     common.Hash
		Expiration uint64
		// Ignore additional fields (for forward compatibility).
		Rest []rlp.RawValue `rlp:"tail"`
	}

	// neighbors is the reply to findnode. sendNeighbours splits long replies
	// into several packets of at most maxNeighbors nodes each.
	neighbors struct {
		Nodes      []rpcNode
		Expiration uint64
		// Ignore additional fields (for forward compatibility).
		Rest []rlp.RawValue `rlp:"tail"`
	}

	// topicRegister is sent by sendTopicRegister, which fills Idx and Pong
	// from its idx and pong arguments. Presumably Idx selects an entry of
	// Topics and Pong carries a previously received pong — TODO confirm.
	topicRegister struct {
		Topics []Topic
		Idx    uint
		Pong   []byte
	}

	// topicQuery is answered with a topicNodes packet.
	topicQuery struct {
		Topic      Topic
		Expiration uint64
	}

	// topicNodes is the reply to topicQuery. Echo carries the query hash
	// passed to sendTopicNodes; long replies are split into packets of at
	// most maxTopicNodes nodes.
	topicNodes struct {
		Echo  common.Hash
		Nodes []rpcNode
	}

	// rpcNode is the wire form of a node (see nodeToRPC and nodeFromRPC).
	rpcNode struct {
		IP  net.IP // len 4 for IPv4 or 16 for IPv6
		UDP uint16 // for discovery protocol
		TCP uint16 // for RLPx protocol
		ID  NodeID
	}

	// rpcEndpoint is the wire form of a network endpoint (see makeEndpoint).
	rpcEndpoint struct {
		IP  net.IP // len 4 for IPv4 or 16 for IPv6
		UDP uint16 // for discovery protocol
		TCP uint16 // for RLPx protocol
	}
)
   139  
// Packet framing. Every packet is laid out as
//
//	versionPrefix | signature (sigSize bytes) | packet type (1 byte) | RLP payload
//
// headSize covers the prefix and signature; encodePacket reserves that space
// up front and decodePacket splits incoming datagrams at the same offsets.
var (
	versionPrefix     = []byte("temporary discovery v5")
	versionPrefixSize = len(versionPrefix)
	// 520 bits = 65 bytes, the room reserved for the crypto.Sign signature
	// that encodePacket writes after the prefix.
	sigSize  = 520 / 8
	headSize = versionPrefixSize + sigSize // space of packet frame data
)
   146  
   147  // Neighbors replies are sent across multiple packets to
   148  // stay below the 1280 byte limit. We compute the maximum number
   149  // of entries by stuffing a packet until it grows too large.
   150  var maxNeighbors = func() int {
   151  	p := neighbors{Expiration: ^uint64(0)}
   152  	maxSizeNode := rpcNode{IP: make(net.IP, 16), UDP: ^uint16(0), TCP: ^uint16(0)}
   153  	for n := 0; ; n++ {
   154  		p.Nodes = append(p.Nodes, maxSizeNode)
   155  		size, _, err := rlp.EncodeToReader(p)
   156  		if err != nil {
   157  			// If this ever happens, it will be caught by the unit tests.
   158  			panic("cannot encode: " + err.Error())
   159  		}
   160  		if headSize+size+1 >= 1280 {
   161  			return n
   162  		}
   163  	}
   164  }()
   165  
   166  var maxTopicNodes = func() int {
   167  	p := topicNodes{}
   168  	maxSizeNode := rpcNode{IP: make(net.IP, 16), UDP: ^uint16(0), TCP: ^uint16(0)}
   169  	for n := 0; ; n++ {
   170  		p.Nodes = append(p.Nodes, maxSizeNode)
   171  		size, _, err := rlp.EncodeToReader(p)
   172  		if err != nil {
   173  			// If this ever happens, it will be caught by the unit tests.
   174  			panic("cannot encode: " + err.Error())
   175  		}
   176  		if headSize+size+1 >= 1280 {
   177  			return n
   178  		}
   179  	}
   180  }()
   181  
   182  func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint {
   183  	ip := addr.IP.To4()
   184  	if ip == nil {
   185  		ip = addr.IP.To16()
   186  	}
   187  	return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort}
   188  }
   189  
   190  func (e1 rpcEndpoint) equal(e2 rpcEndpoint) bool {
   191  	return e1.UDP == e2.UDP && e1.TCP == e2.TCP && e1.IP.Equal(e2.IP)
   192  }
   193  
   194  func nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) {
   195  	if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil {
   196  		return nil, err
   197  	}
   198  	n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP)
   199  	err := n.validateComplete()
   200  	return n, err
   201  }
   202  
   203  func nodeToRPC(n *Node) rpcNode {
   204  	return rpcNode{ID: n.ID, IP: n.IP, UDP: n.UDP, TCP: n.TCP}
   205  }
   206  
// ingressPacket is an inbound packet as produced by decodePacket; handlePacket
// passes it to the network via reqReadPacket.
//
// Fields:
//   - remoteID: sender ID recovered from the packet signature.
//   - remoteAddr: UDP address the datagram was read from.
//   - ev: packet type, taken from the first byte after the frame header.
//   - hash: Keccak256 of everything after the version prefix.
//   - data: pointer to the decoded RPC struct matching ev.
//   - rawData: private copy of the whole datagram (not aliasing the read buffer).
type ingressPacket struct {
	remoteID   NodeID
	remoteAddr *net.UDPAddr
	ev         nodeEvent
	hash       []byte
	data       interface{} // one of the RPC structs
	rawData    []byte
}
   215  
// conn is the set of UDP socket operations the transport uses. *net.UDPConn
// provides all of these methods; the interface presumably exists so that
// another implementation (e.g. in tests) can be substituted — TODO confirm.
type conn interface {
	ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error)
	WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error)
	Close() error
	LocalAddr() net.Addr
}
   222  
// udp implements the RPC protocol.
//
// Fields:
//   - conn: socket used for all reads and writes.
//   - priv: key that signs outgoing packets (see encodePacket).
//   - ourEndpoint: our endpoint, sent as From in outgoing pings.
//   - nat: not set by listenUDP; its use is outside this file — TODO confirm.
//   - net: network that receives decoded ingress packets via reqReadPacket.
type udp struct {
	conn        conn
	priv        *ecdsa.PrivateKey
	ourEndpoint rpcEndpoint
	nat         nat.Interface
	net         *Network
}
   231  
   232  // ListenUDP returns a new table that listens for UDP packets on laddr.
   233  func ListenUDP(priv *ecdsa.PrivateKey, conn conn, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) {
   234  	realaddr := conn.LocalAddr().(*net.UDPAddr)
   235  	transport, err := listenUDP(priv, conn, realaddr)
   236  	if err != nil {
   237  		return nil, err
   238  	}
   239  	net, err := newNetwork(transport, priv.PublicKey, nodeDBPath, netrestrict)
   240  	if err != nil {
   241  		return nil, err
   242  	}
   243  	log.Info("UDP listener up", "net", net.tab.self)
   244  	transport.net = net
   245  	go transport.readLoop()
   246  	return net, nil
   247  }
   248  
   249  func listenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr) (*udp, error) {
   250  	return &udp{conn: conn, priv: priv, ourEndpoint: makeEndpoint(realaddr, uint16(realaddr.Port))}, nil
   251  }
   252  
   253  func (t *udp) localAddr() *net.UDPAddr {
   254  	return t.conn.LocalAddr().(*net.UDPAddr)
   255  }
   256  
   257  func (t *udp) Close() {
   258  	t.conn.Close()
   259  }
   260  
   261  func (t *udp) send(remote *Node, ptype nodeEvent, data interface{}) (hash []byte) {
   262  	hash, _ = t.sendPacket(remote.ID, remote.addr(), byte(ptype), data)
   263  	return hash
   264  }
   265  
   266  func (t *udp) sendPing(remote *Node, toaddr *net.UDPAddr, topics []Topic) (hash []byte) {
   267  	hash, _ = t.sendPacket(remote.ID, toaddr, byte(pingPacket), ping{
   268  		Version:    Version,
   269  		From:       t.ourEndpoint,
   270  		To:         makeEndpoint(toaddr, uint16(toaddr.Port)), // TODO: maybe use known TCP port from DB
   271  		Expiration: uint64(time.Now().Add(expiration).Unix()),
   272  		Topics:     topics,
   273  	})
   274  	return hash
   275  }
   276  
   277  func (t *udp) sendFindnode(remote *Node, target NodeID) {
   278  	t.sendPacket(remote.ID, remote.addr(), byte(findnodePacket), findnode{
   279  		Target:     target,
   280  		Expiration: uint64(time.Now().Add(expiration).Unix()),
   281  	})
   282  }
   283  
   284  func (t *udp) sendNeighbours(remote *Node, results []*Node) {
   285  	// Send neighbors in chunks with at most maxNeighbors per packet
   286  	// to stay below the 1280 byte limit.
   287  	p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
   288  	for i, result := range results {
   289  		p.Nodes = append(p.Nodes, nodeToRPC(result))
   290  		if len(p.Nodes) == maxNeighbors || i == len(results)-1 {
   291  			t.sendPacket(remote.ID, remote.addr(), byte(neighborsPacket), p)
   292  			p.Nodes = p.Nodes[:0]
   293  		}
   294  	}
   295  }
   296  
   297  func (t *udp) sendFindnodeHash(remote *Node, target common.Hash) {
   298  	t.sendPacket(remote.ID, remote.addr(), byte(findnodeHashPacket), findnodeHash{
   299  		Target:     target,
   300  		Expiration: uint64(time.Now().Add(expiration).Unix()),
   301  	})
   302  }
   303  
   304  func (t *udp) sendTopicRegister(remote *Node, topics []Topic, idx int, pong []byte) {
   305  	t.sendPacket(remote.ID, remote.addr(), byte(topicRegisterPacket), topicRegister{
   306  		Topics: topics,
   307  		Idx:    uint(idx),
   308  		Pong:   pong,
   309  	})
   310  }
   311  
   312  func (t *udp) sendTopicNodes(remote *Node, queryHash common.Hash, nodes []*Node) {
   313  	p := topicNodes{Echo: queryHash}
   314  	var sent bool
   315  	for _, result := range nodes {
   316  		if result.IP.Equal(t.net.tab.self.IP) || netutil.CheckRelayIP(remote.IP, result.IP) == nil {
   317  			p.Nodes = append(p.Nodes, nodeToRPC(result))
   318  		}
   319  		if len(p.Nodes) == maxTopicNodes {
   320  			t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p)
   321  			p.Nodes = p.Nodes[:0]
   322  			sent = true
   323  		}
   324  	}
   325  	if !sent || len(p.Nodes) > 0 {
   326  		t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p)
   327  	}
   328  }
   329  
   330  func (t *udp) sendPacket(toid NodeID, toaddr *net.UDPAddr, ptype byte, req interface{}) (hash []byte, err error) {
   331  	//fmt.Println("sendPacket", nodeEvent(ptype), toaddr.String(), toid.String())
   332  	packet, hash, err := encodePacket(t.priv, ptype, req)
   333  	if err != nil {
   334  		//fmt.Println(err)
   335  		return hash, err
   336  	}
   337  	log.Trace(fmt.Sprintf(">>> %v to %x@%v", nodeEvent(ptype), toid[:8], toaddr))
   338  	if nbytes, err := t.conn.WriteToUDP(packet, toaddr); err != nil {
   339  		log.Trace(fmt.Sprint("UDP send failed:", err))
   340  	} else {
   341  		egressTrafficMeter.Mark(int64(nbytes))
   342  	}
   343  	//fmt.Println(err)
   344  	return hash, err
   345  }
   346  
// zeroed padding space for encodePacket.
// encodePacket writes it first to reserve headSize bytes for the version
// prefix and signature, which are filled in only after the payload is signed.
var headSpace = make([]byte, headSize)
   349  
   350  func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) (p, hash []byte, err error) {
   351  	b := new(bytes.Buffer)
   352  	b.Write(headSpace)
   353  	b.WriteByte(ptype)
   354  	if err := rlp.Encode(b, req); err != nil {
   355  		log.Error(fmt.Sprint("error encoding packet:", err))
   356  		return nil, nil, err
   357  	}
   358  	packet := b.Bytes()
   359  	sig, err := crypto.Sign(crypto.Keccak256(packet[headSize:]), priv)
   360  	if err != nil {
   361  		log.Error(fmt.Sprint("could not sign packet:", err))
   362  		return nil, nil, err
   363  	}
   364  	copy(packet, versionPrefix)
   365  	copy(packet[versionPrefixSize:], sig)
   366  	hash = crypto.Keccak256(packet[versionPrefixSize:])
   367  	return packet, hash, nil
   368  }
   369  
   370  // readLoop runs in its own goroutine. it injects ingress UDP packets
   371  // into the network loop.
   372  func (t *udp) readLoop() {
   373  	defer t.conn.Close()
   374  	// Discovery packets are defined to be no larger than 1280 bytes.
   375  	// Packets larger than this size will be cut at the end and treated
   376  	// as invalid because their hash won't match.
   377  	buf := make([]byte, 1280)
   378  	for {
   379  		nbytes, from, err := t.conn.ReadFromUDP(buf)
   380  		ingressTrafficMeter.Mark(int64(nbytes))
   381  		if netutil.IsTemporaryError(err) {
   382  			// Ignore temporary read errors.
   383  			log.Debug(fmt.Sprintf("Temporary read error: %v", err))
   384  			continue
   385  		} else if err != nil {
   386  			// Shut down the loop for permament errors.
   387  			log.Debug(fmt.Sprintf("Read error: %v", err))
   388  			return
   389  		}
   390  		t.handlePacket(from, buf[:nbytes])
   391  	}
   392  }
   393  
   394  func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error {
   395  	pkt := ingressPacket{remoteAddr: from}
   396  	if err := decodePacket(buf, &pkt); err != nil {
   397  		log.Debug(fmt.Sprintf("Bad packet from %v: %v", from, err))
   398  		//fmt.Println("bad packet", err)
   399  		return err
   400  	}
   401  	t.net.reqReadPacket(pkt)
   402  	return nil
   403  }
   404  
   405  func decodePacket(buffer []byte, pkt *ingressPacket) error {
   406  	if len(buffer) < headSize+1 {
   407  		return errPacketTooSmall
   408  	}
   409  	buf := make([]byte, len(buffer))
   410  	copy(buf, buffer)
   411  	prefix, sig, sigdata := buf[:versionPrefixSize], buf[versionPrefixSize:headSize], buf[headSize:]
   412  	if !bytes.Equal(prefix, versionPrefix) {
   413  		return errBadPrefix
   414  	}
   415  	fromID, err := recoverNodeID(crypto.Keccak256(buf[headSize:]), sig)
   416  	if err != nil {
   417  		return err
   418  	}
   419  	pkt.rawData = buf
   420  	pkt.hash = crypto.Keccak256(buf[versionPrefixSize:])
   421  	pkt.remoteID = fromID
   422  	switch pkt.ev = nodeEvent(sigdata[0]); pkt.ev {
   423  	case pingPacket:
   424  		pkt.data = new(ping)
   425  	case pongPacket:
   426  		pkt.data = new(pong)
   427  	case findnodePacket:
   428  		pkt.data = new(findnode)
   429  	case neighborsPacket:
   430  		pkt.data = new(neighbors)
   431  	case findnodeHashPacket:
   432  		pkt.data = new(findnodeHash)
   433  	case topicRegisterPacket:
   434  		pkt.data = new(topicRegister)
   435  	case topicQueryPacket:
   436  		pkt.data = new(topicQuery)
   437  	case topicNodesPacket:
   438  		pkt.data = new(topicNodes)
   439  	default:
   440  		return fmt.Errorf("unknown packet type: %d", sigdata[0])
   441  	}
   442  	s := rlp.NewStream(bytes.NewReader(sigdata[1:]), 0)
   443  	err = s.Decode(pkt.data)
   444  	return err
   445  }