github.com/weaviate/weaviate@v1.24.6/modules/text2vec-contextionary/helpers_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package modcontextionary 13 14 import ( 15 "context" 16 "fmt" 17 "net/http" 18 19 "github.com/sirupsen/logrus/hooks/test" 20 "github.com/tailor-inc/graphql" 21 "github.com/tailor-inc/graphql/language/ast" 22 "github.com/weaviate/weaviate/adapters/handlers/graphql/local/explore" 23 "github.com/weaviate/weaviate/adapters/handlers/graphql/local/get" 24 test_helper "github.com/weaviate/weaviate/adapters/handlers/graphql/test/helper" 25 "github.com/weaviate/weaviate/entities/dto" 26 "github.com/weaviate/weaviate/entities/models" 27 "github.com/weaviate/weaviate/entities/modulecapabilities" 28 "github.com/weaviate/weaviate/entities/moduletools" 29 "github.com/weaviate/weaviate/entities/search" 30 text2vecadditional "github.com/weaviate/weaviate/modules/text2vec-contextionary/additional" 31 text2vecadditionalsempath "github.com/weaviate/weaviate/modules/text2vec-contextionary/additional/sempath" 32 text2vecadditionalprojector "github.com/weaviate/weaviate/usecases/modulecomponents/additional/projector" 33 text2vecneartext "github.com/weaviate/weaviate/usecases/modulecomponents/arguments/nearText" 34 "github.com/weaviate/weaviate/usecases/traverser" 35 ) 36 37 type mockRequestsLog struct{} 38 39 func (m *mockRequestsLog) Register(first string, second string) { 40 } 41 42 type mockResolver struct { 43 test_helper.MockResolver 44 } 45 46 type fakeInterpretation struct{} 47 48 func (f *fakeInterpretation) AdditionalPropertyFn(ctx context.Context, 49 in []search.Result, params interface{}, limit *int, 50 argumentModuleParams map[string]interface{}, cfg moduletools.ClassConfig, 51 ) ([]search.Result, error) { 52 return in, nil 53 } 54 55 func (f *fakeInterpretation) ExtractAdditionalFn(param []*ast.Argument) interface{} { 56 return true 57 } 58 59 func (f *fakeInterpretation) AdditionalPropertyDefaultValue() interface{} { 60 return true 61 } 62 63 type fakeExtender struct { 64 returnArgs []search.Result 65 } 66 67 func (f *fakeExtender) AdditionalPropertyFn(ctx context.Context, 68 in []search.Result, params interface{}, limit *int, 69 argumentModuleParams map[string]interface{}, cfg moduletools.ClassConfig, 70 ) ([]search.Result, error) { 71 return f.returnArgs, nil 72 } 73 74 func (f *fakeExtender) ExtractAdditionalFn(param []*ast.Argument) interface{} { 75 return true 76 } 77 78 func (f *fakeExtender) AdditionalPropertyDefaultValue() interface{} { 79 return true 80 } 81 82 type fakeProjector struct { 83 returnArgs []search.Result 84 } 85 86 func (f *fakeProjector) AdditionalPropertyFn(ctx context.Context, 87 in []search.Result, params interface{}, limit *int, 88 argumentModuleParams map[string]interface{}, cfg moduletools.ClassConfig, 89 ) ([]search.Result, error) { 90 return f.returnArgs, nil 91 } 92 93 func (f *fakeProjector) ExtractAdditionalFn(param []*ast.Argument) interface{} { 94 if len(param) > 0 { 95 p := &text2vecadditionalprojector.Params{} 96 err := p.SetDefaultsAndValidate(100, 4) 97 if err != nil { 98 return nil 99 } 100 return p 101 } 102 return &text2vecadditionalprojector.Params{ 103 Enabled: true, 104 } 105 } 106 107 func (f *fakeProjector) AdditionalPropertyDefaultValue() interface{} { 108 return &text2vecadditionalprojector.Params{} 109 } 110 111 type fakePathBuilder struct { 112 returnArgs []search.Result 113 } 114 115 func (f *fakePathBuilder) AdditionalPropertyFn(ctx context.Context, 116 in []search.Result, params interface{}, limit *int, 117 argumentModuleParams map[string]interface{}, cfg moduletools.ClassConfig, 118 ) ([]search.Result, error) { 119 return f.returnArgs, nil 120 } 121 122 func (f *fakePathBuilder) ExtractAdditionalFn(param []*ast.Argument) interface{} { 123 return &text2vecadditionalsempath.Params{} 124 } 125 126 func (f *fakePathBuilder) AdditionalPropertyDefaultValue() interface{} { 127 return &text2vecadditionalsempath.Params{} 128 } 129 130 type mockText2vecContextionaryModule struct{} 131 132 func (m *mockText2vecContextionaryModule) Name() string { 133 return "text2vec-contextionary" 134 } 135 136 func (m *mockText2vecContextionaryModule) Init(params moduletools.ModuleInitParams) error { 137 return nil 138 } 139 140 func (m *mockText2vecContextionaryModule) RootHandler() http.Handler { 141 return nil 142 } 143 144 // graphql arguments 145 func (m *mockText2vecContextionaryModule) Arguments() map[string]modulecapabilities.GraphQLArgument { 146 return text2vecneartext.New(nil).Arguments() 147 } 148 149 // additional properties 150 func (m *mockText2vecContextionaryModule) AdditionalProperties() map[string]modulecapabilities.AdditionalProperty { 151 return text2vecadditional.New(&fakeExtender{}, &fakeProjector{}, &fakePathBuilder{}, &fakeInterpretation{}).AdditionalProperties() 152 } 153 154 type fakeModulesProvider struct{} 155 156 func (fmp *fakeModulesProvider) GetAll() []modulecapabilities.Module { 157 panic("implement me") 158 } 159 160 func (fmp *fakeModulesProvider) VectorFromInput(ctx context.Context, className string, input string) ([]float32, error) { 161 panic("not implemented") 162 } 163 164 func (fmp *fakeModulesProvider) GetArguments(class *models.Class) map[string]*graphql.ArgumentConfig { 165 args := map[string]*graphql.ArgumentConfig{} 166 txt2vec := &mockText2vecContextionaryModule{} 167 if class.Vectorizer == txt2vec.Name() { 168 for name, argument := range txt2vec.Arguments() { 169 args[name] = argument.GetArgumentsFunction(class.Class) 170 } 171 } 172 return args 173 } 174 175 func (fmp *fakeModulesProvider) ExploreArguments(schema *models.Schema) map[string]*graphql.ArgumentConfig { 176 args := map[string]*graphql.ArgumentConfig{} 177 txt2vec := &mockText2vecContextionaryModule{} 178 for _, c := range schema.Classes { 179 if c.Vectorizer == txt2vec.Name() { 180 for name, argument := range txt2vec.Arguments() { 181 args[name] = argument.ExploreArgumentsFunction() 182 } 183 } 184 } 185 return args 186 } 187 188 func (fmp *fakeModulesProvider) CrossClassExtractSearchParams(arguments map[string]interface{}) map[string]interface{} { 189 return fmp.ExtractSearchParams(arguments, "") 190 } 191 192 func (fmp *fakeModulesProvider) ExtractSearchParams(arguments map[string]interface{}, className string) map[string]interface{} { 193 exractedParams := map[string]interface{}{} 194 if param, ok := arguments["nearText"]; ok { 195 exractedParams["nearText"] = extractNearTextParam(param.(map[string]interface{})) 196 } 197 return exractedParams 198 } 199 200 func (fmp *fakeModulesProvider) GetAdditionalFields(class *models.Class) map[string]*graphql.Field { 201 txt2vec := &mockText2vecContextionaryModule{} 202 additionalProperties := map[string]*graphql.Field{} 203 for name, additionalProperty := range txt2vec.AdditionalProperties() { 204 if additionalProperty.GraphQLFieldFunction != nil { 205 additionalProperties[name] = additionalProperty.GraphQLFieldFunction(class.Class) 206 } 207 } 208 return additionalProperties 209 } 210 211 func (fmp *fakeModulesProvider) ExtractAdditionalField(className, name string, params []*ast.Argument) interface{} { 212 txt2vec := &mockText2vecContextionaryModule{} 213 if additionalProperties := txt2vec.AdditionalProperties(); len(additionalProperties) > 0 { 214 if additionalProperty, ok := additionalProperties[name]; ok { 215 if additionalProperty.GraphQLExtractFunction != nil { 216 return additionalProperty.GraphQLExtractFunction(params) 217 } 218 } 219 } 220 return nil 221 } 222 223 func (fmp *fakeModulesProvider) GetExploreAdditionalExtend(ctx context.Context, in []search.Result, 224 moduleParams map[string]interface{}, searchVector []float32, 225 argumentModuleParams map[string]interface{}, cfg moduletools.ClassConfig, 226 ) ([]search.Result, error) { 227 return fmp.additionalExtend(ctx, in, moduleParams, searchVector, "ExploreGet", argumentModuleParams, nil) 228 } 229 230 func (fmp *fakeModulesProvider) additionalExtend(ctx context.Context, 231 in search.Results, moduleParams map[string]interface{}, 232 searchVector []float32, capability string, argumentModuleParams map[string]interface{}, cfg moduletools.ClassConfig, 233 ) (search.Results, error) { 234 txt2vec := &mockText2vecContextionaryModule{} 235 additionalProperties := txt2vec.AdditionalProperties() 236 for name, value := range moduleParams { 237 additionalPropertyFn := fmp.getAdditionalPropertyFn(additionalProperties[name], capability) 238 if additionalPropertyFn != nil && value != nil { 239 searchValue := value 240 if searchVectorValue, ok := value.(modulecapabilities.AdditionalPropertyWithSearchVector); ok { 241 searchVectorValue.SetSearchVector(searchVector) 242 searchValue = searchVectorValue 243 } 244 resArray, err := additionalPropertyFn(ctx, in, searchValue, nil, nil, nil) 245 if err != nil { 246 return nil, err 247 } 248 in = resArray 249 } 250 } 251 return in, nil 252 } 253 254 func (fmp *fakeModulesProvider) getAdditionalPropertyFn(additionalProperty modulecapabilities.AdditionalProperty, 255 capability string, 256 ) modulecapabilities.AdditionalPropertyFn { 257 switch capability { 258 case "ObjectGet": 259 return additionalProperty.SearchFunctions.ObjectGet 260 case "ObjectList": 261 return additionalProperty.SearchFunctions.ObjectList 262 case "ExploreGet": 263 return additionalProperty.SearchFunctions.ExploreGet 264 case "ExploreList": 265 return additionalProperty.SearchFunctions.ExploreList 266 default: 267 return nil 268 } 269 } 270 271 func (fmp *fakeModulesProvider) GraphQLAdditionalFieldNames() []string { 272 txt2vec := &mockText2vecContextionaryModule{} 273 additionalPropertiesNames := []string{} 274 for _, additionalProperty := range txt2vec.AdditionalProperties() { 275 if additionalProperty.GraphQLNames != nil { 276 additionalPropertiesNames = append(additionalPropertiesNames, additionalProperty.GraphQLNames...) 277 } 278 } 279 return additionalPropertiesNames 280 } 281 282 func extractNearTextParam(param map[string]interface{}) interface{} { 283 txt2vec := &mockText2vecContextionaryModule{} 284 argument := txt2vec.Arguments()["nearText"] 285 return argument.ExtractFunction(param) 286 } 287 288 func createArg(name string, value string) *ast.Argument { 289 n := ast.Name{ 290 Value: name, 291 } 292 val := ast.StringValue{ 293 Kind: "Kind", 294 Value: value, 295 } 296 arg := ast.Argument{ 297 Name: ast.NewName(&n), 298 Kind: "Kind", 299 Value: ast.NewStringValue(&val), 300 } 301 a := ast.NewArgument(&arg) 302 return a 303 } 304 305 func extractAdditionalParam(name string, args []*ast.Argument) interface{} { 306 txt2vec := &mockText2vecContextionaryModule{} 307 additionalProperties := txt2vec.AdditionalProperties() 308 switch name { 309 case "semanticPath", "featureProjection": 310 if ap, ok := additionalProperties[name]; ok { 311 return ap.GraphQLExtractFunction(args) 312 } 313 return nil 314 default: 315 return nil 316 } 317 } 318 319 func getFakeModulesProvider() *fakeModulesProvider { 320 return &fakeModulesProvider{} 321 } 322 323 func newMockResolver() *mockResolver { 324 logger, _ := test.NewNullLogger() 325 field, err := get.Build(&test_helper.SimpleSchema, logger, getFakeModulesProvider()) 326 if err != nil { 327 panic(fmt.Sprintf("could not build graphql test schema: %s", err)) 328 } 329 mocker := &mockResolver{} 330 mockLog := &mockRequestsLog{} 331 mocker.RootFieldName = "Get" 332 mocker.RootField = field 333 mocker.RootObject = map[string]interface{}{"Resolver": GetResolver(mocker), "RequestsLog": RequestsLog(mockLog)} 334 return mocker 335 } 336 337 func newExploreMockResolver() *mockResolver { 338 field := explore.Build(test_helper.SimpleSchema.Objects, getFakeModulesProvider()) 339 mocker := &mockResolver{} 340 mockLog := &mockRequestsLog{} 341 mocker.RootFieldName = "Explore" 342 mocker.RootField = field 343 mocker.RootObject = map[string]interface{}{ 344 "Resolver": ExploreResolver(mocker), 345 "RequestsLog": mockLog, 346 } 347 return mocker 348 } 349 350 func (m *mockResolver) GetClass(ctx context.Context, principal *models.Principal, 351 params dto.GetParams, 352 ) ([]interface{}, error) { 353 args := m.Called(params) 354 return args.Get(0).([]interface{}), args.Error(1) 355 } 356 357 func (m *mockResolver) Explore(ctx context.Context, 358 principal *models.Principal, params traverser.ExploreParams, 359 ) ([]search.Result, error) { 360 args := m.Called(params) 361 return args.Get(0).([]search.Result), args.Error(1) 362 } 363 364 // Resolver is a local abstraction of the required UC resolvers 365 type GetResolver interface { 366 GetClass(ctx context.Context, principal *models.Principal, info dto.GetParams) ([]interface{}, error) 367 } 368 369 type ExploreResolver interface { 370 Explore(ctx context.Context, principal *models.Principal, params traverser.ExploreParams) ([]search.Result, error) 371 } 372 373 // RequestsLog is a local abstraction on the RequestsLog that needs to be 374 // provided to the graphQL API in order to log Local.Get queries. 375 type RequestsLog interface { 376 Register(requestType string, identifier string) 377 }