github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/python/tests/integration/sdk/remote_enabled_test.py

#
# Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
#

from typing import List
import unittest
import boto3

from aistore.sdk.const import PROVIDER_AIS
from aistore import Client
from tests.integration import (
    REMOTE_SET,
    REMOTE_BUCKET,
    CLUSTER_ENDPOINT,
)
from tests.utils import (
    random_string,
    destroy_bucket,
    create_and_put_objects,
    create_and_put_object,
)
from tests import AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY
from tests.integration.boto3 import AWS_REGION
from tests.const import TEST_TIMEOUT_LONG, OBJECT_COUNT, SUFFIX_NAME


class RemoteEnabledTest(unittest.TestCase):
    """
    This class is intended to be used with all tests that work with remote buckets.
    It provides helper methods for dealing with remote buckets and objects and tracking them for proper cleanup.
    This includes prefixing all objects with a unique value and deleting all objects after tests finish
    to avoid collisions with multiple instances using the same bucket.
    To use this class with another test class, simply inherit from this rather than TestCase.
    To extend setUp behavior in a child class, define it as normal for a TestCase, then call
    super().setUp() before adding additional setup steps (same process for tearDown).
    """

    def setUp(self) -> None:
        self.bck_name = random_string()
        self.client = Client(CLUSTER_ENDPOINT)
        self.buckets = []
        self.obj_prefix = f"{self._testMethodName}-{random_string(6)}-"

        if REMOTE_SET:
            self.cloud_objects = []
            provider, bck_name = REMOTE_BUCKET.split("://")
            self.bucket = self.client.bucket(bck_name, provider=provider)
            self.provider = provider
            self.bck_name = bck_name
        else:
            self.provider = PROVIDER_AIS
            self.bucket = self._create_bucket(self.bck_name)

    def tearDown(self) -> None:
        """
        Cleanup after each test: delete any tracked remote objects and destroy any buckets created
        """
        if REMOTE_SET:
            entries = self.bucket.list_all_objects(prefix=self.obj_prefix)
            obj_names = [entry.name for entry in entries]
            obj_names.extend(self.cloud_objects)
            if len(obj_names) > 0:
                job_id = self.bucket.objects(obj_names=obj_names).delete()
                self.client.job(job_id).wait(timeout=TEST_TIMEOUT_LONG)
        for bck in self.buckets:
            destroy_bucket(self.client, bck)

    def _create_bucket(self, bck_name, provider=PROVIDER_AIS):
        """
        Create a bucket and store its name for later cleanup
        Args:
            bck_name: Name of the new bucket
            provider: Provider for the new bucket
        """
        bck = self.client.bucket(bck_name, provider=provider)
        bck.create()
        self._register_for_post_test_cleanup(names=[bck_name], is_bucket=True)
        return bck

    def _register_for_post_test_cleanup(self, names: List[str], is_bucket: bool):
        """
        Register objects or buckets for post-test cleanup

        Args:
            names (List[str]): Names of buckets or objects
            is_bucket (bool): True if we are storing a bucket; False for an object
        """
        if is_bucket:
            self.buckets.extend(names)
        elif REMOTE_SET:
            self.cloud_objects.extend(names)

    def _create_object(self, obj_name=""):
        """
        Create an object with the given name and track it for later cleanup

        Args:
            obj_name: Name of the object to create

        Returns:
            The object created
        """
        obj = self.bucket.object(obj_name=obj_name)
        self._register_for_post_test_cleanup(names=[obj_name], is_bucket=False)
        return obj

    def _create_object_with_content(self, obj_name="", obj_size=None):
        """
        Create an object with the given name and some content, and track it for later cleanup

        Args:
            obj_name: Name of the object to create
            obj_size: Size of the object content to generate

        Returns:
            The content of the object created
        """

        content = create_and_put_object(
            client=self.client,
            bck_name=self.bck_name,
            obj_name=obj_name,
            provider=self.provider,
            obj_size=obj_size,
        )
        self._register_for_post_test_cleanup(names=[obj_name], is_bucket=False)
        return content

    def _create_objects(
        self, num_obj=OBJECT_COUNT, suffix="", obj_names=None, obj_size=None
    ):
        """
        Create a list of objects using a unique test prefix and track them for later cleanup
        Args:
            num_obj: Number of objects to create
            suffix: Optional suffix for each object name
            obj_names: Optional list of object names to use instead of generated names
            obj_size: Optional size of each object's content
        """
        obj_names = create_and_put_objects(
            self.client,
            self.bucket,
            self.obj_prefix,
            suffix,
            num_obj,
            obj_names,
            obj_size,
        )
        self._register_for_post_test_cleanup(names=obj_names, is_bucket=False)
        return obj_names

    def _check_all_objects_cached(self, num_obj, expected_cached):
        """
        List all objects with this test prefix and validate their cache status
        Args:
            num_obj: Number of objects we expect to find
            expected_cached: Whether we expect them to be cached
        """
        objects = self.bucket.list_objects(
            props="name,cached", prefix=self.obj_prefix
        ).entries
        self.assertEqual(num_obj, len(objects))
        self._validate_objects_cached(objects, expected_cached)

    def _validate_objects_cached(self, objects, expected_cached):
        """
        Validate that all provided objects are either cached or not cached
        Args:
            objects: List of objects to check
            expected_cached: Whether we expect them to be cached
        """
        self.assertTrue(len(objects) > 0)
        for obj in objects:
            self.assertTrue(obj.is_ok())
            if expected_cached:
                self.assertTrue(obj.is_cached())
            else:
                self.assertFalse(obj.is_cached())

    def _verify_cached_objects(self, expected_object_count, cached_range):
        """
        List the objects and verify the correct count, that all objects matching
        the cached range are cached, and that all others are not

        Args:
            expected_object_count: Expected number of objects to list
            cached_range: Object indices that should be cached; all others should not be
        """
        objects = self.bucket.list_objects(
            props="name,cached", prefix=self.obj_prefix
        ).entries
        self.assertEqual(expected_object_count, len(objects))
        cached_names = {self.obj_prefix + str(x) + SUFFIX_NAME for x in cached_range}
        cached_objs = []
        evicted_objs = []
        for obj in objects:
            if obj.name in cached_names:
                cached_objs.append(obj)
            else:
                evicted_objs.append(obj)
        if len(cached_objs) > 0:
            self._validate_objects_cached(cached_objs, True)
        if len(evicted_objs) > 0:
            self._validate_objects_cached(evicted_objs, False)

    def _get_boto3_client(self):
        return boto3.client(
            "s3",
            region_name=AWS_REGION,
            aws_access_key_id=AWS_ACCESS_KEY_ID,
            aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
        )
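
Usage sketch (not part of the file above): per the class docstring, a child test class inherits from RemoteEnabledTest rather than unittest.TestCase and calls super().setUp() before any additional setup. The hypothetical ExampleRemoteTest below illustrates that pattern using only helpers defined in this file; the import path for RemoteEnabledTest is inferred from the file's location, and the unittest.skipIf guard on REMOTE_SET is simply one way to restrict a test to remote-bucket runs.

import unittest

from tests.const import OBJECT_COUNT
from tests.integration import REMOTE_SET
from tests.integration.sdk.remote_enabled_test import RemoteEnabledTest


class ExampleRemoteTest(RemoteEnabledTest):
    """Hypothetical child class showing how to build on RemoteEnabledTest."""

    def setUp(self) -> None:
        # Run the shared setup first, then add any test-specific setup steps.
        super().setUp()

    @unittest.skipIf(not REMOTE_SET, "Remote bucket is not set")
    def test_created_objects_are_cached(self):
        # Objects are created under self.obj_prefix and registered for cleanup,
        # so tearDown() removes them even if the test fails midway.
        obj_names = self._create_objects(num_obj=OBJECT_COUNT)
        self.assertEqual(OBJECT_COUNT, len(obj_names))
        # Freshly written objects are expected to be present (cached) in the cluster.
        self._check_all_objects_cached(OBJECT_COUNT, expected_cached=True)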