# Source: github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/python/aistore/sdk/dataset/dataset_config.py
#
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#

import logging
from typing import List, Dict, Any, Generator, Optional, Tuple
from pathlib import Path
from webdataset import ShardWriter

from aistore.sdk.dataset.config_attribute import ConfigAttribute
from aistore.sdk.const import DEFAULT_DATASET_MAX_COUNT


# pylint: disable=too-few-public-methods
class DatasetConfig:
    """
    Represents the configuration for managing datasets, particularly how the data
    attributes of each sample are structured.

    Args:
        primary_attribute (ConfigAttribute): The attribute that drives sample discovery.
            Every file matching "*.<file_type>" under its path becomes one sample, and
            that file's name (without extension) is the key used to look up every
            secondary attribute.
        secondary_attributes (List[ConfigAttribute], optional): Additional attributes
            (features) attached to each sample. Defaults to no secondary attributes.
    """

    def __init__(
        self,
        primary_attribute: ConfigAttribute,
        secondary_attributes: Optional[List[ConfigAttribute]] = None,
    ):
        self.primary_attribute = primary_attribute
        self.secondary_attributes = secondary_attributes if secondary_attributes else []

    def write_shards(self, skip_missing: bool = True, **kwargs):
        """
        Write the dataset to a bucket in webdataset format and log the missing attributes

        Args:
            skip_missing (bool, optional): Skip samples that are missing one or more attributes, defaults to True.
                When False, such samples are still written, just without the missing attributes.
            **kwargs: Additional arguments to pass to the webdataset writer. "pattern" (default "dataset")
                is used as a prefix and suffixed with a zero-padded shard number and ".tar"; "maxcount"
                (default DEFAULT_DATASET_MAX_COUNT) also determines the zero-padding width.
        """
        logger = logging.getLogger(f"{__name__}.write_shards")
        max_shard_items = kwargs.get("maxcount", DEFAULT_DATASET_MAX_COUNT)
        num_digits = len(str(max_shard_items))
        kwargs["pattern"] = kwargs.get("pattern", "dataset") + f"-%0{num_digits}d.tar"
        shard_writer = ShardWriter(**kwargs)

        try:
            for sample, missing_attributes in self.generate_dataset(max_shard_items):
                if missing_attributes:
                    missing_attributes_str = ", ".join(missing_attributes)
                    if skip_missing:
                        logger.warning(
                            "Missing attributes: %s - Skipping sample.",
                            missing_attributes_str,
                        )
                        continue
                    logger.warning(
                        "Missing attributes: %s - Including sample without missing attributes.",
                        missing_attributes_str,
                    )
                shard_writer.write(sample)
        finally:
            # Close the writer even if generating or writing a sample raises,
            # so the currently open shard is not leaked.
            shard_writer.close()

    def generate_dataset(
        self,
        max_shard_items: int,
    ) -> Generator[Tuple[Dict[str, Any], List[str]], None, None]:
        """
        Generate a dataset in webdataset format

        Args:
            max_shard_items (int): The maximum number of items to include in a shard; its digit
                count sets the zero-padding width of each sample's "__key__"

        Yields:
            Tuple[Dict[str, Any], List[str]]: A sample in webdataset format (attribute key -> data,
                plus "__key__") and a list of "<filename> - <attribute key>" entries for every
                attribute whose data was missing
        """
        all_attributes = [self.primary_attribute] + self.secondary_attributes
        # The key format depends only on max_shard_items, so build it once, not per sample.
        key_format = self._get_format_string(max_shard_items)
        # NOTE(review): Path.glob yields files in an unspecified order, so the index
        # (and therefore "__key__") assigned to a given file may differ between runs.
        primary_files = Path(self.primary_attribute.path).glob(
            "*." + self.primary_attribute.file_type
        )
        for index, file in enumerate(primary_files):
            filename = file.stem
            item = {}
            missing_attributes = []
            for cfg in all_attributes:
                key, data = cfg.get_data_for_entry(filename)
                # Any falsy data (e.g. None or empty content) is treated as missing.
                if not data:
                    missing_attributes.append(f"{filename} - {key}")
                else:
                    item[key] = data
            item["__key__"] = key_format % index
            yield item, missing_attributes

    @staticmethod
    def _get_format_string(val) -> str:
        """
        Get a printf-style format string for an item's __key__ in webdataset format.

        Returns "sample_%0<N>d", where N is the number of characters in str(val).
        """
        num_digits = len(str(val))
        format_str = "sample_%0" + str(num_digits) + "d"
        return format_str