github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/hnsw/recall_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  //go:build benchmarkRecall
    13  // +build benchmarkRecall
    14  
    15  package hnsw
    16  
    17  import (
    18  	"context"
    19  	"encoding/json"
    20  	"fmt"
    21  	"io/ioutil"
    22  	"runtime"
    23  	"sync"
    24  	"testing"
    25  	"time"
    26  
    27  	"github.com/stretchr/testify/assert"
    28  	"github.com/stretchr/testify/require"
    29  	"github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw/distancer"
    30  	"github.com/weaviate/weaviate/adapters/repos/db/vector/testinghelpers"
    31  )
    32  
    33  func TestRecall(t *testing.T) {
    34  	efConstruction := 256
    35  	ef := 256
    36  	maxNeighbors := 64
    37  
    38  	var vectors [][]float32
    39  	var queries [][]float32
    40  	var truths [][]uint64
    41  	var vectorIndex *hnsw
    42  
    43  	t.Run("generate random vectors", func(t *testing.T) {
    44  		vectorsJSON, err := ioutil.ReadFile("recall_vectors.json")
    45  		require.Nil(t, err)
    46  		err = json.Unmarshal(vectorsJSON, &vectors)
    47  		require.Nil(t, err)
    48  
    49  		queriesJSON, err := ioutil.ReadFile("recall_queries.json")
    50  		require.Nil(t, err)
    51  		err = json.Unmarshal(queriesJSON, &queries)
    52  		require.Nil(t, err)
    53  
    54  		truthsJSON, err := ioutil.ReadFile("recall_truths.json")
    55  		require.Nil(t, err)
    56  		err = json.Unmarshal(truthsJSON, &truths)
    57  		require.Nil(t, err)
    58  	})
    59  
    60  	t.Run("importing into hnsw", func(t *testing.T) {
    61  		fmt.Printf("importing into hnsw\n")
    62  
    63  		index, err := New(Config{
    64  			RootPath:              "doesnt-matter-as-committlogger-is-mocked-out",
    65  			ID:                    "recallbenchmark",
    66  			MakeCommitLoggerThunk: MakeNoopCommitLogger,
    67  			DistanceProvider:      distancer.NewCosineDistanceProvider(),
    68  			VectorForIDThunk: func(ctx context.Context, id uint64) ([]float32, error) {
    69  				return vectors[int(id)], nil
    70  			},
    71  		}, UserConfig{
    72  			MaxConnections: maxNeighbors,
    73  			EFConstruction: efConstruction,
    74  			EF:             ef,
    75  		}, testinghelpers.NewDummyStore(t))
    76  		require.Nil(t, err)
    77  		vectorIndex = index
    78  
    79  		workerCount := runtime.GOMAXPROCS(0)
    80  		jobsForWorker := make([][][]float32, workerCount)
    81  
    82  		before := time.Now()
    83  		for i, vec := range vectors {
    84  			workerID := i % workerCount
    85  			jobsForWorker[workerID] = append(jobsForWorker[workerID], vec)
    86  		}
    87  
    88  		wg := &sync.WaitGroup{}
    89  		for workerID, jobs := range jobsForWorker {
    90  			wg.Add(1)
    91  			go func(workerID int, myJobs [][]float32) {
    92  				defer wg.Done()
    93  				for i, vec := range myJobs {
    94  					originalIndex := (i * workerCount) + workerID
    95  					err := vectorIndex.Add(uint64(originalIndex), vec)
    96  					require.Nil(t, err)
    97  				}
    98  			}(workerID, jobs)
    99  		}
   100  
   101  		wg.Wait()
   102  		fmt.Printf("importing took %s\n", time.Since(before))
   103  	})
   104  
   105  	t.Run("inspect a query", func(t *testing.T) {
   106  		k := 20
   107  
   108  		hasDuplicates := 0
   109  
   110  		for _, vec := range queries {
   111  			results, _, err := vectorIndex.SearchByVector(vec, k, nil)
   112  			require.Nil(t, err)
   113  			if containsDuplicates(results) {
   114  				hasDuplicates++
   115  				panic("stop")
   116  			}
   117  		}
   118  
   119  		fmt.Printf("%d out of %d searches contained duplicates", hasDuplicates, len(queries))
   120  	})
   121  
   122  	t.Run("with k=10", func(t *testing.T) {
   123  		k := 10
   124  
   125  		var relevant int
   126  		var retrieved int
   127  
   128  		for i := 0; i < len(queries); i++ {
   129  			results, _, err := vectorIndex.SearchByVector(queries[i], k, nil)
   130  			require.Nil(t, err)
   131  
   132  			retrieved += k
   133  			relevant += matchesInLists(truths[i], results)
   134  		}
   135  
   136  		recall := float32(relevant) / float32(retrieved)
   137  		fmt.Printf("recall is %f\n", recall)
   138  		assert.True(t, recall >= 0.99)
   139  	})
   140  }
   141  
   142  func matchesInLists(control []uint64, results []uint64) int {
   143  	desired := map[uint64]struct{}{}
   144  	for _, relevant := range control {
   145  		desired[relevant] = struct{}{}
   146  	}
   147  
   148  	var matches int
   149  	for _, candidate := range results {
   150  		_, ok := desired[candidate]
   151  		if ok {
   152  			matches++
   153  		}
   154  	}
   155  
   156  	return matches
   157  }
   158  
   159  func containsDuplicates(in []uint64) bool {
   160  	seen := map[uint64]struct{}{}
   161  
   162  	for _, value := range in {
   163  		if _, ok := seen[value]; ok {
   164  			return true
   165  		}
   166  		seen[value] = struct{}{}
   167  	}
   168  
   169  	return false
   170  }