github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/python/tests/unit/sdk/test_dsort.py (about)

     1  import unittest
     2  from typing import Dict
     3  from unittest.mock import Mock, patch, mock_open, call
     4  
     5  from aistore.sdk.const import (
     6      URL_PATH_DSORT,
     7      HTTP_METHOD_POST,
     8      DSORT_ABORT,
     9      HTTP_METHOD_DELETE,
    10      DSORT_UUID,
    11      HTTP_METHOD_GET,
    12  )
    13  from aistore.sdk.dsort import Dsort
    14  from aistore.sdk.dsort_types import DsortMetrics, JobInfo
    15  from aistore.sdk.errors import Timeout
    16  from aistore.sdk.utils import probing_frequency
    17  
    18  
    19  class TestDsort(unittest.TestCase):
    20      def setUp(self) -> None:
    21          self.mock_client = Mock()
    22          self.dsort_id = "123"
    23          self.dsort = Dsort(client=self.mock_client, dsort_id=self.dsort_id)
    24  
    25      @staticmethod
    26      def _get_mock_job_info(finished, aborted=False):
    27          mock_metrics = Mock(DsortMetrics)
    28          mock_metrics.aborted = aborted
    29          mock_metrics.shard_creation = Mock(finished=finished)
    30          mock_job_info = Mock(JobInfo)
    31          mock_job_info.metrics = mock_metrics
    32          return mock_job_info
    33  
    34      def test_properties(self):
    35          self.assertEqual(self.dsort_id, self.dsort.dsort_id)
    36  
    37      @patch("aistore.sdk.dsort.validate_file")
    38      @patch("aistore.sdk.dsort.json")
    39      # pylint: disable=unused-argument
    40      def test_start(self, mock_json, mock_validate_file):
    41          new_id = "456"
    42          spec = {"test_spec_entry": "test_spec_value"}
    43          mock_request_return_val = Mock(text=new_id)
    44          mock_json.load.return_value = spec
    45          self.mock_client.request.return_value = mock_request_return_val
    46  
    47          with patch("builtins.open", mock_open()):
    48              res = self.dsort.start("spec_file")
    49  
    50          self.assertEqual(new_id, res)
    51          self.assertEqual(new_id, self.dsort.dsort_id)
    52          self.mock_client.request.assert_called_with(
    53              HTTP_METHOD_POST, path=URL_PATH_DSORT, json=spec
    54          )
    55  
    56      def test_abort(self):
    57          self.dsort.abort()
    58          self.mock_client.request.assert_called_with(
    59              HTTP_METHOD_DELETE,
    60              path=f"{URL_PATH_DSORT}/{DSORT_ABORT}",
    61              params={DSORT_UUID: [self.dsort_id]},
    62          )
    63  
    64      def test_get_job_info(self):
    65          mock_job_info = {"id_1": Mock(JobInfo)}
    66          self.mock_client.request_deserialize.return_value = mock_job_info
    67          res = self.dsort.get_job_info()
    68          self.assertEqual(mock_job_info, res)
    69          self.mock_client.request_deserialize.assert_called_with(
    70              HTTP_METHOD_GET,
    71              path=URL_PATH_DSORT,
    72              res_model=Dict[str, JobInfo],
    73              params={DSORT_UUID: [self.dsort_id]},
    74          )
    75  
    76      @patch("aistore.sdk.dsort.time.sleep")
    77      @patch("aistore.sdk.dsort.Dsort.get_job_info")
    78      def test_wait_default_timeout(self, mock_get_job_info, mock_sleep):
    79          timeout = 300
    80          frequency = probing_frequency(timeout)
    81          expected_job_info_calls = [
    82              call(),
    83              call(),
    84              call(),
    85          ]
    86          expected_sleep_calls = [call(frequency), call(frequency)]
    87          self._wait_test_helper(
    88              self.dsort,
    89              mock_get_job_info,
    90              mock_sleep,
    91              expected_job_info_calls,
    92              expected_sleep_calls,
    93          )
    94  
    95      @patch("aistore.sdk.dsort.time.sleep")
    96      @patch("aistore.sdk.dsort.Dsort.get_job_info")
    97      def test_wait(self, mock_get_job_info, mock_sleep):
    98          timeout = 20
    99          frequency = probing_frequency(timeout)
   100          expected_job_info_calls = [call(), call(), call()]
   101          expected_sleep_calls = [call(frequency), call(frequency)]
   102          self._wait_test_helper(
   103              self.dsort,
   104              mock_get_job_info,
   105              mock_sleep,
   106              expected_job_info_calls,
   107              expected_sleep_calls,
   108              timeout=timeout,
   109          )
   110  
   111      @patch("aistore.sdk.dsort.time.sleep")
   112      @patch("aistore.sdk.dsort.Dsort.get_job_info")
   113      # pylint: disable=unused-argument
   114      def test_wait_timeout(self, mock_get_job_info, mock_sleep):
   115          mock_get_job_info.return_value = {
   116              "key": self._get_mock_job_info(finished=False, aborted=False)
   117          }
   118          self.assertRaises(Timeout, self.dsort.wait)
   119  
   120      @patch("aistore.sdk.dsort.time.sleep")
   121      @patch("aistore.sdk.dsort.Dsort.get_job_info")
   122      def test_wait_aborted(self, mock_get_job_info, mock_sleep):
   123          timeout = 300
   124          frequency = probing_frequency(timeout)
   125          expected_metrics_calls = [
   126              call(),
   127              call(),
   128          ]
   129          expected_sleep_calls = [call(frequency)]
   130          mock_get_job_info.side_effect = [
   131              {"key": self._get_mock_job_info(finished=False)},
   132              {"key": self._get_mock_job_info(finished=False, aborted=True)},
   133              {"key": self._get_mock_job_info(finished=False)},
   134          ]
   135  
   136          self._wait_exec_assert(
   137              self.dsort,
   138              mock_get_job_info,
   139              mock_sleep,
   140              expected_metrics_calls,
   141              expected_sleep_calls,
   142          )
   143  
   144      # pylint: disable=too-many-arguments
   145      def _wait_test_helper(
   146          self,
   147          dsort,
   148          mock_get_job_info,
   149          mock_sleep,
   150          expected_job_info_calls,
   151          expected_sleep_calls,
   152          **kwargs,
   153      ):
   154          mock_get_job_info.side_effect = [
   155              {"job_id": self._get_mock_job_info(finished=False)},
   156              {"job_id": self._get_mock_job_info(finished=False)},
   157              {"job_id": self._get_mock_job_info(finished=True)},
   158          ]
   159          self._wait_exec_assert(
   160              dsort,
   161              mock_get_job_info,
   162              mock_sleep,
   163              expected_job_info_calls,
   164              expected_sleep_calls,
   165              **kwargs,
   166          )
   167  
   168      def _wait_exec_assert(
   169          self,
   170          dsort,
   171          mock_get_job_info,
   172          mock_sleep,
   173          expected_job_info_calls,
   174          expected_sleep_calls,
   175          **kwargs,
   176      ):
   177          dsort.wait(**kwargs)
   178  
   179          mock_get_job_info.assert_has_calls(expected_job_info_calls)
   180          mock_sleep.assert_has_calls(expected_sleep_calls)
   181          self.assertEqual(len(expected_job_info_calls), mock_get_job_info.call_count)
   182          self.assertEqual(len(expected_sleep_calls), mock_sleep.call_count)