github.com/instill-ai/component@v0.16.0-beta/pkg/connector/huggingface/v0/connector_test.go (about)

     1  package huggingface
     2  
     3  import (
     4  	"encoding/base64"
     5  	"fmt"
     6  	"io"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"testing"
    10  
    11  	qt "github.com/frankban/quicktest"
    12  	"go.uber.org/zap"
    13  	"google.golang.org/protobuf/types/known/structpb"
    14  
    15  	"github.com/instill-ai/component/pkg/base"
    16  	"github.com/instill-ai/component/pkg/connector/util/httpclient"
    17  	"github.com/instill-ai/x/errmsg"
    18  )
    19  
    20  const (
    21  	apiKey = "123"
    22  	model  = "openai/whisper-tiny"
    23  
    24  	testInput = "testing generation"
    25  
    26  	fillMaskResp            = `[{"score": 0.234, "token": 3, "sequence": "one", "token_str": "three"}]`
    27  	classificationResp      = `[{"score": 0.123, "label": "backpack hip-hop"}, {"score": 0.894, "label": "lo-fi jazz"}]`
    28  	tokenClassificationResp = `[{"entity_group":"foo", "score": 0.234, "start": 0, "end": 5, "word": "bar"}]`
    29  	objDetectionResp        = `
    30  [
    31    {
    32  	"score": 0.123,
    33  	"label": "backpack hip-hop",
    34  	"box": {
    35  	  "xmin": 0,
    36  	  "xmax": 1,
    37  	  "ymin": 0,
    38  	  "ymax": 1
    39  	}
    40    }
    41  ]`
    42  
    43  	errorResp  = ` { "error": "Invalid request" }`
    44  	errorsResp = ` { "error": ["Temporarily unavailable", "Too many requests"] }`
    45  )
    46  
    47  var (
    48  	bRaw     = []byte("aaa")
    49  	bEncoded = base64.StdEncoding.EncodeToString(bRaw)
    50  
    51  	inputsBody = []byte(`{"inputs": "testing generation"}`)
    52  )
    53  
    54  type taskParams struct {
    55  	task        string
    56  	input       any
    57  	contentType string // content type received in Hugging Face
    58  	wantBody    []byte // expected request body in Hugging Face
    59  	okResp      string // successful response from Hugging Face
    60  	wantResp    string // successful response from connector
    61  }
    62  
    63  func wrapArrayInObject(array, key string) string {
    64  	return fmt.Sprintf(`{"%s": %s}`, key, array)
    65  }
    66  
    67  var coveredTasks = []taskParams{
    68  	{
    69  		task:        textGenerationTask,
    70  		input:       TextGenerationRequest{Inputs: testInput},
    71  		contentType: httpclient.MIMETypeJSON,
    72  		wantBody:    inputsBody,
    73  		okResp:      `[{"generated_text": "text response"}]`,
    74  		wantResp:    `{"generated_text": "text response"}`,
    75  	},
    76  	{
    77  		task:        textToImageTask,
    78  		input:       TextToImageRequest{Inputs: testInput},
    79  		contentType: httpclient.MIMETypeJSON,
    80  		wantBody:    inputsBody,
    81  		okResp:      string(bRaw),
    82  		wantResp:    fmt.Sprintf(`{"image": "data:image/jpeg;base64,%s"}`, bEncoded),
    83  	},
    84  	{
    85  		task:        fillMaskTask,
    86  		input:       FillMaskRequest{Inputs: testInput},
    87  		contentType: httpclient.MIMETypeJSON,
    88  		wantBody:    inputsBody,
    89  		okResp:      fillMaskResp,
    90  		wantResp:    wrapArrayInObject(fillMaskResp, "results"),
    91  	},
    92  	{
    93  		task:        summarizationTask,
    94  		input:       SummarizationRequest{Inputs: testInput},
    95  		contentType: httpclient.MIMETypeJSON,
    96  		wantBody:    inputsBody,
    97  		okResp:      `[{"summary_text": "summary"}]`,
    98  		wantResp:    `{"summary_text": "summary"}`,
    99  	},
   100  	{
   101  		task:        textClassificationTask,
   102  		input:       TextClassificationRequest{Inputs: testInput},
   103  		contentType: httpclient.MIMETypeJSON,
   104  		wantBody:    inputsBody,
   105  		okResp:      "[" + classificationResp + "]",
   106  		wantResp:    wrapArrayInObject(classificationResp, "results"),
   107  	},
   108  	{
   109  		task:        tokenClassificationTask,
   110  		input:       TokenClassificationRequest{Inputs: testInput},
   111  		contentType: httpclient.MIMETypeJSON,
   112  		wantBody:    inputsBody,
   113  		okResp:      tokenClassificationResp,
   114  		wantResp:    wrapArrayInObject(tokenClassificationResp, "results"),
   115  	},
   116  	{
   117  		task:        translationTask,
   118  		input:       TranslationRequest{Inputs: testInput},
   119  		contentType: httpclient.MIMETypeJSON,
   120  		wantBody:    inputsBody,
   121  		okResp:      `[{"translation_text": "translated"}]`,
   122  		wantResp:    `{"translation_text": "translated"}`,
   123  	},
   124  	{
   125  		task:        zeroShotClassificationTask,
   126  		input:       ZeroShotRequest{Inputs: testInput},
   127  		contentType: httpclient.MIMETypeJSON,
   128  		wantBody:    inputsBody,
   129  		okResp:      `{"sequence": "seq"}`,
   130  		wantResp:    `{"sequence": "seq"}`,
   131  	},
   132  	{
   133  		task:        questionAnsweringTask,
   134  		input:       QuestionAnsweringRequest{Inputs: QuestionAnsweringInputs{Question: "isn't it?"}},
   135  		contentType: httpclient.MIMETypeJSON,
   136  		wantBody:    []byte(`{"inputs": {"question": "is it?"}}`),
   137  		okResp:      `{"answer": "it is"}`,
   138  		wantResp:    `{"answer": "it is"}`,
   139  	},
   140  	{
   141  		task:        tableQuestionAnsweringTask,
   142  		input:       TableQuestionAnsweringRequest{Inputs: TableQuestionAnsweringInputs{Query: "yes?"}},
   143  		contentType: httpclient.MIMETypeJSON,
   144  		wantBody:    []byte(`{"inputs": {"query": "yes?"}}`),
   145  		okResp:      `{"answer": "yes"}`,
   146  		wantResp:    `{"answer": "yes"}`,
   147  	},
   148  	{
   149  		task:        sentenceSimilarityTask,
   150  		input:       SentenceSimilarityRequest{Inputs: SentenceSimilarityInputs{}},
   151  		contentType: httpclient.MIMETypeJSON,
   152  		wantBody:    []byte(`{"inputs": {}}`),
   153  		okResp:      `[0.23]`,
   154  		wantResp:    wrapArrayInObject(`[0.23]`, "scores"),
   155  	},
   156  	{
   157  		task:        conversationalTask,
   158  		input:       ConversationalRequest{Inputs: ConversationalInputs{}},
   159  		contentType: httpclient.MIMETypeJSON,
   160  		wantBody:    []byte(`{"inputs": {}}`),
   161  		okResp:      `{"generated_text": "gen"}`,
   162  		wantResp:    `{"generated_text": "gen"}`,
   163  	},
   164  	{
   165  		task:        imageClassificationTask,
   166  		input:       ImageRequest{Image: bEncoded},
   167  		contentType: "text/plain.*",
   168  		wantBody:    bRaw,
   169  		okResp:      classificationResp,
   170  		wantResp:    wrapArrayInObject(classificationResp, "classes"),
   171  	},
   172  	{
   173  		task:        imageSegmentationTask,
   174  		input:       ImageRequest{Image: bEncoded},
   175  		contentType: "text/plain.*",
   176  		wantBody:    bRaw,
   177  		okResp:      `[{"score": 0.123, "label": "backpack hip-hop", "mask": "YBcsSdfg"}]`,
   178  		wantResp:    `{"segments": [{"score": 0.123, "label": "backpack hip-hop", "mask": "data:image/png;base64,YBcsSdfg"}]}`,
   179  	},
   180  	{
   181  		task:        objectDetectionTask,
   182  		input:       ImageRequest{Image: bEncoded},
   183  		contentType: "text/plain.*",
   184  		wantBody:    bRaw,
   185  		okResp:      objDetectionResp,
   186  		wantResp:    wrapArrayInObject(objDetectionResp, "objects"),
   187  	},
   188  	{
   189  		task:        imageToTextTask,
   190  		input:       ImageRequest{Image: bEncoded},
   191  		contentType: "text/plain.*",
   192  		wantBody:    bRaw,
   193  		okResp:      `[{"generated_text": "Me robaron mi runa mula"}]`,
   194  		wantResp:    `{"text": "Me robaron mi runa mula"}`,
   195  	},
   196  	{
   197  		task:        speechRecognitionTask,
   198  		input:       AudioRequest{Audio: bEncoded},
   199  		contentType: "text/plain.*",
   200  		wantBody:    bRaw,
   201  		okResp:      `{"text": "Me robaron mi runa mula"}`,
   202  		wantResp:    `{"text": "Me robaron mi runa mula"}`,
   203  	},
   204  	{
   205  		task:        audioClassificationTask,
   206  		input:       AudioRequest{Audio: bEncoded},
   207  		contentType: "text/plain.*",
   208  		wantBody:    bRaw,
   209  		okResp:      classificationResp,
   210  		wantResp:    wrapArrayInObject(classificationResp, "classes"),
   211  	},
   212  }
   213  
   214  func TestConnector_ExecuteSpeechRecognition(t *testing.T) {
   215  	c := qt.New(t)
   216  
   217  	for _, params := range coveredTasks {
   218  		testTask(c, params)
   219  	}
   220  }
   221  
   222  func testTask(c *qt.C, p taskParams) {
   223  	logger := zap.NewNop()
   224  	connector := Init(logger, nil)
   225  
   226  	c.Run("nok - HTTP client error - "+p.task, func(c *qt.C) {
   227  		c.Parallel()
   228  
   229  		connection, err := structpb.NewStruct(map[string]any{
   230  			"base_url": "http://no-such.host",
   231  		})
   232  		c.Assert(err, qt.IsNil)
   233  
   234  		exec, err := connector.CreateExecution(nil, connection, p.task)
   235  		c.Assert(err, qt.IsNil)
   236  
   237  		pbIn, err := base.ConvertToStructpb(p.input)
   238  		c.Assert(err, qt.IsNil)
   239  		pbIn.Fields["model"] = structpb.NewStringValue(model)
   240  
   241  		_, err = exec.Execution.Execute([]*structpb.Struct{pbIn})
   242  		c.Check(err, qt.IsNotNil)
   243  		c.Check(err, qt.ErrorMatches, ".*no such host")
   244  		c.Check(errmsg.Message(err), qt.Matches, "Failed to call .*check that the connector configuration is correct.")
   245  	})
   246  
   247  	testcases := []struct {
   248  		name           string
   249  		customEndpoint bool
   250  		httpStatus     int
   251  		httpBody       string
   252  		wantErr        string
   253  	}{
   254  		{
   255  			name:       "ok",
   256  			httpStatus: http.StatusOK,
   257  			httpBody:   p.okResp,
   258  		},
   259  		{
   260  			name:       "nok - API error",
   261  			httpStatus: http.StatusBadRequest,
   262  			httpBody:   errorResp,
   263  			wantErr:    "Hugging Face responded with a 400 status code. Invalid request",
   264  		},
   265  		{
   266  			name:           "nok - API errors",
   267  			customEndpoint: true,
   268  			httpStatus:     http.StatusTooManyRequests,
   269  			httpBody:       errorsResp,
   270  			wantErr:        "Hugging Face responded with a 429 status code. [Temporarily unavailable, Too many requests]",
   271  		},
   272  	}
   273  
   274  	for _, tc := range testcases {
   275  		tc := tc
   276  		c.Run(tc.name+" _ "+p.task, func(c *qt.C) {
   277  			c.Parallel()
   278  
   279  			wantPath := modelsPath + model
   280  			if tc.customEndpoint {
   281  				wantPath = "/"
   282  			}
   283  
   284  			h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   285  				c.Check(r.Method, qt.Equals, http.MethodPost)
   286  				c.Check(r.URL.Path, qt.Matches, wantPath)
   287  
   288  				c.Check(r.Header.Get("Authorization"), qt.Equals, "Bearer "+apiKey)
   289  
   290  				ct := r.Header.Get("Content-Type")
   291  				c.Check(ct, qt.Matches, p.contentType)
   292  
   293  				c.Assert(r.Body, qt.IsNotNil)
   294  				defer r.Body.Close()
   295  
   296  				body, err := io.ReadAll(r.Body)
   297  				c.Assert(err, qt.IsNil)
   298  				if ct == httpclient.MIMETypeJSON {
   299  					// If we have a case where we don't pass the input request,
   300  					// we can check if p.wantBody is not empty and then do
   301  					// c.Check(body, qt.JSONEquals, json.RawMessage(p.wantBody)
   302  
   303  					c.Check(body, qt.JSONEquals, p.input)
   304  				} else {
   305  					c.Check(body, qt.ContentEquals, p.wantBody)
   306  				}
   307  
   308  				w.Header().Set("Content-Type", httpclient.MIMETypeJSON)
   309  				w.WriteHeader(tc.httpStatus)
   310  				fmt.Fprint(w, tc.httpBody)
   311  			})
   312  
   313  			srv := httptest.NewServer(h)
   314  			c.Cleanup(srv.Close)
   315  
   316  			connection, _ := structpb.NewStruct(map[string]any{
   317  				"api_key":            apiKey,
   318  				"base_url":           srv.URL,
   319  				"is_custom_endpoint": tc.customEndpoint,
   320  			})
   321  
   322  			exec, err := connector.CreateExecution(nil, connection, p.task)
   323  			c.Assert(err, qt.IsNil)
   324  
   325  			pbIn, err := base.ConvertToStructpb(p.input)
   326  			c.Assert(err, qt.IsNil)
   327  			pbIn.Fields["model"] = structpb.NewStringValue(model)
   328  
   329  			got, err := exec.Execution.Execute([]*structpb.Struct{pbIn})
   330  			if tc.wantErr != "" {
   331  				c.Check(err, qt.IsNotNil)
   332  				c.Check(errmsg.Message(err), qt.Equals, tc.wantErr)
   333  				return
   334  			}
   335  
   336  			c.Check(err, qt.IsNil)
   337  
   338  			c.Assert(got, qt.HasLen, 1)
   339  			c.Check(p.wantResp, qt.JSONEquals, got[0].AsMap())
   340  		})
   341  	}
   342  }