# Source: github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/docs/examples/aisio_webdataset/pytorch_webdataset.py
"""Stream ETL-transformed WebDataset shards from AIStore into a PyTorch loader.

The script lists every object in the ``images`` bucket via ``AISSourceLister``
with the ``wd-transform`` ETL applied, feeds the resulting URLs into a
``webdataset`` pipeline, and displays the first image/trimap pair of the first
batch.

Set the ``AIS_ENDPOINT`` environment variable to the AIStore cluster URL before
running.
"""
import os

import torchvision

from aistore.pytorch import AISSourceLister
from aistore.sdk import Client
import webdataset as wds


AIS_ENDPOINT = os.getenv("AIS_ENDPOINT")
# NOTE(review): when AIS_ENDPOINT is unset this constructs Client(None); the
# script entry point below refuses to run in that case instead of failing on
# the first request with an unclear error.
client = Client(AIS_ENDPOINT)
bucket_name = "images"
etl_name = "wd-transform"


def show_image_tensor(image_data):
    """Convert an image tensor to a PIL image and open it in the default viewer."""
    transform = torchvision.transforms.ToPILImage()
    image = transform(image_data)
    image.show()


def create_dataset() -> wds.WebDataset:
    """Build a WebDataset over the ETL-transformed objects of ``bucket_name``.

    Returns:
        A pipeline yielding batches of 16 ``("cls", "image.jpg", "trimap.png")``
        tuples, decoded with webdataset's built-in ``"torchrgb"`` decoder.
    """
    bucket = client.bucket(bucket_name)
    # Get a URL for each object in AIS (with the ETL applied), wrapped in the
    # {"url": ...} dict format WebDataset expects, in shuffled order.
    sources = (
        AISSourceLister(ais_sources=[bucket], etl_name=etl_name)
        .map(lambda source_url: {"url": source_url})
        .shuffle()
    )
    # Load the shuffled list of transformed shards into the WebDataset pipeline.
    # NOTE(review): with an iterable source list (rather than URL strings),
    # confirm that shards are split across DataLoader workers; otherwise each
    # worker started by create_dataloader may read every shard.
    dataset = wds.WebDataset(sources)
    # Shuffle samples (buffer size 1000) and apply webdataset's built-in decoder
    # for image files.
    dataset = dataset.shuffle(size=1000).decode("torchrgb")
    # Return samples as tuples, grouped into batches of 16.
    return dataset.to_tuple("cls", "image.jpg", "trimap.png").batched(16)


def create_dataloader(dataset) -> wds.WebLoader:
    """Wrap ``dataset`` in a 4-worker WebLoader and rebatch into batches of 64.

    ``batch_size=None`` because ``dataset`` already yields batches; the loader
    output is unbatched, reshuffled (buffer size 1000), and rebatched.
    """
    loader = wds.WebLoader(dataset, num_workers=4, batch_size=None)
    return loader.unbatched().shuffle(1000).batched(64)


def view_data(dataloader):
    """Print the tensor shapes of the first batch and show its first image and trimap."""
    # Get the first batch as (classes, images, trimaps).
    classes, images, trimaps = next(iter(dataloader))
    # Each tensor's first dimension is the batch size.
    print(classes.shape, images.shape, trimaps.shape)
    # View the first image and trimap in the batch.
    show_image_tensor(images[0])
    show_image_tensor(trimaps[0])


if __name__ == "__main__":
    if not AIS_ENDPOINT:
        raise SystemExit(
            "Set the AIS_ENDPOINT environment variable to the AIStore cluster URL."
        )
    wd_dataset = create_dataset()
    wd_dataloader = create_dataloader(wd_dataset)
    # view_data consumes and inspects the first batch; iterating the loader
    # again would restart all of its workers only to discard the result.
    view_data(wd_dataloader)