github.com/instill-ai/component@v0.16.0-beta/pkg/connector/instill/v0/main.go (about)

     1  //go:generate compogen readme --connector ./config ./README.mdx
     2  package instill
     3  
     4  import (
     5  	"context"
     6  	_ "embed"
     7  	"fmt"
     8  	"strings"
     9  	"sync"
    10  	"time"
    11  
    12  	"go.uber.org/zap"
    13  	"google.golang.org/grpc/metadata"
    14  	"google.golang.org/protobuf/proto"
    15  	"google.golang.org/protobuf/types/known/structpb"
    16  
    17  	"github.com/instill-ai/component/pkg/base"
    18  
    19  	commonPB "github.com/instill-ai/protogen-go/common/task/v1alpha"
    20  	mgmtPB "github.com/instill-ai/protogen-go/core/mgmt/v1beta"
    21  	modelPB "github.com/instill-ai/protogen-go/model/model/v1alpha"
    22  	pipelinePB "github.com/instill-ai/protogen-go/vdp/pipeline/v1beta"
    23  )
    24  
var (
	// definitionJSON is the content of config/definition.json, embedded at
	// build time. Init passes it to LoadConnectorDefinition.
	//
	//go:embed config/definition.json
	definitionJSON []byte
	// tasksJSON is the content of config/tasks.json, embedded at build time.
	// Init passes it to LoadConnectorDefinition alongside definitionJSON.
	//
	//go:embed config/tasks.json
	tasksJSON []byte
	once      sync.Once  // makes Init build the connector only on its first call
	con       *connector // package-level connector instance created by Init
)
    33  
// connector is this package's connector implementation. It embeds
// base.BaseConnector and adds an in-memory copy of the generated definition.
type connector struct {
	base.BaseConnector

	// Workaround solution: the definition built by GetConnectorDefinition,
	// stored and reused only when the __STATIC_MODEL_LIST system variable is
	// true.
	cacheDefinition *pipelinePB.ConnectorDefinition
}
    40  
// execution is the value CreateExecution wraps for running a task. It embeds
// base.BaseConnectorExecution, which CreateExecution fills with the connector,
// system variables, connection and task name.
type execution struct {
	base.BaseConnectorExecution
}
    44  
    45  func Init(l *zap.Logger, u base.UsageHandler) *connector {
    46  	once.Do(func() {
    47  		con = &connector{
    48  			BaseConnector: base.BaseConnector{
    49  				Logger:       l,
    50  				UsageHandler: u,
    51  			},
    52  		}
    53  		err := con.LoadConnectorDefinition(definitionJSON, tasksJSON, nil)
    54  		if err != nil {
    55  			panic(err)
    56  		}
    57  	})
    58  	return con
    59  }
    60  
    61  func (c *connector) CreateExecution(sysVars map[string]any, connection *structpb.Struct, task string) (*base.ExecutionWrapper, error) {
    62  	return &base.ExecutionWrapper{Execution: &execution{
    63  		BaseConnectorExecution: base.BaseConnectorExecution{Connector: c, SystemVariables: sysVars, Connection: connection, Task: task},
    64  	}}, nil
    65  }
    66  
    67  func getHeaderAuthorization(vars map[string]any) string {
    68  	if v, ok := vars["__PIPELINE_HEADER_AUTHORIZATION"]; ok {
    69  		return v.(string)
    70  	}
    71  	return ""
    72  }
    73  func getInstillUserUID(vars map[string]any) string {
    74  	return vars["__PIPELINE_USER_UID"].(string)
    75  }
    76  
    77  func getModelServerURL(vars map[string]any) string {
    78  	if v, ok := vars["__MODEL_BACKEND"]; ok {
    79  		return v.(string)
    80  	}
    81  	return ""
    82  }
    83  
    84  func getMgmtServerURL(vars map[string]any) string {
    85  	if v, ok := vars["__MGMT_BACKEND"]; ok {
    86  		return v.(string)
    87  	}
    88  	return ""
    89  }
    90  
    91  // This is a workaround solution for caching the definition in memory if the model list is static.
    92  func useStaticModelList(vars map[string]any) bool {
    93  	if v, ok := vars["__STATIC_MODEL_LIST"]; ok {
    94  		return v.(bool)
    95  	}
    96  	return false
    97  }
    98  
    99  func getRequestMetadata(vars map[string]any) metadata.MD {
   100  	return metadata.Pairs(
   101  		"Authorization", getHeaderAuthorization(vars),
   102  		"Instill-User-Uid", getInstillUserUID(vars),
   103  		"Instill-Auth-Type", "user",
   104  	)
   105  }
   106  
   107  func (e *execution) Execute(inputs []*structpb.Struct) ([]*structpb.Struct, error) {
   108  	var err error
   109  
   110  	if len(inputs) <= 0 || inputs[0] == nil {
   111  		return inputs, fmt.Errorf("invalid input")
   112  	}
   113  
   114  	gRPCCLient, gRPCCLientConn := initModelPublicServiceClient(getModelServerURL(e.SystemVariables))
   115  	if gRPCCLientConn != nil {
   116  		defer gRPCCLientConn.Close()
   117  	}
   118  
   119  	mgmtGRPCCLient, mgmtGRPCCLientConn := initMgmtPublicServiceClient(getMgmtServerURL(e.SystemVariables))
   120  	if mgmtGRPCCLientConn != nil {
   121  		defer mgmtGRPCCLientConn.Close()
   122  	}
   123  
   124  	modelNameSplits := strings.Split(inputs[0].GetFields()["model_name"].GetStringValue(), "/")
   125  	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
   126  	defer cancel()
   127  
   128  	ctx = metadata.NewOutgoingContext(ctx, getRequestMetadata(e.SystemVariables))
   129  	nsResp, err := mgmtGRPCCLient.CheckNamespace(ctx, &mgmtPB.CheckNamespaceRequest{
   130  		Id: modelNameSplits[0],
   131  	})
   132  	if err != nil {
   133  		return nil, err
   134  	}
   135  	nsType := ""
   136  	if nsResp.Type == mgmtPB.CheckNamespaceResponse_NAMESPACE_ORGANIZATION {
   137  		nsType = "organizations"
   138  	} else {
   139  		nsType = "users"
   140  	}
   141  
   142  	modelName := fmt.Sprintf("%s/%s/models/%s", nsType, modelNameSplits[0], modelNameSplits[1])
   143  
   144  	var result []*structpb.Struct
   145  	switch e.Task {
   146  	case commonPB.Task_TASK_UNSPECIFIED.String():
   147  		result, err = e.executeUnspecified(gRPCCLient, modelName, inputs)
   148  	case commonPB.Task_TASK_CLASSIFICATION.String():
   149  		result, err = e.executeImageClassification(gRPCCLient, modelName, inputs)
   150  	case commonPB.Task_TASK_DETECTION.String():
   151  		result, err = e.executeObjectDetection(gRPCCLient, modelName, inputs)
   152  	case commonPB.Task_TASK_KEYPOINT.String():
   153  		result, err = e.executeKeyPointDetection(gRPCCLient, modelName, inputs)
   154  	case commonPB.Task_TASK_OCR.String():
   155  		result, err = e.executeOCR(gRPCCLient, modelName, inputs)
   156  	case commonPB.Task_TASK_INSTANCE_SEGMENTATION.String():
   157  		result, err = e.executeInstanceSegmentation(gRPCCLient, modelName, inputs)
   158  	case commonPB.Task_TASK_SEMANTIC_SEGMENTATION.String():
   159  		result, err = e.executeSemanticSegmentation(gRPCCLient, modelName, inputs)
   160  	case commonPB.Task_TASK_TEXT_TO_IMAGE.String():
   161  		result, err = e.executeTextToImage(gRPCCLient, modelName, inputs)
   162  	case commonPB.Task_TASK_TEXT_GENERATION.String():
   163  		result, err = e.executeTextGeneration(gRPCCLient, modelName, inputs)
   164  	case commonPB.Task_TASK_TEXT_GENERATION_CHAT.String():
   165  		result, err = e.executeTextGenerationChat(gRPCCLient, modelName, inputs)
   166  	case commonPB.Task_TASK_VISUAL_QUESTION_ANSWERING.String():
   167  		result, err = e.executeVisualQuestionAnswering(gRPCCLient, modelName, inputs)
   168  	case commonPB.Task_TASK_IMAGE_TO_IMAGE.String():
   169  		result, err = e.executeImageToImage(gRPCCLient, modelName, inputs)
   170  	default:
   171  		return inputs, fmt.Errorf("unsupported task: %s", e.Task)
   172  	}
   173  
   174  	return result, err
   175  }
   176  
   177  func (c *connector) Test(sysVars map[string]any, connection *structpb.Struct) error {
   178  	gRPCCLient, gRPCCLientConn := initModelPublicServiceClient(getModelServerURL(sysVars))
   179  	if gRPCCLientConn != nil {
   180  		defer gRPCCLientConn.Close()
   181  	}
   182  	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
   183  	defer cancel()
   184  
   185  	ctx = metadata.NewOutgoingContext(ctx, getRequestMetadata(sysVars))
   186  	_, err := gRPCCLient.ListModels(ctx, &modelPB.ListModelsRequest{})
   187  	if err != nil {
   188  		return err
   189  	}
   190  
   191  	return nil
   192  }
   193  
// ModelsResp is the JSON shape of a model listing: a "models" array whose
// entries each carry a "name" and a "task".
//
// NOTE(review): nothing in the visible part of this file references this
// type — confirm it is still used elsewhere before relying on it.
type ModelsResp struct {
	Models []struct {
		Name string `json:"name"`
		Task string `json:"task"`
	} `json:"models"`
}
   200  
   201  // Generate the `model_name` enum based on the task.
   202  // This implementation is a temporary solution due to the incomplete feature set of Instill Model.
   203  // We'll re-implement this after Instill Model is stable.
   204  func (c *connector) GetConnectorDefinition(sysVars map[string]any, component *pipelinePB.ConnectorComponent) (*pipelinePB.ConnectorDefinition, error) {
   205  	if useStaticModelList(sysVars) && c.cacheDefinition != nil {
   206  		return c.cacheDefinition, nil
   207  	}
   208  
   209  	oriDef, err := c.BaseConnector.GetConnectorDefinition(nil, nil)
   210  	if err != nil {
   211  		return nil, err
   212  	}
   213  	def := proto.Clone(oriDef).(*pipelinePB.ConnectorDefinition)
   214  
   215  	if getModelServerURL(sysVars) == "" {
   216  		return def, nil
   217  	}
   218  
   219  	gRPCCLient, gRPCCLientConn := initModelPublicServiceClient(getModelServerURL(sysVars))
   220  	if gRPCCLientConn != nil {
   221  		defer gRPCCLientConn.Close()
   222  	}
   223  
   224  	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
   225  	defer cancel()
   226  
   227  	ctx = metadata.NewOutgoingContext(ctx, getRequestMetadata(sysVars))
   228  
   229  	pageToken := ""
   230  	models := []*modelPB.Model{}
   231  	for {
   232  		resp, err := gRPCCLient.ListModels(ctx, &modelPB.ListModelsRequest{PageToken: &pageToken})
   233  		if err != nil {
   234  
   235  			return def, nil
   236  		}
   237  		models = append(models, resp.Models...)
   238  		pageToken = resp.NextPageToken
   239  		if pageToken == "" {
   240  			break
   241  		}
   242  	}
   243  
   244  	modelNameMap := map[string]*structpb.ListValue{}
   245  
   246  	modelName := &structpb.ListValue{}
   247  	for _, model := range models {
   248  		if _, ok := modelNameMap[model.Task.String()]; !ok {
   249  			modelNameMap[model.Task.String()] = &structpb.ListValue{}
   250  		}
   251  		namePaths := strings.Split(model.Name, "/")
   252  		modelName.Values = append(modelName.Values, structpb.NewStringValue(fmt.Sprintf("%s/%s", namePaths[1], namePaths[3])))
   253  		modelNameMap[model.Task.String()].Values = append(modelNameMap[model.Task.String()].Values, structpb.NewStringValue(fmt.Sprintf("%s/%s", namePaths[1], namePaths[3])))
   254  	}
   255  	for _, sch := range def.Spec.ComponentSpecification.Fields["oneOf"].GetListValue().Values {
   256  		task := sch.GetStructValue().Fields["properties"].GetStructValue().Fields["task"].GetStructValue().Fields["const"].GetStringValue()
   257  		if _, ok := modelNameMap[task]; ok {
   258  			addModelEnum(sch.GetStructValue().Fields, modelNameMap[task])
   259  		}
   260  
   261  	}
   262  	if useStaticModelList(sysVars) {
   263  		c.cacheDefinition = def
   264  	}
   265  	return def, nil
   266  }
   267  
   268  func addModelEnum(compSpec map[string]*structpb.Value, modelName *structpb.ListValue) {
   269  	if compSpec == nil {
   270  		return
   271  	}
   272  	for key, sch := range compSpec {
   273  		if key == "model_name" {
   274  			sch.GetStructValue().Fields["enum"] = structpb.NewListValue(modelName)
   275  		}
   276  
   277  		if sch.GetStructValue() != nil {
   278  			addModelEnum(sch.GetStructValue().Fields, modelName)
   279  		}
   280  		if sch.GetListValue() != nil {
   281  			for _, v := range sch.GetListValue().Values {
   282  				if v.GetStructValue() != nil {
   283  					addModelEnum(v.GetStructValue().Fields, modelName)
   284  				}
   285  			}
   286  		}
   287  	}
   288  }