github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/compressionhelpers/product_quantization.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package compressionhelpers
    13  
    14  import (
    15  	"errors"
    16  	"fmt"
    17  	"math"
    18  	"sync"
    19  
    20  	"github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw/distancer"
    21  	ent "github.com/weaviate/weaviate/entities/vectorindex/hnsw"
    22  )
    23  
// Encoder identifies which per-segment encoder implementation a
// ProductQuantizer trains in Fit.
type Encoder byte

const (
	// UseTileEncoder makes Fit build the segment encoders with NewTileEncoder.
	UseTileEncoder   Encoder = 0
	// UseKMeansEncoder makes Fit build the segment encoders with NewKMeans.
	UseKMeansEncoder Encoder = 1
)
    30  
// DistanceLookUpTable caches the partial distances between one fixed center
// vector and the centroids of every segment, so that each (segment, code)
// pair is computed at most once per center (see LookUp).
type DistanceLookUpTable struct {
	// calculated[pos] reports whether distances[pos] already holds a value.
	calculated []bool
	// distances is a flattened segments x centroids table, indexed by
	// segment*centroids + code.
	distances  []float32
	// center holds one sub-slice of flatCenter per segment.
	center     [][]float32
	segments   int
	centroids  int
	// flatCenter is the full center vector as passed to the constructor/Reset.
	flatCenter []float32
}
    39  
    40  func NewDistanceLookUpTable(segments int, centroids int, center []float32) *DistanceLookUpTable {
    41  	distances := make([]float32, segments*centroids)
    42  	calculated := make([]bool, segments*centroids)
    43  	parsedCenter := make([][]float32, segments)
    44  	ds := len(center) / segments
    45  	for c := 0; c < segments; c++ {
    46  		parsedCenter[c] = center[c*ds : (c+1)*ds]
    47  	}
    48  
    49  	dlt := &DistanceLookUpTable{
    50  		distances:  distances,
    51  		calculated: calculated,
    52  		center:     parsedCenter,
    53  		segments:   segments,
    54  		centroids:  centroids,
    55  		flatCenter: center,
    56  	}
    57  	return dlt
    58  }
    59  
    60  func (lut *DistanceLookUpTable) Reset(segments int, centroids int, center []float32) {
    61  	elems := segments * centroids
    62  	lut.segments = segments
    63  	lut.centroids = centroids
    64  	if len(lut.distances) != elems ||
    65  		len(lut.calculated) != elems ||
    66  		len(lut.center) != segments {
    67  		lut.distances = make([]float32, segments*centroids)
    68  		lut.calculated = make([]bool, segments*centroids)
    69  		lut.center = make([][]float32, segments)
    70  	} else {
    71  		for i := range lut.calculated {
    72  			lut.calculated[i] = false
    73  		}
    74  	}
    75  
    76  	ds := len(center) / segments
    77  	for c := 0; c < segments; c++ {
    78  		lut.center[c] = center[c*ds : (c+1)*ds]
    79  	}
    80  	lut.flatCenter = center
    81  }
    82  
    83  func (lut *DistanceLookUpTable) LookUp(
    84  	encoded []byte,
    85  	pq *ProductQuantizer,
    86  ) float32 {
    87  	var sum float32
    88  
    89  	for i := range pq.kms {
    90  		c := ExtractCode8(encoded, i)
    91  		if lut.distCalculated(i, c) {
    92  			sum += lut.codeDist(i, c)
    93  		} else {
    94  			centroid := pq.kms[i].Centroid(c)
    95  			dist := pq.distance.Step(lut.center[i], centroid)
    96  			lut.setCodeDist(i, c, dist)
    97  			lut.setDistCalculated(i, c)
    98  			sum += dist
    99  		}
   100  	}
   101  	return pq.distance.Wrap(sum)
   102  }
   103  
   104  // meant for better readability, rely on the fact that the compiler will inline this
   105  func (lut *DistanceLookUpTable) posForSegmentAndCode(segment int, code byte) int {
   106  	return segment*lut.centroids + int(code)
   107  }
   108  
   109  // meant for better readability, rely on the fact that the compiler will inline this
   110  func (lut *DistanceLookUpTable) distCalculated(segment int, code byte) bool {
   111  	return lut.calculated[lut.posForSegmentAndCode(segment, code)]
   112  }
   113  
   114  // meant for better readability, rely on the fact that the compiler will inline this
   115  func (lut *DistanceLookUpTable) setDistCalculated(segment int, code byte) {
   116  	lut.calculated[lut.posForSegmentAndCode(segment, code)] = true
   117  }
   118  
   119  // meant for better readability, rely on the fact that the compiler will inline this
   120  func (lut *DistanceLookUpTable) codeDist(segment int, code byte) float32 {
   121  	return lut.distances[lut.posForSegmentAndCode(segment, code)]
   122  }
   123  
   124  // meant for better readability, rely on the fact that the compiler will inline this
   125  func (lut *DistanceLookUpTable) setCodeDist(segment int, code byte, dist float32) {
   126  	lut.distances[lut.posForSegmentAndCode(segment, code)] = dist
   127  }
   128  
// DLUTPool recycles *DistanceLookUpTable values through a sync.Pool so their
// buffers can be reused between queries (see Get and Return).
type DLUTPool struct {
	pool sync.Pool
}
   132  
   133  func NewDLUTPool() *DLUTPool {
   134  	return &DLUTPool{
   135  		pool: sync.Pool{
   136  			New: func() any {
   137  				return &DistanceLookUpTable{}
   138  			},
   139  		},
   140  	}
   141  }
   142  
   143  func (p *DLUTPool) Get(segments, centroids int, centers []float32) *DistanceLookUpTable {
   144  	dlt := p.pool.Get().(*DistanceLookUpTable)
   145  	dlt.Reset(segments, centroids, centers)
   146  	return dlt
   147  }
   148  
   149  func (p *DLUTPool) Return(dlt *DistanceLookUpTable) {
   150  	p.pool.Put(dlt)
   151  }
   152  
// ProductQuantizer compresses float32 vectors into one byte per segment using
// a per-segment PQEncoder, and computes distances on the compressed form.
type ProductQuantizer struct {
	ks                  int // centroids
	m                   int // segments
	ds                  int // dimensions per segment
	// distance supplies the Step/Wrap/SingleDist primitives used for all
	// distance computations.
	distance            distancer.Provider
	// dimensions is the length of the uncompressed vectors.
	dimensions          int
	// kms holds one encoder per segment; it is set by Fit or by
	// NewProductQuantizerWithEncoders and is nil before that.
	kms                 []PQEncoder
	encoderType         Encoder
	encoderDistribution EncoderDistribution
	// dlutPool recycles the lookup tables handed out by CenterAt.
	dlutPool            *DLUTPool
	// trainingLimit caps how many vectors Fit uses; values <= 0 mean no cap.
	trainingLimit       int
	// globalDistances is a flattened m x ks x ks table of centroid-to-centroid
	// partial distances, built by buildGlobalDistances.
	globalDistances     []float32
}
   166  
// PQData is the exported snapshot of a ProductQuantizer's configuration and
// encoders, as produced by ExposeFields.
// NOTE(review): presumably consumed by persistence/restore code — confirm.
type PQData struct {
	Ks                  uint16
	M                   uint16
	Dimensions          uint16
	EncoderType         Encoder
	EncoderDistribution byte
	Encoders            []PQEncoder
	// UseBitsEncoding is not populated by ExposeFields in this file.
	UseBitsEncoding     bool
	TrainingLimit       int
}
   177  
// PQEncoder is the per-segment encoder used by ProductQuantizer.
type PQEncoder interface {
	// Encode returns the one-byte code for x. ProductQuantizer.Encode passes
	// the full vector to every segment's encoder.
	Encode(x []float32) byte
	// Centroid returns the centroid that code b stands for.
	Centroid(b byte) []float32
	// Add feeds a single training vector to the encoder; Fit calls it for
	// every training vector when using the tile encoder.
	Add(x []float32)
	// Fit trains the encoder on data.
	Fit(data [][]float32) error
	// ExposeDataForRestore returns the encoder state as bytes.
	// NOTE(review): format is implementation-specific — confirm with callers.
	ExposeDataForRestore() []byte
}
   185  
   186  func NewProductQuantizer(cfg ent.PQConfig, distance distancer.Provider, dimensions int) (*ProductQuantizer, error) {
   187  	if cfg.Segments <= 0 {
   188  		return nil, errors.New("segments cannot be 0 nor negative")
   189  	}
   190  	if cfg.Centroids > 256 {
   191  		return nil, fmt.Errorf("centroids should not be higher than 256. Attempting to use %d", cfg.Centroids)
   192  	}
   193  	if dimensions%cfg.Segments != 0 {
   194  		return nil, errors.New("segments should be an integer divisor of dimensions")
   195  	}
   196  	encoderType, err := parseEncoder(cfg.Encoder.Type)
   197  	if err != nil {
   198  		return nil, errors.New("invalid encoder type")
   199  	}
   200  
   201  	encoderDistribution, err := parseEncoderDistribution(cfg.Encoder.Distribution)
   202  	if err != nil {
   203  		return nil, errors.New("invalid encoder distribution")
   204  	}
   205  	pq := &ProductQuantizer{
   206  		ks:                  cfg.Centroids,
   207  		m:                   cfg.Segments,
   208  		ds:                  int(dimensions / cfg.Segments),
   209  		distance:            distance,
   210  		trainingLimit:       cfg.TrainingLimit,
   211  		dimensions:          dimensions,
   212  		encoderType:         encoderType,
   213  		encoderDistribution: encoderDistribution,
   214  		dlutPool:            NewDLUTPool(),
   215  	}
   216  
   217  	return pq, nil
   218  }
   219  
   220  func NewProductQuantizerWithEncoders(cfg ent.PQConfig, distance distancer.Provider, dimensions int, encoders []PQEncoder) (*ProductQuantizer, error) {
   221  	cfg.Segments = len(encoders)
   222  	pq, err := NewProductQuantizer(cfg, distance, dimensions)
   223  	if err != nil {
   224  		return nil, err
   225  	}
   226  
   227  	pq.kms = encoders
   228  	pq.buildGlobalDistances()
   229  	return pq, nil
   230  }
   231  
   232  func (pq *ProductQuantizer) buildGlobalDistances() {
   233  	// This hosts the partial distances between the centroids. This way we do not need
   234  	// to recalculate all the time when calculating full distances between compressed vecs
   235  	pq.globalDistances = make([]float32, pq.m*pq.ks*pq.ks)
   236  	for segment := 0; segment < pq.m; segment++ {
   237  		for i := 0; i < pq.ks; i++ {
   238  			cX := pq.kms[segment].Centroid(byte(i))
   239  			for j := 0; j <= i; j++ {
   240  				cY := pq.kms[segment].Centroid(byte(j))
   241  				pq.globalDistances[segment*pq.ks*pq.ks+i*pq.ks+j] = pq.distance.Step(cX, cY)
   242  				// Just copy from already calculated cell since step should be symmetric.
   243  				pq.globalDistances[segment*pq.ks*pq.ks+j*pq.ks+i] = pq.globalDistances[segment*pq.ks*pq.ks+i*pq.ks+j]
   244  			}
   245  		}
   246  	}
   247  }
   248  
   249  // Only made public for testing purposes... Not sure we need it outside
   250  func ExtractCode8(encoded []byte, index int) byte {
   251  	return encoded[index]
   252  }
   253  
   254  func parseEncoder(encoder string) (Encoder, error) {
   255  	switch encoder {
   256  	case ent.PQEncoderTypeTile:
   257  		return UseTileEncoder, nil
   258  	case ent.PQEncoderTypeKMeans:
   259  		return UseKMeansEncoder, nil
   260  	default:
   261  		return 0, fmt.Errorf("invalid encoder type: %s", encoder)
   262  	}
   263  }
   264  
   265  func parseEncoderDistribution(distribution string) (EncoderDistribution, error) {
   266  	switch distribution {
   267  	case ent.PQEncoderDistributionLogNormal:
   268  		return LogNormalEncoderDistribution, nil
   269  	case ent.PQEncoderDistributionNormal:
   270  		return NormalEncoderDistribution, nil
   271  	default:
   272  		return 0, fmt.Errorf("invalid encoder distribution: %s", distribution)
   273  	}
   274  }
   275  
   276  // Only made public for testing purposes... Not sure we need it outside
   277  func PutCode8(code byte, buffer []byte, index int) {
   278  	buffer[index] = code
   279  }
   280  
   281  func (pq *ProductQuantizer) ExposeFields() PQData {
   282  	return PQData{
   283  		Dimensions:          uint16(pq.dimensions),
   284  		EncoderType:         pq.encoderType,
   285  		Ks:                  uint16(pq.ks),
   286  		M:                   uint16(pq.m),
   287  		EncoderDistribution: byte(pq.encoderDistribution),
   288  		Encoders:            pq.kms,
   289  		TrainingLimit:       pq.trainingLimit,
   290  	}
   291  }
   292  
   293  func (pq *ProductQuantizer) DistanceBetweenCompressedVectors(x, y []byte) (float32, error) {
   294  	if len(x) != pq.m || len(y) != pq.m {
   295  		return 0, fmt.Errorf("inconsistent compressed vectors lengths")
   296  	}
   297  
   298  	dist := float32(0)
   299  
   300  	for i := 0; i < pq.m; i++ {
   301  		cX := ExtractCode8(x, i)
   302  		cY := ExtractCode8(y, i)
   303  		dist += pq.globalDistances[i*pq.ks*pq.ks+int(cX)*pq.ks+int(cY)]
   304  	}
   305  
   306  	return pq.distance.Wrap(dist), nil
   307  }
   308  
   309  func (pq *ProductQuantizer) DistanceBetweenCompressedAndUncompressedVectors(x []float32, encoded []byte) (float32, error) {
   310  	dist := float32(0)
   311  	for i := 0; i < pq.m; i++ {
   312  		cY := pq.kms[i].Centroid(ExtractCode8(encoded, i))
   313  		dist += pq.distance.Step(x[i*pq.ds:(i+1)*pq.ds], cY)
   314  	}
   315  	return pq.distance.Wrap(dist), nil
   316  }
   317  
// PQDistancer measures distances from one fixed query to compressed vectors.
// It works in one of two modes depending on how it was created:
//   - NewDistancer: x and lut are set (lookup-table mode), compressed is nil.
//   - NewCompressedQuantizerDistancer: compressed is set, x and lut are nil.
type PQDistancer struct {
	x          []float32
	pq         *ProductQuantizer
	lut        *DistanceLookUpTable
	compressed []byte
}
   324  
   325  func (pq *ProductQuantizer) NewDistancer(a []float32) *PQDistancer {
   326  	lut := pq.CenterAt(a)
   327  	return &PQDistancer{
   328  		x:          a,
   329  		pq:         pq,
   330  		lut:        lut,
   331  		compressed: nil,
   332  	}
   333  }
   334  
   335  func (pq *ProductQuantizer) NewCompressedQuantizerDistancer(a []byte) quantizerDistancer[byte] {
   336  	return &PQDistancer{
   337  		x:          nil,
   338  		pq:         pq,
   339  		lut:        nil,
   340  		compressed: a,
   341  	}
   342  }
   343  
   344  func (pq *ProductQuantizer) ReturnDistancer(d *PQDistancer) {
   345  	pq.dlutPool.Return(d.lut)
   346  }
   347  
   348  func (d *PQDistancer) Distance(x []byte) (float32, bool, error) {
   349  	if d.lut == nil {
   350  		dist, err := d.pq.DistanceBetweenCompressedVectors(d.compressed, x)
   351  		return dist, err == nil, err
   352  	}
   353  	if len(x) != d.pq.m {
   354  		return 0, false, fmt.Errorf("inconsistent compressed vector length")
   355  	}
   356  	return d.pq.Distance(x, d.lut), true, nil
   357  }
   358  
   359  func (d *PQDistancer) DistanceToFloat(x []float32) (float32, bool, error) {
   360  	if d.lut != nil {
   361  		return d.pq.distance.SingleDist(x, d.lut.flatCenter)
   362  	}
   363  	xComp := d.pq.Encode(x)
   364  	dist, err := d.pq.DistanceBetweenCompressedVectors(d.compressed, xComp)
   365  	return dist, err == nil, err
   366  }
   367  
   368  func (pq *ProductQuantizer) Fit(data [][]float32) error {
   369  	if pq.trainingLimit > 0 && len(data) > pq.trainingLimit {
   370  		data = data[:pq.trainingLimit]
   371  	}
   372  	switch pq.encoderType {
   373  	case UseTileEncoder:
   374  		pq.kms = make([]PQEncoder, pq.m)
   375  		Concurrently(uint64(pq.m), func(i uint64) {
   376  			pq.kms[i] = NewTileEncoder(int(math.Log2(float64(pq.ks))), int(i), pq.encoderDistribution)
   377  			for j := 0; j < len(data); j++ {
   378  				pq.kms[i].Add(data[j])
   379  			}
   380  			pq.kms[i].Fit(data)
   381  		})
   382  	case UseKMeansEncoder:
   383  		mutex := sync.Mutex{}
   384  		var errorResult error = nil
   385  		pq.kms = make([]PQEncoder, pq.m)
   386  		Concurrently(uint64(pq.m), func(i uint64) {
   387  			mutex.Lock()
   388  			if errorResult != nil {
   389  				mutex.Unlock()
   390  				return
   391  			}
   392  			mutex.Unlock()
   393  			pq.kms[i] = NewKMeans(
   394  				pq.ks,
   395  				pq.ds,
   396  				int(i),
   397  			)
   398  			err := pq.kms[i].Fit(data)
   399  			mutex.Lock()
   400  			if errorResult == nil && err != nil {
   401  				errorResult = err
   402  			}
   403  			mutex.Unlock()
   404  		})
   405  		if errorResult != nil {
   406  			return errorResult
   407  		}
   408  	}
   409  	pq.buildGlobalDistances()
   410  	return nil
   411  }
   412  
   413  func (pq *ProductQuantizer) Encode(vec []float32) []byte {
   414  	codes := make([]byte, pq.m)
   415  	for i := 0; i < pq.m; i++ {
   416  		PutCode8(pq.kms[i].Encode(vec), codes, i)
   417  	}
   418  	return codes
   419  }
   420  
   421  func (pq *ProductQuantizer) Decode(code []byte) []float32 {
   422  	vec := make([]float32, 0, pq.m)
   423  	for i := 0; i < pq.m; i++ {
   424  		vec = append(vec, pq.kms[i].Centroid(ExtractCode8(code, i))...)
   425  	}
   426  	return vec
   427  }
   428  
   429  func (pq *ProductQuantizer) CenterAt(vec []float32) *DistanceLookUpTable {
   430  	return pq.dlutPool.Get(int(pq.m), int(pq.ks), vec)
   431  }
   432  
   433  func (pq *ProductQuantizer) Distance(encoded []byte, lut *DistanceLookUpTable) float32 {
   434  	return lut.LookUp(encoded, pq)
   435  }