# docs/examples/aisio_webdataset/load_webdataset_example.py
"""Load the Oxford-IIIT Pet dataset into an AIStore bucket as WebDataset shards.

Reads images, trimaps, and class annotations from a local directory tree,
packs them into tar shards with ``webdataset.ShardWriter``, and uploads each
finished shard to an AIStore bucket, deleting the local tar afterwards.
"""
import os
from pathlib import Path

from aistore.sdk import Client

import webdataset as wds

AIS_ENDPOINT = os.getenv("AIS_ENDPOINT")
bucket_name = "images"


def parse_annotations(annotations_file):
    """Parse the annotations file into a dict mapping file name -> pet class.

    Lines starting with '#' are comments; blank lines are skipped (the
    original ``line[0]`` check raised IndexError on an empty line).
    """
    classes = {}
    with open(annotations_file, "r") as annotations:
        # Iterate the file object directly instead of readlines() so the
        # whole file is never materialized in memory at once.
        for line in annotations:
            if not line.strip() or line.startswith("#"):
                continue
            # Annotation line format: <file_name> <class_id> <species> <breed_id>;
            # only the first two fields are needed here.
            file_name, pet_class = line.split(" ")[:2]
            classes[file_name] = pet_class
    return classes


def create_sample_generator(image_dir, trimap_dir, annotations_file):
    """Yield one WebDataset sample dict per complete image record.

    Using a generator keeps memory bounded: samples are produced only as
    the consumer (ShardWriter) requests them.
    """
    classes = parse_annotations(annotations_file)
    # Iterate over all image files; the index becomes the sample key.
    for index, image_file in enumerate(Path(image_dir).glob("*.jpg")):
        sample = create_sample(classes, trimap_dir, index, image_file)
        if sample is None:
            # Incomplete record (missing class, empty file, or absent trimap).
            continue
        yield sample


def create_sample(classes, trimap_dir, index, image_file):
    """Build a single sample dict for *image_file*, or return None.

    Returns None for incomplete records: missing/empty image or trimap
    data, or no class entry in *classes* for the image's base name.
    """
    # Path.stem is the portable way to get the base name without extension;
    # splitting str(path) on "/" breaks on Windows path separators.
    file_name = image_file.stem
    pet_class = classes.get(file_name)
    try:
        # Keep the try body minimal: only the reads that can raise.
        image_data = image_file.read_bytes()
        trimap_data = trimap_dir.joinpath(file_name + ".png").read_bytes()
    except FileNotFoundError as err:
        # Ignore records with any missing files
        print(err)
        return None
    if not image_data or not pet_class or not trimap_data:
        # Ignore incomplete records
        return None
    return {
        "__key__": "sample_%04d" % index,
        "image.jpg": image_data,
        "cls": pet_class,
        "trimap.png": trimap_data,
    }


def load_data(bucket, sample_generator):
    """Write samples as local tar shards and upload each one to *bucket*.

    ShardWriter's ``post`` callback fires after each shard file is closed;
    we use it to upload the shard to AIStore and delete the local copy.
    """

    def upload_shard(filename):
        bucket.object(filename).put_file(filename)
        os.unlink(filename)

    # maxcount bounds samples per shard; "post" uploads and removes each tar.
    with wds.ShardWriter("samples-%02d.tar", maxcount=400, post=upload_shard) as writer:
        for sample in sample_generator:
            writer.write(sample)


def view_shuffled_shards():
    """Print the names of all objects in the bucket with prefix 'shuffled'.

    NOTE(review): relies on the module-level ``client`` created in the
    __main__ guard below — only callable after that runs.
    """
    objects = client.bucket(bucket_name).list_all_objects(prefix="shuffled")
    print([entry.name for entry in objects])


if __name__ == "__main__":
    client = Client(AIS_ENDPOINT)
    image_bucket = client.bucket(bucket_name).create(exist_ok=True)
    # Example local layout of the Oxford-IIIT Pet dataset download.
    base_dir = Path("/home/aaron/pets")
    pet_image_dir = base_dir.joinpath("images")
    pet_trimap_dir = base_dir.joinpath("annotations").joinpath("trimaps")
    pet_annotations_file = base_dir.joinpath("annotations").joinpath("list.txt")
    samples = create_sample_generator(pet_image_dir, pet_trimap_dir, pet_annotations_file)
    load_data(image_bucket, samples)