github.com/instill-ai/component@v0.16.0-beta/pkg/connector/stabilityai/v0/connector_test.go (about) 1 package stabilityai 2 3 import ( 4 "encoding/json" 5 "fmt" 6 "net/http" 7 "net/http/httptest" 8 "testing" 9 10 qt "github.com/frankban/quicktest" 11 "go.uber.org/zap" 12 "google.golang.org/protobuf/types/known/structpb" 13 14 "github.com/instill-ai/component/pkg/base" 15 "github.com/instill-ai/component/pkg/connector/util/httpclient" 16 "github.com/instill-ai/x/errmsg" 17 ) 18 19 const ( 20 apiKey = "123" 21 errResp = ` 22 { 23 "id": "6e958442e7911ffb2e0bf89c6efe804f", 24 "message": "Incorrect API key provided", 25 "name": "unauthorized" 26 }` 27 28 okResp = ` 29 { 30 "artifacts": [ 31 { 32 "base64": "a", 33 "seed": 1234, 34 "finishReason": "SUCCESS" 35 } 36 ] 37 } 38 ` 39 ) 40 41 func TestConnector_ExecuteImageFromText(t *testing.T) { 42 c := qt.New(t) 43 44 weight := 0.5 45 text := "a cat and a dog" 46 engine := "engine" 47 48 logger := zap.NewNop() 49 connector := Init(logger, nil) 50 51 testcases := []struct { 52 name string 53 gotStatus int 54 gotResp string 55 wantResp TextToImageOutput 56 wantErr string 57 }{ 58 { 59 name: "ok - 200", 60 gotStatus: http.StatusOK, 61 gotResp: okResp, 62 wantResp: TextToImageOutput{ 63 Images: []string{""}, 64 Seeds: []uint32{1234}, 65 }, 66 }, 67 { 68 name: "nok - 401", 69 gotStatus: http.StatusUnauthorized, 70 gotResp: errResp, 71 wantErr: "Stability AI responded with a 401 status code. Incorrect API key provided", 72 }, 73 } 74 75 for _, tc := range testcases { 76 c.Run(tc.name, func(c *qt.C) { 77 h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 78 c.Check(r.Method, qt.Equals, http.MethodPost) 79 c.Check(r.URL.Path, qt.Matches, `/v1/generation/.*/text-to-image`) 80 81 c.Check(r.Header.Get("Authorization"), qt.Equals, "Bearer "+apiKey) 82 c.Check(r.Header.Get("Content-Type"), qt.Equals, httpclient.MIMETypeJSON) 83 84 w.Header().Set("Content-Type", httpclient.MIMETypeJSON) 85 w.WriteHeader(tc.gotStatus) 86 fmt.Fprintln(w, tc.gotResp) 87 }) 88 89 srv := httptest.NewServer(h) 90 c.Cleanup(srv.Close) 91 92 connection, err := structpb.NewStruct(map[string]any{ 93 "base_path": srv.URL, 94 "api_key": apiKey, 95 }) 96 c.Assert(err, qt.IsNil) 97 98 exec, err := connector.CreateExecution(nil, connection, textToImageTask) 99 c.Assert(err, qt.IsNil) 100 101 weights := []float64{weight} 102 pbIn, err := base.ConvertToStructpb(TextToImageInput{ 103 Engine: engine, 104 Prompts: []string{text}, 105 Weights: &weights, 106 }) 107 c.Assert(err, qt.IsNil) 108 109 got, err := exec.Execution.Execute([]*structpb.Struct{pbIn}) 110 if tc.wantErr != "" { 111 c.Check(errmsg.Message(err), qt.Equals, tc.wantErr) 112 return 113 } 114 115 wantJSON, err := json.Marshal(tc.wantResp) 116 c.Assert(err, qt.IsNil) 117 c.Check(wantJSON, qt.JSONEquals, got[0].AsMap()) 118 }) 119 } 120 121 c.Run("nok - unsupported task", func(c *qt.C) { 122 task := "FOOBAR" 123 exec, err := connector.CreateExecution(nil, new(structpb.Struct), task) 124 c.Assert(err, qt.IsNil) 125 126 pbIn := new(structpb.Struct) 127 _, err = exec.Execution.Execute([]*structpb.Struct{pbIn}) 128 c.Check(err, qt.IsNotNil) 129 130 want := "FOOBAR task is not supported." 131 c.Check(errmsg.Message(err), qt.Equals, want) 132 }) 133 } 134 135 func TestConnector_ExecuteImageFromImage(t *testing.T) { 136 c := qt.New(t) 137 138 weight := 0.5 139 text := "a cat and a dog" 140 engine := "engine" 141 142 logger := zap.NewNop() 143 connector := Init(logger, nil) 144 145 testcases := []struct { 146 name string 147 gotStatus int 148 gotResp string 149 wantResp ImageToImageOutput 150 wantErr string 151 }{ 152 { 153 name: "ok - 200", 154 gotStatus: http.StatusOK, 155 gotResp: okResp, 156 wantResp: ImageToImageOutput{ 157 Images: []string{""}, 158 Seeds: []uint32{1234}, 159 }, 160 }, 161 { 162 name: "nok - 401", 163 gotStatus: http.StatusUnauthorized, 164 gotResp: errResp, 165 wantErr: "Stability AI responded with a 401 status code. Incorrect API key provided", 166 }, 167 } 168 169 for _, tc := range testcases { 170 c.Run(tc.name, func(c *qt.C) { 171 h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 172 c.Check(r.Method, qt.Equals, http.MethodPost) 173 c.Check(r.URL.Path, qt.Matches, `/v1/generation/.*/image-to-image`) 174 175 c.Check(r.Header.Get("Authorization"), qt.Equals, "Bearer "+apiKey) 176 c.Check(r.Header.Get("Content-Type"), qt.Matches, "multipart/form-data; boundary=.*") 177 178 w.Header().Set("Content-Type", httpclient.MIMETypeJSON) 179 w.WriteHeader(tc.gotStatus) 180 fmt.Fprintln(w, tc.gotResp) 181 }) 182 183 srv := httptest.NewServer(h) 184 c.Cleanup(srv.Close) 185 186 connection, err := structpb.NewStruct(map[string]any{ 187 "base_path": srv.URL, 188 "api_key": apiKey, 189 }) 190 c.Assert(err, qt.IsNil) 191 192 exec, err := connector.CreateExecution(nil, connection, imageToImageTask) 193 c.Assert(err, qt.IsNil) 194 195 weights := []float64{weight} 196 pbIn, err := base.ConvertToStructpb(ImageToImageInput{ 197 Engine: engine, 198 Prompts: []string{text}, 199 Weights: &weights, 200 }) 201 c.Assert(err, qt.IsNil) 202 203 got, err := exec.Execution.Execute([]*structpb.Struct{pbIn}) 204 if tc.wantErr != "" { 205 c.Check(errmsg.Message(err), qt.Equals, tc.wantErr) 206 return 207 } 208 209 wantJSON, err := json.Marshal(tc.wantResp) 210 c.Assert(err, qt.IsNil) 211 c.Check(wantJSON, qt.JSONEquals, got[0].AsMap()) 212 }) 213 } 214 215 c.Run("nok - unsupported task", func(c *qt.C) { 216 task := "FOOBAR" 217 exec, err := connector.CreateExecution(nil, new(structpb.Struct), task) 218 c.Assert(err, qt.IsNil) 219 220 pbIn := new(structpb.Struct) 221 _, err = exec.Execution.Execute([]*structpb.Struct{pbIn}) 222 c.Check(err, qt.IsNotNil) 223 224 want := "FOOBAR task is not supported." 225 c.Check(errmsg.Message(err), qt.Equals, want) 226 }) 227 } 228 229 func TestConnector_Test(t *testing.T) { 230 c := qt.New(t) 231 232 logger := zap.NewNop() 233 connector := Init(logger, nil) 234 235 c.Run("nok - error", func(c *qt.C) { 236 h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 237 c.Check(r.Method, qt.Equals, http.MethodGet) 238 c.Check(r.URL.Path, qt.Equals, listEnginesPath) 239 240 w.Header().Set("Content-Type", httpclient.MIMETypeJSON) 241 w.WriteHeader(http.StatusUnauthorized) 242 fmt.Fprintln(w, errResp) 243 }) 244 245 srv := httptest.NewServer(h) 246 c.Cleanup(srv.Close) 247 248 connection, err := structpb.NewStruct(map[string]any{ 249 "base_path": srv.URL, 250 }) 251 c.Assert(err, qt.IsNil) 252 253 err = connector.Test(nil, connection) 254 c.Check(err, qt.IsNotNil) 255 256 wantMsg := "Stability AI responded with a 401 status code. Incorrect API key provided" 257 c.Check(errmsg.Message(err), qt.Equals, wantMsg) 258 }) 259 260 c.Run("ok - disconnected", func(c *qt.C) { 261 h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 262 c.Check(r.Method, qt.Equals, http.MethodGet) 263 c.Check(r.URL.Path, qt.Equals, listEnginesPath) 264 265 w.Header().Set("Content-Type", httpclient.MIMETypeJSON) 266 fmt.Fprintln(w, `[]`) 267 }) 268 269 srv := httptest.NewServer(h) 270 c.Cleanup(srv.Close) 271 272 connection, err := structpb.NewStruct(map[string]any{ 273 "base_path": srv.URL, 274 }) 275 c.Assert(err, qt.IsNil) 276 277 err = connector.Test(nil, connection) 278 c.Check(err, qt.IsNotNil) 279 }) 280 281 c.Run("ok - connected", func(c *qt.C) { 282 h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 283 c.Check(r.Method, qt.Equals, http.MethodGet) 284 c.Check(r.URL.Path, qt.Equals, listEnginesPath) 285 286 w.Header().Set("Content-Type", httpclient.MIMETypeJSON) 287 fmt.Fprintln(w, `[{}]`) 288 }) 289 290 srv := httptest.NewServer(h) 291 c.Cleanup(srv.Close) 292 293 connection, err := structpb.NewStruct(map[string]any{ 294 "base_path": srv.URL, 295 }) 296 c.Assert(err, qt.IsNil) 297 298 err = connector.Test(nil, connection) 299 c.Check(err, qt.IsNil) 300 }) 301 }