github.com/weaviate/weaviate@v1.24.6/modules/text2vec-aws/clients/aws_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "context" 16 "encoding/json" 17 "io" 18 "net/http" 19 "net/http/httptest" 20 "os" 21 "strings" 22 "testing" 23 "time" 24 25 "github.com/pkg/errors" 26 "github.com/sirupsen/logrus" 27 "github.com/sirupsen/logrus/hooks/test" 28 "github.com/stretchr/testify/assert" 29 "github.com/stretchr/testify/require" 30 "github.com/weaviate/weaviate/modules/text2vec-aws/ent" 31 ) 32 33 func TestClient(t *testing.T) { 34 t.Run("when all is fine", func(t *testing.T) { 35 t.Skip("Skipping this test for now") 36 server := httptest.NewServer(&fakeHandler{t: t}) 37 defer server.Close() 38 c := &aws{ 39 httpClient: &http.Client{}, 40 logger: nullLogger(), 41 awsAccessKey: "access_key", 42 awsSecret: "secret", 43 buildBedrockUrlFn: func(service, region, model string) string { 44 return server.URL 45 }, 46 buildSagemakerUrlFn: func(service, region, endpoint string) string { 47 return server.URL 48 }, 49 } 50 expected := &ent.VectorizationResult{ 51 Text: "This is my text", 52 Vector: []float32{0.1, 0.2, 0.3}, 53 Dimensions: 3, 54 } 55 res, err := c.Vectorize(context.Background(), []string{"This is my text"}, 56 ent.VectorizationConfig{ 57 Service: "bedrock", 58 Region: "region", 59 Model: "model", 60 }) 61 62 assert.Nil(t, err) 63 assert.Equal(t, expected, res) 64 }) 65 66 t.Run("when all is fine - Sagemaker", func(t *testing.T) { 67 server := httptest.NewServer(&fakeHandler{t: t}) 68 defer server.Close() 69 c := &aws{ 70 httpClient: &http.Client{}, 71 logger: nullLogger(), 72 awsAccessKey: "access_key", 73 awsSecret: "secret", 74 buildBedrockUrlFn: func(service, region, model string) string { 75 return server.URL 76 }, 77 buildSagemakerUrlFn: func(service, region, endpoint string) string { 78 return server.URL 79 }, 80 } 81 expected := &ent.VectorizationResult{ 82 Text: "This is my text", 83 Vector: []float32{0.1, 0.2, 0.3}, 84 Dimensions: 3, 85 } 86 res, err := c.Vectorize(context.Background(), []string{"This is my text"}, 87 ent.VectorizationConfig{ 88 Service: "sagemaker", 89 Region: "region", 90 Endpoint: "endpoint", 91 }) 92 93 assert.Nil(t, err) 94 assert.Equal(t, expected, res) 95 }) 96 97 t.Run("when the server returns an error", func(t *testing.T) { 98 t.Skip("Skipping this test for now") 99 server := httptest.NewServer(&fakeHandler{ 100 t: t, 101 serverError: errors.Errorf("nope, not gonna happen"), 102 }) 103 defer server.Close() 104 c := &aws{ 105 httpClient: &http.Client{}, 106 logger: nullLogger(), 107 awsAccessKey: "access_key", 108 awsSecret: "secret", 109 buildBedrockUrlFn: func(service, region, model string) string { 110 return server.URL 111 }, 112 buildSagemakerUrlFn: func(service, region, endpoint string) string { 113 return server.URL 114 }, 115 } 116 _, err := c.Vectorize(context.Background(), []string{"This is my text"}, 117 ent.VectorizationConfig{ 118 Service: "bedrock", 119 }) 120 121 require.NotNil(t, err) 122 assert.EqualError(t, err, "connection to AWS failed with status: 500 error: nope, not gonna happen") 123 }) 124 125 t.Run("when AWS key is passed using X-Aws-Api-Key header", func(t *testing.T) { 126 t.Skip("Skipping this test for now") 127 server := httptest.NewServer(&fakeHandler{t: t}) 128 defer server.Close() 129 c := &aws{ 130 httpClient: &http.Client{}, 131 logger: nullLogger(), 132 awsAccessKey: "access_key", 133 awsSecret: "secret", 134 buildBedrockUrlFn: func(service, region, model string) string { 135 return server.URL 136 }, 137 buildSagemakerUrlFn: func(service, region, endpoint string) string { 138 return server.URL 139 }, 140 } 141 ctxWithValue := context.WithValue(context.Background(), 142 "X-Aws-Api-Key", []string{"some-key"}) 143 144 expected := &ent.VectorizationResult{ 145 Text: "This is my text", 146 Vector: []float32{0.1, 0.2, 0.3}, 147 Dimensions: 3, 148 } 149 res, err := c.Vectorize(ctxWithValue, []string{"This is my text"}, ent.VectorizationConfig{ 150 Service: "bedrock", 151 }) 152 153 require.Nil(t, err) 154 assert.Equal(t, expected, res) 155 }) 156 157 t.Run("when X-Aws-Access-Key header is passed but empty", func(t *testing.T) { 158 t.Skip("Skipping this test for now") 159 server := httptest.NewServer(&fakeHandler{t: t}) 160 defer server.Close() 161 c := &aws{ 162 httpClient: &http.Client{}, 163 logger: nullLogger(), 164 awsAccessKey: "", 165 awsSecret: "123", 166 buildBedrockUrlFn: func(service, region, model string) string { 167 return server.URL 168 }, 169 buildSagemakerUrlFn: func(service, region, endpoint string) string { 170 return server.URL 171 }, 172 } 173 ctxWithValue := context.WithValue(context.Background(), 174 "X-Aws-Api-Key", []string{""}) 175 176 _, err := c.Vectorize(ctxWithValue, []string{"This is my text"}, ent.VectorizationConfig{ 177 Service: "bedrock", 178 }) 179 180 require.NotNil(t, err) 181 assert.Equal(t, err.Error(), "AWS Access Key: no access key found neither in request header: "+ 182 "X-Aws-Access-Key nor in environment variable under AWS_ACCESS_KEY_ID") 183 }) 184 185 t.Run("when X-Aws-Secret-Key header is passed but empty", func(t *testing.T) { 186 t.Skip("Skipping this test for now") 187 server := httptest.NewServer(&fakeHandler{t: t}) 188 defer server.Close() 189 c := &aws{ 190 httpClient: &http.Client{}, 191 logger: nullLogger(), 192 awsAccessKey: "123", 193 awsSecret: "", 194 buildBedrockUrlFn: func(service, region, model string) string { 195 return server.URL 196 }, 197 buildSagemakerUrlFn: func(service, region, endpoint string) string { 198 return server.URL 199 }, 200 } 201 ctxWithValue := context.WithValue(context.Background(), 202 "X-Aws-Api-Key", []string{""}) 203 204 _, err := c.Vectorize(ctxWithValue, []string{"This is my text"}, ent.VectorizationConfig{ 205 Service: "bedrock", 206 }) 207 208 require.NotNil(t, err) 209 assert.Equal(t, err.Error(), "AWS Secret Key: no secret found neither in request header: "+ 210 "X-Aws-Access-Secret nor in environment variable under AWS_SECRET_ACCESS_KEY") 211 }) 212 } 213 214 func TestBuildBedrockUrl(t *testing.T) { 215 service := "bedrock" 216 region := "us-east-1" 217 t.Run("when using a Cohere", func(t *testing.T) { 218 model := "cohere.embed-english-v3" 219 220 expected := "https://bedrock-runtime.us-east-1.amazonaws.com/model/cohere.embed-english-v3/invoke" 221 result := buildBedrockUrl(service, region, model) 222 223 if result != expected { 224 t.Errorf("Expected %s but got %s", expected, result) 225 } 226 }) 227 228 t.Run("When using an AWS model", func(t *testing.T) { 229 model := "amazon.titan-e1t-medium" 230 231 expected := "https://bedrock.us-east-1.amazonaws.com/model/amazon.titan-e1t-medium/invoke" 232 result := buildBedrockUrl(service, region, model) 233 234 if result != expected { 235 t.Errorf("Expected %s but got %s", expected, result) 236 } 237 }) 238 } 239 240 func TestCreateRequestBody(t *testing.T) { 241 input := []string{"Hello, world!"} 242 243 t.Run("Create request for Amazon embedding model", func(t *testing.T) { 244 model := "amazon.titan-e1t-medium" 245 req, _ := createRequestBody(model, input, vectorizeObject) 246 _, ok := req.(bedrockEmbeddingsRequest) 247 if !ok { 248 t.Fatalf("Expected req to be a bedrockEmbeddingsRequest, got %T", req) 249 } 250 }) 251 252 t.Run("Create request for Cohere embedding model", func(t *testing.T) { 253 model := "cohere.embed-english-v3" 254 req, _ := createRequestBody(model, input, vectorizeObject) 255 _, ok := req.(bedrockCohereEmbeddingRequest) 256 if !ok { 257 t.Fatalf("Expected req to be a bedrockCohereEmbeddingRequest, got %T", req) 258 } 259 }) 260 261 t.Run("Create request for unknown embedding model", func(t *testing.T) { 262 model := "unknown.model" 263 _, err := createRequestBody(model, input, vectorizeObject) 264 if err == nil { 265 t.Errorf("Expected an error for unknown model, got nil") 266 } 267 }) 268 } 269 270 func TestVectorize(t *testing.T) { 271 ctx := context.Background() 272 input := []string{"Hello, world!"} 273 274 t.Run("Vectorize using an Amazon model", func(t *testing.T) { 275 t.Skip("Skipping because CI doesnt have the right credentials") 276 config := ent.VectorizationConfig{ 277 Model: "amazon.titan-e1t-medium", 278 Service: "bedrock", 279 Region: "us-east-1", 280 } 281 282 awsAccessKeyID := os.Getenv("AWS_ACCESS_KEY_ID_AMAZON") 283 awsSecretAccessKey := os.Getenv("AWS_SECRET_ACCESS_KEY_AMAZON") 284 285 aws := New(awsAccessKeyID, awsSecretAccessKey, 60*time.Second, nil) 286 287 _, err := aws.Vectorize(ctx, input, config) 288 if err != nil { 289 t.Errorf("Vectorize returned an error: %v", err) 290 } 291 }) 292 293 t.Run("Vectorize using a Cohere model", func(t *testing.T) { 294 t.Skip("Skipping because CI doesnt have the right credentials") 295 config := ent.VectorizationConfig{ 296 Model: "cohere.embed-english-v3", 297 Service: "bedrock", 298 Region: "us-east-1", 299 } 300 301 awsAccessKeyID := os.Getenv("AWS_ACCESS_KEY_ID_COHERE") 302 awsSecretAccessKey := os.Getenv("AWS_SECRET_ACCESS_KEY_COHERE") 303 304 aws := New(awsAccessKeyID, awsSecretAccessKey, 60*time.Second, nil) 305 306 _, err := aws.Vectorize(ctx, input, config) 307 if err != nil { 308 t.Errorf("Vectorize returned an error: %v", err) 309 } 310 }) 311 } 312 313 func TestExtractHostAndPath(t *testing.T) { 314 t.Run("valid URL", func(t *testing.T) { 315 endpointUrl := "https://service.region.amazonaws.com/model/model-name/invoke" 316 expectedHost := "service.region.amazonaws.com" 317 expectedPath := "/model/model-name/invoke" 318 319 host, path, err := extractHostAndPath(endpointUrl) 320 if err != nil { 321 t.Errorf("Unexpected error: %v", err) 322 } 323 if host != expectedHost { 324 t.Errorf("Expected host %s but got %s", expectedHost, host) 325 } 326 if path != expectedPath { 327 t.Errorf("Expected path %s but got %s", expectedPath, path) 328 } 329 }) 330 331 t.Run("URL without host or path", func(t *testing.T) { 332 endpointUrl := "https://" 333 334 _, _, err := extractHostAndPath(endpointUrl) 335 336 if err == nil { 337 t.Error("Expected error but got nil") 338 } 339 }) 340 } 341 342 type fakeHandler struct { 343 t *testing.T 344 serverError error 345 } 346 347 func (f *fakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 348 assert.Equal(f.t, http.MethodPost, r.Method) 349 350 authHeader := r.Header["Authorization"][0] 351 if f.serverError != nil { 352 var outBytes []byte 353 var err error 354 355 if strings.Contains(authHeader, "bedrock") { 356 embeddingResponse := &bedrockEmbeddingResponse{ 357 Message: ptString(f.serverError.Error()), 358 } 359 outBytes, err = json.Marshal(embeddingResponse) 360 } else { 361 embeddingResponse := &sagemakerEmbeddingResponse{ 362 Message: ptString(f.serverError.Error()), 363 } 364 outBytes, err = json.Marshal(embeddingResponse) 365 } 366 367 require.Nil(f.t, err) 368 369 w.WriteHeader(http.StatusInternalServerError) 370 w.Write(outBytes) 371 return 372 } 373 374 bodyBytes, err := io.ReadAll(r.Body) 375 require.Nil(f.t, err) 376 defer r.Body.Close() 377 378 var outBytes []byte 379 if strings.Contains(authHeader, "bedrock") { 380 var req bedrockEmbeddingsRequest 381 require.Nil(f.t, json.Unmarshal(bodyBytes, &req)) 382 383 textInput := req.InputText 384 assert.Greater(f.t, len(textInput), 0) 385 embeddingResponse := &bedrockEmbeddingResponse{ 386 Embedding: []float32{0.1, 0.2, 0.3}, 387 } 388 outBytes, err = json.Marshal(embeddingResponse) 389 } else { 390 var req sagemakerEmbeddingsRequest 391 require.Nil(f.t, json.Unmarshal(bodyBytes, &req)) 392 393 textInputs := req.TextInputs 394 assert.Greater(f.t, len(textInputs), 0) 395 embeddingResponse := &sagemakerEmbeddingResponse{ 396 Embedding: [][]float32{{0.1, 0.2, 0.3}}, 397 } 398 outBytes, err = json.Marshal(embeddingResponse) 399 } 400 401 require.Nil(f.t, err) 402 403 w.Write(outBytes) 404 } 405 406 func nullLogger() logrus.FieldLogger { 407 l, _ := test.NewNullLogger() 408 return l 409 } 410 411 func ptString(in string) *string { 412 return &in 413 }