github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/cache/sharded_lock_cache_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package cache
    13  
    14  import (
    15  	"context"
    16  	"math/rand"
    17  	"sync"
    18  	"testing"
    19  	"time"
    20  
    21  	"github.com/sirupsen/logrus/hooks/test"
    22  	"github.com/stretchr/testify/assert"
    23  	"github.com/weaviate/weaviate/adapters/repos/db/vector/common"
    24  )
    25  
    26  func TestVectorCacheGrowth(t *testing.T) {
    27  	logger, _ := test.NewNullLogger()
    28  	var vecForId common.VectorForID[float32] = nil
    29  	id := 100_000
    30  	expectedCount := int64(0)
    31  
    32  	vectorCache := NewShardedFloat32LockCache(vecForId, 1_000_000, logger, false, time.Duration(10_000))
    33  	initialSize := vectorCache.Len()
    34  	assert.Less(t, int(initialSize), id)
    35  	assert.Equal(t, expectedCount, vectorCache.CountVectors())
    36  
    37  	vectorCache.Grow(uint64(id))
    38  	size1stGrow := vectorCache.Len()
    39  	assert.Greater(t, int(size1stGrow), id)
    40  	assert.Equal(t, expectedCount, vectorCache.CountVectors())
    41  
    42  	vectorCache.Grow(uint64(id))
    43  	size2ndGrow := vectorCache.Len()
    44  	assert.Equal(t, size1stGrow, size2ndGrow)
    45  	assert.Equal(t, expectedCount, vectorCache.CountVectors())
    46  }
    47  
    48  func TestCache_ParallelGrowth(t *testing.T) {
    49  	// no asserts
    50  	// ensures there is no "index out of range" panic on get
    51  
    52  	logger, _ := test.NewNullLogger()
    53  	var vecForId common.VectorForID[float32] = func(context.Context, uint64) ([]float32, error) { return nil, nil }
    54  	vectorCache := NewShardedFloat32LockCache(vecForId, 1_000_000, logger, false, time.Second)
    55  
    56  	r := rand.New(rand.NewSource(time.Now().UnixNano()))
    57  	count := 10_000
    58  	maxNode := 100_000
    59  
    60  	wg := new(sync.WaitGroup)
    61  	wg.Add(count)
    62  	for i := 0; i < count; i++ {
    63  		node := uint64(r.Intn(maxNode))
    64  		go func(node uint64) {
    65  			defer wg.Done()
    66  
    67  			vectorCache.Grow(node)
    68  			vectorCache.Get(context.Background(), node)
    69  		}(node)
    70  	}
    71  
    72  	wg.Wait()
    73  }
    74  
    75  func TestCacheCleanup(t *testing.T) {
    76  	logger, _ := test.NewNullLogger()
    77  	var vecForId common.VectorForID[float32] = nil
    78  
    79  	maxSize := 10
    80  	batchSize := maxSize - 1
    81  	deletionInterval := 200 * time.Millisecond // overwrite default deletionInterval of 3s
    82  	sleepMs := deletionInterval + 100*time.Millisecond
    83  
    84  	t.Run("count is not reset on unnecessary deletion", func(t *testing.T) {
    85  		vectorCache := NewShardedFloat32LockCache(vecForId, maxSize, logger, false, deletionInterval)
    86  		shardedLockCache, ok := vectorCache.(*shardedLockCache[float32])
    87  		assert.True(t, ok)
    88  
    89  		for i := 0; i < batchSize; i++ {
    90  			shardedLockCache.Preload(uint64(i), []float32{float32(i), float32(i)})
    91  		}
    92  		time.Sleep(sleepMs) // wait for deletion to fire
    93  
    94  		assert.Equal(t, batchSize, int(shardedLockCache.CountVectors()))
    95  		assert.Equal(t, batchSize, countCached(shardedLockCache))
    96  
    97  		shardedLockCache.Drop()
    98  
    99  		assert.Equal(t, 0, int(shardedLockCache.count))
   100  		assert.Equal(t, 0, countCached(shardedLockCache))
   101  	})
   102  
   103  	t.Run("deletion clears cache and counter when maxSize exceeded", func(t *testing.T) {
   104  		vectorCache := NewShardedFloat32LockCache(vecForId, maxSize, logger, false, deletionInterval)
   105  		shardedLockCache, ok := vectorCache.(*shardedLockCache[float32])
   106  		assert.True(t, ok)
   107  
   108  		for b := 0; b < 2; b++ {
   109  			for i := 0; i < batchSize; i++ {
   110  				id := b*batchSize + i
   111  				shardedLockCache.Preload(uint64(id), []float32{float32(id), float32(id)})
   112  			}
   113  			time.Sleep(sleepMs) // wait for deletion to fire, 2nd should clean the cache
   114  		}
   115  
   116  		assert.Equal(t, 0, int(shardedLockCache.CountVectors()))
   117  		assert.Equal(t, 0, countCached(shardedLockCache))
   118  
   119  		shardedLockCache.Drop()
   120  	})
   121  }
   122  
   123  func countCached(c *shardedLockCache[float32]) int {
   124  	c.shardedLocks.LockAll()
   125  	defer c.shardedLocks.UnlockAll()
   126  
   127  	count := 0
   128  	for _, vec := range c.cache {
   129  		if vec != nil {
   130  			count++
   131  		}
   132  	}
   133  	return count
   134  }