github.com/weaviate/weaviate@v1.24.6/usecases/traverser/hybrid/searcher_score_fusion_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package hybrid 13 14 import ( 15 "context" 16 "fmt" 17 "testing" 18 19 "github.com/sirupsen/logrus/hooks/test" 20 "github.com/stretchr/testify/assert" 21 "github.com/stretchr/testify/require" 22 "github.com/weaviate/weaviate/adapters/handlers/graphql/local/common_filters" 23 "github.com/weaviate/weaviate/entities/models" 24 "github.com/weaviate/weaviate/entities/searchparams" 25 "github.com/weaviate/weaviate/entities/storobj" 26 ) 27 28 type hybridTestSet struct { 29 documents []*storobj.Object 30 weights []float64 31 inputScores [][]float32 32 expectedScores []float32 33 expectedOrder []uint64 34 } 35 36 func inputSet() []hybridTestSet { 37 cases := []hybridTestSet{ 38 { 39 documents: []*storobj.Object{ 40 {Object: models.Object{}, Vector: []float32{1, 2, 3}, VectorLen: 3, DocID: 12345}, 41 {Object: models.Object{}, Vector: []float32{4, 5, 6}, VectorLen: 3, DocID: 12346}, 42 {Object: models.Object{}, Vector: []float32{7, 8, 9}, VectorLen: 3, DocID: 12347}, 43 }, 44 weights: []float64{0.5, 0.5}, 45 inputScores: [][]float32{{1, 2, 3}, {0, 1, 2}}, 46 expectedScores: []float32{1, 0.5, 0}, 47 expectedOrder: []uint64{2, 1, 0}, 48 }, 49 50 {weights: []float64{0.5, 0.5}, inputScores: [][]float32{{0, 2, 0.1}, {0, 0.2, 2}}, expectedScores: []float32{0.55, 0.525, 0}, expectedOrder: []uint64{1, 2, 0}}, 51 {weights: []float64{0.75, 0.25}, inputScores: [][]float32{{0.5, 0.5, 0}, {0, 0.01, 0.001}}, expectedScores: []float32{1, 0.75, 0.025}, expectedOrder: []uint64{1, 0, 2}}, 52 {weights: []float64{0.75, 0.25}, inputScores: [][]float32{{}, {}}, expectedScores: []float32{}, expectedOrder: []uint64{}}, 53 {weights: []float64{0.75, 0.25}, inputScores: [][]float32{{1}, {}}, expectedScores: []float32{0.75}, expectedOrder: []uint64{0}}, 54 {weights: []float64{0.75, 0.25}, inputScores: [][]float32{{}, {1}}, expectedScores: []float32{0.25}, expectedOrder: []uint64{0}}, 55 {weights: []float64{0.75, 0.25}, inputScores: [][]float32{{1, 2}, {}}, expectedScores: []float32{0.75, 0}, expectedOrder: []uint64{1, 0}}, 56 {weights: []float64{0.75, 0.25}, inputScores: [][]float32{{}, {1, 2}}, expectedScores: []float32{0.25, 0}, expectedOrder: []uint64{1, 0}}, 57 {weights: []float64{0.75, 0.25}, inputScores: [][]float32{{1, 1}, {1, 2}}, expectedScores: []float32{1, 0.75}, expectedOrder: []uint64{1, 0}}, 58 {weights: []float64{1}, inputScores: [][]float32{{1, 2, 3}}, expectedScores: []float32{1, 0.5, 0}, expectedOrder: []uint64{2, 1, 0}}, 59 {weights: []float64{0.75, 0.25}, inputScores: [][]float32{{1, 2, 3, 4}, {1, 2, 3}}, expectedScores: []float32{0.75, 0.75, 0.375, 0}, expectedOrder: []uint64{3, 2, 1, 0}}, 60 } 61 62 return cases 63 } 64 65 func TestScoreFusionSearchWithoutModuleProvider(t *testing.T) { 66 ctx := context.Background() 67 logger, _ := test.NewNullLogger() 68 class := "HybridClass" 69 inputs := inputSet() 70 params := &Params{ 71 HybridSearch: &searchparams.HybridSearch{ 72 Type: "hybrid", 73 Alpha: 0.5, 74 Query: "some query", 75 FusionAlgorithm: common_filters.HybridRelativeScoreFusion, 76 }, 77 Class: class, 78 } 79 sparse := func() ([]*storobj.Object, []float32, error) { 80 return inputs[0].documents, inputs[0].inputScores[0], nil 81 } 82 dense := func([]float32) ([]*storobj.Object, []float32, error) { 83 return inputs[0].documents, inputs[0].inputScores[1], nil 84 } 85 86 res, err := Search(ctx, params, logger, sparse, dense, nil, nil, nil, nil) 87 require.Nil(t, err) 88 fmt.Printf("res: %v\n", res) 89 } 90 91 func TestScoreFusionSearchWithModuleProvider(t *testing.T) { 92 ctx := context.Background() 93 logger, _ := test.NewNullLogger() 94 class := "HybridClass" 95 params := &Params{ 96 HybridSearch: &searchparams.HybridSearch{ 97 Type: "hybrid", 98 Alpha: 0.5, 99 Query: "some query", 100 TargetVectors: []string{"default"}, 101 FusionAlgorithm: common_filters.HybridRelativeScoreFusion, 102 }, 103 Class: class, 104 } 105 sparse := func() ([]*storobj.Object, []float32, error) { return nil, nil, nil } 106 dense := func([]float32) ([]*storobj.Object, []float32, error) { return nil, nil, nil } 107 provider := &fakeModuleProvider{} 108 schemaGetter := newFakeSchemaManager() 109 targetVectorParamHelper := newFakeTargetVectorParamHelper() 110 _, err := Search(ctx, params, logger, sparse, dense, nil, provider, schemaGetter, targetVectorParamHelper) 111 require.Nil(t, err) 112 } 113 114 func TestScoreFusionSearchWithSparseSearchOnly(t *testing.T) { 115 ctx := context.Background() 116 logger, _ := test.NewNullLogger() 117 class := "HybridClass" 118 params := &Params{ 119 HybridSearch: &searchparams.HybridSearch{ 120 Type: "hybrid", 121 Alpha: 0, 122 Query: "some query", 123 FusionAlgorithm: common_filters.HybridRelativeScoreFusion, 124 }, 125 Class: class, 126 } 127 sparse := func() ([]*storobj.Object, []float32, error) { 128 return []*storobj.Object{ 129 { 130 Object: models.Object{ 131 Class: class, 132 ID: "1889a225-3b28-477d-b8fc-5f6071bb4731", 133 Properties: map[string]any{"prop": "val"}, 134 Vector: []float32{1, 2, 3}, 135 }, 136 Vector: []float32{1, 2, 3}, 137 VectorLen: 3, 138 DocID: 1, 139 }, 140 }, []float32{0.008}, nil 141 } 142 dense := func([]float32) ([]*storobj.Object, []float32, error) { return nil, nil, nil } 143 res, err := Search(ctx, params, logger, sparse, dense, nil, nil, nil, nil) 144 require.Nil(t, err) 145 assert.Len(t, res, 1) 146 assert.NotNil(t, res[0]) 147 assert.Contains(t, res[0].ExplainScore, "(Result Set keyword) Document") 148 assert.Contains(t, res[0].ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731") 149 assert.Equal(t, res[0].Vector, []float32{1, 2, 3}) 150 assert.Equal(t, res[0].Dist, float32(0.000)) 151 assert.Equal(t, float32(1), res[0].Score) 152 } 153 154 func TestScoreFusionSearchWithDenseSearchOnly(t *testing.T) { 155 ctx := context.Background() 156 logger, _ := test.NewNullLogger() 157 class := "HybridClass" 158 params := &Params{ 159 HybridSearch: &searchparams.HybridSearch{ 160 Type: "hybrid", 161 Alpha: 1, 162 Query: "some query", 163 Vector: []float32{1, 2, 3}, 164 FusionAlgorithm: common_filters.HybridRelativeScoreFusion, 165 }, 166 Class: class, 167 } 168 sparse := func() ([]*storobj.Object, []float32, error) { return nil, nil, nil } 169 dense := func([]float32) ([]*storobj.Object, []float32, error) { 170 return []*storobj.Object{ 171 { 172 Object: models.Object{ 173 Class: class, 174 ID: "1889a225-3b28-477d-b8fc-5f6071bb4731", 175 Properties: map[string]any{"prop": "val"}, 176 Vector: []float32{1, 2, 3}, 177 }, 178 Vector: []float32{1, 2, 3}, 179 VectorLen: 3, 180 DocID: 1, 181 }, 182 }, []float32{0.008}, nil 183 } 184 185 res, err := Search(ctx, params, logger, sparse, dense, nil, nil, nil, nil) 186 require.Nil(t, err) 187 assert.Len(t, res, 1) 188 assert.NotNil(t, res[0]) 189 assert.Contains(t, res[0].ExplainScore, "(Result Set vector) Document") 190 assert.Contains(t, res[0].ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731") 191 assert.Equal(t, res[0].Vector, []float32{1, 2, 3}) 192 assert.Equal(t, res[0].Dist, float32(0.008)) 193 assert.Equal(t, float32(1), res[0].Score) 194 } 195 196 func TestScoreFusionCombinedHybridSearch(t *testing.T) { 197 ctx := context.Background() 198 logger, _ := test.NewNullLogger() 199 class := "HybridClass" 200 params := &Params{ 201 HybridSearch: &searchparams.HybridSearch{ 202 Type: "hybrid", 203 Alpha: 0.5, 204 Query: "some query", 205 Vector: []float32{1, 2, 3}, 206 FusionAlgorithm: common_filters.HybridRelativeScoreFusion, 207 }, 208 Class: class, 209 } 210 sparse := func() ([]*storobj.Object, []float32, error) { 211 return []*storobj.Object{ 212 { 213 Object: models.Object{ 214 Class: class, 215 ID: "1889a225-3b28-477d-b8fc-5f6071bb4731", 216 Properties: map[string]any{"prop": "val"}, 217 Vector: []float32{1, 2, 3}, 218 }, 219 Vector: []float32{1, 2, 3}, 220 VectorLen: 3, 221 DocID: 1, 222 }, 223 }, []float32{0.008}, nil 224 } 225 dense := func([]float32) ([]*storobj.Object, []float32, error) { 226 return []*storobj.Object{ 227 { 228 Object: models.Object{ 229 Class: class, 230 ID: "79a636c2-3314-442e-a4d1-e94d7c0afc3a", 231 Properties: map[string]any{"prop": "val"}, 232 Vector: []float32{4, 5, 6}, 233 }, 234 Vector: []float32{4, 5, 6}, 235 VectorLen: 3, 236 DocID: 2, 237 }, 238 }, []float32{0.008}, nil 239 } 240 res, err := Search(ctx, params, logger, sparse, dense, nil, nil, nil, nil) 241 require.Nil(t, err) 242 assert.Len(t, res, 2) 243 assert.NotNil(t, res[0]) 244 assert.NotNil(t, res[1]) 245 assert.Contains(t, res[0].ExplainScore, "(Result Set vector) Document") 246 assert.Contains(t, res[0].ExplainScore, "79a636c2-3314-442e-a4d1-e94d7c0afc3a") 247 assert.Equal(t, res[0].Vector, []float32{4, 5, 6}) 248 assert.Equal(t, res[0].Dist, float32(0.008)) 249 assert.Equal(t, float32(0.5), res[0].Score) 250 assert.Contains(t, res[1].ExplainScore, "(Result Set keyword) Document") 251 assert.Contains(t, res[1].ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731") 252 assert.Equal(t, res[1].Vector, []float32{1, 2, 3}) 253 assert.Equal(t, res[1].Dist, float32(0.000)) 254 assert.Equal(t, float32(0.5), res[1].Score) 255 } 256 257 func TestScoreFusionWithSparseSubsearchFilter(t *testing.T) { 258 ctx := context.Background() 259 logger, _ := test.NewNullLogger() 260 class := "HybridClass" 261 params := &Params{ 262 HybridSearch: &searchparams.HybridSearch{ 263 Type: "hybrid", 264 FusionAlgorithm: common_filters.HybridRelativeScoreFusion, 265 SubSearches: []searchparams.WeightedSearchResult{ 266 { 267 Type: "sparseSearch", 268 SearchParams: searchparams.KeywordRanking{ 269 Type: "bm25", 270 Properties: []string{"propA", "propB"}, 271 Query: "some query", 272 }, 273 }, 274 }, 275 }, 276 Class: class, 277 } 278 sparse := func() ([]*storobj.Object, []float32, error) { 279 return []*storobj.Object{ 280 { 281 Object: models.Object{ 282 Class: class, 283 ID: "1889a225-3b28-477d-b8fc-5f6071bb4731", 284 Properties: map[string]any{"prop": "val"}, 285 Vector: []float32{1, 2, 3}, 286 Additional: map[string]interface{}{"score": float32(0.008)}, 287 }, 288 Vector: []float32{1, 2, 3}, 289 }, 290 }, []float32{0.008}, nil 291 } 292 dense := func([]float32) ([]*storobj.Object, []float32, error) { return nil, nil, nil } 293 res, err := Search(ctx, params, logger, sparse, dense, nil, nil, nil, nil) 294 require.Nil(t, err) 295 assert.Len(t, res, 1) 296 assert.NotNil(t, res[0]) 297 assert.Contains(t, res[0].ExplainScore, "(Result Set bm25f) Document 1889a225-3b28-477d-b8fc-5f6071bb4731") 298 assert.Contains(t, res[0].ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731") 299 assert.Equal(t, res[0].Vector, []float32{1, 2, 3}) 300 assert.Equal(t, res[0].Dist, float32(0.008)) 301 } 302 303 func TestScoreFusionWithNearTextSubsearchFilter(t *testing.T) { 304 ctx := context.Background() 305 logger, _ := test.NewNullLogger() 306 class := "HybridClass" 307 params := &Params{ 308 HybridSearch: &searchparams.HybridSearch{ 309 TargetVectors: []string{"default"}, 310 Type: "hybrid", 311 FusionAlgorithm: common_filters.HybridRelativeScoreFusion, 312 SubSearches: []searchparams.WeightedSearchResult{ 313 { 314 Type: "nearText", 315 SearchParams: searchparams.NearTextParams{ 316 Values: []string{"some query"}, 317 Certainty: 0.8, 318 }, 319 }, 320 }, 321 }, 322 Class: class, 323 } 324 sparse := func() ([]*storobj.Object, []float32, error) { return nil, nil, nil } 325 dense := func([]float32) ([]*storobj.Object, []float32, error) { 326 return []*storobj.Object{ 327 { 328 Object: models.Object{ 329 Class: class, 330 ID: "1889a225-3b28-477d-b8fc-5f6071bb4731", 331 Properties: map[string]any{"prop": "val"}, 332 Vector: []float32{1, 2, 3}, 333 Additional: map[string]interface{}{"score": float32(0.008)}, 334 }, 335 Vector: []float32{1, 2, 3}, 336 }, 337 }, []float32{0.008}, nil 338 } 339 provider := &fakeModuleProvider{} 340 schemaGetter := newFakeSchemaManager() 341 targetVectorParamHelper := newFakeTargetVectorParamHelper() 342 res, err := Search(ctx, params, logger, sparse, dense, nil, provider, schemaGetter, targetVectorParamHelper) 343 require.Nil(t, err) 344 assert.Len(t, res, 1) 345 assert.NotNil(t, res[0]) 346 assert.Contains(t, res[0].ExplainScore, "(Result Set vector,nearText) Document 1889a225-3b28-477d-b8fc-5f6071bb4731") 347 assert.Contains(t, res[0].ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731") 348 assert.Equal(t, res[0].Vector, []float32{1, 2, 3}) 349 assert.Equal(t, res[0].Dist, float32(0.008)) 350 } 351 352 func TestScoreFusionWithNearVectorSubsearchFilter(t *testing.T) { 353 ctx := context.Background() 354 logger, _ := test.NewNullLogger() 355 class := "HybridClass" 356 params := &Params{ 357 HybridSearch: &searchparams.HybridSearch{ 358 TargetVectors: []string{"default"}, 359 Type: "hybrid", 360 FusionAlgorithm: common_filters.HybridRelativeScoreFusion, 361 SubSearches: []searchparams.WeightedSearchResult{ 362 { 363 Type: "nearVector", 364 SearchParams: searchparams.NearVector{ 365 Vector: []float32{1, 2, 3}, 366 Certainty: 0.8, 367 }, 368 }, 369 }, 370 }, 371 Class: class, 372 } 373 sparse := func() ([]*storobj.Object, []float32, error) { return nil, nil, nil } 374 dense := func([]float32) ([]*storobj.Object, []float32, error) { 375 return []*storobj.Object{ 376 { 377 Object: models.Object{ 378 Class: class, 379 ID: "1889a225-3b28-477d-b8fc-5f6071bb4731", 380 Properties: map[string]any{"prop": "val"}, 381 Vector: []float32{1, 2, 3}, 382 Additional: map[string]interface{}{"score": float32(0.008)}, 383 }, 384 Vector: []float32{1, 2, 3}, 385 }, 386 }, []float32{0.008}, nil 387 } 388 provider := &fakeModuleProvider{} 389 schemaGetter := newFakeSchemaManager() 390 targetVectorParamHelper := newFakeTargetVectorParamHelper() 391 res, err := Search(ctx, params, logger, sparse, dense, nil, provider, schemaGetter, targetVectorParamHelper) 392 require.Nil(t, err) 393 assert.Len(t, res, 1) 394 assert.NotNil(t, res[0]) 395 assert.Contains(t, res[0].ExplainScore, "(Result Set vector,nearVector) Document 1889a225-3b28-477d-b8fc-5f6071bb4731") 396 assert.Contains(t, res[0].ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731") 397 assert.Equal(t, res[0].Vector, []float32{1, 2, 3}) 398 assert.Equal(t, res[0].Dist, float32(0.008)) 399 } 400 401 func TestScoreFusionWithAllSubsearchFilters(t *testing.T) { 402 ctx := context.Background() 403 logger, _ := test.NewNullLogger() 404 class := "HybridClass" 405 params := &Params{ 406 HybridSearch: &searchparams.HybridSearch{ 407 TargetVectors: []string{"default"}, 408 Type: "hybrid", 409 FusionAlgorithm: common_filters.HybridRelativeScoreFusion, 410 SubSearches: []searchparams.WeightedSearchResult{ 411 { 412 Type: "nearVector", 413 SearchParams: searchparams.NearVector{ 414 Vector: []float32{1, 2, 3}, 415 Certainty: 0.8, 416 }, 417 Weight: 100, 418 }, 419 { 420 Type: "nearText", 421 SearchParams: searchparams.NearTextParams{ 422 Values: []string{"some query"}, 423 Certainty: 0.8, 424 }, 425 Weight: 2, 426 }, 427 { 428 Type: "sparseSearch", 429 SearchParams: searchparams.KeywordRanking{ 430 Type: "bm25", 431 Properties: []string{"propA", "propB"}, 432 Query: "some query", 433 }, 434 Weight: 3, 435 }, 436 }, 437 }, 438 Class: class, 439 } 440 sparse := func() ([]*storobj.Object, []float32, error) { 441 return []*storobj.Object{ 442 { 443 Object: models.Object{ 444 Class: class, 445 ID: "1889a225-3b28-477d-b8fc-5f6071bb4731", 446 Properties: map[string]any{"prop": "val"}, 447 Vector: []float32{1, 2, 3}, 448 Additional: map[string]interface{}{"score": float32(0.008)}, 449 }, 450 Vector: []float32{1, 2, 3}, 451 }, 452 }, []float32{0.008}, nil 453 } 454 dense := func([]float32) ([]*storobj.Object, []float32, error) { 455 return []*storobj.Object{ 456 { 457 Object: models.Object{ 458 Class: class, 459 ID: "79a636c2-3314-442e-a4d1-e94d7c0afc3a", 460 Properties: map[string]any{"prop": "val"}, 461 Vector: []float32{4, 5, 6}, 462 Additional: map[string]interface{}{"score": float32(0.8)}, 463 }, 464 Vector: []float32{4, 5, 6}, 465 }, 466 }, []float32{0.008}, nil 467 } 468 provider := &fakeModuleProvider{} 469 schemaGetter := newFakeSchemaManager() 470 targetVectorParamHelper := newFakeTargetVectorParamHelper() 471 res, err := Search(ctx, params, logger, sparse, dense, nil, provider, schemaGetter, targetVectorParamHelper) 472 require.Nil(t, err) 473 assert.Len(t, res, 2) 474 assert.NotNil(t, res[0]) 475 assert.NotNil(t, res[1]) 476 assert.Contains(t, res[0].ExplainScore, "(Result Set vector,nearText) Document 79a636c2-3314-442e-a4d1-e94d7c0afc3a") 477 assert.Contains(t, res[0].ExplainScore, "79a636c2-3314-442e-a4d1-e94d7c0afc3a") 478 assert.Equal(t, res[0].Vector, []float32{4, 5, 6}) 479 assert.Equal(t, res[0].Dist, float32(0.008)) 480 assert.Contains(t, res[1].ExplainScore, "(Result Set bm25f) Document 1889a225-3b28-477d-b8fc-5f6071bb4731") 481 assert.Contains(t, res[1].ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731") 482 assert.Equal(t, res[1].Vector, []float32{1, 2, 3}) 483 assert.Equal(t, res[1].Dist, float32(0.008)) 484 }