github.com/instill-ai/component@v0.16.0-beta/pkg/connector/instill/v0/main.go (about) 1 //go:generate compogen readme --connector ./config ./README.mdx 2 package instill 3 4 import ( 5 "context" 6 _ "embed" 7 "fmt" 8 "strings" 9 "sync" 10 "time" 11 12 "go.uber.org/zap" 13 "google.golang.org/grpc/metadata" 14 "google.golang.org/protobuf/proto" 15 "google.golang.org/protobuf/types/known/structpb" 16 17 "github.com/instill-ai/component/pkg/base" 18 19 commonPB "github.com/instill-ai/protogen-go/common/task/v1alpha" 20 mgmtPB "github.com/instill-ai/protogen-go/core/mgmt/v1beta" 21 modelPB "github.com/instill-ai/protogen-go/model/model/v1alpha" 22 pipelinePB "github.com/instill-ai/protogen-go/vdp/pipeline/v1beta" 23 ) 24 25 var ( 26 //go:embed config/definition.json 27 definitionJSON []byte 28 //go:embed config/tasks.json 29 tasksJSON []byte 30 once sync.Once 31 con *connector 32 ) 33 34 type connector struct { 35 base.BaseConnector 36 37 // Workaround solution 38 cacheDefinition *pipelinePB.ConnectorDefinition 39 } 40 41 type execution struct { 42 base.BaseConnectorExecution 43 } 44 45 func Init(l *zap.Logger, u base.UsageHandler) *connector { 46 once.Do(func() { 47 con = &connector{ 48 BaseConnector: base.BaseConnector{ 49 Logger: l, 50 UsageHandler: u, 51 }, 52 } 53 err := con.LoadConnectorDefinition(definitionJSON, tasksJSON, nil) 54 if err != nil { 55 panic(err) 56 } 57 }) 58 return con 59 } 60 61 func (c *connector) CreateExecution(sysVars map[string]any, connection *structpb.Struct, task string) (*base.ExecutionWrapper, error) { 62 return &base.ExecutionWrapper{Execution: &execution{ 63 BaseConnectorExecution: base.BaseConnectorExecution{Connector: c, SystemVariables: sysVars, Connection: connection, Task: task}, 64 }}, nil 65 } 66 67 func getHeaderAuthorization(vars map[string]any) string { 68 if v, ok := vars["__PIPELINE_HEADER_AUTHORIZATION"]; ok { 69 return v.(string) 70 } 71 return "" 72 } 73 func getInstillUserUID(vars map[string]any) string { 74 return vars["__PIPELINE_USER_UID"].(string) 75 } 76 77 func getModelServerURL(vars map[string]any) string { 78 if v, ok := vars["__MODEL_BACKEND"]; ok { 79 return v.(string) 80 } 81 return "" 82 } 83 84 func getMgmtServerURL(vars map[string]any) string { 85 if v, ok := vars["__MGMT_BACKEND"]; ok { 86 return v.(string) 87 } 88 return "" 89 } 90 91 // This is a workaround solution for caching the definition in memory if the model list is static. 92 func useStaticModelList(vars map[string]any) bool { 93 if v, ok := vars["__STATIC_MODEL_LIST"]; ok { 94 return v.(bool) 95 } 96 return false 97 } 98 99 func getRequestMetadata(vars map[string]any) metadata.MD { 100 return metadata.Pairs( 101 "Authorization", getHeaderAuthorization(vars), 102 "Instill-User-Uid", getInstillUserUID(vars), 103 "Instill-Auth-Type", "user", 104 ) 105 } 106 107 func (e *execution) Execute(inputs []*structpb.Struct) ([]*structpb.Struct, error) { 108 var err error 109 110 if len(inputs) <= 0 || inputs[0] == nil { 111 return inputs, fmt.Errorf("invalid input") 112 } 113 114 gRPCCLient, gRPCCLientConn := initModelPublicServiceClient(getModelServerURL(e.SystemVariables)) 115 if gRPCCLientConn != nil { 116 defer gRPCCLientConn.Close() 117 } 118 119 mgmtGRPCCLient, mgmtGRPCCLientConn := initMgmtPublicServiceClient(getMgmtServerURL(e.SystemVariables)) 120 if mgmtGRPCCLientConn != nil { 121 defer mgmtGRPCCLientConn.Close() 122 } 123 124 modelNameSplits := strings.Split(inputs[0].GetFields()["model_name"].GetStringValue(), "/") 125 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) 126 defer cancel() 127 128 ctx = metadata.NewOutgoingContext(ctx, getRequestMetadata(e.SystemVariables)) 129 nsResp, err := mgmtGRPCCLient.CheckNamespace(ctx, &mgmtPB.CheckNamespaceRequest{ 130 Id: modelNameSplits[0], 131 }) 132 if err != nil { 133 return nil, err 134 } 135 nsType := "" 136 if nsResp.Type == mgmtPB.CheckNamespaceResponse_NAMESPACE_ORGANIZATION { 137 nsType = "organizations" 138 } else { 139 nsType = "users" 140 } 141 142 modelName := fmt.Sprintf("%s/%s/models/%s", nsType, modelNameSplits[0], modelNameSplits[1]) 143 144 var result []*structpb.Struct 145 switch e.Task { 146 case commonPB.Task_TASK_UNSPECIFIED.String(): 147 result, err = e.executeUnspecified(gRPCCLient, modelName, inputs) 148 case commonPB.Task_TASK_CLASSIFICATION.String(): 149 result, err = e.executeImageClassification(gRPCCLient, modelName, inputs) 150 case commonPB.Task_TASK_DETECTION.String(): 151 result, err = e.executeObjectDetection(gRPCCLient, modelName, inputs) 152 case commonPB.Task_TASK_KEYPOINT.String(): 153 result, err = e.executeKeyPointDetection(gRPCCLient, modelName, inputs) 154 case commonPB.Task_TASK_OCR.String(): 155 result, err = e.executeOCR(gRPCCLient, modelName, inputs) 156 case commonPB.Task_TASK_INSTANCE_SEGMENTATION.String(): 157 result, err = e.executeInstanceSegmentation(gRPCCLient, modelName, inputs) 158 case commonPB.Task_TASK_SEMANTIC_SEGMENTATION.String(): 159 result, err = e.executeSemanticSegmentation(gRPCCLient, modelName, inputs) 160 case commonPB.Task_TASK_TEXT_TO_IMAGE.String(): 161 result, err = e.executeTextToImage(gRPCCLient, modelName, inputs) 162 case commonPB.Task_TASK_TEXT_GENERATION.String(): 163 result, err = e.executeTextGeneration(gRPCCLient, modelName, inputs) 164 case commonPB.Task_TASK_TEXT_GENERATION_CHAT.String(): 165 result, err = e.executeTextGenerationChat(gRPCCLient, modelName, inputs) 166 case commonPB.Task_TASK_VISUAL_QUESTION_ANSWERING.String(): 167 result, err = e.executeVisualQuestionAnswering(gRPCCLient, modelName, inputs) 168 case commonPB.Task_TASK_IMAGE_TO_IMAGE.String(): 169 result, err = e.executeImageToImage(gRPCCLient, modelName, inputs) 170 default: 171 return inputs, fmt.Errorf("unsupported task: %s", e.Task) 172 } 173 174 return result, err 175 } 176 177 func (c *connector) Test(sysVars map[string]any, connection *structpb.Struct) error { 178 gRPCCLient, gRPCCLientConn := initModelPublicServiceClient(getModelServerURL(sysVars)) 179 if gRPCCLientConn != nil { 180 defer gRPCCLientConn.Close() 181 } 182 ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 183 defer cancel() 184 185 ctx = metadata.NewOutgoingContext(ctx, getRequestMetadata(sysVars)) 186 _, err := gRPCCLient.ListModels(ctx, &modelPB.ListModelsRequest{}) 187 if err != nil { 188 return err 189 } 190 191 return nil 192 } 193 194 type ModelsResp struct { 195 Models []struct { 196 Name string `json:"name"` 197 Task string `json:"task"` 198 } `json:"models"` 199 } 200 201 // Generate the `model_name` enum based on the task. 202 // This implementation is a temporary solution due to the incomplete feature set of Instill Model. 203 // We'll re-implement this after Instill Model is stable. 204 func (c *connector) GetConnectorDefinition(sysVars map[string]any, component *pipelinePB.ConnectorComponent) (*pipelinePB.ConnectorDefinition, error) { 205 if useStaticModelList(sysVars) && c.cacheDefinition != nil { 206 return c.cacheDefinition, nil 207 } 208 209 oriDef, err := c.BaseConnector.GetConnectorDefinition(nil, nil) 210 if err != nil { 211 return nil, err 212 } 213 def := proto.Clone(oriDef).(*pipelinePB.ConnectorDefinition) 214 215 if getModelServerURL(sysVars) == "" { 216 return def, nil 217 } 218 219 gRPCCLient, gRPCCLientConn := initModelPublicServiceClient(getModelServerURL(sysVars)) 220 if gRPCCLientConn != nil { 221 defer gRPCCLientConn.Close() 222 } 223 224 ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) 225 defer cancel() 226 227 ctx = metadata.NewOutgoingContext(ctx, getRequestMetadata(sysVars)) 228 229 pageToken := "" 230 models := []*modelPB.Model{} 231 for { 232 resp, err := gRPCCLient.ListModels(ctx, &modelPB.ListModelsRequest{PageToken: &pageToken}) 233 if err != nil { 234 235 return def, nil 236 } 237 models = append(models, resp.Models...) 238 pageToken = resp.NextPageToken 239 if pageToken == "" { 240 break 241 } 242 } 243 244 modelNameMap := map[string]*structpb.ListValue{} 245 246 modelName := &structpb.ListValue{} 247 for _, model := range models { 248 if _, ok := modelNameMap[model.Task.String()]; !ok { 249 modelNameMap[model.Task.String()] = &structpb.ListValue{} 250 } 251 namePaths := strings.Split(model.Name, "/") 252 modelName.Values = append(modelName.Values, structpb.NewStringValue(fmt.Sprintf("%s/%s", namePaths[1], namePaths[3]))) 253 modelNameMap[model.Task.String()].Values = append(modelNameMap[model.Task.String()].Values, structpb.NewStringValue(fmt.Sprintf("%s/%s", namePaths[1], namePaths[3]))) 254 } 255 for _, sch := range def.Spec.ComponentSpecification.Fields["oneOf"].GetListValue().Values { 256 task := sch.GetStructValue().Fields["properties"].GetStructValue().Fields["task"].GetStructValue().Fields["const"].GetStringValue() 257 if _, ok := modelNameMap[task]; ok { 258 addModelEnum(sch.GetStructValue().Fields, modelNameMap[task]) 259 } 260 261 } 262 if useStaticModelList(sysVars) { 263 c.cacheDefinition = def 264 } 265 return def, nil 266 } 267 268 func addModelEnum(compSpec map[string]*structpb.Value, modelName *structpb.ListValue) { 269 if compSpec == nil { 270 return 271 } 272 for key, sch := range compSpec { 273 if key == "model_name" { 274 sch.GetStructValue().Fields["enum"] = structpb.NewListValue(modelName) 275 } 276 277 if sch.GetStructValue() != nil { 278 addModelEnum(sch.GetStructValue().Fields, modelName) 279 } 280 if sch.GetListValue() != nil { 281 for _, v := range sch.GetListValue().Values { 282 if v.GetStructValue() != nil { 283 addModelEnum(v.GetStructValue().Fields, modelName) 284 } 285 } 286 } 287 } 288 }