github.com/instill-ai/component@v0.16.0-beta/pkg/connector/huggingface/v0/connector_test.go (about) 1 package huggingface 2 3 import ( 4 "encoding/base64" 5 "fmt" 6 "io" 7 "net/http" 8 "net/http/httptest" 9 "testing" 10 11 qt "github.com/frankban/quicktest" 12 "go.uber.org/zap" 13 "google.golang.org/protobuf/types/known/structpb" 14 15 "github.com/instill-ai/component/pkg/base" 16 "github.com/instill-ai/component/pkg/connector/util/httpclient" 17 "github.com/instill-ai/x/errmsg" 18 ) 19 20 const ( 21 apiKey = "123" 22 model = "openai/whisper-tiny" 23 24 testInput = "testing generation" 25 26 fillMaskResp = `[{"score": 0.234, "token": 3, "sequence": "one", "token_str": "three"}]` 27 classificationResp = `[{"score": 0.123, "label": "backpack hip-hop"}, {"score": 0.894, "label": "lo-fi jazz"}]` 28 tokenClassificationResp = `[{"entity_group":"foo", "score": 0.234, "start": 0, "end": 5, "word": "bar"}]` 29 objDetectionResp = ` 30 [ 31 { 32 "score": 0.123, 33 "label": "backpack hip-hop", 34 "box": { 35 "xmin": 0, 36 "xmax": 1, 37 "ymin": 0, 38 "ymax": 1 39 } 40 } 41 ]` 42 43 errorResp = ` { "error": "Invalid request" }` 44 errorsResp = ` { "error": ["Temporarily unavailable", "Too many requests"] }` 45 ) 46 47 var ( 48 bRaw = []byte("aaa") 49 bEncoded = base64.StdEncoding.EncodeToString(bRaw) 50 51 inputsBody = []byte(`{"inputs": "testing generation"}`) 52 ) 53 54 type taskParams struct { 55 task string 56 input any 57 contentType string // content type received in Hugging Face 58 wantBody []byte // expected request body in Hugging Face 59 okResp string // successful response from Hugging Face 60 wantResp string // successful response from connector 61 } 62 63 func wrapArrayInObject(array, key string) string { 64 return fmt.Sprintf(`{"%s": %s}`, key, array) 65 } 66 67 var coveredTasks = []taskParams{ 68 { 69 task: textGenerationTask, 70 input: TextGenerationRequest{Inputs: testInput}, 71 contentType: httpclient.MIMETypeJSON, 72 wantBody: inputsBody, 73 okResp: `[{"generated_text": "text response"}]`, 74 wantResp: `{"generated_text": "text response"}`, 75 }, 76 { 77 task: textToImageTask, 78 input: TextToImageRequest{Inputs: testInput}, 79 contentType: httpclient.MIMETypeJSON, 80 wantBody: inputsBody, 81 okResp: string(bRaw), 82 wantResp: fmt.Sprintf(`{"image": "data:image/jpeg;base64,%s"}`, bEncoded), 83 }, 84 { 85 task: fillMaskTask, 86 input: FillMaskRequest{Inputs: testInput}, 87 contentType: httpclient.MIMETypeJSON, 88 wantBody: inputsBody, 89 okResp: fillMaskResp, 90 wantResp: wrapArrayInObject(fillMaskResp, "results"), 91 }, 92 { 93 task: summarizationTask, 94 input: SummarizationRequest{Inputs: testInput}, 95 contentType: httpclient.MIMETypeJSON, 96 wantBody: inputsBody, 97 okResp: `[{"summary_text": "summary"}]`, 98 wantResp: `{"summary_text": "summary"}`, 99 }, 100 { 101 task: textClassificationTask, 102 input: TextClassificationRequest{Inputs: testInput}, 103 contentType: httpclient.MIMETypeJSON, 104 wantBody: inputsBody, 105 okResp: "[" + classificationResp + "]", 106 wantResp: wrapArrayInObject(classificationResp, "results"), 107 }, 108 { 109 task: tokenClassificationTask, 110 input: TokenClassificationRequest{Inputs: testInput}, 111 contentType: httpclient.MIMETypeJSON, 112 wantBody: inputsBody, 113 okResp: tokenClassificationResp, 114 wantResp: wrapArrayInObject(tokenClassificationResp, "results"), 115 }, 116 { 117 task: translationTask, 118 input: TranslationRequest{Inputs: testInput}, 119 contentType: httpclient.MIMETypeJSON, 120 wantBody: inputsBody, 121 okResp: `[{"translation_text": "translated"}]`, 122 wantResp: `{"translation_text": "translated"}`, 123 }, 124 { 125 task: zeroShotClassificationTask, 126 input: ZeroShotRequest{Inputs: testInput}, 127 contentType: httpclient.MIMETypeJSON, 128 wantBody: inputsBody, 129 okResp: `{"sequence": "seq"}`, 130 wantResp: `{"sequence": "seq"}`, 131 }, 132 { 133 task: questionAnsweringTask, 134 input: QuestionAnsweringRequest{Inputs: QuestionAnsweringInputs{Question: "isn't it?"}}, 135 contentType: httpclient.MIMETypeJSON, 136 wantBody: []byte(`{"inputs": {"question": "is it?"}}`), 137 okResp: `{"answer": "it is"}`, 138 wantResp: `{"answer": "it is"}`, 139 }, 140 { 141 task: tableQuestionAnsweringTask, 142 input: TableQuestionAnsweringRequest{Inputs: TableQuestionAnsweringInputs{Query: "yes?"}}, 143 contentType: httpclient.MIMETypeJSON, 144 wantBody: []byte(`{"inputs": {"query": "yes?"}}`), 145 okResp: `{"answer": "yes"}`, 146 wantResp: `{"answer": "yes"}`, 147 }, 148 { 149 task: sentenceSimilarityTask, 150 input: SentenceSimilarityRequest{Inputs: SentenceSimilarityInputs{}}, 151 contentType: httpclient.MIMETypeJSON, 152 wantBody: []byte(`{"inputs": {}}`), 153 okResp: `[0.23]`, 154 wantResp: wrapArrayInObject(`[0.23]`, "scores"), 155 }, 156 { 157 task: conversationalTask, 158 input: ConversationalRequest{Inputs: ConversationalInputs{}}, 159 contentType: httpclient.MIMETypeJSON, 160 wantBody: []byte(`{"inputs": {}}`), 161 okResp: `{"generated_text": "gen"}`, 162 wantResp: `{"generated_text": "gen"}`, 163 }, 164 { 165 task: imageClassificationTask, 166 input: ImageRequest{Image: bEncoded}, 167 contentType: "text/plain.*", 168 wantBody: bRaw, 169 okResp: classificationResp, 170 wantResp: wrapArrayInObject(classificationResp, "classes"), 171 }, 172 { 173 task: imageSegmentationTask, 174 input: ImageRequest{Image: bEncoded}, 175 contentType: "text/plain.*", 176 wantBody: bRaw, 177 okResp: `[{"score": 0.123, "label": "backpack hip-hop", "mask": "YBcsSdfg"}]`, 178 wantResp: `{"segments": [{"score": 0.123, "label": "backpack hip-hop", "mask": ""}]}`, 179 }, 180 { 181 task: objectDetectionTask, 182 input: ImageRequest{Image: bEncoded}, 183 contentType: "text/plain.*", 184 wantBody: bRaw, 185 okResp: objDetectionResp, 186 wantResp: wrapArrayInObject(objDetectionResp, "objects"), 187 }, 188 { 189 task: imageToTextTask, 190 input: ImageRequest{Image: bEncoded}, 191 contentType: "text/plain.*", 192 wantBody: bRaw, 193 okResp: `[{"generated_text": "Me robaron mi runa mula"}]`, 194 wantResp: `{"text": "Me robaron mi runa mula"}`, 195 }, 196 { 197 task: speechRecognitionTask, 198 input: AudioRequest{Audio: bEncoded}, 199 contentType: "text/plain.*", 200 wantBody: bRaw, 201 okResp: `{"text": "Me robaron mi runa mula"}`, 202 wantResp: `{"text": "Me robaron mi runa mula"}`, 203 }, 204 { 205 task: audioClassificationTask, 206 input: AudioRequest{Audio: bEncoded}, 207 contentType: "text/plain.*", 208 wantBody: bRaw, 209 okResp: classificationResp, 210 wantResp: wrapArrayInObject(classificationResp, "classes"), 211 }, 212 } 213 214 func TestConnector_ExecuteSpeechRecognition(t *testing.T) { 215 c := qt.New(t) 216 217 for _, params := range coveredTasks { 218 testTask(c, params) 219 } 220 } 221 222 func testTask(c *qt.C, p taskParams) { 223 logger := zap.NewNop() 224 connector := Init(logger, nil) 225 226 c.Run("nok - HTTP client error - "+p.task, func(c *qt.C) { 227 c.Parallel() 228 229 connection, err := structpb.NewStruct(map[string]any{ 230 "base_url": "http://no-such.host", 231 }) 232 c.Assert(err, qt.IsNil) 233 234 exec, err := connector.CreateExecution(nil, connection, p.task) 235 c.Assert(err, qt.IsNil) 236 237 pbIn, err := base.ConvertToStructpb(p.input) 238 c.Assert(err, qt.IsNil) 239 pbIn.Fields["model"] = structpb.NewStringValue(model) 240 241 _, err = exec.Execution.Execute([]*structpb.Struct{pbIn}) 242 c.Check(err, qt.IsNotNil) 243 c.Check(err, qt.ErrorMatches, ".*no such host") 244 c.Check(errmsg.Message(err), qt.Matches, "Failed to call .*check that the connector configuration is correct.") 245 }) 246 247 testcases := []struct { 248 name string 249 customEndpoint bool 250 httpStatus int 251 httpBody string 252 wantErr string 253 }{ 254 { 255 name: "ok", 256 httpStatus: http.StatusOK, 257 httpBody: p.okResp, 258 }, 259 { 260 name: "nok - API error", 261 httpStatus: http.StatusBadRequest, 262 httpBody: errorResp, 263 wantErr: "Hugging Face responded with a 400 status code. Invalid request", 264 }, 265 { 266 name: "nok - API errors", 267 customEndpoint: true, 268 httpStatus: http.StatusTooManyRequests, 269 httpBody: errorsResp, 270 wantErr: "Hugging Face responded with a 429 status code. [Temporarily unavailable, Too many requests]", 271 }, 272 } 273 274 for _, tc := range testcases { 275 tc := tc 276 c.Run(tc.name+" _ "+p.task, func(c *qt.C) { 277 c.Parallel() 278 279 wantPath := modelsPath + model 280 if tc.customEndpoint { 281 wantPath = "/" 282 } 283 284 h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 285 c.Check(r.Method, qt.Equals, http.MethodPost) 286 c.Check(r.URL.Path, qt.Matches, wantPath) 287 288 c.Check(r.Header.Get("Authorization"), qt.Equals, "Bearer "+apiKey) 289 290 ct := r.Header.Get("Content-Type") 291 c.Check(ct, qt.Matches, p.contentType) 292 293 c.Assert(r.Body, qt.IsNotNil) 294 defer r.Body.Close() 295 296 body, err := io.ReadAll(r.Body) 297 c.Assert(err, qt.IsNil) 298 if ct == httpclient.MIMETypeJSON { 299 // If we have a case where we don't pass the input request, 300 // we can check if p.wantBody is not empty and then do 301 // c.Check(body, qt.JSONEquals, json.RawMessage(p.wantBody) 302 303 c.Check(body, qt.JSONEquals, p.input) 304 } else { 305 c.Check(body, qt.ContentEquals, p.wantBody) 306 } 307 308 w.Header().Set("Content-Type", httpclient.MIMETypeJSON) 309 w.WriteHeader(tc.httpStatus) 310 fmt.Fprint(w, tc.httpBody) 311 }) 312 313 srv := httptest.NewServer(h) 314 c.Cleanup(srv.Close) 315 316 connection, _ := structpb.NewStruct(map[string]any{ 317 "api_key": apiKey, 318 "base_url": srv.URL, 319 "is_custom_endpoint": tc.customEndpoint, 320 }) 321 322 exec, err := connector.CreateExecution(nil, connection, p.task) 323 c.Assert(err, qt.IsNil) 324 325 pbIn, err := base.ConvertToStructpb(p.input) 326 c.Assert(err, qt.IsNil) 327 pbIn.Fields["model"] = structpb.NewStringValue(model) 328 329 got, err := exec.Execution.Execute([]*structpb.Struct{pbIn}) 330 if tc.wantErr != "" { 331 c.Check(err, qt.IsNotNil) 332 c.Check(errmsg.Message(err), qt.Equals, tc.wantErr) 333 return 334 } 335 336 c.Check(err, qt.IsNil) 337 338 c.Assert(got, qt.HasLen, 1) 339 c.Check(p.wantResp, qt.JSONEquals, got[0].AsMap()) 340 }) 341 } 342 }