github.com/instill-ai/component@v0.16.0-beta/pkg/base/connector.go (about)

     1  package base
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  
     7  	"github.com/gofrs/uuid"
     8  	"go.uber.org/zap"
     9  	"google.golang.org/protobuf/encoding/protojson"
    10  	"google.golang.org/protobuf/proto"
    11  	"google.golang.org/protobuf/types/known/structpb"
    12  
    13  	pipelinePB "github.com/instill-ai/protogen-go/vdp/pipeline/v1beta"
    14  )
    15  
    16  // IConnector is the interface that all connectors need to implement
    17  type IConnector interface {
    18  	IComponent
    19  
    20  	LoadConnectorDefinition(definitionJSON []byte, tasksJSON []byte, additionalJSONBytes map[string][]byte) error
    21  
    22  	// Note: Some content in the definition JSON schema needs to be generated by sysVars or component setting.
    23  	GetConnectorDefinition(sysVars map[string]any, component *pipelinePB.ConnectorComponent) (*pipelinePB.ConnectorDefinition, error)
    24  
    25  	CreateExecution(sysVars map[string]any, connection *structpb.Struct, task string) (*ExecutionWrapper, error)
    26  	Test(sysVars map[string]any, connection *structpb.Struct) error
    27  
    28  	IsCredentialField(target string) bool
    29  }
    30  
    31  // Connector is the base struct for all connectors
    32  type BaseConnector struct {
    33  	Logger       *zap.Logger
    34  	UsageHandler UsageHandler
    35  
    36  	taskInputSchemas  map[string]string
    37  	taskOutputSchemas map[string]string
    38  
    39  	definition       *pipelinePB.ConnectorDefinition
    40  	credentialFields []string
    41  }
    42  
    43  type IConnectorExecution interface {
    44  	IExecution
    45  
    46  	GetConnector() IConnector
    47  	GetConnection() *structpb.Struct
    48  }
    49  
    50  type BaseConnectorExecution struct {
    51  	Connector       IConnector
    52  	SystemVariables map[string]any
    53  	Connection      *structpb.Struct
    54  	Task            string
    55  }
    56  
    57  func (c *BaseConnector) GetID() string {
    58  	return c.definition.Id
    59  }
    60  
    61  func (c *BaseConnector) GetUID() uuid.UUID {
    62  	return uuid.FromStringOrNil(c.definition.Uid)
    63  }
    64  
    65  func (c *BaseConnector) GetLogger() *zap.Logger {
    66  	return c.Logger
    67  }
    68  func (c *BaseConnector) GetUsageHandler() UsageHandler {
    69  	return c.UsageHandler
    70  }
    71  func (c *BaseConnector) GetConnectorDefinition(sysVars map[string]any, component *pipelinePB.ConnectorComponent) (*pipelinePB.ConnectorDefinition, error) {
    72  	return c.definition, nil
    73  }
    74  
    75  func (c *BaseConnector) GetTaskInputSchemas() map[string]string {
    76  	return c.taskInputSchemas
    77  }
    78  func (c *BaseConnector) GetTaskOutputSchemas() map[string]string {
    79  	return c.taskOutputSchemas
    80  }
    81  
    82  // LoadConnectorDefinition loads the connector definitions from json files
    83  func (c *BaseConnector) LoadConnectorDefinition(definitionJSONBytes []byte, tasksJSONBytes []byte, additionalJSONBytes map[string][]byte) error {
    84  	var err error
    85  	var definitionJSON any
    86  
    87  	c.credentialFields = []string{}
    88  
    89  	err = json.Unmarshal(definitionJSONBytes, &definitionJSON)
    90  	if err != nil {
    91  		return err
    92  	}
    93  	renderedTasksJSON, err := RenderJSON(tasksJSONBytes, additionalJSONBytes)
    94  	if err != nil {
    95  		return nil
    96  	}
    97  
    98  	availableTasks := []string{}
    99  	for _, availableTask := range definitionJSON.(map[string]interface{})["available_tasks"].([]interface{}) {
   100  		availableTasks = append(availableTasks, availableTask.(string))
   101  	}
   102  
   103  	tasks, taskStructs, err := loadTasks(availableTasks, renderedTasksJSON)
   104  	if err != nil {
   105  		return err
   106  	}
   107  
   108  	c.taskInputSchemas = map[string]string{}
   109  	c.taskOutputSchemas = map[string]string{}
   110  	for k := range taskStructs {
   111  		var s []byte
   112  		s, err = protojson.Marshal(taskStructs[k].Fields["input"].GetStructValue())
   113  		if err != nil {
   114  			return err
   115  		}
   116  		c.taskInputSchemas[k] = string(s)
   117  
   118  		s, err = protojson.Marshal(taskStructs[k].Fields["output"].GetStructValue())
   119  		if err != nil {
   120  			return err
   121  		}
   122  		c.taskOutputSchemas[k] = string(s)
   123  	}
   124  
   125  	c.definition = &pipelinePB.ConnectorDefinition{}
   126  	err = protojson.UnmarshalOptions{DiscardUnknown: true}.Unmarshal(definitionJSONBytes, c.definition)
   127  	if err != nil {
   128  		return err
   129  	}
   130  
   131  	c.definition.Name = fmt.Sprintf("connector-definitions/%s", c.definition.Id)
   132  	c.definition.Tasks = tasks
   133  	if c.definition.Spec == nil {
   134  		c.definition.Spec = &pipelinePB.ConnectorSpec{}
   135  	}
   136  	c.definition.Spec.ComponentSpecification, err = generateComponentSpec(c.definition.Title, tasks, taskStructs)
   137  	if err != nil {
   138  		return err
   139  	}
   140  
   141  	raw := &structpb.Struct{}
   142  	err = protojson.Unmarshal(definitionJSONBytes, raw)
   143  	if err != nil {
   144  		return err
   145  	}
   146  	// TODO: Avoid using structpb traversal here.
   147  	if _, ok := raw.Fields["spec"]; ok {
   148  		if v, ok := raw.Fields["spec"].GetStructValue().Fields["connection_specification"]; ok {
   149  			connection, err := c.refineResourceSpec(v.GetStructValue())
   150  			if err != nil {
   151  				return err
   152  			}
   153  			connectionPropStruct := &structpb.Struct{Fields: map[string]*structpb.Value{}}
   154  			connectionPropStruct.Fields["connection"] = structpb.NewStructValue(connection)
   155  			c.definition.Spec.ComponentSpecification.Fields["properties"] = structpb.NewStructValue(connectionPropStruct)
   156  		}
   157  	}
   158  
   159  	c.definition.Spec.DataSpecifications, err = generateDataSpecs(taskStructs)
   160  	if err != nil {
   161  		return err
   162  	}
   163  
   164  	c.initCredentialField(c.definition)
   165  
   166  	return nil
   167  
   168  }
   169  
   170  func (c *BaseConnector) refineResourceSpec(resourceSpec *structpb.Struct) (*structpb.Struct, error) {
   171  
   172  	spec := proto.Clone(resourceSpec).(*structpb.Struct)
   173  	if _, ok := spec.Fields["instillShortDescription"]; !ok {
   174  		spec.Fields["instillShortDescription"] = structpb.NewStringValue(spec.Fields["description"].GetStringValue())
   175  	}
   176  
   177  	if _, ok := spec.Fields["properties"]; ok {
   178  		for k, v := range spec.Fields["properties"].GetStructValue().AsMap() {
   179  			s, err := structpb.NewStruct(v.(map[string]interface{}))
   180  			if err != nil {
   181  				return nil, err
   182  			}
   183  			converted, err := c.refineResourceSpec(s)
   184  			if err != nil {
   185  				return nil, err
   186  			}
   187  			spec.Fields["properties"].GetStructValue().Fields[k] = structpb.NewStructValue(converted)
   188  
   189  		}
   190  	}
   191  	if _, ok := spec.Fields["patternProperties"]; ok {
   192  		for k, v := range spec.Fields["patternProperties"].GetStructValue().AsMap() {
   193  			s, err := structpb.NewStruct(v.(map[string]interface{}))
   194  			if err != nil {
   195  				return nil, err
   196  			}
   197  			converted, err := c.refineResourceSpec(s)
   198  			if err != nil {
   199  				return nil, err
   200  			}
   201  			spec.Fields["patternProperties"].GetStructValue().Fields[k] = structpb.NewStructValue(converted)
   202  
   203  		}
   204  	}
   205  	for _, target := range []string{"allOf", "anyOf", "oneOf"} {
   206  		if _, ok := spec.Fields[target]; ok {
   207  			for idx, item := range spec.Fields[target].GetListValue().AsSlice() {
   208  				s, err := structpb.NewStruct(item.(map[string]interface{}))
   209  				if err != nil {
   210  					return nil, err
   211  				}
   212  				converted, err := c.refineResourceSpec(s)
   213  				if err != nil {
   214  					return nil, err
   215  				}
   216  				spec.Fields[target].GetListValue().AsSlice()[idx] = structpb.NewStructValue(converted)
   217  			}
   218  		}
   219  	}
   220  
   221  	return spec, nil
   222  }
   223  
   224  // IsCredentialField checks if the target field is credential field
   225  func (c *BaseConnector) IsCredentialField(target string) bool {
   226  	for _, field := range c.credentialFields {
   227  		if target == field {
   228  			return true
   229  		}
   230  	}
   231  	return false
   232  }
   233  
   234  // ListCredentialField lists the credential fields by definition id
   235  func (c *BaseConnector) ListCredentialField() ([]string, error) {
   236  	return c.credentialFields, nil
   237  }
   238  
   239  func (c *BaseConnector) initCredentialField(def *pipelinePB.ConnectorDefinition) {
   240  	if c.credentialFields == nil {
   241  		c.credentialFields = []string{}
   242  	}
   243  	credentialFields := []string{}
   244  	connection := def.Spec.GetComponentSpecification().GetFields()["properties"].GetStructValue().GetFields()["connection"].GetStructValue()
   245  	credentialFields = c.traverseCredentialField(connection.GetFields()["properties"], "", credentialFields)
   246  	if l, ok := connection.GetFields()["oneOf"]; ok {
   247  		for _, v := range l.GetListValue().Values {
   248  			credentialFields = c.traverseCredentialField(v.GetStructValue().GetFields()["properties"], "", credentialFields)
   249  		}
   250  	}
   251  	c.credentialFields = credentialFields
   252  }
   253  
   254  func (c *BaseConnector) traverseCredentialField(input *structpb.Value, prefix string, credentialFields []string) []string {
   255  	for key, v := range input.GetStructValue().GetFields() {
   256  		if isCredential, ok := v.GetStructValue().GetFields()["instillCredentialField"]; ok {
   257  			if isCredential.GetBoolValue() || isCredential.GetStringValue() == "true" {
   258  				credentialFields = append(credentialFields, fmt.Sprintf("%s%s", prefix, key))
   259  			}
   260  		}
   261  		if tp, ok := v.GetStructValue().GetFields()["type"]; ok {
   262  			if tp.GetStringValue() == "object" {
   263  				if l, ok := v.GetStructValue().GetFields()["oneOf"]; ok {
   264  					for _, v := range l.GetListValue().Values {
   265  						credentialFields = c.traverseCredentialField(v.GetStructValue().GetFields()["properties"], fmt.Sprintf("%s%s.", prefix, key), credentialFields)
   266  					}
   267  				}
   268  				credentialFields = c.traverseCredentialField(v.GetStructValue().GetFields()["properties"], fmt.Sprintf("%s%s.", prefix, key), credentialFields)
   269  			}
   270  
   271  		}
   272  	}
   273  
   274  	return credentialFields
   275  }
   276  
   277  func (e *BaseConnectorExecution) GetTask() string {
   278  	return e.Task
   279  }
   280  func (e *BaseConnectorExecution) GetConnector() IConnector {
   281  	return e.Connector
   282  }
   283  func (e *BaseConnectorExecution) GetConnection() *structpb.Struct {
   284  	return e.Connection
   285  }
   286  func (e *BaseConnectorExecution) GetSystemVariables() map[string]any {
   287  	return e.SystemVariables
   288  }
   289  func (e *BaseConnectorExecution) GetLogger() *zap.Logger {
   290  	return e.Connector.GetLogger()
   291  }
   292  func (e *BaseConnectorExecution) GetUsageHandler() UsageHandler {
   293  	return e.Connector.GetUsageHandler()
   294  }
   295  func (e *BaseConnectorExecution) GetTaskInputSchema() string {
   296  	return e.Connector.GetTaskInputSchemas()[e.Task]
   297  }
   298  func (e *BaseConnectorExecution) GetTaskOutputSchema() string {
   299  	return e.Connector.GetTaskOutputSchemas()[e.Task]
   300  }