gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/images/gpu/stable-diffusion-xl/generate_image.py (about)

     1  #!/usr/bin/env python3
     2  
     3  # Copyright 2024 The gVisor Authors.
     4  #
     5  # Licensed under the Apache License, Version 2.0 (the "License");
     6  # you may not use this file except in compliance with the License.
     7  # You may obtain a copy of the License at
     8  #
     9  #     http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  
    17  """Generate image with Stable Diffusion XL.
    18  
    19  Images are written to stdout by wrapper script.
    20  """
    21  
    22  import argparse
    23  import array
    24  import base64
    25  import datetime
    26  import enum
    27  import fcntl
    28  import io
    29  import json
    30  import os
    31  import subprocess
    32  import termios
    33  
    34  import diffusers
    35  import torch
    36  
    37  
    38  # Define arguments.
    39  class Format(enum.Enum):
    40    """Output format enum."""
    41  
    42    PNG = 'PNG'
    43    JPEG = 'JPEG'
    44    ASCII = 'ASCII'
    45    BRAILLE = 'BRAILLE'
    46    PNG_BASE64 = 'PNG-BASE64'
    47    METRICS = 'METRICS'
    48  
    49    @property
    50    def is_terminal_output(self):
    51      return self in (Format.ASCII, Format.BRAILLE, Format.METRICS)
    52  
    53    def __str__(self):
    54      return self.value
    55  
    56  
    57  parser = argparse.ArgumentParser(
    58      prog='generate_image',
    59      description='Generate an image using Stable Diffusion XL',
    60  )
    61  
    62  # Arguments passed by wrapper script.
    63  parser.add_argument('--out', required=True, type=str, help=argparse.SUPPRESS)
    64  parser.add_argument('--terminal_pixel_width', type=str, help=argparse.SUPPRESS)
    65  parser.add_argument('--terminal_pixel_height', type=str, help=argparse.SUPPRESS)
    66  
    67  parser.add_argument(
    68      '--quiet_stderr',
    69      action='store_true',
    70      help=(
    71          'Suppress PyTorch messages to stderr; useful if stderr output is'
    72          ' captured.'
    73      ),
    74  )
    75  parser.add_argument(
    76      '--enable_model_cpu_offload',
    77      action='store_true',
    78      help='Offload non-main components of model to CPU if low on GPU VRAM',
    79  )
    80  parser.add_argument(
    81      '--format',
    82      type=Format,
    83      choices=list(Format),
    84      default=Format.BRAILLE,
    85      help='Output file format: ' + ', '.join(str(v) for v in Format),
    86  )
    87  parser.add_argument(
    88      '--steps', default=50, type=int, help='Number of diffusion steps'
    89  )
    90  parser.add_argument(
    91      '--noise_frac', default=0.8, type=float, help='Noise fraction'
    92  )
    93  parser.add_argument(
    94      '--enable_refiner',
    95      action='store_true',
    96      help='Use the refiner model on top of the base model for better results',
    97  )
    98  parser.add_argument(
    99      '--warm',
   100      action='store_true',
   101      help='Generate the image twice; timing metrics will measure both images',
   102  )
   103  parser.add_argument('prompt', type=str, help='Prompt to generate image')
   104  args = parser.parse_args()
   105  
   106  # Load base model.
   107  time_start = datetime.datetime.now(datetime.timezone.utc)
   108  base = diffusers.DiffusionPipeline.from_pretrained(
   109      'stabilityai/stable-diffusion-xl-base-1.0',
   110      torch_dtype=torch.float16,
   111      variant='fp16',
   112      use_safetensors=True,
   113  )
   114  if args.enable_model_cpu_offload:
   115    base.enable_model_cpu_offload()
   116  else:
   117    base.to('cuda')
   118  base.unet = torch.compile(base.unet, mode='reduce-overhead', fullgraph=True)
   119  
   120  # Load refiner model if enabled.
   121  refiner = None
   122  if args.enable_refiner:
   123    refiner = diffusers.DiffusionPipeline.from_pretrained(
   124        'stabilityai/stable-diffusion-xl-refiner-1.0',
   125        text_encoder_2=base.text_encoder_2,
   126        vae=base.vae,
   127        torch_dtype=torch.float16,
   128        use_safetensors=True,
   129        variant='fp16',
   130    )
   131    if args.enable_model_cpu_offload:
   132      refiner.enable_model_cpu_offload()
   133    else:
   134      refiner.to('cuda')
   135    refiner.unet = torch.compile(
   136        refiner.unet, mode='reduce-overhead', fullgraph=True
   137    )
   138  
   139  # Set the prompt.
   140  default_prompt = (
   141      'Photorealistic image of two androids playing chess aboard a spaceship'
   142  )
   143  if args.format.is_terminal_output:
   144    # If displaying in a terminal, cartoony pictures that have sharp edges will
   145    # look much clearer than photorealistic pictures.
   146    default_prompt = 'A boring flat corporate logo that says "gVisor"'
   147  prompt = args.prompt or default_prompt
   148  
   149  
   150  # Generate image.
   151  def generate_image():
   152    """Run the base model and maybe the refiner model to generate the image."""
   153  
   154    time_start_image = datetime.datetime.now(datetime.timezone.utc)
   155    if not args.enable_refiner:
   156      img = base(
   157          prompt=prompt,
   158          num_inference_steps=args.steps,
   159          output_type='pil',
   160      ).images[0]
   161      time_base_done = datetime.datetime.now(datetime.timezone.utc)
   162      time_refiner_done = None
   163    else:
   164      base_images = base(
   165          prompt=prompt,
   166          num_inference_steps=args.steps,
   167          denoising_end=args.noise_frac,
   168          output_type='latent',
   169      ).images
   170      time_base_done = datetime.datetime.now(datetime.timezone.utc)
   171      img = refiner(
   172          prompt=prompt,
   173          num_inference_steps=args.steps,
   174          denoising_start=args.noise_frac,
   175          image=base_images,
   176      ).images[0]
   177      time_refiner_done = datetime.datetime.now(datetime.timezone.utc)
   178    return img, time_start_image, time_base_done, time_refiner_done
   179  
   180  
   181  image, cold_start_image, cold_base_done, cold_refiner_done = generate_image()
   182  warm_start_image, warm_base_done, warm_refiner_done = None, None, None
   183  if args.warm:
   184    image, warm_start_image, warm_base_done, warm_refiner_done = generate_image()
   185  
   186  
   187  def get_optimal_terminal_width():
   188    """Returns the width of the terminal for ASCII image display."""
   189    try:
   190      terminal_width, terminal_height = os.get_terminal_size()
   191    except OSError:  # Not a TTY, return a sane default.
   192      return 80
   193    if terminal_width == 0 or terminal_height == 0:  # Incoherent terminal size.
   194      return 80
   195    if terminal_width <= 42:
   196      # Ridiculously small terminal, return default dimension anyway because
   197      # whatever we do won't look nice regardless.
   198      return 80
   199    # Try to find the aspect ratio of a single terminal character.
   200    terminal_pixel_width = 0
   201    terminal_pixel_height = 0
   202    if args.terminal_pixel_width.isdigit():
   203      terminal_pixel_width = int(args.terminal_pixel_width)
   204    if args.terminal_pixel_height.isdigit():
   205      terminal_pixel_height = int(args.terminal_pixel_height)
   206    if terminal_pixel_width == 0 or terminal_pixel_height == 0:
   207      termios_buf = array.array('H', [0, 0, 0, 0])
   208      fcntl.ioctl(1, termios.TIOCGWINSZ, termios_buf)
   209      _, _, terminal_pixel_width, terminal_pixel_height = termios_buf
   210    if terminal_pixel_width != 0 and terminal_pixel_height != 0:
   211      character_width = float(terminal_pixel_width) / float(terminal_width)
   212      character_height = float(terminal_pixel_height) / float(terminal_height)
   213      character_aspect_ratio = character_width / character_height
   214    else:
   215      character_aspect_ratio = 0.5  # Just use a sane default.
   216    adjusted_terminal_height = float(terminal_height) / float(
   217        character_aspect_ratio
   218    )
   219    image_width, image_height = image.size
   220    width_ratio = float(image_width) / float(terminal_width)
   221    height_ratio = float(image_height) / adjusted_terminal_height
   222    if width_ratio > height_ratio:
   223      # Width is determining factor.
   224      return terminal_width
   225    # Height is the determining factor.
   226    final_width = int(
   227        adjusted_terminal_height * float(image_width) / float(image_height)
   228    )
   229    # Remove one just to not make it take literally the entire console, and
   230    # in case our estimation for things like character size is wrong.
   231    final_width -= 1
   232    if final_width < 8:
   233      # Very vertical image, most likely text? So it's OK if it scrolls.
   234      # Return a sane default width.
   235      return 42
   236    return final_width
   237  
   238  
   239  # Save image in desired format.
   240  if args.format in (Format.PNG, Format.JPEG):
   241    image.save(args.out, args.format)
   242  else:
   243    buf = io.BytesIO()
   244    image.save(buf, format=Format.PNG.value)
   245    image_bytes = buf.getvalue()
   246    time_done = datetime.datetime.now(datetime.timezone.utc)
   247    with open(args.out, 'wb') as f:
   248      if args.format == Format.PNG_BASE64:
   249        f.write(base64.standard_b64encode(image_bytes))
   250      elif args.format.is_terminal_output:
   251        image_converter_args = [
   252            '/usr/bin/ascii-image-converter',
   253            '/dev/stdin',
   254            '--width=%d' % (get_optimal_terminal_width(),),
   255        ]
   256        if args.format in (Format.BRAILLE, Format.METRICS):
   257          image_converter_args.extend(('--braille', '--dither'))
   258        else:
   259          image_converter_args.append('--complex')
   260        image_ascii = subprocess.run(
   261            image_converter_args,
   262            input=image_bytes,
   263            capture_output=True,
   264            check=True,
   265            timeout=60,
   266        )
   267        if args.format == Format.METRICS:
   268          split_lines = lambda x: [
   269              x[i : i + 1024] for i in range(0, len(x), 1024)
   270          ]
   271          results = {
   272              'image_ascii_base64': split_lines(
   273                  base64.standard_b64encode(image_ascii.stdout).decode('ascii')
   274              ),
   275              'image_png_base64': split_lines(
   276                  base64.standard_b64encode(image_bytes).decode('ascii')
   277              ),
   278          }
   279          for name, timestamp in (
   280              ('start', time_start),
   281              ('cold_start_image', cold_start_image),
   282              ('cold_base_done', cold_base_done),
   283              ('cold_refiner_done', cold_refiner_done),
   284              ('warm_start_image', warm_start_image),
   285              ('warm_base_done', warm_base_done),
   286              ('warm_refiner_done', warm_refiner_done),
   287              ('done', time_done),
   288          ):
   289            results[name] = (
   290                timestamp.isoformat() if timestamp is not None else None
   291            )
   292          # Python's `json` module always outputs strings, not bytes, so
   293          # we cannot directly dump to `f`. Output to string instead, then
   294          # encode.
   295          # Also, `json.dumps` doesn't add a trailing newline, so we do.
   296          results_json = (
   297              json.dumps(results, sort_keys=True, ensure_ascii=True, indent=2)
   298              + '\n'
   299          )
   300          f.write(results_json.encode('ascii'))
   301        else:
   302          f.write(image_ascii.stdout)
   303      else:
   304        raise ValueError(f'Unknown format: {args.format}')