github.com/aquanetwork/aquachain@v1.7.8/p2p/discover/node.go (about)

     1  // Copyright 2015 The aquachain Authors
     2  // This file is part of the aquachain library.
     3  //
     4  // The aquachain library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The aquachain library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the aquachain library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package discover
    18  
    19  import (
    20  	"crypto/ecdsa"
    21  	"crypto/elliptic"
    22  	"encoding/hex"
    23  	"errors"
    24  	"fmt"
    25  	"math/big"
    26  	"math/rand"
    27  	"net"
    28  	"net/url"
    29  	"regexp"
    30  	"strconv"
    31  	"strings"
    32  	"time"
    33  
    34  	"gitlab.com/aquachain/aquachain/common"
    35  	"gitlab.com/aquachain/aquachain/crypto"
    36  )
    37  
    38  const NodeIDBits = 512
    39  
    40  // Node represents a host on the network.
    41  // The fields of Node may not be modified.
    42  type Node struct {
    43  	IP       net.IP // len 4 for IPv4 or 16 for IPv6
    44  	UDP, TCP uint16 // port numbers
    45  	ID       NodeID // the node's public key
    46  
    47  	// This is a cached copy of sha3(ID) which is used for node
    48  	// distance calculations. This is part of Node in order to make it
    49  	// possible to write tests that need a node at a certain distance.
    50  	// In those tests, the content of sha will not actually correspond
    51  	// with ID.
    52  	sha common.Hash
    53  
    54  	// Time when the node was added to the table.
    55  	addedAt time.Time
    56  }
    57  
    58  // NewNode creates a new node. It is mostly meant to be used for
    59  // testing purposes.
    60  func NewNode(id NodeID, ip net.IP, udpPort, tcpPort uint16) *Node {
    61  	if ipv4 := ip.To4(); ipv4 != nil {
    62  		ip = ipv4
    63  	}
    64  	return &Node{
    65  		IP:  ip,
    66  		UDP: udpPort,
    67  		TCP: tcpPort,
    68  		ID:  id,
    69  		sha: crypto.Keccak256Hash(id[:]),
    70  	}
    71  }
    72  
    73  func (n *Node) addr() *net.UDPAddr {
    74  	return &net.UDPAddr{IP: n.IP, Port: int(n.UDP)}
    75  }
    76  
    77  // Incomplete returns true for nodes with no IP address.
    78  func (n *Node) Incomplete() bool {
    79  	return n.IP == nil
    80  }
    81  
    82  // checks whether n is a valid complete node.
    83  func (n *Node) validateComplete() error {
    84  	if n.Incomplete() {
    85  		return errors.New("incomplete node")
    86  	}
    87  	if n.UDP == 0 {
    88  		return errors.New("missing UDP port")
    89  	}
    90  	if n.TCP == 0 {
    91  		return errors.New("missing TCP port")
    92  	}
    93  	if n.IP.IsMulticast() || n.IP.IsUnspecified() {
    94  		return errors.New("invalid IP (multicast/unspecified)")
    95  	}
    96  	_, err := n.ID.Pubkey() // validate the key (on curve, etc.)
    97  	return err
    98  }
    99  
   100  // The string representation of a Node is a URL.
   101  // Please see ParseNode for a description of the format.
   102  func (n *Node) String() string {
   103  	u := url.URL{Scheme: "enode"}
   104  	if n.Incomplete() {
   105  		u.Host = fmt.Sprintf("%x", n.ID[:])
   106  	} else {
   107  		addr := net.TCPAddr{IP: n.IP, Port: int(n.TCP)}
   108  		u.User = url.User(fmt.Sprintf("%x", n.ID[:]))
   109  		u.Host = addr.String()
   110  		if n.UDP != n.TCP {
   111  			u.RawQuery = "discport=" + strconv.Itoa(int(n.UDP))
   112  		}
   113  	}
   114  	return u.String()
   115  }
   116  
   117  var incompleteNodeURL = regexp.MustCompile("(?i)^(?:enode://)?([0-9a-f]+)$")
   118  
   119  // ParseNode parses a node designator.
   120  //
   121  // There are two basic forms of node designators
   122  //   - incomplete nodes, which only have the public key (node ID)
   123  //   - complete nodes, which contain the public key and IP/Port information
   124  //
   125  // For incomplete nodes, the designator must look like one of these
   126  //
   127  //    enode://<hex node id>
   128  //    <hex node id>
   129  //
   130  // For complete nodes, the node ID is encoded in the username portion
   131  // of the URL, separated from the host by an @ sign. The hostname can
   132  // only be given as an IP address, DNS domain names are not allowed.
   133  // The port in the host name section is the TCP listening port. If the
   134  // TCP and UDP (discovery) ports differ, the UDP port is specified as
   135  // query parameter "discport".
   136  //
   137  // In the following example, the node URL describes
   138  // a node with IP address 10.3.58.6, TCP listening port 30303
   139  // and UDP discovery port 30301.
   140  //
   141  //    enode://<hex node id>@10.3.58.6:30303?discport=30301
   142  func ParseNode(rawurl string) (*Node, error) {
   143  	if m := incompleteNodeURL.FindStringSubmatch(rawurl); m != nil {
   144  		id, err := HexID(m[1])
   145  		if err != nil {
   146  			return nil, fmt.Errorf("invalid node ID (%v)", err)
   147  		}
   148  		return NewNode(id, nil, 0, 0), nil
   149  	}
   150  	return parseComplete(rawurl)
   151  }
   152  
   153  func parseComplete(rawurl string) (*Node, error) {
   154  	var (
   155  		id               NodeID
   156  		ip               net.IP
   157  		tcpPort, udpPort uint64
   158  	)
   159  	u, err := url.Parse(rawurl)
   160  	if err != nil {
   161  		return nil, err
   162  	}
   163  	if u.Scheme != "enode" {
   164  		return nil, errors.New("invalid URL scheme, want \"enode\"")
   165  	}
   166  	// Parse the Node ID from the user portion.
   167  	if u.User == nil {
   168  		return nil, errors.New("does not contain node ID")
   169  	}
   170  	if id, err = HexID(u.User.String()); err != nil {
   171  		return nil, fmt.Errorf("invalid node ID (%v)", err)
   172  	}
   173  	// Parse the IP address.
   174  	host, port, err := net.SplitHostPort(u.Host)
   175  	if err != nil {
   176  		return nil, fmt.Errorf("invalid host: %v", err)
   177  	}
   178  	if ip = net.ParseIP(host); ip == nil {
   179  		return nil, errors.New("invalid IP address")
   180  	}
   181  	// Ensure the IP is 4 bytes long for IPv4 addresses.
   182  	if ipv4 := ip.To4(); ipv4 != nil {
   183  		ip = ipv4
   184  	}
   185  	// Parse the port numbers.
   186  	if tcpPort, err = strconv.ParseUint(port, 10, 16); err != nil {
   187  		return nil, errors.New("invalid port")
   188  	}
   189  	udpPort = tcpPort
   190  	qv := u.Query()
   191  	if qv.Get("discport") != "" {
   192  		udpPort, err = strconv.ParseUint(qv.Get("discport"), 10, 16)
   193  		if err != nil {
   194  			return nil, errors.New("invalid discport in query")
   195  		}
   196  	}
   197  	return NewNode(id, ip, uint16(udpPort), uint16(tcpPort)), nil
   198  }
   199  
   200  // MustParseNode parses a node URL. It panics if the URL is not valid.
   201  func MustParseNode(rawurl string) *Node {
   202  	n, err := ParseNode(rawurl)
   203  	if err != nil {
   204  		panic("invalid node URL: " + err.Error())
   205  	}
   206  	return n
   207  }
   208  
   209  // MarshalText implements encoding.TextMarshaler.
   210  func (n *Node) MarshalText() ([]byte, error) {
   211  	return []byte(n.String()), nil
   212  }
   213  
   214  // UnmarshalText implements encoding.TextUnmarshaler.
   215  func (n *Node) UnmarshalText(text []byte) error {
   216  	dec, err := ParseNode(string(text))
   217  	if err == nil {
   218  		*n = *dec
   219  	}
   220  	return err
   221  }
   222  
// NodeID is a unique identifier for each node.
// The node identifier is a marshaled elliptic curve public key: the X and Y
// coordinates concatenated, without the leading format byte that
// elliptic.Marshal emits (see PubkeyID). It is NodeIDBits/8 bytes long.
type NodeID [NodeIDBits / 8]byte
   226  
   227  // Bytes returns a byte slice representation of the NodeID
   228  func (n NodeID) Bytes() []byte {
   229  	return n[:]
   230  }
   231  
   232  // NodeID prints as a long hexadecimal number.
   233  func (n NodeID) String() string {
   234  	return fmt.Sprintf("%x", n[:])
   235  }
   236  
   237  // The Go syntax representation of a NodeID is a call to HexID.
   238  func (n NodeID) GoString() string {
   239  	return fmt.Sprintf("discover.HexID(\"%x\")", n[:])
   240  }
   241  
   242  // TerminalString returns a shortened hex string for terminal logging.
   243  func (n NodeID) TerminalString() string {
   244  	return hex.EncodeToString(n[:8])
   245  }
   246  
   247  // MarshalText implements the encoding.TextMarshaler interface.
   248  func (n NodeID) MarshalText() ([]byte, error) {
   249  	return []byte(hex.EncodeToString(n[:])), nil
   250  }
   251  
   252  // UnmarshalText implements the encoding.TextUnmarshaler interface.
   253  func (n *NodeID) UnmarshalText(text []byte) error {
   254  	id, err := HexID(string(text))
   255  	if err != nil {
   256  		return err
   257  	}
   258  	*n = id
   259  	return nil
   260  }
   261  
   262  // BytesID converts a byte slice to a NodeID
   263  func BytesID(b []byte) (NodeID, error) {
   264  	var id NodeID
   265  	if len(b) != len(id) {
   266  		return id, fmt.Errorf("wrong length, want %d bytes", len(id))
   267  	}
   268  	copy(id[:], b)
   269  	return id, nil
   270  }
   271  
   272  // MustBytesID converts a byte slice to a NodeID.
   273  // It panics if the byte slice is not a valid NodeID.
   274  func MustBytesID(b []byte) NodeID {
   275  	id, err := BytesID(b)
   276  	if err != nil {
   277  		panic(err)
   278  	}
   279  	return id
   280  }
   281  
   282  // HexID converts a hex string to a NodeID.
   283  // The string may be prefixed with 0x.
   284  func HexID(in string) (NodeID, error) {
   285  	var id NodeID
   286  	b, err := hex.DecodeString(strings.TrimPrefix(in, "0x"))
   287  	if err != nil {
   288  		return id, err
   289  	} else if len(b) != len(id) {
   290  		return id, fmt.Errorf("wrong length, want %d hex chars", len(id)*2)
   291  	}
   292  	copy(id[:], b)
   293  	return id, nil
   294  }
   295  
   296  // MustHexID converts a hex string to a NodeID.
   297  // It panics if the string is not a valid NodeID.
   298  func MustHexID(in string) NodeID {
   299  	id, err := HexID(in)
   300  	if err != nil {
   301  		panic(err)
   302  	}
   303  	return id
   304  }
   305  
   306  // PubkeyID returns a marshaled representation of the given public key.
   307  func PubkeyID(pub *ecdsa.PublicKey) NodeID {
   308  	var id NodeID
   309  	pbytes := elliptic.Marshal(pub.Curve, pub.X, pub.Y)
   310  	if len(pbytes)-1 != len(id) {
   311  		panic(fmt.Errorf("need %d bit pubkey, got %d bits", (len(id)+1)*8, len(pbytes)))
   312  	}
   313  	copy(id[:], pbytes[1:])
   314  	return id
   315  }
   316  
   317  // Pubkey returns the public key represented by the node ID.
   318  // It returns an error if the ID is not a point on the curve.
   319  func (id NodeID) Pubkey() (*ecdsa.PublicKey, error) {
   320  	p := &ecdsa.PublicKey{Curve: crypto.S256(), X: new(big.Int), Y: new(big.Int)}
   321  	half := len(id) / 2
   322  	p.X.SetBytes(id[:half])
   323  	p.Y.SetBytes(id[half:])
   324  	if !p.Curve.IsOnCurve(p.X, p.Y) {
   325  		return nil, errors.New("id is invalid secp256k1 curve point")
   326  	}
   327  	return p, nil
   328  }
   329  
   330  // recoverNodeID computes the public key used to sign the
   331  // given hash from the signature.
   332  func recoverNodeID(hash, sig []byte) (id NodeID, err error) {
   333  	pubkey, err := crypto.Ecrecover(hash, sig)
   334  	if err != nil {
   335  		return id, err
   336  	}
   337  	if len(pubkey)-1 != len(id) {
   338  		return id, fmt.Errorf("recovered pubkey has %d bits, want %d bits", len(pubkey)*8, (len(id)+1)*8)
   339  	}
   340  	for i := range id {
   341  		id[i] = pubkey[i+1]
   342  	}
   343  	return id, nil
   344  }
   345  
   346  // distcmp compares the distances a->target and b->target.
   347  // Returns -1 if a is closer to target, 1 if b is closer to target
   348  // and 0 if they are equal.
   349  func distcmp(target, a, b common.Hash) int {
   350  	for i := range target {
   351  		da := a[i] ^ target[i]
   352  		db := b[i] ^ target[i]
   353  		if da > db {
   354  			return 1
   355  		} else if da < db {
   356  			return -1
   357  		}
   358  	}
   359  	return 0
   360  }
   361  
   362  // table of leading zero counts for bytes [0..255]
   363  var lzcount = [256]int{
   364  	8, 7, 6, 6, 5, 5, 5, 5,
   365  	4, 4, 4, 4, 4, 4, 4, 4,
   366  	3, 3, 3, 3, 3, 3, 3, 3,
   367  	3, 3, 3, 3, 3, 3, 3, 3,
   368  	2, 2, 2, 2, 2, 2, 2, 2,
   369  	2, 2, 2, 2, 2, 2, 2, 2,
   370  	2, 2, 2, 2, 2, 2, 2, 2,
   371  	2, 2, 2, 2, 2, 2, 2, 2,
   372  	1, 1, 1, 1, 1, 1, 1, 1,
   373  	1, 1, 1, 1, 1, 1, 1, 1,
   374  	1, 1, 1, 1, 1, 1, 1, 1,
   375  	1, 1, 1, 1, 1, 1, 1, 1,
   376  	1, 1, 1, 1, 1, 1, 1, 1,
   377  	1, 1, 1, 1, 1, 1, 1, 1,
   378  	1, 1, 1, 1, 1, 1, 1, 1,
   379  	1, 1, 1, 1, 1, 1, 1, 1,
   380  	0, 0, 0, 0, 0, 0, 0, 0,
   381  	0, 0, 0, 0, 0, 0, 0, 0,
   382  	0, 0, 0, 0, 0, 0, 0, 0,
   383  	0, 0, 0, 0, 0, 0, 0, 0,
   384  	0, 0, 0, 0, 0, 0, 0, 0,
   385  	0, 0, 0, 0, 0, 0, 0, 0,
   386  	0, 0, 0, 0, 0, 0, 0, 0,
   387  	0, 0, 0, 0, 0, 0, 0, 0,
   388  	0, 0, 0, 0, 0, 0, 0, 0,
   389  	0, 0, 0, 0, 0, 0, 0, 0,
   390  	0, 0, 0, 0, 0, 0, 0, 0,
   391  	0, 0, 0, 0, 0, 0, 0, 0,
   392  	0, 0, 0, 0, 0, 0, 0, 0,
   393  	0, 0, 0, 0, 0, 0, 0, 0,
   394  	0, 0, 0, 0, 0, 0, 0, 0,
   395  	0, 0, 0, 0, 0, 0, 0, 0,
   396  }
   397  
   398  // logdist returns the logarithmic distance between a and b, log2(a ^ b).
   399  func logdist(a, b common.Hash) int {
   400  	lz := 0
   401  	for i := range a {
   402  		x := a[i] ^ b[i]
   403  		if x == 0 {
   404  			lz += 8
   405  		} else {
   406  			lz += lzcount[x]
   407  			break
   408  		}
   409  	}
   410  	return len(a)*8 - lz
   411  }
   412  
   413  // hashAtDistance returns a random hash such that logdist(a, b) == n
   414  func hashAtDistance(a common.Hash, n int) (b common.Hash) {
   415  	if n == 0 {
   416  		return a
   417  	}
   418  	// flip bit at position n, fill the rest with random bits
   419  	b = a
   420  	pos := len(a) - n/8 - 1
   421  	bit := byte(0x01) << (byte(n%8) - 1)
   422  	if bit == 0 {
   423  		pos++
   424  		bit = 0x80
   425  	}
   426  	b[pos] = a[pos]&^bit | ^a[pos]&bit // TODO: randomize end bits
   427  	for i := pos + 1; i < len(a); i++ {
   428  		b[i] = byte(rand.Intn(255))
   429  	}
   430  	return b
   431  }