github.com/weaviate/weaviate@v1.24.6/modules/reranker-transformers/clients/ranker_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package client 13 14 import ( 15 "context" 16 "encoding/json" 17 "io" 18 "net/http" 19 "net/http/httptest" 20 "sync" 21 "testing" 22 23 "github.com/stretchr/testify/assert" 24 "github.com/stretchr/testify/require" 25 "github.com/weaviate/weaviate/usecases/modulecomponents/ent" 26 ) 27 28 func TestGetScore(t *testing.T) { 29 t.Run("when the server has a successful answer", func(t *testing.T) { 30 server := httptest.NewServer(&testCrossRankerHandler{ 31 t: t, 32 res: RankResponse{ 33 Query: "Where do I work?", 34 Scores: []DocumentScore{ 35 { 36 Document: "I work at Apple", 37 Score: 0.15, 38 }, 39 }, 40 }, 41 }) 42 defer server.Close() 43 c := New(server.URL, 0, nullLogger()) 44 res, err := c.Rank(context.Background(), "Where do I work?", []string{"I work at Apple"}, nil) 45 46 assert.Nil(t, err) 47 assert.Equal(t, ent.RankResult{ 48 Query: "Where do I work?", 49 DocumentScores: []ent.DocumentScore{ 50 { 51 Document: "I work at Apple", 52 Score: 0.15, 53 }, 54 }, 55 }, *res) 56 }) 57 58 t.Run("when the server has an error", func(t *testing.T) { 59 server := httptest.NewServer(&testCrossRankerHandler{ 60 t: t, 61 res: RankResponse{ 62 Error: "some error from the server", 63 }, 64 }) 65 defer server.Close() 66 c := New(server.URL, 0, nullLogger()) 67 _, err := c.Rank(context.Background(), "prop", 68 []string{"I work at Apple"}, nil) 69 70 require.NotNil(t, err) 71 assert.Contains(t, err.Error(), "some error from the server") 72 }) 73 74 t.Run("when we send requests in batches", func(t *testing.T) { 75 server := httptest.NewServer(&testCrossRankerHandler{ 76 t: t, 77 res: RankResponse{ 78 Query: "Where do I work?", 79 Scores: []DocumentScore{ 80 { 81 Document: "I work at Apple", 82 Score: 0.15, 83 }, 84 }, 85 }, 86 batchedResults: [][]DocumentScore{ 87 { 88 { 89 Document: "Response 1", 90 Score: 0.99, 91 }, 92 { 93 Document: "Response 2", 94 Score: 0.89, 95 }, 96 }, 97 { 98 { 99 Document: "Response 3", 100 Score: 0.19, 101 }, 102 { 103 Document: "Response 4", 104 Score: 0.29, 105 }, 106 }, 107 { 108 { 109 Document: "Response 5", 110 Score: 0.79, 111 }, 112 { 113 Document: "Response 6", 114 Score: 0.789, 115 }, 116 }, 117 { 118 { 119 Document: "Response 7", 120 Score: 0.0001, 121 }, 122 }, 123 }, 124 }) 125 defer server.Close() 126 127 c := New(server.URL, 0, nullLogger()) 128 c.maxDocuments = 2 129 130 query := "Where do I work?" 131 documents := []string{ 132 "Response 1", "Response 2", "Response 3", "Response 4", 133 "Response 5", "Response 6", "Response 7", 134 } 135 136 resp, err := c.Rank(context.Background(), query, documents, nil) 137 138 require.Nil(t, err) 139 require.NotNil(t, resp) 140 require.NotNil(t, resp.DocumentScores) 141 for i := range resp.DocumentScores { 142 assert.Equal(t, documents[i], resp.DocumentScores[i].Document) 143 if i == 0 { 144 assert.Equal(t, 0.99, resp.DocumentScores[i].Score) 145 } 146 if i == len(documents)-1 { 147 assert.Equal(t, 0.0001, resp.DocumentScores[i].Score) 148 } 149 } 150 }) 151 } 152 153 type testCrossRankerHandler struct { 154 lock sync.RWMutex 155 t *testing.T 156 res RankResponse 157 batchedResults [][]DocumentScore 158 } 159 160 func (f *testCrossRankerHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 161 f.lock.Lock() 162 defer f.lock.Unlock() 163 164 assert.Equal(f.t, "/rerank", r.URL.String()) 165 assert.Equal(f.t, http.MethodPost, r.Method) 166 167 if f.res.Error != "" { 168 w.WriteHeader(500) 169 } 170 171 bodyBytes, err := io.ReadAll(r.Body) 172 require.Nil(f.t, err) 173 defer r.Body.Close() 174 175 var req RankInput 176 require.Nil(f.t, json.Unmarshal(bodyBytes, &req)) 177 178 containsDocument := func(req RankInput, in string) bool { 179 for _, doc := range req.Documents { 180 if doc == in { 181 return true 182 } 183 } 184 return false 185 } 186 187 index := 0 188 if len(f.batchedResults) > 0 { 189 if containsDocument(req, "Response 3") { 190 index = 1 191 } 192 if containsDocument(req, "Response 5") { 193 index = 2 194 } 195 if containsDocument(req, "Response 7") { 196 index = 3 197 } 198 f.res.Scores = f.batchedResults[index] 199 } 200 201 jsonBytes, _ := json.Marshal(f.res) 202 w.Write(jsonBytes) 203 }