github.com/unicornultrafoundation/go-u2u@v1.0.0-rc1.0.20240205080301-e74a83d3fadc/topicsdb/search_parallel.go (about)

     1  package topicsdb
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  
     7  	"github.com/unicornultrafoundation/go-u2u/common"
     8  )
     9  
// logHandler receives log records produced by a topic-index scan.
//
// Returning gonext=false asks the scanner to stop. err reports a failure;
// scanPatternVariant itself ignores it, so whoever builds the handler must
// surface it. A nil rec is passed exactly once, after the scan loop ends,
// to signal that the scanner has finished.
type logHandler func(rec *logrec) (gonext bool, err error)
    11  
    12  func (tt *index) searchParallel(ctx context.Context, pattern [][]common.Hash, blockStart, blockEnd uint64, onMatched logHandler, onDbIterator func()) error {
    13  	if ctx == nil {
    14  		ctx = context.Background()
    15  	}
    16  
    17  	var (
    18  		syncing      = newSynchronizator()
    19  		mu           sync.Mutex
    20  		foundByBlock = make(map[uint64]map[ID]*logrec)
    21  	)
    22  
    23  	aggregator := func(pos, num int) logHandler {
    24  		return func(rec *logrec) (gonext bool, err error) {
    25  			if rec == nil {
    26  				syncing.FinishThread(pos, num)
    27  				return
    28  			}
    29  
    30  			err = ctx.Err()
    31  			if err != nil {
    32  				return
    33  			}
    34  
    35  			block := rec.ID.BlockNumber()
    36  			if blockEnd > 0 && block > blockEnd {
    37  				return
    38  			}
    39  			if rec.topicsCount < uint8(len(pattern)-1) {
    40  				gonext = true
    41  				return
    42  			}
    43  
    44  			var prevBlock uint64
    45  			prevBlock, gonext = syncing.GoNext(block)
    46  			if !gonext {
    47  				return
    48  			}
    49  
    50  			mu.Lock()
    51  			defer mu.Unlock()
    52  
    53  			if prevBlock > 0 {
    54  				delete(foundByBlock, prevBlock)
    55  			}
    56  
    57  			found, ok := foundByBlock[block]
    58  			if !ok {
    59  				found = make(map[ID]*logrec)
    60  				foundByBlock[block] = found
    61  			}
    62  
    63  			if before, ok := found[rec.ID]; ok {
    64  				rec = before
    65  			} else {
    66  				found[rec.ID] = rec
    67  			}
    68  			rec.matched++
    69  			if rec.matched == syncing.PositionsCount() {
    70  				gonext, err = onMatched(rec)
    71  				if !gonext {
    72  					syncing.Halt()
    73  					return
    74  				}
    75  			}
    76  
    77  			return
    78  		}
    79  	}
    80  
    81  	// start the threads
    82  	var preparing sync.WaitGroup
    83  	preparing.Add(1)
    84  	for pos := range pattern {
    85  		if len(pattern[pos]) == 0 {
    86  			continue
    87  		}
    88  		for i, variant := range pattern[pos] {
    89  			syncing.StartThread(pos, i)
    90  			go func(pos, i int, variant common.Hash) {
    91  				onMatched := aggregator(pos, i)
    92  				preparing.Wait()
    93  				tt.scanPatternVariant(uint8(pos), variant, blockStart, onMatched, onDbIterator)
    94  			}(pos, i, variant)
    95  		}
    96  	}
    97  	preparing.Done()
    98  
    99  	syncing.WaitForThreads()
   100  
   101  	return ctx.Err()
   102  }
   103  
   104  func (tt *index) scanPatternVariant(pos uint8, variant common.Hash, start uint64, onMatched logHandler, onDbIterator func()) {
   105  	prefix := append(variant.Bytes(), posToBytes(pos)...)
   106  
   107  	onDbIterator()
   108  	it := tt.table.Topic.NewIterator(prefix, uintToBytes(start))
   109  	defer it.Release()
   110  	for it.Next() {
   111  		id := extractLogrecID(it.Key())
   112  		topicCount := bytesToPos(it.Value())
   113  		rec := newLogrec(id, topicCount)
   114  
   115  		gonext, _ := onMatched(rec)
   116  		if !gonext {
   117  			break
   118  		}
   119  	}
   120  	onMatched(nil)
   121  }