// Source: github.com/prysmaticlabs/prysm@v1.4.4/shared/aggregation/maxcover.go

package aggregation

import (
	"sort"

	"github.com/pkg/errors"
	"github.com/prysmaticlabs/go-bitfield"
)

// ErrInvalidMaxCoverProblem is returned when Maximum Coverage problem was initialized incorrectly.
var ErrInvalidMaxCoverProblem = errors.New("invalid max_cover problem")

// MaxCoverProblem defines Maximum Coverage problem.
//
// Problem is defined as MaxCover(U, S, k): S', where:
// U is a finite set of objects, where |U| = n. Furthermore, let S = {S_1, ..., S_m} be a
// collection of subsets of U whose union is equal to U. Then, Maximum Coverage is the problem
// of finding a collection S' of subsets from S, where |S'| <= k, such that the union of all
// subsets in S' covers as many elements of U as possible.
//
// The current implementation captures the original MaxCover problem, and the variant where
// an additional invariant is enforced: all elements of S' must be disjoint. This comes in handy
// when we need to aggregate bitsets, and overlaps are not allowed.
//
// For more details, see:
// "Analysis of the Greedy Approach in Problems of Maximum k-Coverage" by Hochbaum and Pathria.
// https://hochbaum.ieor.berkeley.edu/html/pub/HPathria-max-k-coverage-greedy.pdf
type MaxCoverProblem struct {
	// Candidates holds the sets S_1..S_m to choose from. Cover filters this slice in place.
	Candidates MaxCoverCandidates
}

// MaxCoverCandidate represents a candidate set to be used in aggregation.
type MaxCoverCandidate struct {
	// key is the caller-supplied identifier; Cover appends it to Aggregation.Keys on selection.
	key int
	// bits is the candidate's set, encoded as a bitlist.
	bits *bitfield.Bitlist
	// score is the number of bits this candidate shares with the still-uncovered bits;
	// it is recomputed by score().
	score uint64
	// processed is set by Cover once the candidate has been added to the solution;
	// filter() drops processed candidates.
	processed bool
}

// MaxCoverCandidates is defined to allow group operations (filtering, sorting) on all candidates.
type MaxCoverCandidates []*MaxCoverCandidate

// NewMaxCoverCandidate returns initialized candidate.
44 func NewMaxCoverCandidate(key int, bits *bitfield.Bitlist) *MaxCoverCandidate { 45 return &MaxCoverCandidate{ 46 key: key, 47 bits: bits, 48 } 49 } 50 51 // Cover calculates solution to Maximum k-Cover problem in O(knm), where 52 // n is number of candidates and m is a length of bitlist in each candidate. 53 func (mc *MaxCoverProblem) Cover(k int, allowOverlaps bool) (*Aggregation, error) { 54 if len(mc.Candidates) == 0 { 55 return nil, errors.Wrap(ErrInvalidMaxCoverProblem, "cannot calculate set coverage") 56 } 57 if len(mc.Candidates) < k { 58 k = len(mc.Candidates) 59 } 60 61 solution := &Aggregation{ 62 Coverage: bitfield.NewBitlist(mc.Candidates[0].bits.Len()), 63 Keys: make([]int, 0, k), 64 } 65 remainingBits, err := mc.Candidates.union() 66 if err != nil { 67 return nil, err 68 } 69 if remainingBits == nil { 70 return nil, errors.Wrap(ErrInvalidMaxCoverProblem, "empty bitlists") 71 } 72 73 for len(solution.Keys) < k && len(mc.Candidates) > 0 { 74 // Score candidates against remaining bits. 75 // Filter out processed and overlapping (when disallowed). 76 // Sort by score in a descending order. 77 s, err := mc.Candidates.score(remainingBits) 78 if err != nil { 79 return nil, err 80 } 81 s, err = s.filter(solution.Coverage, allowOverlaps) 82 if err != nil { 83 return nil, err 84 } 85 s.sort() 86 87 for _, candidate := range mc.Candidates { 88 if len(solution.Keys) >= k { 89 break 90 } 91 if !candidate.processed { 92 var err error 93 solution.Coverage, err = solution.Coverage.Or(*candidate.bits) 94 if err != nil { 95 return nil, err 96 } 97 solution.Keys = append(solution.Keys, candidate.key) 98 remainingBits, err = remainingBits.And(candidate.bits.Not()) 99 if err != nil { 100 return nil, err 101 } 102 candidate.processed = true 103 break 104 } 105 } 106 } 107 return solution, nil 108 } 109 110 // MaxCover finds the k-cover of Maximum Coverage problem. 
111 func MaxCover(candidates []*bitfield.Bitlist64, k int, allowOverlaps bool) (selected, coverage *bitfield.Bitlist64, err error) { 112 if len(candidates) == 0 { 113 return nil, nil, errors.Wrap(ErrInvalidMaxCoverProblem, "cannot calculate set coverage") 114 } 115 if len(candidates) < k { 116 k = len(candidates) 117 } 118 119 // Track usable candidates, and candidates selected for coverage as two bitlists. 120 selectedCandidates := bitfield.NewBitlist64(uint64(len(candidates))) 121 usableCandidates := bitfield.NewBitlist64(uint64(len(candidates))).Not() 122 123 // Track bits covered so far as a bitlist. 124 coveredBits := bitfield.NewBitlist64(candidates[0].Len()) 125 remainingBits, err := union(candidates) 126 if err != nil { 127 return nil, nil, err 128 } 129 if remainingBits == nil { 130 return nil, nil, errors.Wrap(ErrInvalidMaxCoverProblem, "empty bitlists") 131 } 132 133 attempts := 0 134 tmpBitlist := bitfield.NewBitlist64(candidates[0].Len()) // Used as return param for NoAlloc*() methods. 135 indices := make([]int, usableCandidates.Count()) 136 for selectedCandidates.Count() < uint64(k) && usableCandidates.Count() > 0 { 137 // Safe-guard, each iteration should come with at least one candidate selected. 138 if attempts > k { 139 break 140 } 141 attempts++ 142 143 // Greedy select the next best candidate (from usable ones) to cover the remaining bits maximally. 144 maxScore := uint64(0) 145 bestIdx := uint64(0) 146 indices = indices[0:usableCandidates.Count()] 147 usableCandidates.NoAllocBitIndices(indices) 148 for _, idx := range indices { 149 // Score is calculated by taking into account uncovered bits only. 150 score := uint64(0) 151 if candidates[idx].Len() == remainingBits.Len() { 152 var err error 153 score, err = candidates[idx].AndCount(remainingBits) 154 if err != nil { 155 return nil, nil, err 156 } 157 } 158 159 // Filter out zero-score candidates. 
160 if score == 0 { 161 usableCandidates.SetBitAt(uint64(idx), false) 162 continue 163 } 164 165 // Filter out overlapping candidates (if overlapping is not allowed). 166 wrongLen := coveredBits.Len() != candidates[idx].Len() 167 overlaps := func(idx int) (bool, error) { 168 o, err := coveredBits.Overlaps(candidates[idx]) 169 return !allowOverlaps && o, err 170 } 171 if wrongLen { // Shortcut for wrong length check 172 usableCandidates.SetBitAt(uint64(idx), false) 173 continue 174 } else if o, err := overlaps(idx); err != nil { 175 return nil, nil, err 176 } else if o { 177 usableCandidates.SetBitAt(uint64(idx), false) 178 continue 179 } 180 181 // Track the candidate with the best score. 182 if score > maxScore { 183 maxScore = score 184 bestIdx = uint64(idx) 185 } 186 } 187 // Process greedy selected candidate. 188 if maxScore > 0 { 189 if err := coveredBits.NoAllocOr(candidates[bestIdx], coveredBits); err != nil { 190 return nil, nil, err 191 } 192 selectedCandidates.SetBitAt(bestIdx, true) 193 candidates[bestIdx].NoAllocNot(tmpBitlist) 194 if err := remainingBits.NoAllocAnd(tmpBitlist, remainingBits); err != nil { 195 return nil, nil, err 196 } 197 usableCandidates.SetBitAt(bestIdx, false) 198 } 199 } 200 return selectedCandidates, coveredBits, nil 201 } 202 203 // score updates scores of candidates, taking into account the uncovered elements only. 204 func (cl *MaxCoverCandidates) score(uncovered bitfield.Bitlist) (*MaxCoverCandidates, error) { 205 for i := 0; i < len(*cl); i++ { 206 if (*cl)[i].bits.Len() == uncovered.Len() { 207 a, err := (*cl)[i].bits.And(uncovered) 208 if err != nil { 209 return nil, err 210 } 211 (*cl)[i].score = a.Count() 212 } 213 } 214 return cl, nil 215 } 216 217 // filter removes processed, overlapping and zero-score candidates. 
218 func (cl *MaxCoverCandidates) filter(covered bitfield.Bitlist, allowOverlaps bool) (*MaxCoverCandidates, error) { 219 overlaps := func(e bitfield.Bitlist) (bool, error) { 220 if !allowOverlaps && covered.Len() == e.Len() { 221 return covered.Overlaps(e) 222 } 223 return false, nil 224 } 225 cur, end := 0, len(*cl) 226 for cur < end { 227 e := *(*cl)[cur] 228 if e.processed || e.score == 0 { 229 (*cl)[cur] = (*cl)[end-1] 230 end-- 231 continue 232 } else if o, err := overlaps(*e.bits); err == nil && o { 233 (*cl)[cur] = (*cl)[end-1] 234 end-- 235 continue 236 } else if err != nil { 237 return nil, err 238 } 239 240 cur++ 241 } 242 *cl = (*cl)[:end] 243 return cl, nil 244 } 245 246 // sort orders candidates by their score, starting from the candidate with the highest score. 247 func (cl *MaxCoverCandidates) sort() *MaxCoverCandidates { 248 sort.Slice(*cl, func(i, j int) bool { 249 if (*cl)[i].score == (*cl)[j].score { 250 return (*cl)[i].key < (*cl)[j].key 251 } 252 return (*cl)[i].score > (*cl)[j].score 253 }) 254 return cl 255 } 256 257 // union merges all candidate bitlists using logical OR operator. 
258 func (cl *MaxCoverCandidates) union() (bitfield.Bitlist, error) { 259 if len(*cl) == 0 { 260 return nil, nil 261 } 262 if (*cl)[0].bits == nil || (*cl)[0].bits.Len() == 0 { 263 return nil, nil 264 } 265 ret := bitfield.NewBitlist((*cl)[0].bits.Len()) 266 var err error 267 for i := 0; i < len(*cl); i++ { 268 if *(*cl)[i].bits != nil && ret.Len() == (*cl)[i].bits.Len() { 269 ret, err = ret.Or(*(*cl)[i].bits) 270 if err != nil { 271 return nil, err 272 } 273 } 274 } 275 return ret, nil 276 } 277 278 func union(candidates []*bitfield.Bitlist64) (*bitfield.Bitlist64, error) { 279 if len(candidates) == 0 || candidates[0].Len() == 0 { 280 return nil, nil 281 } 282 ret := bitfield.NewBitlist64(candidates[0].Len()) 283 for _, bl := range candidates { 284 if ret.Len() == bl.Len() { 285 if err := ret.NoAllocOr(bl, ret); err != nil { 286 return nil, err 287 } 288 } 289 } 290 return ret, nil 291 }