github.com/weaviate/weaviate@v1.24.6/modules/generative-aws/clients/aws.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "bytes" 16 "context" 17 "encoding/json" 18 "fmt" 19 "io" 20 "net/http" 21 "regexp" 22 "strings" 23 "time" 24 25 "github.com/pkg/errors" 26 "github.com/sirupsen/logrus" 27 "github.com/weaviate/weaviate/entities/moduletools" 28 "github.com/weaviate/weaviate/modules/generative-aws/config" 29 generativemodels "github.com/weaviate/weaviate/usecases/modulecomponents/additional/models" 30 ) 31 32 var compile, _ = regexp.Compile(`{([\w\s]*?)}`) 33 34 func buildBedrockUrl(service, region, model string) string { 35 urlTemplate := "https://%s.%s.amazonaws.com/model/%s/invoke" 36 return fmt.Sprintf(urlTemplate, fmt.Sprintf("%s-runtime", service), region, model) 37 } 38 39 func buildSagemakerUrl(service, region, endpoint string) string { 40 urlTemplate := "https://runtime.%s.%s.amazonaws.com/endpoints/%s/invocations" 41 return fmt.Sprintf(urlTemplate, service, region, endpoint) 42 } 43 44 type aws struct { 45 awsAccessKey string 46 awsSecretKey string 47 buildBedrockUrlFn func(service, region, model string) string 48 buildSagemakerUrlFn func(service, region, endpoint string) string 49 httpClient *http.Client 50 logger logrus.FieldLogger 51 } 52 53 func New(awsAccessKey string, awsSecretKey string, timeout time.Duration, logger logrus.FieldLogger) *aws { 54 return &aws{ 55 awsAccessKey: awsAccessKey, 56 awsSecretKey: awsSecretKey, 57 httpClient: &http.Client{ 58 Timeout: timeout, 59 }, 60 buildBedrockUrlFn: buildBedrockUrl, 61 buildSagemakerUrlFn: buildSagemakerUrl, 62 logger: logger, 63 } 64 } 65 66 func (v *aws) GenerateSingleResult(ctx context.Context, 
textProperties map[string]string, prompt string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) { 67 forPrompt, err := v.generateForPrompt(textProperties, prompt) 68 if err != nil { 69 return nil, err 70 } 71 return v.Generate(ctx, cfg, forPrompt) 72 } 73 74 func (v *aws) GenerateAllResults(ctx context.Context, textProperties []map[string]string, task string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) { 75 forTask, err := v.generatePromptForTask(textProperties, task) 76 if err != nil { 77 return nil, err 78 } 79 return v.Generate(ctx, cfg, forTask) 80 } 81 82 func (v *aws) Generate(ctx context.Context, cfg moduletools.ClassConfig, prompt string) (*generativemodels.GenerateResponse, error) { 83 settings := config.NewClassSettings(cfg) 84 service := settings.Service() 85 region := settings.Region() 86 model := settings.Model() 87 endpoint := settings.Endpoint() 88 targetModel := settings.TargetModel() 89 targetVariant := settings.TargetVariant() 90 91 var body []byte 92 var endpointUrl string 93 var host string 94 var path string 95 var err error 96 97 headers := map[string]string{ 98 "accept": "*/*", 99 "content-type": contentType, 100 } 101 102 if v.isBedrock(service) { 103 endpointUrl = v.buildBedrockUrlFn(service, region, model) 104 host = service + "-runtime" + "." 
+ region + ".amazonaws.com" 105 path = "/model/" + model + "/invoke" 106 107 if v.isAmazonModel(model) { 108 body, err = json.Marshal(bedrockAmazonGenerateRequest{ 109 InputText: prompt, 110 }) 111 } else if v.isAnthropicModel(model) { 112 var builder strings.Builder 113 builder.WriteString("\n\nHuman: ") 114 builder.WriteString(prompt) 115 builder.WriteString("\n\nAssistant:") 116 body, err = json.Marshal(bedrockAnthropicGenerateRequest{ 117 Prompt: builder.String(), 118 MaxTokensToSample: *settings.MaxTokenCount(), 119 Temperature: *settings.Temperature(), 120 TopK: *settings.TopK(), 121 TopP: settings.TopP(), 122 StopSequences: settings.StopSequences(), 123 AnthropicVersion: "bedrock-2023-05-31", 124 }) 125 } else if v.isAI21Model(model) { 126 body, err = json.Marshal(bedrockAI21GenerateRequest{ 127 Prompt: prompt, 128 MaxTokens: *settings.MaxTokenCount(), 129 Temperature: *settings.Temperature(), 130 TopP: settings.TopP(), 131 StopSequences: settings.StopSequences(), 132 }) 133 } else if v.isCohereModel(model) { 134 body, err = json.Marshal(bedrockCohereRequest{ 135 Prompt: prompt, 136 Temperature: *settings.Temperature(), 137 MaxTokens: *settings.MaxTokenCount(), 138 // ReturnLikeliHood: "GENERATION", // contray to docs, this is invalid 139 }) 140 } 141 142 headers["x-amzn-bedrock-save"] = "false" 143 if err != nil { 144 return nil, errors.Wrapf(err, "marshal body") 145 } 146 } else if v.isSagemaker(service) { 147 endpointUrl = v.buildSagemakerUrlFn(service, region, endpoint) 148 host = "runtime." + service + "." 
+ region + ".amazonaws.com" 149 path = "/endpoints/" + endpoint + "/invocations" 150 if targetModel != "" { 151 headers["x-amzn-sagemaker-target-model"] = targetModel 152 } 153 if targetVariant != "" { 154 headers["x-amzn-sagemaker-target-variant"] = targetVariant 155 } 156 body, err = json.Marshal(sagemakerGenerateRequest{ 157 Prompt: prompt, 158 }) 159 if err != nil { 160 return nil, errors.Wrapf(err, "marshal body") 161 } 162 } else { 163 return nil, errors.Wrapf(err, "service error") 164 } 165 166 accessKey, err := v.getAwsAccessKey(ctx) 167 if err != nil { 168 return nil, errors.Wrapf(err, "AWS Access Key") 169 } 170 secretKey, err := v.getAwsAccessSecret(ctx) 171 if err != nil { 172 return nil, errors.Wrapf(err, "AWS Secret Key") 173 } 174 175 headers["host"] = host 176 amzDate, headers, authorizationHeader := getAuthHeader(accessKey, secretKey, host, service, region, path, body, headers) 177 headers["Authorization"] = authorizationHeader 178 headers["x-amz-date"] = amzDate 179 180 req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointUrl, bytes.NewReader(body)) 181 if err != nil { 182 return nil, errors.Wrap(err, "create POST request") 183 } 184 185 for k, v := range headers { 186 req.Header.Set(k, v) 187 } 188 189 res, err := v.httpClient.Do(req) 190 if err != nil { 191 return nil, errors.Wrap(err, "send POST request") 192 } 193 defer res.Body.Close() 194 195 bodyBytes, err := io.ReadAll(res.Body) 196 if err != nil { 197 return nil, errors.Wrap(err, "read response body") 198 } 199 200 if v.isBedrock(service) { 201 return v.parseBedrockResponse(bodyBytes, res) 202 } else if v.isSagemaker(service) { 203 return v.parseSagemakerResponse(bodyBytes, res) 204 } else { 205 return &generativemodels.GenerateResponse{ 206 Result: nil, 207 }, nil 208 } 209 } 210 211 func (v *aws) parseBedrockResponse(bodyBytes []byte, res *http.Response) (*generativemodels.GenerateResponse, error) { 212 var resBodyMap map[string]interface{} 213 if err := 
json.Unmarshal(bodyBytes, &resBodyMap); err != nil { 214 return nil, errors.Wrap(err, "unmarshal response body") 215 } 216 217 var resBody bedrockGenerateResponse 218 if err := json.Unmarshal(bodyBytes, &resBody); err != nil { 219 return nil, errors.Wrap(err, "unmarshal response body") 220 } 221 222 if res.StatusCode != 200 || resBody.Message != nil { 223 if resBody.Message != nil { 224 return nil, fmt.Errorf("connection to AWS Bedrock failed with status: %v error: %s", 225 res.StatusCode, *resBody.Message) 226 } 227 return nil, fmt.Errorf("connection to AWS Bedrock failed with status: %d", res.StatusCode) 228 } 229 230 if len(resBody.Results) == 0 && len(resBody.Generations) == 0 { 231 return nil, fmt.Errorf("received empty response from AWS Bedrock") 232 } 233 234 var content string 235 if len(resBody.Results) > 0 && len(resBody.Results[0].CompletionReason) > 0 { 236 content = resBody.Results[0].OutputText 237 } else if len(resBody.Generations) > 0 { 238 content = resBody.Generations[0].Text 239 } 240 241 if content != "" { 242 return &generativemodels.GenerateResponse{ 243 Result: &content, 244 }, nil 245 } 246 247 return &generativemodels.GenerateResponse{ 248 Result: nil, 249 }, nil 250 } 251 252 func (v *aws) parseSagemakerResponse(bodyBytes []byte, res *http.Response) (*generativemodels.GenerateResponse, error) { 253 var resBody sagemakerGenerateResponse 254 if err := json.Unmarshal(bodyBytes, &resBody); err != nil { 255 return nil, errors.Wrap(err, "unmarshal response body") 256 } 257 258 if res.StatusCode != 200 || resBody.Message != nil { 259 if resBody.Message != nil { 260 return nil, fmt.Errorf("connection to AWS Sagemaker failed with status: %v error: %s", 261 res.StatusCode, *resBody.Message) 262 } 263 return nil, fmt.Errorf("connection to AWS Sagemaker failed with status: %d", res.StatusCode) 264 } 265 266 if len(resBody.Generations) == 0 { 267 return nil, fmt.Errorf("received empty response from AWS Sagemaker") 268 } 269 270 if 
len(resBody.Generations) > 0 && len(resBody.Generations[0].Id) > 0 { 271 content := resBody.Generations[0].Text 272 if content != "" { 273 return &generativemodels.GenerateResponse{ 274 Result: &content, 275 }, nil 276 } 277 } 278 return &generativemodels.GenerateResponse{ 279 Result: nil, 280 }, nil 281 } 282 283 func (v *aws) isSagemaker(service string) bool { 284 return service == "sagemaker" 285 } 286 287 func (v *aws) isBedrock(service string) bool { 288 return service == "bedrock" 289 } 290 291 func (v *aws) generatePromptForTask(textProperties []map[string]string, task string) (string, error) { 292 marshal, err := json.Marshal(textProperties) 293 if err != nil { 294 return "", err 295 } 296 return fmt.Sprintf(`'%v: 297 %v`, task, string(marshal)), nil 298 } 299 300 func (v *aws) generateForPrompt(textProperties map[string]string, prompt string) (string, error) { 301 all := compile.FindAll([]byte(prompt), -1) 302 for _, match := range all { 303 originalProperty := string(match) 304 replacedProperty := compile.FindStringSubmatch(originalProperty)[1] 305 replacedProperty = strings.TrimSpace(replacedProperty) 306 value := textProperties[replacedProperty] 307 if value == "" { 308 return "", errors.Errorf("Following property has empty value: '%v'. 
Make sure you spell the property name correctly, verify that the property exists and has a value", replacedProperty) 309 } 310 prompt = strings.ReplaceAll(prompt, originalProperty, value) 311 } 312 return prompt, nil 313 } 314 315 func (v *aws) getAwsAccessKey(ctx context.Context) (string, error) { 316 awsAccessKey := ctx.Value("X-Aws-Access-Key") 317 if awsAccessKeyHeader, ok := awsAccessKey.([]string); ok && 318 len(awsAccessKeyHeader) > 0 && len(awsAccessKeyHeader[0]) > 0 { 319 return awsAccessKeyHeader[0], nil 320 } 321 if len(v.awsAccessKey) > 0 { 322 return v.awsAccessKey, nil 323 } 324 return "", errors.New("no access key found " + 325 "neither in request header: X-AWS-Access-Key " + 326 "nor in environment variable under AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY") 327 } 328 329 func (v *aws) getAwsAccessSecret(ctx context.Context) (string, error) { 330 awsAccessSecret := ctx.Value("X-Aws-Secret-Key") 331 if awsAccessSecretHeader, ok := awsAccessSecret.([]string); ok && 332 len(awsAccessSecretHeader) > 0 && len(awsAccessSecretHeader[0]) > 0 { 333 return awsAccessSecretHeader[0], nil 334 } 335 if len(v.awsSecretKey) > 0 { 336 return v.awsSecretKey, nil 337 } 338 return "", errors.New("no secret found " + 339 "neither in request header: X-Aws-Secret-Key " + 340 "nor in environment variable under AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY") 341 } 342 343 func (v *aws) isAmazonModel(model string) bool { 344 return strings.Contains(model, "amazon") 345 } 346 347 func (v *aws) isAI21Model(model string) bool { 348 return strings.Contains(model, "ai21") 349 } 350 351 func (v *aws) isAnthropicModel(model string) bool { 352 return strings.Contains(model, "anthropic") 353 } 354 355 func (v *aws) isCohereModel(model string) bool { 356 return strings.Contains(model, "cohere") 357 } 358 359 type bedrockAmazonGenerateRequest struct { 360 InputText string `json:"inputText,omitempty"` 361 TextGenerationConfig *textGenerationConfig `json:"textGenerationConfig,omitempty"` 362 } 363 364 
// Wire types for the JSON requests sent to, and the responses received from,
// AWS Bedrock and Sagemaker. Field order and JSON tags define the wire format
// and must not be changed casually.
type (
	// bedrockAnthropicGenerateRequest is the JSON body for Bedrock models
	// whose name contains "anthropic". Zero-valued fields are omitted.
	bedrockAnthropicGenerateRequest struct {
		Prompt            string   `json:"prompt,omitempty"`
		MaxTokensToSample int      `json:"max_tokens_to_sample,omitempty"`
		Temperature       float64  `json:"temperature,omitempty"`
		TopK              int      `json:"top_k,omitempty"`
		TopP              *float64 `json:"top_p,omitempty"`
		StopSequences     []string `json:"stop_sequences,omitempty"`
		AnthropicVersion  string   `json:"anthropic_version,omitempty"`
	}

	// bedrockAI21GenerateRequest is the JSON body for Bedrock models whose
	// name contains "ai21".
	//
	// The three penalty fields are struct values, for which encoding/json's
	// omitempty has no effect, so they are always emitted (as `{}` when the
	// scale is zero).
	//
	// NOTE(review): the snake_case tags top_p / stop_sequences sit next to
	// camelCase maxTokens / countPenalty — verify them against the provider's
	// request schema.
	bedrockAI21GenerateRequest struct {
		Prompt           string   `json:"prompt,omitempty"`
		MaxTokens        int      `json:"maxTokens,omitempty"`
		Temperature      float64  `json:"temperature,omitempty"`
		TopP             *float64 `json:"top_p,omitempty"`
		StopSequences    []string `json:"stop_sequences,omitempty"`
		CountPenalty     penalty  `json:"countPenalty,omitempty"`
		PresencePenalty  penalty  `json:"presencePenalty,omitempty"`
		FrequencyPenalty penalty  `json:"frequencyPenalty,omitempty"`
	}

	// bedrockCohereRequest is the JSON body for Bedrock models whose name
	// contains "cohere".
	bedrockCohereRequest struct {
		Prompt           string  `json:"prompt,omitempty"`
		MaxTokens        int     `json:"max_tokens,omitempty"`
		Temperature      float64 `json:"temperature,omitempty"`
		ReturnLikeliHood string  `json:"return_likelihood,omitempty"`
	}

	// penalty is the nested object used by the penalty fields of
	// bedrockAI21GenerateRequest.
	penalty struct {
		Scale int `json:"scale,omitempty"`
	}

	// sagemakerGenerateRequest is the JSON body sent to a Sagemaker endpoint.
	sagemakerGenerateRequest struct {
		Prompt string `json:"prompt,omitempty"`
	}

	// textGenerationConfig holds optional generation parameters for
	// bedrockAmazonGenerateRequest. None of its fields use omitempty, so all
	// of them are emitted whenever the config itself is present.
	//
	// NOTE(review): TopP is an int here but a float elsewhere in this file —
	// confirm the intended type.
	textGenerationConfig struct {
		MaxTokenCount int      `json:"maxTokenCount"`
		StopSequences []string `json:"stopSequences"`
		Temperature   float64  `json:"temperature"`
		TopP          int      `json:"topP"`
	}

	// bedrockGenerateResponse is the union of the Bedrock response shapes this
	// client decodes: "results", "generations", and an optional "message".
	bedrockGenerateResponse struct {
		InputTextTokenCount int                 `json:"InputTextTokenCount,omitempty"`
		Results             []Result            `json:"results,omitempty"`
		Generations         []BedrockGeneration `json:"generations,omitempty"`
		Message             *string             `json:"message,omitempty"`
	}

	// sagemakerGenerateResponse is the Sagemaker response shape this client
	// decodes: "generations" and an optional "message".
	sagemakerGenerateResponse struct {
		Generations []Generation `json:"generations,omitempty"`
		Message     *string      `json:"message,omitempty"`
	}

	// Generation is one entry of a Sagemaker response's "generations" array.
	Generation struct {
		Id   string `json:"id,omitempty"`
		Text string `json:"text,omitempty"`
	}

	// BedrockGeneration is one entry of a Bedrock response's "generations"
	// array.
	BedrockGeneration struct {
		Id           string `json:"id,omitempty"`
		Text         string `json:"text,omitempty"`
		FinishReason string `json:"finish_reason,omitempty"`
	}

	// Result is one entry of a Bedrock response's "results" array.
	Result struct {
		TokenCount       int    `json:"tokenCount,omitempty"`
		OutputText       string `json:"outputText,omitempty"`
		CompletionReason string `json:"completionReason,omitempty"`
	}
)