github.com/instill-ai/component@v0.16.0-beta/pkg/connector/openai/v0/connector_test.go (about)

     1  package openai
     2  
     3  import (
     4  	"fmt"
     5  	"net/http"
     6  	"net/http/httptest"
     7  	"testing"
     8  
     9  	qt "github.com/frankban/quicktest"
    10  	"go.uber.org/zap"
    11  	"google.golang.org/protobuf/types/known/structpb"
    12  
    13  	"github.com/instill-ai/component/pkg/connector/util/httpclient"
    14  	"github.com/instill-ai/x/errmsg"
    15  )
    16  
    17  const (
    18  	apiKey  = "123"
    19  	org     = "org1"
    20  	errResp = `
    21  {
    22    "error": {
    23      "message": "Incorrect API key provided."
    24    }
    25  }`
    26  )
    27  
    28  func TestConnector_Execute(t *testing.T) {
    29  	c := qt.New(t)
    30  
    31  	logger := zap.NewNop()
    32  	connector := Init(logger, nil)
    33  
    34  	testcases := []struct {
    35  		name        string
    36  		task        string
    37  		path        string
    38  		contentType string
    39  	}{
    40  		{
    41  			name:        "text generation",
    42  			task:        textGenerationTask,
    43  			path:        completionsPath,
    44  			contentType: httpclient.MIMETypeJSON,
    45  		},
    46  		{
    47  			name:        "text embeddings",
    48  			task:        textEmbeddingsTask,
    49  			path:        embeddingsPath,
    50  			contentType: httpclient.MIMETypeJSON,
    51  		},
    52  		{
    53  			name:        "speech recognition",
    54  			task:        speechRecognitionTask,
    55  			path:        transcriptionsPath,
    56  			contentType: "multipart/form-data; boundary=.*",
    57  		},
    58  		{
    59  			name:        "text to speech",
    60  			task:        textToSpeechTask,
    61  			path:        createSpeechPath,
    62  			contentType: httpclient.MIMETypeJSON,
    63  		},
    64  		{
    65  			name:        "text to image",
    66  			task:        textToImageTask,
    67  			path:        imgGenerationPath,
    68  			contentType: httpclient.MIMETypeJSON,
    69  		},
    70  	}
    71  
    72  	// TODO we'll likely want to have a test function per task and test at
    73  	// least OK, NOK. For now, only errors are tested in order to verify
    74  	// end-user messages.
    75  	for _, tc := range testcases {
    76  		c.Run("nok - "+tc.name+" 401", func(c *qt.C) {
    77  			h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    78  				c.Check(r.Method, qt.Equals, http.MethodPost)
    79  				c.Check(r.URL.Path, qt.Equals, tc.path)
    80  
    81  				c.Check(r.Header.Get("OpenAI-Organization"), qt.Equals, org)
    82  				c.Check(r.Header.Get("Authorization"), qt.Equals, "Bearer "+apiKey)
    83  
    84  				c.Check(r.Header.Get("Content-Type"), qt.Matches, tc.contentType)
    85  
    86  				w.Header().Set("Content-Type", httpclient.MIMETypeJSON)
    87  				w.WriteHeader(http.StatusUnauthorized)
    88  				fmt.Fprintln(w, errResp)
    89  			})
    90  
    91  			openAIServer := httptest.NewServer(h)
    92  			c.Cleanup(openAIServer.Close)
    93  
    94  			connection, err := structpb.NewStruct(map[string]any{
    95  				"base_path":    openAIServer.URL,
    96  				"api_key":      apiKey,
    97  				"organization": org,
    98  			})
    99  			c.Assert(err, qt.IsNil)
   100  
   101  			exec, err := connector.CreateExecution(nil, connection, tc.task)
   102  			c.Assert(err, qt.IsNil)
   103  
   104  			pbIn := new(structpb.Struct)
   105  			_, err = exec.Execution.Execute([]*structpb.Struct{pbIn})
   106  			c.Check(err, qt.IsNotNil)
   107  
   108  			want := "OpenAI responded with a 401 status code. Incorrect API key provided."
   109  			c.Check(errmsg.Message(err), qt.Equals, want)
   110  		})
   111  	}
   112  
   113  	c.Run("nok - unsupported task", func(c *qt.C) {
   114  		task := "FOOBAR"
   115  		exec, err := connector.CreateExecution(nil, new(structpb.Struct), task)
   116  		c.Assert(err, qt.IsNil)
   117  
   118  		pbIn := new(structpb.Struct)
   119  		_, err = exec.Execution.Execute([]*structpb.Struct{pbIn})
   120  		c.Check(err, qt.IsNotNil)
   121  
   122  		want := "FOOBAR task is not supported."
   123  		c.Check(errmsg.Message(err), qt.Equals, want)
   124  	})
   125  }
   126  
   127  func TestConnector_Test(t *testing.T) {
   128  	c := qt.New(t)
   129  
   130  	logger := zap.NewNop()
   131  	connector := Init(logger, nil)
   132  
   133  	c.Run("nok - error", func(c *qt.C) {
   134  		h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   135  			c.Check(r.Method, qt.Equals, http.MethodGet)
   136  			c.Check(r.URL.Path, qt.Equals, listModelsPath)
   137  
   138  			w.Header().Set("Content-Type", httpclient.MIMETypeJSON)
   139  			w.WriteHeader(http.StatusUnauthorized)
   140  			fmt.Fprintln(w, errResp)
   141  		})
   142  
   143  		openAIServer := httptest.NewServer(h)
   144  		c.Cleanup(openAIServer.Close)
   145  
   146  		connection, err := structpb.NewStruct(map[string]any{
   147  			"base_path": openAIServer.URL,
   148  		})
   149  		c.Assert(err, qt.IsNil)
   150  
   151  		err = connector.Test(nil, connection)
   152  		c.Check(err, qt.IsNotNil)
   153  
   154  		wantMsg := "OpenAI responded with a 401 status code. Incorrect API key provided."
   155  		c.Check(errmsg.Message(err), qt.Equals, wantMsg)
   156  	})
   157  
   158  	c.Run("ok - disconnected", func(c *qt.C) {
   159  		h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   160  			c.Check(r.Method, qt.Equals, http.MethodGet)
   161  			c.Check(r.URL.Path, qt.Equals, listModelsPath)
   162  
   163  			w.Header().Set("Content-Type", httpclient.MIMETypeJSON)
   164  			fmt.Fprintln(w, `{}`)
   165  		})
   166  
   167  		openAIServer := httptest.NewServer(h)
   168  		c.Cleanup(openAIServer.Close)
   169  
   170  		connection, err := structpb.NewStruct(map[string]any{
   171  			"base_path": openAIServer.URL,
   172  		})
   173  		c.Assert(err, qt.IsNil)
   174  
   175  		err = connector.Test(nil, connection)
   176  		c.Check(err, qt.IsNotNil)
   177  	})
   178  
   179  	c.Run("ok - connected", func(c *qt.C) {
   180  		h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   181  			c.Check(r.Method, qt.Equals, http.MethodGet)
   182  			c.Check(r.URL.Path, qt.Equals, listModelsPath)
   183  
   184  			w.Header().Set("Content-Type", httpclient.MIMETypeJSON)
   185  			fmt.Fprintln(w, `{"data": [{}]}`)
   186  		})
   187  
   188  		openAIServer := httptest.NewServer(h)
   189  		c.Cleanup(openAIServer.Close)
   190  
   191  		connection, err := structpb.NewStruct(map[string]any{
   192  			"base_path": openAIServer.URL,
   193  		})
   194  		c.Assert(err, qt.IsNil)
   195  
   196  		err = connector.Test(nil, connection)
   197  		c.Check(err, qt.IsNil)
   198  	})
   199  }