github.com/weaviate/weaviate@v1.24.6/modules/text2vec-openai/clients/openai_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package clients
    13  
    14  import (
    15  	"context"
    16  	"encoding/json"
    17  	"io"
    18  	"net/http"
    19  	"net/http/httptest"
    20  	"testing"
    21  	"time"
    22  
    23  	"github.com/pkg/errors"
    24  	"github.com/sirupsen/logrus"
    25  	"github.com/sirupsen/logrus/hooks/test"
    26  	"github.com/stretchr/testify/assert"
    27  	"github.com/stretchr/testify/require"
    28  	"github.com/weaviate/weaviate/modules/text2vec-openai/ent"
    29  )
    30  
    31  func TestBuildUrlFn(t *testing.T) {
    32  	t.Run("buildUrlFn returns default OpenAI Client", func(t *testing.T) {
    33  		config := ent.VectorizationConfig{
    34  			Type:         "",
    35  			Model:        "",
    36  			ModelVersion: "",
    37  			ResourceName: "",
    38  			DeploymentID: "",
    39  			BaseURL:      "https://api.openai.com",
    40  			IsAzure:      false,
    41  		}
    42  		url, err := buildUrl(config.BaseURL, config.ResourceName, config.DeploymentID, config.IsAzure)
    43  		assert.Nil(t, err)
    44  		assert.Equal(t, "https://api.openai.com/v1/embeddings", url)
    45  	})
    46  	t.Run("buildUrlFn returns Azure Client", func(t *testing.T) {
    47  		config := ent.VectorizationConfig{
    48  			Type:         "",
    49  			Model:        "",
    50  			ModelVersion: "",
    51  			ResourceName: "resourceID",
    52  			DeploymentID: "deploymentID",
    53  			BaseURL:      "",
    54  			IsAzure:      true,
    55  		}
    56  		url, err := buildUrl(config.BaseURL, config.ResourceName, config.DeploymentID, config.IsAzure)
    57  		assert.Nil(t, err)
    58  		assert.Equal(t, "https://resourceID.openai.azure.com/openai/deployments/deploymentID/embeddings?api-version=2022-12-01", url)
    59  	})
    60  
    61  	t.Run("buildUrlFn returns Azure client with BaseUrl set", func(t *testing.T) {
    62  		config := ent.VectorizationConfig{
    63  			Type:         "",
    64  			Model:        "",
    65  			ModelVersion: "",
    66  			ResourceName: "resourceID",
    67  			DeploymentID: "deploymentID",
    68  			BaseURL:      "https://foobar.some.proxy",
    69  			IsAzure:      true,
    70  		}
    71  		url, err := buildUrl(config.BaseURL, config.ResourceName, config.DeploymentID, config.IsAzure)
    72  		assert.Nil(t, err)
    73  		assert.Equal(t, "https://foobar.some.proxy/openai/deployments/deploymentID/embeddings?api-version=2022-12-01", url)
    74  	})
    75  
    76  	t.Run("buildUrlFn loads from BaseURL", func(t *testing.T) {
    77  		config := ent.VectorizationConfig{
    78  			Type:         "",
    79  			Model:        "",
    80  			ModelVersion: "",
    81  			ResourceName: "resourceID",
    82  			DeploymentID: "deploymentID",
    83  			BaseURL:      "https://foobar.some.proxy",
    84  			IsAzure:      false,
    85  		}
    86  		url, err := buildUrl(config.BaseURL, config.ResourceName, config.DeploymentID, config.IsAzure)
    87  		assert.Nil(t, err)
    88  		assert.Equal(t, "https://foobar.some.proxy/v1/embeddings", url)
    89  	})
    90  }
    91  
    92  func TestClient(t *testing.T) {
    93  	t.Run("when all is fine", func(t *testing.T) {
    94  		server := httptest.NewServer(&fakeHandler{t: t})
    95  		defer server.Close()
    96  
    97  		c := New("apiKey", "", "", 0, nullLogger())
    98  		c.buildUrlFn = func(baseURL, resourceName, deploymentID string, isAzure bool) (string, error) {
    99  			return server.URL, nil
   100  		}
   101  
   102  		expected := &ent.VectorizationResult{
   103  			Text:       []string{"This is my text"},
   104  			Vector:     [][]float32{{0.1, 0.2, 0.3}},
   105  			Dimensions: 3,
   106  		}
   107  		res, err := c.Vectorize(context.Background(), "This is my text",
   108  			ent.VectorizationConfig{
   109  				Type:  "text",
   110  				Model: "ada",
   111  			})
   112  
   113  		assert.Nil(t, err)
   114  		assert.Equal(t, expected, res)
   115  	})
   116  
   117  	t.Run("when the context is expired", func(t *testing.T) {
   118  		server := httptest.NewServer(&fakeHandler{t: t})
   119  		defer server.Close()
   120  		c := New("apiKey", "", "", 0, nullLogger())
   121  		c.buildUrlFn = func(baseURL, resourceName, deploymentID string, isAzure bool) (string, error) {
   122  			return server.URL, nil
   123  		}
   124  
   125  		ctx, cancel := context.WithDeadline(context.Background(), time.Now())
   126  		defer cancel()
   127  
   128  		_, err := c.Vectorize(ctx, "This is my text", ent.VectorizationConfig{})
   129  
   130  		require.NotNil(t, err)
   131  		assert.Contains(t, err.Error(), "context deadline exceeded")
   132  	})
   133  
   134  	t.Run("when the server returns an error", func(t *testing.T) {
   135  		server := httptest.NewServer(&fakeHandler{
   136  			t:           t,
   137  			serverError: errors.Errorf("nope, not gonna happen"),
   138  		})
   139  		defer server.Close()
   140  		c := New("apiKey", "", "", 0, nullLogger())
   141  		c.buildUrlFn = func(baseURL, resourceName, deploymentID string, isAzure bool) (string, error) {
   142  			return server.URL, nil
   143  		}
   144  
   145  		_, err := c.Vectorize(context.Background(), "This is my text",
   146  			ent.VectorizationConfig{})
   147  
   148  		require.NotNil(t, err)
   149  		assert.EqualError(t, err, "connection to: OpenAI API failed with status: 500 error: nope, not gonna happen")
   150  	})
   151  
   152  	t.Run("when OpenAI key is passed using X-Openai-Api-Key header", func(t *testing.T) {
   153  		server := httptest.NewServer(&fakeHandler{t: t})
   154  		defer server.Close()
   155  		c := New("", "", "", 0, nullLogger())
   156  		c.buildUrlFn = func(baseURL, resourceName, deploymentID string, isAzure bool) (string, error) {
   157  			return server.URL, nil
   158  		}
   159  
   160  		ctxWithValue := context.WithValue(context.Background(),
   161  			"X-Openai-Api-Key", []string{"some-key"})
   162  
   163  		expected := &ent.VectorizationResult{
   164  			Text:       []string{"This is my text"},
   165  			Vector:     [][]float32{{0.1, 0.2, 0.3}},
   166  			Dimensions: 3,
   167  		}
   168  		res, err := c.Vectorize(ctxWithValue, "This is my text",
   169  			ent.VectorizationConfig{
   170  				Type:  "text",
   171  				Model: "ada",
   172  			})
   173  
   174  		require.Nil(t, err)
   175  		assert.Equal(t, expected, res)
   176  	})
   177  
   178  	t.Run("when OpenAI key is empty", func(t *testing.T) {
   179  		server := httptest.NewServer(&fakeHandler{t: t})
   180  		defer server.Close()
   181  		c := New("", "", "", 0, nullLogger())
   182  		c.buildUrlFn = func(baseURL, resourceName, deploymentID string, isAzure bool) (string, error) {
   183  			return server.URL, nil
   184  		}
   185  
   186  		ctx, cancel := context.WithDeadline(context.Background(), time.Now())
   187  		defer cancel()
   188  
   189  		_, err := c.Vectorize(ctx, "This is my text", ent.VectorizationConfig{})
   190  
   191  		require.NotNil(t, err)
   192  		assert.EqualError(t, err, "API Key: no api key found "+
   193  			"neither in request header: X-Openai-Api-Key "+
   194  			"nor in environment variable under OPENAI_APIKEY")
   195  	})
   196  
   197  	t.Run("when X-Openai-Api-Key header is passed but empty", func(t *testing.T) {
   198  		server := httptest.NewServer(&fakeHandler{t: t})
   199  		defer server.Close()
   200  		c := New("", "", "", 0, nullLogger())
   201  		c.buildUrlFn = func(baseURL, resourceName, deploymentID string, isAzure bool) (string, error) {
   202  			return server.URL, nil
   203  		}
   204  
   205  		ctxWithValue := context.WithValue(context.Background(),
   206  			"X-Openai-Api-Key", []string{""})
   207  
   208  		_, err := c.Vectorize(ctxWithValue, "This is my text",
   209  			ent.VectorizationConfig{
   210  				Type:  "text",
   211  				Model: "ada",
   212  			})
   213  
   214  		require.NotNil(t, err)
   215  		assert.EqualError(t, err, "API Key: no api key found "+
   216  			"neither in request header: X-Openai-Api-Key "+
   217  			"nor in environment variable under OPENAI_APIKEY")
   218  	})
   219  
   220  	t.Run("when X-OpenAI-BaseURL header is passed", func(t *testing.T) {
   221  		server := httptest.NewServer(&fakeHandler{t: t})
   222  		defer server.Close()
   223  		c := New("", "", "", 0, nullLogger())
   224  
   225  		config := ent.VectorizationConfig{
   226  			Type:    "text",
   227  			Model:   "ada",
   228  			BaseURL: "http://default-url.com",
   229  		}
   230  
   231  		ctxWithValue := context.WithValue(context.Background(),
   232  			"X-Openai-Baseurl", []string{"http://base-url-passed-in-header.com"})
   233  
   234  		buildURL, err := c.buildURL(ctxWithValue, config)
   235  		require.NoError(t, err)
   236  		assert.Equal(t, "http://base-url-passed-in-header.com/v1/embeddings", buildURL)
   237  
   238  		buildURL, err = c.buildURL(context.TODO(), config)
   239  		require.NoError(t, err)
   240  		assert.Equal(t, "http://default-url.com/v1/embeddings", buildURL)
   241  	})
   242  }
   243  
   244  type fakeHandler struct {
   245  	t           *testing.T
   246  	serverError error
   247  }
   248  
   249  func (f *fakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   250  	assert.Equal(f.t, http.MethodPost, r.Method)
   251  
   252  	if f.serverError != nil {
   253  		embeddingError := map[string]interface{}{
   254  			"message": f.serverError.Error(),
   255  			"type":    "invalid_request_error",
   256  		}
   257  		embedding := map[string]interface{}{
   258  			"error": embeddingError,
   259  		}
   260  		outBytes, err := json.Marshal(embedding)
   261  		require.Nil(f.t, err)
   262  
   263  		w.WriteHeader(http.StatusInternalServerError)
   264  		w.Write(outBytes)
   265  		return
   266  	}
   267  
   268  	bodyBytes, err := io.ReadAll(r.Body)
   269  	require.Nil(f.t, err)
   270  	defer r.Body.Close()
   271  
   272  	var b map[string]interface{}
   273  	require.Nil(f.t, json.Unmarshal(bodyBytes, &b))
   274  
   275  	textInputArray := b["input"].([]interface{})
   276  	textInput := textInputArray[0].(string)
   277  	assert.Greater(f.t, len(textInput), 0)
   278  
   279  	embeddingData := map[string]interface{}{
   280  		"object":    textInput,
   281  		"index":     0,
   282  		"embedding": []float32{0.1, 0.2, 0.3},
   283  	}
   284  	embedding := map[string]interface{}{
   285  		"object": "list",
   286  		"data":   []interface{}{embeddingData},
   287  	}
   288  
   289  	outBytes, err := json.Marshal(embedding)
   290  	require.Nil(f.t, err)
   291  
   292  	w.Write(outBytes)
   293  }
   294  
   295  func nullLogger() logrus.FieldLogger {
   296  	l, _ := test.NewNullLogger()
   297  	return l
   298  }
   299  
   300  func Test_getModelString(t *testing.T) {
   301  	t.Run("getModelStringDocument", func(t *testing.T) {
   302  		type args struct {
   303  			docType string
   304  			model   string
   305  			version string
   306  		}
   307  		tests := []struct {
   308  			name string
   309  			args args
   310  			want string
   311  		}{
   312  			{
   313  				name: "Document type: text model: ada vectorizationType: document",
   314  				args: args{
   315  					docType: "text",
   316  					model:   "ada",
   317  				},
   318  				want: "text-search-ada-doc-001",
   319  			},
   320  			{
   321  				name: "Document type: text model: ada-002 vectorizationType: document",
   322  				args: args{
   323  					docType: "text",
   324  					model:   "ada",
   325  					version: "002",
   326  				},
   327  				want: "text-embedding-ada-002",
   328  			},
   329  			{
   330  				name: "Document type: text model: babbage vectorizationType: document",
   331  				args: args{
   332  					docType: "text",
   333  					model:   "babbage",
   334  				},
   335  				want: "text-search-babbage-doc-001",
   336  			},
   337  			{
   338  				name: "Document type: text model: curie vectorizationType: document",
   339  				args: args{
   340  					docType: "text",
   341  					model:   "curie",
   342  				},
   343  				want: "text-search-curie-doc-001",
   344  			},
   345  			{
   346  				name: "Document type: text model: davinci vectorizationType: document",
   347  				args: args{
   348  					docType: "text",
   349  					model:   "davinci",
   350  				},
   351  				want: "text-search-davinci-doc-001",
   352  			},
   353  			{
   354  				name: "Document type: code model: ada vectorizationType: code",
   355  				args: args{
   356  					docType: "code",
   357  					model:   "ada",
   358  				},
   359  				want: "code-search-ada-code-001",
   360  			},
   361  			{
   362  				name: "Document type: code model: babbage vectorizationType: code",
   363  				args: args{
   364  					docType: "code",
   365  					model:   "babbage",
   366  				},
   367  				want: "code-search-babbage-code-001",
   368  			},
   369  		}
   370  		for _, tt := range tests {
   371  			t.Run(tt.name, func(t *testing.T) {
   372  				v := New("apiKey", "", "", 0, nullLogger())
   373  				if got := v.getModelString(tt.args.docType, tt.args.model, "document", tt.args.version); got != tt.want {
   374  					t.Errorf("vectorizer.getModelString() = %v, want %v", got, tt.want)
   375  				}
   376  			})
   377  		}
   378  	})
   379  
   380  	t.Run("getModelStringQuery", func(t *testing.T) {
   381  		type args struct {
   382  			docType string
   383  			model   string
   384  			version string
   385  		}
   386  		tests := []struct {
   387  			name string
   388  			args args
   389  			want string
   390  		}{
   391  			{
   392  				name: "Document type: text model: ada vectorizationType: query",
   393  				args: args{
   394  					docType: "text",
   395  					model:   "ada",
   396  				},
   397  				want: "text-search-ada-query-001",
   398  			},
   399  			{
   400  				name: "Document type: text model: babbage vectorizationType: query",
   401  				args: args{
   402  					docType: "text",
   403  					model:   "babbage",
   404  				},
   405  				want: "text-search-babbage-query-001",
   406  			},
   407  			{
   408  				name: "Document type: text model: curie vectorizationType: query",
   409  				args: args{
   410  					docType: "text",
   411  					model:   "curie",
   412  				},
   413  				want: "text-search-curie-query-001",
   414  			},
   415  			{
   416  				name: "Document type: text model: davinci vectorizationType: query",
   417  				args: args{
   418  					docType: "text",
   419  					model:   "davinci",
   420  				},
   421  				want: "text-search-davinci-query-001",
   422  			},
   423  			{
   424  				name: "Document type: code model: ada vectorizationType: text",
   425  				args: args{
   426  					docType: "code",
   427  					model:   "ada",
   428  				},
   429  				want: "code-search-ada-text-001",
   430  			},
   431  			{
   432  				name: "Document type: code model: babbage vectorizationType: text",
   433  				args: args{
   434  					docType: "code",
   435  					model:   "babbage",
   436  				},
   437  				want: "code-search-babbage-text-001",
   438  			},
   439  		}
   440  		for _, tt := range tests {
   441  			t.Run(tt.name, func(t *testing.T) {
   442  				v := New("apiKey", "", "", 0, nullLogger())
   443  				if got := v.getModelString(tt.args.docType, tt.args.model, "query", tt.args.version); got != tt.want {
   444  					t.Errorf("vectorizer.getModelString() = %v, want %v", got, tt.want)
   445  				}
   446  			})
   447  		}
   448  	})
   449  }
   450  
   451  func TestOpenAIApiErrorDecode(t *testing.T) {
   452  	t.Run("getModelStringQuery", func(t *testing.T) {
   453  		type args struct {
   454  			response []byte
   455  		}
   456  		tests := []struct {
   457  			name string
   458  			args args
   459  			want string
   460  		}{
   461  			{
   462  				name: "Error code: missing property",
   463  				args: args{
   464  					response: []byte(`{"message": "failed", "type": "error", "param": "arg..."}`),
   465  				},
   466  				want: "",
   467  			},
   468  			{
   469  				name: "Error code: as int",
   470  				args: args{
   471  					response: []byte(`{"message": "failed", "type": "error", "param": "arg...", "code": 500}`),
   472  				},
   473  				want: "500",
   474  			},
   475  			{
   476  				name: "Error code as string number",
   477  				args: args{
   478  					response: []byte(`{"message": "failed", "type": "error", "param": "arg...", "code": "500"}`),
   479  				},
   480  				want: "500",
   481  			},
   482  			{
   483  				name: "Error code as string text",
   484  				args: args{
   485  					response: []byte(`{"message": "failed", "type": "error", "param": "arg...", "code": "invalid_api_key"}`),
   486  				},
   487  				want: "invalid_api_key",
   488  			},
   489  		}
   490  		for _, tt := range tests {
   491  			t.Run(tt.name, func(t *testing.T) {
   492  				var got *openAIApiError
   493  				err := json.Unmarshal(tt.args.response, &got)
   494  				require.NoError(t, err)
   495  
   496  				if got.Code.String() != tt.want {
   497  					t.Errorf("OpenAIerror.code = %v, want %v", got.Code, tt.want)
   498  				}
   499  			})
   500  		}
   501  	})
   502  }