github.com/unicornultrafoundation/go-u2u@v1.0.0-rc1.0.20240205080301-e74a83d3fadc/utils/weighted_shuffle.go (about) 1 package utils 2 3 import ( 4 "crypto/sha256" 5 6 "github.com/unicornultrafoundation/go-helios/common/littleendian" 7 "github.com/unicornultrafoundation/go-helios/hash" 8 "github.com/unicornultrafoundation/go-helios/native/pos" 9 ) 10 11 type weightedShuffleNode struct { 12 thisWeight pos.Weight 13 leftWeight pos.Weight 14 rightWeight pos.Weight 15 } 16 17 type weightedShuffleTree struct { 18 seed hash.Hash 19 seedIndex int 20 21 weights []pos.Weight 22 nodes []weightedShuffleNode 23 } 24 25 func (t *weightedShuffleTree) leftIndex(i int) int { 26 return i*2 + 1 27 } 28 29 func (t *weightedShuffleTree) rightIndex(i int) int { 30 return i*2 + 2 31 } 32 33 func (t *weightedShuffleTree) build(i int) pos.Weight { 34 if i >= len(t.weights) { 35 return 0 36 } 37 thisW := t.weights[i] 38 leftW := t.build(t.leftIndex(i)) 39 rightW := t.build(t.rightIndex(i)) 40 41 if thisW <= 0 { 42 panic("all the weight must be positive") 43 } 44 45 t.nodes[i] = weightedShuffleNode{ 46 thisWeight: thisW, 47 leftWeight: leftW, 48 rightWeight: rightW, 49 } 50 return thisW + leftW + rightW 51 } 52 53 func (t *weightedShuffleTree) rand32() uint32 { 54 if t.seedIndex == 32 { 55 hasher := sha256.New() // use sha2 instead of sha3 for speed 56 hasher.Write(t.seed.Bytes()) 57 t.seed = hash.BytesToHash(hasher.Sum(nil)) 58 t.seedIndex = 0 59 } 60 // use not used parts of old seed, instead of calculating new one 61 res := littleendian.BytesToUint32(t.seed[t.seedIndex : t.seedIndex+4]) 62 t.seedIndex += 4 63 return res 64 } 65 66 func (t *weightedShuffleTree) retrieve(i int) int { 67 node := t.nodes[i] 68 total := node.rightWeight + node.leftWeight + node.thisWeight 69 70 r := pos.Weight(t.rand32()) % total 71 72 if r < node.thisWeight { 73 t.nodes[i].thisWeight = 0 74 return i 75 } else if r < node.thisWeight+node.leftWeight { 76 chosen := t.retrieve(t.leftIndex(i)) 77 t.nodes[i].leftWeight -= t.weights[chosen] 78 return chosen 79 } else { 80 chosen := t.retrieve(t.rightIndex(i)) 81 t.nodes[i].rightWeight -= t.weights[chosen] 82 return chosen 83 } 84 } 85 86 // WeightedPermutation builds weighted random permutation 87 // Returns first {size} entries of {weights} permutation. 88 // Call with {size} == len(weights) to get the whole permutation. 89 func WeightedPermutation(size int, weights []pos.Weight, seed hash.Hash) []int { 90 if len(weights) < size { 91 panic("the permutation size must be less or equal to weights size") 92 } 93 94 if len(weights) == 0 { 95 return make([]int, 0) 96 } 97 98 tree := weightedShuffleTree{ 99 weights: weights, 100 nodes: make([]weightedShuffleNode, len(weights)), 101 seed: seed, 102 } 103 tree.build(0) 104 105 permutation := make([]int, size) 106 for i := 0; i < size; i++ { 107 permutation[i] = tree.retrieve(0) 108 } 109 return permutation 110 }