github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/hnsw/generate_recall_datasets.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  //go:build ignore
    13  // +build ignore
    14  
    15  package main
    16  
    17  import (
    18  	"encoding/json"
    19  	"fmt"
    20  	"io/ioutil"
    21  	"math"
    22  	"math/rand"
    23  	"sort"
    24  
    25  	"github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw/distancer"
    26  )
    27  
    28  func main() {
    29  	dimensions := 256
    30  	size := 10000
    31  	queries := 1000
    32  
    33  	vectors := make([][]float32, size)
    34  	queryVectors := make([][]float32, queries)
    35  	truths := make([][]uint64, queries)
    36  
    37  	fmt.Printf("generating %d vectors", size)
    38  	for i := 0; i < size; i++ {
    39  		vector := make([]float32, dimensions)
    40  		for j := 0; j < dimensions; j++ {
    41  			vector[j] = rand.Float32()
    42  		}
    43  		vectors[i] = Normalize(vector)
    44  
    45  	}
    46  	fmt.Printf("done\n")
    47  
    48  	fmt.Printf("generating %d search queries", queries)
    49  	for i := 0; i < queries; i++ {
    50  		queryVector := make([]float32, dimensions)
    51  		for j := 0; j < dimensions; j++ {
    52  			queryVector[j] = rand.Float32()
    53  		}
    54  		queryVectors[i] = Normalize(queryVector)
    55  	}
    56  	fmt.Printf("done\n")
    57  
    58  	fmt.Printf("defining truth through brute force")
    59  
    60  	k := 10
    61  	for i, query := range queryVectors {
    62  		truths[i] = bruteForce(vectors, query, k)
    63  	}
    64  
    65  	vectorsJSON, _ := json.Marshal(vectors)
    66  	queriesJSON, _ := json.Marshal(queryVectors)
    67  	truthsJSON, _ := json.Marshal(truths)
    68  
    69  	ioutil.WriteFile("recall_vectors.json", vectorsJSON, 0o644)
    70  	ioutil.WriteFile("recall_queries.json", queriesJSON, 0o644)
    71  	ioutil.WriteFile("recall_truths.json", truthsJSON, 0o644)
    72  }
    73  
    74  func Normalize(v []float32) []float32 {
    75  	var norm float32
    76  	for i := range v {
    77  		norm += v[i] * v[i]
    78  	}
    79  
    80  	norm = float32(math.Sqrt(float64(norm)))
    81  	for i := range v {
    82  		v[i] = v[i] / norm
    83  	}
    84  
    85  	return v
    86  }
    87  
    88  func bruteForce(vectors [][]float32, query []float32, k int) []uint64 {
    89  	type distanceAndIndex struct {
    90  		distance float32
    91  		index    uint64
    92  	}
    93  
    94  	distances := make([]distanceAndIndex, len(vectors))
    95  
    96  	for i, vec := range vectors {
    97  		dist, _, _ := distancer.NewCosineDistanceProvider().SingleDist(query, vec)
    98  		distances[i] = distanceAndIndex{
    99  			index:    uint64(i),
   100  			distance: dist,
   101  		}
   102  	}
   103  
   104  	sort.Slice(distances, func(a, b int) bool {
   105  		return distances[a].distance < distances[b].distance
   106  	})
   107  
   108  	if len(distances) < k {
   109  		k = len(distances)
   110  	}
   111  
   112  	out := make([]uint64, k)
   113  	for i := 0; i < k; i++ {
   114  		out[i] = distances[i].index
   115  	}
   116  
   117  	return out
   118  }