github.com/weaviate/weaviate@v1.24.6/adapters/repos/db/vector/hnsw/compress_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package hnsw
    13  
    14  import (
    15  	"context"
    16  	"testing"
    17  
    18  	"github.com/stretchr/testify/assert"
    19  	"github.com/weaviate/weaviate/adapters/repos/db/vector/common"
    20  	"github.com/weaviate/weaviate/adapters/repos/db/vector/compressionhelpers"
    21  	"github.com/weaviate/weaviate/adapters/repos/db/vector/hnsw/distancer"
    22  	"github.com/weaviate/weaviate/adapters/repos/db/vector/testinghelpers"
    23  	"github.com/weaviate/weaviate/entities/cyclemanager"
    24  	"github.com/weaviate/weaviate/entities/storobj"
    25  	ent "github.com/weaviate/weaviate/entities/vectorindex/hnsw"
    26  )
    27  
    28  func TestCompression_CalculateOptimalSegments(t *testing.T) {
    29  	h := &hnsw{}
    30  
    31  	type testCase struct {
    32  		dimensions       int
    33  		expectedSegments int
    34  	}
    35  
    36  	for _, tc := range []testCase{
    37  		{
    38  			dimensions:       2048,
    39  			expectedSegments: 256,
    40  		},
    41  		{
    42  			dimensions:       1536,
    43  			expectedSegments: 256,
    44  		},
    45  		{
    46  			dimensions:       768,
    47  			expectedSegments: 128,
    48  		},
    49  		{
    50  			dimensions:       512,
    51  			expectedSegments: 128,
    52  		},
    53  		{
    54  			dimensions:       256,
    55  			expectedSegments: 64,
    56  		},
    57  		{
    58  			dimensions:       125,
    59  			expectedSegments: 125,
    60  		},
    61  		{
    62  			dimensions:       64,
    63  			expectedSegments: 32,
    64  		},
    65  		{
    66  			dimensions:       27,
    67  			expectedSegments: 27,
    68  		},
    69  		{
    70  			dimensions:       19,
    71  			expectedSegments: 19,
    72  		},
    73  		{
    74  			dimensions:       2,
    75  			expectedSegments: 1,
    76  		},
    77  	} {
    78  		segments := h.calculateOptimalSegments(tc.dimensions)
    79  		assert.Equal(t, tc.expectedSegments, segments)
    80  	}
    81  }
    82  
    83  func Test_NoRaceCompressReturnsErrorWhenNotEnoughData(t *testing.T) {
    84  	efConstruction := 64
    85  	ef := 32
    86  	maxNeighbors := 32
    87  	dimensions := 200
    88  	vectors_size := 10
    89  	vectors, _ := testinghelpers.RandomVecs(vectors_size, 0, dimensions)
    90  	distancer := distancer.NewL2SquaredProvider()
    91  
    92  	uc := ent.UserConfig{}
    93  	uc.MaxConnections = maxNeighbors
    94  	uc.EFConstruction = efConstruction
    95  	uc.EF = ef
    96  	uc.VectorCacheMaxObjects = 10e12
    97  	uc.PQ = ent.PQConfig{
    98  		Enabled: false,
    99  		Encoder: ent.PQEncoder{
   100  			Type:         ent.PQEncoderTypeKMeans,
   101  			Distribution: ent.PQEncoderDistributionLogNormal,
   102  		},
   103  		TrainingLimit: 5,
   104  		Segments:      dimensions,
   105  		Centroids:     256,
   106  	}
   107  
   108  	index, _ := New(Config{
   109  		RootPath:              t.TempDir(),
   110  		ID:                    "recallbenchmark",
   111  		MakeCommitLoggerThunk: MakeNoopCommitLogger,
   112  		DistanceProvider:      distancer,
   113  		VectorForIDThunk: func(ctx context.Context, id uint64) ([]float32, error) {
   114  			if int(id) >= len(vectors) {
   115  				return nil, storobj.NewErrNotFoundf(id, "out of range")
   116  			}
   117  			return vectors[int(id)], nil
   118  		},
   119  		TempVectorForIDThunk: func(ctx context.Context, id uint64, container *common.VectorSlice) ([]float32, error) {
   120  			copy(container.Slice, vectors[int(id)])
   121  			return container.Slice, nil
   122  		},
   123  	}, uc, cyclemanager.NewCallbackGroupNoop(), cyclemanager.NewCallbackGroupNoop(),
   124  		cyclemanager.NewCallbackGroupNoop(), testinghelpers.NewDummyStore(t))
   125  	defer index.Shutdown(context.Background())
   126  	compressionhelpers.Concurrently(uint64(len(vectors)), func(id uint64) {
   127  		index.Add(uint64(id), vectors[id])
   128  	})
   129  
   130  	cfg := ent.PQConfig{
   131  		Enabled: true,
   132  		Encoder: ent.PQEncoder{
   133  			Type:         ent.PQEncoderTypeKMeans,
   134  			Distribution: ent.PQEncoderDistributionLogNormal,
   135  		},
   136  		Segments:  dimensions,
   137  		Centroids: 256,
   138  	}
   139  	uc.PQ = cfg
   140  	err := index.compress(uc)
   141  	assert.NotNil(t, err)
   142  }