github.com/instill-ai/component@v0.16.0-beta/pkg/base/operator.go (about)

     1  package base
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  
     7  	"go.uber.org/zap"
     8  	"google.golang.org/protobuf/encoding/protojson"
     9  
    10  	"github.com/gofrs/uuid"
    11  	pipelinePB "github.com/instill-ai/protogen-go/vdp/pipeline/v1beta"
    12  )
    13  
    14  type IOperator interface {
    15  	IComponent
    16  
    17  	LoadOperatorDefinition(definitionJSON []byte, tasksJSON []byte, additionalJSONBytes map[string][]byte) error
    18  
    19  	// Note: Some content in the definition JSON schema needs to be generated by sysVars or component setting.
    20  	GetOperatorDefinition(sysVars map[string]any, component *pipelinePB.OperatorComponent) (*pipelinePB.OperatorDefinition, error)
    21  
    22  	CreateExecution(sysVars map[string]any, task string) (*ExecutionWrapper, error)
    23  }
    24  
    25  type BaseOperator struct {
    26  	Logger       *zap.Logger
    27  	UsageHandler UsageHandler
    28  
    29  	taskInputSchemas  map[string]string
    30  	taskOutputSchemas map[string]string
    31  
    32  	definition *pipelinePB.OperatorDefinition
    33  }
    34  
    35  type IOperatorExecution interface {
    36  	IExecution
    37  
    38  	GetOperator() IOperator
    39  }
    40  
    41  type BaseOperatorExecution struct {
    42  	Operator        IOperator
    43  	SystemVariables map[string]any
    44  	Task            string
    45  }
    46  
    47  func (o *BaseOperator) GetID() string {
    48  	return o.definition.Id
    49  }
    50  func (o *BaseOperator) GetUID() uuid.UUID {
    51  	return uuid.FromStringOrNil(o.definition.Uid)
    52  }
    53  func (o *BaseOperator) GetLogger() *zap.Logger {
    54  	return o.Logger
    55  }
    56  func (o *BaseOperator) GetUsageHandler() UsageHandler {
    57  	return o.UsageHandler
    58  }
    59  func (o *BaseOperator) GetTaskInputSchemas() map[string]string {
    60  	return o.taskInputSchemas
    61  }
    62  func (o *BaseOperator) GetTaskOutputSchemas() map[string]string {
    63  	return o.taskOutputSchemas
    64  }
    65  
    66  func (o *BaseOperator) GetOperatorDefinition(sysVars map[string]any, component *pipelinePB.OperatorComponent) (*pipelinePB.OperatorDefinition, error) {
    67  	return o.definition, nil
    68  }
    69  
    70  // LoadOperatorDefinition loads the operator definitions from json files
    71  func (o *BaseOperator) LoadOperatorDefinition(definitionJSONBytes []byte, tasksJSONBytes []byte, additionalJSONBytes map[string][]byte) error {
    72  	var err error
    73  	var definitionJSON any
    74  
    75  	err = json.Unmarshal(definitionJSONBytes, &definitionJSON)
    76  	if err != nil {
    77  		return err
    78  	}
    79  	renderedTasksJSON, err := RenderJSON(tasksJSONBytes, additionalJSONBytes)
    80  	if err != nil {
    81  		return nil
    82  	}
    83  
    84  	availableTasks := []string{}
    85  	for _, availableTask := range definitionJSON.(map[string]interface{})["available_tasks"].([]interface{}) {
    86  		availableTasks = append(availableTasks, availableTask.(string))
    87  	}
    88  
    89  	tasks, taskStructs, err := loadTasks(availableTasks, renderedTasksJSON)
    90  	if err != nil {
    91  		return err
    92  	}
    93  
    94  	o.taskInputSchemas = map[string]string{}
    95  	o.taskOutputSchemas = map[string]string{}
    96  	for k := range taskStructs {
    97  		var s []byte
    98  		s, err = protojson.Marshal(taskStructs[k].Fields["input"].GetStructValue())
    99  		if err != nil {
   100  			return err
   101  		}
   102  		o.taskInputSchemas[k] = string(s)
   103  
   104  		s, err = protojson.Marshal(taskStructs[k].Fields["output"].GetStructValue())
   105  		if err != nil {
   106  			return err
   107  		}
   108  		o.taskOutputSchemas[k] = string(s)
   109  	}
   110  
   111  	o.definition = &pipelinePB.OperatorDefinition{}
   112  	err = protojson.UnmarshalOptions{DiscardUnknown: true}.Unmarshal(definitionJSONBytes, o.definition)
   113  	if err != nil {
   114  		return err
   115  	}
   116  
   117  	o.definition.Name = fmt.Sprintf("operator-definitions/%s", o.definition.Id)
   118  	o.definition.Tasks = tasks
   119  	o.definition.Spec.ComponentSpecification, err = generateComponentSpec(o.definition.Title, tasks, taskStructs)
   120  	if err != nil {
   121  		return err
   122  	}
   123  	o.definition.Spec.DataSpecifications, err = generateDataSpecs(taskStructs)
   124  	if err != nil {
   125  		return err
   126  	}
   127  
   128  	return nil
   129  }
   130  
   131  func (e *BaseOperatorExecution) GetTask() string {
   132  	return e.Task
   133  }
   134  func (e *BaseOperatorExecution) GetOperator() IOperator {
   135  	return e.Operator
   136  }
   137  func (e *BaseOperatorExecution) GetSystemVariables() map[string]any {
   138  	return e.SystemVariables
   139  }
   140  func (e *BaseOperatorExecution) GetLogger() *zap.Logger {
   141  	return e.Operator.GetLogger()
   142  }
   143  func (e *BaseOperatorExecution) GetUsageHandler() UsageHandler {
   144  	return e.Operator.GetUsageHandler()
   145  }
   146  func (e *BaseOperatorExecution) GetTaskInputSchema() string {
   147  	return e.Operator.GetTaskInputSchemas()[e.Task]
   148  }
   149  func (e *BaseOperatorExecution) GetTaskOutputSchema() string {
   150  	return e.Operator.GetTaskOutputSchemas()[e.Task]
   151  }