# Source: github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/python/tests/unit/sdk/test_dsort.py
"""Unit tests for aistore.sdk.dsort.Dsort, exercised against a mocked client."""

import unittest
from typing import Dict
from unittest.mock import Mock, patch, mock_open, call

from aistore.sdk.const import (
    URL_PATH_DSORT,
    HTTP_METHOD_POST,
    DSORT_ABORT,
    HTTP_METHOD_DELETE,
    DSORT_UUID,
    HTTP_METHOD_GET,
)
from aistore.sdk.dsort import Dsort
from aistore.sdk.dsort_types import DsortMetrics, JobInfo
from aistore.sdk.errors import Timeout
from aistore.sdk.utils import probing_frequency


class TestDsort(unittest.TestCase):
    """Tests for Dsort: start, abort, get_job_info and the wait() polling loop."""

    def setUp(self) -> None:
        # Every request goes through this mock; no real cluster is contacted.
        self.mock_client = Mock()
        self.dsort_id = "123"
        self.dsort = Dsort(client=self.mock_client, dsort_id=self.dsort_id)

    @staticmethod
    def _get_mock_job_info(finished: bool, aborted: bool = False) -> Mock:
        """
        Build a JobInfo mock whose metrics report the given state.

        Args:
            finished: Value exposed as ``metrics.shard_creation.finished``.
            aborted: Value exposed as ``metrics.aborted``.

        Returns:
            A Mock spec'd on JobInfo whose ``metrics`` is a DsortMetrics-spec'd Mock.
        """
        mock_metrics = Mock(DsortMetrics)
        mock_metrics.aborted = aborted
        mock_metrics.shard_creation = Mock(finished=finished)
        mock_job_info = Mock(JobInfo)
        mock_job_info.metrics = mock_metrics
        return mock_job_info

    def test_properties(self):
        """The ID passed to the constructor is exposed as ``dsort_id``."""
        self.assertEqual(self.dsort_id, self.dsort.dsort_id)

    @patch("aistore.sdk.dsort.validate_file")
    @patch("aistore.sdk.dsort.json")
    def test_start(self, mock_json, mock_validate_file):
        """start() validates the spec path, POSTs the parsed spec and stores the returned ID."""
        new_id = "456"
        spec = {"test_spec_entry": "test_spec_value"}
        mock_json.load.return_value = spec
        self.mock_client.request.return_value = Mock(text=new_id)

        # Patch open() so nothing is read from disk; json.load is mocked above.
        with patch("builtins.open", mock_open()):
            res = self.dsort.start("spec_file")

        self.assertEqual(new_id, res)
        self.assertEqual(new_id, self.dsort.dsort_id)
        # Previously unchecked: the spec path handed to start() must be validated.
        mock_validate_file.assert_called_once_with("spec_file")
        self.mock_client.request.assert_called_with(
            HTTP_METHOD_POST, path=URL_PATH_DSORT, json=spec
        )

    def test_abort(self):
        """abort() issues a DELETE on the abort endpoint for this job's ID."""
        self.dsort.abort()
        self.mock_client.request.assert_called_with(
            HTTP_METHOD_DELETE,
            path=f"{URL_PATH_DSORT}/{DSORT_ABORT}",
            params={DSORT_UUID: [self.dsort_id]},
        )

    def test_get_job_info(self):
        """get_job_info() returns the client's deserialized Dict[str, JobInfo] unchanged."""
        mock_job_info = {"id_1": Mock(JobInfo)}
        self.mock_client.request_deserialize.return_value = mock_job_info

        res = self.dsort.get_job_info()

        self.assertEqual(mock_job_info, res)
        self.mock_client.request_deserialize.assert_called_with(
            HTTP_METHOD_GET,
            path=URL_PATH_DSORT,
            res_model=Dict[str, JobInfo],
            params={DSORT_UUID: [self.dsort_id]},
        )

    @patch("aistore.sdk.dsort.time.sleep")
    @patch("aistore.sdk.dsort.Dsort.get_job_info")
    def test_wait_default_timeout(self, mock_get_job_info, mock_sleep):
        """wait() with no arguments polls until finished, sleeping between polls."""
        # NOTE(review): assumes Dsort.wait's default timeout is 300 s — confirm
        # against the SDK's default constant if it changes.
        timeout = 300
        frequency = probing_frequency(timeout)
        expected_job_info_calls = [call(), call(), call()]
        expected_sleep_calls = [call(frequency), call(frequency)]
        self._wait_test_helper(
            self.dsort,
            mock_get_job_info,
            mock_sleep,
            expected_job_info_calls,
            expected_sleep_calls,
        )

    @patch("aistore.sdk.dsort.time.sleep")
    @patch("aistore.sdk.dsort.Dsort.get_job_info")
    def test_wait(self, mock_get_job_info, mock_sleep):
        """wait(timeout=...) sleeps at the frequency derived from that timeout."""
        timeout = 20
        frequency = probing_frequency(timeout)
        expected_job_info_calls = [call(), call(), call()]
        expected_sleep_calls = [call(frequency), call(frequency)]
        self._wait_test_helper(
            self.dsort,
            mock_get_job_info,
            mock_sleep,
            expected_job_info_calls,
            expected_sleep_calls,
            timeout=timeout,
        )

    @patch("aistore.sdk.dsort.time.sleep")
    @patch("aistore.sdk.dsort.Dsort.get_job_info")
    # pylint: disable=unused-argument
    def test_wait_timeout(self, mock_get_job_info, mock_sleep):
        """wait() raises Timeout when the job never reports finished or aborted."""
        # time.sleep is patched, so the polling loop runs without real delay.
        mock_get_job_info.return_value = {
            "key": self._get_mock_job_info(finished=False, aborted=False)
        }
        self.assertRaises(Timeout, self.dsort.wait)

    @patch("aistore.sdk.dsort.time.sleep")
    @patch("aistore.sdk.dsort.Dsort.get_job_info")
    def test_wait_aborted(self, mock_get_job_info, mock_sleep):
        """wait() returns without raising once a poll reports the job as aborted."""
        timeout = 300
        frequency = probing_frequency(timeout)
        # Abort is reported on the second poll, so the third response is never consumed.
        expected_job_info_calls = [call(), call()]
        expected_sleep_calls = [call(frequency)]
        mock_get_job_info.side_effect = [
            {"key": self._get_mock_job_info(finished=False)},
            {"key": self._get_mock_job_info(finished=False, aborted=True)},
            {"key": self._get_mock_job_info(finished=False)},
        ]

        self._wait_exec_assert(
            self.dsort,
            mock_get_job_info,
            mock_sleep,
            expected_job_info_calls,
            expected_sleep_calls,
        )

    # pylint: disable=too-many-arguments
    def _wait_test_helper(
        self,
        dsort,
        mock_get_job_info,
        mock_sleep,
        expected_job_info_calls,
        expected_sleep_calls,
        **kwargs,
    ):
        """
        Feed wait() two unfinished polls followed by a finished one, then assert.

        ``kwargs`` are forwarded to ``dsort.wait``.
        """
        mock_get_job_info.side_effect = [
            {"job_id": self._get_mock_job_info(finished=False)},
            {"job_id": self._get_mock_job_info(finished=False)},
            {"job_id": self._get_mock_job_info(finished=True)},
        ]
        self._wait_exec_assert(
            dsort,
            mock_get_job_info,
            mock_sleep,
            expected_job_info_calls,
            expected_sleep_calls,
            **kwargs,
        )

    def _wait_exec_assert(
        self,
        dsort,
        mock_get_job_info,
        mock_sleep,
        expected_job_info_calls,
        expected_sleep_calls,
        **kwargs,
    ):
        """
        Run ``dsort.wait(**kwargs)`` and check poll/sleep calls exactly.

        Both the call sequence and the total call count are asserted, so extra
        polls or sleeps beyond the expected lists also fail the test.
        """
        dsort.wait(**kwargs)

        mock_get_job_info.assert_has_calls(expected_job_info_calls)
        mock_sleep.assert_has_calls(expected_sleep_calls)
        self.assertEqual(len(expected_job_info_calls), mock_get_job_info.call_count)
        self.assertEqual(len(expected_sleep_calls), mock_sleep.call_count)