github.com/cilium/statedb@v0.3.2/part/map.go (about)

     1  // SPDX-License-Identifier: Apache-2.0
     2  // Copyright Authors of Cilium
     3  
     4  package part
     5  
     6  import (
     7  	"bytes"
     8  	"encoding/json"
     9  	"fmt"
    10  	"iter"
    11  	"reflect"
    12  
    13  	"gopkg.in/yaml.v3"
    14  )
    15  
// Map of key-value pairs. The zero value is ready for use, provided
// that the key type has been registered with RegisterKeyType.
//
// Map is a typed wrapper around Tree[T] for working with
// keys that are not []byte.
//
// Map has value semantics: Set and Delete return a new Map and leave
// the receiver unchanged.
type Map[K, V any] struct {
	// bytesFromKey converts a key into the []byte key used in tree.
	// It is assigned by ensureTree; the methods in this file only call
	// it after checking for, or allocating, a non-nil tree.
	bytesFromKey func(K) []byte

	// tree holds the entries. It stays nil until the first write so
	// that an empty Map does not allocate.
	tree *Tree[mapKVPair[K, V]]
}
    25  
// mapKVPair is the element type stored in the underlying tree. The
// original key is stored next to the value so that iteration can yield
// the key as a K rather than as its []byte encoding.
//
// The "k"/"v" tags (and, for JSON, the field order) define the element
// layout written by MarshalJSON/MarshalYAML and read back by
// UnmarshalJSON/UnmarshalYAML.
type mapKVPair[K, V any] struct {
	Key   K `json:"k" yaml:"k"`
	Value V `json:"v" yaml:"v"`
}
    30  
    31  // FromMap copies values from the hash map into the given Map.
    32  // This is not implemented as a method on Map[K,V] as hash maps require the
    33  // comparable constraint and we do not need to limit Map[K, V] to that.
    34  func FromMap[K comparable, V any](m Map[K, V], hm map[K]V) Map[K, V] {
    35  	m.ensureTree()
    36  	txn := m.tree.Txn()
    37  	for k, v := range hm {
    38  		txn.Insert(m.bytesFromKey(k), mapKVPair[K, V]{k, v})
    39  	}
    40  	m.tree = txn.CommitOnly()
    41  	return m
    42  }
    43  
    44  // ensureTree checks that the tree is not nil and allocates it if
    45  // it is. The whole nil tree thing is to make sure that creating
    46  // an empty map does not allocate anything.
    47  func (m *Map[K, V]) ensureTree() {
    48  	if m.tree == nil {
    49  		m.tree = New[mapKVPair[K, V]](RootOnlyWatch)
    50  	}
    51  	m.bytesFromKey = lookupKeyType[K]()
    52  }
    53  
    54  // Get a value from the map by its key.
    55  func (m Map[K, V]) Get(key K) (value V, found bool) {
    56  	if m.tree == nil {
    57  		return
    58  	}
    59  	kv, _, found := m.tree.Get(m.bytesFromKey(key))
    60  	return kv.Value, found
    61  }
    62  
    63  // Set a value. Returns a new map with the value set.
    64  // Original map is unchanged.
    65  func (m Map[K, V]) Set(key K, value V) Map[K, V] {
    66  	m.ensureTree()
    67  	txn := m.tree.Txn()
    68  	txn.Insert(m.bytesFromKey(key), mapKVPair[K, V]{key, value})
    69  	m.tree = txn.CommitOnly()
    70  	return m
    71  }
    72  
    73  // Delete a value from the map. Returns a new map
    74  // without the element pointed to by the key (if found).
    75  func (m Map[K, V]) Delete(key K) Map[K, V] {
    76  	if m.tree != nil {
    77  		txn := m.tree.Txn()
    78  		txn.Delete(m.bytesFromKey(key))
    79  		// Map is a struct passed by value, so we can modify
    80  		// it without changing the caller's view of it.
    81  		m.tree = txn.CommitOnly()
    82  	}
    83  	return m
    84  }
    85  
    86  func toSeq2[K, V any](iter *Iterator[mapKVPair[K, V]]) iter.Seq2[K, V] {
    87  	return func(yield func(K, V) bool) {
    88  		if iter == nil {
    89  			return
    90  		}
    91  		iter = iter.Clone()
    92  		for _, kv, ok := iter.Next(); ok; _, kv, ok = iter.Next() {
    93  			if !yield(kv.Key, kv.Value) {
    94  				break
    95  			}
    96  		}
    97  	}
    98  }
    99  
   100  // LowerBound iterates over all keys in order with value equal
   101  // to or greater than [from].
   102  func (m Map[K, V]) LowerBound(from K) iter.Seq2[K, V] {
   103  	if m.tree == nil {
   104  		return toSeq2[K, V](nil)
   105  	}
   106  	return toSeq2(m.tree.LowerBound(m.bytesFromKey(from)))
   107  }
   108  
   109  // Prefix iterates in order over all keys that start with
   110  // the given prefix.
   111  func (m Map[K, V]) Prefix(prefix K) iter.Seq2[K, V] {
   112  	if m.tree == nil {
   113  		return toSeq2[K, V](nil)
   114  	}
   115  	iter, _ := m.tree.Prefix(m.bytesFromKey(prefix))
   116  	return toSeq2(iter)
   117  }
   118  
   119  // All iterates every key-value in the map in order.
   120  // The order is in bytewise order of the byte slice
   121  // returned by bytesFromKey.
   122  func (m Map[K, V]) All() iter.Seq2[K, V] {
   123  	if m.tree == nil {
   124  		return toSeq2[K, V](nil)
   125  	}
   126  	return toSeq2(m.tree.Iterator())
   127  }
   128  
   129  // EqualKeys returns true if both maps contain the same keys.
   130  func (m Map[K, V]) EqualKeys(other Map[K, V]) bool {
   131  	switch {
   132  	case m.tree == nil && other.tree == nil:
   133  		return true
   134  	case m.Len() != other.Len():
   135  		return false
   136  	default:
   137  		iter1 := m.tree.Iterator()
   138  		iter2 := other.tree.Iterator()
   139  		for {
   140  			k1, _, ok := iter1.Next()
   141  			if !ok {
   142  				break
   143  			}
   144  			k2, _, _ := iter2.Next()
   145  			// Equal lengths, no need to check 'ok' for 'iter2'.
   146  			if !bytes.Equal(k1, k2) {
   147  				return false
   148  			}
   149  		}
   150  		return true
   151  	}
   152  }
   153  
   154  // SlowEqual returns true if the two maps contain the same keys and values.
   155  // Value comparison is implemented with reflect.DeepEqual which makes this
   156  // slow and mostly useful for testing.
   157  func (m Map[K, V]) SlowEqual(other Map[K, V]) bool {
   158  	switch {
   159  	case m.tree == nil && other.tree == nil:
   160  		return true
   161  	case m.Len() != other.Len():
   162  		return false
   163  	default:
   164  		iter1 := m.tree.Iterator()
   165  		iter2 := other.tree.Iterator()
   166  		for {
   167  			k1, v1, ok := iter1.Next()
   168  			if !ok {
   169  				break
   170  			}
   171  			k2, v2, _ := iter2.Next()
   172  			// Equal lengths, no need to check 'ok' for 'iter2'.
   173  			if !bytes.Equal(k1, k2) || !reflect.DeepEqual(v1, v2) {
   174  				return false
   175  			}
   176  		}
   177  		return true
   178  	}
   179  }
   180  
   181  // Len returns the number of elements in the map.
   182  func (m Map[K, V]) Len() int {
   183  	if m.tree == nil {
   184  		return 0
   185  	}
   186  	return m.tree.size
   187  }
   188  
   189  func (m Map[K, V]) MarshalJSON() ([]byte, error) {
   190  	if m.tree == nil {
   191  		return []byte("[]"), nil
   192  	}
   193  
   194  	var b bytes.Buffer
   195  	b.WriteRune('[')
   196  	iter := m.tree.Iterator()
   197  	_, kv, ok := iter.Next()
   198  	for ok {
   199  		bs, err := json.Marshal(kv)
   200  		if err != nil {
   201  			return nil, err
   202  		}
   203  		b.Write(bs)
   204  		_, kv, ok = iter.Next()
   205  		if ok {
   206  			b.WriteRune(',')
   207  		}
   208  	}
   209  	b.WriteRune(']')
   210  	return b.Bytes(), nil
   211  }
   212  
   213  func (m *Map[K, V]) UnmarshalJSON(data []byte) error {
   214  	dec := json.NewDecoder(bytes.NewReader(data))
   215  	t, err := dec.Token()
   216  	if err != nil {
   217  		return err
   218  	}
   219  	if d, ok := t.(json.Delim); !ok || d != '[' {
   220  		return fmt.Errorf("%T.UnmarshalJSON: expected '[' got %v", m, t)
   221  	}
   222  	m.ensureTree()
   223  	txn := m.tree.Txn()
   224  	for dec.More() {
   225  		var kv mapKVPair[K, V]
   226  		err := dec.Decode(&kv)
   227  		if err != nil {
   228  			return err
   229  		}
   230  		txn.Insert(m.bytesFromKey(kv.Key), mapKVPair[K, V]{kv.Key, kv.Value})
   231  	}
   232  
   233  	t, err = dec.Token()
   234  	if err != nil {
   235  		return err
   236  	}
   237  	if d, ok := t.(json.Delim); !ok || d != ']' {
   238  		return fmt.Errorf("%T.UnmarshalJSON: expected ']' got %v", m, t)
   239  	}
   240  	m.tree = txn.CommitOnly()
   241  	return nil
   242  }
   243  
   244  func (m Map[K, V]) MarshalYAML() (any, error) {
   245  	kvs := make([]mapKVPair[K, V], 0, m.Len())
   246  	iter := m.tree.Iterator()
   247  	for _, kv, ok := iter.Next(); ok; _, kv, ok = iter.Next() {
   248  		kvs = append(kvs, kv)
   249  	}
   250  	return kvs, nil
   251  }
   252  
   253  func (m *Map[K, V]) UnmarshalYAML(value *yaml.Node) error {
   254  	if value.Kind != yaml.SequenceNode {
   255  		return fmt.Errorf("%T.UnmarshalYAML: expected sequence", m)
   256  	}
   257  	m.ensureTree()
   258  	txn := m.tree.Txn()
   259  	for _, e := range value.Content {
   260  		var kv mapKVPair[K, V]
   261  		if err := e.Decode(&kv); err != nil {
   262  			return err
   263  		}
   264  		txn.Insert(m.bytesFromKey(kv.Key), mapKVPair[K, V]{kv.Key, kv.Value})
   265  	}
   266  	m.tree = txn.CommitOnly()
   267  	return nil
   268  }