github.com/weaviate/weaviate@v1.24.6/modules/text2vec-aws/clients/aws.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "bytes" 16 "context" 17 "encoding/json" 18 "fmt" 19 "io" 20 "net/http" 21 "net/url" 22 "strings" 23 "time" 24 25 "github.com/google/uuid" 26 "github.com/pkg/errors" 27 "github.com/sirupsen/logrus" 28 "github.com/weaviate/weaviate/modules/text2vec-aws/ent" 29 "github.com/weaviate/weaviate/usecases/modulecomponents" 30 ) 31 32 type operationType string 33 34 var ( 35 vectorizeObject operationType = "vectorize_object" 36 vectorizeQuery operationType = "vectorize_query" 37 ) 38 39 func buildBedrockUrl(service, region, model string) string { 40 serviceName := service 41 if strings.HasPrefix(model, "cohere") { 42 serviceName = fmt.Sprintf("%s-runtime", serviceName) 43 } 44 urlTemplate := "https://%s.%s.amazonaws.com/model/%s/invoke" 45 return fmt.Sprintf(urlTemplate, serviceName, region, model) 46 } 47 48 func buildSagemakerUrl(service, region, endpoint string) string { 49 urlTemplate := "https://runtime.%s.%s.amazonaws.com/endpoints/%s/invocations" 50 return fmt.Sprintf(urlTemplate, service, region, endpoint) 51 } 52 53 type aws struct { 54 awsAccessKey string 55 awsSecret string 56 buildBedrockUrlFn func(service, region, model string) string 57 buildSagemakerUrlFn func(service, region, endpoint string) string 58 httpClient *http.Client 59 logger logrus.FieldLogger 60 } 61 62 func New(awsAccessKey string, awsSecret string, timeout time.Duration, logger logrus.FieldLogger) *aws { 63 return &aws{ 64 awsAccessKey: awsAccessKey, 65 awsSecret: awsSecret, 66 httpClient: &http.Client{ 67 Timeout: timeout, 68 }, 69 buildBedrockUrlFn: buildBedrockUrl, 70 buildSagemakerUrlFn: buildSagemakerUrl, 71 logger: logger, 72 } 73 } 74 75 func (v *aws) Vectorize(ctx context.Context, input []string, 76 config ent.VectorizationConfig, 77 ) (*ent.VectorizationResult, error) { 78 return v.vectorize(ctx, input, vectorizeObject, config) 79 } 80 81 func (v *aws) VectorizeQuery(ctx context.Context, input []string, 82 config ent.VectorizationConfig, 83 ) (*ent.VectorizationResult, error) { 84 return v.vectorize(ctx, input, vectorizeQuery, config) 85 } 86 87 func (v *aws) vectorize(ctx context.Context, input []string, operation operationType, config ent.VectorizationConfig) (*ent.VectorizationResult, error) { 88 service := v.getService(config) 89 region := v.getRegion(config) 90 model := v.getModel(config) 91 endpoint := v.getEndpoint(config) 92 targetModel := v.getTargetModel(config) 93 targetVariant := v.getTargetVariant(config) 94 95 var body []byte 96 var endpointUrl string 97 var host string 98 var path string 99 var err error 100 101 headers := map[string]string{ 102 "accept": "*/*", 103 "content-type": contentType, 104 } 105 106 if v.isBedrock(service) { 107 endpointUrl = v.buildBedrockUrlFn(service, region, model) 108 host, path, _ = extractHostAndPath(endpointUrl) 109 110 req, err := createRequestBody(model, input, operation) 111 if err != nil { 112 return nil, err 113 } 114 115 body, err = json.Marshal(req) 116 if err != nil { 117 return nil, errors.Wrapf(err, "marshal body") 118 } 119 } else if v.isSagemaker(service) { 120 endpointUrl = v.buildSagemakerUrlFn(service, region, endpoint) 121 host = "runtime." + service + "." + region + ".amazonaws.com" 122 path = "/endpoints/" + endpoint + "/invocations" 123 if targetModel != "" { 124 headers["x-amzn-sagemaker-target-model"] = targetModel 125 } 126 if targetVariant != "" { 127 headers["x-amzn-sagemaker-target-variant"] = targetVariant 128 } 129 body, err = json.Marshal(sagemakerEmbeddingsRequest{ 130 TextInputs: input, 131 }) 132 if err != nil { 133 return nil, errors.Wrapf(err, "marshal body") 134 } 135 } else { 136 return nil, errors.Wrapf(err, "service error") 137 } 138 139 accessKey, err := v.getAwsAccessKey(ctx) 140 if err != nil { 141 return nil, errors.Wrapf(err, "AWS Access Key") 142 } 143 secretKey, err := v.getAwsAccessSecret(ctx) 144 if err != nil { 145 return nil, errors.Wrapf(err, "AWS Secret Key") 146 } 147 148 headers["host"] = host 149 amzDate, headers, authorizationHeader := getAuthHeader(accessKey, secretKey, host, service, region, path, body, headers) 150 headers["Authorization"] = authorizationHeader 151 headers["x-amz-date"] = amzDate 152 153 req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpointUrl, bytes.NewReader(body)) 154 if err != nil { 155 return nil, errors.Wrap(err, "create POST request") 156 } 157 158 for k, v := range headers { 159 req.Header.Set(k, v) 160 } 161 162 res, err := v.makeRequest(req, 30, 5) 163 if err != nil { 164 return nil, errors.Wrap(err, "send POST request") 165 } 166 defer res.Body.Close() 167 168 bodyBytes, err := io.ReadAll(res.Body) 169 if err != nil { 170 return nil, errors.Wrap(err, "read response body") 171 } 172 if v.isBedrock(service) { 173 return v.parseBedrockResponse(bodyBytes, res, input) 174 } else { 175 return v.parseSagemakerResponse(bodyBytes, res, input) 176 } 177 } 178 179 func (v *aws) makeRequest(req *http.Request, delayInSeconds int, maxRetries int) (*http.Response, error) { 180 var res *http.Response 181 var err error 182 183 // Generate a UUID for this request 184 requestID := uuid.New().String() 185 186 for i := 0; i < maxRetries; i++ { 187 res, err = v.httpClient.Do(req) 188 if err != nil { 189 return nil, errors.Wrap(err, "send POST request") 190 } 191 192 // If the status code is not 429 or 400, break the loop 193 if res.StatusCode != http.StatusTooManyRequests && res.StatusCode != http.StatusBadRequest { 194 break 195 } 196 197 v.logger.Debugf("Request ID %s to %s returned 429, retrying in %d seconds", requestID, req.URL, delayInSeconds) 198 199 // Sleep for a while and then continue to the next iteration 200 time.Sleep(time.Duration(delayInSeconds) * time.Second) 201 202 // Double the delay for the next iteration 203 delayInSeconds *= 2 204 205 } 206 207 return res, err 208 } 209 210 func (v *aws) parseBedrockResponse(bodyBytes []byte, res *http.Response, input []string) (*ent.VectorizationResult, error) { 211 var resBodyMap map[string]interface{} 212 if err := json.Unmarshal(bodyBytes, &resBodyMap); err != nil { 213 return nil, errors.Wrap(err, "unmarshal response body") 214 } 215 216 // if resBodyMap has inputTextTokenCount, it's a resonse from an Amazon model 217 // otherwise, it is a response from a Cohere model 218 var resBody bedrockEmbeddingResponse 219 if err := json.Unmarshal(bodyBytes, &resBody); err != nil { 220 return nil, errors.Wrap(err, "unmarshal response body") 221 } 222 223 if res.StatusCode != 200 || resBody.Message != nil { 224 if resBody.Message != nil { 225 return nil, fmt.Errorf("connection to AWS Bedrock failed with status: %v error: %s", 226 res.StatusCode, *resBody.Message) 227 } 228 return nil, fmt.Errorf("connection to AWS Bedrock failed with status: %d", res.StatusCode) 229 } 230 231 if len(resBody.Embedding) == 0 && len(resBody.Embeddings) == 0 { 232 return nil, fmt.Errorf("could not obtain vector from AWS Bedrock") 233 } 234 235 embedding := resBody.Embedding 236 if len(resBody.Embeddings) > 0 { 237 embedding = resBody.Embeddings[0] 238 } 239 240 return &ent.VectorizationResult{ 241 Text: input[0], 242 Dimensions: len(embedding), 243 Vector: embedding, 244 }, nil 245 } 246 247 func (v *aws) parseSagemakerResponse(bodyBytes []byte, res *http.Response, input []string) (*ent.VectorizationResult, error) { 248 var resBody sagemakerEmbeddingResponse 249 if err := json.Unmarshal(bodyBytes, &resBody); err != nil { 250 return nil, errors.Wrap(err, "unmarshal response body") 251 } 252 253 if res.StatusCode != 200 || resBody.Message != nil { 254 if resBody.Message != nil { 255 return nil, fmt.Errorf("connection to AWS failed with status: %v error: %s", 256 res.StatusCode, *resBody.Message) 257 } 258 return nil, fmt.Errorf("connection to AWS failed with status: %d", res.StatusCode) 259 } 260 261 if len(resBody.Embedding) == 0 { 262 return nil, errors.Errorf("empty embeddings response") 263 } 264 265 return &ent.VectorizationResult{ 266 Text: input[0], 267 Dimensions: len(resBody.Embedding[0]), 268 Vector: resBody.Embedding[0], 269 }, nil 270 } 271 272 func (v *aws) isSagemaker(service string) bool { 273 return service == "sagemaker" 274 } 275 276 func (v *aws) isBedrock(service string) bool { 277 return service == "bedrock" 278 } 279 280 func (v *aws) getAwsAccessKey(ctx context.Context) (string, error) { 281 awsAccessKey := ctx.Value("X-Aws-Access-Key") 282 if awsAccessKeyHeader, ok := awsAccessKey.([]string); ok && 283 len(awsAccessKeyHeader) > 0 && len(awsAccessKeyHeader[0]) > 0 { 284 return awsAccessKeyHeader[0], nil 285 } 286 if len(v.awsAccessKey) > 0 { 287 return v.awsAccessKey, nil 288 } 289 // try getting header from GRPC if not successful 290 if accessKey := modulecomponents.GetValueFromGRPC(ctx, "X-Aws-Access-Key"); len(accessKey) > 0 && len(accessKey[0]) > 0 { 291 return accessKey[0], nil 292 } 293 return "", errors.New("no access key found " + 294 "neither in request header: X-AWS-Access-Key " + 295 "nor in environment variable under AWS_ACCESS_KEY_ID or AWS_ACCESS_KEY") 296 } 297 298 func (v *aws) getAwsAccessSecret(ctx context.Context) (string, error) { 299 awsSecretKey := ctx.Value("X-Aws-Secret-Key") 300 if awsAccessSecretHeader, ok := awsSecretKey.([]string); ok && 301 len(awsAccessSecretHeader) > 0 && len(awsAccessSecretHeader[0]) > 0 { 302 return awsAccessSecretHeader[0], nil 303 } 304 if len(v.awsSecret) > 0 { 305 return v.awsSecret, nil 306 } 307 // try getting header from GRPC if not successful 308 if secretKey := modulecomponents.GetValueFromGRPC(ctx, "X-Aws-Secret-Key"); len(secretKey) > 0 && len(secretKey[0]) > 0 { 309 return secretKey[0], nil 310 } 311 return "", errors.New("no secret found " + 312 "neither in request header: X-AWS-Secret-Key " + 313 "nor in environment variable under AWS_SECRET_ACCESS_KEY or AWS_SECRET_KEY") 314 } 315 316 func (v *aws) getModel(config ent.VectorizationConfig) string { 317 return config.Model 318 } 319 320 func (v *aws) getRegion(config ent.VectorizationConfig) string { 321 return config.Region 322 } 323 324 func (v *aws) getService(config ent.VectorizationConfig) string { 325 return config.Service 326 } 327 328 func (v *aws) getEndpoint(config ent.VectorizationConfig) string { 329 return config.Endpoint 330 } 331 332 func (v *aws) getTargetModel(config ent.VectorizationConfig) string { 333 return config.TargetModel 334 } 335 336 func (v *aws) getTargetVariant(config ent.VectorizationConfig) string { 337 return config.TargetVariant 338 } 339 340 type bedrockEmbeddingsRequest struct { 341 InputText string `json:"inputText,omitempty"` 342 } 343 344 type bedrockCohereEmbeddingRequest struct { 345 Texts []string `json:"texts"` 346 InputType string `json:"input_type"` 347 } 348 349 type sagemakerEmbeddingsRequest struct { 350 TextInputs []string `json:"text_inputs,omitempty"` 351 } 352 353 type bedrockEmbeddingResponse struct { 354 InputTextTokenCount int `json:"InputTextTokenCount,omitempty"` 355 Embedding []float32 `json:"embedding,omitempty"` 356 Embeddings [][]float32 `json:"embeddings,omitempty"` 357 Message *string `json:"message,omitempty"` 358 } 359 type sagemakerEmbeddingResponse struct { 360 Embedding [][]float32 `json:"embedding,omitempty"` 361 ErrorCode *string `json:"ErrorCode,omitempty"` 362 LogStreamArn *string `json:"LogStreamArn,omitempty"` 363 OriginalMessage *string `json:"OriginalMessage,omitempty"` 364 Message *string `json:"Message,omitempty"` 365 OriginalStatusCode *int `json:"OriginalStatusCode,omitempty"` 366 } 367 368 func extractHostAndPath(endpointUrl string) (string, string, error) { 369 u, err := url.Parse(endpointUrl) 370 if err != nil { 371 return "", "", err 372 } 373 374 if u.Host == "" || u.Path == "" { 375 return "", "", fmt.Errorf("invalid endpoint URL: %s", endpointUrl) 376 } 377 378 return u.Host, u.Path, nil 379 } 380 381 func createRequestBody(model string, texts []string, operation operationType) (interface{}, error) { 382 modelParts := strings.Split(model, ".") 383 if len(modelParts) == 0 { 384 return nil, fmt.Errorf("invalid model: %s", model) 385 } 386 387 modelProvider := modelParts[0] 388 389 switch modelProvider { 390 case "amazon": 391 return bedrockEmbeddingsRequest{ 392 InputText: texts[0], 393 }, nil 394 case "cohere": 395 inputType := "search_document" 396 if operation == vectorizeQuery { 397 inputType = "search_query" 398 } 399 return bedrockCohereEmbeddingRequest{ 400 Texts: texts, 401 InputType: inputType, 402 }, nil 403 default: 404 return nil, fmt.Errorf("unknown model provider: %s", modelProvider) 405 } 406 }