github.com/weaviate/weaviate@v1.24.6/adapters/handlers/graphql/local/explore/helpers_for_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package explore
    13  
    14  import (
    15  	"context"
    16  	"fmt"
    17  	"net/http"
    18  
    19  	"github.com/tailor-inc/graphql"
    20  	"github.com/weaviate/weaviate/adapters/handlers/graphql/descriptions"
    21  	testhelper "github.com/weaviate/weaviate/adapters/handlers/graphql/test/helper"
    22  	"github.com/weaviate/weaviate/entities/models"
    23  	"github.com/weaviate/weaviate/entities/modulecapabilities"
    24  	"github.com/weaviate/weaviate/entities/moduletools"
    25  	"github.com/weaviate/weaviate/entities/search"
    26  	"github.com/weaviate/weaviate/usecases/traverser"
    27  )
    28  
// mockRequestsLog is a stand-in requests log for tests; its only method,
// Register, does nothing.
type mockRequestsLog struct{}
    30  
// Register is a no-op: both arguments are accepted and ignored.
func (m *mockRequestsLog) Register(first string, second string) {
}
    33  
// mockResolver is the resolver used by the tests of this package. It embeds
// testhelper.MockResolver, so that type's fields and methods are promoted,
// and adds the Explore method defined further down in this file.
type mockResolver struct {
	testhelper.MockResolver
}
    37  
// fakeModulesProvider is a stateless test double that is returned as a
// ModulesProvider by getFakeModulesProvider. It exposes the arguments of
// nearCustomTextModule.
type fakeModulesProvider struct{}
    39  
// VectorFromInput is not supported by this fake and always panics with
// "not implemented".
func (p *fakeModulesProvider) VectorFromInput(ctx context.Context, className string, input string) ([]float32, error) {
	panic("not implemented")
}
    43  
    44  func (p *fakeModulesProvider) ExploreArguments(schema *models.Schema) map[string]*graphql.ArgumentConfig {
    45  	args := map[string]*graphql.ArgumentConfig{}
    46  	txt2vec := &nearCustomTextModule{}
    47  	for _, c := range schema.Classes {
    48  		if c.Vectorizer == txt2vec.Name() {
    49  			for name, argument := range txt2vec.Arguments() {
    50  				args[name] = argument.ExploreArgumentsFunction()
    51  			}
    52  		}
    53  	}
    54  	return args
    55  }
    56  
    57  func (p *fakeModulesProvider) CrossClassExtractSearchParams(arguments map[string]interface{}) map[string]interface{} {
    58  	exractedParams := map[string]interface{}{}
    59  	if param, ok := arguments["nearCustomText"]; ok {
    60  		exractedParams["nearCustomText"] = extractNearCustomTextParam(param.(map[string]interface{}))
    61  	}
    62  	return exractedParams
    63  }
    64  
    65  func extractNearCustomTextParam(param map[string]interface{}) interface{} {
    66  	nearCustomText := &nearCustomTextModule{}
    67  	argument := nearCustomText.Arguments()["nearCustomText"]
    68  	return argument.ExtractFunction(param)
    69  }
    70  
// getFakeModulesProvider returns a new fakeModulesProvider as a
// ModulesProvider.
func getFakeModulesProvider() ModulesProvider {
	return &fakeModulesProvider{}
}
    74  
    75  func newMockResolver() *mockResolver {
    76  	field := Build(testhelper.SimpleSchema.Objects, getFakeModulesProvider())
    77  	mocker := &mockResolver{}
    78  	mockLog := &mockRequestsLog{}
    79  	mocker.RootFieldName = "Explore"
    80  	mocker.RootField = field
    81  	mocker.RootObject = map[string]interface{}{
    82  		"Resolver":    Resolver(mocker),
    83  		"RequestsLog": mockLog,
    84  	}
    85  	return mocker
    86  }
    87  
    88  func newMockResolverNoModules() *mockResolver {
    89  	field := Build(testhelper.SimpleSchema.Objects, nil)
    90  	mocker := &mockResolver{}
    91  	mockLog := &mockRequestsLog{}
    92  	mocker.RootFieldName = "Explore"
    93  	mocker.RootField = field
    94  	mocker.RootObject = map[string]interface{}{
    95  		"Resolver":    Resolver(mocker),
    96  		"RequestsLog": mockLog,
    97  	}
    98  	return mocker
    99  }
   100  
   101  func newMockResolverEmptySchema() *mockResolver {
   102  	field := Build(&models.Schema{}, getFakeModulesProvider())
   103  	mocker := &mockResolver{}
   104  	mockLog := &mockRequestsLog{}
   105  	mocker.RootFieldName = "Explore"
   106  	mocker.RootField = field
   107  	mocker.RootObject = map[string]interface{}{
   108  		"Resolver":    Resolver(mocker),
   109  		"RequestsLog": mockLog,
   110  	}
   111  	return mocker
   112  }
   113  
   114  func (m *mockResolver) Explore(ctx context.Context,
   115  	principal *models.Principal, params traverser.ExploreParams,
   116  ) ([]search.Result, error) {
   117  	args := m.Called(params)
   118  	return args.Get(0).([]search.Result), args.Error(1)
   119  }
   120  
// nearCustomTextParams is the parsed form of a nearCustomText argument, as
// produced by nearCustomTextModule.extractNearCustomTextArgument.
type nearCustomTextParams struct {
	// Values holds the strings given under "concepts".
	Values []string
	// MoveTo is filled from the optional "moveTo" object.
	MoveTo nearExploreMove
	// MoveAwayFrom is filled from the optional "moveAwayFrom" object.
	MoveAwayFrom nearExploreMove
	// Certainty is the value given under "certainty", zero if absent.
	Certainty float64
	// Distance is the value given under "distance", zero if absent.
	Distance float64
	// WithDistance is set to true when a "distance" key was present.
	WithDistance bool
}
   129  
// nearExploreMove is the parsed form of a "moveTo" or "moveAwayFrom" object,
// as produced by nearCustomTextModule.parseMoveParam.
type nearExploreMove struct {
	// Values holds the strings given under "concepts".
	Values []string
	// Force is the "force" value converted to float32.
	Force float32
	// Objects holds one entry per element given under "objects".
	Objects []nearObjectMove
}
   135  
// nearObjectMove is a single element of a movement's "objects" list. Either
// field stays empty when the corresponding key ("id" / "beacon") is missing
// or nil in the input.
type nearObjectMove struct {
	ID     string
	Beacon string
}
   140  
// nearCustomTextModule is a stateless test module that provides the
// "nearCustomText" GraphQL argument (see Arguments).
type nearCustomTextModule struct{}
   142  
// Name returns the module name "text2vec-contextionary".
// fakeModulesProvider.ExploreArguments compares it against a class's
// Vectorizer.
func (m *nearCustomTextModule) Name() string {
	return "text2vec-contextionary"
}
   146  
// Init does nothing with params and always returns nil.
func (m *nearCustomTextModule) Init(params moduletools.ModuleInitParams) error {
	return nil
}
   150  
// RootHandler always returns nil; this test module provides no HTTP handler.
func (m *nearCustomTextModule) RootHandler() http.Handler {
	return nil
}
   154  
   155  func (m *nearCustomTextModule) Arguments() map[string]modulecapabilities.GraphQLArgument {
   156  	arguments := map[string]modulecapabilities.GraphQLArgument{}
   157  	// define nearCustomText argument
   158  	arguments["nearCustomText"] = modulecapabilities.GraphQLArgument{
   159  		GetArgumentsFunction: func(classname string) *graphql.ArgumentConfig {
   160  			return m.getNearCustomTextArgument(classname)
   161  		},
   162  		ExploreArgumentsFunction: func() *graphql.ArgumentConfig {
   163  			return m.getNearCustomTextArgument("")
   164  		},
   165  		ExtractFunction: func(source map[string]interface{}) interface{} {
   166  			return m.extractNearCustomTextArgument(source)
   167  		},
   168  		ValidateFunction: func(param interface{}) error {
   169  			// all is valid
   170  			return nil
   171  		},
   172  	}
   173  	return arguments
   174  }
   175  
   176  func (m *nearCustomTextModule) getNearCustomTextArgument(classname string) *graphql.ArgumentConfig {
   177  	prefix := classname
   178  	return &graphql.ArgumentConfig{
   179  		Type: graphql.NewInputObject(
   180  			graphql.InputObjectConfig{
   181  				Name: fmt.Sprintf("%sNearCustomTextInpObj", prefix),
   182  				Fields: graphql.InputObjectConfigFieldMap{
   183  					"concepts": &graphql.InputObjectFieldConfig{
   184  						Type: graphql.NewNonNull(graphql.NewList(graphql.String)),
   185  					},
   186  					"moveTo": &graphql.InputObjectFieldConfig{
   187  						Description: descriptions.VectorMovement,
   188  						Type: graphql.NewInputObject(
   189  							graphql.InputObjectConfig{
   190  								Name: fmt.Sprintf("%sMoveTo", prefix),
   191  								Fields: graphql.InputObjectConfigFieldMap{
   192  									"concepts": &graphql.InputObjectFieldConfig{
   193  										Description: descriptions.Keywords,
   194  										Type:        graphql.NewList(graphql.String),
   195  									},
   196  									"objects": &graphql.InputObjectFieldConfig{
   197  										Description: "objects",
   198  										Type: graphql.NewList(graphql.NewInputObject(
   199  											graphql.InputObjectConfig{
   200  												Name: fmt.Sprintf("%sMovementObjectsToInpObj", prefix),
   201  												Fields: graphql.InputObjectConfigFieldMap{
   202  													"id": &graphql.InputObjectFieldConfig{
   203  														Type:        graphql.String,
   204  														Description: "id of an object",
   205  													},
   206  													"beacon": &graphql.InputObjectFieldConfig{
   207  														Type:        graphql.String,
   208  														Description: descriptions.Beacon,
   209  													},
   210  												},
   211  												Description: "Movement Object",
   212  											},
   213  										)),
   214  									},
   215  									"force": &graphql.InputObjectFieldConfig{
   216  										Description: descriptions.Force,
   217  										Type:        graphql.NewNonNull(graphql.Float),
   218  									},
   219  								},
   220  							}),
   221  					},
   222  					"moveAwayFrom": &graphql.InputObjectFieldConfig{
   223  						Description: descriptions.VectorMovement,
   224  						Type: graphql.NewInputObject(
   225  							graphql.InputObjectConfig{
   226  								Name: fmt.Sprintf("%sMoveAway", prefix),
   227  								Fields: graphql.InputObjectConfigFieldMap{
   228  									"concepts": &graphql.InputObjectFieldConfig{
   229  										Description: descriptions.Keywords,
   230  										Type:        graphql.NewList(graphql.String),
   231  									},
   232  									"objects": &graphql.InputObjectFieldConfig{
   233  										Description: "objects",
   234  										Type: graphql.NewList(graphql.NewInputObject(
   235  											graphql.InputObjectConfig{
   236  												Name: fmt.Sprintf("%sMovementObjectsAwayInpObj", prefix),
   237  												Fields: graphql.InputObjectConfigFieldMap{
   238  													"id": &graphql.InputObjectFieldConfig{
   239  														Type:        graphql.String,
   240  														Description: "id of an object",
   241  													},
   242  													"beacon": &graphql.InputObjectFieldConfig{
   243  														Type:        graphql.String,
   244  														Description: descriptions.Beacon,
   245  													},
   246  												},
   247  												Description: "Movement Object",
   248  											},
   249  										)),
   250  									},
   251  									"force": &graphql.InputObjectFieldConfig{
   252  										Description: descriptions.Force,
   253  										Type:        graphql.NewNonNull(graphql.Float),
   254  									},
   255  								},
   256  							}),
   257  					},
   258  					"certainty": &graphql.InputObjectFieldConfig{
   259  						Description: descriptions.Certainty,
   260  						Type:        graphql.Float,
   261  					},
   262  					"distance": &graphql.InputObjectFieldConfig{
   263  						Description: descriptions.Distance,
   264  						Type:        graphql.Float,
   265  					},
   266  				},
   267  				Description: descriptions.GetWhereInpObj,
   268  			},
   269  		),
   270  	}
   271  }
   272  
   273  func (m *nearCustomTextModule) extractNearCustomTextArgument(source map[string]interface{}) *nearCustomTextParams {
   274  	var args nearCustomTextParams
   275  
   276  	concepts := source["concepts"].([]interface{})
   277  	args.Values = make([]string, len(concepts))
   278  	for i, value := range concepts {
   279  		args.Values[i] = value.(string)
   280  	}
   281  
   282  	certainty, ok := source["certainty"]
   283  	if ok {
   284  		args.Certainty = certainty.(float64)
   285  	}
   286  
   287  	distance, ok := source["distance"]
   288  	if ok {
   289  		args.Distance = distance.(float64)
   290  		args.WithDistance = true
   291  	}
   292  
   293  	// moveTo is an optional arg, so it could be nil
   294  	moveTo, ok := source["moveTo"]
   295  	if ok {
   296  		moveToMap := moveTo.(map[string]interface{})
   297  		args.MoveTo = m.parseMoveParam(moveToMap)
   298  	}
   299  
   300  	moveAwayFrom, ok := source["moveAwayFrom"]
   301  	if ok {
   302  		moveAwayFromMap := moveAwayFrom.(map[string]interface{})
   303  		args.MoveAwayFrom = m.parseMoveParam(moveAwayFromMap)
   304  	}
   305  
   306  	return &args
   307  }
   308  
   309  func (m *nearCustomTextModule) parseMoveParam(source map[string]interface{}) nearExploreMove {
   310  	res := nearExploreMove{}
   311  	res.Force = float32(source["force"].(float64))
   312  
   313  	concepts, ok := source["concepts"].([]interface{})
   314  	if ok {
   315  		res.Values = make([]string, len(concepts))
   316  		for i, value := range concepts {
   317  			res.Values[i] = value.(string)
   318  		}
   319  	}
   320  
   321  	objects, ok := source["objects"].([]interface{})
   322  	if ok {
   323  		res.Objects = make([]nearObjectMove, len(objects))
   324  		for i, value := range objects {
   325  			v, ok := value.(map[string]interface{})
   326  			if ok {
   327  				if v["id"] != nil {
   328  					res.Objects[i].ID = v["id"].(string)
   329  				}
   330  				if v["beacon"] != nil {
   331  					res.Objects[i].Beacon = v["beacon"].(string)
   332  				}
   333  			}
   334  		}
   335  	}
   336  
   337  	return res
   338  }