github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/hnsw/compress_recall_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 //go:build !race 13 14 package hnsw_test 15 16 import ( 17 "context" 18 "fmt" 19 "os" 20 "sync" 21 "testing" 22 "time" 23 24 "github.com/stretchr/testify/assert" 25 "github.com/weaviate/weaviate/adapters/repos/db/vector/common" 26 "github.com/weaviate/weaviate/adapters/repos/db/vector/compressionhelpers" 27 "github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw" 28 "github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw/distancer" 29 "github.com/weaviate/weaviate/adapters/repos/db/vector/testinghelpers" 30 "github.com/weaviate/weaviate/entities/cyclemanager" 31 "github.com/weaviate/weaviate/entities/storobj" 32 ent "github.com/weaviate/weaviate/entities/vectorindex/hnsw" 33 ) 34 35 func distanceWrapper(provider distancer.Provider) func(x, y []float32) float32 { 36 return func(x, y []float32) float32 { 37 dist, _, _ := provider.SingleDist(x, y) 38 return dist 39 } 40 } 41 42 func Test_NoRaceCompressionRecall(t *testing.T) { 43 path := t.TempDir() 44 45 efConstruction := 64 46 ef := 64 47 maxNeighbors := 32 48 segments := 4 49 dimensions := 64 50 vectors_size := 10000 51 queries_size := 100 52 fmt.Println("Sift1M PQ") 53 before := time.Now() 54 vectors, queries := testinghelpers.RandomVecs(vectors_size, queries_size, dimensions) 55 testinghelpers.Normalize(vectors) 56 testinghelpers.Normalize(queries) 57 k := 100 58 59 distancers := []distancer.Provider{ 60 distancer.NewL2SquaredProvider(), 61 distancer.NewCosineDistanceProvider(), 62 distancer.NewDotProductProvider(), 63 } 64 65 for _, distancer := range distancers { 66 truths := make([][]uint64, queries_size) 67 compressionhelpers.Concurrently(uint64(len(queries)), func(i uint64) { 68 truths[i], _ = testinghelpers.BruteForce(vectors, queries[i], k, distanceWrapper(distancer)) 69 }) 70 fmt.Printf("generating data took %s\n", time.Since(before)) 71 72 uc := ent.UserConfig{ 73 MaxConnections: maxNeighbors, 74 EFConstruction: efConstruction, 75 EF: ef, 76 VectorCacheMaxObjects: 10e12, 77 } 78 index, _ := hnsw.New(hnsw.Config{ 79 RootPath: path, 80 ID: "recallbenchmark", 81 MakeCommitLoggerThunk: hnsw.MakeNoopCommitLogger, 82 ClassName: "clasRecallBenchmark", 83 ShardName: "shardRecallBenchmark", 84 DistanceProvider: distancer, 85 VectorForIDThunk: func(ctx context.Context, id uint64) ([]float32, error) { 86 if int(id) >= len(vectors) { 87 return nil, storobj.NewErrNotFoundf(id, "out of range") 88 } 89 return vectors[int(id)], nil 90 }, 91 TempVectorForIDThunk: func(ctx context.Context, id uint64, container *common.VectorSlice) ([]float32, error) { 92 copy(container.Slice, vectors[int(id)]) 93 return container.Slice, nil 94 }, 95 }, uc, cyclemanager.NewCallbackGroupNoop(), cyclemanager.NewCallbackGroupNoop(), 96 cyclemanager.NewCallbackGroupNoop(), testinghelpers.NewDummyStore(t)) 97 init := time.Now() 98 compressionhelpers.Concurrently(uint64(vectors_size), func(id uint64) { 99 index.Add(id, vectors[id]) 100 }) 101 before = time.Now() 102 fmt.Println("Start compressing...") 103 uc.PQ = ent.PQConfig{ 104 Enabled: true, 105 Segments: dimensions / segments, 106 Centroids: 256, 107 Encoder: ent.NewDefaultUserConfig().PQ.Encoder, 108 } 109 uc.EF = 256 110 wg := sync.WaitGroup{} 111 wg.Add(1) 112 index.UpdateUserConfig(uc, func() { 113 fmt.Printf("Time to compress: %s\n", time.Since(before)) 114 fmt.Printf("Building the index took %s\n", time.Since(init)) 115 116 var relevant uint64 117 var retrieved int 118 119 var querying time.Duration = 0 120 compressionhelpers.Concurrently(uint64(len(queries)), func(i uint64) { 121 before = time.Now() 122 results, _, _ := index.SearchByVector(queries[i], k, nil) 123 querying += time.Since(before) 124 retrieved += k 125 relevant += testinghelpers.MatchesInLists(truths[i], results) 126 }) 127 128 recall := float32(relevant) / float32(retrieved) 129 latency := float32(querying.Microseconds()) / float32(queries_size) 130 fmt.Println(recall, latency) 131 assert.True(t, recall > 0.9) 132 133 err := os.RemoveAll(path) 134 if err != nil { 135 fmt.Println(err) 136 } 137 wg.Done() 138 }) 139 wg.Wait() 140 } 141 }