github.com/MetalBlockchain/metalgo@v1.11.9/utils/set/set.go (about) 1 // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. 2 // See the file LICENSE for licensing terms. 3 4 package set 5 6 import ( 7 "bytes" 8 "encoding/json" 9 "slices" 10 11 "golang.org/x/exp/maps" 12 13 "github.com/MetalBlockchain/metalgo/utils" 14 "github.com/MetalBlockchain/metalgo/utils/wrappers" 15 16 avajson "github.com/MetalBlockchain/metalgo/utils/json" 17 ) 18 19 // The minimum capacity of a set 20 const minSetSize = 16 21 22 var _ json.Marshaler = (*Set[int])(nil) 23 24 // Set is a set of elements. 25 type Set[T comparable] map[T]struct{} 26 27 // Of returns a Set initialized with [elts] 28 func Of[T comparable](elts ...T) Set[T] { 29 s := NewSet[T](len(elts)) 30 s.Add(elts...) 31 return s 32 } 33 34 // Return a new set with initial capacity [size]. 35 // More or less than [size] elements can be added to this set. 36 // Using NewSet() rather than Set[T]{} is just an optimization that can 37 // be used if you know how many elements will be put in this set. 38 func NewSet[T comparable](size int) Set[T] { 39 if size < 0 { 40 return Set[T]{} 41 } 42 return make(map[T]struct{}, size) 43 } 44 45 func (s *Set[T]) resize(size int) { 46 if *s == nil { 47 if minSetSize > size { 48 size = minSetSize 49 } 50 *s = make(map[T]struct{}, size) 51 } 52 } 53 54 // Add all the elements to this set. 55 // If the element is already in the set, nothing happens. 56 func (s *Set[T]) Add(elts ...T) { 57 s.resize(2 * len(elts)) 58 for _, elt := range elts { 59 (*s)[elt] = struct{}{} 60 } 61 } 62 63 // Union adds all the elements from the provided set to this set. 64 func (s *Set[T]) Union(set Set[T]) { 65 s.resize(2 * set.Len()) 66 for elt := range set { 67 (*s)[elt] = struct{}{} 68 } 69 } 70 71 // Difference removes all the elements in [set] from [s]. 72 func (s *Set[T]) Difference(set Set[T]) { 73 for elt := range set { 74 delete(*s, elt) 75 } 76 } 77 78 // Contains returns true iff the set contains this element. 79 func (s *Set[T]) Contains(elt T) bool { 80 _, contains := (*s)[elt] 81 return contains 82 } 83 84 // Overlaps returns true if the intersection of the set is non-empty 85 func (s *Set[T]) Overlaps(big Set[T]) bool { 86 small := *s 87 if small.Len() > big.Len() { 88 small, big = big, small 89 } 90 91 for elt := range small { 92 if _, ok := big[elt]; ok { 93 return true 94 } 95 } 96 return false 97 } 98 99 // Len returns the number of elements in this set. 100 func (s Set[_]) Len() int { 101 return len(s) 102 } 103 104 // Remove all the given elements from this set. 105 // If an element isn't in the set, it's ignored. 106 func (s *Set[T]) Remove(elts ...T) { 107 for _, elt := range elts { 108 delete(*s, elt) 109 } 110 } 111 112 // Clear empties this set 113 func (s *Set[_]) Clear() { 114 clear(*s) 115 } 116 117 // List converts this set into a list 118 func (s Set[T]) List() []T { 119 return maps.Keys(s) 120 } 121 122 // Equals returns true if the sets contain the same elements 123 func (s Set[T]) Equals(other Set[T]) bool { 124 return maps.Equal(s, other) 125 } 126 127 // Removes and returns an element. 128 // If the set is empty, does nothing and returns false. 129 func (s *Set[T]) Pop() (T, bool) { 130 for elt := range *s { 131 delete(*s, elt) 132 return elt, true 133 } 134 return utils.Zero[T](), false 135 } 136 137 func (s *Set[T]) UnmarshalJSON(b []byte) error { 138 str := string(b) 139 if str == avajson.Null { 140 return nil 141 } 142 var elts []T 143 if err := json.Unmarshal(b, &elts); err != nil { 144 return err 145 } 146 s.Clear() 147 s.Add(elts...) 148 return nil 149 } 150 151 func (s Set[_]) MarshalJSON() ([]byte, error) { 152 var ( 153 eltBytes = make([][]byte, len(s)) 154 i int 155 err error 156 ) 157 for elt := range s { 158 eltBytes[i], err = json.Marshal(elt) 159 if err != nil { 160 return nil, err 161 } 162 i++ 163 } 164 // Sort for determinism 165 slices.SortFunc(eltBytes, bytes.Compare) 166 167 // Build the JSON 168 var ( 169 jsonBuf = bytes.Buffer{} 170 errs = wrappers.Errs{} 171 ) 172 _, err = jsonBuf.WriteString("[") 173 errs.Add(err) 174 for i, elt := range eltBytes { 175 _, err := jsonBuf.Write(elt) 176 errs.Add(err) 177 if i != len(eltBytes)-1 { 178 _, err := jsonBuf.WriteString(",") 179 errs.Add(err) 180 } 181 } 182 _, err = jsonBuf.WriteString("]") 183 errs.Add(err) 184 185 return jsonBuf.Bytes(), errs.Err 186 } 187 188 // Returns a random element. If the set is empty, returns false 189 func (s *Set[T]) Peek() (T, bool) { 190 for elt := range *s { 191 return elt, true 192 } 193 return utils.Zero[T](), false 194 }