# Source: github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/python/aistore/pytorch/dataset.py
"""
AIS Plugin for PyTorch

PyTorch Dataset and DataLoader for AIS.

Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
"""

from typing import List, Optional, Union
from torch.utils.data import Dataset, IterableDataset

from aistore.sdk import Client
from aistore.sdk.ais_source import AISSource
from aistore.pytorch.utils import list_objects, list_objects_iterator


def _as_list(value, single_type):
    """
    Normalize a "single item or list of items" argument to a list

    Args:
        value: None, a single instance of `single_type`, or a collection of items
        single_type (type): Type that marks `value` as a single item to be wrapped

    Returns:
        A new empty list if `value` is None, a one-element list if `value` is an
        instance of `single_type`, otherwise `value` itself, unchanged
    """
    if value is None:
        # A fresh list per call, so no list object is ever shared between instances
        return []
    if isinstance(value, single_type):
        return [value]
    return value


class AISBaseClass:
    """
    A base class for creating AIS Datasets for PyTorch

    Args:
        client_url (str): AIS endpoint URL
        urls_list (str, List[str]): single or list of url prefixes objects to load data
        ais_source_list (AISSource, List[AISSource]): single or list of AISSource objects to load data
    """

    def __init__(
        self,
        client_url: str,
        urls_list: Optional[Union[str, List[str]]],
        ais_source_list: Optional[Union[AISSource, List[AISSource]]],
    ) -> None:
        self.client = Client(client_url)
        urls_list = _as_list(urls_list, str)
        ais_source_list = _as_list(ais_source_list, AISSource)

        # The full object listing is resolved once, at construction time
        self._objects = list_objects(self.client, urls_list, ais_source_list)


class AISDataset(AISBaseClass, Dataset):
    """
    A map-style dataset for objects in AIS
    If `etl_name` is provided, that ETL must already exist on the AIStore cluster

    Args:
        client_url (str): AIS endpoint URL
        urls_list (str, List[str]): single or list of url prefixes objects to load data
        ais_source_list (AISSource, List[AISSource]): single or list of AISSource objects to load data
        etl_name (str, optional): Optional etl on the AIS cluster to apply to each object

    Raises:
        ValueError: If both `urls_list` and `ais_source_list` are empty or omitted

    Note:
        Each object is represented as a tuple of object_name(str) and object_content(bytes)
    """

    def __init__(
        self,
        client_url: str,
        urls_list: Optional[Union[str, List[str]]] = None,
        ais_source_list: Optional[Union[AISSource, List[AISSource]]] = None,
        etl_name: Optional[str] = None,
    ):
        # Validate before the base initializer creates the client and lists objects
        if not urls_list and not ais_source_list:
            raise ValueError(
                "At least one of urls_list or ais_source_list must be provided"
            )
        AISBaseClass.__init__(self, client_url, urls_list, ais_source_list)
        self.etl_name = etl_name

    def __len__(self):
        """Return the number of objects listed at construction time"""
        return len(self._objects)

    def __getitem__(self, index: int):
        """
        Fetch a single object by position

        Args:
            index (int): Position of the object in the listing built at construction time

        Returns:
            Tuple of the object's name and its full content, read with `etl_name` applied
        """
        obj = self._objects[index]
        content = obj.get(etl_name=self.etl_name).read_all()
        return obj.name, content


class AISBaseClassIter:
    """
    A base class for creating AIS Iterable Datasets for PyTorch

    Args:
        client_url (str): AIS endpoint URL
        urls_list (str, List[str]): single or list of url prefixes objects to load data
        ais_source_list (AISSource, List[AISSource]): single or list of AISSource objects to load data
    """

    def __init__(
        self,
        client_url: str,
        urls_list: Optional[Union[str, List[str]]],
        ais_source_list: Optional[Union[AISSource, List[AISSource]]],
    ) -> None:
        self.client = Client(client_url)
        # Kept on the instance so the iterator can be rebuilt by _reset_iterator()
        self.urls_list = _as_list(urls_list, str)
        self.ais_source_list = _as_list(ais_source_list, AISSource)
        self._reset_iterator()

    def _reset_iterator(self):
        """
        Reset the object iterator to start from the beginning
        """
        self._object_iter = list_objects_iterator(
            self.client, self.urls_list, self.ais_source_list
        )


class AISIterDataset(AISBaseClassIter, IterableDataset):
    """
    An iterable-style dataset which iterates over objects in AIS
    If `etl_name` is provided, that ETL must already exist on the AIStore cluster

    Args:
        client_url (str): AIS endpoint URL
        urls_list (str, List[str]): single or list of url prefixes objects to load data
        ais_source_list (AISSource, List[AISSource]): single or list of AISSource objects to load data
        etl_name (str, optional): Optional etl on the AIS cluster to apply to each object

    Raises:
        ValueError: If both `urls_list` and `ais_source_list` are empty or omitted

    Note:
        Each object is represented as a tuple of object_name(str) and object_content(bytes)
    """

    def __init__(
        self,
        client_url: str,
        urls_list: Optional[Union[str, List[str]]] = None,
        ais_source_list: Optional[Union[AISSource, List[AISSource]]] = None,
        etl_name: Optional[str] = None,
    ):
        # Validate before the base initializer creates the client and the iterator
        if not urls_list and not ais_source_list:
            raise ValueError(
                "At least one of urls_list or ais_source_list must be provided"
            )
        AISBaseClassIter.__init__(self, client_url, urls_list, ais_source_list)
        self.etl_name = etl_name
        # Cached object count; None means "not computed yet"
        self.length = None

    def __iter__(self):
        """
        Yield (object_name, object_content) tuples, restarting from the first object
        """
        # Always start from a fresh iterator: a previous pass or a len() call
        # may have consumed the current one
        self._reset_iterator()
        for obj in self._object_iter:
            content = obj.get(etl_name=self.etl_name).read_all()
            yield obj.name, content

    def __len__(self):
        """
        Return the number of objects, counting them on the first call only

        The count is cached after the first call and is not refreshed afterwards.
        """
        # Compare against None rather than truthiness so that a dataset with
        # zero objects is cached too, instead of being re-counted on every call
        if self.length is None:
            self._reset_iterator()
            self.length = self._calculate_len()
        return self.length

    def _calculate_len(self):
        """Count the objects by consuming the current object iterator"""
        return sum(1 for _ in self._object_iter)