github.com/bigzoro/my_simplechain@v0.0.0-20240315012955-8ad0a2a29bb9/parallel_trie/trie_dag.go (about) 1 package trie 2 3 import ( 4 "runtime" 5 "sync" 6 "sync/atomic" 7 8 "github.com/bigzoro/my_simplechain/common" 9 10 "github.com/cespare/xxhash" 11 ) 12 13 var fullNodeSuffix = []byte("fullnode") 14 15 // dagNode 16 type dagNode struct { 17 collapsed node 18 cached node 19 20 pid uint64 21 idx int 22 } 23 24 // trieDag 25 type trieDag struct { 26 nodes map[uint64]*dagNode 27 dag *dag 28 29 lock sync.Mutex 30 31 loged bool 32 } 33 34 func newTrieDag() *trieDag { 35 return &trieDag{ 36 nodes: make(map[uint64]*dagNode), 37 dag: newDag(), 38 loged: false, 39 } 40 } 41 42 func (td *trieDag) addVertexAndEdge(pprefix, prefix []byte, n node) { 43 td.lock.Lock() 44 defer td.lock.Unlock() 45 td.internalAddVertexAndEdge(pprefix, prefix, n, true) 46 } 47 48 func (td *trieDag) internalAddVertexAndEdge(pprefix, prefix []byte, n node, recursive bool) { 49 var pid uint64 50 if len(pprefix) > 0 { 51 pid = xxhash.Sum64(pprefix) 52 } 53 54 cachedHash := func(n node) (node, bool) { 55 if hash, _ := n.cache(); len(hash) != 0 { 56 return hash, true 57 } 58 return n, false 59 } 60 61 switch nc := n.(type) { 62 case *shortNode: 63 collapsed, cached := nc.copy(), nc.copy() 64 collapsed.Key = hexToCompact(nc.Key) 65 cached.Key = common.CopyBytes(nc.Key) 66 67 hash, has := cachedHash(nc.Val) 68 if has { 69 hash, _ = hash.(hashNode) 70 collapsed.Val = hash 71 } 72 73 id := xxhash.Sum64(append(prefix, nc.Key...)) 74 td.nodes[id] = &dagNode{ 75 collapsed: collapsed, 76 cached: cached, 77 pid: pid, 78 } 79 if len(prefix) > 0 { 80 td.nodes[id].idx = int(prefix[len(prefix)-1]) 81 } 82 td.dag.addVertex(id) 83 84 if pid > 0 { 85 td.dag.addEdge(id, pid) 86 } 87 88 case *fullNode: 89 collapsed, cached := nc.copy(), nc.copy() 90 cached.Children[16] = nc.Children[16] 91 92 dagNode := &dagNode{ 93 collapsed: collapsed, 94 cached: cached, 95 pid: pid, 96 } 97 if len(prefix) > 0 { 98 dagNode.idx = int(prefix[len(prefix)-1]) 99 } 100 101 id := xxhash.Sum64(append(prefix, fullNodeSuffix...)) 102 td.nodes[id] = dagNode 103 td.dag.addVertex(id) 104 if pid > 0 { 105 td.dag.addEdge(id, pid) 106 } 107 108 if recursive { 109 for i := 0; i < 16; i++ { 110 if cached.Children[i] != nil { 111 cn := cached.Children[i] 112 td.internalAddVertexAndEdge(append(prefix, fullNodeSuffix...), append(prefix, byte(i)), cn, false) 113 } 114 } 115 } 116 } 117 } 118 119 func (td *trieDag) delVertexAndEdge(key []byte) { 120 id := xxhash.Sum64(key) 121 td.delVertexAndEdgeByID(id) 122 } 123 124 func (td *trieDag) delVertexAndEdgeByID(id uint64) { 125 td.lock.Lock() 126 defer td.lock.Unlock() 127 //td.dag.delEdge(id) 128 td.dag.delVertex(id) 129 delete(td.nodes, id) 130 //fmt.Printf("del: %d\n", id) 131 } 132 133 func (td *trieDag) delVertexAndEdgeByNode(prefix []byte, n node) { 134 var id uint64 135 switch nc := n.(type) { 136 case *shortNode: 137 id = xxhash.Sum64(append(prefix, nc.Key...)) 138 case *fullNode: 139 id = xxhash.Sum64(append(prefix, fullNodeSuffix...)) 140 } 141 td.delVertexAndEdgeByID(id) 142 } 143 144 func (td *trieDag) clear() { 145 td.lock.Lock() 146 defer td.lock.Unlock() 147 148 td.dag.clear() 149 td.nodes = make(map[uint64]*dagNode) 150 } 151 152 func (td *trieDag) hash(db *Database, force bool, onleaf LeafCallback) (node, node, error) { 153 td.lock.Lock() 154 defer td.lock.Unlock() 155 156 td.dag.generate() 157 158 //log.Trace("Prepare do hash", "me", fmt.Sprintf("%p", td), "routineID", goid.Get(), "dag", fmt.Sprintf("%p", td.dag), "nodes", len(td.nodes), "topLevel", td.dag.topLevel.Len(), "consumed", td.dag.totalConsumed, "vtxs", td.dag.totalVertexs, "cv", td.dag.cv) 159 160 var wg sync.WaitGroup 161 var errDone uint32 162 var e atomic.Value // error 163 var resHash node = hashNode{} 164 var newRoot node 165 numCPU := runtime.NumCPU() 166 167 cachedHash := func(n, c node) (node, node, bool) { 168 if hash, dirty := c.cache(); len(hash) != 0 { 169 if db == nil { 170 return hash, c, true 171 } 172 173 if !dirty { 174 return hash, c, true 175 } 176 } 177 return n, c, false 178 } 179 180 process := func() { 181 //log.Trace("Do hash", "me", fmt.Sprintf("%p", td), "routineID", goid.Get(), "dag", fmt.Sprintf("%p", td.dag), "nodes", len(td.nodes), "topLevel", td.dag.topLevel.Len(), "consumed", td.dag.totalConsumed, "vtxs", td.dag.totalVertexs, "cv", td.dag.cv) 182 hasher := newHasher(onleaf) 183 184 id := td.dag.waitPop() 185 if id == invalidID { 186 returnHasherToPool(hasher) 187 wg.Done() 188 return 189 } 190 191 var hashed node 192 var cached node 193 var err error 194 var hasCache bool 195 for id != invalidID { 196 n := td.nodes[id] 197 198 tmpForce := false 199 if n.pid == 0 { 200 tmpForce = force 201 } 202 203 hashed, cached, hasCache = cachedHash(n.collapsed, n.cached) 204 if !hasCache { 205 switch ct := n.collapsed.(type) { 206 case *fullNode: 207 for i := 0; i < 16; i++ { 208 if ct.Children[i] != nil { 209 nc := ct.Children[i] 210 if _, isHash := nc.(hashNode); !isHash { 211 h, _, _ := cachedHash(nc, nc) 212 ct.Children[i] = h 213 } 214 } 215 } 216 } 217 218 hashed, err = hasher.store(n.collapsed, db, tmpForce) 219 if err != nil { 220 e.Store(err) 221 atomic.StoreUint32(&errDone, 1) 222 break 223 } 224 cached = n.cached 225 } 226 227 if n.pid > 0 { 228 p := td.nodes[n.pid] 229 switch ptype := p.collapsed.(type) { 230 case *shortNode: 231 ptype.Val = hashed 232 case *fullNode: 233 ptype.Children[n.idx] = hashed 234 } 235 236 if _, ok := cached.(hashNode); ok { 237 switch nc := p.cached.(type) { 238 case *shortNode: 239 nc.Val = cached 240 case *fullNode: 241 nc.Children[n.idx] = cached 242 } 243 } 244 } 245 246 cachedHash, _ := hashed.(hashNode) 247 switch cn := n.cached.(type) { 248 case *shortNode: 249 *cn.flags.hash = cachedHash 250 if db != nil { 251 *cn.flags.dirty = false 252 } 253 case *fullNode: 254 *cn.flags.hash = cachedHash 255 if db != nil { 256 *cn.flags.dirty = false 257 } 258 } 259 260 id = td.dag.consume(id) 261 if n.pid == 0 { 262 resHash = hashed 263 newRoot = n.cached 264 break 265 } 266 267 if atomic.LoadUint32(&errDone) > 0 { 268 break 269 } 270 271 if id == invalidID && !td.dag.hasFinished() { 272 id = td.dag.waitPop() 273 } 274 } 275 returnHasherToPool(hasher) 276 wg.Done() 277 } 278 279 wg.Add(numCPU) 280 for i := 0; i < numCPU; i++ { 281 go process() 282 } 283 284 wg.Wait() 285 td.dag.reset() 286 td.loged = true 287 288 if e.Load() != nil && e.Load().(error) != nil { 289 return hashNode{}, nil, e.Load().(error) 290 } 291 return resHash, newRoot, nil 292 }