github.com/weaviate/weaviate@v1.24.6/modules/qna-openai/clients/qna_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "context" 16 "encoding/json" 17 "io" 18 "net/http" 19 "net/http/httptest" 20 "testing" 21 22 "github.com/sirupsen/logrus" 23 "github.com/sirupsen/logrus/hooks/test" 24 "github.com/stretchr/testify/assert" 25 "github.com/stretchr/testify/require" 26 "github.com/weaviate/weaviate/modules/qna-openai/ent" 27 ) 28 29 func nullLogger() logrus.FieldLogger { 30 l, _ := test.NewNullLogger() 31 return l 32 } 33 34 func TestGetAnswer(t *testing.T) { 35 t.Run("when the server has a successful answer ", func(t *testing.T) { 36 handler := &testAnswerHandler{ 37 t: t, 38 answer: answersResponse{ 39 Choices: []choice{{ 40 FinishReason: "test", 41 Index: 0, 42 Logprobs: "", 43 Text: "John", 44 }}, 45 Error: nil, 46 }, 47 } 48 server := httptest.NewServer(handler) 49 defer server.Close() 50 51 c := New("openAIApiKey", "", "", 0, nullLogger()) 52 c.buildUrlFn = func(baseURL, resourceName, deploymentID string) (string, error) { 53 return buildUrl(server.URL, resourceName, deploymentID) 54 } 55 56 expected := ent.AnswerResult{ 57 Text: "My name is John", 58 Question: "What is my name?", 59 Answer: ptString("John"), 60 } 61 62 res, err := c.Answer(context.Background(), "My name is John", "What is my name?", nil) 63 64 assert.Nil(t, err) 65 assert.Equal(t, expected, *res) 66 }) 67 68 t.Run("when the server has a an error", func(t *testing.T) { 69 server := httptest.NewServer(&testAnswerHandler{ 70 t: t, 71 answer: answersResponse{ 72 Error: &openAIApiError{ 73 Message: "some error from the server", 74 }, 75 }, 76 }) 77 defer server.Close() 78 79 c := New("openAIApiKey", "", "", 0, nullLogger()) 80 c.buildUrlFn = func(baseURL, resourceName, deploymentID string) (string, error) { 81 return buildUrl(server.URL, resourceName, deploymentID) 82 } 83 84 _, err := c.Answer(context.Background(), "My name is John", "What is my name?", nil) 85 86 require.NotNil(t, err) 87 assert.Error(t, err, "connection to OpenAI failed with status: 500 error: some error from the server") 88 }) 89 90 t.Run("when X-OpenAI-BaseURL header is passed", func(t *testing.T) { 91 c := New("openAIApiKey", "", "", 0, nullLogger()) 92 93 ctxWithValue := context.WithValue(context.Background(), 94 "X-Openai-Baseurl", []string{"http://base-url-passed-in-header.com"}) 95 96 buildURL, err := c.buildOpenAIUrl(ctxWithValue, "http://default-url.com", "", "") 97 require.NoError(t, err) 98 assert.Equal(t, "http://base-url-passed-in-header.com/v1/completions", buildURL) 99 100 buildURL, err = c.buildOpenAIUrl(context.TODO(), "http://default-url.com", "", "") 101 require.NoError(t, err) 102 assert.Equal(t, "http://default-url.com/v1/completions", buildURL) 103 }) 104 } 105 106 type testAnswerHandler struct { 107 t *testing.T 108 // the test handler will report as not ready before the time has passed 109 answer answersResponse 110 } 111 112 func (f *testAnswerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 113 assert.Equal(f.t, "/v1/completions", r.URL.String()) 114 assert.Equal(f.t, http.MethodPost, r.Method) 115 116 if f.answer.Error != nil && f.answer.Error.Message != "" { 117 outBytes, err := json.Marshal(f.answer) 118 require.Nil(f.t, err) 119 120 w.WriteHeader(http.StatusInternalServerError) 121 w.Write(outBytes) 122 return 123 } 124 125 bodyBytes, err := io.ReadAll(r.Body) 126 require.Nil(f.t, err) 127 defer r.Body.Close() 128 129 var b map[string]interface{} 130 require.Nil(f.t, json.Unmarshal(bodyBytes, &b)) 131 132 outBytes, err := json.Marshal(f.answer) 133 require.Nil(f.t, err) 134 135 w.Write(outBytes) 136 } 137 138 func TestOpenAIApiErrorDecode(t *testing.T) { 139 t.Run("getModelStringQuery", func(t *testing.T) { 140 type args struct { 141 response []byte 142 } 143 tests := []struct { 144 name string 145 args args 146 want string 147 }{ 148 { 149 name: "Error code: missing property", 150 args: args{ 151 response: []byte(`{"message": "failed", "type": "error", "param": "arg..."}`), 152 }, 153 want: "", 154 }, 155 { 156 name: "Error code: as int", 157 args: args{ 158 response: []byte(`{"message": "failed", "type": "error", "param": "arg...", "code": 500}`), 159 }, 160 want: "500", 161 }, 162 { 163 name: "Error code as string number", 164 args: args{ 165 response: []byte(`{"message": "failed", "type": "error", "param": "arg...", "code": "500"}`), 166 }, 167 want: "500", 168 }, 169 { 170 name: "Error code as string text", 171 args: args{ 172 response: []byte(`{"message": "failed", "type": "error", "param": "arg...", "code": "invalid_api_key"}`), 173 }, 174 want: "invalid_api_key", 175 }, 176 } 177 for _, tt := range tests { 178 t.Run(tt.name, func(t *testing.T) { 179 var got *openAIApiError 180 err := json.Unmarshal(tt.args.response, &got) 181 require.NoError(t, err) 182 183 if got.Code.String() != tt.want { 184 t.Errorf("OpenAIerror.code = %v, want %v", got.Code, tt.want) 185 } 186 }) 187 } 188 }) 189 } 190 191 func ptString(in string) *string { 192 return &in 193 }