gitlab.com/aquachain/aquachain@v1.17.16-rc3.0.20221018032414-e3ddf1e1c055/p2p/discover/node.go (about)

     1  // Copyright 2018 The aquachain Authors
     2  // This file is part of the aquachain library.
     3  //
     4  // The aquachain library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The aquachain library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the aquachain library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package discover
    18  
    19  import (
    20  	"crypto/ecdsa"
    21  	"crypto/elliptic"
    22  	"encoding/hex"
    23  	"errors"
    24  	"fmt"
    25  	"math/big"
    26  	"net"
    27  	"net/url"
    28  	"regexp"
    29  	"strconv"
    30  	"strings"
    31  	"time"
    32  
    33  	"gitlab.com/aquachain/aquachain/common"
    34  	"gitlab.com/aquachain/aquachain/crypto"
    35  )
    36  
    37  const NodeIDBits = 512
    38  
    39  // Node represents a host on the network.
    40  // The fields of Node may not be modified.
    41  type Node struct {
    42  	IP       net.IP // len 4 for IPv4 or 16 for IPv6
    43  	UDP, TCP uint16 // port numbers
    44  	ID       NodeID // the node's public key
    45  
    46  	// This is a cached copy of sha3(ID) which is used for node
    47  	// distance calculations. This is part of Node in order to make it
    48  	// possible to write tests that need a node at a certain distance.
    49  	// In those tests, the content of sha will not actually correspond
    50  	// with ID.
    51  	sha common.Hash
    52  
    53  	// Time when the node was added to the table.
    54  	addedAt time.Time
    55  }
    56  
    57  // NewNode creates a new node. It is mostly meant to be used for
    58  // testing purposes.
    59  func NewNode(id NodeID, ip net.IP, udpPort, tcpPort uint16) *Node {
    60  	if ipv4 := ip.To4(); ipv4 != nil {
    61  		ip = ipv4
    62  	}
    63  	return &Node{
    64  		IP:  ip,
    65  		UDP: udpPort,
    66  		TCP: tcpPort,
    67  		ID:  id,
    68  		sha: crypto.Keccak256Hash(id[:]),
    69  	}
    70  }
    71  
    72  func (n *Node) addr() *net.UDPAddr {
    73  	return &net.UDPAddr{IP: n.IP, Port: int(n.UDP)}
    74  }
    75  
    76  // Incomplete returns true for nodes with no IP address.
    77  func (n *Node) Incomplete() bool {
    78  	return n.IP == nil
    79  }
    80  
    81  // checks whether n is a valid complete node.
    82  func (n *Node) validateComplete() error {
    83  	if n.Incomplete() {
    84  		return errors.New("incomplete node")
    85  	}
    86  	if n.UDP == 0 {
    87  		return errors.New("missing UDP port")
    88  	}
    89  	if n.TCP == 0 {
    90  		return errors.New("missing TCP port")
    91  	}
    92  	if n.IP.IsMulticast() || n.IP.IsUnspecified() {
    93  		return errors.New("invalid IP (multicast/unspecified)")
    94  	}
    95  	_, err := n.ID.Pubkey() // validate the key (on curve, etc.)
    96  	return err
    97  }
    98  
    99  // The string representation of a Node is a URL.
   100  // Please see ParseNode for a description of the format.
   101  func (n *Node) String() string {
   102  	u := url.URL{Scheme: "enode"}
   103  	if n.Incomplete() {
   104  		u.Host = fmt.Sprintf("%x", n.ID[:])
   105  	} else {
   106  		addr := net.TCPAddr{IP: n.IP, Port: int(n.TCP)}
   107  		u.User = url.User(fmt.Sprintf("%x", n.ID[:]))
   108  		u.Host = addr.String()
   109  		if n.UDP != n.TCP {
   110  			u.RawQuery = "discport=" + strconv.Itoa(int(n.UDP))
   111  		}
   112  	}
   113  	return u.String()
   114  }
   115  
   116  var incompleteNodeURL = regexp.MustCompile("(?i)^(?:enode://)?([0-9a-f]+)$")
   117  
   118  // ParseNode parses a node designator.
   119  //
   120  // There are two basic forms of node designators
   121  //   - incomplete nodes, which only have the public key (node ID)
   122  //   - complete nodes, which contain the public key and IP/Port information
   123  //
   124  // For incomplete nodes, the designator must look like one of these
   125  //
   126  //    enode://<hex node id>
   127  //    <hex node id>
   128  //
   129  // For complete nodes, the node ID is encoded in the username portion
   130  // of the URL, separated from the host by an @ sign. The hostname can
   131  // only be given as an IP address, DNS domain names are not allowed.
   132  // The port in the host name section is the TCP listening port. If the
   133  // TCP and UDP (discovery) ports differ, the UDP port is specified as
   134  // query parameter "discport".
   135  //
   136  // In the following example, the node URL describes
   137  // a node with IP address 10.3.58.6, TCP listening port 30303
   138  // and UDP discovery port 30301.
   139  //
   140  //    enode://<hex node id>@10.3.58.6:30303?discport=30301
   141  func ParseNode(rawurl string) (*Node, error) {
   142  	if m := incompleteNodeURL.FindStringSubmatch(rawurl); m != nil {
   143  		id, err := HexID(m[1])
   144  		if err != nil {
   145  			return nil, fmt.Errorf("invalid node ID (%v)", err)
   146  		}
   147  		return NewNode(id, nil, 0, 0), nil
   148  	}
   149  	return parseComplete(rawurl)
   150  }
   151  
   152  func parseComplete(rawurl string) (*Node, error) {
   153  	var (
   154  		id               NodeID
   155  		ip               net.IP
   156  		tcpPort, udpPort uint64
   157  	)
   158  	u, err := url.Parse(rawurl)
   159  	if err != nil {
   160  		return nil, err
   161  	}
   162  	if u.Scheme != "enode" {
   163  		return nil, errors.New("invalid URL scheme, want \"enode\"")
   164  	}
   165  	// Parse the Node ID from the user portion.
   166  	if u.User == nil {
   167  		return nil, errors.New("does not contain node ID")
   168  	}
   169  	if id, err = HexID(u.User.String()); err != nil {
   170  		return nil, fmt.Errorf("invalid node ID (%v)", err)
   171  	}
   172  	// Parse the IP address.
   173  	host, port, err := net.SplitHostPort(u.Host)
   174  	if err != nil {
   175  		return nil, fmt.Errorf("invalid host: %v", err)
   176  	}
   177  	if ip = net.ParseIP(host); ip == nil {
   178  		return nil, errors.New("invalid IP address")
   179  	}
   180  	// Ensure the IP is 4 bytes long for IPv4 addresses.
   181  	if ipv4 := ip.To4(); ipv4 != nil {
   182  		ip = ipv4
   183  	}
   184  	// Parse the port numbers.
   185  	if tcpPort, err = strconv.ParseUint(port, 10, 16); err != nil {
   186  		return nil, errors.New("invalid port")
   187  	}
   188  	udpPort = tcpPort
   189  	qv := u.Query()
   190  	if qv.Get("discport") != "" {
   191  		udpPort, err = strconv.ParseUint(qv.Get("discport"), 10, 16)
   192  		if err != nil {
   193  			return nil, errors.New("invalid discport in query")
   194  		}
   195  	}
   196  	return NewNode(id, ip, uint16(udpPort), uint16(tcpPort)), nil
   197  }
   198  
   199  // MustParseNode parses a node URL. It panics if the URL is not valid.
   200  func MustParseNode(rawurl string) *Node {
   201  	n, err := ParseNode(rawurl)
   202  	if err != nil {
   203  		panic("invalid node URL: " + err.Error())
   204  	}
   205  	return n
   206  }
   207  
   208  // MarshalText implements encoding.TextMarshaler.
   209  func (n *Node) MarshalText() ([]byte, error) {
   210  	return []byte(n.String()), nil
   211  }
   212  
   213  // UnmarshalText implements encoding.TextUnmarshaler.
   214  func (n *Node) UnmarshalText(text []byte) error {
   215  	dec, err := ParseNode(string(text))
   216  	if err == nil {
   217  		*n = *dec
   218  	}
   219  	return err
   220  }
   221  
// NodeID is a unique identifier for each node.
// The node identifier is a marshaled elliptic curve public key: PubkeyID
// fills it with the key's encoding minus its leading byte, and Pubkey reads
// the first half as the X coordinate and the second half as Y.
// Its size is NodeIDBits/8 bytes.
type NodeID [NodeIDBits / 8]byte
   225  
   226  // Bytes returns a byte slice representation of the NodeID
   227  func (n NodeID) Bytes() []byte {
   228  	return n[:]
   229  }
   230  
   231  // NodeID prints as a long hexadecimal number.
   232  func (n NodeID) String() string {
   233  	return fmt.Sprintf("%x", n[:])
   234  }
   235  
   236  // The Go syntax representation of a NodeID is a call to HexID.
   237  func (n NodeID) GoString() string {
   238  	return fmt.Sprintf("discover.HexID(\"%x\")", n[:])
   239  }
   240  
   241  // TerminalString returns a shortened hex string for terminal logging.
   242  func (n NodeID) TerminalString() string {
   243  	return hex.EncodeToString(n[:8])
   244  }
   245  
   246  // MarshalText implements the encoding.TextMarshaler interface.
   247  func (n NodeID) MarshalText() ([]byte, error) {
   248  	return []byte(hex.EncodeToString(n[:])), nil
   249  }
   250  
   251  // UnmarshalText implements the encoding.TextUnmarshaler interface.
   252  func (n *NodeID) UnmarshalText(text []byte) error {
   253  	id, err := HexID(string(text))
   254  	if err != nil {
   255  		return err
   256  	}
   257  	*n = id
   258  	return nil
   259  }
   260  
   261  // BytesID converts a byte slice to a NodeID
   262  func BytesID(b []byte) (NodeID, error) {
   263  	var id NodeID
   264  	if len(b) != len(id) {
   265  		return id, fmt.Errorf("wrong length, want %d bytes", len(id))
   266  	}
   267  	copy(id[:], b)
   268  	return id, nil
   269  }
   270  
   271  // MustBytesID converts a byte slice to a NodeID.
   272  // It panics if the byte slice is not a valid NodeID.
   273  func MustBytesID(b []byte) NodeID {
   274  	id, err := BytesID(b)
   275  	if err != nil {
   276  		panic(err)
   277  	}
   278  	return id
   279  }
   280  
   281  // HexID converts a hex string to a NodeID.
   282  // The string may be prefixed with 0x.
   283  func HexID(in string) (NodeID, error) {
   284  	var id NodeID
   285  	b, err := hex.DecodeString(strings.TrimPrefix(in, "0x"))
   286  	if err != nil {
   287  		return id, err
   288  	} else if len(b) != len(id) {
   289  		return id, fmt.Errorf("wrong length, want %d hex chars", len(id)*2)
   290  	}
   291  	copy(id[:], b)
   292  	return id, nil
   293  }
   294  
   295  // MustHexID converts a hex string to a NodeID.
   296  // It panics if the string is not a valid NodeID.
   297  func MustHexID(in string) NodeID {
   298  	id, err := HexID(in)
   299  	if err != nil {
   300  		panic(err)
   301  	}
   302  	return id
   303  }
   304  
   305  // PubkeyID returns a marshaled representation of the given public key.
   306  func PubkeyID(pub *ecdsa.PublicKey) NodeID {
   307  	var id NodeID
   308  	pbytes := elliptic.Marshal(pub.Curve, pub.X, pub.Y)
   309  	if len(pbytes)-1 != len(id) {
   310  		panic(fmt.Errorf("need %d bit pubkey, got %d bits", (len(id)+1)*8, len(pbytes)))
   311  	}
   312  	copy(id[:], pbytes[1:])
   313  	return id
   314  }
   315  
   316  // Pubkey returns the public key represented by the node ID.
   317  // It returns an error if the ID is not a point on the curve.
   318  func (id NodeID) Pubkey() (*ecdsa.PublicKey, error) {
   319  	p := &ecdsa.PublicKey{Curve: crypto.S256(), X: new(big.Int), Y: new(big.Int)}
   320  	half := len(id) / 2
   321  	p.X.SetBytes(id[:half])
   322  	p.Y.SetBytes(id[half:])
   323  	if !p.Curve.IsOnCurve(p.X, p.Y) {
   324  		return nil, errors.New("id is invalid secp256k1 curve point")
   325  	}
   326  	return p, nil
   327  }
   328  
   329  // recoverNodeID computes the public key used to sign the
   330  // given hash from the signature.
   331  func recoverNodeID(hash, sig []byte) (id NodeID, err error) {
   332  	pubkey, err := crypto.Ecrecover(hash, sig)
   333  	if err != nil {
   334  		return id, err
   335  	}
   336  	if len(pubkey)-1 != len(id) {
   337  		return id, fmt.Errorf("recovered pubkey has %d bits, want %d bits", len(pubkey)*8, (len(id)+1)*8)
   338  	}
   339  	for i := range id {
   340  		id[i] = pubkey[i+1]
   341  	}
   342  	return id, nil
   343  }
   344  
   345  // distcmp compares the distances a->target and b->target.
   346  // Returns -1 if a is closer to target, 1 if b is closer to target
   347  // and 0 if they are equal.
   348  func distcmp(target, a, b common.Hash) int {
   349  	for i := range target {
   350  		da := a[i] ^ target[i]
   351  		db := b[i] ^ target[i]
   352  		if da > db {
   353  			return 1
   354  		} else if da < db {
   355  			return -1
   356  		}
   357  	}
   358  	return 0
   359  }
   360  
   361  // table of leading zero counts for bytes [0..255]
   362  var lzcount = [256]int{
   363  	8, 7, 6, 6, 5, 5, 5, 5,
   364  	4, 4, 4, 4, 4, 4, 4, 4,
   365  	3, 3, 3, 3, 3, 3, 3, 3,
   366  	3, 3, 3, 3, 3, 3, 3, 3,
   367  	2, 2, 2, 2, 2, 2, 2, 2,
   368  	2, 2, 2, 2, 2, 2, 2, 2,
   369  	2, 2, 2, 2, 2, 2, 2, 2,
   370  	2, 2, 2, 2, 2, 2, 2, 2,
   371  	1, 1, 1, 1, 1, 1, 1, 1,
   372  	1, 1, 1, 1, 1, 1, 1, 1,
   373  	1, 1, 1, 1, 1, 1, 1, 1,
   374  	1, 1, 1, 1, 1, 1, 1, 1,
   375  	1, 1, 1, 1, 1, 1, 1, 1,
   376  	1, 1, 1, 1, 1, 1, 1, 1,
   377  	1, 1, 1, 1, 1, 1, 1, 1,
   378  	1, 1, 1, 1, 1, 1, 1, 1,
   379  	0, 0, 0, 0, 0, 0, 0, 0,
   380  	0, 0, 0, 0, 0, 0, 0, 0,
   381  	0, 0, 0, 0, 0, 0, 0, 0,
   382  	0, 0, 0, 0, 0, 0, 0, 0,
   383  	0, 0, 0, 0, 0, 0, 0, 0,
   384  	0, 0, 0, 0, 0, 0, 0, 0,
   385  	0, 0, 0, 0, 0, 0, 0, 0,
   386  	0, 0, 0, 0, 0, 0, 0, 0,
   387  	0, 0, 0, 0, 0, 0, 0, 0,
   388  	0, 0, 0, 0, 0, 0, 0, 0,
   389  	0, 0, 0, 0, 0, 0, 0, 0,
   390  	0, 0, 0, 0, 0, 0, 0, 0,
   391  	0, 0, 0, 0, 0, 0, 0, 0,
   392  	0, 0, 0, 0, 0, 0, 0, 0,
   393  	0, 0, 0, 0, 0, 0, 0, 0,
   394  	0, 0, 0, 0, 0, 0, 0, 0,
   395  }
   396  
   397  // logdist returns the logarithmic distance between a and b, log2(a ^ b).
   398  func logdist(a, b common.Hash) int {
   399  	lz := 0
   400  	for i := range a {
   401  		x := a[i] ^ b[i]
   402  		if x == 0 {
   403  			lz += 8
   404  		} else {
   405  			lz += lzcount[x]
   406  			break
   407  		}
   408  	}
   409  	return len(a)*8 - lz
   410  }