github.com/instill-ai/component@v0.16.0-beta/pkg/connector/stabilityai/v0/connector_test.go (about)

     1  package stabilityai
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"testing"
     9  
    10  	qt "github.com/frankban/quicktest"
    11  	"go.uber.org/zap"
    12  	"google.golang.org/protobuf/types/known/structpb"
    13  
    14  	"github.com/instill-ai/component/pkg/base"
    15  	"github.com/instill-ai/component/pkg/connector/util/httpclient"
    16  	"github.com/instill-ai/x/errmsg"
    17  )
    18  
    19  const (
    20  	apiKey  = "123"
    21  	errResp = `
    22  {
    23    "id": "6e958442e7911ffb2e0bf89c6efe804f",
    24    "message": "Incorrect API key provided",
    25    "name": "unauthorized"
    26  }`
    27  
    28  	okResp = `
    29  {
    30    "artifacts": [
    31      {
    32        "base64": "a",
    33        "seed": 1234,
    34        "finishReason": "SUCCESS"
    35      }
    36    ]
    37  }
    38  `
    39  )
    40  
    41  func TestConnector_ExecuteImageFromText(t *testing.T) {
    42  	c := qt.New(t)
    43  
    44  	weight := 0.5
    45  	text := "a cat and a dog"
    46  	engine := "engine"
    47  
    48  	logger := zap.NewNop()
    49  	connector := Init(logger, nil)
    50  
    51  	testcases := []struct {
    52  		name      string
    53  		gotStatus int
    54  		gotResp   string
    55  		wantResp  TextToImageOutput
    56  		wantErr   string
    57  	}{
    58  		{
    59  			name:      "ok - 200",
    60  			gotStatus: http.StatusOK,
    61  			gotResp:   okResp,
    62  			wantResp: TextToImageOutput{
    63  				Images: []string{"data:image/png;base64,a"},
    64  				Seeds:  []uint32{1234},
    65  			},
    66  		},
    67  		{
    68  			name:      "nok - 401",
    69  			gotStatus: http.StatusUnauthorized,
    70  			gotResp:   errResp,
    71  			wantErr:   "Stability AI responded with a 401 status code. Incorrect API key provided",
    72  		},
    73  	}
    74  
    75  	for _, tc := range testcases {
    76  		c.Run(tc.name, func(c *qt.C) {
    77  			h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    78  				c.Check(r.Method, qt.Equals, http.MethodPost)
    79  				c.Check(r.URL.Path, qt.Matches, `/v1/generation/.*/text-to-image`)
    80  
    81  				c.Check(r.Header.Get("Authorization"), qt.Equals, "Bearer "+apiKey)
    82  				c.Check(r.Header.Get("Content-Type"), qt.Equals, httpclient.MIMETypeJSON)
    83  
    84  				w.Header().Set("Content-Type", httpclient.MIMETypeJSON)
    85  				w.WriteHeader(tc.gotStatus)
    86  				fmt.Fprintln(w, tc.gotResp)
    87  			})
    88  
    89  			srv := httptest.NewServer(h)
    90  			c.Cleanup(srv.Close)
    91  
    92  			connection, err := structpb.NewStruct(map[string]any{
    93  				"base_path": srv.URL,
    94  				"api_key":   apiKey,
    95  			})
    96  			c.Assert(err, qt.IsNil)
    97  
    98  			exec, err := connector.CreateExecution(nil, connection, textToImageTask)
    99  			c.Assert(err, qt.IsNil)
   100  
   101  			weights := []float64{weight}
   102  			pbIn, err := base.ConvertToStructpb(TextToImageInput{
   103  				Engine:  engine,
   104  				Prompts: []string{text},
   105  				Weights: &weights,
   106  			})
   107  			c.Assert(err, qt.IsNil)
   108  
   109  			got, err := exec.Execution.Execute([]*structpb.Struct{pbIn})
   110  			if tc.wantErr != "" {
   111  				c.Check(errmsg.Message(err), qt.Equals, tc.wantErr)
   112  				return
   113  			}
   114  
   115  			wantJSON, err := json.Marshal(tc.wantResp)
   116  			c.Assert(err, qt.IsNil)
   117  			c.Check(wantJSON, qt.JSONEquals, got[0].AsMap())
   118  		})
   119  	}
   120  
   121  	c.Run("nok - unsupported task", func(c *qt.C) {
   122  		task := "FOOBAR"
   123  		exec, err := connector.CreateExecution(nil, new(structpb.Struct), task)
   124  		c.Assert(err, qt.IsNil)
   125  
   126  		pbIn := new(structpb.Struct)
   127  		_, err = exec.Execution.Execute([]*structpb.Struct{pbIn})
   128  		c.Check(err, qt.IsNotNil)
   129  
   130  		want := "FOOBAR task is not supported."
   131  		c.Check(errmsg.Message(err), qt.Equals, want)
   132  	})
   133  }
   134  
   135  func TestConnector_ExecuteImageFromImage(t *testing.T) {
   136  	c := qt.New(t)
   137  
   138  	weight := 0.5
   139  	text := "a cat and a dog"
   140  	engine := "engine"
   141  
   142  	logger := zap.NewNop()
   143  	connector := Init(logger, nil)
   144  
   145  	testcases := []struct {
   146  		name      string
   147  		gotStatus int
   148  		gotResp   string
   149  		wantResp  ImageToImageOutput
   150  		wantErr   string
   151  	}{
   152  		{
   153  			name:      "ok - 200",
   154  			gotStatus: http.StatusOK,
   155  			gotResp:   okResp,
   156  			wantResp: ImageToImageOutput{
   157  				Images: []string{"data:image/png;base64,a"},
   158  				Seeds:  []uint32{1234},
   159  			},
   160  		},
   161  		{
   162  			name:      "nok - 401",
   163  			gotStatus: http.StatusUnauthorized,
   164  			gotResp:   errResp,
   165  			wantErr:   "Stability AI responded with a 401 status code. Incorrect API key provided",
   166  		},
   167  	}
   168  
   169  	for _, tc := range testcases {
   170  		c.Run(tc.name, func(c *qt.C) {
   171  			h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   172  				c.Check(r.Method, qt.Equals, http.MethodPost)
   173  				c.Check(r.URL.Path, qt.Matches, `/v1/generation/.*/image-to-image`)
   174  
   175  				c.Check(r.Header.Get("Authorization"), qt.Equals, "Bearer "+apiKey)
   176  				c.Check(r.Header.Get("Content-Type"), qt.Matches, "multipart/form-data; boundary=.*")
   177  
   178  				w.Header().Set("Content-Type", httpclient.MIMETypeJSON)
   179  				w.WriteHeader(tc.gotStatus)
   180  				fmt.Fprintln(w, tc.gotResp)
   181  			})
   182  
   183  			srv := httptest.NewServer(h)
   184  			c.Cleanup(srv.Close)
   185  
   186  			connection, err := structpb.NewStruct(map[string]any{
   187  				"base_path": srv.URL,
   188  				"api_key":   apiKey,
   189  			})
   190  			c.Assert(err, qt.IsNil)
   191  
   192  			exec, err := connector.CreateExecution(nil, connection, imageToImageTask)
   193  			c.Assert(err, qt.IsNil)
   194  
   195  			weights := []float64{weight}
   196  			pbIn, err := base.ConvertToStructpb(ImageToImageInput{
   197  				Engine:  engine,
   198  				Prompts: []string{text},
   199  				Weights: &weights,
   200  			})
   201  			c.Assert(err, qt.IsNil)
   202  
   203  			got, err := exec.Execution.Execute([]*structpb.Struct{pbIn})
   204  			if tc.wantErr != "" {
   205  				c.Check(errmsg.Message(err), qt.Equals, tc.wantErr)
   206  				return
   207  			}
   208  
   209  			wantJSON, err := json.Marshal(tc.wantResp)
   210  			c.Assert(err, qt.IsNil)
   211  			c.Check(wantJSON, qt.JSONEquals, got[0].AsMap())
   212  		})
   213  	}
   214  
   215  	c.Run("nok - unsupported task", func(c *qt.C) {
   216  		task := "FOOBAR"
   217  		exec, err := connector.CreateExecution(nil, new(structpb.Struct), task)
   218  		c.Assert(err, qt.IsNil)
   219  
   220  		pbIn := new(structpb.Struct)
   221  		_, err = exec.Execution.Execute([]*structpb.Struct{pbIn})
   222  		c.Check(err, qt.IsNotNil)
   223  
   224  		want := "FOOBAR task is not supported."
   225  		c.Check(errmsg.Message(err), qt.Equals, want)
   226  	})
   227  }
   228  
   229  func TestConnector_Test(t *testing.T) {
   230  	c := qt.New(t)
   231  
   232  	logger := zap.NewNop()
   233  	connector := Init(logger, nil)
   234  
   235  	c.Run("nok - error", func(c *qt.C) {
   236  		h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   237  			c.Check(r.Method, qt.Equals, http.MethodGet)
   238  			c.Check(r.URL.Path, qt.Equals, listEnginesPath)
   239  
   240  			w.Header().Set("Content-Type", httpclient.MIMETypeJSON)
   241  			w.WriteHeader(http.StatusUnauthorized)
   242  			fmt.Fprintln(w, errResp)
   243  		})
   244  
   245  		srv := httptest.NewServer(h)
   246  		c.Cleanup(srv.Close)
   247  
   248  		connection, err := structpb.NewStruct(map[string]any{
   249  			"base_path": srv.URL,
   250  		})
   251  		c.Assert(err, qt.IsNil)
   252  
   253  		err = connector.Test(nil, connection)
   254  		c.Check(err, qt.IsNotNil)
   255  
   256  		wantMsg := "Stability AI responded with a 401 status code. Incorrect API key provided"
   257  		c.Check(errmsg.Message(err), qt.Equals, wantMsg)
   258  	})
   259  
   260  	c.Run("ok - disconnected", func(c *qt.C) {
   261  		h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   262  			c.Check(r.Method, qt.Equals, http.MethodGet)
   263  			c.Check(r.URL.Path, qt.Equals, listEnginesPath)
   264  
   265  			w.Header().Set("Content-Type", httpclient.MIMETypeJSON)
   266  			fmt.Fprintln(w, `[]`)
   267  		})
   268  
   269  		srv := httptest.NewServer(h)
   270  		c.Cleanup(srv.Close)
   271  
   272  		connection, err := structpb.NewStruct(map[string]any{
   273  			"base_path": srv.URL,
   274  		})
   275  		c.Assert(err, qt.IsNil)
   276  
   277  		err = connector.Test(nil, connection)
   278  		c.Check(err, qt.IsNotNil)
   279  	})
   280  
   281  	c.Run("ok - connected", func(c *qt.C) {
   282  		h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   283  			c.Check(r.Method, qt.Equals, http.MethodGet)
   284  			c.Check(r.URL.Path, qt.Equals, listEnginesPath)
   285  
   286  			w.Header().Set("Content-Type", httpclient.MIMETypeJSON)
   287  			fmt.Fprintln(w, `[{}]`)
   288  		})
   289  
   290  		srv := httptest.NewServer(h)
   291  		c.Cleanup(srv.Close)
   292  
   293  		connection, err := structpb.NewStruct(map[string]any{
   294  			"base_path": srv.URL,
   295  		})
   296  		c.Assert(err, qt.IsNil)
   297  
   298  		err = connector.Test(nil, connection)
   299  		c.Check(err, qt.IsNil)
   300  	})
   301  }