github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/docs/examples/aisio_webdataset/etl_webdataset.py (about)

     1  import io
     2  import os
     3  
     4  import torchvision
     5  import webdataset as wds
     6  from PIL import Image
     7  from aistore.sdk import Client
     8  from torch.utils.data import IterableDataset
     9  from torch.utils.data.dataset import T_co
    10  
# Cluster endpoint is read from the environment so the example runs against any deployment.
AIS_ENDPOINT = os.getenv("AIS_ENDPOINT")
# Source bucket that holds the WebDataset shards used by this example.
bucket_name = "images"
# Name under which the ETL defined below is registered on the cluster.
etl_name = "wd-transform"
    14  
    15  
    16  def show_image(image_data):
    17      with Image.open(io.BytesIO(image_data)) as image:
    18          image.show()
    19  
    20  
    21  def wd_etl(object_url):
    22      def img_to_bytes(img):
    23          buf = io.BytesIO()
    24          img = img.convert("RGB")
    25          img.save(buf, format="JPEG")
    26          return buf.getvalue()
    27  
    28      def process_trimap(trimap_bytes):
    29          image = Image.open(io.BytesIO(trimap_bytes))
    30          preprocessing = torchvision.transforms.Compose(
    31              [
    32                  torchvision.transforms.CenterCrop(350),
    33                  torchvision.transforms.Lambda(img_to_bytes)
    34              ]
    35          )
    36          return preprocessing(image)
    37  
    38      def process_image(image_bytes):
    39          image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    40          preprocessing = torchvision.transforms.Compose(
    41              [
    42                  torchvision.transforms.CenterCrop(350),
    43                  torchvision.transforms.ToTensor(),
    44                  # Means and stds from ImageNet
    45                  torchvision.transforms.Normalize(
    46                      mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    47                  ),
    48                  torchvision.transforms.ToPILImage(),
    49                  torchvision.transforms.Lambda(img_to_bytes),
    50              ]
    51          )
    52          return preprocessing(image)
    53  
    54      # Initialize a WD object from the AIS URL
    55      dataset = wds.WebDataset(object_url)
    56      # Map the files for each individual sample to the appropriate processing function
    57      processed_shard = dataset.map_dict(**{"image.jpg": process_image, "trimap.png": process_trimap})
    58  
    59      # Write the output to a memory buffer and return the value
    60      buffer = io.BytesIO()
    61      with wds.TarWriter(fileobj=buffer) as dst:
    62          for sample in processed_shard:
    63              dst.write(sample)
    64      return buffer.getvalue()
    65  
    66  
    67  def create_wd_etl(client):
    68      client.etl(etl_name).init_code(
    69          transform=wd_etl,
    70          preimported_modules=["torch"],
    71          dependencies=["webdataset", "pillow", "torch", "torchvision"],
    72          communication_type="hpull",
    73          transform_url=True
    74      )
    75  
    76  
    77  class LocalTarDataset(IterableDataset):
    78      """
    79      Builds a PyTorch IterableDataset from bytes in memory as if was read from a URL by WebDataset. This lets us
    80      initialize a WebDataset Pipeline without writing to local disk and iterate over each record from a shard.
    81      """
    82      def __getitem__(self, index) -> T_co:
    83          raise NotImplemented
    84  
    85      def __init__(self, input_bytes):
    86          self.data = [{"url": "input_data", "stream": io.BytesIO(input_bytes)}]
    87  
    88      def __iter__(self):
    89          files = wds.tariterators.tar_file_expander(self.data)
    90          samples = wds.tariterators.group_by_keys(files)
    91          return samples
    92  
    93  
    94  def read_object_tar(shard_data):
    95      local_dataset = LocalTarDataset(shard_data)
    96      sample = next(iter(local_dataset))
    97      show_image(sample.get('image.jpg'))
    98  
    99  
   100  def transform_object_inline():
   101      single_object = client.bucket(bucket_name).object("samples-00.tar")
   102      # Get object contents with ETL applied
   103      processed_shard = single_object.get(etl_name=etl_name).read_all()
   104      read_object_tar(processed_shard)
   105  
   106  
   107  def transform_bucket_offline():
   108      dest_bucket = client.bucket("processed-samples").create(exist_ok=True)
   109      # Transform the entire bucket, placing the output in the destination bucket
   110      transform_job = client.bucket(bucket_name).transform(to_bck=dest_bucket, etl_name=etl_name)
   111      client.job(transform_job).wait(verbose=True)
   112      processed_shard = dest_bucket.object("samples-00.tar").get().read_all()
   113      read_object_tar(processed_shard)
   114  
   115  
   116  if __name__ == "__main__":
   117      client = Client(AIS_ENDPOINT)
   118      image_bucket = client.bucket(bucket_name)
   119      create_wd_etl(client)
   120      transform_object_inline()
   121      transform_bucket_offline()