github.com/weaviate/weaviate@v1.24.6/usecases/modules/modules.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package modules 13 14 import ( 15 "context" 16 "fmt" 17 "regexp" 18 "sync" 19 20 "github.com/pkg/errors" 21 "github.com/sirupsen/logrus" 22 "github.com/tailor-inc/graphql" 23 "github.com/tailor-inc/graphql/language/ast" 24 "github.com/weaviate/weaviate/entities/models" 25 "github.com/weaviate/weaviate/entities/modulecapabilities" 26 "github.com/weaviate/weaviate/entities/moduletools" 27 "github.com/weaviate/weaviate/entities/schema" 28 "github.com/weaviate/weaviate/entities/search" 29 "github.com/weaviate/weaviate/usecases/modulecomponents" 30 ) 31 32 var ( 33 internalSearchers = []string{ 34 "nearObject", "nearVector", "where", "group", "limit", "offset", 35 "after", "groupBy", "bm25", "hybrid", 36 } 37 internalAdditionalProperties = []string{"classification", "certainty", "id", "distance", "group"} 38 ) 39 40 type Provider struct { 41 vectorsLock sync.RWMutex 42 registered map[string]modulecapabilities.Module 43 altNames map[string]string 44 schemaGetter schemaGetter 45 hasMultipleVectorizers bool 46 targetVectorNameValidator *regexp.Regexp 47 } 48 49 type schemaGetter interface { 50 GetSchemaSkipAuth() schema.Schema 51 } 52 53 func NewProvider() *Provider { 54 return &Provider{ 55 registered: map[string]modulecapabilities.Module{}, 56 altNames: map[string]string{}, 57 targetVectorNameValidator: regexp.MustCompile(`^` + schema.TargetVectorNameRegex + `$`), 58 } 59 } 60 61 func (p *Provider) Register(mod modulecapabilities.Module) { 62 p.registered[mod.Name()] = mod 63 if modHasAltNames, ok := mod.(modulecapabilities.ModuleHasAltNames); ok { 64 for _, altName := range modHasAltNames.AltNames() { 65 p.altNames[altName] = mod.Name() 66 } 67 } 68 } 69 70 func (p *Provider) GetByName(name string) modulecapabilities.Module { 71 if mod, ok := p.registered[name]; ok { 72 return mod 73 } 74 if origName, ok := p.altNames[name]; ok { 75 return p.registered[origName] 76 } 77 return nil 78 } 79 80 func (p *Provider) GetAll() []modulecapabilities.Module { 81 out := make([]modulecapabilities.Module, len(p.registered)) 82 i := 0 83 for _, mod := range p.registered { 84 out[i] = mod 85 i++ 86 } 87 88 return out 89 } 90 91 func (p *Provider) GetAllExclude(module string) []modulecapabilities.Module { 92 filtered := []modulecapabilities.Module{} 93 for _, mod := range p.GetAll() { 94 if mod.Name() != module { 95 filtered = append(filtered, mod) 96 } 97 } 98 return filtered 99 } 100 101 func (p *Provider) SetSchemaGetter(sg schemaGetter) { 102 p.schemaGetter = sg 103 } 104 105 func (p *Provider) Init(ctx context.Context, 106 params moduletools.ModuleInitParams, logger logrus.FieldLogger, 107 ) error { 108 for i, mod := range p.GetAll() { 109 if err := mod.Init(ctx, params); err != nil { 110 return errors.Wrapf(err, "init module %d (%q)", i, mod.Name()) 111 } else { 112 logger.WithField("action", "startup"). 113 WithField("module", mod.Name()). 114 Debug("initialized module") 115 } 116 } 117 for i, mod := range p.GetAll() { 118 if modExtension, ok := mod.(modulecapabilities.ModuleExtension); ok { 119 if err := modExtension.InitExtension(p.GetAllExclude(mod.Name())); err != nil { 120 return errors.Wrapf(err, "init module extension %d (%q)", i, mod.Name()) 121 } else { 122 logger.WithField("action", "startup"). 123 WithField("module", mod.Name()). 124 Debug("initialized module extension") 125 } 126 } 127 } 128 for i, mod := range p.GetAll() { 129 if modDependency, ok := mod.(modulecapabilities.ModuleDependency); ok { 130 if err := modDependency.InitDependency(p.GetAllExclude(mod.Name())); err != nil { 131 return errors.Wrapf(err, "init module dependency %d (%q)", i, mod.Name()) 132 } else { 133 logger.WithField("action", "startup"). 134 WithField("module", mod.Name()). 135 Debug("initialized module dependency") 136 } 137 } 138 } 139 if err := p.validate(); err != nil { 140 return errors.Wrap(err, "validate modules") 141 } 142 if p.HasMultipleVectorizers() { 143 logger.Warn("Multiple vector spaces are present, GraphQL Explore and REST API list objects endpoint module include params has been disabled as a result.") 144 } 145 return nil 146 } 147 148 func (p *Provider) validate() error { 149 searchers := map[string][]string{} 150 additionalGraphQLProps := map[string][]string{} 151 additionalRestAPIProps := map[string][]string{} 152 for _, mod := range p.GetAll() { 153 if module, ok := mod.(modulecapabilities.GraphQLArguments); ok { 154 allArguments := []string{} 155 for paraName, argument := range module.Arguments() { 156 if argument.ExtractFunction != nil { 157 allArguments = append(allArguments, paraName) 158 } 159 } 160 searchers = p.scanProperties(searchers, allArguments, mod.Name()) 161 } 162 if module, ok := mod.(modulecapabilities.AdditionalProperties); ok { 163 allAdditionalRestAPIProps, allAdditionalGrapQLProps := p.getAdditionalProps(module.AdditionalProperties()) 164 additionalGraphQLProps = p.scanProperties(additionalGraphQLProps, 165 allAdditionalGrapQLProps, mod.Name()) 166 additionalRestAPIProps = p.scanProperties(additionalRestAPIProps, 167 allAdditionalRestAPIProps, mod.Name()) 168 } 169 } 170 171 var errorMessages []string 172 errorMessages = append(errorMessages, 173 p.validateModules("searcher", searchers, internalSearchers)...) 174 errorMessages = append(errorMessages, 175 p.validateModules("graphql additional property", additionalGraphQLProps, internalAdditionalProperties)...) 176 errorMessages = append(errorMessages, 177 p.validateModules("rest api additional property", additionalRestAPIProps, internalAdditionalProperties)...) 178 if len(errorMessages) > 0 { 179 return errors.Errorf("%v", errorMessages) 180 } 181 182 return nil 183 } 184 185 func (p *Provider) scanProperties(result map[string][]string, properties []string, module string) map[string][]string { 186 for i := range properties { 187 if result[properties[i]] == nil { 188 result[properties[i]] = []string{} 189 } 190 modules := result[properties[i]] 191 modules = append(modules, module) 192 result[properties[i]] = modules 193 } 194 return result 195 } 196 197 func (p *Provider) getAdditionalProps(additionalProps map[string]modulecapabilities.AdditionalProperty) ([]string, []string) { 198 restProps := []string{} 199 graphQLProps := []string{} 200 201 for _, additionalProperty := range additionalProps { 202 if additionalProperty.RestNames != nil { 203 restProps = append(restProps, additionalProperty.RestNames...) 204 } 205 if additionalProperty.GraphQLNames != nil { 206 graphQLProps = append(graphQLProps, additionalProperty.GraphQLNames...) 207 } 208 } 209 return restProps, graphQLProps 210 } 211 212 func (p *Provider) validateModules(name string, properties map[string][]string, internalProperties []string) []string { 213 errorMessages := []string{} 214 for propertyName, modules := range properties { 215 for i := range internalProperties { 216 if internalProperties[i] == propertyName { 217 errorMessages = append(errorMessages, 218 fmt.Sprintf("%s: %s conflicts with weaviate's internal searcher in modules: %v", 219 name, propertyName, modules)) 220 } 221 } 222 if len(modules) > 1 { 223 p.hasMultipleVectorizers = true 224 } 225 for _, moduleName := range modules { 226 moduleType := p.GetByName(moduleName).Type() 227 if p.moduleProvidesMultipleVectorizers(moduleType) { 228 p.hasMultipleVectorizers = true 229 } 230 } 231 } 232 return errorMessages 233 } 234 235 func (p *Provider) moduleProvidesMultipleVectorizers(moduleType modulecapabilities.ModuleType) bool { 236 return moduleType == modulecapabilities.Text2MultiVec 237 } 238 239 func (p *Provider) isOnlyOneModuleEnabledOfAGivenType(moduleType modulecapabilities.ModuleType) bool { 240 i := 0 241 for _, mod := range p.registered { 242 if mod.Type() == moduleType { 243 i++ 244 } 245 } 246 return i == 1 247 } 248 249 func (p *Provider) isVectorizerModule(moduleType modulecapabilities.ModuleType) bool { 250 switch moduleType { 251 case modulecapabilities.Text2Vec, 252 modulecapabilities.Img2Vec, 253 modulecapabilities.Multi2Vec, 254 modulecapabilities.Text2MultiVec, 255 modulecapabilities.Ref2Vec: 256 return true 257 default: 258 return false 259 } 260 } 261 262 func (p *Provider) shouldIncludeClassArgument(class *models.Class, module string, 263 moduleType modulecapabilities.ModuleType, 264 ) bool { 265 if p.isVectorizerModule(moduleType) { 266 for _, vectorConfig := range class.VectorConfig { 267 if vectorizer, ok := vectorConfig.Vectorizer.(map[string]interface{}); ok { 268 if _, ok := vectorizer[module]; ok { 269 return true 270 } 271 } 272 } 273 return class.Vectorizer == module 274 } 275 if moduleConfig, ok := class.ModuleConfig.(map[string]interface{}); ok { 276 existsConfigForModule := moduleConfig[module] != nil 277 if existsConfigForModule { 278 return true 279 } 280 } 281 // Allow Text2Text (Generative, QnA, Summarize, NER) modules to be registered to a given class 282 // only if there's no configuration present and there's only one module of a given type enabled 283 return p.isOnlyOneModuleEnabledOfAGivenType(moduleType) 284 } 285 286 func (p *Provider) shouldCrossClassIncludeClassArgument(class *models.Class, module string, 287 moduleType modulecapabilities.ModuleType, 288 ) bool { 289 if class == nil { 290 return !p.HasMultipleVectorizers() 291 } 292 return p.shouldIncludeClassArgument(class, module, moduleType) 293 } 294 295 func (p *Provider) shouldIncludeArgument(schema *models.Schema, module string, 296 moduleType modulecapabilities.ModuleType, 297 ) bool { 298 for _, c := range schema.Classes { 299 if p.shouldIncludeClassArgument(c, module, moduleType) { 300 return true 301 } 302 } 303 return false 304 } 305 306 func (p *Provider) shouldAddGenericArgument(class *models.Class, moduleType modulecapabilities.ModuleType) bool { 307 return p.hasMultipleVectorizersConfig(class) && p.isVectorizerModule(moduleType) 308 } 309 310 func (p *Provider) hasMultipleVectorizersConfig(class *models.Class) bool { 311 return len(class.VectorConfig) > 0 312 } 313 314 func (p *Provider) shouldCrossClassAddGenericArgument(schema *models.Schema, moduleType modulecapabilities.ModuleType) bool { 315 for _, c := range schema.Classes { 316 if p.shouldAddGenericArgument(c, moduleType) { 317 return true 318 } 319 } 320 return false 321 } 322 323 func (p *Provider) getGenericArgument(name, className string, 324 argumentType modulecomponents.ArgumentType, 325 ) *graphql.ArgumentConfig { 326 var nearTextTransformer modulecapabilities.TextTransform 327 if name == "nearText" { 328 // nearText argument might be exposed with an extension, we need to check 329 // if text transformers module is enabled if so then we need to init nearText 330 // argument with this extension 331 for _, mod := range p.GetAll() { 332 if arg, ok := mod.(modulecapabilities.TextTransformers); ok { 333 if arg != nil && arg.TextTransformers() != nil { 334 nearTextTransformer = arg.TextTransformers()["nearText"] 335 break 336 } 337 } 338 } 339 } 340 return modulecomponents.GetGenericArgument(name, className, argumentType, nearTextTransformer) 341 } 342 343 func (p *Provider) getGenericAdditionalProperty(name string, class *models.Class) *modulecapabilities.AdditionalProperty { 344 if p.hasMultipleVectorizersConfig(class) { 345 return modulecomponents.GetGenericAdditionalProperty(name, class.Class) 346 } 347 return nil 348 } 349 350 // GetArguments provides GraphQL Get arguments 351 func (p *Provider) GetArguments(class *models.Class) map[string]*graphql.ArgumentConfig { 352 arguments := map[string]*graphql.ArgumentConfig{} 353 for _, module := range p.GetAll() { 354 if p.shouldIncludeClassArgument(class, module.Name(), module.Type()) { 355 if arg, ok := module.(modulecapabilities.GraphQLArguments); ok { 356 for name, argument := range arg.Arguments() { 357 if argument.GetArgumentsFunction != nil { 358 if p.shouldAddGenericArgument(class, module.Type()) { 359 if _, ok := arguments[name]; !ok { 360 arguments[name] = p.getGenericArgument(name, class.Class, modulecomponents.Get) 361 } 362 } else { 363 arguments[name] = argument.GetArgumentsFunction(class.Class) 364 } 365 } 366 } 367 } 368 } 369 } 370 return arguments 371 } 372 373 // AggregateArguments provides GraphQL Aggregate arguments 374 func (p *Provider) AggregateArguments(class *models.Class) map[string]*graphql.ArgumentConfig { 375 arguments := map[string]*graphql.ArgumentConfig{} 376 for _, module := range p.GetAll() { 377 if p.shouldIncludeClassArgument(class, module.Name(), module.Type()) { 378 if arg, ok := module.(modulecapabilities.GraphQLArguments); ok { 379 for name, argument := range arg.Arguments() { 380 if argument.AggregateArgumentsFunction != nil { 381 if p.shouldAddGenericArgument(class, module.Type()) { 382 if _, ok := arguments[name]; !ok { 383 arguments[name] = p.getGenericArgument(name, class.Class, modulecomponents.Aggregate) 384 } 385 } else { 386 arguments[name] = argument.AggregateArgumentsFunction(class.Class) 387 } 388 } 389 } 390 } 391 } 392 } 393 return arguments 394 } 395 396 // ExploreArguments provides GraphQL Explore arguments 397 func (p *Provider) ExploreArguments(schema *models.Schema) map[string]*graphql.ArgumentConfig { 398 arguments := map[string]*graphql.ArgumentConfig{} 399 for _, module := range p.GetAll() { 400 if p.shouldIncludeArgument(schema, module.Name(), module.Type()) { 401 if arg, ok := module.(modulecapabilities.GraphQLArguments); ok { 402 for name, argument := range arg.Arguments() { 403 if argument.ExploreArgumentsFunction != nil { 404 if p.shouldCrossClassAddGenericArgument(schema, module.Type()) { 405 if _, ok := arguments[name]; !ok { 406 arguments[name] = p.getGenericArgument(name, "", modulecomponents.Explore) 407 } 408 } else { 409 arguments[name] = argument.ExploreArgumentsFunction() 410 } 411 } 412 } 413 } 414 } 415 } 416 return arguments 417 } 418 419 // CrossClassExtractSearchParams extracts GraphQL arguments from modules without 420 // being specific to any one class and it's configuration. This is used in 421 // Explore() { } for example 422 func (p *Provider) CrossClassExtractSearchParams(arguments map[string]interface{}) map[string]interface{} { 423 return p.extractSearchParams(arguments, nil) 424 } 425 426 // ExtractSearchParams extracts GraphQL arguments 427 func (p *Provider) ExtractSearchParams(arguments map[string]interface{}, className string) map[string]interface{} { 428 exractedParams := map[string]interface{}{} 429 class, err := p.getClass(className) 430 if err != nil { 431 return exractedParams 432 } 433 return p.extractSearchParams(arguments, class) 434 } 435 436 func (p *Provider) extractSearchParams(arguments map[string]interface{}, class *models.Class) map[string]interface{} { 437 exractedParams := map[string]interface{}{} 438 for _, module := range p.GetAll() { 439 if p.shouldCrossClassIncludeClassArgument(class, module.Name(), module.Type()) { 440 if args, ok := module.(modulecapabilities.GraphQLArguments); ok { 441 for paramName, argument := range args.Arguments() { 442 if param, ok := arguments[paramName]; ok && argument.ExtractFunction != nil { 443 extracted := argument.ExtractFunction(param.(map[string]interface{})) 444 exractedParams[paramName] = extracted 445 } 446 } 447 } 448 } 449 } 450 return exractedParams 451 } 452 453 // CrossClassValidateSearchParam validates module parameters without 454 // being specific to any one class and it's configuration. This is used in 455 // Explore() { } for example 456 func (p *Provider) CrossClassValidateSearchParam(name string, value interface{}) error { 457 return p.validateSearchParam(name, value, nil) 458 } 459 460 // ValidateSearchParam validates module parameters 461 func (p *Provider) ValidateSearchParam(name string, value interface{}, className string) error { 462 class, err := p.getClass(className) 463 if err != nil { 464 return err 465 } 466 467 return p.validateSearchParam(name, value, class) 468 } 469 470 func (p *Provider) validateSearchParam(name string, value interface{}, class *models.Class) error { 471 for _, module := range p.GetAll() { 472 if p.shouldCrossClassIncludeClassArgument(class, module.Name(), module.Type()) { 473 if args, ok := module.(modulecapabilities.GraphQLArguments); ok { 474 for paramName, argument := range args.Arguments() { 475 if paramName == name && argument.ValidateFunction != nil { 476 return argument.ValidateFunction(value) 477 } 478 } 479 } 480 } 481 } 482 483 panic("ValidateParam was called without any known params present") 484 } 485 486 // GetAdditionalFields provides GraphQL Get additional fields 487 func (p *Provider) GetAdditionalFields(class *models.Class) map[string]*graphql.Field { 488 additionalProperties := map[string]*graphql.Field{} 489 for _, module := range p.GetAll() { 490 if p.shouldIncludeClassArgument(class, module.Name(), module.Type()) { 491 if arg, ok := module.(modulecapabilities.AdditionalProperties); ok { 492 for name, additionalProperty := range arg.AdditionalProperties() { 493 if additionalProperty.GraphQLFieldFunction != nil { 494 if genericAdditionalProperty := p.getGenericAdditionalProperty(name, class); genericAdditionalProperty != nil { 495 if genericAdditionalProperty.GraphQLFieldFunction != nil { 496 if _, ok := additionalProperties[name]; !ok { 497 additionalProperties[name] = genericAdditionalProperty.GraphQLFieldFunction(class.Class) 498 } 499 } 500 } else { 501 additionalProperties[name] = additionalProperty.GraphQLFieldFunction(class.Class) 502 } 503 } 504 } 505 } 506 } 507 } 508 return additionalProperties 509 } 510 511 // ExtractAdditionalField extracts additional properties from given graphql arguments 512 func (p *Provider) ExtractAdditionalField(className, name string, params []*ast.Argument) interface{} { 513 class, err := p.getClass(className) 514 if err != nil { 515 return err 516 } 517 for _, module := range p.GetAll() { 518 if p.shouldIncludeClassArgument(class, module.Name(), module.Type()) { 519 if arg, ok := module.(modulecapabilities.AdditionalProperties); ok { 520 if additionalProperties := arg.AdditionalProperties(); len(additionalProperties) > 0 { 521 if additionalProperty, ok := additionalProperties[name]; ok { 522 return additionalProperty.GraphQLExtractFunction(params) 523 } 524 } 525 } 526 } 527 } 528 return nil 529 } 530 531 // GetObjectAdditionalExtend extends rest api get queries with additional properties 532 func (p *Provider) GetObjectAdditionalExtend(ctx context.Context, 533 in *search.Result, moduleParams map[string]interface{}, 534 ) (*search.Result, error) { 535 resArray, err := p.additionalExtend(ctx, search.Results{*in}, moduleParams, nil, "ObjectGet", nil) 536 if err != nil { 537 return nil, err 538 } 539 return &resArray[0], nil 540 } 541 542 // ListObjectsAdditionalExtend extends rest api list queries with additional properties 543 func (p *Provider) ListObjectsAdditionalExtend(ctx context.Context, 544 in search.Results, moduleParams map[string]interface{}, 545 ) (search.Results, error) { 546 return p.additionalExtend(ctx, in, moduleParams, nil, "ObjectList", nil) 547 } 548 549 // GetExploreAdditionalExtend extends graphql api get queries with additional properties 550 func (p *Provider) GetExploreAdditionalExtend(ctx context.Context, in []search.Result, 551 moduleParams map[string]interface{}, searchVector []float32, 552 argumentModuleParams map[string]interface{}, 553 ) ([]search.Result, error) { 554 return p.additionalExtend(ctx, in, moduleParams, searchVector, "ExploreGet", argumentModuleParams) 555 } 556 557 // ListExploreAdditionalExtend extends graphql api list queries with additional properties 558 func (p *Provider) ListExploreAdditionalExtend(ctx context.Context, in []search.Result, 559 moduleParams map[string]interface{}, 560 argumentModuleParams map[string]interface{}, 561 ) ([]search.Result, error) { 562 return p.additionalExtend(ctx, in, moduleParams, nil, "ExploreList", argumentModuleParams) 563 } 564 565 func (p *Provider) additionalExtend(ctx context.Context, in []search.Result, 566 moduleParams map[string]interface{}, searchVector []float32, 567 capability string, argumentModuleParams map[string]interface{}, 568 ) ([]search.Result, error) { 569 toBeExtended := in 570 if len(toBeExtended) > 0 { 571 class, err := p.getClassFromSearchResult(toBeExtended) 572 if err != nil { 573 return nil, err 574 } 575 allAdditionalProperties := map[string]modulecapabilities.AdditionalProperty{} 576 for _, module := range p.GetAll() { 577 if p.shouldIncludeClassArgument(class, module.Name(), module.Type()) { 578 if arg, ok := module.(modulecapabilities.AdditionalProperties); ok { 579 if arg != nil && arg.AdditionalProperties() != nil { 580 for name, additionalProperty := range arg.AdditionalProperties() { 581 allAdditionalProperties[name] = additionalProperty 582 } 583 } 584 } 585 } 586 } 587 if len(allAdditionalProperties) > 0 { 588 if err := p.checkCapabilities(allAdditionalProperties, moduleParams, capability); err != nil { 589 return nil, err 590 } 591 cfg := NewClassBasedModuleConfig(class, "", "", "") 592 for name, value := range moduleParams { 593 additionalPropertyFn := p.getAdditionalPropertyFn(allAdditionalProperties[name], capability) 594 if additionalPropertyFn != nil && value != nil { 595 searchValue := value 596 if searchVectorValue, ok := value.(modulecapabilities.AdditionalPropertyWithSearchVector); ok { 597 searchVectorValue.SetSearchVector(searchVector) 598 searchValue = searchVectorValue 599 } 600 resArray, err := additionalPropertyFn(ctx, toBeExtended, searchValue, nil, argumentModuleParams, cfg) 601 if err != nil { 602 return nil, errors.Errorf("extend %s: %v", name, err) 603 } 604 toBeExtended = resArray 605 } else { 606 return nil, errors.Errorf("unknown capability: %s", name) 607 } 608 } 609 } 610 } 611 return toBeExtended, nil 612 } 613 614 func (p *Provider) getClassFromSearchResult(in []search.Result) (*models.Class, error) { 615 if len(in) > 0 { 616 return p.getClass(in[0].ClassName) 617 } 618 return nil, errors.Errorf("unknown class") 619 } 620 621 func (p *Provider) checkCapabilities(additionalProperties map[string]modulecapabilities.AdditionalProperty, 622 moduleParams map[string]interface{}, capability string, 623 ) error { 624 for name := range moduleParams { 625 additionalPropertyFn := p.getAdditionalPropertyFn(additionalProperties[name], capability) 626 if additionalPropertyFn == nil { 627 return errors.Errorf("unknown capability: %s", name) 628 } 629 } 630 return nil 631 } 632 633 func (p *Provider) getAdditionalPropertyFn( 634 additionalProperty modulecapabilities.AdditionalProperty, 635 capability string, 636 ) modulecapabilities.AdditionalPropertyFn { 637 switch capability { 638 case "ObjectGet": 639 return additionalProperty.SearchFunctions.ObjectGet 640 case "ObjectList": 641 return additionalProperty.SearchFunctions.ObjectList 642 case "ExploreGet": 643 return additionalProperty.SearchFunctions.ExploreGet 644 case "ExploreList": 645 return additionalProperty.SearchFunctions.ExploreList 646 default: 647 return nil 648 } 649 } 650 651 // GraphQLAdditionalFieldNames get's all additional field names used in graphql 652 func (p *Provider) GraphQLAdditionalFieldNames() []string { 653 additionalPropertiesNames := []string{} 654 for _, module := range p.GetAll() { 655 if arg, ok := module.(modulecapabilities.AdditionalProperties); ok { 656 for _, additionalProperty := range arg.AdditionalProperties() { 657 if additionalProperty.GraphQLNames != nil { 658 additionalPropertiesNames = append(additionalPropertiesNames, additionalProperty.GraphQLNames...) 659 } 660 } 661 } 662 } 663 return additionalPropertiesNames 664 } 665 666 // RestApiAdditionalProperties get's all rest specific additional properties with their 667 // default values 668 func (p *Provider) RestApiAdditionalProperties(includeProp string, class *models.Class) map[string]interface{} { 669 moduleParams := map[string]interface{}{} 670 for _, module := range p.GetAll() { 671 if p.shouldCrossClassIncludeClassArgument(class, module.Name(), module.Type()) { 672 if arg, ok := module.(modulecapabilities.AdditionalProperties); ok { 673 for name, additionalProperty := range arg.AdditionalProperties() { 674 for _, includePropName := range additionalProperty.RestNames { 675 if includePropName == includeProp && moduleParams[name] == nil { 676 moduleParams[name] = additionalProperty.DefaultValue 677 } 678 } 679 } 680 } 681 } 682 } 683 return moduleParams 684 } 685 686 // VectorFromSearchParam gets a vector for a given argument. This is used in 687 // Get { Class() } for example 688 func (p *Provider) VectorFromSearchParam(ctx context.Context, 689 className string, param string, params interface{}, 690 findVectorFn modulecapabilities.FindVectorFn, tenant string, 691 ) ([]float32, string, error) { 692 class, err := p.getClass(className) 693 if err != nil { 694 return nil, "", err 695 } 696 targetVector, err := p.getTargetVector(class, params) 697 if err != nil { 698 return nil, "", err 699 } 700 targetModule := p.getModuleNameForTargetVector(class, targetVector) 701 702 for _, mod := range p.GetAll() { 703 if p.shouldIncludeClassArgument(class, mod.Name(), mod.Type()) { 704 var moduleName string 705 var vectorSearches modulecapabilities.ArgumentVectorForParams 706 if searcher, ok := mod.(modulecapabilities.Searcher); ok { 707 if mod.Name() == targetModule { 708 moduleName = mod.Name() 709 vectorSearches = searcher.VectorSearches() 710 } 711 } else if searchers, ok := mod.(modulecapabilities.DependencySearcher); ok { 712 if dependencySearchers := searchers.VectorSearches(); dependencySearchers != nil { 713 moduleName = targetModule 714 vectorSearches = dependencySearchers[targetModule] 715 } 716 } 717 if vectorSearches != nil { 718 if searchVectorFn := vectorSearches[param]; searchVectorFn != nil { 719 cfg := NewClassBasedModuleConfig(class, moduleName, tenant, targetVector) 720 vector, err := searchVectorFn(ctx, params, class.Class, findVectorFn, cfg) 721 if err != nil { 722 return nil, "", errors.Errorf("vectorize params: %v", err) 723 } 724 return vector, targetVector, nil 725 } 726 } 727 } 728 } 729 730 panic("VectorFromParams was called without any known params present") 731 } 732 733 // CrossClassVectorFromSearchParam gets a vector for a given argument without 734 // being specific to any one class and it's configuration. This is used in 735 // Explore() { } for example 736 func (p *Provider) CrossClassVectorFromSearchParam(ctx context.Context, 737 param string, params interface{}, 738 findVectorFn modulecapabilities.FindVectorFn, 739 ) ([]float32, string, error) { 740 for _, mod := range p.GetAll() { 741 if searcher, ok := mod.(modulecapabilities.Searcher); ok { 742 if vectorSearches := searcher.VectorSearches(); vectorSearches != nil { 743 if searchVectorFn := vectorSearches[param]; searchVectorFn != nil { 744 cfg := NewCrossClassModuleConfig() 745 vector, err := searchVectorFn(ctx, params, "", findVectorFn, cfg) 746 if err != nil { 747 return nil, "", errors.Errorf("vectorize params: %v", err) 748 } 749 targetVector, err := p.getTargetVector(nil, params) 750 if err != nil { 751 return nil, "", errors.Errorf("get target vector: %v", err) 752 } 753 return vector, targetVector, nil 754 } 755 } 756 } 757 } 758 759 panic("VectorFromParams was called without any known params present") 760 } 761 762 func (p *Provider) getTargetVector(class *models.Class, params interface{}) (string, error) { 763 if nearParam, ok := params.(modulecapabilities.NearParam); ok && len(nearParam.GetTargetVectors()) == 1 { 764 return nearParam.GetTargetVectors()[0], nil 765 } 766 if class != nil { 767 if len(class.VectorConfig) > 1 { 768 return "", fmt.Errorf("multiple vectorizers configuration found, please specify target vector name") 769 } 770 771 if len(class.VectorConfig) == 1 { 772 for name := range class.VectorConfig { 773 return name, nil 774 } 775 } 776 } 777 return "", nil 778 } 779 780 func (p *Provider) getModuleNameForTargetVector(class *models.Class, targetVector string) string { 781 if len(class.VectorConfig) > 0 { 782 if vectorConfig, ok := class.VectorConfig[targetVector]; ok { 783 if vectorizer, ok := vectorConfig.Vectorizer.(map[string]interface{}); ok && len(vectorizer) == 1 { 784 for moduleName := range vectorizer { 785 return moduleName 786 } 787 } 788 } 789 } 790 return class.Vectorizer 791 } 792 793 func (p *Provider) VectorFromInput(ctx context.Context, 794 className, input, targetVector string, 795 ) ([]float32, error) { 796 class, err := p.getClass(className) 797 if err != nil { 798 return nil, err 799 } 800 targetModule := p.getModuleNameForTargetVector(class, targetVector) 801 802 for _, mod := range p.GetAll() { 803 if mod.Name() == targetModule { 804 if p.shouldIncludeClassArgument(class, mod.Name(), mod.Type()) { 805 if vectorizer, ok := mod.(modulecapabilities.InputVectorizer); ok { 806 // does not access any objects, therefore tenant is irrelevant 807 cfg := NewClassBasedModuleConfig(class, mod.Name(), "", targetVector) 808 return vectorizer.VectorizeInput(ctx, input, cfg) 809 } 810 } 811 } 812 } 813 814 return nil, fmt.Errorf("VectorFromInput was called without vectorizer") 815 } 816 817 // ParseClassifierSettings parses and adds classifier specific settings 818 func (p *Provider) ParseClassifierSettings(name string, 819 params *models.Classification, 820 ) error { 821 class, err := p.getClass(params.Class) 822 if err != nil { 823 return err 824 } 825 for _, module := range p.GetAll() { 826 if p.shouldIncludeClassArgument(class, module.Name(), module.Type()) { 827 if c, ok := module.(modulecapabilities.ClassificationProvider); ok { 828 for _, classifier := range c.Classifiers() { 829 if classifier != nil && classifier.Name() == name { 830 return classifier.ParseClassifierSettings(params) 831 } 832 } 833 } 834 } 835 } 836 return nil 837 } 838 839 // GetClassificationFn returns given module's classification 840 func (p *Provider) GetClassificationFn(className, name string, 841 params modulecapabilities.ClassifyParams, 842 ) (modulecapabilities.ClassifyItemFn, error) { 843 class, err := p.getClass(className) 844 if err != nil { 845 return nil, err 846 } 847 for _, module := range p.GetAll() { 848 if p.shouldIncludeClassArgument(class, module.Name(), module.Type()) { 849 if c, ok := module.(modulecapabilities.ClassificationProvider); ok { 850 for _, classifier := range c.Classifiers() { 851 if classifier != nil && classifier.Name() == name { 852 return classifier.ClassifyFn(params) 853 } 854 } 855 } 856 } 857 } 858 return nil, errors.Errorf("classifier %s not found", name) 859 } 860 861 // GetMeta returns meta information about modules 862 func (p *Provider) GetMeta() (map[string]interface{}, error) { 863 metaInfos := map[string]interface{}{} 864 for _, module := range p.GetAll() { 865 if c, ok := module.(modulecapabilities.MetaProvider); ok { 866 meta, err := c.MetaInfo() 867 if err != nil { 868 return nil, err 869 } 870 metaInfos[module.Name()] = meta 871 } 872 } 873 return metaInfos, nil 874 } 875 876 func (p *Provider) getClass(className string) (*models.Class, error) { 877 sch := p.schemaGetter.GetSchemaSkipAuth() 878 class := sch.FindClassByName(schema.ClassName(className)) 879 if class == nil { 880 return nil, errors.Errorf("class %q not found in schema", className) 881 } 882 return class, nil 883 } 884 885 func (p *Provider) HasMultipleVectorizers() bool { 886 return p.hasMultipleVectorizers 887 } 888 889 func (p *Provider) BackupBackend(backend string) (modulecapabilities.BackupBackend, error) { 890 if module := p.GetByName(backend); module != nil { 891 if module.Type() == modulecapabilities.Backup { 892 if backend, ok := module.(modulecapabilities.BackupBackend); ok { 893 return backend, nil 894 } 895 } 896 } 897 return nil, errors.Errorf("backup: %s not found", backend) 898 }