github.com/instill-ai/component@v0.16.0-beta/pkg/connector/instill/v0/image_classification.go (about) 1 package instill 2 3 import ( 4 "context" 5 "fmt" 6 7 "google.golang.org/grpc/metadata" 8 "google.golang.org/protobuf/encoding/protojson" 9 "google.golang.org/protobuf/types/known/structpb" 10 11 "github.com/instill-ai/component/pkg/base" 12 modelPB "github.com/instill-ai/protogen-go/model/model/v1alpha" 13 ) 14 15 func (e *execution) executeImageClassification(grpcClient modelPB.ModelPublicServiceClient, modelName string, inputs []*structpb.Struct) ([]*structpb.Struct, error) { 16 if len(inputs) <= 0 { 17 return nil, fmt.Errorf("invalid input: %v for model: %s", inputs, modelName) 18 } 19 20 if grpcClient == nil { 21 return nil, fmt.Errorf("uninitialized client") 22 } 23 24 taskInputs := []*modelPB.TaskInput{} 25 for _, input := range inputs { 26 27 inputJSON, err := protojson.Marshal(input) 28 if err != nil { 29 return nil, err 30 } 31 classificationInput := &modelPB.ClassificationInput{} 32 err = protojson.UnmarshalOptions{DiscardUnknown: true}.Unmarshal(inputJSON, classificationInput) 33 if err != nil { 34 return nil, err 35 } 36 classificationInput.Type = &modelPB.ClassificationInput_ImageBase64{ 37 ImageBase64: base.TrimBase64Mime(classificationInput.GetImageBase64()), 38 } 39 40 taskInput := &modelPB.TaskInput_Classification{ 41 Classification: classificationInput, 42 } 43 taskInputs = append(taskInputs, &modelPB.TaskInput{Input: taskInput}) 44 } 45 46 req := modelPB.TriggerUserModelRequest{ 47 Name: modelName, 48 TaskInputs: taskInputs, 49 } 50 ctx := metadata.NewOutgoingContext(context.Background(), getRequestMetadata(e.SystemVariables)) 51 res, err := grpcClient.TriggerUserModel(ctx, &req) 52 if err != nil || res == nil { 53 return nil, err 54 } 55 taskOutputs := res.GetTaskOutputs() 56 if len(taskOutputs) <= 0 { 57 return nil, fmt.Errorf("invalid output: %v for model: %s", taskOutputs, modelName) 58 } 59 outputs := []*structpb.Struct{} 60 for idx := range inputs { 61 imgClassificationOp := taskOutputs[idx].GetClassification() 62 if imgClassificationOp == nil { 63 return nil, fmt.Errorf("invalid output: %v for model: %s", imgClassificationOp, modelName) 64 } 65 outputJSON, err := protojson.MarshalOptions{ 66 UseProtoNames: true, 67 EmitUnpopulated: true, 68 }.Marshal(imgClassificationOp) 69 if err != nil { 70 return nil, err 71 } 72 output := &structpb.Struct{} 73 err = protojson.Unmarshal(outputJSON, output) 74 if err != nil { 75 return nil, err 76 } 77 outputs = append(outputs, output) 78 } 79 return outputs, nil 80 }