github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/python/tests/integration/pytorch/test_pytorch_plugin.py (about)

     1  """
     2  Test class for AIStore PyTorch Plugin
     3  Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
     4  """
     5  
     6  import unittest
     7  import torchdata.datapipes.iter as torch_pipes
     8  
     9  from aistore.sdk import Client
    10  from aistore.sdk.errors import AISError, ErrBckNotFound
    11  from aistore.pytorch import AISFileLister, AISFileLoader
    12  from tests.integration import CLUSTER_ENDPOINT
    13  from tests.utils import create_and_put_object, random_string, destroy_bucket
    14  
    15  
    16  # pylint: disable=unused-variable
    17  class TestPytorchPlugin(unittest.TestCase):
    18      """
    19      Integration tests for the Pytorch plugin
    20      """
    21  
    22      def setUp(self) -> None:
    23          self.bck_name = random_string()
    24          self.client = Client(CLUSTER_ENDPOINT)
    25          self.client.bucket(self.bck_name).create()
    26  
    27      def tearDown(self) -> None:
    28          """
    29          Cleanup after each test, destroy the bucket if it exists
    30          """
    31          destroy_bucket(self.client, self.bck_name)
    32  
    33      def test_filelister_with_prefix_variations(self):
    34          num_objs = 10
    35  
    36          # create 10 objects in the /temp dir
    37          for i in range(num_objs):
    38              create_and_put_object(
    39                  self.client, bck_name=self.bck_name, obj_name=f"temp/obj{ i }"
    40              )
    41  
    42          # create 10 objects in the / dir
    43          for i in range(num_objs):
    44              obj_name = f"obj{ i }"
    45              create_and_put_object(
    46                  self.client, bck_name=self.bck_name, obj_name=obj_name
    47              )
    48  
    49          prefixes = [
    50              ["ais://" + self.bck_name],
    51              ["ais://" + self.bck_name + "/"],
    52              ["ais://" + self.bck_name + "/temp/", "ais://" + self.bck_name + "/obj"],
    53          ]
    54          for prefix in prefixes:
    55              urls = AISFileLister(url=CLUSTER_ENDPOINT, source_datapipe=prefix)
    56              ais_loader = AISFileLoader(url=CLUSTER_ENDPOINT, source_datapipe=urls)
    57              with self.assertRaises(TypeError):
    58                  len(urls)
    59              self.assertEqual(len(list(urls)), 20)
    60              self.assertEqual(sum(1 for _ in ais_loader), 20)
    61  
    62      def test_incorrect_inputs(self):
    63          prefixes = ["ais://asdasd"]
    64  
    65          # AISFileLister: Bucket not found
    66          try:
    67              list(AISFileLister(url=CLUSTER_ENDPOINT, source_datapipe=prefixes))
    68          except ErrBckNotFound as err:
    69              self.assertEqual(err.status_code, 404)
    70  
    71          # AISFileLoader: incorrect inputs
    72          url_list = [[""], ["ais:"], ["ais://"], ["s3:///unkown-bucket"]]
    73  
    74          for url in url_list:
    75              with self.assertRaises(AISError):
    76                  s3_loader_dp = AISFileLoader(url=CLUSTER_ENDPOINT, source_datapipe=url)
    77                  for _ in s3_loader_dp:
    78                      pass
    79  
    80      def test_torch_library(self):
    81          # Tests the torch library imports of aistore
    82          torch_pipes.AISFileLister(
    83              url=CLUSTER_ENDPOINT, source_datapipe=["ais://" + self.bck_name]
    84          )
    85          torch_pipes.AISFileLoader(
    86              url=CLUSTER_ENDPOINT, source_datapipe=["ais://" + self.bck_name]
    87          )
    88  
    89  
# Run the tests with the unittest runner when this file is executed directly.
if __name__ == "__main__":
    unittest.main()