github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/docs/examples/transform-images-sdk/transform_sdk.py (about)

     1  import os
     2  import io
     3  import sys
     4  from PIL import Image
     5  from torchvision import transforms
     6  import torch
     7  
     8  from aistore.pytorch import AISDataset
     9  from aistore.sdk import Client
    10  from aistore.sdk.multiobj import ObjectRange
    11  
    12  AISTORE_ENDPOINT = os.getenv("AIS_ENDPOINT", "http://192.168.49.2:8080")
    13  client = Client(AISTORE_ENDPOINT)
    14  bucket_name = "images"
    15  
    16  
    17  def etl():
    18      def img_to_bytes(img):
    19          buf = io.BytesIO()
    20          img = img.convert('RGB')
    21          img.save(buf, format='JPEG')
    22          return buf.getvalue()
    23  
    24      input_bytes = sys.stdin.buffer.read()
    25      image = Image.open(io.BytesIO(input_bytes)).convert('RGB')
    26      preprocessing = transforms.Compose([
    27          transforms.RandomResizedCrop(224),
    28          transforms.RandomHorizontalFlip(),
    29          transforms.ToTensor(),
    30          transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    31          transforms.ToPILImage(),
    32          transforms.Lambda(img_to_bytes),
    33      ])
    34      processed_bytes = preprocessing(image)
    35      sys.stdout.buffer.write(processed_bytes)
    36  
    37  
    38  def show_image(image_data):
    39      with Image.open(io.BytesIO(image_data)) as image:
    40          image.show()
    41  
    42  
    43  def load_data():
    44      # First, let's create a bucket and put the data into AIS
    45      bucket = client.bucket(bucket_name).create()
    46      bucket.put_files("images/", pattern="*.jpg")
    47      # Show a random (non-transformed) image from the dataset
    48      image_data = bucket.object("Bengal_171.jpg").get().read_all()
    49      show_image(image_data)
    50  
    51  
    52  def create_etl(etl_name):
    53      image_etl = client.etl(etl_name)
    54      image_etl.init_code(
    55                             transform=etl,
    56                             dependencies=["torchvision"],
    57                             communication_type="io")
    58      return image_etl
    59  
    60  
    61  def show_etl(etl):
    62      print(client.cluster().list_running_etls())
    63      print(etl.view())
    64  
    65  
    66  def get_with_etl(etl):
    67      transformed_data = client.bucket(bucket_name).object("Bengal_171.jpg").get(etl_name=etl.name).read_all()
    68      show_image(transformed_data)
    69  
    70  
    71  def etl_bucket(etl):
    72      dest_bucket = client.bucket("transformed-images").create()
    73      transform_job = client.bucket(bucket_name).transform(etl_name=etl.name, to_bck=dest_bucket)
    74      client.job(transform_job).wait()
    75      print(entry.name for entry in dest_bucket.list_all_objects())
    76  
    77  
    78  def etl_group(etl):
    79      dest_bucket = client.bucket("transformed-selected-images").create()
    80      # Select a range of objects from the source bucket
    81      object_range = ObjectRange(min_index=0, max_index=100, prefix="Bengal_", suffix=".jpg")
    82      object_group = client.bucket(bucket_name).objects(obj_range=object_range)
    83      transform_job = object_group.transform(etl_name=etl.name, to_bck=dest_bucket)
    84      client.job(transform_job).wait_for_idle(timeout=300)
    85      print([entry.name for entry in dest_bucket.list_all_objects()])
    86  
    87  
    88  def create_dataloader():
    89      # Construct a dataset and dataloader to read data from the transformed bucket
    90      dataset = AISDataset(AISTORE_ENDPOINT, "ais://transformed-images")
    91      train_loader = torch.utils.data.DataLoader(dataset, shuffle=True)
    92      return train_loader
    93  
    94  
    95  if __name__ == "__main__":
    96      load_data()
    97      image_etl = create_etl("transform-images")
    98      show_etl(image_etl)
    99      get_with_etl(image_etl)
   100      etl_bucket(image_etl)
   101      etl_group(image_etl)
   102      data_loader = create_dataloader()