github.com/flashbots/go-ethereum@v1.9.7/p2p/discv5/node.go (about)

     1  // Copyright 2015 The go-ethereum Authors
     2  // This file is part of the go-ethereum library.
     3  //
     4  // The go-ethereum library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The go-ethereum library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package discv5
    18  
    19  import (
    20  	"crypto/ecdsa"
    21  	"crypto/elliptic"
    22  	"encoding/hex"
    23  	"errors"
    24  	"fmt"
    25  	"math/big"
    26  	"math/rand"
    27  	"net"
    28  	"net/url"
    29  	"regexp"
    30  	"strconv"
    31  	"strings"
    32  
    33  	"github.com/ethereum/go-ethereum/common"
    34  	"github.com/ethereum/go-ethereum/crypto"
    35  )
    36  
// Node represents a host on the network.
// The public fields of Node may not be modified.
//
// A Node whose IP is nil is "incomplete" (see Incomplete) and carries only
// its ID. UDP is the discovery port and TCP the listening port; String and
// ParseNode encode UDP as the "discport" query parameter when they differ.
type Node struct {
	IP       net.IP // len 4 for IPv4 or 16 for IPv6
	UDP, TCP uint16 // port numbers
	ID       NodeID // the node's public key

	// Network-related fields are contained in nodeNetGuts.
	// These fields are not supposed to be used off the
	// Network.loop goroutine.
	// NewNode initialises its sha field to the Keccak256 hash of ID.
	nodeNetGuts
}
    49  
    50  // NewNode creates a new node. It is mostly meant to be used for
    51  // testing purposes.
    52  func NewNode(id NodeID, ip net.IP, udpPort, tcpPort uint16) *Node {
    53  	if ipv4 := ip.To4(); ipv4 != nil {
    54  		ip = ipv4
    55  	}
    56  	return &Node{
    57  		IP:          ip,
    58  		UDP:         udpPort,
    59  		TCP:         tcpPort,
    60  		ID:          id,
    61  		nodeNetGuts: nodeNetGuts{sha: crypto.Keccak256Hash(id[:])},
    62  	}
    63  }
    64  
    65  func (n *Node) addr() *net.UDPAddr {
    66  	return &net.UDPAddr{IP: n.IP, Port: int(n.UDP)}
    67  }
    68  
    69  func (n *Node) setAddr(a *net.UDPAddr) {
    70  	n.IP = a.IP
    71  	if ipv4 := a.IP.To4(); ipv4 != nil {
    72  		n.IP = ipv4
    73  	}
    74  	n.UDP = uint16(a.Port)
    75  }
    76  
    77  // compares the given address against the stored values.
    78  func (n *Node) addrEqual(a *net.UDPAddr) bool {
    79  	ip := a.IP
    80  	if ipv4 := a.IP.To4(); ipv4 != nil {
    81  		ip = ipv4
    82  	}
    83  	return n.UDP == uint16(a.Port) && n.IP.Equal(ip)
    84  }
    85  
    86  // Incomplete returns true for nodes with no IP address.
    87  func (n *Node) Incomplete() bool {
    88  	return n.IP == nil
    89  }
    90  
    91  // checks whether n is a valid complete node.
    92  func (n *Node) validateComplete() error {
    93  	if n.Incomplete() {
    94  		return errors.New("incomplete node")
    95  	}
    96  	if n.UDP == 0 {
    97  		return errors.New("missing UDP port")
    98  	}
    99  	if n.TCP == 0 {
   100  		return errors.New("missing TCP port")
   101  	}
   102  	if n.IP.IsMulticast() || n.IP.IsUnspecified() {
   103  		return errors.New("invalid IP (multicast/unspecified)")
   104  	}
   105  	_, err := n.ID.Pubkey() // validate the key (on curve, etc.)
   106  	return err
   107  }
   108  
   109  // The string representation of a Node is a URL.
   110  // Please see ParseNode for a description of the format.
   111  func (n *Node) String() string {
   112  	u := url.URL{Scheme: "enode"}
   113  	if n.Incomplete() {
   114  		u.Host = fmt.Sprintf("%x", n.ID[:])
   115  	} else {
   116  		addr := net.TCPAddr{IP: n.IP, Port: int(n.TCP)}
   117  		u.User = url.User(fmt.Sprintf("%x", n.ID[:]))
   118  		u.Host = addr.String()
   119  		if n.UDP != n.TCP {
   120  			u.RawQuery = "discport=" + strconv.Itoa(int(n.UDP))
   121  		}
   122  	}
   123  	return u.String()
   124  }
   125  
   126  var incompleteNodeURL = regexp.MustCompile("(?i)^(?:enode://)?([0-9a-f]+)$")
   127  
   128  // ParseNode parses a node designator.
   129  //
   130  // There are two basic forms of node designators
   131  //   - incomplete nodes, which only have the public key (node ID)
   132  //   - complete nodes, which contain the public key and IP/Port information
   133  //
   134  // For incomplete nodes, the designator must look like one of these
   135  //
   136  //    enode://<hex node id>
   137  //    <hex node id>
   138  //
   139  // For complete nodes, the node ID is encoded in the username portion
   140  // of the URL, separated from the host by an @ sign. The hostname can
   141  // only be given as an IP address, DNS domain names are not allowed.
   142  // The port in the host name section is the TCP listening port. If the
   143  // TCP and UDP (discovery) ports differ, the UDP port is specified as
   144  // query parameter "discport".
   145  //
   146  // In the following example, the node URL describes
   147  // a node with IP address 10.3.58.6, TCP listening port 30303
   148  // and UDP discovery port 30301.
   149  //
   150  //    enode://<hex node id>@10.3.58.6:30303?discport=30301
   151  func ParseNode(rawurl string) (*Node, error) {
   152  	if m := incompleteNodeURL.FindStringSubmatch(rawurl); m != nil {
   153  		id, err := HexID(m[1])
   154  		if err != nil {
   155  			return nil, fmt.Errorf("invalid node ID (%v)", err)
   156  		}
   157  		return NewNode(id, nil, 0, 0), nil
   158  	}
   159  	return parseComplete(rawurl)
   160  }
   161  
   162  func parseComplete(rawurl string) (*Node, error) {
   163  	var (
   164  		id               NodeID
   165  		ip               net.IP
   166  		tcpPort, udpPort uint64
   167  	)
   168  	u, err := url.Parse(rawurl)
   169  	if err != nil {
   170  		return nil, err
   171  	}
   172  	if u.Scheme != "enode" {
   173  		return nil, errors.New("invalid URL scheme, want \"enode\"")
   174  	}
   175  	// Parse the Node ID from the user portion.
   176  	if u.User == nil {
   177  		return nil, errors.New("does not contain node ID")
   178  	}
   179  	if id, err = HexID(u.User.String()); err != nil {
   180  		return nil, fmt.Errorf("invalid node ID (%v)", err)
   181  	}
   182  	// Parse the IP address.
   183  	host, port, err := net.SplitHostPort(u.Host)
   184  	if err != nil {
   185  		return nil, fmt.Errorf("invalid host: %v", err)
   186  	}
   187  	if ip = net.ParseIP(host); ip == nil {
   188  		return nil, errors.New("invalid IP address")
   189  	}
   190  	// Ensure the IP is 4 bytes long for IPv4 addresses.
   191  	if ipv4 := ip.To4(); ipv4 != nil {
   192  		ip = ipv4
   193  	}
   194  	// Parse the port numbers.
   195  	if tcpPort, err = strconv.ParseUint(port, 10, 16); err != nil {
   196  		return nil, errors.New("invalid port")
   197  	}
   198  	udpPort = tcpPort
   199  	qv := u.Query()
   200  	if qv.Get("discport") != "" {
   201  		udpPort, err = strconv.ParseUint(qv.Get("discport"), 10, 16)
   202  		if err != nil {
   203  			return nil, errors.New("invalid discport in query")
   204  		}
   205  	}
   206  	return NewNode(id, ip, uint16(udpPort), uint16(tcpPort)), nil
   207  }
   208  
   209  // MustParseNode parses a node URL. It panics if the URL is not valid.
   210  func MustParseNode(rawurl string) *Node {
   211  	n, err := ParseNode(rawurl)
   212  	if err != nil {
   213  		panic("invalid node URL: " + err.Error())
   214  	}
   215  	return n
   216  }
   217  
   218  // MarshalText implements encoding.TextMarshaler.
   219  func (n *Node) MarshalText() ([]byte, error) {
   220  	return []byte(n.String()), nil
   221  }
   222  
   223  // UnmarshalText implements encoding.TextUnmarshaler.
   224  func (n *Node) UnmarshalText(text []byte) error {
   225  	dec, err := ParseNode(string(text))
   226  	if err == nil {
   227  		*n = *dec
   228  	}
   229  	return err
   230  }
   231  
   232  // type nodeQueue []*Node
   233  //
   234  // // pushNew adds n to the end if it is not present.
   235  // func (nl *nodeList) appendNew(n *Node) {
   236  // 	for _, entry := range n {
   237  // 		if entry == n {
   238  // 			return
   239  // 		}
   240  // 	}
   241  // 	*nq = append(*nq, n)
   242  // }
   243  //
// // popRandom removes a random node. Nodes closer to
// // the head of the list have a slightly higher probability.
   246  // func (nl *nodeList) popRandom() *Node {
   247  // 	ix := rand.Intn(len(*nq))
   248  // 	//TODO: probability as mentioned above.
   249  // 	nl.removeIndex(ix)
   250  // }
   251  //
   252  // func (nl *nodeList) removeIndex(i int) *Node {
   253  // 	slice = *nl
   254  // 	if len(*slice) <= i {
   255  // 		return nil
   256  // 	}
   257  // 	*nl = append(slice[:i], slice[i+1:]...)
   258  // }
   259  
   260  const nodeIDBits = 512
   261  
   262  // NodeID is a unique identifier for each node.
   263  // The node identifier is a marshaled elliptic curve public key.
   264  type NodeID [nodeIDBits / 8]byte
   265  
   266  // NodeID prints as a long hexadecimal number.
   267  func (n NodeID) String() string {
   268  	return fmt.Sprintf("%x", n[:])
   269  }
   270  
   271  // The Go syntax representation of a NodeID is a call to HexID.
   272  func (n NodeID) GoString() string {
   273  	return fmt.Sprintf("discover.HexID(\"%x\")", n[:])
   274  }
   275  
   276  // TerminalString returns a shortened hex string for terminal logging.
   277  func (n NodeID) TerminalString() string {
   278  	return hex.EncodeToString(n[:8])
   279  }
   280  
   281  // HexID converts a hex string to a NodeID.
   282  // The string may be prefixed with 0x.
   283  func HexID(in string) (NodeID, error) {
   284  	var id NodeID
   285  	b, err := hex.DecodeString(strings.TrimPrefix(in, "0x"))
   286  	if err != nil {
   287  		return id, err
   288  	} else if len(b) != len(id) {
   289  		return id, fmt.Errorf("wrong length, want %d hex chars", len(id)*2)
   290  	}
   291  	copy(id[:], b)
   292  	return id, nil
   293  }
   294  
   295  // MustHexID converts a hex string to a NodeID.
   296  // It panics if the string is not a valid NodeID.
   297  func MustHexID(in string) NodeID {
   298  	id, err := HexID(in)
   299  	if err != nil {
   300  		panic(err)
   301  	}
   302  	return id
   303  }
   304  
   305  // PubkeyID returns a marshaled representation of the given public key.
   306  func PubkeyID(pub *ecdsa.PublicKey) NodeID {
   307  	var id NodeID
   308  	pbytes := elliptic.Marshal(pub.Curve, pub.X, pub.Y)
   309  	if len(pbytes)-1 != len(id) {
   310  		panic(fmt.Errorf("need %d bit pubkey, got %d bits", (len(id)+1)*8, len(pbytes)))
   311  	}
   312  	copy(id[:], pbytes[1:])
   313  	return id
   314  }
   315  
   316  // Pubkey returns the public key represented by the node ID.
   317  // It returns an error if the ID is not a point on the curve.
   318  func (n NodeID) Pubkey() (*ecdsa.PublicKey, error) {
   319  	p := &ecdsa.PublicKey{Curve: crypto.S256(), X: new(big.Int), Y: new(big.Int)}
   320  	half := len(n) / 2
   321  	p.X.SetBytes(n[:half])
   322  	p.Y.SetBytes(n[half:])
   323  	if !p.Curve.IsOnCurve(p.X, p.Y) {
   324  		return nil, errors.New("id is invalid secp256k1 curve point")
   325  	}
   326  	return p, nil
   327  }
   328  
   329  func (id NodeID) mustPubkey() ecdsa.PublicKey {
   330  	pk, err := id.Pubkey()
   331  	if err != nil {
   332  		panic(err)
   333  	}
   334  	return *pk
   335  }
   336  
   337  // recoverNodeID computes the public key used to sign the
   338  // given hash from the signature.
   339  func recoverNodeID(hash, sig []byte) (id NodeID, err error) {
   340  	pubkey, err := crypto.Ecrecover(hash, sig)
   341  	if err != nil {
   342  		return id, err
   343  	}
   344  	if len(pubkey)-1 != len(id) {
   345  		return id, fmt.Errorf("recovered pubkey has %d bits, want %d bits", len(pubkey)*8, (len(id)+1)*8)
   346  	}
   347  	for i := range id {
   348  		id[i] = pubkey[i+1]
   349  	}
   350  	return id, nil
   351  }
   352  
   353  // distcmp compares the distances a->target and b->target.
   354  // Returns -1 if a is closer to target, 1 if b is closer to target
   355  // and 0 if they are equal.
   356  func distcmp(target, a, b common.Hash) int {
   357  	for i := range target {
   358  		da := a[i] ^ target[i]
   359  		db := b[i] ^ target[i]
   360  		if da > db {
   361  			return 1
   362  		} else if da < db {
   363  			return -1
   364  		}
   365  	}
   366  	return 0
   367  }
   368  
// lzcount is a table of leading zero counts for bytes [0..255]: lzcount[x]
// is the number of zero bits above the most significant set bit of x, and
// 8 for x == 0. logdist uses it to count the leading zeros of the first
// non-zero byte of a XOR distance.
var lzcount = [256]int{
	8, 7, 6, 6, 5, 5, 5, 5,
	4, 4, 4, 4, 4, 4, 4, 4,
	3, 3, 3, 3, 3, 3, 3, 3,
	3, 3, 3, 3, 3, 3, 3, 3,
	2, 2, 2, 2, 2, 2, 2, 2,
	2, 2, 2, 2, 2, 2, 2, 2,
	2, 2, 2, 2, 2, 2, 2, 2,
	2, 2, 2, 2, 2, 2, 2, 2,
	1, 1, 1, 1, 1, 1, 1, 1,
	1, 1, 1, 1, 1, 1, 1, 1,
	1, 1, 1, 1, 1, 1, 1, 1,
	1, 1, 1, 1, 1, 1, 1, 1,
	1, 1, 1, 1, 1, 1, 1, 1,
	1, 1, 1, 1, 1, 1, 1, 1,
	1, 1, 1, 1, 1, 1, 1, 1,
	1, 1, 1, 1, 1, 1, 1, 1,
	0, 0, 0, 0, 0, 0, 0, 0,
	0, 0, 0, 0, 0, 0, 0, 0,
	0, 0, 0, 0, 0, 0, 0, 0,
	0, 0, 0, 0, 0, 0, 0, 0,
	0, 0, 0, 0, 0, 0, 0, 0,
	0, 0, 0, 0, 0, 0, 0, 0,
	0, 0, 0, 0, 0, 0, 0, 0,
	0, 0, 0, 0, 0, 0, 0, 0,
	0, 0, 0, 0, 0, 0, 0, 0,
	0, 0, 0, 0, 0, 0, 0, 0,
	0, 0, 0, 0, 0, 0, 0, 0,
	0, 0, 0, 0, 0, 0, 0, 0,
	0, 0, 0, 0, 0, 0, 0, 0,
	0, 0, 0, 0, 0, 0, 0, 0,
	0, 0, 0, 0, 0, 0, 0, 0,
	0, 0, 0, 0, 0, 0, 0, 0,
}
   404  
   405  // logdist returns the logarithmic distance between a and b, log2(a ^ b).
   406  func logdist(a, b common.Hash) int {
   407  	lz := 0
   408  	for i := range a {
   409  		x := a[i] ^ b[i]
   410  		if x == 0 {
   411  			lz += 8
   412  		} else {
   413  			lz += lzcount[x]
   414  			break
   415  		}
   416  	}
   417  	return len(a)*8 - lz
   418  }
   419  
   420  // hashAtDistance returns a random hash such that logdist(a, b) == n
   421  func hashAtDistance(a common.Hash, n int) (b common.Hash) {
   422  	if n == 0 {
   423  		return a
   424  	}
   425  	// flip bit at position n, fill the rest with random bits
   426  	b = a
   427  	pos := len(a) - n/8 - 1
   428  	bit := byte(0x01) << (byte(n%8) - 1)
   429  	if bit == 0 {
   430  		pos++
   431  		bit = 0x80
   432  	}
   433  	b[pos] = a[pos]&^bit | ^a[pos]&bit // TODO: randomize end bits
   434  	for i := pos + 1; i < len(a); i++ {
   435  		b[i] = byte(rand.Intn(255))
   436  	}
   437  	return b
   438  }