github.com/neatlab/neatio@v1.7.3-0.20220425043230-d903e92fcc75/chain/trie/node.go (about)

     1  package trie
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"strings"
     7  
     8  	"github.com/neatlab/neatio/utilities/common"
     9  	"github.com/neatlab/neatio/utilities/rlp"
    10  )
    11  
    12  var indices = []string{"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b", "c", "d", "e", "f", "[17]"}
    13  
// node is the interface shared by the four in-memory trie node kinds declared
// in this file: *fullNode, *shortNode, hashNode and valueNode.
type node interface {
	// fstring renders the node as text. The argument is the indentation
	// prefix to use when the output spans several lines.
	fstring(string) string
	// cache returns the node's stored hash and dirty flag. hashNode and
	// valueNode store neither and always report (nil, true).
	cache() (hashNode, bool)
}
    18  
type (
	// fullNode is a branch node with 17 child slots. When decoded by
	// decodeFull, slots 0-15 hold child references and slot 16 holds an
	// optional value.
	fullNode struct {
		Children [17]node
		flags    nodeFlag
	}
	// shortNode pairs a key with a single child. When decoded by
	// decodeShort, Key is the output of compactToHex and Val is a valueNode
	// if hasTerm(Key) reports true, or a child reference otherwise.
	shortNode struct {
		Key   []byte
		Val   node
		flags nodeFlag
	}
	// hashNode is a node reference held as raw bytes; decodeRef produces
	// one from every 32-byte RLP string it encounters.
	hashNode []byte
	// valueNode holds the raw bytes of a value string taken from an
	// encoded node.
	valueNode []byte
)
    32  
// nilValueNode is a valueNode wrapping a nil slice. fullNode.EncodeRLP
// substitutes it for nil children so that every slot handed to the encoder
// holds a concrete valueNode instead of a nil interface.
var nilValueNode = valueNode(nil)
    34  
    35  func (n *fullNode) EncodeRLP(w io.Writer) error {
    36  	var nodes [17]node
    37  
    38  	for i, side := range &n.Children {
    39  		if side != nil {
    40  			nodes[i] = side
    41  		} else {
    42  			nodes[i] = nilValueNode
    43  		}
    44  	}
    45  	return rlp.Encode(w, nodes)
    46  }
    47  
    48  func (n *fullNode) copy() *fullNode   { copy := *n; return &copy }
    49  func (n *shortNode) copy() *shortNode { copy := *n; return &copy }
    50  
// nodeFlag carries the bookkeeping attached to fullNode and shortNode values.
type nodeFlag struct {
	// hash is the stored hash of the node; the decode functions set it
	// from their hash argument.
	hash hashNode
	// dirty is returned as the second result of cache().
	// NOTE(review): presumably marks a node with unsaved modifications —
	// confirm against the code that sets it.
	dirty bool
}
    55  
    56  func (n *fullNode) cache() (hashNode, bool)  { return n.flags.hash, n.flags.dirty }
    57  func (n *shortNode) cache() (hashNode, bool) { return n.flags.hash, n.flags.dirty }
    58  func (n hashNode) cache() (hashNode, bool)   { return nil, true }
    59  func (n valueNode) cache() (hashNode, bool)  { return nil, true }
    60  
    61  func (n *fullNode) String() string  { return n.fstring("") }
    62  func (n *shortNode) String() string { return n.fstring("") }
    63  func (n hashNode) String() string   { return n.fstring("") }
    64  func (n valueNode) String() string  { return n.fstring("") }
    65  
    66  func (n *fullNode) fstring(ind string) string {
    67  	resp := fmt.Sprintf("[\n%s  ", ind)
    68  	for i, node := range &n.Children {
    69  		if node == nil {
    70  			resp += fmt.Sprintf("%s: <nil> ", indices[i])
    71  		} else {
    72  			resp += fmt.Sprintf("%s: %v", indices[i], node.fstring(ind+"  "))
    73  		}
    74  	}
    75  	return resp + fmt.Sprintf("\n%s] ", ind)
    76  }
    77  func (n *shortNode) fstring(ind string) string {
    78  	return fmt.Sprintf("{%x: %v} ", n.Key, n.Val.fstring(ind+"  "))
    79  }
    80  func (n hashNode) fstring(ind string) string {
    81  	return fmt.Sprintf("<%x> ", []byte(n))
    82  }
    83  func (n valueNode) fstring(ind string) string {
    84  	return fmt.Sprintf("%x ", []byte(n))
    85  }
    86  
    87  func mustDecodeNode(hash, buf []byte) node {
    88  	n, err := decodeNode(hash, buf)
    89  	if err != nil {
    90  		panic(fmt.Sprintf("node %x: %v", hash, err))
    91  	}
    92  	return n
    93  }
    94  
    95  func decodeNode(hash, buf []byte) (node, error) {
    96  	if len(buf) == 0 {
    97  		return nil, io.ErrUnexpectedEOF
    98  	}
    99  	elems, _, err := rlp.SplitList(buf)
   100  	if err != nil {
   101  		return nil, fmt.Errorf("decode error: %v", err)
   102  	}
   103  	switch c, _ := rlp.CountValues(elems); c {
   104  	case 2:
   105  		n, err := decodeShort(hash, elems)
   106  		return n, wrapError(err, "short")
   107  	case 17:
   108  		n, err := decodeFull(hash, elems)
   109  		return n, wrapError(err, "full")
   110  	default:
   111  		return nil, fmt.Errorf("invalid number of list elements: %v", c)
   112  	}
   113  }
   114  
   115  func decodeShort(hash, elems []byte) (node, error) {
   116  	kbuf, rest, err := rlp.SplitString(elems)
   117  	if err != nil {
   118  		return nil, err
   119  	}
   120  	flag := nodeFlag{hash: hash}
   121  	key := compactToHex(kbuf)
   122  	if hasTerm(key) {
   123  
   124  		val, _, err := rlp.SplitString(rest)
   125  		if err != nil {
   126  			return nil, fmt.Errorf("invalid value node: %v", err)
   127  		}
   128  		return &shortNode{key, append(valueNode{}, val...), flag}, nil
   129  	}
   130  	r, _, err := decodeRef(rest)
   131  	if err != nil {
   132  		return nil, wrapError(err, "val")
   133  	}
   134  	return &shortNode{key, r, flag}, nil
   135  }
   136  
   137  func decodeFull(hash, elems []byte) (*fullNode, error) {
   138  	n := &fullNode{flags: nodeFlag{hash: hash}}
   139  	for i := 0; i < 16; i++ {
   140  		cld, rest, err := decodeRef(elems)
   141  		if err != nil {
   142  			return n, wrapError(err, fmt.Sprintf("[%d]", i))
   143  		}
   144  		n.Children[i], elems = cld, rest
   145  	}
   146  	val, _, err := rlp.SplitString(elems)
   147  	if err != nil {
   148  		return n, err
   149  	}
   150  	if len(val) > 0 {
   151  		n.Children[16] = append(valueNode{}, val...)
   152  	}
   153  	return n, nil
   154  }
   155  
// hashLen is the length in bytes of a common.Hash. decodeRef uses it as the
// upper bound on the encoded size of an embedded node.
const hashLen = len(common.Hash{})
   157  
   158  func decodeRef(buf []byte) (node, []byte, error) {
   159  	kind, val, rest, err := rlp.Split(buf)
   160  	if err != nil {
   161  		return nil, buf, err
   162  	}
   163  	switch {
   164  	case kind == rlp.List:
   165  
   166  		if size := len(buf) - len(rest); size > hashLen {
   167  			err := fmt.Errorf("oversized embedded node (size is %d bytes, want size < %d)", size, hashLen)
   168  			return nil, buf, err
   169  		}
   170  		n, err := decodeNode(nil, buf)
   171  		return n, rest, err
   172  	case kind == rlp.String && len(val) == 0:
   173  
   174  		return nil, rest, nil
   175  	case kind == rlp.String && len(val) == 32:
   176  		return append(hashNode{}, val...), rest, nil
   177  	default:
   178  		return nil, nil, fmt.Errorf("invalid RLP string size %d (want 0 or 32)", len(val))
   179  	}
   180  }
   181  
   182  type decodeError struct {
   183  	what  error
   184  	stack []string
   185  }
   186  
   187  func wrapError(err error, ctx string) error {
   188  	if err == nil {
   189  		return nil
   190  	}
   191  	if decErr, ok := err.(*decodeError); ok {
   192  		decErr.stack = append(decErr.stack, ctx)
   193  		return decErr
   194  	}
   195  	return &decodeError{err, []string{ctx}}
   196  }
   197  
   198  func (err *decodeError) Error() string {
   199  	return fmt.Sprintf("%v (decode path: %s)", err.what, strings.Join(err.stack, "<-"))
   200  }