github.com/instill-ai/component@v0.16.0-beta/pkg/connector/stabilityai/v0/text_to_image.go (about) 1 package stabilityai 2 3 import ( 4 "fmt" 5 6 "github.com/instill-ai/component/pkg/base" 7 "google.golang.org/protobuf/types/known/structpb" 8 ) 9 10 const ( 11 successFinishReason = "SUCCESS" 12 textToImagePathTemplate = "/v1/generation/%s/text-to-image" 13 ) 14 15 func textToImagePath(engine string) string { 16 return fmt.Sprintf(textToImagePathTemplate, engine) 17 } 18 19 type TextToImageInput struct { 20 Task string `json:"task"` 21 Prompts []string `json:"prompts"` 22 Engine string `json:"engine"` 23 Weights *[]float64 `json:"weights,omitempty"` 24 Height *uint32 `json:"height,omitempty"` 25 Width *uint32 `json:"width,omitempty"` 26 CfgScale *float64 `json:"cfg_scale,omitempty"` 27 ClipGuidancePreset *string `json:"clip_guidance_preset,omitempty"` 28 Sampler *string `json:"sampler,omitempty"` 29 Samples *uint32 `json:"samples,omitempty"` 30 Seed *uint32 `json:"seed,omitempty"` 31 Steps *uint32 `json:"steps,omitempty"` 32 StylePreset *string `json:"style_preset,omitempty"` 33 } 34 35 type TextToImageOutput struct { 36 Images []string `json:"images"` 37 Seeds []uint32 `json:"seeds"` 38 } 39 40 // TextToImageReq represents the request body for text-to-image API 41 type TextToImageReq struct { 42 TextPrompts []TextPrompt `json:"text_prompts" om:"texts[:]"` 43 CFGScale *float64 `json:"cfg_scale,omitempty" om:"metadata.cfg_scale"` 44 ClipGuidancePreset *string `json:"clip_guidance_preset,omitempty" om:"metadata.clip_guidance_preset"` 45 Sampler *string `json:"sampler,omitempty" om:"metadata.sampler"` 46 Samples *uint32 `json:"samples,omitempty" om:"metadata.samples"` 47 Seed *uint32 `json:"seed,omitempty" om:"metadata.seed"` 48 Steps *uint32 `json:"steps,omitempty" om:"metadata.steps"` 49 StylePreset *string `json:"style_preset,omitempty" om:"metadata.style_preset"` 50 Height *uint32 `json:"height,omitempty" om:"metadata.height"` 51 Width *uint32 `json:"width,omitempty" om:"metadata.width"` 52 53 path string 54 } 55 56 // TextPrompt holds a prompt's text and its weight. 57 type TextPrompt struct { 58 Text string `json:"text" om:"."` 59 Weight *float64 `json:"weight"` 60 } 61 62 // Image represents a single image. 63 type Image struct { 64 Base64 string `json:"base64"` 65 Seed uint32 `json:"seed"` 66 FinishReason string `json:"finishReason"` 67 } 68 69 // ImageTaskRes represents the response body for text-to-image API. 70 type ImageTaskRes struct { 71 Images []Image `json:"artifacts"` 72 } 73 74 func parseTextToImageReq(from *structpb.Struct) (TextToImageReq, error) { 75 // Parse from pb. 76 input := TextToImageInput{} 77 if err := base.ConvertFromStructpb(from, &input); err != nil { 78 return TextToImageReq{}, err 79 } 80 81 // Validate input. 82 nPrompts := len(input.Prompts) 83 if nPrompts <= 0 { 84 return TextToImageReq{}, fmt.Errorf("no text prompts given") 85 } 86 87 if input.Engine == "" { 88 return TextToImageReq{}, fmt.Errorf("no engine selected") 89 } 90 91 // Convert to req. 92 req := TextToImageReq{ 93 CFGScale: input.CfgScale, 94 ClipGuidancePreset: input.ClipGuidancePreset, 95 Sampler: input.Sampler, 96 Samples: input.Samples, 97 Seed: input.Seed, 98 Steps: input.Steps, 99 StylePreset: input.StylePreset, 100 Height: input.Height, 101 Width: input.Width, 102 103 path: textToImagePath(input.Engine), 104 } 105 106 req.TextPrompts = make([]TextPrompt, 0, nPrompts) 107 for index, t := range input.Prompts { 108 // If weight isn't provided, set to 1. 109 w := 1.0 110 if input.Weights != nil && len(*input.Weights) > index { 111 w = (*input.Weights)[index] 112 } 113 114 req.TextPrompts = append(req.TextPrompts, TextPrompt{ 115 Text: t, 116 Weight: &w, 117 }) 118 } 119 120 return req, nil 121 } 122 123 func textToImageOutput(from ImageTaskRes) (*structpb.Struct, error) { 124 output := TextToImageOutput{ 125 Images: []string{}, 126 Seeds: []uint32{}, 127 } 128 129 for _, image := range from.Images { 130 if image.FinishReason != successFinishReason { 131 continue 132 } 133 134 output.Images = append(output.Images, fmt.Sprintf("data:image/png;base64,%s", image.Base64)) 135 output.Seeds = append(output.Seeds, image.Seed) 136 } 137 138 return base.ConvertToStructpb(output) 139 }