github.com/weaviate/weaviate@v1.24.6/modules/text2vec-contextionary/helpers_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package modcontextionary
    13  
    14  import (
    15  	"context"
    16  	"fmt"
    17  	"net/http"
    18  
    19  	"github.com/sirupsen/logrus/hooks/test"
    20  	"github.com/tailor-inc/graphql"
    21  	"github.com/tailor-inc/graphql/language/ast"
    22  	"github.com/weaviate/weaviate/adapters/handlers/graphql/local/explore"
    23  	"github.com/weaviate/weaviate/adapters/handlers/graphql/local/get"
    24  	test_helper "github.com/weaviate/weaviate/adapters/handlers/graphql/test/helper"
    25  	"github.com/weaviate/weaviate/entities/dto"
    26  	"github.com/weaviate/weaviate/entities/models"
    27  	"github.com/weaviate/weaviate/entities/modulecapabilities"
    28  	"github.com/weaviate/weaviate/entities/moduletools"
    29  	"github.com/weaviate/weaviate/entities/search"
    30  	text2vecadditional "github.com/weaviate/weaviate/modules/text2vec-contextionary/additional"
    31  	text2vecadditionalsempath "github.com/weaviate/weaviate/modules/text2vec-contextionary/additional/sempath"
    32  	text2vecadditionalprojector "github.com/weaviate/weaviate/usecases/modulecomponents/additional/projector"
    33  	text2vecneartext "github.com/weaviate/weaviate/usecases/modulecomponents/arguments/nearText"
    34  	"github.com/weaviate/weaviate/usecases/traverser"
    35  )
    36  
    37  type mockRequestsLog struct{}
    38  
    39  func (m *mockRequestsLog) Register(first string, second string) {
    40  }
    41  
    42  type mockResolver struct {
    43  	test_helper.MockResolver
    44  }
    45  
    46  type fakeInterpretation struct{}
    47  
    48  func (f *fakeInterpretation) AdditionalPropertyFn(ctx context.Context,
    49  	in []search.Result, params interface{}, limit *int,
    50  	argumentModuleParams map[string]interface{}, cfg moduletools.ClassConfig,
    51  ) ([]search.Result, error) {
    52  	return in, nil
    53  }
    54  
    55  func (f *fakeInterpretation) ExtractAdditionalFn(param []*ast.Argument) interface{} {
    56  	return true
    57  }
    58  
    59  func (f *fakeInterpretation) AdditionalPropertyDefaultValue() interface{} {
    60  	return true
    61  }
    62  
    63  type fakeExtender struct {
    64  	returnArgs []search.Result
    65  }
    66  
    67  func (f *fakeExtender) AdditionalPropertyFn(ctx context.Context,
    68  	in []search.Result, params interface{}, limit *int,
    69  	argumentModuleParams map[string]interface{}, cfg moduletools.ClassConfig,
    70  ) ([]search.Result, error) {
    71  	return f.returnArgs, nil
    72  }
    73  
    74  func (f *fakeExtender) ExtractAdditionalFn(param []*ast.Argument) interface{} {
    75  	return true
    76  }
    77  
    78  func (f *fakeExtender) AdditionalPropertyDefaultValue() interface{} {
    79  	return true
    80  }
    81  
    82  type fakeProjector struct {
    83  	returnArgs []search.Result
    84  }
    85  
    86  func (f *fakeProjector) AdditionalPropertyFn(ctx context.Context,
    87  	in []search.Result, params interface{}, limit *int,
    88  	argumentModuleParams map[string]interface{}, cfg moduletools.ClassConfig,
    89  ) ([]search.Result, error) {
    90  	return f.returnArgs, nil
    91  }
    92  
    93  func (f *fakeProjector) ExtractAdditionalFn(param []*ast.Argument) interface{} {
    94  	if len(param) > 0 {
    95  		p := &text2vecadditionalprojector.Params{}
    96  		err := p.SetDefaultsAndValidate(100, 4)
    97  		if err != nil {
    98  			return nil
    99  		}
   100  		return p
   101  	}
   102  	return &text2vecadditionalprojector.Params{
   103  		Enabled: true,
   104  	}
   105  }
   106  
   107  func (f *fakeProjector) AdditionalPropertyDefaultValue() interface{} {
   108  	return &text2vecadditionalprojector.Params{}
   109  }
   110  
   111  type fakePathBuilder struct {
   112  	returnArgs []search.Result
   113  }
   114  
   115  func (f *fakePathBuilder) AdditionalPropertyFn(ctx context.Context,
   116  	in []search.Result, params interface{}, limit *int,
   117  	argumentModuleParams map[string]interface{}, cfg moduletools.ClassConfig,
   118  ) ([]search.Result, error) {
   119  	return f.returnArgs, nil
   120  }
   121  
   122  func (f *fakePathBuilder) ExtractAdditionalFn(param []*ast.Argument) interface{} {
   123  	return &text2vecadditionalsempath.Params{}
   124  }
   125  
   126  func (f *fakePathBuilder) AdditionalPropertyDefaultValue() interface{} {
   127  	return &text2vecadditionalsempath.Params{}
   128  }
   129  
   130  type mockText2vecContextionaryModule struct{}
   131  
   132  func (m *mockText2vecContextionaryModule) Name() string {
   133  	return "text2vec-contextionary"
   134  }
   135  
   136  func (m *mockText2vecContextionaryModule) Init(params moduletools.ModuleInitParams) error {
   137  	return nil
   138  }
   139  
   140  func (m *mockText2vecContextionaryModule) RootHandler() http.Handler {
   141  	return nil
   142  }
   143  
   144  // graphql arguments
   145  func (m *mockText2vecContextionaryModule) Arguments() map[string]modulecapabilities.GraphQLArgument {
   146  	return text2vecneartext.New(nil).Arguments()
   147  }
   148  
   149  // additional properties
   150  func (m *mockText2vecContextionaryModule) AdditionalProperties() map[string]modulecapabilities.AdditionalProperty {
   151  	return text2vecadditional.New(&fakeExtender{}, &fakeProjector{}, &fakePathBuilder{}, &fakeInterpretation{}).AdditionalProperties()
   152  }
   153  
   154  type fakeModulesProvider struct{}
   155  
   156  func (fmp *fakeModulesProvider) GetAll() []modulecapabilities.Module {
   157  	panic("implement me")
   158  }
   159  
   160  func (fmp *fakeModulesProvider) VectorFromInput(ctx context.Context, className string, input string) ([]float32, error) {
   161  	panic("not implemented")
   162  }
   163  
   164  func (fmp *fakeModulesProvider) GetArguments(class *models.Class) map[string]*graphql.ArgumentConfig {
   165  	args := map[string]*graphql.ArgumentConfig{}
   166  	txt2vec := &mockText2vecContextionaryModule{}
   167  	if class.Vectorizer == txt2vec.Name() {
   168  		for name, argument := range txt2vec.Arguments() {
   169  			args[name] = argument.GetArgumentsFunction(class.Class)
   170  		}
   171  	}
   172  	return args
   173  }
   174  
   175  func (fmp *fakeModulesProvider) ExploreArguments(schema *models.Schema) map[string]*graphql.ArgumentConfig {
   176  	args := map[string]*graphql.ArgumentConfig{}
   177  	txt2vec := &mockText2vecContextionaryModule{}
   178  	for _, c := range schema.Classes {
   179  		if c.Vectorizer == txt2vec.Name() {
   180  			for name, argument := range txt2vec.Arguments() {
   181  				args[name] = argument.ExploreArgumentsFunction()
   182  			}
   183  		}
   184  	}
   185  	return args
   186  }
   187  
   188  func (fmp *fakeModulesProvider) CrossClassExtractSearchParams(arguments map[string]interface{}) map[string]interface{} {
   189  	return fmp.ExtractSearchParams(arguments, "")
   190  }
   191  
   192  func (fmp *fakeModulesProvider) ExtractSearchParams(arguments map[string]interface{}, className string) map[string]interface{} {
   193  	exractedParams := map[string]interface{}{}
   194  	if param, ok := arguments["nearText"]; ok {
   195  		exractedParams["nearText"] = extractNearTextParam(param.(map[string]interface{}))
   196  	}
   197  	return exractedParams
   198  }
   199  
   200  func (fmp *fakeModulesProvider) GetAdditionalFields(class *models.Class) map[string]*graphql.Field {
   201  	txt2vec := &mockText2vecContextionaryModule{}
   202  	additionalProperties := map[string]*graphql.Field{}
   203  	for name, additionalProperty := range txt2vec.AdditionalProperties() {
   204  		if additionalProperty.GraphQLFieldFunction != nil {
   205  			additionalProperties[name] = additionalProperty.GraphQLFieldFunction(class.Class)
   206  		}
   207  	}
   208  	return additionalProperties
   209  }
   210  
   211  func (fmp *fakeModulesProvider) ExtractAdditionalField(className, name string, params []*ast.Argument) interface{} {
   212  	txt2vec := &mockText2vecContextionaryModule{}
   213  	if additionalProperties := txt2vec.AdditionalProperties(); len(additionalProperties) > 0 {
   214  		if additionalProperty, ok := additionalProperties[name]; ok {
   215  			if additionalProperty.GraphQLExtractFunction != nil {
   216  				return additionalProperty.GraphQLExtractFunction(params)
   217  			}
   218  		}
   219  	}
   220  	return nil
   221  }
   222  
   223  func (fmp *fakeModulesProvider) GetExploreAdditionalExtend(ctx context.Context, in []search.Result,
   224  	moduleParams map[string]interface{}, searchVector []float32,
   225  	argumentModuleParams map[string]interface{}, cfg moduletools.ClassConfig,
   226  ) ([]search.Result, error) {
   227  	return fmp.additionalExtend(ctx, in, moduleParams, searchVector, "ExploreGet", argumentModuleParams, nil)
   228  }
   229  
   230  func (fmp *fakeModulesProvider) additionalExtend(ctx context.Context,
   231  	in search.Results, moduleParams map[string]interface{},
   232  	searchVector []float32, capability string, argumentModuleParams map[string]interface{}, cfg moduletools.ClassConfig,
   233  ) (search.Results, error) {
   234  	txt2vec := &mockText2vecContextionaryModule{}
   235  	additionalProperties := txt2vec.AdditionalProperties()
   236  	for name, value := range moduleParams {
   237  		additionalPropertyFn := fmp.getAdditionalPropertyFn(additionalProperties[name], capability)
   238  		if additionalPropertyFn != nil && value != nil {
   239  			searchValue := value
   240  			if searchVectorValue, ok := value.(modulecapabilities.AdditionalPropertyWithSearchVector); ok {
   241  				searchVectorValue.SetSearchVector(searchVector)
   242  				searchValue = searchVectorValue
   243  			}
   244  			resArray, err := additionalPropertyFn(ctx, in, searchValue, nil, nil, nil)
   245  			if err != nil {
   246  				return nil, err
   247  			}
   248  			in = resArray
   249  		}
   250  	}
   251  	return in, nil
   252  }
   253  
   254  func (fmp *fakeModulesProvider) getAdditionalPropertyFn(additionalProperty modulecapabilities.AdditionalProperty,
   255  	capability string,
   256  ) modulecapabilities.AdditionalPropertyFn {
   257  	switch capability {
   258  	case "ObjectGet":
   259  		return additionalProperty.SearchFunctions.ObjectGet
   260  	case "ObjectList":
   261  		return additionalProperty.SearchFunctions.ObjectList
   262  	case "ExploreGet":
   263  		return additionalProperty.SearchFunctions.ExploreGet
   264  	case "ExploreList":
   265  		return additionalProperty.SearchFunctions.ExploreList
   266  	default:
   267  		return nil
   268  	}
   269  }
   270  
   271  func (fmp *fakeModulesProvider) GraphQLAdditionalFieldNames() []string {
   272  	txt2vec := &mockText2vecContextionaryModule{}
   273  	additionalPropertiesNames := []string{}
   274  	for _, additionalProperty := range txt2vec.AdditionalProperties() {
   275  		if additionalProperty.GraphQLNames != nil {
   276  			additionalPropertiesNames = append(additionalPropertiesNames, additionalProperty.GraphQLNames...)
   277  		}
   278  	}
   279  	return additionalPropertiesNames
   280  }
   281  
   282  func extractNearTextParam(param map[string]interface{}) interface{} {
   283  	txt2vec := &mockText2vecContextionaryModule{}
   284  	argument := txt2vec.Arguments()["nearText"]
   285  	return argument.ExtractFunction(param)
   286  }
   287  
   288  func createArg(name string, value string) *ast.Argument {
   289  	n := ast.Name{
   290  		Value: name,
   291  	}
   292  	val := ast.StringValue{
   293  		Kind:  "Kind",
   294  		Value: value,
   295  	}
   296  	arg := ast.Argument{
   297  		Name:  ast.NewName(&n),
   298  		Kind:  "Kind",
   299  		Value: ast.NewStringValue(&val),
   300  	}
   301  	a := ast.NewArgument(&arg)
   302  	return a
   303  }
   304  
   305  func extractAdditionalParam(name string, args []*ast.Argument) interface{} {
   306  	txt2vec := &mockText2vecContextionaryModule{}
   307  	additionalProperties := txt2vec.AdditionalProperties()
   308  	switch name {
   309  	case "semanticPath", "featureProjection":
   310  		if ap, ok := additionalProperties[name]; ok {
   311  			return ap.GraphQLExtractFunction(args)
   312  		}
   313  		return nil
   314  	default:
   315  		return nil
   316  	}
   317  }
   318  
   319  func getFakeModulesProvider() *fakeModulesProvider {
   320  	return &fakeModulesProvider{}
   321  }
   322  
   323  func newMockResolver() *mockResolver {
   324  	logger, _ := test.NewNullLogger()
   325  	field, err := get.Build(&test_helper.SimpleSchema, logger, getFakeModulesProvider())
   326  	if err != nil {
   327  		panic(fmt.Sprintf("could not build graphql test schema: %s", err))
   328  	}
   329  	mocker := &mockResolver{}
   330  	mockLog := &mockRequestsLog{}
   331  	mocker.RootFieldName = "Get"
   332  	mocker.RootField = field
   333  	mocker.RootObject = map[string]interface{}{"Resolver": GetResolver(mocker), "RequestsLog": RequestsLog(mockLog)}
   334  	return mocker
   335  }
   336  
   337  func newExploreMockResolver() *mockResolver {
   338  	field := explore.Build(test_helper.SimpleSchema.Objects, getFakeModulesProvider())
   339  	mocker := &mockResolver{}
   340  	mockLog := &mockRequestsLog{}
   341  	mocker.RootFieldName = "Explore"
   342  	mocker.RootField = field
   343  	mocker.RootObject = map[string]interface{}{
   344  		"Resolver":    ExploreResolver(mocker),
   345  		"RequestsLog": mockLog,
   346  	}
   347  	return mocker
   348  }
   349  
   350  func (m *mockResolver) GetClass(ctx context.Context, principal *models.Principal,
   351  	params dto.GetParams,
   352  ) ([]interface{}, error) {
   353  	args := m.Called(params)
   354  	return args.Get(0).([]interface{}), args.Error(1)
   355  }
   356  
   357  func (m *mockResolver) Explore(ctx context.Context,
   358  	principal *models.Principal, params traverser.ExploreParams,
   359  ) ([]search.Result, error) {
   360  	args := m.Called(params)
   361  	return args.Get(0).([]search.Result), args.Error(1)
   362  }
   363  
   364  // Resolver is a local abstraction of the required UC resolvers
   365  type GetResolver interface {
   366  	GetClass(ctx context.Context, principal *models.Principal, info dto.GetParams) ([]interface{}, error)
   367  }
   368  
   369  type ExploreResolver interface {
   370  	Explore(ctx context.Context, principal *models.Principal, params traverser.ExploreParams) ([]search.Result, error)
   371  }
   372  
   373  // RequestsLog is a local abstraction on the RequestsLog that needs to be
   374  // provided to the graphQL API in order to log Local.Get queries.
   375  type RequestsLog interface {
   376  	Register(requestType string, identifier string)
   377  }