github.com/weaviate/weaviate@v1.24.6/modules/qna-openai/clients/qna_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package clients
    13  
    14  import (
    15  	"context"
    16  	"encoding/json"
    17  	"io"
    18  	"net/http"
    19  	"net/http/httptest"
    20  	"testing"
    21  
    22  	"github.com/sirupsen/logrus"
    23  	"github.com/sirupsen/logrus/hooks/test"
    24  	"github.com/stretchr/testify/assert"
    25  	"github.com/stretchr/testify/require"
    26  	"github.com/weaviate/weaviate/modules/qna-openai/ent"
    27  )
    28  
    29  func nullLogger() logrus.FieldLogger {
    30  	l, _ := test.NewNullLogger()
    31  	return l
    32  }
    33  
    34  func TestGetAnswer(t *testing.T) {
    35  	t.Run("when the server has a successful answer ", func(t *testing.T) {
    36  		handler := &testAnswerHandler{
    37  			t: t,
    38  			answer: answersResponse{
    39  				Choices: []choice{{
    40  					FinishReason: "test",
    41  					Index:        0,
    42  					Logprobs:     "",
    43  					Text:         "John",
    44  				}},
    45  				Error: nil,
    46  			},
    47  		}
    48  		server := httptest.NewServer(handler)
    49  		defer server.Close()
    50  
    51  		c := New("openAIApiKey", "", "", 0, nullLogger())
    52  		c.buildUrlFn = func(baseURL, resourceName, deploymentID string) (string, error) {
    53  			return buildUrl(server.URL, resourceName, deploymentID)
    54  		}
    55  
    56  		expected := ent.AnswerResult{
    57  			Text:     "My name is John",
    58  			Question: "What is my name?",
    59  			Answer:   ptString("John"),
    60  		}
    61  
    62  		res, err := c.Answer(context.Background(), "My name is John", "What is my name?", nil)
    63  
    64  		assert.Nil(t, err)
    65  		assert.Equal(t, expected, *res)
    66  	})
    67  
    68  	t.Run("when the server has a an error", func(t *testing.T) {
    69  		server := httptest.NewServer(&testAnswerHandler{
    70  			t: t,
    71  			answer: answersResponse{
    72  				Error: &openAIApiError{
    73  					Message: "some error from the server",
    74  				},
    75  			},
    76  		})
    77  		defer server.Close()
    78  
    79  		c := New("openAIApiKey", "", "", 0, nullLogger())
    80  		c.buildUrlFn = func(baseURL, resourceName, deploymentID string) (string, error) {
    81  			return buildUrl(server.URL, resourceName, deploymentID)
    82  		}
    83  
    84  		_, err := c.Answer(context.Background(), "My name is John", "What is my name?", nil)
    85  
    86  		require.NotNil(t, err)
    87  		assert.Error(t, err, "connection to OpenAI failed with status: 500 error: some error from the server")
    88  	})
    89  
    90  	t.Run("when X-OpenAI-BaseURL header is passed", func(t *testing.T) {
    91  		c := New("openAIApiKey", "", "", 0, nullLogger())
    92  
    93  		ctxWithValue := context.WithValue(context.Background(),
    94  			"X-Openai-Baseurl", []string{"http://base-url-passed-in-header.com"})
    95  
    96  		buildURL, err := c.buildOpenAIUrl(ctxWithValue, "http://default-url.com", "", "")
    97  		require.NoError(t, err)
    98  		assert.Equal(t, "http://base-url-passed-in-header.com/v1/completions", buildURL)
    99  
   100  		buildURL, err = c.buildOpenAIUrl(context.TODO(), "http://default-url.com", "", "")
   101  		require.NoError(t, err)
   102  		assert.Equal(t, "http://default-url.com/v1/completions", buildURL)
   103  	})
   104  }
   105  
   106  type testAnswerHandler struct {
   107  	t *testing.T
   108  	// the test handler will report as not ready before the time has passed
   109  	answer answersResponse
   110  }
   111  
   112  func (f *testAnswerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   113  	assert.Equal(f.t, "/v1/completions", r.URL.String())
   114  	assert.Equal(f.t, http.MethodPost, r.Method)
   115  
   116  	if f.answer.Error != nil && f.answer.Error.Message != "" {
   117  		outBytes, err := json.Marshal(f.answer)
   118  		require.Nil(f.t, err)
   119  
   120  		w.WriteHeader(http.StatusInternalServerError)
   121  		w.Write(outBytes)
   122  		return
   123  	}
   124  
   125  	bodyBytes, err := io.ReadAll(r.Body)
   126  	require.Nil(f.t, err)
   127  	defer r.Body.Close()
   128  
   129  	var b map[string]interface{}
   130  	require.Nil(f.t, json.Unmarshal(bodyBytes, &b))
   131  
   132  	outBytes, err := json.Marshal(f.answer)
   133  	require.Nil(f.t, err)
   134  
   135  	w.Write(outBytes)
   136  }
   137  
   138  func TestOpenAIApiErrorDecode(t *testing.T) {
   139  	t.Run("getModelStringQuery", func(t *testing.T) {
   140  		type args struct {
   141  			response []byte
   142  		}
   143  		tests := []struct {
   144  			name string
   145  			args args
   146  			want string
   147  		}{
   148  			{
   149  				name: "Error code: missing property",
   150  				args: args{
   151  					response: []byte(`{"message": "failed", "type": "error", "param": "arg..."}`),
   152  				},
   153  				want: "",
   154  			},
   155  			{
   156  				name: "Error code: as int",
   157  				args: args{
   158  					response: []byte(`{"message": "failed", "type": "error", "param": "arg...", "code": 500}`),
   159  				},
   160  				want: "500",
   161  			},
   162  			{
   163  				name: "Error code as string number",
   164  				args: args{
   165  					response: []byte(`{"message": "failed", "type": "error", "param": "arg...", "code": "500"}`),
   166  				},
   167  				want: "500",
   168  			},
   169  			{
   170  				name: "Error code as string text",
   171  				args: args{
   172  					response: []byte(`{"message": "failed", "type": "error", "param": "arg...", "code": "invalid_api_key"}`),
   173  				},
   174  				want: "invalid_api_key",
   175  			},
   176  		}
   177  		for _, tt := range tests {
   178  			t.Run(tt.name, func(t *testing.T) {
   179  				var got *openAIApiError
   180  				err := json.Unmarshal(tt.args.response, &got)
   181  				require.NoError(t, err)
   182  
   183  				if got.Code.String() != tt.want {
   184  					t.Errorf("OpenAIerror.code = %v, want %v", got.Code, tt.want)
   185  				}
   186  			})
   187  		}
   188  	})
   189  }
   190  
   191  func ptString(in string) *string {
   192  	return &in
   193  }