github.com/instill-ai/component@v0.16.0-beta/pkg/connector/stabilityai/v0/main.go (about)

     1  //go:generate compogen readme --connector ./config ./README.mdx
     2  package stabilityai
     3  
     4  import (
     5  	_ "embed"
     6  	"fmt"
     7  	"sync"
     8  
     9  	"go.uber.org/zap"
    10  	"google.golang.org/protobuf/types/known/structpb"
    11  
    12  	"github.com/instill-ai/component/pkg/base"
    13  	"github.com/instill-ai/x/errmsg"
    14  )
    15  
    16  const (
    17  	host             = "https://api.stability.ai"
    18  	textToImageTask  = "TASK_TEXT_TO_IMAGE"
    19  	imageToImageTask = "TASK_IMAGE_TO_IMAGE"
    20  )
    21  
var (
	// Config files embedded at build time from the config directory. Init
	// passes them to LoadConnectorDefinition; stabilityai.json is supplied as
	// an additional named file alongside the definition and tasks.
	//go:embed config/definition.json
	definitionJSON []byte
	//go:embed config/tasks.json
	tasksJSON []byte
	//go:embed config/stabilityai.json
	stabilityaiJSON []byte
	// once guards the one-time construction of con inside Init.
	once            sync.Once
	// con is the shared connector instance that every Init call returns.
	con             *connector
)
    32  
    33  type connector struct {
    34  	base.BaseConnector
    35  }
    36  
    37  type execution struct {
    38  	base.BaseConnectorExecution
    39  }
    40  
    41  func Init(l *zap.Logger, u base.UsageHandler) *connector {
    42  	once.Do(func() {
    43  		con = &connector{
    44  			BaseConnector: base.BaseConnector{
    45  				Logger:       l,
    46  				UsageHandler: u,
    47  			},
    48  		}
    49  		err := con.LoadConnectorDefinition(definitionJSON, tasksJSON, map[string][]byte{"stabilityai.json": stabilityaiJSON})
    50  		if err != nil {
    51  			panic(err)
    52  		}
    53  	})
    54  	return con
    55  }
    56  
    57  func (c *connector) CreateExecution(sysVars map[string]any, connection *structpb.Struct, task string) (*base.ExecutionWrapper, error) {
    58  	return &base.ExecutionWrapper{Execution: &execution{
    59  		BaseConnectorExecution: base.BaseConnectorExecution{Connector: c, SystemVariables: sysVars, Connection: connection, Task: task},
    60  	}}, nil
    61  }
    62  
    63  func getAPIKey(config *structpb.Struct) string {
    64  	return config.GetFields()["api_key"].GetStringValue()
    65  }
    66  
    67  // getBasePath returns Stability AI's API URL. This configuration param allows
    68  // us to override the API the connector will point to. It isn't meant to be
    69  // exposed to users. Rather, it can serve to test the logic against a fake
    70  // server.
    71  // TODO instead of having the API value hardcoded in the codebase, it should be
    72  // read from a config file or environment variable.
    73  func getBasePath(config *structpb.Struct) string {
    74  	v, ok := config.GetFields()["base_path"]
    75  	if !ok {
    76  		return host
    77  	}
    78  	return v.GetStringValue()
    79  }
    80  
    81  func (e *execution) Execute(inputs []*structpb.Struct) ([]*structpb.Struct, error) {
    82  	client := newClient(e.Connection, e.GetLogger())
    83  	outputs := []*structpb.Struct{}
    84  
    85  	for _, input := range inputs {
    86  		switch e.Task {
    87  		case textToImageTask:
    88  			params, err := parseTextToImageReq(input)
    89  			if err != nil {
    90  				return inputs, err
    91  			}
    92  
    93  			resp := ImageTaskRes{}
    94  			req := client.R().SetResult(&resp).SetBody(params)
    95  
    96  			if _, err := req.Post(params.path); err != nil {
    97  				return inputs, err
    98  			}
    99  
   100  			output, err := textToImageOutput(resp)
   101  			if err != nil {
   102  				return nil, err
   103  			}
   104  
   105  			outputs = append(outputs, output)
   106  		case imageToImageTask:
   107  			params, err := parseImageToImageReq(input)
   108  			if err != nil {
   109  				return inputs, err
   110  			}
   111  
   112  			data, ct, err := params.getBytes()
   113  			if err != nil {
   114  				return inputs, err
   115  			}
   116  
   117  			resp := ImageTaskRes{}
   118  			req := client.R().SetBody(data).SetResult(&resp).SetHeader("Content-Type", ct)
   119  
   120  			if _, err := req.Post(params.path); err != nil {
   121  				return inputs, err
   122  			}
   123  
   124  			output, err := imageToImageOutput(resp)
   125  			if err != nil {
   126  				return nil, err
   127  			}
   128  
   129  			outputs = append(outputs, output)
   130  
   131  		default:
   132  			return nil, errmsg.AddMessage(
   133  				fmt.Errorf("not supported task: %s", e.Task),
   134  				fmt.Sprintf("%s task is not supported.", e.Task),
   135  			)
   136  		}
   137  	}
   138  	return outputs, nil
   139  }
   140  
   141  // Test checks the connector state.
   142  func (c *connector) Test(sysVars map[string]any, connection *structpb.Struct) error {
   143  	var engines []Engine
   144  	req := newClient(connection, c.Logger).R().SetResult(&engines)
   145  
   146  	if _, err := req.Get(listEnginesPath); err != nil {
   147  		return err
   148  	}
   149  
   150  	if len(engines) == 0 {
   151  		return fmt.Errorf("no engines")
   152  	}
   153  
   154  	return nil
   155  }