github.com/weaviate/weaviate@v1.24.6/modules/generative-openai/clients/openai_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "context" 16 "encoding/json" 17 "io" 18 "net/http" 19 "net/http/httptest" 20 "os" 21 "strings" 22 "testing" 23 24 "github.com/sirupsen/logrus" 25 "github.com/sirupsen/logrus/hooks/test" 26 "github.com/stretchr/testify/assert" 27 "github.com/stretchr/testify/require" 28 "github.com/weaviate/weaviate/entities/models" 29 "github.com/weaviate/weaviate/modules/generative-openai/config" 30 generativemodels "github.com/weaviate/weaviate/usecases/modulecomponents/additional/models" 31 ) 32 33 func nullLogger() logrus.FieldLogger { 34 l, _ := test.NewNullLogger() 35 return l 36 } 37 38 func fakeBuildUrl(serverURL string, isLegacy bool, resourceName, deploymentID, baseURL, apiVersion string) (string, error) { 39 endpoint, err := buildUrlFn(isLegacy, resourceName, deploymentID, baseURL, apiVersion) 40 if err != nil { 41 return "", err 42 } 43 endpoint = strings.Replace(endpoint, "https://api.openai.com", serverURL, 1) 44 return endpoint, nil 45 } 46 47 func TestBuildUrlFn(t *testing.T) { 48 t.Run("buildUrlFn returns default OpenAI Client", func(t *testing.T) { 49 url, err := buildUrlFn(false, "", "", config.DefaultOpenAIBaseURL, config.DefaultApiVersion) 50 assert.Nil(t, err) 51 assert.Equal(t, "https://api.openai.com/v1/chat/completions", url) 52 }) 53 t.Run("buildUrlFn returns Azure Client", func(t *testing.T) { 54 url, err := buildUrlFn(false, "resourceID", "deploymentID", "", config.DefaultApiVersion) 55 assert.Nil(t, err) 56 assert.Equal(t, "https://resourceID.openai.azure.com/openai/deployments/deploymentID/chat/completions?api-version=2023-05-15", url) 57 }) 58 t.Run("buildUrlFn loads from environment variable", func(t *testing.T) { 59 url, err := buildUrlFn(false, "", "", "https://foobar.some.proxy", config.DefaultApiVersion) 60 assert.Nil(t, err) 61 assert.Equal(t, "https://foobar.some.proxy/v1/chat/completions", url) 62 os.Unsetenv("OPENAI_BASE_URL") 63 }) 64 t.Run("buildUrlFn returns Azure Client with custom baseURL", func(t *testing.T) { 65 url, err := buildUrlFn(false, "resourceID", "deploymentID", "customBaseURL", config.DefaultApiVersion) 66 assert.Nil(t, err) 67 assert.Equal(t, "customBaseURL/openai/deployments/deploymentID/chat/completions?api-version=2023-05-15", url) 68 }) 69 } 70 71 func TestGetAnswer(t *testing.T) { 72 textProperties := []map[string]string{{"prop": "My name is john"}} 73 t.Run("when the server has a successful answer ", func(t *testing.T) { 74 handler := &testAnswerHandler{ 75 t: t, 76 answer: generateResponse{ 77 Choices: []choice{{ 78 FinishReason: "test", 79 Index: 0, 80 Logprobs: "", 81 Text: "John", 82 }}, 83 Error: nil, 84 }, 85 } 86 server := httptest.NewServer(handler) 87 defer server.Close() 88 89 c := New("openAIApiKey", "", "", 0, nullLogger()) 90 c.buildUrl = func(isLegacy bool, resourceName, deploymentID, baseURL, apiVersion string) (string, error) { 91 return fakeBuildUrl(server.URL, isLegacy, resourceName, deploymentID, baseURL, apiVersion) 92 } 93 94 expected := generativemodels.GenerateResponse{ 95 Result: ptString("John"), 96 } 97 98 res, err := c.GenerateAllResults(context.Background(), textProperties, "What is my name?", nil) 99 100 assert.Nil(t, err) 101 assert.Equal(t, expected, *res) 102 }) 103 104 t.Run("when the server has a an error", func(t *testing.T) { 105 server := httptest.NewServer(&testAnswerHandler{ 106 t: t, 107 answer: generateResponse{ 108 Error: &openAIApiError{ 109 Message: "some error from the server", 110 }, 111 }, 112 }) 113 defer server.Close() 114 115 c := New("openAIApiKey", "", "", 0, nullLogger()) 116 c.buildUrl = func(isLegacy bool, resourceName, deploymentID, baseURL, apiVersion string) (string, error) { 117 return fakeBuildUrl(server.URL, isLegacy, resourceName, deploymentID, baseURL, apiVersion) 118 } 119 120 _, err := c.GenerateAllResults(context.Background(), textProperties, "What is my name?", nil) 121 122 require.NotNil(t, err) 123 assert.Error(t, err, "connection to OpenAI failed with status: 500 error: some error from the server") 124 }) 125 126 t.Run("when X-OpenAI-BaseURL header is passed", func(t *testing.T) { 127 settings := &fakeClassSettings{ 128 baseURL: "http://default-url.com", 129 } 130 c := New("openAIApiKey", "", "", 0, nullLogger()) 131 132 ctxWithValue := context.WithValue(context.Background(), 133 "X-Openai-Baseurl", []string{"http://base-url-passed-in-header.com"}) 134 135 buildURL, err := c.buildOpenAIUrl(ctxWithValue, settings) 136 require.NoError(t, err) 137 assert.Equal(t, "http://base-url-passed-in-header.com/v1/chat/completions", buildURL) 138 139 buildURL, err = c.buildOpenAIUrl(context.TODO(), settings) 140 require.NoError(t, err) 141 assert.Equal(t, "http://default-url.com/v1/chat/completions", buildURL) 142 }) 143 } 144 145 type testAnswerHandler struct { 146 t *testing.T 147 // the test handler will report as not ready before the time has passed 148 answer generateResponse 149 } 150 151 func (f *testAnswerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 152 assert.Equal(f.t, "/v1/chat/completions", r.URL.String()) 153 assert.Equal(f.t, http.MethodPost, r.Method) 154 155 if f.answer.Error != nil && f.answer.Error.Message != "" { 156 outBytes, err := json.Marshal(f.answer) 157 require.Nil(f.t, err) 158 159 w.WriteHeader(http.StatusInternalServerError) 160 w.Write(outBytes) 161 return 162 } 163 164 bodyBytes, err := io.ReadAll(r.Body) 165 require.Nil(f.t, err) 166 defer r.Body.Close() 167 168 var b map[string]interface{} 169 require.Nil(f.t, json.Unmarshal(bodyBytes, &b)) 170 171 outBytes, err := json.Marshal(f.answer) 172 require.Nil(f.t, err) 173 174 w.Write(outBytes) 175 } 176 177 func TestOpenAIApiErrorDecode(t *testing.T) { 178 t.Run("getModelStringQuery", func(t *testing.T) { 179 type args struct { 180 response []byte 181 } 182 tests := []struct { 183 name string 184 args args 185 want string 186 }{ 187 { 188 name: "Error code: missing property", 189 args: args{ 190 response: []byte(`{"message": "failed", "type": "error", "param": "arg..."}`), 191 }, 192 want: "", 193 }, 194 { 195 name: "Error code: as int", 196 args: args{ 197 response: []byte(`{"message": "failed", "type": "error", "param": "arg...", "code": 500}`), 198 }, 199 want: "500", 200 }, 201 { 202 name: "Error code as string number", 203 args: args{ 204 response: []byte(`{"message": "failed", "type": "error", "param": "arg...", "code": "500"}`), 205 }, 206 want: "500", 207 }, 208 { 209 name: "Error code as string text", 210 args: args{ 211 response: []byte(`{"message": "failed", "type": "error", "param": "arg...", "code": "invalid_api_key"}`), 212 }, 213 want: "invalid_api_key", 214 }, 215 } 216 for _, tt := range tests { 217 t.Run(tt.name, func(t *testing.T) { 218 var got *openAIApiError 219 err := json.Unmarshal(tt.args.response, &got) 220 require.NoError(t, err) 221 222 if got.Code.String() != tt.want { 223 t.Errorf("OpenAIerror.code = %v, want %v", got.Code, tt.want) 224 } 225 }) 226 } 227 }) 228 } 229 230 func ptString(in string) *string { 231 return &in 232 } 233 234 type fakeClassSettings struct { 235 isLegacy bool 236 model string 237 maxTokens float64 238 temperature float64 239 frequencyPenalty float64 240 presencePenalty float64 241 topP float64 242 resourceName string 243 deploymentID string 244 isAzure bool 245 baseURL string 246 apiVersion string 247 } 248 249 func (s *fakeClassSettings) IsLegacy() bool { 250 return s.isLegacy 251 } 252 253 func (s *fakeClassSettings) Model() string { 254 return s.model 255 } 256 257 func (s *fakeClassSettings) MaxTokens() float64 { 258 return s.maxTokens 259 } 260 261 func (s *fakeClassSettings) Temperature() float64 { 262 return s.temperature 263 } 264 265 func (s *fakeClassSettings) FrequencyPenalty() float64 { 266 return s.frequencyPenalty 267 } 268 269 func (s *fakeClassSettings) PresencePenalty() float64 { 270 return s.presencePenalty 271 } 272 273 func (s *fakeClassSettings) TopP() float64 { 274 return s.topP 275 } 276 277 func (s *fakeClassSettings) ResourceName() string { 278 return s.resourceName 279 } 280 281 func (s *fakeClassSettings) DeploymentID() string { 282 return s.deploymentID 283 } 284 285 func (s *fakeClassSettings) IsAzure() bool { 286 return s.isAzure 287 } 288 289 func (s *fakeClassSettings) GetMaxTokensForModel(model string) float64 { 290 return 0 291 } 292 293 func (s *fakeClassSettings) Validate(class *models.Class) error { 294 return nil 295 } 296 297 func (s *fakeClassSettings) BaseURL() string { 298 return s.baseURL 299 } 300 301 func (s *fakeClassSettings) ApiVersion() string { 302 return s.apiVersion 303 }