github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/hnsw/compress_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package hnsw 13 14 import ( 15 "context" 16 "testing" 17 18 "github.com/stretchr/testify/assert" 19 "github.com/weaviate/weaviate/adapters/repos/db/vector/common" 20 "github.com/weaviate/weaviate/adapters/repos/db/vector/compressionhelpers" 21 "github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw/distancer" 22 "github.com/weaviate/weaviate/adapters/repos/db/vector/testinghelpers" 23 "github.com/weaviate/weaviate/entities/cyclemanager" 24 "github.com/weaviate/weaviate/entities/storobj" 25 ent "github.com/weaviate/weaviate/entities/vectorindex/hnsw" 26 ) 27 28 func TestCompression_CalculateOptimalSegments(t *testing.T) { 29 h := &hnsw{} 30 31 type testCase struct { 32 dimensions int 33 expectedSegments int 34 } 35 36 for _, tc := range []testCase{ 37 { 38 dimensions: 2048, 39 expectedSegments: 256, 40 }, 41 { 42 dimensions: 1536, 43 expectedSegments: 256, 44 }, 45 { 46 dimensions: 768, 47 expectedSegments: 128, 48 }, 49 { 50 dimensions: 512, 51 expectedSegments: 128, 52 }, 53 { 54 dimensions: 256, 55 expectedSegments: 64, 56 }, 57 { 58 dimensions: 125, 59 expectedSegments: 125, 60 }, 61 { 62 dimensions: 64, 63 expectedSegments: 32, 64 }, 65 { 66 dimensions: 27, 67 expectedSegments: 27, 68 }, 69 { 70 dimensions: 19, 71 expectedSegments: 19, 72 }, 73 { 74 dimensions: 2, 75 expectedSegments: 1, 76 }, 77 } { 78 segments := h.calculateOptimalSegments(tc.dimensions) 79 assert.Equal(t, tc.expectedSegments, segments) 80 } 81 } 82 83 func Test_NoRaceCompressReturnsErrorWhenNotEnoughData(t *testing.T) { 84 efConstruction := 64 85 ef := 32 86 maxNeighbors := 32 87 dimensions := 200 88 vectors_size := 10 89 vectors, _ := testinghelpers.RandomVecs(vectors_size, 0, dimensions) 90 distancer := distancer.NewL2SquaredProvider() 91 92 uc := ent.UserConfig{} 93 uc.MaxConnections = maxNeighbors 94 uc.EFConstruction = efConstruction 95 uc.EF = ef 96 uc.VectorCacheMaxObjects = 10e12 97 uc.PQ = ent.PQConfig{ 98 Enabled: false, 99 Encoder: ent.PQEncoder{ 100 Type: ent.PQEncoderTypeKMeans, 101 Distribution: ent.PQEncoderDistributionLogNormal, 102 }, 103 TrainingLimit: 5, 104 Segments: dimensions, 105 Centroids: 256, 106 } 107 108 index, _ := New(Config{ 109 RootPath: t.TempDir(), 110 ID: "recallbenchmark", 111 MakeCommitLoggerThunk: MakeNoopCommitLogger, 112 DistanceProvider: distancer, 113 VectorForIDThunk: func(ctx context.Context, id uint64) ([]float32, error) { 114 if int(id) >= len(vectors) { 115 return nil, storobj.NewErrNotFoundf(id, "out of range") 116 } 117 return vectors[int(id)], nil 118 }, 119 TempVectorForIDThunk: func(ctx context.Context, id uint64, container *common.VectorSlice) ([]float32, error) { 120 copy(container.Slice, vectors[int(id)]) 121 return container.Slice, nil 122 }, 123 }, uc, cyclemanager.NewCallbackGroupNoop(), cyclemanager.NewCallbackGroupNoop(), 124 cyclemanager.NewCallbackGroupNoop(), testinghelpers.NewDummyStore(t)) 125 defer index.Shutdown(context.Background()) 126 compressionhelpers.Concurrently(uint64(len(vectors)), func(id uint64) { 127 index.Add(uint64(id), vectors[id]) 128 }) 129 130 cfg := ent.PQConfig{ 131 Enabled: true, 132 Encoder: ent.PQEncoder{ 133 Type: ent.PQEncoderTypeKMeans, 134 Distribution: ent.PQEncoderDistributionLogNormal, 135 }, 136 Segments: dimensions, 137 Centroids: 256, 138 } 139 uc.PQ = cfg 140 err := index.compress(uc) 141 assert.NotNil(t, err) 142 }