github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/python/aistore/sdk/dataset/dataset_config.py (about)

     1  #
     2  # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
     3  #
     4  
     5  import logging
     6  from typing import List, Dict, Any, Generator, Tuple
     7  from pathlib import Path
     8  from webdataset import ShardWriter
     9  
    10  from aistore.sdk.dataset.config_attribute import ConfigAttribute
    11  from aistore.sdk.const import DEFAULT_DATASET_MAX_COUNT
    12  
    13  
    14  # pylint: disable=too-few-public-methods
    15  class DatasetConfig:
    16      """
    17      Represents the configuration for managing datasets, particularly focusing on how data attributes are structured
    18  
    19      Args:
    20          primary_attribute (ConfigAttribute): The primary key used for looking up any secondary_attributes will
    21              be determined by the filename of each sample defined by primary_attribute
    22          secondary_attributes (List[ConfigAttribute], optional): A list of configurations for
    23              each attribute or feature in the dataset
    24      """
    25  
    26      def __init__(
    27          self,
    28          primary_attribute: ConfigAttribute,
    29          secondary_attributes: List[ConfigAttribute] = None,
    30      ):
    31          self.primary_attribute = primary_attribute
    32          self.secondary_attributes = secondary_attributes if secondary_attributes else []
    33  
    34      def write_shards(self, skip_missing: bool, **kwargs):
    35          """
    36          Write the dataset to a bucket in webdataset format and log the missing attributes
    37  
    38          Args:
    39              skip_missing (bool, optional): Skip samples that are missing one or more attributes, defaults to True
    40              **kwargs: Additional arguments to pass to the webdataset writer
    41          """
    42          logger = logging.getLogger(f"{__name__}.write_shards")
    43          max_shard_items = kwargs.get("maxcount", DEFAULT_DATASET_MAX_COUNT)
    44          num_digits = len(str(max_shard_items))
    45          kwargs["pattern"] = kwargs.get("pattern", "dataset") + f"-%0{num_digits}d.tar"
    46          shard_writer = ShardWriter(**kwargs)
    47  
    48          dataset = self.generate_dataset(max_shard_items)
    49          for sample, missing_attributes in dataset:
    50              if missing_attributes:
    51                  missing_attributes_str = ", ".join(missing_attributes)
    52                  if skip_missing:
    53                      logger.warning(
    54                          "Missing attributes: %s - Skipping sample.",
    55                          missing_attributes_str,
    56                      )
    57                  else:
    58                      logger.warning(
    59                          "Missing attributes: %s - Including sample without missing attributes.",
    60                          missing_attributes_str,
    61                      )
    62                      shard_writer.write(sample)
    63              else:
    64                  shard_writer.write(sample)
    65  
    66          shard_writer.close()
    67  
    68      def generate_dataset(
    69          self,
    70          max_shard_items: int,
    71      ) -> Generator[Tuple[Dict[str, Any], List[str]], None, None]:
    72          """
    73          Generate a dataset in webdataset format
    74  
    75          Args:
    76              max_shard_items (int): The maximum number of items to include in a shard
    77  
    78          Returns:
    79              Generator (Tuple[Dict[str, Any], List[str]]): A generator that yields samples in webdataset format
    80                  and a list of missing attributes
    81          """
    82          all_attributes = [self.primary_attribute] + self.secondary_attributes
    83          # Generate the dataset
    84          for index, file in enumerate(
    85              Path(self.primary_attribute.path).glob(
    86                  "*." + self.primary_attribute.file_type
    87              )
    88          ):
    89              filename = file.stem
    90              item = {}
    91              missing_attributes = []
    92              for cfg in all_attributes:
    93                  key, data = cfg.get_data_for_entry(filename)
    94                  if not data:
    95                      missing_attributes.append(f"{filename} - {key}")
    96                  else:
    97                      item[key] = data
    98              item["__key__"] = self._get_format_string(max_shard_items) % index
    99              yield item, missing_attributes
   100  
   101      @staticmethod
   102      def _get_format_string(val) -> str:
   103          """
   104          Get a __key__ string for an item in webdataset format
   105          """
   106          num_digits = len(str(val))
   107          format_str = "sample_%0" + str(num_digits) + "d"
   108          return format_str