github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/external/xlang_bigqueryio_it_test.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for cross-language BigQuery sources and sinks."""
# pytype: skip-file
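# These tests drive the cross-language (Java) BigQuery Storage Write transform
# from Python. A sketch of one way to run them, assuming a Java expansion
# service is already listening on the chosen port; the port, project, and
# bucket below are placeholders for your own setup:
#
#   EXPANSION_PORT=8097 pytest -m uses_gcp_java_expansion_service \
#       apache_beam/io/external/xlang_bigqueryio_it_test.py \
#       --test-pipeline-options="--project=<gcp-project> --temp_location=gs://<bucket>/tmp"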

import datetime
import logging
import os
import secrets
import time
import unittest
from decimal import Decimal

import pytest
from hamcrest.core import assert_that as hamcrest_assert

import apache_beam as beam
from apache_beam.io.external.generate_sequence import GenerateSequence
from apache_beam.io.gcp.bigquery import StorageWriteToBigQuery
from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper
from apache_beam.io.gcp.tests.bigquery_matcher import BigqueryFullResultMatcher
from apache_beam.io.gcp.tests.bigquery_matcher import BigqueryFullResultStreamingMatcher
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.utils.timestamp import Timestamp

# Protect against environments where bigquery library is not available.
# pylint: disable=wrong-import-order, wrong-import-position

try:
  from apitools.base.py.exceptions import HttpError
except ImportError:
  HttpError = None
# pylint: enable=wrong-import-order, wrong-import-position

_LOGGER = logging.getLogger(__name__)


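# Every test below expands its transforms through an external Java expansion
# service, so the whole suite is skipped when no EXPANSION_PORT is provided.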
@pytest.mark.uses_gcp_java_expansion_service
@unittest.skipUnless(
    os.environ.get('EXPANSION_PORT'),
    "EXPANSION_PORT environment var is not provided.")
class BigQueryXlangStorageWriteIT(unittest.TestCase):
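  """Integration tests for cross-language BigQuery Storage Write API writes."""
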
  BIGQUERY_DATASET = 'python_xlang_storage_write'

  ELEMENTS = [
      # (int, float, numeric, string, bool, bytes, timestamp)
      {
          "int": 1,
          "float": 0.1,
          "numeric": Decimal("1.11"),
          "str": "a",
          "bool": True,
          "bytes": b'a',
          "timestamp": Timestamp(1000, 100)
      },
      {
          "int": 2,
          "float": 0.2,
          "numeric": Decimal("2.22"),
          "str": "b",
          "bool": False,
          "bytes": b'b',
          "timestamp": Timestamp(2000, 200)
      },
      {
          "int": 3,
          "float": 0.3,
          "numeric": Decimal("3.33"),
          "str": "c",
          "bool": True,
          "bytes": b'c',
          "timestamp": Timestamp(3000, 300)
      },
      {
          "int": 4,
          "float": 0.4,
          "numeric": Decimal("4.44"),
          "str": "d",
          "bool": False,
          "bytes": b'd',
          "timestamp": Timestamp(4000, 400)
      }
  ]
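  # Compact "name:TYPE,name:TYPE,..." schema string form accepted by
  # WriteToBigQuery.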
  ALL_TYPES_SCHEMA = (
      "int:INTEGER,float:FLOAT,numeric:NUMERIC,str:STRING,"
      "bool:BOOLEAN,bytes:BYTES,timestamp:TIMESTAMP")

  def setUp(self):
    self.test_pipeline = TestPipeline(is_integration_test=True)
    self.args = self.test_pipeline.get_full_options_as_args()
    self.project = self.test_pipeline.get_option('project')

    self.bigquery_client = BigQueryWrapper()
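    # Create a uniquely named dataset for each test so concurrent runs cannot
    # collide; it is deleted again in tearDown.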
    self.dataset_id = '%s_%s_%s' % (
        self.BIGQUERY_DATASET, str(int(time.time())), secrets.token_hex(3))
    self.bigquery_client.get_or_create_dataset(self.project, self.dataset_id)
    _LOGGER.info(
        "Created dataset %s in project %s", self.dataset_id, self.project)

    _LOGGER.info("expansion port: %s", os.environ.get('EXPANSION_PORT'))
    self.expansion_service = ('localhost:%s' % os.environ.get('EXPANSION_PORT'))

  def tearDown(self):
    try:
      _LOGGER.info(
          "Deleting dataset %s in project %s", self.dataset_id, self.project)
      self.bigquery_client._delete_dataset(
          project_id=self.project,
          dataset_id=self.dataset_id,
          delete_contents=True)
    except HttpError:
      _LOGGER.debug(
          'Failed to clean up dataset %s in project %s',
          self.dataset_id,
          self.project)

  def parse_expected_data(self, expected_elements):
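    """Converts expected row dicts to tuples, normalizing Timestamp values to
    the timezone-aware datetimes returned by the BigQuery matcher."""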
    data = []
    for row in expected_elements:
      values = list(row.values())
      for i, val in enumerate(values):
        if isinstance(val, Timestamp):
          # BigQuery matcher query returns a datetime.datetime object
          values[i] = val.to_utc_datetime().replace(
              tzinfo=datetime.timezone.utc)
      data.append(tuple(values))

    return data

  def run_storage_write_test(
      self, table_name, items, schema, use_at_least_once=False):
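    """Writes `items` to a fresh table via the Storage Write API and verifies
    the table contents against the same elements."""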
    table_id = '{}:{}.{}'.format(self.project, self.dataset_id, table_name)

    bq_matcher = BigqueryFullResultMatcher(
        project=self.project,
        query="SELECT * FROM {}.{}".format(self.dataset_id, table_name),
        data=self.parse_expected_data(items))

    with beam.Pipeline(argv=self.args) as p:
      _ = (
          p
          | beam.Create(items)
          | beam.io.WriteToBigQuery(
              table=table_id,
              method=beam.io.WriteToBigQuery.Method.STORAGE_WRITE_API,
              schema=schema,
              use_at_least_once=use_at_least_once,
              expansion_service=self.expansion_service))
    hamcrest_assert(p, bq_matcher)

  def test_all_types(self):
    table_name = "all_types"
    schema = self.ALL_TYPES_SCHEMA
    self.run_storage_write_test(table_name, self.ELEMENTS, schema)

  def test_with_at_least_once_semantics(self):
    table_name = "with_at_least_once_semantics"
    schema = self.ALL_TYPES_SCHEMA
    self.run_storage_write_test(
        table_name, self.ELEMENTS, schema, use_at_least_once=True)

  def test_nested_records_and_lists(self):
    table_name = "nested_records_and_lists"
    schema = {
        "fields": [{
            "name": "repeated_int", "type": "INTEGER", "mode": "REPEATED"
        },
                   {
                       "name": "struct",
                       "type": "STRUCT",
                       "fields": [{
                           "name": "nested_int", "type": "INTEGER"
                       }, {
                           "name": "nested_str", "type": "STRING"
                       }]
                   },
                   {
                       "name": "repeated_struct",
                       "type": "STRUCT",
                       "mode": "REPEATED",
                       "fields": [{
                           "name": "nested_numeric", "type": "NUMERIC"
                       }, {
                           "name": "nested_bytes", "type": "BYTES"
                       }]
                   }]
    }
    items = [{
        "repeated_int": [1, 2, 3],
        "struct": {
            "nested_int": 1, "nested_str": "a"
        },
        "repeated_struct": [{
            "nested_numeric": Decimal("1.23"), "nested_bytes": b'a'
        },
                            {
                                "nested_numeric": Decimal("3.21"),
                                "nested_bytes": b'aa'
                            }]
    }]

    self.run_storage_write_test(table_name, items, schema)

  def test_write_with_beam_rows(self):
    table = 'write_with_beam_rows'
    table_id = '{}:{}.{}'.format(self.project, self.dataset_id, table)

    row_elements = [
        beam.Row(
            my_int=e['int'],
            my_float=e['float'],
            my_numeric=e['numeric'],
            my_string=e['str'],
            my_bool=e['bool'],
            my_bytes=e['bytes'],
            my_timestamp=e['timestamp']) for e in self.ELEMENTS
    ]

    bq_matcher = BigqueryFullResultMatcher(
        project=self.project,
        query="SELECT * FROM {}.{}".format(self.dataset_id, table),
        data=self.parse_expected_data(self.ELEMENTS))

    with beam.Pipeline(argv=self.args) as p:
      _ = (
          p
          | beam.Create(row_elements)
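          # Beam Rows carry their own schema, so no explicit BigQuery schema
          # is passed here; the sink can derive the table schema from the
          # Row's fields.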
          | StorageWriteToBigQuery(
              table=table_id, expansion_service=self.expansion_service))
    hamcrest_assert(p, bq_matcher)

  def run_streaming(
      self, table_name, auto_sharding=False, use_at_least_once=False):
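    """Streaming variant of the write test: GenerateSequence emits indices
    that are mapped back to ELEMENTS and written via the Storage Write API."""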
    elements = self.ELEMENTS.copy()
    schema = self.ALL_TYPES_SCHEMA
    table_id = '{}:{}.{}'.format(self.project, self.dataset_id, table_name)

    bq_matcher = BigqueryFullResultStreamingMatcher(
        project=self.project,
        query="SELECT * FROM {}.{}".format(self.dataset_id, table_name),
        data=self.parse_expected_data(self.ELEMENTS))

    args = self.test_pipeline.get_full_options_as_args(
        on_success_matcher=bq_matcher,
        streaming=True,
        allow_unsafe_triggers=True)

    with beam.Pipeline(argv=args) as p:
      _ = (
          p
          | GenerateSequence(
              start=0, stop=4, expansion_service=self.expansion_service)
          | beam.Map(lambda x: elements[x])
          | beam.io.WriteToBigQuery(
              table=table_id,
              method=beam.io.WriteToBigQuery.Method.STORAGE_WRITE_API,
              schema=schema,
              with_auto_sharding=auto_sharding,
              use_at_least_once=use_at_least_once,
              expansion_service=self.expansion_service))
    hamcrest_assert(p, bq_matcher)

  def test_streaming(self):
    table = 'streaming'
    self.run_streaming(table_name=table)

  def test_streaming_with_at_least_once(self):
    table = 'streaming_with_at_least_once'
    self.run_streaming(table_name=table, use_at_least_once=True)

  def test_streaming_with_auto_sharding(self):
    table = 'streaming_with_auto_sharding'
    self.run_streaming(table_name=table, auto_sharding=True)


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()