github.com/weaviate/weaviate@v1.24.6/modules/generative-mistral/clients/mistral_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package clients
    13  
    14  import (
    15  	"context"
    16  	"encoding/json"
    17  	"io"
    18  	"net/http"
    19  	"net/http/httptest"
    20  	"testing"
    21  	"time"
    22  
    23  	"github.com/sirupsen/logrus"
    24  	"github.com/sirupsen/logrus/hooks/test"
    25  	"github.com/stretchr/testify/assert"
    26  	"github.com/stretchr/testify/require"
    27  )
    28  
    29  func nullLogger() logrus.FieldLogger {
    30  	l, _ := test.NewNullLogger()
    31  	return l
    32  }
    33  
    34  func TestGetAnswer(t *testing.T) {
    35  	textProperties := []map[string]string{{"prop": "My name is john"}}
    36  
    37  	tests := []struct {
    38  		name           string
    39  		answer         generateResponse
    40  		timeout        time.Duration
    41  		expectedResult string
    42  	}{
    43  		{
    44  			name: "when the server has a successful answer",
    45  			answer: generateResponse{
    46  				Choices: []Choice{
    47  					{
    48  						Message: Message{
    49  							Content: "John",
    50  						},
    51  					},
    52  				},
    53  				Error: nil,
    54  			},
    55  			expectedResult: "John",
    56  		},
    57  		{
    58  			name: "when the server has an error",
    59  			answer: generateResponse{
    60  				Error: &mistralApiError{
    61  					Message: "some error from the server",
    62  				},
    63  			},
    64  		},
    65  		{
    66  			name:    "when the server does not respond in time",
    67  			answer:  generateResponse{Error: &mistralApiError{Message: "context deadline exceeded"}},
    68  			timeout: time.Second,
    69  		},
    70  	}
    71  	for _, test := range tests {
    72  		t.Run(test.name, func(t *testing.T) {
    73  			handler := &testAnswerHandler{
    74  				t:       t,
    75  				answer:  test.answer,
    76  				timeout: test.timeout,
    77  			}
    78  			server := httptest.NewServer(handler)
    79  			defer server.Close()
    80  
    81  			c := New("apiKey", test.timeout, nullLogger())
    82  
    83  			settings := &fakeClassConfig{baseURL: server.URL}
    84  			res, err := c.GenerateAllResults(context.Background(), textProperties, "What is my name?", settings)
    85  
    86  			if test.answer.Error != nil {
    87  				assert.Contains(t, err.Error(), test.answer.Error.Message)
    88  			} else {
    89  				assert.Equal(t, test.expectedResult, *res.Result)
    90  			}
    91  		})
    92  	}
    93  	t.Run("when X-Mistral-BaseURL header is passed", func(t *testing.T) {
    94  		c := New("apiKey", 5*time.Second, nullLogger())
    95  
    96  		baseURL := "http://default-url.com"
    97  		ctxWithValue := context.WithValue(context.Background(),
    98  			"X-Mistral-Baseurl", []string{"http://base-url-passed-in-header.com"})
    99  
   100  		buildURL, err := c.getMistralUrl(ctxWithValue, baseURL)
   101  		require.NoError(t, err)
   102  		assert.Equal(t, "http://base-url-passed-in-header.com/v1/chat/completions", buildURL)
   103  
   104  		buildURL, err = c.getMistralUrl(context.TODO(), baseURL)
   105  		require.NoError(t, err)
   106  		assert.Equal(t, "http://default-url.com/v1/chat/completions", buildURL)
   107  	})
   108  }
   109  
   110  type testAnswerHandler struct {
   111  	t *testing.T
   112  	// the test handler will report as not ready before the time has passed
   113  	answer  generateResponse
   114  	timeout time.Duration
   115  }
   116  
   117  func (f *testAnswerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   118  	assert.Equal(f.t, "/v1/chat/completions", r.URL.String())
   119  	assert.Equal(f.t, http.MethodPost, r.Method)
   120  
   121  	time.Sleep(f.timeout)
   122  
   123  	if f.answer.Error != nil && f.answer.Error.Message != "" {
   124  		outBytes, err := json.Marshal(f.answer)
   125  		require.Nil(f.t, err)
   126  
   127  		w.WriteHeader(http.StatusInternalServerError)
   128  		w.Write(outBytes)
   129  		return
   130  	}
   131  
   132  	bodyBytes, err := io.ReadAll(r.Body)
   133  	require.Nil(f.t, err)
   134  	defer r.Body.Close()
   135  
   136  	var b map[string]interface{}
   137  	require.Nil(f.t, json.Unmarshal(bodyBytes, &b))
   138  
   139  	outBytes, err := json.Marshal(f.answer)
   140  	require.Nil(f.t, err)
   141  
   142  	w.Write(outBytes)
   143  }
   144  
   145  type fakeClassConfig struct {
   146  	baseURL string
   147  }
   148  
   149  func (cfg *fakeClassConfig) Tenant() string {
   150  	return ""
   151  }
   152  
   153  func (cfg *fakeClassConfig) Class() map[string]interface{} {
   154  	return nil
   155  }
   156  
   157  func (cfg *fakeClassConfig) ClassByModuleName(moduleName string) map[string]interface{} {
   158  	settings := map[string]interface{}{
   159  		"baseURL": cfg.baseURL,
   160  	}
   161  	return settings
   162  }
   163  
   164  func (cfg *fakeClassConfig) Property(propName string) map[string]interface{} {
   165  	return nil
   166  }
   167  
   168  func (f fakeClassConfig) TargetVector() string {
   169  	return ""
   170  }