github.com/instill-ai/component@v0.16.0-beta/pkg/connector/stabilityai/v0/image_to_image.go (about)

     1  package stabilityai
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"mime/multipart"
     7  
     8  	"github.com/instill-ai/component/pkg/base"
     9  	"github.com/instill-ai/component/pkg/connector/util"
    10  	"google.golang.org/protobuf/types/known/structpb"
    11  )
    12  
    13  const imageToImagePathTemplate = "/v1/generation/%s/image-to-image"
    14  
    15  func imageToImagePath(engine string) string {
    16  	return fmt.Sprintf(imageToImagePathTemplate, engine)
    17  }
    18  
    19  type ImageToImageInput struct {
    20  	Task               string     `json:"task"`
    21  	Engine             string     `json:"engine"`
    22  	Prompts            []string   `json:"prompts"`
    23  	InitImage          string     `json:"init_image"`
    24  	Weights            *[]float64 `json:"weights,omitempty"`
    25  	InitImageMode      *string    `json:"init_image_mode,omitempty"`
    26  	ImageStrength      *float64   `json:"image_strength,omitempty"`
    27  	StepScheduleStart  *float64   `json:"step_schedule_start,omitempty"`
    28  	StepScheduleEnd    *float64   `json:"step_schedule_end,omitempty"`
    29  	CfgScale           *float64   `json:"cfg_scale,omitempty"`
    30  	ClipGuidancePreset *string    `json:"clip_guidance_preset,omitempty"`
    31  	Sampler            *string    `json:"sampler,omitempty"`
    32  	Samples            *uint32    `json:"samples,omitempty"`
    33  	Seed               *uint32    `json:"seed,omitempty"`
    34  	Steps              *uint32    `json:"steps,omitempty"`
    35  	StylePreset        *string    `json:"style_preset,omitempty"`
    36  }
    37  
    38  type ImageToImageOutput struct {
    39  	Images []string `json:"images"`
    40  	Seeds  []uint32 `json:"seeds"`
    41  }
    42  
    43  // ImageToImageReq represents the request body for image-to-image API
    44  type ImageToImageReq struct {
    45  	TextPrompts        []TextPrompt `json:"text_prompts" om:"texts[:]"`
    46  	InitImage          string       `json:"init_image" om:"images[0]"`
    47  	CFGScale           *float64     `json:"cfg_scale,omitempty" om:"metadata.cfg_scale"`
    48  	ClipGuidancePreset *string      `json:"clip_guidance_preset,omitempty" om:"metadata.clip_guidance_preset"`
    49  	Sampler            *string      `json:"sampler,omitempty" om:"metadata.sampler"`
    50  	Samples            *uint32      `json:"samples,omitempty" om:"metadata.samples"`
    51  	Seed               *uint32      `json:"seed,omitempty" om:"metadata.seed"`
    52  	Steps              *uint32      `json:"steps,omitempty" om:"metadata.steps"`
    53  	StylePreset        *string      `json:"style_preset,omitempty" om:"metadata.style_preset"`
    54  	InitImageMode      *string      `json:"init_image_mode,omitempty" om:"metadata.init_image_mode"`
    55  	ImageStrength      *float64     `json:"image_strength,omitempty" om:"metadata.image_strength"`
    56  	StepScheduleStart  *float64     `json:"step_schedule_start,omitempty" om:"metadata.step_schedule_start"`
    57  	StepScheduleEnd    *float64     `json:"step_schedule_end,omitempty" om:"metadata.step_schedule_end"`
    58  
    59  	path string
    60  }
    61  
    62  func parseImageToImageReq(from *structpb.Struct) (ImageToImageReq, error) {
    63  	// Parse from pb.
    64  	input := ImageToImageInput{}
    65  	if err := base.ConvertFromStructpb(from, &input); err != nil {
    66  		return ImageToImageReq{}, err
    67  	}
    68  
    69  	// Validate input.
    70  	nPrompts := len(input.Prompts)
    71  	if nPrompts <= 0 {
    72  		return ImageToImageReq{}, fmt.Errorf("no text prompts given")
    73  	}
    74  
    75  	if input.Engine == "" {
    76  		return ImageToImageReq{}, fmt.Errorf("no engine selected")
    77  	}
    78  
    79  	// Convert to req.
    80  	req := ImageToImageReq{
    81  		InitImage:          input.InitImage,
    82  		InitImageMode:      input.InitImageMode,
    83  		ImageStrength:      input.ImageStrength,
    84  		StepScheduleStart:  input.StepScheduleStart,
    85  		StepScheduleEnd:    input.StepScheduleEnd,
    86  		CFGScale:           input.CfgScale,
    87  		ClipGuidancePreset: input.ClipGuidancePreset,
    88  		Sampler:            input.Sampler,
    89  		Samples:            input.Samples,
    90  		Seed:               input.Seed,
    91  		Steps:              input.Steps,
    92  		StylePreset:        input.StylePreset,
    93  
    94  		path: imageToImagePath(input.Engine),
    95  	}
    96  
    97  	req.TextPrompts = make([]TextPrompt, 0, nPrompts)
    98  	for index, t := range input.Prompts {
    99  		var w float64
   100  		if input.Weights != nil && len(*input.Weights) > index {
   101  			w = (*input.Weights)[index]
   102  		}
   103  
   104  		req.TextPrompts = append(req.TextPrompts, TextPrompt{
   105  			Text:   t,
   106  			Weight: &w,
   107  		})
   108  	}
   109  
   110  	return req, nil
   111  }
   112  
   113  func (req ImageToImageReq) getBytes() (b *bytes.Reader, contentType string, err error) {
   114  	data := &bytes.Buffer{}
   115  	initImage, err := util.DecodeBase64(req.InitImage)
   116  	if err != nil {
   117  		return nil, "", err
   118  	}
   119  	writer := multipart.NewWriter(data)
   120  	err = util.WriteFile(writer, "init_image", initImage)
   121  	if err != nil {
   122  		return nil, "", err
   123  	}
   124  	if req.CFGScale != nil {
   125  		util.WriteField(writer, "cfg_scale", fmt.Sprintf("%f", *req.CFGScale))
   126  	}
   127  	if req.ClipGuidancePreset != nil {
   128  		util.WriteField(writer, "clip_guidance_preset", *req.ClipGuidancePreset)
   129  	}
   130  	if req.Sampler != nil {
   131  		util.WriteField(writer, "sampler", *req.Sampler)
   132  	}
   133  	if req.Seed != nil {
   134  		util.WriteField(writer, "seed", fmt.Sprintf("%d", *req.Seed))
   135  	}
   136  	if req.StylePreset != nil {
   137  		util.WriteField(writer, "style_preset", *req.StylePreset)
   138  	}
   139  	if req.InitImageMode != nil {
   140  		util.WriteField(writer, "init_image_mode", *req.InitImageMode)
   141  	}
   142  	if req.ImageStrength != nil {
   143  		util.WriteField(writer, "image_strength", fmt.Sprintf("%f", *req.ImageStrength))
   144  	}
   145  	if req.Samples != nil {
   146  		util.WriteField(writer, "samples", fmt.Sprintf("%d", *req.Samples))
   147  	}
   148  	if req.Steps != nil {
   149  		util.WriteField(writer, "steps", fmt.Sprintf("%d", *req.Steps))
   150  	}
   151  
   152  	i := 0
   153  	for _, t := range req.TextPrompts {
   154  		if t.Text == "" {
   155  			continue
   156  		}
   157  		util.WriteField(writer, fmt.Sprintf("text_prompts[%d][text]", i), t.Text)
   158  		if t.Weight != nil {
   159  			util.WriteField(writer, fmt.Sprintf("text_prompts[%d][weight]", i), fmt.Sprintf("%f", *t.Weight))
   160  		}
   161  		i++
   162  	}
   163  	writer.Close()
   164  	return bytes.NewReader(data.Bytes()), writer.FormDataContentType(), nil
   165  }
   166  
   167  func imageToImageOutput(from ImageTaskRes) (*structpb.Struct, error) {
   168  	output := ImageToImageOutput{
   169  		Images: []string{},
   170  		Seeds:  []uint32{},
   171  	}
   172  
   173  	for _, image := range from.Images {
   174  		if image.FinishReason != successFinishReason {
   175  			continue
   176  		}
   177  
   178  		output.Images = append(output.Images, fmt.Sprintf("data:image/png;base64,%s", image.Base64))
   179  		output.Seeds = append(output.Seeds, image.Seed)
   180  
   181  	}
   182  	return base.ConvertToStructpb(output)
   183  }