github.com/instill-ai/component@v0.16.0-beta/pkg/connector/stabilityai/v0/main.go (about) 1 //go:generate compogen readme --connector ./config ./README.mdx 2 package stabilityai 3 4 import ( 5 _ "embed" 6 "fmt" 7 "sync" 8 9 "go.uber.org/zap" 10 "google.golang.org/protobuf/types/known/structpb" 11 12 "github.com/instill-ai/component/pkg/base" 13 "github.com/instill-ai/x/errmsg" 14 ) 15 16 const ( 17 host = "https://api.stability.ai" 18 textToImageTask = "TASK_TEXT_TO_IMAGE" 19 imageToImageTask = "TASK_IMAGE_TO_IMAGE" 20 ) 21 22 var ( 23 //go:embed config/definition.json 24 definitionJSON []byte 25 //go:embed config/tasks.json 26 tasksJSON []byte 27 //go:embed config/stabilityai.json 28 stabilityaiJSON []byte 29 once sync.Once 30 con *connector 31 ) 32 33 type connector struct { 34 base.BaseConnector 35 } 36 37 type execution struct { 38 base.BaseConnectorExecution 39 } 40 41 func Init(l *zap.Logger, u base.UsageHandler) *connector { 42 once.Do(func() { 43 con = &connector{ 44 BaseConnector: base.BaseConnector{ 45 Logger: l, 46 UsageHandler: u, 47 }, 48 } 49 err := con.LoadConnectorDefinition(definitionJSON, tasksJSON, map[string][]byte{"stabilityai.json": stabilityaiJSON}) 50 if err != nil { 51 panic(err) 52 } 53 }) 54 return con 55 } 56 57 func (c *connector) CreateExecution(sysVars map[string]any, connection *structpb.Struct, task string) (*base.ExecutionWrapper, error) { 58 return &base.ExecutionWrapper{Execution: &execution{ 59 BaseConnectorExecution: base.BaseConnectorExecution{Connector: c, SystemVariables: sysVars, Connection: connection, Task: task}, 60 }}, nil 61 } 62 63 func getAPIKey(config *structpb.Struct) string { 64 return config.GetFields()["api_key"].GetStringValue() 65 } 66 67 // getBasePath returns Stability AI's API URL. This configuration param allows 68 // us to override the API the connector will point to. It isn't meant to be 69 // exposed to users. Rather, it can serve to test the logic against a fake 70 // server. 71 // TODO instead of having the API value hardcoded in the codebase, it should be 72 // read from a config file or environment variable. 73 func getBasePath(config *structpb.Struct) string { 74 v, ok := config.GetFields()["base_path"] 75 if !ok { 76 return host 77 } 78 return v.GetStringValue() 79 } 80 81 func (e *execution) Execute(inputs []*structpb.Struct) ([]*structpb.Struct, error) { 82 client := newClient(e.Connection, e.GetLogger()) 83 outputs := []*structpb.Struct{} 84 85 for _, input := range inputs { 86 switch e.Task { 87 case textToImageTask: 88 params, err := parseTextToImageReq(input) 89 if err != nil { 90 return inputs, err 91 } 92 93 resp := ImageTaskRes{} 94 req := client.R().SetResult(&resp).SetBody(params) 95 96 if _, err := req.Post(params.path); err != nil { 97 return inputs, err 98 } 99 100 output, err := textToImageOutput(resp) 101 if err != nil { 102 return nil, err 103 } 104 105 outputs = append(outputs, output) 106 case imageToImageTask: 107 params, err := parseImageToImageReq(input) 108 if err != nil { 109 return inputs, err 110 } 111 112 data, ct, err := params.getBytes() 113 if err != nil { 114 return inputs, err 115 } 116 117 resp := ImageTaskRes{} 118 req := client.R().SetBody(data).SetResult(&resp).SetHeader("Content-Type", ct) 119 120 if _, err := req.Post(params.path); err != nil { 121 return inputs, err 122 } 123 124 output, err := imageToImageOutput(resp) 125 if err != nil { 126 return nil, err 127 } 128 129 outputs = append(outputs, output) 130 131 default: 132 return nil, errmsg.AddMessage( 133 fmt.Errorf("not supported task: %s", e.Task), 134 fmt.Sprintf("%s task is not supported.", e.Task), 135 ) 136 } 137 } 138 return outputs, nil 139 } 140 141 // Test checks the connector state. 142 func (c *connector) Test(sysVars map[string]any, connection *structpb.Struct) error { 143 var engines []Engine 144 req := newClient(connection, c.Logger).R().SetResult(&engines) 145 146 if _, err := req.Get(listEnginesPath); err != nil { 147 return err 148 } 149 150 if len(engines) == 0 { 151 return fmt.Errorf("no engines") 152 } 153 154 return nil 155 }