github.com/weaviate/weaviate@v1.24.6/modules/generative-aws/clients/aws.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package clients
    13  
    14  import (
    15  	"bytes"
    16  	"context"
    17  	"encoding/json"
    18  	"fmt"
    19  	"io"
    20  	"net/http"
    21  	"regexp"
    22  	"strings"
    23  	"time"
    24  
    25  	"github.com/pkg/errors"
    26  	"github.com/sirupsen/logrus"
    27  	"github.com/weaviate/weaviate/entities/moduletools"
    28  	"github.com/weaviate/weaviate/modules/generative-aws/config"
    29  	generativemodels "github.com/weaviate/weaviate/usecases/modulecomponents/additional/models"
    30  )
    31  
    32  var compile, _ = regexp.Compile(`{([\w\s]*?)}`)
    33  
    34  func buildBedrockUrl(service, region, model string) string {
    35  	urlTemplate := "https://%s.%s.amazonaws.com/model/%s/invoke"
    36  	return fmt.Sprintf(urlTemplate, fmt.Sprintf("%s-runtime", service), region, model)
    37  }
    38  
    39  func buildSagemakerUrl(service, region, endpoint string) string {
    40  	urlTemplate := "https://runtime.%s.%s.amazonaws.com/endpoints/%s/invocations"
    41  	return fmt.Sprintf(urlTemplate, service, region, endpoint)
    42  }
    43  
    44  type aws struct {
    45  	awsAccessKey        string
    46  	awsSecretKey        string
    47  	buildBedrockUrlFn   func(service, region, model string) string
    48  	buildSagemakerUrlFn func(service, region, endpoint string) string
    49  	httpClient          *http.Client
    50  	logger              logrus.FieldLogger
    51  }
    52  
    53  func New(awsAccessKey string, awsSecretKey string, timeout time.Duration, logger logrus.FieldLogger) *aws {
    54  	return &aws{
    55  		awsAccessKey: awsAccessKey,
    56  		awsSecretKey: awsSecretKey,
    57  		httpClient: &http.Client{
    58  			Timeout: timeout,
    59  		},
    60  		buildBedrockUrlFn:   buildBedrockUrl,
    61  		buildSagemakerUrlFn: buildSagemakerUrl,
    62  		logger:              logger,
    63  	}
    64  }
    65  
    66  func (v *aws) GenerateSingleResult(ctx context.Context, textProperties map[string]string, prompt string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) {
    67  	forPrompt, err := v.generateForPrompt(textProperties, prompt)
    68  	if err != nil {
    69  		return nil, err
    70  	}
    71  	return v.Generate(ctx, cfg, forPrompt)
    72  }
    73  
    74  func (v *aws) GenerateAllResults(ctx context.Context, textProperties []map[string]string, task string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) {
    75  	forTask, err := v.generatePromptForTask(textProperties, task)
    76  	if err != nil {
    77  		return nil, err
    78  	}
    79  	return v.Generate(ctx, cfg, forTask)
    80  }
    81  
    82  func (v *aws) Generate(ctx context.Context, cfg moduletools.ClassConfig, prompt string) (*generativemodels.GenerateResponse, error) {
    83  	settings := config.NewClassSettings(cfg)
    84  	service := settings.Service()
    85  	region := settings.Region()
    86  	model := settings.Model()
    87  	endpoint := settings.Endpoint()
    88  	targetModel := settings.TargetModel()
    89  	targetVariant := settings.TargetVariant()
    90  
    91  	var body []byte
    92  	var endpointUrl string
    93  	var host string
    94  	var path string
    95  	var err error
    96  
    97  	headers := map[string]string{
    98  		"accept":       "*/*",
    99  		"content-type": contentType,
   100  	}
   101  
   102  	if v.isBedrock(service) {
   103  		endpointUrl = v.buildBedrockUrlFn(service, region, model)
   104  		host = service + "-runtime" + "." + region + ".amazonaws.com"
   105  		path = "/model/" + model + "/invoke"
   106  
   107  		if v.isAmazonModel(model) {
   108  			body, err = json.Marshal(bedrockAmazonGenerateRequest{
   109  				InputText: prompt,
   110  			})
   111  		} else if v.isAnthropicModel(model) {
   112  			var builder strings.Builder
   113  			builder.WriteString("\n\nHuman: ")
   114  			builder.WriteString(prompt)
   115  			builder.WriteString("\n\nAssistant:")
   116  			body, err = json.Marshal(bedrockAnthropicGenerateRequest{
   117  				Prompt:            builder.String(),
   118  				MaxTokensToSample: *settings.MaxTokenCount(),
   119  				Temperature:       *settings.Temperature(),
   120  				TopK:              *settings.TopK(),
   121  				TopP:              settings.TopP(),
   122  				StopSequences:     settings.StopSequences(),
   123  				AnthropicVersion:  "bedrock-2023-05-31",
   124  			})
   125  		} else if v.isAI21Model(model) {
   126  			body, err = json.Marshal(bedrockAI21GenerateRequest{
   127  				Prompt:        prompt,
   128  				MaxTokens:     *settings.MaxTokenCount(),
   129  				Temperature:   *settings.Temperature(),
   130  				TopP:          settings.TopP(),
   131  				StopSequences: settings.StopSequences(),
   132  			})
   133  		} else if v.isCohereModel(model) {
   134  			body, err = json.Marshal(bedrockCohereRequest{
   135  				Prompt:      prompt,
   136  				Temperature: *settings.Temperature(),
   137  				MaxTokens:   *settings.MaxTokenCount(),
   138  				// ReturnLikeliHood: "GENERATION", // contray to docs, this is invalid
   139  			})
   140  		}
   141  
   142  		headers["x-amzn-bedrock-save"] = "false"
   143  		if err != nil {
   144  			return nil, errors.Wrapf(err, "marshal body")
   145  		}
   146  	} else if v.isSagemaker(service) {
   147  		endpointUrl = v.buildSagemakerUrlFn(service, region, endpoint)
   148  		host = "runtime." + service + "." + region + ".amazonaws.com"
   149  		path = "/endpoints/" + endpoint + "/invocations"
   150  		if targetModel != "" {
   151  			headers["x-amzn-sagemaker-target-model"] = targetModel
   152  		}
   153  		if targetVariant != "" {
   154  			headers["x-amzn-sagemaker-target-variant"] = targetVariant
   155  		}
   156  		body, err = json.Marshal(sagemakerGenerateRequest{
   157  			Prompt: prompt,
   158  		})
   159  		if err != nil {
   160  			return nil, errors.Wrapf(err, "marshal body")
   161  		}
   162  	} else {
   163  		return nil, errors.Wrapf(err, "service error")
   164  	}
   165  
   166  	accessKey, err := v.getAwsAccessKey(ctx)
   167  	if err != nil {
   168  		return nil, errors.Wrapf(err, "AWS Access Key")
   169  	}
   170  	secretKey, err := v.getAwsAccessSecret(ctx)
   171  	if err != nil {
   172  		return nil, errors.Wrapf(err, "AWS Secret Key")
   173  	}
   174  
   175  	headers["host"] = host
   176  	amzDate, headers, authorizationHeader := getAuthHeader(accessKey, secretKey, host, service, region, path, body, headers)
   177  	headers["Authorization"] = authorizationHeader
   178  	headers["x-amz-date"] = amzDate
   179  
   180  	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointUrl, bytes.NewReader(body))
   181  	if err != nil {
   182  		return nil, errors.Wrap(err, "create POST request")
   183  	}
   184  
   185  	for k, v := range headers {
   186  		req.Header.Set(k, v)
   187  	}
   188  
   189  	res, err := v.httpClient.Do(req)
   190  	if err != nil {
   191  		return nil, errors.Wrap(err, "send POST request")
   192  	}
   193  	defer res.Body.Close()
   194  
   195  	bodyBytes, err := io.ReadAll(res.Body)
   196  	if err != nil {
   197  		return nil, errors.Wrap(err, "read response body")
   198  	}
   199  
   200  	if v.isBedrock(service) {
   201  		return v.parseBedrockResponse(bodyBytes, res)
   202  	} else if v.isSagemaker(service) {
   203  		return v.parseSagemakerResponse(bodyBytes, res)
   204  	} else {
   205  		return &generativemodels.GenerateResponse{
   206  			Result: nil,
   207  		}, nil
   208  	}
   209  }
   210  
   211  func (v *aws) parseBedrockResponse(bodyBytes []byte, res *http.Response) (*generativemodels.GenerateResponse, error) {
   212  	var resBodyMap map[string]interface{}
   213  	if err := json.Unmarshal(bodyBytes, &resBodyMap); err != nil {
   214  		return nil, errors.Wrap(err, "unmarshal response body")
   215  	}
   216  
   217  	var resBody bedrockGenerateResponse
   218  	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
   219  		return nil, errors.Wrap(err, "unmarshal response body")
   220  	}
   221  
   222  	if res.StatusCode != 200 || resBody.Message != nil {
   223  		if resBody.Message != nil {
   224  			return nil, fmt.Errorf("connection to AWS Bedrock failed with status: %v error: %s",
   225  				res.StatusCode, *resBody.Message)
   226  		}
   227  		return nil, fmt.Errorf("connection to AWS Bedrock failed with status: %d", res.StatusCode)
   228  	}
   229  
   230  	if len(resBody.Results) == 0 && len(resBody.Generations) == 0 {
   231  		return nil, fmt.Errorf("received empty response from AWS Bedrock")
   232  	}
   233  
   234  	var content string
   235  	if len(resBody.Results) > 0 && len(resBody.Results[0].CompletionReason) > 0 {
   236  		content = resBody.Results[0].OutputText
   237  	} else if len(resBody.Generations) > 0 {
   238  		content = resBody.Generations[0].Text
   239  	}
   240  
   241  	if content != "" {
   242  		return &generativemodels.GenerateResponse{
   243  			Result: &content,
   244  		}, nil
   245  	}
   246  
   247  	return &generativemodels.GenerateResponse{
   248  		Result: nil,
   249  	}, nil
   250  }
   251  
   252  func (v *aws) parseSagemakerResponse(bodyBytes []byte, res *http.Response) (*generativemodels.GenerateResponse, error) {
   253  	var resBody sagemakerGenerateResponse
   254  	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
   255  		return nil, errors.Wrap(err, "unmarshal response body")
   256  	}
   257  
   258  	if res.StatusCode != 200 || resBody.Message != nil {
   259  		if resBody.Message != nil {
   260  			return nil, fmt.Errorf("connection to AWS Sagemaker failed with status: %v error: %s",
   261  				res.StatusCode, *resBody.Message)
   262  		}
   263  		return nil, fmt.Errorf("connection to AWS Sagemaker failed with status: %d", res.StatusCode)
   264  	}
   265  
   266  	if len(resBody.Generations) == 0 {
   267  		return nil, fmt.Errorf("received empty response from AWS Sagemaker")
   268  	}
   269  
   270  	if len(resBody.Generations) > 0 && len(resBody.Generations[0].Id) > 0 {
   271  		content := resBody.Generations[0].Text
   272  		if content != "" {
   273  			return &generativemodels.GenerateResponse{
   274  				Result: &content,
   275  			}, nil
   276  		}
   277  	}
   278  	return &generativemodels.GenerateResponse{
   279  		Result: nil,
   280  	}, nil
   281  }
   282  
   283  func (v *aws) isSagemaker(service string) bool {
   284  	return service == "sagemaker"
   285  }
   286  
   287  func (v *aws) isBedrock(service string) bool {
   288  	return service == "bedrock"
   289  }
   290  
   291  func (v *aws) generatePromptForTask(textProperties []map[string]string, task string) (string, error) {
   292  	marshal, err := json.Marshal(textProperties)
   293  	if err != nil {
   294  		return "", err
   295  	}
   296  	return fmt.Sprintf(`'%v:
   297  %v`, task, string(marshal)), nil
   298  }
   299  
   300  func (v *aws) generateForPrompt(textProperties map[string]string, prompt string) (string, error) {
   301  	all := compile.FindAll([]byte(prompt), -1)
   302  	for _, match := range all {
   303  		originalProperty := string(match)
   304  		replacedProperty := compile.FindStringSubmatch(originalProperty)[1]
   305  		replacedProperty = strings.TrimSpace(replacedProperty)
   306  		value := textProperties[replacedProperty]
   307  		if value == "" {
   308  			return "", errors.Errorf("Following property has empty value: '%v'. Make sure you spell the property name correctly, verify that the property exists and has a value", replacedProperty)
   309  		}
   310  		prompt = strings.ReplaceAll(prompt, originalProperty, value)
   311  	}
   312  	return prompt, nil
   313  }
   314  
   315  func (v *aws) getAwsAccessKey(ctx context.Context) (string, error) {
   316  	awsAccessKey := ctx.Value("X-Aws-Access-Key")
   317  	if awsAccessKeyHeader, ok := awsAccessKey.([]string); ok &&
   318  		len(awsAccessKeyHeader) > 0 && len(awsAccessKeyHeader[0]) > 0 {
   319  		return awsAccessKeyHeader[0], nil
   320  	}
   321  	if len(v.awsAccessKey) > 0 {
   322  		return v.awsAccessKey, nil
   323  	}
   324  	return "", errors.New("no access key found " +
   325  		"neither in request header: X-AWS-Access-Key " +
   326  		"nor in environment variable under AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY")
   327  }
   328  
   329  func (v *aws) getAwsAccessSecret(ctx context.Context) (string, error) {
   330  	awsAccessSecret := ctx.Value("X-Aws-Secret-Key")
   331  	if awsAccessSecretHeader, ok := awsAccessSecret.([]string); ok &&
   332  		len(awsAccessSecretHeader) > 0 && len(awsAccessSecretHeader[0]) > 0 {
   333  		return awsAccessSecretHeader[0], nil
   334  	}
   335  	if len(v.awsSecretKey) > 0 {
   336  		return v.awsSecretKey, nil
   337  	}
   338  	return "", errors.New("no secret found " +
   339  		"neither in request header: X-Aws-Secret-Key " +
   340  		"nor in environment variable under AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY")
   341  }
   342  
   343  func (v *aws) isAmazonModel(model string) bool {
   344  	return strings.Contains(model, "amazon")
   345  }
   346  
   347  func (v *aws) isAI21Model(model string) bool {
   348  	return strings.Contains(model, "ai21")
   349  }
   350  
   351  func (v *aws) isAnthropicModel(model string) bool {
   352  	return strings.Contains(model, "anthropic")
   353  }
   354  
   355  func (v *aws) isCohereModel(model string) bool {
   356  	return strings.Contains(model, "cohere")
   357  }
   358  
// bedrockAmazonGenerateRequest is the JSON body Generate sends to Bedrock when
// the model ID contains "amazon". Generate sets only InputText, so
// TextGenerationConfig stays nil and is omitted from the encoding.
type bedrockAmazonGenerateRequest struct {
	InputText            string                `json:"inputText,omitempty"`
	TextGenerationConfig *textGenerationConfig `json:"textGenerationConfig,omitempty"`
}
   363  
// bedrockAnthropicGenerateRequest is the JSON body Generate sends to Bedrock
// when the model ID contains "anthropic". Prompt carries the Human/Assistant
// framed text built in Generate, and AnthropicVersion is set there to
// "bedrock-2023-05-31".
//
// NOTE(review): every field is omitempty, so a Temperature or TopK of 0 is
// dropped from the JSON and the service-side default applies instead. Confirm
// that an explicit 0 is never meant to be sent.
type bedrockAnthropicGenerateRequest struct {
	Prompt            string   `json:"prompt,omitempty"`
	MaxTokensToSample int      `json:"max_tokens_to_sample,omitempty"`
	Temperature       float64  `json:"temperature,omitempty"`
	TopK              int      `json:"top_k,omitempty"`
	TopP              *float64 `json:"top_p,omitempty"`
	StopSequences     []string `json:"stop_sequences,omitempty"`
	AnthropicVersion  string   `json:"anthropic_version,omitempty"`
}
   373  
   374  type bedrockAI21GenerateRequest struct {
   375  	Prompt           string   `json:"prompt,omitempty"`
   376  	MaxTokens        int      `json:"maxTokens,omitempty"`
   377  	Temperature      float64  `json:"temperature,omitempty"`
   378  	TopP             *float64 `json:"top_p,omitempty"`
   379  	StopSequences    []string `json:"stop_sequences,omitempty"`
   380  	CountPenalty     penalty  `json:"countPenalty,omitempty"`
   381  	PresencePenalty  penalty  `json:"presencePenalty,omitempty"`
   382  	FrequencyPenalty penalty  `json:"frequencyPenalty,omitempty"`
   383  }
// bedrockCohereRequest is the JSON body Generate sends to Bedrock when the
// model ID contains "cohere". Generate never sets ReturnLikeliHood (its
// assignment there is commented out), so that key is always omitted.
type bedrockCohereRequest struct {
	Prompt           string  `json:"prompt,omitempty"`
	MaxTokens        int     `json:"max_tokens,omitempty"`
	Temperature      float64 `json:"temperature,omitempty"`
	ReturnLikeliHood string  `json:"return_likelihood,omitempty"`
}
   390  
// penalty is the {"scale": ...} object used by the AI21 request's penalty
// fields. A zero Scale encodes as {}.
//
// NOTE(review): Scale is an int, so fractional scales cannot be expressed.
// Confirm against the AI21 request schema before relying on these fields.
type penalty struct {
	Scale int `json:"scale,omitempty"`
}
   394  
// sagemakerGenerateRequest is the JSON body Generate sends to a SageMaker
// endpoint: just the prompt text.
type sagemakerGenerateRequest struct {
	Prompt string `json:"prompt,omitempty"`
}
   398  
// textGenerationConfig is the optional "textGenerationConfig" object of
// bedrockAmazonGenerateRequest. Nothing in this file populates it, because
// Generate leaves the pointer nil. None of its fields are omitempty, so all
// four keys are serialized whenever a config is attached.
//
// NOTE(review): TopP is an int and therefore cannot hold fractional values.
// Confirm the expected type against the Amazon Titan request schema before
// using this struct.
type textGenerationConfig struct {
	MaxTokenCount int      `json:"maxTokenCount"`
	StopSequences []string `json:"stopSequences"`
	Temperature   float64  `json:"temperature"`
	TopP          int      `json:"topP"`
}
   405  
// bedrockGenerateResponse is decoded from every Bedrock invoke response.
// Results holds entries with outputText/completionReason, Generations holds
// entries with text, and a non-nil Message is treated as an error by
// parseBedrockResponse. encoding/json matches keys case-insensitively, so the
// capitalized InputTextTokenCount tag also accepts "inputTextTokenCount".
//
// NOTE(review): no field here captures a top-level "completion" or
// "completions" key, so a reply in that shape decodes as empty through this
// struct. Confirm against the Bedrock response schemas of each model family.
type bedrockGenerateResponse struct {
	InputTextTokenCount int                 `json:"InputTextTokenCount,omitempty"`
	Results             []Result            `json:"results,omitempty"`
	Generations         []BedrockGeneration `json:"generations,omitempty"`
	Message             *string             `json:"message,omitempty"`
}
   412  
// sagemakerGenerateResponse is decoded from a SageMaker invocation response.
// parseSagemakerResponse reads the first element of Generations and treats a
// non-nil Message as an error.
type sagemakerGenerateResponse struct {
	Generations []Generation `json:"generations,omitempty"`
	Message     *string      `json:"message,omitempty"`
}
   417  
// Generation is one entry of a SageMaker response's "generations" array. The
// text is only used by parseSagemakerResponse when Id is non-empty.
type Generation struct {
	Id   string `json:"id,omitempty"`
	Text string `json:"text,omitempty"`
}
   422  
// BedrockGeneration is one entry of a Bedrock response's "generations" array.
// parseBedrockResponse uses Text of the first entry. Id and FinishReason are
// decoded but not read in this file.
type BedrockGeneration struct {
	Id           string `json:"id,omitempty"`
	Text         string `json:"text,omitempty"`
	FinishReason string `json:"finish_reason,omitempty"`
}
   428  
// Result is one entry of a Bedrock response's "results" array.
// parseBedrockResponse uses OutputText of the first entry only when its
// CompletionReason is non-empty. TokenCount is decoded but not read in this
// file.
type Result struct {
	TokenCount       int    `json:"tokenCount,omitempty"`
	OutputText       string `json:"outputText,omitempty"`
	CompletionReason string `json:"completionReason,omitempty"`
}