github.com/cilium/statedb@v0.3.2/part/set.go (about) 1 // SPDX-License-Identifier: Apache-2.0 2 // Copyright Authors of Cilium 3 4 package part 5 6 import ( 7 "bytes" 8 "encoding/json" 9 "fmt" 10 "iter" 11 "slices" 12 13 "gopkg.in/yaml.v3" 14 ) 15 16 // Set is a persistent (immutable) set of values. A Set can be 17 // defined for any type for which a byte slice key can be derived. 18 // 19 // A zero value Set[T] can be used provided that the conversion 20 // function for T have been registered with RegisterKeyType. 21 // For Set-only use only [bytesFromKey] needs to be defined. 22 type Set[T any] struct { 23 toBytes func(T) []byte 24 tree *Tree[T] 25 } 26 27 // NewSet creates a new set of T. 28 // The value type T must be registered with RegisterKeyType. 29 func NewSet[T any](values ...T) Set[T] { 30 s := Set[T]{tree: New[T](RootOnlyWatch)} 31 s.toBytes = lookupKeyType[T]() 32 if len(values) > 0 { 33 txn := s.tree.Txn() 34 for _, v := range values { 35 txn.Insert(s.toBytes(v), v) 36 } 37 s.tree = txn.CommitOnly() 38 } 39 return s 40 } 41 42 // Set a value. Returns a new set. Original is unchanged. 43 func (s Set[T]) Set(v T) Set[T] { 44 if s.tree == nil { 45 return NewSet(v) 46 } 47 txn := s.tree.Txn() 48 txn.Insert(s.toBytes(v), v) 49 s.tree = txn.CommitOnly() // As Set is passed by value we can just modify it. 50 return s 51 } 52 53 // Delete returns a new set without the value. The original 54 // set is unchanged. 55 func (s Set[T]) Delete(v T) Set[T] { 56 if s.tree == nil { 57 return s 58 } 59 txn := s.tree.Txn() 60 txn.Delete(s.toBytes(v)) 61 s.tree = txn.CommitOnly() 62 return s 63 } 64 65 // Has returns true if the set has the value. 66 func (s Set[T]) Has(v T) bool { 67 if s.tree == nil { 68 return false 69 } 70 _, _, found := s.tree.Get(s.toBytes(v)) 71 return found 72 } 73 74 // All returns an iterator for all values. 75 func (s Set[T]) All() iter.Seq[T] { 76 if s.tree == nil { 77 return toSeq[T](nil) 78 } 79 return toSeq(s.tree.Iterator()) 80 } 81 82 // Union returns a set that is the union of the values 83 // in the input sets. 84 func (s Set[T]) Union(s2 Set[T]) Set[T] { 85 if s2.tree == nil { 86 return s 87 } 88 if s.tree == nil { 89 return s2 90 } 91 txn := s.tree.Txn() 92 iter := s2.tree.Iterator() 93 for k, v, ok := iter.Next(); ok; k, v, ok = iter.Next() { 94 txn.Insert(k, v) 95 } 96 s.tree = txn.CommitOnly() 97 return s 98 } 99 100 // Difference returns a set with values that only 101 // appear in the first set. 102 func (s Set[T]) Difference(s2 Set[T]) Set[T] { 103 if s.tree == nil || s2.tree == nil { 104 return s 105 } 106 107 txn := s.tree.Txn() 108 iter := s2.tree.Iterator() 109 for k, _, ok := iter.Next(); ok; k, _, ok = iter.Next() { 110 txn.Delete(k) 111 } 112 s.tree = txn.CommitOnly() 113 return s 114 } 115 116 // Len returns the number of values in the set. 117 func (s Set[T]) Len() int { 118 if s.tree == nil { 119 return 0 120 } 121 return s.tree.size 122 } 123 124 // Equal returns true if the two sets contain the equal keys. 125 func (s Set[T]) Equal(other Set[T]) bool { 126 switch { 127 case s.tree == nil && other.tree == nil: 128 return true 129 case s.Len() != other.Len(): 130 return false 131 default: 132 iter1 := s.tree.Iterator() 133 iter2 := other.tree.Iterator() 134 for { 135 k1, _, ok := iter1.Next() 136 if !ok { 137 break 138 } 139 k2, _, _ := iter2.Next() 140 // Equal lengths, no need to check 'ok' for 'iter2'. 141 if !bytes.Equal(k1, k2) { 142 return false 143 } 144 } 145 return true 146 } 147 } 148 149 // ToBytesFunc returns the function to extract the key from 150 // the element type. Useful for utilities that are interested 151 // in the key. 152 func (s Set[T]) ToBytesFunc() func(T) []byte { 153 return s.toBytes 154 } 155 156 func (s Set[T]) MarshalJSON() ([]byte, error) { 157 if s.tree == nil { 158 return []byte("[]"), nil 159 } 160 var b bytes.Buffer 161 b.WriteRune('[') 162 iter := s.tree.Iterator() 163 _, v, ok := iter.Next() 164 for ok { 165 bs, err := json.Marshal(v) 166 if err != nil { 167 return nil, err 168 } 169 b.Write(bs) 170 _, v, ok = iter.Next() 171 if ok { 172 b.WriteRune(',') 173 } 174 } 175 b.WriteRune(']') 176 return b.Bytes(), nil 177 } 178 179 func (s *Set[T]) UnmarshalJSON(data []byte) error { 180 dec := json.NewDecoder(bytes.NewReader(data)) 181 t, err := dec.Token() 182 if err != nil { 183 return err 184 } 185 if d, ok := t.(json.Delim); !ok || d != '[' { 186 return fmt.Errorf("%T.UnmarshalJSON: expected '[' got %v", s, t) 187 } 188 189 if s.tree == nil { 190 *s = NewSet[T]() 191 } 192 txn := s.tree.Txn() 193 194 for dec.More() { 195 var x T 196 err := dec.Decode(&x) 197 if err != nil { 198 return err 199 } 200 txn.Insert(s.toBytes(x), x) 201 } 202 s.tree = txn.CommitOnly() 203 204 t, err = dec.Token() 205 if err != nil { 206 return err 207 } 208 if d, ok := t.(json.Delim); !ok || d != ']' { 209 return fmt.Errorf("%T.UnmarshalJSON: expected ']' got %v", s, t) 210 } 211 return nil 212 } 213 214 func (s Set[T]) MarshalYAML() (any, error) { 215 // TODO: Once yaml.v3 supports iter.Seq, drop the Collect(). 216 return slices.Collect(s.All()), nil 217 } 218 219 func (s *Set[T]) UnmarshalYAML(value *yaml.Node) error { 220 if value.Kind != yaml.SequenceNode { 221 return fmt.Errorf("%T.UnmarshalYAML: expected sequence", s) 222 } 223 224 if s.tree == nil { 225 *s = NewSet[T]() 226 } 227 txn := s.tree.Txn() 228 229 for _, e := range value.Content { 230 var v T 231 if err := e.Decode(&v); err != nil { 232 return err 233 } 234 txn.Insert(s.toBytes(v), v) 235 } 236 s.tree = txn.CommitOnly() 237 return nil 238 } 239 240 func toSeq[T any](iter *Iterator[T]) iter.Seq[T] { 241 return func(yield func(T) bool) { 242 if iter == nil { 243 return 244 } 245 iter = iter.Clone() 246 for _, x, ok := iter.Next(); ok; _, x, ok = iter.Next() { 247 if !yield(x) { 248 break 249 } 250 } 251 } 252 }