#!/usr/bin/env python3
#
# Source listing:
# gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/images/gpu/stable-diffusion-xl/generate_image.py

# Copyright 2024 The gVisor Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Generate image with Stable Diffusion XL.

Images are written to stdout by wrapper script.
"""

import argparse
import array
import base64
import datetime
import enum
import fcntl
import io
import json
import os
import subprocess
import termios

import diffusers
import torch


# Define arguments.
class Format(enum.Enum):
  """Output format enum.

  Each member's value is the exact string accepted by the `--format` flag.
  """

  PNG = 'PNG'
  JPEG = 'JPEG'
  ASCII = 'ASCII'
  BRAILLE = 'BRAILLE'
  PNG_BASE64 = 'PNG-BASE64'
  METRICS = 'METRICS'

  @property
  def is_terminal_output(self) -> bool:
    """Whether this format is one of ASCII, BRAILLE or METRICS."""
    return self in (Format.ASCII, Format.BRAILLE, Format.METRICS)

  def __str__(self) -> str:
    # Return the bare value (e.g. 'PNG' rather than 'Format.PNG') so that the
    # `--format` help text, which joins `str(v)` over all members, lists the
    # strings a user actually has to type.
    return self.value


parser = argparse.ArgumentParser(
    prog='generate_image',
    description='Generate an image using Stable Diffusion XL',
)

# Arguments passed by wrapper script.
# `--out` is the file the result is written to. The two pixel-size flags are
# optional hints for the terminal's size in pixels; they are kept as strings
# and validated later, when the terminal width is computed.
parser.add_argument('--out', required=True, type=str, help=argparse.SUPPRESS)
parser.add_argument('--terminal_pixel_width', type=str, help=argparse.SUPPRESS)
parser.add_argument('--terminal_pixel_height', type=str, help=argparse.SUPPRESS)

parser.add_argument(
    '--quiet_stderr',
    action='store_true',
    help=(
        'Suppress PyTorch messages to stderr; useful if stderr output is'
        ' captured.'
    ),
)
parser.add_argument(
    '--enable_model_cpu_offload',
    action='store_true',
    help='Offload non-main components of model to CPU if low on GPU VRAM',
)
parser.add_argument(
    '--format',
    type=Format,
    choices=list(Format),
    default=Format.BRAILLE,
    help='Output file format: ' + ', '.join(str(v) for v in Format),
)
parser.add_argument(
    '--steps', default=50, type=int, help='Number of diffusion steps'
)
parser.add_argument(
    '--noise_frac', default=0.8, type=float, help='Noise fraction'
)
parser.add_argument(
    '--enable_refiner',
    action='store_true',
    help='Use the refiner model on top of the base model for better results',
)
parser.add_argument(
    '--warm',
    action='store_true',
    help='Generate the image twice; timing metrics will measure both images',
)
parser.add_argument('prompt', type=str, help='Prompt to generate image')
args = parser.parse_args()

# Load base model.
# `time_start` is taken before the model is loaded, so the reported 'start'
# timestamp includes model loading time.
time_start = datetime.datetime.now(datetime.timezone.utc)
base = diffusers.DiffusionPipeline.from_pretrained(
    'stabilityai/stable-diffusion-xl-base-1.0',
    torch_dtype=torch.float16,
    variant='fp16',
    use_safetensors=True,
)
if args.enable_model_cpu_offload:
  base.enable_model_cpu_offload()
else:
  base.to('cuda')
  # NOTE(review): indentation was lost in this listing; compiling the UNet is
  # assumed to belong to the non-offload branch only -- confirm against
  # upstream.
  base.unet = torch.compile(base.unet, mode='reduce-overhead', fullgraph=True)

# Load refiner model if enabled.
refiner = None
if args.enable_refiner:
  # The refiner reuses the base pipeline's second text encoder and VAE rather
  # than loading its own copies.
  refiner = diffusers.DiffusionPipeline.from_pretrained(
      'stabilityai/stable-diffusion-xl-refiner-1.0',
      text_encoder_2=base.text_encoder_2,
      vae=base.vae,
      torch_dtype=torch.float16,
      use_safetensors=True,
      variant='fp16',
  )
  if args.enable_model_cpu_offload:
    refiner.enable_model_cpu_offload()
  else:
    refiner.to('cuda')
    # NOTE(review): indentation was lost in this listing; compiling the UNet
    # is assumed to belong to the non-offload branch only -- confirm against
    # upstream.
    refiner.unet = torch.compile(
        refiner.unet, mode='reduce-overhead', fullgraph=True
    )

# Set the prompt.
default_prompt = (
    'Photorealistic image of two androids playing chess aboard a spaceship'
)
if args.format.is_terminal_output:
  # If displaying in a terminal, cartoony pictures that have sharp edges will
  # look much clearer than photorealistic pictures.
  default_prompt = 'A boring flat corporate logo that says "gVisor"'
# The default only applies when an empty prompt string is passed.
prompt = args.prompt or default_prompt


# Generate image.
def generate_image():
  """Run the base model and maybe the refiner model to generate the image.

  Returns:
    A tuple `(img, time_start_image, time_base_done, time_refiner_done)`.
    `img` is the first generated image. The three timestamps are timezone-aware
    UTC datetimes taken before generation, after the base model finished and
    after the refiner finished; `time_refiner_done` is None when the refiner is
    disabled.
  """
  time_start_image = datetime.datetime.now(datetime.timezone.utc)
  if not args.enable_refiner:
    img = base(
        prompt=prompt,
        num_inference_steps=args.steps,
        output_type='pil',
    ).images[0]
    time_base_done = datetime.datetime.now(datetime.timezone.utc)
    time_refiner_done = None
  else:
    # The base model stops at `noise_frac` and hands its latents to the
    # refiner, which resumes denoising from the same fraction.
    base_images = base(
        prompt=prompt,
        num_inference_steps=args.steps,
        denoising_end=args.noise_frac,
        output_type='latent',
    ).images
    time_base_done = datetime.datetime.now(datetime.timezone.utc)
    img = refiner(
        prompt=prompt,
        num_inference_steps=args.steps,
        denoising_start=args.noise_frac,
        image=base_images,
    ).images[0]
    time_refiner_done = datetime.datetime.now(datetime.timezone.utc)
  return img, time_start_image, time_base_done, time_refiner_done


image, cold_start_image, cold_base_done, cold_refiner_done = generate_image()
warm_start_image, warm_base_done, warm_refiner_done = None, None, None
if args.warm:
  # Second run; its image replaces the cold one and its timings are reported
  # separately as the `warm_*` metrics.
  image, warm_start_image, warm_base_done, warm_refiner_done = generate_image()


def _parse_pixel_dimension(raw):
  """Returns `raw` as an int, or 0 if it is None or not a decimal string."""
  # The pixel-size flags have no default, so `raw` is None when the wrapper
  # script does not pass them.
  if raw is not None and raw.isdecimal():
    return int(raw)
  return 0


def _split_lines(text, width=1024):
  """Splits `text` into consecutive chunks of at most `width` characters."""
  return [text[i : i + width] for i in range(0, len(text), width)]


def get_optimal_terminal_width():
  """Returns the width of the terminal for ASCII image display.

  The result is the number of character columns to render the global `image`
  at so that it fits the terminal both horizontally and vertically. Fixed
  fallback widths are returned when the terminal size cannot be determined or
  is too small to be useful.
  """
  try:
    terminal_width, terminal_height = os.get_terminal_size()
  except OSError:  # Not a TTY, return a sane default.
    return 80
  if terminal_width == 0 or terminal_height == 0:  # Incoherent terminal size.
    return 80
  if terminal_width <= 42:
    # Ridiculously small terminal, return default dimension anyway because
    # whatever we do won't look nice regardless.
    return 80
  # Try to find the aspect ratio of a single terminal character.
  terminal_pixel_width = _parse_pixel_dimension(args.terminal_pixel_width)
  terminal_pixel_height = _parse_pixel_dimension(args.terminal_pixel_height)
  if terminal_pixel_width == 0 or terminal_pixel_height == 0:
    # No usable hint from the wrapper script; ask the terminal on stdout.
    # TIOCGWINSZ fills four unsigned shorts: rows, columns, x pixels, y pixels.
    termios_buf = array.array('H', [0, 0, 0, 0])
    try:
      fcntl.ioctl(1, termios.TIOCGWINSZ, termios_buf)
    except OSError:
      # Leave the pixel sizes unset; the default aspect ratio is used below.
      terminal_pixel_width, terminal_pixel_height = 0, 0
    else:
      _, _, terminal_pixel_width, terminal_pixel_height = termios_buf
  if terminal_pixel_width != 0 and terminal_pixel_height != 0:
    character_width = float(terminal_pixel_width) / float(terminal_width)
    character_height = float(terminal_pixel_height) / float(terminal_height)
    character_aspect_ratio = character_width / character_height
  else:
    character_aspect_ratio = 0.5  # Just use a sane default.
  # Terminal height expressed in units of character widths, so that it can be
  # compared directly with the width in columns.
  adjusted_terminal_height = float(terminal_height) / float(
      character_aspect_ratio
  )
  image_width, image_height = image.size
  width_ratio = float(image_width) / float(terminal_width)
  height_ratio = float(image_height) / adjusted_terminal_height
  if width_ratio > height_ratio:
    # Width is determining factor.
    return terminal_width
  # Height is the determining factor.
  final_width = int(
      adjusted_terminal_height * float(image_width) / float(image_height)
  )
  # Remove one just to not make it take literally the entire console, and
  # in case our estimation for things like character size is wrong.
  final_width -= 1
  if final_width < 8:
    # Very vertical image, most likely text? So it's OK if it scrolls.
    # Return a sane default width.
    return 42
  return final_width


# Save image in desired format.
if args.format in (Format.PNG, Format.JPEG):
  # Pillow takes the format name as a string, so pass the enum's value rather
  # than the enum member itself.
  image.save(args.out, args.format.value)
else:
  # NOTE(review): indentation was lost in this listing; everything below is
  # assumed to sit inside this `else` branch -- confirm against upstream.
  buf = io.BytesIO()
  image.save(buf, format=Format.PNG.value)
  image_bytes = buf.getvalue()
  # 'done' is recorded once the PNG bytes exist, before any terminal rendering.
  time_done = datetime.datetime.now(datetime.timezone.utc)
  with open(args.out, 'wb') as f:
    if args.format == Format.PNG_BASE64:
      f.write(base64.standard_b64encode(image_bytes))
    elif args.format.is_terminal_output:
      # The converter reads the PNG from stdin and prints text to stdout.
      image_converter_args = [
          '/usr/bin/ascii-image-converter',
          '/dev/stdin',
          f'--width={get_optimal_terminal_width()}',
      ]
      if args.format in (Format.BRAILLE, Format.METRICS):
        image_converter_args.extend(('--braille', '--dither'))
      else:
        image_converter_args.append('--complex')
      image_ascii = subprocess.run(
          image_converter_args,
          input=image_bytes,
          capture_output=True,
          check=True,
          timeout=60,
      )
      if args.format == Format.METRICS:
        # Base64 payloads are stored as lists of 1024-character chunks.
        results = {
            'image_ascii_base64': _split_lines(
                base64.standard_b64encode(image_ascii.stdout).decode('ascii')
            ),
            'image_png_base64': _split_lines(
                base64.standard_b64encode(image_bytes).decode('ascii')
            ),
        }
        for name, timestamp in (
            ('start', time_start),
            ('cold_start_image', cold_start_image),
            ('cold_base_done', cold_base_done),
            ('cold_refiner_done', cold_refiner_done),
            ('warm_start_image', warm_start_image),
            ('warm_base_done', warm_base_done),
            ('warm_refiner_done', warm_refiner_done),
            ('done', time_done),
        ):
          # Timestamps for stages that did not run are emitted as null.
          results[name] = (
              timestamp.isoformat() if timestamp is not None else None
          )
        # Python's `json` module always outputs strings, not bytes, so
        # we cannot directly dump to `f`. Output to string instead, then
        # encode.
        # Also, `json.dumps` doesn't add a trailing newline, so we do.
        results_json = (
            json.dumps(results, sort_keys=True, ensure_ascii=True, indent=2)
            + '\n'
        )
        f.write(results_json.encode('ascii'))
      else:
        f.write(image_ascii.stdout)
    else:
      raise ValueError(f'Unknown format: {args.format}')