github.com/instill-ai/component@v0.16.0-beta/pkg/connector/instill/v0/text_to_image.go (about) 1 package instill 2 3 import ( 4 "context" 5 "fmt" 6 7 "google.golang.org/grpc/metadata" 8 "google.golang.org/protobuf/encoding/protojson" 9 "google.golang.org/protobuf/types/known/structpb" 10 11 "github.com/instill-ai/component/pkg/base" 12 modelPB "github.com/instill-ai/protogen-go/model/model/v1alpha" 13 ) 14 15 func (e *execution) executeTextToImage(grpcClient modelPB.ModelPublicServiceClient, modelName string, inputs []*structpb.Struct) ([]*structpb.Struct, error) { 16 if len(inputs) <= 0 { 17 return nil, fmt.Errorf("invalid input: %v for model: %s", inputs, modelName) 18 } 19 20 if grpcClient == nil { 21 return nil, fmt.Errorf("uninitialized client") 22 } 23 24 outputs := []*structpb.Struct{} 25 for _, input := range inputs { 26 textToImageInput := &modelPB.TextToImageInput{ 27 Prompt: input.GetFields()["prompt"].GetStringValue(), 28 } 29 if _, ok := input.GetFields()["steps"]; ok { 30 v := int32(input.GetFields()["steps"].GetNumberValue()) 31 textToImageInput.Steps = &v 32 } 33 if _, ok := input.GetFields()["image_base64"]; ok { 34 textToImageInput.Type = &modelPB.TextToImageInput_PromptImageBase64{ 35 PromptImageBase64: base.TrimBase64Mime(input.GetFields()["image_base64"].GetStringValue()), 36 } 37 } 38 if _, ok := input.GetFields()["cfg_scale"]; ok { 39 v := float32(input.GetFields()["cfg_scale"].GetNumberValue()) 40 textToImageInput.CfgScale = &v 41 } 42 if _, ok := input.GetFields()["samples"]; ok { 43 v := int32(input.GetFields()["samples"].GetNumberValue()) 44 textToImageInput.Samples = &v 45 } 46 if _, ok := input.GetFields()["seed"]; ok { 47 v := int32(input.GetFields()["seed"].GetNumberValue()) 48 textToImageInput.Seed = &v 49 } 50 taskInput := &modelPB.TaskInput_TextToImage{ 51 TextToImage: textToImageInput, 52 } 53 54 // only support batch 1 55 req := modelPB.TriggerUserModelRequest{ 56 Name: modelName, 57 TaskInputs: []*modelPB.TaskInput{{Input: taskInput}}, 58 } 59 ctx := metadata.NewOutgoingContext(context.Background(), getRequestMetadata(e.SystemVariables)) 60 res, err := grpcClient.TriggerUserModel(ctx, &req) 61 if err != nil || res == nil { 62 return nil, err 63 } 64 taskOutputs := res.GetTaskOutputs() 65 if len(taskOutputs) <= 0 { 66 return nil, fmt.Errorf("invalid output: %v for model: %s", taskOutputs, modelName) 67 } 68 69 textToImgOutput := taskOutputs[0].GetTextToImage() 70 71 for imageIdx := range textToImgOutput.Images { 72 textToImgOutput.Images[imageIdx] = fmt.Sprintf("data:image/jpeg;base64,%s", textToImgOutput.Images[imageIdx]) 73 } 74 75 if textToImgOutput == nil || len(textToImgOutput.Images) <= 0 { 76 return nil, fmt.Errorf("invalid output: %v for model: %s", textToImgOutput, modelName) 77 } 78 79 outputJSON, err := protojson.MarshalOptions{ 80 UseProtoNames: true, 81 EmitUnpopulated: true, 82 }.Marshal(textToImgOutput) 83 if err != nil { 84 return nil, err 85 } 86 output := &structpb.Struct{} 87 err = protojson.Unmarshal(outputJSON, output) 88 if err != nil { 89 return nil, err 90 } 91 outputs = append(outputs, output) 92 } 93 return outputs, nil 94 }