github.com/weaviate/weaviate@v1.24.6/modules/text2vec-palm/clients/palm.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "bytes" 16 "context" 17 "encoding/json" 18 "fmt" 19 "io" 20 "net/http" 21 "strings" 22 "time" 23 24 "github.com/weaviate/weaviate/usecases/modulecomponents" 25 26 "github.com/pkg/errors" 27 "github.com/sirupsen/logrus" 28 "github.com/weaviate/weaviate/modules/text2vec-palm/ent" 29 ) 30 31 type taskType string 32 33 var ( 34 // Specifies the given text is a document in a search/retrieval setting 35 retrievalQuery taskType = "RETRIEVAL_QUERY" 36 // Specifies the given text is a query in a search/retrieval setting 37 retrievalDocument taskType = "RETRIEVAL_DOCUMENT" 38 ) 39 40 func buildURL(useGenerativeAI bool, apiEndoint, projectID, modelID string) string { 41 if useGenerativeAI { 42 // Generative AI supports only 1 embedding model: embedding-gecko-001. So for now 43 // in order to keep it simple we generate one variation of PaLM API url. 44 // For more context check out this link: 45 // https://developers.generativeai.google/models/language#model_variations 46 return "https://generativelanguage.googleapis.com/v1beta3/models/embedding-gecko-001:batchEmbedText" 47 } 48 urlTemplate := "https://%s/v1/projects/%s/locations/us-central1/publishers/google/models/%s:predict" 49 return fmt.Sprintf(urlTemplate, apiEndoint, projectID, modelID) 50 } 51 52 type palm struct { 53 apiKey string 54 httpClient *http.Client 55 urlBuilderFn func(useGenerativeAI bool, apiEndoint, projectID, modelID string) string 56 logger logrus.FieldLogger 57 } 58 59 func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *palm { 60 return &palm{ 61 apiKey: apiKey, 62 httpClient: &http.Client{ 63 Timeout: timeout, 64 }, 65 urlBuilderFn: buildURL, 66 logger: logger, 67 } 68 } 69 70 func (v *palm) Vectorize(ctx context.Context, input []string, 71 config ent.VectorizationConfig, titlePropertyValue string, 72 ) (*ent.VectorizationResult, error) { 73 return v.vectorize(ctx, input, retrievalDocument, titlePropertyValue, config) 74 } 75 76 func (v *palm) VectorizeQuery(ctx context.Context, input []string, 77 config ent.VectorizationConfig, 78 ) (*ent.VectorizationResult, error) { 79 return v.vectorize(ctx, input, retrievalQuery, "", config) 80 } 81 82 func (v *palm) vectorize(ctx context.Context, input []string, taskType taskType, 83 titlePropertyValue string, config ent.VectorizationConfig, 84 ) (*ent.VectorizationResult, error) { 85 useGenerativeAIEndpoint := v.useGenerativeAIEndpoint(config) 86 87 payload := v.getPayload(useGenerativeAIEndpoint, input, taskType, titlePropertyValue, config) 88 body, err := json.Marshal(payload) 89 if err != nil { 90 return nil, errors.Wrapf(err, "marshal body") 91 } 92 93 endpointURL := v.urlBuilderFn(useGenerativeAIEndpoint, 94 v.getApiEndpoint(config), v.getProjectID(config), v.getModel(config)) 95 96 req, err := http.NewRequestWithContext(ctx, "POST", endpointURL, 97 bytes.NewReader(body)) 98 if err != nil { 99 return nil, errors.Wrap(err, "create POST request") 100 } 101 102 apiKey, err := v.getApiKey(ctx) 103 if err != nil { 104 return nil, errors.Wrapf(err, "Google API Key") 105 } 106 req.Header.Add("Content-Type", "application/json") 107 if useGenerativeAIEndpoint { 108 req.Header.Add("x-goog-api-key", apiKey) 109 } else { 110 req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", apiKey)) 111 } 112 113 res, err := v.httpClient.Do(req) 114 if err != nil { 115 return nil, errors.Wrap(err, "send POST request") 116 } 117 defer res.Body.Close() 118 119 bodyBytes, err := io.ReadAll(res.Body) 120 if err != nil { 121 return nil, errors.Wrap(err, "read response body") 122 } 123 124 if useGenerativeAIEndpoint { 125 return v.parseGenerativeAIApiResponse(res.StatusCode, bodyBytes, input) 126 } 127 return v.parseEmbeddingsResponse(res.StatusCode, bodyBytes, input) 128 } 129 130 func (v *palm) useGenerativeAIEndpoint(config ent.VectorizationConfig) bool { 131 return v.getApiEndpoint(config) == "generativelanguage.googleapis.com" 132 } 133 134 func (v *palm) getPayload(useGenerativeAI bool, input []string, 135 taskType taskType, title string, config ent.VectorizationConfig, 136 ) interface{} { 137 if useGenerativeAI { 138 return batchEmbedTextRequest{Texts: input} 139 } 140 isModelVersion001 := strings.HasSuffix(config.Model, "@001") 141 instances := make([]instance, len(input)) 142 for i := range input { 143 if isModelVersion001 { 144 instances[i] = instance{Content: input[i]} 145 } else { 146 instances[i] = instance{Content: input[i], TaskType: taskType, Title: title} 147 } 148 } 149 return embeddingsRequest{instances} 150 } 151 152 func (v *palm) checkResponse(statusCode int, palmApiError *palmApiError) error { 153 if statusCode != 200 || palmApiError != nil { 154 if palmApiError != nil { 155 return fmt.Errorf("connection to Google failed with status: %v error: %v", 156 statusCode, palmApiError.Message) 157 } 158 return fmt.Errorf("connection to Google failed with status: %d", statusCode) 159 } 160 return nil 161 } 162 163 func (v *palm) getApiKey(ctx context.Context) (string, error) { 164 if apiKeyValue := v.getValueFromContext(ctx, "X-Google-Api-Key"); apiKeyValue != "" { 165 return apiKeyValue, nil 166 } 167 if apiKeyValue := v.getValueFromContext(ctx, "X-Palm-Api-Key"); apiKeyValue != "" { 168 return apiKeyValue, nil 169 } 170 if len(v.apiKey) > 0 { 171 return v.apiKey, nil 172 } 173 return "", errors.New("no api key found " + 174 "neither in request header: X-Palm-Api-Key or X-Google-Api-Key " + 175 "nor in environment variable under PALM_APIKEY or GOOGLE_APIKEY") 176 } 177 178 func (v *palm) getValueFromContext(ctx context.Context, key string) string { 179 if value := ctx.Value(key); value != nil { 180 if keyHeader, ok := value.([]string); ok && len(keyHeader) > 0 && len(keyHeader[0]) > 0 { 181 return keyHeader[0] 182 } 183 } 184 // try getting header from GRPC if not successful 185 if apiKey := modulecomponents.GetValueFromGRPC(ctx, key); len(apiKey) > 0 && len(apiKey[0]) > 0 { 186 return apiKey[0] 187 } 188 return "" 189 } 190 191 func (v *palm) parseGenerativeAIApiResponse(statusCode int, 192 bodyBytes []byte, input []string, 193 ) (*ent.VectorizationResult, error) { 194 var resBody batchEmbedTextResponse 195 if err := json.Unmarshal(bodyBytes, &resBody); err != nil { 196 return nil, errors.Wrap(err, "unmarshal response body") 197 } 198 199 if err := v.checkResponse(statusCode, resBody.Error); err != nil { 200 return nil, err 201 } 202 203 if len(resBody.Embeddings) == 0 { 204 return nil, errors.Errorf("empty embeddings response") 205 } 206 207 vectors := make([][]float32, len(resBody.Embeddings)) 208 for i := range resBody.Embeddings { 209 vectors[i] = resBody.Embeddings[i].Value 210 } 211 dimensions := len(resBody.Embeddings[0].Value) 212 213 return v.getResponse(input, dimensions, vectors) 214 } 215 216 func (v *palm) parseEmbeddingsResponse(statusCode int, 217 bodyBytes []byte, input []string, 218 ) (*ent.VectorizationResult, error) { 219 var resBody embeddingsResponse 220 if err := json.Unmarshal(bodyBytes, &resBody); err != nil { 221 return nil, errors.Wrap(err, "unmarshal response body") 222 } 223 224 if err := v.checkResponse(statusCode, resBody.Error); err != nil { 225 return nil, err 226 } 227 228 if len(resBody.Predictions) == 0 { 229 return nil, errors.Errorf("empty embeddings response") 230 } 231 232 vectors := make([][]float32, len(resBody.Predictions)) 233 for i := range resBody.Predictions { 234 vectors[i] = resBody.Predictions[i].Embeddings.Values 235 } 236 dimensions := len(resBody.Predictions[0].Embeddings.Values) 237 238 return v.getResponse(input, dimensions, vectors) 239 } 240 241 func (v *palm) getResponse(input []string, dimensions int, vectors [][]float32) (*ent.VectorizationResult, error) { 242 return &ent.VectorizationResult{ 243 Texts: input, 244 Dimensions: dimensions, 245 Vectors: vectors, 246 }, nil 247 } 248 249 func (v *palm) getApiEndpoint(config ent.VectorizationConfig) string { 250 return config.ApiEndpoint 251 } 252 253 func (v *palm) getProjectID(config ent.VectorizationConfig) string { 254 return config.ProjectID 255 } 256 257 func (v *palm) getModel(config ent.VectorizationConfig) string { 258 return config.Model 259 } 260 261 type embeddingsRequest struct { 262 Instances []instance `json:"instances,omitempty"` 263 } 264 265 type instance struct { 266 Content string `json:"content"` 267 TaskType taskType `json:"task_type,omitempty"` 268 Title string `json:"title,omitempty"` 269 } 270 271 type embeddingsResponse struct { 272 Predictions []prediction `json:"predictions,omitempty"` 273 Error *palmApiError `json:"error,omitempty"` 274 DeployedModelId string `json:"deployedModelId,omitempty"` 275 Model string `json:"model,omitempty"` 276 ModelDisplayName string `json:"modelDisplayName,omitempty"` 277 ModelVersionId string `json:"modelVersionId,omitempty"` 278 } 279 280 type prediction struct { 281 Embeddings embeddings `json:"embeddings,omitempty"` 282 SafetyAttributes *safetyAttributes `json:"safetyAttributes,omitempty"` 283 } 284 285 type embeddings struct { 286 Values []float32 `json:"values,omitempty"` 287 } 288 289 type safetyAttributes struct { 290 Scores []float64 `json:"scores,omitempty"` 291 Blocked *bool `json:"blocked,omitempty"` 292 Categories []string `json:"categories,omitempty"` 293 } 294 295 type palmApiError struct { 296 Code int `json:"code"` 297 Message string `json:"message"` 298 Status string `json:"status"` 299 } 300 301 type batchEmbedTextRequest struct { 302 Texts []string `json:"texts,omitempty"` 303 } 304 305 type batchEmbedTextResponse struct { 306 Embeddings []embedding `json:"embeddings,omitempty"` 307 Error *palmApiError `json:"error,omitempty"` 308 } 309 310 type embedding struct { 311 Value []float32 `json:"value,omitempty"` 312 }