github.com/aigarnetwork/aigar@v0.0.0-20191115204914-d59a6eb70f8e/p2p/dnsdisc/tree.go (about)

     1  //  Copyright 2018 The go-ethereum Authors
     2  //  Copyright 2019 The go-aigar Authors
     3  //  This file is part of the go-aigar library.
     4  //
     5  //  The go-aigar library is free software: you can redistribute it and/or modify
     6  //  it under the terms of the GNU Lesser General Public License as published by
     7  //  the Free Software Foundation, either version 3 of the License, or
     8  //  (at your option) any later version.
     9  //
    10  //  The go-aigar library is distributed in the hope that it will be useful,
    11  //  but WITHOUT ANY WARRANTY; without even the implied warranty of
    12  //  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    13  //  GNU Lesser General Public License for more details.
    14  //
    15  //  You should have received a copy of the GNU Lesser General Public License
    16  //  along with the go-aigar library. If not, see <http://www.gnu.org/licenses/>.
    17  
    18  package dnsdisc
    19  
    20  import (
    21  	"bytes"
    22  	"crypto/ecdsa"
    23  	"encoding/base32"
    24  	"encoding/base64"
    25  	"fmt"
    26  	"io"
    27  	"sort"
    28  	"strings"
    29  
    30  	"github.com/AigarNetwork/aigar/crypto"
    31  	"github.com/AigarNetwork/aigar/p2p/enode"
    32  	"github.com/AigarNetwork/aigar/p2p/enr"
    33  	"github.com/AigarNetwork/aigar/rlp"
    34  	"golang.org/x/crypto/sha3"
    35  )
    36  
    37  // Tree is a merkle tree of node records.
    38  type Tree struct {
    39  	root    *rootEntry
    40  	entries map[string]entry
    41  }
    42  
    43  // Sign signs the tree with the given private key and sets the sequence number.
    44  func (t *Tree) Sign(key *ecdsa.PrivateKey, domain string) (url string, err error) {
    45  	root := *t.root
    46  	sig, err := crypto.Sign(root.sigHash(), key)
    47  	if err != nil {
    48  		return "", err
    49  	}
    50  	root.sig = sig
    51  	t.root = &root
    52  	link := &linkEntry{domain, &key.PublicKey}
    53  	return link.String(), nil
    54  }
    55  
    56  // SetSignature verifies the given signature and assigns it as the tree's current
    57  // signature if valid.
    58  func (t *Tree) SetSignature(pubkey *ecdsa.PublicKey, signature string) error {
    59  	sig, err := b64format.DecodeString(signature)
    60  	if err != nil || len(sig) != crypto.SignatureLength {
    61  		return errInvalidSig
    62  	}
    63  	root := *t.root
    64  	root.sig = sig
    65  	if !root.verifySignature(pubkey) {
    66  		return errInvalidSig
    67  	}
    68  	t.root = &root
    69  	return nil
    70  }
    71  
    72  // Seq returns the sequence number of the tree.
    73  func (t *Tree) Seq() uint {
    74  	return t.root.seq
    75  }
    76  
    77  // Signature returns the signature of the tree.
    78  func (t *Tree) Signature() string {
    79  	return b64format.EncodeToString(t.root.sig)
    80  }
    81  
    82  // ToTXT returns all DNS TXT records required for the tree.
    83  func (t *Tree) ToTXT(domain string) map[string]string {
    84  	records := map[string]string{domain: t.root.String()}
    85  	for _, e := range t.entries {
    86  		sd := subdomain(e)
    87  		if domain != "" {
    88  			sd = sd + "." + domain
    89  		}
    90  		records[sd] = e.String()
    91  	}
    92  	return records
    93  }
    94  
    95  // Links returns all links contained in the tree.
    96  func (t *Tree) Links() []string {
    97  	var links []string
    98  	for _, e := range t.entries {
    99  		if le, ok := e.(*linkEntry); ok {
   100  			links = append(links, le.String())
   101  		}
   102  	}
   103  	return links
   104  }
   105  
   106  // Nodes returns all nodes contained in the tree.
   107  func (t *Tree) Nodes() []*enode.Node {
   108  	var nodes []*enode.Node
   109  	for _, e := range t.entries {
   110  		if ee, ok := e.(*enrEntry); ok {
   111  			nodes = append(nodes, ee.node)
   112  		}
   113  	}
   114  	return nodes
   115  }
   116  
   117  const (
   118  	hashAbbrev    = 16
   119  	maxChildren   = 300 / hashAbbrev * (13 / 8)
   120  	minHashLength = 12
   121  )
   122  
   123  // MakeTree creates a tree containing the given nodes and links.
   124  func MakeTree(seq uint, nodes []*enode.Node, links []string) (*Tree, error) {
   125  	// Sort records by ID and ensure all nodes have a valid record.
   126  	records := make([]*enode.Node, len(nodes))
   127  
   128  	copy(records, nodes)
   129  	sortByID(records)
   130  	for _, n := range records {
   131  		if len(n.Record().Signature()) == 0 {
   132  			return nil, fmt.Errorf("can't add node %v: unsigned node record", n.ID())
   133  		}
   134  	}
   135  
   136  	// Create the leaf list.
   137  	enrEntries := make([]entry, len(records))
   138  	for i, r := range records {
   139  		enrEntries[i] = &enrEntry{r}
   140  	}
   141  	linkEntries := make([]entry, len(links))
   142  	for i, l := range links {
   143  		le, err := parseLink(l)
   144  		if err != nil {
   145  			return nil, err
   146  		}
   147  		linkEntries[i] = le
   148  	}
   149  
   150  	// Create intermediate nodes.
   151  	t := &Tree{entries: make(map[string]entry)}
   152  	eroot := t.build(enrEntries)
   153  	t.entries[subdomain(eroot)] = eroot
   154  	lroot := t.build(linkEntries)
   155  	t.entries[subdomain(lroot)] = lroot
   156  	t.root = &rootEntry{seq: seq, eroot: subdomain(eroot), lroot: subdomain(lroot)}
   157  	return t, nil
   158  }
   159  
   160  func (t *Tree) build(entries []entry) entry {
   161  	if len(entries) == 1 {
   162  		return entries[0]
   163  	}
   164  	if len(entries) <= maxChildren {
   165  		hashes := make([]string, len(entries))
   166  		for i, e := range entries {
   167  			hashes[i] = subdomain(e)
   168  			t.entries[hashes[i]] = e
   169  		}
   170  		return &branchEntry{hashes}
   171  	}
   172  	var subtrees []entry
   173  	for len(entries) > 0 {
   174  		n := maxChildren
   175  		if len(entries) < n {
   176  			n = len(entries)
   177  		}
   178  		sub := t.build(entries[:n])
   179  		entries = entries[n:]
   180  		subtrees = append(subtrees, sub)
   181  		t.entries[subdomain(sub)] = sub
   182  	}
   183  	return t.build(subtrees)
   184  }
   185  
   186  func sortByID(nodes []*enode.Node) []*enode.Node {
   187  	sort.Slice(nodes, func(i, j int) bool {
   188  		return bytes.Compare(nodes[i].ID().Bytes(), nodes[j].ID().Bytes()) < 0
   189  	})
   190  	return nodes
   191  }
   192  
   193  // Entry Types
   194  
   195  type entry interface {
   196  	fmt.Stringer
   197  }
   198  
   199  type (
   200  	rootEntry struct {
   201  		eroot string
   202  		lroot string
   203  		seq   uint
   204  		sig   []byte
   205  	}
   206  	branchEntry struct {
   207  		children []string
   208  	}
   209  	enrEntry struct {
   210  		node *enode.Node
   211  	}
   212  	linkEntry struct {
   213  		domain string
   214  		pubkey *ecdsa.PublicKey
   215  	}
   216  )
   217  
   218  // Entry Encoding
   219  
   220  var (
   221  	b32format = base32.StdEncoding.WithPadding(base32.NoPadding)
   222  	b64format = base64.RawURLEncoding
   223  )
   224  
   225  const (
   226  	rootPrefix   = "enrtree-root:v1"
   227  	linkPrefix   = "enrtree://"
   228  	branchPrefix = "enrtree-branch:"
   229  	enrPrefix    = "enr:"
   230  )
   231  
   232  func subdomain(e entry) string {
   233  	h := sha3.NewLegacyKeccak256()
   234  	io.WriteString(h, e.String())
   235  	return b32format.EncodeToString(h.Sum(nil)[:16])
   236  }
   237  
   238  func (e *rootEntry) String() string {
   239  	return fmt.Sprintf(rootPrefix+" e=%s l=%s seq=%d sig=%s", e.eroot, e.lroot, e.seq, b64format.EncodeToString(e.sig))
   240  }
   241  
   242  func (e *rootEntry) sigHash() []byte {
   243  	h := sha3.NewLegacyKeccak256()
   244  	fmt.Fprintf(h, rootPrefix+" e=%s l=%s seq=%d", e.eroot, e.lroot, e.seq)
   245  	return h.Sum(nil)
   246  }
   247  
   248  func (e *rootEntry) verifySignature(pubkey *ecdsa.PublicKey) bool {
   249  	sig := e.sig[:crypto.RecoveryIDOffset] // remove recovery id
   250  	return crypto.VerifySignature(crypto.FromECDSAPub(pubkey), e.sigHash(), sig)
   251  }
   252  
   253  func (e *branchEntry) String() string {
   254  	return branchPrefix + strings.Join(e.children, ",")
   255  }
   256  
   257  func (e *enrEntry) String() string {
   258  	return e.node.String()
   259  }
   260  
   261  func (e *linkEntry) String() string {
   262  	pubkey := b32format.EncodeToString(crypto.CompressPubkey(e.pubkey))
   263  	return fmt.Sprintf("%s%s@%s", linkPrefix, pubkey, e.domain)
   264  }
   265  
   266  // Entry Parsing
   267  
   268  func parseEntry(e string, validSchemes enr.IdentityScheme) (entry, error) {
   269  	switch {
   270  	case strings.HasPrefix(e, linkPrefix):
   271  		return parseLinkEntry(e)
   272  	case strings.HasPrefix(e, branchPrefix):
   273  		return parseBranch(e)
   274  	case strings.HasPrefix(e, enrPrefix):
   275  		return parseENR(e, validSchemes)
   276  	default:
   277  		return nil, errUnknownEntry
   278  	}
   279  }
   280  
   281  func parseRoot(e string) (rootEntry, error) {
   282  	var eroot, lroot, sig string
   283  	var seq uint
   284  	if _, err := fmt.Sscanf(e, rootPrefix+" e=%s l=%s seq=%d sig=%s", &eroot, &lroot, &seq, &sig); err != nil {
   285  		return rootEntry{}, entryError{"root", errSyntax}
   286  	}
   287  	if !isValidHash(eroot) || !isValidHash(lroot) {
   288  		return rootEntry{}, entryError{"root", errInvalidChild}
   289  	}
   290  	sigb, err := b64format.DecodeString(sig)
   291  	if err != nil || len(sigb) != crypto.SignatureLength {
   292  		return rootEntry{}, entryError{"root", errInvalidSig}
   293  	}
   294  	return rootEntry{eroot, lroot, seq, sigb}, nil
   295  }
   296  
   297  func parseLinkEntry(e string) (entry, error) {
   298  	le, err := parseLink(e)
   299  	if err != nil {
   300  		return nil, err
   301  	}
   302  	return le, nil
   303  }
   304  
   305  func parseLink(e string) (*linkEntry, error) {
   306  	if !strings.HasPrefix(e, linkPrefix) {
   307  		return nil, fmt.Errorf("wrong/missing scheme 'enrtree' in URL")
   308  	}
   309  	e = e[len(linkPrefix):]
   310  	pos := strings.IndexByte(e, '@')
   311  	if pos == -1 {
   312  		return nil, entryError{"link", errNoPubkey}
   313  	}
   314  	keystring, domain := e[:pos], e[pos+1:]
   315  	keybytes, err := b32format.DecodeString(keystring)
   316  	if err != nil {
   317  		return nil, entryError{"link", errBadPubkey}
   318  	}
   319  	key, err := crypto.DecompressPubkey(keybytes)
   320  	if err != nil {
   321  		return nil, entryError{"link", errBadPubkey}
   322  	}
   323  	return &linkEntry{domain, key}, nil
   324  }
   325  
   326  func parseBranch(e string) (entry, error) {
   327  	e = e[len(branchPrefix):]
   328  	if e == "" {
   329  		return &branchEntry{}, nil // empty entry is OK
   330  	}
   331  	hashes := make([]string, 0, strings.Count(e, ","))
   332  	for _, c := range strings.Split(e, ",") {
   333  		if !isValidHash(c) {
   334  			return nil, entryError{"branch", errInvalidChild}
   335  		}
   336  		hashes = append(hashes, c)
   337  	}
   338  	return &branchEntry{hashes}, nil
   339  }
   340  
   341  func parseENR(e string, validSchemes enr.IdentityScheme) (entry, error) {
   342  	e = e[len(enrPrefix):]
   343  	enc, err := b64format.DecodeString(e)
   344  	if err != nil {
   345  		return nil, entryError{"enr", errInvalidENR}
   346  	}
   347  	var rec enr.Record
   348  	if err := rlp.DecodeBytes(enc, &rec); err != nil {
   349  		return nil, entryError{"enr", err}
   350  	}
   351  	n, err := enode.New(validSchemes, &rec)
   352  	if err != nil {
   353  		return nil, entryError{"enr", err}
   354  	}
   355  	return &enrEntry{n}, nil
   356  }
   357  
   358  func isValidHash(s string) bool {
   359  	dlen := b32format.DecodedLen(len(s))
   360  	if dlen < minHashLength || dlen > 32 || strings.ContainsAny(s, "\n\r") {
   361  		return false
   362  	}
   363  	buf := make([]byte, 32)
   364  	_, err := b32format.Decode(buf, []byte(s))
   365  	return err == nil
   366  }
   367  
   368  // truncateHash truncates the given base32 hash string to the minimum acceptable length.
   369  func truncateHash(hash string) string {
   370  	maxLen := b32format.EncodedLen(minHashLength)
   371  	if len(hash) < maxLen {
   372  		panic(fmt.Errorf("dnsdisc: hash %q is too short", hash))
   373  	}
   374  	return hash[:maxLen]
   375  }
   376  
   377  // URL encoding
   378  
   379  // ParseURL parses an enrtree:// URL and returns its components.
   380  func ParseURL(url string) (domain string, pubkey *ecdsa.PublicKey, err error) {
   381  	le, err := parseLink(url)
   382  	if err != nil {
   383  		return "", nil, err
   384  	}
   385  	return le.domain, le.pubkey, nil
   386  }