github.com/instill-ai/component@v0.16.0-beta/pkg/connector/instill/v0/image_to_image.go (about)

     1  package instill
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  
     7  	"google.golang.org/grpc/metadata"
     8  	"google.golang.org/protobuf/encoding/protojson"
     9  	"google.golang.org/protobuf/types/known/structpb"
    10  
    11  	"github.com/instill-ai/component/pkg/base"
    12  	modelPB "github.com/instill-ai/protogen-go/model/model/v1alpha"
    13  )
    14  
    15  func (e *execution) executeImageToImage(grpcClient modelPB.ModelPublicServiceClient, modelName string, inputs []*structpb.Struct) ([]*structpb.Struct, error) {
    16  	if len(inputs) <= 0 {
    17  		return nil, fmt.Errorf("invalid input: %v for model: %s", inputs, modelName)
    18  	}
    19  
    20  	if grpcClient == nil {
    21  		return nil, fmt.Errorf("uninitialized client")
    22  	}
    23  
    24  	outputs := []*structpb.Struct{}
    25  
    26  	for _, input := range inputs {
    27  
    28  		prompt := input.GetFields()["prompt"].GetStringValue()
    29  		imageToImageInput := &modelPB.ImageToImageInput{
    30  			Prompt: &prompt,
    31  		}
    32  		if _, ok := input.GetFields()["steps"]; ok {
    33  			v := int32(input.GetFields()["steps"].GetNumberValue())
    34  			imageToImageInput.Steps = &v
    35  		}
    36  		if _, ok := input.GetFields()["image_base64"]; ok {
    37  			imageToImageInput.Type = &modelPB.ImageToImageInput_PromptImageBase64{
    38  				PromptImageBase64: base.TrimBase64Mime(input.GetFields()["image_base64"].GetStringValue()),
    39  			}
    40  		}
    41  		if _, ok := input.GetFields()["temperature"]; ok {
    42  			v := int32(input.GetFields()["temperature"].GetNumberValue())
    43  			imageToImageInput.Seed = &v
    44  		}
    45  		if _, ok := input.GetFields()["cfg_scale"]; ok {
    46  			v := float32(input.GetFields()["cfg_scale"].GetNumberValue())
    47  			imageToImageInput.CfgScale = &v
    48  		}
    49  
    50  		if _, ok := input.GetFields()["seed"]; ok {
    51  			v := int32(input.GetFields()["seed"].GetNumberValue())
    52  			imageToImageInput.Seed = &v
    53  		}
    54  		if _, ok := input.GetFields()["extra_params"]; ok {
    55  			v := input.GetFields()["extra_params"].GetStructValue()
    56  			imageToImageInput.ExtraParams = v
    57  		}
    58  
    59  		taskInput := &modelPB.TaskInput_ImageToImage{
    60  			ImageToImage: imageToImageInput,
    61  		}
    62  
    63  		// only support batch 1
    64  		req := modelPB.TriggerUserModelRequest{
    65  			Name:       modelName,
    66  			TaskInputs: []*modelPB.TaskInput{{Input: taskInput}},
    67  		}
    68  		ctx := metadata.NewOutgoingContext(context.Background(), getRequestMetadata(e.SystemVariables))
    69  		res, err := grpcClient.TriggerUserModel(ctx, &req)
    70  		if err != nil || res == nil {
    71  			return nil, err
    72  		}
    73  		taskOutputs := res.GetTaskOutputs()
    74  		if len(taskOutputs) <= 0 {
    75  			return nil, fmt.Errorf("invalid output: %v for model: %s", taskOutputs, modelName)
    76  		}
    77  
    78  		imageToImageOutput := taskOutputs[0].GetImageToImage()
    79  		if imageToImageOutput == nil {
    80  			return nil, fmt.Errorf("invalid output: %v for model: %s", imageToImageOutput, modelName)
    81  		}
    82  		for imageIdx := range imageToImageOutput.Images {
    83  			imageToImageOutput.Images[imageIdx] = fmt.Sprintf("data:image/jpeg;base64,%s", imageToImageOutput.Images[imageIdx])
    84  		}
    85  
    86  		outputJSON, err := protojson.MarshalOptions{
    87  			UseProtoNames:   true,
    88  			EmitUnpopulated: true,
    89  		}.Marshal(imageToImageOutput)
    90  		if err != nil {
    91  			return nil, err
    92  		}
    93  		output := &structpb.Struct{}
    94  		err = protojson.Unmarshal(outputJSON, output)
    95  		if err != nil {
    96  			return nil, err
    97  		}
    98  		outputs = append(outputs, output)
    99  
   100  	}
   101  	return outputs, nil
   102  }