github.com/prysmaticlabs/prysm@v1.4.4/shared/aggregation/maxcover.go (about)

     1  package aggregation
     2  
     3  import (
     4  	"sort"
     5  
     6  	"github.com/pkg/errors"
     7  	"github.com/prysmaticlabs/go-bitfield"
     8  )
     9  
    10  // ErrInvalidMaxCoverProblem is returned when Maximum Coverage problem was initialized incorrectly.
    11  var ErrInvalidMaxCoverProblem = errors.New("invalid max_cover problem")
    12  
    13  // MaxCoverProblem defines Maximum Coverage problem.
    14  //
    15  // Problem is defined as MaxCover(U, S, k): S', where:
    16  // U is a finite set of objects, where |U| = n. Furthermore, let S = {S_1, ..., S_m} be all
    17  // subsets of U, that's their union is equal to U. Then, Maximum Coverage is the problem of
    18  // finding such a collection S' of subsets from S, where |S'| <= k, and union of all subsets in S'
    19  // covering U with maximum cardinality.
    20  //
    21  // The current implementation captures the original MaxCover problem, and the variant where
    22  // additional invariant is enforced: all elements of S' must be disjoint. This comes handy when
    23  // we need to aggregate bitsets, and overlaps are not allowed.
    24  //
    25  // For more details, see:
    26  // "Analysis of the Greedy Approach in Problems of Maximum k-Coverage" by Hochbaum and Pathria.
    27  // https://hochbaum.ieor.berkeley.edu/html/pub/HPathria-max-k-coverage-greedy.pdf
    28  type MaxCoverProblem struct {
    29  	Candidates MaxCoverCandidates
    30  }
    31  
    32  // MaxCoverCandidate represents a candidate set to be used in aggregation.
    33  type MaxCoverCandidate struct {
    34  	key       int
    35  	bits      *bitfield.Bitlist
    36  	score     uint64
    37  	processed bool
    38  }
    39  
    40  // MaxCoverCandidates is defined to allow group operations (filtering, sorting) on all candidates.
    41  type MaxCoverCandidates []*MaxCoverCandidate
    42  
    43  // NewMaxCoverCandidate returns initialized candidate.
    44  func NewMaxCoverCandidate(key int, bits *bitfield.Bitlist) *MaxCoverCandidate {
    45  	return &MaxCoverCandidate{
    46  		key:  key,
    47  		bits: bits,
    48  	}
    49  }
    50  
    51  // Cover calculates solution to Maximum k-Cover problem in O(knm), where
    52  // n is number of candidates and m is a length of bitlist in each candidate.
    53  func (mc *MaxCoverProblem) Cover(k int, allowOverlaps bool) (*Aggregation, error) {
    54  	if len(mc.Candidates) == 0 {
    55  		return nil, errors.Wrap(ErrInvalidMaxCoverProblem, "cannot calculate set coverage")
    56  	}
    57  	if len(mc.Candidates) < k {
    58  		k = len(mc.Candidates)
    59  	}
    60  
    61  	solution := &Aggregation{
    62  		Coverage: bitfield.NewBitlist(mc.Candidates[0].bits.Len()),
    63  		Keys:     make([]int, 0, k),
    64  	}
    65  	remainingBits, err := mc.Candidates.union()
    66  	if err != nil {
    67  		return nil, err
    68  	}
    69  	if remainingBits == nil {
    70  		return nil, errors.Wrap(ErrInvalidMaxCoverProblem, "empty bitlists")
    71  	}
    72  
    73  	for len(solution.Keys) < k && len(mc.Candidates) > 0 {
    74  		// Score candidates against remaining bits.
    75  		// Filter out processed and overlapping (when disallowed).
    76  		// Sort by score in a descending order.
    77  		s, err := mc.Candidates.score(remainingBits)
    78  		if err != nil {
    79  			return nil, err
    80  		}
    81  		s, err = s.filter(solution.Coverage, allowOverlaps)
    82  		if err != nil {
    83  			return nil, err
    84  		}
    85  		s.sort()
    86  
    87  		for _, candidate := range mc.Candidates {
    88  			if len(solution.Keys) >= k {
    89  				break
    90  			}
    91  			if !candidate.processed {
    92  				var err error
    93  				solution.Coverage, err = solution.Coverage.Or(*candidate.bits)
    94  				if err != nil {
    95  					return nil, err
    96  				}
    97  				solution.Keys = append(solution.Keys, candidate.key)
    98  				remainingBits, err = remainingBits.And(candidate.bits.Not())
    99  				if err != nil {
   100  					return nil, err
   101  				}
   102  				candidate.processed = true
   103  				break
   104  			}
   105  		}
   106  	}
   107  	return solution, nil
   108  }
   109  
   110  // MaxCover finds the k-cover of Maximum Coverage problem.
   111  func MaxCover(candidates []*bitfield.Bitlist64, k int, allowOverlaps bool) (selected, coverage *bitfield.Bitlist64, err error) {
   112  	if len(candidates) == 0 {
   113  		return nil, nil, errors.Wrap(ErrInvalidMaxCoverProblem, "cannot calculate set coverage")
   114  	}
   115  	if len(candidates) < k {
   116  		k = len(candidates)
   117  	}
   118  
   119  	// Track usable candidates, and candidates selected for coverage as two bitlists.
   120  	selectedCandidates := bitfield.NewBitlist64(uint64(len(candidates)))
   121  	usableCandidates := bitfield.NewBitlist64(uint64(len(candidates))).Not()
   122  
   123  	// Track bits covered so far as a bitlist.
   124  	coveredBits := bitfield.NewBitlist64(candidates[0].Len())
   125  	remainingBits, err := union(candidates)
   126  	if err != nil {
   127  		return nil, nil, err
   128  	}
   129  	if remainingBits == nil {
   130  		return nil, nil, errors.Wrap(ErrInvalidMaxCoverProblem, "empty bitlists")
   131  	}
   132  
   133  	attempts := 0
   134  	tmpBitlist := bitfield.NewBitlist64(candidates[0].Len()) // Used as return param for NoAlloc*() methods.
   135  	indices := make([]int, usableCandidates.Count())
   136  	for selectedCandidates.Count() < uint64(k) && usableCandidates.Count() > 0 {
   137  		// Safe-guard, each iteration should come with at least one candidate selected.
   138  		if attempts > k {
   139  			break
   140  		}
   141  		attempts++
   142  
   143  		// Greedy select the next best candidate (from usable ones) to cover the remaining bits maximally.
   144  		maxScore := uint64(0)
   145  		bestIdx := uint64(0)
   146  		indices = indices[0:usableCandidates.Count()]
   147  		usableCandidates.NoAllocBitIndices(indices)
   148  		for _, idx := range indices {
   149  			// Score is calculated by taking into account uncovered bits only.
   150  			score := uint64(0)
   151  			if candidates[idx].Len() == remainingBits.Len() {
   152  				var err error
   153  				score, err = candidates[idx].AndCount(remainingBits)
   154  				if err != nil {
   155  					return nil, nil, err
   156  				}
   157  			}
   158  
   159  			// Filter out zero-score candidates.
   160  			if score == 0 {
   161  				usableCandidates.SetBitAt(uint64(idx), false)
   162  				continue
   163  			}
   164  
   165  			// Filter out overlapping candidates (if overlapping is not allowed).
   166  			wrongLen := coveredBits.Len() != candidates[idx].Len()
   167  			overlaps := func(idx int) (bool, error) {
   168  				o, err := coveredBits.Overlaps(candidates[idx])
   169  				return !allowOverlaps && o, err
   170  			}
   171  			if wrongLen { // Shortcut for wrong length check
   172  				usableCandidates.SetBitAt(uint64(idx), false)
   173  				continue
   174  			} else if o, err := overlaps(idx); err != nil {
   175  				return nil, nil, err
   176  			} else if o {
   177  				usableCandidates.SetBitAt(uint64(idx), false)
   178  				continue
   179  			}
   180  
   181  			// Track the candidate with the best score.
   182  			if score > maxScore {
   183  				maxScore = score
   184  				bestIdx = uint64(idx)
   185  			}
   186  		}
   187  		// Process greedy selected candidate.
   188  		if maxScore > 0 {
   189  			if err := coveredBits.NoAllocOr(candidates[bestIdx], coveredBits); err != nil {
   190  				return nil, nil, err
   191  			}
   192  			selectedCandidates.SetBitAt(bestIdx, true)
   193  			candidates[bestIdx].NoAllocNot(tmpBitlist)
   194  			if err := remainingBits.NoAllocAnd(tmpBitlist, remainingBits); err != nil {
   195  				return nil, nil, err
   196  			}
   197  			usableCandidates.SetBitAt(bestIdx, false)
   198  		}
   199  	}
   200  	return selectedCandidates, coveredBits, nil
   201  }
   202  
   203  // score updates scores of candidates, taking into account the uncovered elements only.
   204  func (cl *MaxCoverCandidates) score(uncovered bitfield.Bitlist) (*MaxCoverCandidates, error) {
   205  	for i := 0; i < len(*cl); i++ {
   206  		if (*cl)[i].bits.Len() == uncovered.Len() {
   207  			a, err := (*cl)[i].bits.And(uncovered)
   208  			if err != nil {
   209  				return nil, err
   210  			}
   211  			(*cl)[i].score = a.Count()
   212  		}
   213  	}
   214  	return cl, nil
   215  }
   216  
   217  // filter removes processed, overlapping and zero-score candidates.
   218  func (cl *MaxCoverCandidates) filter(covered bitfield.Bitlist, allowOverlaps bool) (*MaxCoverCandidates, error) {
   219  	overlaps := func(e bitfield.Bitlist) (bool, error) {
   220  		if !allowOverlaps && covered.Len() == e.Len() {
   221  			return covered.Overlaps(e)
   222  		}
   223  		return false, nil
   224  	}
   225  	cur, end := 0, len(*cl)
   226  	for cur < end {
   227  		e := *(*cl)[cur]
   228  		if e.processed || e.score == 0 {
   229  			(*cl)[cur] = (*cl)[end-1]
   230  			end--
   231  			continue
   232  		} else if o, err := overlaps(*e.bits); err == nil && o {
   233  			(*cl)[cur] = (*cl)[end-1]
   234  			end--
   235  			continue
   236  		} else if err != nil {
   237  			return nil, err
   238  		}
   239  
   240  		cur++
   241  	}
   242  	*cl = (*cl)[:end]
   243  	return cl, nil
   244  }
   245  
   246  // sort orders candidates by their score, starting from the candidate with the highest score.
   247  func (cl *MaxCoverCandidates) sort() *MaxCoverCandidates {
   248  	sort.Slice(*cl, func(i, j int) bool {
   249  		if (*cl)[i].score == (*cl)[j].score {
   250  			return (*cl)[i].key < (*cl)[j].key
   251  		}
   252  		return (*cl)[i].score > (*cl)[j].score
   253  	})
   254  	return cl
   255  }
   256  
   257  // union merges all candidate bitlists using logical OR operator.
   258  func (cl *MaxCoverCandidates) union() (bitfield.Bitlist, error) {
   259  	if len(*cl) == 0 {
   260  		return nil, nil
   261  	}
   262  	if (*cl)[0].bits == nil || (*cl)[0].bits.Len() == 0 {
   263  		return nil, nil
   264  	}
   265  	ret := bitfield.NewBitlist((*cl)[0].bits.Len())
   266  	var err error
   267  	for i := 0; i < len(*cl); i++ {
   268  		if *(*cl)[i].bits != nil && ret.Len() == (*cl)[i].bits.Len() {
   269  			ret, err = ret.Or(*(*cl)[i].bits)
   270  			if err != nil {
   271  				return nil, err
   272  			}
   273  		}
   274  	}
   275  	return ret, nil
   276  }
   277  
   278  func union(candidates []*bitfield.Bitlist64) (*bitfield.Bitlist64, error) {
   279  	if len(candidates) == 0 || candidates[0].Len() == 0 {
   280  		return nil, nil
   281  	}
   282  	ret := bitfield.NewBitlist64(candidates[0].Len())
   283  	for _, bl := range candidates {
   284  		if ret.Len() == bl.Len() {
   285  			if err := ret.NoAllocOr(bl, ret); err != nil {
   286  				return nil, err
   287  			}
   288  		}
   289  	}
   290  	return ret, nil
   291  }