github.com/weaviate/weaviate@v1.24.6/modules/reranker-transformers/clients/ranker_test.go (about)

     1  //                           _       _
     2  // __      _____  __ ___   ___  __ _| |_ ___
     3  // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \
     4  //  \ V  V /  __/ (_| |\ V /| | (_| | ||  __/
     5  //   \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___|
     6  //
     7  //  Copyright © 2016 - 2024 Weaviate B.V. All rights reserved.
     8  //
     9  //  CONTACT: hello@weaviate.io
    10  //
    11  
    12  package client
    13  
    14  import (
    15  	"context"
    16  	"encoding/json"
    17  	"io"
    18  	"net/http"
    19  	"net/http/httptest"
    20  	"sync"
    21  	"testing"
    22  
    23  	"github.com/stretchr/testify/assert"
    24  	"github.com/stretchr/testify/require"
    25  	"github.com/weaviate/weaviate/usecases/modulecomponents/ent"
    26  )
    27  
    28  func TestGetScore(t *testing.T) {
    29  	t.Run("when the server has a successful answer", func(t *testing.T) {
    30  		server := httptest.NewServer(&testCrossRankerHandler{
    31  			t: t,
    32  			res: RankResponse{
    33  				Query: "Where do I work?",
    34  				Scores: []DocumentScore{
    35  					{
    36  						Document: "I work at Apple",
    37  						Score:    0.15,
    38  					},
    39  				},
    40  			},
    41  		})
    42  		defer server.Close()
    43  		c := New(server.URL, 0, nullLogger())
    44  		res, err := c.Rank(context.Background(), "Where do I work?", []string{"I work at Apple"}, nil)
    45  
    46  		assert.Nil(t, err)
    47  		assert.Equal(t, ent.RankResult{
    48  			Query: "Where do I work?",
    49  			DocumentScores: []ent.DocumentScore{
    50  				{
    51  					Document: "I work at Apple",
    52  					Score:    0.15,
    53  				},
    54  			},
    55  		}, *res)
    56  	})
    57  
    58  	t.Run("when the server has an error", func(t *testing.T) {
    59  		server := httptest.NewServer(&testCrossRankerHandler{
    60  			t: t,
    61  			res: RankResponse{
    62  				Error: "some error from the server",
    63  			},
    64  		})
    65  		defer server.Close()
    66  		c := New(server.URL, 0, nullLogger())
    67  		_, err := c.Rank(context.Background(), "prop",
    68  			[]string{"I work at Apple"}, nil)
    69  
    70  		require.NotNil(t, err)
    71  		assert.Contains(t, err.Error(), "some error from the server")
    72  	})
    73  
    74  	t.Run("when we send requests in batches", func(t *testing.T) {
    75  		server := httptest.NewServer(&testCrossRankerHandler{
    76  			t: t,
    77  			res: RankResponse{
    78  				Query: "Where do I work?",
    79  				Scores: []DocumentScore{
    80  					{
    81  						Document: "I work at Apple",
    82  						Score:    0.15,
    83  					},
    84  				},
    85  			},
    86  			batchedResults: [][]DocumentScore{
    87  				{
    88  					{
    89  						Document: "Response 1",
    90  						Score:    0.99,
    91  					},
    92  					{
    93  						Document: "Response 2",
    94  						Score:    0.89,
    95  					},
    96  				},
    97  				{
    98  					{
    99  						Document: "Response 3",
   100  						Score:    0.19,
   101  					},
   102  					{
   103  						Document: "Response 4",
   104  						Score:    0.29,
   105  					},
   106  				},
   107  				{
   108  					{
   109  						Document: "Response 5",
   110  						Score:    0.79,
   111  					},
   112  					{
   113  						Document: "Response 6",
   114  						Score:    0.789,
   115  					},
   116  				},
   117  				{
   118  					{
   119  						Document: "Response 7",
   120  						Score:    0.0001,
   121  					},
   122  				},
   123  			},
   124  		})
   125  		defer server.Close()
   126  
   127  		c := New(server.URL, 0, nullLogger())
   128  		c.maxDocuments = 2
   129  
   130  		query := "Where do I work?"
   131  		documents := []string{
   132  			"Response 1", "Response 2", "Response 3", "Response 4",
   133  			"Response 5", "Response 6", "Response 7",
   134  		}
   135  
   136  		resp, err := c.Rank(context.Background(), query, documents, nil)
   137  
   138  		require.Nil(t, err)
   139  		require.NotNil(t, resp)
   140  		require.NotNil(t, resp.DocumentScores)
   141  		for i := range resp.DocumentScores {
   142  			assert.Equal(t, documents[i], resp.DocumentScores[i].Document)
   143  			if i == 0 {
   144  				assert.Equal(t, 0.99, resp.DocumentScores[i].Score)
   145  			}
   146  			if i == len(documents)-1 {
   147  				assert.Equal(t, 0.0001, resp.DocumentScores[i].Score)
   148  			}
   149  		}
   150  	})
   151  }
   152  
   153  type testCrossRankerHandler struct {
   154  	lock           sync.RWMutex
   155  	t              *testing.T
   156  	res            RankResponse
   157  	batchedResults [][]DocumentScore
   158  }
   159  
   160  func (f *testCrossRankerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   161  	f.lock.Lock()
   162  	defer f.lock.Unlock()
   163  
   164  	assert.Equal(f.t, "/rerank", r.URL.String())
   165  	assert.Equal(f.t, http.MethodPost, r.Method)
   166  
   167  	if f.res.Error != "" {
   168  		w.WriteHeader(500)
   169  	}
   170  
   171  	bodyBytes, err := io.ReadAll(r.Body)
   172  	require.Nil(f.t, err)
   173  	defer r.Body.Close()
   174  
   175  	var req RankInput
   176  	require.Nil(f.t, json.Unmarshal(bodyBytes, &req))
   177  
   178  	containsDocument := func(req RankInput, in string) bool {
   179  		for _, doc := range req.Documents {
   180  			if doc == in {
   181  				return true
   182  			}
   183  		}
   184  		return false
   185  	}
   186  
   187  	index := 0
   188  	if len(f.batchedResults) > 0 {
   189  		if containsDocument(req, "Response 3") {
   190  			index = 1
   191  		}
   192  		if containsDocument(req, "Response 5") {
   193  			index = 2
   194  		}
   195  		if containsDocument(req, "Response 7") {
   196  			index = 3
   197  		}
   198  		f.res.Scores = f.batchedResults[index]
   199  	}
   200  
   201  	jsonBytes, _ := json.Marshal(f.res)
   202  	w.Write(jsonBytes)
   203  }