github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/python/aistore/pytorch/dataset.py (about)

     1  """
     2  AIS Plugin for PyTorch
     3  
     4  PyTorch Dataset and DataLoader for AIS.
     5  
     6  Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
     7  """
     8  
     9  from typing import List, Union
    10  from torch.utils.data import Dataset, IterableDataset
    11  
    12  from aistore.sdk import Client
    13  from aistore.sdk.ais_source import AISSource
    14  from aistore.pytorch.utils import list_objects, list_objects_iterator
    15  
    16  
    17  class AISBaseClass:
    18      """
    19      A base class for creating AIS Datasets for PyTorch
    20  
    21      Args:
    22          client_url(str): AIS endpoint URL
    23          urls_list (str, List[str]): single or list of url prefixes objects to load data
    24          ais_source_list (AISSource, List[AISSource]): single or list of AISSource objects to load data
    25      """
    26  
    27      def __init__(
    28          self,
    29          client_url: str,
    30          urls_list: Union[str, List],
    31          ais_source_list: Union[AISSource, List[AISSource]],
    32      ) -> None:
    33          self.client = Client(client_url)
    34          if isinstance(urls_list, str):
    35              urls_list = [urls_list]
    36          if isinstance(ais_source_list, AISSource):
    37              ais_source_list = [ais_source_list]
    38  
    39          self._objects = list_objects(self.client, urls_list, ais_source_list)
    40  
    41  
    42  class AISDataset(AISBaseClass, Dataset):
    43      """
    44      A map-style dataset for objects in AIS
    45      If `etl_name` is provided, that ETL must already exist on the AIStore cluster
    46  
    47      Args:
    48          client_url (str): AIS endpoint URL
    49          urls_list (str, List[str]): single or list of url prefixes objects to load data
    50          ais_source_list (AISSource, List[AISSource]): single or list of AISSource objects to load data
    51          etl_name (str, optional): Optional etl on the AIS cluster to apply to each object
    52  
    53      Note:
    54          Each object is represented as a tuple of object_name(str) and object_content(bytes)
    55      """
    56  
    57      def __init__(
    58          self,
    59          client_url: str,
    60          urls_list: Union[str, List[str]] = [],
    61          ais_source_list: Union[AISSource, List[AISSource]] = [],
    62          etl_name=None,
    63      ):
    64          if not urls_list and not ais_source_list:
    65              raise ValueError(
    66                  "At least one of urls_list or ais_source_list must be provided"
    67              )
    68          AISBaseClass.__init__(self, client_url, urls_list, ais_source_list)
    69          self.etl_name = etl_name
    70  
    71      def __len__(self):
    72          return len(self._objects)
    73  
    74      def __getitem__(self, index: int):
    75          obj = self._objects[index]
    76          obj_name = self._objects[index].name
    77          content = obj.get(etl_name=self.etl_name).read_all()
    78          return obj_name, content
    79  
    80  
    81  class AISBaseClassIter:
    82      """
    83      A base class for creating AIS Iterable Datasets for PyTorch
    84  
    85      Args:
    86          client_url (str): AIS endpoint URL
    87          urls_list (str, List[str]): single or list of url prefixes objects to load data
    88          ais_source_list (AISSource, List[AISSource]): single or list of AISSource objects to load data
    89      """
    90  
    91      def __init__(
    92          self,
    93          client_url: str,
    94          urls_list: Union[str, List[str]],
    95          ais_source_list: Union[AISSource, List[AISSource]],
    96      ) -> None:
    97          self.client = Client(client_url)
    98          if isinstance(urls_list, str):
    99              urls_list = [urls_list]
   100          if isinstance(ais_source_list, AISSource):
   101              ais_source_list = [ais_source_list]
   102          self.urls_list = urls_list
   103          self.ais_source_list = ais_source_list
   104          self._reset_iterator()
   105  
   106      def _reset_iterator(self):
   107          """
   108          Reset the object iterator to start from the beginning
   109          """
   110          self._object_iter = list_objects_iterator(
   111              self.client, self.urls_list, self.ais_source_list
   112          )
   113  
   114  
   115  class AISIterDataset(AISBaseClassIter, IterableDataset):
   116      """
   117      A iterable style dataset which iterates over objects in AIS
   118      If `etl_name` is provided, that ETL must already exist on the AIStore cluster
   119  
   120      Args:
   121          client_url (str): AIS endpoint URL
   122          urls_list (str, List[str]): single or list of url prefixes objects to load data
   123          ais_source_list (AISSource, List[AISSource]): single or list of AISSource objects to load data
   124          etl_name (str, optional): Optional etl on the AIS cluster to apply to each object
   125  
   126      Note:
   127          Each object is represented as a tuple of object_name(str) and object_content(bytes)
   128      """
   129  
   130      def __init__(
   131          self,
   132          client_url: str,
   133          urls_list: Union[str, List[str]] = [],
   134          ais_source_list: Union[AISSource, List[AISSource]] = [],
   135          etl_name=None,
   136      ):
   137          if not urls_list and not ais_source_list:
   138              raise ValueError(
   139                  "At least one of urls_list or ais_source_list must be provided"
   140              )
   141          AISBaseClassIter.__init__(self, client_url, urls_list, ais_source_list)
   142          self.etl_name = etl_name
   143          self.length = None
   144  
   145      def __iter__(self):
   146          self._reset_iterator()
   147          for obj in self._object_iter:
   148              obj_name = obj.name
   149              content = obj.get(etl_name=self.etl_name).read_all()
   150              yield obj_name, content
   151  
   152      def __len__(self):
   153          if not self.length:
   154              self._reset_iterator()
   155              self.length = self._calculate_len()
   156          return self.length
   157  
   158      def _calculate_len(self):
   159          return sum(1 for _ in self._object_iter)