github.com/cranelv/ethereum_mpc@v0.0.0-20191031014521-23aeb1415092/p2p/discv5/udp.go (about)

     1  // Copyright 2016 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package discv5
    18  
    19  import (
    20  	"bytes"
    21  	"crypto/ecdsa"
    22  	"errors"
    23  	"fmt"
    24  	"net"
    25  	"time"
    26  
    27  	"github.com/ethereum/go-ethereum/common"
    28  	"github.com/ethereum/go-ethereum/crypto"
    29  	"github.com/ethereum/go-ethereum/log"
    30  	"github.com/ethereum/go-ethereum/p2p/nat"
    31  	"github.com/ethereum/go-ethereum/p2p/netutil"
    32  	"github.com/ethereum/go-ethereum/rlp"
    33  )
    34  
    35  const Version = 4
    36  
    37  // Errors
    38  var (
    39  	errPacketTooSmall = errors.New("too small")
    40  	errBadPrefix      = errors.New("bad prefix")
    41  	errTimeout        = errors.New("RPC timeout")
    42  )
    43  
    44  // Timeouts
    45  const (
    46  	respTimeout = 500 * time.Millisecond
    47  	expiration  = 20 * time.Second
    48  
    49  	driftThreshold = 10 * time.Second // Allowed clock drift before warning user
    50  )
    51  
// RPC request structures.
//
// These structs are RLP-encoded as packet payloads (see encodePacket and
// decodePacket). RLP encodes struct fields in declaration order, so the
// field order and the `rlp:"tail"` tags define the wire format: do not
// reorder, insert or remove fields without a protocol change.
type (
	// ping is sent by sendPing. From is our own endpoint, To is the endpoint
	// the ping is addressed to, and Expiration is a Unix timestamp in seconds.
	ping struct {
		Version    uint
		From, To   rpcEndpoint
		Expiration uint64

		// v5
		Topics []Topic

		// Ignore additional fields (for forward compatibility).
		Rest []rlp.RawValue `rlp:"tail"`
	}

	// pong is the reply to ping.
	pong struct {
		// This field should mirror the UDP envelope address
		// of the ping packet, which provides a way to discover
		// the external address (after NAT).
		To rpcEndpoint

		ReplyTok   []byte // This contains the hash of the ping packet.
		Expiration uint64 // Absolute timestamp at which the packet becomes invalid.

		// v5
		TopicHash    common.Hash
		TicketSerial uint32
		WaitPeriods  []uint32

		// Ignore additional fields (for forward compatibility).
		Rest []rlp.RawValue `rlp:"tail"`
	}

	// findnode is a query for nodes close to the given target node ID.
	findnode struct {
		Target     NodeID // doesn't need to be an actual public key
		Expiration uint64
		// Ignore additional fields (for forward compatibility).
		Rest []rlp.RawValue `rlp:"tail"`
	}

	// findnodeHash is a query for nodes close to the given target hash.
	findnodeHash struct {
		Target     common.Hash
		Expiration uint64
		// Ignore additional fields (for forward compatibility).
		Rest []rlp.RawValue `rlp:"tail"`
	}

	// neighbors is the reply to findnode. Large replies are split across
	// several packets of at most maxNeighbors entries (see sendNeighbours).
	neighbors struct {
		Nodes      []rpcNode
		Expiration uint64
		// Ignore additional fields (for forward compatibility).
		Rest []rlp.RawValue `rlp:"tail"`
	}

	// topicRegister is sent by sendTopicRegister.
	// NOTE(review): Idx presumably selects an entry of Topics and Pong
	// presumably carries data from an earlier pong — confirm against the
	// handler in net.go.
	topicRegister struct {
		Topics []Topic
		Idx    uint
		Pong   []byte
	}

	// topicQuery is a query for a single Topic; it is answered with topicNodes.
	topicQuery struct {
		Topic      Topic
		Expiration uint64
	}

	// reply to topicQuery. Echo carries the hash of the query being answered
	// (sendTopicNodes fills it from its queryHash argument).
	topicNodes struct {
		Echo  common.Hash
		Nodes []rpcNode
	}

	// rpcNode is the wire representation of a node (see nodeToRPC and
	// nodeFromRPC).
	rpcNode struct {
		IP  net.IP // len 4 for IPv4 or 16 for IPv6
		UDP uint16 // for discovery protocol
		TCP uint16 // for RLPx protocol
		ID  NodeID
	}

	// rpcEndpoint is the wire representation of a network endpoint
	// (see makeEndpoint).
	rpcEndpoint struct {
		IP  net.IP // len 4 for IPv4 or 16 for IPv6
		UDP uint16 // for discovery protocol
		TCP uint16 // for RLPx protocol
	}
)
   139  
   140  var (
   141  	versionPrefix     = []byte("temporary discovery v5")
   142  	versionPrefixSize = len(versionPrefix)
   143  	sigSize           = 520 / 8
   144  	headSize          = versionPrefixSize + sigSize // space of packet frame data
   145  )
   146  
   147  // Neighbors replies are sent across multiple packets to
   148  // stay below the 1280 byte limit. We compute the maximum number
   149  // of entries by stuffing a packet until it grows too large.
   150  var maxNeighbors = func() int {
   151  	p := neighbors{Expiration: ^uint64(0)}
   152  	maxSizeNode := rpcNode{IP: make(net.IP, 16), UDP: ^uint16(0), TCP: ^uint16(0)}
   153  	for n := 0; ; n++ {
   154  		p.Nodes = append(p.Nodes, maxSizeNode)
   155  		size, _, err := rlp.EncodeToReader(p)
   156  		if err != nil {
   157  			// If this ever happens, it will be caught by the unit tests.
   158  			panic("cannot encode: " + err.Error())
   159  		}
   160  		if headSize+size+1 >= 1280 {
   161  			return n
   162  		}
   163  	}
   164  }()
   165  
   166  var maxTopicNodes = func() int {
   167  	p := topicNodes{}
   168  	maxSizeNode := rpcNode{IP: make(net.IP, 16), UDP: ^uint16(0), TCP: ^uint16(0)}
   169  	for n := 0; ; n++ {
   170  		p.Nodes = append(p.Nodes, maxSizeNode)
   171  		size, _, err := rlp.EncodeToReader(p)
   172  		if err != nil {
   173  			// If this ever happens, it will be caught by the unit tests.
   174  			panic("cannot encode: " + err.Error())
   175  		}
   176  		if headSize+size+1 >= 1280 {
   177  			return n
   178  		}
   179  	}
   180  }()
   181  
   182  func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint {
   183  	ip := addr.IP.To4()
   184  	if ip == nil {
   185  		ip = addr.IP.To16()
   186  	}
   187  	return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort}
   188  }
   189  
   190  func (e1 rpcEndpoint) equal(e2 rpcEndpoint) bool {
   191  	return e1.UDP == e2.UDP && e1.TCP == e2.TCP && e1.IP.Equal(e2.IP)
   192  }
   193  
   194  func nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) {
   195  	if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil {
   196  		return nil, err
   197  	}
   198  	n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP)
   199  	err := n.validateComplete()
   200  	return n, err
   201  }
   202  
   203  func nodeToRPC(n *Node) rpcNode {
   204  	return rpcNode{ID: n.ID, IP: n.IP, UDP: n.UDP, TCP: n.TCP}
   205  }
   206  
// ingressPacket is a received packet as filled in by decodePacket. It is
// handed to the network via reqReadPacket (see handlePacket).
type ingressPacket struct {
	remoteID   NodeID       // sender ID recovered from the packet signature
	remoteAddr *net.UDPAddr // UDP address the packet was received from
	ev         nodeEvent    // packet type, taken from the byte after the header
	hash       []byte       // Keccak256 of the packet without versionPrefix
	data       interface{}  // one of the RPC structs (a pointer, e.g. *ping)
	rawData    []byte       // private copy of the whole received packet
}
   215  
   216  type conn interface {
   217  	ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error)
   218  	WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error)
   219  	Close() error
   220  	LocalAddr() net.Addr
   221  }
   222  
// udp implements the RPC protocol.
type udp struct {
	conn        conn              // read by readLoop, written by sendPacket
	priv        *ecdsa.PrivateKey // signs outgoing packets (see encodePacket)
	ourEndpoint rpcEndpoint       // sent as From in outgoing pings
	nat         nat.Interface     // NOTE(review): not set or read anywhere in this file
	net         *Network          // receives decoded packets (see handlePacket)
}
   231  
   232  // ListenUDP returns a new table that listens for UDP packets on laddr.
   233  func ListenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) {
   234  	transport, err := listenUDP(priv, conn, realaddr)
   235  	if err != nil {
   236  		return nil, err
   237  	}
   238  	net, err := newNetwork(transport, priv.PublicKey, nodeDBPath, netrestrict)
   239  	if err != nil {
   240  		return nil, err
   241  	}
   242  	log.Info("UDP listener up", "net", net.tab.self)
   243  	transport.net = net
   244  	go transport.readLoop()
   245  	return net, nil
   246  }
   247  
   248  func listenUDP(priv *ecdsa.PrivateKey, conn conn, realaddr *net.UDPAddr) (*udp, error) {
   249  	return &udp{conn: conn, priv: priv, ourEndpoint: makeEndpoint(realaddr, uint16(realaddr.Port))}, nil
   250  }
   251  
   252  func (t *udp) localAddr() *net.UDPAddr {
   253  	return t.conn.LocalAddr().(*net.UDPAddr)
   254  }
   255  
   256  func (t *udp) Close() {
   257  	t.conn.Close()
   258  }
   259  
   260  func (t *udp) send(remote *Node, ptype nodeEvent, data interface{}) (hash []byte) {
   261  	hash, _ = t.sendPacket(remote.ID, remote.addr(), byte(ptype), data)
   262  	return hash
   263  }
   264  
   265  func (t *udp) sendPing(remote *Node, toaddr *net.UDPAddr, topics []Topic) (hash []byte) {
   266  	hash, _ = t.sendPacket(remote.ID, toaddr, byte(pingPacket), ping{
   267  		Version:    Version,
   268  		From:       t.ourEndpoint,
   269  		To:         makeEndpoint(toaddr, uint16(toaddr.Port)), // TODO: maybe use known TCP port from DB
   270  		Expiration: uint64(time.Now().Add(expiration).Unix()),
   271  		Topics:     topics,
   272  	})
   273  	return hash
   274  }
   275  
   276  func (t *udp) sendFindnode(remote *Node, target NodeID) {
   277  	t.sendPacket(remote.ID, remote.addr(), byte(findnodePacket), findnode{
   278  		Target:     target,
   279  		Expiration: uint64(time.Now().Add(expiration).Unix()),
   280  	})
   281  }
   282  
   283  func (t *udp) sendNeighbours(remote *Node, results []*Node) {
   284  	// Send neighbors in chunks with at most maxNeighbors per packet
   285  	// to stay below the 1280 byte limit.
   286  	p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
   287  	for i, result := range results {
   288  		p.Nodes = append(p.Nodes, nodeToRPC(result))
   289  		if len(p.Nodes) == maxNeighbors || i == len(results)-1 {
   290  			t.sendPacket(remote.ID, remote.addr(), byte(neighborsPacket), p)
   291  			p.Nodes = p.Nodes[:0]
   292  		}
   293  	}
   294  }
   295  
   296  func (t *udp) sendFindnodeHash(remote *Node, target common.Hash) {
   297  	t.sendPacket(remote.ID, remote.addr(), byte(findnodeHashPacket), findnodeHash{
   298  		Target:     target,
   299  		Expiration: uint64(time.Now().Add(expiration).Unix()),
   300  	})
   301  }
   302  
   303  func (t *udp) sendTopicRegister(remote *Node, topics []Topic, idx int, pong []byte) {
   304  	t.sendPacket(remote.ID, remote.addr(), byte(topicRegisterPacket), topicRegister{
   305  		Topics: topics,
   306  		Idx:    uint(idx),
   307  		Pong:   pong,
   308  	})
   309  }
   310  
   311  func (t *udp) sendTopicNodes(remote *Node, queryHash common.Hash, nodes []*Node) {
   312  	p := topicNodes{Echo: queryHash}
   313  	var sent bool
   314  	for _, result := range nodes {
   315  		if result.IP.Equal(t.net.tab.self.IP) || netutil.CheckRelayIP(remote.IP, result.IP) == nil {
   316  			p.Nodes = append(p.Nodes, nodeToRPC(result))
   317  		}
   318  		if len(p.Nodes) == maxTopicNodes {
   319  			t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p)
   320  			p.Nodes = p.Nodes[:0]
   321  			sent = true
   322  		}
   323  	}
   324  	if !sent || len(p.Nodes) > 0 {
   325  		t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p)
   326  	}
   327  }
   328  
   329  func (t *udp) sendPacket(toid NodeID, toaddr *net.UDPAddr, ptype byte, req interface{}) (hash []byte, err error) {
   330  	//fmt.Println("sendPacket", nodeEvent(ptype), toaddr.String(), toid.String())
   331  	packet, hash, err := encodePacket(t.priv, ptype, req)
   332  	if err != nil {
   333  		//fmt.Println(err)
   334  		return hash, err
   335  	}
   336  	log.Trace(fmt.Sprintf(">>> %v to %x@%v", nodeEvent(ptype), toid[:8], toaddr))
   337  	if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
   338  		log.Trace(fmt.Sprint("UDP send failed:", err))
   339  	}
   340  	//fmt.Println(err)
   341  	return hash, err
   342  }
   343  
   344  // zeroed padding space for encodePacket.
   345  var headSpace = make([]byte, headSize)
   346  
   347  func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) (p, hash []byte, err error) {
   348  	b := new(bytes.Buffer)
   349  	b.Write(headSpace)
   350  	b.WriteByte(ptype)
   351  	if err := rlp.Encode(b, req); err != nil {
   352  		log.Error(fmt.Sprint("error encoding packet:", err))
   353  		return nil, nil, err
   354  	}
   355  	packet := b.Bytes()
   356  	sig, err := crypto.Sign(crypto.Keccak256(packet[headSize:]), priv)
   357  	if err != nil {
   358  		log.Error(fmt.Sprint("could not sign packet:", err))
   359  		return nil, nil, err
   360  	}
   361  	copy(packet, versionPrefix)
   362  	copy(packet[versionPrefixSize:], sig)
   363  	hash = crypto.Keccak256(packet[versionPrefixSize:])
   364  	return packet, hash, nil
   365  }
   366  
   367  // readLoop runs in its own goroutine. it injects ingress UDP packets
   368  // into the network loop.
   369  func (t *udp) readLoop() {
   370  	defer t.conn.Close()
   371  	// Discovery packets are defined to be no larger than 1280 bytes.
   372  	// Packets larger than this size will be cut at the end and treated
   373  	// as invalid because their hash won't match.
   374  	buf := make([]byte, 1280)
   375  	for {
   376  		nbytes, from, err := t.conn.ReadFromUDP(buf)
   377  		if netutil.IsTemporaryError(err) {
   378  			// Ignore temporary read errors.
   379  			log.Debug(fmt.Sprintf("Temporary read error: %v", err))
   380  			continue
   381  		} else if err != nil {
   382  			// Shut down the loop for permament errors.
   383  			log.Debug(fmt.Sprintf("Read error: %v", err))
   384  			return
   385  		}
   386  		t.handlePacket(from, buf[:nbytes])
   387  	}
   388  }
   389  
   390  func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error {
   391  	pkt := ingressPacket{remoteAddr: from}
   392  	if err := decodePacket(buf, &pkt); err != nil {
   393  		log.Debug(fmt.Sprintf("Bad packet from %v: %v", from, err))
   394  		//fmt.Println("bad packet", err)
   395  		return err
   396  	}
   397  	t.net.reqReadPacket(pkt)
   398  	return nil
   399  }
   400  
   401  func decodePacket(buffer []byte, pkt *ingressPacket) error {
   402  	if len(buffer) < headSize+1 {
   403  		return errPacketTooSmall
   404  	}
   405  	buf := make([]byte, len(buffer))
   406  	copy(buf, buffer)
   407  	prefix, sig, sigdata := buf[:versionPrefixSize], buf[versionPrefixSize:headSize], buf[headSize:]
   408  	if !bytes.Equal(prefix, versionPrefix) {
   409  		return errBadPrefix
   410  	}
   411  	fromID, err := recoverNodeID(crypto.Keccak256(buf[headSize:]), sig)
   412  	if err != nil {
   413  		return err
   414  	}
   415  	pkt.rawData = buf
   416  	pkt.hash = crypto.Keccak256(buf[versionPrefixSize:])
   417  	pkt.remoteID = fromID
   418  	switch pkt.ev = nodeEvent(sigdata[0]); pkt.ev {
   419  	case pingPacket:
   420  		pkt.data = new(ping)
   421  	case pongPacket:
   422  		pkt.data = new(pong)
   423  	case findnodePacket:
   424  		pkt.data = new(findnode)
   425  	case neighborsPacket:
   426  		pkt.data = new(neighbors)
   427  	case findnodeHashPacket:
   428  		pkt.data = new(findnodeHash)
   429  	case topicRegisterPacket:
   430  		pkt.data = new(topicRegister)
   431  	case topicQueryPacket:
   432  		pkt.data = new(topicQuery)
   433  	case topicNodesPacket:
   434  		pkt.data = new(topicNodes)
   435  	default:
   436  		return fmt.Errorf("unknown packet type: %d", sigdata[0])
   437  	}
   438  	s := rlp.NewStream(bytes.NewReader(sigdata[1:]), 0)
   439  	err = s.Decode(pkt.data)
   440  	return err
   441  }