github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/gcp/tests/utils_test.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unittest for GCP testing utils."""

# pytype: skip-file

import logging
import unittest

import mock

from apache_beam.io.gcp.pubsub import PubsubMessage
from apache_beam.io.gcp.tests import utils
from apache_beam.testing import test_utils

# Protect against environments where bigquery library is not available.
try:
  from google.api_core import exceptions as gexc
  from google.cloud import bigquery
  from google.cloud import pubsub
except ImportError:
  gexc = None
  bigquery = None
  pubsub = None


@unittest.skipIf(bigquery is None, 'Bigquery dependencies are not installed.')
@mock.patch.object(bigquery, 'Client')
class UtilsTest(unittest.TestCase):
  def setUp(self):
    test_utils.patch_retry(self, utils)

  @mock.patch.object(bigquery, 'Dataset')
  def test_create_bq_dataset(self, mock_dataset, mock_client):
    mock_client.dataset.return_value = 'dataset_ref'
    mock_dataset.return_value = 'dataset_obj'

    utils.create_bq_dataset('project', 'dataset_base_name')
    mock_client.return_value.create_dataset.assert_called_with('dataset_obj')

  def test_delete_bq_dataset(self, mock_client):
    utils.delete_bq_dataset('project', 'dataset_ref')
    mock_client.return_value.delete_dataset.assert_called_with(
        'dataset_ref', delete_contents=mock.ANY)

  def test_delete_table_succeeds(self, mock_client):
    mock_client.return_value.dataset.return_value.table.return_value = (
        'table_ref')

    utils.delete_bq_table('unused_project', 'unused_dataset', 'unused_table')
    mock_client.return_value.delete_table.assert_called_with('table_ref')

  def test_delete_table_fails_not_found(self, mock_client):
    mock_client.return_value.dataset.return_value.table.return_value = (
        'table_ref')
    mock_client.return_value.delete_table.side_effect = gexc.NotFound('test')

    with self.assertRaisesRegex(Exception, r'does not exist:.*table_ref'):
      utils.delete_bq_table('unused_project', 'unused_dataset', 'unused_table')


@unittest.skipIf(pubsub is None, 'GCP dependencies are not installed')
class PubSubUtilTest(unittest.TestCase):
  def test_write_to_pubsub(self):
    mock_pubsub = mock.Mock()
    topic_path = "project/fakeproj/topics/faketopic"
    data = b'data'
    utils.write_to_pubsub(mock_pubsub, topic_path, [data])
    mock_pubsub.publish.assert_has_calls(
        [mock.call(topic_path, data), mock.call().result()])

  def test_write_to_pubsub_with_attributes(self):
    mock_pubsub = mock.Mock()
    topic_path = "project/fakeproj/topics/faketopic"
    data = b'data'
    attributes = {'key': 'value'}
    message = PubsubMessage(data, attributes)
    utils.write_to_pubsub(
        mock_pubsub, topic_path, [message], with_attributes=True)
    mock_pubsub.publish.assert_has_calls(
        [mock.call(topic_path, data, **attributes), mock.call().result()])

  def test_write_to_pubsub_delay(self):
    number_of_elements = 2
    chunk_size = 1
    mock_pubsub = mock.Mock()
    topic_path = "project/fakeproj/topics/faketopic"
    data = b'data'
    with mock.patch('apache_beam.io.gcp.tests.utils.time') as mock_time:
      utils.write_to_pubsub(
          mock_pubsub,
          topic_path, [data] * number_of_elements,
          chunk_size=chunk_size,
          delay_between_chunks=123)
    mock_time.sleep.assert_called_with(123)
    mock_pubsub.publish.assert_has_calls(
        [mock.call(topic_path, data), mock.call().result()] *
        number_of_elements)

  def test_write_to_pubsub_many_chunks(self):
    number_of_elements = 83
    chunk_size = 11
    mock_pubsub = mock.Mock()
    topic_path = "project/fakeproj/topics/faketopic"
    data_list = [
        'data {}'.format(i).encode("utf-8") for i in range(number_of_elements)
    ]
    utils.write_to_pubsub(
        mock_pubsub, topic_path, data_list, chunk_size=chunk_size)
    call_list = []
    for start in range(0, number_of_elements, chunk_size):
      # Publish a batch of messages
      call_list += [
          mock.call(topic_path, data)
          for data in data_list[start:start + chunk_size]
      ]
      # Wait for those messages to be received
      call_list += [
          mock.call().result() for _ in data_list[start:start + chunk_size]
      ]
    mock_pubsub.publish.assert_has_calls(call_list)

  def test_read_from_pubsub(self):
    mock_pubsub = mock.Mock()
    subscription_path = "project/fakeproj/subscriptions/fakesub"
    data = b'data'
    ack_id = 'ack_id'
    pull_response = test_utils.create_pull_response(
        [test_utils.PullResponseMessage(data, ack_id=ack_id)])
    mock_pubsub.pull.return_value = pull_response
    output = utils.read_from_pubsub(
        mock_pubsub, subscription_path, number_of_elements=1)
    self.assertEqual([data], output)
    mock_pubsub.acknowledge.assert_called_once_with(
        subscription=subscription_path, ack_ids=[ack_id])

  def test_read_from_pubsub_with_attributes(self):
    mock_pubsub = mock.Mock()
    subscription_path = "project/fakeproj/subscriptions/fakesub"
    data = b'data'
    ack_id = 'ack_id'
    attributes = {'key': 'value'}
    message = PubsubMessage(data, attributes)
    pull_response = test_utils.create_pull_response(
        [test_utils.PullResponseMessage(data, attributes, ack_id=ack_id)])
    mock_pubsub.pull.return_value = pull_response
    output = utils.read_from_pubsub(
        mock_pubsub,
        subscription_path,
        with_attributes=True,
        number_of_elements=1)
    self.assertEqual([message], output)
    mock_pubsub.acknowledge.assert_called_once_with(
        subscription=subscription_path, ack_ids=[ack_id])

  def test_read_from_pubsub_flaky(self):
    number_of_elements = 10
    mock_pubsub = mock.Mock()
    subscription_path = "project/fakeproj/subscriptions/fakesub"
    data = b'data'
    ack_id = 'ack_id'
    pull_response = test_utils.create_pull_response(
        [test_utils.PullResponseMessage(data, ack_id=ack_id)])

    class FlakyPullResponse(object):
      def __init__(self, pull_response):
        self.pull_response = pull_response
        self._state = -1

      def __call__(self, *args, **kwargs):
        self._state += 1
        if self._state % 3 == 0:
          raise gexc.RetryError("", "")
        if self._state % 3 == 1:
          raise gexc.DeadlineExceeded("")
        if self._state % 3 == 2:
          return self.pull_response

    mock_pubsub.pull.side_effect = FlakyPullResponse(pull_response)
    output = utils.read_from_pubsub(
        mock_pubsub, subscription_path, number_of_elements=number_of_elements)
    self.assertEqual([data] * number_of_elements, output)
    self._assert_ack_ids_equal(mock_pubsub, [ack_id] * number_of_elements)

  def test_read_from_pubsub_many(self):
    response_size = 33
    number_of_elements = 100
    mock_pubsub = mock.Mock()
    subscription_path = "project/fakeproj/subscriptions/fakesub"
    data_list = [
        'data {}'.format(i).encode("utf-8") for i in range(number_of_elements)
    ]
    attributes_list = [{
        'key': 'value {}'.format(i)
    } for i in range(number_of_elements)]
    ack_ids = ['ack_id_{}'.format(i) for i in range(number_of_elements)]
    messages = [
        PubsubMessage(data, attributes) for data,
        attributes in zip(data_list, attributes_list)
    ]
    response_messages = [
        test_utils.PullResponseMessage(data, attributes, ack_id=ack_id)
        for data,
        attributes,
        ack_id in zip(data_list, attributes_list, ack_ids)
    ]

    class SequentialPullResponse(object):
      def __init__(self, response_messages, response_size):
        self.response_messages = response_messages
        self.response_size = response_size
        self._index = 0

      def __call__(self, *args, **kwargs):
        start = self._index
        self._index += self.response_size
        response = test_utils.create_pull_response(
            self.response_messages[start:start + self.response_size])
        return response

    mock_pubsub.pull.side_effect = SequentialPullResponse(
        response_messages, response_size)
    output = utils.read_from_pubsub(
        mock_pubsub,
        subscription_path,
        with_attributes=True,
        number_of_elements=number_of_elements)
    self.assertEqual(messages, output)
    self._assert_ack_ids_equal(mock_pubsub, ack_ids)

  def test_read_from_pubsub_invalid_arg(self):
    sub_client = mock.Mock()
    subscription_path = "project/fakeproj/subscriptions/fakesub"
    with self.assertRaisesRegex(ValueError, "number_of_elements"):
      utils.read_from_pubsub(sub_client, subscription_path)
    with self.assertRaisesRegex(ValueError, "number_of_elements"):
      utils.read_from_pubsub(
          sub_client, subscription_path, with_attributes=True)

  def _assert_ack_ids_equal(self, mock_pubsub, ack_ids):
    actual_ack_ids = [
        ack_id for args_list in mock_pubsub.acknowledge.call_args_list
        for ack_id in args_list[1]["ack_ids"]
    ]
    self.assertEqual(actual_ack_ids, ack_ids)


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()