# github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/docs/examples/transform-images-sdk/transform_sdk.py
import os
import io
import sys
from PIL import Image
from torchvision import transforms
import torch

from aistore.pytorch import AISDataset
from aistore.sdk import Client
from aistore.sdk.multiobj import ObjectRange

AISTORE_ENDPOINT = os.getenv("AIS_ENDPOINT", "http://192.168.49.2:8080")
client = Client(AISTORE_ENDPOINT)
bucket_name = "images"


def etl():
    # Runs inside the cluster ("io" communication): read an image from stdin,
    # apply the torchvision preprocessing pipeline, write JPEG bytes to stdout
    def img_to_bytes(img):
        buf = io.BytesIO()
        img = img.convert('RGB')
        img.save(buf, format='JPEG')
        return buf.getvalue()

    input_bytes = sys.stdin.buffer.read()
    image = Image.open(io.BytesIO(input_bytes)).convert('RGB')
    preprocessing = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        transforms.ToPILImage(),
        transforms.Lambda(img_to_bytes),
    ])
    processed_bytes = preprocessing(image)
    sys.stdout.buffer.write(processed_bytes)


def show_image(image_data):
    with Image.open(io.BytesIO(image_data)) as image:
        image.show()


def load_data():
    # First, let's create a bucket and put the data into AIS
    bucket = client.bucket(bucket_name).create()
    bucket.put_files("images/", pattern="*.jpg")
    # Show a random (non-transformed) image from the dataset
    image_data = bucket.object("Bengal_171.jpg").get().read_all()
    show_image(image_data)


def create_etl(etl_name):
    # Initialize the ETL in the cluster from the `etl` function above
    image_etl = client.etl(etl_name)
    image_etl.init_code(
        transform=etl,
        dependencies=["torchvision"],
        communication_type="io")
    return image_etl


def show_etl(etl):
    # List all running ETLs in the cluster, then show details for this one
    print(client.cluster().list_running_etls())
    print(etl.view())


def get_with_etl(etl):
    # GET a single object, transformed inline by the ETL
    transformed_data = client.bucket(bucket_name).object("Bengal_171.jpg").get(etl_name=etl.name).read_all()
    show_image(transformed_data)


def etl_bucket(etl):
    # Transform the entire source bucket into a new destination bucket
    dest_bucket = client.bucket("transformed-images").create()
    transform_job = client.bucket(bucket_name).transform(etl_name=etl.name, to_bck=dest_bucket)
    client.job(transform_job).wait()
    print([entry.name for entry in dest_bucket.list_all_objects()])


def etl_group(etl):
    dest_bucket = client.bucket("transformed-selected-images").create()
    # Select a range of objects from the source bucket
    object_range = ObjectRange(min_index=0, max_index=100, prefix="Bengal_", suffix=".jpg")
    object_group = client.bucket(bucket_name).objects(obj_range=object_range)
    transform_job = object_group.transform(etl_name=etl.name, to_bck=dest_bucket)
    client.job(transform_job).wait_for_idle(timeout=300)
    print([entry.name for entry in dest_bucket.list_all_objects()])


def create_dataloader():
    # Construct a dataset and dataloader to read data from the transformed bucket
    dataset = AISDataset(AISTORE_ENDPOINT, "ais://transformed-images")
    train_loader = torch.utils.data.DataLoader(dataset, shuffle=True)
    return train_loader


if __name__ == "__main__":
    load_data()
    image_etl = create_etl("transform-images")
    show_etl(image_etl)
    get_with_etl(image_etl)
    etl_bucket(image_etl)
    etl_group(image_etl)
    data_loader = create_dataloader()
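

# Illustrative sketch, not part of the original example and not invoked above:
# read one transformed object back from the "transformed-images" bucket created
# by etl_bucket() and decode it into a CHW float tensor as a quick sanity check
# of the ETL output. The helper name `inspect_transformed_object` and its default
# object name are assumptions reused from the code above for demonstration only.
def inspect_transformed_object(object_name="Bengal_171.jpg"):
    # Fetch the transformed JPEG bytes via the SDK
    transformed_bytes = client.bucket("transformed-images").object(object_name).get().read_all()
    # Decode the JPEG and convert it to a tensor
    tensor = transforms.ToTensor()(Image.open(io.BytesIO(transformed_bytes)).convert('RGB'))
    print(object_name, tuple(tensor.shape))
    return tensor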