gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/test/gpu/stablediffusion/stablediffusion.go (about) 1 // Copyright 2024 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package stablediffusion provides utilities to generate images with 16 // Stable Diffusion. 17 package stablediffusion 18 19 import ( 20 "bytes" 21 "context" 22 "encoding/base64" 23 "encoding/json" 24 "fmt" 25 "image" 26 "image/png" 27 "strings" 28 "time" 29 30 "gvisor.dev/gvisor/pkg/test/dockerutil" 31 "gvisor.dev/gvisor/pkg/test/testutil" 32 ) 33 34 // ContainerRunner is an interface to run containers. 35 type ContainerRunner interface { 36 // Run runs a container with the given image and arguments to completion, 37 // and returns its combined output as a byte string. 38 Run(ctx context.Context, image string, argv []string) ([]byte, error) 39 } 40 41 // dockerRunner runs Docker containers on the local machine. 42 type dockerRunner struct { 43 logger testutil.Logger 44 } 45 46 // Run implements `ContainerRunner.Run`. 47 func (dr *dockerRunner) Run(ctx context.Context, image string, argv []string) ([]byte, error) { 48 cont := dockerutil.MakeContainer(ctx, dr.logger) 49 defer cont.CleanUp(ctx) 50 opts := dockerutil.GPURunOpts() 51 opts.Image = image 52 if err := cont.Spawn(ctx, opts, argv...); err != nil { 53 return nil, fmt.Errorf("could not start Stable Diffusion container: %v", err) 54 } 55 waitErr := cont.Wait(ctx) 56 logs, logsErr := cont.Logs(ctx) 57 if waitErr != nil { 58 if logsErr == nil { 59 return nil, fmt.Errorf("container exited with error: %v; logs: %v", waitErr, logs) 60 } 61 return nil, fmt.Errorf("container exited with error: %v (cannot get logs: %v)", waitErr, logsErr) 62 } 63 if logsErr != nil { 64 return nil, fmt.Errorf("could not get container logs: %v", logsErr) 65 } 66 return []byte(logs), nil 67 } 68 69 // XL generates images using Stable Diffusion XL. 70 type XL struct { 71 image string 72 runner ContainerRunner 73 } 74 75 // NewXL returns a new Stable Diffusion XL generator. 76 func NewXL(sdxlImage string, runner ContainerRunner) *XL { 77 return &XL{ 78 image: sdxlImage, 79 runner: runner, 80 } 81 } 82 83 // NewDockerXL returns a new Stable Diffusion XL generator using Docker 84 // containers on the local machine. 85 func NewDockerXL(logger testutil.Logger) *XL { 86 return NewXL("gpu/stable-diffusion-xl", &dockerRunner{logger: logger}) 87 } 88 89 // XLPrompt is the input to Stable Diffusion XL to generate an image. 90 type XLPrompt struct { 91 // Query is the text query to generate the image with. 92 Query string 93 94 // AllowCPUOffload is whether to allow offloading parts of the model to CPU. 95 AllowCPUOffload bool 96 97 // UseRefiner is whether to use the refiner model after the base model. 98 // This takes more VRAM and more time but produces a better image. 99 UseRefiner bool 100 101 // NoiseFraction is the fraction of noise to seed the image with. 102 // Must be between 0.0 and 1.0 inclusively. 103 NoiseFraction float64 104 105 // Steps is the number of diffusion steps to run for the base and refiner 106 // models. More steps generally means sharper results but more time to 107 // generate the image. A reasonable value is between 30 and 50. 108 Steps int 109 110 // Warm controls whether the image will be generated while the model is 111 // warm. This will double the running time, as the image will still be 112 // generated with a cold model first. 113 Warm bool 114 } 115 116 // xlImageJSON is the JSON response from the Stable Diffusion XL 117 // container's generate_image.py. 118 // Warm* fields are only present when `XLPrompt.Warm` is set. 119 type xlImageJSON struct { 120 ImageASCIIBase64 []string `json:"image_ascii_base64"` 121 ImagePNGBase64 []string `json:"image_png_base64"` 122 Start time.Time `json:"start"` 123 ColdStartImage time.Time `json:"cold_start_image"` 124 ColdBaseDone time.Time `json:"cold_base_done"` 125 ColdRefinerDone time.Time `json:"cold_refiner_done"` 126 WarmStartImage time.Time `json:"warm_start_image"` 127 WarmBaseDone time.Time `json:"warm_base_done"` 128 WarmRefinerDone time.Time `json:"warm_refiner_done"` 129 Done time.Time `json:"done"` 130 } 131 132 // XLImage is an image generated by Stable Diffusion XL. 133 type XLImage struct { 134 Prompt *XLPrompt 135 data xlImageJSON 136 } 137 138 // ASCII returns an ASCII version of the generated image. 139 func (i *XLImage) ASCII() (string, error) { 140 ascii, err := base64.StdEncoding.DecodeString(strings.Join(i.data.ImageASCIIBase64, "")) 141 if err != nil { 142 return "", fmt.Errorf("invalid base64: %w", err) 143 } 144 return string(ascii), nil 145 } 146 147 // Image returns the generated image. 148 func (i *XLImage) Image() (image.Image, error) { 149 return png.Decode(base64.NewDecoder(base64.StdEncoding, bytes.NewBufferString(strings.Join(i.data.ImagePNGBase64, "")))) 150 } 151 152 // TotalDuration returns the total time taken to generate the image. 153 func (i *XLImage) TotalDuration() time.Duration { 154 return i.data.Done.Sub(i.data.Start) 155 } 156 157 // ColdBaseDuration returns time taken to run the base image generation model 158 // the first time the image was generated (i.e. the model was cold). 159 func (i *XLImage) ColdBaseDuration() time.Duration { 160 return i.data.ColdBaseDone.Sub(i.data.ColdStartImage) 161 } 162 163 // ColdRefinerDuration returns time taken to run the refiner model 164 // the first time the image was generated (i.e. the model was cold). 165 // Returns -1 if the refiner was not run. 166 func (i *XLImage) ColdRefinerDuration() time.Duration { 167 if !i.Prompt.UseRefiner { 168 return -1 169 } 170 return i.data.ColdRefinerDone.Sub(i.data.ColdBaseDone) 171 } 172 173 // WarmBaseDuration returns time taken to run the base image generation model 174 // the second time the image was generated (i.e. the model was warm). 175 func (i *XLImage) WarmBaseDuration() time.Duration { 176 return i.data.WarmBaseDone.Sub(i.data.WarmStartImage) 177 } 178 179 // WarmRefinerDuration returns time taken to run the refiner model 180 // the second time the image was generated (i.e. the model was warm). 181 // Returns -1 if the refiner was not run. 182 func (i *XLImage) WarmRefinerDuration() time.Duration { 183 if !i.Prompt.UseRefiner { 184 return -1 185 } 186 return i.data.WarmRefinerDone.Sub(i.data.WarmBaseDone) 187 } 188 189 // Generate generates an image with Stable Diffusion XL. 190 func (xl *XL) Generate(ctx context.Context, prompt *XLPrompt) (*XLImage, error) { 191 argv := []string{ 192 "--format=METRICS", 193 fmt.Sprintf("--steps=%d", prompt.Steps), 194 fmt.Sprintf("--noise_frac=%f", prompt.NoiseFraction), 195 "--quiet_stderr", 196 } 197 if prompt.AllowCPUOffload { 198 argv = append(argv, "--enable_model_cpu_offload") 199 } 200 if prompt.UseRefiner { 201 argv = append(argv, "--enable_refiner") 202 } 203 if prompt.Warm { 204 argv = append(argv, "--warm") 205 } 206 argv = append(argv, prompt.Query) 207 output, err := xl.runner.Run(ctx, xl.image, argv) 208 if err != nil { 209 return nil, err 210 } 211 xlImage := &XLImage{Prompt: prompt} 212 if err := json.Unmarshal(output, &xlImage.data); err != nil { 213 return nil, fmt.Errorf("malformed JSON output %q: %w", string(output), err) 214 } 215 return xlImage, nil 216 }