github.com/weaviate/weaviate@v1.24.6/modules/generative-palm/clients/palm.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package clients
    13  
    14  import (
    15  	"bytes"
    16  	"context"
    17  	"encoding/json"
    18  	"fmt"
    19  	"io"
    20  	"net/http"
    21  	"regexp"
    22  	"strings"
    23  	"time"
    24  
    25  	"github.com/weaviate/weaviate/usecases/modulecomponents"
    26  
    27  	"github.com/pkg/errors"
    28  	"github.com/sirupsen/logrus"
    29  	"github.com/weaviate/weaviate/entities/moduletools"
    30  	"github.com/weaviate/weaviate/modules/generative-palm/config"
    31  	generativemodels "github.com/weaviate/weaviate/usecases/modulecomponents/additional/models"
    32  )
    33  
    34  type harmCategory string
    35  
    36  var (
    37  	// Category is unspecified.
    38  	HarmCategoryUnspecified harmCategory = "HARM_CATEGORY_UNSPECIFIED"
    39  	// Negative or harmful comments targeting identity and/or protected attribute.
    40  	HarmCategoryDerogatory harmCategory = "HARM_CATEGORY_DEROGATORY"
    41  	// Content that is rude, disrepspectful, or profane.
    42  	HarmCategoryToxicity harmCategory = "HARM_CATEGORY_TOXICITY"
    43  	// Describes scenarios depictng violence against an individual or group, or general descriptions of gore.
    44  	HarmCategoryViolence harmCategory = "HARM_CATEGORY_VIOLENCE"
    45  	// Contains references to sexual acts or other lewd content.
    46  	HarmCategorySexual harmCategory = "HARM_CATEGORY_SEXUAL"
    47  	// Promotes unchecked medical advice.
    48  	HarmCategoryMedical harmCategory = "HARM_CATEGORY_MEDICAL"
    49  	// Dangerous content that promotes, facilitates, or encourages harmful acts.
    50  	HarmCategoryDangerous harmCategory = "HARM_CATEGORY_DANGEROUS"
    51  	// Harassment content.
    52  	HarmCategoryHarassment harmCategory = "HARM_CATEGORY_HARASSMENT"
    53  	// Hate speech and content.
    54  	HarmCategoryHate_speech harmCategory = "HARM_CATEGORY_HATE_SPEECH"
    55  	// Sexually explicit content.
    56  	HarmCategorySexually_explicit harmCategory = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
    57  	// Dangerous content.
    58  	HarmCategoryDangerous_content harmCategory = "HARM_CATEGORY_DANGEROUS_CONTENT"
    59  )
    60  
    61  type harmBlockThreshold string
    62  
    63  var (
    64  	// Threshold is unspecified.
    65  	HarmBlockThresholdUnspecified harmBlockThreshold = "HARM_BLOCK_THRESHOLD_UNSPECIFIED"
    66  	// Content with NEGLIGIBLE will be allowed.
    67  	BlockLowAndAbove harmBlockThreshold = "BLOCK_LOW_AND_ABOVE"
    68  	// Content with NEGLIGIBLE and LOW will be allowed.
    69  	BlockMediumAndAbove harmBlockThreshold = "BLOCK_MEDIUM_AND_ABOVE"
    70  	// Content with NEGLIGIBLE, LOW, and MEDIUM will be allowed.
    71  	BlockOnlyHigh harmBlockThreshold = "BLOCK_ONLY_HIGH"
    72  	// All content will be allowed.
    73  	BlockNone harmBlockThreshold = "BLOCK_NONE"
    74  )
    75  
    76  type harmProbability string
    77  
    78  var (
    79  	// Probability is unspecified.
    80  	HARM_PROBABILITY_UNSPECIFIED harmProbability = "HARM_PROBABILITY_UNSPECIFIED"
    81  	// Content has a negligible chance of being unsafe.
    82  	NEGLIGIBLE harmProbability = "NEGLIGIBLE"
    83  	// Content has a low chance of being unsafe.
    84  	LOW harmProbability = "LOW"
    85  	// Content has a medium chance of being unsafe.
    86  	MEDIUM harmProbability = "MEDIUM"
    87  	// Content has a high chance of being unsafe.
    88  	HIGH harmProbability = "HIGH"
    89  )
    90  
    91  var compile, _ = regexp.Compile(`{([\w\s]*?)}`)
    92  
    93  func buildURL(useGenerativeAI bool, apiEndoint, projectID, modelID string) string {
    94  	if useGenerativeAI {
    95  		// Generative AI endpoints, for more context check out this link:
    96  		// https://developers.generativeai.google/models/language#model_variations
    97  		// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage
    98  		if strings.HasPrefix(modelID, "gemini") {
    99  			return fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:generateContent", modelID)
   100  		}
   101  		return "https://generativelanguage.googleapis.com/v1beta2/models/chat-bison-001:generateMessage"
   102  	}
   103  	urlTemplate := "https://%s/v1/projects/%s/locations/us-central1/publishers/google/models/%s:predict"
   104  	return fmt.Sprintf(urlTemplate, apiEndoint, projectID, modelID)
   105  }
   106  
// palm is the client that posts generation requests over HTTP.
//
// apiKey is the fallback key used by getApiKey when the request context
// carries none. buildUrlFn produces the endpoint URL (New sets it to
// buildURL). httpClient performs the request. logger is stored by New; no
// method visible in this file writes to it.
type palm struct {
	apiKey     string
	buildUrlFn func(useGenerativeAI bool, apiEndoint, projectID, modelID string) string
	httpClient *http.Client
	logger     logrus.FieldLogger
}
   113  
   114  func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *palm {
   115  	return &palm{
   116  		apiKey: apiKey,
   117  		httpClient: &http.Client{
   118  			Timeout: timeout,
   119  		},
   120  		buildUrlFn: buildURL,
   121  		logger:     logger,
   122  	}
   123  }
   124  
   125  func (v *palm) GenerateSingleResult(ctx context.Context, textProperties map[string]string, prompt string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) {
   126  	forPrompt, err := v.generateForPrompt(textProperties, prompt)
   127  	if err != nil {
   128  		return nil, err
   129  	}
   130  	return v.Generate(ctx, cfg, forPrompt)
   131  }
   132  
   133  func (v *palm) GenerateAllResults(ctx context.Context, textProperties []map[string]string, task string, cfg moduletools.ClassConfig) (*generativemodels.GenerateResponse, error) {
   134  	forTask, err := v.generatePromptForTask(textProperties, task)
   135  	if err != nil {
   136  		return nil, err
   137  	}
   138  	return v.Generate(ctx, cfg, forTask)
   139  }
   140  
   141  func (v *palm) Generate(ctx context.Context, cfg moduletools.ClassConfig, prompt string) (*generativemodels.GenerateResponse, error) {
   142  	settings := config.NewClassSettings(cfg)
   143  
   144  	useGenerativeAIEndpoint := v.useGenerativeAIEndpoint(settings.ApiEndpoint())
   145  	modelID := settings.ModelID()
   146  	if settings.EndpointID() != "" {
   147  		modelID = settings.EndpointID()
   148  	}
   149  
   150  	endpointURL := v.buildUrlFn(useGenerativeAIEndpoint, settings.ApiEndpoint(), settings.ProjectID(), modelID)
   151  	input := v.getPayload(useGenerativeAIEndpoint, prompt, settings)
   152  
   153  	body, err := json.Marshal(input)
   154  	if err != nil {
   155  		return nil, errors.Wrap(err, "marshal body")
   156  	}
   157  
   158  	req, err := http.NewRequestWithContext(ctx, "POST", endpointURL,
   159  		bytes.NewReader(body))
   160  	if err != nil {
   161  		return nil, errors.Wrap(err, "create POST request")
   162  	}
   163  
   164  	apiKey, err := v.getApiKey(ctx)
   165  	if err != nil {
   166  		return nil, errors.Wrapf(err, "Google API Key")
   167  	}
   168  	req.Header.Add("Content-Type", "application/json")
   169  	if useGenerativeAIEndpoint {
   170  		req.Header.Add("x-goog-api-key", apiKey)
   171  	} else {
   172  		req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", apiKey))
   173  	}
   174  
   175  	res, err := v.httpClient.Do(req)
   176  	if err != nil {
   177  		return nil, errors.Wrap(err, "send POST request")
   178  	}
   179  	defer res.Body.Close()
   180  
   181  	bodyBytes, err := io.ReadAll(res.Body)
   182  	if err != nil {
   183  		return nil, errors.Wrap(err, "read response body")
   184  	}
   185  
   186  	if useGenerativeAIEndpoint {
   187  		if strings.HasPrefix(modelID, "gemini") {
   188  			return v.parseGenerateContentResponse(res.StatusCode, bodyBytes)
   189  		}
   190  		return v.parseGenerateMessageResponse(res.StatusCode, bodyBytes)
   191  	}
   192  	return v.parseResponse(res.StatusCode, bodyBytes)
   193  }
   194  
   195  func (v *palm) parseGenerateMessageResponse(statusCode int, bodyBytes []byte) (*generativemodels.GenerateResponse, error) {
   196  	var resBody generateMessageResponse
   197  	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
   198  		return nil, errors.Wrap(err, "unmarshal response body")
   199  	}
   200  
   201  	if err := v.checkResponse(statusCode, resBody.Error); err != nil {
   202  		return nil, err
   203  	}
   204  
   205  	if len(resBody.Candidates) > 0 {
   206  		return v.getGenerateResponse(resBody.Candidates[0].Content)
   207  	}
   208  
   209  	return &generativemodels.GenerateResponse{
   210  		Result: nil,
   211  	}, nil
   212  }
   213  
   214  func (v *palm) parseGenerateContentResponse(statusCode int, bodyBytes []byte) (*generativemodels.GenerateResponse, error) {
   215  	var resBody generateContentResponse
   216  	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
   217  		return nil, errors.Wrap(err, "unmarshal response body")
   218  	}
   219  
   220  	if err := v.checkResponse(statusCode, resBody.Error); err != nil {
   221  		return nil, err
   222  	}
   223  
   224  	if len(resBody.Candidates) > 0 && len(resBody.Candidates[0].Content.Parts) > 0 {
   225  		return v.getGenerateResponse(resBody.Candidates[0].Content.Parts[0].Text)
   226  	}
   227  
   228  	return &generativemodels.GenerateResponse{
   229  		Result: nil,
   230  	}, nil
   231  }
   232  
   233  func (v *palm) parseResponse(statusCode int, bodyBytes []byte) (*generativemodels.GenerateResponse, error) {
   234  	var resBody generateResponse
   235  	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
   236  		return nil, errors.Wrap(err, "unmarshal response body")
   237  	}
   238  
   239  	if err := v.checkResponse(statusCode, resBody.Error); err != nil {
   240  		return nil, err
   241  	}
   242  
   243  	if len(resBody.Predictions) > 0 && len(resBody.Predictions[0].Candidates) > 0 {
   244  		return v.getGenerateResponse(resBody.Predictions[0].Candidates[0].Content)
   245  	}
   246  
   247  	return &generativemodels.GenerateResponse{
   248  		Result: nil,
   249  	}, nil
   250  }
   251  
   252  func (v *palm) getGenerateResponse(content string) (*generativemodels.GenerateResponse, error) {
   253  	if content != "" {
   254  		trimmedResponse := strings.Trim(content, "\n")
   255  		return &generativemodels.GenerateResponse{
   256  			Result: &trimmedResponse,
   257  		}, nil
   258  	}
   259  
   260  	return &generativemodels.GenerateResponse{
   261  		Result: nil,
   262  	}, nil
   263  }
   264  
   265  func (v *palm) checkResponse(statusCode int, palmApiError *palmApiError) error {
   266  	if statusCode != 200 || palmApiError != nil {
   267  		if palmApiError != nil {
   268  			return fmt.Errorf("connection to Google failed with status: %v error: %v",
   269  				statusCode, palmApiError.Message)
   270  		}
   271  		return fmt.Errorf("connection to Google failed with status: %d", statusCode)
   272  	}
   273  	return nil
   274  }
   275  
   276  func (v *palm) useGenerativeAIEndpoint(apiEndpoint string) bool {
   277  	return apiEndpoint == "generativelanguage.googleapis.com"
   278  }
   279  
   280  func (v *palm) getPayload(useGenerativeAI bool, prompt string, settings config.ClassSettings) any {
   281  	if useGenerativeAI {
   282  		if strings.HasPrefix(settings.ModelID(), "gemini") {
   283  			input := generateContentRequest{
   284  				Contents: []content{
   285  					{
   286  						Role: "user",
   287  						Parts: []part{
   288  							{
   289  								Text: prompt,
   290  							},
   291  						},
   292  					},
   293  				},
   294  				GenerationConfig: &generationConfig{
   295  					Temperature:    settings.Temperature(),
   296  					TopP:           settings.TopP(),
   297  					TopK:           settings.TopK(),
   298  					CandidateCount: 1,
   299  				},
   300  				SafetySettings: []safetySetting{
   301  					{
   302  						Category:  HarmCategoryHarassment,
   303  						Threshold: BlockMediumAndAbove,
   304  					},
   305  					{
   306  						Category:  HarmCategoryHate_speech,
   307  						Threshold: BlockMediumAndAbove,
   308  					},
   309  					{
   310  						Category:  HarmCategoryDangerous_content,
   311  						Threshold: BlockMediumAndAbove,
   312  					},
   313  					{
   314  						Category:  HarmCategoryDangerous_content,
   315  						Threshold: BlockMediumAndAbove,
   316  					},
   317  				},
   318  			}
   319  			return input
   320  		}
   321  		input := generateMessageRequest{
   322  			Prompt: &generateMessagePrompt{
   323  				Messages: []generateMessage{
   324  					{
   325  						Content: prompt,
   326  					},
   327  				},
   328  			},
   329  			Temperature:    settings.Temperature(),
   330  			TopP:           settings.TopP(),
   331  			TopK:           settings.TopK(),
   332  			CandidateCount: 1,
   333  		}
   334  		return input
   335  	}
   336  	input := generateInput{
   337  		Instances: []instance{
   338  			{
   339  				Messages: []message{
   340  					{
   341  						Content: prompt,
   342  					},
   343  				},
   344  			},
   345  		},
   346  		Parameters: parameters{
   347  			Temperature:     settings.Temperature(),
   348  			MaxOutputTokens: settings.TokenLimit(),
   349  			TopP:            settings.TopP(),
   350  			TopK:            settings.TopK(),
   351  		},
   352  	}
   353  	return input
   354  }
   355  
   356  func (v *palm) generatePromptForTask(textProperties []map[string]string, task string) (string, error) {
   357  	marshal, err := json.Marshal(textProperties)
   358  	if err != nil {
   359  		return "", err
   360  	}
   361  	return fmt.Sprintf(`'%v:
   362  %v`, task, string(marshal)), nil
   363  }
   364  
   365  func (v *palm) generateForPrompt(textProperties map[string]string, prompt string) (string, error) {
   366  	all := compile.FindAll([]byte(prompt), -1)
   367  	for _, match := range all {
   368  		originalProperty := string(match)
   369  		replacedProperty := compile.FindStringSubmatch(originalProperty)[1]
   370  		replacedProperty = strings.TrimSpace(replacedProperty)
   371  		value := textProperties[replacedProperty]
   372  		if value == "" {
   373  			return "", errors.Errorf("Following property has empty value: '%v'. Make sure you spell the property name correctly, verify that the property exists and has a value", replacedProperty)
   374  		}
   375  		prompt = strings.ReplaceAll(prompt, originalProperty, value)
   376  	}
   377  	return prompt, nil
   378  }
   379  
   380  func (v *palm) getApiKey(ctx context.Context) (string, error) {
   381  	if apiKeyValue := v.getValueFromContext(ctx, "X-Google-Api-Key"); apiKeyValue != "" {
   382  		return apiKeyValue, nil
   383  	}
   384  	if apiKeyValue := v.getValueFromContext(ctx, "X-Palm-Api-Key"); apiKeyValue != "" {
   385  		return apiKeyValue, nil
   386  	}
   387  	if len(v.apiKey) > 0 {
   388  		return v.apiKey, nil
   389  	}
   390  	return "", errors.New("no api key found " +
   391  		"neither in request header: X-Palm-Api-Key or X-Google-Api-Key " +
   392  		"nor in environment variable under PALM_APIKEY or GOOGLE_APIKEY")
   393  }
   394  
   395  func (v *palm) getValueFromContext(ctx context.Context, key string) string {
   396  	if value := ctx.Value(key); value != nil {
   397  		if keyHeader, ok := value.([]string); ok && len(keyHeader) > 0 && len(keyHeader[0]) > 0 {
   398  			return keyHeader[0]
   399  		}
   400  	}
   401  	// try getting header from GRPC if not successful
   402  	if apiKey := modulecomponents.GetValueFromGRPC(ctx, key); len(apiKey) > 0 && len(apiKey[0]) > 0 {
   403  		return apiKey[0]
   404  	}
   405  	return ""
   406  }
   407  
// generateInput is the request body getPayload builds when the Generative AI
// endpoint is not used (the URL from buildURL then ends in ":predict").
// Parameters is always serialised; Instances is omitted when empty.
type generateInput struct {
	Instances  []instance `json:"instances,omitempty"`
	Parameters parameters `json:"parameters"`
}
   412  
// instance is one element of generateInput.Instances. getPayload only fills
// Messages; Context and Examples are omitted from the JSON when empty.
type instance struct {
	Context  string    `json:"context,omitempty"`
	Messages []message `json:"messages,omitempty"`
	Examples []example `json:"examples,omitempty"`
}
   418  
// message is one entry of instance.Messages. Both fields are always
// serialised, so an unset Author is sent as an empty string.
type message struct {
	Author  string `json:"author"`
	Content string `json:"content"`
}
   423  
// example is one input/output pair of instance.Examples. Nothing in this
// file populates it.
type example struct {
	Input  string `json:"input"`
	Output string `json:"output"`
}
   428  
// parameters carries the generation settings of a generateInput. All four
// fields are always serialised (no omitempty), so zero values are sent as-is.
type parameters struct {
	Temperature     float64 `json:"temperature"`
	MaxOutputTokens int     `json:"maxOutputTokens"`
	TopP            float64 `json:"topP"`
	TopK            int     `json:"topK"`
}
   435  
// generateResponse is the reply body decoded by parseResponse. Only
// Predictions and Error are read by this file; the remaining fields are
// decoded but not used here.
type generateResponse struct {
	Predictions      []prediction  `json:"predictions,omitempty"`
	Error            *palmApiError `json:"error,omitempty"`
	DeployedModelId  string        `json:"deployedModelId,omitempty"`
	Model            string        `json:"model,omitempty"`
	ModelDisplayName string        `json:"modelDisplayName,omitempty"`
	ModelVersionId   string        `json:"modelVersionId,omitempty"`
}
   444  
// prediction is one element of generateResponse.Predictions. parseResponse
// reads the first candidate; SafetyAttributes is decoded but not read in
// this file.
type prediction struct {
	Candidates       []candidate         `json:"candidates,omitempty"`
	SafetyAttributes *[]safetyAttributes `json:"safetyAttributes,omitempty"`
}
   449  
// candidate is one generated answer inside a prediction; parseResponse
// returns its Content.
type candidate struct {
	Author  string `json:"author"`
	Content string `json:"content"`
}
   454  
// safetyAttributes mirrors the "safetyAttributes" entries of a prediction.
// It is decoded but never inspected in this file.
type safetyAttributes struct {
	Scores     []float64 `json:"scores,omitempty"`
	Blocked    *bool     `json:"blocked,omitempty"`
	Categories []string  `json:"categories,omitempty"`
}
   460  
// palmApiError is the "error" object shared by all three response types.
// checkResponse reports its Message; Code and Status are decoded but not
// read in this file.
type palmApiError struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
	Status  string `json:"status"`
}
   466  
// generateMessageRequest is the request body getPayload builds for the
// Generative AI endpoint when the model ID does not start with "gemini".
// Temperature and CandidateCount are omitted from the JSON when zero,
// whereas TopP and TopK are always serialised.
type generateMessageRequest struct {
	Prompt         *generateMessagePrompt `json:"prompt,omitempty"`
	Temperature    float64                `json:"temperature,omitempty"`
	CandidateCount int                    `json:"candidateCount,omitempty"` // default 1
	TopP           float64                `json:"topP"`
	TopK           int                    `json:"topK"`
}
   474  
   475  type generateMessagePrompt struct {
   476  	Context  string            `json:"prompt,omitempty"`
   477  	Examples []generateExample `json:"examples,omitempty"`
   478  	Messages []generateMessage `json:"messages,omitempty"`
   479  }
   480  
   481  type generateMessage struct {
   482  	Author           string                    `json:"author,omitempty"`
   483  	Content          string                    `json:"content,omitempty"`
   484  	CitationMetadata *generateCitationMetadata `json:"citationMetadata,omitempty"`
   485  }
   486  
   487  type generateCitationMetadata struct {
   488  	CitationSources []generateCitationSource `json:"citationSources,omitempty"`
   489  }
   490  
   491  type generateCitationSource struct {
   492  	StartIndex int    `json:"startIndex,omitempty"`
   493  	EndIndex   int    `json:"endIndex,omitempty"`
   494  	URI        string `json:"uri,omitempty"`
   495  	License    string `json:"license,omitempty"`
   496  }
   497  
   498  type generateExample struct {
   499  	Input  *generateMessage `json:"input,omitempty"`
   500  	Output *generateMessage `json:"output,omitempty"`
   501  }
   502  
// generateMessageResponse is the reply body decoded by
// parseGenerateMessageResponse. Only Candidates and Error are read by this
// file; Messages and Filters are decoded but not used here.
type generateMessageResponse struct {
	Candidates []generateMessage `json:"candidates,omitempty"`
	Messages   []generateMessage `json:"messages,omitempty"`
	Filters    []contentFilter   `json:"filters,omitempty"`
	Error      *palmApiError     `json:"error,omitempty"`
}
   509  
// contentFilter is one entry of generateMessageResponse.Filters. It is
// decoded but not read in this file.
type contentFilter struct {
	Reason  string `json:"reason,omitempty"`
	Message string `json:"message,omitempty"`
}
   514  
// generateContentRequest is the request body getPayload builds for the
// Generative AI endpoint when the model ID starts with "gemini". Every field
// is omitted from the JSON when empty.
type generateContentRequest struct {
	Contents         []content         `json:"contents,omitempty"`
	SafetySettings   []safetySetting   `json:"safetySettings,omitempty"`
	GenerationConfig *generationConfig `json:"generationConfig,omitempty"`
}
   520  
// content is one entry of generateContentRequest.Contents. getPayload sends
// a single entry with Role "user" and one text part.
type content struct {
	Parts []part `json:"parts,omitempty"`
	Role  string `json:"role,omitempty"`
}
   525  
// part is one piece of a content / contentResponse. This file only writes
// and reads Text.
// NOTE(review): InlineData is declared as a string, so a reply whose
// "inline_data" value is a JSON object would fail to decode. Confirm the
// wire shape against the API reference.
type part struct {
	Text       string `json:"text,omitempty"`
	InlineData string `json:"inline_data,omitempty"`
}
   530  
// safetySetting pairs a harm category with the blocking threshold to apply
// to it; it is sent in generateContentRequest.SafetySettings.
type safetySetting struct {
	Category  harmCategory       `json:"category,omitempty"`
	Threshold harmBlockThreshold `json:"threshold,omitempty"`
}
   535  
// generationConfig carries the generation settings of a
// generateContentRequest. Every field uses omitempty, so zero values
// (including a temperature of 0) are left out of the JSON. getPayload sets
// Temperature, TopP, TopK and CandidateCount only.
type generationConfig struct {
	StopSequences   []string `json:"stopSequences,omitempty"`
	CandidateCount  int      `json:"candidateCount,omitempty"`
	MaxOutputTokens int      `json:"maxOutputTokens,omitempty"`
	Temperature     float64  `json:"temperature,omitempty"`
	TopP            float64  `json:"topP,omitempty"`
	TopK            int      `json:"topK,omitempty"`
}
   544  
// generateContentResponse is the reply body decoded by
// parseGenerateContentResponse. Only Candidates and Error are read by this
// file; PromptFeedback is decoded but not used here.
type generateContentResponse struct {
	Candidates     []generateContentCandidate `json:"candidates,omitempty"`
	PromptFeedback *promptFeedback            `json:"promptFeedback,omitempty"`
	Error          *palmApiError              `json:"error,omitempty"`
}
   550  
// generateContentCandidate is one generated answer of a
// generateContentResponse. parseGenerateContentResponse reads
// Content.Parts[0].Text of the first candidate; the other fields are decoded
// but not read in this file.
type generateContentCandidate struct {
	Content       contentResponse `json:"content,omitempty"`
	FinishReason  string          `json:"finishReason,omitempty"`
	Index         int             `json:"index,omitempty"`
	SafetyRatings []safetyRating  `json:"safetyRatings,omitempty"`
}
   557  
// contentResponse is the "content" object of a generateContentCandidate. It
// has the same fields as the request-side content type.
type contentResponse struct {
	Parts []part `json:"parts,omitempty"`
	Role  string `json:"role,omitempty"`
}
   562  
// promptFeedback is the "promptFeedback" object of a
// generateContentResponse. It is decoded but not read in this file.
type promptFeedback struct {
	SafetyRatings []safetyRating `json:"safetyRatings,omitempty"`
}
   566  
// safetyRating is one rating entry found in promptFeedback and in
// generateContentCandidate. It is decoded but not read in this file.
type safetyRating struct {
	Category    harmCategory    `json:"category,omitempty"`
	Probability harmProbability `json:"probability,omitempty"`
	Blocked     *bool           `json:"blocked,omitempty"`
}