github.com/weaviate/weaviate@v1.24.6/modules/ref2vec-centroid/module_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package modcentroid
    13  
    14  import (
    15  	"context"
    16  	"fmt"
    17  	"testing"
    18  	"time"
    19  
    20  	"github.com/go-openapi/strfmt"
    21  	"github.com/google/uuid"
    22  	"github.com/sirupsen/logrus/hooks/test"
    23  	"github.com/stretchr/testify/assert"
    24  	"github.com/stretchr/testify/mock"
    25  	"github.com/weaviate/weaviate/entities/additional"
    26  	"github.com/weaviate/weaviate/entities/models"
    27  	"github.com/weaviate/weaviate/entities/modulecapabilities"
    28  	"github.com/weaviate/weaviate/entities/moduletools"
    29  	"github.com/weaviate/weaviate/entities/schema"
    30  	"github.com/weaviate/weaviate/entities/schema/crossref"
    31  	"github.com/weaviate/weaviate/entities/search"
    32  	"github.com/weaviate/weaviate/usecases/config"
    33  )
    34  
    35  func TestRef2VecCentroid(t *testing.T) {
    36  	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
    37  	defer cancel()
    38  	sp := newFakeStorageProvider(t)
    39  	logger, _ := test.NewNullLogger()
    40  	params := moduletools.NewInitParams(sp, nil, config.Config{}, logger)
    41  
    42  	mod := New()
    43  	classConfig := fakeClassConfig(mod.ClassConfigDefaults())
    44  	refProp := "someRef"
    45  	classConfig["referenceProperties"] = []interface{}{refProp}
    46  
    47  	t.Run("Init", func(t *testing.T) {
    48  		err := mod.Init(ctx, params)
    49  		assert.Nil(t, err)
    50  	})
    51  
    52  	t.Run("RootHandler", func(t *testing.T) {
    53  		h := mod.RootHandler()
    54  		assert.Nil(t, h)
    55  	})
    56  
    57  	t.Run("Type", func(t *testing.T) {
    58  		typ := mod.Type()
    59  		assert.Equal(t, modulecapabilities.Ref2Vec, typ)
    60  	})
    61  
    62  	t.Run("Name", func(t *testing.T) {
    63  		name := mod.Name()
    64  		assert.Equal(t, Name, name)
    65  	})
    66  
    67  	t.Run("MetaInfo", func(t *testing.T) {
    68  		meta, err := mod.MetaInfo()
    69  		assert.Nil(t, err)
    70  		assert.Empty(t, meta)
    71  	})
    72  
    73  	t.Run("PropertyConfigDefaults", func(t *testing.T) {
    74  		dt := schema.DataType("dataType")
    75  		props := mod.PropertyConfigDefaults(&dt)
    76  		assert.Nil(t, props)
    77  	})
    78  
    79  	t.Run("ValidateClass", func(t *testing.T) {
    80  		t.Run("expected success", func(t *testing.T) {
    81  			class := &models.Class{}
    82  
    83  			err := mod.ValidateClass(ctx, class, classConfig)
    84  			assert.Nil(t, err)
    85  		})
    86  
    87  		t.Run("expected error", func(t *testing.T) {
    88  			class := &models.Class{Class: "InvalidConfigClass"}
    89  			cfg := fakeClassConfig{}
    90  
    91  			expectedErr := fmt.Sprintf(
    92  				"validate %q: invalid config: must have at least one "+
    93  					"value in the \"referenceProperties\" field",
    94  				class.Class)
    95  			err := mod.ValidateClass(ctx, class, cfg)
    96  			assert.EqualError(t, err, expectedErr)
    97  		})
    98  	})
    99  
   100  	t.Run("VectorizeObject", func(t *testing.T) {
   101  		t.Run("expected success", func(t *testing.T) {
   102  			t.Run("one refVec", func(t *testing.T) {
   103  				repo := &fakeObjectsRepo{}
   104  				ref := crossref.New("localhost", "SomeClass", strfmt.UUID(uuid.NewString()))
   105  				obj := &models.Object{Properties: map[string]interface{}{
   106  					refProp: models.MultipleRef{ref.SingleRef()},
   107  				}}
   108  
   109  				repo.On("Object", ctx, ref.Class, ref.TargetID).
   110  					Return(&search.Result{Vector: []float32{1, 2, 3}}, nil)
   111  
   112  				vec, err := mod.VectorizeObject(ctx, obj, classConfig, repo.Object)
   113  				assert.Nil(t, err)
   114  				expectedVec := models.C11yVector{1, 2, 3}
   115  				assert.EqualValues(t, expectedVec, vec)
   116  			})
   117  
   118  			t.Run("no refVecs", func(t *testing.T) {
   119  				repo := &fakeObjectsRepo{}
   120  				ref := crossref.New("localhost", "SomeClass", strfmt.UUID(uuid.NewString()))
   121  				obj := &models.Object{Properties: map[string]interface{}{
   122  					refProp: models.MultipleRef{ref.SingleRef()},
   123  				}}
   124  
   125  				repo.On("Object", ctx, ref.Class, ref.TargetID).
   126  					Return(&search.Result{}, nil)
   127  
   128  				_, err := mod.VectorizeObject(ctx, obj, classConfig, repo.Object)
   129  				assert.Nil(t, err)
   130  				assert.Nil(t, nil, obj.Vector)
   131  			})
   132  		})
   133  
   134  		t.Run("expected error", func(t *testing.T) {
   135  			t.Run("mismatched refVec lengths", func(t *testing.T) {
   136  				repo := &fakeObjectsRepo{}
   137  				ref1 := crossref.New("localhost", "SomeClass", strfmt.UUID(uuid.NewString()))
   138  				ref2 := crossref.New("localhost", "OtherClass", strfmt.UUID(uuid.NewString()))
   139  				obj := &models.Object{Properties: map[string]interface{}{
   140  					refProp: models.MultipleRef{
   141  						ref1.SingleRef(),
   142  						ref2.SingleRef(),
   143  					},
   144  				}}
   145  				expectedErr := fmt.Errorf("calculate vector: calculate mean: " +
   146  					"found vectors of different length: 2 and 3")
   147  
   148  				repo.On("Object", ctx, ref1.Class, ref1.TargetID).
   149  					Return(&search.Result{Vector: []float32{1, 2}}, nil)
   150  				repo.On("Object", ctx, ref2.Class, ref2.TargetID).
   151  					Return(&search.Result{Vector: []float32{1, 2, 3}}, nil)
   152  
   153  				_, err := mod.VectorizeObject(ctx, obj, classConfig, repo.Object)
   154  				assert.EqualError(t, err, expectedErr.Error())
   155  			})
   156  		})
   157  	})
   158  }
   159  
   160  type fakeObjectsRepo struct {
   161  	mock.Mock
   162  }
   163  
   164  func (r *fakeObjectsRepo) Object(ctx context.Context, class string,
   165  	id strfmt.UUID, props search.SelectProperties,
   166  	addl additional.Properties, tenant string,
   167  ) (*search.Result, error) {
   168  	args := r.Called(ctx, class, id)
   169  	if args.Get(0) == nil {
   170  		return nil, args.Error(1)
   171  	}
   172  	return args.Get(0).(*search.Result), args.Error(1)
   173  }