github.com/instill-ai/component@v0.16.0-beta/pkg/connector/instill/v0/text_to_image.go (about)

     1  package instill
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  
     7  	"google.golang.org/grpc/metadata"
     8  	"google.golang.org/protobuf/encoding/protojson"
     9  	"google.golang.org/protobuf/types/known/structpb"
    10  
    11  	"github.com/instill-ai/component/pkg/base"
    12  	modelPB "github.com/instill-ai/protogen-go/model/model/v1alpha"
    13  )
    14  
    15  func (e *execution) executeTextToImage(grpcClient modelPB.ModelPublicServiceClient, modelName string, inputs []*structpb.Struct) ([]*structpb.Struct, error) {
    16  	if len(inputs) <= 0 {
    17  		return nil, fmt.Errorf("invalid input: %v for model: %s", inputs, modelName)
    18  	}
    19  
    20  	if grpcClient == nil {
    21  		return nil, fmt.Errorf("uninitialized client")
    22  	}
    23  
    24  	outputs := []*structpb.Struct{}
    25  	for _, input := range inputs {
    26  		textToImageInput := &modelPB.TextToImageInput{
    27  			Prompt: input.GetFields()["prompt"].GetStringValue(),
    28  		}
    29  		if _, ok := input.GetFields()["steps"]; ok {
    30  			v := int32(input.GetFields()["steps"].GetNumberValue())
    31  			textToImageInput.Steps = &v
    32  		}
    33  		if _, ok := input.GetFields()["image_base64"]; ok {
    34  			textToImageInput.Type = &modelPB.TextToImageInput_PromptImageBase64{
    35  				PromptImageBase64: base.TrimBase64Mime(input.GetFields()["image_base64"].GetStringValue()),
    36  			}
    37  		}
    38  		if _, ok := input.GetFields()["cfg_scale"]; ok {
    39  			v := float32(input.GetFields()["cfg_scale"].GetNumberValue())
    40  			textToImageInput.CfgScale = &v
    41  		}
    42  		if _, ok := input.GetFields()["samples"]; ok {
    43  			v := int32(input.GetFields()["samples"].GetNumberValue())
    44  			textToImageInput.Samples = &v
    45  		}
    46  		if _, ok := input.GetFields()["seed"]; ok {
    47  			v := int32(input.GetFields()["seed"].GetNumberValue())
    48  			textToImageInput.Seed = &v
    49  		}
    50  		taskInput := &modelPB.TaskInput_TextToImage{
    51  			TextToImage: textToImageInput,
    52  		}
    53  
    54  		// only support batch 1
    55  		req := modelPB.TriggerUserModelRequest{
    56  			Name:       modelName,
    57  			TaskInputs: []*modelPB.TaskInput{{Input: taskInput}},
    58  		}
    59  		ctx := metadata.NewOutgoingContext(context.Background(), getRequestMetadata(e.SystemVariables))
    60  		res, err := grpcClient.TriggerUserModel(ctx, &req)
    61  		if err != nil || res == nil {
    62  			return nil, err
    63  		}
    64  		taskOutputs := res.GetTaskOutputs()
    65  		if len(taskOutputs) <= 0 {
    66  			return nil, fmt.Errorf("invalid output: %v for model: %s", taskOutputs, modelName)
    67  		}
    68  
    69  		textToImgOutput := taskOutputs[0].GetTextToImage()
    70  
    71  		for imageIdx := range textToImgOutput.Images {
    72  			textToImgOutput.Images[imageIdx] = fmt.Sprintf("data:image/jpeg;base64,%s", textToImgOutput.Images[imageIdx])
    73  		}
    74  
    75  		if textToImgOutput == nil || len(textToImgOutput.Images) <= 0 {
    76  			return nil, fmt.Errorf("invalid output: %v for model: %s", textToImgOutput, modelName)
    77  		}
    78  
    79  		outputJSON, err := protojson.MarshalOptions{
    80  			UseProtoNames:   true,
    81  			EmitUnpopulated: true,
    82  		}.Marshal(textToImgOutput)
    83  		if err != nil {
    84  			return nil, err
    85  		}
    86  		output := &structpb.Struct{}
    87  		err = protojson.Unmarshal(outputJSON, output)
    88  		if err != nil {
    89  			return nil, err
    90  		}
    91  		outputs = append(outputs, output)
    92  	}
    93  	return outputs, nil
    94  }