github.com/neatio-net/neatio@v1.7.3-0.20231114194659-f4d7a2226baa/network/p2p/discover/node.go (about)

     1  package discover
     2  
     3  import (
     4  	"crypto/ecdsa"
     5  	"crypto/elliptic"
     6  	"encoding/hex"
     7  	"errors"
     8  	"fmt"
     9  	"math/big"
    10  	"math/rand"
    11  	"net"
    12  	"net/url"
    13  	"regexp"
    14  	"strconv"
    15  	"strings"
    16  	"time"
    17  
    18  	"github.com/neatio-net/neatio/utilities/common"
    19  	"github.com/neatio-net/neatio/utilities/crypto"
    20  	"github.com/neatio-net/neatio/utilities/crypto/secp256k1"
    21  )
    22  
    23  const NodeIDBits = 512
    24  
    25  type Node struct {
    26  	IP       net.IP
    27  	UDP, TCP uint16
    28  	ID       NodeID
    29  
    30  	sha common.Hash
    31  
    32  	addedAt time.Time
    33  }
    34  
    35  func NewNode(id NodeID, ip net.IP, udpPort, tcpPort uint16) *Node {
    36  	if ipv4 := ip.To4(); ipv4 != nil {
    37  		ip = ipv4
    38  	}
    39  	return &Node{
    40  		IP:  ip,
    41  		UDP: udpPort,
    42  		TCP: tcpPort,
    43  		ID:  id,
    44  		sha: crypto.Keccak256Hash(id[:]),
    45  	}
    46  }
    47  
    48  func (n *Node) addr() *net.UDPAddr {
    49  	return &net.UDPAddr{IP: n.IP, Port: int(n.UDP)}
    50  }
    51  
    52  func (n *Node) Incomplete() bool {
    53  	return n.IP == nil
    54  }
    55  
    56  func (n *Node) validateComplete() error {
    57  	if n.Incomplete() {
    58  		return errors.New("incomplete node")
    59  	}
    60  	if n.UDP == 0 {
    61  		return errors.New("missing UDP port")
    62  	}
    63  	if n.TCP == 0 {
    64  		return errors.New("missing TCP port")
    65  	}
    66  	if n.IP.IsMulticast() || n.IP.IsUnspecified() {
    67  		return errors.New("invalid IP (multicast/unspecified)")
    68  	}
    69  	_, err := n.ID.Pubkey()
    70  	return err
    71  }
    72  
    73  func (n *Node) String() string {
    74  	u := url.URL{Scheme: "enode"}
    75  	if n.Incomplete() {
    76  		u.Host = fmt.Sprintf("%x", n.ID[:])
    77  	} else {
    78  		addr := net.TCPAddr{IP: n.IP, Port: int(n.TCP)}
    79  		u.User = url.User(fmt.Sprintf("%x", n.ID[:]))
    80  		u.Host = addr.String()
    81  		if n.UDP != n.TCP {
    82  			u.RawQuery = "discport=" + strconv.Itoa(int(n.UDP))
    83  		}
    84  	}
    85  	return u.String()
    86  }
    87  
    88  var incompleteNodeURL = regexp.MustCompile("(?i)^(?:enode://)?([0-9a-f]+)$")
    89  
    90  func ParseNode(rawurl string) (*Node, error) {
    91  	if m := incompleteNodeURL.FindStringSubmatch(rawurl); m != nil {
    92  		id, err := HexID(m[1])
    93  		if err != nil {
    94  			return nil, fmt.Errorf("invalid node ID (%v)", err)
    95  		}
    96  		return NewNode(id, nil, 0, 0), nil
    97  	}
    98  	return parseComplete(rawurl)
    99  }
   100  
   101  func parseComplete(rawurl string) (*Node, error) {
   102  	var (
   103  		id               NodeID
   104  		ip               net.IP
   105  		tcpPort, udpPort uint64
   106  	)
   107  	u, err := url.Parse(rawurl)
   108  	if err != nil {
   109  		return nil, err
   110  	}
   111  	if u.Scheme != "enode" {
   112  		return nil, errors.New("invalid URL scheme, want \"enode\"")
   113  	}
   114  	if u.User == nil {
   115  		return nil, errors.New("does not contain node ID")
   116  	}
   117  	if id, err = HexID(u.User.String()); err != nil {
   118  		return nil, fmt.Errorf("invalid node ID (%v)", err)
   119  	}
   120  	host, port, err := net.SplitHostPort(u.Host)
   121  	if err != nil {
   122  		return nil, fmt.Errorf("invalid host: %v", err)
   123  	}
   124  	if ip = net.ParseIP(host); ip == nil {
   125  		return nil, errors.New("invalid IP address")
   126  	}
   127  	if ipv4 := ip.To4(); ipv4 != nil {
   128  		ip = ipv4
   129  	}
   130  	if tcpPort, err = strconv.ParseUint(port, 10, 16); err != nil {
   131  		return nil, errors.New("invalid port")
   132  	}
   133  	udpPort = tcpPort
   134  	qv := u.Query()
   135  	if qv.Get("discport") != "" {
   136  		udpPort, err = strconv.ParseUint(qv.Get("discport"), 10, 16)
   137  		if err != nil {
   138  			return nil, errors.New("invalid discport in query")
   139  		}
   140  	}
   141  	return NewNode(id, ip, uint16(udpPort), uint16(tcpPort)), nil
   142  }
   143  
   144  func MustParseNode(rawurl string) *Node {
   145  	n, err := ParseNode(rawurl)
   146  	if err != nil {
   147  		panic("invalid node URL: " + err.Error())
   148  	}
   149  	return n
   150  }
   151  
   152  func (n *Node) MarshalText() ([]byte, error) {
   153  	return []byte(n.String()), nil
   154  }
   155  
   156  func (n *Node) UnmarshalText(text []byte) error {
   157  	dec, err := ParseNode(string(text))
   158  	if err == nil {
   159  		*n = *dec
   160  	}
   161  	return err
   162  }
   163  
   164  type NodeID [NodeIDBits / 8]byte
   165  
   166  func (n NodeID) Bytes() []byte {
   167  	return n[:]
   168  }
   169  
   170  func (n NodeID) String() string {
   171  	return fmt.Sprintf("%x", n[:])
   172  }
   173  
   174  func (n NodeID) GoString() string {
   175  	return fmt.Sprintf("discover.HexID(\"%x\")", n[:])
   176  }
   177  
   178  func (n NodeID) TerminalString() string {
   179  	return hex.EncodeToString(n[:8])
   180  }
   181  
   182  func (n NodeID) MarshalText() ([]byte, error) {
   183  	return []byte(hex.EncodeToString(n[:])), nil
   184  }
   185  
   186  func (n *NodeID) UnmarshalText(text []byte) error {
   187  	id, err := HexID(string(text))
   188  	if err != nil {
   189  		return err
   190  	}
   191  	*n = id
   192  	return nil
   193  }
   194  
   195  func BytesID(b []byte) (NodeID, error) {
   196  	var id NodeID
   197  	if len(b) != len(id) {
   198  		return id, fmt.Errorf("wrong length, want %d bytes", len(id))
   199  	}
   200  	copy(id[:], b)
   201  	return id, nil
   202  }
   203  
   204  func MustBytesID(b []byte) NodeID {
   205  	id, err := BytesID(b)
   206  	if err != nil {
   207  		panic(err)
   208  	}
   209  	return id
   210  }
   211  
   212  func HexID(in string) (NodeID, error) {
   213  	var id NodeID
   214  	b, err := hex.DecodeString(strings.TrimPrefix(in, "0x"))
   215  	if err != nil {
   216  		return id, err
   217  	} else if len(b) != len(id) {
   218  		return id, fmt.Errorf("wrong length, want %d hex chars", len(id)*2)
   219  	}
   220  	copy(id[:], b)
   221  	return id, nil
   222  }
   223  
   224  func MustHexID(in string) NodeID {
   225  	id, err := HexID(in)
   226  	if err != nil {
   227  		panic(err)
   228  	}
   229  	return id
   230  }
   231  
   232  func PubkeyID(pub *ecdsa.PublicKey) NodeID {
   233  	var id NodeID
   234  	pbytes := elliptic.Marshal(pub.Curve, pub.X, pub.Y)
   235  	if len(pbytes)-1 != len(id) {
   236  		panic(fmt.Errorf("need %d bit pubkey, got %d bits", (len(id)+1)*8, len(pbytes)))
   237  	}
   238  	copy(id[:], pbytes[1:])
   239  	return id
   240  }
   241  
   242  func (id NodeID) Pubkey() (*ecdsa.PublicKey, error) {
   243  	p := &ecdsa.PublicKey{Curve: crypto.S256(), X: new(big.Int), Y: new(big.Int)}
   244  	half := len(id) / 2
   245  	p.X.SetBytes(id[:half])
   246  	p.Y.SetBytes(id[half:])
   247  	if !p.Curve.IsOnCurve(p.X, p.Y) {
   248  		return nil, errors.New("id is invalid secp256k1 curve point")
   249  	}
   250  	return p, nil
   251  }
   252  
   253  func recoverNodeID(hash, sig []byte) (id NodeID, err error) {
   254  	pubkey, err := secp256k1.RecoverPubkey(hash, sig)
   255  	if err != nil {
   256  		return id, err
   257  	}
   258  	if len(pubkey)-1 != len(id) {
   259  		return id, fmt.Errorf("recovered pubkey has %d bits, want %d bits", len(pubkey)*8, (len(id)+1)*8)
   260  	}
   261  	for i := range id {
   262  		id[i] = pubkey[i+1]
   263  	}
   264  	return id, nil
   265  }
   266  
   267  func distcmp(target, a, b common.Hash) int {
   268  	for i := range target {
   269  		da := a[i] ^ target[i]
   270  		db := b[i] ^ target[i]
   271  		if da > db {
   272  			return 1
   273  		} else if da < db {
   274  			return -1
   275  		}
   276  	}
   277  	return 0
   278  }
   279  
   280  var lzcount = [256]int{
   281  	8, 7, 6, 6, 5, 5, 5, 5,
   282  	4, 4, 4, 4, 4, 4, 4, 4,
   283  	3, 3, 3, 3, 3, 3, 3, 3,
   284  	3, 3, 3, 3, 3, 3, 3, 3,
   285  	2, 2, 2, 2, 2, 2, 2, 2,
   286  	2, 2, 2, 2, 2, 2, 2, 2,
   287  	2, 2, 2, 2, 2, 2, 2, 2,
   288  	2, 2, 2, 2, 2, 2, 2, 2,
   289  	1, 1, 1, 1, 1, 1, 1, 1,
   290  	1, 1, 1, 1, 1, 1, 1, 1,
   291  	1, 1, 1, 1, 1, 1, 1, 1,
   292  	1, 1, 1, 1, 1, 1, 1, 1,
   293  	1, 1, 1, 1, 1, 1, 1, 1,
   294  	1, 1, 1, 1, 1, 1, 1, 1,
   295  	1, 1, 1, 1, 1, 1, 1, 1,
   296  	1, 1, 1, 1, 1, 1, 1, 1,
   297  	0, 0, 0, 0, 0, 0, 0, 0,
   298  	0, 0, 0, 0, 0, 0, 0, 0,
   299  	0, 0, 0, 0, 0, 0, 0, 0,
   300  	0, 0, 0, 0, 0, 0, 0, 0,
   301  	0, 0, 0, 0, 0, 0, 0, 0,
   302  	0, 0, 0, 0, 0, 0, 0, 0,
   303  	0, 0, 0, 0, 0, 0, 0, 0,
   304  	0, 0, 0, 0, 0, 0, 0, 0,
   305  	0, 0, 0, 0, 0, 0, 0, 0,
   306  	0, 0, 0, 0, 0, 0, 0, 0,
   307  	0, 0, 0, 0, 0, 0, 0, 0,
   308  	0, 0, 0, 0, 0, 0, 0, 0,
   309  	0, 0, 0, 0, 0, 0, 0, 0,
   310  	0, 0, 0, 0, 0, 0, 0, 0,
   311  	0, 0, 0, 0, 0, 0, 0, 0,
   312  	0, 0, 0, 0, 0, 0, 0, 0,
   313  }
   314  
   315  func logdist(a, b common.Hash) int {
   316  	lz := 0
   317  	for i := range a {
   318  		x := a[i] ^ b[i]
   319  		if x == 0 {
   320  			lz += 8
   321  		} else {
   322  			lz += lzcount[x]
   323  			break
   324  		}
   325  	}
   326  	return len(a)*8 - lz
   327  }
   328  
   329  func hashAtDistance(a common.Hash, n int) (b common.Hash) {
   330  	if n == 0 {
   331  		return a
   332  	}
   333  	b = a
   334  	pos := len(a) - n/8 - 1
   335  	bit := byte(0x01) << (byte(n%8) - 1)
   336  	if bit == 0 {
   337  		pos++
   338  		bit = 0x80
   339  	}
   340  	b[pos] = a[pos]&^bit | ^a[pos]&bit
   341  	for i := pos + 1; i < len(a); i++ {
   342  		b[i] = byte(rand.Intn(255))
   343  	}
   344  	return b
   345  }