github.com/weaviate/weaviate@v1.24.6/modules/text2vec-huggingface/clients/huggingface.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "bytes" 16 "context" 17 "encoding/json" 18 "fmt" 19 "io" 20 "net/http" 21 "time" 22 23 "github.com/weaviate/weaviate/usecases/modulecomponents" 24 25 "github.com/pkg/errors" 26 "github.com/sirupsen/logrus" 27 "github.com/weaviate/weaviate/modules/text2vec-huggingface/ent" 28 ) 29 30 const ( 31 DefaultOrigin = "https://api-inference.huggingface.co" 32 DefaultPath = "pipeline/feature-extraction" 33 ) 34 35 type embeddingsRequest struct { 36 Inputs []string `json:"inputs"` 37 Options *options `json:"options,omitempty"` 38 } 39 40 type options struct { 41 WaitForModel bool `json:"wait_for_model,omitempty"` 42 UseGPU bool `json:"use_gpu,omitempty"` 43 UseCache bool `json:"use_cache,omitempty"` 44 } 45 46 type embedding [][]float32 47 48 type embeddingBert [][][][]float32 49 50 type embeddingObject struct { 51 Embeddings embedding `json:"embeddings"` 52 } 53 54 type huggingFaceApiError struct { 55 Error string `json:"error"` 56 EstimatedTime *float32 `json:"estimated_time,omitempty"` 57 Warnings []string `json:"warnings"` 58 } 59 60 type vectorizer struct { 61 apiKey string 62 httpClient *http.Client 63 bertEmbeddingsDecoder *bertEmbeddingsDecoder 64 logger logrus.FieldLogger 65 } 66 67 func New(apiKey string, timeout time.Duration, logger logrus.FieldLogger) *vectorizer { 68 return &vectorizer{ 69 apiKey: apiKey, 70 httpClient: &http.Client{ 71 Timeout: timeout, 72 }, 73 bertEmbeddingsDecoder: newBertEmbeddingsDecoder(), 74 logger: logger, 75 } 76 } 77 78 func (v *vectorizer) Vectorize(ctx context.Context, input string, 79 config ent.VectorizationConfig, 80 ) (*ent.VectorizationResult, error) { 81 return v.vectorize(ctx, v.getURL(config), input, v.getOptions(config)) 82 } 83 84 func (v *vectorizer) VectorizeQuery(ctx context.Context, input string, 85 config ent.VectorizationConfig, 86 ) (*ent.VectorizationResult, error) { 87 return v.vectorize(ctx, v.getURL(config), input, v.getOptions(config)) 88 } 89 90 func (v *vectorizer) vectorize(ctx context.Context, url string, 91 input string, options options, 92 ) (*ent.VectorizationResult, error) { 93 body, err := json.Marshal(embeddingsRequest{ 94 Inputs: []string{input}, 95 Options: &options, 96 }) 97 if err != nil { 98 return nil, errors.Wrapf(err, "marshal body") 99 } 100 101 req, err := http.NewRequestWithContext(ctx, "POST", url, 102 bytes.NewReader(body)) 103 if err != nil { 104 return nil, errors.Wrap(err, "create POST request") 105 } 106 if apiKey := v.getApiKey(ctx); apiKey != "" { 107 req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", apiKey)) 108 } 109 req.Header.Add("Content-Type", "application/json") 110 111 res, err := v.httpClient.Do(req) 112 if err != nil { 113 return nil, errors.Wrap(err, "send POST request") 114 } 115 defer res.Body.Close() 116 117 bodyBytes, err := io.ReadAll(res.Body) 118 if err != nil { 119 return nil, errors.Wrap(err, "read response body") 120 } 121 122 if err := checkResponse(res, bodyBytes); err != nil { 123 return nil, err 124 } 125 126 vector, err := v.decodeVector(bodyBytes) 127 if err != nil { 128 return nil, errors.Wrap(err, "cannot decode vector") 129 } 130 131 return &ent.VectorizationResult{ 132 Text: input, 133 Dimensions: len(vector), 134 Vector: vector, 135 }, nil 136 } 137 138 func checkResponse(res *http.Response, bodyBytes []byte) error { 139 if res.StatusCode < 400 { 140 return nil 141 } 142 143 var resBody huggingFaceApiError 144 if err := json.Unmarshal(bodyBytes, &resBody); err != nil { 145 return fmt.Errorf("unmarshal error response body: %v", string(bodyBytes)) 146 } 147 148 message := fmt.Sprintf("failed with status: %d", res.StatusCode) 149 if resBody.Error != "" { 150 message = fmt.Sprintf("%s error: %v", message, resBody.Error) 151 if resBody.EstimatedTime != nil { 152 message = fmt.Sprintf("%s estimated time: %v", message, *resBody.EstimatedTime) 153 } 154 if len(resBody.Warnings) > 0 { 155 message = fmt.Sprintf("%s warnings: %v", message, resBody.Warnings) 156 } 157 } 158 159 if res.StatusCode == http.StatusInternalServerError { 160 message = fmt.Sprintf("connection to HuggingFace %v", message) 161 } 162 163 return errors.New(message) 164 } 165 166 func (v *vectorizer) decodeVector(bodyBytes []byte) ([]float32, error) { 167 var emb embedding 168 if err := json.Unmarshal(bodyBytes, &emb); err != nil { 169 var embObject embeddingObject 170 if err := json.Unmarshal(bodyBytes, &embObject); err != nil { 171 var embBert embeddingBert 172 if err := json.Unmarshal(bodyBytes, &embBert); err != nil { 173 return nil, errors.Wrap(err, "unmarshal response body") 174 } 175 176 if len(embBert) == 1 && len(embBert[0]) == 1 { 177 return v.bertEmbeddingsDecoder.calculateVector(embBert[0][0]) 178 } 179 180 return nil, errors.New("unprocessable response body") 181 } 182 if len(embObject.Embeddings) == 1 { 183 return embObject.Embeddings[0], nil 184 } 185 186 return nil, errors.New("unprocessable response body") 187 } 188 189 if len(emb) == 1 { 190 return emb[0], nil 191 } 192 193 return nil, errors.New("unprocessable response body") 194 } 195 196 func (v *vectorizer) getApiKey(ctx context.Context) string { 197 if len(v.apiKey) > 0 { 198 return v.apiKey 199 } 200 key := "X-Huggingface-Api-Key" 201 apiKey := ctx.Value(key) 202 // try getting header from GRPC if not successful 203 if apiKey == nil { 204 apiKey = modulecomponents.GetValueFromGRPC(ctx, key) 205 } 206 207 if apiKeyHeader, ok := apiKey.([]string); ok && 208 len(apiKeyHeader) > 0 && len(apiKeyHeader[0]) > 0 { 209 return apiKeyHeader[0] 210 } 211 return "" 212 } 213 214 func (v *vectorizer) getOptions(config ent.VectorizationConfig) options { 215 return options{ 216 WaitForModel: config.WaitForModel, 217 UseGPU: config.UseGPU, 218 UseCache: config.UseCache, 219 } 220 } 221 222 func (v *vectorizer) getURL(config ent.VectorizationConfig) string { 223 if config.EndpointURL != "" { 224 return config.EndpointURL 225 } 226 227 return fmt.Sprintf("%s/%s/%s", DefaultOrigin, DefaultPath, config.Model) 228 }