# Source: github.com/NVIDIA/aistore docs/examples/aisio_webdataset/etl_webdataset.py
"""Example: transforming WebDataset shards stored in AIStore with an ETL.

Demonstrates initializing an ETL on the cluster, applying it inline to a
single object, and applying it offline to a whole bucket.
"""
import io
import os

import torchvision
import webdataset as wds
from PIL import Image
from aistore.sdk import Client
from torch.utils.data import IterableDataset
from torch.utils.data.dataset import T_co

AIS_ENDPOINT = os.getenv("AIS_ENDPOINT")
bucket_name = "images"
etl_name = "wd-transform"


def show_image(image_data):
    """Open raw image bytes with PIL and display them in the default viewer."""
    with Image.open(io.BytesIO(image_data)) as image:
        image.show()


def wd_etl(object_url):
    """ETL transform executed on the AIStore cluster for one WebDataset shard.

    Reads the shard at ``object_url``, preprocesses each sample's image and
    trimap, and returns the re-packed shard as tar bytes.

    NOTE: this function is shipped to the cluster via ``init_code``, so it
    must be self-contained (all helpers defined inside, imports available
    through ``preimported_modules``/``dependencies``).
    """

    def img_to_bytes(img):
        # Serialize a PIL image back to JPEG bytes (JPEG requires RGB mode).
        buf = io.BytesIO()
        img = img.convert("RGB")
        img.save(buf, format="JPEG")
        return buf.getvalue()

    def process_trimap(trimap_bytes):
        # Crop the segmentation trimap to match the image crop, return bytes.
        image = Image.open(io.BytesIO(trimap_bytes))
        preprocessing = torchvision.transforms.Compose(
            [
                torchvision.transforms.CenterCrop(350),
                torchvision.transforms.Lambda(img_to_bytes),
            ]
        )
        return preprocessing(image)

    def process_image(image_bytes):
        # Crop, normalize (ImageNet statistics), and re-encode the image.
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        preprocessing = torchvision.transforms.Compose(
            [
                torchvision.transforms.CenterCrop(350),
                torchvision.transforms.ToTensor(),
                # Means and stds from ImageNet
                torchvision.transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
                torchvision.transforms.ToPILImage(),
                torchvision.transforms.Lambda(img_to_bytes),
            ]
        )
        return preprocessing(image)

    # Initialize a WD object from the AIS URL
    dataset = wds.WebDataset(object_url)
    # Map the files for each individual sample to the appropriate processing function
    processed_shard = dataset.map_dict(
        **{"image.jpg": process_image, "trimap.png": process_trimap}
    )

    # Write the output to a memory buffer and return the value
    buffer = io.BytesIO()
    with wds.TarWriter(fileobj=buffer) as dst:
        for sample in processed_shard:
            dst.write(sample)
    return buffer.getvalue()


def create_wd_etl(client):
    """Register ``wd_etl`` as an ETL on the cluster under ``etl_name``.

    ``transform_url=True`` makes AIStore pass the object's URL (rather than
    its bytes) to the transform, so ``wd_etl`` can stream the shard itself.
    """
    client.etl(etl_name).init_code(
        transform=wd_etl,
        preimported_modules=["torch"],
        dependencies=["webdataset", "pillow", "torch", "torchvision"],
        communication_type="hpull",
        transform_url=True,
    )


class LocalTarDataset(IterableDataset):
    """
    Builds a PyTorch IterableDataset from bytes in memory as if was read from a URL by WebDataset. This lets us
    initialize a WebDataset Pipeline without writing to local disk and iterate over each record from a shard.
    """

    def __getitem__(self, index) -> T_co:
        # BUG FIX: ``raise NotImplemented`` raised a TypeError because
        # NotImplemented is a sentinel value, not an exception class.
        raise NotImplementedError

    def __init__(self, input_bytes):
        # WebDataset's tar iterators expect dicts with "url" and "stream" keys.
        self.data = [{"url": "input_data", "stream": io.BytesIO(input_bytes)}]

    def __iter__(self):
        # Expand the in-memory tar into per-file records, then regroup files
        # into per-sample dicts keyed by extension (e.g. "image.jpg").
        files = wds.tariterators.tar_file_expander(self.data)
        samples = wds.tariterators.group_by_keys(files)
        return samples


def read_object_tar(shard_data):
    """Parse transformed shard bytes and display the first sample's image."""
    local_dataset = LocalTarDataset(shard_data)
    sample = next(iter(local_dataset))
    show_image(sample.get("image.jpg"))


def transform_object_inline():
    """Fetch one object with the ETL applied on-the-fly and show the result."""
    # NOTE(review): relies on the module-global ``client`` created under
    # ``__main__`` — only callable after that guard has run.
    single_object = client.bucket(bucket_name).object("samples-00.tar")
    # Get object contents with ETL applied
    processed_shard = single_object.get(etl_name=etl_name).read_all()
    read_object_tar(processed_shard)


def transform_bucket_offline():
    """Transform the whole bucket offline, then read back one shard."""
    # NOTE(review): also depends on the module-global ``client``.
    dest_bucket = client.bucket("processed-samples").create(exist_ok=True)
    # Transform the entire bucket, placing the output in the destination bucket
    transform_job = client.bucket(bucket_name).transform(
        to_bck=dest_bucket, etl_name=etl_name
    )
    client.job(transform_job).wait(verbose=True)
    processed_shard = dest_bucket.object("samples-00.tar").get().read_all()
    read_object_tar(processed_shard)


if __name__ == "__main__":
    client = Client(AIS_ENDPOINT)
    image_bucket = client.bucket(bucket_name)
    create_wd_etl(client)
    transform_object_inline()
    transform_bucket_offline()