github.com/bytom/bytom@v1.1.2-0.20221014091027-bbcba3df6075/p2p/discover/dht/udp.go (about)

     1  package dht
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/ecdsa"
     6  	"encoding/hex"
     7  	"errors"
     8  	"fmt"
     9  	"net"
    10  	"path"
    11  	"strconv"
    12  	"time"
    13  
    14  	log "github.com/sirupsen/logrus"
    15  	"github.com/tendermint/go-wire"
    16  
    17  	"github.com/bytom/bytom/common"
    18  	cfg "github.com/bytom/bytom/config"
    19  	"github.com/bytom/bytom/crypto"
    20  	"github.com/bytom/bytom/crypto/ed25519/chainkd"
    21  	"github.com/bytom/bytom/p2p/netutil"
    22  	"github.com/bytom/bytom/version"
    23  )
    24  
    25  const (
    26  	Version   = 4
    27  	logModule = "discover"
    28  )
    29  
    30  // Errors
    31  var (
    32  	errPacketTooSmall   = errors.New("too small")
    33  	errBadPrefix        = errors.New("bad prefix")
    34  	errExpired          = errors.New("expired")
    35  	errUnsolicitedReply = errors.New("unsolicited reply")
    36  	errUnknownNode      = errors.New("unknown node")
    37  	errTimeout          = errors.New("RPC timeout")
    38  	errClockWarp        = errors.New("reply deadline too far in the future")
    39  	errClosed           = errors.New("socket closed")
    40  )
    41  
    42  // Timeouts
    43  const (
    44  	respTimeout = 1 * time.Second
    45  	queryDelay  = 1000 * time.Millisecond
    46  	expiration  = 20 * time.Second
    47  
    48  	ntpFailureThreshold = 32               // Continuous timeouts after which to check NTP
    49  	ntpWarningCooldown  = 10 * time.Minute // Minimum amount of time to pass before repeating NTP warning
    50  	driftThreshold      = 10 * time.Second // Allowed clock drift before warning user
    51  )
    52  
    53  // ReadPacket is sent to the unhandled channel when it could not be processed
    54  type ReadPacket struct {
    55  	Data []byte
    56  	Addr *net.UDPAddr
    57  }
    58  
    59  // Config holds Table-related settings.
    60  type Config struct {
    61  	// These settings are required and configure the UDP listener:
    62  	PrivateKey *ecdsa.PrivateKey
    63  
    64  	// These settings are optional:
    65  	AnnounceAddr *net.UDPAddr // local address announced in the DHT
    66  	NodeDBPath   string       // if set, the node database is stored at this filesystem location
    67  	//NetRestrict  *netutil.Netlist  // network whitelist
    68  	Bootnodes []*Node           // list of bootstrap nodes
    69  	Unhandled chan<- ReadPacket // unhandled packets are sent on this channel
    70  }
    71  
    72  // RPC request structures
    73  type (
    74  	ping struct {
    75  		Version    uint
    76  		From, To   rpcEndpoint
    77  		Expiration uint64
    78  
    79  		// v5
    80  		Topics []Topic
    81  
    82  		// Ignore additional fields (for forward compatibility).
    83  		Rest []byte
    84  	}
    85  
    86  	// pong is the reply to ping.
    87  	pong struct {
    88  		// This field should mirror the UDP envelope address
    89  		// of the ping packet, which provides a way to discover the
    90  		// the external address (after NAT).
    91  		To rpcEndpoint
    92  
    93  		ReplyTok   []byte // This contains the hash of the ping packet.
    94  		Expiration uint64 // Absolute timestamp at which the packet becomes invalid.
    95  
    96  		// v5
    97  		TopicHash    common.Hash
    98  		TicketSerial uint32
    99  		WaitPeriods  []uint32
   100  
   101  		// Ignore additional fields (for forward compatibility).
   102  		Rest []byte
   103  	}
   104  
   105  	// findnode is a query for nodes close to the given target.
   106  	findnode struct {
   107  		Target     NodeID // doesn't need to be an actual public key
   108  		Expiration uint64
   109  		// Ignore additional fields (for forward compatibility).
   110  		Rest []byte
   111  	}
   112  
   113  	// findnode is a query for nodes close to the given target.
   114  	findnodeHash struct {
   115  		Target     common.Hash
   116  		Expiration uint64
   117  		// Ignore additional fields (for forward compatibility).
   118  		Rest []byte
   119  	}
   120  
   121  	// reply to findnode
   122  	neighbors struct {
   123  		Nodes      []rpcNode
   124  		Expiration uint64
   125  		// Ignore additional fields (for forward compatibility).
   126  		Rest []byte
   127  	}
   128  
   129  	topicRegister struct {
   130  		Topics []Topic
   131  		Idx    uint
   132  		Pong   []byte
   133  	}
   134  
   135  	topicQuery struct {
   136  		Topic      Topic
   137  		Expiration uint64
   138  	}
   139  
   140  	// reply to topicQuery
   141  	topicNodes struct {
   142  		Echo  common.Hash
   143  		Nodes []rpcNode
   144  	}
   145  
   146  	rpcNode struct {
   147  		IP  net.IP // len 4 for IPv4 or 16 for IPv6
   148  		UDP uint16 // for discovery protocol
   149  		TCP uint16 // for RLPx protocol
   150  		ID  NodeID
   151  	}
   152  
   153  	rpcEndpoint struct {
   154  		IP  net.IP // len 4 for IPv4 or 16 for IPv6
   155  		UDP uint16 // for discovery protocol
   156  		TCP uint16 // for RLPx protocol
   157  	}
   158  )
   159  
   160  var (
   161  	versionPrefix     = []byte("bytom discovery")
   162  	versionPrefixSize = len(versionPrefix)
   163  	nodeIDSize        = 32
   164  	sigSize           = 520 / 8
   165  	headSize          = versionPrefixSize + nodeIDSize + sigSize // space of packet frame data
   166  )
   167  
   168  // Neighbors replies are sent across multiple packets to
   169  // stay below the 1280 byte limit. We compute the maximum number
   170  // of entries by stuffing a packet until it grows too large.
   171  var maxNeighbors = func() int {
   172  	p := neighbors{Expiration: ^uint64(0)}
   173  	maxSizeNode := rpcNode{IP: make(net.IP, 16), UDP: ^uint16(0), TCP: ^uint16(0)}
   174  	for n := 0; ; n++ {
   175  		p.Nodes = append(p.Nodes, maxSizeNode)
   176  		var size int
   177  		var err error
   178  		b := new(bytes.Buffer)
   179  		wire.WriteJSON(p, b, &size, &err)
   180  		if err != nil {
   181  			// If this ever happens, it will be caught by the unit tests.
   182  			panic("cannot encode: " + err.Error())
   183  		}
   184  		if headSize+size+1 >= 1280 {
   185  			return n
   186  		}
   187  	}
   188  }()
   189  
   190  var maxTopicNodes = func() int {
   191  	p := topicNodes{}
   192  	maxSizeNode := rpcNode{IP: make(net.IP, 16), UDP: ^uint16(0), TCP: ^uint16(0)}
   193  	for n := 0; ; n++ {
   194  		p.Nodes = append(p.Nodes, maxSizeNode)
   195  		var size int
   196  		var err error
   197  		b := new(bytes.Buffer)
   198  		wire.WriteJSON(p, b, &size, &err)
   199  		if err != nil {
   200  			// If this ever happens, it will be caught by the unit tests.
   201  			panic("cannot encode: " + err.Error())
   202  		}
   203  		if headSize+size+1 >= 1280 {
   204  			return n
   205  		}
   206  	}
   207  }()
   208  
   209  func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint {
   210  	ip := addr.IP.To4()
   211  	if ip == nil {
   212  		ip = addr.IP.To16()
   213  	}
   214  	return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort}
   215  }
   216  
   217  func (e1 rpcEndpoint) equal(e2 rpcEndpoint) bool {
   218  	return e1.UDP == e2.UDP && e1.TCP == e2.TCP && e1.IP.Equal(e2.IP)
   219  }
   220  
   221  func nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) {
   222  	if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil {
   223  		return nil, err
   224  	}
   225  	n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP)
   226  	err := n.validateComplete()
   227  	return n, err
   228  }
   229  
   230  func nodeToRPC(n *Node) rpcNode {
   231  	return rpcNode{ID: n.ID, IP: n.IP, UDP: n.UDP, TCP: n.TCP}
   232  }
   233  
   234  type ingressPacket struct {
   235  	remoteID   NodeID
   236  	remoteAddr *net.UDPAddr
   237  	ev         nodeEvent
   238  	hash       []byte
   239  	data       interface{} // one of the RPC structs
   240  	rawData    []byte
   241  }
   242  
   243  type conn interface {
   244  	ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error)
   245  	WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error)
   246  	Close() error
   247  	LocalAddr() net.Addr
   248  }
   249  
   250  type netWork interface {
   251  	reqReadPacket(pkt ingressPacket)
   252  	selfIP() net.IP
   253  }
   254  
   255  // udp implements the RPC protocol.
   256  type udp struct {
   257  	conn        conn
   258  	priv        chainkd.XPrv
   259  	ourEndpoint rpcEndpoint
   260  	//nat         nat.Interface
   261  	net netWork
   262  }
   263  
   264  func NewDiscover(config *cfg.Config, priv chainkd.XPrv, port uint16) (*Network, error) {
   265  	addr, err := net.ResolveUDPAddr("udp", net.JoinHostPort("0.0.0.0", strconv.FormatUint(uint64(port), 10)))
   266  	if err != nil {
   267  		return nil, err
   268  	}
   269  
   270  	conn, err := net.ListenUDP("udp", addr)
   271  	if err != nil {
   272  		return nil, err
   273  	}
   274  
   275  	realaddr := conn.LocalAddr().(*net.UDPAddr)
   276  	ntab, err := ListenUDP(priv, conn, realaddr, path.Join(config.DBDir(), "discover"), nil)
   277  	if err != nil {
   278  		return nil, err
   279  	}
   280  	seeds, err := QueryDNSSeeds(net.LookupHost)
   281  	if err != nil {
   282  		log.WithFields(log.Fields{"module": logModule, "err": err}).Error("fail on query dns seeds")
   283  	}
   284  
   285  	codedSeeds := netutil.CheckAndSplitAddresses(config.P2P.Seeds)
   286  	seeds = append(seeds, codedSeeds...)
   287  	if len(seeds) == 0 {
   288  		return ntab, nil
   289  	}
   290  
   291  	var nodes []*Node
   292  	for _, seed := range seeds {
   293  		version.Status.AddSeed(seed)
   294  		url := "enode://" + hex.EncodeToString(crypto.Sha256([]byte(seed))) + "@" + seed
   295  		nodes = append(nodes, MustParseNode(url))
   296  	}
   297  
   298  	if err = ntab.SetFallbackNodes(nodes); err != nil {
   299  		return nil, err
   300  	}
   301  	return ntab, nil
   302  }
   303  
   304  // ListenUDP returns a new table that listens for UDP packets on laddr.
   305  func ListenUDP(priv chainkd.XPrv, conn conn, realaddr *net.UDPAddr, nodeDBPath string, netrestrict *netutil.Netlist) (*Network, error) {
   306  	transport, err := listenUDP(priv, conn, realaddr)
   307  	if err != nil {
   308  		return nil, err
   309  	}
   310  
   311  	net, err := newNetwork(transport, priv.XPub().PublicKey(), nodeDBPath, netrestrict)
   312  	if err != nil {
   313  		return nil, err
   314  	}
   315  	log.WithFields(log.Fields{"module": logModule, "net": net.tab.self}).Info("UDP listener up v5")
   316  	transport.net = net
   317  	go transport.readLoop()
   318  	return net, nil
   319  }
   320  
   321  func listenUDP(priv chainkd.XPrv, conn conn, realaddr *net.UDPAddr) (*udp, error) {
   322  	return &udp{conn: conn, priv: priv, ourEndpoint: makeEndpoint(realaddr, uint16(realaddr.Port))}, nil
   323  }
   324  
   325  func (t *udp) localAddr() *net.UDPAddr {
   326  	return t.conn.LocalAddr().(*net.UDPAddr)
   327  }
   328  
   329  func (t *udp) Close() {
   330  	t.conn.Close()
   331  }
   332  
   333  func (t *udp) send(remote *Node, ptype nodeEvent, data interface{}) (hash []byte) {
   334  	hash, _ = t.sendPacket(remote.ID, remote.addr(), byte(ptype), data)
   335  	return hash
   336  }
   337  
   338  func (t *udp) sendPing(remote *Node, toaddr *net.UDPAddr, topics []Topic) (hash []byte) {
   339  	hash, _ = t.sendPacket(remote.ID, toaddr, byte(pingPacket), ping{
   340  		Version:    Version,
   341  		From:       t.ourEndpoint,
   342  		To:         makeEndpoint(toaddr, uint16(toaddr.Port)), // TODO: maybe use known TCP port from DB
   343  		Expiration: uint64(time.Now().Add(expiration).Unix()),
   344  		Topics:     topics,
   345  	})
   346  	return hash
   347  }
   348  
   349  func (t *udp) sendFindnode(remote *Node, target NodeID) {
   350  	t.sendPacket(remote.ID, remote.addr(), byte(findnodePacket), findnode{
   351  		Target:     target,
   352  		Expiration: uint64(time.Now().Add(expiration).Unix()),
   353  	})
   354  }
   355  
   356  func (t *udp) sendNeighbours(remote *Node, results []*Node) {
   357  	// Send neighbors in chunks with at most maxNeighbors per packet
   358  	// to stay below the 1280 byte limit.
   359  	p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
   360  	for i, result := range results {
   361  		p.Nodes = append(p.Nodes, nodeToRPC(result))
   362  		if len(p.Nodes) == maxNeighbors || i == len(results)-1 {
   363  			t.sendPacket(remote.ID, remote.addr(), byte(neighborsPacket), p)
   364  			p.Nodes = p.Nodes[:0]
   365  		}
   366  	}
   367  }
   368  
   369  func (t *udp) sendFindnodeHash(remote *Node, target common.Hash) {
   370  	t.sendPacket(remote.ID, remote.addr(), byte(findnodeHashPacket), findnodeHash{
   371  		Target:     common.Hash(target),
   372  		Expiration: uint64(time.Now().Add(expiration).Unix()),
   373  	})
   374  }
   375  
   376  func (t *udp) sendTopicRegister(remote *Node, topics []Topic, idx int, pong []byte) {
   377  	t.sendPacket(remote.ID, remote.addr(), byte(topicRegisterPacket), topicRegister{
   378  		Topics: topics,
   379  		Idx:    uint(idx),
   380  		Pong:   pong,
   381  	})
   382  }
   383  
   384  func (t *udp) sendTopicNodes(remote *Node, queryHash common.Hash, nodes []*Node) {
   385  	p := topicNodes{Echo: queryHash}
   386  	var sent bool
   387  	for _, result := range nodes {
   388  		if result.IP.Equal(t.net.selfIP()) || netutil.CheckRelayIP(remote.IP, result.IP) == nil {
   389  			p.Nodes = append(p.Nodes, nodeToRPC(result))
   390  		}
   391  		if len(p.Nodes) == maxTopicNodes {
   392  			t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p)
   393  			p.Nodes = p.Nodes[:0]
   394  			sent = true
   395  		}
   396  	}
   397  	if !sent || len(p.Nodes) > 0 {
   398  		t.sendPacket(remote.ID, remote.addr(), byte(topicNodesPacket), p)
   399  	}
   400  }
   401  
   402  func (t *udp) sendPacket(toid NodeID, toaddr *net.UDPAddr, ptype byte, req interface{}) (hash []byte, err error) {
   403  	packet, hash, err := encodePacket(t.priv, ptype, req)
   404  	if err != nil {
   405  		return hash, err
   406  	}
   407  	log.WithFields(log.Fields{"module": logModule, "event": nodeEvent(ptype), "to id": hex.EncodeToString(toid[:8]), "to addr": toaddr}).Debug("send packet")
   408  	if _, err = t.conn.WriteToUDP(packet, toaddr); err != nil {
   409  		log.WithFields(log.Fields{"module": logModule, "error": err}).Info(fmt.Sprint("UDP send failed"))
   410  	}
   411  	return hash, err
   412  }
   413  
   414  // zeroed padding space for encodePacket.
   415  var headSpace = make([]byte, headSize)
   416  
   417  func encodePacket(priv chainkd.XPrv, ptype byte, req interface{}) (p, hash []byte, err error) {
   418  	b := new(bytes.Buffer)
   419  	b.Write(headSpace)
   420  	b.WriteByte(ptype)
   421  	var size int
   422  	wire.WriteJSON(req, b, &size, &err)
   423  	if err != nil {
   424  		log.WithFields(log.Fields{"module": logModule, "error": err}).Error("error encoding packet")
   425  		return nil, nil, err
   426  	}
   427  	packet := b.Bytes()
   428  	nodeID := priv.XPub().PublicKey()
   429  	sig := priv.Sign(common.BytesToHash(packet[headSize:]).Bytes())
   430  	copy(packet, versionPrefix)
   431  	copy(packet[versionPrefixSize:], nodeID[:])
   432  	copy(packet[versionPrefixSize+nodeIDSize:], sig)
   433  
   434  	hash = common.BytesToHash(packet[versionPrefixSize:]).Bytes()
   435  	return packet, hash, nil
   436  }
   437  
   438  // readLoop runs in its own goroutine. it injects ingress UDP packets
   439  // into the network loop.
   440  func (t *udp) readLoop() {
   441  	defer t.conn.Close()
   442  	// Discovery packets are defined to be no larger than 1280 bytes.
   443  	// Packets larger than this size will be cut at the end and treated
   444  	// as invalid because their hash won't match.
   445  	buf := make([]byte, 1280)
   446  	for {
   447  		nbytes, from, err := t.conn.ReadFromUDP(buf)
   448  		if netutil.IsTemporaryError(err) {
   449  			// Ignore temporary read errors.
   450  			log.WithFields(log.Fields{"module": logModule, "error": err}).Debug("Temporary read error")
   451  			continue
   452  		} else if err != nil {
   453  			// Shut down the loop for permament errors.
   454  			log.WithFields(log.Fields{"module": logModule, "error": err}).Debug("Read error")
   455  			return
   456  		}
   457  		t.handlePacket(from, buf[:nbytes])
   458  	}
   459  }
   460  
   461  func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error {
   462  	pkt := ingressPacket{remoteAddr: from}
   463  	if err := decodePacket(buf, &pkt); err != nil {
   464  		log.WithFields(log.Fields{"module": logModule, "from": from, "error": err}).Error("Bad packet")
   465  		return err
   466  	}
   467  	t.net.reqReadPacket(pkt)
   468  	return nil
   469  }
   470  
   471  func decodePacket(buffer []byte, pkt *ingressPacket) error {
   472  	if len(buffer) < headSize+1 {
   473  		return errPacketTooSmall
   474  	}
   475  	buf := make([]byte, len(buffer))
   476  	copy(buf, buffer)
   477  	prefix, fromID, sigdata := buf[:versionPrefixSize], buf[versionPrefixSize:versionPrefixSize+nodeIDSize], buf[headSize:]
   478  	if !bytes.Equal(prefix, versionPrefix) {
   479  		return errBadPrefix
   480  	}
   481  	pkt.rawData = buf
   482  	pkt.hash = common.BytesToHash(buf[versionPrefixSize:]).Bytes()
   483  	pkt.remoteID = ByteID(fromID)
   484  	switch pkt.ev = nodeEvent(sigdata[0]); pkt.ev {
   485  	case pingPacket:
   486  		pkt.data = new(ping)
   487  	case pongPacket:
   488  		pkt.data = new(pong)
   489  	case findnodePacket:
   490  		pkt.data = new(findnode)
   491  	case neighborsPacket:
   492  		pkt.data = new(neighbors)
   493  	case findnodeHashPacket:
   494  		pkt.data = new(findnodeHash)
   495  	case topicRegisterPacket:
   496  		pkt.data = new(topicRegister)
   497  	case topicQueryPacket:
   498  		pkt.data = new(topicQuery)
   499  	case topicNodesPacket:
   500  		pkt.data = new(topicNodes)
   501  	default:
   502  		return fmt.Errorf("unknown packet type: %d", sigdata[0])
   503  	}
   504  	var err error
   505  	wire.ReadJSON(pkt.data, sigdata[1:], &err)
   506  	if err != nil {
   507  		log.WithFields(log.Fields{"module": logModule, "error": err}).Error("wire readjson err")
   508  	}
   509  
   510  	return err
   511  }