github.com/weaviate/weaviate@v1.24.6/modules/text2vec-cohere/clients/cohere.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package clients
    13  
    14  import (
    15  	"bytes"
    16  	"context"
    17  	"encoding/json"
    18  	"fmt"
    19  	"io"
    20  	"net/http"
    21  	"time"
    22  
    23  	"github.com/weaviate/weaviate/usecases/modulecomponents"
    24  
    25  	"github.com/pkg/errors"
    26  	"github.com/sirupsen/logrus"
    27  	"github.com/weaviate/weaviate/modules/text2vec-cohere/ent"
    28  )
    29  
    30  type embeddingsRequest struct {
    31  	Texts     []string  `json:"texts"`
    32  	Model     string    `json:"model,omitempty"`
    33  	Truncate  string    `json:"truncate,omitempty"`
    34  	InputType inputType `json:"input_type,omitempty"`
    35  }
    36  
    37  type embeddingsResponse struct {
    38  	Embeddings [][]float32 `json:"embeddings,omitempty"`
    39  	Message    string      `json:"message,omitempty"`
    40  }
    41  
    42  type vectorizer struct {
    43  	apiKey     string
    44  	httpClient *http.Client
    45  	urlBuilder *cohereUrlBuilder
    46  	logger     logrus.FieldLogger
    47  }
    48  
    49  type inputType string
    50  
    51  const (
    52  	searchDocument inputType = "search_document"
    53  	searchQuery    inputType = "search_query"
    54  )
    55  
    56  func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *vectorizer {
    57  	return &vectorizer{
    58  		apiKey: apiKey,
    59  		httpClient: &http.Client{
    60  			Timeout: timeout,
    61  		},
    62  		urlBuilder: newCohereUrlBuilder(),
    63  		logger:     logger,
    64  	}
    65  }
    66  
    67  func (v *vectorizer) Vectorize(ctx context.Context, input []string,
    68  	config ent.VectorizationConfig,
    69  ) (*ent.VectorizationResult, error) {
    70  	return v.vectorize(ctx, input, config.Model, config.Truncate, config.BaseURL, searchDocument)
    71  }
    72  
    73  func (v *vectorizer) VectorizeQuery(ctx context.Context, input []string,
    74  	config ent.VectorizationConfig,
    75  ) (*ent.VectorizationResult, error) {
    76  	return v.vectorize(ctx, input, config.Model, config.Truncate, config.BaseURL, searchQuery)
    77  }
    78  
    79  func (v *vectorizer) vectorize(ctx context.Context, input []string,
    80  	model, truncate, baseURL string, inputType inputType,
    81  ) (*ent.VectorizationResult, error) {
    82  	body, err := json.Marshal(embeddingsRequest{
    83  		Texts:     input,
    84  		Model:     model,
    85  		Truncate:  truncate,
    86  		InputType: inputType,
    87  	})
    88  	if err != nil {
    89  		return nil, errors.Wrapf(err, "marshal body")
    90  	}
    91  
    92  	url := v.getCohereUrl(ctx, baseURL)
    93  	req, err := http.NewRequestWithContext(ctx, "POST", url,
    94  		bytes.NewReader(body))
    95  	if err != nil {
    96  		return nil, errors.Wrap(err, "create POST request")
    97  	}
    98  	apiKey, err := v.getApiKey(ctx)
    99  	if err != nil {
   100  		return nil, errors.Wrapf(err, "Cohere API Key")
   101  	}
   102  	req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", apiKey))
   103  	req.Header.Add("Content-Type", "application/json")
   104  	req.Header.Add("Request-Source", "unspecified:weaviate")
   105  
   106  	res, err := v.httpClient.Do(req)
   107  	if err != nil {
   108  		return nil, errors.Wrap(err, "send POST request")
   109  	}
   110  	defer res.Body.Close()
   111  	bodyBytes, err := io.ReadAll(res.Body)
   112  	if err != nil {
   113  		return nil, errors.Wrap(err, "read response body")
   114  	}
   115  	var resBody embeddingsResponse
   116  	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
   117  		return nil, errors.Wrap(err, "unmarshal response body")
   118  	}
   119  
   120  	if res.StatusCode >= 500 {
   121  		errorMessage := getErrorMessage(res.StatusCode, resBody.Message, "connection to Cohere failed with status: %d error: %v")
   122  		return nil, errors.Errorf(errorMessage)
   123  	} else if res.StatusCode > 200 {
   124  		errorMessage := getErrorMessage(res.StatusCode, resBody.Message, "failed with status: %d error: %v")
   125  		return nil, errors.Errorf(errorMessage)
   126  	}
   127  
   128  	if len(resBody.Embeddings) == 0 {
   129  		return nil, errors.Errorf("empty embeddings response")
   130  	}
   131  
   132  	return &ent.VectorizationResult{
   133  		Text:       input,
   134  		Dimensions: len(resBody.Embeddings[0]),
   135  		Vectors:    resBody.Embeddings,
   136  	}, nil
   137  }
   138  
   139  func (v *vectorizer) getCohereUrl(ctx context.Context, baseURL string) string {
   140  	passedBaseURL := baseURL
   141  	if headerBaseURL := v.getValueFromContext(ctx, "X-Cohere-Baseurl"); headerBaseURL != "" {
   142  		passedBaseURL = headerBaseURL
   143  	}
   144  	return v.urlBuilder.url(passedBaseURL)
   145  }
   146  
   147  func getErrorMessage(statusCode int, resBodyError string, errorTemplate string) string {
   148  	if resBodyError != "" {
   149  		return fmt.Sprintf(errorTemplate, statusCode, resBodyError)
   150  	}
   151  	return fmt.Sprintf(errorTemplate, statusCode)
   152  }
   153  
   154  func (v *vectorizer) getValueFromContext(ctx context.Context, key string) string {
   155  	if value := ctx.Value(key); value != nil {
   156  		if keyHeader, ok := value.([]string); ok && len(keyHeader) > 0 && len(keyHeader[0]) > 0 {
   157  			return keyHeader[0]
   158  		}
   159  	}
   160  	// try getting header from GRPC if not successful
   161  	if apiKey := modulecomponents.GetValueFromGRPC(ctx, key); len(apiKey) > 0 && len(apiKey[0]) > 0 {
   162  		return apiKey[0]
   163  	}
   164  	return ""
   165  }
   166  
   167  func (v *vectorizer) getApiKey(ctx context.Context) (string, error) {
   168  	if apiKey := v.getValueFromContext(ctx, "X-Cohere-Api-Key"); apiKey != "" {
   169  		return apiKey, nil
   170  	}
   171  	if v.apiKey != "" {
   172  		return v.apiKey, nil
   173  	}
   174  	return "", errors.New("no api key found " +
   175  		"neither in request header: X-Cohere-Api-Key " +
   176  		"nor in environment variable under COHERE_APIKEY")
   177  }