github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/hnsw/compress_recall_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  //go:build !race
    13  
    14  package hnsw_test
    15  
    16  import (
    17  	"context"
    18  	"fmt"
    19  	"os"
    20  	"sync"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/stretchr/testify/assert"
    25  	"github.com/weaviate/weaviate/adapters/repos/db/vector/common"
    26  	"github.com/weaviate/weaviate/adapters/repos/db/vector/compressionhelpers"
    27  	"github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw"
    28  	"github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw/distancer"
    29  	"github.com/weaviate/weaviate/adapters/repos/db/vector/testinghelpers"
    30  	"github.com/weaviate/weaviate/entities/cyclemanager"
    31  	"github.com/weaviate/weaviate/entities/storobj"
    32  	ent "github.com/weaviate/weaviate/entities/vectorindex/hnsw"
    33  )
    34  
    35  func distanceWrapper(provider distancer.Provider) func(x, y []float32) float32 {
    36  	return func(x, y []float32) float32 {
    37  		dist, _, _ := provider.SingleDist(x, y)
    38  		return dist
    39  	}
    40  }
    41  
    42  func Test_NoRaceCompressionRecall(t *testing.T) {
    43  	path := t.TempDir()
    44  
    45  	efConstruction := 64
    46  	ef := 64
    47  	maxNeighbors := 32
    48  	segments := 4
    49  	dimensions := 64
    50  	vectors_size := 10000
    51  	queries_size := 100
    52  	fmt.Println("Sift1M PQ")
    53  	before := time.Now()
    54  	vectors, queries := testinghelpers.RandomVecs(vectors_size, queries_size, dimensions)
    55  	testinghelpers.Normalize(vectors)
    56  	testinghelpers.Normalize(queries)
    57  	k := 100
    58  
    59  	distancers := []distancer.Provider{
    60  		distancer.NewL2SquaredProvider(),
    61  		distancer.NewCosineDistanceProvider(),
    62  		distancer.NewDotProductProvider(),
    63  	}
    64  
    65  	for _, distancer := range distancers {
    66  		truths := make([][]uint64, queries_size)
    67  		compressionhelpers.Concurrently(uint64(len(queries)), func(i uint64) {
    68  			truths[i], _ = testinghelpers.BruteForce(vectors, queries[i], k, distanceWrapper(distancer))
    69  		})
    70  		fmt.Printf("generating data took %s\n", time.Since(before))
    71  
    72  		uc := ent.UserConfig{
    73  			MaxConnections:        maxNeighbors,
    74  			EFConstruction:        efConstruction,
    75  			EF:                    ef,
    76  			VectorCacheMaxObjects: 10e12,
    77  		}
    78  		index, _ := hnsw.New(hnsw.Config{
    79  			RootPath:              path,
    80  			ID:                    "recallbenchmark",
    81  			MakeCommitLoggerThunk: hnsw.MakeNoopCommitLogger,
    82  			ClassName:             "clasRecallBenchmark",
    83  			ShardName:             "shardRecallBenchmark",
    84  			DistanceProvider:      distancer,
    85  			VectorForIDThunk: func(ctx context.Context, id uint64) ([]float32, error) {
    86  				if int(id) >= len(vectors) {
    87  					return nil, storobj.NewErrNotFoundf(id, "out of range")
    88  				}
    89  				return vectors[int(id)], nil
    90  			},
    91  			TempVectorForIDThunk: func(ctx context.Context, id uint64, container *common.VectorSlice) ([]float32, error) {
    92  				copy(container.Slice, vectors[int(id)])
    93  				return container.Slice, nil
    94  			},
    95  		}, uc, cyclemanager.NewCallbackGroupNoop(), cyclemanager.NewCallbackGroupNoop(),
    96  			cyclemanager.NewCallbackGroupNoop(), testinghelpers.NewDummyStore(t))
    97  		init := time.Now()
    98  		compressionhelpers.Concurrently(uint64(vectors_size), func(id uint64) {
    99  			index.Add(id, vectors[id])
   100  		})
   101  		before = time.Now()
   102  		fmt.Println("Start compressing...")
   103  		uc.PQ = ent.PQConfig{
   104  			Enabled:   true,
   105  			Segments:  dimensions / segments,
   106  			Centroids: 256,
   107  			Encoder:   ent.NewDefaultUserConfig().PQ.Encoder,
   108  		}
   109  		uc.EF = 256
   110  		wg := sync.WaitGroup{}
   111  		wg.Add(1)
   112  		index.UpdateUserConfig(uc, func() {
   113  			fmt.Printf("Time to compress: %s\n", time.Since(before))
   114  			fmt.Printf("Building the index took %s\n", time.Since(init))
   115  
   116  			var relevant uint64
   117  			var retrieved int
   118  
   119  			var querying time.Duration = 0
   120  			compressionhelpers.Concurrently(uint64(len(queries)), func(i uint64) {
   121  				before = time.Now()
   122  				results, _, _ := index.SearchByVector(queries[i], k, nil)
   123  				querying += time.Since(before)
   124  				retrieved += k
   125  				relevant += testinghelpers.MatchesInLists(truths[i], results)
   126  			})
   127  
   128  			recall := float32(relevant) / float32(retrieved)
   129  			latency := float32(querying.Microseconds()) / float32(queries_size)
   130  			fmt.Println(recall, latency)
   131  			assert.True(t, recall > 0.9)
   132  
   133  			err := os.RemoveAll(path)
   134  			if err != nil {
   135  				fmt.Println(err)
   136  			}
   137  			wg.Done()
   138  		})
   139  		wg.Wait()
   140  	}
   141  }