github.com/unicornultrafoundation/go-u2u@v1.0.0-rc1.0.20240205080301-e74a83d3fadc/utils/weighted_shuffle.go (about)

     1  package utils
     2  
     3  import (
     4  	"crypto/sha256"
     5  
     6  	"github.com/unicornultrafoundation/go-helios/common/littleendian"
     7  	"github.com/unicornultrafoundation/go-helios/hash"
     8  	"github.com/unicornultrafoundation/go-helios/native/pos"
     9  )
    10  
    11  type weightedShuffleNode struct {
    12  	thisWeight  pos.Weight
    13  	leftWeight  pos.Weight
    14  	rightWeight pos.Weight
    15  }
    16  
    17  type weightedShuffleTree struct {
    18  	seed      hash.Hash
    19  	seedIndex int
    20  
    21  	weights []pos.Weight
    22  	nodes   []weightedShuffleNode
    23  }
    24  
    25  func (t *weightedShuffleTree) leftIndex(i int) int {
    26  	return i*2 + 1
    27  }
    28  
    29  func (t *weightedShuffleTree) rightIndex(i int) int {
    30  	return i*2 + 2
    31  }
    32  
    33  func (t *weightedShuffleTree) build(i int) pos.Weight {
    34  	if i >= len(t.weights) {
    35  		return 0
    36  	}
    37  	thisW := t.weights[i]
    38  	leftW := t.build(t.leftIndex(i))
    39  	rightW := t.build(t.rightIndex(i))
    40  
    41  	if thisW <= 0 {
    42  		panic("all the weight must be positive")
    43  	}
    44  
    45  	t.nodes[i] = weightedShuffleNode{
    46  		thisWeight:  thisW,
    47  		leftWeight:  leftW,
    48  		rightWeight: rightW,
    49  	}
    50  	return thisW + leftW + rightW
    51  }
    52  
    53  func (t *weightedShuffleTree) rand32() uint32 {
    54  	if t.seedIndex == 32 {
    55  		hasher := sha256.New() // use sha2 instead of sha3 for speed
    56  		hasher.Write(t.seed.Bytes())
    57  		t.seed = hash.BytesToHash(hasher.Sum(nil))
    58  		t.seedIndex = 0
    59  	}
    60  	// use not used parts of old seed, instead of calculating new one
    61  	res := littleendian.BytesToUint32(t.seed[t.seedIndex : t.seedIndex+4])
    62  	t.seedIndex += 4
    63  	return res
    64  }
    65  
    66  func (t *weightedShuffleTree) retrieve(i int) int {
    67  	node := t.nodes[i]
    68  	total := node.rightWeight + node.leftWeight + node.thisWeight
    69  
    70  	r := pos.Weight(t.rand32()) % total
    71  
    72  	if r < node.thisWeight {
    73  		t.nodes[i].thisWeight = 0
    74  		return i
    75  	} else if r < node.thisWeight+node.leftWeight {
    76  		chosen := t.retrieve(t.leftIndex(i))
    77  		t.nodes[i].leftWeight -= t.weights[chosen]
    78  		return chosen
    79  	} else {
    80  		chosen := t.retrieve(t.rightIndex(i))
    81  		t.nodes[i].rightWeight -= t.weights[chosen]
    82  		return chosen
    83  	}
    84  }
    85  
    86  // WeightedPermutation builds weighted random permutation
    87  // Returns first {size} entries of {weights} permutation.
    88  // Call with {size} == len(weights) to get the whole permutation.
    89  func WeightedPermutation(size int, weights []pos.Weight, seed hash.Hash) []int {
    90  	if len(weights) < size {
    91  		panic("the permutation size must be less or equal to weights size")
    92  	}
    93  
    94  	if len(weights) == 0 {
    95  		return make([]int, 0)
    96  	}
    97  
    98  	tree := weightedShuffleTree{
    99  		weights: weights,
   100  		nodes:   make([]weightedShuffleNode, len(weights)),
   101  		seed:    seed,
   102  	}
   103  	tree.build(0)
   104  
   105  	permutation := make([]int, size)
   106  	for i := 0; i < size; i++ {
   107  		permutation[i] = tree.retrieve(0)
   108  	}
   109  	return permutation
   110  }