github.com/MetalBlockchain/metalgo@v1.11.9/utils/set/sampleable_set.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package set
     5  
     6  import (
     7  	"bytes"
     8  	"encoding/json"
     9  	"slices"
    10  
    11  	"github.com/MetalBlockchain/metalgo/utils"
    12  	"github.com/MetalBlockchain/metalgo/utils/sampler"
    13  	"github.com/MetalBlockchain/metalgo/utils/wrappers"
    14  
    15  	avajson "github.com/MetalBlockchain/metalgo/utils/json"
    16  )
    17  
    18  var _ json.Marshaler = (*Set[int])(nil)
    19  
// SampleableSet is a set of elements that supports sampling.
//
// Membership is tracked in a map, while the members themselves are also kept
// in a slice so that they can be addressed by position when sampling.
type SampleableSet[T comparable] struct {
	// indices maps the element in the set to the index that it appears in
	// elements. elements holds each member of the set exactly once, so its
	// length is the size of the set.
	indices  map[T]int
	elements []T
}
    27  
    28  // OfSampleable returns a Set initialized with [elts]
    29  func OfSampleable[T comparable](elts ...T) SampleableSet[T] {
    30  	s := NewSampleableSet[T](len(elts))
    31  	s.Add(elts...)
    32  	return s
    33  }
    34  
    35  // Return a new sampleable set with initial capacity [size].
    36  // More or less than [size] elements can be added to this set.
    37  // Using NewSampleableSet() rather than SampleableSet[T]{} is just an
    38  // optimization that can be used if you know how many elements will be put in
    39  // this set.
    40  func NewSampleableSet[T comparable](size int) SampleableSet[T] {
    41  	if size < 0 {
    42  		return SampleableSet[T]{}
    43  	}
    44  	return SampleableSet[T]{
    45  		indices:  make(map[T]int, size),
    46  		elements: make([]T, 0, size),
    47  	}
    48  }
    49  
    50  // Add all the elements to this set.
    51  // If the element is already in the set, nothing happens.
    52  func (s *SampleableSet[T]) Add(elements ...T) {
    53  	s.resize(2 * len(elements))
    54  	for _, e := range elements {
    55  		s.add(e)
    56  	}
    57  }
    58  
    59  // Union adds all the elements from the provided set to this set.
    60  func (s *SampleableSet[T]) Union(set SampleableSet[T]) {
    61  	s.resize(2 * set.Len())
    62  	for _, e := range set.elements {
    63  		s.add(e)
    64  	}
    65  }
    66  
    67  // Difference removes all the elements in [set] from [s].
    68  func (s *SampleableSet[T]) Difference(set SampleableSet[T]) {
    69  	for _, e := range set.elements {
    70  		s.remove(e)
    71  	}
    72  }
    73  
    74  // Contains returns true iff the set contains this element.
    75  func (s SampleableSet[T]) Contains(e T) bool {
    76  	_, contains := s.indices[e]
    77  	return contains
    78  }
    79  
    80  // Overlaps returns true if the intersection of the set is non-empty
    81  func (s SampleableSet[T]) Overlaps(big SampleableSet[T]) bool {
    82  	small := s
    83  	if small.Len() > big.Len() {
    84  		small, big = big, small
    85  	}
    86  
    87  	for _, e := range small.elements {
    88  		if _, ok := big.indices[e]; ok {
    89  			return true
    90  		}
    91  	}
    92  	return false
    93  }
    94  
// Len returns the number of elements in this set.
// It is the length of the backing elements slice, so it also works on the
// zero-value set (where it returns 0).
func (s SampleableSet[_]) Len() int {
	return len(s.elements)
}
    99  
   100  // Remove all the given elements from this set.
   101  // If an element isn't in the set, it's ignored.
   102  func (s *SampleableSet[T]) Remove(elements ...T) {
   103  	for _, e := range elements {
   104  		s.remove(e)
   105  	}
   106  }
   107  
   108  // Clear empties this set
   109  func (s *SampleableSet[T]) Clear() {
   110  	clear(s.indices)
   111  	for i := range s.elements {
   112  		s.elements[i] = utils.Zero[T]()
   113  	}
   114  	s.elements = s.elements[:0]
   115  }
   116  
// List converts this set into a list.
// The returned slice is a copy of the internal element slice, so the caller
// may modify it without affecting the set.
func (s SampleableSet[T]) List() []T {
	return slices.Clone(s.elements)
}
   121  
   122  // Equals returns true if the sets contain the same elements
   123  func (s SampleableSet[T]) Equals(other SampleableSet[T]) bool {
   124  	if len(s.indices) != len(other.indices) {
   125  		return false
   126  	}
   127  	for k := range s.indices {
   128  		if _, ok := other.indices[k]; !ok {
   129  			return false
   130  		}
   131  	}
   132  	return true
   133  }
   134  
   135  func (s SampleableSet[T]) Sample(numToSample int) []T {
   136  	if numToSample <= 0 {
   137  		return nil
   138  	}
   139  
   140  	uniform := sampler.NewUniform()
   141  	uniform.Initialize(uint64(len(s.elements)))
   142  	indices, _ := uniform.Sample(min(len(s.elements), numToSample))
   143  	elements := make([]T, len(indices))
   144  	for i, index := range indices {
   145  		elements[i] = s.elements[index]
   146  	}
   147  	return elements
   148  }
   149  
   150  func (s *SampleableSet[T]) UnmarshalJSON(b []byte) error {
   151  	str := string(b)
   152  	if str == avajson.Null {
   153  		return nil
   154  	}
   155  	var elements []T
   156  	if err := json.Unmarshal(b, &elements); err != nil {
   157  		return err
   158  	}
   159  	s.Clear()
   160  	s.Add(elements...)
   161  	return nil
   162  }
   163  
   164  func (s *SampleableSet[_]) MarshalJSON() ([]byte, error) {
   165  	var (
   166  		elementBytes = make([][]byte, len(s.elements))
   167  		err          error
   168  	)
   169  	for i, e := range s.elements {
   170  		elementBytes[i], err = json.Marshal(e)
   171  		if err != nil {
   172  			return nil, err
   173  		}
   174  	}
   175  	// Sort for determinism
   176  	slices.SortFunc(elementBytes, bytes.Compare)
   177  
   178  	// Build the JSON
   179  	var (
   180  		jsonBuf = bytes.Buffer{}
   181  		errs    = wrappers.Errs{}
   182  	)
   183  	_, err = jsonBuf.WriteString("[")
   184  	errs.Add(err)
   185  	for i, elt := range elementBytes {
   186  		_, err := jsonBuf.Write(elt)
   187  		errs.Add(err)
   188  		if i != len(elementBytes)-1 {
   189  			_, err := jsonBuf.WriteString(",")
   190  			errs.Add(err)
   191  		}
   192  	}
   193  	_, err = jsonBuf.WriteString("]")
   194  	errs.Add(err)
   195  
   196  	return jsonBuf.Bytes(), errs.Err
   197  }
   198  
   199  func (s *SampleableSet[T]) resize(size int) {
   200  	if s.elements == nil {
   201  		if minSetSize > size {
   202  			size = minSetSize
   203  		}
   204  		s.indices = make(map[T]int, size)
   205  	}
   206  }
   207  
   208  func (s *SampleableSet[T]) add(e T) {
   209  	_, ok := s.indices[e]
   210  	if ok {
   211  		return
   212  	}
   213  
   214  	s.indices[e] = len(s.elements)
   215  	s.elements = append(s.elements, e)
   216  }
   217  
   218  func (s *SampleableSet[T]) remove(e T) {
   219  	indexToRemove, ok := s.indices[e]
   220  	if !ok {
   221  		return
   222  	}
   223  
   224  	lastIndex := len(s.elements) - 1
   225  	if indexToRemove != lastIndex {
   226  		lastElement := s.elements[lastIndex]
   227  
   228  		s.indices[lastElement] = indexToRemove
   229  		s.elements[indexToRemove] = lastElement
   230  	}
   231  
   232  	delete(s.indices, e)
   233  	s.elements[lastIndex] = utils.Zero[T]()
   234  	s.elements = s.elements[:lastIndex]
   235  }