github.com/weaviate/weaviate@v1.24.6/modules/generative-mistral/clients/mistral_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "context" 16 "encoding/json" 17 "io" 18 "net/http" 19 "net/http/httptest" 20 "testing" 21 "time" 22 23 "github.com/sirupsen/logrus" 24 "github.com/sirupsen/logrus/hooks/test" 25 "github.com/stretchr/testify/assert" 26 "github.com/stretchr/testify/require" 27 ) 28 29 func nullLogger() logrus.FieldLogger { 30 l, _ := test.NewNullLogger() 31 return l 32 } 33 34 func TestGetAnswer(t *testing.T) { 35 textProperties := []map[string]string{{"prop": "My name is john"}} 36 37 tests := []struct { 38 name string 39 answer generateResponse 40 timeout time.Duration 41 expectedResult string 42 }{ 43 { 44 name: "when the server has a successful answer", 45 answer: generateResponse{ 46 Choices: []Choice{ 47 { 48 Message: Message{ 49 Content: "John", 50 }, 51 }, 52 }, 53 Error: nil, 54 }, 55 expectedResult: "John", 56 }, 57 { 58 name: "when the server has an error", 59 answer: generateResponse{ 60 Error: &mistralApiError{ 61 Message: "some error from the server", 62 }, 63 }, 64 }, 65 { 66 name: "when the server does not respond in time", 67 answer: generateResponse{Error: &mistralApiError{Message: "context deadline exceeded"}}, 68 timeout: time.Second, 69 }, 70 } 71 for _, test := range tests { 72 t.Run(test.name, func(t *testing.T) { 73 handler := &testAnswerHandler{ 74 t: t, 75 answer: test.answer, 76 timeout: test.timeout, 77 } 78 server := httptest.NewServer(handler) 79 defer server.Close() 80 81 c := New("apiKey", test.timeout, nullLogger()) 82 83 settings := &fakeClassConfig{baseURL: server.URL} 84 res, err := c.GenerateAllResults(context.Background(), textProperties, "What is my name?", settings) 85 86 if test.answer.Error != nil { 87 assert.Contains(t, err.Error(), test.answer.Error.Message) 88 } else { 89 assert.Equal(t, test.expectedResult, *res.Result) 90 } 91 }) 92 } 93 t.Run("when X-Mistral-BaseURL header is passed", func(t *testing.T) { 94 c := New("apiKey", 5*time.Second, nullLogger()) 95 96 baseURL := "http://default-url.com" 97 ctxWithValue := context.WithValue(context.Background(), 98 "X-Mistral-Baseurl", []string{"http://base-url-passed-in-header.com"}) 99 100 buildURL, err := c.getMistralUrl(ctxWithValue, baseURL) 101 require.NoError(t, err) 102 assert.Equal(t, "http://base-url-passed-in-header.com/v1/chat/completions", buildURL) 103 104 buildURL, err = c.getMistralUrl(context.TODO(), baseURL) 105 require.NoError(t, err) 106 assert.Equal(t, "http://default-url.com/v1/chat/completions", buildURL) 107 }) 108 } 109 110 type testAnswerHandler struct { 111 t *testing.T 112 // the test handler will report as not ready before the time has passed 113 answer generateResponse 114 timeout time.Duration 115 } 116 117 func (f *testAnswerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 118 assert.Equal(f.t, "/v1/chat/completions", r.URL.String()) 119 assert.Equal(f.t, http.MethodPost, r.Method) 120 121 time.Sleep(f.timeout) 122 123 if f.answer.Error != nil && f.answer.Error.Message != "" { 124 outBytes, err := json.Marshal(f.answer) 125 require.Nil(f.t, err) 126 127 w.WriteHeader(http.StatusInternalServerError) 128 w.Write(outBytes) 129 return 130 } 131 132 bodyBytes, err := io.ReadAll(r.Body) 133 require.Nil(f.t, err) 134 defer r.Body.Close() 135 136 var b map[string]interface{} 137 require.Nil(f.t, json.Unmarshal(bodyBytes, &b)) 138 139 outBytes, err := json.Marshal(f.answer) 140 require.Nil(f.t, err) 141 142 w.Write(outBytes) 143 } 144 145 type fakeClassConfig struct { 146 baseURL string 147 } 148 149 func (cfg *fakeClassConfig) Tenant() string { 150 return "" 151 } 152 153 func (cfg *fakeClassConfig) Class() map[string]interface{} { 154 return nil 155 } 156 157 func (cfg *fakeClassConfig) ClassByModuleName(moduleName string) map[string]interface{} { 158 settings := map[string]interface{}{ 159 "baseURL": cfg.baseURL, 160 } 161 return settings 162 } 163 164 func (cfg *fakeClassConfig) Property(propName string) map[string]interface{} { 165 return nil 166 } 167 168 func (f fakeClassConfig) TargetVector() string { 169 return "" 170 }