github.com/neatlab/neatio@v1.7.3-0.20220425043230-d903e92fcc75/chain/trie/hasher.go (about)

     1  package trie
     2  
     3  import (
     4  	"hash"
     5  	"sync"
     6  
     7  	"github.com/neatlab/neatio/utilities/common"
     8  	"github.com/neatlab/neatio/utilities/rlp"
     9  	"golang.org/x/crypto/sha3"
    10  )
    11  
// hasher converts trie nodes into their collapsed/hashed form and, when a
// Database is supplied to hash/store, records the encoded nodes in it.
// Instances are recycled through hasherPool via newHasher and
// returnHasherToPool.
type hasher struct {
	tmp    sliceBuffer  // scratch buffer that store fills with the RLP encoding of the current node
	sha    keccakState  // hash state reused by makeHashNode (Keccak-256 for pooled hashers)
	onleaf LeafCallback // optional; store invokes it for value children of nodes written to a Database
}
    17  
    18  type keccakState interface {
    19  	hash.Hash
    20  	Read([]byte) (int, error)
    21  }
    22  
    23  type sliceBuffer []byte
    24  
    25  func (b *sliceBuffer) Write(data []byte) (n int, err error) {
    26  	*b = append(*b, data...)
    27  	return len(data), nil
    28  }
    29  
    30  func (b *sliceBuffer) Reset() {
    31  	*b = (*b)[:0]
    32  }
    33  
    34  var hasherPool = sync.Pool{
    35  	New: func() interface{} {
    36  		return &hasher{
    37  			tmp: make(sliceBuffer, 0, 550),
    38  			sha: sha3.NewLegacyKeccak256().(keccakState),
    39  		}
    40  	},
    41  }
    42  
    43  func newHasher(onleaf LeafCallback) *hasher {
    44  	h := hasherPool.Get().(*hasher)
    45  	h.onleaf = onleaf
    46  	return h
    47  }
    48  
    49  func returnHasherToPool(h *hasher) {
    50  	hasherPool.Put(h)
    51  }
    52  
    53  func (h *hasher) hash(n node, db *Database, force bool) (node, node, error) {
    54  
    55  	if hash, dirty := n.cache(); hash != nil {
    56  		if db == nil {
    57  			return hash, n, nil
    58  		}
    59  		if !dirty {
    60  			switch n.(type) {
    61  			case *fullNode, *shortNode:
    62  				return hash, hash, nil
    63  			default:
    64  				return hash, n, nil
    65  			}
    66  		}
    67  	}
    68  
    69  	collapsed, cached, err := h.hashChildren(n, db)
    70  	if err != nil {
    71  		return hashNode{}, n, err
    72  	}
    73  	hashed, err := h.store(collapsed, db, force)
    74  	if err != nil {
    75  		return hashNode{}, n, err
    76  	}
    77  
    78  	cachedHash, _ := hashed.(hashNode)
    79  	switch cn := cached.(type) {
    80  	case *shortNode:
    81  		cn.flags.hash = cachedHash
    82  		if db != nil {
    83  			cn.flags.dirty = false
    84  		}
    85  	case *fullNode:
    86  		cn.flags.hash = cachedHash
    87  		if db != nil {
    88  			cn.flags.dirty = false
    89  		}
    90  	}
    91  	return hashed, cached, nil
    92  }
    93  
    94  func (h *hasher) hashChildren(original node, db *Database) (node, node, error) {
    95  	var err error
    96  
    97  	switch n := original.(type) {
    98  	case *shortNode:
    99  
   100  		collapsed, cached := n.copy(), n.copy()
   101  		collapsed.Key = hexToCompact(n.Key)
   102  		cached.Key = common.CopyBytes(n.Key)
   103  
   104  		if _, ok := n.Val.(valueNode); !ok {
   105  			collapsed.Val, cached.Val, err = h.hash(n.Val, db, false)
   106  			if err != nil {
   107  				return original, original, err
   108  			}
   109  		}
   110  		return collapsed, cached, nil
   111  
   112  	case *fullNode:
   113  
   114  		collapsed, cached := n.copy(), n.copy()
   115  
   116  		for i := 0; i < 16; i++ {
   117  			if n.Children[i] != nil {
   118  				collapsed.Children[i], cached.Children[i], err = h.hash(n.Children[i], db, false)
   119  				if err != nil {
   120  					return original, original, err
   121  				}
   122  			}
   123  		}
   124  		cached.Children[16] = n.Children[16]
   125  		return collapsed, cached, nil
   126  
   127  	default:
   128  
   129  		return n, original, nil
   130  	}
   131  }
   132  
   133  func (h *hasher) store(n node, db *Database, force bool) (node, error) {
   134  
   135  	if _, isHash := n.(hashNode); n == nil || isHash {
   136  		return n, nil
   137  	}
   138  
   139  	h.tmp.Reset()
   140  	if err := rlp.Encode(&h.tmp, n); err != nil {
   141  		panic("encode error: " + err.Error())
   142  	}
   143  	if len(h.tmp) < 32 && !force {
   144  		return n, nil
   145  	}
   146  
   147  	hash, _ := n.cache()
   148  	if hash == nil {
   149  		hash = h.makeHashNode(h.tmp)
   150  	}
   151  
   152  	if db != nil {
   153  
   154  		hash := common.BytesToHash(hash)
   155  
   156  		db.lock.Lock()
   157  		db.insert(hash, h.tmp, n)
   158  		db.lock.Unlock()
   159  
   160  		if h.onleaf != nil {
   161  			switch n := n.(type) {
   162  			case *shortNode:
   163  				if side, ok := n.Val.(valueNode); ok {
   164  					h.onleaf(side, hash)
   165  				}
   166  			case *fullNode:
   167  				for i := 0; i < 16; i++ {
   168  					if side, ok := n.Children[i].(valueNode); ok {
   169  						h.onleaf(side, hash)
   170  					}
   171  				}
   172  			}
   173  		}
   174  	}
   175  	return hash, nil
   176  }
   177  
   178  func (h *hasher) makeHashNode(data []byte) hashNode {
   179  	n := make(hashNode, h.sha.Size())
   180  	h.sha.Reset()
   181  	h.sha.Write(data)
   182  	h.sha.Read(n)
   183  	return n
   184  }