github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/gcp/bigquery_test.py

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Unit tests for BigQuery sources and sinks."""
    19  # pytype: skip-file
    20  
    21  import datetime
    22  import decimal
    23  import gc
    24  import json
    25  import logging
    26  import os
    27  import pickle
    28  import re
    29  import secrets
    30  import time
    31  import unittest
    32  import uuid
    33  
    34  import hamcrest as hc
    35  import mock
    36  import pytest
    37  import pytz
    38  import requests
    39  from parameterized import param
    40  from parameterized import parameterized
    41  
    42  import apache_beam as beam
    43  from apache_beam.internal import pickler
    44  from apache_beam.internal.gcp.json_value import to_json_value
    45  from apache_beam.io.filebasedsink_test import _TestCaseWithTempDirCleanUp
    46  from apache_beam.io.gcp import bigquery as beam_bq
    47  from apache_beam.io.gcp import bigquery_tools
    48  from apache_beam.io.gcp.bigquery import ReadFromBigQuery
    49  from apache_beam.io.gcp.bigquery import TableRowJsonCoder
    50  from apache_beam.io.gcp.bigquery import WriteToBigQuery
    51  from apache_beam.io.gcp.bigquery import _StreamToBigQuery
    52  from apache_beam.io.gcp.bigquery_file_loads_test import _ELEMENTS
    53  from apache_beam.io.gcp.bigquery_read_internal import _JsonToDictCoder
    54  from apache_beam.io.gcp.bigquery_read_internal import bigquery_export_destination_uri
    55  from apache_beam.io.gcp.bigquery_tools import JSON_COMPLIANCE_ERROR
    56  from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper
    57  from apache_beam.io.gcp.bigquery_tools import RetryStrategy
    58  from apache_beam.io.gcp.internal.clients import bigquery
    59  from apache_beam.io.gcp.internal.clients.bigquery import bigquery_v2_client
    60  from apache_beam.io.gcp.pubsub import ReadFromPubSub
    61  from apache_beam.io.gcp.tests import utils
    62  from apache_beam.io.gcp.tests.bigquery_matcher import BigqueryFullResultMatcher
    63  from apache_beam.io.gcp.tests.bigquery_matcher import BigqueryFullResultStreamingMatcher
    64  from apache_beam.io.gcp.tests.bigquery_matcher import BigQueryTableMatcher
    65  from apache_beam.options import value_provider
    66  from apache_beam.options.pipeline_options import PipelineOptions
    67  from apache_beam.options.pipeline_options import StandardOptions
    68  from apache_beam.options.value_provider import RuntimeValueProvider
    69  from apache_beam.options.value_provider import StaticValueProvider
    70  from apache_beam.runners.dataflow.test_dataflow_runner import TestDataflowRunner
    71  from apache_beam.runners.runner import PipelineState
    72  from apache_beam.testing import test_utils
    73  from apache_beam.testing.pipeline_verifiers import PipelineStateMatcher
    74  from apache_beam.testing.test_pipeline import TestPipeline
    75  from apache_beam.testing.test_stream import TestStream
    76  from apache_beam.testing.util import assert_that
    77  from apache_beam.testing.util import equal_to
    78  from apache_beam.transforms.display import DisplayData
    79  from apache_beam.transforms.display_test import DisplayDataItemMatcher
    80  from apache_beam.utils import retry
    81  
    82  # Protect against environments where the bigquery library is not available.
    83  # pylint: disable=wrong-import-order, wrong-import-position
    84  
    85  try:
    86    from apitools.base.py.exceptions import HttpError
    87    from google.cloud import bigquery as gcp_bigquery
    88    from google.api_core import exceptions
    89  except ImportError:
    90    gcp_bigquery = None
    91    HttpError = None
    92    exceptions = None
    93  # pylint: enable=wrong-import-order, wrong-import-position
    94  
    95  _LOGGER = logging.getLogger(__name__)
    96  
    97  
    98  def _load_or_default(filename):
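          """Best-effort JSON load: returns {} if the file is missing or invalid."""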
    99    try:
   100      with open(filename) as f:
   101        return json.load(f)
   102    except:  # pylint: disable=bare-except
   103      return {}
   104  
   105  
   106  @unittest.skipIf(
   107      HttpError is None or gcp_bigquery is None,
   108      'GCP dependencies are not installed')
   109  class TestTableRowJsonCoder(unittest.TestCase):
   110    def test_row_as_table_row(self):
   111      schema_definition = [('s', 'STRING'), ('i', 'INTEGER'), ('f', 'FLOAT'),
   112                           ('b', 'BOOLEAN'), ('n', 'NUMERIC'), ('r', 'RECORD'),
   113                           ('g', 'GEOGRAPHY')]
   114      data_definition = [
   115          'abc',
   116          123,
   117          123.456,
   118          True,
   119          decimal.Decimal('987654321.987654321'), {
   120              'a': 'b'
   121          },
   122          'LINESTRING(1 2, 3 4, 5 6, 7 8)'
   123      ]
   124      str_def = (
   125          '{"s": "abc", '
   126          '"i": 123, '
   127          '"f": 123.456, '
   128          '"b": true, '
   129          '"n": "987654321.987654321", '
   130          '"r": {"a": "b"}, '
   131          '"g": "LINESTRING(1 2, 3 4, 5 6, 7 8)"}')
   132      schema = bigquery.TableSchema(
   133          fields=[
   134              bigquery.TableFieldSchema(name=k, type=v) for k,
   135              v in schema_definition
   136          ])
   137      coder = TableRowJsonCoder(table_schema=schema)
   138  
   139      def value_or_decimal_to_json(val):
   140        if isinstance(val, decimal.Decimal):
   141          return to_json_value(str(val))
   142        else:
   143          return to_json_value(val)
   144  
   145      test_row = bigquery.TableRow(
   146          f=[
   147              bigquery.TableCell(v=value_or_decimal_to_json(e))
   148              for e in data_definition
   149          ])
   150  
   151      self.assertEqual(str_def, coder.encode(test_row))
   152      self.assertEqual(test_row, coder.decode(coder.encode(test_row)))
   153      # A coder without schema can still decode.
   154      self.assertEqual(
   155          test_row, TableRowJsonCoder().decode(coder.encode(test_row)))
   156  
   157    def test_row_and_no_schema(self):
   158      coder = TableRowJsonCoder()
   159      test_row = bigquery.TableRow(
   160          f=[
   161              bigquery.TableCell(v=to_json_value(e))
   162              for e in ['abc', 123, 123.456, True]
   163          ])
   164      with self.assertRaisesRegex(AttributeError,
   165                                  r'^The TableRowJsonCoder requires'):
   166        coder.encode(test_row)
   167  
   168    def json_compliance_exception(self, value):
   169      with self.assertRaisesRegex(ValueError, re.escape(JSON_COMPLIANCE_ERROR)):
   170        schema_definition = [('f', 'FLOAT')]
   171        schema = bigquery.TableSchema(
   172            fields=[
   173                bigquery.TableFieldSchema(name=k, type=v) for k,
   174                v in schema_definition
   175            ])
   176        coder = TableRowJsonCoder(table_schema=schema)
   177        test_row = bigquery.TableRow(
   178            f=[bigquery.TableCell(v=to_json_value(value))])
   179        coder.encode(test_row)
   180  
   181    def test_invalid_json_nan(self):
   182      self.json_compliance_exception(float('nan'))
   183  
   184    def test_invalid_json_inf(self):
   185      self.json_compliance_exception(float('inf'))
   186  
   187    def test_invalid_json_neg_inf(self):
   188      self.json_compliance_exception(float('-inf'))
   189  
   190  
   191  @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
   192  class TestJsonToDictCoder(unittest.TestCase):
   193    @staticmethod
   194    def _make_schema(fields):
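            """Builds a TableSchema from (name, type, mode, nested_fields) tuples."""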
   195      def _fill_schema(fields):
   196        for field in fields:
   197          table_field = bigquery.TableFieldSchema()
   198          table_field.name, table_field.type, table_field.mode, nested_fields, \
   199            = field
   200          if nested_fields:
   201            table_field.fields = list(_fill_schema(nested_fields))
   202          yield table_field
   203  
   204      schema = bigquery.TableSchema()
   205      schema.fields = list(_fill_schema(fields))
   206      return schema
   207  
   208    def test_coder_is_pickable(self):
   209      try:
   210        schema = self._make_schema([
   211            (
   212                'record',
   213                'RECORD',
   214                'NULLABLE', [
   215                    ('float', 'FLOAT', 'NULLABLE', []),
   216                ]),
   217            ('integer', 'INTEGER', 'NULLABLE', []),
   218        ])
   219        coder = _JsonToDictCoder(schema)
   220        pickler.loads(pickler.dumps(coder))
   221      except pickle.PicklingError:
   222        self.fail('{} is not picklable'.format(coder.__class__.__name__))
   223  
   224    def test_values_are_converted(self):
   225      input_row = b'{"float": "10.5", "string": "abc"}'
   226      expected_row = {'float': 10.5, 'string': 'abc'}
   227      schema = self._make_schema([('float', 'FLOAT', 'NULLABLE', []),
   228                                  ('string', 'STRING', 'NULLABLE', [])])
   229      coder = _JsonToDictCoder(schema)
   230  
   231      actual = coder.decode(input_row)
   232      self.assertEqual(expected_row, actual)
   233  
   234    def test_null_fields_are_preserved(self):
   235      input_row = b'{"float": "10.5"}'
   236      expected_row = {'float': 10.5, 'string': None}
   237      schema = self._make_schema([('float', 'FLOAT', 'NULLABLE', []),
   238                                  ('string', 'STRING', 'NULLABLE', [])])
   239      coder = _JsonToDictCoder(schema)
   240  
   241      actual = coder.decode(input_row)
   242      self.assertEqual(expected_row, actual)
   243  
   244    def test_record_field_is_properly_converted(self):
   245      input_row = b'{"record": {"float": "55.5"}, "integer": 10}'
   246      expected_row = {'record': {'float': 55.5}, 'integer': 10}
   247      schema = self._make_schema([
   248          (
   249              'record',
   250              'RECORD',
   251              'NULLABLE', [
   252                  ('float', 'FLOAT', 'NULLABLE', []),
   253              ]),
   254          ('integer', 'INTEGER', 'NULLABLE', []),
   255      ])
   256      coder = _JsonToDictCoder(schema)
   257  
   258      actual = coder.decode(input_row)
   259      self.assertEqual(expected_row, actual)
   260  
   261    def test_record_and_repeatable_field_is_properly_converted(self):
   262      input_row = b'{"record": [{"float": "55.5"}, {"float": "65.5"}], ' \
   263                  b'"integer": 10}'
   264      expected_row = {'record': [{'float': 55.5}, {'float': 65.5}], 'integer': 10}
   265      schema = self._make_schema([
   266          (
   267              'record',
   268              'RECORD',
   269              'REPEATED', [
   270                  ('float', 'FLOAT', 'NULLABLE', []),
   271              ]),
   272          ('integer', 'INTEGER', 'NULLABLE', []),
   273      ])
   274      coder = _JsonToDictCoder(schema)
   275  
   276      actual = coder.decode(input_row)
   277      self.assertEqual(expected_row, actual)
   278  
   279    def test_repeatable_field_is_properly_converted(self):
   280      input_row = b'{"repeated": ["55.5", "65.5"], "integer": "10"}'
   281      expected_row = {'repeated': [55.5, 65.5], 'integer': 10}
   282      schema = self._make_schema([
   283          ('repeated', 'FLOAT', 'REPEATED', []),
   284          ('integer', 'INTEGER', 'NULLABLE', []),
   285      ])
   286      coder = _JsonToDictCoder(schema)
   287  
   288      actual = coder.decode(input_row)
   289      self.assertEqual(expected_row, actual)
   290  
   291  
   292  @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
   293  class TestReadFromBigQuery(unittest.TestCase):
   294    @classmethod
   295    def setUpClass(cls):
   296      class UserDefinedOptions(PipelineOptions):
   297        @classmethod
   298        def _add_argparse_args(cls, parser):
   299          parser.add_value_provider_argument('--gcs_location')
   300  
   301      cls.UserDefinedOptions = UserDefinedOptions
   302  
   303    def tearDown(self):
   304      # Reset runtime options to avoid side-effects caused by other tests.
   305      RuntimeValueProvider.set_runtime_options(None)
   306  
   307    @classmethod
   308    def tearDownClass(cls):
   309      # Unset the option added in setUpClass to avoid interfering with other tests.
   310      # Force a gc so PipelineOptions.__subclasses__() no longer contains it.
   311      del cls.UserDefinedOptions
   312      gc.collect()
   313  
   314    def test_get_destination_uri_empty_runtime_vp(self):
   315      with self.assertRaisesRegex(ValueError,
   316                                  '^ReadFromBigQuery requires a GCS '
   317                                  'location to be provided'):
   318        # Don't provide any runtime values.
   319        RuntimeValueProvider.set_runtime_options({})
   320        options = self.UserDefinedOptions()
   321  
   322        bigquery_export_destination_uri(
   323            options.gcs_location, None, uuid.uuid4().hex)
   324  
   325    def test_get_destination_uri_none(self):
   326      with self.assertRaisesRegex(ValueError,
   327                                  '^ReadFromBigQuery requires a GCS '
   328                                  'location to be provided'):
   329        bigquery_export_destination_uri(None, None, uuid.uuid4().hex)
   330  
   331    def test_get_destination_uri_runtime_vp(self):
   332      # Provide values at job-execution time.
   333      RuntimeValueProvider.set_runtime_options({'gcs_location': 'gs://bucket'})
   334      options = self.UserDefinedOptions()
   335      unique_id = uuid.uuid4().hex
   336  
   337      uri = bigquery_export_destination_uri(options.gcs_location, None, unique_id)
   338      self.assertEqual(
   339          uri, 'gs://bucket/' + unique_id + '/bigquery-table-dump-*.json')
   340  
   341    def test_get_destination_uri_static_vp(self):
   342      unique_id = uuid.uuid4().hex
   343      uri = bigquery_export_destination_uri(
   344          StaticValueProvider(str, 'gs://bucket'), None, unique_id)
   345      self.assertEqual(
   346          uri, 'gs://bucket/' + unique_id + '/bigquery-table-dump-*.json')
   347  
   348    def test_get_destination_uri_fallback_temp_location(self):
   349      # Don't provide any runtime values.
   350      RuntimeValueProvider.set_runtime_options({})
   351      options = self.UserDefinedOptions()
   352  
   353      with self.assertLogs('apache_beam.io.gcp.bigquery_read_internal',
   354                           level='DEBUG') as context:
   355        bigquery_export_destination_uri(
   356            options.gcs_location, 'gs://bucket', uuid.uuid4().hex)
   357      self.assertEqual(
   358          context.output,
   359          [
   360              'DEBUG:apache_beam.io.gcp.bigquery_read_internal:gcs_location is '
   361              'empty, using temp_location instead'
   362          ])
   363  
   364    @mock.patch.object(BigQueryWrapper, '_delete_table')
   365    @mock.patch.object(BigQueryWrapper, '_delete_dataset')
   366    @mock.patch('apache_beam.io.gcp.internal.clients.bigquery.BigqueryV2')
   367    def test_temp_dataset_is_configurable(
   368        self, api, delete_dataset, delete_table):
   369      temp_dataset = bigquery.DatasetReference(
   370          projectId='temp-project', datasetId='bq_dataset')
   371      bq = BigQueryWrapper(client=api, temp_dataset_id=temp_dataset.datasetId)
   372      gcs_location = 'gs://gcs_location'
   373  
   374      c = beam.io.gcp.bigquery._CustomBigQuerySource(
   375          query='select * from test_table',
   376          gcs_location=gcs_location,
   377          method=beam.io.ReadFromBigQuery.Method.EXPORT,
   378          validate=True,
   379          pipeline_options=beam.options.pipeline_options.PipelineOptions(),
   380          job_name='job_name',
   381          step_name='step_name',
   382          project='execution_project',
   383          **{'temp_dataset': temp_dataset})
   384  
   385      c._setup_temporary_dataset(bq)
   386      api.datasets.assert_not_called()
   387  
   388      # The user-provided temporary dataset should not be deleted, but the
   389      # temporary table created by Beam should be deleted.
   390      bq.clean_up_temporary_dataset(temp_dataset.projectId)
   391      delete_dataset.assert_not_called()
   392      delete_table.assert_called_with(
   393          temp_dataset.projectId, temp_dataset.datasetId, mock.ANY)
   394  
   395    @parameterized.expand([
   396        param(
   397            exception_type=exceptions.Forbidden if exceptions else None,
   398            error_message='accessDenied'),
   399        param(
   400            exception_type=exceptions.ServiceUnavailable if exceptions else None,
   401            error_message='backendError'),
   402    ])
   403    def test_create_temp_dataset_exception(self, exception_type, error_message):
   404  
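            # Make dataset creation raise; the underlying error message should
            # surface from the failed pipeline.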
   405      with mock.patch.object(bigquery_v2_client.BigqueryV2.JobsService,
   406                             'Insert'),\
   407        mock.patch.object(BigQueryWrapper,
   408                          'get_or_create_dataset') as mock_insert, \
   409        mock.patch('time.sleep'), \
   410        self.assertRaises(Exception) as exc,\
   411        beam.Pipeline() as p:
   412  
   413        mock_insert.side_effect = exception_type(error_message)
   414  
   415        _ = p | ReadFromBigQuery(
   416            project='apache-beam-testing',
   417            query='SELECT * FROM `project.dataset.table`',
   418            gcs_location='gs://temp_location')
   419  
   420      mock_insert.assert_called()
   421      self.assertIn(error_message, exc.exception.args[0])
   422  
   423    @parameterized.expand([
   424        param(
   425            exception_type=exceptions.BadRequest if exceptions else None,
   426            error_message='invalidQuery'),
   427        param(
   428            exception_type=exceptions.NotFound if exceptions else None,
   429            error_message='notFound'),
   430        param(
   431            exception_type=exceptions.Forbidden if exceptions else None,
   432            error_message='responseTooLarge')
   433    ])
   434    def test_query_job_exception(self, exception_type, error_message):
   435  
   436      with mock.patch.object(beam.io.gcp.bigquery._CustomBigQuerySource,
   437                             'estimate_size') as mock_estimate,\
   438        mock.patch.object(BigQueryWrapper,
   439                          'get_query_location') as mock_query_location,\
   440        mock.patch.object(bigquery_v2_client.BigqueryV2.JobsService,
   441                          'Insert') as mock_query_job,\
   442        mock.patch.object(bigquery_v2_client.BigqueryV2.DatasetsService, 'Get'), \
   443        mock.patch('time.sleep'), \
   444        self.assertRaises(Exception) as exc, \
   445        beam.Pipeline() as p:
   446  
   447        mock_estimate.return_value = None
   448        mock_query_location.return_value = None
   449        mock_query_job.side_effect = exception_type(error_message)
   450  
   451        _ = p | ReadFromBigQuery(
   452            query='SELECT * FROM `project.dataset.table`',
   453            gcs_location='gs://temp_location')
   454  
   455      mock_query_job.assert_called()
   456      self.assertIn(error_message, exc.exception.args[0])
   457  
   458    @parameterized.expand([
   459        param(
   460            exception_type=exceptions.BadRequest if exceptions else None,
   461            error_message='invalid'),
   462        param(
   463            exception_type=exceptions.Forbidden if exceptions else None,
   464            error_message='accessDenied')
   465    ])
   466    def test_read_export_exception(self, exception_type, error_message):
   467  
   468      with mock.patch.object(beam.io.gcp.bigquery._CustomBigQuerySource,
   469                             'estimate_size') as mock_estimate,\
   470        mock.patch.object(bigquery_v2_client.BigqueryV2.TablesService, 'Get'),\
   471        mock.patch.object(bigquery_v2_client.BigqueryV2.JobsService,
   472                          'Insert') as mock_query_job, \
   473        mock.patch('time.sleep'), \
   474        self.assertRaises(Exception) as exc,\
   475        beam.Pipeline() as p:
   476  
   477        mock_estimate.return_value = None
   478        mock_query_job.side_effect = exception_type(error_message)
   479  
   480        _ = p | ReadFromBigQuery(
   481            project='apache-beam-testing',
   482            method=ReadFromBigQuery.Method.EXPORT,
   483            table='project:dataset.table',
   484            gcs_location="gs://temp_location")
   485  
   486      mock_query_job.assert_called()
   487      self.assertIn(error_message, exc.exception.args[0])
   488  
   489  
   490  @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
   491  class TestBigQuerySink(unittest.TestCase):
   492    def test_table_spec_display_data(self):
   493      sink = beam.io.BigQuerySink('dataset.table')
   494      dd = DisplayData.create_from(sink)
   495      expected_items = [
   496          DisplayDataItemMatcher('table', 'dataset.table'),
   497          DisplayDataItemMatcher('validation', False)
   498      ]
   499      hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
   500  
   501    def test_parse_schema_descriptor(self):
   502      sink = beam.io.BigQuerySink('dataset.table', schema='s:STRING, n:INTEGER')
   503      self.assertEqual(sink.table_reference.datasetId, 'dataset')
   504      self.assertEqual(sink.table_reference.tableId, 'table')
   505      result_schema = {
   506          field['name']: field['type']
   507          for field in sink.schema['fields']
   508      }
   509      self.assertEqual({'n': 'INTEGER', 's': 'STRING'}, result_schema)
   510  
   511    def test_project_table_display_data(self):
   512      sinkq = beam.io.BigQuerySink('project:dataset.table')
   513      dd = DisplayData.create_from(sinkq)
   514      expected_items = [
   515          DisplayDataItemMatcher('table', 'project:dataset.table'),
   516          DisplayDataItemMatcher('validation', False)
   517      ]
   518      hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
   519  
   520  
   521  @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
   522  class TestWriteToBigQuery(unittest.TestCase):
   523    def _cleanup_files(self):
   524      if os.path.exists('insert_calls1'):
   525        os.remove('insert_calls1')
   526  
   527      if os.path.exists('insert_calls2'):
   528        os.remove('insert_calls2')
   529  
   530    def setUp(self):
   531      self._cleanup_files()
   532  
   533    def tearDown(self):
   534      self._cleanup_files()
   535  
   536    def test_noop_schema_parsing(self):
   537      expected_table_schema = None
   538      table_schema = beam.io.gcp.bigquery.BigQueryWriteFn.get_table_schema(
   539          schema=None)
   540      self.assertEqual(expected_table_schema, table_schema)
   541  
   542    def test_dict_schema_parsing(self):
   543      schema = {
   544          'fields': [{
   545              'name': 's', 'type': 'STRING', 'mode': 'NULLABLE'
   546          }, {
   547              'name': 'n', 'type': 'INTEGER', 'mode': 'NULLABLE'
   548          },
   549                     {
   550                         'name': 'r',
   551                         'type': 'RECORD',
   552                         'mode': 'NULLABLE',
   553                         'fields': [{
   554                             'name': 'x', 'type': 'INTEGER', 'mode': 'NULLABLE'
   555                         }]
   556                     }]
   557      }
   558      table_schema = beam.io.gcp.bigquery.BigQueryWriteFn.get_table_schema(schema)
   559      string_field = bigquery.TableFieldSchema(
   560          name='s', type='STRING', mode='NULLABLE')
   561      nested_field = bigquery.TableFieldSchema(
   562          name='x', type='INTEGER', mode='NULLABLE')
   563      number_field = bigquery.TableFieldSchema(
   564          name='n', type='INTEGER', mode='NULLABLE')
   565      record_field = bigquery.TableFieldSchema(
   566          name='r', type='RECORD', mode='NULLABLE', fields=[nested_field])
   567      expected_table_schema = bigquery.TableSchema(
   568          fields=[string_field, number_field, record_field])
   569      self.assertEqual(expected_table_schema, table_schema)
   570  
   571    def test_string_schema_parsing(self):
   572      schema = 's:STRING, n:INTEGER'
   573      expected_dict_schema = {
   574          'fields': [{
   575              'name': 's', 'type': 'STRING', 'mode': 'NULLABLE'
   576          }, {
   577              'name': 'n', 'type': 'INTEGER', 'mode': 'NULLABLE'
   578          }]
   579      }
   580      dict_schema = (
   581          beam.io.gcp.bigquery.WriteToBigQuery.get_dict_table_schema(schema))
   582      self.assertEqual(expected_dict_schema, dict_schema)
   583  
   584    def test_table_schema_parsing(self):
   585      string_field = bigquery.TableFieldSchema(
   586          name='s', type='STRING', mode='NULLABLE')
   587      nested_field = bigquery.TableFieldSchema(
   588          name='x', type='INTEGER', mode='NULLABLE')
   589      number_field = bigquery.TableFieldSchema(
   590          name='n', type='INTEGER', mode='NULLABLE')
   591      record_field = bigquery.TableFieldSchema(
   592          name='r', type='RECORD', mode='NULLABLE', fields=[nested_field])
   593      schema = bigquery.TableSchema(
   594          fields=[string_field, number_field, record_field])
   595      expected_dict_schema = {
   596          'fields': [{
   597              'name': 's', 'type': 'STRING', 'mode': 'NULLABLE'
   598          }, {
   599              'name': 'n', 'type': 'INTEGER', 'mode': 'NULLABLE'
   600          },
   601                     {
   602                         'name': 'r',
   603                         'type': 'RECORD',
   604                         'mode': 'NULLABLE',
   605                         'fields': [{
   606                             'name': 'x', 'type': 'INTEGER', 'mode': 'NULLABLE'
   607                         }]
   608                     }]
   609      }
   610      dict_schema = (
   611          beam.io.gcp.bigquery.WriteToBigQuery.get_dict_table_schema(schema))
   612      self.assertEqual(expected_dict_schema, dict_schema)
   613  
   614    def test_table_schema_parsing_end_to_end(self):
   615      string_field = bigquery.TableFieldSchema(
   616          name='s', type='STRING', mode='NULLABLE')
   617      nested_field = bigquery.TableFieldSchema(
   618          name='x', type='INTEGER', mode='NULLABLE')
   619      number_field = bigquery.TableFieldSchema(
   620          name='n', type='INTEGER', mode='NULLABLE')
   621      record_field = bigquery.TableFieldSchema(
   622          name='r', type='RECORD', mode='NULLABLE', fields=[nested_field])
   623      schema = bigquery.TableSchema(
   624          fields=[string_field, number_field, record_field])
   625      table_schema = beam.io.gcp.bigquery.BigQueryWriteFn.get_table_schema(
   626          beam.io.gcp.bigquery.WriteToBigQuery.get_dict_table_schema(schema))
   627      self.assertEqual(table_schema, schema)
   628  
   629    def test_none_schema_parsing(self):
   630      schema = None
   631      expected_dict_schema = None
   632      dict_schema = (
   633          beam.io.gcp.bigquery.WriteToBigQuery.get_dict_table_schema(schema))
   634      self.assertEqual(expected_dict_schema, dict_schema)
   635  
   636    def test_noop_dict_schema_parsing(self):
   637      schema = {
   638          'fields': [{
   639              'name': 's', 'type': 'STRING', 'mode': 'NULLABLE'
   640          }, {
   641              'name': 'n', 'type': 'INTEGER', 'mode': 'NULLABLE'
   642          }]
   643      }
   644      expected_dict_schema = schema
   645      dict_schema = (
   646          beam.io.gcp.bigquery.WriteToBigQuery.get_dict_table_schema(schema))
   647      self.assertEqual(expected_dict_schema, dict_schema)
   648  
   649    def test_schema_autodetect_not_allowed_with_avro_file_loads(self):
   650      with TestPipeline() as p:
   651        pc = p | beam.Impulse()
   652  
   653        with self.assertRaisesRegex(ValueError, '^A schema must be provided'):
   654          _ = (
   655              pc
   656              | 'No Schema' >> beam.io.gcp.bigquery.WriteToBigQuery(
   657                  "dataset.table",
   658                  schema=None,
   659                  temp_file_format=bigquery_tools.FileFormat.AVRO))
   660  
   661        with self.assertRaisesRegex(ValueError,
   662                                    '^Schema auto-detection is not supported'):
   663          _ = (
   664              pc
   665              | 'Schema Autodetected' >> beam.io.gcp.bigquery.WriteToBigQuery(
   666                  "dataset.table",
   667                  schema=beam.io.gcp.bigquery.SCHEMA_AUTODETECT,
   668                  temp_file_format=bigquery_tools.FileFormat.AVRO))
   669  
   670    def test_to_from_runner_api(self):
   671      """Tests that serialization of WriteToBigQuery is correct.
   672  
   673      This is not intended to be a change-detector test. As such, this only tests
   674      the more complicated serialization logic of parameters: ValueProviders,
   675      callables, and side inputs.
   676      """
   677      FULL_OUTPUT_TABLE = 'test_project:output_table'
   678  
   679      p = TestPipeline()
   680  
   681      # Used for testing side input parameters.
   682      table_record_pcv = beam.pvalue.AsDict(
   683          p | "MakeTable" >> beam.Create([('table', FULL_OUTPUT_TABLE)]))
   684  
   685      # Used for testing value provider parameters.
   686      schema = value_provider.StaticValueProvider(str, '"a:str"')
   687  
   688      original = WriteToBigQuery(
   689          table=lambda _,
   690          side_input: side_input['table'],
   691          table_side_inputs=(table_record_pcv, ),
   692          schema=schema)
   693  
   694      # pylint: disable=expression-not-assigned
   695      p | 'MyWriteToBigQuery' >> original
   696  
   697      # Run the pipeline through to generate a pipeline proto from an empty
   698      # context. This ensures that the serialization code ran.
   699      pipeline_proto, context = TestPipeline.from_runner_api(
   700          p.to_runner_api(), p.runner, p.get_pipeline_options()).to_runner_api(
   701              return_context=True)
   702  
   703      # Find the transform from the context.
   704      write_to_bq_id = [
   705          k for k,
   706          v in pipeline_proto.components.transforms.items()
   707          if v.unique_name == 'MyWriteToBigQuery'
   708      ][0]
   709      deserialized_node = context.transforms.get_by_id(write_to_bq_id)
   710      deserialized = deserialized_node.transform
   711      self.assertIsInstance(deserialized, WriteToBigQuery)
   712  
   713      # Test that the serialization of a value provider is correct.
   714      self.assertEqual(original.schema, deserialized.schema)
   715  
   716      # Test that the serialization of a callable is correct.
   717      self.assertEqual(
   718          deserialized._table(None, {'table': FULL_OUTPUT_TABLE}),
   719          FULL_OUTPUT_TABLE)
   720  
   721      # Test that the serialization of a side input is correct.
   722      self.assertEqual(
   723          len(original.table_side_inputs), len(deserialized.table_side_inputs))
   724      original_side_input_data = original.table_side_inputs[0]._side_input_data()
   725      deserialized_side_input_data = deserialized.table_side_inputs[
   726          0]._side_input_data()
   727      self.assertEqual(
   728          original_side_input_data.access_pattern,
   729          deserialized_side_input_data.access_pattern)
   730      self.assertEqual(
   731          original_side_input_data.window_mapping_fn,
   732          deserialized_side_input_data.window_mapping_fn)
   733      self.assertEqual(
   734          original_side_input_data.view_fn, deserialized_side_input_data.view_fn)
   735  
   736    def test_streaming_triggering_frequency_without_auto_sharding(self):
   737      def noop(table, **kwargs):
   738        return []
   739  
   740      client = mock.Mock()
   741      client.insert_rows_json = mock.Mock(side_effect=noop)
   742      opt = StandardOptions()
   743      opt.streaming = True
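            # With STREAMING_INSERTS, triggering_frequency is only valid together
            # with with_auto_sharding=True, so building this pipeline should fail.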
   744      with self.assertRaises(ValueError,
   745                             msg="triggering_frequency with STREAMING_INSERTS " +
   746                             "can only be used with with_auto_sharding=True"):
   747        with beam.Pipeline(runner='BundleBasedDirectRunner', options=opt) as p:
   748          _ = (
   749              p
   750              | beam.Create([{
   751                  'columnA': 'value1'
   752              }])
   753              | WriteToBigQuery(
   754                  table='project:dataset.table',
   755                  schema={
   756                      'fields': [{
   757                          'name': 'columnA', 'type': 'STRING', 'mode': 'NULLABLE'
   758                      }]
   759                  },
   760                  create_disposition='CREATE_NEVER',
   761                  triggering_frequency=1,
   762                  with_auto_sharding=False,
   763                  test_client=client))
   764  
   765    def test_streaming_triggering_frequency_with_auto_sharding(self):
   766      def noop(table, **kwargs):
   767        return []
   768  
   769      client = mock.Mock()
   770      client.insert_rows_json = mock.Mock(side_effect=noop)
   771      opt = StandardOptions()
   772      opt.streaming = True
   773      with beam.Pipeline(runner='BundleBasedDirectRunner', options=opt) as p:
   774        _ = (
   775            p
   776            | beam.Create([{
   777                'columnA': 'value1'
   778            }])
   779            | WriteToBigQuery(
   780                table='project:dataset.table',
   781                schema={
   782                    'fields': [{
   783                        'name': 'columnA', 'type': 'STRING', 'mode': 'NULLABLE'
   784                    }]
   785                },
   786                create_disposition='CREATE_NEVER',
   787                triggering_frequency=1,
   788                with_auto_sharding=True,
   789                test_client=client))
   790  
   791    @parameterized.expand([
   792        param(
   793            exception_type=exceptions.Forbidden if exceptions else None,
   794            error_message='accessDenied'),
   795        param(
   796            exception_type=exceptions.ServiceUnavailable if exceptions else None,
   797            error_message='backendError')
   798    ])
   799    def test_load_job_exception(self, exception_type, error_message):
   800  
   801      with mock.patch.object(bigquery_v2_client.BigqueryV2.JobsService,
   802                       'Insert') as mock_load_job,\
   803        mock.patch('apache_beam.io.gcp.internal.clients'
   804                   '.storage.storage_v1_client.StorageV1.ObjectsService'),\
   805        mock.patch('time.sleep'),\
   806        self.assertRaises(Exception) as exc,\
   807        beam.Pipeline() as p:
   808  
   809        mock_load_job.side_effect = exception_type(error_message)
   810  
   811        _ = (
   812            p
   813            | beam.Create([{
   814                'columnA': 'value1'
   815            }])
   816            | WriteToBigQuery(
   817                table='project:dataset.table',
   818                schema={
   819                    'fields': [{
   820                        'name': 'columnA', 'type': 'STRING', 'mode': 'NULLABLE'
   821                    }]
   822                },
   823                create_disposition='CREATE_NEVER',
   824                custom_gcs_temp_location="gs://temp_location",
   825                method='FILE_LOADS'))
   826  
   827      mock_load_job.assert_called()
   828      self.assertIn(error_message, exc.exception.args[0])
   829  
   830    @parameterized.expand([
   831        param(
   832            exception_type=exceptions.ServiceUnavailable if exceptions else None,
   833            error_message='backendError'),
   834        param(
   835            exception_type=exceptions.InternalServerError if exceptions else None,
   836            error_message='internalError'),
   837    ])
   838    def test_copy_load_job_exception(self, exception_type, error_message):
   839  
   840      from apache_beam.io.gcp import bigquery_file_loads
   841  
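            # Shrink the file-load limits so the three rows below span multiple load
            # jobs, which forces copy jobs into the destination table (the call this
            # test makes fail).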
   842      old_max_file_size = bigquery_file_loads._DEFAULT_MAX_FILE_SIZE
   843      old_max_partition_size = bigquery_file_loads._MAXIMUM_LOAD_SIZE
   844      old_max_files_per_partition = bigquery_file_loads._MAXIMUM_SOURCE_URIS
   845      bigquery_file_loads._DEFAULT_MAX_FILE_SIZE = 15
   846      bigquery_file_loads._MAXIMUM_LOAD_SIZE = 30
   847      bigquery_file_loads._MAXIMUM_SOURCE_URIS = 1
   848  
   849      with mock.patch.object(bigquery_v2_client.BigqueryV2.JobsService,
   850                          'Insert') as mock_insert_copy_job, \
   851        mock.patch.object(BigQueryWrapper,
   852                          'perform_load_job') as mock_load_job, \
   853        mock.patch.object(BigQueryWrapper,
   854                          'wait_for_bq_job'), \
   855        mock.patch('apache_beam.io.gcp.internal.clients'
   856          '.storage.storage_v1_client.StorageV1.ObjectsService'), \
   857        mock.patch('time.sleep'), \
   858        self.assertRaises(Exception) as exc, \
   859        beam.Pipeline() as p:
   860  
   861        mock_insert_copy_job.side_effect = exception_type(error_message)
   862  
   863        dummy_job_reference = beam.io.gcp.internal.clients.bigquery.JobReference()
   864        dummy_job_reference.jobId = 'job_id'
   865        dummy_job_reference.location = 'US'
   866        dummy_job_reference.projectId = 'apache-beam-testing'
   867  
   868        mock_load_job.return_value = dummy_job_reference
   869  
   870        _ = (
   871            p
   872            | beam.Create([{
   873                'columnA': 'value1'
   874            }, {
   875                'columnA': 'value2'
   876            }, {
   877                'columnA': 'value3'
   878            }])
   879            | WriteToBigQuery(
   880                table='project:dataset.table',
   881                schema={
   882                    'fields': [{
   883                        'name': 'columnA', 'type': 'STRING', 'mode': 'NULLABLE'
   884                    }]
   885                },
   886                create_disposition='CREATE_NEVER',
   887                custom_gcs_temp_location="gs://temp_location",
   888                method='FILE_LOADS'))
   889  
   890      bigquery_file_loads._DEFAULT_MAX_FILE_SIZE = old_max_file_size
   891      bigquery_file_loads._MAXIMUM_LOAD_SIZE = old_max_partition_size
   892      bigquery_file_loads._MAXIMUM_SOURCE_URIS = old_max_files_per_partition
   893  
   894      self.assertEqual(4, mock_insert_copy_job.call_count)
   895      self.assertIn(error_message, exc.exception.args[0])
   896  
   897  
   898  @unittest.skipIf(
   899      HttpError is None or exceptions is None,
   900      'GCP dependencies are not installed')
   901  class BigQueryStreamingInsertsErrorHandling(unittest.TestCase):
   902  
   903    # Using https://cloud.google.com/bigquery/docs/error-messages and
   904    # https://googleapis.dev/python/google-api-core/latest/_modules/google
   905    #    /api_core/exceptions.html
   906    # to determine error types and messages to try for retriables.
   907    @parameterized.expand([
   908        param(
   909            exception_type=exceptions.Forbidden if exceptions else None,
   910            error_reason='rateLimitExceeded'),
   911        param(
   912            exception_type=exceptions.DeadlineExceeded if exceptions else None,
   913            error_reason='somereason'),
   914        param(
   915            exception_type=exceptions.ServiceUnavailable if exceptions else None,
   916            error_reason='backendError'),
   917        param(
   918            exception_type=exceptions.InternalServerError if exceptions else None,
   919            error_reason='internalError'),
   920        param(
   921            exception_type=exceptions.InternalServerError if exceptions else None,
   922            error_reason='backendError'),
   923    ])
   924    @mock.patch('time.sleep')
   925    @mock.patch('google.cloud.bigquery.Client.insert_rows_json')
   926    def test_insert_all_retries_if_structured_retriable(
   927        self,
   928        mock_send,
   929        unused_mock_sleep,
   930        exception_type=None,
   931        error_reason=None):
   932      # In this test, a BATCH pipeline will retry the known RETRIABLE errors.
   933      mock_send.side_effect = [
   934          exception_type(
   935              'some retriable exception', errors=[{
   936                  'reason': error_reason
   937              }]),
   938          exception_type(
   939              'some retriable exception', errors=[{
   940                  'reason': error_reason
   941              }]),
   942          exception_type(
   943              'some retriable exception', errors=[{
   944                  'reason': error_reason
   945              }]),
   946          exception_type(
   947              'some retriable exception', errors=[{
   948                  'reason': error_reason
   949              }]),
   950      ]
   951  
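            # Every mocked attempt fails with a retriable error, so the write keeps
            # retrying until it gives up and the pipeline surfaces the exception.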
   952      with self.assertRaises(Exception) as exc:
   953        with beam.Pipeline() as p:
   954          _ = (
   955              p
   956              | beam.Create([{
   957                  'columnA': 'value1'
   958              }])
   959              | WriteToBigQuery(
   960                  table='project:dataset.table',
   961                  schema={
   962                      'fields': [{
   963                          'name': 'columnA', 'type': 'STRING', 'mode': 'NULLABLE'
   964                      }]
   965                  },
   966                  create_disposition='CREATE_NEVER',
   967                  method='STREAMING_INSERTS'))
   968      self.assertEqual(4, mock_send.call_count)
   969      self.assertIn('some retriable exception', exc.exception.args[0])
   970  
   971    # Using https://googleapis.dev/python/google-api-core/latest/_modules/google
   972    #   /api_core/exceptions.html
   973    # to determine error types and messages to try for retriables.
   974    @parameterized.expand([
   975        param(
   976            exception_type=requests.exceptions.ConnectionError,
   977            error_message='some connection error'),
   978        param(
   979            exception_type=requests.exceptions.Timeout,
   980            error_message='some timeout error'),
   981        param(
   982            exception_type=ConnectionError,
   983            error_message='some py connection error'),
   984        param(
   985            exception_type=exceptions.BadGateway if exceptions else None,
   986            error_message='some badgateway error'),
   987    ])
   988    @mock.patch('time.sleep')
   989    @mock.patch('google.cloud.bigquery.Client.insert_rows_json')
   990    def test_insert_all_retries_if_unstructured_retriable(
   991        self,
   992        mock_send,
   993        unused_mock_sleep,
   994        exception_type=None,
   995        error_message=None):
   996      # In this test, a BATCH pipeline will retry the unknown RETRIABLE errors.
   997      mock_send.side_effect = [
   998          exception_type(error_message),
   999          exception_type(error_message),
  1000          exception_type(error_message),
  1001          exception_type(error_message),
  1002      ]
  1003  
  1004      with self.assertRaises(Exception) as exc:
  1005        with beam.Pipeline() as p:
  1006          _ = (
  1007              p
  1008              | beam.Create([{
  1009                  'columnA': 'value1'
  1010              }])
  1011              | WriteToBigQuery(
  1012                  table='project:dataset.table',
  1013                  schema={
  1014                      'fields': [{
  1015                          'name': 'columnA', 'type': 'STRING', 'mode': 'NULLABLE'
  1016                      }]
  1017                  },
  1018                  create_disposition='CREATE_NEVER',
  1019                  method='STREAMING_INSERTS'))
  1020      self.assertEqual(4, mock_send.call_count)
  1021      self.assertIn(error_message, exc.exception.args[0])
  1022  
  1023    # Using https://googleapis.dev/python/google-api-core/latest/_modules/google
  1024    #   /api_core/exceptions.html
  1025    # to determine error types and messages to try for non-retriable errors.
  1026    @parameterized.expand([
  1027        param(
  1028            exception_type=retry.PermanentException,
  1029            error_args=('nonretriable', )),
  1030        param(
  1031            exception_type=exceptions.BadRequest if exceptions else None,
  1032            error_args=(
  1033                'forbidden morbidden', [{
  1034                    'reason': 'nonretriablereason'
  1035                }])),
  1036        param(
  1037            exception_type=exceptions.BadRequest if exceptions else None,
  1038            error_args=('BAD REQUEST!', [{
  1039                'reason': 'nonretriablereason'
  1040            }])),
  1041        param(
  1042            exception_type=exceptions.MethodNotAllowed if exceptions else None,
  1043            error_args=(
  1044                'method not allowed!', [{
  1045                    'reason': 'nonretriablereason'
  1046                }])),
  1047        param(
  1048            exception_type=exceptions.MethodNotAllowed if exceptions else None,
  1049            error_args=('method not allowed!', 'args')),
  1050        param(
  1051            exception_type=exceptions.Unknown if exceptions else None,
  1052            error_args=('unknown!', 'args')),
  1053        param(
  1054            exception_type=exceptions.Aborted if exceptions else None,
  1055            error_args=('abortet!', 'abort')),
  1056    ])
  1057    @mock.patch('time.sleep')
  1058    @mock.patch('google.cloud.bigquery.Client.insert_rows_json')
  1059    def test_insert_all_unretriable_errors(
  1060        self, mock_send, unused_mock_sleep, exception_type=None, error_args=None):
  1061      # In this test, a BATCH pipeline will not retry NON-RETRIABLE errors.
  1062      mock_send.side_effect = [
  1063          exception_type(*error_args),
  1064          exception_type(*error_args),
  1065          exception_type(*error_args),
  1066          exception_type(*error_args),
  1067      ]
  1068  
  1069      with self.assertRaises(Exception):
  1070        with beam.Pipeline() as p:
  1071          _ = (
  1072              p
  1073              | beam.Create([{
  1074                  'columnA': 'value1'
  1075              }])
  1076              | WriteToBigQuery(
  1077                  table='project:dataset.table',
  1078                  schema={
  1079                      'fields': [{
  1080                          'name': 'columnA', 'type': 'STRING', 'mode': 'NULLABLE'
  1081                      }]
  1082                  },
  1083                  create_disposition='CREATE_NEVER',
  1084                  method='STREAMING_INSERTS'))
  1085      self.assertEqual(1, mock_send.call_count)
  1086  
  1087    # Using https://googleapis.dev/python/google-api-core/latest/_modules/google
  1088    #    /api_core/exceptions.html
  1089    # to determine error types and messages to try.
  1090    @parameterized.expand([
  1091        param(
  1092            exception_type=retry.PermanentException,
  1093            error_args=('nonretriable', )),
  1094        param(
  1095            exception_type=exceptions.BadRequest if exceptions else None,
  1096            error_args=(
  1097                'forbidden morbidden', [{
  1098                    'reason': 'nonretriablereason'
  1099                }])),
  1100        param(
  1101            exception_type=exceptions.BadRequest if exceptions else None,
  1102            error_args=('BAD REQUEST!', [{
  1103                'reason': 'nonretriablereason'
  1104            }])),
  1105        param(
  1106            exception_type=exceptions.MethodNotAllowed if exceptions else None,
  1107            error_args=(
  1108                'method not allowed!', [{
  1109                    'reason': 'nonretriablereason'
  1110                }])),
  1111        param(
  1112            exception_type=exceptions.MethodNotAllowed if exceptions else None,
  1113            error_args=('method not allowed!', 'args')),
  1114        param(
  1115            exception_type=exceptions.Unknown if exceptions else None,
  1116            error_args=('unknown!', 'args')),
  1117        param(
  1118            exception_type=exceptions.Aborted if exceptions else None,
  1119            error_args=('abortet!', 'abort')),
  1120        param(
  1121            exception_type=requests.exceptions.ConnectionError,
  1122            error_args=('some connection error', )),
  1123        param(
  1124            exception_type=requests.exceptions.Timeout,
  1125            error_args=('some timeout error', )),
  1126        param(
  1127            exception_type=ConnectionError,
  1128            error_args=('some py connection error', )),
  1129        param(
  1130            exception_type=exceptions.BadGateway if exceptions else None,
  1131            error_args=('some badgateway error', )),
  1132    ])
  1133    @mock.patch('time.sleep')
  1134    @mock.patch('google.cloud.bigquery.Client.insert_rows_json')
  1135    def test_insert_all_unretriable_errors_streaming(
  1136        self, mock_send, unused_mock_sleep, exception_type=None, error_args=None):
  1137      # In this test, a STREAMING pipeline will retry ALL errors, and never throw
  1138      # an exception.
  1139      mock_send.side_effect = [
  1140          exception_type(*error_args),
  1141          exception_type(*error_args),
  1142          []  # Fails twice, then the third attempt succeeds
  1143      ]
  1144  
  1145      opt = StandardOptions()
  1146      opt.streaming = True
  1147      with beam.Pipeline(runner='BundleBasedDirectRunner', options=opt) as p:
  1148        _ = (
  1149            p
  1150            | beam.Create([{
  1151                'columnA': 'value1'
  1152            }])
  1153            | WriteToBigQuery(
  1154                table='project:dataset.table',
  1155                schema={
  1156                    'fields': [{
  1157                        'name': 'columnA', 'type': 'STRING', 'mode': 'NULLABLE'
  1158                    }]
  1159                },
  1160                create_disposition='CREATE_NEVER',
  1161                method='STREAMING_INSERTS'))
  1162      self.assertEqual(3, mock_send.call_count)
  1163  
  1164  
  1165  @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
  1166  class BigQueryStreamingInsertTransformTests(unittest.TestCase):
  1167    def test_dofn_client_process_performs_batching(self):
  1168      client = mock.Mock()
  1169      client.tables.Get.return_value = bigquery.Table(
  1170          tableReference=bigquery.TableReference(
  1171              projectId='project-id', datasetId='dataset_id', tableId='table_id'))
  1172      client.insert_rows_json.return_value = []
  1173      create_disposition = beam.io.BigQueryDisposition.CREATE_NEVER
  1174      write_disposition = beam.io.BigQueryDisposition.WRITE_APPEND
  1175  
  1176      fn = beam.io.gcp.bigquery.BigQueryWriteFn(
  1177          batch_size=2,
  1178          create_disposition=create_disposition,
  1179          write_disposition=write_disposition,
  1180          kms_key=None,
  1181          test_client=client)
  1182  
  1183      fn.process(('project-id:dataset_id.table_id', {'month': 1}))
  1184  
  1185      # InsertRows not called as batch size is not hit yet
  1186      self.assertFalse(client.insert_rows_json.called)
  1187  
  1188    def test_dofn_client_process_flush_called(self):
  1189      client = mock.Mock()
  1190      client.tables.Get.return_value = bigquery.Table(
  1191          tableReference=bigquery.TableReference(
  1192              projectId='project-id', datasetId='dataset_id', tableId='table_id'))
  1193      client.insert_rows_json.return_value = []
  1194      create_disposition = beam.io.BigQueryDisposition.CREATE_NEVER
  1195      write_disposition = beam.io.BigQueryDisposition.WRITE_APPEND
  1196  
  1197      fn = beam.io.gcp.bigquery.BigQueryWriteFn(
  1198          batch_size=2,
  1199          create_disposition=create_disposition,
  1200          write_disposition=write_disposition,
  1201          kms_key=None,
  1202          test_client=client)
  1203  
  1204      fn.start_bundle()
  1205      fn.process(('project-id:dataset_id.table_id', ({'month': 1}, 'insertid1')))
  1206      fn.process(('project-id:dataset_id.table_id', ({'month': 2}, 'insertid2')))
  1207      # InsertRows called as batch size is hit
  1208      self.assertTrue(client.insert_rows_json.called)
  1209  
  1210    def test_dofn_client_finish_bundle_flush_called(self):
  1211      client = mock.Mock()
  1212      client.tables.Get.return_value = bigquery.Table(
  1213          tableReference=bigquery.TableReference(
  1214              projectId='project-id', datasetId='dataset_id', tableId='table_id'))
  1215      client.insert_rows_json.return_value = []
  1216      create_disposition = beam.io.BigQueryDisposition.CREATE_IF_NEEDED
  1217      write_disposition = beam.io.BigQueryDisposition.WRITE_APPEND
  1218  
  1219      fn = beam.io.gcp.bigquery.BigQueryWriteFn(
  1220          batch_size=2,
  1221          create_disposition=create_disposition,
  1222          write_disposition=write_disposition,
  1223          kms_key=None,
  1224          test_client=client)
  1225  
  1226      fn.start_bundle()
  1227  
  1228      # Destination is a tuple of (destination, schema) to ensure the table is
  1229      # created.
  1230      fn.process(('project-id:dataset_id.table_id', ({'month': 1}, 'insertid3')))
  1231  
  1232      self.assertTrue(client.tables.Get.called)
  1233      # InsertRows not called as batch size is not hit
  1234      self.assertFalse(client.insert_rows_json.called)
  1235  
  1236      fn.finish_bundle()
  1237      # InsertRows called in finish bundle
  1238      self.assertTrue(client.insert_rows_json.called)
  1239  
  1240    def test_dofn_client_no_records(self):
  1241      client = mock.Mock()
  1242      client.tables.Get.return_value = bigquery.Table(
  1243          tableReference=bigquery.TableReference(
  1244              projectId='project-id', datasetId='dataset_id', tableId='table_id'))
  1245      client.tabledata.InsertAll.return_value = \
  1246        bigquery.TableDataInsertAllResponse(insertErrors=[])
  1247      create_disposition = beam.io.BigQueryDisposition.CREATE_NEVER
  1248      write_disposition = beam.io.BigQueryDisposition.WRITE_APPEND
  1249  
  1250      fn = beam.io.gcp.bigquery.BigQueryWriteFn(
  1251          batch_size=2,
  1252          create_disposition=create_disposition,
  1253          write_disposition=write_disposition,
  1254          kms_key=None,
  1255          test_client=client)
  1256  
  1257      fn.start_bundle()
  1258      # InsertRows not called as no rows have been processed
  1259      self.assertFalse(client.insert_rows_json.called)
  1260  
  1261      fn.finish_bundle()
  1262      # InsertRows not called in finish_bundle either, as there are no records
  1263      self.assertFalse(client.insert_rows_json.called)
  1264  
  1265    def test_with_batched_input(self):
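            # With with_batched_input=True each element is already a batch of
            # (row, insert_id) pairs, so a single process() call flushes right
            # away even though fewer rows than batch_size were seen.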
  1266      client = mock.Mock()
  1267      client.tables.Get.return_value = bigquery.Table(
  1268          tableReference=bigquery.TableReference(
  1269              projectId='project-id', datasetId='dataset_id', tableId='table_id'))
  1270      client.insert_rows_json.return_value = []
  1271      create_disposition = beam.io.BigQueryDisposition.CREATE_IF_NEEDED
  1272      write_disposition = beam.io.BigQueryDisposition.WRITE_APPEND
  1273  
  1274      fn = beam.io.gcp.bigquery.BigQueryWriteFn(
  1275          batch_size=10,
  1276          create_disposition=create_disposition,
  1277          write_disposition=write_disposition,
  1278          kms_key=None,
  1279          with_batched_input=True,
  1280          test_client=client)
  1281  
  1282      fn.start_bundle()
  1283  
  1284      # Destination is a tuple of (destination, schema) to ensure the table is
  1285      # created.
  1286      fn.process((
  1287          'project-id:dataset_id.table_id',
  1288          [({
  1289              'month': 1
  1290          }, 'insertid3'), ({
  1291              'month': 2
  1292          }, 'insertid2'), ({
  1293              'month': 3
  1294          }, 'insertid1')]))
  1295  
  1296      # InsertRows called since the input is already batched.
  1297      self.assertTrue(client.insert_rows_json.called)
  1298  
  1299  
  1300  @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
  1301  class PipelineBasedStreamingInsertTest(_TestCaseWithTempDirCleanUp):
  1302    @mock.patch('time.sleep')
  1303    def test_failure_has_same_insert_ids(self, unused_mock_sleep):
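            # The mocked insert_rows_json records the row ids it receives and
            # fails the first call, forcing a retry. Comparing the recordings of
            # both attempts verifies that the same insert ids are reused on
            # retry, which is what allows BigQuery to deduplicate the rows.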
  1304      tempdir = '%s%s' % (self._new_tempdir(), os.sep)
  1305      file_name_1 = os.path.join(tempdir, 'file1')
  1306      file_name_2 = os.path.join(tempdir, 'file2')
  1307  
  1308      def store_callback(table, **kwargs):
  1309        insert_ids = list(kwargs['row_ids'])
  1310        colA_values = [r['columnA'] for r in kwargs['json_rows']]
  1311        json_output = {'insertIds': insert_ids, 'colA_values': colA_values}
  1312        # On the first insert attempt, save the rows to file_name_1 and fail,
  1313        # so the bundle is retried; the retried rows are saved to file_name_2.
  1314        if not os.path.exists(file_name_1):
  1315          with open(file_name_1, 'w') as f:
  1316            json.dump(json_output, f)
  1317          raise RuntimeError()
  1318        else:
  1319          with open(file_name_2, 'w') as f:
  1320            json.dump(json_output, f)
  1321  
  1322        return []
  1323  
  1324      client = mock.Mock()
  1325      client.insert_rows_json = mock.Mock(side_effect=store_callback)
  1326  
  1327      # Using the bundle based direct runner to avoid pickling problems
  1328      # with mocks.
  1329      with beam.Pipeline(runner='BundleBasedDirectRunner') as p:
  1330        _ = (
  1331            p
  1332            | beam.Create([{
  1333                'columnA': 'value1', 'columnB': 'value2'
  1334            }, {
  1335                'columnA': 'value3', 'columnB': 'value4'
  1336            }, {
  1337                'columnA': 'value5', 'columnB': 'value6'
  1338            }])
  1339            | _StreamToBigQuery(
  1340                table_reference='project:dataset.table',
  1341                table_side_inputs=[],
  1342                schema_side_inputs=[],
  1343                schema='anyschema',
  1344                batch_size=None,
  1345                triggering_frequency=None,
  1346                create_disposition='CREATE_NEVER',
  1347                write_disposition=None,
  1348                kms_key=None,
  1349                retry_strategy=None,
  1350                additional_bq_parameters=[],
  1351                ignore_insert_ids=False,
  1352                ignore_unknown_columns=False,
  1353                with_auto_sharding=False,
  1354                test_client=client,
  1355                num_streaming_keys=500))
  1356  
  1357      with open(file_name_1) as f1, open(file_name_2) as f2:
  1358        self.assertEqual(json.load(f1), json.load(f2))
  1359  
  1360    @parameterized.expand([
  1361        param(retry_strategy=RetryStrategy.RETRY_ALWAYS),
  1362        param(retry_strategy=RetryStrategy.RETRY_NEVER),
  1363        param(retry_strategy=RetryStrategy.RETRY_ON_TRANSIENT_ERROR),
  1364    ])
  1365    def test_failure_in_some_rows_does_not_duplicate(self, retry_strategy=None):
  1366      with mock.patch('time.sleep'):
  1367        # In this test we simulate a failure to write out two out of three rows.
  1368        # Row 0 and row 2 fail to be written on the first attempt, and then
  1369        # succeed on the next attempt (if there is one).
  1370        tempdir = '%s%s' % (self._new_tempdir(), os.sep)
  1371        file_name_1 = os.path.join(tempdir, 'file1_partial')
  1372        file_name_2 = os.path.join(tempdir, 'file2_partial')
  1373  
  1374        def store_callback(table, **kwargs):
  1375          insert_ids = list(kwargs['row_ids'])
  1376          colA_values = [r['columnA'] for r in kwargs['json_rows']]
  1377  
  1378          # The first time this function is called, all rows are included
  1379          # so we need to filter out 'failed' rows.
  1380          json_output_1 = {
  1381              'insertIds': [insert_ids[1]], 'colA_values': [colA_values[1]]
  1382          }
  1383          # The second time this function is called, only rows 0 and 2 are
  1384          # included, so we don't need to filter any of them; we write them all out.
  1385          json_output_2 = {'insertIds': insert_ids, 'colA_values': colA_values}
  1386  
  1387          # On the first insert attempt, save the 'successful' row to file_name_1
  1388          # and report rows 0 and 2 as failed; the retried rows go to file_name_2.
  1389          if not os.path.exists(file_name_1):
  1390            with open(file_name_1, 'w') as f:
  1391              json.dump(json_output_1, f)
  1392            return [
  1393                {
  1394                    'index': 0,
  1395                    'errors': [{
  1396                        'reason': 'i dont like this row'
  1397                    }, {
  1398                        'reason': 'its bad'
  1399                    }]
  1400                },
  1401                {
  1402                    'index': 2,
  1403                    'errors': [{
  1404                        'reason': 'i het this row'
  1405                    }, {
  1406                        'reason': 'its no gud'
  1407                    }]
  1408                },
  1409            ]
  1410          else:
  1411            with open(file_name_2, 'w') as f:
  1412              json.dump(json_output_2, f)
  1413            return []
  1414  
  1415        client = mock.Mock()
  1416        client.insert_rows_json = mock.Mock(side_effect=store_callback)
  1417  
  1418        # The rows expected to be inserted, according to the retry strategy.
  1419        if retry_strategy == RetryStrategy.RETRY_NEVER:
  1420          result = ['value3']
  1421        else:  # RETRY_ALWAYS and RETRY_ON_TRANSIENT_ERROR should insert all rows
  1422          result = ['value1', 'value3', 'value5']
  1423  
  1424        # Using the bundle based direct runner to avoid pickling problems
  1425        # with mocks.
  1426        with beam.Pipeline(runner='BundleBasedDirectRunner') as p:
  1427          bq_write_out = (
  1428              p
  1429              | beam.Create([{
  1430                  'columnA': 'value1', 'columnB': 'value2'
  1431              }, {
  1432                  'columnA': 'value3', 'columnB': 'value4'
  1433              }, {
  1434                  'columnA': 'value5', 'columnB': 'value6'
  1435              }])
  1436              | _StreamToBigQuery(
  1437                  table_reference='project:dataset.table',
  1438                  table_side_inputs=[],
  1439                  schema_side_inputs=[],
  1440                  schema='anyschema',
  1441                  batch_size=None,
  1442                  triggering_frequency=None,
  1443                  create_disposition='CREATE_NEVER',
  1444                  write_disposition=None,
  1445                  kms_key=None,
  1446                  retry_strategy=retry_strategy,
  1447                  additional_bq_parameters=[],
  1448                  ignore_insert_ids=False,
  1449                  ignore_unknown_columns=False,
  1450                  with_auto_sharding=False,
  1451                  test_client=client,
  1452                  num_streaming_keys=500))
  1453  
  1454          failed_values = (
  1455              bq_write_out[beam_bq.BigQueryWriteFn.FAILED_ROWS_WITH_ERRORS]
  1456              | beam.Map(lambda x: x[1]['columnA']))
  1457  
  1458          assert_that(
  1459              failed_values,
  1460              equal_to(list({'value1', 'value3', 'value5'}.difference(result))))
  1461  
  1462        data1 = _load_or_default(file_name_1)
  1463        data2 = _load_or_default(file_name_2)
  1464  
  1465        self.assertListEqual(
  1466            sorted(data1.get('colA_values', []) + data2.get('colA_values', [])),
  1467            result)
  1468        self.assertEqual(len(data1['colA_values']), 1)
  1469  
  1470    @parameterized.expand([
  1471        param(retry_strategy=RetryStrategy.RETRY_ALWAYS),
  1472        param(retry_strategy=RetryStrategy.RETRY_NEVER),
  1473        param(retry_strategy=RetryStrategy.RETRY_ON_TRANSIENT_ERROR),
  1474    ])
  1475    def test_permanent_failure_in_some_rows_does_not_duplicate(
  1476        self, retry_strategy=None):
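            # Every attempt to insert row 0 fails with a non-transient error, so
            # it should never be written (and never duplicated) regardless of the
            # retry strategy; the test only checks the failed-rows output.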
  1477      with mock.patch('time.sleep'):
  1478  
  1479        def store_callback(table, **kwargs):
  1480          return [
  1481              {
  1482                  'index': 0,
  1483                  'errors': [{
  1484                      'reason': 'invalid'
  1485                  }, {
  1486                      'reason': 'its bad'
  1487                  }]
  1488              },
  1489          ]
  1490  
  1491        client = mock.Mock()
  1492        client.insert_rows_json = mock.Mock(side_effect=store_callback)
  1493  
  1494        # Row 0 ('value1') always fails with a non-transient error, so it is
  1495        # never inserted, regardless of the retry strategy used.
  1496        inserted_rows = ['value3', 'value5']
  1499  
  1500        # Using the bundle based direct runner to avoid pickling problems
  1501        # with mocks.
  1502        with beam.Pipeline(runner='BundleBasedDirectRunner') as p:
  1503          bq_write_out = (
  1504              p
  1505              | beam.Create([{
  1506                  'columnA': 'value1', 'columnB': 'value2'
  1507              }, {
  1508                  'columnA': 'value3', 'columnB': 'value4'
  1509              }, {
  1510                  'columnA': 'value5', 'columnB': 'value6'
  1511              }])
  1512              | _StreamToBigQuery(
  1513                  table_reference='project:dataset.table',
  1514                  table_side_inputs=[],
  1515                  schema_side_inputs=[],
  1516                  schema='anyschema',
  1517                  batch_size=None,
  1518                  triggering_frequency=None,
  1519                  create_disposition='CREATE_NEVER',
  1520                  write_disposition=None,
  1521                  kms_key=None,
  1522                  retry_strategy=retry_strategy,
  1523                  additional_bq_parameters=[],
  1524                  ignore_insert_ids=False,
  1525                  ignore_unknown_columns=False,
  1526                  with_auto_sharding=False,
  1527                  test_client=client,
  1528                  max_retries=10,
  1529                  num_streaming_keys=500))
  1530  
  1531          failed_values = (
  1532              bq_write_out[beam_bq.BigQueryWriteFn.FAILED_ROWS]
  1533              | beam.Map(lambda x: x[1]['columnA']))
  1534  
  1535          assert_that(
  1536              failed_values,
  1537              equal_to(
  1538                  list({'value1', 'value3', 'value5'}.difference(inserted_rows))))
  1539  
  1540    @parameterized.expand([
  1541        param(with_auto_sharding=False),
  1542        param(with_auto_sharding=True),
  1543    ])
  1544    def test_batch_size_with_auto_sharding(self, with_auto_sharding):
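            # Three input rows with batch_size=2 should be flushed in two insert
            # calls (two rows, then one), with or without auto-sharding. The
            # mocked client stores each batch in its own file so the split can
            # be checked.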
  1545      tempdir = '%s%s' % (self._new_tempdir(), os.sep)
  1546      file_name_1 = os.path.join(tempdir, 'file1')
  1547      file_name_2 = os.path.join(tempdir, 'file2')
  1548  
  1549      def store_callback(table, **kwargs):
  1550        insert_ids = list(kwargs['row_ids'])
  1551        colA_values = [r['columnA'] for r in kwargs['json_rows']]
  1552        json_output = {'insertIds': insert_ids, 'colA_values': colA_values}
  1553        # Expect two batches of rows will be inserted. Store them separately.
  1554        if not os.path.exists(file_name_1):
  1555          with open(file_name_1, 'w') as f:
  1556            json.dump(json_output, f)
  1557        else:
  1558          with open(file_name_2, 'w') as f:
  1559            json.dump(json_output, f)
  1560  
  1561        return []
  1562  
  1563      client = mock.Mock()
  1564      client.insert_rows_json = mock.Mock(side_effect=store_callback)
  1565  
  1566      # Using the bundle based direct runner to avoid pickling problems
  1567      # with mocks.
  1568      with beam.Pipeline(runner='BundleBasedDirectRunner') as p:
  1569        _ = (
  1570            p
  1571            | beam.Create([{
  1572                'columnA': 'value1', 'columnB': 'value2'
  1573            }, {
  1574                'columnA': 'value3', 'columnB': 'value4'
  1575            }, {
  1576                'columnA': 'value5', 'columnB': 'value6'
  1577            }])
  1578            | _StreamToBigQuery(
  1579                table_reference='project:dataset.table',
  1580                table_side_inputs=[],
  1581                schema_side_inputs=[],
  1582                schema='anyschema',
  1583                # Set a batch size such that the input elements will be inserted
  1584                # in 2 batches.
  1585                batch_size=2,
  1586                triggering_frequency=None,
  1587                create_disposition='CREATE_NEVER',
  1588                write_disposition=None,
  1589                kms_key=None,
  1590                retry_strategy=None,
  1591                additional_bq_parameters=[],
  1592                ignore_insert_ids=False,
  1593                ignore_unknown_columns=False,
  1594                with_auto_sharding=with_auto_sharding,
  1595                test_client=client,
  1596                num_streaming_keys=500))
  1597  
  1598      with open(file_name_1) as f1, open(file_name_2) as f2:
  1599        out1 = json.load(f1)
  1600        self.assertEqual(out1['colA_values'], ['value1', 'value3'])
  1601        out2 = json.load(f2)
  1602        self.assertEqual(out2['colA_values'], ['value5'])
  1603  
  1604  
  1605  @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
  1606  class BigQueryStreamingInsertTransformIntegrationTests(unittest.TestCase):
  1607    BIG_QUERY_DATASET_ID = 'python_bq_streaming_inserts_'
  1608  
  1609    def setUp(self):
  1610      self.test_pipeline = TestPipeline(is_integration_test=True)
  1611      self.runner_name = type(self.test_pipeline.runner).__name__
  1612      self.project = self.test_pipeline.get_option('project')
  1613  
  1614      self.dataset_id = '%s%d%s' % (
  1615          self.BIG_QUERY_DATASET_ID, int(time.time()), secrets.token_hex(3))
  1616      self.bigquery_client = bigquery_tools.BigQueryWrapper()
  1617      self.bigquery_client.get_or_create_dataset(self.project, self.dataset_id)
  1618      self.output_table = "%s.output_table" % (self.dataset_id)
  1619      _LOGGER.info(
  1620          "Created dataset %s in project %s", self.dataset_id, self.project)
  1621  
  1622    @pytest.mark.it_postcommit
  1623    def test_value_provider_transform(self):
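            # Passes the destination table (and, for the streaming sink, the
            # schema) as ValueProviders and checks that additional_bq_parameters
            # (time partitioning and clustering) are applied by both the
            # STREAMING_INSERTS and FILE_LOADS sinks.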
  1624      output_table_1 = '%s%s' % (self.output_table, 1)
  1625      output_table_2 = '%s%s' % (self.output_table, 2)
  1626      schema = {
  1627          'fields': [{
  1628              'name': 'name', 'type': 'STRING', 'mode': 'NULLABLE'
  1629          }, {
  1630              'name': 'language', 'type': 'STRING', 'mode': 'NULLABLE'
  1631          }]
  1632      }
  1633  
  1634      additional_bq_parameters = {
  1635          'timePartitioning': {
  1636              'type': 'DAY'
  1637          },
  1638          'clustering': {
  1639              'fields': ['language']
  1640          }
  1641      }
  1642  
  1643      table_ref = bigquery_tools.parse_table_reference(output_table_1)
  1644      table_ref2 = bigquery_tools.parse_table_reference(output_table_2)
  1645  
  1646      pipeline_verifiers = [
  1647          BigQueryTableMatcher(
  1648              project=self.project,
  1649              dataset=table_ref.datasetId,
  1650              table=table_ref.tableId,
  1651              expected_properties=additional_bq_parameters),
  1652          BigQueryTableMatcher(
  1653              project=self.project,
  1654              dataset=table_ref2.datasetId,
  1655              table=table_ref2.tableId,
  1656              expected_properties=additional_bq_parameters),
  1657          BigqueryFullResultMatcher(
  1658              project=self.project,
  1659              query="SELECT name, language FROM %s" % output_table_1,
  1660              data=[(d['name'], d['language']) for d in _ELEMENTS
  1661                    if 'language' in d]),
  1662          BigqueryFullResultMatcher(
  1663              project=self.project,
  1664              query="SELECT name, language FROM %s" % output_table_2,
  1665              data=[(d['name'], d['language']) for d in _ELEMENTS
  1666                    if 'language' in d])
  1667      ]
  1668  
  1669      args = self.test_pipeline.get_full_options_as_args(
  1670          on_success_matcher=hc.all_of(*pipeline_verifiers))
  1671  
  1672      with beam.Pipeline(argv=args) as p:
  1673        input = p | beam.Create([row for row in _ELEMENTS if 'language' in row])
  1674  
  1675        _ = (
  1676            input
  1677            | "WriteWithMultipleDests" >> beam.io.gcp.bigquery.WriteToBigQuery(
  1678                table=value_provider.StaticValueProvider(
  1679                    str, '%s:%s' % (self.project, output_table_1)),
  1680                schema=value_provider.StaticValueProvider(dict, schema),
  1681                additional_bq_parameters=additional_bq_parameters,
  1682                method='STREAMING_INSERTS'))
  1683        _ = (
  1684            input
  1685            | "WriteWithMultipleDests2" >> beam.io.gcp.bigquery.WriteToBigQuery(
  1686                table=value_provider.StaticValueProvider(
  1687                    str, '%s:%s' % (self.project, output_table_2)),
  1688                schema=beam.io.gcp.bigquery.SCHEMA_AUTODETECT,
  1689                additional_bq_parameters=lambda _: additional_bq_parameters,
  1690                method='FILE_LOADS'))
  1691  
  1692    @pytest.mark.it_postcommit
  1693    def test_multiple_destinations_transform(self):
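            # Routes each element to one of two tables through callable table and
            # schema arguments backed by side inputs, and checks that the
            # deliberately malformed record ends up in the failed-rows outputs.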
  1694      streaming = self.test_pipeline.options.view_as(StandardOptions).streaming
  1695      if streaming and isinstance(self.test_pipeline.runner, TestDataflowRunner):
  1696        self.skipTest("TestStream is not supported on TestDataflowRunner")
  1697  
  1698      output_table_1 = '%s%s' % (self.output_table, 1)
  1699      output_table_2 = '%s%s' % (self.output_table, 2)
  1700  
  1701      full_output_table_1 = '%s:%s' % (self.project, output_table_1)
  1702      full_output_table_2 = '%s:%s' % (self.project, output_table_2)
  1703  
  1704      schema1 = {
  1705          'fields': [{
  1706              'name': 'name', 'type': 'STRING', 'mode': 'NULLABLE'
  1707          }, {
  1708              'name': 'language', 'type': 'STRING', 'mode': 'NULLABLE'
  1709          }]
  1710      }
  1711      schema2 = {
  1712          'fields': [{
  1713              'name': 'name', 'type': 'STRING', 'mode': 'NULLABLE'
  1714          }, {
  1715              'name': 'foundation', 'type': 'STRING', 'mode': 'NULLABLE'
  1716          }]
  1717      }
  1718  
  1719      bad_record = {'language': 1, 'manguage': 2}
  1720  
  1721      if streaming:
  1722        pipeline_verifiers = [
  1723            PipelineStateMatcher(PipelineState.RUNNING),
  1724            BigqueryFullResultStreamingMatcher(
  1725                project=self.project,
  1726                query="SELECT name, language FROM %s" % output_table_1,
  1727                data=[(d['name'], d['language']) for d in _ELEMENTS
  1728                      if 'language' in d]),
  1729            BigqueryFullResultStreamingMatcher(
  1730                project=self.project,
  1731                query="SELECT name, foundation FROM %s" % output_table_2,
  1732                data=[(d['name'], d['foundation']) for d in _ELEMENTS
  1733                      if 'foundation' in d])
  1734        ]
  1735      else:
  1736        pipeline_verifiers = [
  1737            BigqueryFullResultMatcher(
  1738                project=self.project,
  1739                query="SELECT name, language FROM %s" % output_table_1,
  1740                data=[(d['name'], d['language']) for d in _ELEMENTS
  1741                      if 'language' in d]),
  1742            BigqueryFullResultMatcher(
  1743                project=self.project,
  1744                query="SELECT name, foundation FROM %s" % output_table_2,
  1745                data=[(d['name'], d['foundation']) for d in _ELEMENTS
  1746                      if 'foundation' in d])
  1747        ]
  1748  
  1749      args = self.test_pipeline.get_full_options_as_args(
  1750          on_success_matcher=hc.all_of(*pipeline_verifiers))
  1751  
  1752      with beam.Pipeline(argv=args) as p:
  1753        if streaming:
  1754          _SIZE = len(_ELEMENTS)
  1755          test_stream = (
  1756              TestStream().advance_watermark_to(0).add_elements(
  1757                  _ELEMENTS[:_SIZE // 2]).advance_watermark_to(100).add_elements(
  1758                      _ELEMENTS[_SIZE // 2:]).advance_watermark_to_infinity())
  1759          input = p | test_stream
  1760        else:
  1761          input = p | beam.Create(_ELEMENTS)
  1762  
  1763        schema_table_pcv = beam.pvalue.AsDict(
  1764            p | "MakeSchemas" >> beam.Create([(full_output_table_1, schema1),
  1765                                              (full_output_table_2, schema2)]))
  1766  
  1767        table_record_pcv = beam.pvalue.AsDict(
  1768            p | "MakeTables" >> beam.Create([('table1', full_output_table_1),
  1769                                             ('table2', full_output_table_2)]))
  1770  
  1771        input2 = p | "Broken record" >> beam.Create([bad_record])
  1772  
  1773        input = (input, input2) | beam.Flatten()
  1774  
  1775        r = (
  1776            input
  1777            | "WriteWithMultipleDests" >> beam.io.gcp.bigquery.WriteToBigQuery(
  1778                table=lambda x,
  1779                tables:
  1780                (tables['table1'] if 'language' in x else tables['table2']),
  1781                table_side_inputs=(table_record_pcv, ),
  1782                schema=lambda dest,
  1783                table_map: table_map.get(dest, None),
  1784                schema_side_inputs=(schema_table_pcv, ),
  1785                insert_retry_strategy=RetryStrategy.RETRY_ON_TRANSIENT_ERROR,
  1786                method='STREAMING_INSERTS'))
  1787  
  1788        assert_that(
  1789            r[beam.io.gcp.bigquery.BigQueryWriteFn.FAILED_ROWS_WITH_ERRORS]
  1790            | beam.Map(lambda elm: (elm[0], elm[1])),
  1791            equal_to([(full_output_table_1, bad_record)]))
  1792  
  1793        assert_that(
  1794            r[beam.io.gcp.bigquery.BigQueryWriteFn.FAILED_ROWS],
  1795            equal_to([(full_output_table_1, bad_record)]),
  1796            label='FailedRowsMatch')
  1797  
  1798    def tearDown(self):
  1799      request = bigquery.BigqueryDatasetsDeleteRequest(
  1800          projectId=self.project, datasetId=self.dataset_id, deleteContents=True)
  1801      try:
  1802        _LOGGER.info(
  1803            "Deleting dataset %s in project %s", self.dataset_id, self.project)
  1804        self.bigquery_client.client.datasets.Delete(request)
  1805      except HttpError:
  1806        _LOGGER.debug(
  1807            'Failed to clean up dataset %s in project %s',
  1808            self.dataset_id,
  1809            self.project)
  1810  
  1811  
  1812  @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
  1813  class PubSubBigQueryIT(unittest.TestCase):
  1814  
  1815    INPUT_TOPIC = 'psit_topic_output'
  1816    INPUT_SUB = 'psit_subscription_input'
  1817  
  1818    BIG_QUERY_DATASET_ID = 'python_pubsub_bq_'
  1819    SCHEMA = {
  1820        'fields': [{
  1821            'name': 'number', 'type': 'INTEGER', 'mode': 'NULLABLE'
  1822        }]
  1823    }
  1824  
  1825    _SIZE = 4
  1826  
  1827    WAIT_UNTIL_FINISH_DURATION = 15 * 60 * 1000
  1828  
  1829    def setUp(self):
  1830      # Set up PubSub
  1831      self.test_pipeline = TestPipeline(is_integration_test=True)
  1832      self.runner_name = type(self.test_pipeline.runner).__name__
  1833      self.project = self.test_pipeline.get_option('project')
  1834      self.uuid = str(uuid.uuid4())
  1835      from google.cloud import pubsub
  1836      self.pub_client = pubsub.PublisherClient()
  1837      self.input_topic = self.pub_client.create_topic(
  1838          name=self.pub_client.topic_path(
  1839              self.project, self.INPUT_TOPIC + self.uuid))
  1840      self.sub_client = pubsub.SubscriberClient()
  1841      self.input_sub = self.sub_client.create_subscription(
  1842          name=self.sub_client.subscription_path(
  1843              self.project, self.INPUT_SUB + self.uuid),
  1844          topic=self.input_topic.name)
  1845  
  1846      # Set up BQ
  1847      self.dataset_ref = utils.create_bq_dataset(
  1848          self.project, self.BIG_QUERY_DATASET_ID)
  1849      self.output_table = "%s.output_table" % (self.dataset_ref.dataset_id)
  1850  
  1851    def tearDown(self):
  1852      # Tear down PubSub
  1853      test_utils.cleanup_topics(self.pub_client, [self.input_topic])
  1854      test_utils.cleanup_subscriptions(self.sub_client, [self.input_sub])
  1855      # Tear down BigQuery
  1856      utils.delete_bq_dataset(self.project, self.dataset_ref)
  1857  
  1858    def _run_pubsub_bq_pipeline(self, method, triggering_frequency=None):
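            # Publishes _SIZE numbers to the test topic as UTF-8 messages, then
            # runs a streaming pipeline that reads them back, wraps each payload
            # in a {'number': ...} dict, and writes them to BigQuery with the
            # requested method. The matchers poll the output table for the rows.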
  1859      numbers = list(range(self._SIZE))
  1860  
  1861      matchers = [
  1862          PipelineStateMatcher(PipelineState.RUNNING),
  1863          BigqueryFullResultStreamingMatcher(
  1864              project=self.project,
  1865              query="SELECT number FROM %s" % self.output_table,
  1866              data=[(i, ) for i in numbers])
  1867      ]
  1868  
  1869      args = self.test_pipeline.get_full_options_as_args(
  1870          on_success_matcher=hc.all_of(*matchers),
  1871          wait_until_finish_duration=self.WAIT_UNTIL_FINISH_DURATION,
  1872          streaming=True,
  1873          allow_unsafe_triggers=True)
  1874  
  1875      def add_schema_info(element):
  1876        yield {'number': element}
  1877  
  1878      messages = [str(i).encode('utf-8') for i in numbers]
  1879      for message in messages:
  1880        self.pub_client.publish(self.input_topic.name, message)
  1881  
  1882      with beam.Pipeline(argv=args) as p:
  1883        rows = (
  1884            p
  1885            | ReadFromPubSub(subscription=self.input_sub.name)
  1886            | beam.ParDo(add_schema_info))
  1887        _ = rows | WriteToBigQuery(
  1888            self.output_table,
  1889            schema=self.SCHEMA,
  1890            method=method,
  1891            triggering_frequency=triggering_frequency)
  1892  
  1893    @pytest.mark.it_postcommit
  1894    def test_streaming_inserts(self):
  1895      self._run_pubsub_bq_pipeline(WriteToBigQuery.Method.STREAMING_INSERTS)
  1896  
  1897    @pytest.mark.it_postcommit
  1898    def test_file_loads(self):
  1899      self._run_pubsub_bq_pipeline(
  1900          WriteToBigQuery.Method.FILE_LOADS, triggering_frequency=20)
  1901  
  1902  
  1903  @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
  1904  class BigQueryFileLoadsIntegrationTests(unittest.TestCase):
  1905    BIG_QUERY_DATASET_ID = 'python_bq_file_loads_'
  1906  
  1907    def setUp(self):
  1908      self.test_pipeline = TestPipeline(is_integration_test=True)
  1909      self.runner_name = type(self.test_pipeline.runner).__name__
  1910      self.project = self.test_pipeline.get_option('project')
  1911  
  1912      self.dataset_id = '%s%d%s' % (
  1913          self.BIG_QUERY_DATASET_ID, int(time.time()), secrets.token_hex(3))
  1914      self.bigquery_client = bigquery_tools.BigQueryWrapper()
  1915      self.bigquery_client.get_or_create_dataset(self.project, self.dataset_id)
  1916      self.output_table = '%s.output_table' % (self.dataset_id)
  1917      self.table_ref = bigquery_tools.parse_table_reference(self.output_table)
  1918      _LOGGER.info(
  1919          'Created dataset %s in project %s', self.dataset_id, self.project)
  1920  
  1921    @pytest.mark.it_postcommit
  1922    def test_avro_file_load(self):
  1923      # Construct elements such that they can be written via Avro but not via
  1924      # JSON. See BEAM-8841.
  1925      from apache_beam.io.gcp import bigquery_file_loads
  1926      old_max_files = bigquery_file_loads._MAXIMUM_SOURCE_URIS
  1927      old_max_file_size = bigquery_file_loads._DEFAULT_MAX_FILE_SIZE
  1928      bigquery_file_loads._MAXIMUM_SOURCE_URIS = 1
  1929      bigquery_file_loads._DEFAULT_MAX_FILE_SIZE = 100
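            # The limits above are shrunk so that even this tiny input is split
            # into multiple files and load jobs; the original values are restored
            # after the pipeline runs.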
  1930      elements = [
  1931          {
  1932              'name': 'Negative infinity',
  1933              'value': -float('inf'),
  1934              'timestamp': datetime.datetime(1970, 1, 1, tzinfo=pytz.utc),
  1935          },
  1936          {
  1937              'name': 'Not a number',
  1938              'value': float('nan'),
  1939              'timestamp': datetime.datetime(2930, 12, 9, tzinfo=pytz.utc),
  1940          },
  1941      ]
  1942  
  1943      schema = beam.io.gcp.bigquery.WriteToBigQuery.get_dict_table_schema(
  1944          bigquery.TableSchema(
  1945              fields=[
  1946                  bigquery.TableFieldSchema(
  1947                      name='name', type='STRING', mode='REQUIRED'),
  1948                  bigquery.TableFieldSchema(
  1949                      name='value', type='FLOAT', mode='REQUIRED'),
  1950                  bigquery.TableFieldSchema(
  1951                      name='timestamp', type='TIMESTAMP', mode='REQUIRED'),
  1952              ]))
  1953  
  1954      pipeline_verifiers = [
  1955          # Some gymnastics here to avoid comparing NaN since NaN is not equal to
  1956          # anything, including itself.
  1957          BigqueryFullResultMatcher(
  1958              project=self.project,
  1959              query="SELECT name, value, timestamp FROM {} WHERE value<0".format(
  1960                  self.output_table),
  1961              data=[(d['name'], d['value'], d['timestamp'])
  1962                    for d in elements[:1]],
  1963          ),
  1964          BigqueryFullResultMatcher(
  1965              project=self.project,
  1966              query="SELECT name, timestamp FROM {}".format(self.output_table),
  1967              data=[(d['name'], d['timestamp']) for d in elements],
  1968          ),
  1969      ]
  1970  
  1971      args = self.test_pipeline.get_full_options_as_args(
  1972          on_success_matcher=hc.all_of(*pipeline_verifiers),
  1973      )
  1974  
  1975      with beam.Pipeline(argv=args) as p:
  1976        input = p | 'CreateInput' >> beam.Create(elements)
  1977        schema_pc = p | 'CreateSchema' >> beam.Create([schema])
  1978  
  1979        _ = (
  1980            input
  1981            | 'WriteToBigQuery' >> beam.io.gcp.bigquery.WriteToBigQuery(
  1982                table='%s:%s' % (self.project, self.output_table),
  1983                schema=lambda _,
  1984                schema: schema,
  1985                schema_side_inputs=(beam.pvalue.AsSingleton(schema_pc), ),
  1986                method='FILE_LOADS',
  1987                temp_file_format=bigquery_tools.FileFormat.AVRO,
  1988            ))
  1989      bigquery_file_loads._MAXIMUM_SOURCE_URIS = old_max_files
  1990      bigquery_file_loads._DEFAULT_MAX_FILE_SIZE = old_max_file_size
  1991  
  1992    def tearDown(self):
  1993      request = bigquery.BigqueryDatasetsDeleteRequest(
  1994          projectId=self.project, datasetId=self.dataset_id, deleteContents=True)
  1995      try:
  1996        _LOGGER.info(
  1997            "Deleting dataset %s in project %s", self.dataset_id, self.project)
  1998        self.bigquery_client.client.datasets.Delete(request)
  1999      except HttpError:
  2000        _LOGGER.debug(
  2001            'Failed to clean up dataset %s in project %s',
  2002            self.dataset_id,
  2003            self.project)
  2004  
  2005  
  2006  if __name__ == '__main__':
  2007    logging.getLogger().setLevel(logging.INFO)
  2008    unittest.main()