github.com/treeverse/lakefs@v1.24.1-0.20240520134607-95648127bfb0/pkg/kv/mem/store.go (about)

     1  package mem
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/base64"
     7  	"fmt"
     8  	"sort"
     9  	"sync"
    10  
    11  	"github.com/treeverse/lakefs/pkg/kv"
    12  	"github.com/treeverse/lakefs/pkg/kv/kvparams"
    13  )
    14  
// Driver is the kv driver for the in-memory store. An instance is registered
// under DriverName by this package's init function.
type Driver struct{}

// PartitionMap holds key-value pairs of a given partition, indexed by the
// base64-encoded form of each entry's key (see encodeKey).
type PartitionMap map[string]kv.Entry

// Store is the in-memory kv.Store returned by Driver.Open. Its data is held only
// in the map m; nothing in this file persists it.
type Store struct {
	// m maps a partition key (converted to string) to that partition's entries.
	m map[string]PartitionMap

	// mu guards m: read paths take RLock, mutating paths take Lock.
	mu sync.RWMutex
}

// EntriesIterator walks the entries of one partition of a Store in ascending
// key order (bytes.Compare), re-reading the store under its read lock on each
// call to Next.
type EntriesIterator struct {
	entry     *kv.Entry // entry produced by the most recent successful Next
	err       error     // set to kv.ErrClosedEntries by Close; Next returns false once non-nil
	start     []byte    // inclusive lower bound for the next key; nil once the iterator is exhausted
	partition string    // partition being scanned
	store     *Store    // store being scanned
}
    33  
    34  func (e *EntriesIterator) SeekGE(key []byte) {
    35  	e.start = key
    36  }
    37  
    38  const DriverName = "mem"
    39  
    40  //nolint:gochecknoinits
    41  func init() {
    42  	kv.Register(DriverName, &Driver{})
    43  }
    44  
    45  func (d *Driver) Open(_ context.Context, _ kvparams.Config) (kv.Store, error) {
    46  	return &Store{
    47  		m: make(map[string]PartitionMap),
    48  	}, nil
    49  }
    50  
    51  func encodeKey(key []byte) string {
    52  	return base64.StdEncoding.EncodeToString(key)
    53  }
    54  
    55  func (s *Store) Get(_ context.Context, partitionKey, key []byte) (*kv.ValueWithPredicate, error) {
    56  	if len(partitionKey) == 0 {
    57  		return nil, kv.ErrMissingPartitionKey
    58  	}
    59  	if len(key) == 0 {
    60  		return nil, kv.ErrMissingKey
    61  	}
    62  	s.mu.RLock()
    63  	defer s.mu.RUnlock()
    64  
    65  	sKey := encodeKey(key)
    66  	value, ok := s.m[string(partitionKey)][sKey]
    67  	if !ok {
    68  		return nil, fmt.Errorf("partition=%s, key=%v, encoding=%s: %w", partitionKey, key, sKey, kv.ErrNotFound)
    69  	}
    70  	return &kv.ValueWithPredicate{
    71  		Value:     value.Value,
    72  		Predicate: kv.Predicate(value.Value),
    73  	}, nil
    74  }
    75  
    76  func (s *Store) Set(_ context.Context, partitionKey, key, value []byte) error {
    77  	if len(partitionKey) == 0 {
    78  		return kv.ErrMissingPartitionKey
    79  	}
    80  	if len(key) == 0 {
    81  		return kv.ErrMissingKey
    82  	}
    83  	if value == nil {
    84  		return kv.ErrMissingValue
    85  	}
    86  	s.mu.Lock()
    87  	defer s.mu.Unlock()
    88  
    89  	s.internalSet(partitionKey, key, value)
    90  
    91  	return nil
    92  }
    93  
    94  func (s *Store) internalSet(partitionKey, key, value []byte) {
    95  	sKey := encodeKey(key)
    96  	if _, ok := s.m[string(partitionKey)]; !ok {
    97  		s.m[string(partitionKey)] = make(map[string]kv.Entry)
    98  	}
    99  	s.m[string(partitionKey)][sKey] = kv.Entry{
   100  		PartitionKey: partitionKey,
   101  		Key:          key,
   102  		Value:        value,
   103  	}
   104  }
   105  
   106  func (s *Store) SetIf(_ context.Context, partitionKey, key, value []byte, valuePredicate kv.Predicate) error {
   107  	if len(partitionKey) == 0 {
   108  		return kv.ErrMissingPartitionKey
   109  	}
   110  	if len(key) == 0 {
   111  		return kv.ErrMissingKey
   112  	}
   113  	if value == nil {
   114  		return kv.ErrMissingValue
   115  	}
   116  	s.mu.Lock()
   117  	defer s.mu.Unlock()
   118  
   119  	sKey := encodeKey(key)
   120  	curr, currOK := s.m[string(partitionKey)][sKey]
   121  
   122  	switch valuePredicate {
   123  	case nil:
   124  		if currOK {
   125  			return fmt.Errorf("key=%v: %w", key, kv.ErrPredicateFailed)
   126  		}
   127  
   128  	case kv.PrecondConditionalExists:
   129  		if !currOK {
   130  			return fmt.Errorf("key=%v: %w", key, kv.ErrPredicateFailed)
   131  		}
   132  
   133  	default: // check for predicate
   134  		if !bytes.Equal(valuePredicate.([]byte), curr.Value) {
   135  			return fmt.Errorf("%w: partition=%s, key=%v, encoding=%s", kv.ErrPredicateFailed, partitionKey, key, sKey)
   136  		}
   137  	}
   138  
   139  	s.internalSet(partitionKey, key, value)
   140  	return nil
   141  }
   142  
   143  func (s *Store) Delete(_ context.Context, partitionKey, key []byte) error {
   144  	if len(partitionKey) == 0 {
   145  		return kv.ErrMissingPartitionKey
   146  	}
   147  	if len(key) == 0 {
   148  		return kv.ErrMissingKey
   149  	}
   150  	s.mu.Lock()
   151  	defer s.mu.Unlock()
   152  
   153  	sKey := encodeKey(key)
   154  	if _, ok := s.m[string(partitionKey)][sKey]; !ok {
   155  		return nil
   156  	}
   157  	delete(s.m[string(partitionKey)], sKey)
   158  	return nil
   159  }
   160  
   161  func (s *Store) Scan(_ context.Context, partitionKey []byte, options kv.ScanOptions) (kv.EntriesIterator, error) {
   162  	if len(partitionKey) == 0 {
   163  		return nil, kv.ErrMissingPartitionKey
   164  	}
   165  
   166  	start := options.KeyStart
   167  	if start == nil {
   168  		start = []byte{}
   169  	}
   170  	return &EntriesIterator{
   171  		store:     s,
   172  		start:     start,
   173  		partition: string(partitionKey),
   174  	}, nil
   175  }
   176  
   177  func (s *Store) Close() {}
   178  
   179  func (e *EntriesIterator) Next() bool {
   180  	if e.err != nil || e.start == nil { // start is nil only if last iteration we reached end of keys
   181  		return false
   182  	}
   183  
   184  	e.store.mu.RLock()
   185  	defer e.store.mu.RUnlock()
   186  
   187  	var l []*kv.Entry
   188  	if _, ok := e.store.m[e.partition]; ok {
   189  		for _, entry := range e.store.m[e.partition] {
   190  			if bytes.Compare(entry.Key, e.start) >= 0 {
   191  				entry := entry
   192  				l = append(l, &entry)
   193  			}
   194  		}
   195  	}
   196  	if len(l) == 0 { // No results
   197  		e.start = nil
   198  		return false
   199  	}
   200  	if len(l) == 1 { // only one key >= start, set start to nil, so to indicate the next call to return false immediately.
   201  		e.start = nil
   202  		e.entry = l[0]
   203  		return true
   204  	}
   205  	sort.Slice(l, func(i, j int) bool { return bytes.Compare(l[i].Key, l[j].Key) < 0 })
   206  	e.start = l[1].Key
   207  	e.entry = l[0]
   208  	return true
   209  }
   210  
   211  func (e *EntriesIterator) Entry() *kv.Entry {
   212  	return e.entry
   213  }
   214  
   215  func (e *EntriesIterator) Err() error {
   216  	return e.err
   217  }
   218  
   219  func (e *EntriesIterator) Close() {
   220  	e.err = kv.ErrClosedEntries
   221  }