github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/python/tests/integration/sdk/test_etl_ops.py

#
# Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
#

from itertools import cycle
import unittest
import hashlib
import sys
import time

import pytest

from aistore.sdk import Client, Bucket
from aistore.sdk.etl_const import ETL_COMM_HPUSH, ETL_COMM_IO
from aistore.sdk.errors import AISError
from aistore.sdk.etl_templates import MD5, ECHO
from tests.integration import CLUSTER_ENDPOINT
from tests.utils import create_and_put_object, random_string

ETL_NAME_CODE = "etl-" + random_string(5)
ETL_NAME_CODE_IO = "etl-" + random_string(5)
ETL_NAME_CODE_STREAM = "etl-" + random_string(5)
ETL_NAME_SPEC = "etl-" + random_string(5)
ETL_NAME_SPEC_COMP = "etl-" + random_string(5)


# pylint: disable=unused-variable
class TestETLOps(unittest.TestCase):
    def setUp(self) -> None:
        self.bck_name = random_string()
        print("URL END PT ", CLUSTER_ENDPOINT)
        self.client = Client(CLUSTER_ENDPOINT)

        self.bucket = self.client.bucket(bck_name=self.bck_name).create()
        self.obj_name = "temp-obj1.jpg"
        self.obj_size = 128
        self.content = create_and_put_object(
            client=self.client,
            bck_name=self.bck_name,
            obj_name=self.obj_name,
            obj_size=self.obj_size,
        )
        create_and_put_object(
            client=self.client, bck_name=self.bck_name, obj_name="obj2.jpg"
        )

        self.current_etl_count = len(self.client.cluster().list_running_etls())

    def tearDown(self) -> None:
        # Destroy any temporary buckets that are left over
        for bucket in self.client.cluster().list_buckets():
            self.client.bucket(bucket.name).delete(missing_ok=True)

        # Stop and delete all running ETLs
        for etl in self.client.cluster().list_running_etls():
            self.client.etl(etl.id).stop()
            self.client.etl(etl.id).delete()

    # pylint: disable=too-many-statements,too-many-locals
    @pytest.mark.etl
    def test_etl_apis(self):
        # code (hpush, the default communicator)
        def transform(input_bytes):
            md5 = hashlib.md5()
            md5.update(input_bytes)
            return md5.hexdigest().encode()

        code_etl = self.client.etl(ETL_NAME_CODE)
        code_etl.init_code(transform=transform)

        obj = self.bucket.object(self.obj_name).get(etl_name=code_etl.name).read_all()
        self.assertEqual(obj, transform(bytes(self.content)))
        self.assertEqual(
            self.current_etl_count + 1, len(self.client.cluster().list_running_etls())
        )

        # code (io comm)
        def main():
            md5 = hashlib.md5()
            chunk = sys.stdin.buffer.read()
            md5.update(chunk)
            sys.stdout.buffer.write(md5.hexdigest().encode())

        code_io_etl = self.client.etl(ETL_NAME_CODE_IO)
        code_io_etl.init_code(transform=main, communication_type=ETL_COMM_IO)

        obj_io = (
            self.bucket.object(self.obj_name).get(etl_name=code_io_etl.name).read_all()
        )
        self.assertEqual(obj_io, transform(bytes(self.content)))

        code_io_etl.stop()
        code_io_etl.delete()

        # spec
        template = MD5.format(communication_type=ETL_COMM_HPUSH)
        spec_etl = self.client.etl(ETL_NAME_SPEC)
        spec_etl.init_spec(template=template)

        obj = self.bucket.object(self.obj_name).get(etl_name=spec_etl.name).read_all()
        self.assertEqual(obj, transform(bytes(self.content)))

        self.assertEqual(
            self.current_etl_count + 2, len(self.client.cluster().list_running_etls())
        )

        self.assertIsNotNone(code_etl.view())
        self.assertIsNotNone(spec_etl.view())
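
        # Bucket-to-bucket transforms below: Bucket.transform() dispatches an
        # asynchronous job and returns its job ID; client.job(job_id).wait()
        # then blocks until every matching object has been copied through the
        # ETL into the destination bucket.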
        temp_bck1 = self.client.bucket(random_string()).create()

        # Transform bucket with the MD5 template
        job_id = self.bucket.transform(
            etl_name=spec_etl.name, to_bck=temp_bck1, prefix_filter="temp-"
        )
        self.client.job(job_id).wait()

        starting_obj = self.bucket.list_objects().entries
        transformed_obj = temp_bck1.list_objects().entries
        # Should transform only the object matching the prefix filter
        self.assertEqual(len(starting_obj) - 1, len(transformed_obj))

        md5_obj = temp_bck1.object(self.obj_name).get().read_all()

        # Verify that bucket-level and object-level transformations produce the
        # same result
        self.assertEqual(obj, md5_obj)

        # Start ETL with the ECHO template
        template = ECHO.format(communication_type=ETL_COMM_HPUSH)
        echo_spec_etl = self.client.etl(ETL_NAME_SPEC_COMP)
        echo_spec_etl.init_spec(template=template)

        temp_bck2 = self.client.bucket(random_string()).create()

        # Transform bucket with the ECHO template, renaming the extension
        job_id = self.bucket.transform(
            etl_name=echo_spec_etl.name,
            to_bck=temp_bck2,
            ext={"jpg": "txt"},
        )
        self.client.job(job_id).wait()

        # Verify the extension rename
        for obj_iter in temp_bck2.list_objects().entries:
            self.assertEqual(obj_iter.name.split(".")[1], "txt")

        echo_obj = temp_bck2.object("temp-obj1.txt").get().read_all()

        # Verify that different bucket-level transformations are not the same
        # (compare the ECHO transformation with the MD5 transformation)
        self.assertNotEqual(md5_obj, echo_obj)

        echo_spec_etl.stop()
        echo_spec_etl.delete()

        # Transforming with a non-existent ETL name raises an exception
        with self.assertRaises(AISError):
            self.bucket.transform(
                etl_name="faulty-name", to_bck=Bucket(random_string())
            )

        # Stop ETLs
        code_etl.stop()
        spec_etl.stop()
        self.assertEqual(
            len(self.client.cluster().list_running_etls()), self.current_etl_count
        )

        # Restart the stopped ETLs
        code_etl.start()
        spec_etl.start()
        self.assertEqual(
            len(self.client.cluster().list_running_etls()), self.current_etl_count + 2
        )

        # Delete the stopped ETLs
        code_etl.stop()
        spec_etl.stop()
        code_etl.delete()
        spec_etl.delete()

        # Starting a deleted ETL raises an error
        with self.assertRaises(AISError):
            code_etl.start()
        with self.assertRaises(AISError):
            spec_etl.start()

    @pytest.mark.etl
    def test_etl_apis_stress(self):
        num_objs = 200
        content = {}
        for i in range(num_objs):
            obj_name = f"obj{i}"
            content[obj_name] = create_and_put_object(
                client=self.client, bck_name=self.bck_name, obj_name=obj_name
            )

        # code (hpush)
        def transform(input_bytes):
            md5 = hashlib.md5()
            md5.update(input_bytes)
            return md5.hexdigest().encode()

        md5_hpush_etl = self.client.etl(ETL_NAME_CODE)
        md5_hpush_etl.init_code(transform=transform)

        # code (io comm)
        def main():
            md5 = hashlib.md5()
            chunk = sys.stdin.buffer.read()
            md5.update(chunk)
            sys.stdout.buffer.write(md5.hexdigest().encode())

        md5_io_etl = self.client.etl(ETL_NAME_CODE_IO)
        md5_io_etl.init_code(transform=main, communication_type=ETL_COMM_IO)

        start_time = time.time()
        job_id = self.bucket.transform(
            etl_name=md5_hpush_etl.name, to_bck=Bucket("transformed-etl-hpush")
        )
        self.client.job(job_id).wait()
        print("Transform bucket using HPUSH took ", time.time() - start_time)
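
        # Repeat the same transform through the IO communicator for a rough
        # timing comparison against hpush: main() reads the whole object from
        # stdin and writes the digest to stdout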
        start_time = time.time()
        job_id = self.bucket.transform(
            etl_name=md5_io_etl.name, to_bck=Bucket("transformed-etl-io")
        )
        self.client.job(job_id).wait()
        print("Transform bucket using IO took ", time.time() - start_time)

        # Verify that both communicators transformed every object identically
        for key, value in content.items():
            transformed_obj_hpush = (
                self.bucket.object(key).get(etl_name=md5_hpush_etl.name).read_all()
            )
            transformed_obj_io = (
                self.bucket.object(key).get(etl_name=md5_io_etl.name).read_all()
            )

            self.assertEqual(transform(bytes(value)), transformed_obj_hpush)
            self.assertEqual(transform(bytes(value)), transformed_obj_io)

    @pytest.mark.etl
    def test_etl_apis_stream(self):
        # Streaming transform: with chunk_size set, the reader yields the object
        # in fixed-size chunks
        def transform(reader, writer):
            checksum = hashlib.md5()
            for chunk in reader:
                checksum.update(chunk)
            writer.write(checksum.hexdigest().encode())

        code_stream_etl = self.client.etl(ETL_NAME_CODE_STREAM)
        code_stream_etl.init_code(transform=transform, chunk_size=32768)

        obj = (
            self.bucket.object(self.obj_name)
            .get(etl_name=code_stream_etl.name)
            .read_all()
        )
        md5 = hashlib.md5()
        md5.update(self.content)
        self.assertEqual(obj, md5.hexdigest().encode())

    @pytest.mark.etl
    def test_etl_api_xor(self):
        # XOR each chunk with a repeating key, then append the MD5 digest of the
        # transformed payload
        def transform(reader, writer):
            checksum = hashlib.md5()
            key = b"AISTORE"
            for chunk in reader:
                out = bytes([_a ^ _b for _a, _b in zip(chunk, cycle(key))])
                writer.write(out)
                checksum.update(out)
            writer.write(checksum.hexdigest().encode())

        xor_etl = self.client.etl("etl-xor1")
        xor_etl.init_code(transform=transform, chunk_size=32)
        transformed_obj = (
            self.bucket.object(self.obj_name).get(etl_name=xor_etl.name).read_all()
        )
        # The trailing 32 bytes are the hex-encoded MD5 digest of the payload
        data, checksum = transformed_obj[:-32], transformed_obj[-32:]
        computed_checksum = hashlib.md5(data).hexdigest().encode()
        self.assertEqual(checksum, computed_checksum)

    @pytest.mark.etl
    def test_etl_transform_url(self):
        # With arg_type="url", the transform receives the object's URL rather
        # than its content
        def url_transform(url):
            return url.encode("utf-8")

        url_etl = self.client.etl("etl-hpull-url")
        url_etl.init_code(
            transform=url_transform, arg_type="url", communication_type="hpull"
        )
        res = self.bucket.object(self.obj_name).get(etl_name=url_etl.name).read_all()
        result_url = res.decode("utf-8")

        self.assertIn(self.bucket.name, result_url)
        self.assertIn(self.obj_name, result_url)


if __name__ == "__main__":
    unittest.main()
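
# To run just these ETL integration tests (a sketch; assumes pytest is
# installed and an AIS cluster is reachable at the CLUSTER_ENDPOINT configured
# under tests/integration):
#
#   pytest -m etl python/tests/integration/sdk/test_etl_ops.py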