github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/compressionhelpers/product_quantization.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package compressionhelpers 13 14 import ( 15 "errors" 16 "fmt" 17 "math" 18 "sync" 19 20 "github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw/distancer" 21 ent "github.com/weaviate/weaviate/entities/vectorindex/hnsw" 22 ) 23 24 type Encoder byte 25 26 const ( 27 UseTileEncoder Encoder = 0 28 UseKMeansEncoder Encoder = 1 29 ) 30 31 type DistanceLookUpTable struct { 32 calculated []bool 33 distances []float32 34 center [][]float32 35 segments int 36 centroids int 37 flatCenter []float32 38 } 39 40 func NewDistanceLookUpTable(segments int, centroids int, center []float32) *DistanceLookUpTable { 41 distances := make([]float32, segments*centroids) 42 calculated := make([]bool, segments*centroids) 43 parsedCenter := make([][]float32, segments) 44 ds := len(center) / segments 45 for c := 0; c < segments; c++ { 46 parsedCenter[c] = center[c*ds : (c+1)*ds] 47 } 48 49 dlt := &DistanceLookUpTable{ 50 distances: distances, 51 calculated: calculated, 52 center: parsedCenter, 53 segments: segments, 54 centroids: centroids, 55 flatCenter: center, 56 } 57 return dlt 58 } 59 60 func (lut *DistanceLookUpTable) Reset(segments int, centroids int, center []float32) { 61 elems := segments * centroids 62 lut.segments = segments 63 lut.centroids = centroids 64 if len(lut.distances) != elems || 65 len(lut.calculated) != elems || 66 len(lut.center) != segments { 67 lut.distances = make([]float32, segments*centroids) 68 lut.calculated = make([]bool, segments*centroids) 69 lut.center = make([][]float32, segments) 70 } else { 71 for i := range lut.calculated { 72 lut.calculated[i] = false 73 } 74 } 75 76 ds := len(center) / segments 77 for c := 0; c < segments; c++ { 78 lut.center[c] = center[c*ds : (c+1)*ds] 79 } 80 lut.flatCenter = center 81 } 82 83 func (lut *DistanceLookUpTable) LookUp( 84 encoded []byte, 85 pq *ProductQuantizer, 86 ) float32 { 87 var sum float32 88 89 for i := range pq.kms { 90 c := ExtractCode8(encoded, i) 91 if lut.distCalculated(i, c) { 92 sum += lut.codeDist(i, c) 93 } else { 94 centroid := pq.kms[i].Centroid(c) 95 dist := pq.distance.Step(lut.center[i], centroid) 96 lut.setCodeDist(i, c, dist) 97 lut.setDistCalculated(i, c) 98 sum += dist 99 } 100 } 101 return pq.distance.Wrap(sum) 102 } 103 104 // meant for better readability, rely on the fact that the compiler will inline this 105 func (lut *DistanceLookUpTable) posForSegmentAndCode(segment int, code byte) int { 106 return segment*lut.centroids + int(code) 107 } 108 109 // meant for better readability, rely on the fact that the compiler will inline this 110 func (lut *DistanceLookUpTable) distCalculated(segment int, code byte) bool { 111 return lut.calculated[lut.posForSegmentAndCode(segment, code)] 112 } 113 114 // meant for better readability, rely on the fact that the compiler will inline this 115 func (lut *DistanceLookUpTable) setDistCalculated(segment int, code byte) { 116 lut.calculated[lut.posForSegmentAndCode(segment, code)] = true 117 } 118 119 // meant for better readability, rely on the fact that the compiler will inline this 120 func (lut *DistanceLookUpTable) codeDist(segment int, code byte) float32 { 121 return lut.distances[lut.posForSegmentAndCode(segment, code)] 122 } 123 124 // meant for better readability, rely on the fact that the compiler will inline this 125 func (lut *DistanceLookUpTable) setCodeDist(segment int, code byte, dist float32) { 126 lut.distances[lut.posForSegmentAndCode(segment, code)] = dist 127 } 128 129 type DLUTPool struct { 130 pool sync.Pool 131 } 132 133 func NewDLUTPool() *DLUTPool { 134 return &DLUTPool{ 135 pool: sync.Pool{ 136 New: func() any { 137 return &DistanceLookUpTable{} 138 }, 139 }, 140 } 141 } 142 143 func (p *DLUTPool) Get(segments, centroids int, centers []float32) *DistanceLookUpTable { 144 dlt := p.pool.Get().(*DistanceLookUpTable) 145 dlt.Reset(segments, centroids, centers) 146 return dlt 147 } 148 149 func (p *DLUTPool) Return(dlt *DistanceLookUpTable) { 150 p.pool.Put(dlt) 151 } 152 153 type ProductQuantizer struct { 154 ks int // centroids 155 m int // segments 156 ds int // dimensions per segment 157 distance distancer.Provider 158 dimensions int 159 kms []PQEncoder 160 encoderType Encoder 161 encoderDistribution EncoderDistribution 162 dlutPool *DLUTPool 163 trainingLimit int 164 globalDistances []float32 165 } 166 167 type PQData struct { 168 Ks uint16 169 M uint16 170 Dimensions uint16 171 EncoderType Encoder 172 EncoderDistribution byte 173 Encoders []PQEncoder 174 UseBitsEncoding bool 175 TrainingLimit int 176 } 177 178 type PQEncoder interface { 179 Encode(x []float32) byte 180 Centroid(b byte) []float32 181 Add(x []float32) 182 Fit(data [][]float32) error 183 ExposeDataForRestore() []byte 184 } 185 186 func NewProductQuantizer(cfg ent.PQConfig, distance distancer.Provider, dimensions int) (*ProductQuantizer, error) { 187 if cfg.Segments <= 0 { 188 return nil, errors.New("segments cannot be 0 nor negative") 189 } 190 if cfg.Centroids > 256 { 191 return nil, fmt.Errorf("centroids should not be higher than 256. Attempting to use %d", cfg.Centroids) 192 } 193 if dimensions%cfg.Segments != 0 { 194 return nil, errors.New("segments should be an integer divisor of dimensions") 195 } 196 encoderType, err := parseEncoder(cfg.Encoder.Type) 197 if err != nil { 198 return nil, errors.New("invalid encoder type") 199 } 200 201 encoderDistribution, err := parseEncoderDistribution(cfg.Encoder.Distribution) 202 if err != nil { 203 return nil, errors.New("invalid encoder distribution") 204 } 205 pq := &ProductQuantizer{ 206 ks: cfg.Centroids, 207 m: cfg.Segments, 208 ds: int(dimensions / cfg.Segments), 209 distance: distance, 210 trainingLimit: cfg.TrainingLimit, 211 dimensions: dimensions, 212 encoderType: encoderType, 213 encoderDistribution: encoderDistribution, 214 dlutPool: NewDLUTPool(), 215 } 216 217 return pq, nil 218 } 219 220 func NewProductQuantizerWithEncoders(cfg ent.PQConfig, distance distancer.Provider, dimensions int, encoders []PQEncoder) (*ProductQuantizer, error) { 221 cfg.Segments = len(encoders) 222 pq, err := NewProductQuantizer(cfg, distance, dimensions) 223 if err != nil { 224 return nil, err 225 } 226 227 pq.kms = encoders 228 pq.buildGlobalDistances() 229 return pq, nil 230 } 231 232 func (pq *ProductQuantizer) buildGlobalDistances() { 233 // This hosts the partial distances between the centroids. This way we do not need 234 // to recalculate all the time when calculating full distances between compressed vecs 235 pq.globalDistances = make([]float32, pq.m*pq.ks*pq.ks) 236 for segment := 0; segment < pq.m; segment++ { 237 for i := 0; i < pq.ks; i++ { 238 cX := pq.kms[segment].Centroid(byte(i)) 239 for j := 0; j <= i; j++ { 240 cY := pq.kms[segment].Centroid(byte(j)) 241 pq.globalDistances[segment*pq.ks*pq.ks+i*pq.ks+j] = pq.distance.Step(cX, cY) 242 // Just copy from already calculated cell since step should be symmetric. 243 pq.globalDistances[segment*pq.ks*pq.ks+j*pq.ks+i] = pq.globalDistances[segment*pq.ks*pq.ks+i*pq.ks+j] 244 } 245 } 246 } 247 } 248 249 // Only made public for testing purposes... Not sure we need it outside 250 func ExtractCode8(encoded []byte, index int) byte { 251 return encoded[index] 252 } 253 254 func parseEncoder(encoder string) (Encoder, error) { 255 switch encoder { 256 case ent.PQEncoderTypeTile: 257 return UseTileEncoder, nil 258 case ent.PQEncoderTypeKMeans: 259 return UseKMeansEncoder, nil 260 default: 261 return 0, fmt.Errorf("invalid encoder type: %s", encoder) 262 } 263 } 264 265 func parseEncoderDistribution(distribution string) (EncoderDistribution, error) { 266 switch distribution { 267 case ent.PQEncoderDistributionLogNormal: 268 return LogNormalEncoderDistribution, nil 269 case ent.PQEncoderDistributionNormal: 270 return NormalEncoderDistribution, nil 271 default: 272 return 0, fmt.Errorf("invalid encoder distribution: %s", distribution) 273 } 274 } 275 276 // Only made public for testing purposes... Not sure we need it outside 277 func PutCode8(code byte, buffer []byte, index int) { 278 buffer[index] = code 279 } 280 281 func (pq *ProductQuantizer) ExposeFields() PQData { 282 return PQData{ 283 Dimensions: uint16(pq.dimensions), 284 EncoderType: pq.encoderType, 285 Ks: uint16(pq.ks), 286 M: uint16(pq.m), 287 EncoderDistribution: byte(pq.encoderDistribution), 288 Encoders: pq.kms, 289 TrainingLimit: pq.trainingLimit, 290 } 291 } 292 293 func (pq *ProductQuantizer) DistanceBetweenCompressedVectors(x, y []byte) (float32, error) { 294 if len(x) != pq.m || len(y) != pq.m { 295 return 0, fmt.Errorf("inconsistent compressed vectors lengths") 296 } 297 298 dist := float32(0) 299 300 for i := 0; i < pq.m; i++ { 301 cX := ExtractCode8(x, i) 302 cY := ExtractCode8(y, i) 303 dist += pq.globalDistances[i*pq.ks*pq.ks+int(cX)*pq.ks+int(cY)] 304 } 305 306 return pq.distance.Wrap(dist), nil 307 } 308 309 func (pq *ProductQuantizer) DistanceBetweenCompressedAndUncompressedVectors(x []float32, encoded []byte) (float32, error) { 310 dist := float32(0) 311 for i := 0; i < pq.m; i++ { 312 cY := pq.kms[i].Centroid(ExtractCode8(encoded, i)) 313 dist += pq.distance.Step(x[i*pq.ds:(i+1)*pq.ds], cY) 314 } 315 return pq.distance.Wrap(dist), nil 316 } 317 318 type PQDistancer struct { 319 x []float32 320 pq *ProductQuantizer 321 lut *DistanceLookUpTable 322 compressed []byte 323 } 324 325 func (pq *ProductQuantizer) NewDistancer(a []float32) *PQDistancer { 326 lut := pq.CenterAt(a) 327 return &PQDistancer{ 328 x: a, 329 pq: pq, 330 lut: lut, 331 compressed: nil, 332 } 333 } 334 335 func (pq *ProductQuantizer) NewCompressedQuantizerDistancer(a []byte) quantizerDistancer[byte] { 336 return &PQDistancer{ 337 x: nil, 338 pq: pq, 339 lut: nil, 340 compressed: a, 341 } 342 } 343 344 func (pq *ProductQuantizer) ReturnDistancer(d *PQDistancer) { 345 pq.dlutPool.Return(d.lut) 346 } 347 348 func (d *PQDistancer) Distance(x []byte) (float32, bool, error) { 349 if d.lut == nil { 350 dist, err := d.pq.DistanceBetweenCompressedVectors(d.compressed, x) 351 return dist, err == nil, err 352 } 353 if len(x) != d.pq.m { 354 return 0, false, fmt.Errorf("inconsistent compressed vector length") 355 } 356 return d.pq.Distance(x, d.lut), true, nil 357 } 358 359 func (d *PQDistancer) DistanceToFloat(x []float32) (float32, bool, error) { 360 if d.lut != nil { 361 return d.pq.distance.SingleDist(x, d.lut.flatCenter) 362 } 363 xComp := d.pq.Encode(x) 364 dist, err := d.pq.DistanceBetweenCompressedVectors(d.compressed, xComp) 365 return dist, err == nil, err 366 } 367 368 func (pq *ProductQuantizer) Fit(data [][]float32) error { 369 if pq.trainingLimit > 0 && len(data) > pq.trainingLimit { 370 data = data[:pq.trainingLimit] 371 } 372 switch pq.encoderType { 373 case UseTileEncoder: 374 pq.kms = make([]PQEncoder, pq.m) 375 Concurrently(uint64(pq.m), func(i uint64) { 376 pq.kms[i] = NewTileEncoder(int(math.Log2(float64(pq.ks))), int(i), pq.encoderDistribution) 377 for j := 0; j < len(data); j++ { 378 pq.kms[i].Add(data[j]) 379 } 380 pq.kms[i].Fit(data) 381 }) 382 case UseKMeansEncoder: 383 mutex := sync.Mutex{} 384 var errorResult error = nil 385 pq.kms = make([]PQEncoder, pq.m) 386 Concurrently(uint64(pq.m), func(i uint64) { 387 mutex.Lock() 388 if errorResult != nil { 389 mutex.Unlock() 390 return 391 } 392 mutex.Unlock() 393 pq.kms[i] = NewKMeans( 394 pq.ks, 395 pq.ds, 396 int(i), 397 ) 398 err := pq.kms[i].Fit(data) 399 mutex.Lock() 400 if errorResult == nil && err != nil { 401 errorResult = err 402 } 403 mutex.Unlock() 404 }) 405 if errorResult != nil { 406 return errorResult 407 } 408 } 409 pq.buildGlobalDistances() 410 return nil 411 } 412 413 func (pq *ProductQuantizer) Encode(vec []float32) []byte { 414 codes := make([]byte, pq.m) 415 for i := 0; i < pq.m; i++ { 416 PutCode8(pq.kms[i].Encode(vec), codes, i) 417 } 418 return codes 419 } 420 421 func (pq *ProductQuantizer) Decode(code []byte) []float32 { 422 vec := make([]float32, 0, pq.m) 423 for i := 0; i < pq.m; i++ { 424 vec = append(vec, pq.kms[i].Centroid(ExtractCode8(code, i))...) 425 } 426 return vec 427 } 428 429 func (pq *ProductQuantizer) CenterAt(vec []float32) *DistanceLookUpTable { 430 return pq.dlutPool.Get(int(pq.m), int(pq.ks), vec) 431 } 432 433 func (pq *ProductQuantizer) Distance(encoded []byte, lut *DistanceLookUpTable) float32 { 434 return lut.LookUp(encoded, pq) 435 }