github.com/weaviate/weaviate@v1.24.6/usecases/traverser/hybrid/searcher_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package hybrid 13 14 import ( 15 "context" 16 "testing" 17 18 "github.com/sirupsen/logrus/hooks/test" 19 "github.com/stretchr/testify/assert" 20 "github.com/stretchr/testify/require" 21 "github.com/weaviate/weaviate/adapters/handlers/graphql/local/common_filters" 22 "github.com/weaviate/weaviate/entities/models" 23 "github.com/weaviate/weaviate/entities/searchparams" 24 "github.com/weaviate/weaviate/entities/storobj" 25 ) 26 27 func TestSearcher(t *testing.T) { 28 ctx := context.Background() 29 logger, _ := test.NewNullLogger() 30 class := "HybridClass" 31 32 tests := []struct { 33 name string 34 f func(t *testing.T) 35 }{ 36 { 37 name: "with module provider", 38 f: func(t *testing.T) { 39 params := &Params{ 40 HybridSearch: &searchparams.HybridSearch{ 41 Type: "hybrid", 42 Alpha: 0.5, 43 Query: "some query", 44 TargetVectors: []string{"default"}, 45 }, 46 Class: class, 47 } 48 sparse := func() ([]*storobj.Object, []float32, error) { return nil, nil, nil } 49 dense := func([]float32) ([]*storobj.Object, []float32, error) { return nil, nil, nil } 50 provider := &fakeModuleProvider{} 51 schemaGetter := newFakeSchemaManager() 52 targetVectorParamHelper := newFakeTargetVectorParamHelper() 53 _, err := Search(ctx, params, logger, sparse, dense, nil, provider, schemaGetter, targetVectorParamHelper) 54 require.Nil(t, err) 55 }, 56 }, 57 { 58 name: "without module provider", 59 f: func(t *testing.T) { 60 params := &Params{ 61 HybridSearch: &searchparams.HybridSearch{ 62 Type: "hybrid", 63 Alpha: 0.5, 64 Query: "some query", 65 }, 66 Class: class, 67 } 68 sparse := func() ([]*storobj.Object, []float32, error) { return nil, nil, nil } 69 dense := func([]float32) ([]*storobj.Object, []float32, error) { return nil, nil, nil } 70 71 _, err := Search(ctx, params, logger, sparse, dense, nil, nil, nil, nil) 72 require.Nil(t, err) 73 }, 74 }, 75 { 76 name: "with sparse search only", 77 f: func(t *testing.T) { 78 params := &Params{ 79 HybridSearch: &searchparams.HybridSearch{ 80 Type: "hybrid", 81 Alpha: 0, 82 Query: "some query", 83 }, 84 Class: class, 85 } 86 sparse := func() ([]*storobj.Object, []float32, error) { 87 return []*storobj.Object{ 88 { 89 Object: models.Object{ 90 Class: class, 91 ID: "1889a225-3b28-477d-b8fc-5f6071bb4731", 92 Properties: map[string]any{"prop": "val"}, 93 Vector: []float32{1, 2, 3}, 94 }, 95 Vector: []float32{1, 2, 3}, 96 }, 97 }, []float32{0.008}, nil 98 } 99 dense := func([]float32) ([]*storobj.Object, []float32, error) { return nil, nil, nil } 100 res, err := Search(ctx, params, logger, sparse, dense, nil, nil, nil, nil) 101 require.Nil(t, err) 102 assert.Len(t, res, 1) 103 assert.NotNil(t, res[0]) 104 assert.Contains(t, res[0].ExplainScore, "(Result Set keyword) Document") 105 assert.Contains(t, res[0].ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731") 106 assert.Equal(t, res[0].Vector, []float32{1, 2, 3}) 107 assert.Equal(t, res[0].Dist, float32(0.000)) 108 }, 109 }, 110 { 111 name: "with dense search only", 112 f: func(t *testing.T) { 113 params := &Params{ 114 HybridSearch: &searchparams.HybridSearch{ 115 Type: "hybrid", 116 Alpha: 1, 117 Query: "some query", 118 Vector: []float32{1, 2, 3}, 119 }, 120 Class: class, 121 } 122 sparse := func() ([]*storobj.Object, []float32, error) { return nil, nil, nil } 123 dense := func([]float32) ([]*storobj.Object, []float32, error) { 124 return []*storobj.Object{ 125 { 126 Object: models.Object{ 127 Class: class, 128 ID: "1889a225-3b28-477d-b8fc-5f6071bb4731", 129 Properties: map[string]any{"prop": "val"}, 130 Vector: []float32{1, 2, 3}, 131 }, 132 Vector: []float32{1, 2, 3}, 133 }, 134 }, []float32{0.008}, nil 135 } 136 137 res, err := Search(ctx, params, logger, sparse, dense, nil, nil, nil, nil) 138 require.Nil(t, err) 139 assert.Len(t, res, 1) 140 assert.NotNil(t, res[0]) 141 assert.Contains(t, res[0].ExplainScore, "(Result Set vector) Document") 142 assert.Contains(t, res[0].ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731") 143 assert.Equal(t, res[0].Vector, []float32{1, 2, 3}) 144 assert.Equal(t, res[0].Dist, float32(0.008)) 145 }, 146 }, 147 { 148 name: "combined hybrid search", 149 f: func(t *testing.T) { 150 params := &Params{ 151 HybridSearch: &searchparams.HybridSearch{ 152 Type: "hybrid", 153 Alpha: 0.5, 154 Query: "some query", 155 Vector: []float32{1, 2, 3}, 156 }, 157 Class: class, 158 } 159 sparse := func() ([]*storobj.Object, []float32, error) { 160 return []*storobj.Object{ 161 { 162 Object: models.Object{ 163 Class: class, 164 ID: "1889a225-3b28-477d-b8fc-5f6071bb4731", 165 Properties: map[string]any{"prop": "val"}, 166 Vector: []float32{1, 2, 3}, 167 }, 168 Vector: []float32{1, 2, 3}, 169 }, 170 }, []float32{0.008}, nil 171 } 172 dense := func([]float32) ([]*storobj.Object, []float32, error) { 173 return []*storobj.Object{ 174 { 175 Object: models.Object{ 176 Class: class, 177 ID: "79a636c2-3314-442e-a4d1-e94d7c0afc3a", 178 Properties: map[string]any{"prop": "val"}, 179 Vector: []float32{4, 5, 6}, 180 }, 181 Vector: []float32{4, 5, 6}, 182 }, 183 }, []float32{0.008}, nil 184 } 185 res, err := Search(ctx, params, logger, sparse, dense, nil, nil, nil, nil) 186 require.Nil(t, err) 187 assert.Len(t, res, 2) 188 assert.NotNil(t, res[0]) 189 assert.NotNil(t, res[1]) 190 assert.Contains(t, res[0].ExplainScore, "(Result Set vector) Document") 191 assert.Contains(t, res[0].ExplainScore, "79a636c2-3314-442e-a4d1-e94d7c0afc3a") 192 assert.Equal(t, res[0].Vector, []float32{4, 5, 6}) 193 assert.Equal(t, res[0].Dist, float32(0.008)) 194 assert.Contains(t, res[1].ExplainScore, "(Result Set keyword) Document") 195 assert.Contains(t, res[1].ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731") 196 assert.Equal(t, res[1].Vector, []float32{1, 2, 3}) 197 assert.Equal(t, res[1].Dist, float32(0.000)) 198 assert.Equal(t, float32(0.008333334), res[0].Score) 199 }, 200 }, 201 { 202 name: "with sparse subsearch filter", 203 f: func(t *testing.T) { 204 params := &Params{ 205 HybridSearch: &searchparams.HybridSearch{ 206 Type: "hybrid", 207 SubSearches: []searchparams.WeightedSearchResult{ 208 { 209 Type: "sparseSearch", 210 SearchParams: searchparams.KeywordRanking{ 211 Type: "bm25", 212 Properties: []string{"propA", "propB"}, 213 Query: "some query", 214 }, 215 }, 216 }, 217 }, 218 Class: class, 219 } 220 sparse := func() ([]*storobj.Object, []float32, error) { 221 return []*storobj.Object{ 222 { 223 Object: models.Object{ 224 Class: class, 225 ID: "1889a225-3b28-477d-b8fc-5f6071bb4731", 226 Properties: map[string]any{"prop": "val"}, 227 Vector: []float32{1, 2, 3}, 228 Additional: map[string]interface{}{"score": float32(0.008)}, 229 }, 230 Vector: []float32{1, 2, 3}, 231 }, 232 }, []float32{0.008}, nil 233 } 234 dense := func([]float32) ([]*storobj.Object, []float32, error) { 235 return nil, nil, nil 236 } 237 res, err := Search(ctx, params, logger, sparse, dense, nil, nil, nil, nil) 238 require.Nil(t, err) 239 assert.Len(t, res, 1) 240 assert.NotNil(t, res[0]) 241 assert.Contains(t, res[0].ExplainScore, "(Result Set bm25f) Document 1889a225-3b28-477d-b8fc-5f6071bb4731") 242 assert.Contains(t, res[0].ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731") 243 assert.Equal(t, res[0].Vector, []float32{1, 2, 3}) 244 assert.Equal(t, res[0].Dist, float32(0.008)) 245 }, 246 }, 247 { 248 name: "with nearText subsearch filter", 249 f: func(t *testing.T) { 250 params := &Params{ 251 HybridSearch: &searchparams.HybridSearch{ 252 TargetVectors: []string{"default"}, 253 Type: "hybrid", 254 SubSearches: []searchparams.WeightedSearchResult{ 255 { 256 Type: "nearText", 257 SearchParams: searchparams.NearTextParams{ 258 Values: []string{"some query"}, 259 Certainty: 0.8, 260 }, 261 }, 262 }, 263 }, 264 Class: class, 265 } 266 sparse := func() ([]*storobj.Object, []float32, error) { 267 return nil, nil, nil 268 } 269 dense := func([]float32) ([]*storobj.Object, []float32, error) { 270 return []*storobj.Object{ 271 { 272 Object: models.Object{ 273 Class: class, 274 ID: "1889a225-3b28-477d-b8fc-5f6071bb4731", 275 Properties: map[string]any{"prop": "val"}, 276 Vector: []float32{1, 2, 3}, 277 Additional: map[string]interface{}{"score": float32(0.008)}, 278 }, 279 Vector: []float32{1, 2, 3}, 280 }, 281 }, []float32{0.008}, nil 282 } 283 provider := &fakeModuleProvider{} 284 schemaGetter := newFakeSchemaManager() 285 targetVectorParamHelper := newFakeTargetVectorParamHelper() 286 res, err := Search(ctx, params, logger, sparse, dense, nil, provider, schemaGetter, targetVectorParamHelper) 287 require.Nil(t, err) 288 assert.Len(t, res, 1) 289 assert.NotNil(t, res[0]) 290 assert.Contains(t, res[0].ExplainScore, "(Result Set vector,nearText) Document 1889a225-3b28-477d-b8fc-5f6071bb4731") 291 assert.Contains(t, res[0].ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731") 292 assert.Equal(t, res[0].Vector, []float32{1, 2, 3}) 293 assert.Equal(t, res[0].Dist, float32(0.008)) 294 }, 295 }, 296 { 297 name: "with nearVector subsearch filter", 298 f: func(t *testing.T) { 299 params := &Params{ 300 HybridSearch: &searchparams.HybridSearch{ 301 TargetVectors: []string{"default"}, 302 Type: "hybrid", 303 SubSearches: []searchparams.WeightedSearchResult{ 304 { 305 Type: "nearVector", 306 SearchParams: searchparams.NearVector{ 307 Vector: []float32{1, 2, 3}, 308 Certainty: 0.8, 309 }, 310 }, 311 }, 312 }, 313 Class: class, 314 } 315 sparse := func() ([]*storobj.Object, []float32, error) { 316 return nil, nil, nil 317 } 318 dense := func([]float32) ([]*storobj.Object, []float32, error) { 319 return []*storobj.Object{ 320 { 321 Object: models.Object{ 322 Class: class, 323 ID: "1889a225-3b28-477d-b8fc-5f6071bb4731", 324 Properties: map[string]any{"prop": "val"}, 325 Vector: []float32{1, 2, 3}, 326 Additional: map[string]interface{}{"score": float32(0.008)}, 327 }, 328 Vector: []float32{1, 2, 3}, 329 }, 330 }, []float32{0.008}, nil 331 } 332 provider := &fakeModuleProvider{} 333 schemaGetter := newFakeSchemaManager() 334 targetVectorParamHelper := newFakeTargetVectorParamHelper() 335 res, err := Search(ctx, params, logger, sparse, dense, nil, provider, schemaGetter, targetVectorParamHelper) 336 require.Nil(t, err) 337 assert.Len(t, res, 1) 338 assert.NotNil(t, res[0]) 339 assert.Contains(t, res[0].ExplainScore, "(Result Set vector,nearVector) Document 1889a225-3b28-477d-b8fc-5f6071bb4731") 340 assert.Contains(t, res[0].ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731") 341 assert.Equal(t, res[0].Vector, []float32{1, 2, 3}) 342 assert.Equal(t, res[0].Dist, float32(0.008)) 343 }, 344 }, 345 { 346 name: "with all subsearch filters", 347 f: func(t *testing.T) { 348 params := &Params{ 349 HybridSearch: &searchparams.HybridSearch{ 350 TargetVectors: []string{"default"}, 351 Type: "hybrid", 352 SubSearches: []searchparams.WeightedSearchResult{ 353 { 354 Type: "nearVector", 355 SearchParams: searchparams.NearVector{ 356 Vector: []float32{1, 2, 3}, 357 Certainty: 0.8, 358 }, 359 Weight: 100, 360 }, 361 { 362 Type: "nearText", 363 SearchParams: searchparams.NearTextParams{ 364 Values: []string{"some query"}, 365 Certainty: 0.8, 366 }, 367 Weight: 2, 368 }, 369 { 370 Type: "sparseSearch", 371 SearchParams: searchparams.KeywordRanking{ 372 Type: "bm25", 373 Properties: []string{"propA", "propB"}, 374 Query: "some query", 375 }, 376 Weight: 3, 377 }, 378 }, 379 }, 380 Class: class, 381 } 382 sparse := func() ([]*storobj.Object, []float32, error) { 383 return []*storobj.Object{ 384 { 385 Object: models.Object{ 386 Class: class, 387 ID: "1889a225-3b28-477d-b8fc-5f6071bb4731", 388 Properties: map[string]any{"prop": "val"}, 389 Vector: []float32{1, 2, 3}, 390 Additional: map[string]interface{}{"score": float32(0.008)}, 391 }, 392 Vector: []float32{1, 2, 3}, 393 }, 394 }, []float32{0.008}, nil 395 } 396 dense := func([]float32) ([]*storobj.Object, []float32, error) { 397 return []*storobj.Object{ 398 { 399 Object: models.Object{ 400 Class: class, 401 ID: "79a636c2-3314-442e-a4d1-e94d7c0afc3a", 402 Properties: map[string]any{"prop": "val"}, 403 Vector: []float32{4, 5, 6}, 404 Additional: map[string]interface{}{"score": float32(0.8)}, 405 }, 406 Vector: []float32{4, 5, 6}, 407 }, 408 }, []float32{0.008}, nil 409 } 410 provider := &fakeModuleProvider{} 411 schemaGetter := newFakeSchemaManager() 412 targetVectorParamHelper := newFakeTargetVectorParamHelper() 413 res, err := Search(ctx, params, logger, sparse, dense, nil, provider, schemaGetter, targetVectorParamHelper) 414 require.Nil(t, err) 415 assert.Len(t, res, 2) 416 assert.NotNil(t, res[0]) 417 assert.NotNil(t, res[1]) 418 assert.Contains(t, res[0].ExplainScore, "(Result Set vector,nearVector) Document 79a636c2-3314-442e-a4d1-e94d7c0afc3a") 419 assert.Contains(t, res[0].ExplainScore, "79a636c2-3314-442e-a4d1-e94d7c0afc3a") 420 assert.Equal(t, res[0].Vector, []float32{4, 5, 6}) 421 assert.Equal(t, res[0].Dist, float32(0.008)) 422 assert.Contains(t, res[1].ExplainScore, "(Result Set bm25f) Document 1889a225-3b28-477d-b8fc-5f6071bb4731") 423 assert.Contains(t, res[1].ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731") 424 assert.Equal(t, res[1].Vector, []float32{1, 2, 3}) 425 assert.Equal(t, res[1].Dist, float32(0.008)) 426 }, 427 }, 428 } 429 430 for _, test := range tests { 431 t.Run(test.name, test.f) 432 } 433 } 434 435 type fakeModuleProvider struct { 436 vector []float32 437 } 438 439 func (f *fakeModuleProvider) VectorFromInput(ctx context.Context, className, input, targetVector string) ([]float32, error) { 440 return f.vector, nil 441 } 442 443 func TestNullScores(t *testing.T) { 444 ctx := context.Background() 445 logger, _ := test.NewNullLogger() 446 class := "HybridClass" 447 448 params := &Params{ 449 HybridSearch: &searchparams.HybridSearch{ 450 TargetVectors: []string{"default"}, 451 FusionAlgorithm: common_filters.HybridRelativeScoreFusion, 452 Type: "hybrid", 453 SubSearches: []searchparams.WeightedSearchResult{ 454 { 455 Type: "nearVector", 456 SearchParams: searchparams.NearVector{ 457 Vector: []float32{1, 2, 3}, 458 Certainty: 0.8, 459 }, 460 Weight: 100, 461 }, 462 { 463 Type: "nearText", 464 SearchParams: searchparams.NearTextParams{ 465 Values: []string{"some query"}, 466 Certainty: 0.8, 467 }, 468 Weight: 2, 469 }, 470 { 471 Type: "sparseSearch", 472 SearchParams: searchparams.KeywordRanking{ 473 Type: "bm25", 474 Properties: []string{"propA", "propB"}, 475 Query: "some query", 476 }, 477 Weight: 3, 478 }, 479 }, 480 }, 481 Class: class, 482 } 483 sparse := func() ([]*storobj.Object, []float32, error) { 484 return []*storobj.Object{ 485 { 486 Object: models.Object{ 487 Class: class, 488 ID: "1889a225-3b28-477d-b8fc-5f6071bb4731", 489 Properties: map[string]any{"prop": "val"}, 490 Vector: []float32{1, 2, 3}, 491 Additional: map[string]interface{}{"score": float32(0.008)}, 492 }, 493 Vector: []float32{1, 2, 3}, 494 VectorLen: 3, 495 }, 496 }, []float32{0.008}, nil 497 } 498 dense := func([]float32) ([]*storobj.Object, []float32, error) { 499 return []*storobj.Object{ 500 { 501 Object: models.Object{ 502 Class: class, 503 ID: "79a636c2-3314-442e-a4d1-e94d7c0afc3a", 504 Properties: map[string]any{"prop": "val"}, 505 Vector: []float32{4, 5, 6}, 506 Additional: map[string]interface{}{"score": float32(0.8)}, 507 }, 508 Vector: []float32{4, 5, 6}, 509 VectorLen: 3, 510 }, 511 }, []float32{0.008}, nil 512 } 513 provider := &fakeModuleProvider{} 514 schemaGetter := newFakeSchemaManager() 515 targetVectorParamHelper := newFakeTargetVectorParamHelper() 516 res, err := Search(ctx, params, logger, sparse, dense, nil, provider, schemaGetter, targetVectorParamHelper) 517 require.Nil(t, err) 518 assert.Len(t, res, 2) 519 assert.NotNil(t, res[0]) 520 assert.NotNil(t, res[1]) 521 assert.Contains(t, res[0].ExplainScore, "(Result Set vector,nearVector) Document 79a636c2-3314-442e-a4d1-e94d7c0afc3a") 522 assert.Contains(t, res[0].ExplainScore, "79a636c2-3314-442e-a4d1-e94d7c0afc3a") 523 assert.Equal(t, res[0].Vector, []float32{4, 5, 6}) 524 assert.Equal(t, res[0].Dist, float32(0.008)) 525 assert.Contains(t, res[1].ExplainScore, "(Result Set bm25f) Document 1889a225-3b28-477d-b8fc-5f6071bb4731") 526 assert.Contains(t, res[1].ExplainScore, "1889a225-3b28-477d-b8fc-5f6071bb4731") 527 assert.Equal(t, res[1].Vector, []float32{1, 2, 3}) 528 assert.Equal(t, res[1].Dist, float32(0.008)) 529 }