github.com/weaviate/weaviate@v1.24.6/modules/text2vec-aws/clients/aws.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package clients
    13  
    14  import (
    15  	"bytes"
    16  	"context"
    17  	"encoding/json"
    18  	"fmt"
    19  	"io"
    20  	"net/http"
    21  	"net/url"
    22  	"strings"
    23  	"time"
    24  
    25  	"github.com/google/uuid"
    26  	"github.com/pkg/errors"
    27  	"github.com/sirupsen/logrus"
    28  	"github.com/weaviate/weaviate/modules/text2vec-aws/ent"
    29  	"github.com/weaviate/weaviate/usecases/modulecomponents"
    30  )
    31  
    32  type operationType string
    33  
    34  var (
    35  	vectorizeObject operationType = "vectorize_object"
    36  	vectorizeQuery  operationType = "vectorize_query"
    37  )
    38  
    39  func buildBedrockUrl(service, region, model string) string {
    40  	serviceName := service
    41  	if strings.HasPrefix(model, "cohere") {
    42  		serviceName = fmt.Sprintf("%s-runtime", serviceName)
    43  	}
    44  	urlTemplate := "https://%s.%s.amazonaws.com/model/%s/invoke"
    45  	return fmt.Sprintf(urlTemplate, serviceName, region, model)
    46  }
    47  
    48  func buildSagemakerUrl(service, region, endpoint string) string {
    49  	urlTemplate := "https://runtime.%s.%s.amazonaws.com/endpoints/%s/invocations"
    50  	return fmt.Sprintf(urlTemplate, service, region, endpoint)
    51  }
    52  
// aws is the client used to vectorize text through AWS Bedrock or AWS
// SageMaker. Instances are created with New.
type aws struct {
	// awsAccessKey is the configured access key; getAwsAccessKey uses it when
	// the request context does not carry one.
	awsAccessKey        string
	// awsSecret is the configured secret key; getAwsAccessSecret uses it when
	// the request context does not carry one.
	awsSecret           string
	// buildBedrockUrlFn builds the Bedrock endpoint URL. It is a field
	// (defaulting to buildBedrockUrl in New) so it can be replaced.
	buildBedrockUrlFn   func(service, region, model string) string
	// buildSagemakerUrlFn builds the SageMaker endpoint URL. It is a field
	// (defaulting to buildSagemakerUrl in New) so it can be replaced.
	buildSagemakerUrlFn func(service, region, endpoint string) string
	// httpClient performs the outgoing HTTP requests.
	httpClient          *http.Client
	// logger receives debug output, e.g. about retried requests.
	logger              logrus.FieldLogger
}
    61  
    62  func New(awsAccessKey string, awsSecret string, timeout time.Duration, logger logrus.FieldLogger) *aws {
    63  	return &aws{
    64  		awsAccessKey: awsAccessKey,
    65  		awsSecret:    awsSecret,
    66  		httpClient: &http.Client{
    67  			Timeout: timeout,
    68  		},
    69  		buildBedrockUrlFn:   buildBedrockUrl,
    70  		buildSagemakerUrlFn: buildSagemakerUrl,
    71  		logger:              logger,
    72  	}
    73  }
    74  
// Vectorize vectorizes input using the AWS service described by config. It
// delegates to vectorize with the vectorizeObject operation.
func (v *aws) Vectorize(ctx context.Context, input []string,
	config ent.VectorizationConfig,
) (*ent.VectorizationResult, error) {
	return v.vectorize(ctx, input, vectorizeObject, config)
}
    80  
// VectorizeQuery vectorizes input using the AWS service described by config.
// It delegates to vectorize with the vectorizeQuery operation, which makes
// Cohere request bodies use the "search_query" input type.
func (v *aws) VectorizeQuery(ctx context.Context, input []string,
	config ent.VectorizationConfig,
) (*ent.VectorizationResult, error) {
	return v.vectorize(ctx, input, vectorizeQuery, config)
}
    86  
    87  func (v *aws) vectorize(ctx context.Context, input []string, operation operationType, config ent.VectorizationConfig) (*ent.VectorizationResult, error) {
    88  	service := v.getService(config)
    89  	region := v.getRegion(config)
    90  	model := v.getModel(config)
    91  	endpoint := v.getEndpoint(config)
    92  	targetModel := v.getTargetModel(config)
    93  	targetVariant := v.getTargetVariant(config)
    94  
    95  	var body []byte
    96  	var endpointUrl string
    97  	var host string
    98  	var path string
    99  	var err error
   100  
   101  	headers := map[string]string{
   102  		"accept":       "*/*",
   103  		"content-type": contentType,
   104  	}
   105  
   106  	if v.isBedrock(service) {
   107  		endpointUrl = v.buildBedrockUrlFn(service, region, model)
   108  		host, path, _ = extractHostAndPath(endpointUrl)
   109  
   110  		req, err := createRequestBody(model, input, operation)
   111  		if err != nil {
   112  			return nil, err
   113  		}
   114  
   115  		body, err = json.Marshal(req)
   116  		if err != nil {
   117  			return nil, errors.Wrapf(err, "marshal body")
   118  		}
   119  	} else if v.isSagemaker(service) {
   120  		endpointUrl = v.buildSagemakerUrlFn(service, region, endpoint)
   121  		host = "runtime." + service + "." + region + ".amazonaws.com"
   122  		path = "/endpoints/" + endpoint + "/invocations"
   123  		if targetModel != "" {
   124  			headers["x-amzn-sagemaker-target-model"] = targetModel
   125  		}
   126  		if targetVariant != "" {
   127  			headers["x-amzn-sagemaker-target-variant"] = targetVariant
   128  		}
   129  		body, err = json.Marshal(sagemakerEmbeddingsRequest{
   130  			TextInputs: input,
   131  		})
   132  		if err != nil {
   133  			return nil, errors.Wrapf(err, "marshal body")
   134  		}
   135  	} else {
   136  		return nil, errors.Wrapf(err, "service error")
   137  	}
   138  
   139  	accessKey, err := v.getAwsAccessKey(ctx)
   140  	if err != nil {
   141  		return nil, errors.Wrapf(err, "AWS Access Key")
   142  	}
   143  	secretKey, err := v.getAwsAccessSecret(ctx)
   144  	if err != nil {
   145  		return nil, errors.Wrapf(err, "AWS Secret Key")
   146  	}
   147  
   148  	headers["host"] = host
   149  	amzDate, headers, authorizationHeader := getAuthHeader(accessKey, secretKey, host, service, region, path, body, headers)
   150  	headers["Authorization"] = authorizationHeader
   151  	headers["x-amz-date"] = amzDate
   152  
   153  	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointUrl, bytes.NewReader(body))
   154  	if err != nil {
   155  		return nil, errors.Wrap(err, "create POST request")
   156  	}
   157  
   158  	for k, v := range headers {
   159  		req.Header.Set(k, v)
   160  	}
   161  
   162  	res, err := v.makeRequest(req, 30, 5)
   163  	if err != nil {
   164  		return nil, errors.Wrap(err, "send POST request")
   165  	}
   166  	defer res.Body.Close()
   167  
   168  	bodyBytes, err := io.ReadAll(res.Body)
   169  	if err != nil {
   170  		return nil, errors.Wrap(err, "read response body")
   171  	}
   172  	if v.isBedrock(service) {
   173  		return v.parseBedrockResponse(bodyBytes, res, input)
   174  	} else {
   175  		return v.parseSagemakerResponse(bodyBytes, res, input)
   176  	}
   177  }
   178  
   179  func (v *aws) makeRequest(req *http.Request, delayInSeconds int, maxRetries int) (*http.Response, error) {
   180  	var res *http.Response
   181  	var err error
   182  
   183  	// Generate a UUID for this request
   184  	requestID := uuid.New().String()
   185  
   186  	for i := 0; i < maxRetries; i++ {
   187  		res, err = v.httpClient.Do(req)
   188  		if err != nil {
   189  			return nil, errors.Wrap(err, "send POST request")
   190  		}
   191  
   192  		// If the status code is not 429 or 400, break the loop
   193  		if res.StatusCode != http.StatusTooManyRequests && res.StatusCode != http.StatusBadRequest {
   194  			break
   195  		}
   196  
   197  		v.logger.Debugf("Request ID %s to %s returned 429, retrying in %d seconds", requestID, req.URL, delayInSeconds)
   198  
   199  		// Sleep for a while and then continue to the next iteration
   200  		time.Sleep(time.Duration(delayInSeconds) * time.Second)
   201  
   202  		// Double the delay for the next iteration
   203  		delayInSeconds *= 2
   204  
   205  	}
   206  
   207  	return res, err
   208  }
   209  
   210  func (v *aws) parseBedrockResponse(bodyBytes []byte, res *http.Response, input []string) (*ent.VectorizationResult, error) {
   211  	var resBodyMap map[string]interface{}
   212  	if err := json.Unmarshal(bodyBytes, &resBodyMap); err != nil {
   213  		return nil, errors.Wrap(err, "unmarshal response body")
   214  	}
   215  
   216  	// if resBodyMap has inputTextTokenCount, it's a resonse from an Amazon model
   217  	// otherwise, it is a response from a Cohere model
   218  	var resBody bedrockEmbeddingResponse
   219  	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
   220  		return nil, errors.Wrap(err, "unmarshal response body")
   221  	}
   222  
   223  	if res.StatusCode != 200 || resBody.Message != nil {
   224  		if resBody.Message != nil {
   225  			return nil, fmt.Errorf("connection to AWS Bedrock failed with status: %v error: %s",
   226  				res.StatusCode, *resBody.Message)
   227  		}
   228  		return nil, fmt.Errorf("connection to AWS Bedrock failed with status: %d", res.StatusCode)
   229  	}
   230  
   231  	if len(resBody.Embedding) == 0 && len(resBody.Embeddings) == 0 {
   232  		return nil, fmt.Errorf("could not obtain vector from AWS Bedrock")
   233  	}
   234  
   235  	embedding := resBody.Embedding
   236  	if len(resBody.Embeddings) > 0 {
   237  		embedding = resBody.Embeddings[0]
   238  	}
   239  
   240  	return &ent.VectorizationResult{
   241  		Text:       input[0],
   242  		Dimensions: len(embedding),
   243  		Vector:     embedding,
   244  	}, nil
   245  }
   246  
   247  func (v *aws) parseSagemakerResponse(bodyBytes []byte, res *http.Response, input []string) (*ent.VectorizationResult, error) {
   248  	var resBody sagemakerEmbeddingResponse
   249  	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
   250  		return nil, errors.Wrap(err, "unmarshal response body")
   251  	}
   252  
   253  	if res.StatusCode != 200 || resBody.Message != nil {
   254  		if resBody.Message != nil {
   255  			return nil, fmt.Errorf("connection to AWS failed with status: %v error: %s",
   256  				res.StatusCode, *resBody.Message)
   257  		}
   258  		return nil, fmt.Errorf("connection to AWS failed with status: %d", res.StatusCode)
   259  	}
   260  
   261  	if len(resBody.Embedding) == 0 {
   262  		return nil, errors.Errorf("empty embeddings response")
   263  	}
   264  
   265  	return &ent.VectorizationResult{
   266  		Text:       input[0],
   267  		Dimensions: len(resBody.Embedding[0]),
   268  		Vector:     resBody.Embedding[0],
   269  	}, nil
   270  }
   271  
// isSagemaker reports whether service is exactly "sagemaker" (case-sensitive).
func (v *aws) isSagemaker(service string) bool {
	return service == "sagemaker"
}
   275  
// isBedrock reports whether service is exactly "bedrock" (case-sensitive).
func (v *aws) isBedrock(service string) bool {
	return service == "bedrock"
}
   279  
   280  func (v *aws) getAwsAccessKey(ctx context.Context) (string, error) {
   281  	awsAccessKey := ctx.Value("X-Aws-Access-Key")
   282  	if awsAccessKeyHeader, ok := awsAccessKey.([]string); ok &&
   283  		len(awsAccessKeyHeader) > 0 && len(awsAccessKeyHeader[0]) > 0 {
   284  		return awsAccessKeyHeader[0], nil
   285  	}
   286  	if len(v.awsAccessKey) > 0 {
   287  		return v.awsAccessKey, nil
   288  	}
   289  	// try getting header from GRPC if not successful
   290  	if accessKey := modulecomponents.GetValueFromGRPC(ctx, "X-Aws-Access-Key"); len(accessKey) > 0 && len(accessKey[0]) > 0 {
   291  		return accessKey[0], nil
   292  	}
   293  	return "", errors.New("no access key found " +
   294  		"neither in request header: X-AWS-Access-Key " +
   295  		"nor in environment variable under AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY")
   296  }
   297  
   298  func (v *aws) getAwsAccessSecret(ctx context.Context) (string, error) {
   299  	awsSecretKey := ctx.Value("X-Aws-Secret-Key")
   300  	if awsAccessSecretHeader, ok := awsSecretKey.([]string); ok &&
   301  		len(awsAccessSecretHeader) > 0 && len(awsAccessSecretHeader[0]) > 0 {
   302  		return awsAccessSecretHeader[0], nil
   303  	}
   304  	if len(v.awsSecret) > 0 {
   305  		return v.awsSecret, nil
   306  	}
   307  	// try getting header from GRPC if not successful
   308  	if secretKey := modulecomponents.GetValueFromGRPC(ctx, "X-Aws-Secret-Key"); len(secretKey) > 0 && len(secretKey[0]) > 0 {
   309  		return secretKey[0], nil
   310  	}
   311  	return "", errors.New("no secret found " +
   312  		"neither in request header: X-AWS-Secret-Key " +
   313  		"nor in environment variable under AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY")
   314  }
   315  
// getModel returns the Model field of config.
func (v *aws) getModel(config ent.VectorizationConfig) string {
	return config.Model
}

// getRegion returns the Region field of config.
func (v *aws) getRegion(config ent.VectorizationConfig) string {
	return config.Region
}

// getService returns the Service field of config; vectorize compares it
// against "bedrock" and "sagemaker".
func (v *aws) getService(config ent.VectorizationConfig) string {
	return config.Service
}

// getEndpoint returns the Endpoint field of config, used by vectorize to build
// the SageMaker URL and path.
func (v *aws) getEndpoint(config ent.VectorizationConfig) string {
	return config.Endpoint
}

// getTargetModel returns the TargetModel field of config; when non-empty,
// vectorize sends it as the x-amzn-sagemaker-target-model header.
func (v *aws) getTargetModel(config ent.VectorizationConfig) string {
	return config.TargetModel
}

// getTargetVariant returns the TargetVariant field of config; when non-empty,
// vectorize sends it as the x-amzn-sagemaker-target-variant header.
func (v *aws) getTargetVariant(config ent.VectorizationConfig) string {
	return config.TargetVariant
}
   339  
// bedrockEmbeddingsRequest is the JSON request body that createRequestBody
// produces for models of the "amazon" provider; it carries a single text.
type bedrockEmbeddingsRequest struct {
	InputText string `json:"inputText,omitempty"`
}

// bedrockCohereEmbeddingRequest is the JSON request body that
// createRequestBody produces for models of the "cohere" provider.
type bedrockCohereEmbeddingRequest struct {
	// Texts holds all input texts.
	Texts     []string `json:"texts"`
	// InputType is "search_document" or, for query vectorization,
	// "search_query".
	InputType string   `json:"input_type"`
}

// sagemakerEmbeddingsRequest is the JSON request body that vectorize sends to
// a SageMaker endpoint.
type sagemakerEmbeddingsRequest struct {
	TextInputs []string `json:"text_inputs,omitempty"`
}
   352  
// bedrockEmbeddingResponse is the JSON response body decoded by
// parseBedrockResponse. It covers both a single vector (Embedding) and a list
// of vectors (Embeddings); a non-nil Message is treated as an error.
type bedrockEmbeddingResponse struct {
	InputTextTokenCount int         `json:"InputTextTokenCount,omitempty"`
	Embedding           []float32   `json:"embedding,omitempty"`
	Embeddings          [][]float32 `json:"embeddings,omitempty"`
	Message             *string     `json:"message,omitempty"`
}
// sagemakerEmbeddingResponse is the JSON response body decoded by
// parseSagemakerResponse. Embedding holds one vector per returned embedding,
// of which the parser uses the first; a non-nil Message is treated as an
// error. The remaining fields are decoded but not read in this file.
type sagemakerEmbeddingResponse struct {
	Embedding          [][]float32 `json:"embedding,omitempty"`
	ErrorCode          *string     `json:"ErrorCode,omitempty"`
	LogStreamArn       *string     `json:"LogStreamArn,omitempty"`
	OriginalMessage    *string     `json:"OriginalMessage,omitempty"`
	Message            *string     `json:"Message,omitempty"`
	OriginalStatusCode *int        `json:"OriginalStatusCode,omitempty"`
}
   367  
   368  func extractHostAndPath(endpointUrl string) (string, string, error) {
   369  	u, err := url.Parse(endpointUrl)
   370  	if err != nil {
   371  		return "", "", err
   372  	}
   373  
   374  	if u.Host == "" || u.Path == "" {
   375  		return "", "", fmt.Errorf("invalid endpoint URL: %s", endpointUrl)
   376  	}
   377  
   378  	return u.Host, u.Path, nil
   379  }
   380  
   381  func createRequestBody(model string, texts []string, operation operationType) (interface{}, error) {
   382  	modelParts := strings.Split(model, ".")
   383  	if len(modelParts) == 0 {
   384  		return nil, fmt.Errorf("invalid model: %s", model)
   385  	}
   386  
   387  	modelProvider := modelParts[0]
   388  
   389  	switch modelProvider {
   390  	case "amazon":
   391  		return bedrockEmbeddingsRequest{
   392  			InputText: texts[0],
   393  		}, nil
   394  	case "cohere":
   395  		inputType := "search_document"
   396  		if operation == vectorizeQuery {
   397  			inputType = "search_query"
   398  		}
   399  		return bedrockCohereEmbeddingRequest{
   400  			Texts:     texts,
   401  			InputType: inputType,
   402  		}, nil
   403  	default:
   404  		return nil, fmt.Errorf("unknown model provider: %s", modelProvider)
   405  	}
   406  }