github.com/weaviate/weaviate@v1.24.6/modules/text2vec-cohere/clients/cohere_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "context" 16 "encoding/json" 17 "io" 18 "net/http" 19 "net/http/httptest" 20 "testing" 21 "time" 22 23 "github.com/pkg/errors" 24 "github.com/sirupsen/logrus" 25 "github.com/sirupsen/logrus/hooks/test" 26 "github.com/stretchr/testify/assert" 27 "github.com/stretchr/testify/require" 28 "github.com/weaviate/weaviate/modules/text2vec-cohere/ent" 29 ) 30 31 func TestClient(t *testing.T) { 32 t.Run("when all is fine", func(t *testing.T) { 33 server := httptest.NewServer(&fakeHandler{t: t}) 34 defer server.Close() 35 c := &vectorizer{ 36 apiKey: "apiKey", 37 httpClient: &http.Client{}, 38 urlBuilder: &cohereUrlBuilder{ 39 origin: server.URL, 40 pathMask: "/v1/embed", 41 }, 42 logger: nullLogger(), 43 } 44 expected := &ent.VectorizationResult{ 45 Text: []string{"This is my text"}, 46 Vectors: [][]float32{{0.1, 0.2, 0.3}}, 47 Dimensions: 3, 48 } 49 res, err := c.Vectorize(context.Background(), []string{"This is my text"}, 50 ent.VectorizationConfig{ 51 Model: "large", 52 }) 53 54 assert.Nil(t, err) 55 assert.Equal(t, expected, res) 56 }) 57 58 t.Run("when the context is expired", func(t *testing.T) { 59 server := httptest.NewServer(&fakeHandler{t: t}) 60 defer server.Close() 61 c := &vectorizer{ 62 apiKey: "apiKey", 63 httpClient: &http.Client{}, 64 urlBuilder: &cohereUrlBuilder{ 65 origin: server.URL, 66 pathMask: "/v1/embed", 67 }, 68 logger: nullLogger(), 69 } 70 ctx, cancel := context.WithDeadline(context.Background(), time.Now()) 71 defer cancel() 72 73 _, err := c.Vectorize(ctx, []string{"This is my text"}, ent.VectorizationConfig{}) 74 75 require.NotNil(t, err) 76 assert.Contains(t, err.Error(), "context deadline exceeded") 77 }) 78 79 t.Run("when the server returns an error", func(t *testing.T) { 80 server := httptest.NewServer(&fakeHandler{ 81 t: t, 82 serverError: errors.Errorf("nope, not gonna happen"), 83 }) 84 defer server.Close() 85 c := &vectorizer{ 86 apiKey: "apiKey", 87 httpClient: &http.Client{}, 88 urlBuilder: &cohereUrlBuilder{ 89 origin: server.URL, 90 pathMask: "/v1/embed", 91 }, 92 logger: nullLogger(), 93 } 94 _, err := c.Vectorize(context.Background(), []string{"This is my text"}, 95 ent.VectorizationConfig{}) 96 97 require.NotNil(t, err) 98 assert.Equal(t, err.Error(), "connection to Cohere failed with status: 500 error: nope, not gonna happen") 99 }) 100 101 t.Run("when Cohere key is passed using X-Cohere-Api-Key header", func(t *testing.T) { 102 server := httptest.NewServer(&fakeHandler{t: t}) 103 defer server.Close() 104 c := &vectorizer{ 105 apiKey: "", 106 httpClient: &http.Client{}, 107 urlBuilder: &cohereUrlBuilder{ 108 origin: server.URL, 109 pathMask: "/v1/embed", 110 }, 111 logger: nullLogger(), 112 } 113 ctxWithValue := context.WithValue(context.Background(), 114 "X-Cohere-Api-Key", []string{"some-key"}) 115 116 expected := &ent.VectorizationResult{ 117 Text: []string{"This is my text"}, 118 Vectors: [][]float32{{0.1, 0.2, 0.3}}, 119 Dimensions: 3, 120 } 121 res, err := c.Vectorize(ctxWithValue, []string{"This is my text"}, 122 ent.VectorizationConfig{ 123 Model: "large", 124 }) 125 126 require.Nil(t, err) 127 assert.Equal(t, expected, res) 128 }) 129 130 t.Run("when Cohere key is empty", func(t *testing.T) { 131 server := httptest.NewServer(&fakeHandler{t: t}) 132 defer server.Close() 133 c := &vectorizer{ 134 apiKey: "", 135 httpClient: &http.Client{}, 136 urlBuilder: &cohereUrlBuilder{ 137 origin: server.URL, 138 pathMask: "/v1/embed", 139 }, 140 logger: nullLogger(), 141 } 142 ctx, cancel := context.WithDeadline(context.Background(), time.Now()) 143 defer cancel() 144 145 _, err := c.Vectorize(ctx, []string{"This is my text"}, ent.VectorizationConfig{}) 146 147 require.NotNil(t, err) 148 assert.Equal(t, err.Error(), "Cohere API Key: no api key found "+ 149 "neither in request header: X-Cohere-Api-Key "+ 150 "nor in environment variable under COHERE_APIKEY") 151 }) 152 153 t.Run("when X-Cohere-Api-Key header is passed but empty", func(t *testing.T) { 154 server := httptest.NewServer(&fakeHandler{t: t}) 155 defer server.Close() 156 c := &vectorizer{ 157 apiKey: "", 158 httpClient: &http.Client{}, 159 urlBuilder: &cohereUrlBuilder{ 160 origin: server.URL, 161 pathMask: "/v1/embed", 162 }, 163 logger: nullLogger(), 164 } 165 ctxWithValue := context.WithValue(context.Background(), 166 "X-Cohere-Api-Key", []string{""}) 167 168 _, err := c.Vectorize(ctxWithValue, []string{"This is my text"}, 169 ent.VectorizationConfig{ 170 Model: "large", 171 }) 172 173 require.NotNil(t, err) 174 assert.Equal(t, err.Error(), "Cohere API Key: no api key found "+ 175 "neither in request header: X-Cohere-Api-Key "+ 176 "nor in environment variable under COHERE_APIKEY") 177 }) 178 179 t.Run("when X-Cohere-BaseURL header is passed", func(t *testing.T) { 180 server := httptest.NewServer(&fakeHandler{t: t}) 181 defer server.Close() 182 c := &vectorizer{ 183 apiKey: "", 184 httpClient: &http.Client{}, 185 urlBuilder: &cohereUrlBuilder{ 186 origin: server.URL, 187 pathMask: "/v1/embed", 188 }, 189 logger: nullLogger(), 190 } 191 192 baseURL := "http://default-url.com" 193 ctxWithValue := context.WithValue(context.Background(), 194 "X-Cohere-Baseurl", []string{"http://base-url-passed-in-header.com"}) 195 196 buildURL := c.getCohereUrl(ctxWithValue, baseURL) 197 assert.Equal(t, "http://base-url-passed-in-header.com/v1/embed", buildURL) 198 199 buildURL = c.getCohereUrl(context.TODO(), baseURL) 200 assert.Equal(t, "http://default-url.com/v1/embed", buildURL) 201 }) 202 } 203 204 type fakeHandler struct { 205 t *testing.T 206 serverError error 207 } 208 209 func (f *fakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 210 assert.Equal(f.t, http.MethodPost, r.Method) 211 212 if f.serverError != nil { 213 embeddingError := map[string]interface{}{ 214 "message": f.serverError.Error(), 215 "type": "invalid_request_error", 216 } 217 embeddingResponse := map[string]interface{}{ 218 "message": embeddingError["message"], 219 } 220 outBytes, err := json.Marshal(embeddingResponse) 221 require.Nil(f.t, err) 222 223 w.WriteHeader(http.StatusInternalServerError) 224 w.Write(outBytes) 225 return 226 } 227 228 bodyBytes, err := io.ReadAll(r.Body) 229 require.Nil(f.t, err) 230 defer r.Body.Close() 231 232 var b map[string]interface{} 233 require.Nil(f.t, json.Unmarshal(bodyBytes, &b)) 234 235 textInput := b["texts"].([]interface{}) 236 assert.Greater(f.t, len(textInput), 0) 237 238 embeddingResponse := map[string]interface{}{ 239 "embeddings": [][]float32{{0.1, 0.2, 0.3}}, 240 } 241 outBytes, err := json.Marshal(embeddingResponse) 242 require.Nil(f.t, err) 243 244 w.Write(outBytes) 245 } 246 247 func nullLogger() logrus.FieldLogger { 248 l, _ := test.NewNullLogger() 249 return l 250 }