# Source: github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/python/tests/integration/pytorch/test_pytorch_plugin.py
"""
Test class for AIStore PyTorch Plugin
Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
"""

import unittest
import torchdata.datapipes.iter as torch_pipes

from aistore.sdk import Client
from aistore.sdk.errors import AISError, ErrBckNotFound
from aistore.pytorch import AISFileLister, AISFileLoader
from tests.integration import CLUSTER_ENDPOINT
from tests.utils import create_and_put_object, random_string, destroy_bucket


# pylint: disable=unused-variable
class TestPytorchPlugin(unittest.TestCase):
    """
    Integration tests for the Pytorch plugin
    """

    def setUp(self) -> None:
        """
        Create a client for CLUSTER_ENDPOINT and a fresh, randomly named
        bucket for each test
        """
        self.bck_name = random_string()
        self.client = Client(CLUSTER_ENDPOINT)
        self.client.bucket(self.bck_name).create()

    def tearDown(self) -> None:
        """
        Cleanup after each test, destroy the bucket if it exists
        """
        destroy_bucket(self.client, self.bck_name)

    def test_filelister_with_prefix_variations(self):
        """
        List and load the same 20 objects through several equivalent
        combinations of URL prefixes
        """
        num_objs = 10

        # create 10 objects in the /temp dir
        for i in range(num_objs):
            create_and_put_object(
                self.client, bck_name=self.bck_name, obj_name=f"temp/obj{i}"
            )

        # create 10 objects in the / dir
        for i in range(num_objs):
            create_and_put_object(
                self.client, bck_name=self.bck_name, obj_name=f"obj{i}"
            )

        bck_url = "ais://" + self.bck_name
        # Every entry is expected to resolve to all 20 objects: the whole
        # bucket, the bucket with a trailing slash, and the two prefixes
        # ("temp/" and "obj") that together cover both groups of 10.
        prefixes = [
            [bck_url],
            [bck_url + "/"],
            [bck_url + "/temp/", bck_url + "/obj"],
        ]
        for prefix in prefixes:
            # subTest identifies which prefix set failed
            with self.subTest(prefix=prefix):
                urls = AISFileLister(url=CLUSTER_ENDPOINT, source_datapipe=prefix)
                ais_loader = AISFileLoader(
                    url=CLUSTER_ENDPOINT, source_datapipe=urls
                )
                # The test expects len() on the lister to raise TypeError
                with self.assertRaises(TypeError):
                    len(urls)
                self.assertEqual(len(list(urls)), 20)
                self.assertEqual(sum(1 for _ in ais_loader), 20)

    def test_incorrect_inputs(self):
        """
        Verify that invalid buckets and malformed URLs raise the expected errors
        """
        prefixes = ["ais://asdasd"]

        # AISFileLister: Bucket not found
        # assertRaises (rather than try/except) makes the test fail when no
        # exception is raised at all, instead of passing vacuously.
        with self.assertRaises(ErrBckNotFound) as ctx:
            list(AISFileLister(url=CLUSTER_ENDPOINT, source_datapipe=prefixes))
        self.assertEqual(ctx.exception.status_code, 404)

        # AISFileLoader: incorrect inputs
        url_list = [[""], ["ais:"], ["ais://"], ["s3:///unkown-bucket"]]

        for url in url_list:
            # subTest identifies which malformed URL failed
            with self.subTest(url=url):
                with self.assertRaises(AISError):
                    s3_loader_dp = AISFileLoader(
                        url=CLUSTER_ENDPOINT, source_datapipe=url
                    )
                    # Iterate so that an error raised only on iteration
                    # is also caught
                    for _ in s3_loader_dp:
                        pass

    def test_torch_library(self):
        """
        Construct the AIS datapipes through the torchdata namespace
        """
        # Tests the torch library imports of aistore
        bck_url = "ais://" + self.bck_name
        torch_pipes.AISFileLister(url=CLUSTER_ENDPOINT, source_datapipe=[bck_url])
        torch_pipes.AISFileLoader(url=CLUSTER_ENDPOINT, source_datapipe=[bck_url])


if __name__ == "__main__":
    unittest.main()