"""
AIS IO Datapipe
Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
"""

from typing import Iterator, List, Optional, Tuple

from torch.utils.data.dataset import T_co
from torchdata.datapipes import functional_datapipe
from torchdata.datapipes.iter import IterDataPipe
from torchdata.datapipes.utils import StreamWrapper

from aistore.sdk.ais_source import AISSource

try:
    from aistore.sdk import Client
    from aistore.pytorch.utils import parse_url, unparse_url

    HAS_AIS = True
except ImportError:
    HAS_AIS = False


def _assert_aistore() -> None:
    """Raise ``ModuleNotFoundError`` if the `aistore` SDK failed to import."""
    if not HAS_AIS:
        # NOTE: separators added between the concatenated fragments — the
        # original message rendered as "...datapipe.Please run...package For..."
        raise ModuleNotFoundError(
            "Package `aistore` is required to be installed to use this datapipe. "
            "Please run `pip install aistore` or `conda install aistore` to install the package. "
            "For more info visit: https://github.com/NVIDIA/aistore/blob/main/python/aistore/"
        )


# pylint: disable=unused-variable
# pylint: disable=W0223
@functional_datapipe("ais_list_files")
class AISFileListerIterDataPipe(IterDataPipe[str]):
    """
    Iterable Datapipe that lists files from the AIStore backends with the given URL prefixes.
    (functional name: ``ais_list_files``).
    Acceptable prefixes include but not limited to - `ais://bucket-name`, `ais://bucket-name/`

    Note:
        - This function also supports files from multiple backends (`aws://..`, `gcp://..`, etc.)
        - Input must be a list and direct URLs are not supported.
        - length is -1 by default, all calls to len() are invalid as
          not all items are iterated at the start.
        - This internally uses AIStore Python SDK.

    Args:
        source_datapipe(IterDataPipe[str]): a DataPipe that contains URLs/URL
                                            prefixes to objects on AIS
        length(int): length of the datapipe
        url(str): AIStore endpoint

    Example:
        >>> from torchdata.datapipes.iter import IterableWrapper, AISFileLister
        >>> ais_prefixes = IterableWrapper(['gcp://bucket-name/folder/', 'aws:bucket-name/folder/',
        >>>                                'ais://bucket-name/folder/', ...])
        >>> dp_ais_urls = AISFileLister(url='localhost:8080', source_datapipe=ais_prefixes)
        >>> for url in dp_ais_urls:
        ...     pass
        >>> # Functional API
        >>> dp_ais_urls = ais_prefixes.ais_list_files(url='localhost:8080')
        >>> for url in dp_ais_urls:
        ...     pass
    """

    def __init__(
        self, source_datapipe: IterDataPipe[str], url: str, length: int = -1
    ) -> None:
        """
        Args:
            source_datapipe: DataPipe yielding URL prefixes to list objects under.
            url: AIStore cluster endpoint used to create the SDK client.
            length: declared length of the pipe; -1 means "unknown" and makes
                len() raise (see __len__).
        """
        _assert_aistore()
        self.source_datapipe: IterDataPipe[str] = source_datapipe
        self.length: int = length
        self.client = Client(url)

    def __iter__(self) -> Iterator[str]:
        # For each prefix, page through the bucket's object listing lazily and
        # re-assemble each entry into a full provider://bucket/object URL.
        for prefix in self.source_datapipe:
            provider, bck_name, prefix = parse_url(prefix)
            obj_iter = self.client.bucket(bck_name, provider).list_objects_iter(
                prefix=prefix
            )
            for entry in obj_iter:
                yield unparse_url(
                    provider=provider, bck_name=bck_name, obj_name=entry.name
                )

    def __len__(self) -> int:
        # Listing is lazy, so the true length is unknown unless the caller
        # supplied one explicitly at construction time.
        if self.length == -1:
            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
        return self.length


# pylint: disable=unused-variable
# pylint: disable=W0223
@functional_datapipe("ais_load_files")
class AISFileLoaderIterDataPipe(IterDataPipe[Tuple[str, StreamWrapper]]):
    """
    Iterable DataPipe that loads files from AIStore with the given URLs (functional name: ``ais_load_files``).
    Iterates all files in BytesIO format and returns a tuple (url, BytesIO).

    Note:
        - This function also supports files from multiple backends (`aws://..`, `gcp://..`, etc)
        - Input must be a list and direct URLs are not supported.
        - This internally uses AIStore Python SDK.
        - An `etl_name` can be provided to run an existing ETL on the AIS cluster.
          See https://github.com/NVIDIA/aistore/blob/main/docs/etl.md for more info on AIStore ETL.

    Args:
        source_datapipe(IterDataPipe[str]): a DataPipe that contains URLs/URL prefixes to objects
        length(int): length of the datapipe
        url(str): AIStore endpoint
        etl_name (str, optional): Optional etl on the AIS cluster to apply to each object

    Example:
        >>> from torchdata.datapipes.iter import IterableWrapper, AISFileLister, AISFileLoader
        >>> ais_prefixes = IterableWrapper(['gcp://bucket-name/folder/', 'aws:bucket-name/folder/',
        >>>                                'ais://bucket-name/folder/', ...])
        >>> dp_ais_urls = AISFileLister(url='localhost:8080', source_datapipe=ais_prefixes)
        >>> dp_cloud_files = AISFileLoader(url='localhost:8080', source_datapipe=dp_ais_urls)
        >>> for url, file in dp_cloud_files:
        ...     pass
        >>> # Functional API
        >>> dp_cloud_files = dp_ais_urls.ais_load_files(url='localhost:8080')
        >>> for url, file in dp_cloud_files:
        ...     pass
    """

    def __init__(
        self,
        source_datapipe: IterDataPipe[str],
        url: str,
        length: int = -1,
        etl_name: Optional[str] = None,
    ) -> None:
        """
        Args:
            source_datapipe: DataPipe yielding fully-qualified object URLs.
            url: AIStore cluster endpoint used to create the SDK client.
            length: declared length; defaults to -1, but __len__ here delegates
                to the source datapipe rather than using this value.
            etl_name: name of a pre-existing ETL on the cluster to apply
                server-side to each fetched object, or None for raw objects.
        """
        _assert_aistore()
        self.source_datapipe: IterDataPipe[str] = source_datapipe
        self.length = length
        self.client = Client(url)
        self.etl_name = etl_name

    def __iter__(self) -> Iterator[Tuple[str, StreamWrapper]]:
        # Fetch each object as a raw stream (optionally transformed by the ETL)
        # and pair it with its originating URL.
        for url in self.source_datapipe:
            provider, bck_name, obj_name = parse_url(url)
            yield url, StreamWrapper(
                self.client.bucket(bck_name=bck_name, provider=provider)
                .object(obj_name=obj_name)
                .get(etl_name=self.etl_name)
                .raw()
            )

    def __len__(self) -> int:
        # Length mirrors the upstream datapipe (one loaded file per input URL).
        return len(self.source_datapipe)


@functional_datapipe("ais_list_sources")
class AISSourceLister(IterDataPipe[str]):
    def __init__(self, ais_sources: List[AISSource], prefix="", etl_name=None):
        """
        Iterable DataPipe over the full URLs for each of the provided AIS source object types

        Args:
            ais_sources (List[AISSource]): List of types implementing the AISSource interface: Bucket, ObjectGroup,
              Object, etc.
            prefix (str, optional): Filter results to only include objects with names starting with this prefix
            etl_name (str, optional): Pre-existing ETL on AIS to apply to all selected objects on the cluster side
        """
        _assert_aistore()
        self.sources = ais_sources
        self.prefix = prefix
        self.etl_name = etl_name

    def __getitem__(self, index) -> T_co:
        # Random access is not supported; this pipe is iteration-only.
        raise NotImplementedError

    def __iter__(self) -> Iterator[T_co]:
        # Flatten the per-source URL listings into a single stream of URLs.
        for source in self.sources:
            for url in source.list_urls(prefix=self.prefix, etl_name=self.etl_name):
                yield url