github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/flat/index_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  //go:build !race
    13  
    14  package flat
    15  
    16  import (
    17  	"encoding/binary"
    18  	"errors"
    19  	"fmt"
    20  	"os"
    21  	"strconv"
    22  	"sync"
    23  	"testing"
    24  	"time"
    25  
    26  	"github.com/google/uuid"
    27  	"github.com/sirupsen/logrus"
    28  	"github.com/sirupsen/logrus/hooks/test"
    29  	"github.com/stretchr/testify/assert"
    30  	"github.com/stretchr/testify/require"
    31  	"github.com/weaviate/weaviate/adapters/repos/db/helpers"
    32  	"github.com/weaviate/weaviate/adapters/repos/db/lsmkv"
    33  
    34  	"github.com/weaviate/weaviate/adapters/repos/db/vector/compressionhelpers"
    35  	"github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw/distancer"
    36  	"github.com/weaviate/weaviate/adapters/repos/db/vector/testinghelpers"
    37  	"github.com/weaviate/weaviate/entities/cyclemanager"
    38  	flatent "github.com/weaviate/weaviate/entities/vectorindex/flat"
    39  )
    40  
    41  func distanceWrapper(provider distancer.Provider) func(x, y []float32) float32 {
    42  	return func(x, y []float32) float32 {
    43  		dist, _, _ := provider.SingleDist(x, y)
    44  		return dist
    45  	}
    46  }
    47  
    48  func run(dirName string, logger *logrus.Logger, compression string, vectorCache bool,
    49  	vectors [][]float32, queries [][]float32, k int, truths [][]uint64,
    50  	extraVectorsForDelete [][]float32, allowIds []uint64,
    51  	distancer distancer.Provider,
    52  ) (float32, float32, error) {
    53  	vectors_size := len(vectors)
    54  	queries_size := len(queries)
    55  	runId := uuid.New().String()
    56  
    57  	store, err := lsmkv.New(dirName, dirName, logger, nil,
    58  		cyclemanager.NewCallbackGroupNoop(), cyclemanager.NewCallbackGroupNoop())
    59  	if err != nil {
    60  		return 0, 0, err
    61  	}
    62  
    63  	pq := flatent.CompressionUserConfig{
    64  		Enabled: false,
    65  	}
    66  	bq := flatent.CompressionUserConfig{
    67  		Enabled: false,
    68  	}
    69  	switch compression {
    70  	case compressionPQ:
    71  		pq.Enabled = true
    72  		pq.RescoreLimit = 100 * k
    73  		pq.Cache = vectorCache
    74  	case compressionBQ:
    75  		bq.Enabled = true
    76  		bq.RescoreLimit = 100 * k
    77  		bq.Cache = vectorCache
    78  	}
    79  	index, err := New(Config{
    80  		ID:               runId,
    81  		DistanceProvider: distancer,
    82  	}, flatent.UserConfig{
    83  		PQ: pq,
    84  		BQ: bq,
    85  	}, store)
    86  	if err != nil {
    87  		return 0, 0, err
    88  	}
    89  
    90  	compressionhelpers.Concurrently(uint64(vectors_size), func(id uint64) {
    91  		index.Add(id, vectors[id])
    92  	})
    93  
    94  	for i := range extraVectorsForDelete {
    95  		index.Add(uint64(vectors_size+i), extraVectorsForDelete[i])
    96  	}
    97  
    98  	for i := range extraVectorsForDelete {
    99  		Id := make([]byte, 16)
   100  		binary.BigEndian.PutUint64(Id[8:], uint64(vectors_size+i))
   101  		err := index.Delete(uint64(vectors_size + i))
   102  		if err != nil {
   103  			return 0, 0, err
   104  		}
   105  	}
   106  
   107  	var relevant uint64
   108  	var retrieved int
   109  	var querying time.Duration = 0
   110  	mutex := new(sync.Mutex)
   111  
   112  	var allowList helpers.AllowList = nil
   113  	if allowIds != nil {
   114  		allowList = helpers.NewAllowList(allowIds...)
   115  	}
   116  	err = nil
   117  	compressionhelpers.Concurrently(uint64(len(queries)), func(i uint64) {
   118  		before := time.Now()
   119  		results, _, _ := index.SearchByVector(queries[i], k, allowList)
   120  
   121  		since := time.Since(before)
   122  		len := len(results)
   123  		matches := testinghelpers.MatchesInLists(truths[i], results)
   124  
   125  		if hasDuplicates(results) {
   126  			err = errors.New("results have duplicates")
   127  		}
   128  
   129  		mutex.Lock()
   130  		querying += since
   131  		retrieved += len
   132  		relevant += matches
   133  		mutex.Unlock()
   134  	})
   135  
   136  	return float32(relevant) / float32(retrieved), float32(querying.Microseconds()) / float32(queries_size), err
   137  }
   138  
   139  func hasDuplicates(results []uint64) bool {
   140  	for i := 0; i < len(results)-1; i++ {
   141  		for j := i + 1; j < len(results); j++ {
   142  			if results[i] == results[j] {
   143  				return true
   144  			}
   145  		}
   146  	}
   147  	return false
   148  }
   149  
// Test_NoRaceFlatIndex runs the flat index end-to-end over random normalized
// vectors with cosine distance, checking recall and average latency for every
// combination of compression (none/BQ), vector cache on/off, with and without
// deletes, and with and without an allow-list filter. Ground truth comes from
// brute-force search. Runs only under the !race build tag (see file header).
func Test_NoRaceFlatIndex(t *testing.T) {
	dirName := t.TempDir()

	logger, _ := test.NewNullLogger()

	dimensions := 256
	vectors_size := 12000
	queries_size := 100
	k := 10
	vectors, queries := testinghelpers.RandomVecs(vectors_size, queries_size, dimensions)
	// Normalize so cosine distance behaves uniformly across the random data.
	testinghelpers.Normalize(vectors)
	testinghelpers.Normalize(queries)
	distancer := distancer.NewCosineDistanceProvider()

	// Exact top-k per query via brute force; this is the recall baseline.
	truths := make([][]uint64, queries_size)
	for i := range queries {
		truths[i], _ = testinghelpers.BruteForce(vectors, queries[i], k, distanceWrapper(distancer))
	}

	extraVectorsForDelete, _ := testinghelpers.RandomVecs(5_000, 0, dimensions)
	// First pass: unfiltered searches (allowIds == nil).
	for _, compression := range []string{compressionNone, compressionBQ} {
		t.Run("compression: "+compression, func(t *testing.T) {
			for _, cache := range []bool{false, true} {
				t.Run("cache: "+strconv.FormatBool(cache), func(t *testing.T) {
					// The cache only applies to compressed vectors, so
					// uncompressed+cache is a meaningless combination — skip.
					if compression == compressionNone && cache == true {
						return
					}
					// BQ is lossy, so its recall bar is lower.
					targetRecall := float32(0.99)
					if compression == compressionBQ {
						targetRecall = 0.8
					}
					t.Run("recall", func(t *testing.T) {
						recall, latency, err := run(dirName, logger, compression, cache, vectors, queries, k, truths, nil, nil, distancer)
						require.Nil(t, err)

						fmt.Println(recall, latency)
						assert.Greater(t, recall, targetRecall)
						// Latency bound is deliberately loose (1s per query in µs).
						assert.Less(t, latency, float32(1_000_000))
					})

					t.Run("recall with deletes", func(t *testing.T) {
						recall, latency, err := run(dirName, logger, compression, cache, vectors, queries, k, truths, extraVectorsForDelete, nil, distancer)
						require.Nil(t, err)

						fmt.Println(recall, latency)
						assert.Greater(t, recall, targetRecall)
						assert.Less(t, latency, float32(1_000_000))
					})
				})
			}
		})
	}
	// Second pass: filtered searches restricted to ids [0, 3000). The shared
	// truths slice is recomputed against that sub-range, so these subtests
	// must run after the unfiltered pass above.
	for _, compression := range []string{compressionNone, compressionBQ} {
		t.Run("compression: "+compression, func(t *testing.T) {
			for _, cache := range []bool{false, true} {
				t.Run("cache: "+strconv.FormatBool(cache), func(t *testing.T) {
					from := 0
					to := 3_000
					for i := range queries {
						truths[i], _ = testinghelpers.BruteForce(vectors[from:to], queries[i], k, distanceWrapper(distancer))
					}

					// Allow list covers exactly the brute-forced sub-range.
					allowIds := make([]uint64, 0, to-from)
					for i := uint64(from); i < uint64(to); i++ {
						allowIds = append(allowIds, i)
					}
					targetRecall := float32(0.99)
					if compression == compressionBQ {
						targetRecall = 0.8
					}

					t.Run("recall on filtered", func(t *testing.T) {
						recall, latency, err := run(dirName, logger, compression, cache, vectors, queries, k, truths, nil, allowIds, distancer)
						require.Nil(t, err)

						fmt.Println(recall, latency)
						assert.Greater(t, recall, targetRecall)
						assert.Less(t, latency, float32(1_000_000))
					})

					t.Run("recall on filtered with deletes", func(t *testing.T) {
						recall, latency, err := run(dirName, logger, compression, cache, vectors, queries, k, truths, extraVectorsForDelete, allowIds, distancer)
						require.Nil(t, err)

						fmt.Println(recall, latency)
						assert.Greater(t, recall, targetRecall)
						assert.Less(t, latency, float32(1_000_000))
					})
				})
			}
		})
	}

	// Best-effort cleanup; t.TempDir would remove it anyway at test end.
	err := os.RemoveAll(dirName)
	if err != nil {
		fmt.Println(err)
	}
}