github.com/neatlab/neatio@v1.7.3-0.20220425043230-d903e92fcc75/network/p2p/discover/udp.go (about)

     1  package discover
     2  
     3  import (
     4  	"bytes"
     5  	"container/list"
     6  	"crypto/ecdsa"
     7  	"errors"
     8  	"fmt"
     9  	"net"
    10  	"time"
    11  
    12  	"github.com/neatlab/neatio/chain/log"
    13  	"github.com/neatlab/neatio/network/p2p/nat"
    14  	"github.com/neatlab/neatio/network/p2p/netutil"
    15  	"github.com/neatlab/neatio/utilities/crypto"
    16  	"github.com/neatlab/neatio/utilities/rlp"
    17  )
    18  
// Version is the protocol version sent in the Version field of ping packets.
const Version = 4

// Errors produced while decoding and handling packets and while waiting for
// replies.
var (
	errPacketTooSmall   = errors.New("too small")
	errBadHash          = errors.New("bad hash")
	errExpired          = errors.New("expired")
	errUnsolicitedReply = errors.New("unsolicited reply")
	errUnknownNode      = errors.New("unknown node")
	errTimeout          = errors.New("RPC timeout")
	errClockWarp        = errors.New("reply deadline too far in the future")
	errClosed           = errors.New("socket closed")
)

// Timing parameters.
const (
	// respTimeout is how long loop waits for a reply to a pending entry
	// before failing it with errTimeout.
	respTimeout = 500 * time.Millisecond

	// NOTE(review): sendTimeout is not referenced in this file.
	sendTimeout = 500 * time.Millisecond

	// expiration is added to the current time to form the Expiration field
	// of outgoing packets.
	expiration = 20 * time.Second

	// When more than ntpFailureThreshold timeouts occur without a matched
	// reply in between, loop starts a clock drift check, at most once per
	// ntpWarningCooldown.
	ntpFailureThreshold = 32
	ntpWarningCooldown  = 10 * time.Minute

	// NOTE(review): driftThreshold is not referenced in this file;
	// presumably it is used by checkClockDrift — confirm.
	driftThreshold = 10 * time.Second
)

// Packet type codes. encodePacket writes the code as the byte following the
// packet header, and decodePacket dispatches on it. Values start at 1.
const (
	pingPacket = iota + 1
	pongPacket
	findnodePacket
	neighborsPacket
)
    48  
// Wire formats of the packets. Each struct is RLP-encoded by encodePacket and
// decoded by decodePacket, so the field order defines the encoding and must
// not change. Expiration is a Unix timestamp in seconds; handlers reject a
// packet whose Expiration lies in the past (see expired). The trailing Rest
// field, tagged `rlp:"tail"`, collects any additional list elements.
// NOTE(review): presumably Rest lets packets that carry extra fields still
// decode — confirm.
type (
	// ping asks the recipient to answer with a pong.
	ping struct {
		Version    uint
		From, To   rpcEndpoint
		Expiration uint64

		Rest []rlp.RawValue `rlp:"tail"`
	}

	// pong answers a ping. ReplyTok carries the hash of the ping being
	// answered, which udp.ping compares against the hash it sent.
	pong struct {
		To rpcEndpoint

		ReplyTok   []byte
		Expiration uint64

		Rest []rlp.RawValue `rlp:"tail"`
	}

	// findnode asks the recipient for the nodes in its table closest to
	// Keccak256(Target).
	findnode struct {
		Target     NodeID
		Expiration uint64

		Rest []rlp.RawValue `rlp:"tail"`
	}

	// neighbors carries nodes in reply to a findnode.
	neighbors struct {
		Nodes      []rpcNode
		Expiration uint64

		Rest []rlp.RawValue `rlp:"tail"`
	}

	// rpcNode is the wire representation of a node.
	rpcNode struct {
		IP  net.IP
		UDP uint16
		TCP uint16
		ID  NodeID
	}

	// rpcEndpoint is the wire representation of a network address.
	rpcEndpoint struct {
		IP  net.IP
		UDP uint16
		TCP uint16
	}
)
    94  
    95  func makeEndpoint(addr *net.UDPAddr, tcpPort uint16) rpcEndpoint {
    96  	ip := addr.IP.To4()
    97  	if ip == nil {
    98  		ip = addr.IP.To16()
    99  	}
   100  	return rpcEndpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort}
   101  }
   102  
   103  func (t *udp) nodeFromRPC(sender *net.UDPAddr, rn rpcNode) (*Node, error) {
   104  	if rn.UDP <= 1024 {
   105  		return nil, errors.New("low port")
   106  	}
   107  	if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil {
   108  		return nil, err
   109  	}
   110  	if t.netrestrict != nil && !t.netrestrict.Contains(rn.IP) {
   111  		return nil, errors.New("not contained in netrestrict whitelist")
   112  	}
   113  	n := NewNode(rn.ID, rn.IP, rn.UDP, rn.TCP)
   114  	err := n.validateComplete()
   115  	return n, err
   116  }
   117  
   118  func nodeToRPC(n *Node) rpcNode {
   119  	return rpcNode{ID: n.ID, IP: n.IP, UDP: n.UDP, TCP: n.TCP}
   120  }
   121  
// packet is implemented by every decoded packet type.
type packet interface {
	// handle processes the packet received from the given address and sender
	// ID; mac is the packet's header hash as verified by decodePacket.
	handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error
	// name returns a short identifier used in log messages.
	name() string
}

// conn is the socket used by the transport. *net.UDPConn satisfies it.
type conn interface {
	ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error)
	WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error)
	Close() error
	LocalAddr() net.Addr
}
   133  
// udp implements the discovery wire protocol on top of a conn. newUDP starts
// the loop and readLoop goroutines for it.
type udp struct {
	// conn is the socket packets are read from and written to.
	conn conn

	// netrestrict, if non-nil, limits which neighbor IPs nodeFromRPC accepts.
	netrestrict *netutil.Netlist

	// priv signs outgoing packets.
	priv *ecdsa.PrivateKey

	// ourEndpoint is sent as the From field of outgoing pings.
	ourEndpoint rpcEndpoint

	// addpending and gotreply feed new reply expectations and incoming
	// replies to loop.
	addpending chan *pending
	gotreply   chan reply

	// closing is closed by close to stop loop and unblock waiting callers.
	closing chan struct{}

	// NOTE(review): nat is neither set nor read in this file.
	nat nat.Interface

	// The embedded Table is the node table this transport serves.
	*Table
}
   148  
// pending is a reply expectation registered with loop through udp.pending.
type pending struct {
	// from and ptype select the replies this entry is interested in: the
	// sender's node ID and the packet type code.
	from  NodeID
	ptype byte

	// deadline is set by loop when the entry is queued (now + respTimeout);
	// once it passes, the entry fails with errTimeout.
	deadline time.Time

	// callback is invoked by loop for every reply with matching from and
	// ptype. Returning true completes the entry: nil is sent on errc and the
	// entry is removed.
	callback func(resp interface{}) (done bool)

	// errc receives the entry's single outcome: nil on completion, or an
	// error such as errTimeout, errClockWarp or errClosed.
	errc chan<- error
}

// reply is an incoming packet handed to loop by handleReply so it can be
// matched against pending entries.
type reply struct {
	from  NodeID
	ptype byte
	data  interface{}

	// matched receives whether any pending entry had the same from and ptype.
	matched chan<- bool
}
   167  
// ReadPacket is a packet that the discovery transport received but failed to
// handle. Such packets are forwarded on Config.Unhandled.
type ReadPacket struct {
	Data []byte
	Addr *net.UDPAddr
}

// Config holds the settings for ListenUDP.
type Config struct {
	// PrivateKey signs outgoing packets; the local node ID is derived from
	// its public key.
	PrivateKey *ecdsa.PrivateKey

	// AnnounceAddr, if non-nil, is advertised instead of the socket's local
	// address.
	AnnounceAddr *net.UDPAddr

	// NodeDBPath is passed to newTable.
	// NOTE(review): presumably the node database location — confirm.
	NodeDBPath string

	// NetRestrict, if non-nil, limits which neighbor IPs are accepted.
	NetRestrict *netutil.Netlist

	// Bootnodes is passed to newTable.
	Bootnodes []*Node

	// Unhandled, if non-nil, receives packets that failed handling. Sends
	// are non-blocking, so packets are dropped while the channel is full.
	// The channel is closed when the read loop exits.
	Unhandled chan<- ReadPacket
}
   182  
   183  func ListenUDP(c conn, cfg Config) (*Table, error) {
   184  	tab, _, err := newUDP(c, cfg)
   185  	if err != nil {
   186  		return nil, err
   187  	}
   188  
   189  	return tab, nil
   190  }
   191  
   192  func newUDP(c conn, cfg Config) (*Table, *udp, error) {
   193  	udp := &udp{
   194  		conn:        c,
   195  		priv:        cfg.PrivateKey,
   196  		netrestrict: cfg.NetRestrict,
   197  		closing:     make(chan struct{}),
   198  		gotreply:    make(chan reply),
   199  		addpending:  make(chan *pending),
   200  	}
   201  	realaddr := c.LocalAddr().(*net.UDPAddr)
   202  	if cfg.AnnounceAddr != nil {
   203  		realaddr = cfg.AnnounceAddr
   204  	}
   205  
   206  	udp.ourEndpoint = makeEndpoint(realaddr, uint16(realaddr.Port))
   207  	tab, err := newTable(udp, PubkeyID(&cfg.PrivateKey.PublicKey), realaddr, cfg.NodeDBPath, cfg.Bootnodes)
   208  	if err != nil {
   209  		return nil, nil, err
   210  	}
   211  	udp.Table = tab
   212  
   213  	go udp.loop()
   214  	go udp.readLoop(cfg.Unhandled)
   215  	return udp.Table, udp, nil
   216  }
   217  
   218  func (t *udp) close() {
   219  	close(t.closing)
   220  	t.conn.Close()
   221  
   222  }
   223  
   224  func (t *udp) ping(toid NodeID, toaddr *net.UDPAddr) error {
   225  	req := &ping{
   226  		Version:    Version,
   227  		From:       t.ourEndpoint,
   228  		To:         makeEndpoint(toaddr, 0),
   229  		Expiration: uint64(time.Now().Add(expiration).Unix()),
   230  	}
   231  	packet, hash, err := encodePacket(t.priv, pingPacket, req)
   232  	if err != nil {
   233  		return err
   234  	}
   235  	errc := t.pending(toid, pongPacket, func(p interface{}) bool {
   236  		return bytes.Equal(p.(*pong).ReplyTok, hash)
   237  	})
   238  	t.write(toaddr, req.name(), packet)
   239  	return <-errc
   240  }
   241  
   242  func (t *udp) waitping(from NodeID) error {
   243  	return <-t.pending(from, pingPacket, func(interface{}) bool { return true })
   244  }
   245  
   246  func (t *udp) findnode(toid NodeID, toaddr *net.UDPAddr, target NodeID) ([]*Node, error) {
   247  	nodes := make([]*Node, 0, bucketSize)
   248  	nreceived := 0
   249  	errc := t.pending(toid, neighborsPacket, func(r interface{}) bool {
   250  		reply := r.(*neighbors)
   251  		for _, rn := range reply.Nodes {
   252  			nreceived++
   253  			n, err := t.nodeFromRPC(toaddr, rn)
   254  			if err != nil {
   255  				log.Trace("Invalid neighbor node received", "ip", rn.IP, "addr", toaddr, "err", err)
   256  				continue
   257  			}
   258  			nodes = append(nodes, n)
   259  		}
   260  		return nreceived >= bucketSize
   261  	})
   262  	t.send(toaddr, findnodePacket, &findnode{
   263  		Target:     target,
   264  		Expiration: uint64(time.Now().Add(expiration).Unix()),
   265  	})
   266  	err := <-errc
   267  	return nodes, err
   268  }
   269  
   270  func (t *udp) pending(id NodeID, ptype byte, callback func(interface{}) bool) <-chan error {
   271  	ch := make(chan error, 1)
   272  	p := &pending{from: id, ptype: ptype, callback: callback, errc: ch}
   273  	select {
   274  	case t.addpending <- p:
   275  
   276  	case <-t.closing:
   277  		ch <- errClosed
   278  	}
   279  	return ch
   280  }
   281  
   282  func (t *udp) handleReply(from NodeID, ptype byte, req packet) bool {
   283  	matched := make(chan bool, 1)
   284  	select {
   285  	case t.gotreply <- reply{from, ptype, req, matched}:
   286  
   287  		return <-matched
   288  	case <-t.closing:
   289  		return false
   290  	}
   291  }
   292  
   293  func (t *udp) loop() {
   294  	var (
   295  		plist        = list.New()
   296  		timeout      = time.NewTimer(0)
   297  		nextTimeout  *pending
   298  		contTimeouts = 0
   299  		ntpWarnTime  = time.Unix(0, 0)
   300  	)
   301  	<-timeout.C
   302  	defer timeout.Stop()
   303  
   304  	resetTimeout := func() {
   305  		if plist.Front() == nil || nextTimeout == plist.Front().Value {
   306  			return
   307  		}
   308  
   309  		now := time.Now()
   310  		for el := plist.Front(); el != nil; el = el.Next() {
   311  			nextTimeout = el.Value.(*pending)
   312  			if dist := nextTimeout.deadline.Sub(now); dist < 2*respTimeout {
   313  				timeout.Reset(dist)
   314  				return
   315  			}
   316  
   317  			nextTimeout.errc <- errClockWarp
   318  			plist.Remove(el)
   319  		}
   320  		nextTimeout = nil
   321  		timeout.Stop()
   322  	}
   323  
   324  	for {
   325  		resetTimeout()
   326  
   327  		select {
   328  		case <-t.closing:
   329  			for el := plist.Front(); el != nil; el = el.Next() {
   330  				el.Value.(*pending).errc <- errClosed
   331  			}
   332  			return
   333  
   334  		case p := <-t.addpending:
   335  			p.deadline = time.Now().Add(respTimeout)
   336  			plist.PushBack(p)
   337  
   338  		case r := <-t.gotreply:
   339  			var matched bool
   340  			for el := plist.Front(); el != nil; el = el.Next() {
   341  				p := el.Value.(*pending)
   342  				if p.from == r.from && p.ptype == r.ptype {
   343  					matched = true
   344  
   345  					if p.callback(r.data) {
   346  						p.errc <- nil
   347  						plist.Remove(el)
   348  					}
   349  
   350  					contTimeouts = 0
   351  				}
   352  			}
   353  			r.matched <- matched
   354  
   355  		case now := <-timeout.C:
   356  			nextTimeout = nil
   357  
   358  			for el := plist.Front(); el != nil; el = el.Next() {
   359  				p := el.Value.(*pending)
   360  				if now.After(p.deadline) || now.Equal(p.deadline) {
   361  					p.errc <- errTimeout
   362  					plist.Remove(el)
   363  					contTimeouts++
   364  				}
   365  			}
   366  
   367  			if contTimeouts > ntpFailureThreshold {
   368  				if time.Since(ntpWarnTime) >= ntpWarningCooldown {
   369  					ntpWarnTime = time.Now()
   370  					go checkClockDrift()
   371  				}
   372  				contTimeouts = 0
   373  			}
   374  		}
   375  	}
   376  }
   377  
// Packet header layout: a Keccak256 hash (macSize bytes) followed by a
// signature (sigSize bytes, i.e. 520 bits). The packet type byte and the RLP
// payload follow the header.
const (
	macSize  = 256 / 8
	sigSize  = 520 / 8
	headSize = macSize + sigSize
)

var (
	// headSpace is a zeroed placeholder that encodePacket writes first. It is
	// overwritten once the signature and hash are known.
	headSpace = make([]byte, headSize)

	// maxNeighbors is the largest number of nodes that fit into one
	// neighbors packet. It is computed by init.
	maxNeighbors int
)
   389  
   390  func init() {
   391  	p := neighbors{Expiration: ^uint64(0)}
   392  	maxSizeNode := rpcNode{IP: make(net.IP, 16), UDP: ^uint16(0), TCP: ^uint16(0)}
   393  	for n := 0; ; n++ {
   394  		p.Nodes = append(p.Nodes, maxSizeNode)
   395  		size, _, err := rlp.EncodeToReader(p)
   396  		if err != nil {
   397  
   398  			panic("cannot encode: " + err.Error())
   399  		}
   400  		if headSize+size+1 >= 1280 {
   401  			maxNeighbors = n
   402  			break
   403  		}
   404  	}
   405  }
   406  
   407  func (t *udp) send(toaddr *net.UDPAddr, ptype byte, req packet) ([]byte, error) {
   408  	packet, hash, err := encodePacket(t.priv, ptype, req)
   409  	if err != nil {
   410  		return hash, err
   411  	}
   412  	return hash, t.write(toaddr, req.name(), packet)
   413  }
   414  
   415  func (t *udp) write(toaddr *net.UDPAddr, what string, packet []byte) error {
   416  	_, err := t.conn.WriteToUDP(packet, toaddr)
   417  	log.Trace(">> "+what, "addr", toaddr, "err", err)
   418  	return err
   419  }
   420  
   421  func encodePacket(priv *ecdsa.PrivateKey, ptype byte, req interface{}) (packet, hash []byte, err error) {
   422  	b := new(bytes.Buffer)
   423  	b.Write(headSpace)
   424  	b.WriteByte(ptype)
   425  	if err := rlp.Encode(b, req); err != nil {
   426  		log.Error("Can't encode discv4 packet", "err", err)
   427  		return nil, nil, err
   428  	}
   429  	packet = b.Bytes()
   430  	sig, err := crypto.Sign(crypto.Keccak256(packet[headSize:]), priv)
   431  	if err != nil {
   432  		log.Error("Can't sign discv4 packet", "err", err)
   433  		return nil, nil, err
   434  	}
   435  	copy(packet[macSize:], sig)
   436  
   437  	hash = crypto.Keccak256(packet[macSize:])
   438  	copy(packet, hash)
   439  	return packet, hash, nil
   440  }
   441  
   442  func (t *udp) readLoop(unhandled chan<- ReadPacket) {
   443  	defer t.conn.Close()
   444  	if unhandled != nil {
   445  		defer close(unhandled)
   446  	}
   447  
   448  	buf := make([]byte, 1280)
   449  	for {
   450  		nbytes, from, err := t.conn.ReadFromUDP(buf)
   451  		if netutil.IsTemporaryError(err) {
   452  
   453  			log.Debug("Temporary UDP read error", "err", err)
   454  			continue
   455  		} else if err != nil {
   456  
   457  			log.Debug("UDP read error", "err", err)
   458  			return
   459  		}
   460  		if t.handlePacket(from, buf[:nbytes]) != nil && unhandled != nil {
   461  			select {
   462  			case unhandled <- ReadPacket{buf[:nbytes], from}:
   463  			default:
   464  			}
   465  		}
   466  	}
   467  }
   468  
   469  func (t *udp) handlePacket(from *net.UDPAddr, buf []byte) error {
   470  	packet, fromID, hash, err := decodePacket(buf)
   471  	if err != nil {
   472  		log.Debug("Bad discv4 packet", "addr", from, "err", err)
   473  		return err
   474  	}
   475  	err = packet.handle(t, from, fromID, hash)
   476  	log.Trace("<< "+packet.name(), "addr", from, "err", err)
   477  	return err
   478  }
   479  
   480  func decodePacket(buf []byte) (packet, NodeID, []byte, error) {
   481  	if len(buf) < headSize+1 {
   482  		return nil, NodeID{}, nil, errPacketTooSmall
   483  	}
   484  	hash, sig, sigdata := buf[:macSize], buf[macSize:headSize], buf[headSize:]
   485  	shouldhash := crypto.Keccak256(buf[macSize:])
   486  	if !bytes.Equal(hash, shouldhash) {
   487  		return nil, NodeID{}, nil, errBadHash
   488  	}
   489  	fromID, err := recoverNodeID(crypto.Keccak256(buf[headSize:]), sig)
   490  	if err != nil {
   491  		return nil, NodeID{}, hash, err
   492  	}
   493  	var req packet
   494  	switch ptype := sigdata[0]; ptype {
   495  	case pingPacket:
   496  		req = new(ping)
   497  	case pongPacket:
   498  		req = new(pong)
   499  	case findnodePacket:
   500  		req = new(findnode)
   501  	case neighborsPacket:
   502  		req = new(neighbors)
   503  	default:
   504  		return nil, fromID, hash, fmt.Errorf("unknown type: %d", ptype)
   505  	}
   506  	s := rlp.NewStream(bytes.NewReader(sigdata[1:]), 0)
   507  	err = s.Decode(req)
   508  	return req, fromID, hash, err
   509  }
   510  
   511  func (req *ping) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
   512  	if expired(req.Expiration) {
   513  		return errExpired
   514  	}
   515  	t.send(from, pongPacket, &pong{
   516  		To:         makeEndpoint(from, req.From.TCP),
   517  		ReplyTok:   mac,
   518  		Expiration: uint64(time.Now().Add(expiration).Unix()),
   519  	})
   520  	if !t.handleReply(fromID, pingPacket, req) {
   521  
   522  		go t.bond(true, fromID, from, req.From.TCP)
   523  	}
   524  	return nil
   525  }
   526  
   527  func (req *ping) name() string { return "PING/v4" }
   528  
   529  func (req *pong) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
   530  	if expired(req.Expiration) {
   531  		return errExpired
   532  	}
   533  	if !t.handleReply(fromID, pongPacket, req) {
   534  		return errUnsolicitedReply
   535  	}
   536  	return nil
   537  }
   538  
   539  func (req *pong) name() string { return "PONG/v4" }
   540  
   541  func (req *findnode) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
   542  	if expired(req.Expiration) {
   543  		return errExpired
   544  	}
   545  	if !t.db.hasBond(fromID) {
   546  
   547  		return errUnknownNode
   548  	}
   549  	target := crypto.Keccak256Hash(req.Target[:])
   550  	t.mutex.Lock()
   551  	closest := t.closest(target, bucketSize).entries
   552  	t.mutex.Unlock()
   553  
   554  	p := neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())}
   555  	var sent bool
   556  
   557  	for _, n := range closest {
   558  		if netutil.CheckRelayIP(from.IP, n.IP) == nil {
   559  			p.Nodes = append(p.Nodes, nodeToRPC(n))
   560  		}
   561  		if len(p.Nodes) == maxNeighbors {
   562  			t.send(from, neighborsPacket, &p)
   563  			p.Nodes = p.Nodes[:0]
   564  			sent = true
   565  		}
   566  	}
   567  	if len(p.Nodes) > 0 || !sent {
   568  		t.send(from, neighborsPacket, &p)
   569  	}
   570  	return nil
   571  }
   572  
   573  func (req *findnode) name() string { return "FINDNODE/v4" }
   574  
   575  func (req *neighbors) handle(t *udp, from *net.UDPAddr, fromID NodeID, mac []byte) error {
   576  	if expired(req.Expiration) {
   577  		return errExpired
   578  	}
   579  	if !t.handleReply(fromID, neighborsPacket, req) {
   580  		return errUnsolicitedReply
   581  	}
   582  	return nil
   583  }
   584  
   585  func (req *neighbors) name() string { return "NEIGHBORS/v4" }
   586  
   587  func expired(ts uint64) bool {
   588  	return time.Unix(int64(ts), 0).Before(time.Now())
   589  }