github.com/instill-ai/component@v0.16.0-beta/pkg/connector/stabilityai/v0/image_to_image.go (about) 1 package stabilityai 2 3 import ( 4 "bytes" 5 "fmt" 6 "mime/multipart" 7 8 "github.com/instill-ai/component/pkg/base" 9 "github.com/instill-ai/component/pkg/connector/util" 10 "google.golang.org/protobuf/types/known/structpb" 11 ) 12 13 const imageToImagePathTemplate = "/v1/generation/%s/image-to-image" 14 15 func imageToImagePath(engine string) string { 16 return fmt.Sprintf(imageToImagePathTemplate, engine) 17 } 18 19 type ImageToImageInput struct { 20 Task string `json:"task"` 21 Engine string `json:"engine"` 22 Prompts []string `json:"prompts"` 23 InitImage string `json:"init_image"` 24 Weights *[]float64 `json:"weights,omitempty"` 25 InitImageMode *string `json:"init_image_mode,omitempty"` 26 ImageStrength *float64 `json:"image_strength,omitempty"` 27 StepScheduleStart *float64 `json:"step_schedule_start,omitempty"` 28 StepScheduleEnd *float64 `json:"step_schedule_end,omitempty"` 29 CfgScale *float64 `json:"cfg_scale,omitempty"` 30 ClipGuidancePreset *string `json:"clip_guidance_preset,omitempty"` 31 Sampler *string `json:"sampler,omitempty"` 32 Samples *uint32 `json:"samples,omitempty"` 33 Seed *uint32 `json:"seed,omitempty"` 34 Steps *uint32 `json:"steps,omitempty"` 35 StylePreset *string `json:"style_preset,omitempty"` 36 } 37 38 type ImageToImageOutput struct { 39 Images []string `json:"images"` 40 Seeds []uint32 `json:"seeds"` 41 } 42 43 // ImageToImageReq represents the request body for image-to-image API 44 type ImageToImageReq struct { 45 TextPrompts []TextPrompt `json:"text_prompts" om:"texts[:]"` 46 InitImage string `json:"init_image" om:"images[0]"` 47 CFGScale *float64 `json:"cfg_scale,omitempty" om:"metadata.cfg_scale"` 48 ClipGuidancePreset *string `json:"clip_guidance_preset,omitempty" om:"metadata.clip_guidance_preset"` 49 Sampler *string `json:"sampler,omitempty" om:"metadata.sampler"` 50 Samples *uint32 `json:"samples,omitempty" om:"metadata.samples"` 51 Seed *uint32 `json:"seed,omitempty" om:"metadata.seed"` 52 Steps *uint32 `json:"steps,omitempty" om:"metadata.steps"` 53 StylePreset *string `json:"style_preset,omitempty" om:"metadata.style_preset"` 54 InitImageMode *string `json:"init_image_mode,omitempty" om:"metadata.init_image_mode"` 55 ImageStrength *float64 `json:"image_strength,omitempty" om:"metadata.image_strength"` 56 StepScheduleStart *float64 `json:"step_schedule_start,omitempty" om:"metadata.step_schedule_start"` 57 StepScheduleEnd *float64 `json:"step_schedule_end,omitempty" om:"metadata.step_schedule_end"` 58 59 path string 60 } 61 62 func parseImageToImageReq(from *structpb.Struct) (ImageToImageReq, error) { 63 // Parse from pb. 64 input := ImageToImageInput{} 65 if err := base.ConvertFromStructpb(from, &input); err != nil { 66 return ImageToImageReq{}, err 67 } 68 69 // Validate input. 70 nPrompts := len(input.Prompts) 71 if nPrompts <= 0 { 72 return ImageToImageReq{}, fmt.Errorf("no text prompts given") 73 } 74 75 if input.Engine == "" { 76 return ImageToImageReq{}, fmt.Errorf("no engine selected") 77 } 78 79 // Convert to req. 80 req := ImageToImageReq{ 81 InitImage: input.InitImage, 82 InitImageMode: input.InitImageMode, 83 ImageStrength: input.ImageStrength, 84 StepScheduleStart: input.StepScheduleStart, 85 StepScheduleEnd: input.StepScheduleEnd, 86 CFGScale: input.CfgScale, 87 ClipGuidancePreset: input.ClipGuidancePreset, 88 Sampler: input.Sampler, 89 Samples: input.Samples, 90 Seed: input.Seed, 91 Steps: input.Steps, 92 StylePreset: input.StylePreset, 93 94 path: imageToImagePath(input.Engine), 95 } 96 97 req.TextPrompts = make([]TextPrompt, 0, nPrompts) 98 for index, t := range input.Prompts { 99 var w float64 100 if input.Weights != nil && len(*input.Weights) > index { 101 w = (*input.Weights)[index] 102 } 103 104 req.TextPrompts = append(req.TextPrompts, TextPrompt{ 105 Text: t, 106 Weight: &w, 107 }) 108 } 109 110 return req, nil 111 } 112 113 func (req ImageToImageReq) getBytes() (b *bytes.Reader, contentType string, err error) { 114 data := &bytes.Buffer{} 115 initImage, err := util.DecodeBase64(req.InitImage) 116 if err != nil { 117 return nil, "", err 118 } 119 writer := multipart.NewWriter(data) 120 err = util.WriteFile(writer, "init_image", initImage) 121 if err != nil { 122 return nil, "", err 123 } 124 if req.CFGScale != nil { 125 util.WriteField(writer, "cfg_scale", fmt.Sprintf("%f", *req.CFGScale)) 126 } 127 if req.ClipGuidancePreset != nil { 128 util.WriteField(writer, "clip_guidance_preset", *req.ClipGuidancePreset) 129 } 130 if req.Sampler != nil { 131 util.WriteField(writer, "sampler", *req.Sampler) 132 } 133 if req.Seed != nil { 134 util.WriteField(writer, "seed", fmt.Sprintf("%d", *req.Seed)) 135 } 136 if req.StylePreset != nil { 137 util.WriteField(writer, "style_preset", *req.StylePreset) 138 } 139 if req.InitImageMode != nil { 140 util.WriteField(writer, "init_image_mode", *req.InitImageMode) 141 } 142 if req.ImageStrength != nil { 143 util.WriteField(writer, "image_strength", fmt.Sprintf("%f", *req.ImageStrength)) 144 } 145 if req.Samples != nil { 146 util.WriteField(writer, "samples", fmt.Sprintf("%d", *req.Samples)) 147 } 148 if req.Steps != nil { 149 util.WriteField(writer, "steps", fmt.Sprintf("%d", *req.Steps)) 150 } 151 152 i := 0 153 for _, t := range req.TextPrompts { 154 if t.Text == "" { 155 continue 156 } 157 util.WriteField(writer, fmt.Sprintf("text_prompts[%d][text]", i), t.Text) 158 if t.Weight != nil { 159 util.WriteField(writer, fmt.Sprintf("text_prompts[%d][weight]", i), fmt.Sprintf("%f", *t.Weight)) 160 } 161 i++ 162 } 163 writer.Close() 164 return bytes.NewReader(data.Bytes()), writer.FormDataContentType(), nil 165 } 166 167 func imageToImageOutput(from ImageTaskRes) (*structpb.Struct, error) { 168 output := ImageToImageOutput{ 169 Images: []string{}, 170 Seeds: []uint32{}, 171 } 172 173 for _, image := range from.Images { 174 if image.FinishReason != successFinishReason { 175 continue 176 } 177 178 output.Images = append(output.Images, fmt.Sprintf("data:image/png;base64,%s", image.Base64)) 179 output.Seeds = append(output.Seeds, image.Seed) 180 181 } 182 return base.ConvertToStructpb(output) 183 }