github.com/weaviate/weaviate@v1.24.6/modules/ref2vec-centroid/module_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package modcentroid 13 14 import ( 15 "context" 16 "fmt" 17 "testing" 18 "time" 19 20 "github.com/go-openapi/strfmt" 21 "github.com/google/uuid" 22 "github.com/sirupsen/logrus/hooks/test" 23 "github.com/stretchr/testify/assert" 24 "github.com/stretchr/testify/mock" 25 "github.com/weaviate/weaviate/entities/additional" 26 "github.com/weaviate/weaviate/entities/models" 27 "github.com/weaviate/weaviate/entities/modulecapabilities" 28 "github.com/weaviate/weaviate/entities/moduletools" 29 "github.com/weaviate/weaviate/entities/schema" 30 "github.com/weaviate/weaviate/entities/schema/crossref" 31 "github.com/weaviate/weaviate/entities/search" 32 "github.com/weaviate/weaviate/usecases/config" 33 ) 34 35 func TestRef2VecCentroid(t *testing.T) { 36 ctx, cancel := context.WithTimeout(context.Background(), time.Minute) 37 defer cancel() 38 sp := newFakeStorageProvider(t) 39 logger, _ := test.NewNullLogger() 40 params := moduletools.NewInitParams(sp, nil, config.Config{}, logger) 41 42 mod := New() 43 classConfig := fakeClassConfig(mod.ClassConfigDefaults()) 44 refProp := "someRef" 45 classConfig["referenceProperties"] = []interface{}{refProp} 46 47 t.Run("Init", func(t *testing.T) { 48 err := mod.Init(ctx, params) 49 assert.Nil(t, err) 50 }) 51 52 t.Run("RootHandler", func(t *testing.T) { 53 h := mod.RootHandler() 54 assert.Nil(t, h) 55 }) 56 57 t.Run("Type", func(t *testing.T) { 58 typ := mod.Type() 59 assert.Equal(t, modulecapabilities.Ref2Vec, typ) 60 }) 61 62 t.Run("Name", func(t *testing.T) { 63 name := mod.Name() 64 assert.Equal(t, Name, name) 65 }) 66 67 t.Run("MetaInfo", func(t *testing.T) { 68 meta, err := mod.MetaInfo() 69 assert.Nil(t, err) 70 assert.Empty(t, meta) 71 }) 72 73 t.Run("PropertyConfigDefaults", func(t *testing.T) { 74 dt := schema.DataType("dataType") 75 props := mod.PropertyConfigDefaults(&dt) 76 assert.Nil(t, props) 77 }) 78 79 t.Run("ValidateClass", func(t *testing.T) { 80 t.Run("expected success", func(t *testing.T) { 81 class := &models.Class{} 82 83 err := mod.ValidateClass(ctx, class, classConfig) 84 assert.Nil(t, err) 85 }) 86 87 t.Run("expected error", func(t *testing.T) { 88 class := &models.Class{Class: "InvalidConfigClass"} 89 cfg := fakeClassConfig{} 90 91 expectedErr := fmt.Sprintf( 92 "validate %q: invalid config: must have at least one "+ 93 "value in the \"referenceProperties\" field", 94 class.Class) 95 err := mod.ValidateClass(ctx, class, cfg) 96 assert.EqualError(t, err, expectedErr) 97 }) 98 }) 99 100 t.Run("VectorizeObject", func(t *testing.T) { 101 t.Run("expected success", func(t *testing.T) { 102 t.Run("one refVec", func(t *testing.T) { 103 repo := &fakeObjectsRepo{} 104 ref := crossref.New("localhost", "SomeClass", strfmt.UUID(uuid.NewString())) 105 obj := &models.Object{Properties: map[string]interface{}{ 106 refProp: models.MultipleRef{ref.SingleRef()}, 107 }} 108 109 repo.On("Object", ctx, ref.Class, ref.TargetID). 110 Return(&search.Result{Vector: []float32{1, 2, 3}}, nil) 111 112 vec, err := mod.VectorizeObject(ctx, obj, classConfig, repo.Object) 113 assert.Nil(t, err) 114 expectedVec := models.C11yVector{1, 2, 3} 115 assert.EqualValues(t, expectedVec, vec) 116 }) 117 118 t.Run("no refVecs", func(t *testing.T) { 119 repo := &fakeObjectsRepo{} 120 ref := crossref.New("localhost", "SomeClass", strfmt.UUID(uuid.NewString())) 121 obj := &models.Object{Properties: map[string]interface{}{ 122 refProp: models.MultipleRef{ref.SingleRef()}, 123 }} 124 125 repo.On("Object", ctx, ref.Class, ref.TargetID). 126 Return(&search.Result{}, nil) 127 128 _, err := mod.VectorizeObject(ctx, obj, classConfig, repo.Object) 129 assert.Nil(t, err) 130 assert.Nil(t, nil, obj.Vector) 131 }) 132 }) 133 134 t.Run("expected error", func(t *testing.T) { 135 t.Run("mismatched refVec lengths", func(t *testing.T) { 136 repo := &fakeObjectsRepo{} 137 ref1 := crossref.New("localhost", "SomeClass", strfmt.UUID(uuid.NewString())) 138 ref2 := crossref.New("localhost", "OtherClass", strfmt.UUID(uuid.NewString())) 139 obj := &models.Object{Properties: map[string]interface{}{ 140 refProp: models.MultipleRef{ 141 ref1.SingleRef(), 142 ref2.SingleRef(), 143 }, 144 }} 145 expectedErr := fmt.Errorf("calculate vector: calculate mean: " + 146 "found vectors of different length: 2 and 3") 147 148 repo.On("Object", ctx, ref1.Class, ref1.TargetID). 149 Return(&search.Result{Vector: []float32{1, 2}}, nil) 150 repo.On("Object", ctx, ref2.Class, ref2.TargetID). 151 Return(&search.Result{Vector: []float32{1, 2, 3}}, nil) 152 153 _, err := mod.VectorizeObject(ctx, obj, classConfig, repo.Object) 154 assert.EqualError(t, err, expectedErr.Error()) 155 }) 156 }) 157 }) 158 } 159 160 type fakeObjectsRepo struct { 161 mock.Mock 162 } 163 164 func (r *fakeObjectsRepo) Object(ctx context.Context, class string, 165 id strfmt.UUID, props search.SelectProperties, 166 addl additional.Properties, tenant string, 167 ) (*search.Result, error) { 168 args := r.Called(ctx, class, id) 169 if args.Get(0) == nil { 170 return nil, args.Error(1) 171 } 172 return args.Get(0).(*search.Result), args.Error(1) 173 }