github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/python/aistore/pytorch/aisio.py (about)

     1  """
     2  AIS IO Datapipe
     3  Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
     4  """
     5  
from typing import Iterator, List, Optional, Tuple

from torch.utils.data.dataset import T_co
from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe
from torchdata.datapipes.utils import StreamWrapper

from aistore.sdk.ais_source import AISSource
    15  
    16  try:
    17      from aistore.sdk import Client
    18      from aistore.pytorch.utils import parse_url, unparse_url
    19  
    20      HAS_AIS = True
    21  except ImportError:
    22      HAS_AIS = False
    23  
    24  
    25  def _assert_aistore() -> None:
    26      if not HAS_AIS:
    27          raise ModuleNotFoundError(
    28              "Package `aistore` is required to be installed to use this datapipe."
    29              "Please run `pip install aistore` or `conda install aistore` to install the package"
    30              "For more info visit: https://github.com/NVIDIA/aistore/blob/main/python/aistore/"
    31          )
    32  
    33  
    34  # pylint: disable=unused-variable
    35  # pylint: disable=W0223
    36  @functional_datapipe("ais_list_files")
    37  class AISFileListerIterDataPipe(IterDataPipe[str]):
    38      """
    39      Iterable Datapipe that lists files from the AIStore backends with the given URL prefixes.
    40          (functional name: ``list_files_by_ais``).
    41      Acceptable prefixes include but not limited to - `ais://bucket-name`, `ais://bucket-name/`
    42  
    43      Note:
    44      -   This function also supports files from multiple backends (`aws://..`, `gcp://..`, etc.)
    45      -   Input must be a list and direct URLs are not supported.
    46      -   length is -1 by default, all calls to len() are invalid as
    47          not all items are iterated at the start.
    48      -   This internally uses AIStore Python SDK.
    49  
    50      Args:
    51          source_datapipe(IterDataPipe[str]): a DataPipe that contains URLs/URL
    52                                              prefixes to objects on AIS
    53          length(int): length of the datapipe
    54          url(str): AIStore endpoint
    55  
    56      Example:
    57          >>> from torchdata.datapipes.iter import IterableWrapper, AISFileLister
    58          >>> ais_prefixes = IterableWrapper(['gcp://bucket-name/folder/', 'aws:bucket-name/folder/',
    59          >>>        'ais://bucket-name/folder/', ...])
    60          >>> dp_ais_urls = AISFileLister(url='localhost:8080', source_datapipe=ais_prefixes)
    61          >>> for url in dp_ais_urls:
    62          ...     pass
    63          >>> # Functional API
    64          >>> dp_ais_urls = ais_prefixes.list_files_by_ais(url='localhost:8080')
    65          >>> for url in dp_ais_urls:
    66          ...     pass
    67      """
    68  
    69      def __init__(
    70          self, source_datapipe: IterDataPipe[str], url: str, length: int = -1
    71      ) -> None:
    72          _assert_aistore()
    73          self.source_datapipe: IterDataPipe[str] = source_datapipe
    74          self.length: int = length
    75          self.client = Client(url)
    76  
    77      def __iter__(self) -> Iterator[str]:
    78          for prefix in self.source_datapipe:
    79              provider, bck_name, prefix = parse_url(prefix)
    80              obj_iter = self.client.bucket(bck_name, provider).list_objects_iter(
    81                  prefix=prefix
    82              )
    83              for entry in obj_iter:
    84                  yield unparse_url(
    85                      provider=provider, bck_name=bck_name, obj_name=entry.name
    86                  )
    87  
    88      def __len__(self) -> int:
    89          if self.length == -1:
    90              raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
    91          return self.length
    92  
    93  
    94  # pylint: disable=unused-variable
    95  # pylint: disable=W0223
    96  @functional_datapipe("ais_load_files")
    97  class AISFileLoaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
    98      """
    99      Iterable DataPipe that loads files from AIStore with the given URLs (functional name: ``load_files_by_ais``).
   100      Iterates all files in BytesIO format and returns a tuple (url, BytesIO).
   101  
   102      Note:
   103      -   This function also supports files from multiple backends (`aws://..`, `gcp://..`, etc)
   104      -   Input must be a list and direct URLs are not supported.
   105      -   This internally uses AIStore Python SDK.
   106      -   An `etl_name` can be provided to run an existing ETL on the AIS cluster.
   107          See https://github.com/NVIDIA/aistore/blob/main/docs/etl.md for more info on AIStore ETL.
   108  
   109      Args:
   110          source_datapipe(IterDataPipe[str]): a DataPipe that contains URLs/URL prefixes to objects
   111          length(int): length of the datapipe
   112          url(str): AIStore endpoint
   113          etl_name (str, optional): Optional etl on the AIS cluster to apply to each object
   114  
   115      Example:
   116          >>> from torchdata.datapipes.iter import IterableWrapper, AISFileLister,AISFileLoader
   117          >>> ais_prefixes = IterableWrapper(['gcp://bucket-name/folder/', 'aws:bucket-name/folder/',
   118          >>>     'ais://bucket-name/folder/', ...])
   119          >>> dp_ais_urls = AISFileLister(url='localhost:8080', source_datapipe=ais_prefixes)
   120          >>> dp_cloud_files = AISFileLoader(url='localhost:8080', source_datapipe=dp_ais_urls)
   121          >>> for url, file in dp_cloud_files:
   122          ...     pass
   123          >>> # Functional API
   124          >>> dp_cloud_files = dp_ais_urls.load_files_by_ais(url='localhost:8080')
   125          >>> for url, file in dp_cloud_files:
   126          ...     pass
   127      """
   128  
   129      def __init__(
   130          self,
   131          source_datapipe: IterDataPipe[str],
   132          url: str,
   133          length: int = -1,
   134          etl_name: str = None,
   135      ) -> None:
   136          _assert_aistore()
   137          self.source_datapipe: IterDataPipe[str] = source_datapipe
   138          self.length = length
   139          self.client = Client(url)
   140          self.etl_name = etl_name
   141  
   142      def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]:
   143          for url in self.source_datapipe:
   144              provider, bck_name, obj_name = parse_url(url)
   145              yield url, StreamWrapper(
   146                  self.client.bucket(bck_name=bck_name, provider=provider)
   147                  .object(obj_name=obj_name)
   148                  .get(etl_name=self.etl_name)
   149                  .raw()
   150              )
   151  
   152      def __len__(self) -> int:
   153          return len(self.source_datapipe)
   154  
   155  
   156  @functional_datapipe("ais_list_sources")
   157  class AISSourceLister(IterDataPipe[str]):
   158      def __init__(self, ais_sources: List[AISSource], prefix="", etl_name=None):
   159          """
   160          Iterable DataPipe over the full URLs for each of the provided AIS source object types
   161  
   162          Args:
   163              ais_sources (List[AISSource]): List of types implementing the AISSource interface: Bucket, ObjectGroup,
   164               Object, etc.
   165              prefix (str, optional): Filter results to only include objects with names starting with this prefix
   166              etl_name (str, optional): Pre-existing ETL on AIS to apply to all selected objects on the cluster side
   167          """
   168          _assert_aistore()
   169          self.sources = ais_sources
   170          self.prefix = prefix
   171          self.etl_name = etl_name
   172  
   173      def __getitem__(self, index) -> T_co:
   174          raise NotImplementedError
   175  
   176      def __iter__(self) -> Iterator[T_co]:
   177          for source in self.sources:
   178              for url in source.list_urls(prefix=self.prefix, etl_name=self.etl_name):
   179                  yield url