github.com/weaviate/weaviate@v1.24.6/modules/multi2vec-palm/clients/palm.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "bytes" 16 "context" 17 "encoding/json" 18 "fmt" 19 "io" 20 "net/http" 21 "time" 22 23 "github.com/weaviate/weaviate/usecases/modulecomponents" 24 25 "github.com/pkg/errors" 26 "github.com/sirupsen/logrus" 27 "github.com/weaviate/weaviate/modules/multi2vec-palm/ent" 28 libvectorizer "github.com/weaviate/weaviate/usecases/vectorizer" 29 ) 30 31 func buildURL(location, projectID, model string) string { 32 return fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predict", 33 location, projectID, location, model) 34 } 35 36 type palm struct { 37 apiKey string 38 httpClient *http.Client 39 urlBuilderFn func(location, projectID, model string) string 40 logger logrus.FieldLogger 41 } 42 43 func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *palm { 44 return &palm{ 45 apiKey: apiKey, 46 httpClient: &http.Client{ 47 Timeout: timeout, 48 }, 49 urlBuilderFn: buildURL, 50 logger: logger, 51 } 52 } 53 54 func (v *palm) Vectorize(ctx context.Context, 55 texts, images, videos []string, config ent.VectorizationConfig, 56 ) (*ent.VectorizationResult, error) { 57 return v.vectorize(ctx, texts, images, videos, config) 58 } 59 60 func (v *palm) VectorizeQuery(ctx context.Context, input []string, 61 config ent.VectorizationConfig, 62 ) (*ent.VectorizationResult, error) { 63 return v.vectorize(ctx, input, nil, nil, config) 64 } 65 66 func (v *palm) vectorize(ctx context.Context, 67 texts, images, videos []string, config ent.VectorizationConfig, 68 ) (*ent.VectorizationResult, error) { 69 var textEmbeddings [][]float32 70 var imageEmbeddings [][]float32 71 var videoEmbeddings [][]float32 72 endpointURL := v.getURL(config) 73 maxCount := max(len(texts), len(images), len(videos)) 74 for i := 0; i < maxCount; i++ { 75 text := v.safelyGet(texts, i) 76 image := v.safelyGet(images, i) 77 video := v.safelyGet(videos, i) 78 payload := v.getPayload(text, image, video, config) 79 statusCode, res, err := v.sendRequest(ctx, endpointURL, payload) 80 if err != nil { 81 return nil, err 82 } 83 textVectors, imageVectors, videoVectors, err := v.getEmbeddingsFromResponse(statusCode, res) 84 if err != nil { 85 return nil, err 86 } 87 textEmbeddings = append(textEmbeddings, textVectors...) 88 imageEmbeddings = append(imageEmbeddings, imageVectors...) 89 videoEmbeddings = append(videoEmbeddings, videoVectors...) 90 } 91 92 return v.getResponse(textEmbeddings, imageEmbeddings, videoEmbeddings) 93 } 94 95 func (v *palm) safelyGet(input []string, i int) string { 96 if i < len(input) { 97 return input[i] 98 } 99 return "" 100 } 101 102 func (v *palm) sendRequest(ctx context.Context, 103 endpointURL string, payload embeddingsRequest, 104 ) (int, embeddingsResponse, error) { 105 body, err := json.Marshal(payload) 106 if err != nil { 107 return 0, embeddingsResponse{}, errors.Wrapf(err, "marshal body") 108 } 109 110 req, err := http.NewRequestWithContext(ctx, "POST", endpointURL, 111 bytes.NewReader(body)) 112 if err != nil { 113 return 0, embeddingsResponse{}, errors.Wrap(err, "create POST request") 114 } 115 116 apiKey, err := v.getApiKey(ctx) 117 if err != nil { 118 return 0, embeddingsResponse{}, errors.Wrapf(err, "Google API Key") 119 } 120 req.Header.Add("Content-Type", "application/json; charset=utf-8") 121 req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", apiKey)) 122 123 res, err := v.httpClient.Do(req) 124 if err != nil { 125 return 0, embeddingsResponse{}, errors.Wrap(err, "send POST request") 126 } 127 defer res.Body.Close() 128 129 bodyBytes, err := io.ReadAll(res.Body) 130 if err != nil { 131 return 0, embeddingsResponse{}, errors.Wrap(err, "read response body") 132 } 133 134 var resBody embeddingsResponse 135 if err := json.Unmarshal(bodyBytes, &resBody); err != nil { 136 return 0, embeddingsResponse{}, errors.Wrap(err, "unmarshal response body") 137 } 138 139 return res.StatusCode, resBody, nil 140 } 141 142 func (v *palm) getURL(config ent.VectorizationConfig) string { 143 return v.urlBuilderFn(config.Location, config.ProjectID, config.Model) 144 } 145 146 func (v *palm) getPayload(text, img, vid string, config ent.VectorizationConfig) embeddingsRequest { 147 inst := instance{} 148 if text != "" { 149 inst.Text = &text 150 } 151 if img != "" { 152 inst.Image = &image{BytesBase64Encoded: img} 153 } 154 if vid != "" { 155 inst.Video = &video{ 156 BytesBase64Encoded: vid, 157 VideoSegmentConfig: videoSegmentConfig{IntervalSec: &config.VideoIntervalSeconds}, 158 } 159 } 160 return embeddingsRequest{ 161 Instances: []instance{inst}, 162 Parameters: parameters{Dimension: config.Dimensions}, 163 } 164 } 165 166 func (v *palm) checkResponse(statusCode int, palmApiError *palmApiError) error { 167 if statusCode != 200 || palmApiError != nil { 168 if palmApiError != nil { 169 return fmt.Errorf("connection to Google failed with status: %v error: %v", 170 statusCode, palmApiError.Message) 171 } 172 return fmt.Errorf("connection to Google failed with status: %d", statusCode) 173 } 174 return nil 175 } 176 177 func (v *palm) getApiKey(ctx context.Context) (string, error) { 178 if apiKeyValue := v.getValueFromContext(ctx, "X-Google-Api-Key"); apiKeyValue != "" { 179 return apiKeyValue, nil 180 } 181 if apiKeyValue := v.getValueFromContext(ctx, "X-Palm-Api-Key"); apiKeyValue != "" { 182 return apiKeyValue, nil 183 } 184 if len(v.apiKey) > 0 { 185 return v.apiKey, nil 186 } 187 return "", errors.New("no api key found " + 188 "neither in request header: X-Palm-Api-Key or X-Google-Api-Key " + 189 "nor in environment variable under PALM_APIKEY or GOOGLE_APIKEY") 190 } 191 192 func (v *palm) getValueFromContext(ctx context.Context, key string) string { 193 if value := ctx.Value(key); value != nil { 194 if keyHeader, ok := value.([]string); ok && len(keyHeader) > 0 && len(keyHeader[0]) > 0 { 195 return keyHeader[0] 196 } 197 } 198 // try getting header from GRPC if not successful 199 if apiKey := modulecomponents.GetValueFromGRPC(ctx, key); len(apiKey) > 0 && len(apiKey[0]) > 0 { 200 return apiKey[0] 201 } 202 return "" 203 } 204 205 func (v *palm) getEmbeddingsFromResponse(statusCode int, resBody embeddingsResponse) ( 206 textEmbeddings [][]float32, 207 imageEmbeddings [][]float32, 208 videoEmbeddings [][]float32, 209 err error, 210 ) { 211 if respErr := v.checkResponse(statusCode, resBody.Error); respErr != nil { 212 err = respErr 213 return 214 } 215 216 if len(resBody.Predictions) == 0 { 217 err = errors.Errorf("empty embeddings response") 218 return 219 } 220 221 for _, p := range resBody.Predictions { 222 if len(p.TextEmbedding) > 0 { 223 textEmbeddings = append(textEmbeddings, p.TextEmbedding) 224 } 225 if len(p.ImageEmbedding) > 0 { 226 imageEmbeddings = append(imageEmbeddings, p.ImageEmbedding) 227 } 228 if len(p.VideoEmbeddings) > 0 { 229 var embeddings [][]float32 230 for _, videoEmbedding := range p.VideoEmbeddings { 231 embeddings = append(embeddings, videoEmbedding.Embedding) 232 } 233 embedding := embeddings[0] 234 if len(embeddings) > 1 { 235 embedding = libvectorizer.CombineVectors(embeddings) 236 } 237 videoEmbeddings = append(videoEmbeddings, embedding) 238 } 239 } 240 return 241 } 242 243 func (v *palm) getResponse(textVectors, imageVectors, videoVectors [][]float32) (*ent.VectorizationResult, error) { 244 return &ent.VectorizationResult{ 245 TextVectors: textVectors, 246 ImageVectors: imageVectors, 247 VideoVectors: videoVectors, 248 }, nil 249 } 250 251 type embeddingsRequest struct { 252 Instances []instance `json:"instances,omitempty"` 253 Parameters parameters `json:"parameters,omitempty"` 254 } 255 256 type parameters struct { 257 Dimension int64 `json:"dimension,omitempty"` 258 } 259 260 type instance struct { 261 Text *string `json:"text,omitempty"` 262 Image *image `json:"image,omitempty"` 263 Video *video `json:"video,omitempty"` 264 } 265 266 type image struct { 267 BytesBase64Encoded string `json:"bytesBase64Encoded"` 268 } 269 270 type video struct { 271 BytesBase64Encoded string `json:"bytesBase64Encoded"` 272 VideoSegmentConfig videoSegmentConfig `json:"videoSegmentConfig"` 273 } 274 275 type videoSegmentConfig struct { 276 StartOffsetSec *int64 `json:"startOffsetSec,omitempty"` 277 EndOffsetSec *int64 `json:"endOffsetSec,omitempty"` 278 IntervalSec *int64 `json:"intervalSec,omitempty"` 279 } 280 281 type embeddingsResponse struct { 282 Predictions []prediction `json:"predictions,omitempty"` 283 Error *palmApiError `json:"error,omitempty"` 284 DeployedModelId string `json:"deployedModelId,omitempty"` 285 } 286 287 type prediction struct { 288 TextEmbedding []float32 `json:"textEmbedding,omitempty"` 289 ImageEmbedding []float32 `json:"imageEmbedding,omitempty"` 290 VideoEmbeddings []videoEmbedding `json:"videoEmbeddings,omitempty"` 291 } 292 293 type videoEmbedding struct { 294 StartOffsetSec *int64 `json:"startOffsetSec,omitempty"` 295 EndOffsetSec *int64 `json:"endOffsetSec,omitempty"` 296 Embedding []float32 `json:"embedding,omitempty"` 297 } 298 299 type palmApiError struct { 300 Code int `json:"code"` 301 Message string `json:"message"` 302 Status string `json:"status"` 303 }