github.com/jonasnick/go-ethereum@v0.7.12-0.20150216215225-22176f05d387/trie/trie.go (about)

     1  package trie
     2  
     3  import (
     4  	"bytes"
     5  	"container/list"
     6  	"fmt"
     7  	"sync"
     8  
     9  	"github.com/jonasnick/go-ethereum/crypto"
    10  	"github.com/jonasnick/go-ethereum/ethutil"
    11  )
    12  
    13  func ParanoiaCheck(t1 *Trie, backend Backend) (bool, *Trie) {
    14  	t2 := New(nil, backend)
    15  
    16  	it := t1.Iterator()
    17  	for it.Next() {
    18  		t2.Update(it.Key, it.Value)
    19  	}
    20  
    21  	return bytes.Equal(t2.Hash(), t1.Hash()), t2
    22  }
    23  
// Trie maps byte keys to byte values. The node structure lives in memory under
// root; nodes whose encoding is 32 bytes or longer are written to cache keyed
// by their Sha3 hash (see store).
type Trie struct {
	// mu serializes Update, Get, Delete, Commit and Reset. Hash, Copy,
	// Iterator and PrintRoot do not take it.
	mu       sync.Mutex
	// root is the in-memory root node; nil means the trie is empty.
	root     Node
	// roothash is the root hash most recently recorded by Hash, or the root
	// passed to New.
	roothash []byte
	// cache wraps the Backend given to New; it is nil when New received a
	// nil backend.
	cache    *Cache

	// revisions is a stack of earlier roothash values: Hash pushes the old
	// value whenever the root hash changes, and Reset pops the latest one.
	revisions *list.List
}
    32  
    33  func New(root []byte, backend Backend) *Trie {
    34  	trie := &Trie{}
    35  	trie.revisions = list.New()
    36  	trie.roothash = root
    37  	if backend != nil {
    38  		trie.cache = NewCache(backend)
    39  	}
    40  
    41  	if root != nil {
    42  		value := ethutil.NewValueFromBytes(trie.cache.Get(root))
    43  		trie.root = trie.mknode(value)
    44  	}
    45  
    46  	return trie
    47  }
    48  
    49  func (self *Trie) Iterator() *Iterator {
    50  	return NewIterator(self)
    51  }
    52  
    53  func (self *Trie) Copy() *Trie {
    54  	cpy := make([]byte, 32)
    55  	copy(cpy, self.roothash)
    56  	trie := New(nil, nil)
    57  	trie.cache = self.cache.Copy()
    58  	if self.root != nil {
    59  		trie.root = self.root.Copy(trie)
    60  	}
    61  
    62  	return trie
    63  }
    64  
    65  // Legacy support
    66  func (self *Trie) Root() []byte { return self.Hash() }
    67  func (self *Trie) Hash() []byte {
    68  	var hash []byte
    69  	if self.root != nil {
    70  		t := self.root.Hash()
    71  		if byts, ok := t.([]byte); ok && len(byts) > 0 {
    72  			hash = byts
    73  		} else {
    74  			hash = crypto.Sha3(ethutil.Encode(self.root.RlpData()))
    75  		}
    76  	} else {
    77  		hash = crypto.Sha3(ethutil.Encode(""))
    78  	}
    79  
    80  	if !bytes.Equal(hash, self.roothash) {
    81  		self.revisions.PushBack(self.roothash)
    82  		self.roothash = hash
    83  	}
    84  
    85  	return hash
    86  }
    87  func (self *Trie) Commit() {
    88  	self.mu.Lock()
    89  	defer self.mu.Unlock()
    90  
    91  	// Hash first
    92  	self.Hash()
    93  
    94  	self.cache.Flush()
    95  }
    96  
    97  // Reset should only be called if the trie has been hashed
    98  func (self *Trie) Reset() {
    99  	self.mu.Lock()
   100  	defer self.mu.Unlock()
   101  
   102  	self.cache.Reset()
   103  
   104  	if self.revisions.Len() > 0 {
   105  		revision := self.revisions.Remove(self.revisions.Back()).([]byte)
   106  		self.roothash = revision
   107  	}
   108  	value := ethutil.NewValueFromBytes(self.cache.Get(self.roothash))
   109  	self.root = self.mknode(value)
   110  }
   111  
   112  func (self *Trie) UpdateString(key, value string) Node { return self.Update([]byte(key), []byte(value)) }
   113  func (self *Trie) Update(key, value []byte) Node {
   114  	self.mu.Lock()
   115  	defer self.mu.Unlock()
   116  
   117  	k := CompactHexDecode(string(key))
   118  
   119  	if len(value) != 0 {
   120  		self.root = self.insert(self.root, k, &ValueNode{self, value})
   121  	} else {
   122  		self.root = self.delete(self.root, k)
   123  	}
   124  
   125  	return self.root
   126  }
   127  
   128  func (self *Trie) GetString(key string) []byte { return self.Get([]byte(key)) }
   129  func (self *Trie) Get(key []byte) []byte {
   130  	self.mu.Lock()
   131  	defer self.mu.Unlock()
   132  
   133  	k := CompactHexDecode(string(key))
   134  
   135  	n := self.get(self.root, k)
   136  	if n != nil {
   137  		return n.(*ValueNode).Val()
   138  	}
   139  
   140  	return nil
   141  }
   142  
   143  func (self *Trie) DeleteString(key string) Node { return self.Delete([]byte(key)) }
   144  func (self *Trie) Delete(key []byte) Node {
   145  	self.mu.Lock()
   146  	defer self.mu.Unlock()
   147  
   148  	k := CompactHexDecode(string(key))
   149  	self.root = self.delete(self.root, k)
   150  
   151  	return self.root
   152  }
   153  
// insert places value at the nibble path key below node and returns the root
// of the resulting subtree. An empty key means value itself replaces the
// subtree; a nil node becomes a ShortNode holding the whole key.
//
// For a *ShortNode with key k:
//   - if k equals key, the node is replaced by a ShortNode(key, value);
//   - if all of k matches (matchlength == len(k)), value is inserted into k's
//     child with the remaining path;
//   - otherwise the paths diverge after matchlength nibbles: a FullNode gets
//     the old child and the new value under the first differing nibble of
//     each path, wrapped in a ShortNode carrying the shared prefix when that
//     prefix is non-empty.
//
// For a *FullNode, the node is copied and the copy's slot key[0] is set to the
// result of inserting into that branch; set is never called on the original.
// Any other node type panics.
func (self *Trie) insert(node Node, key []byte, value Node) Node {
	if len(key) == 0 {
		return value
	}

	if node == nil {
		return NewShortNode(self, key, value)
	}

	switch node := node.(type) {
	case *ShortNode:
		k := node.Key()
		cnode := node.Value()
		if bytes.Equal(k, key) {
			return NewShortNode(self, key, value)
		}

		var n Node
		matchlength := MatchingNibbleLength(key, k)
		if matchlength == len(k) {
			// k is fully consumed: descend into its child with the rest of key.
			n = self.insert(cnode, key[matchlength:], value)
		} else {
			// Split at the first differing nibble.
			// NOTE(review): this indexes k[matchlength] and key[matchlength], so it
			// assumes neither path is a strict prefix of the other (e.g. both end
			// in a terminator nibble) — confirm against CompactHexDecode.
			pnode := self.insert(nil, k[matchlength+1:], cnode)
			nnode := self.insert(nil, key[matchlength+1:], value)
			fulln := NewFullNode(self)
			fulln.set(k[matchlength], pnode)
			fulln.set(key[matchlength], nnode)
			n = fulln
		}
		if matchlength == 0 {
			// No shared prefix: the branch itself is the new subtree root.
			return n
		}

		return NewShortNode(self, key[:matchlength], n)

	case *FullNode:
		cpy := node.Copy(self).(*FullNode)
		cpy.set(key[0], self.insert(node.branch(key[0]), key[1:], value))

		return cpy

	default:
		panic(fmt.Sprintf("%T: invalid node: %v", node, node))
	}
}
   199  
   200  func (self *Trie) get(node Node, key []byte) Node {
   201  	if len(key) == 0 {
   202  		return node
   203  	}
   204  
   205  	if node == nil {
   206  		return nil
   207  	}
   208  
   209  	switch node := node.(type) {
   210  	case *ShortNode:
   211  		k := node.Key()
   212  		cnode := node.Value()
   213  
   214  		if len(key) >= len(k) && bytes.Equal(k, key[:len(k)]) {
   215  			return self.get(cnode, key[len(k):])
   216  		}
   217  
   218  		return nil
   219  	case *FullNode:
   220  		return self.get(node.branch(key[0]), key[1:])
   221  	default:
   222  		panic(fmt.Sprintf("%T: invalid node: %v", node, node))
   223  	}
   224  }
   225  
   226  func (self *Trie) delete(node Node, key []byte) Node {
   227  	if len(key) == 0 && node == nil {
   228  		return nil
   229  	}
   230  
   231  	switch node := node.(type) {
   232  	case *ShortNode:
   233  		k := node.Key()
   234  		cnode := node.Value()
   235  		if bytes.Equal(key, k) {
   236  			return nil
   237  		} else if bytes.Equal(key[:len(k)], k) {
   238  			child := self.delete(cnode, key[len(k):])
   239  
   240  			var n Node
   241  			switch child := child.(type) {
   242  			case *ShortNode:
   243  				nkey := append(k, child.Key()...)
   244  				n = NewShortNode(self, nkey, child.Value())
   245  			case *FullNode:
   246  				sn := NewShortNode(self, node.Key(), child)
   247  				sn.key = node.key
   248  				n = sn
   249  			}
   250  
   251  			return n
   252  		} else {
   253  			return node
   254  		}
   255  
   256  	case *FullNode:
   257  		n := node.Copy(self).(*FullNode)
   258  		n.set(key[0], self.delete(n.branch(key[0]), key[1:]))
   259  
   260  		pos := -1
   261  		for i := 0; i < 17; i++ {
   262  			if n.branch(byte(i)) != nil {
   263  				if pos == -1 {
   264  					pos = i
   265  				} else {
   266  					pos = -2
   267  				}
   268  			}
   269  		}
   270  
   271  		var nnode Node
   272  		if pos == 16 {
   273  			nnode = NewShortNode(self, []byte{16}, n.branch(byte(pos)))
   274  		} else if pos >= 0 {
   275  			cnode := n.branch(byte(pos))
   276  			switch cnode := cnode.(type) {
   277  			case *ShortNode:
   278  				// Stitch keys
   279  				k := append([]byte{byte(pos)}, cnode.Key()...)
   280  				nnode = NewShortNode(self, k, cnode.Value())
   281  			case *FullNode:
   282  				nnode = NewShortNode(self, []byte{byte(pos)}, n.branch(byte(pos)))
   283  			}
   284  		} else {
   285  			nnode = n
   286  		}
   287  
   288  		return nnode
   289  	case nil:
   290  		return nil
   291  	default:
   292  		panic(fmt.Sprintf("%T: invalid node: %v (%v)", node, node, key))
   293  	}
   294  }
   295  
   296  // casting functions and cache storing
   297  func (self *Trie) mknode(value *ethutil.Value) Node {
   298  	l := value.Len()
   299  	switch l {
   300  	case 0:
   301  		return nil
   302  	case 2:
   303  		// A value node may consists of 2 bytes.
   304  		if value.Get(0).Len() != 0 {
   305  			return NewShortNode(self, CompactDecode(string(value.Get(0).Bytes())), self.mknode(value.Get(1)))
   306  		}
   307  	case 17:
   308  		fnode := NewFullNode(self)
   309  		for i := 0; i < l; i++ {
   310  			fnode.set(byte(i), self.mknode(value.Get(i)))
   311  		}
   312  		return fnode
   313  	case 32:
   314  		return &HashNode{value.Bytes(), self}
   315  	}
   316  
   317  	return &ValueNode{self, value.Bytes()}
   318  }
   319  
   320  func (self *Trie) trans(node Node) Node {
   321  	switch node := node.(type) {
   322  	case *HashNode:
   323  		value := ethutil.NewValueFromBytes(self.cache.Get(node.key))
   324  		return self.mknode(value)
   325  	default:
   326  		return node
   327  	}
   328  }
   329  
   330  func (self *Trie) store(node Node) interface{} {
   331  	data := ethutil.Encode(node)
   332  	if len(data) >= 32 {
   333  		key := crypto.Sha3(data)
   334  		self.cache.Put(key, data)
   335  
   336  		return key
   337  	}
   338  
   339  	return node.RlpData()
   340  }
   341  
   342  func (self *Trie) PrintRoot() {
   343  	fmt.Println(self.root)
   344  	fmt.Printf("root=%x\n", self.Root())
   345  }