github.com/weaviate/weaviate@v1.24.6/modules/reranker-cohere/clients/ranker_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package clients
    13  
    14  import (
    15  	"context"
    16  	"encoding/json"
    17  	"io"
    18  	"net/http"
    19  	"net/http/httptest"
    20  	"sync"
    21  	"testing"
    22  
    23  	"github.com/sirupsen/logrus"
    24  	"github.com/sirupsen/logrus/hooks/test"
    25  	"github.com/stretchr/testify/assert"
    26  	"github.com/stretchr/testify/require"
    27  	"github.com/weaviate/weaviate/usecases/modulecomponents/ent"
    28  )
    29  
    30  func nullLogger() logrus.FieldLogger {
    31  	l, _ := test.NewNullLogger()
    32  	return l
    33  }
    34  
    35  func TestRank(t *testing.T) {
    36  	t.Run("when the server has a successful response", func(t *testing.T) {
    37  		handler := &testRankHandler{
    38  			t: t,
    39  			response: RankResponse{
    40  				Results: []Result{
    41  					{
    42  						Index:          0,
    43  						RelevanceScore: 0.9,
    44  					},
    45  				},
    46  			},
    47  		}
    48  		server := httptest.NewServer(handler)
    49  		defer server.Close()
    50  
    51  		c := New("apiKey", 0, nullLogger())
    52  		c.host = server.URL
    53  
    54  		expected := &ent.RankResult{
    55  			DocumentScores: []ent.DocumentScore{
    56  				{
    57  					Document: "I work at Apple",
    58  					Score:    0.9,
    59  				},
    60  			},
    61  			Query: "Where do I work?",
    62  		}
    63  
    64  		res, err := c.Rank(context.Background(), "Where do I work?", []string{"I work at Apple"}, nil)
    65  
    66  		assert.Nil(t, err)
    67  		assert.Equal(t, expected, res)
    68  	})
    69  
    70  	t.Run("when the server has an error", func(t *testing.T) {
    71  		handler := &testRankHandler{
    72  			t: t,
    73  			response: RankResponse{
    74  				Results: []Result{},
    75  			},
    76  			errorMessage: "some error from the server",
    77  		}
    78  		server := httptest.NewServer(handler)
    79  		defer server.Close()
    80  
    81  		c := New("apiKey", 0, nullLogger())
    82  		c.host = server.URL
    83  
    84  		_, err := c.Rank(context.Background(), "I work at Apple", []string{"Where do I work?"}, nil)
    85  
    86  		require.NotNil(t, err)
    87  		assert.Contains(t, err.Error(), "some error from the server")
    88  	})
    89  
    90  	t.Run("when we send requests in batches", func(t *testing.T) {
    91  		handler := &testRankHandler{
    92  			t: t,
    93  			batchedResults: [][]Result{
    94  				{
    95  					{
    96  						Index:          0,
    97  						RelevanceScore: 0.99,
    98  					},
    99  					{
   100  						Index:          1,
   101  						RelevanceScore: 0.89,
   102  					},
   103  				},
   104  				{
   105  					{
   106  						Index:          0,
   107  						RelevanceScore: 0.19,
   108  					},
   109  					{
   110  						Index:          1,
   111  						RelevanceScore: 0.29,
   112  					},
   113  				},
   114  				{
   115  					{
   116  						Index:          0,
   117  						RelevanceScore: 0.79,
   118  					},
   119  					{
   120  						Index:          1,
   121  						RelevanceScore: 0.789,
   122  					},
   123  				},
   124  				{
   125  					{
   126  						Index:          0,
   127  						RelevanceScore: 0.0001,
   128  					},
   129  				},
   130  			},
   131  		}
   132  		server := httptest.NewServer(handler)
   133  		defer server.Close()
   134  
   135  		c := New("apiKey", 0, nullLogger())
   136  		c.host = server.URL
   137  		// this will trigger 4 go routines
   138  		c.maxDocuments = 2
   139  
   140  		query := "Where do I work?"
   141  		documents := []string{
   142  			"Response 1", "Response 2", "Response 3", "Response 4",
   143  			"Response 5", "Response 6", "Response 7",
   144  		}
   145  
   146  		resp, err := c.Rank(context.Background(), query, documents, nil)
   147  
   148  		require.Nil(t, err)
   149  		require.NotNil(t, resp)
   150  		require.NotNil(t, resp.DocumentScores)
   151  		for i := range resp.DocumentScores {
   152  			assert.Equal(t, documents[i], resp.DocumentScores[i].Document)
   153  			if i == 0 {
   154  				assert.Equal(t, 0.99, resp.DocumentScores[i].Score)
   155  			}
   156  			if i == len(documents)-1 {
   157  				assert.Equal(t, 0.0001, resp.DocumentScores[i].Score)
   158  			}
   159  		}
   160  	})
   161  }
   162  
   163  type testRankHandler struct {
   164  	lock           sync.RWMutex
   165  	t              *testing.T
   166  	response       RankResponse
   167  	batchedResults [][]Result
   168  	errorMessage   string
   169  }
   170  
   171  func (f *testRankHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   172  	f.lock.Lock()
   173  	defer f.lock.Unlock()
   174  
   175  	if f.errorMessage != "" {
   176  		w.WriteHeader(http.StatusInternalServerError)
   177  		w.Write([]byte(`{"message":"` + f.errorMessage + `"}`))
   178  		return
   179  	}
   180  
   181  	bodyBytes, err := io.ReadAll(r.Body)
   182  	require.Nil(f.t, err)
   183  	defer r.Body.Close()
   184  
   185  	var req RankInput
   186  	require.Nil(f.t, json.Unmarshal(bodyBytes, &req))
   187  
   188  	containsDocument := func(req RankInput, in string) bool {
   189  		for _, doc := range req.Documents {
   190  			if doc == in {
   191  				return true
   192  			}
   193  		}
   194  		return false
   195  	}
   196  
   197  	index := 0
   198  	if len(f.batchedResults) > 0 {
   199  		if containsDocument(req, "Response 3") {
   200  			index = 1
   201  		}
   202  		if containsDocument(req, "Response 5") {
   203  			index = 2
   204  		}
   205  		if containsDocument(req, "Response 7") {
   206  			index = 3
   207  		}
   208  		f.response.Results = f.batchedResults[index]
   209  	}
   210  
   211  	outBytes, err := json.Marshal(f.response)
   212  	require.Nil(f.t, err)
   213  
   214  	w.Write(outBytes)
   215  }