github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/hnsw/compress_deletes_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  //go:build !race
    13  
    14  package hnsw
    15  
    16  import (
    17  	"context"
    18  	"fmt"
    19  	"os"
    20  	"testing"
    21  
    22  	"github.com/stretchr/testify/assert"
    23  	"github.com/stretchr/testify/require"
    24  	"github.com/weaviate/weaviate/adapters/repos/db/vector/common"
    25  	"github.com/weaviate/weaviate/adapters/repos/db/vector/compressionhelpers"
    26  	"github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw/distancer"
    27  	"github.com/weaviate/weaviate/adapters/repos/db/vector/testinghelpers"
    28  	"github.com/weaviate/weaviate/entities/cyclemanager"
    29  	"github.com/weaviate/weaviate/entities/storobj"
    30  	ent "github.com/weaviate/weaviate/entities/vectorindex/hnsw"
    31  )
    32  
    33  func Test_NoRaceCompressDoesNotCrash(t *testing.T) {
    34  	efConstruction := 64
    35  	ef := 32
    36  	maxNeighbors := 32
    37  	dimensions := 20
    38  	vectors_size := 10000
    39  	queries_size := 100
    40  	k := 100
    41  	delete_indices := make([]uint64, 0, 1000)
    42  	for i := 0; i < 1000; i++ {
    43  		delete_indices = append(delete_indices, uint64(i+10))
    44  	}
    45  	delete_indices = append(delete_indices, uint64(1))
    46  
    47  	vectors, queries := testinghelpers.RandomVecs(vectors_size, queries_size, dimensions)
    48  	distancer := distancer.NewL2SquaredProvider()
    49  
    50  	uc := ent.UserConfig{}
    51  	uc.MaxConnections = maxNeighbors
    52  	uc.EFConstruction = efConstruction
    53  	uc.EF = ef
    54  	uc.VectorCacheMaxObjects = 10e12
    55  	uc.PQ = ent.PQConfig{Enabled: true, Encoder: ent.PQEncoder{Type: "title", Distribution: "normal"}}
    56  
    57  	index, _ := New(Config{
    58  		RootPath:              t.TempDir(),
    59  		ID:                    "recallbenchmark",
    60  		MakeCommitLoggerThunk: MakeNoopCommitLogger,
    61  		DistanceProvider:      distancer,
    62  		VectorForIDThunk: func(ctx context.Context, id uint64) ([]float32, error) {
    63  			if int(id) >= len(vectors) {
    64  				return nil, storobj.NewErrNotFoundf(id, "out of range")
    65  			}
    66  			return vectors[int(id)], nil
    67  		},
    68  		TempVectorForIDThunk: func(ctx context.Context, id uint64, container *common.VectorSlice) ([]float32, error) {
    69  			copy(container.Slice, vectors[int(id)])
    70  			return container.Slice, nil
    71  		},
    72  	}, uc, cyclemanager.NewCallbackGroupNoop(), cyclemanager.NewCallbackGroupNoop(),
    73  		cyclemanager.NewCallbackGroupNoop(), testinghelpers.NewDummyStore(t))
    74  	defer index.Shutdown(context.Background())
    75  	compressionhelpers.Concurrently(uint64(len(vectors)), func(id uint64) {
    76  		index.Add(uint64(id), vectors[id])
    77  	})
    78  	index.Delete(delete_indices...)
    79  
    80  	cfg := ent.PQConfig{
    81  		Enabled: true,
    82  		Encoder: ent.PQEncoder{
    83  			Type:         ent.PQEncoderTypeKMeans,
    84  			Distribution: ent.PQEncoderDistributionLogNormal,
    85  		},
    86  		Segments:  dimensions,
    87  		Centroids: 256,
    88  	}
    89  	uc.PQ = cfg
    90  	index.compress(uc)
    91  	for _, v := range queries {
    92  		_, _, err := index.SearchByVector(v, k, nil)
    93  		assert.Nil(t, err)
    94  	}
    95  }
    96  
    97  func TestHnswPqNilVectors(t *testing.T) {
    98  	dimensions := 20
    99  	vectors_size := 10_000
   100  	queries_size := 10
   101  
   102  	vectors, _ := testinghelpers.RandomVecs(vectors_size, queries_size, dimensions)
   103  
   104  	// set some vectors to nil
   105  	for i := range vectors {
   106  		if i == 500 {
   107  			vectors[i] = nil
   108  		}
   109  	}
   110  
   111  	userConfig := ent.UserConfig{
   112  		MaxConnections: 30,
   113  		EFConstruction: 64,
   114  		EF:             32,
   115  
   116  		// The actual size does not matter for this test, but if it defaults to
   117  		// zero it will constantly think it's full and needs to be deleted - even
   118  		// after just being deleted, so make sure to use a positive number here.
   119  		VectorCacheMaxObjects: 1000000,
   120  	}
   121  
   122  	rootPath := "doesnt-matter-as-committlogger-is-mocked-out"
   123  	defer func(path string) {
   124  		err := os.RemoveAll(path)
   125  		if err != nil {
   126  			fmt.Println(err)
   127  		}
   128  	}(rootPath)
   129  
   130  	index, err := New(Config{
   131  		RootPath:              rootPath,
   132  		ID:                    "nil-vector-test",
   133  		MakeCommitLoggerThunk: MakeNoopCommitLogger,
   134  		DistanceProvider:      distancer.NewCosineDistanceProvider(),
   135  		VectorForIDThunk: func(ctx context.Context, id uint64) ([]float32, error) {
   136  			vec := vectors[int(id)]
   137  			if vec == nil {
   138  				return nil, storobj.NewErrNotFoundf(id, "nil vec")
   139  			}
   140  			return vec, nil
   141  		},
   142  		TempVectorForIDThunk: TempVectorForIDThunk(vectors),
   143  	}, userConfig, cyclemanager.NewCallbackGroupNoop(), cyclemanager.NewCallbackGroupNoop(), cyclemanager.NewCallbackGroupNoop(), testinghelpers.NewDummyStore(t))
   144  
   145  	require.NoError(t, err)
   146  
   147  	compressionhelpers.Concurrently(uint64(len(vectors)/2), func(id uint64) {
   148  		if vectors[id] == nil {
   149  			return
   150  		}
   151  
   152  		err := index.Add(uint64(id), vectors[id])
   153  		require.Nil(t, err)
   154  	})
   155  
   156  	userConfig.PQ = ent.PQConfig{
   157  		Enabled: true,
   158  		Encoder: ent.PQEncoder{
   159  			Type:         ent.PQEncoderTypeTile,
   160  			Distribution: ent.PQEncoderDistributionLogNormal,
   161  		},
   162  		BitCompression: false,
   163  		Segments:       dimensions,
   164  		Centroids:      256,
   165  	}
   166  
   167  	ch := make(chan error)
   168  	err = index.UpdateUserConfig(userConfig, func() {
   169  		close(ch)
   170  	})
   171  	require.NoError(t, err)
   172  
   173  	<-ch
   174  	start := uint64(len(vectors) / 2)
   175  	compressionhelpers.Concurrently(uint64(len(vectors)/2), func(id uint64) {
   176  		if vectors[id+start] == nil {
   177  			return
   178  		}
   179  
   180  		err = index.Add(uint64(id)+start, vectors[id+start])
   181  		require.Nil(t, err)
   182  	})
   183  }