github.com/weaviate/weaviate@v1.24.6/modules/reranker-cohere/clients/ranker.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "bytes" 16 "context" 17 "encoding/json" 18 "fmt" 19 "io" 20 "net/http" 21 "net/url" 22 "runtime" 23 "sync" 24 "time" 25 26 enterrors "github.com/weaviate/weaviate/entities/errors" 27 28 "github.com/weaviate/weaviate/usecases/modulecomponents" 29 30 "github.com/pkg/errors" 31 "github.com/sirupsen/logrus" 32 "github.com/weaviate/weaviate/entities/moduletools" 33 "github.com/weaviate/weaviate/modules/reranker-cohere/config" 34 "github.com/weaviate/weaviate/usecases/modulecomponents/ent" 35 ) 36 37 var _NUMCPU = runtime.NumCPU() 38 39 type client struct { 40 lock sync.RWMutex 41 apiKey string 42 host string 43 path string 44 httpClient *http.Client 45 maxDocuments int 46 logger logrus.FieldLogger 47 } 48 49 func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *client { 50 return &client{ 51 apiKey: apiKey, 52 httpClient: &http.Client{Timeout: timeout}, 53 host: "https://api.cohere.ai", 54 path: "/v1/rerank", 55 maxDocuments: 1000, 56 logger: logger, 57 } 58 } 59 60 func (c *client) Rank(ctx context.Context, query string, documents []string, 61 cfg moduletools.ClassConfig, 62 ) (*ent.RankResult, error) { 63 eg := enterrors.NewErrorGroupWrapper(c.logger) 64 eg.SetLimit(_NUMCPU) 65 66 chunkedDocuments := c.chunkDocuments(documents, c.maxDocuments) 67 documentScoreResponses := make([][]ent.DocumentScore, len(chunkedDocuments)) 68 for i := range chunkedDocuments { 69 i := i // https://golang.org/doc/faq#closures_and_goroutines 70 eg.Go(func() error { 71 documentScoreResponse, err := c.performRank(ctx, query, chunkedDocuments[i], cfg) 72 if err != nil { 73 return err 74 } 75 c.lockGuard(func() { 76 documentScoreResponses[i] = documentScoreResponse 77 }) 78 return nil 79 }, chunkedDocuments[i]) 80 } 81 if err := eg.Wait(); err != nil { 82 return nil, err 83 } 84 85 return c.toRankResult(query, documentScoreResponses), nil 86 } 87 88 func (c *client) lockGuard(mutate func()) { 89 c.lock.Lock() 90 defer c.lock.Unlock() 91 mutate() 92 } 93 94 func (c *client) performRank(ctx context.Context, query string, documents []string, 95 cfg moduletools.ClassConfig, 96 ) ([]ent.DocumentScore, error) { 97 settings := config.NewClassSettings(cfg) 98 cohereUrl, err := url.JoinPath(c.host, c.path) 99 if err != nil { 100 return nil, errors.Wrap(err, "join Cohere API host and path") 101 } 102 103 input := RankInput{ 104 Documents: documents, 105 Query: query, 106 Model: settings.Model(), 107 ReturnDocuments: false, 108 } 109 110 body, err := json.Marshal(input) 111 if err != nil { 112 return nil, errors.Wrapf(err, "marshal body") 113 } 114 115 req, err := http.NewRequestWithContext(ctx, "POST", cohereUrl, bytes.NewReader(body)) 116 if err != nil { 117 return nil, errors.Wrap(err, "create POST request") 118 } 119 120 apiKey, err := c.getApiKey(ctx) 121 if err != nil { 122 return nil, errors.Wrapf(err, "Cohere API Key") 123 } 124 req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", apiKey)) 125 req.Header.Add("Content-Type", "application/json") 126 req.Header.Add("Request-Source", "unspecified:weaviate") 127 128 res, err := c.httpClient.Do(req) 129 if err != nil { 130 return nil, errors.Wrap(err, "send POST request") 131 } 132 defer res.Body.Close() 133 134 bodyBytes, err := io.ReadAll(res.Body) 135 if err != nil { 136 return nil, errors.Wrap(err, "read response body") 137 } 138 139 if res.StatusCode != 200 { 140 var apiError cohereApiError 141 err = json.Unmarshal(bodyBytes, &apiError) 142 if err != nil { 143 return nil, errors.Wrap(err, "unmarshal error from response body") 144 } 145 if apiError.Message != "" { 146 return nil, errors.Errorf("connection to Cohere API failed with status %d: %s", res.StatusCode, apiError.Message) 147 } 148 return nil, errors.Errorf("connection to Cohere API failed with status %d", res.StatusCode) 149 } 150 151 var rankResponse RankResponse 152 if err := json.Unmarshal(bodyBytes, &rankResponse); err != nil { 153 return nil, errors.Wrap(err, "unmarshal response body") 154 } 155 return c.toDocumentScores(documents, rankResponse.Results), nil 156 } 157 158 func (c *client) chunkDocuments(documents []string, chunkSize int) [][]string { 159 var requests [][]string 160 for i := 0; i < len(documents); i += chunkSize { 161 end := i + chunkSize 162 163 if end > len(documents) { 164 end = len(documents) 165 } 166 167 requests = append(requests, documents[i:end]) 168 } 169 170 return requests 171 } 172 173 func (c *client) toDocumentScores(documents []string, results []Result) []ent.DocumentScore { 174 documentScores := make([]ent.DocumentScore, len(results)) 175 for _, result := range results { 176 documentScores[result.Index] = ent.DocumentScore{ 177 Document: documents[result.Index], 178 Score: result.RelevanceScore, 179 } 180 } 181 return documentScores 182 } 183 184 func (c *client) toRankResult(query string, results [][]ent.DocumentScore) *ent.RankResult { 185 documentScores := []ent.DocumentScore{} 186 for i := range results { 187 documentScores = append(documentScores, results[i]...) 188 } 189 return &ent.RankResult{ 190 Query: query, 191 DocumentScores: documentScores, 192 } 193 } 194 195 func (c *client) getApiKey(ctx context.Context) (string, error) { 196 if len(c.apiKey) > 0 { 197 return c.apiKey, nil 198 } 199 key := "X-Cohere-Api-Key" 200 201 apiKey := ctx.Value(key) 202 // try getting header from GRPC if not successful 203 if apiKey == nil { 204 apiKey = modulecomponents.GetValueFromGRPC(ctx, key) 205 } 206 if apiKeyHeader, ok := apiKey.([]string); ok && 207 len(apiKeyHeader) > 0 && len(apiKeyHeader[0]) > 0 { 208 return apiKeyHeader[0], nil 209 } 210 return "", errors.New("no api key found " + 211 "neither in request header: X-Cohere-Api-Key " + 212 "nor in environment variable under COHERE_APIKEY") 213 } 214 215 type RankInput struct { 216 Documents []string `json:"documents"` 217 Query string `json:"query"` 218 Model string `json:"model"` 219 ReturnDocuments bool `json:"return_documents"` 220 } 221 222 type Document struct { 223 Text string `json:"text"` 224 } 225 226 type Result struct { 227 Index int `json:"index"` 228 RelevanceScore float64 `json:"relevance_score"` 229 Document Document `json:"document"` 230 } 231 232 type APIVersion struct { 233 Version string `json:"version"` 234 } 235 236 type Meta struct { 237 APIVersion APIVersion `json:"api_version"` 238 } 239 240 type RankResponse struct { 241 ID string `json:"id"` 242 Results []Result `json:"results"` 243 Meta Meta `json:"meta"` 244 } 245 246 type cohereApiError struct { 247 Message string `json:"message"` 248 }