github.com/NVIDIA/aistore@v1.3.23-0.20240517131212-7df6609be51d/python/tests/unit/sdk/test_object.py (about)

     1  import unittest
     2  from unittest.mock import Mock, patch, mock_open
     3  
     4  from requests import Response
     5  from requests.structures import CaseInsensitiveDict
     6  
     7  from aistore.sdk.const import (
     8      HTTP_METHOD_HEAD,
     9      DEFAULT_CHUNK_SIZE,
    10      HTTP_METHOD_GET,
    11      QPARAM_ARCHPATH,
    12      QPARAM_ETL_NAME,
    13      HTTP_METHOD_PUT,
    14      HTTP_METHOD_DELETE,
    15      HEADER_CONTENT_LENGTH,
    16      AIS_CHECKSUM_VALUE,
    17      AIS_CHECKSUM_TYPE,
    18      AIS_ACCESS_TIME,
    19      AIS_VERSION,
    20      AIS_CUSTOM_MD,
    21      HTTP_METHOD_POST,
    22      ACT_PROMOTE,
    23      ACT_BLOB_DOWNLOAD,
    24      URL_PATH_OBJECTS,
    25      HEADER_OBJECT_BLOB_DOWNLOAD,
    26      HEADER_OBJECT_BLOB_CHUNK_SIZE,
    27      HEADER_OBJECT_BLOB_WORKERS,
    28  )
    29  from aistore.sdk.object import Object
    30  from aistore.sdk.object_reader import ObjectReader
    31  from aistore.sdk.types import ActionMsg, BlobMsg, PromoteAPIArgs
    32  from tests.const import SMALL_FILE_SIZE, ETL_NAME
    33  
# Name assigned to the mocked bucket in ``setUp``.
BCK_NAME = "bucket_name"
# Name passed to ``Object`` when the object under test is constructed.
OBJ_NAME = "object_name"
# Path the tests expect in requests that address the object itself:
# "<URL_PATH_OBJECTS>/<bucket name>/<object name>".
REQUEST_PATH = f"{URL_PATH_OBJECTS}/{BCK_NAME}/{OBJ_NAME}"
    37  
    38  
    39  # pylint: disable=unused-variable, too-many-locals
    40  class TestObject(unittest.TestCase):
    41      def setUp(self) -> None:
    42          self.mock_client = Mock()
    43          self.mock_bucket = Mock()
    44          self.mock_bucket.client = self.mock_client
    45          self.mock_bucket.name = BCK_NAME
    46          self.mock_writer = Mock()
    47          self.mock_bucket.qparam = {}
    48          self.expected_params = {}
    49          self.object = Object(self.mock_bucket, OBJ_NAME)
    50  
    51      def test_properties(self):
    52          self.assertEqual(self.mock_bucket, self.object.bucket)
    53          self.assertEqual(OBJ_NAME, self.object.name)
    54  
    55      def test_head(self):
    56          self.object.head()
    57  
    58          self.mock_client.request.assert_called_with(
    59              HTTP_METHOD_HEAD,
    60              path=REQUEST_PATH,
    61              params=self.expected_params,
    62          )
    63  
    64      def test_get_default_params(self):
    65          self.expected_params[QPARAM_ARCHPATH] = ""
    66          self.get_exec_assert()
    67  
    68      def test_get(self):
    69          archpath_param = "archpath"
    70          blob_chunk_size = "4mb"
    71          blob_num_workers = 10
    72          self.expected_params[QPARAM_ARCHPATH] = archpath_param
    73          self.expected_params[QPARAM_ETL_NAME] = ETL_NAME
    74          self.get_exec_assert(
    75              archpath=archpath_param,
    76              chunk_size=3,
    77              etl_name=ETL_NAME,
    78              writer=self.mock_writer,
    79              blob_chunk_size=blob_chunk_size,
    80              blob_num_workers=blob_num_workers,
    81          )
    82  
    83      def get_exec_assert(self, **kwargs):
    84          content = b"123456789"
    85          content_length = 9
    86          ais_check_val = "xyz"
    87          ais_check_type = "md5"
    88          ais_atime = "time string"
    89          ais_version = "3"
    90          custom_metadata_dict = {"key1": "val1", "key2": "val2"}
    91          custom_metadata = ", ".join(
    92              ["=".join(kv) for kv in custom_metadata_dict.items()]
    93          )
    94          resp_headers = CaseInsensitiveDict(
    95              {
    96                  HEADER_CONTENT_LENGTH: content_length,
    97                  AIS_CHECKSUM_VALUE: ais_check_val,
    98                  AIS_CHECKSUM_TYPE: ais_check_type,
    99                  AIS_ACCESS_TIME: ais_atime,
   100                  AIS_VERSION: ais_version,
   101                  AIS_CUSTOM_MD: custom_metadata,
   102              }
   103          )
   104          mock_response = Mock(Response)
   105          mock_response.headers = resp_headers
   106          mock_response.iter_content.return_value = content
   107          mock_response.raw = content
   108          expected_obj = ObjectReader(
   109              response_headers=resp_headers,
   110              stream=mock_response,
   111          )
   112          self.mock_client.request.return_value = mock_response
   113  
   114          res = self.object.get(**kwargs)
   115          blob_chunk_size = kwargs.get("blob_chunk_size")
   116          blob_num_workers = kwargs.get("blob_num_workers")
   117          headers = {}
   118          if blob_chunk_size or blob_num_workers:
   119              headers[HEADER_OBJECT_BLOB_DOWNLOAD] = "true"
   120          if blob_chunk_size:
   121              headers[HEADER_OBJECT_BLOB_CHUNK_SIZE] = blob_chunk_size
   122          if blob_num_workers:
   123              headers[HEADER_OBJECT_BLOB_WORKERS] = blob_num_workers
   124  
   125          self.assertEqual(expected_obj.raw(), res.raw())
   126          self.assertEqual(content_length, res.attributes.size)
   127          self.assertEqual(ais_check_type, res.attributes.checksum_type)
   128          self.assertEqual(ais_check_val, res.attributes.checksum_value)
   129          self.assertEqual(ais_atime, res.attributes.access_time)
   130          self.assertEqual(ais_version, res.attributes.obj_version)
   131          self.assertEqual(custom_metadata_dict, res.attributes.custom_metadata)
   132          self.mock_client.request.assert_called_with(
   133              HTTP_METHOD_GET,
   134              path=REQUEST_PATH,
   135              params=self.expected_params,
   136              stream=True,
   137              headers=headers,
   138          )
   139  
   140          # Use the object reader iterator to call the stream with the chunk size
   141          for _ in res:
   142              continue
   143          mock_response.iter_content.assert_called_with(
   144              chunk_size=kwargs.get("chunk_size", DEFAULT_CHUNK_SIZE)
   145          )
   146  
   147          if "writer" in kwargs:
   148              self.mock_writer.writelines.assert_called_with(res)
   149  
   150      def test_get_url(self):
   151          expected_res = "full url"
   152          archpath = "arch"
   153          self.mock_client.get_full_url.return_value = expected_res
   154          res = self.object.get_url(archpath=archpath, etl_name=ETL_NAME)
   155          self.assertEqual(expected_res, res)
   156          self.mock_client.get_full_url.assert_called_with(
   157              REQUEST_PATH, {QPARAM_ARCHPATH: archpath, QPARAM_ETL_NAME: ETL_NAME}
   158          )
   159  
   160      @patch("pathlib.Path.is_file")
   161      @patch("pathlib.Path.exists")
   162      def test_put_file(self, mock_exists, mock_is_file):
   163          mock_exists.return_value = True
   164          mock_is_file.return_value = True
   165          path = "any/filepath"
   166          data = b"bytes in the file"
   167  
   168          with patch("builtins.open", mock_open(read_data=data)):
   169              self.object.put_file(path)
   170  
   171          self.mock_client.request.assert_called_with(
   172              HTTP_METHOD_PUT,
   173              path=REQUEST_PATH,
   174              params=self.expected_params,
   175              data=data,
   176          )
   177  
   178      def test_put_content(self):
   179          content = b"user-supplied-bytes"
   180          self.object.put_content(content)
   181          self.mock_client.request.assert_called_with(
   182              HTTP_METHOD_PUT,
   183              path=REQUEST_PATH,
   184              params=self.expected_params,
   185              data=content,
   186          )
   187  
   188      def test_promote_default_args(self):
   189          filename = "promoted file"
   190          expected_value = PromoteAPIArgs(source_path=filename, object_name=OBJ_NAME)
   191          self.promote_exec_assert(filename, expected_value)
   192  
   193      def test_promote(self):
   194          filename = "promoted file"
   195          target_id = "target node"
   196          recursive = True
   197          overwrite_dest = True
   198          delete_source = True
   199          src_not_file_share = True
   200          expected_value = PromoteAPIArgs(
   201              source_path=filename,
   202              object_name=OBJ_NAME,
   203              target_id=target_id,
   204              recursive=recursive,
   205              overwrite_dest=overwrite_dest,
   206              delete_source=delete_source,
   207              src_not_file_share=src_not_file_share,
   208          )
   209          self.promote_exec_assert(
   210              filename,
   211              expected_value,
   212              target_id=target_id,
   213              recursive=recursive,
   214              overwrite_dest=overwrite_dest,
   215              delete_source=delete_source,
   216              src_not_file_share=src_not_file_share,
   217          )
   218  
   219      def promote_exec_assert(self, filename, expected_value, **kwargs):
   220          request_path = f"{URL_PATH_OBJECTS}/{BCK_NAME}"
   221          expected_json = ActionMsg(
   222              action=ACT_PROMOTE, name=filename, value=expected_value.as_dict()
   223          ).dict()
   224          self.object.promote(filename, **kwargs)
   225          self.mock_client.request.assert_called_with(
   226              HTTP_METHOD_POST,
   227              path=request_path,
   228              params=self.expected_params,
   229              json=expected_json,
   230          )
   231  
   232      def test_delete(self):
   233          self.object.delete()
   234          self.mock_client.request.assert_called_with(
   235              HTTP_METHOD_DELETE, path=REQUEST_PATH, params=self.expected_params
   236          )
   237  
   238      def test_blob_download_default_args(self):
   239          request_path = f"{URL_PATH_OBJECTS}/{BCK_NAME}"
   240          expected_blob_msg = BlobMsg(
   241              chunk_size=None,
   242              num_workers=None,
   243              latest=False,
   244          ).as_dict()
   245          expected_json = ActionMsg(
   246              action=ACT_BLOB_DOWNLOAD, name=OBJ_NAME, value=expected_blob_msg
   247          ).dict()
   248          self.object.blob_download()
   249          self.mock_client.request.assert_called_with(
   250              HTTP_METHOD_POST,
   251              path=request_path,
   252              params=self.expected_params,
   253              json=expected_json,
   254          )
   255  
   256      def test_blob_download(self):
   257          request_path = f"{URL_PATH_OBJECTS}/{BCK_NAME}"
   258          chunk_size = SMALL_FILE_SIZE
   259          num_workers = 10
   260          latest = True
   261          expected_blob_msg = BlobMsg(
   262              chunk_size=chunk_size,
   263              num_workers=num_workers,
   264              latest=latest,
   265          ).as_dict()
   266          expected_json = ActionMsg(
   267              action=ACT_BLOB_DOWNLOAD, name=OBJ_NAME, value=expected_blob_msg
   268          ).dict()
   269          self.object.blob_download(
   270              num_workers=num_workers, chunk_size=chunk_size, latest=latest
   271          )
   272          self.mock_client.request.assert_called_with(
   273              HTTP_METHOD_POST,
   274              path=request_path,
   275              params=self.expected_params,
   276              json=expected_json,
   277          )