github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/python/tests/integration/sdk/test_etl_ops.py (about)

     1  #
     2  # Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
     3  #
     4  
     5  from itertools import cycle
     6  import unittest
     7  import hashlib
     8  import sys
     9  import time
    10  
    11  import pytest
    12  
    13  from aistore.sdk import Client, Bucket
    14  from aistore.sdk.etl_const import ETL_COMM_HPUSH, ETL_COMM_IO
    15  from aistore.sdk.errors import AISError
    16  from aistore.sdk.etl_templates import MD5, ECHO
    17  from tests.integration import CLUSTER_ENDPOINT
    18  from tests.utils import create_and_put_object, random_string
    19  
    20  ETL_NAME_CODE = "etl-" + random_string(5)
    21  ETL_NAME_CODE_IO = "etl-" + random_string(5)
    22  ETL_NAME_CODE_STREAM = "etl-" + random_string(5)
    23  ETL_NAME_SPEC = "etl-" + random_string(5)
    24  ETL_NAME_SPEC_COMP = "etl-" + random_string(5)
    25  
    26  
# pylint: disable=unused-variable
class TestETLOps(unittest.TestCase):
    """Integration tests for AIStore ETL operations.

    Exercises the ETL lifecycle against a live cluster at CLUSTER_ENDPOINT:
    initialization from Python code (hpush, io, and chunked/streaming
    communication) and from pod-spec templates (MD5, ECHO), inline object
    transforms, offline bucket-to-bucket transforms, and the
    stop/start/delete state transitions.
    """

    def setUp(self) -> None:
        # Fresh, randomly named bucket per test to avoid cross-test interference.
        self.bck_name = random_string()
        print("URL END PT ", CLUSTER_ENDPOINT)
        self.client = Client(CLUSTER_ENDPOINT)

        self.bucket = self.client.bucket(bck_name=self.bck_name).create()
        self.obj_name = "temp-obj1.jpg"
        self.obj_size = 128
        # Keep the generated payload bytes so tests can compute the expected
        # transform output locally and compare against the cluster's result.
        self.content = create_and_put_object(
            client=self.client,
            bck_name=self.bck_name,
            obj_name=self.obj_name,
            obj_size=self.obj_size,
        )
        # Second object without the "temp-" prefix; prefix-filter tests expect
        # this one to be skipped by the transform.
        create_and_put_object(
            client=self.client, bck_name=self.bck_name, obj_name="obj2.jpg"
        )

        # Baseline count so assertions tolerate ETLs already running on the cluster.
        self.current_etl_count = len(self.client.cluster().list_running_etls())

    def tearDown(self) -> None:
        # Try to destroy all temporary buckets if there are left.
        # NOTE(review): this deletes every bucket the cluster lists, not only
        # the ones this test created — assumes a dedicated test cluster.
        for bucket in self.client.cluster().list_buckets():
            self.client.bucket(bucket.name).delete(missing_ok=True)

        # delete all the etls (an ETL must be stopped before it can be deleted)
        for etl in self.client.cluster().list_running_etls():
            self.client.etl(etl.id).stop()
            self.client.etl(etl.id).delete()

    # pylint: disable=too-many-statements,too-many-locals
    @pytest.mark.etl
    def test_etl_apis(self):
        """End-to-end ETL lifecycle: init (code/io/spec), inline and offline
        transforms, prefix filter, extension rename, error cases, stop/start/delete.

        The running-ETL count assertions (+1, +2, back to baseline) depend on
        the exact order of init/stop/start calls below — do not reorder.
        """

        # code-based ETL, default (hpush) communication
        def transform(input_bytes):
            md5 = hashlib.md5()
            md5.update(input_bytes)
            return md5.hexdigest().encode()

        code_etl = self.client.etl(ETL_NAME_CODE)
        code_etl.init_code(transform=transform)

        # Inline transform on GET must match the locally computed MD5 digest.
        obj = self.bucket.object(self.obj_name).get(etl_name=code_etl.name).read_all()
        self.assertEqual(obj, transform(bytes(self.content)))
        self.assertEqual(
            self.current_etl_count + 1, len(self.client.cluster().list_running_etls())
        )

        # code (io comm): reads the whole object from stdin, writes digest to stdout
        def main():
            md5 = hashlib.md5()
            chunk = sys.stdin.buffer.read()
            md5.update(chunk)
            sys.stdout.buffer.write(md5.hexdigest().encode())

        code_io_etl = self.client.etl(ETL_NAME_CODE_IO)
        code_io_etl.init_code(transform=main, communication_type=ETL_COMM_IO)

        # The io-comm ETL computes the same MD5, so compare against transform().
        obj_io = (
            self.bucket.object(self.obj_name).get(etl_name=code_io_etl.name).read_all()
        )
        self.assertEqual(obj_io, transform(bytes(self.content)))

        # Remove the io ETL now so the later count assertions see only two ETLs.
        code_io_etl.stop()
        code_io_etl.delete()

        # spec-based ETL from the MD5 pod template
        template = MD5.format(communication_type=ETL_COMM_HPUSH)
        spec_etl = self.client.etl(ETL_NAME_SPEC)
        spec_etl.init_spec(template=template)

        obj = self.bucket.object(self.obj_name).get(etl_name=spec_etl.name).read_all()
        self.assertEqual(obj, transform(bytes(self.content)))

        # Two ETLs remain running: code_etl and spec_etl.
        self.assertEqual(
            self.current_etl_count + 2, len(self.client.cluster().list_running_etls())
        )

        self.assertIsNotNone(code_etl.view())
        self.assertIsNotNone(spec_etl.view())

        temp_bck1 = self.client.bucket(random_string()).create()

        # Transform Bucket with MD5 Template (offline/bucket-to-bucket)
        job_id = self.bucket.transform(
            etl_name=spec_etl.name, to_bck=temp_bck1, prefix_filter="temp-"
        )
        self.client.job(job_id).wait()

        starting_obj = self.bucket.list_objects().entries
        transformed_obj = temp_bck1.list_objects().entries
        # Should transform only the object defined by the prefix filter
        # (setUp created exactly one non-"temp-" object, hence the -1).
        self.assertEqual(len(starting_obj) - 1, len(transformed_obj))

        md5_obj = temp_bck1.object(self.obj_name).get().read_all()

        # Verify bucket-level transformation and object-level transformation are the same
        self.assertEqual(obj, md5_obj)

        # Start ETL with ECHO template
        template = ECHO.format(communication_type=ETL_COMM_HPUSH)
        echo_spec_etl = self.client.etl(ETL_NAME_SPEC_COMP)
        echo_spec_etl.init_spec(template=template)

        temp_bck2 = self.client.bucket(random_string()).create()

        # Transform bucket with ECHO template, renaming .jpg -> .txt on the way
        job_id = self.bucket.transform(
            etl_name=echo_spec_etl.name,
            to_bck=temp_bck2,
            ext={"jpg": "txt"},
        )
        self.client.job(job_id).wait()

        # Verify extension rename
        # NOTE(review): split(".")[1] assumes object names contain exactly one dot.
        for obj_iter in temp_bck2.list_objects().entries:
            self.assertEqual(obj_iter.name.split(".")[1], "txt")

        echo_obj = temp_bck2.object("temp-obj1.txt").get().read_all()

        # Verify different bucket-level transformations are not the same (compare ECHO transformation and MD5
        # transformation)
        self.assertNotEqual(md5_obj, echo_obj)

        echo_spec_etl.stop()
        echo_spec_etl.delete()

        # Transform w/ non-existent ETL name raises exception
        with self.assertRaises(AISError):
            self.bucket.transform(
                etl_name="faulty-name", to_bck=Bucket(random_string())
            )

        # Stop ETLs — running count must fall back to the setUp baseline.
        code_etl.stop()
        spec_etl.stop()
        self.assertEqual(
            len(self.client.cluster().list_running_etls()), self.current_etl_count
        )

        # Start stopped ETLs — both should be running again.
        code_etl.start()
        spec_etl.start()
        self.assertEqual(
            len(self.client.cluster().list_running_etls()), self.current_etl_count + 2
        )

        # Delete stopped ETLs (stop first: delete requires a stopped ETL)
        code_etl.stop()
        spec_etl.stop()
        code_etl.delete()
        spec_etl.delete()

        # Starting deleted ETLs raises error
        with self.assertRaises(AISError):
            code_etl.start()
        with self.assertRaises(AISError):
            spec_etl.start()

    @pytest.mark.etl
    def test_etl_apis_stress(self):
        """Offline-transform 200 objects with hpush and io ETLs; verify each
        output against a locally computed MD5 and print rough timings."""
        num_objs = 200
        content = {}
        # Map object name -> original payload for per-object verification below.
        for i in range(num_objs):
            obj_name = f"obj{ i }"
            content[obj_name] = create_and_put_object(
                client=self.client, bck_name=self.bck_name, obj_name=obj_name
            )

        # code (hpush)
        def transform(input_bytes):
            md5 = hashlib.md5()
            md5.update(input_bytes)
            return md5.hexdigest().encode()

        md5_hpush_etl = self.client.etl(ETL_NAME_CODE)
        md5_hpush_etl.init_code(transform=transform)

        # code (io comm)
        def main():
            md5 = hashlib.md5()
            chunk = sys.stdin.buffer.read()
            md5.update(chunk)
            sys.stdout.buffer.write(md5.hexdigest().encode())

        md5_io_etl = self.client.etl(ETL_NAME_CODE_IO)
        md5_io_etl.init_code(transform=main, communication_type=ETL_COMM_IO)

        # Rough wall-clock timing only; not an assertion.
        start_time = time.time()
        job_id = self.bucket.transform(
            etl_name=md5_hpush_etl.name, to_bck=Bucket("transformed-etl-hpush")
        )
        self.client.job(job_id).wait()
        print("Transform bucket using HPUSH took ", time.time() - start_time)

        start_time = time.time()
        job_id = self.bucket.transform(
            etl_name=md5_io_etl.name, to_bck=Bucket("transformed-etl-io")
        )
        self.client.job(job_id).wait()
        print("Transform bucket using IO took ", time.time() - start_time)

        # Verify inline transforms per object: both ETLs compute the same MD5.
        for key, value in content.items():
            transformed_obj_hpush = (
                self.bucket.object(key).get(etl_name=md5_hpush_etl.name).read_all()
            )
            transformed_obj_io = (
                self.bucket.object(key).get(etl_name=md5_io_etl.name).read_all()
            )

            self.assertEqual(transform(bytes(value)), transformed_obj_hpush)
            self.assertEqual(transform(bytes(value)), transformed_obj_io)

    @pytest.mark.etl
    def test_etl_apis_stream(self):
        """Chunked (streaming) code ETL: transform receives a reader/writer pair
        when chunk_size is set; result must equal the whole-object MD5."""
        def transform(reader, writer):
            checksum = hashlib.md5()
            # reader yields the object in chunk_size pieces; fold them all in.
            for byte in reader:
                checksum.update(byte)
            writer.write(checksum.hexdigest().encode())

        code_stream_etl = self.client.etl(ETL_NAME_CODE_STREAM)
        code_stream_etl.init_code(transform=transform, chunk_size=32768)

        obj = (
            self.bucket.object(self.obj_name)
            .get(etl_name=code_stream_etl.name)
            .read_all()
        )
        # Streaming digest must match the digest of the full payload.
        md5 = hashlib.md5()
        md5.update(self.content)
        self.assertEqual(obj, md5.hexdigest().encode())

    @pytest.mark.etl
    def test_etl_api_xor(self):
        """Streaming ETL that XORs the payload with a repeating key and appends
        the 32-char hex MD5 of the XORed data; verify the trailing checksum."""
        def transform(reader, writer):
            checksum = hashlib.md5()
            key = b"AISTORE"
            for byte in reader:
                # XOR each chunk against the key, cycled to the chunk length.
                out = bytes([_a ^ _b for _a, _b in zip(byte, cycle(key))])
                writer.write(out)
                checksum.update(out)
            # Trailing 32 bytes of the output are the hex digest of the XORed data.
            writer.write(checksum.hexdigest().encode())

        xor_etl = self.client.etl("etl-xor1")
        xor_etl.init_code(transform=transform, chunk_size=32)
        transformed_obj = (
            self.bucket.object(self.obj_name).get(etl_name=xor_etl.name).read_all()
        )
        # Split payload from the appended 32-byte hex checksum and re-verify it.
        data, checksum = transformed_obj[:-32], transformed_obj[-32:]
        computed_checksum = hashlib.md5(data).hexdigest().encode()
        self.assertEqual(checksum, computed_checksum)

    @pytest.mark.etl
    def test_etl_transform_url(self):
        """ETL with arg_type="url" + hpull: the transform receives the object's
        URL (not its bytes); echoing it back must surface bucket and object names."""
        def url_transform(url):
            return url.encode("utf-8")

        url_etl = self.client.etl("etl-hpull-url")
        url_etl.init_code(
            transform=url_transform, arg_type="url", communication_type="hpull"
        )
        res = self.bucket.object(self.obj_name).get(etl_name=url_etl.name).read_all()
        result_url = res.decode("utf-8")

        # The echoed URL should embed both the bucket name and the object name.
        self.assertTrue(self.bucket.name in result_url)
        self.assertTrue(self.obj_name in result_url)

   296  
   297  
# Allow running this test module directly (outside the pytest runner).
if __name__ == "__main__":
    unittest.main()