github.com/weaviate/weaviate@v1.24.6/usecases/modules/searchers_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package modules
    13  
    14  import (
    15  	"context"
    16  	"testing"
    17  
    18  	"github.com/go-openapi/strfmt"
    19  	"github.com/sirupsen/logrus/hooks/test"
    20  	"github.com/stretchr/testify/assert"
    21  	"github.com/stretchr/testify/require"
    22  	"github.com/weaviate/weaviate/entities/models"
    23  	"github.com/weaviate/weaviate/entities/modulecapabilities"
    24  	"github.com/weaviate/weaviate/entities/moduletools"
    25  	"github.com/weaviate/weaviate/entities/schema"
    26  )
    27  
    28  func TestModulesWithSearchers(t *testing.T) {
    29  	sch := schema.Schema{
    30  		Objects: &models.Schema{
    31  			Classes: []*models.Class{
    32  				{
    33  					Class:      "MyClass",
    34  					Vectorizer: "mod",
    35  					ModuleConfig: map[string]interface{}{
    36  						"mod": map[string]interface{}{
    37  							"some-config": "some-config-value",
    38  						},
    39  					},
    40  				},
    41  			},
    42  		},
    43  	}
    44  	logger, _ := test.NewNullLogger()
    45  
    46  	t.Run("get a vector for a class", func(t *testing.T) {
    47  		p := NewProvider()
    48  		p.SetSchemaGetter(&fakeSchemaGetter{
    49  			schema: sch,
    50  		})
    51  		p.Register(newSearcherModule("mod").
    52  			withArg("nearGrape").
    53  			withSearcher("nearGrape", func(ctx context.Context, params interface{},
    54  				className string,
    55  				findVectorFn modulecapabilities.FindVectorFn,
    56  				cfg moduletools.ClassConfig,
    57  			) ([]float32, error) {
    58  				// verify that the config tool is set, as this is a per-class search,
    59  				// so it must be set
    60  				assert.NotNil(t, cfg)
    61  
    62  				// take the findVectorFn and append one dimension. This doesn't make too
    63  				// much sense, but helps verify that the modules method was used in the
    64  				// decisions
    65  				initial, _, _ := findVectorFn(ctx, "class", "123", "", "")
    66  				return append(initial, 4), nil
    67  			}),
    68  		)
    69  		p.Init(context.Background(), nil, logger)
    70  
    71  		res, targetVector, err := p.VectorFromSearchParam(context.Background(), "MyClass",
    72  			"nearGrape", nil, fakeFindVector, "")
    73  
    74  		require.Nil(t, err)
    75  		assert.Equal(t, []float32{1, 2, 3, 4}, res)
    76  		assert.Equal(t, "", targetVector)
    77  	})
    78  
    79  	t.Run("get a vector across classes", func(t *testing.T) {
    80  		p := NewProvider()
    81  		p.SetSchemaGetter(&fakeSchemaGetter{
    82  			schema: sch,
    83  		})
    84  		p.Register(newSearcherModule("mod").
    85  			withArg("nearGrape").
    86  			withSearcher("nearGrape", func(ctx context.Context, params interface{},
    87  				className string,
    88  				findVectorFn modulecapabilities.FindVectorFn,
    89  				cfg moduletools.ClassConfig,
    90  			) ([]float32, error) {
    91  				// this is a cross-class search, such as is used for Explore{}, in this
    92  				// case we do not have class-based config, but we need at least pass
    93  				// a tenant information, that's why we pass an empty config with empty tenant
    94  				// so that it would be possible to perform cross class searches, without
    95  				// tenant context. Modules must be able to deal with this situation!
    96  				assert.NotNil(t, cfg)
    97  				assert.Equal(t, "", cfg.Tenant())
    98  
    99  				// take the findVectorFn and append one dimension. This doesn't make too
   100  				// much sense, but helps verify that the modules method was used in the
   101  				// decisions
   102  				initial, _, _ := findVectorFn(ctx, "class", "123", "", "")
   103  				return append(initial, 4), nil
   104  			}),
   105  		)
   106  		p.Init(context.Background(), nil, logger)
   107  
   108  		res, targetVector, err := p.CrossClassVectorFromSearchParam(context.Background(),
   109  			"nearGrape", nil, fakeFindVector)
   110  
   111  		require.Nil(t, err)
   112  		assert.Equal(t, []float32{1, 2, 3, 4}, res)
   113  		assert.Equal(t, "", targetVector)
   114  	})
   115  }
   116  
   117  func fakeFindVector(ctx context.Context, className string, id strfmt.UUID, tenant, targetVector string) ([]float32, string, error) {
   118  	return []float32{1, 2, 3}, targetVector, nil
   119  }
   120  
   121  func newSearcherModule(name string) *dummySearcherModule {
   122  	return &dummySearcherModule{
   123  		dummyGraphQLModule: newGraphQLModule(name),
   124  		searchers:          map[string]modulecapabilities.VectorForParams{},
   125  	}
   126  }
   127  
// dummySearcherModule is a test double that embeds dummyGraphQLModule and adds
// a registry of vector searchers. Through its VectorSearches method it can be
// registered with the Provider as a module that provides searchers.
type dummySearcherModule struct {
	*dummyGraphQLModule
	// searchers maps a search argument name (e.g. "nearGrape") to the function
	// that computes a vector for that argument's params.
	searchers map[string]modulecapabilities.VectorForParams
}
   132  
   133  func (m *dummySearcherModule) withArg(arg string) *dummySearcherModule {
   134  	// call the super's withArg
   135  	m.dummyGraphQLModule.withArg(arg)
   136  
   137  	// but don't return their return type but ours :)
   138  	return m
   139  }
   140  
   141  // a helper for our test
   142  func (m *dummySearcherModule) withSearcher(arg string,
   143  	impl modulecapabilities.VectorForParams,
   144  ) *dummySearcherModule {
   145  	m.searchers[arg] = impl
   146  	return m
   147  }
   148  
   149  // public method to implement the modulecapabilities.Searcher interface
   150  func (m *dummySearcherModule) VectorSearches() map[string]modulecapabilities.VectorForParams {
   151  	return m.searchers
   152  }