github.com/weaviate/weaviate@v1.24.6/modules/ref2vec-centroid/vectorizer/vectorizer_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package vectorizer 13 14 import ( 15 "context" 16 "errors" 17 "reflect" 18 "testing" 19 20 "github.com/go-openapi/strfmt" 21 "github.com/google/uuid" 22 "github.com/stretchr/testify/assert" 23 "github.com/weaviate/weaviate/entities/models" 24 "github.com/weaviate/weaviate/entities/schema/crossref" 25 "github.com/weaviate/weaviate/entities/search" 26 "github.com/weaviate/weaviate/modules/ref2vec-centroid/config" 27 ) 28 29 func TestVectorizer_New(t *testing.T) { 30 repo := &fakeObjectsRepo{} 31 t.Run("default is set correctly", func(t *testing.T) { 32 vzr := New(fakeClassConfig(config.Default()), repo.Object) 33 34 expected := reflect.ValueOf(calculateMean).Pointer() 35 received := reflect.ValueOf(vzr.calcFn).Pointer() 36 37 assert.EqualValues(t, expected, received) 38 }) 39 40 t.Run("default calcFn is used when none provided", func(t *testing.T) { 41 cfg := fakeClassConfig{"method": ""} 42 vzr := New(cfg, repo.Object) 43 44 expected := reflect.ValueOf(calculateMean).Pointer() 45 received := reflect.ValueOf(vzr.calcFn).Pointer() 46 47 assert.EqualValues(t, expected, received) 48 }) 49 } 50 51 func TestVectorizer_Object(t *testing.T) { 52 t.Run("calculate with mean", func(t *testing.T) { 53 type objectSearchResult struct { 54 res *search.Result 55 err error 56 } 57 58 tests := []struct { 59 name string 60 objectSearchResults []objectSearchResult 61 expectedResult []float32 62 expectedCalcError error 63 }{ 64 { 65 name: "expected success 1", 66 objectSearchResults: []objectSearchResult{ 67 {res: &search.Result{Vector: []float32{2, 4, 6}}}, 68 {res: &search.Result{Vector: []float32{4, 6, 8}}}, 69 }, 70 expectedResult: []float32{3, 5, 7}, 71 }, 72 { 73 name: "expected success 2", 74 objectSearchResults: []objectSearchResult{ 75 {res: &search.Result{Vector: []float32{1, 1, 1, 1, 1, 1, 1, 1, 1, 1}}}, 76 {res: &search.Result{Vector: []float32{2, 2, 2, 2, 2, 2, 2, 2, 2, 2}}}, 77 {res: &search.Result{Vector: []float32{3, 3, 3, 3, 3, 3, 3, 3, 3, 3}}}, 78 {res: &search.Result{Vector: []float32{4, 4, 4, 4, 4, 4, 4, 4, 4, 4}}}, 79 {res: &search.Result{Vector: []float32{5, 5, 5, 5, 5, 5, 5, 5, 5, 5}}}, 80 {res: &search.Result{Vector: []float32{6, 6, 6, 6, 6, 6, 6, 6, 6, 6}}}, 81 {res: &search.Result{Vector: []float32{7, 7, 7, 7, 7, 7, 7, 7, 7, 7}}}, 82 {res: &search.Result{Vector: []float32{8, 8, 8, 8, 8, 8, 8, 8, 8, 8}}}, 83 {res: &search.Result{Vector: []float32{9, 9, 9, 9, 9, 9, 9, 9, 9, 9}}}, 84 }, 85 expectedResult: []float32{5, 5, 5, 5, 5, 5, 5, 5, 5, 5}, 86 }, 87 { 88 name: "expected success 3", 89 objectSearchResults: []objectSearchResult{{}}, 90 }, 91 { 92 name: "expected success 4", 93 objectSearchResults: []objectSearchResult{ 94 {res: &search.Result{Vector: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9}}}, 95 }, 96 expectedResult: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9}, 97 }, 98 { 99 name: "expected success 5", 100 objectSearchResults: []objectSearchResult{ 101 {res: &search.Result{}}, 102 }, 103 expectedResult: nil, 104 }, 105 { 106 name: "expected error - mismatched vector dimensions", 107 objectSearchResults: []objectSearchResult{ 108 {res: &search.Result{Vector: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9}}}, 109 {res: &search.Result{Vector: []float32{1, 2, 3, 4, 5, 6, 7, 8}}}, 110 }, 111 expectedCalcError: errors.New( 112 "calculate vector: calculate mean: found vectors of different length: 9 and 8"), 113 }, 114 } 115 116 for _, test := range tests { 117 t.Run(test.name, func(t *testing.T) { 118 ctx := context.Background() 119 repo := &fakeObjectsRepo{} 120 refProps := []interface{}{"toRef"} 121 cfg := fakeClassConfig{"method": "mean", "referenceProperties": refProps} 122 123 crossRefs := make([]*crossref.Ref, len(test.objectSearchResults)) 124 modelRefs := make(models.MultipleRef, len(test.objectSearchResults)) 125 for i, res := range test.objectSearchResults { 126 crossRef := crossref.New("localhost", "SomeClass", 127 strfmt.UUID(uuid.NewString())) 128 crossRefs[i] = crossRef 129 modelRefs[i] = crossRef.SingleRef() 130 131 repo.On("Object", ctx, crossRef.Class, crossRef.TargetID, ""). 132 Return(res.res, res.err) 133 } 134 135 obj := &models.Object{ 136 Properties: map[string]interface{}{"toRef": modelRefs}, 137 } 138 139 vec, err := New(cfg, repo.Object).Object(ctx, obj) 140 if test.expectedCalcError != nil { 141 assert.EqualError(t, err, test.expectedCalcError.Error()) 142 } else { 143 assert.EqualValues(t, test.expectedResult, vec) 144 } 145 }) 146 } 147 }) 148 149 // due to the fix introduced in https://github.com/weaviate/weaviate/pull/2320, 150 // MultipleRef's can appear as empty []interface{} when no actual refs are provided for 151 // an object's reference property. 152 // 153 // this test asserts that reference properties do not break when they are unmarshalled 154 // as empty interface{} slices. 155 t.Run("when rep prop is stored as empty interface{} slice", func(t *testing.T) { 156 ctx := context.Background() 157 repo := &fakeObjectsRepo{} 158 refProps := []interface{}{"toRef"} 159 cfg := fakeClassConfig{"method": "mean", "referenceProperties": refProps} 160 161 obj := &models.Object{ 162 Properties: map[string]interface{}{"toRef": []interface{}{}}, 163 } 164 165 _, err := New(cfg, repo.Object).Object(ctx, obj) 166 assert.Nil(t, err) 167 assert.Nil(t, obj.Vector) 168 }) 169 } 170 171 func TestVectorizer_Tenant(t *testing.T) { 172 objectSearchResults := search.Result{Vector: []float32{}} 173 ctx := context.Background() 174 repo := &fakeObjectsRepo{} 175 refProps := []interface{}{"toRef"} 176 cfg := fakeClassConfig{"method": "mean", "referenceProperties": refProps} 177 tenant := "randomTenant" 178 179 crossRef := crossref.New("localhost", "SomeClass", 180 strfmt.UUID(uuid.NewString())) 181 modelRefs := models.MultipleRef{crossRef.SingleRef()} 182 183 repo.On("Object", ctx, crossRef.Class, crossRef.TargetID, tenant). 184 Return(&objectSearchResults, nil) 185 186 obj := &models.Object{ 187 Properties: map[string]interface{}{"toRef": modelRefs}, 188 Tenant: tenant, 189 } 190 191 _, err := New(cfg, repo.Object).Object(ctx, obj) 192 assert.Nil(t, err) 193 assert.Nil(t, obj.Vector) 194 }