# Source: github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/python/tests/unit/sdk/test_request_client.py
import unittest
from unittest.mock import patch, Mock

from requests import Response

from aistore.sdk.const import (
    JSON_CONTENT_TYPE,
    HEADER_USER_AGENT,
    USER_AGENT_BASE,
    HEADER_CONTENT_TYPE,
    AIS_SERVER_CRT,
)
from aistore.sdk.request_client import RequestClient
from aistore.version import __version__ as sdk_version
from tests.utils import test_cases


class TestRequestClient(unittest.TestCase):  # pylint: disable=unused-variable
    """Unit tests for aistore.sdk.request_client.RequestClient."""

    def setUp(self) -> None:
        """Build a client whose HTTP session is self.mock_session.

        Also records the headers every request is expected to carry by default.
        """
        self.endpoint = "https://aistore-endpoint"
        self.mock_session = Mock()
        # Patch `requests` as seen by request_client so that the session the
        # client obtains from requests.sessions.session() is our mock.
        with patch("aistore.sdk.request_client.requests") as mock_requests_lib:
            mock_requests_lib.sessions.session.return_value = self.mock_session
            self.request_client = RequestClient(
                self.endpoint, skip_verify=True, ca_cert=""
            )

        self.request_headers = {
            HEADER_CONTENT_TYPE: JSON_CONTENT_TYPE,
            HEADER_USER_AGENT: f"{USER_AGENT_BASE}/{sdk_version}",
        }

    def test_default_session(self):
        """With no ca_cert argument and AIS_SERVER_CRT unset, verify is True."""
        with patch(
            "aistore.sdk.request_client.os.getenv", return_value=None
        ) as mock_getenv:
            self.request_client = RequestClient(self.endpoint)
            mock_getenv.assert_called_with(AIS_SERVER_CRT)
            self.assertEqual(True, self.request_client.session.verify)

    @test_cases(
        (("env-cert", "arg-cert", False), "arg-cert"),
        (("env-cert", "arg-cert", True), False),
        (("env-cert", None, False), "env-cert"),
        ((True, None, False), True),
        ((None, None, True), False),
    )
    def test_session(self, test_case):
        """Check how env cert, ca_cert argument and skip_verify set session.verify.

        Each case is ``((env_cert, arg_cert, skip_verify), expected_verify)``,
        where env_cert is what os.getenv(AIS_SERVER_CRT) returns.
        """
        (env_cert, arg_cert, skip_verify), expected_verify = test_case
        with patch(
            "aistore.sdk.request_client.os.getenv", return_value=env_cert
        ) as mock_getenv:
            self.request_client = RequestClient(
                self.endpoint, skip_verify=skip_verify, ca_cert=arg_cert
            )
            # The environment variable must be consulted when verification is
            # enabled and no explicit ca_cert argument was given.
            if not skip_verify and not arg_cert:
                mock_getenv.assert_called_with(AIS_SERVER_CRT)
            self.assertEqual(expected_verify, self.request_client.session.verify)

    def test_properties(self):
        """base_url is the endpoint plus "/v1"; endpoint is returned as given."""
        self.assertEqual(self.endpoint + "/v1", self.request_client.base_url)
        self.assertEqual(self.endpoint, self.request_client.endpoint)

    @patch("aistore.sdk.request_client.RequestClient.request")
    @patch("aistore.sdk.request_client.decode_response")
    def test_request_deserialize(self, mock_decode, mock_request):
        """request_deserialize forwards to request() and decodes its response."""
        method = "method"
        path = "path"
        decoded_value = "test value"
        custom_kw = "arg"
        mock_decode.return_value = decoded_value
        mock_response = Mock(Response)
        mock_request.return_value = mock_response

        res = self.request_client.request_deserialize(
            method, path, str, keyword=custom_kw
        )

        self.assertEqual(decoded_value, res)
        mock_request.assert_called_with(method, path, keyword=custom_kw)
        mock_decode.assert_called_with(str, mock_response)

    @test_cases(None, "http://custom_endpoint")
    def test_request(self, endpoint_arg):
        """request() builds the URL, merges headers and forwards kwargs.

        Covers both the client's own base URL (endpoint_arg is None) and an
        explicit endpoint, plus routing of non-2xx status codes to
        handle_errors.
        """
        method = "request_method"
        path = "request_path"
        extra_kw_arg = "arg"
        extra_headers = {"header_1_key": "header_1_val", "header_2_key": "header_2_val"}
        # Build the expected headers locally rather than mutating
        # self.request_headers, which is shared by every @test_cases run.
        expected_headers = {**self.request_headers, **extra_headers}
        if endpoint_arg:
            req_url = f"{endpoint_arg}/v1/{path}"
        else:
            req_url = f"{self.request_client.base_url}/{path}"

        mock_response = Mock()
        mock_response.status_code = 200
        self.mock_session.request.return_value = mock_response
        if endpoint_arg:
            res = self.request_client.request(
                method,
                path,
                endpoint=endpoint_arg,
                headers=extra_headers,
                keyword=extra_kw_arg,
            )
        else:
            res = self.request_client.request(
                method, path, headers=extra_headers, keyword=extra_kw_arg
            )
        self.mock_session.request.assert_called_with(
            method,
            req_url,
            headers=expected_headers,
            timeout=None,
            keyword=extra_kw_arg,
        )
        self.assertEqual(mock_response, res)

        # Status codes just outside 200-299 must be passed to handle_errors.
        # handle_errors is mocked (so it does not raise) and the response is
        # still expected to be returned.
        for response_code in [199, 300]:
            with self.subTest(response_code=response_code), patch(
                "aistore.sdk.request_client.handle_errors"
            ) as mock_handle_err:
                # mock_response is already the session's return_value.
                mock_response.status_code = response_code
                res = self.request_client.request(
                    method,
                    path,
                    endpoint=endpoint_arg,
                    headers=extra_headers,
                    keyword=extra_kw_arg,
                )
                self.mock_session.request.assert_called_with(
                    method,
                    req_url,
                    headers=expected_headers,
                    timeout=None,
                    keyword=extra_kw_arg,
                )
                self.assertEqual(mock_response, res)
                mock_handle_err.assert_called_once()

    def test_get_full_url(self):
        """get_full_url joins path onto base_url and appends params as a query."""
        path = "/testpath/to_obj"
        params = {"p1key": "p1val", "p2key": "p2val"}
        res = self.request_client.get_full_url(path, params)
        self.assertEqual(
            "https://aistore-endpoint/v1/testpath/to_obj?p1key=p1val&p2key=p2val", res
        )