github.com/weaviate/weaviate@v1.24.6/modules/text2vec-palm/clients/palm.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package clients
    13  
    14  import (
    15  	"bytes"
    16  	"context"
    17  	"encoding/json"
    18  	"fmt"
    19  	"io"
    20  	"net/http"
    21  	"strings"
    22  	"time"
    23  
    24  	"github.com/weaviate/weaviate/usecases/modulecomponents"
    25  
    26  	"github.com/pkg/errors"
    27  	"github.com/sirupsen/logrus"
    28  	"github.com/weaviate/weaviate/modules/text2vec-palm/ent"
    29  )
    30  
    31  type taskType string
    32  
    33  var (
    34  	// Specifies the given text is a document in a search/retrieval setting
    35  	retrievalQuery taskType = "RETRIEVAL_QUERY"
    36  	// Specifies the given text is a query in a search/retrieval setting
    37  	retrievalDocument taskType = "RETRIEVAL_DOCUMENT"
    38  )
    39  
    40  func buildURL(useGenerativeAI bool, apiEndoint, projectID, modelID string) string {
    41  	if useGenerativeAI {
    42  		// Generative AI supports only 1 embedding model: embedding-gecko-001. So for now
    43  		// in order to keep it simple we generate one variation of PaLM API url.
    44  		// For more context check out this link:
    45  		// https://developers.generativeai.google/models/language#model_variations
    46  		return "https://generativelanguage.googleapis.com/v1beta3/models/embedding-gecko-001:batchEmbedText"
    47  	}
    48  	urlTemplate := "https://%s/v1/projects/%s/locations/us-central1/publishers/google/models/%s:predict"
    49  	return fmt.Sprintf(urlTemplate, apiEndoint, projectID, modelID)
    50  }
    51  
// palm is the HTTP client used to request text embeddings from Google.
//
// Fields:
//   - apiKey: key used by getApiKey when the request context carries none.
//   - httpClient: client that performs the POST request in vectorize.
//   - urlBuilderFn: produces the request URL; New sets it to buildURL.
//     NOTE(review): presumably a field so tests can substitute it — confirm.
//   - logger: logger handed in through New; not used by the methods visible
//     in this file.
type palm struct {
	apiKey       string
	httpClient   *http.Client
	urlBuilderFn func(useGenerativeAI bool, apiEndoint, projectID, modelID string) string
	logger       logrus.FieldLogger
}
    58  
    59  func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *palm {
    60  	return &palm{
    61  		apiKey: apiKey,
    62  		httpClient: &http.Client{
    63  			Timeout: timeout,
    64  		},
    65  		urlBuilderFn: buildURL,
    66  		logger:       logger,
    67  	}
    68  }
    69  
// Vectorize requests embeddings for input, treating the texts as documents:
// it delegates to vectorize with the retrievalDocument task type and passes
// titlePropertyValue along as the title.
func (v *palm) Vectorize(ctx context.Context, input []string,
	config ent.VectorizationConfig, titlePropertyValue string,
) (*ent.VectorizationResult, error) {
	return v.vectorize(ctx, input, retrievalDocument, titlePropertyValue, config)
}
    75  
// VectorizeQuery requests embeddings for input, treating the texts as search
// queries: it delegates to vectorize with the retrievalQuery task type and an
// empty title.
func (v *palm) VectorizeQuery(ctx context.Context, input []string,
	config ent.VectorizationConfig,
) (*ent.VectorizationResult, error) {
	return v.vectorize(ctx, input, retrievalQuery, "", config)
}
    81  
    82  func (v *palm) vectorize(ctx context.Context, input []string, taskType taskType,
    83  	titlePropertyValue string, config ent.VectorizationConfig,
    84  ) (*ent.VectorizationResult, error) {
    85  	useGenerativeAIEndpoint := v.useGenerativeAIEndpoint(config)
    86  
    87  	payload := v.getPayload(useGenerativeAIEndpoint, input, taskType, titlePropertyValue, config)
    88  	body, err := json.Marshal(payload)
    89  	if err != nil {
    90  		return nil, errors.Wrapf(err, "marshal body")
    91  	}
    92  
    93  	endpointURL := v.urlBuilderFn(useGenerativeAIEndpoint,
    94  		v.getApiEndpoint(config), v.getProjectID(config), v.getModel(config))
    95  
    96  	req, err := http.NewRequestWithContext(ctx, "POST", endpointURL,
    97  		bytes.NewReader(body))
    98  	if err != nil {
    99  		return nil, errors.Wrap(err, "create POST request")
   100  	}
   101  
   102  	apiKey, err := v.getApiKey(ctx)
   103  	if err != nil {
   104  		return nil, errors.Wrapf(err, "Google API Key")
   105  	}
   106  	req.Header.Add("Content-Type", "application/json")
   107  	if useGenerativeAIEndpoint {
   108  		req.Header.Add("x-goog-api-key", apiKey)
   109  	} else {
   110  		req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", apiKey))
   111  	}
   112  
   113  	res, err := v.httpClient.Do(req)
   114  	if err != nil {
   115  		return nil, errors.Wrap(err, "send POST request")
   116  	}
   117  	defer res.Body.Close()
   118  
   119  	bodyBytes, err := io.ReadAll(res.Body)
   120  	if err != nil {
   121  		return nil, errors.Wrap(err, "read response body")
   122  	}
   123  
   124  	if useGenerativeAIEndpoint {
   125  		return v.parseGenerativeAIApiResponse(res.StatusCode, bodyBytes, input)
   126  	}
   127  	return v.parseEmbeddingsResponse(res.StatusCode, bodyBytes, input)
   128  }
   129  
   130  func (v *palm) useGenerativeAIEndpoint(config ent.VectorizationConfig) bool {
   131  	return v.getApiEndpoint(config) == "generativelanguage.googleapis.com"
   132  }
   133  
   134  func (v *palm) getPayload(useGenerativeAI bool, input []string,
   135  	taskType taskType, title string, config ent.VectorizationConfig,
   136  ) interface{} {
   137  	if useGenerativeAI {
   138  		return batchEmbedTextRequest{Texts: input}
   139  	}
   140  	isModelVersion001 := strings.HasSuffix(config.Model, "@001")
   141  	instances := make([]instance, len(input))
   142  	for i := range input {
   143  		if isModelVersion001 {
   144  			instances[i] = instance{Content: input[i]}
   145  		} else {
   146  			instances[i] = instance{Content: input[i], TaskType: taskType, Title: title}
   147  		}
   148  	}
   149  	return embeddingsRequest{instances}
   150  }
   151  
   152  func (v *palm) checkResponse(statusCode int, palmApiError *palmApiError) error {
   153  	if statusCode != 200 || palmApiError != nil {
   154  		if palmApiError != nil {
   155  			return fmt.Errorf("connection to Google failed with status: %v error: %v",
   156  				statusCode, palmApiError.Message)
   157  		}
   158  		return fmt.Errorf("connection to Google failed with status: %d", statusCode)
   159  	}
   160  	return nil
   161  }
   162  
   163  func (v *palm) getApiKey(ctx context.Context) (string, error) {
   164  	if apiKeyValue := v.getValueFromContext(ctx, "X-Google-Api-Key"); apiKeyValue != "" {
   165  		return apiKeyValue, nil
   166  	}
   167  	if apiKeyValue := v.getValueFromContext(ctx, "X-Palm-Api-Key"); apiKeyValue != "" {
   168  		return apiKeyValue, nil
   169  	}
   170  	if len(v.apiKey) > 0 {
   171  		return v.apiKey, nil
   172  	}
   173  	return "", errors.New("no api key found " +
   174  		"neither in request header: X-Palm-Api-Key or X-Google-Api-Key " +
   175  		"nor in environment variable under PALM_APIKEY or GOOGLE_APIKEY")
   176  }
   177  
   178  func (v *palm) getValueFromContext(ctx context.Context, key string) string {
   179  	if value := ctx.Value(key); value != nil {
   180  		if keyHeader, ok := value.([]string); ok && len(keyHeader) > 0 && len(keyHeader[0]) > 0 {
   181  			return keyHeader[0]
   182  		}
   183  	}
   184  	// try getting header from GRPC if not successful
   185  	if apiKey := modulecomponents.GetValueFromGRPC(ctx, key); len(apiKey) > 0 && len(apiKey[0]) > 0 {
   186  		return apiKey[0]
   187  	}
   188  	return ""
   189  }
   190  
   191  func (v *palm) parseGenerativeAIApiResponse(statusCode int,
   192  	bodyBytes []byte, input []string,
   193  ) (*ent.VectorizationResult, error) {
   194  	var resBody batchEmbedTextResponse
   195  	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
   196  		return nil, errors.Wrap(err, "unmarshal response body")
   197  	}
   198  
   199  	if err := v.checkResponse(statusCode, resBody.Error); err != nil {
   200  		return nil, err
   201  	}
   202  
   203  	if len(resBody.Embeddings) == 0 {
   204  		return nil, errors.Errorf("empty embeddings response")
   205  	}
   206  
   207  	vectors := make([][]float32, len(resBody.Embeddings))
   208  	for i := range resBody.Embeddings {
   209  		vectors[i] = resBody.Embeddings[i].Value
   210  	}
   211  	dimensions := len(resBody.Embeddings[0].Value)
   212  
   213  	return v.getResponse(input, dimensions, vectors)
   214  }
   215  
   216  func (v *palm) parseEmbeddingsResponse(statusCode int,
   217  	bodyBytes []byte, input []string,
   218  ) (*ent.VectorizationResult, error) {
   219  	var resBody embeddingsResponse
   220  	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
   221  		return nil, errors.Wrap(err, "unmarshal response body")
   222  	}
   223  
   224  	if err := v.checkResponse(statusCode, resBody.Error); err != nil {
   225  		return nil, err
   226  	}
   227  
   228  	if len(resBody.Predictions) == 0 {
   229  		return nil, errors.Errorf("empty embeddings response")
   230  	}
   231  
   232  	vectors := make([][]float32, len(resBody.Predictions))
   233  	for i := range resBody.Predictions {
   234  		vectors[i] = resBody.Predictions[i].Embeddings.Values
   235  	}
   236  	dimensions := len(resBody.Predictions[0].Embeddings.Values)
   237  
   238  	return v.getResponse(input, dimensions, vectors)
   239  }
   240  
   241  func (v *palm) getResponse(input []string, dimensions int, vectors [][]float32) (*ent.VectorizationResult, error) {
   242  	return &ent.VectorizationResult{
   243  		Texts:      input,
   244  		Dimensions: dimensions,
   245  		Vectors:    vectors,
   246  	}, nil
   247  }
   248  
// getApiEndpoint returns the ApiEndpoint field of config unchanged.
func (v *palm) getApiEndpoint(config ent.VectorizationConfig) string {
	return config.ApiEndpoint
}
   252  
// getProjectID returns the ProjectID field of config unchanged.
func (v *palm) getProjectID(config ent.VectorizationConfig) string {
	return config.ProjectID
}
   256  
// getModel returns the Model field of config unchanged.
func (v *palm) getModel(config ent.VectorizationConfig) string {
	return config.Model
}
   260  
// embeddingsRequest is the JSON body built by getPayload for the ":predict"
// endpoint: a list of instances, one per input text.
type embeddingsRequest struct {
	Instances []instance `json:"instances,omitempty"`
}
   264  
// instance is a single text to embed inside an embeddingsRequest. Content is
// always serialized; task_type and title are omitted from the JSON when empty.
type instance struct {
	Content  string   `json:"content"`
	TaskType taskType `json:"task_type,omitempty"`
	Title    string   `json:"title,omitempty"`
}
   270  
// embeddingsResponse is the JSON body decoded by parseEmbeddingsResponse.
// Only Predictions and Error are read by the code in this file; the
// remaining fields are decoded but not otherwise used here.
type embeddingsResponse struct {
	Predictions      []prediction  `json:"predictions,omitempty"`
	Error            *palmApiError `json:"error,omitempty"`
	DeployedModelId  string        `json:"deployedModelId,omitempty"`
	Model            string        `json:"model,omitempty"`
	ModelDisplayName string        `json:"modelDisplayName,omitempty"`
	ModelVersionId   string        `json:"modelVersionId,omitempty"`
}
   279  
// prediction is one entry of embeddingsResponse.Predictions. Its
// Embeddings.Values becomes one result vector; SafetyAttributes is decoded
// but not read by the code in this file.
type prediction struct {
	Embeddings       embeddings        `json:"embeddings,omitempty"`
	SafetyAttributes *safetyAttributes `json:"safetyAttributes,omitempty"`
}
   284  
// embeddings wraps the vector values of a single prediction.
type embeddings struct {
	Values []float32 `json:"values,omitempty"`
}
   288  
// safetyAttributes mirrors the "safetyAttributes" object of a prediction.
// NOTE(review): the meaning of the individual fields is defined by the
// Google API and is not interpreted anywhere in this file — confirm against
// the API reference before relying on them.
type safetyAttributes struct {
	Scores     []float64 `json:"scores,omitempty"`
	Blocked    *bool     `json:"blocked,omitempty"`
	Categories []string  `json:"categories,omitempty"`
}
   294  
// palmApiError is the "error" object decoded from a response body. Its
// Message is included in the error produced by checkResponse; Code and
// Status are decoded but not read by the code in this file.
type palmApiError struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
	Status  string `json:"status"`
}
   300  
// batchEmbedTextRequest is the JSON body built by getPayload for the
// Generative AI endpoint: the raw input texts.
type batchEmbedTextRequest struct {
	Texts []string `json:"texts,omitempty"`
}
   304  
// batchEmbedTextResponse is the JSON body decoded by
// parseGenerativeAIApiResponse: the returned embeddings, or an error object.
type batchEmbedTextResponse struct {
	Embeddings []embedding   `json:"embeddings,omitempty"`
	Error      *palmApiError `json:"error,omitempty"`
}
   309  
// embedding is one entry of batchEmbedTextResponse.Embeddings; Value holds
// the vector for a single text.
type embedding struct {
	Value []float32 `json:"value,omitempty"`
}