github.com/weaviate/weaviate@v1.24.6/modules/text2vec-voyageai/clients/voyageai_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "context" 16 "encoding/json" 17 "io" 18 "net/http" 19 "net/http/httptest" 20 "testing" 21 "time" 22 23 "github.com/pkg/errors" 24 "github.com/sirupsen/logrus" 25 "github.com/sirupsen/logrus/hooks/test" 26 "github.com/stretchr/testify/assert" 27 "github.com/stretchr/testify/require" 28 "github.com/weaviate/weaviate/modules/text2vec-voyageai/ent" 29 ) 30 31 func TestClient(t *testing.T) { 32 t.Run("when all is fine", func(t *testing.T) { 33 server := httptest.NewServer(&fakeHandler{t: t}) 34 defer server.Close() 35 c := &vectorizer{ 36 apiKey: "apiKey", 37 httpClient: &http.Client{}, 38 urlBuilder: &voyageaiUrlBuilder{ 39 origin: server.URL, 40 pathMask: "/embeddings", 41 }, 42 logger: nullLogger(), 43 } 44 expected := &ent.VectorizationResult{ 45 Text: []string{"This is my text"}, 46 Vectors: [][]float32{{0.1, 0.2, 0.3}}, 47 Dimensions: 3, 48 } 49 res, err := c.Vectorize(context.Background(), []string{"This is my text"}, 50 ent.VectorizationConfig{ 51 Model: "voyage-2", 52 BaseURL: server.URL, 53 }) 54 55 assert.Nil(t, err) 56 assert.Equal(t, expected, res) 57 }) 58 59 t.Run("when the context is expired", func(t *testing.T) { 60 server := httptest.NewServer(&fakeHandler{t: t}) 61 defer server.Close() 62 c := &vectorizer{ 63 apiKey: "apiKey", 64 httpClient: &http.Client{}, 65 urlBuilder: &voyageaiUrlBuilder{ 66 origin: server.URL, 67 pathMask: "/embeddings", 68 }, 69 logger: nullLogger(), 70 } 71 ctx, cancel := context.WithDeadline(context.Background(), time.Now()) 72 defer cancel() 73 74 _, err := c.Vectorize(ctx, []string{"This is my text"}, ent.VectorizationConfig{ 75 Model: "voyage-2", 76 }) 77 78 require.NotNil(t, err) 79 assert.Contains(t, err.Error(), "context deadline exceeded") 80 }) 81 82 t.Run("when the server returns an error", func(t *testing.T) { 83 server := httptest.NewServer(&fakeHandler{ 84 t: t, 85 serverError: errors.Errorf("nope, not gonna happen"), 86 }) 87 defer server.Close() 88 c := &vectorizer{ 89 apiKey: "apiKey", 90 httpClient: &http.Client{}, 91 urlBuilder: &voyageaiUrlBuilder{ 92 origin: server.URL, 93 pathMask: "/embeddings", 94 }, 95 logger: nullLogger(), 96 } 97 _, err := c.Vectorize(context.Background(), []string{"This is my text"}, 98 ent.VectorizationConfig{ 99 Model: "voyage-2", 100 BaseURL: server.URL, 101 }) 102 103 require.NotNil(t, err) 104 assert.Equal(t, err.Error(), "connection to VoyageAI failed with status: 500 error: nope, not gonna happen") 105 }) 106 107 t.Run("when VoyageAI key is passed using VoyageAIre-Api-Key header", func(t *testing.T) { 108 server := httptest.NewServer(&fakeHandler{t: t}) 109 defer server.Close() 110 c := &vectorizer{ 111 apiKey: "", 112 httpClient: &http.Client{}, 113 urlBuilder: &voyageaiUrlBuilder{ 114 origin: server.URL, 115 pathMask: "/embeddings", 116 }, 117 logger: nullLogger(), 118 } 119 ctxWithValue := context.WithValue(context.Background(), 120 "X-Voyageai-Api-Key", []string{"some-key"}) 121 122 expected := &ent.VectorizationResult{ 123 Text: []string{"This is my text"}, 124 Vectors: [][]float32{{0.1, 0.2, 0.3}}, 125 Dimensions: 3, 126 } 127 res, err := c.Vectorize(ctxWithValue, []string{"This is my text"}, 128 ent.VectorizationConfig{ 129 Model: "voyage-2", 130 }) 131 132 require.Nil(t, err) 133 assert.Equal(t, expected, res) 134 }) 135 136 t.Run("when VoyageAI key is empty", func(t *testing.T) { 137 server := httptest.NewServer(&fakeHandler{t: t}) 138 defer server.Close() 139 c := &vectorizer{ 140 apiKey: "", 141 httpClient: &http.Client{}, 142 urlBuilder: &voyageaiUrlBuilder{ 143 origin: server.URL, 144 pathMask: "/embeddings", 145 }, 146 logger: nullLogger(), 147 } 148 ctx, cancel := context.WithDeadline(context.Background(), time.Now()) 149 defer cancel() 150 151 _, err := c.Vectorize(ctx, []string{"This is my text"}, ent.VectorizationConfig{ 152 Model: "voyage-2", 153 }) 154 155 require.NotNil(t, err) 156 assert.Equal(t, err.Error(), "VoyageAI API Key: no api key found "+ 157 "neither in request header: X-VoyageAI-Api-Key "+ 158 "nor in environment variable under VOYAGEAI_APIKEY") 159 }) 160 161 t.Run("when X-VoyageAI-Api-Key header is passed but empty", func(t *testing.T) { 162 server := httptest.NewServer(&fakeHandler{t: t}) 163 defer server.Close() 164 c := &vectorizer{ 165 apiKey: "", 166 httpClient: &http.Client{}, 167 urlBuilder: &voyageaiUrlBuilder{ 168 origin: server.URL, 169 pathMask: "/embeddings", 170 }, 171 logger: nullLogger(), 172 } 173 ctxWithValue := context.WithValue(context.Background(), 174 "X-VoyageAI-Api-Key", []string{""}) 175 176 _, err := c.Vectorize(ctxWithValue, []string{"This is my text"}, 177 ent.VectorizationConfig{ 178 Model: "voyage-2", 179 }) 180 181 require.NotNil(t, err) 182 assert.Equal(t, err.Error(), "VoyageAI API Key: no api key found "+ 183 "neither in request header: X-VoyageAI-Api-Key "+ 184 "nor in environment variable under VOYAGEAI_APIKEY") 185 }) 186 187 t.Run("when X-VoyageAI-BaseURL header is passed", func(t *testing.T) { 188 server := httptest.NewServer(&fakeHandler{t: t}) 189 defer server.Close() 190 c := &vectorizer{ 191 apiKey: "", 192 httpClient: &http.Client{}, 193 urlBuilder: &voyageaiUrlBuilder{ 194 origin: server.URL, 195 pathMask: "/embeddings", 196 }, 197 logger: nullLogger(), 198 } 199 200 baseURL := "http://default-url.com" 201 ctxWithValue := context.WithValue(context.Background(), 202 "X-Voyageai-Baseurl", []string{"http://base-url-passed-in-header.com"}) 203 204 buildURL := c.getVoyageAIUrl(ctxWithValue, baseURL) 205 assert.Equal(t, "http://base-url-passed-in-header.com/embeddings", buildURL) 206 207 buildURL = c.getVoyageAIUrl(context.TODO(), baseURL) 208 assert.Equal(t, "http://default-url.com/embeddings", buildURL) 209 }) 210 } 211 212 type fakeHandler struct { 213 t *testing.T 214 serverError error 215 } 216 217 func (f *fakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 218 assert.Equal(f.t, http.MethodPost, r.Method) 219 220 if f.serverError != nil { 221 resp := embeddingsResponse{ 222 Detail: "nope, not gonna happen", 223 } 224 outBytes, err := json.Marshal(resp) 225 require.Nil(f.t, err) 226 227 w.WriteHeader(http.StatusInternalServerError) 228 w.Write(outBytes) 229 return 230 } 231 232 bodyBytes, err := io.ReadAll(r.Body) 233 require.Nil(f.t, err) 234 defer r.Body.Close() 235 236 var req embeddingsRequest 237 require.Nil(f.t, json.Unmarshal(bodyBytes, &req)) 238 239 assert.NotNil(f.t, req) 240 assert.NotEmpty(f.t, req.Input) 241 242 resp := embeddingsResponse{ 243 Data: []embeddingsDataResponse{{Embeddings: []float32{0.1, 0.2, 0.3}}}, 244 } 245 outBytes, err := json.Marshal(resp) 246 require.Nil(f.t, err) 247 248 w.Write(outBytes) 249 } 250 251 func nullLogger() logrus.FieldLogger { 252 l, _ := test.NewNullLogger() 253 return l 254 }