github.com/weaviate/weaviate@v1.24.6/modules/ref2vec-centroid/vectorizer/vectorizer_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package vectorizer
    13  
    14  import (
    15  	"context"
    16  	"errors"
    17  	"reflect"
    18  	"testing"
    19  
    20  	"github.com/go-openapi/strfmt"
    21  	"github.com/google/uuid"
    22  	"github.com/stretchr/testify/assert"
    23  	"github.com/weaviate/weaviate/entities/models"
    24  	"github.com/weaviate/weaviate/entities/schema/crossref"
    25  	"github.com/weaviate/weaviate/entities/search"
    26  	"github.com/weaviate/weaviate/modules/ref2vec-centroid/config"
    27  )
    28  
    29  func TestVectorizer_New(t *testing.T) {
    30  	repo := &fakeObjectsRepo{}
    31  	t.Run("default is set correctly", func(t *testing.T) {
    32  		vzr := New(fakeClassConfig(config.Default()), repo.Object)
    33  
    34  		expected := reflect.ValueOf(calculateMean).Pointer()
    35  		received := reflect.ValueOf(vzr.calcFn).Pointer()
    36  
    37  		assert.EqualValues(t, expected, received)
    38  	})
    39  
    40  	t.Run("default calcFn is used when none provided", func(t *testing.T) {
    41  		cfg := fakeClassConfig{"method": ""}
    42  		vzr := New(cfg, repo.Object)
    43  
    44  		expected := reflect.ValueOf(calculateMean).Pointer()
    45  		received := reflect.ValueOf(vzr.calcFn).Pointer()
    46  
    47  		assert.EqualValues(t, expected, received)
    48  	})
    49  }
    50  
    51  func TestVectorizer_Object(t *testing.T) {
    52  	t.Run("calculate with mean", func(t *testing.T) {
    53  		type objectSearchResult struct {
    54  			res *search.Result
    55  			err error
    56  		}
    57  
    58  		tests := []struct {
    59  			name                string
    60  			objectSearchResults []objectSearchResult
    61  			expectedResult      []float32
    62  			expectedCalcError   error
    63  		}{
    64  			{
    65  				name: "expected success 1",
    66  				objectSearchResults: []objectSearchResult{
    67  					{res: &search.Result{Vector: []float32{2, 4, 6}}},
    68  					{res: &search.Result{Vector: []float32{4, 6, 8}}},
    69  				},
    70  				expectedResult: []float32{3, 5, 7},
    71  			},
    72  			{
    73  				name: "expected success 2",
    74  				objectSearchResults: []objectSearchResult{
    75  					{res: &search.Result{Vector: []float32{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}}},
    76  					{res: &search.Result{Vector: []float32{2, 2, 2, 2, 2, 2, 2, 2, 2, 2}}},
    77  					{res: &search.Result{Vector: []float32{3, 3, 3, 3, 3, 3, 3, 3, 3, 3}}},
    78  					{res: &search.Result{Vector: []float32{4, 4, 4, 4, 4, 4, 4, 4, 4, 4}}},
    79  					{res: &search.Result{Vector: []float32{5, 5, 5, 5, 5, 5, 5, 5, 5, 5}}},
    80  					{res: &search.Result{Vector: []float32{6, 6, 6, 6, 6, 6, 6, 6, 6, 6}}},
    81  					{res: &search.Result{Vector: []float32{7, 7, 7, 7, 7, 7, 7, 7, 7, 7}}},
    82  					{res: &search.Result{Vector: []float32{8, 8, 8, 8, 8, 8, 8, 8, 8, 8}}},
    83  					{res: &search.Result{Vector: []float32{9, 9, 9, 9, 9, 9, 9, 9, 9, 9}}},
    84  				},
    85  				expectedResult: []float32{5, 5, 5, 5, 5, 5, 5, 5, 5, 5},
    86  			},
    87  			{
    88  				name:                "expected success 3",
    89  				objectSearchResults: []objectSearchResult{{}},
    90  			},
    91  			{
    92  				name: "expected success 4",
    93  				objectSearchResults: []objectSearchResult{
    94  					{res: &search.Result{Vector: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9}}},
    95  				},
    96  				expectedResult: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9},
    97  			},
    98  			{
    99  				name: "expected success 5",
   100  				objectSearchResults: []objectSearchResult{
   101  					{res: &search.Result{}},
   102  				},
   103  				expectedResult: nil,
   104  			},
   105  			{
   106  				name: "expected error - mismatched vector dimensions",
   107  				objectSearchResults: []objectSearchResult{
   108  					{res: &search.Result{Vector: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9}}},
   109  					{res: &search.Result{Vector: []float32{1, 2, 3, 4, 5, 6, 7, 8}}},
   110  				},
   111  				expectedCalcError: errors.New(
   112  					"calculate vector: calculate mean: found vectors of different length: 9 and 8"),
   113  			},
   114  		}
   115  
   116  		for _, test := range tests {
   117  			t.Run(test.name, func(t *testing.T) {
   118  				ctx := context.Background()
   119  				repo := &fakeObjectsRepo{}
   120  				refProps := []interface{}{"toRef"}
   121  				cfg := fakeClassConfig{"method": "mean", "referenceProperties": refProps}
   122  
   123  				crossRefs := make([]*crossref.Ref, len(test.objectSearchResults))
   124  				modelRefs := make(models.MultipleRef, len(test.objectSearchResults))
   125  				for i, res := range test.objectSearchResults {
   126  					crossRef := crossref.New("localhost", "SomeClass",
   127  						strfmt.UUID(uuid.NewString()))
   128  					crossRefs[i] = crossRef
   129  					modelRefs[i] = crossRef.SingleRef()
   130  
   131  					repo.On("Object", ctx, crossRef.Class, crossRef.TargetID, "").
   132  						Return(res.res, res.err)
   133  				}
   134  
   135  				obj := &models.Object{
   136  					Properties: map[string]interface{}{"toRef": modelRefs},
   137  				}
   138  
   139  				vec, err := New(cfg, repo.Object).Object(ctx, obj)
   140  				if test.expectedCalcError != nil {
   141  					assert.EqualError(t, err, test.expectedCalcError.Error())
   142  				} else {
   143  					assert.EqualValues(t, test.expectedResult, vec)
   144  				}
   145  			})
   146  		}
   147  	})
   148  
   149  	// due to the fix introduced in https://github.com/weaviate/weaviate/pull/2320,
   150  	// MultipleRef's can appear as empty []interface{} when no actual refs are provided for
   151  	// an object's reference property.
   152  	//
   153  	// this test asserts that reference properties do not break when they are unmarshalled
   154  	// as empty interface{} slices.
   155  	t.Run("when rep prop is stored as empty interface{} slice", func(t *testing.T) {
   156  		ctx := context.Background()
   157  		repo := &fakeObjectsRepo{}
   158  		refProps := []interface{}{"toRef"}
   159  		cfg := fakeClassConfig{"method": "mean", "referenceProperties": refProps}
   160  
   161  		obj := &models.Object{
   162  			Properties: map[string]interface{}{"toRef": []interface{}{}},
   163  		}
   164  
   165  		_, err := New(cfg, repo.Object).Object(ctx, obj)
   166  		assert.Nil(t, err)
   167  		assert.Nil(t, obj.Vector)
   168  	})
   169  }
   170  
   171  func TestVectorizer_Tenant(t *testing.T) {
   172  	objectSearchResults := search.Result{Vector: []float32{}}
   173  	ctx := context.Background()
   174  	repo := &fakeObjectsRepo{}
   175  	refProps := []interface{}{"toRef"}
   176  	cfg := fakeClassConfig{"method": "mean", "referenceProperties": refProps}
   177  	tenant := "randomTenant"
   178  
   179  	crossRef := crossref.New("localhost", "SomeClass",
   180  		strfmt.UUID(uuid.NewString()))
   181  	modelRefs := models.MultipleRef{crossRef.SingleRef()}
   182  
   183  	repo.On("Object", ctx, crossRef.Class, crossRef.TargetID, tenant).
   184  		Return(&objectSearchResults, nil)
   185  
   186  	obj := &models.Object{
   187  		Properties: map[string]interface{}{"toRef": modelRefs},
   188  		Tenant:     tenant,
   189  	}
   190  
   191  	_, err := New(cfg, repo.Object).Object(ctx, obj)
   192  	assert.Nil(t, err)
   193  	assert.Nil(t, obj.Vector)
   194  }