github.com/jonasnick/go-ethereum@v0.7.12-0.20150216215225-22176f05d387/trie/trie.go (about) 1 package trie 2 3 import ( 4 "bytes" 5 "container/list" 6 "fmt" 7 "sync" 8 9 "github.com/jonasnick/go-ethereum/crypto" 10 "github.com/jonasnick/go-ethereum/ethutil" 11 ) 12 13 func ParanoiaCheck(t1 *Trie, backend Backend) (bool, *Trie) { 14 t2 := New(nil, backend) 15 16 it := t1.Iterator() 17 for it.Next() { 18 t2.Update(it.Key, it.Value) 19 } 20 21 return bytes.Equal(t2.Hash(), t1.Hash()), t2 22 } 23 24 type Trie struct { 25 mu sync.Mutex 26 root Node 27 roothash []byte 28 cache *Cache 29 30 revisions *list.List 31 } 32 33 func New(root []byte, backend Backend) *Trie { 34 trie := &Trie{} 35 trie.revisions = list.New() 36 trie.roothash = root 37 if backend != nil { 38 trie.cache = NewCache(backend) 39 } 40 41 if root != nil { 42 value := ethutil.NewValueFromBytes(trie.cache.Get(root)) 43 trie.root = trie.mknode(value) 44 } 45 46 return trie 47 } 48 49 func (self *Trie) Iterator() *Iterator { 50 return NewIterator(self) 51 } 52 53 func (self *Trie) Copy() *Trie { 54 cpy := make([]byte, 32) 55 copy(cpy, self.roothash) 56 trie := New(nil, nil) 57 trie.cache = self.cache.Copy() 58 if self.root != nil { 59 trie.root = self.root.Copy(trie) 60 } 61 62 return trie 63 } 64 65 // Legacy support 66 func (self *Trie) Root() []byte { return self.Hash() } 67 func (self *Trie) Hash() []byte { 68 var hash []byte 69 if self.root != nil { 70 t := self.root.Hash() 71 if byts, ok := t.([]byte); ok && len(byts) > 0 { 72 hash = byts 73 } else { 74 hash = crypto.Sha3(ethutil.Encode(self.root.RlpData())) 75 } 76 } else { 77 hash = crypto.Sha3(ethutil.Encode("")) 78 } 79 80 if !bytes.Equal(hash, self.roothash) { 81 self.revisions.PushBack(self.roothash) 82 self.roothash = hash 83 } 84 85 return hash 86 } 87 func (self *Trie) Commit() { 88 self.mu.Lock() 89 defer self.mu.Unlock() 90 91 // Hash first 92 self.Hash() 93 94 self.cache.Flush() 95 } 96 97 // Reset should only be called if the trie has been hashed 98 func (self *Trie) Reset() { 99 self.mu.Lock() 100 defer self.mu.Unlock() 101 102 self.cache.Reset() 103 104 if self.revisions.Len() > 0 { 105 revision := self.revisions.Remove(self.revisions.Back()).([]byte) 106 self.roothash = revision 107 } 108 value := ethutil.NewValueFromBytes(self.cache.Get(self.roothash)) 109 self.root = self.mknode(value) 110 } 111 112 func (self *Trie) UpdateString(key, value string) Node { return self.Update([]byte(key), []byte(value)) } 113 func (self *Trie) Update(key, value []byte) Node { 114 self.mu.Lock() 115 defer self.mu.Unlock() 116 117 k := CompactHexDecode(string(key)) 118 119 if len(value) != 0 { 120 self.root = self.insert(self.root, k, &ValueNode{self, value}) 121 } else { 122 self.root = self.delete(self.root, k) 123 } 124 125 return self.root 126 } 127 128 func (self *Trie) GetString(key string) []byte { return self.Get([]byte(key)) } 129 func (self *Trie) Get(key []byte) []byte { 130 self.mu.Lock() 131 defer self.mu.Unlock() 132 133 k := CompactHexDecode(string(key)) 134 135 n := self.get(self.root, k) 136 if n != nil { 137 return n.(*ValueNode).Val() 138 } 139 140 return nil 141 } 142 143 func (self *Trie) DeleteString(key string) Node { return self.Delete([]byte(key)) } 144 func (self *Trie) Delete(key []byte) Node { 145 self.mu.Lock() 146 defer self.mu.Unlock() 147 148 k := CompactHexDecode(string(key)) 149 self.root = self.delete(self.root, k) 150 151 return self.root 152 } 153 154 func (self *Trie) insert(node Node, key []byte, value Node) Node { 155 if len(key) == 0 { 156 return value 157 } 158 159 if node == nil { 160 return NewShortNode(self, key, value) 161 } 162 163 switch node := node.(type) { 164 case *ShortNode: 165 k := node.Key() 166 cnode := node.Value() 167 if bytes.Equal(k, key) { 168 return NewShortNode(self, key, value) 169 } 170 171 var n Node 172 matchlength := MatchingNibbleLength(key, k) 173 if matchlength == len(k) { 174 n = self.insert(cnode, key[matchlength:], value) 175 } else { 176 pnode := self.insert(nil, k[matchlength+1:], cnode) 177 nnode := self.insert(nil, key[matchlength+1:], value) 178 fulln := NewFullNode(self) 179 fulln.set(k[matchlength], pnode) 180 fulln.set(key[matchlength], nnode) 181 n = fulln 182 } 183 if matchlength == 0 { 184 return n 185 } 186 187 return NewShortNode(self, key[:matchlength], n) 188 189 case *FullNode: 190 cpy := node.Copy(self).(*FullNode) 191 cpy.set(key[0], self.insert(node.branch(key[0]), key[1:], value)) 192 193 return cpy 194 195 default: 196 panic(fmt.Sprintf("%T: invalid node: %v", node, node)) 197 } 198 } 199 200 func (self *Trie) get(node Node, key []byte) Node { 201 if len(key) == 0 { 202 return node 203 } 204 205 if node == nil { 206 return nil 207 } 208 209 switch node := node.(type) { 210 case *ShortNode: 211 k := node.Key() 212 cnode := node.Value() 213 214 if len(key) >= len(k) && bytes.Equal(k, key[:len(k)]) { 215 return self.get(cnode, key[len(k):]) 216 } 217 218 return nil 219 case *FullNode: 220 return self.get(node.branch(key[0]), key[1:]) 221 default: 222 panic(fmt.Sprintf("%T: invalid node: %v", node, node)) 223 } 224 } 225 226 func (self *Trie) delete(node Node, key []byte) Node { 227 if len(key) == 0 && node == nil { 228 return nil 229 } 230 231 switch node := node.(type) { 232 case *ShortNode: 233 k := node.Key() 234 cnode := node.Value() 235 if bytes.Equal(key, k) { 236 return nil 237 } else if bytes.Equal(key[:len(k)], k) { 238 child := self.delete(cnode, key[len(k):]) 239 240 var n Node 241 switch child := child.(type) { 242 case *ShortNode: 243 nkey := append(k, child.Key()...) 244 n = NewShortNode(self, nkey, child.Value()) 245 case *FullNode: 246 sn := NewShortNode(self, node.Key(), child) 247 sn.key = node.key 248 n = sn 249 } 250 251 return n 252 } else { 253 return node 254 } 255 256 case *FullNode: 257 n := node.Copy(self).(*FullNode) 258 n.set(key[0], self.delete(n.branch(key[0]), key[1:])) 259 260 pos := -1 261 for i := 0; i < 17; i++ { 262 if n.branch(byte(i)) != nil { 263 if pos == -1 { 264 pos = i 265 } else { 266 pos = -2 267 } 268 } 269 } 270 271 var nnode Node 272 if pos == 16 { 273 nnode = NewShortNode(self, []byte{16}, n.branch(byte(pos))) 274 } else if pos >= 0 { 275 cnode := n.branch(byte(pos)) 276 switch cnode := cnode.(type) { 277 case *ShortNode: 278 // Stitch keys 279 k := append([]byte{byte(pos)}, cnode.Key()...) 280 nnode = NewShortNode(self, k, cnode.Value()) 281 case *FullNode: 282 nnode = NewShortNode(self, []byte{byte(pos)}, n.branch(byte(pos))) 283 } 284 } else { 285 nnode = n 286 } 287 288 return nnode 289 case nil: 290 return nil 291 default: 292 panic(fmt.Sprintf("%T: invalid node: %v (%v)", node, node, key)) 293 } 294 } 295 296 // casting functions and cache storing 297 func (self *Trie) mknode(value *ethutil.Value) Node { 298 l := value.Len() 299 switch l { 300 case 0: 301 return nil 302 case 2: 303 // A value node may consists of 2 bytes. 304 if value.Get(0).Len() != 0 { 305 return NewShortNode(self, CompactDecode(string(value.Get(0).Bytes())), self.mknode(value.Get(1))) 306 } 307 case 17: 308 fnode := NewFullNode(self) 309 for i := 0; i < l; i++ { 310 fnode.set(byte(i), self.mknode(value.Get(i))) 311 } 312 return fnode 313 case 32: 314 return &HashNode{value.Bytes(), self} 315 } 316 317 return &ValueNode{self, value.Bytes()} 318 } 319 320 func (self *Trie) trans(node Node) Node { 321 switch node := node.(type) { 322 case *HashNode: 323 value := ethutil.NewValueFromBytes(self.cache.Get(node.key)) 324 return self.mknode(value) 325 default: 326 return node 327 } 328 } 329 330 func (self *Trie) store(node Node) interface{} { 331 data := ethutil.Encode(node) 332 if len(data) >= 32 { 333 key := crypto.Sha3(data) 334 self.cache.Put(key, data) 335 336 return key 337 } 338 339 return node.RlpData() 340 } 341 342 func (self *Trie) PrintRoot() { 343 fmt.Println(self.root) 344 fmt.Printf("root=%x\n", self.Root()) 345 }