github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/gcp/tests/utils_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Unittest for GCP testing utils."""
    19  
    20  # pytype: skip-file
    21  
    22  import logging
    23  import unittest
    24  
    25  import mock
    26  
    27  from apache_beam.io.gcp.pubsub import PubsubMessage
    28  from apache_beam.io.gcp.tests import utils
    29  from apache_beam.testing import test_utils
    30  
    31  # Protect against environments where bigquery library is not available.
    32  try:
    33    from google.api_core import exceptions as gexc
    34    from google.cloud import bigquery
    35    from google.cloud import pubsub
    36  except ImportError:
    37    gexc = None
    38    bigquery = None
    39    pubsub = None
    40  
    41  
    42  @unittest.skipIf(bigquery is None, 'Bigquery dependencies are not installed.')
    43  @mock.patch.object(bigquery, 'Client')
    44  class UtilsTest(unittest.TestCase):
    45    def setUp(self):
    46      test_utils.patch_retry(self, utils)
    47  
    48    @mock.patch.object(bigquery, 'Dataset')
    49    def test_create_bq_dataset(self, mock_dataset, mock_client):
    50      mock_client.dataset.return_value = 'dataset_ref'
    51      mock_dataset.return_value = 'dataset_obj'
    52  
    53      utils.create_bq_dataset('project', 'dataset_base_name')
    54      mock_client.return_value.create_dataset.assert_called_with('dataset_obj')
    55  
    56    def test_delete_bq_dataset(self, mock_client):
    57      utils.delete_bq_dataset('project', 'dataset_ref')
    58      mock_client.return_value.delete_dataset.assert_called_with(
    59          'dataset_ref', delete_contents=mock.ANY)
    60  
    61    def test_delete_table_succeeds(self, mock_client):
    62      mock_client.return_value.dataset.return_value.table.return_value = (
    63          'table_ref')
    64  
    65      utils.delete_bq_table('unused_project', 'unused_dataset', 'unused_table')
    66      mock_client.return_value.delete_table.assert_called_with('table_ref')
    67  
    68    def test_delete_table_fails_not_found(self, mock_client):
    69      mock_client.return_value.dataset.return_value.table.return_value = (
    70          'table_ref')
    71      mock_client.return_value.delete_table.side_effect = gexc.NotFound('test')
    72  
    73      with self.assertRaisesRegex(Exception, r'does not exist:.*table_ref'):
    74        utils.delete_bq_table('unused_project', 'unused_dataset', 'unused_table')
    75  
    76  
    77  @unittest.skipIf(pubsub is None, 'GCP dependencies are not installed')
    78  class PubSubUtilTest(unittest.TestCase):
    79    def test_write_to_pubsub(self):
    80      mock_pubsub = mock.Mock()
    81      topic_path = "project/fakeproj/topics/faketopic"
    82      data = b'data'
    83      utils.write_to_pubsub(mock_pubsub, topic_path, [data])
    84      mock_pubsub.publish.assert_has_calls(
    85          [mock.call(topic_path, data), mock.call().result()])
    86  
    87    def test_write_to_pubsub_with_attributes(self):
    88      mock_pubsub = mock.Mock()
    89      topic_path = "project/fakeproj/topics/faketopic"
    90      data = b'data'
    91      attributes = {'key': 'value'}
    92      message = PubsubMessage(data, attributes)
    93      utils.write_to_pubsub(
    94          mock_pubsub, topic_path, [message], with_attributes=True)
    95      mock_pubsub.publish.assert_has_calls(
    96          [mock.call(topic_path, data, **attributes), mock.call().result()])
    97  
    98    def test_write_to_pubsub_delay(self):
    99      number_of_elements = 2
   100      chunk_size = 1
   101      mock_pubsub = mock.Mock()
   102      topic_path = "project/fakeproj/topics/faketopic"
   103      data = b'data'
   104      with mock.patch('apache_beam.io.gcp.tests.utils.time') as mock_time:
   105        utils.write_to_pubsub(
   106            mock_pubsub,
   107            topic_path, [data] * number_of_elements,
   108            chunk_size=chunk_size,
   109            delay_between_chunks=123)
   110      mock_time.sleep.assert_called_with(123)
   111      mock_pubsub.publish.assert_has_calls(
   112          [mock.call(topic_path, data), mock.call().result()] *
   113          number_of_elements)
   114  
   115    def test_write_to_pubsub_many_chunks(self):
   116      number_of_elements = 83
   117      chunk_size = 11
   118      mock_pubsub = mock.Mock()
   119      topic_path = "project/fakeproj/topics/faketopic"
   120      data_list = [
   121          'data {}'.format(i).encode("utf-8") for i in range(number_of_elements)
   122      ]
   123      utils.write_to_pubsub(
   124          mock_pubsub, topic_path, data_list, chunk_size=chunk_size)
   125      call_list = []
   126      for start in range(0, number_of_elements, chunk_size):
   127        # Publish a batch of messages
   128        call_list += [
   129            mock.call(topic_path, data)
   130            for data in data_list[start:start + chunk_size]
   131        ]
   132        # Wait for those messages to be received
   133        call_list += [
   134            mock.call().result() for _ in data_list[start:start + chunk_size]
   135        ]
   136      mock_pubsub.publish.assert_has_calls(call_list)
   137  
   138    def test_read_from_pubsub(self):
   139      mock_pubsub = mock.Mock()
   140      subscription_path = "project/fakeproj/subscriptions/fakesub"
   141      data = b'data'
   142      ack_id = 'ack_id'
   143      pull_response = test_utils.create_pull_response(
   144          [test_utils.PullResponseMessage(data, ack_id=ack_id)])
   145      mock_pubsub.pull.return_value = pull_response
   146      output = utils.read_from_pubsub(
   147          mock_pubsub, subscription_path, number_of_elements=1)
   148      self.assertEqual([data], output)
   149      mock_pubsub.acknowledge.assert_called_once_with(
   150          subscription=subscription_path, ack_ids=[ack_id])
   151  
   152    def test_read_from_pubsub_with_attributes(self):
   153      mock_pubsub = mock.Mock()
   154      subscription_path = "project/fakeproj/subscriptions/fakesub"
   155      data = b'data'
   156      ack_id = 'ack_id'
   157      attributes = {'key': 'value'}
   158      message = PubsubMessage(data, attributes)
   159      pull_response = test_utils.create_pull_response(
   160          [test_utils.PullResponseMessage(data, attributes, ack_id=ack_id)])
   161      mock_pubsub.pull.return_value = pull_response
   162      output = utils.read_from_pubsub(
   163          mock_pubsub,
   164          subscription_path,
   165          with_attributes=True,
   166          number_of_elements=1)
   167      self.assertEqual([message], output)
   168      mock_pubsub.acknowledge.assert_called_once_with(
   169          subscription=subscription_path, ack_ids=[ack_id])
   170  
   171    def test_read_from_pubsub_flaky(self):
   172      number_of_elements = 10
   173      mock_pubsub = mock.Mock()
   174      subscription_path = "project/fakeproj/subscriptions/fakesub"
   175      data = b'data'
   176      ack_id = 'ack_id'
   177      pull_response = test_utils.create_pull_response(
   178          [test_utils.PullResponseMessage(data, ack_id=ack_id)])
   179  
   180      class FlakyPullResponse(object):
   181        def __init__(self, pull_response):
   182          self.pull_response = pull_response
   183          self._state = -1
   184  
   185        def __call__(self, *args, **kwargs):
   186          self._state += 1
   187          if self._state % 3 == 0:
   188            raise gexc.RetryError("", "")
   189          if self._state % 3 == 1:
   190            raise gexc.DeadlineExceeded("")
   191          if self._state % 3 == 2:
   192            return self.pull_response
   193  
   194      mock_pubsub.pull.side_effect = FlakyPullResponse(pull_response)
   195      output = utils.read_from_pubsub(
   196          mock_pubsub, subscription_path, number_of_elements=number_of_elements)
   197      self.assertEqual([data] * number_of_elements, output)
   198      self._assert_ack_ids_equal(mock_pubsub, [ack_id] * number_of_elements)
   199  
   200    def test_read_from_pubsub_many(self):
   201      response_size = 33
   202      number_of_elements = 100
   203      mock_pubsub = mock.Mock()
   204      subscription_path = "project/fakeproj/subscriptions/fakesub"
   205      data_list = [
   206          'data {}'.format(i).encode("utf-8") for i in range(number_of_elements)
   207      ]
   208      attributes_list = [{
   209          'key': 'value {}'.format(i)
   210      } for i in range(number_of_elements)]
   211      ack_ids = ['ack_id_{}'.format(i) for i in range(number_of_elements)]
   212      messages = [
   213          PubsubMessage(data, attributes) for data,
   214          attributes in zip(data_list, attributes_list)
   215      ]
   216      response_messages = [
   217          test_utils.PullResponseMessage(data, attributes, ack_id=ack_id)
   218          for data,
   219          attributes,
   220          ack_id in zip(data_list, attributes_list, ack_ids)
   221      ]
   222  
   223      class SequentialPullResponse(object):
   224        def __init__(self, response_messages, response_size):
   225          self.response_messages = response_messages
   226          self.response_size = response_size
   227          self._index = 0
   228  
   229        def __call__(self, *args, **kwargs):
   230          start = self._index
   231          self._index += self.response_size
   232          response = test_utils.create_pull_response(
   233              self.response_messages[start:start + self.response_size])
   234          return response
   235  
   236      mock_pubsub.pull.side_effect = SequentialPullResponse(
   237          response_messages, response_size)
   238      output = utils.read_from_pubsub(
   239          mock_pubsub,
   240          subscription_path,
   241          with_attributes=True,
   242          number_of_elements=number_of_elements)
   243      self.assertEqual(messages, output)
   244      self._assert_ack_ids_equal(mock_pubsub, ack_ids)
   245  
   246    def test_read_from_pubsub_invalid_arg(self):
   247      sub_client = mock.Mock()
   248      subscription_path = "project/fakeproj/subscriptions/fakesub"
   249      with self.assertRaisesRegex(ValueError, "number_of_elements"):
   250        utils.read_from_pubsub(sub_client, subscription_path)
   251      with self.assertRaisesRegex(ValueError, "number_of_elements"):
   252        utils.read_from_pubsub(
   253            sub_client, subscription_path, with_attributes=True)
   254  
   255    def _assert_ack_ids_equal(self, mock_pubsub, ack_ids):
   256      actual_ack_ids = [
   257          ack_id for args_list in mock_pubsub.acknowledge.call_args_list
   258          for ack_id in args_list[1]["ack_ids"]
   259      ]
   260      self.assertEqual(actual_ack_ids, ack_ids)
   261  
   262  
   263  if __name__ == '__main__':
   264    logging.getLogger().setLevel(logging.INFO)
   265    unittest.main()