github.com/weaviate/weaviate@v1.24.6/modules/reranker-transformers/clients/ranker.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package client 13 14 import ( 15 "bytes" 16 "context" 17 "encoding/json" 18 "fmt" 19 "io" 20 "net/http" 21 "runtime" 22 "sync" 23 "time" 24 25 enterrors "github.com/weaviate/weaviate/entities/errors" 26 27 "github.com/pkg/errors" 28 "github.com/sirupsen/logrus" 29 "github.com/weaviate/weaviate/entities/moduletools" 30 "github.com/weaviate/weaviate/usecases/modulecomponents/ent" 31 ) 32 33 var _NUMCPU = runtime.NumCPU() 34 35 type client struct { 36 lock sync.RWMutex 37 origin string 38 httpClient *http.Client 39 maxDocuments int 40 logger logrus.FieldLogger 41 } 42 43 func New(origin string, timeout time.Duration, logger logrus.FieldLogger) *client { 44 return &client{ 45 origin: origin, 46 httpClient: &http.Client{Timeout: timeout}, 47 maxDocuments: 32, 48 logger: logger, 49 } 50 } 51 52 func (c *client) Rank(ctx context.Context, 53 query string, documents []string, cfg moduletools.ClassConfig, 54 ) (*ent.RankResult, error) { 55 eg := enterrors.NewErrorGroupWrapper(c.logger) 56 eg.SetLimit(_NUMCPU) 57 58 chunkedDocuments := c.chunkDocuments(documents, c.maxDocuments) 59 documentScoreResponses := make([][]DocumentScore, len(chunkedDocuments)) 60 for i := range chunkedDocuments { 61 i := i // https://golang.org/doc/faq#closures_and_goroutines 62 eg.Go(func() error { 63 documentScoreResponse, err := c.performRank(ctx, query, chunkedDocuments[i], cfg) 64 if err != nil { 65 return err 66 } 67 c.lockGuard(func() { 68 documentScoreResponses[i] = documentScoreResponse 69 }) 70 return nil 71 }, chunkedDocuments[i]) 72 } 73 if err := eg.Wait(); err != nil { 74 return nil, err 75 } 76 77 return c.toRankResult(query, documentScoreResponses), nil 78 } 79 80 func (c *client) lockGuard(mutate func()) { 81 c.lock.Lock() 82 defer c.lock.Unlock() 83 mutate() 84 } 85 86 func (c *client) toRankResult(query string, scores [][]DocumentScore) *ent.RankResult { 87 documentScores := []ent.DocumentScore{} 88 for _, docScores := range scores { 89 for i := range docScores { 90 documentScores = append(documentScores, ent.DocumentScore{ 91 Document: docScores[i].Document, 92 Score: docScores[i].Score, 93 }) 94 } 95 } 96 return &ent.RankResult{ 97 Query: query, 98 DocumentScores: documentScores, 99 } 100 } 101 102 func (c *client) performRank(ctx context.Context, 103 query string, documents []string, cfg moduletools.ClassConfig, 104 ) ([]DocumentScore, error) { 105 body, err := json.Marshal(RankInput{ 106 Query: query, 107 Documents: documents, 108 }) 109 if err != nil { 110 return nil, errors.Wrapf(err, "marshal body") 111 } 112 113 req, err := http.NewRequestWithContext(ctx, "POST", c.url("/rerank"), 114 bytes.NewReader(body)) 115 if err != nil { 116 return nil, errors.Wrap(err, "create POST request") 117 } 118 119 res, err := c.httpClient.Do(req) 120 if err != nil { 121 return nil, errors.Wrap(err, "send POST request") 122 } 123 defer res.Body.Close() 124 125 bodyBytes, err := io.ReadAll(res.Body) 126 if err != nil { 127 return nil, errors.Wrap(err, "read response body") 128 } 129 130 var resBody RankResponse 131 if err := json.Unmarshal(bodyBytes, &resBody); err != nil { 132 return nil, errors.Wrap(err, "unmarshal response body") 133 } 134 135 if res.StatusCode != 200 { 136 if resBody.Error != "" { 137 return nil, errors.Errorf("fail with status %d: %s", res.StatusCode, 138 resBody.Error) 139 } 140 return nil, errors.Errorf("fail with status %d", res.StatusCode) 141 } 142 143 return resBody.Scores, nil 144 } 145 146 func (c *client) chunkDocuments(documents []string, chunkSize int) [][]string { 147 var requests [][]string 148 for i := 0; i < len(documents); i += chunkSize { 149 end := i + chunkSize 150 151 if end > len(documents) { 152 end = len(documents) 153 } 154 155 requests = append(requests, documents[i:end]) 156 } 157 158 return requests 159 } 160 161 func (c *client) url(path string) string { 162 return fmt.Sprintf("%s%s", c.origin, path) 163 } 164 165 type RankInput struct { 166 Query string `json:"query"` 167 Documents []string `json:"documents"` 168 RankPropertyValue string `json:"property"` 169 } 170 171 type DocumentScore struct { 172 Document string `json:"document"` 173 Score float64 `json:"score"` 174 } 175 176 type RankResponse struct { 177 Query string `json:"query"` 178 Scores []DocumentScore `json:"scores"` 179 RankPropertyValue string `json:"property"` 180 Score float64 `json:"score"` 181 Error string `json:"error"` 182 }