# Source: github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/python/tests/unit/sdk/test_etl.py
import base64
import unittest
from unittest.mock import Mock
from unittest.mock import patch

import cloudpickle

import aistore
from aistore.sdk.const import (
    HTTP_METHOD_PUT,
    HTTP_METHOD_GET,
    HTTP_METHOD_POST,
    HTTP_METHOD_DELETE,
    URL_PATH_ETL,
    UTF_ENCODING,
)
from aistore.sdk.etl_const import (
    CODE_TEMPLATE,
    ETL_COMM_HPUSH,
    ETL_COMM_HPULL,
    ETL_COMM_IO,
)

from aistore.sdk.etl import Etl, _get_default_runtime
from aistore.sdk.types import ETLDetails
from tests.const import ETL_NAME


class TestEtl(unittest.TestCase):  # pylint: disable=unused-variable
    """Unit tests for ``Etl`` using a mocked client.

    Each test drives an ``Etl`` method and asserts the exact HTTP method,
    path and JSON payload (or deserialization model) handed to the mocked
    client's ``request`` / ``request_deserialize``.
    """

    # Dependency that init_code is expected to append after any
    # user-supplied dependencies (see the init_code tests below).
    CLOUDPICKLE_DEP = "cloudpickle==2.2.0"

    def setUp(self) -> None:
        self.mock_client = Mock()
        self.etl_name = ETL_NAME
        self.etl = Etl(self.mock_client, self.etl_name)

    def test_init_spec_default_params(self):
        """init_spec with only a template uses hpush, a 5m timeout and no argument."""
        expected_action = {
            "communication": "hpush://",
            "timeout": "5m",
            "argument": "",
        }
        self.init_spec_exec_assert(expected_action)

    def test_init_spec_invalid_comm(self):
        """An unknown communication type is rejected with ValueError."""
        with self.assertRaises(ValueError):
            self.etl.init_spec("template", communication_type="invalid")

    def test_init_spec(self):
        """Explicit communication type and timeout are forwarded in the payload."""
        communication_type = ETL_COMM_HPUSH
        timeout = "6m"
        expected_action = {
            "communication": f"{communication_type}://",
            "timeout": timeout,
            "argument": "",
        }
        self.init_spec_exec_assert(
            expected_action, communication_type=communication_type, timeout=timeout
        )

    def init_spec_exec_assert(self, expected_action, **kwargs):
        """Call init_spec and assert the PUT payload and returned text.

        Note: mutates ``expected_action`` in place, adding the base64-encoded
        ``"spec"`` and the ETL ``"id"`` before comparing.

        Args:
            expected_action: Payload fields specific to the calling test.
            **kwargs: Extra keyword arguments forwarded to ``Etl.init_spec``.
        """
        template = "pod spec template"
        expected_action["spec"] = base64.b64encode(
            template.encode(UTF_ENCODING)
        ).decode(UTF_ENCODING)
        expected_action["id"] = self.etl_name
        expected_response_text = self.etl_name
        mock_response = Mock()
        mock_response.text = expected_response_text
        self.mock_client.request.return_value = mock_response

        response = self.etl.init_spec(template, **kwargs)

        self.assertEqual(expected_response_text, response)
        self.mock_client.request.assert_called_with(
            HTTP_METHOD_PUT, path=URL_PATH_ETL, json=expected_action
        )

    def test_init_code_default_runtime(self):
        """The default runtime is chosen from the interpreter's (major, minor).

        Per this table, 3.7 and unknown minors (e.g. 3.1234) fall back to
        ``python3.8v2``.
        """
        version_to_runtime = {
            (3, 7): "python3.8v2",
            (3, 1234): "python3.8v2",
            (3, 8): "python3.8v2",
            (3, 10): "python3.10v2",
            (3, 11): "python3.11v2",
        }
        for version, runtime in version_to_runtime.items():
            # subTest first so the failing version is reported; the patch is
            # undone before subTest records the outcome.
            with self.subTest(version=version), patch.object(
                aistore.sdk.etl.sys, "version_info"
            ) as version_info:
                version_info.major, version_info.minor = version
                self.assertEqual(runtime, _get_default_runtime())

    def test_init_code_default_params(self):
        """init_code with only a transform uses defaults plus the cloudpickle dependency."""
        communication_type = ETL_COMM_HPUSH

        expected_action = {
            "runtime": _get_default_runtime(),
            "communication": f"{communication_type}://",
            "timeout": "5m",
            "funcs": {"transform": "transform"},
            "code": self.encode_fn([], self.transform_fn, communication_type),
            "dependencies": self.encode_deps([self.CLOUDPICKLE_DEP]),
            "argument": "",
        }
        self.init_code_exec_assert(expected_action)

    def test_init_code_invalid_comm(self):
        """An unknown communication type is rejected with ValueError."""
        with self.assertRaises(ValueError):
            self.etl.init_code(Mock(), communication_type="invalid")

    def test_init_code(self):
        """All init_code options are forwarded; cloudpickle is appended to user deps."""
        runtime = "python-non-default"
        communication_type = ETL_COMM_HPULL
        timeout = "6m"
        preimported = ["pytorch"]
        user_dependencies = ["pytorch"]
        chunk_size = 123
        arg_type = "url"

        # User dependencies come first, the pinned cloudpickle last.
        expected_dependencies = [*user_dependencies, self.CLOUDPICKLE_DEP]

        expected_action = {
            "runtime": runtime,
            "communication": f"{communication_type}://",
            "timeout": timeout,
            "funcs": {"transform": "transform"},
            "code": self.encode_fn(preimported, self.transform_fn, communication_type),
            "dependencies": self.encode_deps(expected_dependencies),
            "chunk_size": chunk_size,
            "argument": arg_type,
        }
        self.init_code_exec_assert(
            expected_action,
            preimported_modules=preimported,
            dependencies=user_dependencies,
            runtime=runtime,
            communication_type=communication_type,
            timeout=timeout,
            chunk_size=chunk_size,
            arg_type=arg_type,
        )

    @staticmethod
    def transform_fn():
        """Trivial transform used as the pickled payload in init_code tests."""
        print("example action")

    @staticmethod
    def encode_deps(dependencies):
        """Return the expected ``"dependencies"`` value.

        The dependencies are newline-joined, UTF-8 encoded and then
        base64-encoded to text.
        """
        return base64.b64encode(
            "\n".join(dependencies).encode(UTF_ENCODING)
        ).decode(UTF_ENCODING)

    @staticmethod
    def encode_fn(preimported_modules, func, comm_type):
        """Return the expected ``"code"`` value for ``func``.

        The function is cloudpickled and base64-encoded, then formatted into
        CODE_TEMPLATE with the preimported modules. For io:// communication
        the template also gets a trailing ``transform()`` call. The rendered
        template is base64-encoded again.
        """
        transform = base64.b64encode(cloudpickle.dumps(func)).decode(UTF_ENCODING)
        io_comm_context = "transform()" if comm_type == ETL_COMM_IO else ""
        template = CODE_TEMPLATE.format(
            preimported_modules, transform, io_comm_context
        ).encode(UTF_ENCODING)
        return base64.b64encode(template).decode(UTF_ENCODING)

    def init_code_exec_assert(self, expected_action, **kwargs):
        """Call init_code with ``transform_fn`` and assert the PUT payload.

        Note: mutates ``expected_action`` in place by adding the ETL ``"id"``.

        Args:
            expected_action: Full expected payload except ``"id"``.
            **kwargs: Extra keyword arguments forwarded to ``Etl.init_code``.
        """
        expected_action["id"] = self.etl_name

        expected_response_text = "response text"
        mock_response = Mock()
        mock_response.text = expected_response_text
        self.mock_client.request.return_value = mock_response

        response = self.etl.init_code(transform=self.transform_fn, **kwargs)

        self.assertEqual(expected_response_text, response)
        self.mock_client.request.assert_called_with(
            HTTP_METHOD_PUT, path=URL_PATH_ETL, json=expected_action
        )

    def test_view(self):
        """view() GETs etl/<name> and deserializes into ETLDetails."""
        mock_response = Mock()
        self.mock_client.request_deserialize.return_value = mock_response
        response = self.etl.view()
        self.assertEqual(mock_response, response)
        self.mock_client.request_deserialize.assert_called_with(
            HTTP_METHOD_GET, path=f"etl/{self.etl_name}", res_model=ETLDetails
        )

    def test_start(self):
        """start() POSTs to etl/<name>/start."""
        self.etl.start()
        self.mock_client.request.assert_called_with(
            HTTP_METHOD_POST, path=f"etl/{self.etl_name}/start"
        )

    def test_stop(self):
        """stop() POSTs to etl/<name>/stop."""
        self.etl.stop()
        self.mock_client.request.assert_called_with(
            HTTP_METHOD_POST, path=f"etl/{self.etl_name}/stop"
        )

    def test_delete(self):
        """delete() sends DELETE to etl/<name>."""
        self.etl.delete()
        self.mock_client.request.assert_called_with(
            HTTP_METHOD_DELETE, path=f"etl/{self.etl_name}"
        )