github.com/weaviate/weaviate@v1.24.6/usecases/modules/modules.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package modules
    13  
    14  import (
    15  	"context"
    16  	"fmt"
    17  	"regexp"
    18  	"sync"
    19  
    20  	"github.com/pkg/errors"
    21  	"github.com/sirupsen/logrus"
    22  	"github.com/tailor-inc/graphql"
    23  	"github.com/tailor-inc/graphql/language/ast"
    24  	"github.com/weaviate/weaviate/entities/models"
    25  	"github.com/weaviate/weaviate/entities/modulecapabilities"
    26  	"github.com/weaviate/weaviate/entities/moduletools"
    27  	"github.com/weaviate/weaviate/entities/schema"
    28  	"github.com/weaviate/weaviate/entities/search"
    29  	"github.com/weaviate/weaviate/usecases/modulecomponents"
    30  )
    31  
    32  var (
    33  	internalSearchers = []string{
    34  		"nearObject", "nearVector", "where", "group", "limit", "offset",
    35  		"after", "groupBy", "bm25", "hybrid",
    36  	}
    37  	internalAdditionalProperties = []string{"classification", "certainty", "id", "distance", "group"}
    38  )
    39  
    40  type Provider struct {
    41  	vectorsLock               sync.RWMutex
    42  	registered                map[string]modulecapabilities.Module
    43  	altNames                  map[string]string
    44  	schemaGetter              schemaGetter
    45  	hasMultipleVectorizers    bool
    46  	targetVectorNameValidator *regexp.Regexp
    47  }
    48  
    49  type schemaGetter interface {
    50  	GetSchemaSkipAuth() schema.Schema
    51  }
    52  
    53  func NewProvider() *Provider {
    54  	return &Provider{
    55  		registered:                map[string]modulecapabilities.Module{},
    56  		altNames:                  map[string]string{},
    57  		targetVectorNameValidator: regexp.MustCompile(`^` + schema.TargetVectorNameRegex + `$`),
    58  	}
    59  }
    60  
    61  func (p *Provider) Register(mod modulecapabilities.Module) {
    62  	p.registered[mod.Name()] = mod
    63  	if modHasAltNames, ok := mod.(modulecapabilities.ModuleHasAltNames); ok {
    64  		for _, altName := range modHasAltNames.AltNames() {
    65  			p.altNames[altName] = mod.Name()
    66  		}
    67  	}
    68  }
    69  
    70  func (p *Provider) GetByName(name string) modulecapabilities.Module {
    71  	if mod, ok := p.registered[name]; ok {
    72  		return mod
    73  	}
    74  	if origName, ok := p.altNames[name]; ok {
    75  		return p.registered[origName]
    76  	}
    77  	return nil
    78  }
    79  
    80  func (p *Provider) GetAll() []modulecapabilities.Module {
    81  	out := make([]modulecapabilities.Module, len(p.registered))
    82  	i := 0
    83  	for _, mod := range p.registered {
    84  		out[i] = mod
    85  		i++
    86  	}
    87  
    88  	return out
    89  }
    90  
    91  func (p *Provider) GetAllExclude(module string) []modulecapabilities.Module {
    92  	filtered := []modulecapabilities.Module{}
    93  	for _, mod := range p.GetAll() {
    94  		if mod.Name() != module {
    95  			filtered = append(filtered, mod)
    96  		}
    97  	}
    98  	return filtered
    99  }
   100  
   101  func (p *Provider) SetSchemaGetter(sg schemaGetter) {
   102  	p.schemaGetter = sg
   103  }
   104  
   105  func (p *Provider) Init(ctx context.Context,
   106  	params moduletools.ModuleInitParams, logger logrus.FieldLogger,
   107  ) error {
   108  	for i, mod := range p.GetAll() {
   109  		if err := mod.Init(ctx, params); err != nil {
   110  			return errors.Wrapf(err, "init module %d (%q)", i, mod.Name())
   111  		} else {
   112  			logger.WithField("action", "startup").
   113  				WithField("module", mod.Name()).
   114  				Debug("initialized module")
   115  		}
   116  	}
   117  	for i, mod := range p.GetAll() {
   118  		if modExtension, ok := mod.(modulecapabilities.ModuleExtension); ok {
   119  			if err := modExtension.InitExtension(p.GetAllExclude(mod.Name())); err != nil {
   120  				return errors.Wrapf(err, "init module extension %d (%q)", i, mod.Name())
   121  			} else {
   122  				logger.WithField("action", "startup").
   123  					WithField("module", mod.Name()).
   124  					Debug("initialized module extension")
   125  			}
   126  		}
   127  	}
   128  	for i, mod := range p.GetAll() {
   129  		if modDependency, ok := mod.(modulecapabilities.ModuleDependency); ok {
   130  			if err := modDependency.InitDependency(p.GetAllExclude(mod.Name())); err != nil {
   131  				return errors.Wrapf(err, "init module dependency %d (%q)", i, mod.Name())
   132  			} else {
   133  				logger.WithField("action", "startup").
   134  					WithField("module", mod.Name()).
   135  					Debug("initialized module dependency")
   136  			}
   137  		}
   138  	}
   139  	if err := p.validate(); err != nil {
   140  		return errors.Wrap(err, "validate modules")
   141  	}
   142  	if p.HasMultipleVectorizers() {
   143  		logger.Warn("Multiple vector spaces are present, GraphQL Explore and REST API list objects endpoint module include params has been disabled as a result.")
   144  	}
   145  	return nil
   146  }
   147  
   148  func (p *Provider) validate() error {
   149  	searchers := map[string][]string{}
   150  	additionalGraphQLProps := map[string][]string{}
   151  	additionalRestAPIProps := map[string][]string{}
   152  	for _, mod := range p.GetAll() {
   153  		if module, ok := mod.(modulecapabilities.GraphQLArguments); ok {
   154  			allArguments := []string{}
   155  			for paraName, argument := range module.Arguments() {
   156  				if argument.ExtractFunction != nil {
   157  					allArguments = append(allArguments, paraName)
   158  				}
   159  			}
   160  			searchers = p.scanProperties(searchers, allArguments, mod.Name())
   161  		}
   162  		if module, ok := mod.(modulecapabilities.AdditionalProperties); ok {
   163  			allAdditionalRestAPIProps, allAdditionalGrapQLProps := p.getAdditionalProps(module.AdditionalProperties())
   164  			additionalGraphQLProps = p.scanProperties(additionalGraphQLProps,
   165  				allAdditionalGrapQLProps, mod.Name())
   166  			additionalRestAPIProps = p.scanProperties(additionalRestAPIProps,
   167  				allAdditionalRestAPIProps, mod.Name())
   168  		}
   169  	}
   170  
   171  	var errorMessages []string
   172  	errorMessages = append(errorMessages,
   173  		p.validateModules("searcher", searchers, internalSearchers)...)
   174  	errorMessages = append(errorMessages,
   175  		p.validateModules("graphql additional property", additionalGraphQLProps, internalAdditionalProperties)...)
   176  	errorMessages = append(errorMessages,
   177  		p.validateModules("rest api additional property", additionalRestAPIProps, internalAdditionalProperties)...)
   178  	if len(errorMessages) > 0 {
   179  		return errors.Errorf("%v", errorMessages)
   180  	}
   181  
   182  	return nil
   183  }
   184  
   185  func (p *Provider) scanProperties(result map[string][]string, properties []string, module string) map[string][]string {
   186  	for i := range properties {
   187  		if result[properties[i]] == nil {
   188  			result[properties[i]] = []string{}
   189  		}
   190  		modules := result[properties[i]]
   191  		modules = append(modules, module)
   192  		result[properties[i]] = modules
   193  	}
   194  	return result
   195  }
   196  
   197  func (p *Provider) getAdditionalProps(additionalProps map[string]modulecapabilities.AdditionalProperty) ([]string, []string) {
   198  	restProps := []string{}
   199  	graphQLProps := []string{}
   200  
   201  	for _, additionalProperty := range additionalProps {
   202  		if additionalProperty.RestNames != nil {
   203  			restProps = append(restProps, additionalProperty.RestNames...)
   204  		}
   205  		if additionalProperty.GraphQLNames != nil {
   206  			graphQLProps = append(graphQLProps, additionalProperty.GraphQLNames...)
   207  		}
   208  	}
   209  	return restProps, graphQLProps
   210  }
   211  
   212  func (p *Provider) validateModules(name string, properties map[string][]string, internalProperties []string) []string {
   213  	errorMessages := []string{}
   214  	for propertyName, modules := range properties {
   215  		for i := range internalProperties {
   216  			if internalProperties[i] == propertyName {
   217  				errorMessages = append(errorMessages,
   218  					fmt.Sprintf("%s: %s conflicts with weaviate's internal searcher in modules: %v",
   219  						name, propertyName, modules))
   220  			}
   221  		}
   222  		if len(modules) > 1 {
   223  			p.hasMultipleVectorizers = true
   224  		}
   225  		for _, moduleName := range modules {
   226  			moduleType := p.GetByName(moduleName).Type()
   227  			if p.moduleProvidesMultipleVectorizers(moduleType) {
   228  				p.hasMultipleVectorizers = true
   229  			}
   230  		}
   231  	}
   232  	return errorMessages
   233  }
   234  
   235  func (p *Provider) moduleProvidesMultipleVectorizers(moduleType modulecapabilities.ModuleType) bool {
   236  	return moduleType == modulecapabilities.Text2MultiVec
   237  }
   238  
   239  func (p *Provider) isOnlyOneModuleEnabledOfAGivenType(moduleType modulecapabilities.ModuleType) bool {
   240  	i := 0
   241  	for _, mod := range p.registered {
   242  		if mod.Type() == moduleType {
   243  			i++
   244  		}
   245  	}
   246  	return i == 1
   247  }
   248  
   249  func (p *Provider) isVectorizerModule(moduleType modulecapabilities.ModuleType) bool {
   250  	switch moduleType {
   251  	case modulecapabilities.Text2Vec,
   252  		modulecapabilities.Img2Vec,
   253  		modulecapabilities.Multi2Vec,
   254  		modulecapabilities.Text2MultiVec,
   255  		modulecapabilities.Ref2Vec:
   256  		return true
   257  	default:
   258  		return false
   259  	}
   260  }
   261  
   262  func (p *Provider) shouldIncludeClassArgument(class *models.Class, module string,
   263  	moduleType modulecapabilities.ModuleType,
   264  ) bool {
   265  	if p.isVectorizerModule(moduleType) {
   266  		for _, vectorConfig := range class.VectorConfig {
   267  			if vectorizer, ok := vectorConfig.Vectorizer.(map[string]interface{}); ok {
   268  				if _, ok := vectorizer[module]; ok {
   269  					return true
   270  				}
   271  			}
   272  		}
   273  		return class.Vectorizer == module
   274  	}
   275  	if moduleConfig, ok := class.ModuleConfig.(map[string]interface{}); ok {
   276  		existsConfigForModule := moduleConfig[module] != nil
   277  		if existsConfigForModule {
   278  			return true
   279  		}
   280  	}
   281  	// Allow Text2Text (Generative, QnA, Summarize, NER) modules to be registered to a given class
   282  	// only if there's no configuration present and there's only one module of a given type enabled
   283  	return p.isOnlyOneModuleEnabledOfAGivenType(moduleType)
   284  }
   285  
   286  func (p *Provider) shouldCrossClassIncludeClassArgument(class *models.Class, module string,
   287  	moduleType modulecapabilities.ModuleType,
   288  ) bool {
   289  	if class == nil {
   290  		return !p.HasMultipleVectorizers()
   291  	}
   292  	return p.shouldIncludeClassArgument(class, module, moduleType)
   293  }
   294  
   295  func (p *Provider) shouldIncludeArgument(schema *models.Schema, module string,
   296  	moduleType modulecapabilities.ModuleType,
   297  ) bool {
   298  	for _, c := range schema.Classes {
   299  		if p.shouldIncludeClassArgument(c, module, moduleType) {
   300  			return true
   301  		}
   302  	}
   303  	return false
   304  }
   305  
   306  func (p *Provider) shouldAddGenericArgument(class *models.Class, moduleType modulecapabilities.ModuleType) bool {
   307  	return p.hasMultipleVectorizersConfig(class) && p.isVectorizerModule(moduleType)
   308  }
   309  
   310  func (p *Provider) hasMultipleVectorizersConfig(class *models.Class) bool {
   311  	return len(class.VectorConfig) > 0
   312  }
   313  
   314  func (p *Provider) shouldCrossClassAddGenericArgument(schema *models.Schema, moduleType modulecapabilities.ModuleType) bool {
   315  	for _, c := range schema.Classes {
   316  		if p.shouldAddGenericArgument(c, moduleType) {
   317  			return true
   318  		}
   319  	}
   320  	return false
   321  }
   322  
   323  func (p *Provider) getGenericArgument(name, className string,
   324  	argumentType modulecomponents.ArgumentType,
   325  ) *graphql.ArgumentConfig {
   326  	var nearTextTransformer modulecapabilities.TextTransform
   327  	if name == "nearText" {
   328  		// nearText argument might be exposed with an extension, we need to check
   329  		// if text transformers module is enabled if so then we need to init nearText
   330  		// argument with this extension
   331  		for _, mod := range p.GetAll() {
   332  			if arg, ok := mod.(modulecapabilities.TextTransformers); ok {
   333  				if arg != nil && arg.TextTransformers() != nil {
   334  					nearTextTransformer = arg.TextTransformers()["nearText"]
   335  					break
   336  				}
   337  			}
   338  		}
   339  	}
   340  	return modulecomponents.GetGenericArgument(name, className, argumentType, nearTextTransformer)
   341  }
   342  
   343  func (p *Provider) getGenericAdditionalProperty(name string, class *models.Class) *modulecapabilities.AdditionalProperty {
   344  	if p.hasMultipleVectorizersConfig(class) {
   345  		return modulecomponents.GetGenericAdditionalProperty(name, class.Class)
   346  	}
   347  	return nil
   348  }
   349  
   350  // GetArguments provides GraphQL Get arguments
   351  func (p *Provider) GetArguments(class *models.Class) map[string]*graphql.ArgumentConfig {
   352  	arguments := map[string]*graphql.ArgumentConfig{}
   353  	for _, module := range p.GetAll() {
   354  		if p.shouldIncludeClassArgument(class, module.Name(), module.Type()) {
   355  			if arg, ok := module.(modulecapabilities.GraphQLArguments); ok {
   356  				for name, argument := range arg.Arguments() {
   357  					if argument.GetArgumentsFunction != nil {
   358  						if p.shouldAddGenericArgument(class, module.Type()) {
   359  							if _, ok := arguments[name]; !ok {
   360  								arguments[name] = p.getGenericArgument(name, class.Class, modulecomponents.Get)
   361  							}
   362  						} else {
   363  							arguments[name] = argument.GetArgumentsFunction(class.Class)
   364  						}
   365  					}
   366  				}
   367  			}
   368  		}
   369  	}
   370  	return arguments
   371  }
   372  
   373  // AggregateArguments provides GraphQL Aggregate arguments
   374  func (p *Provider) AggregateArguments(class *models.Class) map[string]*graphql.ArgumentConfig {
   375  	arguments := map[string]*graphql.ArgumentConfig{}
   376  	for _, module := range p.GetAll() {
   377  		if p.shouldIncludeClassArgument(class, module.Name(), module.Type()) {
   378  			if arg, ok := module.(modulecapabilities.GraphQLArguments); ok {
   379  				for name, argument := range arg.Arguments() {
   380  					if argument.AggregateArgumentsFunction != nil {
   381  						if p.shouldAddGenericArgument(class, module.Type()) {
   382  							if _, ok := arguments[name]; !ok {
   383  								arguments[name] = p.getGenericArgument(name, class.Class, modulecomponents.Aggregate)
   384  							}
   385  						} else {
   386  							arguments[name] = argument.AggregateArgumentsFunction(class.Class)
   387  						}
   388  					}
   389  				}
   390  			}
   391  		}
   392  	}
   393  	return arguments
   394  }
   395  
   396  // ExploreArguments provides GraphQL Explore arguments
   397  func (p *Provider) ExploreArguments(schema *models.Schema) map[string]*graphql.ArgumentConfig {
   398  	arguments := map[string]*graphql.ArgumentConfig{}
   399  	for _, module := range p.GetAll() {
   400  		if p.shouldIncludeArgument(schema, module.Name(), module.Type()) {
   401  			if arg, ok := module.(modulecapabilities.GraphQLArguments); ok {
   402  				for name, argument := range arg.Arguments() {
   403  					if argument.ExploreArgumentsFunction != nil {
   404  						if p.shouldCrossClassAddGenericArgument(schema, module.Type()) {
   405  							if _, ok := arguments[name]; !ok {
   406  								arguments[name] = p.getGenericArgument(name, "", modulecomponents.Explore)
   407  							}
   408  						} else {
   409  							arguments[name] = argument.ExploreArgumentsFunction()
   410  						}
   411  					}
   412  				}
   413  			}
   414  		}
   415  	}
   416  	return arguments
   417  }
   418  
   419  // CrossClassExtractSearchParams extracts GraphQL arguments from modules without
   420  // being specific to any one class and it's configuration. This is used in
   421  // Explore() { } for example
   422  func (p *Provider) CrossClassExtractSearchParams(arguments map[string]interface{}) map[string]interface{} {
   423  	return p.extractSearchParams(arguments, nil)
   424  }
   425  
   426  // ExtractSearchParams extracts GraphQL arguments
   427  func (p *Provider) ExtractSearchParams(arguments map[string]interface{}, className string) map[string]interface{} {
   428  	exractedParams := map[string]interface{}{}
   429  	class, err := p.getClass(className)
   430  	if err != nil {
   431  		return exractedParams
   432  	}
   433  	return p.extractSearchParams(arguments, class)
   434  }
   435  
   436  func (p *Provider) extractSearchParams(arguments map[string]interface{}, class *models.Class) map[string]interface{} {
   437  	exractedParams := map[string]interface{}{}
   438  	for _, module := range p.GetAll() {
   439  		if p.shouldCrossClassIncludeClassArgument(class, module.Name(), module.Type()) {
   440  			if args, ok := module.(modulecapabilities.GraphQLArguments); ok {
   441  				for paramName, argument := range args.Arguments() {
   442  					if param, ok := arguments[paramName]; ok && argument.ExtractFunction != nil {
   443  						extracted := argument.ExtractFunction(param.(map[string]interface{}))
   444  						exractedParams[paramName] = extracted
   445  					}
   446  				}
   447  			}
   448  		}
   449  	}
   450  	return exractedParams
   451  }
   452  
   453  // CrossClassValidateSearchParam validates module parameters without
   454  // being specific to any one class and it's configuration. This is used in
   455  // Explore() { } for example
   456  func (p *Provider) CrossClassValidateSearchParam(name string, value interface{}) error {
   457  	return p.validateSearchParam(name, value, nil)
   458  }
   459  
   460  // ValidateSearchParam validates module parameters
   461  func (p *Provider) ValidateSearchParam(name string, value interface{}, className string) error {
   462  	class, err := p.getClass(className)
   463  	if err != nil {
   464  		return err
   465  	}
   466  
   467  	return p.validateSearchParam(name, value, class)
   468  }
   469  
   470  func (p *Provider) validateSearchParam(name string, value interface{}, class *models.Class) error {
   471  	for _, module := range p.GetAll() {
   472  		if p.shouldCrossClassIncludeClassArgument(class, module.Name(), module.Type()) {
   473  			if args, ok := module.(modulecapabilities.GraphQLArguments); ok {
   474  				for paramName, argument := range args.Arguments() {
   475  					if paramName == name && argument.ValidateFunction != nil {
   476  						return argument.ValidateFunction(value)
   477  					}
   478  				}
   479  			}
   480  		}
   481  	}
   482  
   483  	panic("ValidateParam was called without any known params present")
   484  }
   485  
   486  // GetAdditionalFields provides GraphQL Get additional fields
   487  func (p *Provider) GetAdditionalFields(class *models.Class) map[string]*graphql.Field {
   488  	additionalProperties := map[string]*graphql.Field{}
   489  	for _, module := range p.GetAll() {
   490  		if p.shouldIncludeClassArgument(class, module.Name(), module.Type()) {
   491  			if arg, ok := module.(modulecapabilities.AdditionalProperties); ok {
   492  				for name, additionalProperty := range arg.AdditionalProperties() {
   493  					if additionalProperty.GraphQLFieldFunction != nil {
   494  						if genericAdditionalProperty := p.getGenericAdditionalProperty(name, class); genericAdditionalProperty != nil {
   495  							if genericAdditionalProperty.GraphQLFieldFunction != nil {
   496  								if _, ok := additionalProperties[name]; !ok {
   497  									additionalProperties[name] = genericAdditionalProperty.GraphQLFieldFunction(class.Class)
   498  								}
   499  							}
   500  						} else {
   501  							additionalProperties[name] = additionalProperty.GraphQLFieldFunction(class.Class)
   502  						}
   503  					}
   504  				}
   505  			}
   506  		}
   507  	}
   508  	return additionalProperties
   509  }
   510  
   511  // ExtractAdditionalField extracts additional properties from given graphql arguments
   512  func (p *Provider) ExtractAdditionalField(className, name string, params []*ast.Argument) interface{} {
   513  	class, err := p.getClass(className)
   514  	if err != nil {
   515  		return err
   516  	}
   517  	for _, module := range p.GetAll() {
   518  		if p.shouldIncludeClassArgument(class, module.Name(), module.Type()) {
   519  			if arg, ok := module.(modulecapabilities.AdditionalProperties); ok {
   520  				if additionalProperties := arg.AdditionalProperties(); len(additionalProperties) > 0 {
   521  					if additionalProperty, ok := additionalProperties[name]; ok {
   522  						return additionalProperty.GraphQLExtractFunction(params)
   523  					}
   524  				}
   525  			}
   526  		}
   527  	}
   528  	return nil
   529  }
   530  
   531  // GetObjectAdditionalExtend extends rest api get queries with additional properties
   532  func (p *Provider) GetObjectAdditionalExtend(ctx context.Context,
   533  	in *search.Result, moduleParams map[string]interface{},
   534  ) (*search.Result, error) {
   535  	resArray, err := p.additionalExtend(ctx, search.Results{*in}, moduleParams, nil, "ObjectGet", nil)
   536  	if err != nil {
   537  		return nil, err
   538  	}
   539  	return &resArray[0], nil
   540  }
   541  
   542  // ListObjectsAdditionalExtend extends rest api list queries with additional properties
   543  func (p *Provider) ListObjectsAdditionalExtend(ctx context.Context,
   544  	in search.Results, moduleParams map[string]interface{},
   545  ) (search.Results, error) {
   546  	return p.additionalExtend(ctx, in, moduleParams, nil, "ObjectList", nil)
   547  }
   548  
   549  // GetExploreAdditionalExtend extends graphql api get queries with additional properties
   550  func (p *Provider) GetExploreAdditionalExtend(ctx context.Context, in []search.Result,
   551  	moduleParams map[string]interface{}, searchVector []float32,
   552  	argumentModuleParams map[string]interface{},
   553  ) ([]search.Result, error) {
   554  	return p.additionalExtend(ctx, in, moduleParams, searchVector, "ExploreGet", argumentModuleParams)
   555  }
   556  
   557  // ListExploreAdditionalExtend extends graphql api list queries with additional properties
   558  func (p *Provider) ListExploreAdditionalExtend(ctx context.Context, in []search.Result,
   559  	moduleParams map[string]interface{},
   560  	argumentModuleParams map[string]interface{},
   561  ) ([]search.Result, error) {
   562  	return p.additionalExtend(ctx, in, moduleParams, nil, "ExploreList", argumentModuleParams)
   563  }
   564  
   565  func (p *Provider) additionalExtend(ctx context.Context, in []search.Result,
   566  	moduleParams map[string]interface{}, searchVector []float32,
   567  	capability string, argumentModuleParams map[string]interface{},
   568  ) ([]search.Result, error) {
   569  	toBeExtended := in
   570  	if len(toBeExtended) > 0 {
   571  		class, err := p.getClassFromSearchResult(toBeExtended)
   572  		if err != nil {
   573  			return nil, err
   574  		}
   575  		allAdditionalProperties := map[string]modulecapabilities.AdditionalProperty{}
   576  		for _, module := range p.GetAll() {
   577  			if p.shouldIncludeClassArgument(class, module.Name(), module.Type()) {
   578  				if arg, ok := module.(modulecapabilities.AdditionalProperties); ok {
   579  					if arg != nil && arg.AdditionalProperties() != nil {
   580  						for name, additionalProperty := range arg.AdditionalProperties() {
   581  							allAdditionalProperties[name] = additionalProperty
   582  						}
   583  					}
   584  				}
   585  			}
   586  		}
   587  		if len(allAdditionalProperties) > 0 {
   588  			if err := p.checkCapabilities(allAdditionalProperties, moduleParams, capability); err != nil {
   589  				return nil, err
   590  			}
   591  			cfg := NewClassBasedModuleConfig(class, "", "", "")
   592  			for name, value := range moduleParams {
   593  				additionalPropertyFn := p.getAdditionalPropertyFn(allAdditionalProperties[name], capability)
   594  				if additionalPropertyFn != nil && value != nil {
   595  					searchValue := value
   596  					if searchVectorValue, ok := value.(modulecapabilities.AdditionalPropertyWithSearchVector); ok {
   597  						searchVectorValue.SetSearchVector(searchVector)
   598  						searchValue = searchVectorValue
   599  					}
   600  					resArray, err := additionalPropertyFn(ctx, toBeExtended, searchValue, nil, argumentModuleParams, cfg)
   601  					if err != nil {
   602  						return nil, errors.Errorf("extend %s: %v", name, err)
   603  					}
   604  					toBeExtended = resArray
   605  				} else {
   606  					return nil, errors.Errorf("unknown capability: %s", name)
   607  				}
   608  			}
   609  		}
   610  	}
   611  	return toBeExtended, nil
   612  }
   613  
   614  func (p *Provider) getClassFromSearchResult(in []search.Result) (*models.Class, error) {
   615  	if len(in) > 0 {
   616  		return p.getClass(in[0].ClassName)
   617  	}
   618  	return nil, errors.Errorf("unknown class")
   619  }
   620  
   621  func (p *Provider) checkCapabilities(additionalProperties map[string]modulecapabilities.AdditionalProperty,
   622  	moduleParams map[string]interface{}, capability string,
   623  ) error {
   624  	for name := range moduleParams {
   625  		additionalPropertyFn := p.getAdditionalPropertyFn(additionalProperties[name], capability)
   626  		if additionalPropertyFn == nil {
   627  			return errors.Errorf("unknown capability: %s", name)
   628  		}
   629  	}
   630  	return nil
   631  }
   632  
   633  func (p *Provider) getAdditionalPropertyFn(
   634  	additionalProperty modulecapabilities.AdditionalProperty,
   635  	capability string,
   636  ) modulecapabilities.AdditionalPropertyFn {
   637  	switch capability {
   638  	case "ObjectGet":
   639  		return additionalProperty.SearchFunctions.ObjectGet
   640  	case "ObjectList":
   641  		return additionalProperty.SearchFunctions.ObjectList
   642  	case "ExploreGet":
   643  		return additionalProperty.SearchFunctions.ExploreGet
   644  	case "ExploreList":
   645  		return additionalProperty.SearchFunctions.ExploreList
   646  	default:
   647  		return nil
   648  	}
   649  }
   650  
   651  // GraphQLAdditionalFieldNames get's all additional field names used in graphql
   652  func (p *Provider) GraphQLAdditionalFieldNames() []string {
   653  	additionalPropertiesNames := []string{}
   654  	for _, module := range p.GetAll() {
   655  		if arg, ok := module.(modulecapabilities.AdditionalProperties); ok {
   656  			for _, additionalProperty := range arg.AdditionalProperties() {
   657  				if additionalProperty.GraphQLNames != nil {
   658  					additionalPropertiesNames = append(additionalPropertiesNames, additionalProperty.GraphQLNames...)
   659  				}
   660  			}
   661  		}
   662  	}
   663  	return additionalPropertiesNames
   664  }
   665  
   666  // RestApiAdditionalProperties get's all rest specific additional properties with their
   667  // default values
   668  func (p *Provider) RestApiAdditionalProperties(includeProp string, class *models.Class) map[string]interface{} {
   669  	moduleParams := map[string]interface{}{}
   670  	for _, module := range p.GetAll() {
   671  		if p.shouldCrossClassIncludeClassArgument(class, module.Name(), module.Type()) {
   672  			if arg, ok := module.(modulecapabilities.AdditionalProperties); ok {
   673  				for name, additionalProperty := range arg.AdditionalProperties() {
   674  					for _, includePropName := range additionalProperty.RestNames {
   675  						if includePropName == includeProp && moduleParams[name] == nil {
   676  							moduleParams[name] = additionalProperty.DefaultValue
   677  						}
   678  					}
   679  				}
   680  			}
   681  		}
   682  	}
   683  	return moduleParams
   684  }
   685  
   686  // VectorFromSearchParam gets a vector for a given argument. This is used in
   687  // Get { Class() } for example
   688  func (p *Provider) VectorFromSearchParam(ctx context.Context,
   689  	className string, param string, params interface{},
   690  	findVectorFn modulecapabilities.FindVectorFn, tenant string,
   691  ) ([]float32, string, error) {
   692  	class, err := p.getClass(className)
   693  	if err != nil {
   694  		return nil, "", err
   695  	}
   696  	targetVector, err := p.getTargetVector(class, params)
   697  	if err != nil {
   698  		return nil, "", err
   699  	}
   700  	targetModule := p.getModuleNameForTargetVector(class, targetVector)
   701  
   702  	for _, mod := range p.GetAll() {
   703  		if p.shouldIncludeClassArgument(class, mod.Name(), mod.Type()) {
   704  			var moduleName string
   705  			var vectorSearches modulecapabilities.ArgumentVectorForParams
   706  			if searcher, ok := mod.(modulecapabilities.Searcher); ok {
   707  				if mod.Name() == targetModule {
   708  					moduleName = mod.Name()
   709  					vectorSearches = searcher.VectorSearches()
   710  				}
   711  			} else if searchers, ok := mod.(modulecapabilities.DependencySearcher); ok {
   712  				if dependencySearchers := searchers.VectorSearches(); dependencySearchers != nil {
   713  					moduleName = targetModule
   714  					vectorSearches = dependencySearchers[targetModule]
   715  				}
   716  			}
   717  			if vectorSearches != nil {
   718  				if searchVectorFn := vectorSearches[param]; searchVectorFn != nil {
   719  					cfg := NewClassBasedModuleConfig(class, moduleName, tenant, targetVector)
   720  					vector, err := searchVectorFn(ctx, params, class.Class, findVectorFn, cfg)
   721  					if err != nil {
   722  						return nil, "", errors.Errorf("vectorize params: %v", err)
   723  					}
   724  					return vector, targetVector, nil
   725  				}
   726  			}
   727  		}
   728  	}
   729  
   730  	panic("VectorFromParams was called without any known params present")
   731  }
   732  
   733  // CrossClassVectorFromSearchParam gets a vector for a given argument without
   734  // being specific to any one class and it's configuration. This is used in
   735  // Explore() { } for example
   736  func (p *Provider) CrossClassVectorFromSearchParam(ctx context.Context,
   737  	param string, params interface{},
   738  	findVectorFn modulecapabilities.FindVectorFn,
   739  ) ([]float32, string, error) {
   740  	for _, mod := range p.GetAll() {
   741  		if searcher, ok := mod.(modulecapabilities.Searcher); ok {
   742  			if vectorSearches := searcher.VectorSearches(); vectorSearches != nil {
   743  				if searchVectorFn := vectorSearches[param]; searchVectorFn != nil {
   744  					cfg := NewCrossClassModuleConfig()
   745  					vector, err := searchVectorFn(ctx, params, "", findVectorFn, cfg)
   746  					if err != nil {
   747  						return nil, "", errors.Errorf("vectorize params: %v", err)
   748  					}
   749  					targetVector, err := p.getTargetVector(nil, params)
   750  					if err != nil {
   751  						return nil, "", errors.Errorf("get target vector: %v", err)
   752  					}
   753  					return vector, targetVector, nil
   754  				}
   755  			}
   756  		}
   757  	}
   758  
   759  	panic("VectorFromParams was called without any known params present")
   760  }
   761  
   762  func (p *Provider) getTargetVector(class *models.Class, params interface{}) (string, error) {
   763  	if nearParam, ok := params.(modulecapabilities.NearParam); ok && len(nearParam.GetTargetVectors()) == 1 {
   764  		return nearParam.GetTargetVectors()[0], nil
   765  	}
   766  	if class != nil {
   767  		if len(class.VectorConfig) > 1 {
   768  			return "", fmt.Errorf("multiple vectorizers configuration found, please specify target vector name")
   769  		}
   770  
   771  		if len(class.VectorConfig) == 1 {
   772  			for name := range class.VectorConfig {
   773  				return name, nil
   774  			}
   775  		}
   776  	}
   777  	return "", nil
   778  }
   779  
   780  func (p *Provider) getModuleNameForTargetVector(class *models.Class, targetVector string) string {
   781  	if len(class.VectorConfig) > 0 {
   782  		if vectorConfig, ok := class.VectorConfig[targetVector]; ok {
   783  			if vectorizer, ok := vectorConfig.Vectorizer.(map[string]interface{}); ok && len(vectorizer) == 1 {
   784  				for moduleName := range vectorizer {
   785  					return moduleName
   786  				}
   787  			}
   788  		}
   789  	}
   790  	return class.Vectorizer
   791  }
   792  
   793  func (p *Provider) VectorFromInput(ctx context.Context,
   794  	className, input, targetVector string,
   795  ) ([]float32, error) {
   796  	class, err := p.getClass(className)
   797  	if err != nil {
   798  		return nil, err
   799  	}
   800  	targetModule := p.getModuleNameForTargetVector(class, targetVector)
   801  
   802  	for _, mod := range p.GetAll() {
   803  		if mod.Name() == targetModule {
   804  			if p.shouldIncludeClassArgument(class, mod.Name(), mod.Type()) {
   805  				if vectorizer, ok := mod.(modulecapabilities.InputVectorizer); ok {
   806  					// does not access any objects, therefore tenant is irrelevant
   807  					cfg := NewClassBasedModuleConfig(class, mod.Name(), "", targetVector)
   808  					return vectorizer.VectorizeInput(ctx, input, cfg)
   809  				}
   810  			}
   811  		}
   812  	}
   813  
   814  	return nil, fmt.Errorf("VectorFromInput was called without vectorizer")
   815  }
   816  
   817  // ParseClassifierSettings parses and adds classifier specific settings
   818  func (p *Provider) ParseClassifierSettings(name string,
   819  	params *models.Classification,
   820  ) error {
   821  	class, err := p.getClass(params.Class)
   822  	if err != nil {
   823  		return err
   824  	}
   825  	for _, module := range p.GetAll() {
   826  		if p.shouldIncludeClassArgument(class, module.Name(), module.Type()) {
   827  			if c, ok := module.(modulecapabilities.ClassificationProvider); ok {
   828  				for _, classifier := range c.Classifiers() {
   829  					if classifier != nil && classifier.Name() == name {
   830  						return classifier.ParseClassifierSettings(params)
   831  					}
   832  				}
   833  			}
   834  		}
   835  	}
   836  	return nil
   837  }
   838  
   839  // GetClassificationFn returns given module's classification
   840  func (p *Provider) GetClassificationFn(className, name string,
   841  	params modulecapabilities.ClassifyParams,
   842  ) (modulecapabilities.ClassifyItemFn, error) {
   843  	class, err := p.getClass(className)
   844  	if err != nil {
   845  		return nil, err
   846  	}
   847  	for _, module := range p.GetAll() {
   848  		if p.shouldIncludeClassArgument(class, module.Name(), module.Type()) {
   849  			if c, ok := module.(modulecapabilities.ClassificationProvider); ok {
   850  				for _, classifier := range c.Classifiers() {
   851  					if classifier != nil && classifier.Name() == name {
   852  						return classifier.ClassifyFn(params)
   853  					}
   854  				}
   855  			}
   856  		}
   857  	}
   858  	return nil, errors.Errorf("classifier %s not found", name)
   859  }
   860  
   861  // GetMeta returns meta information about modules
   862  func (p *Provider) GetMeta() (map[string]interface{}, error) {
   863  	metaInfos := map[string]interface{}{}
   864  	for _, module := range p.GetAll() {
   865  		if c, ok := module.(modulecapabilities.MetaProvider); ok {
   866  			meta, err := c.MetaInfo()
   867  			if err != nil {
   868  				return nil, err
   869  			}
   870  			metaInfos[module.Name()] = meta
   871  		}
   872  	}
   873  	return metaInfos, nil
   874  }
   875  
   876  func (p *Provider) getClass(className string) (*models.Class, error) {
   877  	sch := p.schemaGetter.GetSchemaSkipAuth()
   878  	class := sch.FindClassByName(schema.ClassName(className))
   879  	if class == nil {
   880  		return nil, errors.Errorf("class %q not found in schema", className)
   881  	}
   882  	return class, nil
   883  }
   884  
   885  func (p *Provider) HasMultipleVectorizers() bool {
   886  	return p.hasMultipleVectorizers
   887  }
   888  
   889  func (p *Provider) BackupBackend(backend string) (modulecapabilities.BackupBackend, error) {
   890  	if module := p.GetByName(backend); module != nil {
   891  		if module.Type() == modulecapabilities.Backup {
   892  			if backend, ok := module.(modulecapabilities.BackupBackend); ok {
   893  				return backend, nil
   894  			}
   895  		}
   896  	}
   897  	return nil, errors.Errorf("backup: %s not found", backend)
   898  }