github.com/instill-ai/component@v0.16.0-beta/pkg/connector/huggingface/v0/main.go (about)

     1  //go:generate compogen readme --connector ./config ./README.mdx
     2  package huggingface
     3  
     4  import (
     5  	_ "embed"
     6  	"encoding/base64"
     7  	"encoding/json"
     8  	"fmt"
     9  	"sync"
    10  
    11  	"go.uber.org/zap"
    12  	"google.golang.org/protobuf/encoding/protojson"
    13  	"google.golang.org/protobuf/types/known/structpb"
    14  
    15  	"github.com/instill-ai/component/pkg/base"
    16  	"github.com/instill-ai/x/errmsg"
    17  )
    18  
// Task identifiers handled by this connector. Execute switches on
// execution.Task and compares it against these values; any other value is
// rejected as an unsupported task.
const (
	textGenerationTask         = "TASK_TEXT_GENERATION"
	textToImageTask            = "TASK_TEXT_TO_IMAGE"
	fillMaskTask               = "TASK_FILL_MASK"
	summarizationTask          = "TASK_SUMMARIZATION"
	textClassificationTask     = "TASK_TEXT_CLASSIFICATION"
	tokenClassificationTask    = "TASK_TOKEN_CLASSIFICATION"
	translationTask            = "TASK_TRANSLATION"
	zeroShotClassificationTask = "TASK_ZERO_SHOT_CLASSIFICATION"
	// featureExtractionTask has no active case in Execute (its handler is
	// commented out pending a fix), so it currently falls through to the
	// "not supported" branch.
	featureExtractionTask      = "TASK_FEATURE_EXTRACTION"
	questionAnsweringTask      = "TASK_QUESTION_ANSWERING"
	tableQuestionAnsweringTask = "TASK_TABLE_QUESTION_ANSWERING"
	sentenceSimilarityTask     = "TASK_SENTENCE_SIMILARITY"
	conversationalTask         = "TASK_CONVERSATIONAL"
	imageClassificationTask    = "TASK_IMAGE_CLASSIFICATION"
	imageSegmentationTask      = "TASK_IMAGE_SEGMENTATION"
	objectDetectionTask        = "TASK_OBJECT_DETECTION"
	imageToTextTask            = "TASK_IMAGE_TO_TEXT"
	speechRecognitionTask      = "TASK_SPEECH_RECOGNITION"
	audioClassificationTask    = "TASK_AUDIO_CLASSIFICATION"
)
    40  
// Package-level state shared by every caller of Init.
var (
	// definitionJSON is the embedded content of config/definition.json; Init
	// hands it to LoadConnectorDefinition.
	//go:embed config/definition.json
	definitionJSON []byte
	// tasksJSON is the embedded content of config/tasks.json; Init hands it to
	// LoadConnectorDefinition together with definitionJSON.
	//go:embed config/tasks.json
	tasksJSON []byte
	// once guards the single construction of con inside Init.
	once      sync.Once
	// con is the connector instance built on the first Init call and returned
	// by every call.
	con       *connector
)
    49  
// connector is this package's connector type. It embeds base.BaseConnector,
// which Init populates with the logger and usage handler, and adds the
// CreateExecution and Test methods.
type connector struct {
	base.BaseConnector
}
    53  
// execution is one task run created by connector.CreateExecution. It embeds
// base.BaseConnectorExecution, which carries the owning connector, the system
// variables, the connection configuration and the task name read by Execute.
type execution struct {
	base.BaseConnectorExecution
}
    57  
    58  func Init(l *zap.Logger, u base.UsageHandler) *connector {
    59  	once.Do(func() {
    60  		con = &connector{
    61  			BaseConnector: base.BaseConnector{
    62  				Logger:       l,
    63  				UsageHandler: u,
    64  			},
    65  		}
    66  		err := con.LoadConnectorDefinition(definitionJSON, tasksJSON, nil)
    67  		if err != nil {
    68  			panic(err)
    69  		}
    70  	})
    71  	return con
    72  }
    73  
    74  func (c *connector) CreateExecution(sysVars map[string]any, connection *structpb.Struct, task string) (*base.ExecutionWrapper, error) {
    75  	return &base.ExecutionWrapper{Execution: &execution{
    76  		BaseConnectorExecution: base.BaseConnectorExecution{Connector: c, SystemVariables: sysVars, Connection: connection, Task: task},
    77  	}}, nil
    78  }
    79  
    80  func getAPIKey(config *structpb.Struct) string {
    81  	return config.GetFields()["api_key"].GetStringValue()
    82  }
    83  
    84  func getBaseURL(config *structpb.Struct) string {
    85  	return config.GetFields()["base_url"].GetStringValue()
    86  }
    87  
    88  func isCustomEndpoint(config *structpb.Struct) bool {
    89  	return config.GetFields()["is_custom_endpoint"].GetBoolValue()
    90  }
    91  
    92  func wrapSliceInStruct(data []byte, key string) (*structpb.Struct, error) {
    93  	var list []any
    94  	if err := json.Unmarshal(data, &list); err != nil {
    95  		return nil, err
    96  	}
    97  
    98  	results, err := structpb.NewList(list)
    99  	if err != nil {
   100  		return nil, err
   101  	}
   102  
   103  	return &structpb.Struct{
   104  		Fields: map[string]*structpb.Value{
   105  			key: structpb.NewListValue(results),
   106  		},
   107  	}, nil
   108  }
   109  
   110  func (e *execution) Execute(inputs []*structpb.Struct) ([]*structpb.Struct, error) {
   111  	client := newClient(e.Connection, e.GetLogger())
   112  	outputs := []*structpb.Struct{}
   113  
   114  	path := "/"
   115  	if !isCustomEndpoint(e.Connection) {
   116  		path = modelsPath + inputs[0].GetFields()["model"].GetStringValue()
   117  	}
   118  
   119  	for _, input := range inputs {
   120  		switch e.Task {
   121  		case textGenerationTask:
   122  			inputStruct := TextGenerationRequest{}
   123  			if err := base.ConvertFromStructpb(input, &inputStruct); err != nil {
   124  				return nil, err
   125  			}
   126  
   127  			resp := []TextGenerationResponse{}
   128  			req := client.R().SetBody(inputStruct).SetResult(&resp)
   129  			if _, err := post(req, path); err != nil {
   130  				return nil, err
   131  			}
   132  
   133  			if len(resp) < 1 {
   134  				err := fmt.Errorf("invalid response")
   135  				return nil, errmsg.AddMessage(err, "Hugging Face didn't return any result")
   136  			}
   137  
   138  			output, err := structpb.NewStruct(map[string]any{"generated_text": resp[0].GeneratedText})
   139  			if err != nil {
   140  				return nil, err
   141  			}
   142  
   143  			outputs = append(outputs, output)
   144  		case textToImageTask:
   145  			inputStruct := TextToImageRequest{}
   146  			if err := base.ConvertFromStructpb(input, &inputStruct); err != nil {
   147  				return nil, err
   148  			}
   149  
   150  			req := client.R().SetBody(inputStruct)
   151  			resp, err := post(req, path)
   152  			if err != nil {
   153  				return nil, err
   154  			}
   155  
   156  			rawImg := base64.StdEncoding.EncodeToString(resp.Body())
   157  			output, err := structpb.NewStruct(map[string]any{
   158  				"image": fmt.Sprintf("data:image/jpeg;base64,%s", rawImg),
   159  			})
   160  			if err != nil {
   161  				return nil, err
   162  			}
   163  
   164  			outputs = append(outputs, output)
   165  		case fillMaskTask:
   166  			inputStruct := FillMaskRequest{}
   167  			if err := base.ConvertFromStructpb(input, &inputStruct); err != nil {
   168  				return nil, err
   169  			}
   170  
   171  			req := client.R().SetBody(inputStruct)
   172  			resp, err := post(req, path)
   173  			if err != nil {
   174  				return nil, err
   175  			}
   176  
   177  			output, err := wrapSliceInStruct(resp.Body(), "results")
   178  			if err != nil {
   179  				return nil, err
   180  			}
   181  
   182  			outputs = append(outputs, output)
   183  		case summarizationTask:
   184  			inputStruct := SummarizationRequest{}
   185  			if err := base.ConvertFromStructpb(input, &inputStruct); err != nil {
   186  				return nil, err
   187  			}
   188  
   189  			resp := []SummarizationResponse{}
   190  			req := client.R().SetBody(inputStruct).SetResult(&resp)
   191  			if _, err := post(req, path); err != nil {
   192  				return nil, err
   193  			}
   194  
   195  			if len(resp) < 1 {
   196  				err := fmt.Errorf("invalid response")
   197  				return nil, errmsg.AddMessage(err, "Hugging Face didn't return any result")
   198  			}
   199  
   200  			output, err := structpb.NewStruct(map[string]any{"summary_text": resp[0].SummaryText})
   201  			if err != nil {
   202  				return nil, err
   203  			}
   204  
   205  			outputs = append(outputs, output)
   206  		case textClassificationTask:
   207  			inputStruct := TextClassificationRequest{}
   208  			if err := base.ConvertFromStructpb(input, &inputStruct); err != nil {
   209  				return nil, err
   210  			}
   211  
   212  			var resp [][]any
   213  			req := client.R().SetBody(inputStruct).SetResult(&resp)
   214  			if _, err := post(req, path); err != nil {
   215  				return nil, err
   216  			}
   217  
   218  			if len(resp) < 1 {
   219  				err := fmt.Errorf("invalid response")
   220  				return nil, errmsg.AddMessage(err, "Hugging Face didn't return any result")
   221  			}
   222  
   223  			results, err := structpb.NewList(resp[0])
   224  			if err != nil {
   225  				return nil, err
   226  			}
   227  
   228  			output := &structpb.Struct{
   229  				Fields: map[string]*structpb.Value{
   230  					"results": structpb.NewListValue(results),
   231  				},
   232  			}
   233  
   234  			outputs = append(outputs, output)
   235  		case tokenClassificationTask:
   236  			inputStruct := TokenClassificationRequest{}
   237  			if err := base.ConvertFromStructpb(input, &inputStruct); err != nil {
   238  				return nil, err
   239  			}
   240  			req := client.R().SetBody(inputStruct)
   241  			resp, err := post(req, path)
   242  			if err != nil {
   243  				return nil, err
   244  			}
   245  
   246  			output, err := wrapSliceInStruct(resp.Body(), "results")
   247  			if err != nil {
   248  				return nil, err
   249  			}
   250  
   251  			outputs = append(outputs, output)
   252  		case translationTask:
   253  			inputStruct := TranslationRequest{}
   254  			if err := base.ConvertFromStructpb(input, &inputStruct); err != nil {
   255  				return nil, err
   256  			}
   257  
   258  			resp := []TranslationResponse{}
   259  			req := client.R().SetBody(inputStruct).SetResult(&resp)
   260  			if _, err := post(req, path); err != nil {
   261  				return nil, err
   262  			}
   263  
   264  			if len(resp) < 1 {
   265  				err := fmt.Errorf("invalid response")
   266  				return nil, errmsg.AddMessage(err, "Hugging Face didn't return any result")
   267  			}
   268  
   269  			output, err := structpb.NewStruct(map[string]any{"translation_text": resp[0].TranslationText})
   270  			if err != nil {
   271  				return nil, err
   272  			}
   273  
   274  			outputs = append(outputs, output)
   275  		case zeroShotClassificationTask:
   276  			inputStruct := ZeroShotRequest{}
   277  			if err := base.ConvertFromStructpb(input, &inputStruct); err != nil {
   278  				return nil, err
   279  			}
   280  
   281  			req := client.R().SetBody(inputStruct)
   282  			resp, err := post(req, path)
   283  			if err != nil {
   284  				return nil, err
   285  			}
   286  
   287  			var output structpb.Struct
   288  			if err = protojson.Unmarshal(resp.Body(), &output); err != nil {
   289  				return nil, err
   290  			}
   291  
   292  			outputs = append(outputs, &output)
   293  		// case featureExtractionTask:
   294  		// TODO: fix this task
   295  		// 	inputStruct := FeatureExtractionRequest{}
   296  		// 	err := base.ConvertFromStructpb(input, &inputStruct)
   297  		// 	if err != nil {
   298  		// 		return nil, err
   299  		// 	}
   300  		// 	jsonBody, _ := json.Marshal(inputStruct)
   301  		// 	resp, err := doer.MakeHFAPIRequest(jsonBody, model)
   302  		// 	if err != nil {
   303  		// 		return nil, err
   304  		// 	}
   305  		// 	threeDArr := [][][]float64{}
   306  		// 	err = json.Unmarshal(resp, &threeDArr)
   307  		// 	if err != nil {
   308  		// 		return nil, err
   309  		// 	}
   310  		// 	if len(threeDArr) <= 0 {
   311  		// 		return nil, errors.New("invalid response")
   312  		// 	}
   313  		// 	nestedArr := threeDArr[0]
   314  		// 	features := structpb.ListValue{}
   315  		// 	features.Values = make([]*structpb.Value, len(nestedArr))
   316  		// 	for i, innerArr := range nestedArr {
   317  		// 		innerValues := make([]*structpb.Value, len(innerArr))
   318  		// 		for j := range innerArr {
   319  		// 			innerValues[j] = &structpb.Value{Kind: &structpb.Value_NumberValue{NumberValue: innerArr[j]}}
   320  		// 		}
   321  		// 		features.Values[i] = &structpb.Value{Kind: &structpb.Value_ListValue{ListValue: &structpb.ListValue{Values: innerValues}}}
   322  		// 	}
   323  		// 	output := structpb.Struct{
   324  		// 		Fields: map[string]*structpb.Value{"feature": {Kind: &structpb.Value_ListValue{ListValue: &features}}},
   325  		// 	}
   326  		// 	outputs = append(outputs, &output)
   327  		case questionAnsweringTask:
   328  			inputStruct := QuestionAnsweringRequest{}
   329  			if err := base.ConvertFromStructpb(input, &inputStruct); err != nil {
   330  				return nil, err
   331  			}
   332  			req := client.R().SetBody(inputStruct)
   333  			resp, err := post(req, path)
   334  			if err != nil {
   335  				return nil, err
   336  			}
   337  
   338  			var output structpb.Struct
   339  			if err = protojson.Unmarshal(resp.Body(), &output); err != nil {
   340  				return nil, err
   341  			}
   342  
   343  			outputs = append(outputs, &output)
   344  		case tableQuestionAnsweringTask:
   345  			inputStruct := TableQuestionAnsweringRequest{}
   346  			if err := base.ConvertFromStructpb(input, &inputStruct); err != nil {
   347  				return nil, err
   348  			}
   349  
   350  			req := client.R().SetBody(inputStruct)
   351  			resp, err := post(req, path)
   352  			if err != nil {
   353  				return nil, err
   354  			}
   355  
   356  			var output structpb.Struct
   357  			if err = protojson.Unmarshal(resp.Body(), &output); err != nil {
   358  				return nil, err
   359  			}
   360  
   361  			outputs = append(outputs, &output)
   362  		case sentenceSimilarityTask:
   363  			inputStruct := SentenceSimilarityRequest{}
   364  			if err := base.ConvertFromStructpb(input, &inputStruct); err != nil {
   365  				return nil, err
   366  			}
   367  
   368  			req := client.R().SetBody(inputStruct)
   369  			resp, err := post(req, path)
   370  			if err != nil {
   371  				return nil, err
   372  			}
   373  
   374  			output, err := wrapSliceInStruct(resp.Body(), "scores")
   375  			if err != nil {
   376  				return nil, err
   377  			}
   378  
   379  			outputs = append(outputs, output)
   380  		case conversationalTask:
   381  			inputStruct := ConversationalRequest{}
   382  			if err := base.ConvertFromStructpb(input, &inputStruct); err != nil {
   383  				return nil, err
   384  			}
   385  
   386  			req := client.R().SetBody(inputStruct)
   387  			resp, err := post(req, path)
   388  			if err != nil {
   389  				return nil, err
   390  			}
   391  
   392  			var output structpb.Struct
   393  			if err = protojson.Unmarshal(resp.Body(), &output); err != nil {
   394  				return nil, err
   395  			}
   396  
   397  			outputs = append(outputs, &output)
   398  		case imageClassificationTask:
   399  			inputStruct := ImageRequest{}
   400  			if err := base.ConvertFromStructpb(input, &inputStruct); err != nil {
   401  				return nil, err
   402  			}
   403  
   404  			b, err := base64.StdEncoding.DecodeString(base.TrimBase64Mime(inputStruct.Image))
   405  			if err != nil {
   406  				return nil, err
   407  			}
   408  
   409  			req := client.R().SetBody(b)
   410  			resp, err := post(req, path)
   411  			if err != nil {
   412  				return nil, err
   413  			}
   414  
   415  			output, err := wrapSliceInStruct(resp.Body(), "classes")
   416  			if err != nil {
   417  				return nil, err
   418  			}
   419  
   420  			outputs = append(outputs, output)
   421  		case imageSegmentationTask:
   422  			inputStruct := ImageRequest{}
   423  			if err := base.ConvertFromStructpb(input, &inputStruct); err != nil {
   424  				return nil, err
   425  			}
   426  
   427  			b, err := base64.StdEncoding.DecodeString(base.TrimBase64Mime(inputStruct.Image))
   428  			if err != nil {
   429  				return nil, err
   430  			}
   431  
   432  			resp := []ImageSegmentationResponse{}
   433  			req := client.R().SetBody(b).SetResult(&resp)
   434  			if _, err := post(req, path); err != nil {
   435  				return nil, err
   436  			}
   437  
   438  			segments := &structpb.ListValue{
   439  				Values: make([]*structpb.Value, len(resp)),
   440  			}
   441  
   442  			for i := range resp {
   443  				segment, err := structpb.NewStruct(map[string]any{
   444  					"score": resp[i].Score,
   445  					"label": resp[i].Label,
   446  					"mask":  fmt.Sprintf("data:image/png;base64,%s", resp[i].Mask),
   447  				})
   448  
   449  				if err != nil {
   450  					return nil, err
   451  				}
   452  
   453  				segments.Values[i] = structpb.NewStructValue(segment)
   454  			}
   455  
   456  			output := &structpb.Struct{
   457  				Fields: map[string]*structpb.Value{
   458  					"segments": structpb.NewListValue(segments),
   459  				},
   460  			}
   461  
   462  			outputs = append(outputs, output)
   463  		case objectDetectionTask:
   464  			inputStruct := ImageRequest{}
   465  			if err := base.ConvertFromStructpb(input, &inputStruct); err != nil {
   466  				return nil, err
   467  			}
   468  
   469  			b, err := base64.StdEncoding.DecodeString(base.TrimBase64Mime(inputStruct.Image))
   470  			if err != nil {
   471  				return nil, err
   472  			}
   473  
   474  			req := client.R().SetBody(b)
   475  			resp, err := post(req, path)
   476  			if err != nil {
   477  				return nil, err
   478  			}
   479  
   480  			output, err := wrapSliceInStruct(resp.Body(), "objects")
   481  			if err != nil {
   482  				return nil, err
   483  			}
   484  
   485  			outputs = append(outputs, output)
   486  		case imageToTextTask:
   487  			inputStruct := ImageRequest{}
   488  			if err := base.ConvertFromStructpb(input, &inputStruct); err != nil {
   489  				return nil, err
   490  			}
   491  
   492  			b, err := base64.StdEncoding.DecodeString(base.TrimBase64Mime(inputStruct.Image))
   493  			if err != nil {
   494  				return nil, err
   495  			}
   496  
   497  			resp := []ImageToTextResponse{}
   498  			req := client.R().SetBody(b).SetResult(&resp)
   499  			if _, err := post(req, path); err != nil {
   500  				return nil, err
   501  			}
   502  
   503  			if len(resp) < 1 {
   504  				err := fmt.Errorf("invalid response")
   505  				return nil, errmsg.AddMessage(err, "Hugging Face didn't return any result")
   506  			}
   507  
   508  			output, err := structpb.NewStruct(map[string]any{"text": resp[0].GeneratedText})
   509  			if err != nil {
   510  				return nil, err
   511  			}
   512  
   513  			outputs = append(outputs, output)
   514  		case speechRecognitionTask:
   515  			inputStruct := AudioRequest{}
   516  			if err := base.ConvertFromStructpb(input, &inputStruct); err != nil {
   517  				return nil, err
   518  			}
   519  
   520  			b, err := base64.StdEncoding.DecodeString(base.TrimBase64Mime(inputStruct.Audio))
   521  			if err != nil {
   522  				return nil, err
   523  			}
   524  
   525  			req := client.R().SetBody(b)
   526  			resp, err := post(req, path)
   527  			if err != nil {
   528  				return nil, err
   529  			}
   530  
   531  			output := new(structpb.Struct)
   532  			if err := protojson.Unmarshal(resp.Body(), output); err != nil {
   533  				return nil, err
   534  			}
   535  
   536  			outputs = append(outputs, output)
   537  		case audioClassificationTask:
   538  			inputStruct := AudioRequest{}
   539  			if err := base.ConvertFromStructpb(input, &inputStruct); err != nil {
   540  				return nil, err
   541  			}
   542  
   543  			b, err := base64.StdEncoding.DecodeString(base.TrimBase64Mime(inputStruct.Audio))
   544  			if err != nil {
   545  				return nil, err
   546  			}
   547  
   548  			req := client.R().SetBody(b)
   549  			resp, err := post(req, path)
   550  			if err != nil {
   551  				return nil, err
   552  			}
   553  
   554  			output, err := wrapSliceInStruct(resp.Body(), "classes")
   555  			if err != nil {
   556  				return nil, err
   557  			}
   558  
   559  			outputs = append(outputs, output)
   560  		default:
   561  			return nil, errmsg.AddMessage(
   562  				fmt.Errorf("not supported task: %s", e.Task),
   563  				fmt.Sprintf("%s task is not supported.", e.Task),
   564  			)
   565  		}
   566  	}
   567  
   568  	return outputs, nil
   569  }
   570  
   571  func (c *connector) Test(sysVars map[string]any, connection *structpb.Struct) error {
   572  	req := newClient(connection, c.Logger).R()
   573  	resp, err := req.Get("")
   574  	if err != nil {
   575  		return err
   576  	}
   577  
   578  	if resp.IsError() {
   579  		return fmt.Errorf("connection error")
   580  	}
   581  
   582  	return nil
   583  }