github.com/weaviate/weaviate@v1.24.6/modules/generative-palm/clients/palm.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "bytes" 16 "context" 17 "encoding/json" 18 "fmt" 19 "io" 20 "net/http" 21 "regexp" 22 "strings" 23 "time" 24 25 "github.com/weaviate/weaviate/usecases/modulecomponents" 26 27 "github.com/pkg/errors" 28 "github.com/sirupsen/logrus" 29 "github.com/weaviate/weaviate/entities/moduletools" 30 "github.com/weaviate/weaviate/modules/generative-palm/config" 31 generativemodels "github.com/weaviate/weaviate/usecases/modulecomponents/additional/models" 32 ) 33 34 type harmCategory string 35 36 var ( 37 // Category is unspecified. 38 HarmCategoryUnspecified harmCategory = "HARM_CATEGORY_UNSPECIFIED" 39 // Negative or harmful comments targeting identity and/or protected attribute. 40 HarmCategoryDerogatory harmCategory = "HARM_CATEGORY_DEROGATORY" 41 // Content that is rude, disrepspectful, or profane. 42 HarmCategoryToxicity harmCategory = "HARM_CATEGORY_TOXICITY" 43 // Describes scenarios depictng violence against an individual or group, or general descriptions of gore. 44 HarmCategoryViolence harmCategory = "HARM_CATEGORY_VIOLENCE" 45 // Contains references to sexual acts or other lewd content. 46 HarmCategorySexual harmCategory = "HARM_CATEGORY_SEXUAL" 47 // Promotes unchecked medical advice. 48 HarmCategoryMedical harmCategory = "HARM_CATEGORY_MEDICAL" 49 // Dangerous content that promotes, facilitates, or encourages harmful acts. 50 HarmCategoryDangerous harmCategory = "HARM_CATEGORY_DANGEROUS" 51 // Harassment content. 52 HarmCategoryHarassment harmCategory = "HARM_CATEGORY_HARASSMENT" 53 // Hate speech and content. 54 HarmCategoryHate_speech harmCategory = "HARM_CATEGORY_HATE_SPEECH" 55 // Sexually explicit content. 56 HarmCategorySexually_explicit harmCategory = "HARM_CATEGORY_SEXUALLY_EXPLICIT" 57 // Dangerous content. 58 HarmCategoryDangerous_content harmCategory = "HARM_CATEGORY_DANGEROUS_CONTENT" 59 ) 60 61 type harmBlockThreshold string 62 63 var ( 64 // Threshold is unspecified. 65 HarmBlockThresholdUnspecified harmBlockThreshold = "HARM_BLOCK_THRESHOLD_UNSPECIFIED" 66 // Content with NEGLIGIBLE will be allowed. 67 BlockLowAndAbove harmBlockThreshold = "BLOCK_LOW_AND_ABOVE" 68 // Content with NEGLIGIBLE and LOW will be allowed. 69 BlockMediumAndAbove harmBlockThreshold = "BLOCK_MEDIUM_AND_ABOVE" 70 // Content with NEGLIGIBLE, LOW, and MEDIUM will be allowed. 71 BlockOnlyHigh harmBlockThreshold = "BLOCK_ONLY_HIGH" 72 // All content will be allowed. 73 BlockNone harmBlockThreshold = "BLOCK_NONE" 74 ) 75 76 type harmProbability string 77 78 var ( 79 // Probability is unspecified. 80 HARM_PROBABILITY_UNSPECIFIED harmProbability = "HARM_PROBABILITY_UNSPECIFIED" 81 // Content has a negligible chance of being unsafe. 82 NEGLIGIBLE harmProbability = "NEGLIGIBLE" 83 // Content has a low chance of being unsafe. 84 LOW harmProbability = "LOW" 85 // Content has a medium chance of being unsafe. 86 MEDIUM harmProbability = "MEDIUM" 87 // Content has a high chance of being unsafe. 88 HIGH harmProbability = "HIGH" 89 ) 90 91 var compile, _ = regexp.Compile(`{([\w\s]*?)}`) 92 93 func buildURL(useGenerativeAI bool, apiEndoint, projectID, modelID string) string { 94 if useGenerativeAI { 95 // Generative AI endpoints, for more context check out this link: 96 // https://developers.generativeai.google/models/language#model_variations 97 // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage 98 if strings.HasPrefix(modelID, "gemini") { 99 return fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:generateContent", modelID) 100 } 101 return "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage" 102 } 103 urlTemplate := "https://%s/v1/projects/%s/locations/us-central1/publishers/google/models/%s:predict" 104 return fmt.Sprintf(urlTemplate, apiEndoint, projectID, modelID) 105 } 106 107 type palm struct { 108 apiKey string 109 buildUrlFn func(useGenerativeAI bool, apiEndoint, projectID, modelID string) string 110 httpClient *http.Client 111 logger logrus.FieldLogger 112 } 113 114 func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *palm { 115 return &palm{ 116 apiKey: apiKey, 117 httpClient: &http.Client{ 118 Timeout: timeout, 119 }, 120 buildUrlFn: buildURL, 121 logger: logger, 122 } 123 } 124 125 func (v *palm) GenerateSingleResult(ctx context.Context, textProperties map[string]string, prompt string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) { 126 forPrompt, err := v.generateForPrompt(textProperties, prompt) 127 if err != nil { 128 return nil, err 129 } 130 return v.Generate(ctx, cfg, forPrompt) 131 } 132 133 func (v *palm) GenerateAllResults(ctx context.Context, textProperties []map[string]string, task string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) { 134 forTask, err := v.generatePromptForTask(textProperties, task) 135 if err != nil { 136 return nil, err 137 } 138 return v.Generate(ctx, cfg, forTask) 139 } 140 141 func (v *palm) Generate(ctx context.Context, cfg moduletools.ClassConfig, prompt string) (*generativemodels.GenerateResponse, error) { 142 settings := config.NewClassSettings(cfg) 143 144 useGenerativeAIEndpoint := v.useGenerativeAIEndpoint(settings.ApiEndpoint()) 145 modelID := settings.ModelID() 146 if settings.EndpointID() != "" { 147 modelID = settings.EndpointID() 148 } 149 150 endpointURL := v.buildUrlFn(useGenerativeAIEndpoint, settings.ApiEndpoint(), settings.ProjectID(), modelID) 151 input := v.getPayload(useGenerativeAIEndpoint, prompt, settings) 152 153 body, err := json.Marshal(input) 154 if err != nil { 155 return nil, errors.Wrap(err, "marshal body") 156 } 157 158 req, err := http.NewRequestWithContext(ctx, "POST", endpointURL, 159 bytes.NewReader(body)) 160 if err != nil { 161 return nil, errors.Wrap(err, "create POST request") 162 } 163 164 apiKey, err := v.getApiKey(ctx) 165 if err != nil { 166 return nil, errors.Wrapf(err, "Google API Key") 167 } 168 req.Header.Add("Content-Type", "application/json") 169 if useGenerativeAIEndpoint { 170 req.Header.Add("x-goog-api-key", apiKey) 171 } else { 172 req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", apiKey)) 173 } 174 175 res, err := v.httpClient.Do(req) 176 if err != nil { 177 return nil, errors.Wrap(err, "send POST request") 178 } 179 defer res.Body.Close() 180 181 bodyBytes, err := io.ReadAll(res.Body) 182 if err != nil { 183 return nil, errors.Wrap(err, "read response body") 184 } 185 186 if useGenerativeAIEndpoint { 187 if strings.HasPrefix(modelID, "gemini") { 188 return v.parseGenerateContentResponse(res.StatusCode, bodyBytes) 189 } 190 return v.parseGenerateMessageResponse(res.StatusCode, bodyBytes) 191 } 192 return v.parseResponse(res.StatusCode, bodyBytes) 193 } 194 195 func (v *palm) parseGenerateMessageResponse(statusCode int, bodyBytes []byte) (*generativemodels.GenerateResponse, error) { 196 var resBody generateMessageResponse 197 if err := json.Unmarshal(bodyBytes, &resBody); err != nil { 198 return nil, errors.Wrap(err, "unmarshal response body") 199 } 200 201 if err := v.checkResponse(statusCode, resBody.Error); err != nil { 202 return nil, err 203 } 204 205 if len(resBody.Candidates) > 0 { 206 return v.getGenerateResponse(resBody.Candidates[0].Content) 207 } 208 209 return &generativemodels.GenerateResponse{ 210 Result: nil, 211 }, nil 212 } 213 214 func (v *palm) parseGenerateContentResponse(statusCode int, bodyBytes []byte) (*generativemodels.GenerateResponse, error) { 215 var resBody generateContentResponse 216 if err := json.Unmarshal(bodyBytes, &resBody); err != nil { 217 return nil, errors.Wrap(err, "unmarshal response body") 218 } 219 220 if err := v.checkResponse(statusCode, resBody.Error); err != nil { 221 return nil, err 222 } 223 224 if len(resBody.Candidates) > 0 && len(resBody.Candidates[0].Content.Parts) > 0 { 225 return v.getGenerateResponse(resBody.Candidates[0].Content.Parts[0].Text) 226 } 227 228 return &generativemodels.GenerateResponse{ 229 Result: nil, 230 }, nil 231 } 232 233 func (v *palm) parseResponse(statusCode int, bodyBytes []byte) (*generativemodels.GenerateResponse, error) { 234 var resBody generateResponse 235 if err := json.Unmarshal(bodyBytes, &resBody); err != nil { 236 return nil, errors.Wrap(err, "unmarshal response body") 237 } 238 239 if err := v.checkResponse(statusCode, resBody.Error); err != nil { 240 return nil, err 241 } 242 243 if len(resBody.Predictions) > 0 && len(resBody.Predictions[0].Candidates) > 0 { 244 return v.getGenerateResponse(resBody.Predictions[0].Candidates[0].Content) 245 } 246 247 return &generativemodels.GenerateResponse{ 248 Result: nil, 249 }, nil 250 } 251 252 func (v *palm) getGenerateResponse(content string) (*generativemodels.GenerateResponse, error) { 253 if content != "" { 254 trimmedResponse := strings.Trim(content, "\n") 255 return &generativemodels.GenerateResponse{ 256 Result: &trimmedResponse, 257 }, nil 258 } 259 260 return &generativemodels.GenerateResponse{ 261 Result: nil, 262 }, nil 263 } 264 265 func (v *palm) checkResponse(statusCode int, palmApiError *palmApiError) error { 266 if statusCode != 200 || palmApiError != nil { 267 if palmApiError != nil { 268 return fmt.Errorf("connection to Google failed with status: %v error: %v", 269 statusCode, palmApiError.Message) 270 } 271 return fmt.Errorf("connection to Google failed with status: %d", statusCode) 272 } 273 return nil 274 } 275 276 func (v *palm) useGenerativeAIEndpoint(apiEndpoint string) bool { 277 return apiEndpoint == "generativelanguage.googleapis.com" 278 } 279 280 func (v *palm) getPayload(useGenerativeAI bool, prompt string, settings config.ClassSettings) any { 281 if useGenerativeAI { 282 if strings.HasPrefix(settings.ModelID(), "gemini") { 283 input := generateContentRequest{ 284 Contents: []content{ 285 { 286 Role: "user", 287 Parts: []part{ 288 { 289 Text: prompt, 290 }, 291 }, 292 }, 293 }, 294 GenerationConfig: &generationConfig{ 295 Temperature: settings.Temperature(), 296 TopP: settings.TopP(), 297 TopK: settings.TopK(), 298 CandidateCount: 1, 299 }, 300 SafetySettings: []safetySetting{ 301 { 302 Category: HarmCategoryHarassment, 303 Threshold: BlockMediumAndAbove, 304 }, 305 { 306 Category: HarmCategoryHate_speech, 307 Threshold: BlockMediumAndAbove, 308 }, 309 { 310 Category: HarmCategoryDangerous_content, 311 Threshold: BlockMediumAndAbove, 312 }, 313 { 314 Category: HarmCategoryDangerous_content, 315 Threshold: BlockMediumAndAbove, 316 }, 317 }, 318 } 319 return input 320 } 321 input := generateMessageRequest{ 322 Prompt: &generateMessagePrompt{ 323 Messages: []generateMessage{ 324 { 325 Content: prompt, 326 }, 327 }, 328 }, 329 Temperature: settings.Temperature(), 330 TopP: settings.TopP(), 331 TopK: settings.TopK(), 332 CandidateCount: 1, 333 } 334 return input 335 } 336 input := generateInput{ 337 Instances: []instance{ 338 { 339 Messages: []message{ 340 { 341 Content: prompt, 342 }, 343 }, 344 }, 345 }, 346 Parameters: parameters{ 347 Temperature: settings.Temperature(), 348 MaxOutputTokens: settings.TokenLimit(), 349 TopP: settings.TopP(), 350 TopK: settings.TopK(), 351 }, 352 } 353 return input 354 } 355 356 func (v *palm) generatePromptForTask(textProperties []map[string]string, task string) (string, error) { 357 marshal, err := json.Marshal(textProperties) 358 if err != nil { 359 return "", err 360 } 361 return fmt.Sprintf(`'%v: 362 %v`, task, string(marshal)), nil 363 } 364 365 func (v *palm) generateForPrompt(textProperties map[string]string, prompt string) (string, error) { 366 all := compile.FindAll([]byte(prompt), -1) 367 for _, match := range all { 368 originalProperty := string(match) 369 replacedProperty := compile.FindStringSubmatch(originalProperty)[1] 370 replacedProperty = strings.TrimSpace(replacedProperty) 371 value := textProperties[replacedProperty] 372 if value == "" { 373 return "", errors.Errorf("Following property has empty value: '%v'. Make sure you spell the property name correctly, verify that the property exists and has a value", replacedProperty) 374 } 375 prompt = strings.ReplaceAll(prompt, originalProperty, value) 376 } 377 return prompt, nil 378 } 379 380 func (v *palm) getApiKey(ctx context.Context) (string, error) { 381 if apiKeyValue := v.getValueFromContext(ctx, "X-Google-Api-Key"); apiKeyValue != "" { 382 return apiKeyValue, nil 383 } 384 if apiKeyValue := v.getValueFromContext(ctx, "X-Palm-Api-Key"); apiKeyValue != "" { 385 return apiKeyValue, nil 386 } 387 if len(v.apiKey) > 0 { 388 return v.apiKey, nil 389 } 390 return "", errors.New("no api key found " + 391 "neither in request header: X-Palm-Api-Key or X-Google-Api-Key " + 392 "nor in environment variable under PALM_APIKEY or GOOGLE_APIKEY") 393 } 394 395 func (v *palm) getValueFromContext(ctx context.Context, key string) string { 396 if value := ctx.Value(key); value != nil { 397 if keyHeader, ok := value.([]string); ok && len(keyHeader) > 0 && len(keyHeader[0]) > 0 { 398 return keyHeader[0] 399 } 400 } 401 // try getting header from GRPC if not successful 402 if apiKey := modulecomponents.GetValueFromGRPC(ctx, key); len(apiKey) > 0 && len(apiKey[0]) > 0 { 403 return apiKey[0] 404 } 405 return "" 406 } 407 408 type generateInput struct { 409 Instances []instance `json:"instances,omitempty"` 410 Parameters parameters `json:"parameters"` 411 } 412 413 type instance struct { 414 Context string `json:"context,omitempty"` 415 Messages []message `json:"messages,omitempty"` 416 Examples []example `json:"examples,omitempty"` 417 } 418 419 type message struct { 420 Author string `json:"author"` 421 Content string `json:"content"` 422 } 423 424 type example struct { 425 Input string `json:"input"` 426 Output string `json:"output"` 427 } 428 429 type parameters struct { 430 Temperature float64 `json:"temperature"` 431 MaxOutputTokens int `json:"maxOutputTokens"` 432 TopP float64 `json:"topP"` 433 TopK int `json:"topK"` 434 } 435 436 type generateResponse struct { 437 Predictions []prediction `json:"predictions,omitempty"` 438 Error *palmApiError `json:"error,omitempty"` 439 DeployedModelId string `json:"deployedModelId,omitempty"` 440 Model string `json:"model,omitempty"` 441 ModelDisplayName string `json:"modelDisplayName,omitempty"` 442 ModelVersionId string `json:"modelVersionId,omitempty"` 443 } 444 445 type prediction struct { 446 Candidates []candidate `json:"candidates,omitempty"` 447 SafetyAttributes *[]safetyAttributes `json:"safetyAttributes,omitempty"` 448 } 449 450 type candidate struct { 451 Author string `json:"author"` 452 Content string `json:"content"` 453 } 454 455 type safetyAttributes struct { 456 Scores []float64 `json:"scores,omitempty"` 457 Blocked *bool `json:"blocked,omitempty"` 458 Categories []string `json:"categories,omitempty"` 459 } 460 461 type palmApiError struct { 462 Code int `json:"code"` 463 Message string `json:"message"` 464 Status string `json:"status"` 465 } 466 467 type generateMessageRequest struct { 468 Prompt *generateMessagePrompt `json:"prompt,omitempty"` 469 Temperature float64 `json:"temperature,omitempty"` 470 CandidateCount int `json:"candidateCount,omitempty"` // default 1 471 TopP float64 `json:"topP"` 472 TopK int `json:"topK"` 473 } 474 475 type generateMessagePrompt struct { 476 Context string `json:"prompt,omitempty"` 477 Examples []generateExample `json:"examples,omitempty"` 478 Messages []generateMessage `json:"messages,omitempty"` 479 } 480 481 type generateMessage struct { 482 Author string `json:"author,omitempty"` 483 Content string `json:"content,omitempty"` 484 CitationMetadata *generateCitationMetadata `json:"citationMetadata,omitempty"` 485 } 486 487 type generateCitationMetadata struct { 488 CitationSources []generateCitationSource `json:"citationSources,omitempty"` 489 } 490 491 type generateCitationSource struct { 492 StartIndex int `json:"startIndex,omitempty"` 493 EndIndex int `json:"endIndex,omitempty"` 494 URI string `json:"uri,omitempty"` 495 License string `json:"license,omitempty"` 496 } 497 498 type generateExample struct { 499 Input *generateMessage `json:"input,omitempty"` 500 Output *generateMessage `json:"output,omitempty"` 501 } 502 503 type generateMessageResponse struct { 504 Candidates []generateMessage `json:"candidates,omitempty"` 505 Messages []generateMessage `json:"messages,omitempty"` 506 Filters []contentFilter `json:"filters,omitempty"` 507 Error *palmApiError `json:"error,omitempty"` 508 } 509 510 type contentFilter struct { 511 Reason string `json:"reason,omitempty"` 512 Message string `json:"message,omitempty"` 513 } 514 515 type generateContentRequest struct { 516 Contents []content `json:"contents,omitempty"` 517 SafetySettings []safetySetting `json:"safetySettings,omitempty"` 518 GenerationConfig *generationConfig `json:"generationConfig,omitempty"` 519 } 520 521 type content struct { 522 Parts []part `json:"parts,omitempty"` 523 Role string `json:"role,omitempty"` 524 } 525 526 type part struct { 527 Text string `json:"text,omitempty"` 528 InlineData string `json:"inline_data,omitempty"` 529 } 530 531 type safetySetting struct { 532 Category harmCategory `json:"category,omitempty"` 533 Threshold harmBlockThreshold `json:"threshold,omitempty"` 534 } 535 536 type generationConfig struct { 537 StopSequences []string `json:"stopSequences,omitempty"` 538 CandidateCount int `json:"candidateCount,omitempty"` 539 MaxOutputTokens int `json:"maxOutputTokens,omitempty"` 540 Temperature float64 `json:"temperature,omitempty"` 541 TopP float64 `json:"topP,omitempty"` 542 TopK int `json:"topK,omitempty"` 543 } 544 545 type generateContentResponse struct { 546 Candidates []generateContentCandidate `json:"candidates,omitempty"` 547 PromptFeedback *promptFeedback `json:"promptFeedback,omitempty"` 548 Error *palmApiError `json:"error,omitempty"` 549 } 550 551 type generateContentCandidate struct { 552 Content contentResponse `json:"content,omitempty"` 553 FinishReason string `json:"finishReason,omitempty"` 554 Index int `json:"index,omitempty"` 555 SafetyRatings []safetyRating `json:"safetyRatings,omitempty"` 556 } 557 558 type contentResponse struct { 559 Parts []part `json:"parts,omitempty"` 560 Role string `json:"role,omitempty"` 561 } 562 563 type promptFeedback struct { 564 SafetyRatings []safetyRating `json:"safetyRatings,omitempty"` 565 } 566 567 type safetyRating struct { 568 Category harmCategory `json:"category,omitempty"` 569 Probability harmProbability `json:"probability,omitempty"` 570 Blocked *bool `json:"blocked,omitempty"` 571 }