github.com/arieschain/arieschain@v0.0.0-20191023063405-37c074544356/p2p/discover/node.go (about)

     1  package discover
     2  
     3  import (
     4  	"crypto/ecdsa"
     5  	"crypto/elliptic"
     6  	"encoding/hex"
     7  	"errors"
     8  	"fmt"
     9  	"math/big"
    10  	"math/rand"
    11  	"net"
    12  	"net/url"
    13  	"regexp"
    14  	"strconv"
    15  	"strings"
    16  	"time"
    17  
    18  	"github.com/quickchainproject/quickchain/common"
    19  	"github.com/quickchainproject/quickchain/crypto"
    20  	"github.com/quickchainproject/quickchain/crypto/secp256k1"
    21  )
    22  
// NodeIDBits is the size of a NodeID in bits. NodeID is declared as an
// array of NodeIDBits/8 bytes.
const NodeIDBits = 512
    24  
// Node represents a host on the network.
// The fields of Node may not be modified.
type Node struct {
	// IP is the node's address: len 4 for IPv4 or 16 for IPv6.
	// A nil IP marks the node as incomplete (see Incomplete).
	IP net.IP
	// UDP is the discovery port and TCP is the listening port
	// (see the URL format described on ParseNode).
	UDP, TCP uint16
	// ID is the node's public key.
	ID NodeID

	// This is a cached copy of sha3(ID) which is used for node
	// distance calculations. NewNode fills it with
	// crypto.Keccak256Hash(ID[:]). This is part of Node in order to
	// make it possible to write tests that need a node at a certain
	// distance. In those tests, the content of sha will not actually
	// correspond with ID.
	sha common.Hash

	// Time when the node was added to the table.
	addedAt time.Time
}
    42  
    43  // NewNode creates a new node. It is mostly meant to be used for
    44  // testing purposes.
    45  func NewNode(id NodeID, ip net.IP, udpPort, tcpPort uint16) *Node {
    46  	if ipv4 := ip.To4(); ipv4 != nil {
    47  		ip = ipv4
    48  	}
    49  	return &Node{
    50  		IP:  ip,
    51  		UDP: udpPort,
    52  		TCP: tcpPort,
    53  		ID:  id,
    54  		sha: crypto.Keccak256Hash(id[:]),
    55  	}
    56  }
    57  
    58  func (n *Node) addr() *net.UDPAddr {
    59  	return &net.UDPAddr{IP: n.IP, Port: int(n.UDP)}
    60  }
    61  
    62  // Incomplete returns true for nodes with no IP address.
    63  func (n *Node) Incomplete() bool {
    64  	return n.IP == nil
    65  }
    66  
    67  // checks whether n is a valid complete node.
    68  func (n *Node) validateComplete() error {
    69  	if n.Incomplete() {
    70  		return errors.New("incomplete node")
    71  	}
    72  	if n.UDP == 0 {
    73  		return errors.New("missing UDP port")
    74  	}
    75  	if n.TCP == 0 {
    76  		return errors.New("missing TCP port")
    77  	}
    78  	if n.IP.IsMulticast() || n.IP.IsUnspecified() {
    79  		return errors.New("invalid IP (multicast/unspecified)")
    80  	}
    81  	_, err := n.ID.Pubkey() // validate the key (on curve, etc.)
    82  	return err
    83  }
    84  
    85  // The string representation of a Node is a URL.
    86  // Please see ParseNode for a description of the format.
    87  func (n *Node) String() string {
    88  	u := url.URL{Scheme: "enode"}
    89  	if n.Incomplete() {
    90  		u.Host = fmt.Sprintf("%x", n.ID[:])
    91  	} else {
    92  		addr := net.TCPAddr{IP: n.IP, Port: int(n.TCP)}
    93  		u.User = url.User(fmt.Sprintf("%x", n.ID[:]))
    94  		u.Host = addr.String()
    95  		if n.UDP != n.TCP {
    96  			u.RawQuery = "discport=" + strconv.Itoa(int(n.UDP))
    97  		}
    98  	}
    99  	return u.String()
   100  }
   101  
   102  var incompleteNodeURL = regexp.MustCompile("(?i)^(?:enode://)?([0-9a-f]+)$")
   103  
   104  // ParseNode parses a node designator.
   105  //
   106  // There are two basic forms of node designators
   107  //   - incomplete nodes, which only have the public key (node ID)
   108  //   - complete nodes, which contain the public key and IP/Port information
   109  //
   110  // For incomplete nodes, the designator must look like one of these
   111  //
   112  //    enode://<hex node id>
   113  //    <hex node id>
   114  //
   115  // For complete nodes, the node ID is encoded in the username portion
   116  // of the URL, separated from the host by an @ sign. The hostname can
   117  // only be given as an IP address, DNS domain names are not allowed.
   118  // The port in the host name section is the TCP listening port. If the
   119  // TCP and UDP (discovery) ports differ, the UDP port is specified as
   120  // query parameter "discport".
   121  //
   122  // In the following example, the node URL describes
   123  // a node with IP address 10.3.58.6, TCP listening port 36663
   124  // and UDP discovery port 36661.
   125  //
   126  //    enode://<hex node id>@10.3.58.6:36663?discport=36661
   127  func ParseNode(rawurl string) (*Node, error) {
   128  	if m := incompleteNodeURL.FindStringSubmatch(rawurl); m != nil {
   129  		id, err := HexID(m[1])
   130  		if err != nil {
   131  			return nil, fmt.Errorf("invalid node ID (%v)", err)
   132  		}
   133  		return NewNode(id, nil, 0, 0), nil
   134  	}
   135  	return parseComplete(rawurl)
   136  }
   137  
   138  func parseComplete(rawurl string) (*Node, error) {
   139  	var (
   140  		id               NodeID
   141  		ip               net.IP
   142  		tcpPort, udpPort uint64
   143  	)
   144  	u, err := url.Parse(rawurl)
   145  	if err != nil {
   146  		return nil, err
   147  	}
   148  	if u.Scheme != "enode" {
   149  		return nil, errors.New("invalid URL scheme, want \"enode\"")
   150  	}
   151  	// Parse the Node ID from the user portion.
   152  	if u.User == nil {
   153  		return nil, errors.New("does not contain node ID")
   154  	}
   155  	if id, err = HexID(u.User.String()); err != nil {
   156  		return nil, fmt.Errorf("invalid node ID (%v)", err)
   157  	}
   158  	// Parse the IP address.
   159  	host, port, err := net.SplitHostPort(u.Host)
   160  	if err != nil {
   161  		return nil, fmt.Errorf("invalid host: %v", err)
   162  	}
   163  	if ip = net.ParseIP(host); ip == nil {
   164  		return nil, errors.New("invalid IP address")
   165  	}
   166  	// Ensure the IP is 4 bytes long for IPv4 addresses.
   167  	if ipv4 := ip.To4(); ipv4 != nil {
   168  		ip = ipv4
   169  	}
   170  	// Parse the port numbers.
   171  	if tcpPort, err = strconv.ParseUint(port, 10, 16); err != nil {
   172  		return nil, errors.New("invalid port")
   173  	}
   174  	udpPort = tcpPort
   175  	qv := u.Query()
   176  	if qv.Get("discport") != "" {
   177  		udpPort, err = strconv.ParseUint(qv.Get("discport"), 10, 16)
   178  		if err != nil {
   179  			return nil, errors.New("invalid discport in query")
   180  		}
   181  	}
   182  	return NewNode(id, ip, uint16(udpPort), uint16(tcpPort)), nil
   183  }
   184  
   185  // MustParseNode parses a node URL. It panics if the URL is not valid.
   186  func MustParseNode(rawurl string) *Node {
   187  	n, err := ParseNode(rawurl)
   188  	if err != nil {
   189  		panic("invalid node URL: " + err.Error())
   190  	}
   191  	return n
   192  }
   193  
   194  // MarshalText implements encoding.TextMarshaler.
   195  func (n *Node) MarshalText() ([]byte, error) {
   196  	return []byte(n.String()), nil
   197  }
   198  
   199  // UnmarshalText implements encoding.TextUnmarshaler.
   200  func (n *Node) UnmarshalText(text []byte) error {
   201  	dec, err := ParseNode(string(text))
   202  	if err == nil {
   203  		*n = *dec
   204  	}
   205  	return err
   206  }
   207  
// NodeID is a unique identifier for each node.
// The node identifier is a marshaled elliptic curve public key.
// It is a fixed-size array of NodeIDBits/8 bytes.
type NodeID [NodeIDBits / 8]byte
   211  
   212  // Bytes returns a byte slice representation of the NodeID
   213  func (n NodeID) Bytes() []byte {
   214  	return n[:]
   215  }
   216  
   217  // NodeID prints as a long hexadecimal number.
   218  func (n NodeID) String() string {
   219  	return fmt.Sprintf("%x", n[:])
   220  }
   221  
   222  // The Go syntax representation of a NodeID is a call to HexID.
   223  func (n NodeID) GoString() string {
   224  	return fmt.Sprintf("discover.HexID(\"%x\")", n[:])
   225  }
   226  
   227  // TerminalString returns a shortened hex string for terminal logging.
   228  func (n NodeID) TerminalString() string {
   229  	return hex.EncodeToString(n[:8])
   230  }
   231  
   232  // MarshalText implements the encoding.TextMarshaler interface.
   233  func (n NodeID) MarshalText() ([]byte, error) {
   234  	return []byte(hex.EncodeToString(n[:])), nil
   235  }
   236  
   237  // UnmarshalText implements the encoding.TextUnmarshaler interface.
   238  func (n *NodeID) UnmarshalText(text []byte) error {
   239  	id, err := HexID(string(text))
   240  	if err != nil {
   241  		return err
   242  	}
   243  	*n = id
   244  	return nil
   245  }
   246  
   247  // BytesID converts a byte slice to a NodeID
   248  func BytesID(b []byte) (NodeID, error) {
   249  	var id NodeID
   250  	if len(b) != len(id) {
   251  		return id, fmt.Errorf("wrong length, want %d bytes", len(id))
   252  	}
   253  	copy(id[:], b)
   254  	return id, nil
   255  }
   256  
   257  // MustBytesID converts a byte slice to a NodeID.
   258  // It panics if the byte slice is not a valid NodeID.
   259  func MustBytesID(b []byte) NodeID {
   260  	id, err := BytesID(b)
   261  	if err != nil {
   262  		panic(err)
   263  	}
   264  	return id
   265  }
   266  
   267  // HexID converts a hex string to a NodeID.
   268  // The string may be prefixed with 0x.
   269  func HexID(in string) (NodeID, error) {
   270  	var id NodeID
   271  	b, err := hex.DecodeString(strings.TrimPrefix(in, "0x"))
   272  	if err != nil {
   273  		return id, err
   274  	} else if len(b) != len(id) {
   275  		return id, fmt.Errorf("wrong length, want %d hex chars", len(id)*2)
   276  	}
   277  	copy(id[:], b)
   278  	return id, nil
   279  }
   280  
   281  // MustHexID converts a hex string to a NodeID.
   282  // It panics if the string is not a valid NodeID.
   283  func MustHexID(in string) NodeID {
   284  	id, err := HexID(in)
   285  	if err != nil {
   286  		panic(err)
   287  	}
   288  	return id
   289  }
   290  
   291  // PubkeyID returns a marshaled representation of the given public key.
   292  func PubkeyID(pub *ecdsa.PublicKey) NodeID {
   293  	var id NodeID
   294  	pbytes := elliptic.Marshal(pub.Curve, pub.X, pub.Y)
   295  	if len(pbytes)-1 != len(id) {
   296  		panic(fmt.Errorf("need %d bit pubkey, got %d bits", (len(id)+1)*8, len(pbytes)*8))
   297  	}
   298  	copy(id[:], pbytes[1:])
   299  	return id
   300  }
   301  
   302  // Pubkey returns the public key represented by the node ID.
   303  // It returns an error if the ID is not a point on the curve.
   304  func (id NodeID) Pubkey() (*ecdsa.PublicKey, error) {
   305  	p := &ecdsa.PublicKey{Curve: crypto.S256(), X: new(big.Int), Y: new(big.Int)}
   306  	half := len(id) / 2
   307  	p.X.SetBytes(id[:half])
   308  	p.Y.SetBytes(id[half:])
   309  	if !p.Curve.IsOnCurve(p.X, p.Y) {
   310  		return nil, errors.New("id is invalid secp256k1 curve point")
   311  	}
   312  	return p, nil
   313  }
   314  
   315  // recoverNodeID computes the public key used to sign the
   316  // given hash from the signature.
   317  func recoverNodeID(hash, sig []byte) (id NodeID, err error) {
   318  	pubkey, err := secp256k1.RecoverPubkey(hash, sig)
   319  	if err != nil {
   320  		return id, err
   321  	}
   322  	if len(pubkey)-1 != len(id) {
   323  		return id, fmt.Errorf("recovered pubkey has %d bits, want %d bits", len(pubkey)*8, (len(id)+1)*8)
   324  	}
   325  	for i := range id {
   326  		id[i] = pubkey[i+1]
   327  	}
   328  	return id, nil
   329  }
   330  
   331  // distcmp compares the distances a->target and b->target.
   332  // Returns -1 if a is closer to target, 1 if b is closer to target
   333  // and 0 if they are equal.
   334  func distcmp(target, a, b common.Hash) int {
   335  	for i := range target {
   336  		da := a[i] ^ target[i]
   337  		db := b[i] ^ target[i]
   338  		if da > db {
   339  			return 1
   340  		} else if da < db {
   341  			return -1
   342  		}
   343  	}
   344  	return 0
   345  }
   346  
   347  // table of leading zero counts for bytes [0..255]
   348  var lzcount = [256]int{
   349  	8, 7, 6, 6, 5, 5, 5, 5,
   350  	4, 4, 4, 4, 4, 4, 4, 4,
   351  	3, 3, 3, 3, 3, 3, 3, 3,
   352  	3, 3, 3, 3, 3, 3, 3, 3,
   353  	2, 2, 2, 2, 2, 2, 2, 2,
   354  	2, 2, 2, 2, 2, 2, 2, 2,
   355  	2, 2, 2, 2, 2, 2, 2, 2,
   356  	2, 2, 2, 2, 2, 2, 2, 2,
   357  	1, 1, 1, 1, 1, 1, 1, 1,
   358  	1, 1, 1, 1, 1, 1, 1, 1,
   359  	1, 1, 1, 1, 1, 1, 1, 1,
   360  	1, 1, 1, 1, 1, 1, 1, 1,
   361  	1, 1, 1, 1, 1, 1, 1, 1,
   362  	1, 1, 1, 1, 1, 1, 1, 1,
   363  	1, 1, 1, 1, 1, 1, 1, 1,
   364  	1, 1, 1, 1, 1, 1, 1, 1,
   365  	0, 0, 0, 0, 0, 0, 0, 0,
   366  	0, 0, 0, 0, 0, 0, 0, 0,
   367  	0, 0, 0, 0, 0, 0, 0, 0,
   368  	0, 0, 0, 0, 0, 0, 0, 0,
   369  	0, 0, 0, 0, 0, 0, 0, 0,
   370  	0, 0, 0, 0, 0, 0, 0, 0,
   371  	0, 0, 0, 0, 0, 0, 0, 0,
   372  	0, 0, 0, 0, 0, 0, 0, 0,
   373  	0, 0, 0, 0, 0, 0, 0, 0,
   374  	0, 0, 0, 0, 0, 0, 0, 0,
   375  	0, 0, 0, 0, 0, 0, 0, 0,
   376  	0, 0, 0, 0, 0, 0, 0, 0,
   377  	0, 0, 0, 0, 0, 0, 0, 0,
   378  	0, 0, 0, 0, 0, 0, 0, 0,
   379  	0, 0, 0, 0, 0, 0, 0, 0,
   380  	0, 0, 0, 0, 0, 0, 0, 0,
   381  }
   382  
   383  // logdist returns the logarithmic distance between a and b, log2(a ^ b).
   384  func logdist(a, b common.Hash) int {
   385  	lz := 0
   386  	for i := range a {
   387  		x := a[i] ^ b[i]
   388  		if x == 0 {
   389  			lz += 8
   390  		} else {
   391  			lz += lzcount[x]
   392  			break
   393  		}
   394  	}
   395  	return len(a)*8 - lz
   396  }
   397  
   398  // hashAtDistance returns a random hash such that logdist(a, b) == n
   399  func hashAtDistance(a common.Hash, n int) (b common.Hash) {
   400  	if n == 0 {
   401  		return a
   402  	}
   403  	// flip bit at position n, fill the rest with random bits
   404  	b = a
   405  	pos := len(a) - n/8 - 1
   406  	bit := byte(0x01) << (byte(n%8) - 1)
   407  	if bit == 0 {
   408  		pos++
   409  		bit = 0x80
   410  	}
   411  	b[pos] = a[pos]&^bit | ^a[pos]&bit // TODO: randomize end bits
   412  	for i := pos + 1; i < len(a); i++ {
   413  		b[i] = byte(rand.Intn(255))
   414  	}
   415  	return b
   416  }