github.com/weaviate/weaviate@v1.24.6/modules/qna-transformers/additional/answer/answer_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package answer 13 14 import ( 15 "context" 16 "testing" 17 18 "github.com/stretchr/testify/assert" 19 "github.com/stretchr/testify/require" 20 "github.com/weaviate/weaviate/entities/additional" 21 "github.com/weaviate/weaviate/entities/search" 22 qnamodels "github.com/weaviate/weaviate/modules/qna-transformers/additional/models" 23 "github.com/weaviate/weaviate/modules/qna-transformers/ent" 24 ) 25 26 func TestAdditionalAnswerProvider(t *testing.T) { 27 t.Run("should fail with empty content", func(t *testing.T) { 28 // given 29 qnaClient := &fakeQnAClient{} 30 fakeHelper := &fakeParamsHelper{} 31 answerProvider := New(qnaClient, fakeHelper) 32 in := []search.Result{ 33 { 34 ID: "some-uuid", 35 }, 36 } 37 fakeParams := &Params{} 38 limit := 1 39 argumentModuleParams := map[string]interface{}{} 40 41 // when 42 out, err := answerProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil) 43 44 // then 45 require.NotNil(t, err) 46 require.NotEmpty(t, out) 47 assert.Error(t, err, "empty content") 48 }) 49 50 t.Run("should fail with empty question", func(t *testing.T) { 51 // given 52 qnaClient := &fakeQnAClient{} 53 fakeHelper := &fakeParamsHelper{} 54 answerProvider := New(qnaClient, fakeHelper) 55 in := []search.Result{ 56 { 57 ID: "some-uuid", 58 Schema: map[string]interface{}{ 59 "content": "content", 60 }, 61 }, 62 } 63 fakeParams := &Params{} 64 limit := 1 65 argumentModuleParams := map[string]interface{}{} 66 67 // when 68 out, err := answerProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil) 69 70 // then 71 require.NotNil(t, err) 72 require.NotEmpty(t, out) 73 assert.Error(t, err, "empty content") 74 }) 75 76 t.Run("should answer", func(t *testing.T) { 77 // given 78 qnaClient := &fakeQnAClient{} 79 fakeHelper := &fakeParamsHelper{} 80 answerProvider := New(qnaClient, fakeHelper) 81 in := []search.Result{ 82 { 83 ID: "some-uuid", 84 Schema: map[string]interface{}{ 85 "content": "content", 86 }, 87 }, 88 } 89 fakeParams := &Params{} 90 limit := 1 91 argumentModuleParams := map[string]interface{}{ 92 "ask": map[string]interface{}{ 93 "question": "question", 94 }, 95 } 96 97 // when 98 out, err := answerProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil) 99 100 // then 101 require.Nil(t, err) 102 require.NotEmpty(t, out) 103 assert.Equal(t, 1, len(in)) 104 answer, answerOK := in[0].AdditionalProperties["answer"] 105 assert.True(t, answerOK) 106 assert.NotNil(t, answer) 107 answerAdditional, answerAdditionalOK := answer.(*qnamodels.Answer) 108 assert.True(t, answerAdditionalOK) 109 assert.Equal(t, "answer", *answerAdditional.Result) 110 }) 111 112 t.Run("should answer with property", func(t *testing.T) { 113 // given 114 qnaClient := &fakeQnAClient{} 115 fakeHelper := &fakeParamsHelper{} 116 answerProvider := New(qnaClient, fakeHelper) 117 in := []search.Result{ 118 { 119 ID: "some-uuid", 120 Schema: map[string]interface{}{ 121 "content": "content with answer", 122 "content2": "this one is just a title", 123 }, 124 }, 125 } 126 fakeParams := &Params{} 127 limit := 1 128 argumentModuleParams := map[string]interface{}{ 129 "ask": map[string]interface{}{ 130 "question": "question", 131 "properties": []string{"content", "content2"}, 132 }, 133 } 134 135 // when 136 out, err := answerProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil) 137 138 // then 139 require.Nil(t, err) 140 require.NotEmpty(t, out) 141 assert.Equal(t, 1, len(in)) 142 answer, answerOK := in[0].AdditionalProperties["answer"] 143 assert.True(t, answerOK) 144 assert.NotNil(t, answer) 145 answerAdditional, answerAdditionalOK := answer.(*qnamodels.Answer) 146 assert.True(t, answerAdditionalOK) 147 assert.Equal(t, "answer", *answerAdditional.Result) 148 assert.Equal(t, "content", *answerAdditional.Property) 149 assert.Equal(t, 0.8, *answerAdditional.Certainty) 150 assert.InDelta(t, 0.4, *answerAdditional.Distance, 1e-9) 151 assert.Equal(t, 13, answerAdditional.StartPosition) 152 assert.Equal(t, 19, answerAdditional.EndPosition) 153 assert.Equal(t, true, answerAdditional.HasAnswer) 154 }) 155 156 t.Run("should answer with similarity set above ask distance", func(t *testing.T) { 157 // given 158 qnaClient := &fakeQnAClient{} 159 fakeHelper := &fakeParamsHelper{} 160 answerProvider := New(qnaClient, fakeHelper) 161 in := []search.Result{ 162 { 163 ID: "some-uuid", 164 Schema: map[string]interface{}{ 165 "content": "content with answer", 166 "content2": "this one is just a title", 167 }, 168 }, 169 } 170 fakeParams := &Params{} 171 limit := 1 172 argumentModuleParams := map[string]interface{}{ 173 "ask": map[string]interface{}{ 174 "question": "question", 175 "properties": []string{"content", "content2"}, 176 "distance": float64(0.4), 177 }, 178 } 179 180 // when 181 out, err := answerProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil) 182 183 // then 184 require.Nil(t, err) 185 require.NotEmpty(t, out) 186 assert.Equal(t, 1, len(out)) 187 answer, answerOK := out[0].AdditionalProperties["answer"] 188 assert.True(t, answerOK) 189 assert.NotNil(t, answer) 190 answerAdditional, answerAdditionalOK := answer.(*qnamodels.Answer) 191 assert.True(t, answerAdditionalOK) 192 assert.Equal(t, "answer", *answerAdditional.Result) 193 assert.Equal(t, "content", *answerAdditional.Property) 194 assert.Equal(t, 0.8, *answerAdditional.Certainty) 195 assert.Equal(t, *additional.CertaintyToDistPtr(ptFloat(0.8)), *answerAdditional.Distance) 196 assert.Equal(t, 13, answerAdditional.StartPosition) 197 assert.Equal(t, 19, answerAdditional.EndPosition) 198 assert.Equal(t, true, answerAdditional.HasAnswer) 199 }) 200 201 t.Run("should answer with similarity set above ask certainty", func(t *testing.T) { 202 // given 203 qnaClient := &fakeQnAClient{} 204 fakeHelper := &fakeParamsHelper{} 205 answerProvider := New(qnaClient, fakeHelper) 206 in := []search.Result{ 207 { 208 ID: "some-uuid", 209 Schema: map[string]interface{}{ 210 "content": "content with answer", 211 "content2": "this one is just a title", 212 }, 213 }, 214 } 215 fakeParams := &Params{} 216 limit := 1 217 argumentModuleParams := map[string]interface{}{ 218 "ask": map[string]interface{}{ 219 "question": "question", 220 "properties": []string{"content", "content2"}, 221 "certainty": float64(0.8), 222 }, 223 } 224 225 // when 226 out, err := answerProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil) 227 228 // then 229 require.Nil(t, err) 230 require.NotEmpty(t, out) 231 assert.Equal(t, 1, len(out)) 232 answer, answerOK := out[0].AdditionalProperties["answer"] 233 assert.True(t, answerOK) 234 assert.NotNil(t, answer) 235 answerAdditional, answerAdditionalOK := answer.(*qnamodels.Answer) 236 assert.True(t, answerAdditionalOK) 237 assert.Equal(t, "answer", *answerAdditional.Result) 238 assert.Equal(t, "content", *answerAdditional.Property) 239 assert.Equal(t, 0.8, *answerAdditional.Certainty) 240 assert.Equal(t, *additional.CertaintyToDistPtr(ptFloat(0.8)), *answerAdditional.Distance) 241 assert.Equal(t, 13, answerAdditional.StartPosition) 242 assert.Equal(t, 19, answerAdditional.EndPosition) 243 assert.Equal(t, true, answerAdditional.HasAnswer) 244 }) 245 246 t.Run("should not answer with distance set below ask distance", func(t *testing.T) { 247 // given 248 qnaClient := &fakeQnAClient{} 249 fakeHelper := &fakeParamsHelper{} 250 answerProvider := New(qnaClient, fakeHelper) 251 in := []search.Result{ 252 { 253 ID: "some-uuid", 254 Schema: map[string]interface{}{ 255 "content": "content with answer", 256 "content2": "this one is just a title", 257 }, 258 }, 259 } 260 fakeParams := &Params{} 261 limit := 1 262 argumentModuleParams := map[string]interface{}{ 263 "ask": map[string]interface{}{ 264 "question": "question", 265 "properties": []string{"content", "content2"}, 266 "distance": float64(0.19), 267 }, 268 } 269 270 // when 271 out, err := answerProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil) 272 273 // then 274 require.Nil(t, err) 275 require.NotEmpty(t, out) 276 assert.Equal(t, 1, len(in)) 277 answer, answerOK := in[0].AdditionalProperties["answer"] 278 assert.True(t, answerOK) 279 assert.NotNil(t, answer) 280 answerAdditional, answerAdditionalOK := answer.(*qnamodels.Answer) 281 assert.True(t, answerAdditionalOK) 282 assert.True(t, answerAdditional.Result == nil) 283 assert.True(t, answerAdditional.Property == nil) 284 assert.True(t, answerAdditional.Certainty == nil) 285 assert.True(t, answerAdditional.Distance == nil) 286 assert.Equal(t, 0, answerAdditional.StartPosition) 287 assert.Equal(t, 0, answerAdditional.EndPosition) 288 assert.Equal(t, false, answerAdditional.HasAnswer) 289 }) 290 291 t.Run("should not answer with certainty set below ask certainty", func(t *testing.T) { 292 // given 293 qnaClient := &fakeQnAClient{} 294 fakeHelper := &fakeParamsHelper{} 295 answerProvider := New(qnaClient, fakeHelper) 296 in := []search.Result{ 297 { 298 ID: "some-uuid", 299 Schema: map[string]interface{}{ 300 "content": "content with answer", 301 "content2": "this one is just a title", 302 }, 303 }, 304 } 305 fakeParams := &Params{} 306 limit := 1 307 argumentModuleParams := map[string]interface{}{ 308 "ask": map[string]interface{}{ 309 "question": "question", 310 "properties": []string{"content", "content2"}, 311 "certainty": float64(0.81), 312 }, 313 } 314 315 // when 316 out, err := answerProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil) 317 318 // then 319 require.Nil(t, err) 320 require.NotEmpty(t, out) 321 assert.Equal(t, 1, len(in)) 322 answer, answerOK := in[0].AdditionalProperties["answer"] 323 assert.True(t, answerOK) 324 assert.NotNil(t, answer) 325 answerAdditional, answerAdditionalOK := answer.(*qnamodels.Answer) 326 assert.True(t, answerAdditionalOK) 327 assert.True(t, answerAdditional.Result == nil) 328 assert.True(t, answerAdditional.Property == nil) 329 assert.True(t, answerAdditional.Certainty == nil) 330 assert.True(t, answerAdditional.Distance == nil) 331 assert.Equal(t, 0, answerAdditional.StartPosition) 332 assert.Equal(t, 0, answerAdditional.EndPosition) 333 assert.Equal(t, false, answerAdditional.HasAnswer) 334 }) 335 336 t.Run("should answer with certainty set above ask certainty and the results should be reranked", func(t *testing.T) { 337 // given 338 qnaClient := &fakeQnAClient{} 339 fakeHelper := &fakeParamsHelper{} 340 answerProvider := New(qnaClient, fakeHelper) 341 in := []search.Result{ 342 { 343 ID: "uuid1", 344 Schema: map[string]interface{}{ 345 "content": "rerank 0.5", 346 }, 347 }, 348 { 349 ID: "uuid2", 350 Schema: map[string]interface{}{ 351 "content": "rerank 0.2", 352 }, 353 }, 354 { 355 ID: "uuid3", 356 Schema: map[string]interface{}{ 357 "content": "rerank 0.9", 358 }, 359 }, 360 } 361 fakeParams := &Params{} 362 limit := 1 363 argumentModuleParams := map[string]interface{}{ 364 "ask": map[string]interface{}{ 365 "question": "question", 366 "properties": []string{"content"}, 367 "rerank": true, 368 }, 369 } 370 371 // when 372 out, err := answerProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil) 373 374 // then 375 require.Nil(t, err) 376 require.NotEmpty(t, out) 377 assert.Equal(t, 3, len(in)) 378 answer, answerOK := in[0].AdditionalProperties["answer"] 379 assert.True(t, answerOK) 380 assert.NotNil(t, answer) 381 answerAdditional, answerAdditionalOK := answer.(*qnamodels.Answer) 382 assert.True(t, answerAdditionalOK) 383 assert.Equal(t, "rerank 0.9", *answerAdditional.Result) 384 assert.Equal(t, "content", *answerAdditional.Property) 385 assert.Equal(t, 0.9, *answerAdditional.Certainty) 386 assert.Equal(t, *additional.CertaintyToDistPtr(ptFloat(0.9)), *answerAdditional.Distance) 387 assert.Equal(t, 0, answerAdditional.StartPosition) 388 assert.Equal(t, 10, answerAdditional.EndPosition) 389 assert.Equal(t, true, answerAdditional.HasAnswer) 390 391 answer, answerOK = in[1].AdditionalProperties["answer"] 392 assert.True(t, answerOK) 393 assert.NotNil(t, answer) 394 answerAdditional, answerAdditionalOK = answer.(*qnamodels.Answer) 395 assert.True(t, answerAdditionalOK) 396 assert.Equal(t, "rerank 0.5", *answerAdditional.Result) 397 assert.Equal(t, "content", *answerAdditional.Property) 398 assert.Equal(t, 0.5, *answerAdditional.Certainty) 399 assert.Equal(t, *additional.CertaintyToDistPtr(ptFloat(0.5)), *answerAdditional.Distance) 400 assert.Equal(t, 0, answerAdditional.StartPosition) 401 assert.Equal(t, 10, answerAdditional.EndPosition) 402 assert.Equal(t, true, answerAdditional.HasAnswer) 403 404 answer, answerOK = in[2].AdditionalProperties["answer"] 405 assert.True(t, answerOK) 406 assert.NotNil(t, answer) 407 answerAdditional, answerAdditionalOK = answer.(*qnamodels.Answer) 408 assert.True(t, answerAdditionalOK) 409 assert.Equal(t, "rerank 0.2", *answerAdditional.Result) 410 assert.Equal(t, "content", *answerAdditional.Property) 411 assert.Equal(t, 0.2, *answerAdditional.Certainty) 412 assert.Equal(t, *additional.CertaintyToDistPtr(ptFloat(0.2)), *answerAdditional.Distance) 413 assert.Equal(t, 0, answerAdditional.StartPosition) 414 assert.Equal(t, 10, answerAdditional.EndPosition) 415 assert.Equal(t, true, answerAdditional.HasAnswer) 416 }) 417 418 t.Run("should answer with certainty set above ask certainty and the results should not be reranked", func(t *testing.T) { 419 // given 420 qnaClient := &fakeQnAClient{} 421 fakeHelper := &fakeParamsHelper{} 422 answerProvider := New(qnaClient, fakeHelper) 423 in := []search.Result{ 424 { 425 ID: "uuid1", 426 Schema: map[string]interface{}{ 427 "content": "rerank 0.5", 428 }, 429 }, 430 { 431 ID: "uuid2", 432 Schema: map[string]interface{}{ 433 "content": "rerank 0.2", 434 }, 435 }, 436 { 437 ID: "uuid3", 438 Schema: map[string]interface{}{ 439 "content": "rerank 0.9", 440 }, 441 }, 442 } 443 fakeParams := &Params{} 444 limit := 1 445 argumentModuleParams := map[string]interface{}{ 446 "ask": map[string]interface{}{ 447 "question": "question", 448 "properties": []string{"content"}, 449 "rerank": false, 450 }, 451 } 452 453 // when 454 out, err := answerProvider.AdditionalPropertyFn(context.Background(), in, fakeParams, &limit, argumentModuleParams, nil) 455 456 // then 457 require.Nil(t, err) 458 require.NotEmpty(t, out) 459 assert.Equal(t, 3, len(in)) 460 answer, answerOK := in[0].AdditionalProperties["answer"] 461 assert.True(t, answerOK) 462 assert.NotNil(t, answer) 463 answerAdditional, answerAdditionalOK := answer.(*qnamodels.Answer) 464 assert.True(t, answerAdditionalOK) 465 assert.Equal(t, "rerank 0.5", *answerAdditional.Result) 466 assert.Equal(t, "content", *answerAdditional.Property) 467 assert.Equal(t, 0.5, *answerAdditional.Certainty) 468 assert.Equal(t, *additional.CertaintyToDistPtr(ptFloat(0.5)), *answerAdditional.Distance) 469 assert.Equal(t, 0, answerAdditional.StartPosition) 470 assert.Equal(t, 10, answerAdditional.EndPosition) 471 assert.Equal(t, true, answerAdditional.HasAnswer) 472 473 answer, answerOK = in[1].AdditionalProperties["answer"] 474 assert.True(t, answerOK) 475 assert.NotNil(t, answer) 476 answerAdditional, answerAdditionalOK = answer.(*qnamodels.Answer) 477 assert.True(t, answerAdditionalOK) 478 assert.Equal(t, "rerank 0.2", *answerAdditional.Result) 479 assert.Equal(t, "content", *answerAdditional.Property) 480 assert.Equal(t, 0.2, *answerAdditional.Certainty) 481 assert.Equal(t, *additional.CertaintyToDistPtr(ptFloat(0.2)), *answerAdditional.Distance) 482 assert.Equal(t, 0, answerAdditional.StartPosition) 483 assert.Equal(t, 10, answerAdditional.EndPosition) 484 assert.Equal(t, true, answerAdditional.HasAnswer) 485 486 answer, answerOK = in[2].AdditionalProperties["answer"] 487 assert.True(t, answerOK) 488 assert.NotNil(t, answer) 489 answerAdditional, answerAdditionalOK = answer.(*qnamodels.Answer) 490 assert.True(t, answerAdditionalOK) 491 assert.Equal(t, "rerank 0.9", *answerAdditional.Result) 492 assert.Equal(t, "content", *answerAdditional.Property) 493 assert.Equal(t, 0.9, *answerAdditional.Certainty) 494 assert.Equal(t, *additional.CertaintyToDistPtr(ptFloat(0.9)), *answerAdditional.Distance) 495 assert.Equal(t, 0, answerAdditional.StartPosition) 496 assert.Equal(t, 10, answerAdditional.EndPosition) 497 assert.Equal(t, true, answerAdditional.HasAnswer) 498 }) 499 } 500 501 type fakeQnAClient struct{} 502 503 func (c *fakeQnAClient) Answer(ctx context.Context, 504 text, question string, 505 ) (*ent.AnswerResult, error) { 506 if text == "rerank 0.9" { 507 return c.getAnswer(question, "rerank 0.9", 0.9), nil 508 } 509 if text == "rerank 0.5" { 510 return c.getAnswer(question, "rerank 0.5", 0.5), nil 511 } 512 if text == "rerank 0.2" { 513 return c.getAnswer(question, "rerank 0.2", 0.2), nil 514 } 515 return c.getAnswer(question, "answer", 0.8), nil 516 } 517 518 func (c *fakeQnAClient) getAnswer(question, answer string, certainty float64) *ent.AnswerResult { 519 return &ent.AnswerResult{ 520 Text: question, 521 Question: question, 522 Answer: &answer, 523 Certainty: &certainty, 524 Distance: additional.CertaintyToDistPtr(&certainty), 525 } 526 } 527 528 type fakeParamsHelper struct{} 529 530 func (h *fakeParamsHelper) GetQuestion(params interface{}) string { 531 if fakeParamsMap, ok := params.(map[string]interface{}); ok { 532 if question, ok := fakeParamsMap["question"].(string); ok { 533 return question 534 } 535 } 536 return "" 537 } 538 539 func (h *fakeParamsHelper) GetProperties(params interface{}) []string { 540 if fakeParamsMap, ok := params.(map[string]interface{}); ok { 541 if properties, ok := fakeParamsMap["properties"].([]string); ok { 542 return properties 543 } 544 } 545 return nil 546 } 547 548 func (h *fakeParamsHelper) GetCertainty(params interface{}) float64 { 549 if fakeParamsMap, ok := params.(map[string]interface{}); ok { 550 if certainty, ok := fakeParamsMap["certainty"].(float64); ok { 551 return certainty 552 } 553 } 554 return 0 555 } 556 557 func (h *fakeParamsHelper) GetDistance(params interface{}) float64 { 558 if fakeParamsMap, ok := params.(map[string]interface{}); ok { 559 if distance, ok := fakeParamsMap["distance"].(float64); ok { 560 return distance 561 } 562 } 563 return 0 564 } 565 566 func (h *fakeParamsHelper) GetRerank(params interface{}) bool { 567 if fakeParamsMap, ok := params.(map[string]interface{}); ok { 568 if rerank, ok := fakeParamsMap["rerank"].(bool); ok { 569 return rerank 570 } 571 } 572 return false 573 } 574 575 func ptFloat(f float64) *float64 { 576 return &f 577 }