github.com/weaviate/weaviate@v1.24.6/modules/reranker-cohere/clients/ranker_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "context" 16 "encoding/json" 17 "io" 18 "net/http" 19 "net/http/httptest" 20 "sync" 21 "testing" 22 23 "github.com/sirupsen/logrus" 24 "github.com/sirupsen/logrus/hooks/test" 25 "github.com/stretchr/testify/assert" 26 "github.com/stretchr/testify/require" 27 "github.com/weaviate/weaviate/usecases/modulecomponents/ent" 28 ) 29 30 func nullLogger() logrus.FieldLogger { 31 l, _ := test.NewNullLogger() 32 return l 33 } 34 35 func TestRank(t *testing.T) { 36 t.Run("when the server has a successful response", func(t *testing.T) { 37 handler := &testRankHandler{ 38 t: t, 39 response: RankResponse{ 40 Results: []Result{ 41 { 42 Index: 0, 43 RelevanceScore: 0.9, 44 }, 45 }, 46 }, 47 } 48 server := httptest.NewServer(handler) 49 defer server.Close() 50 51 c := New("apiKey", 0, nullLogger()) 52 c.host = server.URL 53 54 expected := &ent.RankResult{ 55 DocumentScores: []ent.DocumentScore{ 56 { 57 Document: "I work at Apple", 58 Score: 0.9, 59 }, 60 }, 61 Query: "Where do I work?", 62 } 63 64 res, err := c.Rank(context.Background(), "Where do I work?", []string{"I work at Apple"}, nil) 65 66 assert.Nil(t, err) 67 assert.Equal(t, expected, res) 68 }) 69 70 t.Run("when the server has an error", func(t *testing.T) { 71 handler := &testRankHandler{ 72 t: t, 73 response: RankResponse{ 74 Results: []Result{}, 75 }, 76 errorMessage: "some error from the server", 77 } 78 server := httptest.NewServer(handler) 79 defer server.Close() 80 81 c := New("apiKey", 0, nullLogger()) 82 c.host = server.URL 83 84 _, err := c.Rank(context.Background(), "I work at Apple", []string{"Where do I work?"}, nil) 85 86 require.NotNil(t, err) 87 assert.Contains(t, err.Error(), "some error from the server") 88 }) 89 90 t.Run("when we send requests in batches", func(t *testing.T) { 91 handler := &testRankHandler{ 92 t: t, 93 batchedResults: [][]Result{ 94 { 95 { 96 Index: 0, 97 RelevanceScore: 0.99, 98 }, 99 { 100 Index: 1, 101 RelevanceScore: 0.89, 102 }, 103 }, 104 { 105 { 106 Index: 0, 107 RelevanceScore: 0.19, 108 }, 109 { 110 Index: 1, 111 RelevanceScore: 0.29, 112 }, 113 }, 114 { 115 { 116 Index: 0, 117 RelevanceScore: 0.79, 118 }, 119 { 120 Index: 1, 121 RelevanceScore: 0.789, 122 }, 123 }, 124 { 125 { 126 Index: 0, 127 RelevanceScore: 0.0001, 128 }, 129 }, 130 }, 131 } 132 server := httptest.NewServer(handler) 133 defer server.Close() 134 135 c := New("apiKey", 0, nullLogger()) 136 c.host = server.URL 137 // this will trigger 4 go routines 138 c.maxDocuments = 2 139 140 query := "Where do I work?" 141 documents := []string{ 142 "Response 1", "Response 2", "Response 3", "Response 4", 143 "Response 5", "Response 6", "Response 7", 144 } 145 146 resp, err := c.Rank(context.Background(), query, documents, nil) 147 148 require.Nil(t, err) 149 require.NotNil(t, resp) 150 require.NotNil(t, resp.DocumentScores) 151 for i := range resp.DocumentScores { 152 assert.Equal(t, documents[i], resp.DocumentScores[i].Document) 153 if i == 0 { 154 assert.Equal(t, 0.99, resp.DocumentScores[i].Score) 155 } 156 if i == len(documents)-1 { 157 assert.Equal(t, 0.0001, resp.DocumentScores[i].Score) 158 } 159 } 160 }) 161 } 162 163 type testRankHandler struct { 164 lock sync.RWMutex 165 t *testing.T 166 response RankResponse 167 batchedResults [][]Result 168 errorMessage string 169 } 170 171 func (f *testRankHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 172 f.lock.Lock() 173 defer f.lock.Unlock() 174 175 if f.errorMessage != "" { 176 w.WriteHeader(http.StatusInternalServerError) 177 w.Write([]byte(`{"message":"` + f.errorMessage + `"}`)) 178 return 179 } 180 181 bodyBytes, err := io.ReadAll(r.Body) 182 require.Nil(f.t, err) 183 defer r.Body.Close() 184 185 var req RankInput 186 require.Nil(f.t, json.Unmarshal(bodyBytes, &req)) 187 188 containsDocument := func(req RankInput, in string) bool { 189 for _, doc := range req.Documents { 190 if doc == in { 191 return true 192 } 193 } 194 return false 195 } 196 197 index := 0 198 if len(f.batchedResults) > 0 { 199 if containsDocument(req, "Response 3") { 200 index = 1 201 } 202 if containsDocument(req, "Response 5") { 203 index = 2 204 } 205 if containsDocument(req, "Response 7") { 206 index = 3 207 } 208 f.response.Results = f.batchedResults[index] 209 } 210 211 outBytes, err := json.Marshal(f.response) 212 require.Nil(f.t, err) 213 214 w.Write(outBytes) 215 }