github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/hnsw/recall_geo_spatial_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  //go:build integrationTestSlow && !race
    13  // +build integrationTestSlow,!race
    14  
    15  package hnsw
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"math/rand"
    21  	"runtime"
    22  	"sort"
    23  	"sync"
    24  	"testing"
    25  	"time"
    26  
    27  	"github.com/stretchr/testify/assert"
    28  	"github.com/stretchr/testify/require"
    29  	"github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw/distancer"
    30  	"github.com/weaviate/weaviate/adapters/repos/db/vector/testinghelpers"
    31  	"github.com/weaviate/weaviate/entities/cyclemanager"
    32  	ent "github.com/weaviate/weaviate/entities/vectorindex/hnsw"
    33  )
    34  
    35  func TestRecallGeo(t *testing.T) {
    36  	size := 10000
    37  	queries := 100
    38  	efConstruction := 128
    39  	maxNeighbors := 64
    40  
    41  	vectors := make([][]float32, size)
    42  	queryVectors := make([][]float32, queries)
    43  	var vectorIndex *hnsw
    44  
    45  	t.Run("generate random vectors", func(t *testing.T) {
    46  		fmt.Printf("generating %d vectors", size)
    47  		for i := 0; i < size; i++ {
    48  			lat, lon := randLatLon()
    49  			vectors[i] = []float32{lat, lon}
    50  		}
    51  		fmt.Printf("done\n")
    52  
    53  		fmt.Printf("generating %d search queries", queries)
    54  		for i := 0; i < queries; i++ {
    55  			lat, lon := randLatLon()
    56  			queryVectors[i] = []float32{lat, lon}
    57  		}
    58  		fmt.Printf("done\n")
    59  	})
    60  
    61  	t.Run("importing into hnsw", func(t *testing.T) {
    62  		fmt.Printf("importing into hnsw\n")
    63  		index, err := New(Config{
    64  			RootPath:              "doesnt-matter-as-committlogger-is-mocked-out",
    65  			ID:                    "recallbenchmark",
    66  			MakeCommitLoggerThunk: MakeNoopCommitLogger,
    67  			DistanceProvider:      distancer.NewGeoProvider(),
    68  			VectorForIDThunk: func(ctx context.Context, id uint64) ([]float32, error) {
    69  				return vectors[int(id)], nil
    70  			},
    71  		}, ent.UserConfig{
    72  			MaxConnections: maxNeighbors,
    73  			EFConstruction: efConstruction,
    74  		}, cyclemanager.NewCallbackGroupNoop(), cyclemanager.NewCallbackGroupNoop(),
    75  			cyclemanager.NewCallbackGroupNoop(), testinghelpers.NewDummyStore(t))
    76  
    77  		require.Nil(t, err)
    78  		vectorIndex = index
    79  
    80  		workerCount := runtime.GOMAXPROCS(0)
    81  		jobsForWorker := make([][][]float32, workerCount)
    82  
    83  		for i, vec := range vectors {
    84  			workerID := i % workerCount
    85  			jobsForWorker[workerID] = append(jobsForWorker[workerID], vec)
    86  		}
    87  
    88  		beforeImport := time.Now()
    89  		wg := &sync.WaitGroup{}
    90  		for workerID, jobs := range jobsForWorker {
    91  			wg.Add(1)
    92  			go func(workerID int, myJobs [][]float32) {
    93  				defer wg.Done()
    94  				for i, vec := range myJobs {
    95  					originalIndex := (i * workerCount) + workerID
    96  					err := vectorIndex.Add(uint64(originalIndex), vec)
    97  					require.Nil(t, err)
    98  				}
    99  			}(workerID, jobs)
   100  		}
   101  
   102  		wg.Wait()
   103  		fmt.Printf("import took %s\n", time.Since(beforeImport))
   104  	})
   105  
   106  	t.Run("with k=10", func(t *testing.T) {
   107  		k := 10
   108  
   109  		var relevant int
   110  		var retrieved int
   111  
   112  		var times time.Duration
   113  
   114  		for i := 0; i < queries; i++ {
   115  			controlList := bruteForce(vectors, queryVectors[i], k)
   116  			before := time.Now()
   117  			results, _, err := vectorIndex.knnSearchByVector(queryVectors[i], k, 800, nil)
   118  			times += time.Since(before)
   119  
   120  			require.Nil(t, err)
   121  
   122  			retrieved += k
   123  			relevant += matchesInLists(controlList, results)
   124  		}
   125  
   126  		recall := float32(relevant) / float32(retrieved)
   127  		fmt.Printf("recall is %f\n", recall)
   128  		fmt.Printf("avg search time for k=%d is %s\n", k, times/time.Duration(queries))
   129  		assert.True(t, recall >= 0.99)
   130  	})
   131  
   132  	t.Run("with max dist set", func(t *testing.T) {
   133  		distances := []float32{
   134  			0.1,
   135  			1,
   136  			10,
   137  			100,
   138  			1000,
   139  			2000,
   140  			5000,
   141  			7500,
   142  			10000,
   143  			12500,
   144  			15000,
   145  			20000,
   146  			35000,
   147  			100000, // larger than the circumference of the earth, should contain all
   148  		}
   149  
   150  		for _, maxDist := range distances {
   151  			t.Run(fmt.Sprintf("with maxDist=%f", maxDist), func(t *testing.T) {
   152  				var relevant int
   153  				var retrieved int
   154  
   155  				var times time.Duration
   156  
   157  				for i := 0; i < queries; i++ {
   158  					controlList := bruteForceMaxDist(vectors, queryVectors[i], maxDist)
   159  					before := time.Now()
   160  					results, err := vectorIndex.KnnSearchByVectorMaxDist(queryVectors[i], maxDist, 800, nil)
   161  					times += time.Since(before)
   162  					require.Nil(t, err)
   163  
   164  					retrieved += len(results)
   165  					relevant += matchesInLists(controlList, results)
   166  				}
   167  
   168  				if relevant == 0 {
   169  					// skip, as we risk dividing by zero, if both relevant and retrieved
   170  					// are zero, however, we want to fail with a divide-by-zero if only
   171  					// retrieved is 0 and relevant was more than 0
   172  					return
   173  				}
   174  				recall := float32(relevant) / float32(retrieved)
   175  				fmt.Printf("recall is %f\n", recall)
   176  				fmt.Printf("avg search time for maxDist=%f is %s\n", maxDist, times/time.Duration(queries))
   177  				assert.True(t, recall >= 0.99)
   178  			})
   179  		}
   180  	})
   181  }
   182  
   183  func matchesInLists(control []uint64, results []uint64) int {
   184  	desired := map[uint64]struct{}{}
   185  	for _, relevant := range control {
   186  		desired[relevant] = struct{}{}
   187  	}
   188  
   189  	var matches int
   190  	for _, candidate := range results {
   191  		_, ok := desired[candidate]
   192  		if ok {
   193  			matches++
   194  		}
   195  	}
   196  
   197  	return matches
   198  }
   199  
   200  func bruteForce(vectors [][]float32, query []float32, k int) []uint64 {
   201  	type distanceAndIndex struct {
   202  		distance float32
   203  		index    uint64
   204  	}
   205  
   206  	distances := make([]distanceAndIndex, len(vectors))
   207  
   208  	distancer := distancer.NewGeoProvider().New(query)
   209  	for i, vec := range vectors {
   210  		dist, _, _ := distancer.Distance(vec)
   211  		distances[i] = distanceAndIndex{
   212  			index:    uint64(i),
   213  			distance: dist,
   214  		}
   215  	}
   216  
   217  	sort.Slice(distances, func(a, b int) bool {
   218  		return distances[a].distance < distances[b].distance
   219  	})
   220  
   221  	if len(distances) < k {
   222  		k = len(distances)
   223  	}
   224  
   225  	out := make([]uint64, k)
   226  	for i := 0; i < k; i++ {
   227  		out[i] = distances[i].index
   228  	}
   229  
   230  	return out
   231  }
   232  
   233  func bruteForceMaxDist(vectors [][]float32, query []float32, maxDist float32) []uint64 {
   234  	type distanceAndIndex struct {
   235  		distance float32
   236  		index    uint64
   237  	}
   238  
   239  	distances := make([]distanceAndIndex, len(vectors))
   240  
   241  	distancer := distancer.NewGeoProvider().New(query)
   242  	for i, vec := range vectors {
   243  		dist, _, _ := distancer.Distance(vec)
   244  		distances[i] = distanceAndIndex{
   245  			index:    uint64(i),
   246  			distance: dist,
   247  		}
   248  	}
   249  
   250  	sort.Slice(distances, func(a, b int) bool {
   251  		return distances[a].distance < distances[b].distance
   252  	})
   253  
   254  	out := make([]uint64, len(distances))
   255  	i := 0
   256  	for _, elem := range distances {
   257  		if elem.distance > maxDist {
   258  			break
   259  		}
   260  		out[i] = distances[i].index
   261  		i++
   262  	}
   263  
   264  	return out[:i]
   265  }
   266  
   267  func randLatLon() (float32, float32) {
   268  	maxLat := float32(90.0)
   269  	minLat := float32(-90.0)
   270  	maxLon := float32(180)
   271  	minLon := float32(-180)
   272  
   273  	lat := minLat + (maxLat-minLat)*rand.Float32()
   274  	lon := minLon + (maxLon-minLon)*rand.Float32()
   275  	return lat, lon
   276  }