github.com/weaviate/weaviate@v1.24.6/modules/text2vec-openai/clients/openai_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "context" 16 "encoding/json" 17 "io" 18 "net/http" 19 "net/http/httptest" 20 "testing" 21 "time" 22 23 "github.com/pkg/errors" 24 "github.com/sirupsen/logrus" 25 "github.com/sirupsen/logrus/hooks/test" 26 "github.com/stretchr/testify/assert" 27 "github.com/stretchr/testify/require" 28 "github.com/weaviate/weaviate/modules/text2vec-openai/ent" 29 ) 30 31 func TestBuildUrlFn(t *testing.T) { 32 t.Run("buildUrlFn returns default OpenAI Client", func(t *testing.T) { 33 config := ent.VectorizationConfig{ 34 Type: "", 35 Model: "", 36 ModelVersion: "", 37 ResourceName: "", 38 DeploymentID: "", 39 BaseURL: "https://api.openai.com", 40 IsAzure: false, 41 } 42 url, err := buildUrl(config.BaseURL, config.ResourceName, config.DeploymentID, config.IsAzure) 43 assert.Nil(t, err) 44 assert.Equal(t, "https://api.openai.com/v1/embeddings", url) 45 }) 46 t.Run("buildUrlFn returns Azure Client", func(t *testing.T) { 47 config := ent.VectorizationConfig{ 48 Type: "", 49 Model: "", 50 ModelVersion: "", 51 ResourceName: "resourceID", 52 DeploymentID: "deploymentID", 53 BaseURL: "", 54 IsAzure: true, 55 } 56 url, err := buildUrl(config.BaseURL, config.ResourceName, config.DeploymentID, config.IsAzure) 57 assert.Nil(t, err) 58 assert.Equal(t, "https://resourceID.openai.azure.com/openai/deployments/deploymentID/embeddings?api-version=2022-12-01", url) 59 }) 60 61 t.Run("buildUrlFn returns Azure client with BaseUrl set", func(t *testing.T) { 62 config := ent.VectorizationConfig{ 63 Type: "", 64 Model: "", 65 ModelVersion: "", 66 ResourceName: "resourceID", 67 DeploymentID: "deploymentID", 68 BaseURL: "https://foobar.some.proxy", 69 IsAzure: true, 70 } 71 url, err := buildUrl(config.BaseURL, config.ResourceName, config.DeploymentID, config.IsAzure) 72 assert.Nil(t, err) 73 assert.Equal(t, "https://foobar.some.proxy/openai/deployments/deploymentID/embeddings?api-version=2022-12-01", url) 74 }) 75 76 t.Run("buildUrlFn loads from BaseURL", func(t *testing.T) { 77 config := ent.VectorizationConfig{ 78 Type: "", 79 Model: "", 80 ModelVersion: "", 81 ResourceName: "resourceID", 82 DeploymentID: "deploymentID", 83 BaseURL: "https://foobar.some.proxy", 84 IsAzure: false, 85 } 86 url, err := buildUrl(config.BaseURL, config.ResourceName, config.DeploymentID, config.IsAzure) 87 assert.Nil(t, err) 88 assert.Equal(t, "https://foobar.some.proxy/v1/embeddings", url) 89 }) 90 } 91 92 func TestClient(t *testing.T) { 93 t.Run("when all is fine", func(t *testing.T) { 94 server := httptest.NewServer(&fakeHandler{t: t}) 95 defer server.Close() 96 97 c := New("apiKey", "", "", 0, nullLogger()) 98 c.buildUrlFn = func(baseURL, resourceName, deploymentID string, isAzure bool) (string, error) { 99 return server.URL, nil 100 } 101 102 expected := &ent.VectorizationResult{ 103 Text: []string{"This is my text"}, 104 Vector: [][]float32{{0.1, 0.2, 0.3}}, 105 Dimensions: 3, 106 } 107 res, err := c.Vectorize(context.Background(), "This is my text", 108 ent.VectorizationConfig{ 109 Type: "text", 110 Model: "ada", 111 }) 112 113 assert.Nil(t, err) 114 assert.Equal(t, expected, res) 115 }) 116 117 t.Run("when the context is expired", func(t *testing.T) { 118 server := httptest.NewServer(&fakeHandler{t: t}) 119 defer server.Close() 120 c := New("apiKey", "", "", 0, nullLogger()) 121 c.buildUrlFn = func(baseURL, resourceName, deploymentID string, isAzure bool) (string, error) { 122 return server.URL, nil 123 } 124 125 ctx, cancel := context.WithDeadline(context.Background(), time.Now()) 126 defer cancel() 127 128 _, err := c.Vectorize(ctx, "This is my text", ent.VectorizationConfig{}) 129 130 require.NotNil(t, err) 131 assert.Contains(t, err.Error(), "context deadline exceeded") 132 }) 133 134 t.Run("when the server returns an error", func(t *testing.T) { 135 server := httptest.NewServer(&fakeHandler{ 136 t: t, 137 serverError: errors.Errorf("nope, not gonna happen"), 138 }) 139 defer server.Close() 140 c := New("apiKey", "", "", 0, nullLogger()) 141 c.buildUrlFn = func(baseURL, resourceName, deploymentID string, isAzure bool) (string, error) { 142 return server.URL, nil 143 } 144 145 _, err := c.Vectorize(context.Background(), "This is my text", 146 ent.VectorizationConfig{}) 147 148 require.NotNil(t, err) 149 assert.EqualError(t, err, "connection to: OpenAI API failed with status: 500 error: nope, not gonna happen") 150 }) 151 152 t.Run("when OpenAI key is passed using X-Openai-Api-Key header", func(t *testing.T) { 153 server := httptest.NewServer(&fakeHandler{t: t}) 154 defer server.Close() 155 c := New("", "", "", 0, nullLogger()) 156 c.buildUrlFn = func(baseURL, resourceName, deploymentID string, isAzure bool) (string, error) { 157 return server.URL, nil 158 } 159 160 ctxWithValue := context.WithValue(context.Background(), 161 "X-Openai-Api-Key", []string{"some-key"}) 162 163 expected := &ent.VectorizationResult{ 164 Text: []string{"This is my text"}, 165 Vector: [][]float32{{0.1, 0.2, 0.3}}, 166 Dimensions: 3, 167 } 168 res, err := c.Vectorize(ctxWithValue, "This is my text", 169 ent.VectorizationConfig{ 170 Type: "text", 171 Model: "ada", 172 }) 173 174 require.Nil(t, err) 175 assert.Equal(t, expected, res) 176 }) 177 178 t.Run("when OpenAI key is empty", func(t *testing.T) { 179 server := httptest.NewServer(&fakeHandler{t: t}) 180 defer server.Close() 181 c := New("", "", "", 0, nullLogger()) 182 c.buildUrlFn = func(baseURL, resourceName, deploymentID string, isAzure bool) (string, error) { 183 return server.URL, nil 184 } 185 186 ctx, cancel := context.WithDeadline(context.Background(), time.Now()) 187 defer cancel() 188 189 _, err := c.Vectorize(ctx, "This is my text", ent.VectorizationConfig{}) 190 191 require.NotNil(t, err) 192 assert.EqualError(t, err, "API Key: no api key found "+ 193 "neither in request header: X-Openai-Api-Key "+ 194 "nor in environment variable under OPENAI_APIKEY") 195 }) 196 197 t.Run("when X-Openai-Api-Key header is passed but empty", func(t *testing.T) { 198 server := httptest.NewServer(&fakeHandler{t: t}) 199 defer server.Close() 200 c := New("", "", "", 0, nullLogger()) 201 c.buildUrlFn = func(baseURL, resourceName, deploymentID string, isAzure bool) (string, error) { 202 return server.URL, nil 203 } 204 205 ctxWithValue := context.WithValue(context.Background(), 206 "X-Openai-Api-Key", []string{""}) 207 208 _, err := c.Vectorize(ctxWithValue, "This is my text", 209 ent.VectorizationConfig{ 210 Type: "text", 211 Model: "ada", 212 }) 213 214 require.NotNil(t, err) 215 assert.EqualError(t, err, "API Key: no api key found "+ 216 "neither in request header: X-Openai-Api-Key "+ 217 "nor in environment variable under OPENAI_APIKEY") 218 }) 219 220 t.Run("when X-OpenAI-BaseURL header is passed", func(t *testing.T) { 221 server := httptest.NewServer(&fakeHandler{t: t}) 222 defer server.Close() 223 c := New("", "", "", 0, nullLogger()) 224 225 config := ent.VectorizationConfig{ 226 Type: "text", 227 Model: "ada", 228 BaseURL: "http://default-url.com", 229 } 230 231 ctxWithValue := context.WithValue(context.Background(), 232 "X-Openai-Baseurl", []string{"http://base-url-passed-in-header.com"}) 233 234 buildURL, err := c.buildURL(ctxWithValue, config) 235 require.NoError(t, err) 236 assert.Equal(t, "http://base-url-passed-in-header.com/v1/embeddings", buildURL) 237 238 buildURL, err = c.buildURL(context.TODO(), config) 239 require.NoError(t, err) 240 assert.Equal(t, "http://default-url.com/v1/embeddings", buildURL) 241 }) 242 } 243 244 type fakeHandler struct { 245 t *testing.T 246 serverError error 247 } 248 249 func (f *fakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 250 assert.Equal(f.t, http.MethodPost, r.Method) 251 252 if f.serverError != nil { 253 embeddingError := map[string]interface{}{ 254 "message": f.serverError.Error(), 255 "type": "invalid_request_error", 256 } 257 embedding := map[string]interface{}{ 258 "error": embeddingError, 259 } 260 outBytes, err := json.Marshal(embedding) 261 require.Nil(f.t, err) 262 263 w.WriteHeader(http.StatusInternalServerError) 264 w.Write(outBytes) 265 return 266 } 267 268 bodyBytes, err := io.ReadAll(r.Body) 269 require.Nil(f.t, err) 270 defer r.Body.Close() 271 272 var b map[string]interface{} 273 require.Nil(f.t, json.Unmarshal(bodyBytes, &b)) 274 275 textInputArray := b["input"].([]interface{}) 276 textInput := textInputArray[0].(string) 277 assert.Greater(f.t, len(textInput), 0) 278 279 embeddingData := map[string]interface{}{ 280 "object": textInput, 281 "index": 0, 282 "embedding": []float32{0.1, 0.2, 0.3}, 283 } 284 embedding := map[string]interface{}{ 285 "object": "list", 286 "data": []interface{}{embeddingData}, 287 } 288 289 outBytes, err := json.Marshal(embedding) 290 require.Nil(f.t, err) 291 292 w.Write(outBytes) 293 } 294 295 func nullLogger() logrus.FieldLogger { 296 l, _ := test.NewNullLogger() 297 return l 298 } 299 300 func Test_getModelString(t *testing.T) { 301 t.Run("getModelStringDocument", func(t *testing.T) { 302 type args struct { 303 docType string 304 model string 305 version string 306 } 307 tests := []struct { 308 name string 309 args args 310 want string 311 }{ 312 { 313 name: "Document type: text model: ada vectorizationType: document", 314 args: args{ 315 docType: "text", 316 model: "ada", 317 }, 318 want: "text-search-ada-doc-001", 319 }, 320 { 321 name: "Document type: text model: ada-002 vectorizationType: document", 322 args: args{ 323 docType: "text", 324 model: "ada", 325 version: "002", 326 }, 327 want: "text-embedding-ada-002", 328 }, 329 { 330 name: "Document type: text model: babbage vectorizationType: document", 331 args: args{ 332 docType: "text", 333 model: "babbage", 334 }, 335 want: "text-search-babbage-doc-001", 336 }, 337 { 338 name: "Document type: text model: curie vectorizationType: document", 339 args: args{ 340 docType: "text", 341 model: "curie", 342 }, 343 want: "text-search-curie-doc-001", 344 }, 345 { 346 name: "Document type: text model: davinci vectorizationType: document", 347 args: args{ 348 docType: "text", 349 model: "davinci", 350 }, 351 want: "text-search-davinci-doc-001", 352 }, 353 { 354 name: "Document type: code model: ada vectorizationType: code", 355 args: args{ 356 docType: "code", 357 model: "ada", 358 }, 359 want: "code-search-ada-code-001", 360 }, 361 { 362 name: "Document type: code model: babbage vectorizationType: code", 363 args: args{ 364 docType: "code", 365 model: "babbage", 366 }, 367 want: "code-search-babbage-code-001", 368 }, 369 } 370 for _, tt := range tests { 371 t.Run(tt.name, func(t *testing.T) { 372 v := New("apiKey", "", "", 0, nullLogger()) 373 if got := v.getModelString(tt.args.docType, tt.args.model, "document", tt.args.version); got != tt.want { 374 t.Errorf("vectorizer.getModelString() = %v, want %v", got, tt.want) 375 } 376 }) 377 } 378 }) 379 380 t.Run("getModelStringQuery", func(t *testing.T) { 381 type args struct { 382 docType string 383 model string 384 version string 385 } 386 tests := []struct { 387 name string 388 args args 389 want string 390 }{ 391 { 392 name: "Document type: text model: ada vectorizationType: query", 393 args: args{ 394 docType: "text", 395 model: "ada", 396 }, 397 want: "text-search-ada-query-001", 398 }, 399 { 400 name: "Document type: text model: babbage vectorizationType: query", 401 args: args{ 402 docType: "text", 403 model: "babbage", 404 }, 405 want: "text-search-babbage-query-001", 406 }, 407 { 408 name: "Document type: text model: curie vectorizationType: query", 409 args: args{ 410 docType: "text", 411 model: "curie", 412 }, 413 want: "text-search-curie-query-001", 414 }, 415 { 416 name: "Document type: text model: davinci vectorizationType: query", 417 args: args{ 418 docType: "text", 419 model: "davinci", 420 }, 421 want: "text-search-davinci-query-001", 422 }, 423 { 424 name: "Document type: code model: ada vectorizationType: text", 425 args: args{ 426 docType: "code", 427 model: "ada", 428 }, 429 want: "code-search-ada-text-001", 430 }, 431 { 432 name: "Document type: code model: babbage vectorizationType: text", 433 args: args{ 434 docType: "code", 435 model: "babbage", 436 }, 437 want: "code-search-babbage-text-001", 438 }, 439 } 440 for _, tt := range tests { 441 t.Run(tt.name, func(t *testing.T) { 442 v := New("apiKey", "", "", 0, nullLogger()) 443 if got := v.getModelString(tt.args.docType, tt.args.model, "query", tt.args.version); got != tt.want { 444 t.Errorf("vectorizer.getModelString() = %v, want %v", got, tt.want) 445 } 446 }) 447 } 448 }) 449 } 450 451 func TestOpenAIApiErrorDecode(t *testing.T) { 452 t.Run("getModelStringQuery", func(t *testing.T) { 453 type args struct { 454 response []byte 455 } 456 tests := []struct { 457 name string 458 args args 459 want string 460 }{ 461 { 462 name: "Error code: missing property", 463 args: args{ 464 response: []byte(`{"message": "failed", "type": "error", "param": "arg..."}`), 465 }, 466 want: "", 467 }, 468 { 469 name: "Error code: as int", 470 args: args{ 471 response: []byte(`{"message": "failed", "type": "error", "param": "arg...", "code": 500}`), 472 }, 473 want: "500", 474 }, 475 { 476 name: "Error code as string number", 477 args: args{ 478 response: []byte(`{"message": "failed", "type": "error", "param": "arg...", "code": "500"}`), 479 }, 480 want: "500", 481 }, 482 { 483 name: "Error code as string text", 484 args: args{ 485 response: []byte(`{"message": "failed", "type": "error", "param": "arg...", "code": "invalid_api_key"}`), 486 }, 487 want: "invalid_api_key", 488 }, 489 } 490 for _, tt := range tests { 491 t.Run(tt.name, func(t *testing.T) { 492 var got *openAIApiError 493 err := json.Unmarshal(tt.args.response, &got) 494 require.NoError(t, err) 495 496 if got.Code.String() != tt.want { 497 t.Errorf("OpenAIerror.code = %v, want %v", got.Code, tt.want) 498 } 499 }) 500 } 501 }) 502 }