github.com/bigzoro/my_simplechain@v0.0.0-20240315012955-8ad0a2a29bb9/parallel_trie/trie_dag.go (about)

     1  package trie
     2  
     3  import (
     4  	"runtime"
     5  	"sync"
     6  	"sync/atomic"
     7  
     8  	"github.com/bigzoro/my_simplechain/common"
     9  
    10  	"github.com/cespare/xxhash"
    11  )
    12  
    13  var fullNodeSuffix = []byte("fullnode")
    14  
    15  // dagNode
    16  type dagNode struct {
    17  	collapsed node
    18  	cached    node
    19  
    20  	pid uint64
    21  	idx int
    22  }
    23  
    24  // trieDag
    25  type trieDag struct {
    26  	nodes map[uint64]*dagNode
    27  	dag   *dag
    28  
    29  	lock sync.Mutex
    30  
    31  	loged bool
    32  }
    33  
    34  func newTrieDag() *trieDag {
    35  	return &trieDag{
    36  		nodes: make(map[uint64]*dagNode),
    37  		dag:   newDag(),
    38  		loged: false,
    39  	}
    40  }
    41  
    42  func (td *trieDag) addVertexAndEdge(pprefix, prefix []byte, n node) {
    43  	td.lock.Lock()
    44  	defer td.lock.Unlock()
    45  	td.internalAddVertexAndEdge(pprefix, prefix, n, true)
    46  }
    47  
    48  func (td *trieDag) internalAddVertexAndEdge(pprefix, prefix []byte, n node, recursive bool) {
    49  	var pid uint64
    50  	if len(pprefix) > 0 {
    51  		pid = xxhash.Sum64(pprefix)
    52  	}
    53  
    54  	cachedHash := func(n node) (node, bool) {
    55  		if hash, _ := n.cache(); len(hash) != 0 {
    56  			return hash, true
    57  		}
    58  		return n, false
    59  	}
    60  
    61  	switch nc := n.(type) {
    62  	case *shortNode:
    63  		collapsed, cached := nc.copy(), nc.copy()
    64  		collapsed.Key = hexToCompact(nc.Key)
    65  		cached.Key = common.CopyBytes(nc.Key)
    66  
    67  		hash, has := cachedHash(nc.Val)
    68  		if has {
    69  			hash, _ = hash.(hashNode)
    70  			collapsed.Val = hash
    71  		}
    72  
    73  		id := xxhash.Sum64(append(prefix, nc.Key...))
    74  		td.nodes[id] = &dagNode{
    75  			collapsed: collapsed,
    76  			cached:    cached,
    77  			pid:       pid,
    78  		}
    79  		if len(prefix) > 0 {
    80  			td.nodes[id].idx = int(prefix[len(prefix)-1])
    81  		}
    82  		td.dag.addVertex(id)
    83  
    84  		if pid > 0 {
    85  			td.dag.addEdge(id, pid)
    86  		}
    87  
    88  	case *fullNode:
    89  		collapsed, cached := nc.copy(), nc.copy()
    90  		cached.Children[16] = nc.Children[16]
    91  
    92  		dagNode := &dagNode{
    93  			collapsed: collapsed,
    94  			cached:    cached,
    95  			pid:       pid,
    96  		}
    97  		if len(prefix) > 0 {
    98  			dagNode.idx = int(prefix[len(prefix)-1])
    99  		}
   100  
   101  		id := xxhash.Sum64(append(prefix, fullNodeSuffix...))
   102  		td.nodes[id] = dagNode
   103  		td.dag.addVertex(id)
   104  		if pid > 0 {
   105  			td.dag.addEdge(id, pid)
   106  		}
   107  
   108  		if recursive {
   109  			for i := 0; i < 16; i++ {
   110  				if cached.Children[i] != nil {
   111  					cn := cached.Children[i]
   112  					td.internalAddVertexAndEdge(append(prefix, fullNodeSuffix...), append(prefix, byte(i)), cn, false)
   113  				}
   114  			}
   115  		}
   116  	}
   117  }
   118  
   119  func (td *trieDag) delVertexAndEdge(key []byte) {
   120  	id := xxhash.Sum64(key)
   121  	td.delVertexAndEdgeByID(id)
   122  }
   123  
   124  func (td *trieDag) delVertexAndEdgeByID(id uint64) {
   125  	td.lock.Lock()
   126  	defer td.lock.Unlock()
   127  	//td.dag.delEdge(id)
   128  	td.dag.delVertex(id)
   129  	delete(td.nodes, id)
   130  	//fmt.Printf("del: %d\n", id)
   131  }
   132  
   133  func (td *trieDag) delVertexAndEdgeByNode(prefix []byte, n node) {
   134  	var id uint64
   135  	switch nc := n.(type) {
   136  	case *shortNode:
   137  		id = xxhash.Sum64(append(prefix, nc.Key...))
   138  	case *fullNode:
   139  		id = xxhash.Sum64(append(prefix, fullNodeSuffix...))
   140  	}
   141  	td.delVertexAndEdgeByID(id)
   142  }
   143  
   144  func (td *trieDag) clear() {
   145  	td.lock.Lock()
   146  	defer td.lock.Unlock()
   147  
   148  	td.dag.clear()
   149  	td.nodes = make(map[uint64]*dagNode)
   150  }
   151  
   152  func (td *trieDag) hash(db *Database, force bool, onleaf LeafCallback) (node, node, error) {
   153  	td.lock.Lock()
   154  	defer td.lock.Unlock()
   155  
   156  	td.dag.generate()
   157  
   158  	//log.Trace("Prepare do hash", "me", fmt.Sprintf("%p", td), "routineID", goid.Get(), "dag", fmt.Sprintf("%p", td.dag), "nodes", len(td.nodes), "topLevel", td.dag.topLevel.Len(), "consumed", td.dag.totalConsumed, "vtxs", td.dag.totalVertexs, "cv", td.dag.cv)
   159  
   160  	var wg sync.WaitGroup
   161  	var errDone uint32
   162  	var e atomic.Value // error
   163  	var resHash node = hashNode{}
   164  	var newRoot node
   165  	numCPU := runtime.NumCPU()
   166  
   167  	cachedHash := func(n, c node) (node, node, bool) {
   168  		if hash, dirty := c.cache(); len(hash) != 0 {
   169  			if db == nil {
   170  				return hash, c, true
   171  			}
   172  
   173  			if !dirty {
   174  				return hash, c, true
   175  			}
   176  		}
   177  		return n, c, false
   178  	}
   179  
   180  	process := func() {
   181  		//log.Trace("Do hash", "me", fmt.Sprintf("%p", td), "routineID", goid.Get(), "dag", fmt.Sprintf("%p", td.dag), "nodes", len(td.nodes), "topLevel", td.dag.topLevel.Len(), "consumed", td.dag.totalConsumed, "vtxs", td.dag.totalVertexs, "cv", td.dag.cv)
   182  		hasher := newHasher(onleaf)
   183  
   184  		id := td.dag.waitPop()
   185  		if id == invalidID {
   186  			returnHasherToPool(hasher)
   187  			wg.Done()
   188  			return
   189  		}
   190  
   191  		var hashed node
   192  		var cached node
   193  		var err error
   194  		var hasCache bool
   195  		for id != invalidID {
   196  			n := td.nodes[id]
   197  
   198  			tmpForce := false
   199  			if n.pid == 0 {
   200  				tmpForce = force
   201  			}
   202  
   203  			hashed, cached, hasCache = cachedHash(n.collapsed, n.cached)
   204  			if !hasCache {
   205  				switch ct := n.collapsed.(type) {
   206  				case *fullNode:
   207  					for i := 0; i < 16; i++ {
   208  						if ct.Children[i] != nil {
   209  							nc := ct.Children[i]
   210  							if _, isHash := nc.(hashNode); !isHash {
   211  								h, _, _ := cachedHash(nc, nc)
   212  								ct.Children[i] = h
   213  							}
   214  						}
   215  					}
   216  				}
   217  
   218  				hashed, err = hasher.store(n.collapsed, db, tmpForce)
   219  				if err != nil {
   220  					e.Store(err)
   221  					atomic.StoreUint32(&errDone, 1)
   222  					break
   223  				}
   224  				cached = n.cached
   225  			}
   226  
   227  			if n.pid > 0 {
   228  				p := td.nodes[n.pid]
   229  				switch ptype := p.collapsed.(type) {
   230  				case *shortNode:
   231  					ptype.Val = hashed
   232  				case *fullNode:
   233  					ptype.Children[n.idx] = hashed
   234  				}
   235  
   236  				if _, ok := cached.(hashNode); ok {
   237  					switch nc := p.cached.(type) {
   238  					case *shortNode:
   239  						nc.Val = cached
   240  					case *fullNode:
   241  						nc.Children[n.idx] = cached
   242  					}
   243  				}
   244  			}
   245  
   246  			cachedHash, _ := hashed.(hashNode)
   247  			switch cn := n.cached.(type) {
   248  			case *shortNode:
   249  				*cn.flags.hash = cachedHash
   250  				if db != nil {
   251  					*cn.flags.dirty = false
   252  				}
   253  			case *fullNode:
   254  				*cn.flags.hash = cachedHash
   255  				if db != nil {
   256  					*cn.flags.dirty = false
   257  				}
   258  			}
   259  
   260  			id = td.dag.consume(id)
   261  			if n.pid == 0 {
   262  				resHash = hashed
   263  				newRoot = n.cached
   264  				break
   265  			}
   266  
   267  			if atomic.LoadUint32(&errDone) > 0 {
   268  				break
   269  			}
   270  
   271  			if id == invalidID && !td.dag.hasFinished() {
   272  				id = td.dag.waitPop()
   273  			}
   274  		}
   275  		returnHasherToPool(hasher)
   276  		wg.Done()
   277  	}
   278  
   279  	wg.Add(numCPU)
   280  	for i := 0; i < numCPU; i++ {
   281  		go process()
   282  	}
   283  
   284  	wg.Wait()
   285  	td.dag.reset()
   286  	td.loged = true
   287  
   288  	if e.Load() != nil && e.Load().(error) != nil {
   289  		return hashNode{}, nil, e.Load().(error)
   290  	}
   291  	return resHash, newRoot, nil
   292  }