github.com/instill-ai/component@v0.16.0-beta/pkg/connector/openai/v0/connector_test.go (about) 1 package openai 2 3 import ( 4 "fmt" 5 "net/http" 6 "net/http/httptest" 7 "testing" 8 9 qt "github.com/frankban/quicktest" 10 "go.uber.org/zap" 11 "google.golang.org/protobuf/types/known/structpb" 12 13 "github.com/instill-ai/component/pkg/connector/util/httpclient" 14 "github.com/instill-ai/x/errmsg" 15 ) 16 17 const ( 18 apiKey = "123" 19 org = "org1" 20 errResp = ` 21 { 22 "error": { 23 "message": "Incorrect API key provided." 24 } 25 }` 26 ) 27 28 func TestConnector_Execute(t *testing.T) { 29 c := qt.New(t) 30 31 logger := zap.NewNop() 32 connector := Init(logger, nil) 33 34 testcases := []struct { 35 name string 36 task string 37 path string 38 contentType string 39 }{ 40 { 41 name: "text generation", 42 task: textGenerationTask, 43 path: completionsPath, 44 contentType: httpclient.MIMETypeJSON, 45 }, 46 { 47 name: "text embeddings", 48 task: textEmbeddingsTask, 49 path: embeddingsPath, 50 contentType: httpclient.MIMETypeJSON, 51 }, 52 { 53 name: "speech recognition", 54 task: speechRecognitionTask, 55 path: transcriptionsPath, 56 contentType: "multipart/form-data; boundary=.*", 57 }, 58 { 59 name: "text to speech", 60 task: textToSpeechTask, 61 path: createSpeechPath, 62 contentType: httpclient.MIMETypeJSON, 63 }, 64 { 65 name: "text to image", 66 task: textToImageTask, 67 path: imgGenerationPath, 68 contentType: httpclient.MIMETypeJSON, 69 }, 70 } 71 72 // TODO we'll likely want to have a test function per task and test at 73 // least OK, NOK. For now, only errors are tested in order to verify 74 // end-user messages. 75 for _, tc := range testcases { 76 c.Run("nok - "+tc.name+" 401", func(c *qt.C) { 77 h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 78 c.Check(r.Method, qt.Equals, http.MethodPost) 79 c.Check(r.URL.Path, qt.Equals, tc.path) 80 81 c.Check(r.Header.Get("OpenAI-Organization"), qt.Equals, org) 82 c.Check(r.Header.Get("Authorization"), qt.Equals, "Bearer "+apiKey) 83 84 c.Check(r.Header.Get("Content-Type"), qt.Matches, tc.contentType) 85 86 w.Header().Set("Content-Type", httpclient.MIMETypeJSON) 87 w.WriteHeader(http.StatusUnauthorized) 88 fmt.Fprintln(w, errResp) 89 }) 90 91 openAIServer := httptest.NewServer(h) 92 c.Cleanup(openAIServer.Close) 93 94 connection, err := structpb.NewStruct(map[string]any{ 95 "base_path": openAIServer.URL, 96 "api_key": apiKey, 97 "organization": org, 98 }) 99 c.Assert(err, qt.IsNil) 100 101 exec, err := connector.CreateExecution(nil, connection, tc.task) 102 c.Assert(err, qt.IsNil) 103 104 pbIn := new(structpb.Struct) 105 _, err = exec.Execution.Execute([]*structpb.Struct{pbIn}) 106 c.Check(err, qt.IsNotNil) 107 108 want := "OpenAI responded with a 401 status code. Incorrect API key provided." 109 c.Check(errmsg.Message(err), qt.Equals, want) 110 }) 111 } 112 113 c.Run("nok - unsupported task", func(c *qt.C) { 114 task := "FOOBAR" 115 exec, err := connector.CreateExecution(nil, new(structpb.Struct), task) 116 c.Assert(err, qt.IsNil) 117 118 pbIn := new(structpb.Struct) 119 _, err = exec.Execution.Execute([]*structpb.Struct{pbIn}) 120 c.Check(err, qt.IsNotNil) 121 122 want := "FOOBAR task is not supported." 123 c.Check(errmsg.Message(err), qt.Equals, want) 124 }) 125 } 126 127 func TestConnector_Test(t *testing.T) { 128 c := qt.New(t) 129 130 logger := zap.NewNop() 131 connector := Init(logger, nil) 132 133 c.Run("nok - error", func(c *qt.C) { 134 h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 135 c.Check(r.Method, qt.Equals, http.MethodGet) 136 c.Check(r.URL.Path, qt.Equals, listModelsPath) 137 138 w.Header().Set("Content-Type", httpclient.MIMETypeJSON) 139 w.WriteHeader(http.StatusUnauthorized) 140 fmt.Fprintln(w, errResp) 141 }) 142 143 openAIServer := httptest.NewServer(h) 144 c.Cleanup(openAIServer.Close) 145 146 connection, err := structpb.NewStruct(map[string]any{ 147 "base_path": openAIServer.URL, 148 }) 149 c.Assert(err, qt.IsNil) 150 151 err = connector.Test(nil, connection) 152 c.Check(err, qt.IsNotNil) 153 154 wantMsg := "OpenAI responded with a 401 status code. Incorrect API key provided." 155 c.Check(errmsg.Message(err), qt.Equals, wantMsg) 156 }) 157 158 c.Run("ok - disconnected", func(c *qt.C) { 159 h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 160 c.Check(r.Method, qt.Equals, http.MethodGet) 161 c.Check(r.URL.Path, qt.Equals, listModelsPath) 162 163 w.Header().Set("Content-Type", httpclient.MIMETypeJSON) 164 fmt.Fprintln(w, `{}`) 165 }) 166 167 openAIServer := httptest.NewServer(h) 168 c.Cleanup(openAIServer.Close) 169 170 connection, err := structpb.NewStruct(map[string]any{ 171 "base_path": openAIServer.URL, 172 }) 173 c.Assert(err, qt.IsNil) 174 175 err = connector.Test(nil, connection) 176 c.Check(err, qt.IsNotNil) 177 }) 178 179 c.Run("ok - connected", func(c *qt.C) { 180 h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 181 c.Check(r.Method, qt.Equals, http.MethodGet) 182 c.Check(r.URL.Path, qt.Equals, listModelsPath) 183 184 w.Header().Set("Content-Type", httpclient.MIMETypeJSON) 185 fmt.Fprintln(w, `{"data": [{}]}`) 186 }) 187 188 openAIServer := httptest.NewServer(h) 189 c.Cleanup(openAIServer.Close) 190 191 connection, err := structpb.NewStruct(map[string]any{ 192 "base_path": openAIServer.URL, 193 }) 194 c.Assert(err, qt.IsNil) 195 196 err = connector.Test(nil, connection) 197 c.Check(err, qt.IsNil) 198 }) 199 }