github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/compressionhelpers/binary_quantization_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package compressionhelpers_test
    13  
    14  import (
    15  	"fmt"
    16  	"sync"
    17  	"testing"
    18  	"time"
    19  
    20  	"github.com/stretchr/testify/assert"
    21  	"github.com/weaviate/weaviate/adapters/repos/db/priorityqueue"
    22  	"github.com/weaviate/weaviate/adapters/repos/db/vector/compressionhelpers"
    23  	"github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw/distancer"
    24  	testinghelpers "github.com/weaviate/weaviate/adapters/repos/db/vector/testinghelpers"
    25  )
    26  
    27  func TestBinaryQuantizerRecall(t *testing.T) {
    28  	k := 10
    29  	distanceProvider := distancer.NewCosineDistanceProvider()
    30  	vectors, queryVecs := testinghelpers.RandomVecs(10_000, 100, 1536)
    31  	compressionhelpers.Concurrently(uint64(len(vectors)), func(i uint64) {
    32  		vectors[i] = distancer.Normalize(vectors[i])
    33  	})
    34  	compressionhelpers.Concurrently(uint64(len(queryVecs)), func(i uint64) {
    35  		queryVecs[i] = distancer.Normalize(queryVecs[i])
    36  	})
    37  	bq := compressionhelpers.NewBinaryQuantizer(nil)
    38  
    39  	codes := make([][]uint64, len(vectors))
    40  	compressionhelpers.Concurrently(uint64(len(vectors)), func(i uint64) {
    41  		codes[i] = bq.Encode(vectors[i])
    42  	})
    43  	neighbors := make([][]uint64, len(queryVecs))
    44  	compressionhelpers.Concurrently(uint64(len(queryVecs)), func(i uint64) {
    45  		neighbors[i], _ = testinghelpers.BruteForce(vectors, queryVecs[i], k, func(f1, f2 []float32) float32 {
    46  			d, _, _ := distanceProvider.SingleDist(f1, f2)
    47  			return d
    48  		})
    49  	})
    50  	correctedK := 200
    51  	hits := uint64(0)
    52  	mutex := sync.Mutex{}
    53  	duration := time.Duration(0)
    54  	compressionhelpers.Concurrently(uint64(len(queryVecs)), func(i uint64) {
    55  		before := time.Now()
    56  		query := bq.Encode(queryVecs[i])
    57  		heap := priorityqueue.NewMax[any](correctedK)
    58  		for j := range codes {
    59  			d, _ := bq.DistanceBetweenCompressedVectors(codes[j], query)
    60  			if heap.Len() < correctedK || heap.Top().Dist > d {
    61  				if heap.Len() == correctedK {
    62  					heap.Pop()
    63  				}
    64  				heap.Insert(uint64(j), d)
    65  			}
    66  		}
    67  		ids := make([]uint64, correctedK)
    68  		for j := range ids {
    69  			ids[j] = heap.Pop().ID
    70  		}
    71  		mutex.Lock()
    72  		duration += time.Since(before)
    73  		hits += testinghelpers.MatchesInLists(neighbors[i][:k], ids)
    74  		mutex.Unlock()
    75  	})
    76  	recall := float32(hits) / float32(k*len(queryVecs))
    77  	latency := float32(duration.Microseconds()) / float32(len(queryVecs))
    78  	fmt.Println(recall, latency)
    79  	assert.True(t, recall > 0.7)
    80  }
    81  
    82  func TestBinaryQuantizerChecksSize(t *testing.T) {
    83  	bq := compressionhelpers.NewBinaryQuantizer(nil)
    84  	_, err := bq.DistanceBetweenCompressedVectors(make([]uint64, 3), make([]uint64, 4))
    85  	assert.NotNil(t, err)
    86  }