github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/python/tests/unit/sdk/test_request_client.py (about)

     1  import unittest
     2  from unittest.mock import patch, Mock
     3  
     4  from requests import Response
     5  
     6  from aistore.sdk.const import (
     7      JSON_CONTENT_TYPE,
     8      HEADER_USER_AGENT,
     9      USER_AGENT_BASE,
    10      HEADER_CONTENT_TYPE,
    11      AIS_SERVER_CRT,
    12  )
    13  from aistore.sdk.request_client import RequestClient
    14  from aistore.version import __version__ as sdk_version
    15  from tests.utils import test_cases
    16  
    17  
    18  class TestRequestClient(unittest.TestCase):  # pylint: disable=unused-variable
    19      def setUp(self) -> None:
    20          self.endpoint = "https://aistore-endpoint"
    21          self.mock_session = Mock()
    22          with patch("aistore.sdk.request_client.requests") as mock_requests_lib:
    23              mock_requests_lib.sessions.session.return_value = self.mock_session
    24              self.request_client = RequestClient(
    25                  self.endpoint, skip_verify=True, ca_cert=""
    26              )
    27  
    28          self.request_headers = {
    29              HEADER_CONTENT_TYPE: JSON_CONTENT_TYPE,
    30              HEADER_USER_AGENT: f"{USER_AGENT_BASE}/{sdk_version}",
    31          }
    32  
    33      def test_default_session(self):
    34          with patch(
    35              "aistore.sdk.request_client.os.getenv", return_value=None
    36          ) as mock_getenv:
    37              self.request_client = RequestClient(self.endpoint)
    38              mock_getenv.assert_called_with(AIS_SERVER_CRT)
    39              self.assertEqual(True, self.request_client.session.verify)
    40  
    41      @test_cases(
    42          (("env-cert", "arg-cert", False), "arg-cert"),
    43          (("env-cert", "arg-cert", True), False),
    44          (("env-cert", None, False), "env-cert"),
    45          ((True, None, False), True),
    46          ((None, None, True), False),
    47      )
    48      def test_session(self, test_case):
    49          env_cert, arg_cert, skip_verify = test_case[0]
    50          with patch(
    51              "aistore.sdk.request_client.os.getenv", return_value=env_cert
    52          ) as mock_getenv:
    53              self.request_client = RequestClient(
    54                  self.endpoint, skip_verify=skip_verify, ca_cert=arg_cert
    55              )
    56              if not skip_verify and not arg_cert:
    57                  mock_getenv.assert_called_with(AIS_SERVER_CRT)
    58              self.assertEqual(test_case[1], self.request_client.session.verify)
    59  
    60      def test_properties(self):
    61          self.assertEqual(self.endpoint + "/v1", self.request_client.base_url)
    62          self.assertEqual(self.endpoint, self.request_client.endpoint)
    63  
    64      @patch("aistore.sdk.request_client.RequestClient.request")
    65      @patch("aistore.sdk.request_client.decode_response")
    66      def test_request_deserialize(self, mock_decode, mock_request):
    67          method = "method"
    68          path = "path"
    69          decoded_value = "test value"
    70          custom_kw = "arg"
    71          mock_decode.return_value = decoded_value
    72          mock_response = Mock(Response)
    73          mock_request.return_value = mock_response
    74  
    75          res = self.request_client.request_deserialize(
    76              method, path, str, keyword=custom_kw
    77          )
    78  
    79          self.assertEqual(decoded_value, res)
    80          mock_request.assert_called_with(method, path, keyword=custom_kw)
    81          mock_decode.assert_called_with(str, mock_response)
    82  
    83      @test_cases(None, "http://custom_endpoint")
    84      def test_request(self, endpoint_arg):
    85          method = "request_method"
    86          path = "request_path"
    87          extra_kw_arg = "arg"
    88          extra_headers = {"header_1_key": "header_1_val", "header_2_key": "header_2_val"}
    89          self.request_headers.update(extra_headers)
    90          if endpoint_arg:
    91              req_url = f"{endpoint_arg}/v1/{path}"
    92          else:
    93              req_url = f"{self.request_client.base_url}/{path}"
    94  
    95          mock_response = Mock()
    96          mock_response.status_code = 200
    97          self.mock_session.request.return_value = mock_response
    98          if endpoint_arg:
    99              res = self.request_client.request(
   100                  method,
   101                  path,
   102                  endpoint=endpoint_arg,
   103                  headers=extra_headers,
   104                  keyword=extra_kw_arg,
   105              )
   106          else:
   107              res = self.request_client.request(
   108                  method, path, headers=extra_headers, keyword=extra_kw_arg
   109              )
   110          self.mock_session.request.assert_called_with(
   111              method,
   112              req_url,
   113              headers=self.request_headers,
   114              timeout=None,
   115              keyword=extra_kw_arg,
   116          )
   117          self.assertEqual(mock_response, res)
   118  
   119          for response_code in [199, 300]:
   120              with patch("aistore.sdk.request_client.handle_errors") as mock_handle_err:
   121                  mock_response.status_code = response_code
   122                  self.mock_session.request.return_value = mock_response
   123                  res = self.request_client.request(
   124                      method,
   125                      path,
   126                      endpoint=endpoint_arg,
   127                      headers=extra_headers,
   128                      keyword=extra_kw_arg,
   129                  )
   130                  self.mock_session.request.assert_called_with(
   131                      method,
   132                      req_url,
   133                      headers=self.request_headers,
   134                      timeout=None,
   135                      keyword=extra_kw_arg,
   136                  )
   137                  self.assertEqual(mock_response, res)
   138                  mock_handle_err.assert_called_once()
   139  
   140      def test_get_full_url(self):
   141          path = "/testpath/to_obj"
   142          params = {"p1key": "p1val", "p2key": "p2val"}
   143          res = self.request_client.get_full_url(path, params)
   144          self.assertEqual(
   145              "https://aistore-endpoint/v1/testpath/to_obj?p1key=p1val&p2key=p2val", res
   146          )