github.com/weaviate/weaviate@v1.24.6/modules/reranker-cohere/clients/ranker.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package clients
    13  
    14  import (
    15  	"bytes"
    16  	"context"
    17  	"encoding/json"
    18  	"fmt"
    19  	"io"
    20  	"net/http"
    21  	"net/url"
    22  	"runtime"
    23  	"sync"
    24  	"time"
    25  
    26  	enterrors "github.com/weaviate/weaviate/entities/errors"
    27  
    28  	"github.com/weaviate/weaviate/usecases/modulecomponents"
    29  
    30  	"github.com/pkg/errors"
    31  	"github.com/sirupsen/logrus"
    32  	"github.com/weaviate/weaviate/entities/moduletools"
    33  	"github.com/weaviate/weaviate/modules/reranker-cohere/config"
    34  	"github.com/weaviate/weaviate/usecases/modulecomponents/ent"
    35  )
    36  
    37  var _NUMCPU = runtime.NumCPU()
    38  
    39  type client struct {
    40  	lock         sync.RWMutex
    41  	apiKey       string
    42  	host         string
    43  	path         string
    44  	httpClient   *http.Client
    45  	maxDocuments int
    46  	logger       logrus.FieldLogger
    47  }
    48  
    49  func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *client {
    50  	return &client{
    51  		apiKey:       apiKey,
    52  		httpClient:   &http.Client{Timeout: timeout},
    53  		host:         "https://api.cohere.ai",
    54  		path:         "/v1/rerank",
    55  		maxDocuments: 1000,
    56  		logger:       logger,
    57  	}
    58  }
    59  
    60  func (c *client) Rank(ctx context.Context, query string, documents []string,
    61  	cfg moduletools.ClassConfig,
    62  ) (*ent.RankResult, error) {
    63  	eg := enterrors.NewErrorGroupWrapper(c.logger)
    64  	eg.SetLimit(_NUMCPU)
    65  
    66  	chunkedDocuments := c.chunkDocuments(documents, c.maxDocuments)
    67  	documentScoreResponses := make([][]ent.DocumentScore, len(chunkedDocuments))
    68  	for i := range chunkedDocuments {
    69  		i := i // https://golang.org/doc/faq#closures_and_goroutines
    70  		eg.Go(func() error {
    71  			documentScoreResponse, err := c.performRank(ctx, query, chunkedDocuments[i], cfg)
    72  			if err != nil {
    73  				return err
    74  			}
    75  			c.lockGuard(func() {
    76  				documentScoreResponses[i] = documentScoreResponse
    77  			})
    78  			return nil
    79  		}, chunkedDocuments[i])
    80  	}
    81  	if err := eg.Wait(); err != nil {
    82  		return nil, err
    83  	}
    84  
    85  	return c.toRankResult(query, documentScoreResponses), nil
    86  }
    87  
    88  func (c *client) lockGuard(mutate func()) {
    89  	c.lock.Lock()
    90  	defer c.lock.Unlock()
    91  	mutate()
    92  }
    93  
    94  func (c *client) performRank(ctx context.Context, query string, documents []string,
    95  	cfg moduletools.ClassConfig,
    96  ) ([]ent.DocumentScore, error) {
    97  	settings := config.NewClassSettings(cfg)
    98  	cohereUrl, err := url.JoinPath(c.host, c.path)
    99  	if err != nil {
   100  		return nil, errors.Wrap(err, "join Cohere API host and path")
   101  	}
   102  
   103  	input := RankInput{
   104  		Documents:       documents,
   105  		Query:           query,
   106  		Model:           settings.Model(),
   107  		ReturnDocuments: false,
   108  	}
   109  
   110  	body, err := json.Marshal(input)
   111  	if err != nil {
   112  		return nil, errors.Wrapf(err, "marshal body")
   113  	}
   114  
   115  	req, err := http.NewRequestWithContext(ctx, "POST", cohereUrl, bytes.NewReader(body))
   116  	if err != nil {
   117  		return nil, errors.Wrap(err, "create POST request")
   118  	}
   119  
   120  	apiKey, err := c.getApiKey(ctx)
   121  	if err != nil {
   122  		return nil, errors.Wrapf(err, "Cohere API Key")
   123  	}
   124  	req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", apiKey))
   125  	req.Header.Add("Content-Type", "application/json")
   126  	req.Header.Add("Request-Source", "unspecified:weaviate")
   127  
   128  	res, err := c.httpClient.Do(req)
   129  	if err != nil {
   130  		return nil, errors.Wrap(err, "send POST request")
   131  	}
   132  	defer res.Body.Close()
   133  
   134  	bodyBytes, err := io.ReadAll(res.Body)
   135  	if err != nil {
   136  		return nil, errors.Wrap(err, "read response body")
   137  	}
   138  
   139  	if res.StatusCode != 200 {
   140  		var apiError cohereApiError
   141  		err = json.Unmarshal(bodyBytes, &apiError)
   142  		if err != nil {
   143  			return nil, errors.Wrap(err, "unmarshal error from response body")
   144  		}
   145  		if apiError.Message != "" {
   146  			return nil, errors.Errorf("connection to Cohere API failed with status %d: %s", res.StatusCode, apiError.Message)
   147  		}
   148  		return nil, errors.Errorf("connection to Cohere API failed with status %d", res.StatusCode)
   149  	}
   150  
   151  	var rankResponse RankResponse
   152  	if err := json.Unmarshal(bodyBytes, &rankResponse); err != nil {
   153  		return nil, errors.Wrap(err, "unmarshal response body")
   154  	}
   155  	return c.toDocumentScores(documents, rankResponse.Results), nil
   156  }
   157  
   158  func (c *client) chunkDocuments(documents []string, chunkSize int) [][]string {
   159  	var requests [][]string
   160  	for i := 0; i < len(documents); i += chunkSize {
   161  		end := i + chunkSize
   162  
   163  		if end > len(documents) {
   164  			end = len(documents)
   165  		}
   166  
   167  		requests = append(requests, documents[i:end])
   168  	}
   169  
   170  	return requests
   171  }
   172  
   173  func (c *client) toDocumentScores(documents []string, results []Result) []ent.DocumentScore {
   174  	documentScores := make([]ent.DocumentScore, len(results))
   175  	for _, result := range results {
   176  		documentScores[result.Index] = ent.DocumentScore{
   177  			Document: documents[result.Index],
   178  			Score:    result.RelevanceScore,
   179  		}
   180  	}
   181  	return documentScores
   182  }
   183  
   184  func (c *client) toRankResult(query string, results [][]ent.DocumentScore) *ent.RankResult {
   185  	documentScores := []ent.DocumentScore{}
   186  	for i := range results {
   187  		documentScores = append(documentScores, results[i]...)
   188  	}
   189  	return &ent.RankResult{
   190  		Query:          query,
   191  		DocumentScores: documentScores,
   192  	}
   193  }
   194  
   195  func (c *client) getApiKey(ctx context.Context) (string, error) {
   196  	if len(c.apiKey) > 0 {
   197  		return c.apiKey, nil
   198  	}
   199  	key := "X-Cohere-Api-Key"
   200  
   201  	apiKey := ctx.Value(key)
   202  	// try getting header from GRPC if not successful
   203  	if apiKey == nil {
   204  		apiKey = modulecomponents.GetValueFromGRPC(ctx, key)
   205  	}
   206  	if apiKeyHeader, ok := apiKey.([]string); ok &&
   207  		len(apiKeyHeader) > 0 && len(apiKeyHeader[0]) > 0 {
   208  		return apiKeyHeader[0], nil
   209  	}
   210  	return "", errors.New("no api key found " +
   211  		"neither in request header: X-Cohere-Api-Key " +
   212  		"nor in environment variable under COHERE_APIKEY")
   213  }
   214  
   215  type RankInput struct {
   216  	Documents       []string `json:"documents"`
   217  	Query           string   `json:"query"`
   218  	Model           string   `json:"model"`
   219  	ReturnDocuments bool     `json:"return_documents"`
   220  }
   221  
   222  type Document struct {
   223  	Text string `json:"text"`
   224  }
   225  
   226  type Result struct {
   227  	Index          int      `json:"index"`
   228  	RelevanceScore float64  `json:"relevance_score"`
   229  	Document       Document `json:"document"`
   230  }
   231  
   232  type APIVersion struct {
   233  	Version string `json:"version"`
   234  }
   235  
   236  type Meta struct {
   237  	APIVersion APIVersion `json:"api_version"`
   238  }
   239  
   240  type RankResponse struct {
   241  	ID      string   `json:"id"`
   242  	Results []Result `json:"results"`
   243  	Meta    Meta     `json:"meta"`
   244  }
   245  
   246  type cohereApiError struct {
   247  	Message string `json:"message"`
   248  }