github.com/instill-ai/component@v0.16.0-beta/pkg/connector/stabilityai/v0/text_to_image.go (about)

     1  package stabilityai
     2  
import (
	"fmt"
	"net/url"

	"github.com/instill-ai/component/pkg/base"
	"google.golang.org/protobuf/types/known/structpb"
)
     9  
    10  const (
    11  	successFinishReason     = "SUCCESS"
    12  	textToImagePathTemplate = "/v1/generation/%s/text-to-image"
    13  )
    14  
    15  func textToImagePath(engine string) string {
    16  	return fmt.Sprintf(textToImagePathTemplate, engine)
    17  }
    18  
    19  type TextToImageInput struct {
    20  	Task               string     `json:"task"`
    21  	Prompts            []string   `json:"prompts"`
    22  	Engine             string     `json:"engine"`
    23  	Weights            *[]float64 `json:"weights,omitempty"`
    24  	Height             *uint32    `json:"height,omitempty"`
    25  	Width              *uint32    `json:"width,omitempty"`
    26  	CfgScale           *float64   `json:"cfg_scale,omitempty"`
    27  	ClipGuidancePreset *string    `json:"clip_guidance_preset,omitempty"`
    28  	Sampler            *string    `json:"sampler,omitempty"`
    29  	Samples            *uint32    `json:"samples,omitempty"`
    30  	Seed               *uint32    `json:"seed,omitempty"`
    31  	Steps              *uint32    `json:"steps,omitempty"`
    32  	StylePreset        *string    `json:"style_preset,omitempty"`
    33  }
    34  
    35  type TextToImageOutput struct {
    36  	Images []string `json:"images"`
    37  	Seeds  []uint32 `json:"seeds"`
    38  }
    39  
    40  // TextToImageReq represents the request body for text-to-image API
    41  type TextToImageReq struct {
    42  	TextPrompts        []TextPrompt `json:"text_prompts" om:"texts[:]"`
    43  	CFGScale           *float64     `json:"cfg_scale,omitempty" om:"metadata.cfg_scale"`
    44  	ClipGuidancePreset *string      `json:"clip_guidance_preset,omitempty" om:"metadata.clip_guidance_preset"`
    45  	Sampler            *string      `json:"sampler,omitempty" om:"metadata.sampler"`
    46  	Samples            *uint32      `json:"samples,omitempty" om:"metadata.samples"`
    47  	Seed               *uint32      `json:"seed,omitempty" om:"metadata.seed"`
    48  	Steps              *uint32      `json:"steps,omitempty" om:"metadata.steps"`
    49  	StylePreset        *string      `json:"style_preset,omitempty" om:"metadata.style_preset"`
    50  	Height             *uint32      `json:"height,omitempty" om:"metadata.height"`
    51  	Width              *uint32      `json:"width,omitempty" om:"metadata.width"`
    52  
    53  	path string
    54  }
    55  
    56  // TextPrompt holds a prompt's text and its weight.
    57  type TextPrompt struct {
    58  	Text   string   `json:"text" om:"."`
    59  	Weight *float64 `json:"weight"`
    60  }
    61  
    62  // Image represents a single image.
    63  type Image struct {
    64  	Base64       string `json:"base64"`
    65  	Seed         uint32 `json:"seed"`
    66  	FinishReason string `json:"finishReason"`
    67  }
    68  
    69  // ImageTaskRes represents the response body for text-to-image API.
    70  type ImageTaskRes struct {
    71  	Images []Image `json:"artifacts"`
    72  }
    73  
    74  func parseTextToImageReq(from *structpb.Struct) (TextToImageReq, error) {
    75  	// Parse from pb.
    76  	input := TextToImageInput{}
    77  	if err := base.ConvertFromStructpb(from, &input); err != nil {
    78  		return TextToImageReq{}, err
    79  	}
    80  
    81  	// Validate input.
    82  	nPrompts := len(input.Prompts)
    83  	if nPrompts <= 0 {
    84  		return TextToImageReq{}, fmt.Errorf("no text prompts given")
    85  	}
    86  
    87  	if input.Engine == "" {
    88  		return TextToImageReq{}, fmt.Errorf("no engine selected")
    89  	}
    90  
    91  	// Convert to req.
    92  	req := TextToImageReq{
    93  		CFGScale:           input.CfgScale,
    94  		ClipGuidancePreset: input.ClipGuidancePreset,
    95  		Sampler:            input.Sampler,
    96  		Samples:            input.Samples,
    97  		Seed:               input.Seed,
    98  		Steps:              input.Steps,
    99  		StylePreset:        input.StylePreset,
   100  		Height:             input.Height,
   101  		Width:              input.Width,
   102  
   103  		path: textToImagePath(input.Engine),
   104  	}
   105  
   106  	req.TextPrompts = make([]TextPrompt, 0, nPrompts)
   107  	for index, t := range input.Prompts {
   108  		// If weight isn't provided, set to 1.
   109  		w := 1.0
   110  		if input.Weights != nil && len(*input.Weights) > index {
   111  			w = (*input.Weights)[index]
   112  		}
   113  
   114  		req.TextPrompts = append(req.TextPrompts, TextPrompt{
   115  			Text:   t,
   116  			Weight: &w,
   117  		})
   118  	}
   119  
   120  	return req, nil
   121  }
   122  
   123  func textToImageOutput(from ImageTaskRes) (*structpb.Struct, error) {
   124  	output := TextToImageOutput{
   125  		Images: []string{},
   126  		Seeds:  []uint32{},
   127  	}
   128  
   129  	for _, image := range from.Images {
   130  		if image.FinishReason != successFinishReason {
   131  			continue
   132  		}
   133  
   134  		output.Images = append(output.Images, fmt.Sprintf("data:image/png;base64,%s", image.Base64))
   135  		output.Seeds = append(output.Seeds, image.Seed)
   136  	}
   137  
   138  	return base.ConvertToStructpb(output)
   139  }