github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/docs/examples/aisio_webdataset/load_webdataset_example.py (about)

     1  import os
     2  from pathlib import Path
     3  
     4  from aistore.sdk import Client
     5  
     6  import webdataset as wds
     7  
     8  AIS_ENDPOINT = os.getenv("AIS_ENDPOINT")
     9  bucket_name = "images"
    10  
    11  
    12  def parse_annotations(annotations_file):
    13      classes = {}
    14      # Parse the annotations file into a dictionary from file name -> pet class
    15      with open(annotations_file, "r") as annotations:
    16          for line in annotations.readlines():
    17              if line[0] == "#":
    18                  continue
    19              file_name, pet_class = line.split(" ")[:2]
    20              classes[file_name] = pet_class
    21      return classes
    22  
    23  
    24  def create_sample_generator(image_dir, trimap_dir, annotations_file):
    25      classes = parse_annotations(annotations_file)
    26      # Iterate over all image files
    27      for index, image_file in enumerate(Path(image_dir).glob("*.jpg")):
    28          # Use the image name to look up class and trimap files and create a sample entry
    29          sample = create_sample(classes, trimap_dir, index, image_file)
    30          if sample is None:
    31              continue
    32          # Yield optimizes memory by returning a generator that only generates samples as requested
    33          yield sample
    34  
    35  
    36  def create_sample(classes, trimap_dir, index, image_file):
    37      file_name = str(image_file).split("/")[-1].split(".")[0]
    38      try:
    39          with open(image_file, "rb") as f:
    40              image_data = f.read()
    41          pet_class = classes.get(file_name)
    42          with open(trimap_dir.joinpath(file_name + ".png"), "rb") as f:
    43              trimap_data = f.read()
    44          if not image_data or not pet_class or not trimap_data:
    45              # Ignore incomplete records
    46              return None
    47          return {
    48              "__key__": "sample_%04d" % index,
    49              "image.jpg": image_data,
    50              "cls": pet_class,
    51              "trimap.png": trimap_data
    52              }    
    53      # Ignoring records with any missing files
    54      except FileNotFoundError as err:
    55          print(err)
    56          return None
    57  
    58  
    59  def load_data(bucket, sample_generator):
    60  
    61      def upload_shard(filename):
    62          bucket.object(filename).put_file(filename)
    63          os.unlink(filename)
    64  
    65      # Writes data as tar to disk, uses callback function "post" to upload to AIS and delete
    66      with wds.ShardWriter("samples-%02d.tar", maxcount=400, post=upload_shard) as writer:
    67          for sample in sample_generator:
    68              writer.write(sample)
    69  
    70  
    71  def view_shuffled_shards():
    72      objects = client.bucket("images").list_all_objects(prefix="shuffled")
    73      print([entry.name for entry in objects])
    74  
    75  
    76  if __name__ == "__main__":
    77      client = Client(AIS_ENDPOINT)
    78      image_bucket = client.bucket(bucket_name).create(exist_ok=True)
    79      base_dir = Path("/home/aaron/pets")
    80      pet_image_dir = base_dir.joinpath("images")
    81      pet_trimap_dir = base_dir.joinpath("annotations").joinpath("trimaps")
    82      pet_annotations_file = base_dir.joinpath("annotations").joinpath("list.txt")
    83      samples = create_sample_generator(pet_image_dir, pet_trimap_dir, pet_annotations_file)
    84      load_data(image_bucket, samples)