github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/docs/examples/aisio_webdataset/pytorch_webdataset.py (about)

     1  import os
     2  
     3  import torchvision
     4  
     5  from aistore.pytorch import AISSourceLister
     6  from aistore.sdk import Client
     7  import webdataset as wds
     8  
     9  
    10  AIS_ENDPOINT = os.getenv("AIS_ENDPOINT")
    11  client = Client(AIS_ENDPOINT)
    12  bucket_name = "images"
    13  etl_name = "wd-transform"
    14  
    15  
    16  def show_image_tensor(image_data):
    17      transform = torchvision.transforms.ToPILImage()
    18      image = transform(image_data)
    19      image.show()
    20      
    21  
    22  def create_dataset() -> wds.WebDataset:
    23      bucket = client.bucket(bucket_name)
    24      # Get a list of urls for each object in AIS, with ETL applied, converted to the format WebDataset expects
    25      sources = AISSourceLister(ais_sources=[bucket], etl_name=etl_name).map(lambda source_url: {"url": source_url})\
    26          .shuffle()
    27      # Load shuffled list of transformed shards into WebDataset pipeline
    28      dataset = wds.WebDataset(sources)
    29      # Shuffle samples and apply built-in webdataset decoder for image files
    30      dataset = dataset.shuffle(size=1000).decode("torchrgb")
    31      # Return iterator over samples as tuples in batches
    32      return dataset.to_tuple("cls", "image.jpg", "trimap.png").batched(16)
    33  
    34  
    35  def create_dataloader(dataset) -> wds.WebLoader:
    36      loader = wds.WebLoader(dataset, num_workers=4, batch_size=None)
    37      return loader.unbatched().shuffle(1000).batched(64)
    38  
    39  
    40  def view_data(dataloader):
    41      # Get the first batch
    42      batch = next(iter(dataloader))
    43      classes, images, trimaps = batch
    44      # Result is a set of tensors with the first dimension being the batch size
    45      print(classes.shape, images.shape, trimaps.shape)
    46      # View the first images in the first batch
    47      show_image_tensor(images[0])
    48      show_image_tensor(trimaps[0])
    49  
    50  
    51  if __name__ == '__main__':
    52      wd_dataset = create_dataset()
    53      wd_dataloader = create_dataloader(wd_dataset)
    54      view_data(wd_dataloader)
    55      first_batch = next(iter(wd_dataloader))
    56      classes, images, trimaps = first_batch