github.com/weaviate/weaviate@v1.24.6/modules/reranker-transformers/clients/ranker.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package client
    13  
    14  import (
    15  	"bytes"
    16  	"context"
    17  	"encoding/json"
    18  	"fmt"
    19  	"io"
    20  	"net/http"
    21  	"runtime"
    22  	"sync"
    23  	"time"
    24  
    25  	enterrors "github.com/weaviate/weaviate/entities/errors"
    26  
    27  	"github.com/pkg/errors"
    28  	"github.com/sirupsen/logrus"
    29  	"github.com/weaviate/weaviate/entities/moduletools"
    30  	"github.com/weaviate/weaviate/usecases/modulecomponents/ent"
    31  )
    32  
    33  var _NUMCPU = runtime.NumCPU()
    34  
    35  type client struct {
    36  	lock         sync.RWMutex
    37  	origin       string
    38  	httpClient   *http.Client
    39  	maxDocuments int
    40  	logger       logrus.FieldLogger
    41  }
    42  
    43  func New(origin string, timeout time.Duration, logger logrus.FieldLogger) *client {
    44  	return &client{
    45  		origin:       origin,
    46  		httpClient:   &http.Client{Timeout: timeout},
    47  		maxDocuments: 32,
    48  		logger:       logger,
    49  	}
    50  }
    51  
    52  func (c *client) Rank(ctx context.Context,
    53  	query string, documents []string, cfg moduletools.ClassConfig,
    54  ) (*ent.RankResult, error) {
    55  	eg := enterrors.NewErrorGroupWrapper(c.logger)
    56  	eg.SetLimit(_NUMCPU)
    57  
    58  	chunkedDocuments := c.chunkDocuments(documents, c.maxDocuments)
    59  	documentScoreResponses := make([][]DocumentScore, len(chunkedDocuments))
    60  	for i := range chunkedDocuments {
    61  		i := i // https://golang.org/doc/faq#closures_and_goroutines
    62  		eg.Go(func() error {
    63  			documentScoreResponse, err := c.performRank(ctx, query, chunkedDocuments[i], cfg)
    64  			if err != nil {
    65  				return err
    66  			}
    67  			c.lockGuard(func() {
    68  				documentScoreResponses[i] = documentScoreResponse
    69  			})
    70  			return nil
    71  		}, chunkedDocuments[i])
    72  	}
    73  	if err := eg.Wait(); err != nil {
    74  		return nil, err
    75  	}
    76  
    77  	return c.toRankResult(query, documentScoreResponses), nil
    78  }
    79  
    80  func (c *client) lockGuard(mutate func()) {
    81  	c.lock.Lock()
    82  	defer c.lock.Unlock()
    83  	mutate()
    84  }
    85  
    86  func (c *client) toRankResult(query string, scores [][]DocumentScore) *ent.RankResult {
    87  	documentScores := []ent.DocumentScore{}
    88  	for _, docScores := range scores {
    89  		for i := range docScores {
    90  			documentScores = append(documentScores, ent.DocumentScore{
    91  				Document: docScores[i].Document,
    92  				Score:    docScores[i].Score,
    93  			})
    94  		}
    95  	}
    96  	return &ent.RankResult{
    97  		Query:          query,
    98  		DocumentScores: documentScores,
    99  	}
   100  }
   101  
   102  func (c *client) performRank(ctx context.Context,
   103  	query string, documents []string, cfg moduletools.ClassConfig,
   104  ) ([]DocumentScore, error) {
   105  	body, err := json.Marshal(RankInput{
   106  		Query:     query,
   107  		Documents: documents,
   108  	})
   109  	if err != nil {
   110  		return nil, errors.Wrapf(err, "marshal body")
   111  	}
   112  
   113  	req, err := http.NewRequestWithContext(ctx, "POST", c.url("/rerank"),
   114  		bytes.NewReader(body))
   115  	if err != nil {
   116  		return nil, errors.Wrap(err, "create POST request")
   117  	}
   118  
   119  	res, err := c.httpClient.Do(req)
   120  	if err != nil {
   121  		return nil, errors.Wrap(err, "send POST request")
   122  	}
   123  	defer res.Body.Close()
   124  
   125  	bodyBytes, err := io.ReadAll(res.Body)
   126  	if err != nil {
   127  		return nil, errors.Wrap(err, "read response body")
   128  	}
   129  
   130  	var resBody RankResponse
   131  	if err := json.Unmarshal(bodyBytes, &resBody); err != nil {
   132  		return nil, errors.Wrap(err, "unmarshal response body")
   133  	}
   134  
   135  	if res.StatusCode != 200 {
   136  		if resBody.Error != "" {
   137  			return nil, errors.Errorf("fail with status %d: %s", res.StatusCode,
   138  				resBody.Error)
   139  		}
   140  		return nil, errors.Errorf("fail with status %d", res.StatusCode)
   141  	}
   142  
   143  	return resBody.Scores, nil
   144  }
   145  
   146  func (c *client) chunkDocuments(documents []string, chunkSize int) [][]string {
   147  	var requests [][]string
   148  	for i := 0; i < len(documents); i += chunkSize {
   149  		end := i + chunkSize
   150  
   151  		if end > len(documents) {
   152  			end = len(documents)
   153  		}
   154  
   155  		requests = append(requests, documents[i:end])
   156  	}
   157  
   158  	return requests
   159  }
   160  
   161  func (c *client) url(path string) string {
   162  	return fmt.Sprintf("%s%s", c.origin, path)
   163  }
   164  
// RankInput is the JSON request body that performRank POSTs to the
// reranker's /rerank endpoint.
//
// json.Marshal emits the keys in field order. performRank sets only Query and
// Documents. RankPropertyValue has no omitempty tag, so it is still sent as
// "property": "". NOTE(review): whether the service reads "property" is not
// visible from this file.
type RankInput struct {
	Query             string   `json:"query"`
	Documents         []string `json:"documents"`
	RankPropertyValue string   `json:"property"`
}
   170  
// DocumentScore pairs one document with the score the reranker returned for
// it. RankResponse decodes a slice of these from the response's "scores"
// field. NOTE(review): the scale and direction of Score (e.g. whether higher
// means more relevant) are defined by the service, not by this file.
type DocumentScore struct {
	Document string  `json:"document"`
	Score    float64 `json:"score"`
}
   175  
// RankResponse is the JSON body returned by the reranker's /rerank endpoint.
//
// performRank returns Scores for a 200 response. For any other status it
// includes Error, when non-empty, in the returned error message. Query,
// RankPropertyValue and Score are decoded but not read anywhere in this file.
type RankResponse struct {
	Query             string          `json:"query"`
	Scores            []DocumentScore `json:"scores"`
	RankPropertyValue string          `json:"property"`
	Score             float64         `json:"score"`
	Error             string          `json:"error"`
}