github.com/weaviate/weaviate@v1.24.6/modules/generative-openai/clients/openai_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package clients
    13  
    14  import (
    15  	"context"
    16  	"encoding/json"
    17  	"io"
    18  	"net/http"
    19  	"net/http/httptest"
    20  	"os"
    21  	"strings"
    22  	"testing"
    23  
    24  	"github.com/sirupsen/logrus"
    25  	"github.com/sirupsen/logrus/hooks/test"
    26  	"github.com/stretchr/testify/assert"
    27  	"github.com/stretchr/testify/require"
    28  	"github.com/weaviate/weaviate/entities/models"
    29  	"github.com/weaviate/weaviate/modules/generative-openai/config"
    30  	generativemodels "github.com/weaviate/weaviate/usecases/modulecomponents/additional/models"
    31  )
    32  
    33  func nullLogger() logrus.FieldLogger {
    34  	l, _ := test.NewNullLogger()
    35  	return l
    36  }
    37  
    38  func fakeBuildUrl(serverURL string, isLegacy bool, resourceName, deploymentID, baseURL, apiVersion string) (string, error) {
    39  	endpoint, err := buildUrlFn(isLegacy, resourceName, deploymentID, baseURL, apiVersion)
    40  	if err != nil {
    41  		return "", err
    42  	}
    43  	endpoint = strings.Replace(endpoint, "https://api.openai.com", serverURL, 1)
    44  	return endpoint, nil
    45  }
    46  
    47  func TestBuildUrlFn(t *testing.T) {
    48  	t.Run("buildUrlFn returns default OpenAI Client", func(t *testing.T) {
    49  		url, err := buildUrlFn(false, "", "", config.DefaultOpenAIBaseURL, config.DefaultApiVersion)
    50  		assert.Nil(t, err)
    51  		assert.Equal(t, "https://api.openai.com/v1/chat/completions", url)
    52  	})
    53  	t.Run("buildUrlFn returns Azure Client", func(t *testing.T) {
    54  		url, err := buildUrlFn(false, "resourceID", "deploymentID", "", config.DefaultApiVersion)
    55  		assert.Nil(t, err)
    56  		assert.Equal(t, "https://resourceID.openai.azure.com/openai/deployments/deploymentID/chat/completions?api-version=2023-05-15", url)
    57  	})
    58  	t.Run("buildUrlFn loads from environment variable", func(t *testing.T) {
    59  		url, err := buildUrlFn(false, "", "", "https://foobar.some.proxy", config.DefaultApiVersion)
    60  		assert.Nil(t, err)
    61  		assert.Equal(t, "https://foobar.some.proxy/v1/chat/completions", url)
    62  		os.Unsetenv("OPENAI_BASE_URL")
    63  	})
    64  	t.Run("buildUrlFn returns Azure Client with custom baseURL", func(t *testing.T) {
    65  		url, err := buildUrlFn(false, "resourceID", "deploymentID", "customBaseURL", config.DefaultApiVersion)
    66  		assert.Nil(t, err)
    67  		assert.Equal(t, "customBaseURL/openai/deployments/deploymentID/chat/completions?api-version=2023-05-15", url)
    68  	})
    69  }
    70  
    71  func TestGetAnswer(t *testing.T) {
    72  	textProperties := []map[string]string{{"prop": "My name is john"}}
    73  	t.Run("when the server has a successful answer ", func(t *testing.T) {
    74  		handler := &testAnswerHandler{
    75  			t: t,
    76  			answer: generateResponse{
    77  				Choices: []choice{{
    78  					FinishReason: "test",
    79  					Index:        0,
    80  					Logprobs:     "",
    81  					Text:         "John",
    82  				}},
    83  				Error: nil,
    84  			},
    85  		}
    86  		server := httptest.NewServer(handler)
    87  		defer server.Close()
    88  
    89  		c := New("openAIApiKey", "", "", 0, nullLogger())
    90  		c.buildUrl = func(isLegacy bool, resourceName, deploymentID, baseURL, apiVersion string) (string, error) {
    91  			return fakeBuildUrl(server.URL, isLegacy, resourceName, deploymentID, baseURL, apiVersion)
    92  		}
    93  
    94  		expected := generativemodels.GenerateResponse{
    95  			Result: ptString("John"),
    96  		}
    97  
    98  		res, err := c.GenerateAllResults(context.Background(), textProperties, "What is my name?", nil)
    99  
   100  		assert.Nil(t, err)
   101  		assert.Equal(t, expected, *res)
   102  	})
   103  
   104  	t.Run("when the server has a an error", func(t *testing.T) {
   105  		server := httptest.NewServer(&testAnswerHandler{
   106  			t: t,
   107  			answer: generateResponse{
   108  				Error: &openAIApiError{
   109  					Message: "some error from the server",
   110  				},
   111  			},
   112  		})
   113  		defer server.Close()
   114  
   115  		c := New("openAIApiKey", "", "", 0, nullLogger())
   116  		c.buildUrl = func(isLegacy bool, resourceName, deploymentID, baseURL, apiVersion string) (string, error) {
   117  			return fakeBuildUrl(server.URL, isLegacy, resourceName, deploymentID, baseURL, apiVersion)
   118  		}
   119  
   120  		_, err := c.GenerateAllResults(context.Background(), textProperties, "What is my name?", nil)
   121  
   122  		require.NotNil(t, err)
   123  		assert.Error(t, err, "connection to OpenAI failed with status: 500 error: some error from the server")
   124  	})
   125  
   126  	t.Run("when X-OpenAI-BaseURL header is passed", func(t *testing.T) {
   127  		settings := &fakeClassSettings{
   128  			baseURL: "http://default-url.com",
   129  		}
   130  		c := New("openAIApiKey", "", "", 0, nullLogger())
   131  
   132  		ctxWithValue := context.WithValue(context.Background(),
   133  			"X-Openai-Baseurl", []string{"http://base-url-passed-in-header.com"})
   134  
   135  		buildURL, err := c.buildOpenAIUrl(ctxWithValue, settings)
   136  		require.NoError(t, err)
   137  		assert.Equal(t, "http://base-url-passed-in-header.com/v1/chat/completions", buildURL)
   138  
   139  		buildURL, err = c.buildOpenAIUrl(context.TODO(), settings)
   140  		require.NoError(t, err)
   141  		assert.Equal(t, "http://default-url.com/v1/chat/completions", buildURL)
   142  	})
   143  }
   144  
   145  type testAnswerHandler struct {
   146  	t *testing.T
   147  	// the test handler will report as not ready before the time has passed
   148  	answer generateResponse
   149  }
   150  
   151  func (f *testAnswerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   152  	assert.Equal(f.t, "/v1/chat/completions", r.URL.String())
   153  	assert.Equal(f.t, http.MethodPost, r.Method)
   154  
   155  	if f.answer.Error != nil && f.answer.Error.Message != "" {
   156  		outBytes, err := json.Marshal(f.answer)
   157  		require.Nil(f.t, err)
   158  
   159  		w.WriteHeader(http.StatusInternalServerError)
   160  		w.Write(outBytes)
   161  		return
   162  	}
   163  
   164  	bodyBytes, err := io.ReadAll(r.Body)
   165  	require.Nil(f.t, err)
   166  	defer r.Body.Close()
   167  
   168  	var b map[string]interface{}
   169  	require.Nil(f.t, json.Unmarshal(bodyBytes, &b))
   170  
   171  	outBytes, err := json.Marshal(f.answer)
   172  	require.Nil(f.t, err)
   173  
   174  	w.Write(outBytes)
   175  }
   176  
   177  func TestOpenAIApiErrorDecode(t *testing.T) {
   178  	t.Run("getModelStringQuery", func(t *testing.T) {
   179  		type args struct {
   180  			response []byte
   181  		}
   182  		tests := []struct {
   183  			name string
   184  			args args
   185  			want string
   186  		}{
   187  			{
   188  				name: "Error code: missing property",
   189  				args: args{
   190  					response: []byte(`{"message": "failed", "type": "error", "param": "arg..."}`),
   191  				},
   192  				want: "",
   193  			},
   194  			{
   195  				name: "Error code: as int",
   196  				args: args{
   197  					response: []byte(`{"message": "failed", "type": "error", "param": "arg...", "code": 500}`),
   198  				},
   199  				want: "500",
   200  			},
   201  			{
   202  				name: "Error code as string number",
   203  				args: args{
   204  					response: []byte(`{"message": "failed", "type": "error", "param": "arg...", "code": "500"}`),
   205  				},
   206  				want: "500",
   207  			},
   208  			{
   209  				name: "Error code as string text",
   210  				args: args{
   211  					response: []byte(`{"message": "failed", "type": "error", "param": "arg...", "code": "invalid_api_key"}`),
   212  				},
   213  				want: "invalid_api_key",
   214  			},
   215  		}
   216  		for _, tt := range tests {
   217  			t.Run(tt.name, func(t *testing.T) {
   218  				var got *openAIApiError
   219  				err := json.Unmarshal(tt.args.response, &got)
   220  				require.NoError(t, err)
   221  
   222  				if got.Code.String() != tt.want {
   223  					t.Errorf("OpenAIerror.code = %v, want %v", got.Code, tt.want)
   224  				}
   225  			})
   226  		}
   227  	})
   228  }
   229  
   230  func ptString(in string) *string {
   231  	return &in
   232  }
   233  
   234  type fakeClassSettings struct {
   235  	isLegacy         bool
   236  	model            string
   237  	maxTokens        float64
   238  	temperature      float64
   239  	frequencyPenalty float64
   240  	presencePenalty  float64
   241  	topP             float64
   242  	resourceName     string
   243  	deploymentID     string
   244  	isAzure          bool
   245  	baseURL          string
   246  	apiVersion       string
   247  }
   248  
   249  func (s *fakeClassSettings) IsLegacy() bool {
   250  	return s.isLegacy
   251  }
   252  
   253  func (s *fakeClassSettings) Model() string {
   254  	return s.model
   255  }
   256  
   257  func (s *fakeClassSettings) MaxTokens() float64 {
   258  	return s.maxTokens
   259  }
   260  
   261  func (s *fakeClassSettings) Temperature() float64 {
   262  	return s.temperature
   263  }
   264  
   265  func (s *fakeClassSettings) FrequencyPenalty() float64 {
   266  	return s.frequencyPenalty
   267  }
   268  
   269  func (s *fakeClassSettings) PresencePenalty() float64 {
   270  	return s.presencePenalty
   271  }
   272  
   273  func (s *fakeClassSettings) TopP() float64 {
   274  	return s.topP
   275  }
   276  
   277  func (s *fakeClassSettings) ResourceName() string {
   278  	return s.resourceName
   279  }
   280  
   281  func (s *fakeClassSettings) DeploymentID() string {
   282  	return s.deploymentID
   283  }
   284  
   285  func (s *fakeClassSettings) IsAzure() bool {
   286  	return s.isAzure
   287  }
   288  
   289  func (s *fakeClassSettings) GetMaxTokensForModel(model string) float64 {
   290  	return 0
   291  }
   292  
   293  func (s *fakeClassSettings) Validate(class *models.Class) error {
   294  	return nil
   295  }
   296  
   297  func (s *fakeClassSettings) BaseURL() string {
   298  	return s.baseURL
   299  }
   300  
   301  func (s *fakeClassSettings) ApiVersion() string {
   302  	return s.apiVersion
   303  }