github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/gcp/bigquery_tools_test.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pytype: skip-file

import datetime
import decimal
import io
import json
import logging
import math
import re
import unittest
from typing import Optional
from typing import Sequence

import fastavro
import mock
import numpy as np
import pytz
from parameterized import parameterized

import apache_beam as beam
from apache_beam.io.gcp import resource_identifiers
from apache_beam.io.gcp.bigquery_tools import JSON_COMPLIANCE_ERROR
from apache_beam.io.gcp.bigquery_tools import AvroRowWriter
from apache_beam.io.gcp.bigquery_tools import BigQueryJobTypes
from apache_beam.io.gcp.bigquery_tools import JsonRowWriter
from apache_beam.io.gcp.bigquery_tools import RowAsDictJsonCoder
from apache_beam.io.gcp.bigquery_tools import beam_row_from_dict
from apache_beam.io.gcp.bigquery_tools import check_schema_equal
from apache_beam.io.gcp.bigquery_tools import generate_bq_job_name
from apache_beam.io.gcp.bigquery_tools import get_beam_typehints_from_tableschema
from apache_beam.io.gcp.bigquery_tools import parse_table_reference
from apache_beam.io.gcp.bigquery_tools import parse_table_schema_from_json
from apache_beam.io.gcp.internal.clients import bigquery
from apache_beam.metrics import monitoring_infos
from apache_beam.metrics.execution import MetricsEnvironment
from apache_beam.options.value_provider import StaticValueProvider
from apache_beam.typehints.row_type import RowTypeConstraint
from apache_beam.utils.timestamp import Timestamp

# Protect against environments where bigquery library is not available.
# pylint: disable=wrong-import-order, wrong-import-position
try:
  from apitools.base.py.exceptions import HttpError, HttpForbiddenError
  from google.api_core.exceptions import ClientError, DeadlineExceeded
  from google.api_core.exceptions import InternalServerError
  import google.cloud
except ImportError:
  ClientError = None
  DeadlineExceeded = None
  HttpError = None
  HttpForbiddenError = None
  InternalServerError = None
  google = None
# pylint: enable=wrong-import-order, wrong-import-position


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestTableSchemaParser(unittest.TestCase):
  def test_parse_table_schema_from_json(self):
    string_field = bigquery.TableFieldSchema(
        name='s', type='STRING', mode='NULLABLE', description='s description')
    number_field = bigquery.TableFieldSchema(
        name='n', type='INTEGER', mode='REQUIRED', description='n description')
    record_field = bigquery.TableFieldSchema(
        name='r',
        type='RECORD',
        mode='REQUIRED',
        description='r description',
        fields=[string_field, number_field])
    expected_schema = bigquery.TableSchema(fields=[record_field])
    json_str = json.dumps({
        'fields': [{
            'name': 'r',
            'type': 'RECORD',
            'mode': 'REQUIRED',
            'description': 'r description',
            'fields': [{
                'name': 's',
                'type': 'STRING',
                'mode': 'NULLABLE',
                'description': 's description'
            }, {
                'name': 'n',
                'type': 'INTEGER',
                'mode': 'REQUIRED',
                'description': 'n description'
            }]
        }]
    })
    self.assertEqual(parse_table_schema_from_json(json_str), expected_schema)


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestTableReferenceParser(unittest.TestCase):
  def test_calling_with_table_reference(self):
    table_ref = bigquery.TableReference()
    table_ref.projectId = 'test_project'
    table_ref.datasetId = 'test_dataset'
    table_ref.tableId = 'test_table'
    parsed_ref = parse_table_reference(table_ref)
    self.assertEqual(table_ref, parsed_ref)
    self.assertIsNot(table_ref, parsed_ref)

  def test_calling_with_callable(self):
    callable_ref = lambda: 'foo'
    parsed_ref = parse_table_reference(callable_ref)
    self.assertIs(callable_ref, parsed_ref)

  def test_calling_with_value_provider(self):
    value_provider_ref = StaticValueProvider(str, 'test_dataset.test_table')
    parsed_ref = parse_table_reference(value_provider_ref)
    self.assertIs(value_provider_ref, parsed_ref)

  @parameterized.expand([
      ('project:dataset.test_table', 'project', 'dataset', 'test_table'),
      ('project:dataset.test-table', 'project', 'dataset', 'test-table'),
      ('project:dataset.test- table', 'project', 'dataset', 'test- table'),
      ('project.dataset. test_table', 'project', 'dataset', ' test_table'),
      ('project.dataset.test$table', 'project', 'dataset', 'test$table'),
  ])
  def test_calling_with_fully_qualified_table_ref(
      self,
      fully_qualified_table: str,
      project_id: str,
      dataset_id: str,
      table_id: str,
  ):
    parsed_ref = parse_table_reference(fully_qualified_table)
    self.assertIsInstance(parsed_ref, bigquery.TableReference)
    self.assertEqual(parsed_ref.projectId, project_id)
    self.assertEqual(parsed_ref.datasetId, dataset_id)
    self.assertEqual(parsed_ref.tableId, table_id)

  def test_calling_with_partially_qualified_table_ref(self):
    datasetId = 'test_dataset'
    tableId = 'test_table'
    partially_qualified_table = '{}.{}'.format(datasetId, tableId)
    parsed_ref = parse_table_reference(partially_qualified_table)
    self.assertIsInstance(parsed_ref, bigquery.TableReference)
    self.assertEqual(parsed_ref.datasetId, datasetId)
    self.assertEqual(parsed_ref.tableId, tableId)

  def test_calling_with_insufficient_table_ref(self):
    table = 'test_table'
    self.assertRaises(ValueError, parse_table_reference, table)

  def test_calling_with_all_arguments(self):
    projectId = 'test_project'
    datasetId = 'test_dataset'
    tableId = 'test_table'
    parsed_ref = parse_table_reference(
        tableId, dataset=datasetId, project=projectId)
    self.assertIsInstance(parsed_ref, bigquery.TableReference)
    self.assertEqual(parsed_ref.projectId, projectId)
    self.assertEqual(parsed_ref.datasetId, datasetId)
    self.assertEqual(parsed_ref.tableId, tableId)


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestBigQueryWrapper(unittest.TestCase):
  def test_delete_non_existing_dataset(self):
    client = mock.Mock()
    client.datasets.Delete.side_effect = HttpError(
        response={'status': '404'}, url='', content='')
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    wrapper._delete_dataset('', '')
    self.assertTrue(client.datasets.Delete.called)

  @mock.patch('time.sleep', return_value=None)
  def test_delete_dataset_retries_fail(self, patched_time_sleep):
    client = mock.Mock()
    client.datasets.Delete.side_effect = ValueError("Cannot delete")
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    with self.assertRaises(ValueError):
      wrapper._delete_dataset('', '')
    self.assertEqual(
        beam.io.gcp.bigquery_tools.MAX_RETRIES + 1,
        client.datasets.Delete.call_count)
    self.assertTrue(client.datasets.Delete.called)

  def test_delete_non_existing_table(self):
    client = mock.Mock()
    client.tables.Delete.side_effect = HttpError(
        response={'status': '404'}, url='', content='')
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    wrapper._delete_table('', '', '')
    self.assertTrue(client.tables.Delete.called)

  @mock.patch('time.sleep', return_value=None)
  def test_delete_table_retries_fail(self, patched_time_sleep):
    client = mock.Mock()
    client.tables.Delete.side_effect = ValueError("Cannot delete")
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    with self.assertRaises(ValueError):
      wrapper._delete_table('', '', '')
    self.assertTrue(client.tables.Delete.called)

  @mock.patch('time.sleep', return_value=None)
  def test_delete_dataset_retries_for_timeouts(self, patched_time_sleep):
    client = mock.Mock()
    client.datasets.Delete.side_effect = [
        HttpError(response={'status': '408'}, url='', content=''),
        bigquery.BigqueryDatasetsDeleteResponse()
    ]
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    wrapper._delete_dataset('', '')
    self.assertTrue(client.datasets.Delete.called)

  @unittest.skipIf(
      google and not hasattr(google.cloud, '_http'),  # pylint: disable=c-extension-no-member
      'Dependencies not installed')
  @mock.patch('time.sleep', return_value=None)
  @mock.patch('google.cloud._http.JSONConnection.http')
  def test_user_agent_insert_all(self, http_mock, patched_sleep):
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper()
    try:
      wrapper._insert_all_rows('p', 'd', 't', [{'name': 'any'}], None)
    except:  # pylint: disable=bare-except
      # Ignore errors. The errors come from the fact that we did not mock
      # the response from the API, so the overall insert_all_rows call fails
      # soon after the BQ API is called.
      pass
    call = http_mock.request.mock_calls[-2]
    self.assertIn('apache-beam-', call[2]['headers']['User-Agent'])

  @mock.patch('time.sleep', return_value=None)
  def test_delete_table_retries_for_timeouts(self, patched_time_sleep):
    client = mock.Mock()
    client.tables.Delete.side_effect = [
        HttpError(response={'status': '408'}, url='', content=''),
        bigquery.BigqueryTablesDeleteResponse()
    ]
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    wrapper._delete_table('', '', '')
    self.assertTrue(client.tables.Delete.called)

  @mock.patch('time.sleep', return_value=None)
  def test_temporary_dataset_is_unique(self, patched_time_sleep):
    client = mock.Mock()
    client.datasets.Get.return_value = bigquery.Dataset(
        datasetReference=bigquery.DatasetReference(
            projectId='project-id', datasetId='dataset_id'))
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    with self.assertRaises(RuntimeError):
      wrapper.create_temporary_dataset('project-id', 'location')
    self.assertTrue(client.datasets.Get.called)

  @mock.patch('time.sleep', return_value=None)
  def test_user_agent_passed(self, sleep_mock):
    try:
      wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper()
    except:  # pylint: disable=bare-except
      self.skipTest('Unable to create a BQ Wrapper')
    request_mock = mock.Mock()
    wrapper.client._http.request = request_mock
    try:
      wrapper.create_temporary_dataset('project-id', 'location')
    except:  # pylint: disable=bare-except
      # Ignore errors. The errors come from the fact that we did not mock
      # the response from the API, so the overall create_dataset call fails
      # soon after the BQ API is called.
      pass
    call = request_mock.mock_calls[-1]
    self.assertIn('apache-beam-', call[2]['headers']['user-agent'])

  def test_get_or_create_dataset_created(self):
    client = mock.Mock()
    client.datasets.Get.side_effect = HttpError(
        response={'status': '404'}, url='', content='')
    client.datasets.Insert.return_value = bigquery.Dataset(
        datasetReference=bigquery.DatasetReference(
            projectId='project-id', datasetId='dataset_id'))
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    new_dataset = wrapper.get_or_create_dataset('project-id', 'dataset_id')
    self.assertEqual(new_dataset.datasetReference.datasetId, 'dataset_id')

  def test_get_or_create_dataset_fetched(self):
    client = mock.Mock()
    client.datasets.Get.return_value = bigquery.Dataset(
        datasetReference=bigquery.DatasetReference(
            projectId='project-id', datasetId='dataset_id'))
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    new_dataset = wrapper.get_or_create_dataset('project-id', 'dataset_id')
    self.assertEqual(new_dataset.datasetReference.datasetId, 'dataset_id')

  def test_get_or_create_table(self):
    client = mock.Mock()
    client.tables.Insert.return_value = 'table_id'
    client.tables.Get.side_effect = [None, 'table_id']
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    new_table = wrapper.get_or_create_table(
        'project-id',
        'dataset_id',
        'table_id',
        bigquery.TableSchema(
            fields=[
                bigquery.TableFieldSchema(
                    name='b', type='BOOLEAN', mode='REQUIRED')
            ]),
        False,
        False)
    self.assertEqual(new_table, 'table_id')

  def test_get_or_create_table_race_condition(self):
    client = mock.Mock()
    client.tables.Insert.side_effect = HttpError(
        response={'status': '409'}, url='', content='')
    client.tables.Get.side_effect = [None, 'table_id']
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    new_table = wrapper.get_or_create_table(
        'project-id',
        'dataset_id',
        'table_id',
        bigquery.TableSchema(
            fields=[
                bigquery.TableFieldSchema(
                    name='b', type='BOOLEAN', mode='REQUIRED')
            ]),
        False,
        False)
    self.assertEqual(new_table, 'table_id')

  def test_get_or_create_table_intermittent_exception(self):
    client = mock.Mock()
    client.tables.Insert.side_effect = [
        HttpError(response={'status': '408'}, url='', content=''), 'table_id'
    ]
    client.tables.Get.side_effect = [None, 'table_id']
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    new_table = wrapper.get_or_create_table(
        'project-id',
        'dataset_id',
        'table_id',
        bigquery.TableSchema(
            fields=[
                bigquery.TableFieldSchema(
                    name='b', type='BOOLEAN', mode='REQUIRED')
            ]),
        False,
        False)
    self.assertEqual(new_table, 'table_id')

  @parameterized.expand(['', 'a' * 1025])
  def test_get_or_create_table_invalid_tablename(self, table_id):
    client = mock.Mock()
    client.tables.Get.side_effect = [None]
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)

    self.assertRaises(
        ValueError,
        wrapper.get_or_create_table,
        'project-id',
        'dataset_id',
        table_id,
        bigquery.TableSchema(
            fields=[
                bigquery.TableFieldSchema(
                    name='b', type='BOOLEAN', mode='REQUIRED')
            ]),
        False,
        False)

  def test_wait_for_job_returns_true_when_job_is_done(self):
    def make_response(state):
      m = mock.Mock()
      m.status.errorResult = None
      m.status.state = state
      return m

    client, job_ref = mock.Mock(), mock.Mock()
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    # Return 'DONE' the second time get_job is called.
    wrapper.get_job = mock.Mock(
        side_effect=[make_response('RUNNING'), make_response('DONE')])

    result = wrapper.wait_for_bq_job(
        job_ref, sleep_duration_sec=0, max_retries=5)
    self.assertTrue(result)

  def test_wait_for_job_retries_fail(self):
    client, response, job_ref = mock.Mock(), mock.Mock(), mock.Mock()
    response.status.state = 'RUNNING'
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    # Return 'RUNNING' response forever.
    wrapper.get_job = lambda *args: response

    with self.assertRaises(RuntimeError) as context:
      wrapper.wait_for_bq_job(job_ref, sleep_duration_sec=0, max_retries=5)
    self.assertEqual(
        'The maximum number of retries has been reached',
        str(context.exception))

  def test_get_query_location(self):
    client = mock.Mock()
    query = """
        SELECT
            av.column1, table.column1
        FROM `dataset.authorized_view` as av
        JOIN `dataset.table` as table ON av.column2 = table.column2
    """
    job = mock.MagicMock(spec=bigquery.Job)
    job.statistics.query.referencedTables = [
        bigquery.TableReference(
            projectId="first_project_id",
            datasetId="first_dataset",
            tableId="table_used_by_authorized_view"),
        bigquery.TableReference(
            projectId="second_project_id",
            datasetId="second_dataset",
            tableId="table"),
    ]
    client.jobs.Insert.return_value = job

    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    wrapper.get_table_location = mock.Mock(
        side_effect=[
            HttpForbiddenError(response={'status': '404'}, url='', content=''),
            "US"
        ])
    location = wrapper.get_query_location(
        project_id="second_project_id", query=query, use_legacy_sql=False)
    self.assertEqual("US", location)

  def test_perform_load_job_source_mutual_exclusivity(self):
    client = mock.Mock()
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)

    # Both source_uri and source_stream specified.
    with self.assertRaises(ValueError):
      wrapper.perform_load_job(
          destination=parse_table_reference('project:dataset.table'),
          job_id='job_id',
          source_uris=['gs://example.com/*'],
          source_stream=io.BytesIO())

    # Neither source_uri nor source_stream specified.
    wrapper.perform_load_job(
        destination=parse_table_reference('project:dataset.table'), job_id='J')

  def test_perform_load_job_with_source_stream(self):
    client = mock.Mock()
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)

    wrapper.perform_load_job(
        destination=parse_table_reference('project:dataset.table'),
        job_id='job_id',
        source_stream=io.BytesIO(b'some,data'))

    client.jobs.Insert.assert_called_once()
    upload = client.jobs.Insert.call_args[1]["upload"]
    self.assertEqual(b'some,data', upload.stream.read())

  def test_perform_load_job_with_load_job_id(self):
    client = mock.Mock()
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)

    wrapper.perform_load_job(
        destination=parse_table_reference('project:dataset.table'),
        job_id='job_id',
        source_uris=['gs://example.com/*'],
        load_job_project_id='loadId')
    call_args = client.jobs.Insert.call_args
    self.assertEqual('loadId', call_args[0][0].projectId)

  def verify_write_call_metric(
      self, project_id, dataset_id, table_id, status, count):
    """Check if a metric was recorded for the BQ IO write API call."""
    process_wide_monitoring_infos = list(
        MetricsEnvironment.process_wide_container().
        to_runner_api_monitoring_infos(None).values())
    resource = resource_identifiers.BigQueryTable(
        project_id, dataset_id, table_id)
    labels = {
        # TODO(ajamato): Add Ptransform label.
        monitoring_infos.SERVICE_LABEL: 'BigQuery',
        # Refer to any method which writes elements to BigQuery in batches
        # as "BigQueryBatchWrite". I.e. storage API's insertAll, or future
        # APIs introduced.
        monitoring_infos.METHOD_LABEL: 'BigQueryBatchWrite',
        monitoring_infos.RESOURCE_LABEL: resource,
        monitoring_infos.BIGQUERY_PROJECT_ID_LABEL: project_id,
        monitoring_infos.BIGQUERY_DATASET_LABEL: dataset_id,
        monitoring_infos.BIGQUERY_TABLE_LABEL: table_id,
        monitoring_infos.STATUS_LABEL: status,
    }
    expected_mi = monitoring_infos.int64_counter(
        monitoring_infos.API_REQUEST_COUNT_URN, count, labels=labels)
    expected_mi.ClearField("start_time")

    found = False
    for actual_mi in process_wide_monitoring_infos:
      actual_mi.ClearField("start_time")
      if expected_mi == actual_mi:
        found = True
        break
    self.assertTrue(
        found, "Did not find write call metric with status: %s" % status)

  @unittest.skipIf(ClientError is None, 'GCP dependencies are not installed')
  def test_insert_rows_sets_metric_on_failure(self):
    MetricsEnvironment.process_wide_container().reset()
    client = mock.Mock()
    client.insert_rows_json = mock.Mock(
        # Fail a few times, then succeed.
        side_effect=[
            DeadlineExceeded("Deadline Exceeded"),
            InternalServerError("Internal Error"),
            [],
        ])
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    wrapper.insert_rows("my_project", "my_dataset", "my_table", [])

    # Expect two failing calls, then a success (i.e. two retries).
    self.verify_write_call_metric(
        "my_project", "my_dataset", "my_table", "deadline_exceeded", 1)
    self.verify_write_call_metric(
        "my_project", "my_dataset", "my_table", "internal", 1)
    self.verify_write_call_metric(
        "my_project", "my_dataset", "my_table", "ok", 1)

  @unittest.skipIf(ClientError is None, 'GCP dependencies are not installed')
  def test_start_query_job_priority_configuration(self):
    client = mock.Mock()
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)

    query_result = mock.Mock()
    query_result.pageToken = None
    wrapper._get_query_results = mock.Mock(return_value=query_result)

    wrapper._start_query_job(
        "my_project",
        "my_query",
        use_legacy_sql=False,
        flatten_results=False,
        job_id="my_job_id",
        priority=beam.io.BigQueryQueryPriority.BATCH)

    self.assertEqual(
        client.jobs.Insert.call_args[0][0].job.configuration.query.priority,
        'BATCH')

    wrapper._start_query_job(
        "my_project",
        "my_query",
        use_legacy_sql=False,
        flatten_results=False,
        job_id="my_job_id",
        priority=beam.io.BigQueryQueryPriority.INTERACTIVE)

    self.assertEqual(
        client.jobs.Insert.call_args[0][0].job.configuration.query.priority,
        'INTERACTIVE')


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestRowAsDictJsonCoder(unittest.TestCase):
  def test_row_as_dict(self):
    coder = RowAsDictJsonCoder()
    test_value = {'s': 'abc', 'i': 123, 'f': 123.456, 'b': True}
    self.assertEqual(test_value, coder.decode(coder.encode(test_value)))

  def test_decimal_in_row_as_dict(self):
    decimal_value = decimal.Decimal('123456789.987654321')
    coder = RowAsDictJsonCoder()
    # Bigquery IO uses decimals to represent NUMERIC types.
    # To export to BQ, it's necessary to convert to strings, due to the
    # lower precision of JSON numbers. This means that we can't recognize
    # a NUMERIC when we decode from JSON, thus we match the string here.
    test_value = {'f': 123.456, 'b': True, 'numerico': decimal_value}
    output_value = {'f': 123.456, 'b': True, 'numerico': str(decimal_value)}
    self.assertEqual(output_value, coder.decode(coder.encode(test_value)))

  def json_compliance_exception(self, value):
    with self.assertRaisesRegex(ValueError, re.escape(JSON_COMPLIANCE_ERROR)):
      coder = RowAsDictJsonCoder()
      test_value = {'s': value}
      coder.decode(coder.encode(test_value))

  def test_invalid_json_nan(self):
    self.json_compliance_exception(float('nan'))

  def test_invalid_json_inf(self):
    self.json_compliance_exception(float('inf'))

  def test_invalid_json_neg_inf(self):
    self.json_compliance_exception(float('-inf'))

  def test_ensure_ascii(self):
    coder = RowAsDictJsonCoder()
    test_value = {'s': '🎉'}
    output_value = b'{"s": "\xf0\x9f\x8e\x89"}'

    self.assertEqual(output_value, coder.encode(test_value))


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestJsonRowWriter(unittest.TestCase):
  def test_write_row(self):
    rows = [
        {
            'name': 'beam', 'game': 'dream'
        },
        {
            'name': 'team', 'game': 'cream'
        },
    ]

    with io.BytesIO() as buf:
      # Mock close() so we can access the buffer contents
      # after JsonRowWriter is closed.
      with mock.patch.object(buf, 'close') as mock_close:
        writer = JsonRowWriter(buf)
        for row in rows:
          writer.write(row)
        writer.close()

        mock_close.assert_called_once()

      buf.seek(0)
      read_rows = [
          json.loads(row)
          for row in buf.getvalue().strip().decode('utf-8').split('\n')
      ]

    self.assertEqual(read_rows, rows)


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestAvroRowWriter(unittest.TestCase):
  def test_write_row(self):
    schema = bigquery.TableSchema(
        fields=[
            bigquery.TableFieldSchema(name='stamp', type='TIMESTAMP'),
            bigquery.TableFieldSchema(
                name='number', type='FLOAT', mode='REQUIRED'),
        ])
    stamp = datetime.datetime(2020, 2, 25, 12, 0, 0, tzinfo=pytz.utc)

    with io.BytesIO() as buf:
      # Mock close() so we can access the buffer contents
      # after AvroRowWriter is closed.
      with mock.patch.object(buf, 'close') as mock_close:
        writer = AvroRowWriter(buf, schema)
        writer.write({'stamp': stamp, 'number': float('NaN')})
        writer.close()

        mock_close.assert_called_once()

      buf.seek(0)
      records = [r for r in fastavro.reader(buf)]

    self.assertEqual(len(records), 1)
    self.assertTrue(math.isnan(records[0]['number']))
    self.assertEqual(records[0]['stamp'], stamp)


class TestBQJobNames(unittest.TestCase):
  def test_simple_names(self):
    self.assertEqual(
        "beam_bq_job_EXPORT_beamappjobtest_abcd",
        generate_bq_job_name(
            "beamapp-job-test", "abcd", BigQueryJobTypes.EXPORT))

    self.assertEqual(
        "beam_bq_job_LOAD_beamappjobtest_abcd",
        generate_bq_job_name("beamapp-job-test", "abcd", BigQueryJobTypes.LOAD))

    self.assertEqual(
        "beam_bq_job_QUERY_beamappjobtest_abcd",
        generate_bq_job_name(
            "beamapp-job-test", "abcd", BigQueryJobTypes.QUERY))

    self.assertEqual(
        "beam_bq_job_COPY_beamappjobtest_abcd",
        generate_bq_job_name("beamapp-job-test", "abcd", BigQueryJobTypes.COPY))

  def test_random_in_name(self):
    self.assertEqual(
        "beam_bq_job_COPY_beamappjobtest_abcd_randome",
        generate_bq_job_name(
            "beamapp-job-test", "abcd", BigQueryJobTypes.COPY, "randome"))

  def test_matches_template(self):
    base_pattern = "beam_bq_job_[A-Z]+_[a-z0-9-]+_[a-z0-9-]+(_[a-z0-9-]+)?"
    job_name = generate_bq_job_name(
        "beamapp-job-test", "abcd", BigQueryJobTypes.COPY, "randome")
    self.assertRegex(job_name, base_pattern)

    job_name = generate_bq_job_name(
        "beamapp-job-test", "abcd", BigQueryJobTypes.COPY)
    self.assertRegex(job_name, base_pattern)


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestCheckSchemaEqual(unittest.TestCase):
  def test_simple_schemas(self):
    schema1 = bigquery.TableSchema(fields=[])
    self.assertTrue(check_schema_equal(schema1, schema1))

    schema2 = bigquery.TableSchema(
        fields=[
            bigquery.TableFieldSchema(name="a", mode="NULLABLE", type="INT64")
        ])
    self.assertTrue(check_schema_equal(schema2, schema2))
    self.assertFalse(check_schema_equal(schema1, schema2))

    schema3 = bigquery.TableSchema(
        fields=[
            bigquery.TableFieldSchema(
                name="b",
                mode="REPEATED",
                type="RECORD",
                fields=[
                    bigquery.TableFieldSchema(
                        name="c", mode="REQUIRED", type="BOOL")
                ])
        ])
    self.assertTrue(check_schema_equal(schema3, schema3))
    self.assertFalse(check_schema_equal(schema2, schema3))

  def test_field_order(self):
    """Test that field order is ignored when ignore_field_order=True."""
    schema1 = bigquery.TableSchema(
        fields=[
            bigquery.TableFieldSchema(
                name="a", mode="REQUIRED", type="FLOAT64"),
            bigquery.TableFieldSchema(name="b", mode="REQUIRED", type="INT64"),
        ])

    schema2 = bigquery.TableSchema(fields=list(reversed(schema1.fields)))

    self.assertFalse(check_schema_equal(schema1, schema2))
    self.assertTrue(
        check_schema_equal(schema1, schema2, ignore_field_order=True))

  def test_descriptions(self):
    """
    Test that differences in description are ignored
    when ignore_descriptions=True.
    """
    schema1 = bigquery.TableSchema(
        fields=[
            bigquery.TableFieldSchema(
                name="a",
                mode="REQUIRED",
                type="FLOAT64",
                description="Field A",
            ),
            bigquery.TableFieldSchema(
                name="b",
                mode="REQUIRED",
                type="INT64",
            ),
        ])

    schema2 = bigquery.TableSchema(
        fields=[
            bigquery.TableFieldSchema(
                name="a",
                mode="REQUIRED",
                type="FLOAT64",
                description="Field A is for Apple"),
            bigquery.TableFieldSchema(
                name="b",
                mode="REQUIRED",
                type="INT64",
                description="Field B",
            ),
        ])

    self.assertFalse(check_schema_equal(schema1, schema2))
    self.assertTrue(
        check_schema_equal(schema1, schema2, ignore_descriptions=True))


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestBeamRowFromDict(unittest.TestCase):
  DICT_ROW = {
      "str": "a",
      "bool": True,
      "bytes": b'a',
      "int": 1,
      "float": 0.1,
      "numeric": decimal.Decimal("1.11"),
      "timestamp": Timestamp(1000, 100)
  }

  def get_schema_fields_with_mode(self, mode):
    return [{
        "name": "str", "type": "STRING", "mode": mode
    }, {
        "name": "bool", "type": "boolean", "mode": mode
    }, {
        "name": "bytes", "type": "BYTES", "mode": mode
    }, {
        "name": "int", "type": "INTEGER", "mode": mode
    }, {
        "name": "float", "type": "Float", "mode": mode
    }, {
        "name": "numeric", "type": "NUMERIC", "mode": mode
    }, {
        "name": "timestamp", "type": "TIMESTAMP", "mode": mode
    }]

  def test_dict_to_beam_row_all_types_required(self):
    schema = {"fields": self.get_schema_fields_with_mode("REQUIRED")}
    expected_beam_row = beam.Row(
        str="a",
        bool=True,
        bytes=b'a',
        int=1,
        float=0.1,
        numeric=decimal.Decimal("1.11"),
        timestamp=Timestamp(1000, 100))

    self.assertEqual(
        expected_beam_row, beam_row_from_dict(self.DICT_ROW, schema))

  def test_dict_to_beam_row_all_types_repeated(self):
    schema = {"fields": self.get_schema_fields_with_mode("REPEATED")}
    dict_row = {
        "str": ["a", "b"],
        "bool": [True, False],
        "bytes": [b'a', b'b'],
        "int": [1, 2],
        "float": [0.1, 0.2],
        "numeric": [decimal.Decimal("1.11"), decimal.Decimal("2.22")],
        "timestamp": [Timestamp(1000, 100), Timestamp(2000, 200)]
    }

    expected_beam_row = beam.Row(
        str=["a", "b"],
        bool=[True, False],
        bytes=[b'a', b'b'],
        int=[1, 2],
        float=[0.1, 0.2],
        numeric=[decimal.Decimal("1.11"), decimal.Decimal("2.22")],
        timestamp=[Timestamp(1000, 100), Timestamp(2000, 200)])

    self.assertEqual(expected_beam_row, beam_row_from_dict(dict_row, schema))

  def test_dict_to_beam_row_all_types_nullable(self):
    schema = {"fields": self.get_schema_fields_with_mode("nullable")}
    dict_row = {k: None for k in self.DICT_ROW}

    expected_beam_row = beam.Row(
        str=None,
        bool=None,
        bytes=None,
        int=None,
        float=None,
        numeric=None,
        timestamp=None)

    self.assertEqual(expected_beam_row, beam_row_from_dict(dict_row, schema))

  def test_dict_to_beam_row_nested_record(self):
    schema_fields_with_nested = [{
        "name": "nested_record",
        "type": "record",
        "fields": self.get_schema_fields_with_mode("required")
    }]
    schema_fields_with_nested.extend(
        self.get_schema_fields_with_mode("required"))
    schema = {"fields": schema_fields_with_nested}

    dict_row = {
        "nested_record": self.DICT_ROW,
        "str": "a",
        "bool": True,
        "bytes": b'a',
        "int": 1,
        "float": 0.1,
        "numeric": decimal.Decimal("1.11"),
        "timestamp": Timestamp(1000, 100)
    }
    expected_beam_row = beam.Row(
        nested_record=beam.Row(
            str="a",
            bool=True,
            bytes=b'a',
            int=1,
            float=0.1,
            numeric=decimal.Decimal("1.11"),
            timestamp=Timestamp(1000, 100)),
        str="a",
        bool=True,
        bytes=b'a',
        int=1,
        float=0.1,
        numeric=decimal.Decimal("1.11"),
        timestamp=Timestamp(1000, 100))

    self.assertEqual(expected_beam_row, beam_row_from_dict(dict_row, schema))

  def test_dict_to_beam_row_repeated_nested_record(self):
    schema_fields_with_repeated_nested_record = [{
        "name": "nested_repeated_record",
        "type": "record",
        "mode": "repeated",
        "fields": self.get_schema_fields_with_mode("required")
    }]
    schema = {"fields": schema_fields_with_repeated_nested_record}

    dict_row = {
        "nested_repeated_record": [self.DICT_ROW, self.DICT_ROW, self.DICT_ROW],
    }

    beam_row = beam.Row(
        str="a",
        bool=True,
        bytes=b'a',
        int=1,
        float=0.1,
        numeric=decimal.Decimal("1.11"),
        timestamp=Timestamp(1000, 100))
    expected_beam_row = beam.Row(
        nested_repeated_record=[beam_row, beam_row, beam_row])

    self.assertEqual(expected_beam_row, beam_row_from_dict(dict_row, schema))


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestBeamTypehintFromSchema(unittest.TestCase):
  EXPECTED_TYPEHINTS = [("str", str), ("bool", bool), ("bytes", bytes),
                        ("int", np.int64), ("float", np.float64),
                        ("numeric", decimal.Decimal), ("timestamp", Timestamp)]

  def get_schema_fields_with_mode(self, mode):
    return [{
        "name": "str", "type": "STRING", "mode": mode
    }, {
        "name": "bool", "type": "boolean", "mode": mode
    }, {
        "name": "bytes", "type": "BYTES", "mode": mode
    }, {
        "name": "int", "type": "INTEGER", "mode": mode
    }, {
        "name": "float", "type": "Float", "mode": mode
    }, {
        "name": "numeric", "type": "NUMERIC", "mode": mode
    }, {
        "name": "timestamp", "type": "TIMESTAMP", "mode": mode
    }]

  def test_typehints_from_required_schema(self):
    schema = {"fields": self.get_schema_fields_with_mode("required")}
    typehints = get_beam_typehints_from_tableschema(schema)

    self.assertEqual(typehints, self.EXPECTED_TYPEHINTS)

  def test_typehints_from_repeated_schema(self):
    schema = {"fields": self.get_schema_fields_with_mode("repeated")}
    typehints = get_beam_typehints_from_tableschema(schema)

    expected_repeated_typehints = [
        (name, Sequence[type]) for name, type in self.EXPECTED_TYPEHINTS
    ]

    self.assertEqual(typehints, expected_repeated_typehints)

  def test_typehints_from_nullable_schema(self):
    schema = {"fields": self.get_schema_fields_with_mode("nullable")}
    typehints = get_beam_typehints_from_tableschema(schema)

    expected_nullable_typehints = [
        (name, Optional[type]) for name, type in self.EXPECTED_TYPEHINTS
    ]

    self.assertEqual(typehints, expected_nullable_typehints)

  def test_typehints_from_schema_with_struct(self):
    schema = {
        "fields": [{
            "name": "record",
            "type": "record",
            "mode": "required",
            "fields": self.get_schema_fields_with_mode("required")
        }]
    }
    typehints = get_beam_typehints_from_tableschema(schema)

    expected_typehints = [
        ("record", RowTypeConstraint.from_fields(self.EXPECTED_TYPEHINTS))
    ]

    self.assertEqual(typehints, expected_typehints)

  def test_typehints_from_schema_with_repeated_struct(self):
    schema = {
        "fields": [{
            "name": "record",
            "type": "record",
            "mode": "repeated",
            "fields": self.get_schema_fields_with_mode("required")
        }]
    }
    typehints = get_beam_typehints_from_tableschema(schema)

    expected_typehints = [(
        "record",
        Sequence[RowTypeConstraint.from_fields(self.EXPECTED_TYPEHINTS)])]

    self.assertEqual(typehints, expected_typehints)


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()