github.com/instill-ai/component@v0.16.0-beta/pkg/connector/instill/v0/instance_segmentation.go (about) 1 package instill 2 3 import ( 4 "context" 5 "fmt" 6 7 "google.golang.org/grpc/metadata" 8 "google.golang.org/protobuf/encoding/protojson" 9 "google.golang.org/protobuf/types/known/structpb" 10 11 "github.com/instill-ai/component/pkg/base" 12 modelPB "github.com/instill-ai/protogen-go/model/model/v1alpha" 13 ) 14 15 func (e *execution) executeInstanceSegmentation(grpcClient modelPB.ModelPublicServiceClient, modelName string, inputs []*structpb.Struct) ([]*structpb.Struct, error) { 16 if len(inputs) <= 0 { 17 return nil, fmt.Errorf("invalid input: %v for model: %s", inputs, modelName) 18 } 19 20 if grpcClient == nil { 21 return nil, fmt.Errorf("uninitialized client") 22 } 23 24 taskInputs := []*modelPB.TaskInput{} 25 for _, input := range inputs { 26 inputJSON, err := protojson.Marshal(input) 27 if err != nil { 28 return nil, err 29 } 30 segmentationInput := &modelPB.InstanceSegmentationInput{} 31 err = protojson.UnmarshalOptions{DiscardUnknown: true}.Unmarshal(inputJSON, segmentationInput) 32 if err != nil { 33 return nil, err 34 } 35 segmentationInput.Type = &modelPB.InstanceSegmentationInput_ImageBase64{ 36 ImageBase64: base.TrimBase64Mime(segmentationInput.GetImageBase64()), 37 } 38 39 taskInput := &modelPB.TaskInput_InstanceSegmentation{ 40 InstanceSegmentation: segmentationInput, 41 } 42 taskInputs = append(taskInputs, &modelPB.TaskInput{Input: taskInput}) 43 } 44 req := modelPB.TriggerUserModelRequest{ 45 Name: modelName, 46 TaskInputs: taskInputs, 47 } 48 ctx := metadata.NewOutgoingContext(context.Background(), getRequestMetadata(e.SystemVariables)) 49 res, err := grpcClient.TriggerUserModel(ctx, &req) 50 if err != nil || res == nil { 51 return nil, err 52 } 53 taskOutputs := res.GetTaskOutputs() 54 if len(taskOutputs) <= 0 { 55 return nil, fmt.Errorf("invalid output: %v for model: %s", taskOutputs, modelName) 56 } 57 58 outputs := []*structpb.Struct{} 59 for idx := range inputs { 60 instanceSegmentationOp := taskOutputs[idx].GetInstanceSegmentation() 61 if instanceSegmentationOp == nil { 62 return nil, fmt.Errorf("invalid output: %v for model: %s", instanceSegmentationOp, modelName) 63 } 64 outputJSON, err := protojson.MarshalOptions{ 65 UseProtoNames: true, 66 EmitUnpopulated: true, 67 }.Marshal(instanceSegmentationOp) 68 if err != nil { 69 return nil, err 70 } 71 output := &structpb.Struct{} 72 err = protojson.Unmarshal(outputJSON, output) 73 if err != nil { 74 return nil, err 75 } 76 outputs = append(outputs, output) 77 78 } 79 80 return outputs, nil 81 }