github.com/weaviate/weaviate@v1.24.6/modules/text2vec-voyageai/clients/voyageai.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package clients
    13  
    14  import (
    15  	"bytes"
    16  	"context"
    17  	"encoding/json"
    18  	"fmt"
    19  	"io"
    20  	"net/http"
    21  	"time"
    22  
    23  	"github.com/weaviate/weaviate/usecases/modulecomponents"
    24  
    25  	"github.com/pkg/errors"
    26  	"github.com/sirupsen/logrus"
    27  	"github.com/weaviate/weaviate/modules/text2vec-voyageai/ent"
    28  )
    29  
    30  type embeddingsRequest struct {
    31  	Input      []string  `json:"input"`
    32  	Model      string    `json:"model"`
    33  	Truncation bool      `json:"truncation,omitempty"`
    34  	InputType  inputType `json:"input_type,omitempty"`
    35  }
    36  
    37  type embeddingsDataResponse struct {
    38  	Embeddings []float32 `json:"embedding"`
    39  }
    40  
    41  type embeddingsResponse struct {
    42  	Data   []embeddingsDataResponse `json:"data,omitempty"`
    43  	Model  string                   `json:"model,omitempty"`
    44  	Detail string                   `json:"detail,omitempty"`
    45  }
    46  
    47  type vectorizer struct {
    48  	apiKey     string
    49  	httpClient *http.Client
    50  	urlBuilder *voyageaiUrlBuilder
    51  	logger     logrus.FieldLogger
    52  }
    53  
    54  type inputType string
    55  
    56  const (
    57  	searchDocument inputType = "document"
    58  	searchQuery    inputType = "query"
    59  )
    60  
    61  func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *vectorizer {
    62  	return &vectorizer{
    63  		apiKey: apiKey,
    64  		httpClient: &http.Client{
    65  			Timeout: timeout,
    66  		},
    67  		urlBuilder: newVoyageAIUrlBuilder(),
    68  		logger:     logger,
    69  	}
    70  }
    71  
    72  func (v *vectorizer) Vectorize(ctx context.Context, input []string,
    73  	config ent.VectorizationConfig,
    74  ) (*ent.VectorizationResult, error) {
    75  	return v.vectorize(ctx, input, config.Model, config.Truncate, config.BaseURL, searchDocument)
    76  }
    77  
    78  func (v *vectorizer) VectorizeQuery(ctx context.Context, input []string,
    79  	config ent.VectorizationConfig,
    80  ) (*ent.VectorizationResult, error) {
    81  	return v.vectorize(ctx, input, config.Model, config.Truncate, config.BaseURL, searchQuery)
    82  }
    83  
    84  func (v *vectorizer) vectorize(ctx context.Context, input []string,
    85  	model string, truncate bool, baseURL string, inputType inputType,
    86  ) (*ent.VectorizationResult, error) {
    87  	body, err := json.Marshal(embeddingsRequest{
    88  		Input:      input,
    89  		Model:      model,
    90  		Truncation: truncate,
    91  		InputType:  inputType,
    92  	})
    93  	if err != nil {
    94  		return nil, errors.Wrapf(err, "marshal body")
    95  	}
    96  
    97  	url := v.getVoyageAIUrl(ctx, baseURL)
    98  	req, err := http.NewRequestWithContext(ctx, "POST", url,
    99  		bytes.NewReader(body))
   100  	if err != nil {
   101  		return nil, errors.Wrap(err, "create POST request")
   102  	}
   103  	apiKey, err := v.getApiKey(ctx)
   104  	if err != nil {
   105  		return nil, errors.Wrapf(err, "VoyageAI API Key")
   106  	}
   107  	req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", apiKey))
   108  	req.Header.Add("Content-Type", "application/json")
   109  
   110  	res, err := v.httpClient.Do(req)
   111  	if err != nil {
   112  		return nil, errors.Wrap(err, "send POST request")
   113  	}
   114  	defer res.Body.Close()
   115  	bodyBytes, err := io.ReadAll(res.Body)
   116  	if err != nil {
   117  		return nil, errors.Wrap(err, "read response body")
   118  	}
   119  	var resBody embeddingsResponse
   120  	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
   121  		return nil, errors.Wrap(err, "unmarshal response body")
   122  	}
   123  
   124  	if res.StatusCode != 200 {
   125  		if resBody.Detail != "" {
   126  			errorMessage := getErrorMessage(res.StatusCode, resBody.Detail, "connection to VoyageAI failed with status: %d error: %v")
   127  			return nil, errors.Errorf(errorMessage)
   128  		}
   129  		errorMessage := getErrorMessage(res.StatusCode, "", "connection to VoyageAI failed with status: %d")
   130  		return nil, errors.Errorf(errorMessage)
   131  	}
   132  
   133  	if len(resBody.Data) == 0 || len(resBody.Data[0].Embeddings) == 0 {
   134  		return nil, errors.Errorf("empty embeddings response")
   135  	}
   136  
   137  	vectors := make([][]float32, len(resBody.Data))
   138  	for i, data := range resBody.Data {
   139  		vectors[i] = data.Embeddings
   140  	}
   141  
   142  	return &ent.VectorizationResult{
   143  		Text:       input,
   144  		Dimensions: len(resBody.Data[0].Embeddings),
   145  		Vectors:    vectors,
   146  	}, nil
   147  }
   148  
   149  func (v *vectorizer) getVoyageAIUrl(ctx context.Context, baseURL string) string {
   150  	passedBaseURL := baseURL
   151  	if headerBaseURL := v.getValueFromContext(ctx, "X-Voyageai-Baseurl"); headerBaseURL != "" {
   152  		passedBaseURL = headerBaseURL
   153  	}
   154  	return v.urlBuilder.url(passedBaseURL)
   155  }
   156  
   157  func getErrorMessage(statusCode int, resBodyError string, errorTemplate string) string {
   158  	if resBodyError != "" {
   159  		return fmt.Sprintf(errorTemplate, statusCode, resBodyError)
   160  	}
   161  	return fmt.Sprintf(errorTemplate, statusCode)
   162  }
   163  
   164  func (v *vectorizer) getValueFromContext(ctx context.Context, key string) string {
   165  	if value := ctx.Value(key); value != nil {
   166  		if keyHeader, ok := value.([]string); ok && len(keyHeader) > 0 && len(keyHeader[0]) > 0 {
   167  			return keyHeader[0]
   168  		}
   169  	}
   170  	// try getting header from GRPC if not successful
   171  	if apiKey := modulecomponents.GetValueFromGRPC(ctx, key); len(apiKey) > 0 && len(apiKey[0]) > 0 {
   172  		return apiKey[0]
   173  	}
   174  	return ""
   175  }
   176  
   177  func (v *vectorizer) getApiKey(ctx context.Context) (string, error) {
   178  	if apiKey := v.getValueFromContext(ctx, "X-Voyageai-Api-Key"); apiKey != "" {
   179  		return apiKey, nil
   180  	}
   181  	if v.apiKey != "" {
   182  		return v.apiKey, nil
   183  	}
   184  	return "", errors.New("no api key found " +
   185  		"neither in request header: X-VoyageAI-Api-Key " +
   186  		"nor in environment variable under VOYAGEAI_APIKEY")
   187  }