github.com/weaviate/weaviate@v1.24.6/adapters/handlers/graphql/local/get/helper_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package get 13 14 import ( 15 "context" 16 "fmt" 17 "net/http" 18 19 "github.com/sirupsen/logrus/hooks/test" 20 "github.com/tailor-inc/graphql" 21 "github.com/tailor-inc/graphql/language/ast" 22 "github.com/weaviate/weaviate/adapters/handlers/graphql/descriptions" 23 test_helper "github.com/weaviate/weaviate/adapters/handlers/graphql/test/helper" 24 "github.com/weaviate/weaviate/entities/dto" 25 "github.com/weaviate/weaviate/entities/models" 26 "github.com/weaviate/weaviate/entities/modulecapabilities" 27 "github.com/weaviate/weaviate/entities/moduletools" 28 "github.com/weaviate/weaviate/entities/search" 29 "github.com/weaviate/weaviate/usecases/config" 30 ) 31 32 type mockRequestsLog struct{} 33 34 func (m *mockRequestsLog) Register(first string, second string) { 35 } 36 37 type mockResolver struct { 38 test_helper.MockResolver 39 } 40 41 type fakeInterpretation struct { 42 returnArgs []search.Result 43 } 44 45 func (f *fakeInterpretation) AdditionalPropertyFn(ctx context.Context, 46 in []search.Result, params interface{}, limit *int, 47 argumentModuleParams map[string]interface{}, 48 ) ([]search.Result, error) { 49 return f.returnArgs, nil 50 } 51 52 func (f *fakeInterpretation) ExtractAdditionalFn(param []*ast.Argument) interface{} { 53 return true 54 } 55 56 func (f *fakeInterpretation) AdditonalPropertyDefaultValue() interface{} { 57 return true 58 } 59 60 type fakeExtender struct { 61 returnArgs []search.Result 62 } 63 64 func (f *fakeExtender) AdditionalPropertyFn(ctx context.Context, 65 in []search.Result, params interface{}, limit *int, 66 argumentModuleParams map[string]interface{}, 67 ) ([]search.Result, error) { 68 return f.returnArgs, nil 69 } 70 71 func (f *fakeExtender) ExtractAdditionalFn(param []*ast.Argument) interface{} { 72 return true 73 } 74 75 func (f *fakeExtender) AdditonalPropertyDefaultValue() interface{} { 76 return true 77 } 78 79 type fakeProjectorParams struct { 80 Enabled bool 81 Algorithm string 82 Dimensions int 83 Perplexity int 84 Iterations int 85 LearningRate int 86 IncludeNeighbors bool 87 } 88 89 type fakeProjector struct { 90 returnArgs []search.Result 91 } 92 93 func (f *fakeProjector) AdditionalPropertyFn(ctx context.Context, 94 in []search.Result, params interface{}, limit *int, 95 argumentModuleParams map[string]interface{}, 96 ) ([]search.Result, error) { 97 return f.returnArgs, nil 98 } 99 100 func (f *fakeProjector) ExtractAdditionalFn(param []*ast.Argument) interface{} { 101 if len(param) > 0 { 102 return &fakeProjectorParams{ 103 Enabled: true, 104 Algorithm: "tsne", 105 Dimensions: 3, 106 Iterations: 100, 107 LearningRate: 15, 108 Perplexity: 10, 109 } 110 } 111 return &fakeProjectorParams{ 112 Enabled: true, 113 } 114 } 115 116 func (f *fakeProjector) AdditonalPropertyDefaultValue() interface{} { 117 return &fakeProjectorParams{} 118 } 119 120 type pathBuilderParams struct{} 121 122 type fakePathBuilder struct { 123 returnArgs []search.Result 124 } 125 126 func (f *fakePathBuilder) AdditionalPropertyFn(ctx context.Context, 127 in []search.Result, params interface{}, limit *int, 128 ) ([]search.Result, error) { 129 return f.returnArgs, nil 130 } 131 132 func (f *fakePathBuilder) ExtractAdditionalFn(param []*ast.Argument) interface{} { 133 return &pathBuilderParams{} 134 } 135 136 func (f *fakePathBuilder) AdditonalPropertyDefaultValue() interface{} { 137 return &pathBuilderParams{} 138 } 139 140 type nearCustomTextParams struct { 141 Values []string 142 MoveTo nearExploreMove 143 MoveAwayFrom nearExploreMove 144 Certainty float64 145 Distance float64 146 WithDistance bool 147 TargetVectors []string 148 } 149 150 // implements the modulecapabilities.NearParam interface 151 func (n *nearCustomTextParams) GetCertainty() float64 { 152 return n.Certainty 153 } 154 155 func (n nearCustomTextParams) GetDistance() float64 { 156 return n.Distance 157 } 158 159 func (n nearCustomTextParams) SimilarityMetricProvided() bool { 160 return n.Certainty != 0 || n.WithDistance 161 } 162 163 func (n nearCustomTextParams) GetTargetVectors() []string { 164 return n.TargetVectors 165 } 166 167 type nearExploreMove struct { 168 Values []string 169 Force float32 170 Objects []nearObjectMove 171 } 172 173 type nearObjectMove struct { 174 ID string 175 Beacon string 176 } 177 178 type nearCustomTextModule struct { 179 fakePathBuilder *fakePathBuilder 180 fakeProjector *fakeProjector 181 fakeExtender *fakeExtender 182 fakeInterpretation *fakeInterpretation 183 } 184 185 func newNearCustomTextModule() *nearCustomTextModule { 186 return &nearCustomTextModule{ 187 fakePathBuilder: &fakePathBuilder{}, 188 fakeProjector: &fakeProjector{}, 189 fakeExtender: &fakeExtender{}, 190 fakeInterpretation: &fakeInterpretation{}, 191 } 192 } 193 194 func (m *nearCustomTextModule) Name() string { 195 return "mock-custom-near-text-module" 196 } 197 198 func (m *nearCustomTextModule) Init(params moduletools.ModuleInitParams) error { 199 return nil 200 } 201 202 func (m *nearCustomTextModule) RootHandler() http.Handler { 203 return nil 204 } 205 206 func (m *nearCustomTextModule) getNearCustomTextArgument(classname string) *graphql.ArgumentConfig { 207 prefix := classname 208 return &graphql.ArgumentConfig{ 209 Type: graphql.NewInputObject( 210 graphql.InputObjectConfig{ 211 Name: fmt.Sprintf("%sNearCustomTextInpObj", prefix), 212 Fields: graphql.InputObjectConfigFieldMap{ 213 "concepts": &graphql.InputObjectFieldConfig{ 214 Type: graphql.NewNonNull(graphql.NewList(graphql.String)), 215 }, 216 "moveTo": &graphql.InputObjectFieldConfig{ 217 Description: descriptions.VectorMovement, 218 Type: graphql.NewInputObject( 219 graphql.InputObjectConfig{ 220 Name: fmt.Sprintf("%sMoveTo", prefix), 221 Fields: graphql.InputObjectConfigFieldMap{ 222 "concepts": &graphql.InputObjectFieldConfig{ 223 Description: descriptions.Keywords, 224 Type: graphql.NewList(graphql.String), 225 }, 226 "objects": &graphql.InputObjectFieldConfig{ 227 Description: "objects", 228 Type: graphql.NewList(graphql.NewInputObject( 229 graphql.InputObjectConfig{ 230 Name: fmt.Sprintf("%sMovementObjectsToInpObj", prefix), 231 Fields: graphql.InputObjectConfigFieldMap{ 232 "id": &graphql.InputObjectFieldConfig{ 233 Type: graphql.String, 234 Description: "id of an object", 235 }, 236 "beacon": &graphql.InputObjectFieldConfig{ 237 Type: graphql.String, 238 Description: descriptions.Beacon, 239 }, 240 }, 241 Description: "Movement Object", 242 }, 243 )), 244 }, 245 "force": &graphql.InputObjectFieldConfig{ 246 Description: descriptions.Force, 247 Type: graphql.NewNonNull(graphql.Float), 248 }, 249 }, 250 }), 251 }, 252 "moveAwayFrom": &graphql.InputObjectFieldConfig{ 253 Description: descriptions.VectorMovement, 254 Type: graphql.NewInputObject( 255 graphql.InputObjectConfig{ 256 Name: fmt.Sprintf("%sMoveAway", prefix), 257 Fields: graphql.InputObjectConfigFieldMap{ 258 "concepts": &graphql.InputObjectFieldConfig{ 259 Description: descriptions.Keywords, 260 Type: graphql.NewList(graphql.String), 261 }, 262 "objects": &graphql.InputObjectFieldConfig{ 263 Description: "objects", 264 Type: graphql.NewList(graphql.NewInputObject( 265 graphql.InputObjectConfig{ 266 Name: fmt.Sprintf("%sMovementObjectsAwayInpObj", prefix), 267 Fields: graphql.InputObjectConfigFieldMap{ 268 "id": &graphql.InputObjectFieldConfig{ 269 Type: graphql.String, 270 Description: "id of an object", 271 }, 272 "beacon": &graphql.InputObjectFieldConfig{ 273 Type: graphql.String, 274 Description: descriptions.Beacon, 275 }, 276 }, 277 Description: "Movement Object", 278 }, 279 )), 280 }, 281 "force": &graphql.InputObjectFieldConfig{ 282 Description: descriptions.Force, 283 Type: graphql.NewNonNull(graphql.Float), 284 }, 285 }, 286 }), 287 }, 288 "certainty": &graphql.InputObjectFieldConfig{ 289 Description: descriptions.Certainty, 290 Type: graphql.Float, 291 }, 292 "distance": &graphql.InputObjectFieldConfig{ 293 Description: descriptions.Distance, 294 Type: graphql.Float, 295 }, 296 "targetVectors": &graphql.InputObjectFieldConfig{ 297 Description: "Target vectors", 298 Type: graphql.NewList(graphql.String), 299 }, 300 }, 301 Description: descriptions.GetWhereInpObj, 302 }, 303 ), 304 } 305 } 306 307 func (m *nearCustomTextModule) extractNearCustomTextArgument(source map[string]interface{}) *nearCustomTextParams { 308 var args nearCustomTextParams 309 310 concepts := source["concepts"].([]interface{}) 311 args.Values = make([]string, len(concepts)) 312 for i, value := range concepts { 313 args.Values[i] = value.(string) 314 } 315 316 certainty, ok := source["certainty"] 317 if ok { 318 args.Certainty = certainty.(float64) 319 } 320 321 distance, ok := source["distance"] 322 if ok { 323 args.Distance = distance.(float64) 324 args.WithDistance = true 325 } 326 327 // moveTo is an optional arg, so it could be nil 328 moveTo, ok := source["moveTo"] 329 if ok { 330 moveToMap := moveTo.(map[string]interface{}) 331 args.MoveTo = m.parseMoveParam(moveToMap) 332 } 333 334 moveAwayFrom, ok := source["moveAwayFrom"] 335 if ok { 336 moveAwayFromMap := moveAwayFrom.(map[string]interface{}) 337 args.MoveAwayFrom = m.parseMoveParam(moveAwayFromMap) 338 } 339 340 return &args 341 } 342 343 func (m *nearCustomTextModule) parseMoveParam(source map[string]interface{}) nearExploreMove { 344 res := nearExploreMove{} 345 res.Force = float32(source["force"].(float64)) 346 347 concepts, ok := source["concepts"].([]interface{}) 348 if ok { 349 res.Values = make([]string, len(concepts)) 350 for i, value := range concepts { 351 res.Values[i] = value.(string) 352 } 353 } 354 355 objects, ok := source["objects"].([]interface{}) 356 if ok { 357 res.Objects = make([]nearObjectMove, len(objects)) 358 for i, value := range objects { 359 v, ok := value.(map[string]interface{}) 360 if ok { 361 if v["id"] != nil { 362 res.Objects[i].ID = v["id"].(string) 363 } 364 if v["beacon"] != nil { 365 res.Objects[i].Beacon = v["beacon"].(string) 366 } 367 } 368 } 369 } 370 371 return res 372 } 373 374 func (m *nearCustomTextModule) Arguments() map[string]modulecapabilities.GraphQLArgument { 375 arguments := map[string]modulecapabilities.GraphQLArgument{} 376 // define nearCustomText argument 377 arguments["nearCustomText"] = modulecapabilities.GraphQLArgument{ 378 GetArgumentsFunction: func(classname string) *graphql.ArgumentConfig { 379 return m.getNearCustomTextArgument(classname) 380 }, 381 ExtractFunction: func(source map[string]interface{}) interface{} { 382 return m.extractNearCustomTextArgument(source) 383 }, 384 ValidateFunction: func(param interface{}) error { 385 // all is valid 386 return nil 387 }, 388 } 389 return arguments 390 } 391 392 // additional properties 393 func (m *nearCustomTextModule) AdditionalProperties() map[string]modulecapabilities.AdditionalProperty { 394 additionalProperties := map[string]modulecapabilities.AdditionalProperty{} 395 additionalProperties["featureProjection"] = m.getFeatureProjection() 396 additionalProperties["nearestNeighbors"] = m.getNearestNeighbors() 397 additionalProperties["semanticPath"] = m.getSemanticPath() 398 additionalProperties["interpretation"] = m.getInterpretation() 399 return additionalProperties 400 } 401 402 func (m *nearCustomTextModule) getFeatureProjection() modulecapabilities.AdditionalProperty { 403 return modulecapabilities.AdditionalProperty{ 404 DefaultValue: m.fakeProjector.AdditonalPropertyDefaultValue(), 405 GraphQLNames: []string{"featureProjection"}, 406 GraphQLFieldFunction: func(classname string) *graphql.Field { 407 return &graphql.Field{ 408 Args: graphql.FieldConfigArgument{ 409 "algorithm": &graphql.ArgumentConfig{ 410 Type: graphql.String, 411 DefaultValue: nil, 412 }, 413 "dimensions": &graphql.ArgumentConfig{ 414 Type: graphql.Int, 415 DefaultValue: nil, 416 }, 417 "learningRate": &graphql.ArgumentConfig{ 418 Type: graphql.Int, 419 DefaultValue: nil, 420 }, 421 "iterations": &graphql.ArgumentConfig{ 422 Type: graphql.Int, 423 DefaultValue: nil, 424 }, 425 "perplexity": &graphql.ArgumentConfig{ 426 Type: graphql.Int, 427 DefaultValue: nil, 428 }, 429 }, 430 Type: graphql.NewObject(graphql.ObjectConfig{ 431 Name: fmt.Sprintf("%sAdditionalFeatureProjection", classname), 432 Fields: graphql.Fields{ 433 "vector": &graphql.Field{Type: graphql.NewList(graphql.Float)}, 434 }, 435 }), 436 } 437 }, 438 GraphQLExtractFunction: m.fakeProjector.ExtractAdditionalFn, 439 } 440 } 441 442 func (m *nearCustomTextModule) getNearestNeighbors() modulecapabilities.AdditionalProperty { 443 return modulecapabilities.AdditionalProperty{ 444 DefaultValue: m.fakeExtender.AdditonalPropertyDefaultValue(), 445 GraphQLNames: []string{"nearestNeighbors"}, 446 GraphQLFieldFunction: func(classname string) *graphql.Field { 447 return &graphql.Field{ 448 Type: graphql.NewObject(graphql.ObjectConfig{ 449 Name: fmt.Sprintf("%sAdditionalNearestNeighbors", classname), 450 Fields: graphql.Fields{ 451 "neighbors": &graphql.Field{Type: graphql.NewList(graphql.NewObject(graphql.ObjectConfig{ 452 Name: fmt.Sprintf("%sAdditionalNearestNeighborsNeighbors", classname), 453 Fields: graphql.Fields{ 454 "concept": &graphql.Field{Type: graphql.String}, 455 "distance": &graphql.Field{Type: graphql.Float}, 456 }, 457 }))}, 458 }, 459 }), 460 } 461 }, 462 GraphQLExtractFunction: m.fakeExtender.ExtractAdditionalFn, 463 } 464 } 465 466 func (m *nearCustomTextModule) getSemanticPath() modulecapabilities.AdditionalProperty { 467 return modulecapabilities.AdditionalProperty{ 468 DefaultValue: m.fakePathBuilder.AdditonalPropertyDefaultValue(), 469 GraphQLNames: []string{"semanticPath"}, 470 GraphQLFieldFunction: func(classname string) *graphql.Field { 471 return &graphql.Field{ 472 Type: graphql.NewObject(graphql.ObjectConfig{ 473 Name: fmt.Sprintf("%sAdditionalSemanticPath", classname), 474 Fields: graphql.Fields{ 475 "path": &graphql.Field{Type: graphql.NewList(graphql.NewObject(graphql.ObjectConfig{ 476 Name: fmt.Sprintf("%sAdditionalSemanticPathElement", classname), 477 Fields: graphql.Fields{ 478 "concept": &graphql.Field{Type: graphql.String}, 479 "distanceToQuery": &graphql.Field{Type: graphql.Float}, 480 "distanceToResult": &graphql.Field{Type: graphql.Float}, 481 "distanceToNext": &graphql.Field{Type: graphql.Float}, 482 "distanceToPrevious": &graphql.Field{Type: graphql.Float}, 483 }, 484 }))}, 485 }, 486 }), 487 } 488 }, 489 GraphQLExtractFunction: m.fakePathBuilder.ExtractAdditionalFn, 490 } 491 } 492 493 func (m *nearCustomTextModule) getInterpretation() modulecapabilities.AdditionalProperty { 494 return modulecapabilities.AdditionalProperty{ 495 DefaultValue: m.fakeInterpretation.AdditonalPropertyDefaultValue(), 496 GraphQLNames: []string{"interpretation"}, 497 GraphQLFieldFunction: func(classname string) *graphql.Field { 498 return &graphql.Field{ 499 Type: graphql.NewObject(graphql.ObjectConfig{ 500 Name: fmt.Sprintf("%sAdditionalInterpretation", classname), 501 Fields: graphql.Fields{ 502 "source": &graphql.Field{Type: graphql.NewList(graphql.NewObject(graphql.ObjectConfig{ 503 Name: fmt.Sprintf("%sAdditionalInterpretationSource", classname), 504 Fields: graphql.Fields{ 505 "concept": &graphql.Field{Type: graphql.String}, 506 "weight": &graphql.Field{Type: graphql.Float}, 507 "occurrence": &graphql.Field{Type: graphql.Int}, 508 }, 509 }))}, 510 }, 511 }), 512 } 513 }, 514 GraphQLExtractFunction: m.fakeInterpretation.ExtractAdditionalFn, 515 } 516 } 517 518 type fakeModulesProvider struct { 519 nearCustomTextModule *nearCustomTextModule 520 } 521 522 func newFakeModulesProvider() *fakeModulesProvider { 523 return &fakeModulesProvider{newNearCustomTextModule()} 524 } 525 526 func (fmp *fakeModulesProvider) GetAll() []modulecapabilities.Module { 527 panic("implement me") 528 } 529 530 func (fmp *fakeModulesProvider) VectorFromInput(ctx context.Context, className, input, targetVector string) ([]float32, error) { 531 panic("not implemented") 532 } 533 534 func (fmp *fakeModulesProvider) GetArguments(class *models.Class) map[string]*graphql.ArgumentConfig { 535 args := map[string]*graphql.ArgumentConfig{} 536 if class.Vectorizer == fmp.nearCustomTextModule.Name() { 537 for name, argument := range fmp.nearCustomTextModule.Arguments() { 538 args[name] = argument.GetArgumentsFunction(class.Class) 539 } 540 } 541 return args 542 } 543 544 func (fmp *fakeModulesProvider) ExtractSearchParams(arguments map[string]interface{}, className string) map[string]interface{} { 545 exractedParams := map[string]interface{}{} 546 if param, ok := arguments["nearCustomText"]; ok { 547 exractedParams["nearCustomText"] = extractNearTextParam(param.(map[string]interface{})) 548 } 549 return exractedParams 550 } 551 552 func (fmp *fakeModulesProvider) GetAdditionalFields(class *models.Class) map[string]*graphql.Field { 553 additionalProperties := map[string]*graphql.Field{} 554 for name, additionalProperty := range fmp.nearCustomTextModule.AdditionalProperties() { 555 if additionalProperty.GraphQLFieldFunction != nil { 556 additionalProperties[name] = additionalProperty.GraphQLFieldFunction(class.Class) 557 } 558 } 559 return additionalProperties 560 } 561 562 func (fmp *fakeModulesProvider) ExtractAdditionalField(className, name string, params []*ast.Argument) interface{} { 563 if additionalProperties := fmp.nearCustomTextModule.AdditionalProperties(); len(additionalProperties) > 0 { 564 if additionalProperty, ok := additionalProperties[name]; ok { 565 if additionalProperty.GraphQLExtractFunction != nil { 566 return additionalProperty.GraphQLExtractFunction(params) 567 } 568 } 569 } 570 return nil 571 } 572 573 func (fmp *fakeModulesProvider) GraphQLAdditionalFieldNames() []string { 574 additionalPropertiesNames := []string{} 575 for _, additionalProperty := range fmp.nearCustomTextModule.AdditionalProperties() { 576 if additionalProperty.GraphQLNames != nil { 577 additionalPropertiesNames = append(additionalPropertiesNames, additionalProperty.GraphQLNames...) 578 } 579 } 580 return additionalPropertiesNames 581 } 582 583 func extractNearTextParam(param map[string]interface{}) interface{} { 584 nearCustomTextModule := newNearCustomTextModule() 585 argument := nearCustomTextModule.Arguments()["nearCustomText"] 586 return argument.ExtractFunction(param) 587 } 588 589 func createArg(name string, value string) *ast.Argument { 590 n := ast.Name{ 591 Value: name, 592 } 593 val := ast.StringValue{ 594 Kind: "Kind", 595 Value: value, 596 } 597 arg := ast.Argument{ 598 Name: ast.NewName(&n), 599 Kind: "Kind", 600 Value: ast.NewStringValue(&val), 601 } 602 a := ast.NewArgument(&arg) 603 return a 604 } 605 606 func extractAdditionalParam(name string, args []*ast.Argument) interface{} { 607 nearCustomTextModule := newNearCustomTextModule() 608 additionalProperties := nearCustomTextModule.AdditionalProperties() 609 switch name { 610 case "semanticPath", "featureProjection": 611 if ap, ok := additionalProperties[name]; ok { 612 return ap.GraphQLExtractFunction(args) 613 } 614 return nil 615 default: 616 return nil 617 } 618 } 619 620 func getFakeModulesProvider() ModulesProvider { 621 return newFakeModulesProvider() 622 } 623 624 func newMockResolver() *mockResolver { 625 return newMockResolverWithVectorizer(config.VectorizerModuleText2VecContextionary) 626 } 627 628 func newMockResolverWithVectorizer(vectorizer string) *mockResolver { 629 logger, _ := test.NewNullLogger() 630 simpleSchema := test_helper.CreateSimpleSchema(vectorizer) 631 field, err := Build(&simpleSchema, logger, getFakeModulesProvider()) 632 if err != nil { 633 panic(fmt.Sprintf("could not build graphql test schema: %s", err)) 634 } 635 mocker := &mockResolver{} 636 mockLog := &mockRequestsLog{} 637 mocker.RootFieldName = "Get" 638 mocker.RootField = field 639 mocker.RootObject = map[string]interface{}{"Resolver": Resolver(mocker), "RequestsLog": RequestsLog(mockLog)} 640 return mocker 641 } 642 643 func newMockResolverWithNoModules() *mockResolver { 644 logger, _ := test.NewNullLogger() 645 field, err := Build(&test_helper.SimpleSchema, logger, nil) 646 if err != nil { 647 panic(fmt.Sprintf("could not build graphql test schema: %s", err)) 648 } 649 mocker := &mockResolver{} 650 mockLog := &mockRequestsLog{} 651 mocker.RootFieldName = "Get" 652 mocker.RootField = field 653 mocker.RootObject = map[string]interface{}{"Resolver": Resolver(mocker), "RequestsLog": RequestsLog(mockLog)} 654 return mocker 655 } 656 657 func (m *mockResolver) GetClass(ctx context.Context, principal *models.Principal, 658 params dto.GetParams, 659 ) ([]interface{}, error) { 660 args := m.Called(params) 661 return args.Get(0).([]interface{}), args.Error(1) 662 }