github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/hnsw/index_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package hnsw
    13  
    14  import (
    15  	"context"
    16  	"testing"
    17  
    18  	"github.com/stretchr/testify/assert"
    19  	"github.com/stretchr/testify/require"
    20  	"github.com/weaviate/weaviate/adapters/repos/db/vector/cache"
    21  	"github.com/weaviate/weaviate/adapters/repos/db/vector/common"
    22  	"github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw/distancer"
    23  	"github.com/weaviate/weaviate/adapters/repos/db/vector/testinghelpers"
    24  	"github.com/weaviate/weaviate/entities/cyclemanager"
    25  	ent "github.com/weaviate/weaviate/entities/vectorindex/hnsw"
    26  )
    27  
    28  func TestHnswIndex(t *testing.T) {
    29  	index := createEmptyHnswIndexForTests(t, testVectorForID)
    30  
    31  	for i, vec := range testVectors {
    32  		err := index.Add(uint64(i), vec)
    33  		require.Nil(t, err)
    34  	}
    35  
    36  	t.Run("searching within cluster 1", func(t *testing.T) {
    37  		position := 0
    38  		res, _, err := index.knnSearchByVector(testVectors[position], 3, 36, nil)
    39  		require.Nil(t, err)
    40  		assert.ElementsMatch(t, []uint64{0, 1, 2}, res)
    41  	})
    42  
    43  	t.Run("searching within cluster 2", func(t *testing.T) {
    44  		position := 3
    45  		res, _, err := index.knnSearchByVector(testVectors[position], 3, 36, nil)
    46  		require.Nil(t, err)
    47  		assert.ElementsMatch(t, []uint64{3, 4, 5}, res)
    48  	})
    49  
    50  	t.Run("searching within cluster 3", func(t *testing.T) {
    51  		position := 6
    52  		res, _, err := index.knnSearchByVector(testVectors[position], 3, 36, nil)
    53  		require.Nil(t, err)
    54  		assert.ElementsMatch(t, []uint64{6, 7, 8}, res)
    55  	})
    56  
    57  	t.Run("searching within cluster 2 with a scope larger than the cluster", func(t *testing.T) {
    58  		position := 3
    59  		res, _, err := index.knnSearchByVector(testVectors[position], 50, 36, nil)
    60  		require.Nil(t, err)
    61  		assert.Equal(t, []uint64{
    62  			3, 5, 4, // cluster 2
    63  			7, 8, 6, // cluster 3
    64  			2, 1, 0, // cluster 1
    65  		}, res)
    66  	})
    67  
    68  	t.Run("searching with negative value of k", func(t *testing.T) {
    69  		position := 0
    70  		_, _, err := index.knnSearchByVector(testVectors[position], -1, 36, nil)
    71  		require.Error(t, err)
    72  	})
    73  }
    74  
    75  func TestHnswIndexGrow(t *testing.T) {
    76  	vector := []float32{0.1, 0.2}
    77  	vecForIDFn := func(ctx context.Context, id uint64) ([]float32, error) {
    78  		return vector, nil
    79  	}
    80  	index := createEmptyHnswIndexForTests(t, vecForIDFn)
    81  
    82  	t.Run("should grow initial empty index", func(t *testing.T) {
    83  		// when we invoke Add method suggesting a size bigger then the default
    84  		// initial size, then if we don't grow an index at initial state
    85  		// we get: panic: runtime error: index out of range [25001] with length 25000
    86  		// in order to avoid this, insertInitialElement method is now able
    87  		// to grow it's size at initial state
    88  		err := index.Add(uint64(cache.InitialSize+1), vector)
    89  		require.Nil(t, err)
    90  	})
    91  
    92  	t.Run("should grow index without panic", func(t *testing.T) {
    93  		// This test shows that we had an edge case that was not covered
    94  		// in growIndexToAccomodateNode method which was leading to panic:
    95  		// panic: runtime error: index out of range [170001] with length 170001
    96  		vector := []float32{0.11, 0.22}
    97  		id := uint64(5*cache.InitialSize + 1)
    98  		err := index.Add(id, vector)
    99  		require.Nil(t, err)
   100  		// index should grow to 5001
   101  		assert.Equal(t, int(id)+cache.MinimumIndexGrowthDelta, len(index.nodes))
   102  		assert.Equal(t, int32(id+2*cache.MinimumIndexGrowthDelta), index.cache.Len())
   103  		// try to add a vector with id: 8001
   104  		id = uint64(6*cache.InitialSize + cache.MinimumIndexGrowthDelta + 1)
   105  		err = index.Add(id, vector)
   106  		require.Nil(t, err)
   107  		// index should grow to at least 8001
   108  		assert.GreaterOrEqual(t, len(index.nodes), 8001)
   109  		assert.GreaterOrEqual(t, index.cache.Len(), int32(8001))
   110  	})
   111  
   112  	t.Run("should grow index", func(t *testing.T) {
   113  		// should not increase the nodes size
   114  		sizeBefore := len(index.nodes)
   115  		cacheBefore := index.cache.Len()
   116  		idDontGrowIndex := uint64(6*cache.InitialSize - 1)
   117  		err := index.Add(idDontGrowIndex, vector)
   118  		require.Nil(t, err)
   119  		assert.Equal(t, sizeBefore, len(index.nodes))
   120  		assert.Equal(t, cacheBefore, index.cache.Len())
   121  		// should increase nodes
   122  		id := uint64(8*cache.InitialSize + 1)
   123  		err = index.Add(id, vector)
   124  		require.Nil(t, err)
   125  		assert.GreaterOrEqual(t, len(index.nodes), int(id))
   126  		assert.GreaterOrEqual(t, index.cache.Len(), int32(id))
   127  		// should increase nodes when a much greater id is passed
   128  		id = uint64(20*cache.InitialSize + 22)
   129  		err = index.Add(id, vector)
   130  		require.Nil(t, err)
   131  		assert.Equal(t, int(id)+cache.MinimumIndexGrowthDelta, len(index.nodes))
   132  		assert.Equal(t, int32(id+2*cache.MinimumIndexGrowthDelta), index.cache.Len())
   133  	})
   134  }
   135  
   136  func createEmptyHnswIndexForTests(t testing.TB, vecForIDFn common.VectorForID[float32]) *hnsw {
   137  	// mock out commit logger before adding data so we don't leave a disk
   138  	// footprint. Commit logging and deserializing from a (condensed) commit log
   139  	// is tested in a separate integration test that takes care of providing and
   140  	// cleaning up the correct place on disk to write test files
   141  	index, err := New(Config{
   142  		RootPath:              "doesnt-matter-as-committlogger-is-mocked-out",
   143  		ID:                    "unittest",
   144  		MakeCommitLoggerThunk: MakeNoopCommitLogger,
   145  		DistanceProvider:      distancer.NewCosineDistanceProvider(),
   146  		VectorForIDThunk:      vecForIDFn,
   147  	}, ent.UserConfig{
   148  		MaxConnections: 30,
   149  		EFConstruction: 60,
   150  	}, cyclemanager.NewCallbackGroupNoop(), cyclemanager.NewCallbackGroupNoop(),
   151  		cyclemanager.NewCallbackGroupNoop(), testinghelpers.NewDummyStore(t))
   152  	require.Nil(t, err)
   153  	return index
   154  }