github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/gcp/bigquery_tools_test.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pytype: skip-file
    19  
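"""Unit tests for the BigQuery helpers in apache_beam.io.gcp.bigquery_tools."""
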
import datetime
import decimal
import io
import json
import logging
import math
import re
import unittest
from typing import Optional
from typing import Sequence

import fastavro
import mock
import numpy as np
import pytz
from parameterized import parameterized

import apache_beam as beam
from apache_beam.io.gcp import resource_identifiers
from apache_beam.io.gcp.bigquery_tools import JSON_COMPLIANCE_ERROR
from apache_beam.io.gcp.bigquery_tools import AvroRowWriter
from apache_beam.io.gcp.bigquery_tools import BigQueryJobTypes
from apache_beam.io.gcp.bigquery_tools import JsonRowWriter
from apache_beam.io.gcp.bigquery_tools import RowAsDictJsonCoder
from apache_beam.io.gcp.bigquery_tools import beam_row_from_dict
from apache_beam.io.gcp.bigquery_tools import check_schema_equal
from apache_beam.io.gcp.bigquery_tools import generate_bq_job_name
from apache_beam.io.gcp.bigquery_tools import get_beam_typehints_from_tableschema
from apache_beam.io.gcp.bigquery_tools import parse_table_reference
from apache_beam.io.gcp.bigquery_tools import parse_table_schema_from_json
from apache_beam.io.gcp.internal.clients import bigquery
from apache_beam.metrics import monitoring_infos
from apache_beam.metrics.execution import MetricsEnvironment
from apache_beam.options.value_provider import StaticValueProvider
from apache_beam.typehints.row_type import RowTypeConstraint
from apache_beam.utils.timestamp import Timestamp

# Protect against environments where the bigquery library is not available.
# pylint: disable=wrong-import-order, wrong-import-position
try:
  from apitools.base.py.exceptions import HttpError, HttpForbiddenError
  from google.api_core.exceptions import ClientError, DeadlineExceeded
  from google.api_core.exceptions import InternalServerError
  import google.cloud
except ImportError:
  ClientError = None
  DeadlineExceeded = None
  HttpError = None
  HttpForbiddenError = None
  InternalServerError = None
  google = None
# pylint: enable=wrong-import-order, wrong-import-position


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestTableSchemaParser(unittest.TestCase):
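  """Tests for parse_table_schema_from_json."""
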
  def test_parse_table_schema_from_json(self):
    string_field = bigquery.TableFieldSchema(
        name='s', type='STRING', mode='NULLABLE', description='s description')
    number_field = bigquery.TableFieldSchema(
        name='n', type='INTEGER', mode='REQUIRED', description='n description')
    record_field = bigquery.TableFieldSchema(
        name='r',
        type='RECORD',
        mode='REQUIRED',
        description='r description',
        fields=[string_field, number_field])
    expected_schema = bigquery.TableSchema(fields=[record_field])
    json_str = json.dumps({
        'fields': [{
            'name': 'r',
            'type': 'RECORD',
            'mode': 'REQUIRED',
            'description': 'r description',
            'fields': [{
                'name': 's',
                'type': 'STRING',
                'mode': 'NULLABLE',
                'description': 's description'
            }, {
                'name': 'n',
                'type': 'INTEGER',
                'mode': 'REQUIRED',
                'description': 'n description'
            }]
        }]
    })
    self.assertEqual(parse_table_schema_from_json(json_str), expected_schema)


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestTableReferenceParser(unittest.TestCase):
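  """Tests for parse_table_reference.

  Per the cases below, accepted inputs include TableReference objects
  (copied), callables and ValueProviders (returned as-is), and table spec
  strings such as 'project:dataset.table', 'project.dataset.table', and
  'dataset.table'.
  """
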
  def test_calling_with_table_reference(self):
    table_ref = bigquery.TableReference()
    table_ref.projectId = 'test_project'
    table_ref.datasetId = 'test_dataset'
    table_ref.tableId = 'test_table'
    parsed_ref = parse_table_reference(table_ref)
    self.assertEqual(table_ref, parsed_ref)
    self.assertIsNot(table_ref, parsed_ref)

  def test_calling_with_callable(self):
    callable_ref = lambda: 'foo'
    parsed_ref = parse_table_reference(callable_ref)
    self.assertIs(callable_ref, parsed_ref)

  def test_calling_with_value_provider(self):
    value_provider_ref = StaticValueProvider(str, 'test_dataset.test_table')
    parsed_ref = parse_table_reference(value_provider_ref)
    self.assertIs(value_provider_ref, parsed_ref)

  @parameterized.expand([
      ('project:dataset.test_table', 'project', 'dataset', 'test_table'),
      ('project:dataset.test-table', 'project', 'dataset', 'test-table'),
      ('project:dataset.test- table', 'project', 'dataset', 'test- table'),
      ('project.dataset. test_table', 'project', 'dataset', ' test_table'),
      ('project.dataset.test$table', 'project', 'dataset', 'test$table'),
  ])
  def test_calling_with_fully_qualified_table_ref(
      self,
      fully_qualified_table: str,
      project_id: str,
      dataset_id: str,
      table_id: str,
  ):
    parsed_ref = parse_table_reference(fully_qualified_table)
    self.assertIsInstance(parsed_ref, bigquery.TableReference)
    self.assertEqual(parsed_ref.projectId, project_id)
    self.assertEqual(parsed_ref.datasetId, dataset_id)
    self.assertEqual(parsed_ref.tableId, table_id)

  def test_calling_with_partially_qualified_table_ref(self):
    datasetId = 'test_dataset'
    tableId = 'test_table'
    partially_qualified_table = '{}.{}'.format(datasetId, tableId)
    parsed_ref = parse_table_reference(partially_qualified_table)
    self.assertIsInstance(parsed_ref, bigquery.TableReference)
    self.assertEqual(parsed_ref.datasetId, datasetId)
    self.assertEqual(parsed_ref.tableId, tableId)

  def test_calling_with_insufficient_table_ref(self):
    table = 'test_table'
    self.assertRaises(ValueError, parse_table_reference, table)

  def test_calling_with_all_arguments(self):
    projectId = 'test_project'
    datasetId = 'test_dataset'
    tableId = 'test_table'
    parsed_ref = parse_table_reference(
        tableId, dataset=datasetId, project=projectId)
    self.assertIsInstance(parsed_ref, bigquery.TableReference)
    self.assertEqual(parsed_ref.projectId, projectId)
    self.assertEqual(parsed_ref.datasetId, datasetId)
    self.assertEqual(parsed_ref.tableId, tableId)


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestBigQueryWrapper(unittest.TestCase):
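  """Tests for BigQueryWrapper: dataset and table create/get/delete with
  retry behavior, user-agent propagation, load and query job submission,
  and per-request metrics."""
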
  def test_delete_non_existing_dataset(self):
    client = mock.Mock()
    client.datasets.Delete.side_effect = HttpError(
        response={'status': '404'}, url='', content='')
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    wrapper._delete_dataset('', '')
    self.assertTrue(client.datasets.Delete.called)

  @mock.patch('time.sleep', return_value=None)
  def test_delete_dataset_retries_fail(self, patched_time_sleep):
    client = mock.Mock()
    client.datasets.Delete.side_effect = ValueError("Cannot delete")
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    with self.assertRaises(ValueError):
      wrapper._delete_dataset('', '')
    self.assertEqual(
        beam.io.gcp.bigquery_tools.MAX_RETRIES + 1,
        client.datasets.Delete.call_count)
    self.assertTrue(client.datasets.Delete.called)

  def test_delete_non_existing_table(self):
    client = mock.Mock()
    client.tables.Delete.side_effect = HttpError(
        response={'status': '404'}, url='', content='')
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    wrapper._delete_table('', '', '')
    self.assertTrue(client.tables.Delete.called)

  @mock.patch('time.sleep', return_value=None)
  def test_delete_table_retries_fail(self, patched_time_sleep):
    client = mock.Mock()
    client.tables.Delete.side_effect = ValueError("Cannot delete")
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    with self.assertRaises(ValueError):
      wrapper._delete_table('', '', '')
    self.assertTrue(client.tables.Delete.called)

  @mock.patch('time.sleep', return_value=None)
  def test_delete_dataset_retries_for_timeouts(self, patched_time_sleep):
    client = mock.Mock()
    client.datasets.Delete.side_effect = [
        HttpError(response={'status': '408'}, url='', content=''),
        bigquery.BigqueryDatasetsDeleteResponse()
    ]
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    wrapper._delete_dataset('', '')
    self.assertTrue(client.datasets.Delete.called)

  @unittest.skipIf(
      google and not hasattr(google.cloud, '_http'),  # pylint: disable=c-extension-no-member
      'Dependencies not installed')
  @mock.patch('time.sleep', return_value=None)
  @mock.patch('google.cloud._http.JSONConnection.http')
  def test_user_agent_insert_all(self, http_mock, patched_sleep):
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper()
    try:
      wrapper._insert_all_rows('p', 'd', 't', [{'name': 'any'}], None)
    except:  # pylint: disable=bare-except
      # Ignore errors. The errors come from the fact that we did not mock
      # the response from the API, so the overall insert_all_rows call fails
      # soon after the BQ API is called.
      pass
    call = http_mock.request.mock_calls[-2]
    self.assertIn('apache-beam-', call[2]['headers']['User-Agent'])

  @mock.patch('time.sleep', return_value=None)
  def test_delete_table_retries_for_timeouts(self, patched_time_sleep):
    client = mock.Mock()
    client.tables.Delete.side_effect = [
        HttpError(response={'status': '408'}, url='', content=''),
        bigquery.BigqueryTablesDeleteResponse()
    ]
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    wrapper._delete_table('', '', '')
    self.assertTrue(client.tables.Delete.called)

  @mock.patch('time.sleep', return_value=None)
  def test_temporary_dataset_is_unique(self, patched_time_sleep):
    client = mock.Mock()
    client.datasets.Get.return_value = bigquery.Dataset(
        datasetReference=bigquery.DatasetReference(
            projectId='project-id', datasetId='dataset_id'))
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    with self.assertRaises(RuntimeError):
      wrapper.create_temporary_dataset('project-id', 'location')
    self.assertTrue(client.datasets.Get.called)

  @mock.patch('time.sleep', return_value=None)
  def test_user_agent_passed(self, sleep_mock):
    try:
      wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper()
    except:  # pylint: disable=bare-except
      self.skipTest('Unable to create a BQ Wrapper')
    request_mock = mock.Mock()
    wrapper.client._http.request = request_mock
    try:
      wrapper.create_temporary_dataset('project-id', 'location')
    except:  # pylint: disable=bare-except
      # Ignore errors. The errors come from the fact that we did not mock
      # the response from the API, so the overall create_dataset call fails
      # soon after the BQ API is called.
      pass
    call = request_mock.mock_calls[-1]
    self.assertIn('apache-beam-', call[2]['headers']['user-agent'])

  def test_get_or_create_dataset_created(self):
    client = mock.Mock()
    client.datasets.Get.side_effect = HttpError(
        response={'status': '404'}, url='', content='')
    client.datasets.Insert.return_value = bigquery.Dataset(
        datasetReference=bigquery.DatasetReference(
            projectId='project-id', datasetId='dataset_id'))
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    new_dataset = wrapper.get_or_create_dataset('project-id', 'dataset_id')
    self.assertEqual(new_dataset.datasetReference.datasetId, 'dataset_id')

  def test_get_or_create_dataset_fetched(self):
    client = mock.Mock()
    client.datasets.Get.return_value = bigquery.Dataset(
        datasetReference=bigquery.DatasetReference(
            projectId='project-id', datasetId='dataset_id'))
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    new_dataset = wrapper.get_or_create_dataset('project-id', 'dataset_id')
    self.assertEqual(new_dataset.datasetReference.datasetId, 'dataset_id')

  def test_get_or_create_table(self):
    client = mock.Mock()
    client.tables.Insert.return_value = 'table_id'
    client.tables.Get.side_effect = [None, 'table_id']
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    new_table = wrapper.get_or_create_table(
        'project-id',
        'dataset_id',
        'table_id',
        bigquery.TableSchema(
            fields=[
                bigquery.TableFieldSchema(
                    name='b', type='BOOLEAN', mode='REQUIRED')
            ]),
        False,
        False)
    self.assertEqual(new_table, 'table_id')

  def test_get_or_create_table_race_condition(self):
    client = mock.Mock()
    client.tables.Insert.side_effect = HttpError(
        response={'status': '409'}, url='', content='')
    client.tables.Get.side_effect = [None, 'table_id']
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    new_table = wrapper.get_or_create_table(
        'project-id',
        'dataset_id',
        'table_id',
        bigquery.TableSchema(
            fields=[
                bigquery.TableFieldSchema(
                    name='b', type='BOOLEAN', mode='REQUIRED')
            ]),
        False,
        False)
    self.assertEqual(new_table, 'table_id')

  def test_get_or_create_table_intermittent_exception(self):
    client = mock.Mock()
    client.tables.Insert.side_effect = [
        HttpError(response={'status': '408'}, url='', content=''), 'table_id'
    ]
    client.tables.Get.side_effect = [None, 'table_id']
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    new_table = wrapper.get_or_create_table(
        'project-id',
        'dataset_id',
        'table_id',
        bigquery.TableSchema(
            fields=[
                bigquery.TableFieldSchema(
                    name='b', type='BOOLEAN', mode='REQUIRED')
            ]),
        False,
        False)
    self.assertEqual(new_table, 'table_id')

  @parameterized.expand(['', 'a' * 1025])
  def test_get_or_create_table_invalid_tablename(self, table_id):
    client = mock.Mock()
    client.tables.Get.side_effect = [None]
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)

    self.assertRaises(
        ValueError,
        wrapper.get_or_create_table,
        'project-id',
        'dataset_id',
        table_id,
        bigquery.TableSchema(
            fields=[
                bigquery.TableFieldSchema(
                    name='b', type='BOOLEAN', mode='REQUIRED')
            ]),
        False,
        False)

  def test_wait_for_job_returns_true_when_job_is_done(self):
    def make_response(state):
      m = mock.Mock()
      m.status.errorResult = None
      m.status.state = state
      return m

    client, job_ref = mock.Mock(), mock.Mock()
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    # Return 'DONE' the second time get_job is called.
    wrapper.get_job = mock.Mock(
        side_effect=[make_response('RUNNING'), make_response('DONE')])

    result = wrapper.wait_for_bq_job(
        job_ref, sleep_duration_sec=0, max_retries=5)
    self.assertTrue(result)

  def test_wait_for_job_retries_fail(self):
    client, response, job_ref = mock.Mock(), mock.Mock(), mock.Mock()
    response.status.state = 'RUNNING'
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    # Return a 'RUNNING' response forever.
    wrapper.get_job = lambda *args: response

    with self.assertRaises(RuntimeError) as context:
      wrapper.wait_for_bq_job(job_ref, sleep_duration_sec=0, max_retries=5)
    self.assertEqual(
        'The maximum number of retries has been reached',
        str(context.exception))

  def test_get_query_location(self):
    client = mock.Mock()
    query = """
        SELECT
            av.column1, table.column1
        FROM `dataset.authorized_view` as av
        JOIN `dataset.table` as table ON av.column2 = table.column2
    """
    job = mock.MagicMock(spec=bigquery.Job)
    job.statistics.query.referencedTables = [
        bigquery.TableReference(
            projectId="first_project_id",
            datasetId="first_dataset",
            tableId="table_used_by_authorized_view"),
        bigquery.TableReference(
            projectId="second_project_id",
            datasetId="second_dataset",
            tableId="table"),
    ]
    client.jobs.Insert.return_value = job

    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    wrapper.get_table_location = mock.Mock(
        side_effect=[
            HttpForbiddenError(response={'status': '404'}, url='', content=''),
            "US"
        ])
    location = wrapper.get_query_location(
        project_id="second_project_id", query=query, use_legacy_sql=False)
    self.assertEqual("US", location)
   441  
  def test_perform_load_job_source_mutual_exclusivity(self):
    client = mock.Mock()
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)

    # Specifying both source_uris and source_stream is an error.
    with self.assertRaises(ValueError):
      wrapper.perform_load_job(
          destination=parse_table_reference('project:dataset.table'),
          job_id='job_id',
          source_uris=['gs://example.com/*'],
          source_stream=io.BytesIO())

    # Specifying neither source_uris nor source_stream is allowed.
    wrapper.perform_load_job(
        destination=parse_table_reference('project:dataset.table'), job_id='J')
   457  
  def test_perform_load_job_with_source_stream(self):
    client = mock.Mock()
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)

    wrapper.perform_load_job(
        destination=parse_table_reference('project:dataset.table'),
        job_id='job_id',
        source_stream=io.BytesIO(b'some,data'))

    client.jobs.Insert.assert_called_once()
    upload = client.jobs.Insert.call_args[1]["upload"]
    self.assertEqual(b'some,data', upload.stream.read())

  def test_perform_load_job_with_load_job_id(self):
    client = mock.Mock()
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)

    wrapper.perform_load_job(
        destination=parse_table_reference('project:dataset.table'),
        job_id='job_id',
        source_uris=['gs://example.com/*'],
        load_job_project_id='loadId')
    call_args = client.jobs.Insert.call_args
    self.assertEqual('loadId', call_args[0][0].projectId)
  def verify_write_call_metric(
      self, project_id, dataset_id, table_id, status, count):
    """Check that a metric was recorded for the BQ IO write API call."""
    process_wide_monitoring_infos = list(
        MetricsEnvironment.process_wide_container().
        to_runner_api_monitoring_infos(None).values())
    resource = resource_identifiers.BigQueryTable(
        project_id, dataset_id, table_id)
    labels = {
        # TODO(ajamato): Add Ptransform label.
        monitoring_infos.SERVICE_LABEL: 'BigQuery',
        # Any method that writes elements to BigQuery in batches is referred
        # to as "BigQueryBatchWrite", e.g. the streaming insertAll API or
        # batch-write APIs introduced in the future.
        monitoring_infos.METHOD_LABEL: 'BigQueryBatchWrite',
        monitoring_infos.RESOURCE_LABEL: resource,
        monitoring_infos.BIGQUERY_PROJECT_ID_LABEL: project_id,
        monitoring_infos.BIGQUERY_DATASET_LABEL: dataset_id,
        monitoring_infos.BIGQUERY_TABLE_LABEL: table_id,
        monitoring_infos.STATUS_LABEL: status,
    }
    expected_mi = monitoring_infos.int64_counter(
        monitoring_infos.API_REQUEST_COUNT_URN, count, labels=labels)
    expected_mi.ClearField("start_time")

    found = False
    for actual_mi in process_wide_monitoring_infos:
      actual_mi.ClearField("start_time")
      if expected_mi == actual_mi:
        found = True
        break
    self.assertTrue(
        found, "Did not find write call metric with status: %s" % status)

  @unittest.skipIf(ClientError is None, 'GCP dependencies are not installed')
  def test_insert_rows_sets_metric_on_failure(self):
    MetricsEnvironment.process_wide_container().reset()
    client = mock.Mock()
    client.insert_rows_json = mock.Mock(
        # Fail a few times, then succeed.
        side_effect=[
            DeadlineExceeded("Deadline Exceeded"),
            InternalServerError("Internal Error"),
            [],
        ])
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)
    wrapper.insert_rows("my_project", "my_dataset", "my_table", [])

    # Expect two failing calls, then a success (i.e. two retries).
    self.verify_write_call_metric(
        "my_project", "my_dataset", "my_table", "deadline_exceeded", 1)
    self.verify_write_call_metric(
        "my_project", "my_dataset", "my_table", "internal", 1)
    self.verify_write_call_metric(
        "my_project", "my_dataset", "my_table", "ok", 1)

  @unittest.skipIf(ClientError is None, 'GCP dependencies are not installed')
  def test_start_query_job_priority_configuration(self):
    client = mock.Mock()
    wrapper = beam.io.gcp.bigquery_tools.BigQueryWrapper(client)

    query_result = mock.Mock()
    query_result.pageToken = None
    wrapper._get_query_results = mock.Mock(return_value=query_result)

    wrapper._start_query_job(
        "my_project",
        "my_query",
        use_legacy_sql=False,
        flatten_results=False,
        job_id="my_job_id",
        priority=beam.io.BigQueryQueryPriority.BATCH)

    self.assertEqual(
        client.jobs.Insert.call_args[0][0].job.configuration.query.priority,
        'BATCH')

    wrapper._start_query_job(
        "my_project",
        "my_query",
        use_legacy_sql=False,
        flatten_results=False,
        job_id="my_job_id",
        priority=beam.io.BigQueryQueryPriority.INTERACTIVE)

    self.assertEqual(
        client.jobs.Insert.call_args[0][0].job.configuration.query.priority,
        'INTERACTIVE')


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestRowAsDictJsonCoder(unittest.TestCase):
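  """Tests for RowAsDictJsonCoder, which encodes table rows (Python dicts)
  as JSON bytes."""
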
  def test_row_as_dict(self):
    coder = RowAsDictJsonCoder()
    test_value = {'s': 'abc', 'i': 123, 'f': 123.456, 'b': True}
    self.assertEqual(test_value, coder.decode(coder.encode(test_value)))

  def test_decimal_in_row_as_dict(self):
    decimal_value = decimal.Decimal('123456789.987654321')
    coder = RowAsDictJsonCoder()
    # BigQuery IO uses decimals to represent NUMERIC types.
    # To export to BQ, it's necessary to convert them to strings, due to the
    # lower precision of JSON numbers. This means that we can't recognize
    # a NUMERIC when we decode from JSON, so we match against the string here.
    test_value = {'f': 123.456, 'b': True, 'numerico': decimal_value}
    output_value = {'f': 123.456, 'b': True, 'numerico': str(decimal_value)}
    self.assertEqual(output_value, coder.decode(coder.encode(test_value)))

  def json_compliance_exception(self, value):
    with self.assertRaisesRegex(ValueError, re.escape(JSON_COMPLIANCE_ERROR)):
      coder = RowAsDictJsonCoder()
      test_value = {'s': value}
      coder.decode(coder.encode(test_value))

  def test_invalid_json_nan(self):
    self.json_compliance_exception(float('nan'))

  def test_invalid_json_inf(self):
    self.json_compliance_exception(float('inf'))

  def test_invalid_json_neg_inf(self):
    self.json_compliance_exception(float('-inf'))

  def test_ensure_ascii(self):
    coder = RowAsDictJsonCoder()
    test_value = {'s': '🎉'}
    output_value = b'{"s": "\xf0\x9f\x8e\x89"}'

    self.assertEqual(output_value, coder.encode(test_value))


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestJsonRowWriter(unittest.TestCase):
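  """Tests for JsonRowWriter, which writes rows as newline-delimited
  JSON."""
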
  def test_write_row(self):
    rows = [
        {'name': 'beam', 'game': 'dream'},
        {'name': 'team', 'game': 'cream'},
    ]

    with io.BytesIO() as buf:
      # Mock close() so we can access the buffer contents
      # after JsonRowWriter is closed.
      with mock.patch.object(buf, 'close') as mock_close:
        writer = JsonRowWriter(buf)
        for row in rows:
          writer.write(row)
        writer.close()

        mock_close.assert_called_once()

      buf.seek(0)
      read_rows = [
          json.loads(row)
          for row in buf.getvalue().strip().decode('utf-8').split('\n')
      ]

    self.assertEqual(read_rows, rows)


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestAvroRowWriter(unittest.TestCase):
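  """Tests for AvroRowWriter, which writes rows to an Avro file using a
  schema derived from the BigQuery table schema."""
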
  def test_write_row(self):
    schema = bigquery.TableSchema(
        fields=[
            bigquery.TableFieldSchema(name='stamp', type='TIMESTAMP'),
            bigquery.TableFieldSchema(
                name='number', type='FLOAT', mode='REQUIRED'),
        ])
    stamp = datetime.datetime(2020, 2, 25, 12, 0, 0, tzinfo=pytz.utc)

    with io.BytesIO() as buf:
      # Mock close() so we can access the buffer contents
      # after AvroRowWriter is closed.
      with mock.patch.object(buf, 'close') as mock_close:
        writer = AvroRowWriter(buf, schema)
        writer.write({'stamp': stamp, 'number': float('NaN')})
        writer.close()

        mock_close.assert_called_once()

      buf.seek(0)
      records = list(fastavro.reader(buf))

    self.assertEqual(len(records), 1)
    self.assertTrue(math.isnan(records[0]['number']))
    self.assertEqual(records[0]['stamp'], stamp)


class TestBQJobNames(unittest.TestCase):
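  """Tests for generate_bq_job_name.

  Per the assertions below, generated names follow the pattern
  beam_bq_job_<TYPE>_<jobname>_<stepid>[_<random>], with dashes stripped
  from the job name, e.g. generate_bq_job_name("beamapp-job-test", "abcd",
  BigQueryJobTypes.COPY) -> "beam_bq_job_COPY_beamappjobtest_abcd".
  """
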
  def test_simple_names(self):
    self.assertEqual(
        "beam_bq_job_EXPORT_beamappjobtest_abcd",
        generate_bq_job_name(
            "beamapp-job-test", "abcd", BigQueryJobTypes.EXPORT))

    self.assertEqual(
        "beam_bq_job_LOAD_beamappjobtest_abcd",
        generate_bq_job_name("beamapp-job-test", "abcd", BigQueryJobTypes.LOAD))

    self.assertEqual(
        "beam_bq_job_QUERY_beamappjobtest_abcd",
        generate_bq_job_name(
            "beamapp-job-test", "abcd", BigQueryJobTypes.QUERY))

    self.assertEqual(
        "beam_bq_job_COPY_beamappjobtest_abcd",
        generate_bq_job_name("beamapp-job-test", "abcd", BigQueryJobTypes.COPY))

  def test_random_in_name(self):
    self.assertEqual(
        "beam_bq_job_COPY_beamappjobtest_abcd_randome",
        generate_bq_job_name(
            "beamapp-job-test", "abcd", BigQueryJobTypes.COPY, "randome"))

  def test_matches_template(self):
    base_pattern = "beam_bq_job_[A-Z]+_[a-z0-9-]+_[a-z0-9-]+(_[a-z0-9-]+)?"
    job_name = generate_bq_job_name(
        "beamapp-job-test", "abcd", BigQueryJobTypes.COPY, "randome")
    self.assertRegex(job_name, base_pattern)

    job_name = generate_bq_job_name(
        "beamapp-job-test", "abcd", BigQueryJobTypes.COPY)
    self.assertRegex(job_name, base_pattern)


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestCheckSchemaEqual(unittest.TestCase):
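  """Tests for check_schema_equal and its ignore_field_order and
  ignore_descriptions options."""
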
  def test_simple_schemas(self):
    schema1 = bigquery.TableSchema(fields=[])
    self.assertTrue(check_schema_equal(schema1, schema1))

    schema2 = bigquery.TableSchema(
        fields=[
            bigquery.TableFieldSchema(name="a", mode="NULLABLE", type="INT64")
        ])
    self.assertTrue(check_schema_equal(schema2, schema2))
    self.assertFalse(check_schema_equal(schema1, schema2))

    schema3 = bigquery.TableSchema(
        fields=[
            bigquery.TableFieldSchema(
                name="b",
                mode="REPEATED",
                type="RECORD",
                fields=[
                    bigquery.TableFieldSchema(
                        name="c", mode="REQUIRED", type="BOOL")
                ])
        ])
    self.assertTrue(check_schema_equal(schema3, schema3))
    self.assertFalse(check_schema_equal(schema2, schema3))

  def test_field_order(self):
    """Test that field order is ignored when ignore_field_order=True."""
    schema1 = bigquery.TableSchema(
        fields=[
            bigquery.TableFieldSchema(
                name="a", mode="REQUIRED", type="FLOAT64"),
            bigquery.TableFieldSchema(name="b", mode="REQUIRED", type="INT64"),
        ])

    schema2 = bigquery.TableSchema(fields=list(reversed(schema1.fields)))

    self.assertFalse(check_schema_equal(schema1, schema2))
    self.assertTrue(
        check_schema_equal(schema1, schema2, ignore_field_order=True))

  def test_descriptions(self):
    """Test that differences in description are ignored when
    ignore_descriptions=True.
    """
    schema1 = bigquery.TableSchema(
        fields=[
            bigquery.TableFieldSchema(
                name="a",
                mode="REQUIRED",
                type="FLOAT64",
                description="Field A",
            ),
            bigquery.TableFieldSchema(
                name="b",
                mode="REQUIRED",
                type="INT64",
            ),
        ])

    schema2 = bigquery.TableSchema(
        fields=[
            bigquery.TableFieldSchema(
                name="a",
                mode="REQUIRED",
                type="FLOAT64",
                description="Field A is for Apple"),
            bigquery.TableFieldSchema(
                name="b",
                mode="REQUIRED",
                type="INT64",
                description="Field B",
            ),
        ])

    self.assertFalse(check_schema_equal(schema1, schema2))
    self.assertTrue(
        check_schema_equal(schema1, schema2, ignore_descriptions=True))


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestBeamRowFromDict(unittest.TestCase):
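  """Tests for beam_row_from_dict, which converts a dict keyed by schema
  field names into a beam.Row according to a BigQuery table schema.

  Note: the schema helper below deliberately mixes type-name casing
  ('boolean', 'Float'), which appears intended to exercise case-insensitive
  type handling.
  """
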
  DICT_ROW = {
      "str": "a",
      "bool": True,
      "bytes": b'a',
      "int": 1,
      "float": 0.1,
      "numeric": decimal.Decimal("1.11"),
      "timestamp": Timestamp(1000, 100)
  }

  def get_schema_fields_with_mode(self, mode):
    return [{
        "name": "str", "type": "STRING", "mode": mode
    }, {
        "name": "bool", "type": "boolean", "mode": mode
    }, {
        "name": "bytes", "type": "BYTES", "mode": mode
    }, {
        "name": "int", "type": "INTEGER", "mode": mode
    }, {
        "name": "float", "type": "Float", "mode": mode
    }, {
        "name": "numeric", "type": "NUMERIC", "mode": mode
    }, {
        "name": "timestamp", "type": "TIMESTAMP", "mode": mode
    }]

  def test_dict_to_beam_row_all_types_required(self):
    schema = {"fields": self.get_schema_fields_with_mode("REQUIRED")}
    expected_beam_row = beam.Row(
        str="a",
        bool=True,
        bytes=b'a',
        int=1,
        float=0.1,
        numeric=decimal.Decimal("1.11"),
        timestamp=Timestamp(1000, 100))

    self.assertEqual(
        expected_beam_row, beam_row_from_dict(self.DICT_ROW, schema))

  def test_dict_to_beam_row_all_types_repeated(self):
    schema = {"fields": self.get_schema_fields_with_mode("REPEATED")}
    dict_row = {
        "str": ["a", "b"],
        "bool": [True, False],
        "bytes": [b'a', b'b'],
        "int": [1, 2],
        "float": [0.1, 0.2],
        "numeric": [decimal.Decimal("1.11"), decimal.Decimal("2.22")],
        "timestamp": [Timestamp(1000, 100), Timestamp(2000, 200)]
    }

    expected_beam_row = beam.Row(
        str=["a", "b"],
        bool=[True, False],
        bytes=[b'a', b'b'],
        int=[1, 2],
        float=[0.1, 0.2],
        numeric=[decimal.Decimal("1.11"), decimal.Decimal("2.22")],
        timestamp=[Timestamp(1000, 100), Timestamp(2000, 200)])

    self.assertEqual(expected_beam_row, beam_row_from_dict(dict_row, schema))

  def test_dict_to_beam_row_all_types_nullable(self):
    schema = {"fields": self.get_schema_fields_with_mode("nullable")}
    dict_row = {k: None for k in self.DICT_ROW}

    expected_beam_row = beam.Row(
        str=None,
        bool=None,
        bytes=None,
        int=None,
        float=None,
        numeric=None,
        timestamp=None)

    self.assertEqual(expected_beam_row, beam_row_from_dict(dict_row, schema))

  def test_dict_to_beam_row_nested_record(self):
    schema_fields_with_nested = [{
        "name": "nested_record",
        "type": "record",
        "fields": self.get_schema_fields_with_mode("required")
    }]
    schema_fields_with_nested.extend(
        self.get_schema_fields_with_mode("required"))
    schema = {"fields": schema_fields_with_nested}

    dict_row = {
        "nested_record": self.DICT_ROW,
        "str": "a",
        "bool": True,
        "bytes": b'a',
        "int": 1,
        "float": 0.1,
        "numeric": decimal.Decimal("1.11"),
        "timestamp": Timestamp(1000, 100)
    }
    expected_beam_row = beam.Row(
        nested_record=beam.Row(
            str="a",
            bool=True,
            bytes=b'a',
            int=1,
            float=0.1,
            numeric=decimal.Decimal("1.11"),
            timestamp=Timestamp(1000, 100)),
        str="a",
        bool=True,
        bytes=b'a',
        int=1,
        float=0.1,
        numeric=decimal.Decimal("1.11"),
        timestamp=Timestamp(1000, 100))

    self.assertEqual(expected_beam_row, beam_row_from_dict(dict_row, schema))

  def test_dict_to_beam_row_repeated_nested_record(self):
    schema_fields_with_repeated_nested_record = [{
        "name": "nested_repeated_record",
        "type": "record",
        "mode": "repeated",
        "fields": self.get_schema_fields_with_mode("required")
    }]
    schema = {"fields": schema_fields_with_repeated_nested_record}

    dict_row = {
        "nested_repeated_record": [self.DICT_ROW, self.DICT_ROW, self.DICT_ROW],
    }

    beam_row = beam.Row(
        str="a",
        bool=True,
        bytes=b'a',
        int=1,
        float=0.1,
        numeric=decimal.Decimal("1.11"),
        timestamp=Timestamp(1000, 100))
    expected_beam_row = beam.Row(
        nested_repeated_record=[beam_row, beam_row, beam_row])

    self.assertEqual(expected_beam_row, beam_row_from_dict(dict_row, schema))


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class TestBeamTypehintFromSchema(unittest.TestCase):
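  """Tests for get_beam_typehints_from_tableschema.

  Per the cases below: REQUIRED fields map to plain Python/numpy types,
  REPEATED fields to Sequence[...], NULLABLE fields to Optional[...], and
  RECORD fields to a RowTypeConstraint (wrapped accordingly).
  """
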
  EXPECTED_TYPEHINTS = [("str", str), ("bool", bool), ("bytes", bytes),
                        ("int", np.int64), ("float", np.float64),
                        ("numeric", decimal.Decimal), ("timestamp", Timestamp)]

  def get_schema_fields_with_mode(self, mode):
    return [{
        "name": "str", "type": "STRING", "mode": mode
    }, {
        "name": "bool", "type": "boolean", "mode": mode
    }, {
        "name": "bytes", "type": "BYTES", "mode": mode
    }, {
        "name": "int", "type": "INTEGER", "mode": mode
    }, {
        "name": "float", "type": "Float", "mode": mode
    }, {
        "name": "numeric", "type": "NUMERIC", "mode": mode
    }, {
        "name": "timestamp", "type": "TIMESTAMP", "mode": mode
    }]

  def test_typehints_from_required_schema(self):
    schema = {"fields": self.get_schema_fields_with_mode("required")}
    typehints = get_beam_typehints_from_tableschema(schema)

    self.assertEqual(typehints, self.EXPECTED_TYPEHINTS)

  def test_typehints_from_repeated_schema(self):
    schema = {"fields": self.get_schema_fields_with_mode("repeated")}
    typehints = get_beam_typehints_from_tableschema(schema)

    expected_repeated_typehints = [
        (name, Sequence[type]) for name, type in self.EXPECTED_TYPEHINTS
    ]

    self.assertEqual(typehints, expected_repeated_typehints)

  def test_typehints_from_nullable_schema(self):
    schema = {"fields": self.get_schema_fields_with_mode("nullable")}
    typehints = get_beam_typehints_from_tableschema(schema)

    expected_nullable_typehints = [
        (name, Optional[type]) for name, type in self.EXPECTED_TYPEHINTS
    ]

    self.assertEqual(typehints, expected_nullable_typehints)

  def test_typehints_from_schema_with_struct(self):
    schema = {
        "fields": [{
            "name": "record",
            "type": "record",
            "mode": "required",
            "fields": self.get_schema_fields_with_mode("required")
        }]
    }
    typehints = get_beam_typehints_from_tableschema(schema)

    expected_typehints = [
        ("record", RowTypeConstraint.from_fields(self.EXPECTED_TYPEHINTS))
    ]

    self.assertEqual(typehints, expected_typehints)

  def test_typehints_from_schema_with_repeated_struct(self):
    schema = {
        "fields": [{
            "name": "record",
            "type": "record",
            "mode": "repeated",
            "fields": self.get_schema_fields_with_mode("required")
        }]
    }
    typehints = get_beam_typehints_from_tableschema(schema)

    expected_typehints = [(
        "record",
        Sequence[RowTypeConstraint.from_fields(self.EXPECTED_TYPEHINTS)])]

    self.assertEqual(typehints, expected_typehints)


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()