github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/compressionhelpers/product_quantization_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 //go:build !race 13 14 package compressionhelpers_test 15 16 import ( 17 "fmt" 18 "sort" 19 "testing" 20 21 "github.com/stretchr/testify/assert" 22 "github.com/weaviate/weaviate/adapters/repos/db/vector/compressionhelpers" 23 "github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw/distancer" 24 "github.com/weaviate/weaviate/adapters/repos/db/vector/testinghelpers" 25 "github.com/weaviate/weaviate/entities/vectorindex/hnsw" 26 ent "github.com/weaviate/weaviate/entities/vectorindex/hnsw" 27 ) 28 29 type IndexAndDistance struct { 30 index uint64 31 distance float32 32 } 33 34 func distance(dp distancer.Provider) func(x, y []float32) float32 { 35 return func(x, y []float32) float32 { 36 dist, _, _ := dp.SingleDist(x, y) 37 return dist 38 } 39 } 40 41 func Test_NoRacePQSettings(t *testing.T) { 42 distanceProvider := distancer.NewL2SquaredProvider() 43 44 cfg := ent.PQConfig{ 45 Enabled: true, 46 Encoder: ent.PQEncoder{ 47 Type: ent.PQEncoderTypeKMeans, 48 Distribution: ent.PQEncoderDistributionLogNormal, 49 }, 50 Centroids: 512, 51 Segments: 128, 52 } 53 54 _, err := compressionhelpers.NewProductQuantizer( 55 cfg, 56 distanceProvider, 57 128, 58 ) 59 assert.NotNil(t, err) 60 } 61 62 func Test_NoRacePQKMeans(t *testing.T) { 63 dimensions := 128 64 vectors_size := 1000 65 queries_size := 100 66 k := 100 67 vectors, queries := testinghelpers.RandomVecs(vectors_size, queries_size, int(dimensions)) 68 distanceProvider := distancer.NewDotProductProvider() 69 70 cfg := ent.PQConfig{ 71 Enabled: true, 72 Encoder: ent.PQEncoder{ 73 Type: ent.PQEncoderTypeKMeans, 74 Distribution: ent.PQEncoderDistributionLogNormal, 75 }, 76 Centroids: 255, 77 Segments: dimensions, 78 } 79 pq, _ := compressionhelpers.NewProductQuantizer( 80 cfg, 81 distanceProvider, 82 dimensions, 83 ) 84 pq.Fit(vectors) 85 encoded := make([][]byte, vectors_size) 86 for i := 0; i < vectors_size; i++ { 87 encoded[i] = pq.Encode(vectors[i]) 88 } 89 90 var relevant uint64 91 queries_size = 100 92 for _, query := range queries { 93 truth, _ := testinghelpers.BruteForce(vectors, query, k, distance(distanceProvider)) 94 distances := make([]IndexAndDistance, len(vectors)) 95 96 distancer := pq.NewDistancer(query) 97 for v := range vectors { 98 d, _, _ := distancer.Distance(encoded[v]) 99 distances[v] = IndexAndDistance{index: uint64(v), distance: d} 100 } 101 sort.Slice(distances, func(a, b int) bool { 102 return distances[a].distance < distances[b].distance 103 }) 104 105 results := make([]uint64, 0, k) 106 for i := 0; i < k; i++ { 107 results = append(results, distances[i].index) 108 } 109 relevant += testinghelpers.MatchesInLists(truth, results) 110 } 111 recall := float32(relevant) / float32(k*queries_size) 112 fmt.Println(recall) 113 assert.True(t, recall > 0.99) 114 } 115 116 func Test_NoRacePQDecodeBytes(t *testing.T) { 117 t.Run("extracts correctly on one code per byte", func(t *testing.T) { 118 amount := 100 119 values := make([]byte, 0, amount) 120 for i := byte(0); i < byte(amount); i++ { 121 values = append(values, i) 122 } 123 for i := 0; i < amount; i++ { 124 code := compressionhelpers.ExtractCode8(values, i) 125 assert.Equal(t, code, uint8(i)) 126 } 127 }) 128 } 129 130 func Test_NoRacePQInvalidConfig(t *testing.T) { 131 t.Run("validate pq options", func(t *testing.T) { 132 amount := 100 133 centroids := 256 134 cfg := ent.PQConfig{ 135 Enabled: true, 136 Encoder: ent.PQEncoder{ 137 Type: "lmeans", 138 Distribution: ent.PQEncoderDistributionLogNormal, 139 }, 140 Centroids: centroids, 141 TrainingLimit: 75, 142 Segments: amount, 143 } 144 _, err := compressionhelpers.NewProductQuantizer( 145 cfg, 146 nil, 147 amount, 148 ) 149 assert.ErrorContains(t, err, "invalid encoder type") 150 cfg = ent.PQConfig{ 151 Enabled: true, 152 Encoder: ent.PQEncoder{ 153 Type: ent.DefaultPQEncoderType, 154 Distribution: "log", 155 }, 156 Centroids: centroids, 157 TrainingLimit: 75, 158 Segments: amount, 159 } 160 _, err = compressionhelpers.NewProductQuantizer( 161 cfg, 162 nil, 163 amount, 164 ) 165 assert.ErrorContains(t, err, "invalid encoder distribution") 166 cfg = ent.PQConfig{ 167 Enabled: true, 168 Encoder: ent.PQEncoder{ 169 Type: ent.DefaultPQEncoderType, 170 Distribution: ent.DefaultPQEncoderDistribution, 171 }, 172 Centroids: centroids, 173 TrainingLimit: 75, 174 Segments: 0, 175 } 176 _, err = compressionhelpers.NewProductQuantizer( 177 cfg, 178 nil, 179 amount, 180 ) 181 assert.ErrorContains(t, err, "segments cannot be 0 nor negative") 182 cfg = ent.PQConfig{ 183 Enabled: true, 184 Encoder: ent.PQEncoder{ 185 Type: ent.DefaultPQEncoderType, 186 Distribution: ent.DefaultPQEncoderDistribution, 187 }, 188 Centroids: centroids, 189 TrainingLimit: 75, 190 Segments: 3, 191 } 192 _, err = compressionhelpers.NewProductQuantizer( 193 cfg, 194 nil, 195 4, 196 ) 197 assert.ErrorContains(t, err, "segments should be an integer divisor of dimensions") 198 }) 199 t.Run("validate training limit applied", func(t *testing.T) { 200 amount := 64 201 centroids := 256 202 vectors_size := 400 203 vectors, _ := testinghelpers.RandomVecs(vectors_size, vectors_size, amount) 204 distanceProvider := distancer.NewL2SquaredProvider() 205 206 cfg := ent.PQConfig{ 207 Enabled: true, 208 Encoder: ent.PQEncoder{ 209 Type: hnsw.PQEncoderTypeKMeans, 210 Distribution: ent.PQEncoderDistributionLogNormal, 211 }, 212 Centroids: centroids, 213 TrainingLimit: 260, 214 Segments: amount, 215 } 216 pq, err := compressionhelpers.NewProductQuantizer( 217 cfg, 218 distanceProvider, 219 amount, 220 ) 221 assert.NoError(t, err) 222 pq.Fit(vectors) 223 pqdata := pq.ExposeFields() 224 assert.Equal(t, pqdata.TrainingLimit, 260) 225 }) 226 } 227 228 func Test_NoRacePQEncodeBytes(t *testing.T) { 229 t.Run("encodes correctly on one code per byte", func(t *testing.T) { 230 amount := 100 231 values := make([]byte, amount) 232 for i := 0; i < amount; i++ { 233 compressionhelpers.PutCode8(uint8(i), values, i) 234 } 235 for i := 0; i < amount; i++ { 236 code := compressionhelpers.ExtractCode8(values, i) 237 assert.Equal(t, code, uint8(i)) 238 } 239 }) 240 }