github.com/cilium/statedb@v0.3.2/part/set.go (about)

     1  // SPDX-License-Identifier: Apache-2.0
     2  // Copyright Authors of Cilium
     3  
     4  package part
     5  
     6  import (
     7  	"bytes"
     8  	"encoding/json"
     9  	"fmt"
    10  	"iter"
    11  	"slices"
    12  
    13  	"gopkg.in/yaml.v3"
    14  )
    15  
// Set is a persistent (immutable) set of values. A Set can be
// defined for any type for which a byte slice key can be derived.
//
// A zero value Set[T] can be used provided that the conversion
// function for T has been registered with RegisterKeyType.
// For Set-only use only [bytesFromKey] needs to be defined.
//
// Set, Delete, Union and Difference return a new Set and leave the
// receiver unchanged; only the Unmarshal* methods update a Set in place.
type Set[T any] struct {
	// toBytes derives the byte-slice key under which a value is stored.
	// It is nil in a zero-value Set.
	toBytes func(T) []byte
	// tree holds the values keyed by toBytes(value). It is nil in a
	// zero-value Set, which the methods treat as an empty set.
	tree    *Tree[T]
}
    26  
    27  // NewSet creates a new set of T.
    28  // The value type T must be registered with RegisterKeyType.
    29  func NewSet[T any](values ...T) Set[T] {
    30  	s := Set[T]{tree: New[T](RootOnlyWatch)}
    31  	s.toBytes = lookupKeyType[T]()
    32  	if len(values) > 0 {
    33  		txn := s.tree.Txn()
    34  		for _, v := range values {
    35  			txn.Insert(s.toBytes(v), v)
    36  		}
    37  		s.tree = txn.CommitOnly()
    38  	}
    39  	return s
    40  }
    41  
    42  // Set a value. Returns a new set. Original is unchanged.
    43  func (s Set[T]) Set(v T) Set[T] {
    44  	if s.tree == nil {
    45  		return NewSet(v)
    46  	}
    47  	txn := s.tree.Txn()
    48  	txn.Insert(s.toBytes(v), v)
    49  	s.tree = txn.CommitOnly() // As Set is passed by value we can just modify it.
    50  	return s
    51  }
    52  
    53  // Delete returns a new set without the value. The original
    54  // set is unchanged.
    55  func (s Set[T]) Delete(v T) Set[T] {
    56  	if s.tree == nil {
    57  		return s
    58  	}
    59  	txn := s.tree.Txn()
    60  	txn.Delete(s.toBytes(v))
    61  	s.tree = txn.CommitOnly()
    62  	return s
    63  }
    64  
    65  // Has returns true if the set has the value.
    66  func (s Set[T]) Has(v T) bool {
    67  	if s.tree == nil {
    68  		return false
    69  	}
    70  	_, _, found := s.tree.Get(s.toBytes(v))
    71  	return found
    72  }
    73  
    74  // All returns an iterator for all values.
    75  func (s Set[T]) All() iter.Seq[T] {
    76  	if s.tree == nil {
    77  		return toSeq[T](nil)
    78  	}
    79  	return toSeq(s.tree.Iterator())
    80  }
    81  
    82  // Union returns a set that is the union of the values
    83  // in the input sets.
    84  func (s Set[T]) Union(s2 Set[T]) Set[T] {
    85  	if s2.tree == nil {
    86  		return s
    87  	}
    88  	if s.tree == nil {
    89  		return s2
    90  	}
    91  	txn := s.tree.Txn()
    92  	iter := s2.tree.Iterator()
    93  	for k, v, ok := iter.Next(); ok; k, v, ok = iter.Next() {
    94  		txn.Insert(k, v)
    95  	}
    96  	s.tree = txn.CommitOnly()
    97  	return s
    98  }
    99  
   100  // Difference returns a set with values that only
   101  // appear in the first set.
   102  func (s Set[T]) Difference(s2 Set[T]) Set[T] {
   103  	if s.tree == nil || s2.tree == nil {
   104  		return s
   105  	}
   106  
   107  	txn := s.tree.Txn()
   108  	iter := s2.tree.Iterator()
   109  	for k, _, ok := iter.Next(); ok; k, _, ok = iter.Next() {
   110  		txn.Delete(k)
   111  	}
   112  	s.tree = txn.CommitOnly()
   113  	return s
   114  }
   115  
   116  // Len returns the number of values in the set.
   117  func (s Set[T]) Len() int {
   118  	if s.tree == nil {
   119  		return 0
   120  	}
   121  	return s.tree.size
   122  }
   123  
   124  // Equal returns true if the two sets contain the equal keys.
   125  func (s Set[T]) Equal(other Set[T]) bool {
   126  	switch {
   127  	case s.tree == nil && other.tree == nil:
   128  		return true
   129  	case s.Len() != other.Len():
   130  		return false
   131  	default:
   132  		iter1 := s.tree.Iterator()
   133  		iter2 := other.tree.Iterator()
   134  		for {
   135  			k1, _, ok := iter1.Next()
   136  			if !ok {
   137  				break
   138  			}
   139  			k2, _, _ := iter2.Next()
   140  			// Equal lengths, no need to check 'ok' for 'iter2'.
   141  			if !bytes.Equal(k1, k2) {
   142  				return false
   143  			}
   144  		}
   145  		return true
   146  	}
   147  }
   148  
   149  // ToBytesFunc returns the function to extract the key from
   150  // the element type. Useful for utilities that are interested
   151  // in the key.
   152  func (s Set[T]) ToBytesFunc() func(T) []byte {
   153  	return s.toBytes
   154  }
   155  
   156  func (s Set[T]) MarshalJSON() ([]byte, error) {
   157  	if s.tree == nil {
   158  		return []byte("[]"), nil
   159  	}
   160  	var b bytes.Buffer
   161  	b.WriteRune('[')
   162  	iter := s.tree.Iterator()
   163  	_, v, ok := iter.Next()
   164  	for ok {
   165  		bs, err := json.Marshal(v)
   166  		if err != nil {
   167  			return nil, err
   168  		}
   169  		b.Write(bs)
   170  		_, v, ok = iter.Next()
   171  		if ok {
   172  			b.WriteRune(',')
   173  		}
   174  	}
   175  	b.WriteRune(']')
   176  	return b.Bytes(), nil
   177  }
   178  
   179  func (s *Set[T]) UnmarshalJSON(data []byte) error {
   180  	dec := json.NewDecoder(bytes.NewReader(data))
   181  	t, err := dec.Token()
   182  	if err != nil {
   183  		return err
   184  	}
   185  	if d, ok := t.(json.Delim); !ok || d != '[' {
   186  		return fmt.Errorf("%T.UnmarshalJSON: expected '[' got %v", s, t)
   187  	}
   188  
   189  	if s.tree == nil {
   190  		*s = NewSet[T]()
   191  	}
   192  	txn := s.tree.Txn()
   193  
   194  	for dec.More() {
   195  		var x T
   196  		err := dec.Decode(&x)
   197  		if err != nil {
   198  			return err
   199  		}
   200  		txn.Insert(s.toBytes(x), x)
   201  	}
   202  	s.tree = txn.CommitOnly()
   203  
   204  	t, err = dec.Token()
   205  	if err != nil {
   206  		return err
   207  	}
   208  	if d, ok := t.(json.Delim); !ok || d != ']' {
   209  		return fmt.Errorf("%T.UnmarshalJSON: expected ']' got %v", s, t)
   210  	}
   211  	return nil
   212  }
   213  
   214  func (s Set[T]) MarshalYAML() (any, error) {
   215  	// TODO: Once yaml.v3 supports iter.Seq, drop the Collect().
   216  	return slices.Collect(s.All()), nil
   217  }
   218  
   219  func (s *Set[T]) UnmarshalYAML(value *yaml.Node) error {
   220  	if value.Kind != yaml.SequenceNode {
   221  		return fmt.Errorf("%T.UnmarshalYAML: expected sequence", s)
   222  	}
   223  
   224  	if s.tree == nil {
   225  		*s = NewSet[T]()
   226  	}
   227  	txn := s.tree.Txn()
   228  
   229  	for _, e := range value.Content {
   230  		var v T
   231  		if err := e.Decode(&v); err != nil {
   232  			return err
   233  		}
   234  		txn.Insert(s.toBytes(v), v)
   235  	}
   236  	s.tree = txn.CommitOnly()
   237  	return nil
   238  }
   239  
   240  func toSeq[T any](iter *Iterator[T]) iter.Seq[T] {
   241  	return func(yield func(T) bool) {
   242  		if iter == nil {
   243  			return
   244  		}
   245  		iter = iter.Clone()
   246  		for _, x, ok := iter.Next(); ok; _, x, ok = iter.Next() {
   247  			if !yield(x) {
   248  				break
   249  			}
   250  		}
   251  	}
   252  }