github.com/MetalBlockchain/metalgo@v1.11.9/utils/set/sampleable_set.go (about) 1 // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. 2 // See the file LICENSE for licensing terms. 3 4 package set 5 6 import ( 7 "bytes" 8 "encoding/json" 9 "slices" 10 11 "github.com/MetalBlockchain/metalgo/utils" 12 "github.com/MetalBlockchain/metalgo/utils/sampler" 13 "github.com/MetalBlockchain/metalgo/utils/wrappers" 14 15 avajson "github.com/MetalBlockchain/metalgo/utils/json" 16 ) 17 18 var _ json.Marshaler = (*Set[int])(nil) 19 20 // SampleableSet is a set of elements that supports sampling. 21 type SampleableSet[T comparable] struct { 22 // indices maps the element in the set to the index that it appears in 23 // elements. 24 indices map[T]int 25 elements []T 26 } 27 28 // OfSampleable returns a Set initialized with [elts] 29 func OfSampleable[T comparable](elts ...T) SampleableSet[T] { 30 s := NewSampleableSet[T](len(elts)) 31 s.Add(elts...) 32 return s 33 } 34 35 // Return a new sampleable set with initial capacity [size]. 36 // More or less than [size] elements can be added to this set. 37 // Using NewSampleableSet() rather than SampleableSet[T]{} is just an 38 // optimization that can be used if you know how many elements will be put in 39 // this set. 40 func NewSampleableSet[T comparable](size int) SampleableSet[T] { 41 if size < 0 { 42 return SampleableSet[T]{} 43 } 44 return SampleableSet[T]{ 45 indices: make(map[T]int, size), 46 elements: make([]T, 0, size), 47 } 48 } 49 50 // Add all the elements to this set. 51 // If the element is already in the set, nothing happens. 52 func (s *SampleableSet[T]) Add(elements ...T) { 53 s.resize(2 * len(elements)) 54 for _, e := range elements { 55 s.add(e) 56 } 57 } 58 59 // Union adds all the elements from the provided set to this set. 60 func (s *SampleableSet[T]) Union(set SampleableSet[T]) { 61 s.resize(2 * set.Len()) 62 for _, e := range set.elements { 63 s.add(e) 64 } 65 } 66 67 // Difference removes all the elements in [set] from [s]. 68 func (s *SampleableSet[T]) Difference(set SampleableSet[T]) { 69 for _, e := range set.elements { 70 s.remove(e) 71 } 72 } 73 74 // Contains returns true iff the set contains this element. 75 func (s SampleableSet[T]) Contains(e T) bool { 76 _, contains := s.indices[e] 77 return contains 78 } 79 80 // Overlaps returns true if the intersection of the set is non-empty 81 func (s SampleableSet[T]) Overlaps(big SampleableSet[T]) bool { 82 small := s 83 if small.Len() > big.Len() { 84 small, big = big, small 85 } 86 87 for _, e := range small.elements { 88 if _, ok := big.indices[e]; ok { 89 return true 90 } 91 } 92 return false 93 } 94 95 // Len returns the number of elements in this set. 96 func (s SampleableSet[_]) Len() int { 97 return len(s.elements) 98 } 99 100 // Remove all the given elements from this set. 101 // If an element isn't in the set, it's ignored. 102 func (s *SampleableSet[T]) Remove(elements ...T) { 103 for _, e := range elements { 104 s.remove(e) 105 } 106 } 107 108 // Clear empties this set 109 func (s *SampleableSet[T]) Clear() { 110 clear(s.indices) 111 for i := range s.elements { 112 s.elements[i] = utils.Zero[T]() 113 } 114 s.elements = s.elements[:0] 115 } 116 117 // List converts this set into a list 118 func (s SampleableSet[T]) List() []T { 119 return slices.Clone(s.elements) 120 } 121 122 // Equals returns true if the sets contain the same elements 123 func (s SampleableSet[T]) Equals(other SampleableSet[T]) bool { 124 if len(s.indices) != len(other.indices) { 125 return false 126 } 127 for k := range s.indices { 128 if _, ok := other.indices[k]; !ok { 129 return false 130 } 131 } 132 return true 133 } 134 135 func (s SampleableSet[T]) Sample(numToSample int) []T { 136 if numToSample <= 0 { 137 return nil 138 } 139 140 uniform := sampler.NewUniform() 141 uniform.Initialize(uint64(len(s.elements))) 142 indices, _ := uniform.Sample(min(len(s.elements), numToSample)) 143 elements := make([]T, len(indices)) 144 for i, index := range indices { 145 elements[i] = s.elements[index] 146 } 147 return elements 148 } 149 150 func (s *SampleableSet[T]) UnmarshalJSON(b []byte) error { 151 str := string(b) 152 if str == avajson.Null { 153 return nil 154 } 155 var elements []T 156 if err := json.Unmarshal(b, &elements); err != nil { 157 return err 158 } 159 s.Clear() 160 s.Add(elements...) 161 return nil 162 } 163 164 func (s *SampleableSet[_]) MarshalJSON() ([]byte, error) { 165 var ( 166 elementBytes = make([][]byte, len(s.elements)) 167 err error 168 ) 169 for i, e := range s.elements { 170 elementBytes[i], err = json.Marshal(e) 171 if err != nil { 172 return nil, err 173 } 174 } 175 // Sort for determinism 176 slices.SortFunc(elementBytes, bytes.Compare) 177 178 // Build the JSON 179 var ( 180 jsonBuf = bytes.Buffer{} 181 errs = wrappers.Errs{} 182 ) 183 _, err = jsonBuf.WriteString("[") 184 errs.Add(err) 185 for i, elt := range elementBytes { 186 _, err := jsonBuf.Write(elt) 187 errs.Add(err) 188 if i != len(elementBytes)-1 { 189 _, err := jsonBuf.WriteString(",") 190 errs.Add(err) 191 } 192 } 193 _, err = jsonBuf.WriteString("]") 194 errs.Add(err) 195 196 return jsonBuf.Bytes(), errs.Err 197 } 198 199 func (s *SampleableSet[T]) resize(size int) { 200 if s.elements == nil { 201 if minSetSize > size { 202 size = minSetSize 203 } 204 s.indices = make(map[T]int, size) 205 } 206 } 207 208 func (s *SampleableSet[T]) add(e T) { 209 _, ok := s.indices[e] 210 if ok { 211 return 212 } 213 214 s.indices[e] = len(s.elements) 215 s.elements = append(s.elements, e) 216 } 217 218 func (s *SampleableSet[T]) remove(e T) { 219 indexToRemove, ok := s.indices[e] 220 if !ok { 221 return 222 } 223 224 lastIndex := len(s.elements) - 1 225 if indexToRemove != lastIndex { 226 lastElement := s.elements[lastIndex] 227 228 s.indices[lastElement] = indexToRemove 229 s.elements[indexToRemove] = lastElement 230 } 231 232 delete(s.indices, e) 233 s.elements[lastIndex] = utils.Zero[T]() 234 s.elements = s.elements[:lastIndex] 235 }