github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/cache/sharded_lock_cache.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package cache
    13  
    14  import (
    15  	"context"
    16  	"sync"
    17  	"sync/atomic"
    18  	"time"
    19  	"unsafe"
    20  
    21  	enterrors "github.com/weaviate/weaviate/entities/errors"
    22  
    23  	"github.com/sirupsen/logrus"
    24  	"github.com/weaviate/weaviate/adapters/repos/db/vector/common"
    25  	"github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw/distancer"
    26  )
    27  
// shardedLockCache is a read-through vector cache generic over the element
// type T. Individual entries are guarded by sharded RW locks keyed by id, so
// concurrent access to different ids rarely contends on the same lock.
type shardedLockCache[T float32 | byte | uint64] struct {
	shardedLocks     *common.ShardedRWLocks // per-id locks guarding reads/writes of cache slots
	cache            [][]T                  // vectors indexed by id; a nil slot means "not cached"
	vectorForID      common.VectorForID[T]  // backing loader invoked on a cache miss
	normalizeOnRead  bool                   // set by the float32 constructor; loader normalizes vectors when true
	maxSize          int64                  // accessed atomically; once count reaches it, the cache is emptied
	count            int64                  // accessed atomically; number of occupied (non-nil) slots
	cancel           chan bool              // signals the background deletion goroutine to stop (see Drop)
	logger           logrus.FieldLogger
	deletionInterval time.Duration // 0 disables the periodic replaceIfFull goroutine

	// The maintenanceLock makes sure that only one maintenance operation, such
	// as growing the cache or clearing the cache happens at the same time.
	maintenanceLock sync.RWMutex
}
    43  
// Sizing constants for the cache's backing slice.
const (
	// InitialSize is the number of slots allocated when a cache is created.
	InitialSize = 1000
	// MinimumIndexGrowthDelta is the number of extra slots added beyond the
	// requested node id when Grow resizes the cache.
	MinimumIndexGrowthDelta = 2000
	// indexGrowthRate is not referenced in this file — presumably used by
	// another cache implementation in the package. TODO confirm.
	indexGrowthRate = 1.25
)
    49  
    50  func NewShardedFloat32LockCache(vecForID common.VectorForID[float32], maxSize int,
    51  	logger logrus.FieldLogger, normalizeOnRead bool, deletionInterval time.Duration,
    52  ) Cache[float32] {
    53  	vc := &shardedLockCache[float32]{
    54  		vectorForID: func(ctx context.Context, id uint64) ([]float32, error) {
    55  			vec, err := vecForID(ctx, id)
    56  			if err != nil {
    57  				return nil, err
    58  			}
    59  			if normalizeOnRead {
    60  				vec = distancer.Normalize(vec)
    61  			}
    62  			return vec, nil
    63  		},
    64  		cache:            make([][]float32, InitialSize),
    65  		normalizeOnRead:  normalizeOnRead,
    66  		count:            0,
    67  		maxSize:          int64(maxSize),
    68  		cancel:           make(chan bool),
    69  		logger:           logger,
    70  		shardedLocks:     common.NewDefaultShardedRWLocks(),
    71  		maintenanceLock:  sync.RWMutex{},
    72  		deletionInterval: deletionInterval,
    73  	}
    74  
    75  	vc.watchForDeletion()
    76  	return vc
    77  }
    78  
    79  func NewShardedByteLockCache(vecForID common.VectorForID[byte], maxSize int,
    80  	logger logrus.FieldLogger, deletionInterval time.Duration,
    81  ) Cache[byte] {
    82  	vc := &shardedLockCache[byte]{
    83  		vectorForID:      vecForID,
    84  		cache:            make([][]byte, InitialSize),
    85  		normalizeOnRead:  false,
    86  		count:            0,
    87  		maxSize:          int64(maxSize),
    88  		cancel:           make(chan bool),
    89  		logger:           logger,
    90  		shardedLocks:     common.NewDefaultShardedRWLocks(),
    91  		maintenanceLock:  sync.RWMutex{},
    92  		deletionInterval: deletionInterval,
    93  	}
    94  
    95  	vc.watchForDeletion()
    96  	return vc
    97  }
    98  
    99  func NewShardedUInt64LockCache(vecForID common.VectorForID[uint64], maxSize int,
   100  	logger logrus.FieldLogger, deletionInterval time.Duration,
   101  ) Cache[uint64] {
   102  	vc := &shardedLockCache[uint64]{
   103  		vectorForID:      vecForID,
   104  		cache:            make([][]uint64, InitialSize),
   105  		normalizeOnRead:  false,
   106  		count:            0,
   107  		maxSize:          int64(maxSize),
   108  		cancel:           make(chan bool),
   109  		logger:           logger,
   110  		shardedLocks:     common.NewDefaultShardedRWLocks(),
   111  		maintenanceLock:  sync.RWMutex{},
   112  		deletionInterval: deletionInterval,
   113  	}
   114  
   115  	vc.watchForDeletion()
   116  	return vc
   117  }
   118  
// All returns the cache's live backing slice — not a copy — and takes no
// locks. NOTE(review): concurrent writers may mutate entries while the caller
// iterates; confirm callers only use this when the index is quiescent.
func (s *shardedLockCache[T]) All() [][]T {
	return s.cache
}
   122  
   123  func (s *shardedLockCache[T]) Get(ctx context.Context, id uint64) ([]T, error) {
   124  	s.shardedLocks.RLock(id)
   125  	vec := s.cache[id]
   126  	s.shardedLocks.RUnlock(id)
   127  
   128  	if vec != nil {
   129  		return vec, nil
   130  	}
   131  
   132  	return s.handleCacheMiss(ctx, id)
   133  }
   134  
   135  func (s *shardedLockCache[T]) Delete(ctx context.Context, id uint64) {
   136  	s.shardedLocks.Lock(id)
   137  	defer s.shardedLocks.Unlock(id)
   138  
   139  	if int(id) >= len(s.cache) || s.cache[id] == nil {
   140  		return
   141  	}
   142  
   143  	s.cache[id] = nil
   144  	atomic.AddInt64(&s.count, -1)
   145  }
   146  
   147  func (s *shardedLockCache[T]) handleCacheMiss(ctx context.Context, id uint64) ([]T, error) {
   148  	vec, err := s.vectorForID(ctx, id)
   149  	if err != nil {
   150  		return nil, err
   151  	}
   152  
   153  	atomic.AddInt64(&s.count, 1)
   154  	s.shardedLocks.Lock(id)
   155  	s.cache[id] = vec
   156  	s.shardedLocks.Unlock(id)
   157  
   158  	return vec, nil
   159  }
   160  
   161  func (s *shardedLockCache[T]) MultiGet(ctx context.Context, ids []uint64) ([][]T, []error) {
   162  	out := make([][]T, len(ids))
   163  	errs := make([]error, len(ids))
   164  
   165  	for i, id := range ids {
   166  		s.shardedLocks.RLock(id)
   167  		vec := s.cache[id]
   168  		s.shardedLocks.RUnlock(id)
   169  
   170  		if vec == nil {
   171  			vecFromDisk, err := s.handleCacheMiss(ctx, id)
   172  			errs[i] = err
   173  			vec = vecFromDisk
   174  		}
   175  
   176  		out[i] = vec
   177  	}
   178  
   179  	return out, errs
   180  }
   181  
// prefetchFunc issues a CPU cache-prefetch hint for the given address. The
// default is a no-op; per the comment below it is overridden for amd64 — the
// override itself is not visible in this file.
var prefetchFunc func(in uintptr) = func(in uintptr) {
	// do nothing on default arch
	// this function will be overridden for amd64
}
   186  
// Prefetch hints the CPU to pull the cache slot for id into CPU cache ahead
// of an expected Get. Note it prefetches the address of the slice header
// stored in s.cache[id], not the vector data that header points to.
func (s *shardedLockCache[T]) Prefetch(id uint64) {
	s.shardedLocks.RLock(id)
	defer s.shardedLocks.RUnlock(id)

	prefetchFunc(uintptr(unsafe.Pointer(&s.cache[id])))
}
   193  
   194  func (s *shardedLockCache[T]) Preload(id uint64, vec []T) {
   195  	s.shardedLocks.Lock(id)
   196  	defer s.shardedLocks.Unlock(id)
   197  
   198  	atomic.AddInt64(&s.count, 1)
   199  	s.cache[id] = vec
   200  }
   201  
// Grow ensures the backing slice has at least node+1 slots, using
// double-checked locking: a cheap read-locked size check first, then the
// exclusive maintenance lock with a re-check in case another goroutine grew
// the cache while this one waited. Existing entries are preserved.
func (s *shardedLockCache[T]) Grow(node uint64) {
	s.maintenanceLock.RLock()
	if node < uint64(len(s.cache)) {
		s.maintenanceLock.RUnlock()
		return
	}
	s.maintenanceLock.RUnlock()

	s.maintenanceLock.Lock()
	defer s.maintenanceLock.Unlock()

	// make sure cache still needs growing
	// (it could have grown while waiting for maintenance lock)
	if node < uint64(len(s.cache)) {
		return
	}

	// Take every per-id lock so no reader can observe the backing-slice swap
	// mid-copy.
	s.shardedLocks.LockAll()
	defer s.shardedLocks.UnlockAll()

	newSize := node + MinimumIndexGrowthDelta
	newCache := make([][]T, newSize)
	copy(newCache, s.cache)
	s.cache = newCache
}
   227  
   228  func (s *shardedLockCache[T]) Len() int32 {
   229  	s.maintenanceLock.RLock()
   230  	defer s.maintenanceLock.RUnlock()
   231  
   232  	return int32(len(s.cache))
   233  }
   234  
// CountVectors returns the number of occupied slots, read atomically so it is
// safe to call concurrently with fills and deletes.
func (s *shardedLockCache[T]) CountVectors() int64 {
	return atomic.LoadInt64(&s.count)
}
   238  
// Drop empties the cache and, when the periodic deletion goroutine was
// started (deletionInterval != 0), signals it to stop.
// NOTE(review): s.cancel is unbuffered, so the send blocks until the watcher
// goroutine receives; a second Drop would block forever — confirm callers
// invoke Drop at most once.
func (s *shardedLockCache[T]) Drop() {
	s.deleteAllVectors()
	if s.deletionInterval != 0 {
		s.cancel <- true
	}
}
   245  
   246  func (s *shardedLockCache[T]) deleteAllVectors() {
   247  	s.shardedLocks.LockAll()
   248  	defer s.shardedLocks.UnlockAll()
   249  
   250  	s.logger.WithField("action", "hnsw_delete_vector_cache").
   251  		Debug("deleting full vector cache")
   252  	for i := range s.cache {
   253  		s.cache[i] = nil
   254  	}
   255  
   256  	atomic.StoreInt64(&s.count, 0)
   257  }
   258  
// watchForDeletion starts a background goroutine that empties the cache
// whenever the occupancy reaches maxSize, checking once per deletionInterval.
// It is a no-op when deletionInterval is zero. The goroutine runs until a
// value arrives on s.cancel (sent by Drop). The goroutine is launched through
// the project's GoWrapper helper — presumably for panic recovery/logging;
// confirm in entities/errors.
func (s *shardedLockCache[T]) watchForDeletion() {
	if s.deletionInterval != 0 {
		f := func() {
			t := time.NewTicker(s.deletionInterval)
			defer t.Stop()
			for {
				select {
				case <-s.cancel:
					return
				case <-t.C:
					s.replaceIfFull()
				}
			}
		}
		enterrors.GoWrapper(f, s.logger)
	}
}
   276  
   277  func (s *shardedLockCache[T]) replaceIfFull() {
   278  	if atomic.LoadInt64(&s.count) >= atomic.LoadInt64(&s.maxSize) {
   279  		s.deleteAllVectors()
   280  	}
   281  }
   282  
// UpdateMaxSize atomically replaces the occupancy threshold at which the
// periodic cleanup empties the cache.
func (s *shardedLockCache[T]) UpdateMaxSize(size int64) {
	atomic.StoreInt64(&s.maxSize, size)
}
   286  
   287  func (s *shardedLockCache[T]) CopyMaxSize() int64 {
   288  	sizeCopy := atomic.LoadInt64(&s.maxSize)
   289  	return sizeCopy
   290  }
   291  
// noopCache can be helpful in debugging situations, where we want to
// explicitly pass through each vectorForID call to the underlying vectorForID
// function without caching in between.
type noopCache struct {
	vectorForID common.VectorForID[float32] // backing lookup; consulted on every access, nothing is stored
}
   298  
// NewNoopCache builds a cache that stores nothing and delegates every lookup
// straight to vecForID. maxSize and logger are unused; they are accepted only
// so the signature mirrors the real cache constructors.
func NewNoopCache(vecForID common.VectorForID[float32], maxSize int,
	logger logrus.FieldLogger,
) *noopCache {
	return &noopCache{vectorForID: vecForID}
}