github.com/weaviate/weaviate@v1.24.6/modules/text2vec-huggingface/clients/huggingface_test.go (about) 1 // _ _ 2 // __ _____ __ ___ ___ __ _| |_ ___ 3 // \ \ /\ / / _ \/ _` \ \ / / |/ _` | __/ _ \ 4 // \ V V / __/ (_| |\ V /| | (_| | || __/ 5 // \_/\_/ \___|\__,_| \_/ |_|\__,_|\__\___| 6 // 7 // Copyright © 2016 - 2024 Weaviate B.V. All rights reserved. 8 // 9 // CONTACT: hello@weaviate.io 10 // 11 12 package clients 13 14 import ( 15 "context" 16 "encoding/json" 17 "io" 18 "net/http" 19 "net/http/httptest" 20 "testing" 21 "time" 22 23 "github.com/pkg/errors" 24 "github.com/sirupsen/logrus" 25 "github.com/sirupsen/logrus/hooks/test" 26 "github.com/stretchr/testify/assert" 27 "github.com/stretchr/testify/require" 28 "github.com/weaviate/weaviate/modules/text2vec-huggingface/ent" 29 ) 30 31 func TestClient(t *testing.T) { 32 t.Run("when all is fine", func(t *testing.T) { 33 server := httptest.NewServer(&fakeHandler{t: t}) 34 defer server.Close() 35 c := &vectorizer{ 36 apiKey: "apiKey", 37 httpClient: &http.Client{}, 38 logger: nullLogger(), 39 } 40 expected := &ent.VectorizationResult{ 41 Text: "This is my text", 42 Vector: []float32{0.1, 0.2, 0.3}, 43 Dimensions: 3, 44 } 45 res, err := c.Vectorize(context.Background(), "This is my text", 46 ent.VectorizationConfig{ 47 Model: "sentence-transformers/gtr-t5-xxl", 48 WaitForModel: false, 49 UseGPU: false, 50 UseCache: true, 51 EndpointURL: server.URL, 52 }) 53 54 assert.Nil(t, err) 55 assert.Equal(t, expected, res) 56 }) 57 58 t.Run("when the context is expired", func(t *testing.T) { 59 server := httptest.NewServer(&fakeHandler{t: t}) 60 defer server.Close() 61 c := &vectorizer{ 62 apiKey: "apiKey", 63 httpClient: &http.Client{}, 64 logger: nullLogger(), 65 } 66 ctx, cancel := context.WithDeadline(context.Background(), time.Now()) 67 defer cancel() 68 69 _, err := c.Vectorize(ctx, "This is my text", ent.VectorizationConfig{ 70 EndpointURL: server.URL, 71 }) 72 73 require.NotNil(t, err) 74 assert.Contains(t, err.Error(), "context deadline exceeded") 75 }) 76 77 t.Run("when the server returns an error", func(t *testing.T) { 78 server := httptest.NewServer(&fakeHandler{ 79 t: t, 80 serverError: errors.Errorf("nope, not gonna happen"), 81 }) 82 defer server.Close() 83 c := &vectorizer{ 84 apiKey: "apiKey", 85 httpClient: &http.Client{}, 86 logger: nullLogger(), 87 } 88 _, err := c.Vectorize(context.Background(), "This is my text", 89 ent.VectorizationConfig{ 90 EndpointURL: server.URL, 91 }) 92 93 require.NotNil(t, err) 94 assert.Equal(t, err.Error(), "connection to HuggingFace failed with status: 500 error: nope, not gonna happen estimated time: 20") 95 }) 96 97 t.Run("when HuggingFace key is passed using X-Huggingface-Api-Key header", func(t *testing.T) { 98 server := httptest.NewServer(&fakeHandler{t: t}) 99 defer server.Close() 100 c := &vectorizer{ 101 apiKey: "", 102 httpClient: &http.Client{}, 103 logger: nullLogger(), 104 } 105 ctxWithValue := context.WithValue(context.Background(), 106 "X-Huggingface-Api-Key", []string{"some-key"}) 107 108 expected := &ent.VectorizationResult{ 109 Text: "This is my text", 110 Vector: []float32{0.1, 0.2, 0.3}, 111 Dimensions: 3, 112 } 113 res, err := c.Vectorize(ctxWithValue, "This is my text", 114 ent.VectorizationConfig{ 115 Model: "sentence-transformers/gtr-t5-xxl", 116 WaitForModel: true, 117 UseGPU: false, 118 UseCache: true, 119 EndpointURL: server.URL, 120 }) 121 122 require.Nil(t, err) 123 assert.Equal(t, expected, res) 124 }) 125 126 t.Run("when a request requires an API KEY", func(t *testing.T) { 127 server := httptest.NewServer(&fakeHandler{ 128 t: t, 129 serverError: errors.Errorf("A valid user or organization token is required"), 130 }) 131 defer server.Close() 132 c := &vectorizer{ 133 apiKey: "", 134 httpClient: &http.Client{}, 135 logger: nullLogger(), 136 } 137 ctxWithValue := context.WithValue(context.Background(), 138 "X-Huggingface-Api-Key", []string{""}) 139 140 _, err := c.Vectorize(ctxWithValue, "This is my text", 141 ent.VectorizationConfig{ 142 Model: "sentence-transformers/gtr-t5-xxl", 143 EndpointURL: server.URL, 144 }) 145 146 require.NotNil(t, err) 147 assert.Equal(t, err.Error(), "failed with status: 401 error: A valid user or organization token is required") 148 }) 149 150 t.Run("when the server returns an error with warnings", func(t *testing.T) { 151 server := httptest.NewServer(&fakeHandler{ 152 t: t, 153 serverError: errors.Errorf("with warnings"), 154 }) 155 defer server.Close() 156 c := &vectorizer{ 157 apiKey: "apiKey", 158 httpClient: &http.Client{}, 159 logger: nullLogger(), 160 } 161 _, err := c.Vectorize(context.Background(), "This is my text", 162 ent.VectorizationConfig{ 163 EndpointURL: server.URL, 164 }) 165 166 require.NotNil(t, err) 167 assert.Equal(t, err.Error(), "connection to HuggingFace failed with status: 500 error: with warnings "+ 168 "warnings: [There was an inference error: CUDA error: all CUDA-capable devices are busy or unavailable\n"+ 169 "CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.\n"+ 170 "For debugging consider passing CUDA_LAUNCH_BLOCKING=1.]") 171 }) 172 } 173 174 type fakeHandler struct { 175 t *testing.T 176 serverError error 177 } 178 179 func (f *fakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 180 assert.Equal(f.t, http.MethodPost, r.Method) 181 182 if f.serverError != nil { 183 switch f.serverError.Error() { 184 case "with warnings": 185 embeddingError := map[string]interface{}{ 186 "error": f.serverError.Error(), 187 "warnings": []string{ 188 "There was an inference error: CUDA error: all CUDA-capable devices are busy or unavailable\n" + 189 "CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.\n" + 190 "For debugging consider passing CUDA_LAUNCH_BLOCKING=1.", 191 }, 192 } 193 outBytes, err := json.Marshal(embeddingError) 194 require.Nil(f.t, err) 195 196 w.WriteHeader(http.StatusInternalServerError) 197 w.Write(outBytes) 198 return 199 case "A valid user or organization token is required": 200 embeddingError := map[string]interface{}{ 201 "error": "A valid user or organization token is required", 202 } 203 outBytes, err := json.Marshal(embeddingError) 204 require.Nil(f.t, err) 205 206 w.WriteHeader(http.StatusUnauthorized) 207 w.Write(outBytes) 208 return 209 default: 210 embeddingError := map[string]interface{}{ 211 "error": f.serverError.Error(), 212 "estimated_time": 20.0, 213 } 214 outBytes, err := json.Marshal(embeddingError) 215 require.Nil(f.t, err) 216 217 w.WriteHeader(http.StatusInternalServerError) 218 w.Write(outBytes) 219 return 220 } 221 } 222 223 bodyBytes, err := io.ReadAll(r.Body) 224 require.Nil(f.t, err) 225 defer r.Body.Close() 226 227 var b map[string]interface{} 228 require.Nil(f.t, json.Unmarshal(bodyBytes, &b)) 229 230 textInputs := b["inputs"].([]interface{}) 231 assert.Greater(f.t, len(textInputs), 0) 232 textInput := textInputs[0].(string) 233 assert.Greater(f.t, len(textInput), 0) 234 235 // TODO: fix this 236 embedding := [][]float32{{0.1, 0.2, 0.3}} 237 outBytes, err := json.Marshal(embedding) 238 require.Nil(f.t, err) 239 240 w.Write(outBytes) 241 } 242 243 func nullLogger() logrus.FieldLogger { 244 l, _ := test.NewNullLogger() 245 return l 246 } 247 248 func Test_getURL(t *testing.T) { 249 v := &vectorizer{} 250 251 tests := []struct { 252 name string 253 config ent.VectorizationConfig 254 want string 255 }{ 256 { 257 name: "Facebook DPR model", 258 config: ent.VectorizationConfig{ 259 Model: "sentence-transformers/facebook-dpr-ctx_encoder-multiset-base", 260 }, 261 want: "https://api-inference.huggingface.co/pipeline/feature-extraction/sentence-transformers/facebook-dpr-ctx_encoder-multiset-base", 262 }, 263 { 264 name: "BERT base model (uncased)", 265 config: ent.VectorizationConfig{ 266 Model: "bert-base-uncased", 267 }, 268 want: "https://api-inference.huggingface.co/pipeline/feature-extraction/bert-base-uncased", 269 }, 270 { 271 name: "BERT base model (uncased)", 272 config: ent.VectorizationConfig{ 273 EndpointURL: "https://self-hosted-instance.com/bert-base-uncased", 274 }, 275 want: "https://self-hosted-instance.com/bert-base-uncased", 276 }, 277 } 278 for _, tt := range tests { 279 t.Run(tt.name, func(t *testing.T) { 280 assert.Equal(t, tt.want, v.getURL(tt.config)) 281 }) 282 } 283 }