github.com/weaviate/weaviate@v1.24.6/modules/text2vec-aws/clients/aws_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package clients
    13  
    14  import (
    15  	"context"
    16  	"encoding/json"
    17  	"io"
    18  	"net/http"
    19  	"net/http/httptest"
    20  	"os"
    21  	"strings"
    22  	"testing"
    23  	"time"
    24  
    25  	"github.com/pkg/errors"
    26  	"github.com/sirupsen/logrus"
    27  	"github.com/sirupsen/logrus/hooks/test"
    28  	"github.com/stretchr/testify/assert"
    29  	"github.com/stretchr/testify/require"
    30  	"github.com/weaviate/weaviate/modules/text2vec-aws/ent"
    31  )
    32  
    33  func TestClient(t *testing.T) {
    34  	t.Run("when all is fine", func(t *testing.T) {
    35  		t.Skip("Skipping this test for now")
    36  		server := httptest.NewServer(&fakeHandler{t: t})
    37  		defer server.Close()
    38  		c := &aws{
    39  			httpClient:   &http.Client{},
    40  			logger:       nullLogger(),
    41  			awsAccessKey: "access_key",
    42  			awsSecret:    "secret",
    43  			buildBedrockUrlFn: func(service, region, model string) string {
    44  				return server.URL
    45  			},
    46  			buildSagemakerUrlFn: func(service, region, endpoint string) string {
    47  				return server.URL
    48  			},
    49  		}
    50  		expected := &ent.VectorizationResult{
    51  			Text:       "This is my text",
    52  			Vector:     []float32{0.1, 0.2, 0.3},
    53  			Dimensions: 3,
    54  		}
    55  		res, err := c.Vectorize(context.Background(), []string{"This is my text"},
    56  			ent.VectorizationConfig{
    57  				Service: "bedrock",
    58  				Region:  "region",
    59  				Model:   "model",
    60  			})
    61  
    62  		assert.Nil(t, err)
    63  		assert.Equal(t, expected, res)
    64  	})
    65  
    66  	t.Run("when all is fine - Sagemaker", func(t *testing.T) {
    67  		server := httptest.NewServer(&fakeHandler{t: t})
    68  		defer server.Close()
    69  		c := &aws{
    70  			httpClient:   &http.Client{},
    71  			logger:       nullLogger(),
    72  			awsAccessKey: "access_key",
    73  			awsSecret:    "secret",
    74  			buildBedrockUrlFn: func(service, region, model string) string {
    75  				return server.URL
    76  			},
    77  			buildSagemakerUrlFn: func(service, region, endpoint string) string {
    78  				return server.URL
    79  			},
    80  		}
    81  		expected := &ent.VectorizationResult{
    82  			Text:       "This is my text",
    83  			Vector:     []float32{0.1, 0.2, 0.3},
    84  			Dimensions: 3,
    85  		}
    86  		res, err := c.Vectorize(context.Background(), []string{"This is my text"},
    87  			ent.VectorizationConfig{
    88  				Service:  "sagemaker",
    89  				Region:   "region",
    90  				Endpoint: "endpoint",
    91  			})
    92  
    93  		assert.Nil(t, err)
    94  		assert.Equal(t, expected, res)
    95  	})
    96  
    97  	t.Run("when the server returns an error", func(t *testing.T) {
    98  		t.Skip("Skipping this test for now")
    99  		server := httptest.NewServer(&fakeHandler{
   100  			t:           t,
   101  			serverError: errors.Errorf("nope, not gonna happen"),
   102  		})
   103  		defer server.Close()
   104  		c := &aws{
   105  			httpClient:   &http.Client{},
   106  			logger:       nullLogger(),
   107  			awsAccessKey: "access_key",
   108  			awsSecret:    "secret",
   109  			buildBedrockUrlFn: func(service, region, model string) string {
   110  				return server.URL
   111  			},
   112  			buildSagemakerUrlFn: func(service, region, endpoint string) string {
   113  				return server.URL
   114  			},
   115  		}
   116  		_, err := c.Vectorize(context.Background(), []string{"This is my text"},
   117  			ent.VectorizationConfig{
   118  				Service: "bedrock",
   119  			})
   120  
   121  		require.NotNil(t, err)
   122  		assert.EqualError(t, err, "connection to AWS failed with status: 500 error: nope, not gonna happen")
   123  	})
   124  
   125  	t.Run("when AWS key is passed using X-Aws-Api-Key header", func(t *testing.T) {
   126  		t.Skip("Skipping this test for now")
   127  		server := httptest.NewServer(&fakeHandler{t: t})
   128  		defer server.Close()
   129  		c := &aws{
   130  			httpClient:   &http.Client{},
   131  			logger:       nullLogger(),
   132  			awsAccessKey: "access_key",
   133  			awsSecret:    "secret",
   134  			buildBedrockUrlFn: func(service, region, model string) string {
   135  				return server.URL
   136  			},
   137  			buildSagemakerUrlFn: func(service, region, endpoint string) string {
   138  				return server.URL
   139  			},
   140  		}
   141  		ctxWithValue := context.WithValue(context.Background(),
   142  			"X-Aws-Api-Key", []string{"some-key"})
   143  
   144  		expected := &ent.VectorizationResult{
   145  			Text:       "This is my text",
   146  			Vector:     []float32{0.1, 0.2, 0.3},
   147  			Dimensions: 3,
   148  		}
   149  		res, err := c.Vectorize(ctxWithValue, []string{"This is my text"}, ent.VectorizationConfig{
   150  			Service: "bedrock",
   151  		})
   152  
   153  		require.Nil(t, err)
   154  		assert.Equal(t, expected, res)
   155  	})
   156  
   157  	t.Run("when X-Aws-Access-Key header is passed but empty", func(t *testing.T) {
   158  		t.Skip("Skipping this test for now")
   159  		server := httptest.NewServer(&fakeHandler{t: t})
   160  		defer server.Close()
   161  		c := &aws{
   162  			httpClient:   &http.Client{},
   163  			logger:       nullLogger(),
   164  			awsAccessKey: "",
   165  			awsSecret:    "123",
   166  			buildBedrockUrlFn: func(service, region, model string) string {
   167  				return server.URL
   168  			},
   169  			buildSagemakerUrlFn: func(service, region, endpoint string) string {
   170  				return server.URL
   171  			},
   172  		}
   173  		ctxWithValue := context.WithValue(context.Background(),
   174  			"X-Aws-Api-Key", []string{""})
   175  
   176  		_, err := c.Vectorize(ctxWithValue, []string{"This is my text"}, ent.VectorizationConfig{
   177  			Service: "bedrock",
   178  		})
   179  
   180  		require.NotNil(t, err)
   181  		assert.Equal(t, err.Error(), "AWS Access Key: no access key found neither in request header: "+
   182  			"X-Aws-Access-Key nor in environment variable under AWS_ACCESS_KEY_ID")
   183  	})
   184  
   185  	t.Run("when X-Aws-Secret-Key header is passed but empty", func(t *testing.T) {
   186  		t.Skip("Skipping this test for now")
   187  		server := httptest.NewServer(&fakeHandler{t: t})
   188  		defer server.Close()
   189  		c := &aws{
   190  			httpClient:   &http.Client{},
   191  			logger:       nullLogger(),
   192  			awsAccessKey: "123",
   193  			awsSecret:    "",
   194  			buildBedrockUrlFn: func(service, region, model string) string {
   195  				return server.URL
   196  			},
   197  			buildSagemakerUrlFn: func(service, region, endpoint string) string {
   198  				return server.URL
   199  			},
   200  		}
   201  		ctxWithValue := context.WithValue(context.Background(),
   202  			"X-Aws-Api-Key", []string{""})
   203  
   204  		_, err := c.Vectorize(ctxWithValue, []string{"This is my text"}, ent.VectorizationConfig{
   205  			Service: "bedrock",
   206  		})
   207  
   208  		require.NotNil(t, err)
   209  		assert.Equal(t, err.Error(), "AWS Secret Key: no secret found neither in request header: "+
   210  			"X-Aws-Access-Secret nor in environment variable under AWS_SECRET_ACCESS_KEY")
   211  	})
   212  }
   213  
   214  func TestBuildBedrockUrl(t *testing.T) {
   215  	service := "bedrock"
   216  	region := "us-east-1"
   217  	t.Run("when using a Cohere", func(t *testing.T) {
   218  		model := "cohere.embed-english-v3"
   219  
   220  		expected := "https://bedrock-runtime.us-east-1.amazonaws.com/model/cohere.embed-english-v3/invoke"
   221  		result := buildBedrockUrl(service, region, model)
   222  
   223  		if result != expected {
   224  			t.Errorf("Expected %s but got %s", expected, result)
   225  		}
   226  	})
   227  
   228  	t.Run("When using an AWS model", func(t *testing.T) {
   229  		model := "amazon.titan-e1t-medium"
   230  
   231  		expected := "https://bedrock.us-east-1.amazonaws.com/model/amazon.titan-e1t-medium/invoke"
   232  		result := buildBedrockUrl(service, region, model)
   233  
   234  		if result != expected {
   235  			t.Errorf("Expected %s but got %s", expected, result)
   236  		}
   237  	})
   238  }
   239  
   240  func TestCreateRequestBody(t *testing.T) {
   241  	input := []string{"Hello, world!"}
   242  
   243  	t.Run("Create request for Amazon embedding model", func(t *testing.T) {
   244  		model := "amazon.titan-e1t-medium"
   245  		req, _ := createRequestBody(model, input, vectorizeObject)
   246  		_, ok := req.(bedrockEmbeddingsRequest)
   247  		if !ok {
   248  			t.Fatalf("Expected req to be a bedrockEmbeddingsRequest, got %T", req)
   249  		}
   250  	})
   251  
   252  	t.Run("Create request for Cohere embedding model", func(t *testing.T) {
   253  		model := "cohere.embed-english-v3"
   254  		req, _ := createRequestBody(model, input, vectorizeObject)
   255  		_, ok := req.(bedrockCohereEmbeddingRequest)
   256  		if !ok {
   257  			t.Fatalf("Expected req to be a bedrockCohereEmbeddingRequest, got %T", req)
   258  		}
   259  	})
   260  
   261  	t.Run("Create request for unknown embedding model", func(t *testing.T) {
   262  		model := "unknown.model"
   263  		_, err := createRequestBody(model, input, vectorizeObject)
   264  		if err == nil {
   265  			t.Errorf("Expected an error for unknown model, got nil")
   266  		}
   267  	})
   268  }
   269  
   270  func TestVectorize(t *testing.T) {
   271  	ctx := context.Background()
   272  	input := []string{"Hello, world!"}
   273  
   274  	t.Run("Vectorize using an Amazon model", func(t *testing.T) {
   275  		t.Skip("Skipping because CI doesnt have the right credentials")
   276  		config := ent.VectorizationConfig{
   277  			Model:   "amazon.titan-e1t-medium",
   278  			Service: "bedrock",
   279  			Region:  "us-east-1",
   280  		}
   281  
   282  		awsAccessKeyID := os.Getenv("AWS_ACCESS_KEY_ID_AMAZON")
   283  		awsSecretAccessKey := os.Getenv("AWS_SECRET_ACCESS_KEY_AMAZON")
   284  
   285  		aws := New(awsAccessKeyID, awsSecretAccessKey, 60*time.Second, nil)
   286  
   287  		_, err := aws.Vectorize(ctx, input, config)
   288  		if err != nil {
   289  			t.Errorf("Vectorize returned an error: %v", err)
   290  		}
   291  	})
   292  
   293  	t.Run("Vectorize using a Cohere model", func(t *testing.T) {
   294  		t.Skip("Skipping because CI doesnt have the right credentials")
   295  		config := ent.VectorizationConfig{
   296  			Model:   "cohere.embed-english-v3",
   297  			Service: "bedrock",
   298  			Region:  "us-east-1",
   299  		}
   300  
   301  		awsAccessKeyID := os.Getenv("AWS_ACCESS_KEY_ID_COHERE")
   302  		awsSecretAccessKey := os.Getenv("AWS_SECRET_ACCESS_KEY_COHERE")
   303  
   304  		aws := New(awsAccessKeyID, awsSecretAccessKey, 60*time.Second, nil)
   305  
   306  		_, err := aws.Vectorize(ctx, input, config)
   307  		if err != nil {
   308  			t.Errorf("Vectorize returned an error: %v", err)
   309  		}
   310  	})
   311  }
   312  
   313  func TestExtractHostAndPath(t *testing.T) {
   314  	t.Run("valid URL", func(t *testing.T) {
   315  		endpointUrl := "https://service.region.amazonaws.com/model/model-name/invoke"
   316  		expectedHost := "service.region.amazonaws.com"
   317  		expectedPath := "/model/model-name/invoke"
   318  
   319  		host, path, err := extractHostAndPath(endpointUrl)
   320  		if err != nil {
   321  			t.Errorf("Unexpected error: %v", err)
   322  		}
   323  		if host != expectedHost {
   324  			t.Errorf("Expected host %s but got %s", expectedHost, host)
   325  		}
   326  		if path != expectedPath {
   327  			t.Errorf("Expected path %s but got %s", expectedPath, path)
   328  		}
   329  	})
   330  
   331  	t.Run("URL without host or path", func(t *testing.T) {
   332  		endpointUrl := "https://"
   333  
   334  		_, _, err := extractHostAndPath(endpointUrl)
   335  
   336  		if err == nil {
   337  			t.Error("Expected error but got nil")
   338  		}
   339  	})
   340  }
   341  
   342  type fakeHandler struct {
   343  	t           *testing.T
   344  	serverError error
   345  }
   346  
   347  func (f *fakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   348  	assert.Equal(f.t, http.MethodPost, r.Method)
   349  
   350  	authHeader := r.Header["Authorization"][0]
   351  	if f.serverError != nil {
   352  		var outBytes []byte
   353  		var err error
   354  
   355  		if strings.Contains(authHeader, "bedrock") {
   356  			embeddingResponse := &bedrockEmbeddingResponse{
   357  				Message: ptString(f.serverError.Error()),
   358  			}
   359  			outBytes, err = json.Marshal(embeddingResponse)
   360  		} else {
   361  			embeddingResponse := &sagemakerEmbeddingResponse{
   362  				Message: ptString(f.serverError.Error()),
   363  			}
   364  			outBytes, err = json.Marshal(embeddingResponse)
   365  		}
   366  
   367  		require.Nil(f.t, err)
   368  
   369  		w.WriteHeader(http.StatusInternalServerError)
   370  		w.Write(outBytes)
   371  		return
   372  	}
   373  
   374  	bodyBytes, err := io.ReadAll(r.Body)
   375  	require.Nil(f.t, err)
   376  	defer r.Body.Close()
   377  
   378  	var outBytes []byte
   379  	if strings.Contains(authHeader, "bedrock") {
   380  		var req bedrockEmbeddingsRequest
   381  		require.Nil(f.t, json.Unmarshal(bodyBytes, &req))
   382  
   383  		textInput := req.InputText
   384  		assert.Greater(f.t, len(textInput), 0)
   385  		embeddingResponse := &bedrockEmbeddingResponse{
   386  			Embedding: []float32{0.1, 0.2, 0.3},
   387  		}
   388  		outBytes, err = json.Marshal(embeddingResponse)
   389  	} else {
   390  		var req sagemakerEmbeddingsRequest
   391  		require.Nil(f.t, json.Unmarshal(bodyBytes, &req))
   392  
   393  		textInputs := req.TextInputs
   394  		assert.Greater(f.t, len(textInputs), 0)
   395  		embeddingResponse := &sagemakerEmbeddingResponse{
   396  			Embedding: [][]float32{{0.1, 0.2, 0.3}},
   397  		}
   398  		outBytes, err = json.Marshal(embeddingResponse)
   399  	}
   400  
   401  	require.Nil(f.t, err)
   402  
   403  	w.Write(outBytes)
   404  }
   405  
   406  func nullLogger() logrus.FieldLogger {
   407  	l, _ := test.NewNullLogger()
   408  	return l
   409  }
   410  
   411  func ptString(in string) *string {
   412  	return &in
   413  }