github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/python/tests/unit/sdk/test_etl.py (about)

     1  import base64
     2  import unittest
     3  from unittest.mock import Mock
     4  from unittest.mock import patch
     5  
     6  import cloudpickle
     7  
     8  import aistore
     9  from aistore.sdk.const import (
    10      HTTP_METHOD_PUT,
    11      HTTP_METHOD_GET,
    12      HTTP_METHOD_POST,
    13      HTTP_METHOD_DELETE,
    14      URL_PATH_ETL,
    15      UTF_ENCODING,
    16  )
    17  from aistore.sdk.etl_const import (
    18      CODE_TEMPLATE,
    19      ETL_COMM_HPUSH,
    20      ETL_COMM_HPULL,
    21      ETL_COMM_IO,
    22  )
    23  
    24  from aistore.sdk.etl import Etl, _get_default_runtime
    25  from aistore.sdk.types import ETLDetails
    26  from tests.const import ETL_NAME
    27  
    28  
    29  class TestEtl(unittest.TestCase):  # pylint: disable=unused-variable
    30      def setUp(self) -> None:
    31          self.mock_client = Mock()
    32          self.etl_name = ETL_NAME
    33          self.etl = Etl(self.mock_client, self.etl_name)
    34  
    35      def test_init_spec_default_params(self):
    36          expected_action = {
    37              "communication": "hpush://",
    38              "timeout": "5m",
    39              "argument": "",
    40          }
    41          self.init_spec_exec_assert(expected_action)
    42  
    43      def test_init_spec_invalid_comm(self):
    44          with self.assertRaises(ValueError):
    45              self.etl.init_spec("template", communication_type="invalid")
    46  
    47      def test_init_spec(self):
    48          communication_type = ETL_COMM_HPUSH
    49          timeout = "6m"
    50          expected_action = {
    51              "communication": f"{communication_type}://",
    52              "timeout": timeout,
    53              "argument": "",
    54          }
    55          self.init_spec_exec_assert(
    56              expected_action, communication_type=communication_type, timeout=timeout
    57          )
    58  
    59      def init_spec_exec_assert(self, expected_action, **kwargs):
    60          template = "pod spec template"
    61          expected_action["spec"] = base64.b64encode(
    62              template.encode(UTF_ENCODING)
    63          ).decode(UTF_ENCODING)
    64          expected_action["id"] = self.etl_name
    65          expected_response_text = self.etl_name
    66          mock_response = Mock()
    67          mock_response.text = expected_response_text
    68          self.mock_client.request.return_value = mock_response
    69  
    70          response = self.etl.init_spec(template, **kwargs)
    71  
    72          self.assertEqual(expected_response_text, response)
    73          self.mock_client.request.assert_called_with(
    74              HTTP_METHOD_PUT, path=URL_PATH_ETL, json=expected_action
    75          )
    76  
    77      def test_init_code_default_runtime(self):
    78          version_to_runtime = {
    79              (3, 7): "python3.8v2",
    80              (3, 1234): "python3.8v2",
    81              (3, 8): "python3.8v2",
    82              (3, 10): "python3.10v2",
    83              (3, 11): "python3.11v2",
    84          }
    85          for version, runtime in version_to_runtime.items():
    86              with patch.object(aistore.sdk.etl.sys, "version_info") as version_info:
    87                  version_info.major = version[0]
    88                  version_info.minor = version[1]
    89                  self.assertEqual(runtime, _get_default_runtime())
    90  
    91      def test_init_code_default_params(self):
    92          communication_type = ETL_COMM_HPUSH
    93  
    94          expected_action = {
    95              "runtime": _get_default_runtime(),
    96              "communication": f"{communication_type}://",
    97              "timeout": "5m",
    98              "funcs": {"transform": "transform"},
    99              "code": self.encode_fn([], self.transform_fn, communication_type),
   100              "dependencies": base64.b64encode(b"cloudpickle==2.2.0").decode(
   101                  UTF_ENCODING
   102              ),
   103              "argument": "",
   104          }
   105          self.init_code_exec_assert(expected_action)
   106  
   107      def test_init_code_invalid_comm(self):
   108          with self.assertRaises(ValueError):
   109              self.etl.init_code(Mock(), communication_type="invalid")
   110  
   111      def test_init_code(self):
   112          runtime = "python-non-default"
   113          communication_type = ETL_COMM_HPULL
   114          timeout = "6m"
   115          preimported = ["pytorch"]
   116          user_dependencies = ["pytorch"]
   117          chunk_size = 123
   118          arg_type = "url"
   119  
   120          expected_dependencies = user_dependencies.copy()
   121          expected_dependencies.append("cloudpickle==2.2.0")
   122          expected_dep_str = base64.b64encode(
   123              "\n".join(expected_dependencies).encode(UTF_ENCODING)
   124          ).decode(UTF_ENCODING)
   125  
   126          expected_action = {
   127              "runtime": runtime,
   128              "communication": f"{communication_type}://",
   129              "timeout": timeout,
   130              "funcs": {"transform": "transform"},
   131              "code": self.encode_fn(preimported, self.transform_fn, communication_type),
   132              "dependencies": expected_dep_str,
   133              "chunk_size": chunk_size,
   134              "argument": arg_type,
   135          }
   136          self.init_code_exec_assert(
   137              expected_action,
   138              preimported_modules=preimported,
   139              dependencies=user_dependencies,
   140              runtime=runtime,
   141              communication_type=communication_type,
   142              timeout=timeout,
   143              chunk_size=chunk_size,
   144              arg_type=arg_type,
   145          )
   146  
   147      @staticmethod
   148      def transform_fn():
   149          print("example action")
   150  
   151      @staticmethod
   152      def encode_fn(preimported_modules, func, comm_type):
   153          transform = base64.b64encode(cloudpickle.dumps(func)).decode(UTF_ENCODING)
   154          io_comm_context = "transform()" if comm_type == ETL_COMM_IO else ""
   155          template = CODE_TEMPLATE.format(
   156              preimported_modules, transform, io_comm_context
   157          ).encode(UTF_ENCODING)
   158          return base64.b64encode(template).decode(UTF_ENCODING)
   159  
   160      def init_code_exec_assert(self, expected_action, **kwargs):
   161          expected_action["id"] = self.etl_name
   162  
   163          expected_response_text = "response text"
   164          mock_response = Mock()
   165          mock_response.text = expected_response_text
   166          self.mock_client.request.return_value = mock_response
   167  
   168          response = self.etl.init_code(transform=self.transform_fn, **kwargs)
   169  
   170          self.assertEqual(expected_response_text, response)
   171          self.mock_client.request.assert_called_with(
   172              HTTP_METHOD_PUT, path=URL_PATH_ETL, json=expected_action
   173          )
   174  
   175      def test_view(self):
   176          mock_response = Mock()
   177          self.mock_client.request_deserialize.return_value = mock_response
   178          response = self.etl.view()
   179          self.assertEqual(mock_response, response)
   180          self.mock_client.request_deserialize.assert_called_with(
   181              HTTP_METHOD_GET, path=f"etl/{ self.etl_name }", res_model=ETLDetails
   182          )
   183  
   184      def test_start(self):
   185          self.etl.start()
   186          self.mock_client.request.assert_called_with(
   187              HTTP_METHOD_POST, path=f"etl/{ self.etl_name }/start"
   188          )
   189  
   190      def test_stop(self):
   191          self.etl.stop()
   192          self.mock_client.request.assert_called_with(
   193              HTTP_METHOD_POST, path=f"etl/{ self.etl_name }/stop"
   194          )
   195  
   196      def test_delete(self):
   197          self.etl.delete()
   198          self.mock_client.request.assert_called_with(
   199              HTTP_METHOD_DELETE, path=f"etl/{ self.etl_name }"
   200          )