github.com/weaviate/weaviate@v1.24.6/usecases/modules/searchers_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package modules 13 14 import ( 15 "context" 16 "testing" 17 18 "github.com/go-openapi/strfmt" 19 "github.com/sirupsen/logrus/hooks/test" 20 "github.com/stretchr/testify/assert" 21 "github.com/stretchr/testify/require" 22 "github.com/weaviate/weaviate/entities/models" 23 "github.com/weaviate/weaviate/entities/modulecapabilities" 24 "github.com/weaviate/weaviate/entities/moduletools" 25 "github.com/weaviate/weaviate/entities/schema" 26 ) 27 28 func TestModulesWithSearchers(t *testing.T) { 29 sch := schema.Schema{ 30 Objects: &models.Schema{ 31 Classes: []*models.Class{ 32 { 33 Class: "MyClass", 34 Vectorizer: "mod", 35 ModuleConfig: map[string]interface{}{ 36 "mod": map[string]interface{}{ 37 "some-config": "some-config-value", 38 }, 39 }, 40 }, 41 }, 42 }, 43 } 44 logger, _ := test.NewNullLogger() 45 46 t.Run("get a vector for a class", func(t *testing.T) { 47 p := NewProvider() 48 p.SetSchemaGetter(&fakeSchemaGetter{ 49 schema: sch, 50 }) 51 p.Register(newSearcherModule("mod"). 52 withArg("nearGrape"). 53 withSearcher("nearGrape", func(ctx context.Context, params interface{}, 54 className string, 55 findVectorFn modulecapabilities.FindVectorFn, 56 cfg moduletools.ClassConfig, 57 ) ([]float32, error) { 58 // verify that the config tool is set, as this is a per-class search, 59 // so it must be set 60 assert.NotNil(t, cfg) 61 62 // take the findVectorFn and append one dimension. This doesn't make too 63 // much sense, but helps verify that the modules method was used in the 64 // decisions 65 initial, _, _ := findVectorFn(ctx, "class", "123", "", "") 66 return append(initial, 4), nil 67 }), 68 ) 69 p.Init(context.Background(), nil, logger) 70 71 res, targetVector, err := p.VectorFromSearchParam(context.Background(), "MyClass", 72 "nearGrape", nil, fakeFindVector, "") 73 74 require.Nil(t, err) 75 assert.Equal(t, []float32{1, 2, 3, 4}, res) 76 assert.Equal(t, "", targetVector) 77 }) 78 79 t.Run("get a vector across classes", func(t *testing.T) { 80 p := NewProvider() 81 p.SetSchemaGetter(&fakeSchemaGetter{ 82 schema: sch, 83 }) 84 p.Register(newSearcherModule("mod"). 85 withArg("nearGrape"). 86 withSearcher("nearGrape", func(ctx context.Context, params interface{}, 87 className string, 88 findVectorFn modulecapabilities.FindVectorFn, 89 cfg moduletools.ClassConfig, 90 ) ([]float32, error) { 91 // this is a cross-class search, such as is used for Explore{}, in this 92 // case we do not have class-based config, but we need at least pass 93 // a tenant information, that's why we pass an empty config with empty tenant 94 // so that it would be possible to perform cross class searches, without 95 // tenant context. Modules must be able to deal with this situation! 96 assert.NotNil(t, cfg) 97 assert.Equal(t, "", cfg.Tenant()) 98 99 // take the findVectorFn and append one dimension. This doesn't make too 100 // much sense, but helps verify that the modules method was used in the 101 // decisions 102 initial, _, _ := findVectorFn(ctx, "class", "123", "", "") 103 return append(initial, 4), nil 104 }), 105 ) 106 p.Init(context.Background(), nil, logger) 107 108 res, targetVector, err := p.CrossClassVectorFromSearchParam(context.Background(), 109 "nearGrape", nil, fakeFindVector) 110 111 require.Nil(t, err) 112 assert.Equal(t, []float32{1, 2, 3, 4}, res) 113 assert.Equal(t, "", targetVector) 114 }) 115 } 116 117 func fakeFindVector(ctx context.Context, className string, id strfmt.UUID, tenant, targetVector string) ([]float32, string, error) { 118 return []float32{1, 2, 3}, targetVector, nil 119 } 120 121 func newSearcherModule(name string) *dummySearcherModule { 122 return &dummySearcherModule{ 123 dummyGraphQLModule: newGraphQLModule(name), 124 searchers: map[string]modulecapabilities.VectorForParams{}, 125 } 126 } 127 128 type dummySearcherModule struct { 129 *dummyGraphQLModule 130 searchers map[string]modulecapabilities.VectorForParams 131 } 132 133 func (m *dummySearcherModule) withArg(arg string) *dummySearcherModule { 134 // call the super's withArg 135 m.dummyGraphQLModule.withArg(arg) 136 137 // but don't return their return type but ours :) 138 return m 139 } 140 141 // a helper for our test 142 func (m *dummySearcherModule) withSearcher(arg string, 143 impl modulecapabilities.VectorForParams, 144 ) *dummySearcherModule { 145 m.searchers[arg] = impl 146 return m 147 } 148 149 // public method to implement the modulecapabilities.Searcher interface 150 func (m *dummySearcherModule) VectorSearches() map[string]modulecapabilities.VectorForParams { 151 return m.searchers 152 }