github.com/instill-ai/component@v0.16.0-beta/pkg/connector/instill/v0/semantic_segmentation.go (about)

     1  package instill
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  
     7  	"google.golang.org/grpc/metadata"
     8  	"google.golang.org/protobuf/encoding/protojson"
     9  	"google.golang.org/protobuf/types/known/structpb"
    10  
    11  	"github.com/instill-ai/component/pkg/base"
    12  	modelPB "github.com/instill-ai/protogen-go/model/model/v1alpha"
    13  )
    14  
    15  func (e *execution) executeSemanticSegmentation(grpcClient modelPB.ModelPublicServiceClient, modelName string, inputs []*structpb.Struct) ([]*structpb.Struct, error) {
    16  	if len(inputs) <= 0 {
    17  		return nil, fmt.Errorf("invalid input: %v for model: %s", inputs, modelName)
    18  	}
    19  	if grpcClient == nil {
    20  		return nil, fmt.Errorf("uninitialized client")
    21  	}
    22  	taskInputs := []*modelPB.TaskInput{}
    23  	for _, input := range inputs {
    24  		inputJSON, err := protojson.Marshal(input)
    25  		if err != nil {
    26  			return nil, err
    27  		}
    28  		semanticSegmentationInput := &modelPB.SemanticSegmentationInput{}
    29  		err = protojson.UnmarshalOptions{DiscardUnknown: true}.Unmarshal(inputJSON, semanticSegmentationInput)
    30  		if err != nil {
    31  			return nil, err
    32  		}
    33  		semanticSegmentationInput.Type = &modelPB.SemanticSegmentationInput_ImageBase64{
    34  			ImageBase64: base.TrimBase64Mime(semanticSegmentationInput.GetImageBase64()),
    35  		}
    36  
    37  		taskInput := &modelPB.TaskInput_SemanticSegmentation{
    38  			SemanticSegmentation: semanticSegmentationInput,
    39  		}
    40  		taskInputs = append(taskInputs, &modelPB.TaskInput{Input: taskInput})
    41  
    42  	}
    43  
    44  	req := modelPB.TriggerUserModelRequest{
    45  		Name:       modelName,
    46  		TaskInputs: taskInputs,
    47  	}
    48  	ctx := metadata.NewOutgoingContext(context.Background(), getRequestMetadata(e.SystemVariables))
    49  	res, err := grpcClient.TriggerUserModel(ctx, &req)
    50  	if err != nil || res == nil {
    51  		return nil, err
    52  	}
    53  	taskOutputs := res.GetTaskOutputs()
    54  	if len(taskOutputs) <= 0 {
    55  		return nil, fmt.Errorf("invalid output: %v for model: %s", taskOutputs, modelName)
    56  	}
    57  
    58  	outputs := []*structpb.Struct{}
    59  	for idx := range inputs {
    60  		semanticSegmentationOp := taskOutputs[idx].GetSemanticSegmentation()
    61  		if semanticSegmentationOp == nil {
    62  			return nil, fmt.Errorf("invalid output: %v for model: %s", semanticSegmentationOp, modelName)
    63  		}
    64  		outputJSON, err := protojson.MarshalOptions{
    65  			UseProtoNames:   true,
    66  			EmitUnpopulated: true,
    67  		}.Marshal(semanticSegmentationOp)
    68  		if err != nil {
    69  			return nil, err
    70  		}
    71  		output := &structpb.Struct{}
    72  		err = protojson.Unmarshal(outputJSON, output)
    73  		if err != nil {
    74  			return nil, err
    75  		}
    76  		outputs = append(outputs, output)
    77  	}
    78  	return outputs, nil
    79  }