github.com/neatio-net/neatio@v1.7.3-0.20231114194659-f4d7a2226baa/network/p2p/discv5/node.go (about)

     1  package discv5
     2  
     3  import (
     4  	"crypto/ecdsa"
     5  	"crypto/elliptic"
     6  	"encoding/hex"
     7  	"errors"
     8  	"fmt"
     9  	"math/big"
    10  	"math/rand"
    11  	"net"
    12  	"net/url"
    13  	"regexp"
    14  	"strconv"
    15  	"strings"
    16  
    17  	"github.com/neatio-net/neatio/utilities/common"
    18  	"github.com/neatio-net/neatio/utilities/crypto"
    19  )
    20  
    21  type Node struct {
    22  	IP       net.IP
    23  	UDP, TCP uint16
    24  	ID       NodeID
    25  
    26  	nodeNetGuts
    27  }
    28  
    // NewNode builds a Node from an identity and an endpoint. An IPv4 ip is
    // stored in its 4-byte form; any other ip (including nil) is kept as is.
    // The Keccak-256 hash of id is cached in the embedded nodeNetGuts.
    func NewNode(id NodeID, ip net.IP, udpPort, tcpPort uint16) *Node {
    	addr := ip
    	if v4 := ip.To4(); v4 != nil {
    		addr = v4
    	}
    	return &Node{
    		ID:          id,
    		IP:          addr,
    		UDP:         udpPort,
    		TCP:         tcpPort,
    		nodeNetGuts: nodeNetGuts{sha: crypto.Keccak256Hash(id[:])},
    	}
    }
    41  
    // addr returns the node's discovery endpoint (IP and UDP port) as a newly
    // allocated *net.UDPAddr. The IP slice is shared with n, not copied.
    func (n *Node) addr() *net.UDPAddr {
    	udp := net.UDPAddr{IP: n.IP, Port: int(n.UDP)}
    	return &udp
    }
    45  
    // setAddr replaces the node's IP and UDP port with those of a, storing an
    // IPv4 address in its 4-byte form. The port is truncated to uint16 and
    // the TCP port is left untouched.
    func (n *Node) setAddr(a *net.UDPAddr) {
    	ip := a.IP
    	if v4 := ip.To4(); v4 != nil {
    		ip = v4
    	}
    	n.IP, n.UDP = ip, uint16(a.Port)
    }
    53  
    // addrEqual reports whether a has the same UDP port (after truncation to
    // uint16) and the same IP address as the node.
    func (n *Node) addrEqual(a *net.UDPAddr) bool {
    	if n.UDP != uint16(a.Port) {
    		return false
    	}
    	other := a.IP
    	if v4 := other.To4(); v4 != nil {
    		other = v4
    	}
    	return n.IP.Equal(other)
    }
    61  
    62  func (n *Node) Incomplete() bool {
    63  	return n.IP == nil
    64  }
    65  
    // validateComplete checks that n carries everything needed to contact
    // it. It requires:
    //   - a non-nil IP of IPv4 or IPv6 length that is neither multicast nor
    //     unspecified;
    //   - non-zero UDP and TCP ports;
    //   - an ID that decodes to a valid curve point via NodeID.Pubkey.
    //
    // It returns the first problem found, or nil.
    func (n *Node) validateComplete() error {
    	if n.Incomplete() {
    		return errors.New("incomplete node")
    	}
    	if n.UDP == 0 {
    		return errors.New("missing UDP port")
    	}
    	if n.TCP == 0 {
    		return errors.New("missing TCP port")
    	}
    	// A non-nil IP of the wrong length (e.g. net.IP{}) passes Incomplete,
    	// and IsMulticast/IsUnspecified both report false for it, so it must
    	// be rejected explicitly.
    	if len(n.IP) != net.IPv4len && len(n.IP) != net.IPv6len {
    		return errors.New("invalid IP (bad length)")
    	}
    	if n.IP.IsMulticast() || n.IP.IsUnspecified() {
    		return errors.New("invalid IP (multicast/unspecified)")
    	}
    	_, err := n.ID.Pubkey()
    	return err
    }
    82  
    // String renders the node as an enode URL.
    //   - An incomplete node becomes "enode://<hex id>".
    //   - Otherwise the hex ID is the userinfo and the host is IP:TCP-port.
    //   - A "discport" query parameter is added when the UDP port differs
    //     from the TCP port.
    func (n *Node) String() string {
    	idHex := fmt.Sprintf("%x", n.ID[:])
    	if n.Incomplete() {
    		return (&url.URL{Scheme: "enode", Host: idHex}).String()
    	}
    	u := url.URL{
    		Scheme: "enode",
    		User:   url.User(idHex),
    		Host:   (&net.TCPAddr{IP: n.IP, Port: int(n.TCP)}).String(),
    	}
    	if n.UDP != n.TCP {
    		u.RawQuery = "discport=" + strconv.Itoa(int(n.UDP))
    	}
    	return u.String()
    }
    97  
    // incompleteNodeURL matches a bare hex node ID, optionally prefixed with
    // "enode://", with nothing after it. Matching is case-insensitive and the
    // hex ID is captured in group 1.
    var incompleteNodeURL = regexp.MustCompile(`(?i)^(?:enode://)?([0-9a-f]+)$`)
    99  
   // ParseNode parses a node designator. Two forms are accepted:
   //
   //   - an incomplete node: just the hex-encoded ID, optionally prefixed
   //     with "enode://";
   //   - a complete enode URL:
   //     "enode://<hex id>@<ip>:<tcp port>[?discport=<udp port>]".
   func ParseNode(rawurl string) (*Node, error) {
   	m := incompleteNodeURL.FindStringSubmatch(rawurl)
   	if m == nil {
   		return parseComplete(rawurl)
   	}
   	id, err := HexID(m[1])
   	if err != nil {
   		return nil, fmt.Errorf("invalid node ID (%v)", err)
   	}
   	return NewNode(id, nil, 0, 0), nil
   }
   110  
   111  func parseComplete(rawurl string) (*Node, error) {
   112  	var (
   113  		id               NodeID
   114  		ip               net.IP
   115  		tcpPort, udpPort uint64
   116  	)
   117  	u, err := url.Parse(rawurl)
   118  	if err != nil {
   119  		return nil, err
   120  	}
   121  	if u.Scheme != "enode" {
   122  		return nil, errors.New("invalid URL scheme, want \"enode\"")
   123  	}
   124  	if u.User == nil {
   125  		return nil, errors.New("does not contain node ID")
   126  	}
   127  	if id, err = HexID(u.User.String()); err != nil {
   128  		return nil, fmt.Errorf("invalid node ID (%v)", err)
   129  	}
   130  	host, port, err := net.SplitHostPort(u.Host)
   131  	if err != nil {
   132  		return nil, fmt.Errorf("invalid host: %v", err)
   133  	}
   134  	if ip = net.ParseIP(host); ip == nil {
   135  		return nil, errors.New("invalid IP address")
   136  	}
   137  	if ipv4 := ip.To4(); ipv4 != nil {
   138  		ip = ipv4
   139  	}
   140  	if tcpPort, err = strconv.ParseUint(port, 10, 16); err != nil {
   141  		return nil, errors.New("invalid port")
   142  	}
   143  	udpPort = tcpPort
   144  	qv := u.Query()
   145  	if qv.Get("discport") != "" {
   146  		udpPort, err = strconv.ParseUint(qv.Get("discport"), 10, 16)
   147  		if err != nil {
   148  			return nil, errors.New("invalid discport in query")
   149  		}
   150  	}
   151  	return NewNode(id, ip, uint16(udpPort), uint16(tcpPort)), nil
   152  }
   153  
   154  func MustParseNode(rawurl string) *Node {
   155  	n, err := ParseNode(rawurl)
   156  	if err != nil {
   157  		panic("invalid node URL: " + err.Error())
   158  	}
   159  	return n
   160  }
   161  
   162  func (n *Node) MarshalText() ([]byte, error) {
   163  	return []byte(n.String()), nil
   164  }
   165  
   166  func (n *Node) UnmarshalText(text []byte) error {
   167  	dec, err := ParseNode(string(text))
   168  	if err == nil {
   169  		*n = *dec
   170  	}
   171  	return err
   172  }
   173  
   174  const nodeIDBits = 512
   175  
   176  type NodeID [nodeIDBits / 8]byte
   177  
   178  func (n NodeID) String() string {
   179  	return fmt.Sprintf("%x", n[:])
   180  }
   181  
   182  func (n NodeID) GoString() string {
   183  	return fmt.Sprintf("discover.HexID(\"%x\")", n[:])
   184  }
   185  
   186  func (n NodeID) TerminalString() string {
   187  	return hex.EncodeToString(n[:8])
   188  }
   189  
   190  func HexID(in string) (NodeID, error) {
   191  	var id NodeID
   192  	b, err := hex.DecodeString(strings.TrimPrefix(in, "0x"))
   193  	if err != nil {
   194  		return id, err
   195  	} else if len(b) != len(id) {
   196  		return id, fmt.Errorf("wrong length, want %d hex chars", len(id)*2)
   197  	}
   198  	copy(id[:], b)
   199  	return id, nil
   200  }
   201  
   202  func MustHexID(in string) NodeID {
   203  	id, err := HexID(in)
   204  	if err != nil {
   205  		panic(err)
   206  	}
   207  	return id
   208  }
   209  
   // PubkeyID converts pub into a NodeID by taking its uncompressed
   // elliptic.Marshal encoding and dropping the leading format byte.
   // It panics if that encoding is not exactly len(NodeID)+1 bytes long.
   func PubkeyID(pub *ecdsa.PublicKey) NodeID {
   	var id NodeID
   	pbytes := elliptic.Marshal(pub.Curve, pub.X, pub.Y)
   	if len(pbytes)-1 != len(id) {
   		// Report both sizes in bits; len(pbytes) is a byte count.
   		panic(fmt.Errorf("need %d bit pubkey, got %d bits", (len(id)+1)*8, len(pbytes)*8))
   	}
   	copy(id[:], pbytes[1:])
   	return id
   }
   219  
   // Pubkey decodes the ID into a public key on crypto.S256(). The first half
   // of the bytes is X and the second half is Y (both big-endian). It returns
   // an error if (X, Y) is not on the curve.
   func (id NodeID) Pubkey() (*ecdsa.PublicKey, error) {
   	half := len(id) / 2
   	x := new(big.Int).SetBytes(id[:half])
   	y := new(big.Int).SetBytes(id[half:])
   	curve := crypto.S256()
   	if !curve.IsOnCurve(x, y) {
   		return nil, errors.New("id is invalid secp256k1 curve point")
   	}
   	return &ecdsa.PublicKey{Curve: curve, X: x, Y: y}, nil
   }
   230  
   // mustPubkey is like Pubkey but returns the key by value and panics if the
   // ID is not a valid curve point.
   func (id NodeID) mustPubkey() ecdsa.PublicKey {
   	pk, err := id.Pubkey()
   	if err == nil {
   		return *pk
   	}
   	panic(err)
   }
   238  
   // recoverNodeID recovers a public key from sig over hash using
   // crypto.Ecrecover and turns it into a NodeID by dropping the key's first
   // byte. It fails if recovery fails or the recovered key is not exactly
   // len(NodeID)+1 bytes long.
   func recoverNodeID(hash, sig []byte) (id NodeID, err error) {
   	pubkey, err := crypto.Ecrecover(hash, sig)
   	switch {
   	case err != nil:
   		return id, err
   	case len(pubkey) != len(id)+1:
   		return id, fmt.Errorf("recovered pubkey has %d bits, want %d bits", len(pubkey)*8, (len(id)+1)*8)
   	}
   	copy(id[:], pubkey[1:])
   	return id, nil
   }
   252  
   253  func distcmp(target, a, b common.Hash) int {
   254  	for i := range target {
   255  		da := a[i] ^ target[i]
   256  		db := b[i] ^ target[i]
   257  		if da > db {
   258  			return 1
   259  		} else if da < db {
   260  			return -1
   261  		}
   262  	}
   263  	return 0
   264  }
   265  
   // lzcount[x] is the number of leading zero bits in the byte x: 8 for
   // x == 0, down to 0 for x >= 0x80. The table is filled once at package
   // initialisation by counting how many right shifts it takes to clear x.
   var lzcount = func() (table [256]int) {
   	for x := range table {
   		zeros := 8
   		for v := x; v != 0; v >>= 1 {
   			zeros--
   		}
   		table[x] = zeros
   	}
   	return table
   }()
   300  
   301  func logdist(a, b common.Hash) int {
   302  	lz := 0
   303  	for i := range a {
   304  		x := a[i] ^ b[i]
   305  		if x == 0 {
   306  			lz += 8
   307  		} else {
   308  			lz += lzcount[x]
   309  			break
   310  		}
   311  	}
   312  	return len(a)*8 - lz
   313  }
   314  
   // hashAtDistance returns a hash b with logdist(a, b) == n.
   //   - n == 0 returns a unchanged.
   //   - Otherwise the first differing bit (len(a)*8-n bits from the most
   //     significant end) is flipped, and everything before it is copied
   //     from a, including the lower bits of that same byte.
   //   - Every later byte is filled from math/rand.
   //
   // It panics if n is outside [0, len(a)*8].
   func hashAtDistance(a common.Hash, n int) (b common.Hash) {
   	if n < 0 || n > len(a)*8 {
   		panic(fmt.Sprintf("hashAtDistance: distance %d out of range [0, %d]", n, len(a)*8))
   	}
   	if n == 0 {
   		return a
   	}
   	b = a
   	// Byte holding the first differing bit, and that bit's mask within it.
   	pos := len(a) - (n+7)/8
   	bit := byte(0x80) >> uint((8-n%8)%8)
   	b[pos] ^= bit
   	for i := pos + 1; i < len(a); i++ {
   		// Intn(256) covers every byte value; Intn(255) could never yield 0xff.
   		b[i] = byte(rand.Intn(256))
   	}
   	return b
   }