github.com/klaytn/klaytn@v1.12.1/networks/p2p/discover/node.go (about)

     1  // Modifications Copyright 2018 The klaytn Authors
     2  // Copyright 2015 The go-ethereum Authors
     3  // This file is part of the go-ethereum library.
     4  //
     5  // The go-ethereum library is free software: you can redistribute it and/or modify
     6  // it under the terms of the GNU Lesser General Public License as published by
     7  // the Free Software Foundation, either version 3 of the License, or
     8  // (at your option) any later version.
     9  //
    10  // The go-ethereum library is distributed in the hope that it will be useful,
    11  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    12  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    13  // GNU Lesser General Public License for more details.
    14  //
    15  // You should have received a copy of the GNU Lesser General Public License
    16  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    17  //
    18  // This file is derived from p2p/discover/node.go (2018/06/04).
    19  // Modified and improved for the klaytn development.
    20  
    21  package discover
    22  
    23  import (
    24  	"crypto/ecdsa"
    25  	"crypto/elliptic"
    26  	"encoding/hex"
    27  	"errors"
    28  	"fmt"
    29  	"math/big"
    30  	"math/rand"
    31  	"net"
    32  	"net/url"
    33  	"regexp"
    34  	"strconv"
    35  	"strings"
    36  	"time"
    37  
    38  	"github.com/klaytn/klaytn/common"
    39  	"github.com/klaytn/klaytn/crypto"
    40  	"github.com/klaytn/klaytn/crypto/secp256k1"
    41  )
    42  
    43  const NodeIDBits = 512
    44  
// Node represents a host on the network.
// Nodes are built by NewNode (directly or via ParseNode). The fields of Node
// may not be modified by users of the type; within this file the exception
// is AddSubport, which appends to TCPs.
type Node struct {
	IP    net.IP   // len 4 for IPv4 or 16 for IPv6; NewNode stores IPv4 addresses in 4-byte form
	UDP   uint16   // discovery port number (UDP)
	TCP   uint16   // main TCP listening port number
	TCPs  []uint16 // TCP listening port number including both main port and subports
	ID    NodeID   // the node's public key
	NType NodeType // the node's type (cn, pn, en, bn)

	// This is a cached copy of sha3(ID) which is used for node
	// distance calculations. This is part of Node in order to make it
	// possible to write tests that need a node at a certain distance.
	// In those tests, the content of sha will not actually correspond
	// with ID. NewNode fills it with crypto.Keccak256Hash(ID).
	sha common.Hash

	// Time when the node was added to the table.
	addedAt time.Time
	// PortOrder is the order of the ports that should be connected in multi-channel.
	PortOrder uint16
}
    67  
    68  // NewNode creates a new node. It is mostly meant to be used for
    69  // testing purposes.
    70  func NewNode(id NodeID, ip net.IP, udpPort, tcpPort uint16, tcpSubports []uint16, nType NodeType) *Node {
    71  	if ipv4 := ip.To4(); ipv4 != nil {
    72  		ip = ipv4
    73  	}
    74  	node := &Node{
    75  		IP:    ip,
    76  		UDP:   udpPort,
    77  		TCP:   tcpPort,
    78  		TCPs:  tcpSubports,
    79  		ID:    id,
    80  		NType: nType,
    81  		sha:   crypto.Keccak256Hash(id[:]),
    82  	}
    83  	node.AddSubport(tcpPort)
    84  	return node
    85  }
    86  
    87  func (n *Node) addr() *net.UDPAddr {
    88  	return &net.UDPAddr{IP: n.IP, Port: int(n.UDP)}
    89  }
    90  
    91  // Incomplete returns true for nodes with no IP address.
    92  func (n *Node) Incomplete() bool {
    93  	return n.IP == nil
    94  }
    95  
    96  // checks whether n is a valid complete node.
    97  func (n *Node) validateComplete() error {
    98  	if n.Incomplete() {
    99  		return errors.New("incomplete node")
   100  	}
   101  	if n.UDP == 0 {
   102  		return errors.New("missing UDP port")
   103  	}
   104  	if n.TCP == 0 {
   105  		return errors.New("missing TCP port")
   106  	}
   107  	if n.IP.IsMulticast() || n.IP.IsUnspecified() {
   108  		return errors.New("invalid IP (multicast/unspecified)")
   109  	}
   110  	_, err := n.ID.Pubkey() // validate the key (on curve, etc.)
   111  	return err
   112  }
   113  
   114  // The string representation of a Node is a URL.
   115  // Please see ParseNode for a description of the format.
   116  func (n *Node) String() string {
   117  	u := url.URL{Scheme: "kni"}
   118  	var query []string
   119  	if n.Incomplete() {
   120  		u.Host = fmt.Sprintf("%x", n.ID[:])
   121  	} else {
   122  		addr := net.TCPAddr{IP: n.IP, Port: int(n.TCP)}
   123  		u.User = url.User(fmt.Sprintf("%x", n.ID[:]))
   124  		u.Host = addr.String()
   125  		for _, tcp := range n.TCPs {
   126  			if tcp != n.TCP {
   127  				query = append(query, "subport="+strconv.Itoa(int(tcp)))
   128  			}
   129  		}
   130  		if n.UDP != n.TCP {
   131  			query = append(query, "discport="+strconv.Itoa(int(n.UDP)))
   132  		}
   133  	}
   134  	if n.NType != NodeTypeUnknown {
   135  		query = append(query, "ntype="+StringNodeType(n.NType))
   136  	}
   137  	u.RawQuery = strings.Join(query, "&")
   138  	return u.String()
   139  }
   140  
   141  // TODO-Klaytn-NodeDiscovery: Deprecate supporting "enode"
   142  var incompleteNodeURL = regexp.MustCompile("(?i)^(?:kni://|enode://)?([0-9a-f]+)$")
   143  var lookupIPFunc = net.LookupIP
   144  
   145  // ParseNode parses a node designator.
   146  //
   147  // There are two basic forms of node designators
   148  //   - incomplete nodes, which only have the public key (node ID)
   149  //   - complete nodes, which contain the public key and IP/Port information
   150  //
   151  // For incomplete nodes, the designator must look like one of these
   152  //
   153  //    kni://<hex node id> or enode://<hex node id>
   154  //    <hex node id>
   155  //
   156  // For complete nodes, the node ID is encoded in the username portion
   157  // of the URL, separated from the host by an @ sign. The hostname can
   158  // only be given as an IP address, DNS domain names are not allowed.
   159  // The port in the host name section is the TCP listening port. If the
   160  // TCP and UDP (discovery) ports differ, the UDP port is specified as
   161  // query parameter "discport".
   162  //
   163  // In the following examples, the node URL describes
   164  // a node with IP address 10.3.58.6, TCP listening port 30303
   165  // and UDP discovery port 30301.
   166  //
   167  //    kni://<hex node id>@10.3.58.6:30303?&subport=30304&discport=30301[&ntype=cn|pn|en|bn]
   168  //    enode://<hex node id>@10.3.58.6:30303?discport=30301[&ntype=cn|pn|en|bn]
   169  func ParseNode(rawurl string) (*Node, error) {
   170  	if m := incompleteNodeURL.FindStringSubmatch(rawurl); m != nil {
   171  		id, err := HexID(m[1])
   172  		if err != nil {
   173  			return nil, fmt.Errorf("invalid node ID (%v)", err)
   174  		}
   175  		return NewNode(id, nil, 0, 0, nil, NodeTypeUnknown), nil
   176  	}
   177  	return parseComplete(rawurl)
   178  }
   179  
   180  func parseComplete(rawurl string) (*Node, error) {
   181  	var (
   182  		id               NodeID
   183  		tcpPort, udpPort uint64
   184  		tcpSubports      []uint16
   185  	)
   186  	u, err := url.Parse(rawurl)
   187  	if err != nil {
   188  		return nil, err
   189  	}
   190  	if u.Scheme != "kni" && u.Scheme != "enode" {
   191  		return nil, errors.New("invalid URL scheme, want \"kni\"")
   192  	}
   193  	// Parse the Node ID from the user portion.
   194  	if u.User == nil {
   195  		return nil, errors.New("does not contain node ID")
   196  	}
   197  	if id, err = HexID(u.User.String()); err != nil {
   198  		return nil, fmt.Errorf("invalid node ID (%v)", err)
   199  	}
   200  
   201  	// Parse the host address and port.
   202  	ip := net.ParseIP(u.Hostname())
   203  	if ip == nil {
   204  		ips, err := lookupIPFunc(u.Hostname())
   205  		if err != nil {
   206  			return nil, err
   207  		}
   208  		ip = ips[0]
   209  	}
   210  
   211  	// Ensure the IP is 4 bytes long for IPv4 addresses.
   212  	if ipv4 := ip.To4(); ipv4 != nil {
   213  		ip = ipv4
   214  	}
   215  	// Parse the port numbers.
   216  	if tcpPort, err = strconv.ParseUint(u.Port(), 10, 16); err != nil {
   217  		return nil, errors.New("invalid port")
   218  	}
   219  	// Extract subport from query
   220  	qv := u.Query()
   221  	if qv.Get("subport") != "" {
   222  		for _, subport := range qv["subport"] {
   223  			if p, err := strconv.ParseUint(subport, 10, 16); err != nil {
   224  				logger.Warn("skipping invalid subport in query", "id", id, "ip", ip, "subport", p)
   225  			} else {
   226  				tcpSubports = append(tcpSubports, uint16(p))
   227  			}
   228  		}
   229  	}
   230  	// Extract discovery port from query
   231  	udpPort = tcpPort
   232  	if qv.Get("discport") != "" {
   233  		udpPort, err = strconv.ParseUint(qv.Get("discport"), 10, 16)
   234  		if err != nil {
   235  			return nil, errors.New("invalid discport in query")
   236  		}
   237  	}
   238  
   239  	nType := NodeTypeUnknown
   240  	if qv.Get("ntype") != "" {
   241  		nType = ParseNodeType(qv.Get("ntype"))
   242  	}
   243  	return NewNode(id, ip, uint16(udpPort), uint16(tcpPort), tcpSubports, nType), nil
   244  }
   245  
   246  // MustParseNode parses a node URL. It panics if the URL is not valid.
   247  func MustParseNode(rawurl string) *Node {
   248  	n, err := ParseNode(rawurl)
   249  	if err != nil {
   250  		panic("invalid node URL: " + err.Error())
   251  	}
   252  	return n
   253  }
   254  
   255  // MarshalText implements encoding.TextMarshaler.
   256  func (n *Node) MarshalText() ([]byte, error) {
   257  	return []byte(n.String()), nil
   258  }
   259  
   260  // UnmarshalText implements encoding.TextUnmarshaler.
   261  func (n *Node) UnmarshalText(text []byte) error {
   262  	dec, err := ParseNode(string(text))
   263  	if err == nil {
   264  		*n = *dec
   265  	}
   266  	return err
   267  }
   268  
   269  // CompareNode implements the compare the all node field and return its result
   270  func (n *Node) CompareNode(tn *Node) bool {
   271  	if n.ID != tn.ID {
   272  		return false
   273  	}
   274  	if n.TCP != tn.TCP {
   275  		return false
   276  	}
   277  	if n.UDP != tn.UDP {
   278  		return false
   279  	}
   280  	return true
   281  }
   282  
   283  // AddSubport adds a new port to TCPs
   284  // TCPs contains unique ports
   285  func (n *Node) AddSubport(port uint16) {
   286  	if n.TCPs == nil {
   287  		n.TCPs = []uint16{}
   288  	}
   289  	for _, val := range n.TCPs {
   290  		if val == port {
   291  			return
   292  		}
   293  	}
   294  	n.TCPs = append(n.TCPs, port)
   295  }
   296  
// NodeID is a unique identifier for each node.
// The node identifier is a marshaled elliptic curve public key: NodeIDBits/8
// (64) bytes holding the X and Y coordinates without the leading format byte
// (see PubkeyID and NodeID.Pubkey).
type NodeID [NodeIDBits / 8]byte
   300  
   301  // Bytes returns a byte slice representation of the NodeID
   302  func (n NodeID) Bytes() []byte {
   303  	return n[:]
   304  }
   305  
   306  // NodeID prints as a long hexadecimal number.
   307  func (n NodeID) String() string {
   308  	return fmt.Sprintf("%x", n[:])
   309  }
   310  
   311  // The Go syntax representation of a NodeID is a call to HexID.
   312  func (n NodeID) GoString() string {
   313  	return fmt.Sprintf("discover.HexID(\"%x\")", n[:])
   314  }
   315  
   316  // TerminalString returns a shortened hex string for terminal logging.
   317  func (n NodeID) TerminalString() string {
   318  	return hex.EncodeToString(n[:8])
   319  }
   320  
   321  // ShortString returns a shortened 4 digits hex string for logging
   322  func (n NodeID) ShortString() string {
   323  	return hex.EncodeToString(n[:4])
   324  }
   325  
   326  // MarshalText implements the encoding.TextMarshaler interface.
   327  func (n NodeID) MarshalText() ([]byte, error) {
   328  	return []byte(hex.EncodeToString(n[:])), nil
   329  }
   330  
   331  // UnmarshalText implements the encoding.TextUnmarshaler interface.
   332  func (n *NodeID) UnmarshalText(text []byte) error {
   333  	id, err := HexID(string(text))
   334  	if err != nil {
   335  		return err
   336  	}
   337  	*n = id
   338  	return nil
   339  }
   340  
   341  // BytesID converts a byte slice to a NodeID
   342  func BytesID(b []byte) (NodeID, error) {
   343  	var id NodeID
   344  	if len(b) != len(id) {
   345  		return id, fmt.Errorf("wrong length, want %d bytes", len(id))
   346  	}
   347  	copy(id[:], b)
   348  	return id, nil
   349  }
   350  
   351  // MustBytesID converts a byte slice to a NodeID.
   352  // It panics if the byte slice is not a valid NodeID.
   353  func MustBytesID(b []byte) NodeID {
   354  	id, err := BytesID(b)
   355  	if err != nil {
   356  		panic(err)
   357  	}
   358  	return id
   359  }
   360  
   361  // HexID converts a hex string to a NodeID.
   362  // The string may be prefixed with 0x.
   363  func HexID(in string) (NodeID, error) {
   364  	var id NodeID
   365  	b, err := hex.DecodeString(strings.TrimPrefix(in, "0x"))
   366  	if err != nil {
   367  		return id, err
   368  	} else if len(b) != len(id) {
   369  		return id, fmt.Errorf("wrong length, want %d hex chars", len(id)*2)
   370  	}
   371  	copy(id[:], b)
   372  	return id, nil
   373  }
   374  
   375  // MustHexID converts a hex string to a NodeID.
   376  // It panics if the string is not a valid NodeID.
   377  func MustHexID(in string) NodeID {
   378  	id, err := HexID(in)
   379  	if err != nil {
   380  		panic(err)
   381  	}
   382  	return id
   383  }
   384  
   385  // PubkeyID returns a marshaled representation of the given public key.
   386  func PubkeyID(pub *ecdsa.PublicKey) NodeID {
   387  	var id NodeID
   388  	pbytes := elliptic.Marshal(pub.Curve, pub.X, pub.Y)
   389  	if len(pbytes)-1 != len(id) {
   390  		panic(fmt.Errorf("need %d bit pubkey, got %d bits", (len(id)+1)*8, len(pbytes)))
   391  	}
   392  	copy(id[:], pbytes[1:])
   393  	return id
   394  }
   395  
   396  // Pubkey returns the public key represented by the node ID.
   397  // It returns an error if the ID is not a point on the curve.
   398  func (id NodeID) Pubkey() (*ecdsa.PublicKey, error) {
   399  	p := &ecdsa.PublicKey{Curve: crypto.S256(), X: new(big.Int), Y: new(big.Int)}
   400  	half := len(id) / 2
   401  	p.X.SetBytes(id[:half])
   402  	p.Y.SetBytes(id[half:])
   403  	if !p.Curve.IsOnCurve(p.X, p.Y) {
   404  		return nil, errors.New("id is invalid secp256k1 curve point")
   405  	}
   406  	return p, nil
   407  }
   408  
   409  // recoverNodeID computes the public key used to sign the
   410  // given hash from the signature.
   411  func recoverNodeID(hash, sig []byte) (id NodeID, err error) {
   412  	pubkey, err := secp256k1.RecoverPubkey(hash, sig)
   413  	if err != nil {
   414  		return id, err
   415  	}
   416  	if len(pubkey)-1 != len(id) {
   417  		return id, fmt.Errorf("recovered pubkey has %d bits, want %d bits", len(pubkey)*8, (len(id)+1)*8)
   418  	}
   419  	for i := range id {
   420  		id[i] = pubkey[i+1]
   421  	}
   422  	return id, nil
   423  }
   424  
   425  // distcmp compares the distances a->target and b->target.
   426  // Returns -1 if a is closer to target, 1 if b is closer to target
   427  // and 0 if they are equal.
   428  func distcmp(target, a, b common.Hash) int {
   429  	for i := range target {
   430  		da := a[i] ^ target[i]
   431  		db := b[i] ^ target[i]
   432  		if da > db {
   433  			return 1
   434  		} else if da < db {
   435  			return -1
   436  		}
   437  	}
   438  	return 0
   439  }
   440  
   441  // table of leading zero counts for bytes [0..255]
   442  var lzcount = [256]int{
   443  	8, 7, 6, 6, 5, 5, 5, 5,
   444  	4, 4, 4, 4, 4, 4, 4, 4,
   445  	3, 3, 3, 3, 3, 3, 3, 3,
   446  	3, 3, 3, 3, 3, 3, 3, 3,
   447  	2, 2, 2, 2, 2, 2, 2, 2,
   448  	2, 2, 2, 2, 2, 2, 2, 2,
   449  	2, 2, 2, 2, 2, 2, 2, 2,
   450  	2, 2, 2, 2, 2, 2, 2, 2,
   451  	1, 1, 1, 1, 1, 1, 1, 1,
   452  	1, 1, 1, 1, 1, 1, 1, 1,
   453  	1, 1, 1, 1, 1, 1, 1, 1,
   454  	1, 1, 1, 1, 1, 1, 1, 1,
   455  	1, 1, 1, 1, 1, 1, 1, 1,
   456  	1, 1, 1, 1, 1, 1, 1, 1,
   457  	1, 1, 1, 1, 1, 1, 1, 1,
   458  	1, 1, 1, 1, 1, 1, 1, 1,
   459  	0, 0, 0, 0, 0, 0, 0, 0,
   460  	0, 0, 0, 0, 0, 0, 0, 0,
   461  	0, 0, 0, 0, 0, 0, 0, 0,
   462  	0, 0, 0, 0, 0, 0, 0, 0,
   463  	0, 0, 0, 0, 0, 0, 0, 0,
   464  	0, 0, 0, 0, 0, 0, 0, 0,
   465  	0, 0, 0, 0, 0, 0, 0, 0,
   466  	0, 0, 0, 0, 0, 0, 0, 0,
   467  	0, 0, 0, 0, 0, 0, 0, 0,
   468  	0, 0, 0, 0, 0, 0, 0, 0,
   469  	0, 0, 0, 0, 0, 0, 0, 0,
   470  	0, 0, 0, 0, 0, 0, 0, 0,
   471  	0, 0, 0, 0, 0, 0, 0, 0,
   472  	0, 0, 0, 0, 0, 0, 0, 0,
   473  	0, 0, 0, 0, 0, 0, 0, 0,
   474  	0, 0, 0, 0, 0, 0, 0, 0,
   475  }
   476  
   477  // logdist returns the logarithmic distance between a and b, log2(a ^ b).
   478  func logdist(a, b common.Hash) int {
   479  	lz := 0
   480  	for i := range a {
   481  		x := a[i] ^ b[i]
   482  		if x == 0 {
   483  			lz += 8
   484  		} else {
   485  			lz += lzcount[x]
   486  			break
   487  		}
   488  	}
   489  	return len(a)*8 - lz
   490  }
   491  
   492  // hashAtDistance returns a random hash such that logdist(a, b) == n
   493  func hashAtDistance(a common.Hash, n int) (b common.Hash) {
   494  	if n == 0 {
   495  		return a
   496  	}
   497  	// flip bit at position n, fill the rest with random bits
   498  	b = a
   499  	pos := len(a) - n/8 - 1
   500  	bit := byte(0x01) << (byte(n%8) - 1)
   501  	if bit == 0 {
   502  		pos++
   503  		bit = 0x80
   504  	}
   505  	b[pos] = a[pos]&^bit | ^a[pos]&bit // TODO: randomize end bits
   506  	for i := pos + 1; i < len(a); i++ {
   507  		b[i] = byte(rand.Intn(255))
   508  	}
   509  	return b
   510  }