github.com/weaviate/weaviate@v1.24.6/usecases/traverser/hybrid/searcher_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package hybrid
    13  
    14  import (
    15  	"context"
    16  	"testing"
    17  
    18  	"github.com/sirupsen/logrus/hooks/test"
    19  	"github.com/stretchr/testify/assert"
    20  	"github.com/stretchr/testify/require"
    21  	"github.com/weaviate/weaviate/adapters/handlers/graphql/local/common_filters"
    22  	"github.com/weaviate/weaviate/entities/models"
    23  	"github.com/weaviate/weaviate/entities/searchparams"
    24  	"github.com/weaviate/weaviate/entities/storobj"
    25  )
    26  
    27  func TestSearcher(t *testing.T) {
    28  	ctx := context.Background()
    29  	logger, _ := test.NewNullLogger()
    30  	class := "HybridClass"
    31  
    32  	tests := []struct {
    33  		name string
    34  		f    func(t *testing.T)
    35  	}{
    36  		{
    37  			name: "with module provider",
    38  			f: func(t *testing.T) {
    39  				params := &Params{
    40  					HybridSearch: &searchparams.HybridSearch{
    41  						Type:          "hybrid",
    42  						Alpha:         0.5,
    43  						Query:         "some query",
    44  						TargetVectors: []string{"default"},
    45  					},
    46  					Class: class,
    47  				}
    48  				sparse := func() ([]*storobj.Object, []float32, error) { return nil, nil, nil }
    49  				dense := func([]float32) ([]*storobj.Object, []float32, error) { return nil, nil, nil }
    50  				provider := &fakeModuleProvider{}
    51  				schemaGetter := newFakeSchemaManager()
    52  				targetVectorParamHelper := newFakeTargetVectorParamHelper()
    53  				_, err := Search(ctx, params, logger, sparse, dense, nil, provider, schemaGetter, targetVectorParamHelper)
    54  				require.Nil(t, err)
    55  			},
    56  		},
    57  		{
    58  			name: "without module provider",
    59  			f: func(t *testing.T) {
    60  				params := &Params{
    61  					HybridSearch: &searchparams.HybridSearch{
    62  						Type:  "hybrid",
    63  						Alpha: 0.5,
    64  						Query: "some query",
    65  					},
    66  					Class: class,
    67  				}
    68  				sparse := func() ([]*storobj.Object, []float32, error) { return nil, nil, nil }
    69  				dense := func([]float32) ([]*storobj.Object, []float32, error) { return nil, nil, nil }
    70  
    71  				_, err := Search(ctx, params, logger, sparse, dense, nil, nil, nil, nil)
    72  				require.Nil(t, err)
    73  			},
    74  		},
    75  		{
    76  			name: "with sparse search only",
    77  			f: func(t *testing.T) {
    78  				params := &Params{
    79  					HybridSearch: &searchparams.HybridSearch{
    80  						Type:  "hybrid",
    81  						Alpha: 0,
    82  						Query: "some query",
    83  					},
    84  					Class: class,
    85  				}
    86  				sparse := func() ([]*storobj.Object, []float32, error) {
    87  					return []*storobj.Object{
    88  						{
    89  							Object: models.Object{
    90  								Class:      class,
    91  								ID:         "1889a225-3b28-477d-b8fc-5f6071bb4731",
    92  								Properties: map[string]any{"prop": "val"},
    93  								Vector:     []float32{1, 2, 3},
    94  							},
    95  							Vector: []float32{1, 2, 3},
    96  						},
    97  					}, []float32{0.008}, nil
    98  				}
    99  				dense := func([]float32) ([]*storobj.Object, []float32, error) { return nil, nil, nil }
   100  				res, err := Search(ctx, params, logger, sparse, dense, nil, nil, nil, nil)
   101  				require.Nil(t, err)
   102  				assert.Len(t, res, 1)
   103  				assert.NotNil(t, res[0])
   104  				assert.Contains(t, res[0].ExplainScore, "(Result Set keyword) Document")
   105  				assert.Contains(t, res[0].ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731")
   106  				assert.Equal(t, res[0].Vector, []float32{1, 2, 3})
   107  				assert.Equal(t, res[0].Dist, float32(0.000))
   108  			},
   109  		},
   110  		{
   111  			name: "with dense search only",
   112  			f: func(t *testing.T) {
   113  				params := &Params{
   114  					HybridSearch: &searchparams.HybridSearch{
   115  						Type:   "hybrid",
   116  						Alpha:  1,
   117  						Query:  "some query",
   118  						Vector: []float32{1, 2, 3},
   119  					},
   120  					Class: class,
   121  				}
   122  				sparse := func() ([]*storobj.Object, []float32, error) { return nil, nil, nil }
   123  				dense := func([]float32) ([]*storobj.Object, []float32, error) {
   124  					return []*storobj.Object{
   125  						{
   126  							Object: models.Object{
   127  								Class:      class,
   128  								ID:         "1889a225-3b28-477d-b8fc-5f6071bb4731",
   129  								Properties: map[string]any{"prop": "val"},
   130  								Vector:     []float32{1, 2, 3},
   131  							},
   132  							Vector: []float32{1, 2, 3},
   133  						},
   134  					}, []float32{0.008}, nil
   135  				}
   136  
   137  				res, err := Search(ctx, params, logger, sparse, dense, nil, nil, nil, nil)
   138  				require.Nil(t, err)
   139  				assert.Len(t, res, 1)
   140  				assert.NotNil(t, res[0])
   141  				assert.Contains(t, res[0].ExplainScore, "(Result Set vector) Document")
   142  				assert.Contains(t, res[0].ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731")
   143  				assert.Equal(t, res[0].Vector, []float32{1, 2, 3})
   144  				assert.Equal(t, res[0].Dist, float32(0.008))
   145  			},
   146  		},
   147  		{
   148  			name: "combined hybrid search",
   149  			f: func(t *testing.T) {
   150  				params := &Params{
   151  					HybridSearch: &searchparams.HybridSearch{
   152  						Type:   "hybrid",
   153  						Alpha:  0.5,
   154  						Query:  "some query",
   155  						Vector: []float32{1, 2, 3},
   156  					},
   157  					Class: class,
   158  				}
   159  				sparse := func() ([]*storobj.Object, []float32, error) {
   160  					return []*storobj.Object{
   161  						{
   162  							Object: models.Object{
   163  								Class:      class,
   164  								ID:         "1889a225-3b28-477d-b8fc-5f6071bb4731",
   165  								Properties: map[string]any{"prop": "val"},
   166  								Vector:     []float32{1, 2, 3},
   167  							},
   168  							Vector: []float32{1, 2, 3},
   169  						},
   170  					}, []float32{0.008}, nil
   171  				}
   172  				dense := func([]float32) ([]*storobj.Object, []float32, error) {
   173  					return []*storobj.Object{
   174  						{
   175  							Object: models.Object{
   176  								Class:      class,
   177  								ID:         "79a636c2-3314-442e-a4d1-e94d7c0afc3a",
   178  								Properties: map[string]any{"prop": "val"},
   179  								Vector:     []float32{4, 5, 6},
   180  							},
   181  							Vector: []float32{4, 5, 6},
   182  						},
   183  					}, []float32{0.008}, nil
   184  				}
   185  				res, err := Search(ctx, params, logger, sparse, dense, nil, nil, nil, nil)
   186  				require.Nil(t, err)
   187  				assert.Len(t, res, 2)
   188  				assert.NotNil(t, res[0])
   189  				assert.NotNil(t, res[1])
   190  				assert.Contains(t, res[0].ExplainScore, "(Result Set vector) Document")
   191  				assert.Contains(t, res[0].ExplainScore, "79a636c2-3314-442e-a4d1-e94d7c0afc3a")
   192  				assert.Equal(t, res[0].Vector, []float32{4, 5, 6})
   193  				assert.Equal(t, res[0].Dist, float32(0.008))
   194  				assert.Contains(t, res[1].ExplainScore, "(Result Set keyword) Document")
   195  				assert.Contains(t, res[1].ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731")
   196  				assert.Equal(t, res[1].Vector, []float32{1, 2, 3})
   197  				assert.Equal(t, res[1].Dist, float32(0.000))
   198  				assert.Equal(t, float32(0.008333334), res[0].Score)
   199  			},
   200  		},
   201  		{
   202  			name: "with sparse subsearch filter",
   203  			f: func(t *testing.T) {
   204  				params := &Params{
   205  					HybridSearch: &searchparams.HybridSearch{
   206  						Type: "hybrid",
   207  						SubSearches: []searchparams.WeightedSearchResult{
   208  							{
   209  								Type: "sparseSearch",
   210  								SearchParams: searchparams.KeywordRanking{
   211  									Type:       "bm25",
   212  									Properties: []string{"propA", "propB"},
   213  									Query:      "some query",
   214  								},
   215  							},
   216  						},
   217  					},
   218  					Class: class,
   219  				}
   220  				sparse := func() ([]*storobj.Object, []float32, error) {
   221  					return []*storobj.Object{
   222  						{
   223  							Object: models.Object{
   224  								Class:      class,
   225  								ID:         "1889a225-3b28-477d-b8fc-5f6071bb4731",
   226  								Properties: map[string]any{"prop": "val"},
   227  								Vector:     []float32{1, 2, 3},
   228  								Additional: map[string]interface{}{"score": float32(0.008)},
   229  							},
   230  							Vector: []float32{1, 2, 3},
   231  						},
   232  					}, []float32{0.008}, nil
   233  				}
   234  				dense := func([]float32) ([]*storobj.Object, []float32, error) {
   235  					return nil, nil, nil
   236  				}
   237  				res, err := Search(ctx, params, logger, sparse, dense, nil, nil, nil, nil)
   238  				require.Nil(t, err)
   239  				assert.Len(t, res, 1)
   240  				assert.NotNil(t, res[0])
   241  				assert.Contains(t, res[0].ExplainScore, "(Result Set bm25f) Document 1889a225-3b28-477d-b8fc-5f6071bb4731")
   242  				assert.Contains(t, res[0].ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731")
   243  				assert.Equal(t, res[0].Vector, []float32{1, 2, 3})
   244  				assert.Equal(t, res[0].Dist, float32(0.008))
   245  			},
   246  		},
   247  		{
   248  			name: "with nearText subsearch filter",
   249  			f: func(t *testing.T) {
   250  				params := &Params{
   251  					HybridSearch: &searchparams.HybridSearch{
   252  						TargetVectors: []string{"default"},
   253  						Type:          "hybrid",
   254  						SubSearches: []searchparams.WeightedSearchResult{
   255  							{
   256  								Type: "nearText",
   257  								SearchParams: searchparams.NearTextParams{
   258  									Values:    []string{"some query"},
   259  									Certainty: 0.8,
   260  								},
   261  							},
   262  						},
   263  					},
   264  					Class: class,
   265  				}
   266  				sparse := func() ([]*storobj.Object, []float32, error) {
   267  					return nil, nil, nil
   268  				}
   269  				dense := func([]float32) ([]*storobj.Object, []float32, error) {
   270  					return []*storobj.Object{
   271  						{
   272  							Object: models.Object{
   273  								Class:      class,
   274  								ID:         "1889a225-3b28-477d-b8fc-5f6071bb4731",
   275  								Properties: map[string]any{"prop": "val"},
   276  								Vector:     []float32{1, 2, 3},
   277  								Additional: map[string]interface{}{"score": float32(0.008)},
   278  							},
   279  							Vector: []float32{1, 2, 3},
   280  						},
   281  					}, []float32{0.008}, nil
   282  				}
   283  				provider := &fakeModuleProvider{}
   284  				schemaGetter := newFakeSchemaManager()
   285  				targetVectorParamHelper := newFakeTargetVectorParamHelper()
   286  				res, err := Search(ctx, params, logger, sparse, dense, nil, provider, schemaGetter, targetVectorParamHelper)
   287  				require.Nil(t, err)
   288  				assert.Len(t, res, 1)
   289  				assert.NotNil(t, res[0])
   290  				assert.Contains(t, res[0].ExplainScore, "(Result Set vector,nearText) Document 1889a225-3b28-477d-b8fc-5f6071bb4731")
   291  				assert.Contains(t, res[0].ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731")
   292  				assert.Equal(t, res[0].Vector, []float32{1, 2, 3})
   293  				assert.Equal(t, res[0].Dist, float32(0.008))
   294  			},
   295  		},
   296  		{
   297  			name: "with nearVector subsearch filter",
   298  			f: func(t *testing.T) {
   299  				params := &Params{
   300  					HybridSearch: &searchparams.HybridSearch{
   301  						TargetVectors: []string{"default"},
   302  						Type:          "hybrid",
   303  						SubSearches: []searchparams.WeightedSearchResult{
   304  							{
   305  								Type: "nearVector",
   306  								SearchParams: searchparams.NearVector{
   307  									Vector:    []float32{1, 2, 3},
   308  									Certainty: 0.8,
   309  								},
   310  							},
   311  						},
   312  					},
   313  					Class: class,
   314  				}
   315  				sparse := func() ([]*storobj.Object, []float32, error) {
   316  					return nil, nil, nil
   317  				}
   318  				dense := func([]float32) ([]*storobj.Object, []float32, error) {
   319  					return []*storobj.Object{
   320  						{
   321  							Object: models.Object{
   322  								Class:      class,
   323  								ID:         "1889a225-3b28-477d-b8fc-5f6071bb4731",
   324  								Properties: map[string]any{"prop": "val"},
   325  								Vector:     []float32{1, 2, 3},
   326  								Additional: map[string]interface{}{"score": float32(0.008)},
   327  							},
   328  							Vector: []float32{1, 2, 3},
   329  						},
   330  					}, []float32{0.008}, nil
   331  				}
   332  				provider := &fakeModuleProvider{}
   333  				schemaGetter := newFakeSchemaManager()
   334  				targetVectorParamHelper := newFakeTargetVectorParamHelper()
   335  				res, err := Search(ctx, params, logger, sparse, dense, nil, provider, schemaGetter, targetVectorParamHelper)
   336  				require.Nil(t, err)
   337  				assert.Len(t, res, 1)
   338  				assert.NotNil(t, res[0])
   339  				assert.Contains(t, res[0].ExplainScore, "(Result Set vector,nearVector) Document 1889a225-3b28-477d-b8fc-5f6071bb4731")
   340  				assert.Contains(t, res[0].ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731")
   341  				assert.Equal(t, res[0].Vector, []float32{1, 2, 3})
   342  				assert.Equal(t, res[0].Dist, float32(0.008))
   343  			},
   344  		},
   345  		{
   346  			name: "with all subsearch filters",
   347  			f: func(t *testing.T) {
   348  				params := &Params{
   349  					HybridSearch: &searchparams.HybridSearch{
   350  						TargetVectors: []string{"default"},
   351  						Type:          "hybrid",
   352  						SubSearches: []searchparams.WeightedSearchResult{
   353  							{
   354  								Type: "nearVector",
   355  								SearchParams: searchparams.NearVector{
   356  									Vector:    []float32{1, 2, 3},
   357  									Certainty: 0.8,
   358  								},
   359  								Weight: 100,
   360  							},
   361  							{
   362  								Type: "nearText",
   363  								SearchParams: searchparams.NearTextParams{
   364  									Values:    []string{"some query"},
   365  									Certainty: 0.8,
   366  								},
   367  								Weight: 2,
   368  							},
   369  							{
   370  								Type: "sparseSearch",
   371  								SearchParams: searchparams.KeywordRanking{
   372  									Type:       "bm25",
   373  									Properties: []string{"propA", "propB"},
   374  									Query:      "some query",
   375  								},
   376  								Weight: 3,
   377  							},
   378  						},
   379  					},
   380  					Class: class,
   381  				}
   382  				sparse := func() ([]*storobj.Object, []float32, error) {
   383  					return []*storobj.Object{
   384  						{
   385  							Object: models.Object{
   386  								Class:      class,
   387  								ID:         "1889a225-3b28-477d-b8fc-5f6071bb4731",
   388  								Properties: map[string]any{"prop": "val"},
   389  								Vector:     []float32{1, 2, 3},
   390  								Additional: map[string]interface{}{"score": float32(0.008)},
   391  							},
   392  							Vector: []float32{1, 2, 3},
   393  						},
   394  					}, []float32{0.008}, nil
   395  				}
   396  				dense := func([]float32) ([]*storobj.Object, []float32, error) {
   397  					return []*storobj.Object{
   398  						{
   399  							Object: models.Object{
   400  								Class:      class,
   401  								ID:         "79a636c2-3314-442e-a4d1-e94d7c0afc3a",
   402  								Properties: map[string]any{"prop": "val"},
   403  								Vector:     []float32{4, 5, 6},
   404  								Additional: map[string]interface{}{"score": float32(0.8)},
   405  							},
   406  							Vector: []float32{4, 5, 6},
   407  						},
   408  					}, []float32{0.008}, nil
   409  				}
   410  				provider := &fakeModuleProvider{}
   411  				schemaGetter := newFakeSchemaManager()
   412  				targetVectorParamHelper := newFakeTargetVectorParamHelper()
   413  				res, err := Search(ctx, params, logger, sparse, dense, nil, provider, schemaGetter, targetVectorParamHelper)
   414  				require.Nil(t, err)
   415  				assert.Len(t, res, 2)
   416  				assert.NotNil(t, res[0])
   417  				assert.NotNil(t, res[1])
   418  				assert.Contains(t, res[0].ExplainScore, "(Result Set vector,nearVector) Document 79a636c2-3314-442e-a4d1-e94d7c0afc3a")
   419  				assert.Contains(t, res[0].ExplainScore, "79a636c2-3314-442e-a4d1-e94d7c0afc3a")
   420  				assert.Equal(t, res[0].Vector, []float32{4, 5, 6})
   421  				assert.Equal(t, res[0].Dist, float32(0.008))
   422  				assert.Contains(t, res[1].ExplainScore, "(Result Set bm25f) Document 1889a225-3b28-477d-b8fc-5f6071bb4731")
   423  				assert.Contains(t, res[1].ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731")
   424  				assert.Equal(t, res[1].Vector, []float32{1, 2, 3})
   425  				assert.Equal(t, res[1].Dist, float32(0.008))
   426  			},
   427  		},
   428  	}
   429  
   430  	for _, test := range tests {
   431  		t.Run(test.name, test.f)
   432  	}
   433  }
   434  
   435  type fakeModuleProvider struct {
   436  	vector []float32
   437  }
   438  
   439  func (f *fakeModuleProvider) VectorFromInput(ctx context.Context, className, input, targetVector string) ([]float32, error) {
   440  	return f.vector, nil
   441  }
   442  
   443  func TestNullScores(t *testing.T) {
   444  	ctx := context.Background()
   445  	logger, _ := test.NewNullLogger()
   446  	class := "HybridClass"
   447  
   448  	params := &Params{
   449  		HybridSearch: &searchparams.HybridSearch{
   450  			TargetVectors:   []string{"default"},
   451  			FusionAlgorithm: common_filters.HybridRelativeScoreFusion,
   452  			Type:            "hybrid",
   453  			SubSearches: []searchparams.WeightedSearchResult{
   454  				{
   455  					Type: "nearVector",
   456  					SearchParams: searchparams.NearVector{
   457  						Vector:    []float32{1, 2, 3},
   458  						Certainty: 0.8,
   459  					},
   460  					Weight: 100,
   461  				},
   462  				{
   463  					Type: "nearText",
   464  					SearchParams: searchparams.NearTextParams{
   465  						Values:    []string{"some query"},
   466  						Certainty: 0.8,
   467  					},
   468  					Weight: 2,
   469  				},
   470  				{
   471  					Type: "sparseSearch",
   472  					SearchParams: searchparams.KeywordRanking{
   473  						Type:       "bm25",
   474  						Properties: []string{"propA", "propB"},
   475  						Query:      "some query",
   476  					},
   477  					Weight: 3,
   478  				},
   479  			},
   480  		},
   481  		Class: class,
   482  	}
   483  	sparse := func() ([]*storobj.Object, []float32, error) {
   484  		return []*storobj.Object{
   485  			{
   486  				Object: models.Object{
   487  					Class:      class,
   488  					ID:         "1889a225-3b28-477d-b8fc-5f6071bb4731",
   489  					Properties: map[string]any{"prop": "val"},
   490  					Vector:     []float32{1, 2, 3},
   491  					Additional: map[string]interface{}{"score": float32(0.008)},
   492  				},
   493  				Vector:    []float32{1, 2, 3},
   494  				VectorLen: 3,
   495  			},
   496  		}, []float32{0.008}, nil
   497  	}
   498  	dense := func([]float32) ([]*storobj.Object, []float32, error) {
   499  		return []*storobj.Object{
   500  			{
   501  				Object: models.Object{
   502  					Class:      class,
   503  					ID:         "79a636c2-3314-442e-a4d1-e94d7c0afc3a",
   504  					Properties: map[string]any{"prop": "val"},
   505  					Vector:     []float32{4, 5, 6},
   506  					Additional: map[string]interface{}{"score": float32(0.8)},
   507  				},
   508  				Vector:    []float32{4, 5, 6},
   509  				VectorLen: 3,
   510  			},
   511  		}, []float32{0.008}, nil
   512  	}
   513  	provider := &fakeModuleProvider{}
   514  	schemaGetter := newFakeSchemaManager()
   515  	targetVectorParamHelper := newFakeTargetVectorParamHelper()
   516  	res, err := Search(ctx, params, logger, sparse, dense, nil, provider, schemaGetter, targetVectorParamHelper)
   517  	require.Nil(t, err)
   518  	assert.Len(t, res, 2)
   519  	assert.NotNil(t, res[0])
   520  	assert.NotNil(t, res[1])
   521  	assert.Contains(t, res[0].ExplainScore, "(Result Set vector,nearVector) Document 79a636c2-3314-442e-a4d1-e94d7c0afc3a")
   522  	assert.Contains(t, res[0].ExplainScore, "79a636c2-3314-442e-a4d1-e94d7c0afc3a")
   523  	assert.Equal(t, res[0].Vector, []float32{4, 5, 6})
   524  	assert.Equal(t, res[0].Dist, float32(0.008))
   525  	assert.Contains(t, res[1].ExplainScore, "(Result Set bm25f) Document 1889a225-3b28-477d-b8fc-5f6071bb4731")
   526  	assert.Contains(t, res[1].ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731")
   527  	assert.Equal(t, res[1].Vector, []float32{1, 2, 3})
   528  	assert.Equal(t, res[1].Dist, float32(0.008))
   529  }