github.com/instill-ai/component@v0.16.0-beta/pkg/connector/huggingface/v0/main.go (about) 1 //go:generate compogen readme --connector ./config ./README.mdx 2 package huggingface 3 4 import ( 5 _ "embed" 6 "encoding/base64" 7 "encoding/json" 8 "fmt" 9 "sync" 10 11 "go.uber.org/zap" 12 "google.golang.org/protobuf/encoding/protojson" 13 "google.golang.org/protobuf/types/known/structpb" 14 15 "github.com/instill-ai/component/pkg/base" 16 "github.com/instill-ai/x/errmsg" 17 ) 18 19 const ( 20 textGenerationTask = "TASK_TEXT_GENERATION" 21 textToImageTask = "TASK_TEXT_TO_IMAGE" 22 fillMaskTask = "TASK_FILL_MASK" 23 summarizationTask = "TASK_SUMMARIZATION" 24 textClassificationTask = "TASK_TEXT_CLASSIFICATION" 25 tokenClassificationTask = "TASK_TOKEN_CLASSIFICATION" 26 translationTask = "TASK_TRANSLATION" 27 zeroShotClassificationTask = "TASK_ZERO_SHOT_CLASSIFICATION" 28 featureExtractionTask = "TASK_FEATURE_EXTRACTION" 29 questionAnsweringTask = "TASK_QUESTION_ANSWERING" 30 tableQuestionAnsweringTask = "TASK_TABLE_QUESTION_ANSWERING" 31 sentenceSimilarityTask = "TASK_SENTENCE_SIMILARITY" 32 conversationalTask = "TASK_CONVERSATIONAL" 33 imageClassificationTask = "TASK_IMAGE_CLASSIFICATION" 34 imageSegmentationTask = "TASK_IMAGE_SEGMENTATION" 35 objectDetectionTask = "TASK_OBJECT_DETECTION" 36 imageToTextTask = "TASK_IMAGE_TO_TEXT" 37 speechRecognitionTask = "TASK_SPEECH_RECOGNITION" 38 audioClassificationTask = "TASK_AUDIO_CLASSIFICATION" 39 ) 40 41 var ( 42 //go:embed config/definition.json 43 definitionJSON []byte 44 //go:embed config/tasks.json 45 tasksJSON []byte 46 once sync.Once 47 con *connector 48 ) 49 50 type connector struct { 51 base.BaseConnector 52 } 53 54 type execution struct { 55 base.BaseConnectorExecution 56 } 57 58 func Init(l *zap.Logger, u base.UsageHandler) *connector { 59 once.Do(func() { 60 con = &connector{ 61 BaseConnector: base.BaseConnector{ 62 Logger: l, 63 UsageHandler: u, 64 }, 65 } 66 err := con.LoadConnectorDefinition(definitionJSON, tasksJSON, nil) 67 if err != nil { 68 panic(err) 69 } 70 }) 71 return con 72 } 73 74 func (c *connector) CreateExecution(sysVars map[string]any, connection *structpb.Struct, task string) (*base.ExecutionWrapper, error) { 75 return &base.ExecutionWrapper{Execution: &execution{ 76 BaseConnectorExecution: base.BaseConnectorExecution{Connector: c, SystemVariables: sysVars, Connection: connection, Task: task}, 77 }}, nil 78 } 79 80 func getAPIKey(config *structpb.Struct) string { 81 return config.GetFields()["api_key"].GetStringValue() 82 } 83 84 func getBaseURL(config *structpb.Struct) string { 85 return config.GetFields()["base_url"].GetStringValue() 86 } 87 88 func isCustomEndpoint(config *structpb.Struct) bool { 89 return config.GetFields()["is_custom_endpoint"].GetBoolValue() 90 } 91 92 func wrapSliceInStruct(data []byte, key string) (*structpb.Struct, error) { 93 var list []any 94 if err := json.Unmarshal(data, &list); err != nil { 95 return nil, err 96 } 97 98 results, err := structpb.NewList(list) 99 if err != nil { 100 return nil, err 101 } 102 103 return &structpb.Struct{ 104 Fields: map[string]*structpb.Value{ 105 key: structpb.NewListValue(results), 106 }, 107 }, nil 108 } 109 110 func (e *execution) Execute(inputs []*structpb.Struct) ([]*structpb.Struct, error) { 111 client := newClient(e.Connection, e.GetLogger()) 112 outputs := []*structpb.Struct{} 113 114 path := "/" 115 if !isCustomEndpoint(e.Connection) { 116 path = modelsPath + inputs[0].GetFields()["model"].GetStringValue() 117 } 118 119 for _, input := range inputs { 120 switch e.Task { 121 case textGenerationTask: 122 inputStruct := TextGenerationRequest{} 123 if err := base.ConvertFromStructpb(input, &inputStruct); err != nil { 124 return nil, err 125 } 126 127 resp := []TextGenerationResponse{} 128 req := client.R().SetBody(inputStruct).SetResult(&resp) 129 if _, err := post(req, path); err != nil { 130 return nil, err 131 } 132 133 if len(resp) < 1 { 134 err := fmt.Errorf("invalid response") 135 return nil, errmsg.AddMessage(err, "Hugging Face didn't return any result") 136 } 137 138 output, err := structpb.NewStruct(map[string]any{"generated_text": resp[0].GeneratedText}) 139 if err != nil { 140 return nil, err 141 } 142 143 outputs = append(outputs, output) 144 case textToImageTask: 145 inputStruct := TextToImageRequest{} 146 if err := base.ConvertFromStructpb(input, &inputStruct); err != nil { 147 return nil, err 148 } 149 150 req := client.R().SetBody(inputStruct) 151 resp, err := post(req, path) 152 if err != nil { 153 return nil, err 154 } 155 156 rawImg := base64.StdEncoding.EncodeToString(resp.Body()) 157 output, err := structpb.NewStruct(map[string]any{ 158 "image": fmt.Sprintf("data:image/jpeg;base64,%s", rawImg), 159 }) 160 if err != nil { 161 return nil, err 162 } 163 164 outputs = append(outputs, output) 165 case fillMaskTask: 166 inputStruct := FillMaskRequest{} 167 if err := base.ConvertFromStructpb(input, &inputStruct); err != nil { 168 return nil, err 169 } 170 171 req := client.R().SetBody(inputStruct) 172 resp, err := post(req, path) 173 if err != nil { 174 return nil, err 175 } 176 177 output, err := wrapSliceInStruct(resp.Body(), "results") 178 if err != nil { 179 return nil, err 180 } 181 182 outputs = append(outputs, output) 183 case summarizationTask: 184 inputStruct := SummarizationRequest{} 185 if err := base.ConvertFromStructpb(input, &inputStruct); err != nil { 186 return nil, err 187 } 188 189 resp := []SummarizationResponse{} 190 req := client.R().SetBody(inputStruct).SetResult(&resp) 191 if _, err := post(req, path); err != nil { 192 return nil, err 193 } 194 195 if len(resp) < 1 { 196 err := fmt.Errorf("invalid response") 197 return nil, errmsg.AddMessage(err, "Hugging Face didn't return any result") 198 } 199 200 output, err := structpb.NewStruct(map[string]any{"summary_text": resp[0].SummaryText}) 201 if err != nil { 202 return nil, err 203 } 204 205 outputs = append(outputs, output) 206 case textClassificationTask: 207 inputStruct := TextClassificationRequest{} 208 if err := base.ConvertFromStructpb(input, &inputStruct); err != nil { 209 return nil, err 210 } 211 212 var resp [][]any 213 req := client.R().SetBody(inputStruct).SetResult(&resp) 214 if _, err := post(req, path); err != nil { 215 return nil, err 216 } 217 218 if len(resp) < 1 { 219 err := fmt.Errorf("invalid response") 220 return nil, errmsg.AddMessage(err, "Hugging Face didn't return any result") 221 } 222 223 results, err := structpb.NewList(resp[0]) 224 if err != nil { 225 return nil, err 226 } 227 228 output := &structpb.Struct{ 229 Fields: map[string]*structpb.Value{ 230 "results": structpb.NewListValue(results), 231 }, 232 } 233 234 outputs = append(outputs, output) 235 case tokenClassificationTask: 236 inputStruct := TokenClassificationRequest{} 237 if err := base.ConvertFromStructpb(input, &inputStruct); err != nil { 238 return nil, err 239 } 240 req := client.R().SetBody(inputStruct) 241 resp, err := post(req, path) 242 if err != nil { 243 return nil, err 244 } 245 246 output, err := wrapSliceInStruct(resp.Body(), "results") 247 if err != nil { 248 return nil, err 249 } 250 251 outputs = append(outputs, output) 252 case translationTask: 253 inputStruct := TranslationRequest{} 254 if err := base.ConvertFromStructpb(input, &inputStruct); err != nil { 255 return nil, err 256 } 257 258 resp := []TranslationResponse{} 259 req := client.R().SetBody(inputStruct).SetResult(&resp) 260 if _, err := post(req, path); err != nil { 261 return nil, err 262 } 263 264 if len(resp) < 1 { 265 err := fmt.Errorf("invalid response") 266 return nil, errmsg.AddMessage(err, "Hugging Face didn't return any result") 267 } 268 269 output, err := structpb.NewStruct(map[string]any{"translation_text": resp[0].TranslationText}) 270 if err != nil { 271 return nil, err 272 } 273 274 outputs = append(outputs, output) 275 case zeroShotClassificationTask: 276 inputStruct := ZeroShotRequest{} 277 if err := base.ConvertFromStructpb(input, &inputStruct); err != nil { 278 return nil, err 279 } 280 281 req := client.R().SetBody(inputStruct) 282 resp, err := post(req, path) 283 if err != nil { 284 return nil, err 285 } 286 287 var output structpb.Struct 288 if err = protojson.Unmarshal(resp.Body(), &output); err != nil { 289 return nil, err 290 } 291 292 outputs = append(outputs, &output) 293 // case featureExtractionTask: 294 // TODO: fix this task 295 // inputStruct := FeatureExtractionRequest{} 296 // err := base.ConvertFromStructpb(input, &inputStruct) 297 // if err != nil { 298 // return nil, err 299 // } 300 // jsonBody, _ := json.Marshal(inputStruct) 301 // resp, err := doer.MakeHFAPIRequest(jsonBody, model) 302 // if err != nil { 303 // return nil, err 304 // } 305 // threeDArr := [][][]float64{} 306 // err = json.Unmarshal(resp, &threeDArr) 307 // if err != nil { 308 // return nil, err 309 // } 310 // if len(threeDArr) <= 0 { 311 // return nil, errors.New("invalid response") 312 // } 313 // nestedArr := threeDArr[0] 314 // features := structpb.ListValue{} 315 // features.Values = make([]*structpb.Value, len(nestedArr)) 316 // for i, innerArr := range nestedArr { 317 // innerValues := make([]*structpb.Value, len(innerArr)) 318 // for j := range innerArr { 319 // innerValues[j] = &structpb.Value{Kind: &structpb.Value_NumberValue{NumberValue: innerArr[j]}} 320 // } 321 // features.Values[i] = &structpb.Value{Kind: &structpb.Value_ListValue{ListValue: &structpb.ListValue{Values: innerValues}}} 322 // } 323 // output := structpb.Struct{ 324 // Fields: map[string]*structpb.Value{"feature": {Kind: &structpb.Value_ListValue{ListValue: &features}}}, 325 // } 326 // outputs = append(outputs, &output) 327 case questionAnsweringTask: 328 inputStruct := QuestionAnsweringRequest{} 329 if err := base.ConvertFromStructpb(input, &inputStruct); err != nil { 330 return nil, err 331 } 332 req := client.R().SetBody(inputStruct) 333 resp, err := post(req, path) 334 if err != nil { 335 return nil, err 336 } 337 338 var output structpb.Struct 339 if err = protojson.Unmarshal(resp.Body(), &output); err != nil { 340 return nil, err 341 } 342 343 outputs = append(outputs, &output) 344 case tableQuestionAnsweringTask: 345 inputStruct := TableQuestionAnsweringRequest{} 346 if err := base.ConvertFromStructpb(input, &inputStruct); err != nil { 347 return nil, err 348 } 349 350 req := client.R().SetBody(inputStruct) 351 resp, err := post(req, path) 352 if err != nil { 353 return nil, err 354 } 355 356 var output structpb.Struct 357 if err = protojson.Unmarshal(resp.Body(), &output); err != nil { 358 return nil, err 359 } 360 361 outputs = append(outputs, &output) 362 case sentenceSimilarityTask: 363 inputStruct := SentenceSimilarityRequest{} 364 if err := base.ConvertFromStructpb(input, &inputStruct); err != nil { 365 return nil, err 366 } 367 368 req := client.R().SetBody(inputStruct) 369 resp, err := post(req, path) 370 if err != nil { 371 return nil, err 372 } 373 374 output, err := wrapSliceInStruct(resp.Body(), "scores") 375 if err != nil { 376 return nil, err 377 } 378 379 outputs = append(outputs, output) 380 case conversationalTask: 381 inputStruct := ConversationalRequest{} 382 if err := base.ConvertFromStructpb(input, &inputStruct); err != nil { 383 return nil, err 384 } 385 386 req := client.R().SetBody(inputStruct) 387 resp, err := post(req, path) 388 if err != nil { 389 return nil, err 390 } 391 392 var output structpb.Struct 393 if err = protojson.Unmarshal(resp.Body(), &output); err != nil { 394 return nil, err 395 } 396 397 outputs = append(outputs, &output) 398 case imageClassificationTask: 399 inputStruct := ImageRequest{} 400 if err := base.ConvertFromStructpb(input, &inputStruct); err != nil { 401 return nil, err 402 } 403 404 b, err := base64.StdEncoding.DecodeString(base.TrimBase64Mime(inputStruct.Image)) 405 if err != nil { 406 return nil, err 407 } 408 409 req := client.R().SetBody(b) 410 resp, err := post(req, path) 411 if err != nil { 412 return nil, err 413 } 414 415 output, err := wrapSliceInStruct(resp.Body(), "classes") 416 if err != nil { 417 return nil, err 418 } 419 420 outputs = append(outputs, output) 421 case imageSegmentationTask: 422 inputStruct := ImageRequest{} 423 if err := base.ConvertFromStructpb(input, &inputStruct); err != nil { 424 return nil, err 425 } 426 427 b, err := base64.StdEncoding.DecodeString(base.TrimBase64Mime(inputStruct.Image)) 428 if err != nil { 429 return nil, err 430 } 431 432 resp := []ImageSegmentationResponse{} 433 req := client.R().SetBody(b).SetResult(&resp) 434 if _, err := post(req, path); err != nil { 435 return nil, err 436 } 437 438 segments := &structpb.ListValue{ 439 Values: make([]*structpb.Value, len(resp)), 440 } 441 442 for i := range resp { 443 segment, err := structpb.NewStruct(map[string]any{ 444 "score": resp[i].Score, 445 "label": resp[i].Label, 446 "mask": fmt.Sprintf("data:image/png;base64,%s", resp[i].Mask), 447 }) 448 449 if err != nil { 450 return nil, err 451 } 452 453 segments.Values[i] = structpb.NewStructValue(segment) 454 } 455 456 output := &structpb.Struct{ 457 Fields: map[string]*structpb.Value{ 458 "segments": structpb.NewListValue(segments), 459 }, 460 } 461 462 outputs = append(outputs, output) 463 case objectDetectionTask: 464 inputStruct := ImageRequest{} 465 if err := base.ConvertFromStructpb(input, &inputStruct); err != nil { 466 return nil, err 467 } 468 469 b, err := base64.StdEncoding.DecodeString(base.TrimBase64Mime(inputStruct.Image)) 470 if err != nil { 471 return nil, err 472 } 473 474 req := client.R().SetBody(b) 475 resp, err := post(req, path) 476 if err != nil { 477 return nil, err 478 } 479 480 output, err := wrapSliceInStruct(resp.Body(), "objects") 481 if err != nil { 482 return nil, err 483 } 484 485 outputs = append(outputs, output) 486 case imageToTextTask: 487 inputStruct := ImageRequest{} 488 if err := base.ConvertFromStructpb(input, &inputStruct); err != nil { 489 return nil, err 490 } 491 492 b, err := base64.StdEncoding.DecodeString(base.TrimBase64Mime(inputStruct.Image)) 493 if err != nil { 494 return nil, err 495 } 496 497 resp := []ImageToTextResponse{} 498 req := client.R().SetBody(b).SetResult(&resp) 499 if _, err := post(req, path); err != nil { 500 return nil, err 501 } 502 503 if len(resp) < 1 { 504 err := fmt.Errorf("invalid response") 505 return nil, errmsg.AddMessage(err, "Hugging Face didn't return any result") 506 } 507 508 output, err := structpb.NewStruct(map[string]any{"text": resp[0].GeneratedText}) 509 if err != nil { 510 return nil, err 511 } 512 513 outputs = append(outputs, output) 514 case speechRecognitionTask: 515 inputStruct := AudioRequest{} 516 if err := base.ConvertFromStructpb(input, &inputStruct); err != nil { 517 return nil, err 518 } 519 520 b, err := base64.StdEncoding.DecodeString(base.TrimBase64Mime(inputStruct.Audio)) 521 if err != nil { 522 return nil, err 523 } 524 525 req := client.R().SetBody(b) 526 resp, err := post(req, path) 527 if err != nil { 528 return nil, err 529 } 530 531 output := new(structpb.Struct) 532 if err := protojson.Unmarshal(resp.Body(), output); err != nil { 533 return nil, err 534 } 535 536 outputs = append(outputs, output) 537 case audioClassificationTask: 538 inputStruct := AudioRequest{} 539 if err := base.ConvertFromStructpb(input, &inputStruct); err != nil { 540 return nil, err 541 } 542 543 b, err := base64.StdEncoding.DecodeString(base.TrimBase64Mime(inputStruct.Audio)) 544 if err != nil { 545 return nil, err 546 } 547 548 req := client.R().SetBody(b) 549 resp, err := post(req, path) 550 if err != nil { 551 return nil, err 552 } 553 554 output, err := wrapSliceInStruct(resp.Body(), "classes") 555 if err != nil { 556 return nil, err 557 } 558 559 outputs = append(outputs, output) 560 default: 561 return nil, errmsg.AddMessage( 562 fmt.Errorf("not supported task: %s", e.Task), 563 fmt.Sprintf("%s task is not supported.", e.Task), 564 ) 565 } 566 } 567 568 return outputs, nil 569 } 570 571 func (c *connector) Test(sysVars map[string]any, connection *structpb.Struct) error { 572 req := newClient(connection, c.Logger).R() 573 resp, err := req.Get("") 574 if err != nil { 575 return err 576 } 577 578 if resp.IsError() { 579 return fmt.Errorf("connection error") 580 } 581 582 return nil 583 }