github.com/weaviate/weaviate@v1.24.6/modules/qna-transformers/clients/qna_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "context" 16 "encoding/json" 17 "net/http" 18 "net/http/httptest" 19 "testing" 20 21 "github.com/stretchr/testify/assert" 22 "github.com/stretchr/testify/require" 23 "github.com/weaviate/weaviate/entities/additional" 24 "github.com/weaviate/weaviate/modules/qna-transformers/ent" 25 ) 26 27 func TestGetAnswer(t *testing.T) { 28 t.Run("when the server has a successful answer (with distance)", func(t *testing.T) { 29 server := httptest.NewServer(&testAnswerHandler{ 30 t: t, 31 answer: answersResponse{ 32 answersInput: answersInput{ 33 Text: "My name is John", 34 Question: "What is my name?", 35 }, 36 Answer: ptString("John"), 37 Certainty: ptFloat(0.7), 38 Distance: ptFloat(0.3), 39 }, 40 }) 41 defer server.Close() 42 c := New(server.URL, 0, nullLogger()) 43 res, err := c.Answer(context.Background(), "My name is John", 44 "What is my name?") 45 assert.Nil(t, err) 46 47 expectedResult := ent.AnswerResult{ 48 Text: "My name is John", 49 Question: "What is my name?", 50 Answer: ptString("John"), 51 Certainty: ptFloat(0.7), 52 Distance: ptFloat(0.6), 53 } 54 55 assert.Equal(t, expectedResult.Text, res.Text) 56 assert.Equal(t, expectedResult.Question, res.Question) 57 assert.Equal(t, expectedResult.Answer, res.Answer) 58 assert.Equal(t, expectedResult.Certainty, res.Certainty) 59 assert.InDelta(t, *expectedResult.Distance, *res.Distance, 1e-9) 60 }) 61 62 t.Run("when the server has a successful answer (with certainty)", func(t *testing.T) { 63 server := httptest.NewServer(&testAnswerHandler{ 64 t: t, 65 answer: answersResponse{ 66 answersInput: answersInput{ 67 Text: "My name is John", 68 Question: "What is my name?", 69 }, 70 Answer: ptString("John"), 71 Certainty: ptFloat(0.7), 72 }, 73 }) 74 defer server.Close() 75 c := New(server.URL, 0, nullLogger()) 76 res, err := c.Answer(context.Background(), "My name is John", 77 "What is my name?") 78 79 assert.Nil(t, err) 80 assert.Equal(t, &ent.AnswerResult{ 81 Text: "My name is John", 82 Question: "What is my name?", 83 Answer: ptString("John"), 84 Certainty: ptFloat(0.7), 85 Distance: additional.CertaintyToDistPtr(ptFloat(0.7)), 86 }, res) 87 }) 88 89 t.Run("when the server has a an error", func(t *testing.T) { 90 server := httptest.NewServer(&testAnswerHandler{ 91 t: t, 92 answer: answersResponse{ 93 Error: "some error from the server", 94 }, 95 }) 96 defer server.Close() 97 c := New(server.URL, 0, nullLogger()) 98 _, err := c.Answer(context.Background(), "My name is John", 99 "What is my name?") 100 101 require.NotNil(t, err) 102 assert.Contains(t, err.Error(), "some error from the server") 103 }) 104 } 105 106 type testAnswerHandler struct { 107 t *testing.T 108 // the test handler will report as not ready before the time has passed 109 answer answersResponse 110 } 111 112 func (f *testAnswerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 113 assert.Equal(f.t, "/answers/", r.URL.String()) 114 assert.Equal(f.t, http.MethodPost, r.Method) 115 116 if f.answer.Error != "" { 117 w.WriteHeader(500) 118 } 119 jsonBytes, _ := json.Marshal(f.answer) 120 w.Write(jsonBytes) 121 } 122 123 func ptFloat(in float64) *float64 { 124 return &in 125 } 126 127 func ptString(in string) *string { 128 return &in 129 }