github.com/weaviate/weaviate@v1.24.6/modules/multi2vec-palm/clients/palm.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package clients
    13  
    14  import (
    15  	"bytes"
    16  	"context"
    17  	"encoding/json"
    18  	"fmt"
    19  	"io"
    20  	"net/http"
    21  	"time"
    22  
    23  	"github.com/weaviate/weaviate/usecases/modulecomponents"
    24  
    25  	"github.com/pkg/errors"
    26  	"github.com/sirupsen/logrus"
    27  	"github.com/weaviate/weaviate/modules/multi2vec-palm/ent"
    28  	libvectorizer "github.com/weaviate/weaviate/usecases/vectorizer"
    29  )
    30  
    // buildURL returns the regional Vertex AI (aiplatform.googleapis.com)
// ":predict" endpoint for the given Google publisher model. The location is
// used twice: once as the regional host prefix and once in the resource path.
func buildURL(location, projectID, model string) string {
	host := location + "-aiplatform.googleapis.com"
	resource := "/v1/projects/" + projectID +
		"/locations/" + location +
		"/publishers/google/models/" + model + ":predict"
	return "https://" + host + resource
}
    35  
    36  type palm struct {
    37  	apiKey       string
    38  	httpClient   *http.Client
    39  	urlBuilderFn func(location, projectID, model string) string
    40  	logger       logrus.FieldLogger
    41  }
    42  
    43  func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *palm {
    44  	return &palm{
    45  		apiKey: apiKey,
    46  		httpClient: &http.Client{
    47  			Timeout: timeout,
    48  		},
    49  		urlBuilderFn: buildURL,
    50  		logger:       logger,
    51  	}
    52  }
    53  
    54  func (v *palm) Vectorize(ctx context.Context,
    55  	texts, images, videos []string, config ent.VectorizationConfig,
    56  ) (*ent.VectorizationResult, error) {
    57  	return v.vectorize(ctx, texts, images, videos, config)
    58  }
    59  
    60  func (v *palm) VectorizeQuery(ctx context.Context, input []string,
    61  	config ent.VectorizationConfig,
    62  ) (*ent.VectorizationResult, error) {
    63  	return v.vectorize(ctx, input, nil, nil, config)
    64  }
    65  
    66  func (v *palm) vectorize(ctx context.Context,
    67  	texts, images, videos []string, config ent.VectorizationConfig,
    68  ) (*ent.VectorizationResult, error) {
    69  	var textEmbeddings [][]float32
    70  	var imageEmbeddings [][]float32
    71  	var videoEmbeddings [][]float32
    72  	endpointURL := v.getURL(config)
    73  	maxCount := max(len(texts), len(images), len(videos))
    74  	for i := 0; i < maxCount; i++ {
    75  		text := v.safelyGet(texts, i)
    76  		image := v.safelyGet(images, i)
    77  		video := v.safelyGet(videos, i)
    78  		payload := v.getPayload(text, image, video, config)
    79  		statusCode, res, err := v.sendRequest(ctx, endpointURL, payload)
    80  		if err != nil {
    81  			return nil, err
    82  		}
    83  		textVectors, imageVectors, videoVectors, err := v.getEmbeddingsFromResponse(statusCode, res)
    84  		if err != nil {
    85  			return nil, err
    86  		}
    87  		textEmbeddings = append(textEmbeddings, textVectors...)
    88  		imageEmbeddings = append(imageEmbeddings, imageVectors...)
    89  		videoEmbeddings = append(videoEmbeddings, videoVectors...)
    90  	}
    91  
    92  	return v.getResponse(textEmbeddings, imageEmbeddings, videoEmbeddings)
    93  }
    94  
    95  func (v *palm) safelyGet(input []string, i int) string {
    96  	if i < len(input) {
    97  		return input[i]
    98  	}
    99  	return ""
   100  }
   101  
   102  func (v *palm) sendRequest(ctx context.Context,
   103  	endpointURL string, payload embeddingsRequest,
   104  ) (int, embeddingsResponse, error) {
   105  	body, err := json.Marshal(payload)
   106  	if err != nil {
   107  		return 0, embeddingsResponse{}, errors.Wrapf(err, "marshal body")
   108  	}
   109  
   110  	req, err := http.NewRequestWithContext(ctx, "POST", endpointURL,
   111  		bytes.NewReader(body))
   112  	if err != nil {
   113  		return 0, embeddingsResponse{}, errors.Wrap(err, "create POST request")
   114  	}
   115  
   116  	apiKey, err := v.getApiKey(ctx)
   117  	if err != nil {
   118  		return 0, embeddingsResponse{}, errors.Wrapf(err, "Google API Key")
   119  	}
   120  	req.Header.Add("Content-Type", "application/json; charset=utf-8")
   121  	req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", apiKey))
   122  
   123  	res, err := v.httpClient.Do(req)
   124  	if err != nil {
   125  		return 0, embeddingsResponse{}, errors.Wrap(err, "send POST request")
   126  	}
   127  	defer res.Body.Close()
   128  
   129  	bodyBytes, err := io.ReadAll(res.Body)
   130  	if err != nil {
   131  		return 0, embeddingsResponse{}, errors.Wrap(err, "read response body")
   132  	}
   133  
   134  	var resBody embeddingsResponse
   135  	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
   136  		return 0, embeddingsResponse{}, errors.Wrap(err, "unmarshal response body")
   137  	}
   138  
   139  	return res.StatusCode, resBody, nil
   140  }
   141  
   142  func (v *palm) getURL(config ent.VectorizationConfig) string {
   143  	return v.urlBuilderFn(config.Location, config.ProjectID, config.Model)
   144  }
   145  
   146  func (v *palm) getPayload(text, img, vid string, config ent.VectorizationConfig) embeddingsRequest {
   147  	inst := instance{}
   148  	if text != "" {
   149  		inst.Text = &text
   150  	}
   151  	if img != "" {
   152  		inst.Image = &image{BytesBase64Encoded: img}
   153  	}
   154  	if vid != "" {
   155  		inst.Video = &video{
   156  			BytesBase64Encoded: vid,
   157  			VideoSegmentConfig: videoSegmentConfig{IntervalSec: &config.VideoIntervalSeconds},
   158  		}
   159  	}
   160  	return embeddingsRequest{
   161  		Instances:  []instance{inst},
   162  		Parameters: parameters{Dimension: config.Dimensions},
   163  	}
   164  }
   165  
   166  func (v *palm) checkResponse(statusCode int, palmApiError *palmApiError) error {
   167  	if statusCode != 200 || palmApiError != nil {
   168  		if palmApiError != nil {
   169  			return fmt.Errorf("connection to Google failed with status: %v error: %v",
   170  				statusCode, palmApiError.Message)
   171  		}
   172  		return fmt.Errorf("connection to Google failed with status: %d", statusCode)
   173  	}
   174  	return nil
   175  }
   176  
   177  func (v *palm) getApiKey(ctx context.Context) (string, error) {
   178  	if apiKeyValue := v.getValueFromContext(ctx, "X-Google-Api-Key"); apiKeyValue != "" {
   179  		return apiKeyValue, nil
   180  	}
   181  	if apiKeyValue := v.getValueFromContext(ctx, "X-Palm-Api-Key"); apiKeyValue != "" {
   182  		return apiKeyValue, nil
   183  	}
   184  	if len(v.apiKey) > 0 {
   185  		return v.apiKey, nil
   186  	}
   187  	return "", errors.New("no api key found " +
   188  		"neither in request header: X-Palm-Api-Key or X-Google-Api-Key " +
   189  		"nor in environment variable under PALM_APIKEY or GOOGLE_APIKEY")
   190  }
   191  
   192  func (v *palm) getValueFromContext(ctx context.Context, key string) string {
   193  	if value := ctx.Value(key); value != nil {
   194  		if keyHeader, ok := value.([]string); ok && len(keyHeader) > 0 && len(keyHeader[0]) > 0 {
   195  			return keyHeader[0]
   196  		}
   197  	}
   198  	// try getting header from GRPC if not successful
   199  	if apiKey := modulecomponents.GetValueFromGRPC(ctx, key); len(apiKey) > 0 && len(apiKey[0]) > 0 {
   200  		return apiKey[0]
   201  	}
   202  	return ""
   203  }
   204  
   205  func (v *palm) getEmbeddingsFromResponse(statusCode int, resBody embeddingsResponse) (
   206  	textEmbeddings [][]float32,
   207  	imageEmbeddings [][]float32,
   208  	videoEmbeddings [][]float32,
   209  	err error,
   210  ) {
   211  	if respErr := v.checkResponse(statusCode, resBody.Error); respErr != nil {
   212  		err = respErr
   213  		return
   214  	}
   215  
   216  	if len(resBody.Predictions) == 0 {
   217  		err = errors.Errorf("empty embeddings response")
   218  		return
   219  	}
   220  
   221  	for _, p := range resBody.Predictions {
   222  		if len(p.TextEmbedding) > 0 {
   223  			textEmbeddings = append(textEmbeddings, p.TextEmbedding)
   224  		}
   225  		if len(p.ImageEmbedding) > 0 {
   226  			imageEmbeddings = append(imageEmbeddings, p.ImageEmbedding)
   227  		}
   228  		if len(p.VideoEmbeddings) > 0 {
   229  			var embeddings [][]float32
   230  			for _, videoEmbedding := range p.VideoEmbeddings {
   231  				embeddings = append(embeddings, videoEmbedding.Embedding)
   232  			}
   233  			embedding := embeddings[0]
   234  			if len(embeddings) > 1 {
   235  				embedding = libvectorizer.CombineVectors(embeddings)
   236  			}
   237  			videoEmbeddings = append(videoEmbeddings, embedding)
   238  		}
   239  	}
   240  	return
   241  }
   242  
   243  func (v *palm) getResponse(textVectors, imageVectors, videoVectors [][]float32) (*ent.VectorizationResult, error) {
   244  	return &ent.VectorizationResult{
   245  		TextVectors:  textVectors,
   246  		ImageVectors: imageVectors,
   247  		VideoVectors: videoVectors,
   248  	}, nil
   249  }
   250  
   // Wire types for the JSON body POSTed to the predict endpoint. Field order
// and JSON tags define the encoded output and must be kept as-is.
type (
	// embeddingsRequest is the top-level request body.
	embeddingsRequest struct {
		Instances []instance `json:"instances,omitempty"`
		// NOTE: omitempty has no effect on a non-pointer struct field, so
		// "parameters" is always emitted (as {} when Dimension is 0).
		Parameters parameters `json:"parameters,omitempty"`
	}

	// parameters holds request-wide options; a zero Dimension is omitted.
	parameters struct {
		Dimension int64 `json:"dimension,omitempty"`
	}

	// instance is one multimodal input; nil members are omitted.
	instance struct {
		Text  *string `json:"text,omitempty"`
		Image *image  `json:"image,omitempty"`
		Video *video  `json:"video,omitempty"`
	}

	// image carries base64-encoded image bytes.
	image struct {
		BytesBase64Encoded string `json:"bytesBase64Encoded"`
	}

	// video carries base64-encoded video bytes plus segmentation options.
	video struct {
		BytesBase64Encoded string             `json:"bytesBase64Encoded"`
		VideoSegmentConfig videoSegmentConfig `json:"videoSegmentConfig"`
	}

	// videoSegmentConfig controls how a video is segmented; nil fields are
	// omitted.
	videoSegmentConfig struct {
		StartOffsetSec *int64 `json:"startOffsetSec,omitempty"`
		EndOffsetSec   *int64 `json:"endOffsetSec,omitempty"`
		IntervalSec    *int64 `json:"intervalSec,omitempty"`
	}
)
   280  
   // Wire types for the JSON body returned by the predict endpoint. JSON tags
// must match the service's field names exactly.
type (
	// embeddingsResponse is the top-level response body. Error is set when the
	// service reports a failure.
	embeddingsResponse struct {
		Predictions     []prediction  `json:"predictions,omitempty"`
		Error           *palmApiError `json:"error,omitempty"`
		DeployedModelId string        `json:"deployedModelId,omitempty"`
	}

	// prediction holds the embeddings produced for one request instance.
	prediction struct {
		TextEmbedding   []float32        `json:"textEmbedding,omitempty"`
		ImageEmbedding  []float32        `json:"imageEmbedding,omitempty"`
		VideoEmbeddings []videoEmbedding `json:"videoEmbeddings,omitempty"`
	}

	// videoEmbedding is the embedding of one video segment, with its offsets.
	videoEmbedding struct {
		StartOffsetSec *int64    `json:"startOffsetSec,omitempty"`
		EndOffsetSec   *int64    `json:"endOffsetSec,omitempty"`
		Embedding      []float32 `json:"embedding,omitempty"`
	}

	// palmApiError is the error object embedded in a failed response.
	palmApiError struct {
		Code    int    `json:"code"`
		Message string `json:"message"`
		Status  string `json:"status"`
	}
)