github.com/weaviate/weaviate@v1.24.6/usecases/traverser/hybrid/searcher_score_fusion_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package hybrid
    13  
    14  import (
    15  	"context"
    16  	"fmt"
    17  	"testing"
    18  
    19  	"github.com/sirupsen/logrus/hooks/test"
    20  	"github.com/stretchr/testify/assert"
    21  	"github.com/stretchr/testify/require"
    22  	"github.com/weaviate/weaviate/adapters/handlers/graphql/local/common_filters"
    23  	"github.com/weaviate/weaviate/entities/models"
    24  	"github.com/weaviate/weaviate/entities/searchparams"
    25  	"github.com/weaviate/weaviate/entities/storobj"
    26  )
    27  
    28  type hybridTestSet struct {
    29  	documents      []*storobj.Object
    30  	weights        []float64
    31  	inputScores    [][]float32
    32  	expectedScores []float32
    33  	expectedOrder  []uint64
    34  }
    35  
    36  func inputSet() []hybridTestSet {
    37  	cases := []hybridTestSet{
    38  		{
    39  			documents: []*storobj.Object{
    40  				{Object: models.Object{}, Vector: []float32{1, 2, 3}, VectorLen: 3, DocID: 12345},
    41  				{Object: models.Object{}, Vector: []float32{4, 5, 6}, VectorLen: 3, DocID: 12346},
    42  				{Object: models.Object{}, Vector: []float32{7, 8, 9}, VectorLen: 3, DocID: 12347},
    43  			},
    44  			weights:        []float64{0.5, 0.5},
    45  			inputScores:    [][]float32{{1, 2, 3}, {0, 1, 2}},
    46  			expectedScores: []float32{1, 0.5, 0},
    47  			expectedOrder:  []uint64{2, 1, 0},
    48  		},
    49  
    50  		{weights: []float64{0.5, 0.5}, inputScores: [][]float32{{0, 2, 0.1}, {0, 0.2, 2}}, expectedScores: []float32{0.55, 0.525, 0}, expectedOrder: []uint64{1, 2, 0}},
    51  		{weights: []float64{0.75, 0.25}, inputScores: [][]float32{{0.5, 0.5, 0}, {0, 0.01, 0.001}}, expectedScores: []float32{1, 0.75, 0.025}, expectedOrder: []uint64{1, 0, 2}},
    52  		{weights: []float64{0.75, 0.25}, inputScores: [][]float32{{}, {}}, expectedScores: []float32{}, expectedOrder: []uint64{}},
    53  		{weights: []float64{0.75, 0.25}, inputScores: [][]float32{{1}, {}}, expectedScores: []float32{0.75}, expectedOrder: []uint64{0}},
    54  		{weights: []float64{0.75, 0.25}, inputScores: [][]float32{{}, {1}}, expectedScores: []float32{0.25}, expectedOrder: []uint64{0}},
    55  		{weights: []float64{0.75, 0.25}, inputScores: [][]float32{{1, 2}, {}}, expectedScores: []float32{0.75, 0}, expectedOrder: []uint64{1, 0}},
    56  		{weights: []float64{0.75, 0.25}, inputScores: [][]float32{{}, {1, 2}}, expectedScores: []float32{0.25, 0}, expectedOrder: []uint64{1, 0}},
    57  		{weights: []float64{0.75, 0.25}, inputScores: [][]float32{{1, 1}, {1, 2}}, expectedScores: []float32{1, 0.75}, expectedOrder: []uint64{1, 0}},
    58  		{weights: []float64{1}, inputScores: [][]float32{{1, 2, 3}}, expectedScores: []float32{1, 0.5, 0}, expectedOrder: []uint64{2, 1, 0}},
    59  		{weights: []float64{0.75, 0.25}, inputScores: [][]float32{{1, 2, 3, 4}, {1, 2, 3}}, expectedScores: []float32{0.75, 0.75, 0.375, 0}, expectedOrder: []uint64{3, 2, 1, 0}},
    60  	}
    61  
    62  	return cases
    63  }
    64  
    65  func TestScoreFusionSearchWithoutModuleProvider(t *testing.T) {
    66  	ctx := context.Background()
    67  	logger, _ := test.NewNullLogger()
    68  	class := "HybridClass"
    69  	inputs := inputSet()
    70  	params := &Params{
    71  		HybridSearch: &searchparams.HybridSearch{
    72  			Type:            "hybrid",
    73  			Alpha:           0.5,
    74  			Query:           "some query",
    75  			FusionAlgorithm: common_filters.HybridRelativeScoreFusion,
    76  		},
    77  		Class: class,
    78  	}
    79  	sparse := func() ([]*storobj.Object, []float32, error) {
    80  		return inputs[0].documents, inputs[0].inputScores[0], nil
    81  	}
    82  	dense := func([]float32) ([]*storobj.Object, []float32, error) {
    83  		return inputs[0].documents, inputs[0].inputScores[1], nil
    84  	}
    85  
    86  	res, err := Search(ctx, params, logger, sparse, dense, nil, nil, nil, nil)
    87  	require.Nil(t, err)
    88  	fmt.Printf("res: %v\n", res)
    89  }
    90  
    91  func TestScoreFusionSearchWithModuleProvider(t *testing.T) {
    92  	ctx := context.Background()
    93  	logger, _ := test.NewNullLogger()
    94  	class := "HybridClass"
    95  	params := &Params{
    96  		HybridSearch: &searchparams.HybridSearch{
    97  			Type:            "hybrid",
    98  			Alpha:           0.5,
    99  			Query:           "some query",
   100  			TargetVectors:   []string{"default"},
   101  			FusionAlgorithm: common_filters.HybridRelativeScoreFusion,
   102  		},
   103  		Class: class,
   104  	}
   105  	sparse := func() ([]*storobj.Object, []float32, error) { return nil, nil, nil }
   106  	dense := func([]float32) ([]*storobj.Object, []float32, error) { return nil, nil, nil }
   107  	provider := &fakeModuleProvider{}
   108  	schemaGetter := newFakeSchemaManager()
   109  	targetVectorParamHelper := newFakeTargetVectorParamHelper()
   110  	_, err := Search(ctx, params, logger, sparse, dense, nil, provider, schemaGetter, targetVectorParamHelper)
   111  	require.Nil(t, err)
   112  }
   113  
   114  func TestScoreFusionSearchWithSparseSearchOnly(t *testing.T) {
   115  	ctx := context.Background()
   116  	logger, _ := test.NewNullLogger()
   117  	class := "HybridClass"
   118  	params := &Params{
   119  		HybridSearch: &searchparams.HybridSearch{
   120  			Type:            "hybrid",
   121  			Alpha:           0,
   122  			Query:           "some query",
   123  			FusionAlgorithm: common_filters.HybridRelativeScoreFusion,
   124  		},
   125  		Class: class,
   126  	}
   127  	sparse := func() ([]*storobj.Object, []float32, error) {
   128  		return []*storobj.Object{
   129  			{
   130  				Object: models.Object{
   131  					Class:      class,
   132  					ID:         "1889a225-3b28-477d-b8fc-5f6071bb4731",
   133  					Properties: map[string]any{"prop": "val"},
   134  					Vector:     []float32{1, 2, 3},
   135  				},
   136  				Vector:    []float32{1, 2, 3},
   137  				VectorLen: 3,
   138  				DocID:     1,
   139  			},
   140  		}, []float32{0.008}, nil
   141  	}
   142  	dense := func([]float32) ([]*storobj.Object, []float32, error) { return nil, nil, nil }
   143  	res, err := Search(ctx, params, logger, sparse, dense, nil, nil, nil, nil)
   144  	require.Nil(t, err)
   145  	assert.Len(t, res, 1)
   146  	assert.NotNil(t, res[0])
   147  	assert.Contains(t, res[0].ExplainScore, "(Result Set keyword) Document")
   148  	assert.Contains(t, res[0].ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731")
   149  	assert.Equal(t, res[0].Vector, []float32{1, 2, 3})
   150  	assert.Equal(t, res[0].Dist, float32(0.000))
   151  	assert.Equal(t, float32(1), res[0].Score)
   152  }
   153  
   154  func TestScoreFusionSearchWithDenseSearchOnly(t *testing.T) {
   155  	ctx := context.Background()
   156  	logger, _ := test.NewNullLogger()
   157  	class := "HybridClass"
   158  	params := &Params{
   159  		HybridSearch: &searchparams.HybridSearch{
   160  			Type:            "hybrid",
   161  			Alpha:           1,
   162  			Query:           "some query",
   163  			Vector:          []float32{1, 2, 3},
   164  			FusionAlgorithm: common_filters.HybridRelativeScoreFusion,
   165  		},
   166  		Class: class,
   167  	}
   168  	sparse := func() ([]*storobj.Object, []float32, error) { return nil, nil, nil }
   169  	dense := func([]float32) ([]*storobj.Object, []float32, error) {
   170  		return []*storobj.Object{
   171  			{
   172  				Object: models.Object{
   173  					Class:      class,
   174  					ID:         "1889a225-3b28-477d-b8fc-5f6071bb4731",
   175  					Properties: map[string]any{"prop": "val"},
   176  					Vector:     []float32{1, 2, 3},
   177  				},
   178  				Vector:    []float32{1, 2, 3},
   179  				VectorLen: 3,
   180  				DocID:     1,
   181  			},
   182  		}, []float32{0.008}, nil
   183  	}
   184  
   185  	res, err := Search(ctx, params, logger, sparse, dense, nil, nil, nil, nil)
   186  	require.Nil(t, err)
   187  	assert.Len(t, res, 1)
   188  	assert.NotNil(t, res[0])
   189  	assert.Contains(t, res[0].ExplainScore, "(Result Set vector) Document")
   190  	assert.Contains(t, res[0].ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731")
   191  	assert.Equal(t, res[0].Vector, []float32{1, 2, 3})
   192  	assert.Equal(t, res[0].Dist, float32(0.008))
   193  	assert.Equal(t, float32(1), res[0].Score)
   194  }
   195  
   196  func TestScoreFusionCombinedHybridSearch(t *testing.T) {
   197  	ctx := context.Background()
   198  	logger, _ := test.NewNullLogger()
   199  	class := "HybridClass"
   200  	params := &Params{
   201  		HybridSearch: &searchparams.HybridSearch{
   202  			Type:            "hybrid",
   203  			Alpha:           0.5,
   204  			Query:           "some query",
   205  			Vector:          []float32{1, 2, 3},
   206  			FusionAlgorithm: common_filters.HybridRelativeScoreFusion,
   207  		},
   208  		Class: class,
   209  	}
   210  	sparse := func() ([]*storobj.Object, []float32, error) {
   211  		return []*storobj.Object{
   212  			{
   213  				Object: models.Object{
   214  					Class:      class,
   215  					ID:         "1889a225-3b28-477d-b8fc-5f6071bb4731",
   216  					Properties: map[string]any{"prop": "val"},
   217  					Vector:     []float32{1, 2, 3},
   218  				},
   219  				Vector:    []float32{1, 2, 3},
   220  				VectorLen: 3,
   221  				DocID:     1,
   222  			},
   223  		}, []float32{0.008}, nil
   224  	}
   225  	dense := func([]float32) ([]*storobj.Object, []float32, error) {
   226  		return []*storobj.Object{
   227  			{
   228  				Object: models.Object{
   229  					Class:      class,
   230  					ID:         "79a636c2-3314-442e-a4d1-e94d7c0afc3a",
   231  					Properties: map[string]any{"prop": "val"},
   232  					Vector:     []float32{4, 5, 6},
   233  				},
   234  				Vector:    []float32{4, 5, 6},
   235  				VectorLen: 3,
   236  				DocID:     2,
   237  			},
   238  		}, []float32{0.008}, nil
   239  	}
   240  	res, err := Search(ctx, params, logger, sparse, dense, nil, nil, nil, nil)
   241  	require.Nil(t, err)
   242  	assert.Len(t, res, 2)
   243  	assert.NotNil(t, res[0])
   244  	assert.NotNil(t, res[1])
   245  	assert.Contains(t, res[0].ExplainScore, "(Result Set vector) Document")
   246  	assert.Contains(t, res[0].ExplainScore, "79a636c2-3314-442e-a4d1-e94d7c0afc3a")
   247  	assert.Equal(t, res[0].Vector, []float32{4, 5, 6})
   248  	assert.Equal(t, res[0].Dist, float32(0.008))
   249  	assert.Equal(t, float32(0.5), res[0].Score)
   250  	assert.Contains(t, res[1].ExplainScore, "(Result Set keyword) Document")
   251  	assert.Contains(t, res[1].ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731")
   252  	assert.Equal(t, res[1].Vector, []float32{1, 2, 3})
   253  	assert.Equal(t, res[1].Dist, float32(0.000))
   254  	assert.Equal(t, float32(0.5), res[1].Score)
   255  }
   256  
   257  func TestScoreFusionWithSparseSubsearchFilter(t *testing.T) {
   258  	ctx := context.Background()
   259  	logger, _ := test.NewNullLogger()
   260  	class := "HybridClass"
   261  	params := &Params{
   262  		HybridSearch: &searchparams.HybridSearch{
   263  			Type:            "hybrid",
   264  			FusionAlgorithm: common_filters.HybridRelativeScoreFusion,
   265  			SubSearches: []searchparams.WeightedSearchResult{
   266  				{
   267  					Type: "sparseSearch",
   268  					SearchParams: searchparams.KeywordRanking{
   269  						Type:       "bm25",
   270  						Properties: []string{"propA", "propB"},
   271  						Query:      "some query",
   272  					},
   273  				},
   274  			},
   275  		},
   276  		Class: class,
   277  	}
   278  	sparse := func() ([]*storobj.Object, []float32, error) {
   279  		return []*storobj.Object{
   280  			{
   281  				Object: models.Object{
   282  					Class:      class,
   283  					ID:         "1889a225-3b28-477d-b8fc-5f6071bb4731",
   284  					Properties: map[string]any{"prop": "val"},
   285  					Vector:     []float32{1, 2, 3},
   286  					Additional: map[string]interface{}{"score": float32(0.008)},
   287  				},
   288  				Vector: []float32{1, 2, 3},
   289  			},
   290  		}, []float32{0.008}, nil
   291  	}
   292  	dense := func([]float32) ([]*storobj.Object, []float32, error) { return nil, nil, nil }
   293  	res, err := Search(ctx, params, logger, sparse, dense, nil, nil, nil, nil)
   294  	require.Nil(t, err)
   295  	assert.Len(t, res, 1)
   296  	assert.NotNil(t, res[0])
   297  	assert.Contains(t, res[0].ExplainScore, "(Result Set bm25f) Document 1889a225-3b28-477d-b8fc-5f6071bb4731")
   298  	assert.Contains(t, res[0].ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731")
   299  	assert.Equal(t, res[0].Vector, []float32{1, 2, 3})
   300  	assert.Equal(t, res[0].Dist, float32(0.008))
   301  }
   302  
   303  func TestScoreFusionWithNearTextSubsearchFilter(t *testing.T) {
   304  	ctx := context.Background()
   305  	logger, _ := test.NewNullLogger()
   306  	class := "HybridClass"
   307  	params := &Params{
   308  		HybridSearch: &searchparams.HybridSearch{
   309  			TargetVectors:   []string{"default"},
   310  			Type:            "hybrid",
   311  			FusionAlgorithm: common_filters.HybridRelativeScoreFusion,
   312  			SubSearches: []searchparams.WeightedSearchResult{
   313  				{
   314  					Type: "nearText",
   315  					SearchParams: searchparams.NearTextParams{
   316  						Values:    []string{"some query"},
   317  						Certainty: 0.8,
   318  					},
   319  				},
   320  			},
   321  		},
   322  		Class: class,
   323  	}
   324  	sparse := func() ([]*storobj.Object, []float32, error) { return nil, nil, nil }
   325  	dense := func([]float32) ([]*storobj.Object, []float32, error) {
   326  		return []*storobj.Object{
   327  			{
   328  				Object: models.Object{
   329  					Class:      class,
   330  					ID:         "1889a225-3b28-477d-b8fc-5f6071bb4731",
   331  					Properties: map[string]any{"prop": "val"},
   332  					Vector:     []float32{1, 2, 3},
   333  					Additional: map[string]interface{}{"score": float32(0.008)},
   334  				},
   335  				Vector: []float32{1, 2, 3},
   336  			},
   337  		}, []float32{0.008}, nil
   338  	}
   339  	provider := &fakeModuleProvider{}
   340  	schemaGetter := newFakeSchemaManager()
   341  	targetVectorParamHelper := newFakeTargetVectorParamHelper()
   342  	res, err := Search(ctx, params, logger, sparse, dense, nil, provider, schemaGetter, targetVectorParamHelper)
   343  	require.Nil(t, err)
   344  	assert.Len(t, res, 1)
   345  	assert.NotNil(t, res[0])
   346  	assert.Contains(t, res[0].ExplainScore, "(Result Set vector,nearText) Document 1889a225-3b28-477d-b8fc-5f6071bb4731")
   347  	assert.Contains(t, res[0].ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731")
   348  	assert.Equal(t, res[0].Vector, []float32{1, 2, 3})
   349  	assert.Equal(t, res[0].Dist, float32(0.008))
   350  }
   351  
   352  func TestScoreFusionWithNearVectorSubsearchFilter(t *testing.T) {
   353  	ctx := context.Background()
   354  	logger, _ := test.NewNullLogger()
   355  	class := "HybridClass"
   356  	params := &Params{
   357  		HybridSearch: &searchparams.HybridSearch{
   358  			TargetVectors:   []string{"default"},
   359  			Type:            "hybrid",
   360  			FusionAlgorithm: common_filters.HybridRelativeScoreFusion,
   361  			SubSearches: []searchparams.WeightedSearchResult{
   362  				{
   363  					Type: "nearVector",
   364  					SearchParams: searchparams.NearVector{
   365  						Vector:    []float32{1, 2, 3},
   366  						Certainty: 0.8,
   367  					},
   368  				},
   369  			},
   370  		},
   371  		Class: class,
   372  	}
   373  	sparse := func() ([]*storobj.Object, []float32, error) { return nil, nil, nil }
   374  	dense := func([]float32) ([]*storobj.Object, []float32, error) {
   375  		return []*storobj.Object{
   376  			{
   377  				Object: models.Object{
   378  					Class:      class,
   379  					ID:         "1889a225-3b28-477d-b8fc-5f6071bb4731",
   380  					Properties: map[string]any{"prop": "val"},
   381  					Vector:     []float32{1, 2, 3},
   382  					Additional: map[string]interface{}{"score": float32(0.008)},
   383  				},
   384  				Vector: []float32{1, 2, 3},
   385  			},
   386  		}, []float32{0.008}, nil
   387  	}
   388  	provider := &fakeModuleProvider{}
   389  	schemaGetter := newFakeSchemaManager()
   390  	targetVectorParamHelper := newFakeTargetVectorParamHelper()
   391  	res, err := Search(ctx, params, logger, sparse, dense, nil, provider, schemaGetter, targetVectorParamHelper)
   392  	require.Nil(t, err)
   393  	assert.Len(t, res, 1)
   394  	assert.NotNil(t, res[0])
   395  	assert.Contains(t, res[0].ExplainScore, "(Result Set vector,nearVector) Document 1889a225-3b28-477d-b8fc-5f6071bb4731")
   396  	assert.Contains(t, res[0].ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731")
   397  	assert.Equal(t, res[0].Vector, []float32{1, 2, 3})
   398  	assert.Equal(t, res[0].Dist, float32(0.008))
   399  }
   400  
   401  func TestScoreFusionWithAllSubsearchFilters(t *testing.T) {
   402  	ctx := context.Background()
   403  	logger, _ := test.NewNullLogger()
   404  	class := "HybridClass"
   405  	params := &Params{
   406  		HybridSearch: &searchparams.HybridSearch{
   407  			TargetVectors:   []string{"default"},
   408  			Type:            "hybrid",
   409  			FusionAlgorithm: common_filters.HybridRelativeScoreFusion,
   410  			SubSearches: []searchparams.WeightedSearchResult{
   411  				{
   412  					Type: "nearVector",
   413  					SearchParams: searchparams.NearVector{
   414  						Vector:    []float32{1, 2, 3},
   415  						Certainty: 0.8,
   416  					},
   417  					Weight: 100,
   418  				},
   419  				{
   420  					Type: "nearText",
   421  					SearchParams: searchparams.NearTextParams{
   422  						Values:    []string{"some query"},
   423  						Certainty: 0.8,
   424  					},
   425  					Weight: 2,
   426  				},
   427  				{
   428  					Type: "sparseSearch",
   429  					SearchParams: searchparams.KeywordRanking{
   430  						Type:       "bm25",
   431  						Properties: []string{"propA", "propB"},
   432  						Query:      "some query",
   433  					},
   434  					Weight: 3,
   435  				},
   436  			},
   437  		},
   438  		Class: class,
   439  	}
   440  	sparse := func() ([]*storobj.Object, []float32, error) {
   441  		return []*storobj.Object{
   442  			{
   443  				Object: models.Object{
   444  					Class:      class,
   445  					ID:         "1889a225-3b28-477d-b8fc-5f6071bb4731",
   446  					Properties: map[string]any{"prop": "val"},
   447  					Vector:     []float32{1, 2, 3},
   448  					Additional: map[string]interface{}{"score": float32(0.008)},
   449  				},
   450  				Vector: []float32{1, 2, 3},
   451  			},
   452  		}, []float32{0.008}, nil
   453  	}
   454  	dense := func([]float32) ([]*storobj.Object, []float32, error) {
   455  		return []*storobj.Object{
   456  			{
   457  				Object: models.Object{
   458  					Class:      class,
   459  					ID:         "79a636c2-3314-442e-a4d1-e94d7c0afc3a",
   460  					Properties: map[string]any{"prop": "val"},
   461  					Vector:     []float32{4, 5, 6},
   462  					Additional: map[string]interface{}{"score": float32(0.8)},
   463  				},
   464  				Vector: []float32{4, 5, 6},
   465  			},
   466  		}, []float32{0.008}, nil
   467  	}
   468  	provider := &fakeModuleProvider{}
   469  	schemaGetter := newFakeSchemaManager()
   470  	targetVectorParamHelper := newFakeTargetVectorParamHelper()
   471  	res, err := Search(ctx, params, logger, sparse, dense, nil, provider, schemaGetter, targetVectorParamHelper)
   472  	require.Nil(t, err)
   473  	assert.Len(t, res, 2)
   474  	assert.NotNil(t, res[0])
   475  	assert.NotNil(t, res[1])
   476  	assert.Contains(t, res[0].ExplainScore, "(Result Set vector,nearText) Document 79a636c2-3314-442e-a4d1-e94d7c0afc3a")
   477  	assert.Contains(t, res[0].ExplainScore, "79a636c2-3314-442e-a4d1-e94d7c0afc3a")
   478  	assert.Equal(t, res[0].Vector, []float32{4, 5, 6})
   479  	assert.Equal(t, res[0].Dist, float32(0.008))
   480  	assert.Contains(t, res[1].ExplainScore, "(Result Set bm25f) Document 1889a225-3b28-477d-b8fc-5f6071bb4731")
   481  	assert.Contains(t, res[1].ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731")
   482  	assert.Equal(t, res[1].Vector, []float32{1, 2, 3})
   483  	assert.Equal(t, res[1].Dist, float32(0.008))
   484  }