# Source: github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/python/tests/unit/sdk/test_object.py
"""Unit tests for ``aistore.sdk.object.Object``.

Each test builds an ``Object`` on top of a mocked bucket whose ``client`` is a
``Mock``, calls one ``Object`` method, and asserts the exact arguments that
were passed to ``client.request`` (or ``client.get_full_url``).
"""

import unittest
from unittest.mock import Mock, patch, mock_open

from requests import Response
from requests.structures import CaseInsensitiveDict

from aistore.sdk.const import (
    HTTP_METHOD_HEAD,
    DEFAULT_CHUNK_SIZE,
    HTTP_METHOD_GET,
    QPARAM_ARCHPATH,
    QPARAM_ETL_NAME,
    HTTP_METHOD_PUT,
    HTTP_METHOD_DELETE,
    HEADER_CONTENT_LENGTH,
    AIS_CHECKSUM_VALUE,
    AIS_CHECKSUM_TYPE,
    AIS_ACCESS_TIME,
    AIS_VERSION,
    AIS_CUSTOM_MD,
    HTTP_METHOD_POST,
    ACT_PROMOTE,
    ACT_BLOB_DOWNLOAD,
    URL_PATH_OBJECTS,
    HEADER_OBJECT_BLOB_DOWNLOAD,
    HEADER_OBJECT_BLOB_CHUNK_SIZE,
    HEADER_OBJECT_BLOB_WORKERS,
)
from aistore.sdk.object import Object
from aistore.sdk.object_reader import ObjectReader
from aistore.sdk.types import ActionMsg, BlobMsg, PromoteAPIArgs
from tests.const import SMALL_FILE_SIZE, ETL_NAME

BCK_NAME = "bucket_name"
OBJ_NAME = "object_name"
# Path used by object-level requests (head/get/put/delete).
REQUEST_PATH = f"{URL_PATH_OBJECTS}/{BCK_NAME}/{OBJ_NAME}"


# pylint: disable=unused-variable, too-many-locals
class TestObject(unittest.TestCase):
    def setUp(self) -> None:
        self.mock_client = Mock()
        self.mock_bucket = Mock()
        self.mock_bucket.client = self.mock_client
        self.mock_bucket.name = BCK_NAME
        self.mock_writer = Mock()
        # Bucket contributes no query params, so tests start from an empty dict
        # and add only the params each Object call is expected to send.
        self.mock_bucket.qparam = {}
        self.expected_params = {}
        self.object = Object(self.mock_bucket, OBJ_NAME)

    def test_properties(self):
        self.assertEqual(self.mock_bucket, self.object.bucket)
        self.assertEqual(OBJ_NAME, self.object.name)

    def test_head(self):
        self.object.head()

        self.mock_client.request.assert_called_with(
            HTTP_METHOD_HEAD,
            path=REQUEST_PATH,
            params=self.expected_params,
        )

    def test_get_default_params(self):
        # With no arguments, get() is expected to send an empty archpath param.
        self.expected_params[QPARAM_ARCHPATH] = ""
        self.get_exec_assert()

    def test_get(self):
        archpath_param = "archpath"
        blob_chunk_size = "4mb"
        blob_num_workers = 10
        self.expected_params[QPARAM_ARCHPATH] = archpath_param
        self.expected_params[QPARAM_ETL_NAME] = ETL_NAME
        self.get_exec_assert(
            archpath=archpath_param,
            chunk_size=3,
            etl_name=ETL_NAME,
            writer=self.mock_writer,
            blob_chunk_size=blob_chunk_size,
            blob_num_workers=blob_num_workers,
        )

    def get_exec_assert(self, **kwargs):
        """Call ``Object.get(**kwargs)`` against a mocked response and verify it.

        Checks the attributes parsed from the response headers, the GET request
        (params, ``stream=True`` and any blob-download headers), the chunk size
        used when iterating the returned reader, and, when ``writer`` is passed,
        that the reader was handed to ``writer.writelines``.
        """
        content = b"123456789"
        content_length = len(content)
        ais_check_val = "xyz"
        ais_check_type = "md5"
        ais_atime = "time string"
        ais_version = "3"
        custom_metadata_dict = {"key1": "val1", "key2": "val2"}
        # Custom metadata travels as a single "k1=v1, k2=v2" header value.
        custom_metadata = ", ".join(
            ["=".join(kv) for kv in custom_metadata_dict.items()]
        )
        resp_headers = CaseInsensitiveDict(
            {
                HEADER_CONTENT_LENGTH: content_length,
                AIS_CHECKSUM_VALUE: ais_check_val,
                AIS_CHECKSUM_TYPE: ais_check_type,
                AIS_ACCESS_TIME: ais_atime,
                AIS_VERSION: ais_version,
                AIS_CUSTOM_MD: custom_metadata,
            }
        )
        mock_response = Mock(Response)
        mock_response.headers = resp_headers
        mock_response.iter_content.return_value = content
        mock_response.raw = content
        expected_obj = ObjectReader(
            response_headers=resp_headers,
            stream=mock_response,
        )
        self.mock_client.request.return_value = mock_response

        res = self.object.get(**kwargs)

        # Blob-download headers are expected only when a blob chunk size or
        # worker count was requested; each value header only when it was set.
        blob_chunk_size = kwargs.get("blob_chunk_size")
        blob_num_workers = kwargs.get("blob_num_workers")
        headers = {}
        if blob_chunk_size or blob_num_workers:
            headers[HEADER_OBJECT_BLOB_DOWNLOAD] = "true"
        if blob_chunk_size:
            headers[HEADER_OBJECT_BLOB_CHUNK_SIZE] = blob_chunk_size
        if blob_num_workers:
            headers[HEADER_OBJECT_BLOB_WORKERS] = blob_num_workers

        self.assertEqual(expected_obj.raw(), res.raw())
        self.assertEqual(content_length, res.attributes.size)
        self.assertEqual(ais_check_type, res.attributes.checksum_type)
        self.assertEqual(ais_check_val, res.attributes.checksum_value)
        self.assertEqual(ais_atime, res.attributes.access_time)
        self.assertEqual(ais_version, res.attributes.obj_version)
        self.assertEqual(custom_metadata_dict, res.attributes.custom_metadata)
        self.mock_client.request.assert_called_with(
            HTTP_METHOD_GET,
            path=REQUEST_PATH,
            params=self.expected_params,
            stream=True,
            headers=headers,
        )

        # Use the object reader iterator to call the stream with the chunk size
        for _ in res:
            continue
        mock_response.iter_content.assert_called_with(
            chunk_size=kwargs.get("chunk_size", DEFAULT_CHUNK_SIZE)
        )

        if "writer" in kwargs:
            self.mock_writer.writelines.assert_called_with(res)

    def test_get_url(self):
        expected_res = "full url"
        archpath = "arch"
        self.mock_client.get_full_url.return_value = expected_res
        res = self.object.get_url(archpath=archpath, etl_name=ETL_NAME)
        self.assertEqual(expected_res, res)
        self.mock_client.get_full_url.assert_called_with(
            REQUEST_PATH, {QPARAM_ARCHPATH: archpath, QPARAM_ETL_NAME: ETL_NAME}
        )

    # Decorators apply bottom-up, so the `exists` mock is the first argument.
    @patch("pathlib.Path.is_file")
    @patch("pathlib.Path.exists")
    def test_put_file(self, mock_exists, mock_is_file):
        mock_exists.return_value = True
        mock_is_file.return_value = True
        path = "any/filepath"
        data = b"bytes in the file"

        with patch("builtins.open", mock_open(read_data=data)):
            self.object.put_file(path)

        self.mock_client.request.assert_called_with(
            HTTP_METHOD_PUT,
            path=REQUEST_PATH,
            params=self.expected_params,
            data=data,
        )

    def test_put_content(self):
        content = b"user-supplied-bytes"
        self.object.put_content(content)
        self.mock_client.request.assert_called_with(
            HTTP_METHOD_PUT,
            path=REQUEST_PATH,
            params=self.expected_params,
            data=content,
        )

    def test_promote_default_args(self):
        filename = "promoted file"
        expected_value = PromoteAPIArgs(source_path=filename, object_name=OBJ_NAME)
        self.promote_exec_assert(filename, expected_value)

    def test_promote(self):
        filename = "promoted file"
        target_id = "target node"
        recursive = True
        overwrite_dest = True
        delete_source = True
        src_not_file_share = True
        expected_value = PromoteAPIArgs(
            source_path=filename,
            object_name=OBJ_NAME,
            target_id=target_id,
            recursive=recursive,
            overwrite_dest=overwrite_dest,
            delete_source=delete_source,
            src_not_file_share=src_not_file_share,
        )
        self.promote_exec_assert(
            filename,
            expected_value,
            target_id=target_id,
            recursive=recursive,
            overwrite_dest=overwrite_dest,
            delete_source=delete_source,
            src_not_file_share=src_not_file_share,
        )

    def promote_exec_assert(self, filename, expected_value, **kwargs):
        """Call ``Object.promote(filename, **kwargs)`` and verify the request.

        Promote is expected to POST to the bucket-level objects path (not the
        object path) with an ``ACT_PROMOTE`` action whose value is
        ``expected_value.as_dict()``.
        """
        request_path = f"{URL_PATH_OBJECTS}/{BCK_NAME}"
        expected_json = ActionMsg(
            action=ACT_PROMOTE, name=filename, value=expected_value.as_dict()
        ).dict()
        self.object.promote(filename, **kwargs)
        self.mock_client.request.assert_called_with(
            HTTP_METHOD_POST,
            path=request_path,
            params=self.expected_params,
            json=expected_json,
        )

    def test_delete(self):
        self.object.delete()
        self.mock_client.request.assert_called_with(
            HTTP_METHOD_DELETE, path=REQUEST_PATH, params=self.expected_params
        )

    def test_blob_download_default_args(self):
        self.blob_download_exec_assert(
            BlobMsg(chunk_size=None, num_workers=None, latest=False)
        )

    def test_blob_download(self):
        chunk_size = SMALL_FILE_SIZE
        num_workers = 10
        latest = True
        self.blob_download_exec_assert(
            BlobMsg(chunk_size=chunk_size, num_workers=num_workers, latest=latest),
            num_workers=num_workers,
            chunk_size=chunk_size,
            latest=latest,
        )

    def blob_download_exec_assert(self, expected_blob_msg, **kwargs):
        """Call ``Object.blob_download(**kwargs)`` and verify the request.

        Like promote, blob download is expected to POST to the bucket-level
        objects path, here with an ``ACT_BLOB_DOWNLOAD`` action named after the
        object and carrying ``expected_blob_msg.as_dict()`` as its value.
        """
        request_path = f"{URL_PATH_OBJECTS}/{BCK_NAME}"
        expected_json = ActionMsg(
            action=ACT_BLOB_DOWNLOAD, name=OBJ_NAME, value=expected_blob_msg.as_dict()
        ).dict()
        self.object.blob_download(**kwargs)
        self.mock_client.request.assert_called_with(
            HTTP_METHOD_POST,
            path=request_path,
            params=self.expected_params,
            json=expected_json,
        )