github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/flat/index_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 //go:build !race 13 14 package flat 15 16 import ( 17 "encoding/binary" 18 "errors" 19 "fmt" 20 "os" 21 "strconv" 22 "sync" 23 "testing" 24 "time" 25 26 "github.com/google/uuid" 27 "github.com/sirupsen/logrus" 28 "github.com/sirupsen/logrus/hooks/test" 29 "github.com/stretchr/testify/assert" 30 "github.com/stretchr/testify/require" 31 "github.com/weaviate/weaviate/adapters/repos/db/helpers" 32 "github.com/weaviate/weaviate/adapters/repos/db/lsmkv" 33 34 "github.com/weaviate/weaviate/adapters/repos/db/vector/compressionhelpers" 35 "github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw/distancer" 36 "github.com/weaviate/weaviate/adapters/repos/db/vector/testinghelpers" 37 "github.com/weaviate/weaviate/entities/cyclemanager" 38 flatent "github.com/weaviate/weaviate/entities/vectorindex/flat" 39 ) 40 41 func distanceWrapper(provider distancer.Provider) func(x, y []float32) float32 { 42 return func(x, y []float32) float32 { 43 dist, _, _ := provider.SingleDist(x, y) 44 return dist 45 } 46 } 47 48 func run(dirName string, logger *logrus.Logger, compression string, vectorCache bool, 49 vectors [][]float32, queries [][]float32, k int, truths [][]uint64, 50 extraVectorsForDelete [][]float32, allowIds []uint64, 51 distancer distancer.Provider, 52 ) (float32, float32, error) { 53 vectors_size := len(vectors) 54 queries_size := len(queries) 55 runId := uuid.New().String() 56 57 store, err := lsmkv.New(dirName, dirName, logger, nil, 58 cyclemanager.NewCallbackGroupNoop(), cyclemanager.NewCallbackGroupNoop()) 59 if err != nil { 60 return 0, 0, err 61 } 62 63 pq := flatent.CompressionUserConfig{ 64 Enabled: false, 65 } 66 bq := flatent.CompressionUserConfig{ 67 Enabled: false, 68 } 69 switch compression { 70 case compressionPQ: 71 pq.Enabled = true 72 pq.RescoreLimit = 100 * k 73 pq.Cache = vectorCache 74 case compressionBQ: 75 bq.Enabled = true 76 bq.RescoreLimit = 100 * k 77 bq.Cache = vectorCache 78 } 79 index, err := New(Config{ 80 ID: runId, 81 DistanceProvider: distancer, 82 }, flatent.UserConfig{ 83 PQ: pq, 84 BQ: bq, 85 }, store) 86 if err != nil { 87 return 0, 0, err 88 } 89 90 compressionhelpers.Concurrently(uint64(vectors_size), func(id uint64) { 91 index.Add(id, vectors[id]) 92 }) 93 94 for i := range extraVectorsForDelete { 95 index.Add(uint64(vectors_size+i), extraVectorsForDelete[i]) 96 } 97 98 for i := range extraVectorsForDelete { 99 Id := make([]byte, 16) 100 binary.BigEndian.PutUint64(Id[8:], uint64(vectors_size+i)) 101 err := index.Delete(uint64(vectors_size + i)) 102 if err != nil { 103 return 0, 0, err 104 } 105 } 106 107 var relevant uint64 108 var retrieved int 109 var querying time.Duration = 0 110 mutex := new(sync.Mutex) 111 112 var allowList helpers.AllowList = nil 113 if allowIds != nil { 114 allowList = helpers.NewAllowList(allowIds...) 115 } 116 err = nil 117 compressionhelpers.Concurrently(uint64(len(queries)), func(i uint64) { 118 before := time.Now() 119 results, _, _ := index.SearchByVector(queries[i], k, allowList) 120 121 since := time.Since(before) 122 len := len(results) 123 matches := testinghelpers.MatchesInLists(truths[i], results) 124 125 if hasDuplicates(results) { 126 err = errors.New("results have duplicates") 127 } 128 129 mutex.Lock() 130 querying += since 131 retrieved += len 132 relevant += matches 133 mutex.Unlock() 134 }) 135 136 return float32(relevant) / float32(retrieved), float32(querying.Microseconds()) / float32(queries_size), err 137 } 138 139 func hasDuplicates(results []uint64) bool { 140 for i := 0; i < len(results)-1; i++ { 141 for j := i + 1; j < len(results); j++ { 142 if results[i] == results[j] { 143 return true 144 } 145 } 146 } 147 return false 148 } 149 150 func Test_NoRaceFlatIndex(t *testing.T) { 151 dirName := t.TempDir() 152 153 logger, _ := test.NewNullLogger() 154 155 dimensions := 256 156 vectors_size := 12000 157 queries_size := 100 158 k := 10 159 vectors, queries := testinghelpers.RandomVecs(vectors_size, queries_size, dimensions) 160 testinghelpers.Normalize(vectors) 161 testinghelpers.Normalize(queries) 162 distancer := distancer.NewCosineDistanceProvider() 163 164 truths := make([][]uint64, queries_size) 165 for i := range queries { 166 truths[i], _ = testinghelpers.BruteForce(vectors, queries[i], k, distanceWrapper(distancer)) 167 } 168 169 extraVectorsForDelete, _ := testinghelpers.RandomVecs(5_000, 0, dimensions) 170 for _, compression := range []string{compressionNone, compressionBQ} { 171 t.Run("compression: "+compression, func(t *testing.T) { 172 for _, cache := range []bool{false, true} { 173 t.Run("cache: "+strconv.FormatBool(cache), func(t *testing.T) { 174 if compression == compressionNone && cache == true { 175 return 176 } 177 targetRecall := float32(0.99) 178 if compression == compressionBQ { 179 targetRecall = 0.8 180 } 181 t.Run("recall", func(t *testing.T) { 182 recall, latency, err := run(dirName, logger, compression, cache, vectors, queries, k, truths, nil, nil, distancer) 183 require.Nil(t, err) 184 185 fmt.Println(recall, latency) 186 assert.Greater(t, recall, targetRecall) 187 assert.Less(t, latency, float32(1_000_000)) 188 }) 189 190 t.Run("recall with deletes", func(t *testing.T) { 191 recall, latency, err := run(dirName, logger, compression, cache, vectors, queries, k, truths, extraVectorsForDelete, nil, distancer) 192 require.Nil(t, err) 193 194 fmt.Println(recall, latency) 195 assert.Greater(t, recall, targetRecall) 196 assert.Less(t, latency, float32(1_000_000)) 197 }) 198 }) 199 } 200 }) 201 } 202 for _, compression := range []string{compressionNone, compressionBQ} { 203 t.Run("compression: "+compression, func(t *testing.T) { 204 for _, cache := range []bool{false, true} { 205 t.Run("cache: "+strconv.FormatBool(cache), func(t *testing.T) { 206 from := 0 207 to := 3_000 208 for i := range queries { 209 truths[i], _ = testinghelpers.BruteForce(vectors[from:to], queries[i], k, distanceWrapper(distancer)) 210 } 211 212 allowIds := make([]uint64, 0, to-from) 213 for i := uint64(from); i < uint64(to); i++ { 214 allowIds = append(allowIds, i) 215 } 216 targetRecall := float32(0.99) 217 if compression == compressionBQ { 218 targetRecall = 0.8 219 } 220 221 t.Run("recall on filtered", func(t *testing.T) { 222 recall, latency, err := run(dirName, logger, compression, cache, vectors, queries, k, truths, nil, allowIds, distancer) 223 require.Nil(t, err) 224 225 fmt.Println(recall, latency) 226 assert.Greater(t, recall, targetRecall) 227 assert.Less(t, latency, float32(1_000_000)) 228 }) 229 230 t.Run("recall on filtered with deletes", func(t *testing.T) { 231 recall, latency, err := run(dirName, logger, compression, cache, vectors, queries, k, truths, extraVectorsForDelete, allowIds, distancer) 232 require.Nil(t, err) 233 234 fmt.Println(recall, latency) 235 assert.Greater(t, recall, targetRecall) 236 assert.Less(t, latency, float32(1_000_000)) 237 }) 238 }) 239 } 240 }) 241 } 242 243 err := os.RemoveAll(dirName) 244 if err != nil { 245 fmt.Println(err) 246 } 247 }