github.com/instill-ai/component@v0.16.0-beta/pkg/connector/instill/v0/image_to_image.go (about) 1 package instill 2 3 import ( 4 "context" 5 "fmt" 6 7 "google.golang.org/grpc/metadata" 8 "google.golang.org/protobuf/encoding/protojson" 9 "google.golang.org/protobuf/types/known/structpb" 10 11 "github.com/instill-ai/component/pkg/base" 12 modelPB "github.com/instill-ai/protogen-go/model/model/v1alpha" 13 ) 14 15 func (e *execution) executeImageToImage(grpcClient modelPB.ModelPublicServiceClient, modelName string, inputs []*structpb.Struct) ([]*structpb.Struct, error) { 16 if len(inputs) <= 0 { 17 return nil, fmt.Errorf("invalid input: %v for model: %s", inputs, modelName) 18 } 19 20 if grpcClient == nil { 21 return nil, fmt.Errorf("uninitialized client") 22 } 23 24 outputs := []*structpb.Struct{} 25 26 for _, input := range inputs { 27 28 prompt := input.GetFields()["prompt"].GetStringValue() 29 imageToImageInput := &modelPB.ImageToImageInput{ 30 Prompt: &prompt, 31 } 32 if _, ok := input.GetFields()["steps"]; ok { 33 v := int32(input.GetFields()["steps"].GetNumberValue()) 34 imageToImageInput.Steps = &v 35 } 36 if _, ok := input.GetFields()["image_base64"]; ok { 37 imageToImageInput.Type = &modelPB.ImageToImageInput_PromptImageBase64{ 38 PromptImageBase64: base.TrimBase64Mime(input.GetFields()["image_base64"].GetStringValue()), 39 } 40 } 41 if _, ok := input.GetFields()["temperature"]; ok { 42 v := int32(input.GetFields()["temperature"].GetNumberValue()) 43 imageToImageInput.Seed = &v 44 } 45 if _, ok := input.GetFields()["cfg_scale"]; ok { 46 v := float32(input.GetFields()["cfg_scale"].GetNumberValue()) 47 imageToImageInput.CfgScale = &v 48 } 49 50 if _, ok := input.GetFields()["seed"]; ok { 51 v := int32(input.GetFields()["seed"].GetNumberValue()) 52 imageToImageInput.Seed = &v 53 } 54 if _, ok := input.GetFields()["extra_params"]; ok { 55 v := input.GetFields()["extra_params"].GetStructValue() 56 imageToImageInput.ExtraParams = v 57 } 58 59 taskInput := &modelPB.TaskInput_ImageToImage{ 60 ImageToImage: imageToImageInput, 61 } 62 63 // only support batch 1 64 req := modelPB.TriggerUserModelRequest{ 65 Name: modelName, 66 TaskInputs: []*modelPB.TaskInput{{Input: taskInput}}, 67 } 68 ctx := metadata.NewOutgoingContext(context.Background(), getRequestMetadata(e.SystemVariables)) 69 res, err := grpcClient.TriggerUserModel(ctx, &req) 70 if err != nil || res == nil { 71 return nil, err 72 } 73 taskOutputs := res.GetTaskOutputs() 74 if len(taskOutputs) <= 0 { 75 return nil, fmt.Errorf("invalid output: %v for model: %s", taskOutputs, modelName) 76 } 77 78 imageToImageOutput := taskOutputs[0].GetImageToImage() 79 if imageToImageOutput == nil { 80 return nil, fmt.Errorf("invalid output: %v for model: %s", imageToImageOutput, modelName) 81 } 82 for imageIdx := range imageToImageOutput.Images { 83 imageToImageOutput.Images[imageIdx] = fmt.Sprintf("data:image/jpeg;base64,%s", imageToImageOutput.Images[imageIdx]) 84 } 85 86 outputJSON, err := protojson.MarshalOptions{ 87 UseProtoNames: true, 88 EmitUnpopulated: true, 89 }.Marshal(imageToImageOutput) 90 if err != nil { 91 return nil, err 92 } 93 output := &structpb.Struct{} 94 err = protojson.Unmarshal(outputJSON, output) 95 if err != nil { 96 return nil, err 97 } 98 outputs = append(outputs, output) 99 100 } 101 return outputs, nil 102 }