github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/gcp/bigquery_json_it_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """
    19  Integration tests for BigQuery's JSON data type
    20  """
    21  
    22  import argparse
    23  import json
    24  import logging
    25  import time
    26  import unittest
    27  from random import randint
    28  
    29  import pytest
    30  
    31  import apache_beam as beam
    32  from apache_beam.io.gcp.bigquery import ReadFromBigQuery
    33  from apache_beam.testing.test_pipeline import TestPipeline
    34  from apache_beam.testing.util import assert_that
    35  from apache_beam.testing.util import equal_to
    36  
    37  _LOGGER = logging.getLogger(__name__)
    38  
    39  PROJECT = 'apache-beam-testing'
    40  DATASET_ID = 'bq_jsontype_test_nodelete'
    41  JSON_TABLE_NAME = 'json_data'
    42  
    43  JSON_TABLE_DESTINATION = f"{PROJECT}:{DATASET_ID}.{JSON_TABLE_NAME}"
    44  
    45  STREAMING_TEST_TABLE = "py_streaming_test" \
    46                         f"{time.time_ns() // 1000}_{randint(0,32)}"
    47  
    48  
    49  class BigQueryJsonIT(unittest.TestCase):
    50    @classmethod
    51    def setUpClass(cls):
    52      cls.test_pipeline = TestPipeline(is_integration_test=True)
    53  
    54    def run_test_write(self, options):
    55      json_table_schema = self.generate_schema()
    56      rows_to_write = []
    57      json_data = self.generate_data()
    58      for country_code, country in json_data.items():
    59        cities_to_write = []
    60        for city_name, city in country["cities"].items():
    61          cities_to_write.append({'city_name': city_name, 'city': city})
    62  
    63        rows_to_write.append({
    64            'country_code': country_code,
    65            'country': country["country"],
    66            'stats': country["stats"],
    67            'cities': cities_to_write,
    68            'landmarks': country["landmarks"]
    69        })
    70  
    71      parser = argparse.ArgumentParser()
    72      parser.add_argument('--write_method')
    73      parser.add_argument('--output')
    74  
    75      known_args, pipeline_args = parser.parse_known_args(options)
    76  
    77      with beam.Pipeline(argv=pipeline_args) as p:
    78        _ = (
    79            p
    80            | "Create rows with JSON data" >> beam.Create(rows_to_write)
    81            | "Write to BigQuery" >> beam.io.WriteToBigQuery(
    82                method=known_args.write_method,
    83                table=known_args.output,
    84                schema=json_table_schema,
    85                create_disposition=beam.io.BigQueryDisposition.CREATE_IF_NEEDED,
    86            ))
    87  
    88      extra_opts = {'read_method': "EXPORT", 'input': known_args.output}
    89      read_options = self.test_pipeline.get_full_options_as_args(**extra_opts)
    90      self.read_and_validate_rows(read_options)
    91  
    92    def read_and_validate_rows(self, options):
    93      json_data = self.generate_data()
    94  
    95      class CompareJson(beam.DoFn, unittest.TestCase):
    96        def process(self, row):
    97          country_code = row["country_code"]
    98          expected = json_data[country_code]
    99  
   100          # Test country (JSON String)
   101          country_actual = json.loads(row["country"])
   102          country_expected = json.loads(expected["country"])
   103          self.assertTrue(country_expected == country_actual)
   104  
   105          # Test stats (JSON String in BigQuery struct)
   106          for stat, value in row["stats"].items():
   107            stats_actual = json.loads(value)
   108            stats_expected = json.loads(expected["stats"][stat])
   109            self.assertTrue(stats_expected == stats_actual)
   110  
   111          # Test cities (JSON String in BigQuery array of structs)
   112          for city_row in row["cities"]:
   113            city = city_row["city"]
   114            city_name = city_row["city_name"]
   115  
   116            city_actual = json.loads(city)
   117            city_expected = json.loads(expected["cities"][city_name])
   118            self.assertTrue(city_expected == city_actual)
   119  
   120          # Test landmarks (JSON String in BigQuery array)
   121          landmarks_actual = row["landmarks"]
   122          landmarks_expected = expected["landmarks"]
   123          for i in range(len(landmarks_actual)):
   124            l_actual = json.loads(landmarks_actual[i])
   125            l_expected = json.loads(landmarks_expected[i])
   126            self.assertTrue(l_expected == l_actual)
   127  
   128      parser = argparse.ArgumentParser()
   129      parser.add_argument('--read_method')
   130      parser.add_argument('--query')
   131      parser.add_argument('--input')
   132  
   133      known_args, pipeline_args = parser.parse_known_args(options)
   134  
   135      method = ReadFromBigQuery.Method.DIRECT_READ if \
   136        known_args.read_method == "DIRECT_READ" else \
   137        ReadFromBigQuery.Method.EXPORT
   138  
   139      if known_args.query:
   140        json_query_data = self.generate_query_data()
   141        with beam.Pipeline(argv=pipeline_args) as p:
   142          data = p | 'Read rows' >> ReadFromBigQuery(
   143              query=known_args.query, method=method, use_standard_sql=True)
   144          assert_that(data, equal_to(json_query_data))
   145      else:
   146        with beam.Pipeline(argv=pipeline_args) as p:
   147          _ = p | 'Read rows' >> ReadFromBigQuery(
   148              table=known_args.input,
   149              method=method,
   150          ) | 'Validate rows' >> beam.ParDo(CompareJson())
   151  
   152    @pytest.mark.it_postcommit
   153    def test_direct_read(self):
   154      extra_opts = {
   155          'read_method': "DIRECT_READ",
   156          'input': JSON_TABLE_DESTINATION,
   157      }
   158      options = self.test_pipeline.get_full_options_as_args(**extra_opts)
   159  
   160      self.read_and_validate_rows(options)
   161  
   162    @pytest.mark.it_postcommit
   163    def test_export_read(self):
   164      extra_opts = {
   165          'read_method': "EXPORT",
   166          'input': JSON_TABLE_DESTINATION,
   167      }
   168      options = self.test_pipeline.get_full_options_as_args(**extra_opts)
   169  
   170      self.read_and_validate_rows(options)
   171  
   172    @pytest.mark.it_postcommit
   173    def test_query_read(self):
   174      extra_opts = {
   175          'query': "SELECT "
   176          "country_code, "
   177          "country.past_leaders[2] AS past_leader, "
   178          "stats.gdp_per_capita[\"gdp_per_capita\"] AS gdp, "
   179          "cities[OFFSET(1)].city.name AS city_name, "
   180          "landmarks[OFFSET(1)][\"name\"] AS landmark_name "
   181          f"FROM `{PROJECT}.{DATASET_ID}.{JSON_TABLE_NAME}`",
   182      }
   183      options = self.test_pipeline.get_full_options_as_args(**extra_opts)
   184  
   185      self.read_and_validate_rows(options)
   186  
   187    @pytest.mark.it_postcommit
   188    def test_streaming_inserts(self):
   189      extra_opts = {
   190          'output': f"{PROJECT}:{DATASET_ID}.{STREAMING_TEST_TABLE}",
   191          'write_method': "STREAMING_INSERTS"
   192      }
   193      options = self.test_pipeline.get_full_options_as_args(**extra_opts)
   194  
   195      self.run_test_write(options)
   196  
   197    @pytest.mark.it_postcommit
   198    def test_file_loads_write(self):
   199      extra_opts = {
   200          'output': f"{PROJECT}:{DATASET_ID}.{STREAMING_TEST_TABLE}",
   201          'write_method': "FILE_LOADS"
   202      }
   203      options = self.test_pipeline.get_full_options_as_args(**extra_opts)
   204      with self.assertRaises(ValueError):
   205        self.run_test_write(options)
   206  
   207    # Schema for writing to BigQuery
   208    def generate_schema(self):
   209      from apache_beam.io.gcp.internal.clients.bigquery import TableFieldSchema
   210      from apache_beam.io.gcp.internal.clients.bigquery import TableSchema
   211      json_fields = [
   212          TableFieldSchema(name='country_code', type='STRING', mode='NULLABLE'),
   213          TableFieldSchema(name='country', type='JSON', mode='NULLABLE'),
   214          TableFieldSchema(
   215              name='stats',
   216              type='STRUCT',
   217              mode='NULLABLE',
   218              fields=[
   219                  TableFieldSchema(
   220                      name="gdp_per_capita", type='JSON', mode='NULLABLE'),
   221                  TableFieldSchema(
   222                      name="co2_emissions", type='JSON', mode='NULLABLE'),
   223              ]),
   224          TableFieldSchema(
   225              name='cities',
   226              type='STRUCT',
   227              mode='REPEATED',
   228              fields=[
   229                  TableFieldSchema(
   230                      name="city_name", type='STRING', mode='NULLABLE'),
   231                  TableFieldSchema(name="city", type='JSON', mode='NULLABLE'),
   232              ]),
   233          TableFieldSchema(name='landmarks', type='JSON', mode='REPEATED'),
   234      ]
   235  
   236      schema = TableSchema(fields=json_fields)
   237  
   238      return schema
   239  
   240    # Expected data for query test
   241    def generate_query_data(self):
   242      query_data = [{
   243          'country_code': 'usa',
   244          'past_leader': '\"George W. Bush\"',
   245          'gdp': '58559.675',
   246          'city_name': '\"Los Angeles\"',
   247          'landmark_name': '\"Golden Gate Bridge\"'
   248      },
   249                    {
   250                        'country_code': 'aus',
   251                        'past_leader': '\"Kevin Rudd\"',
   252                        'gdp': '58043.581',
   253                        'city_name': '\"Melbourne\"',
   254                        'landmark_name': '\"Great Barrier Reef\"'
   255                    },
   256                    {
   257                        'country_code': 'special',
   258                        'past_leader': '\"!@#$%^&*()_+\"',
   259                        'gdp': '421.7',
   260                        'city_name': '\"Bikini Bottom\"',
   261                        'landmark_name': "\"Willy Wonka's Factory\""
   262                    }]
   263      return query_data
   264  
   265    def generate_data(self):
   266      # Raw country data
   267      usa = {
   268          "name": "United States of America",
   269          "population": 329484123,
   270          "cities": {
   271              "nyc": {
   272                  "name": "New York City", "state": "NY", "population": 8622357
   273              },
   274              "la": {
   275                  "name": "Los Angeles", "state": "CA", "population": 4085014
   276              },
   277              "chicago": {
   278                  "name": "Chicago", "state": "IL", "population": 2670406
   279              },
   280          },
   281          "past_leaders": [
   282              "Donald Trump", "Barack Obama", "George W. Bush", "Bill Clinton"
   283          ],
   284          "in_northern_hemisphere": True
   285      }
   286  
   287      aus = {
   288          "name": "Australia",
   289          "population": 25687041,
   290          "cities": {
   291              "sydney": {
   292                  "name": "Sydney",
   293                  "state": "New South Wales",
   294                  "population": 5367206
   295              },
   296              "melbourne": {
   297                  "name": "Melbourne", "state": "Victoria", "population": 5159211
   298              },
   299              "brisbane": {
   300                  "name": "Brisbane",
   301                  "state": "Queensland",
   302                  "population": 2560720
   303              }
   304          },
   305          "past_leaders": [
   306              "Malcolm Turnbull",
   307              "Tony Abbot",
   308              "Kevin Rudd",
   309          ],
   310          "in_northern_hemisphere": False
   311      }
   312  
   313      special = {
   314          "name": "newline\n, form\f, tab\t, \"quotes\", "
   315          "\\backslash\\, backspace\b, \u0000_hex_\u0f0f",
   316          "population": -123456789,
   317          "cities": {
   318              "basingse": {
   319                  "name": "Ba Sing Se",
   320                  "state": "The Earth Kingdom",
   321                  "population": 200000
   322              },
   323              "bikinibottom": {
   324                  "name": "Bikini Bottom",
   325                  "state": "The Pacific Ocean",
   326                  "population": 50000
   327              }
   328          },
   329          "past_leaders": [
   330              "1",
   331              "2",
   332              "!@#$%^&*()_+",
   333          ],
   334          "in_northern_hemisphere": True
   335      }
   336  
   337      landmarks = {
   338          "usa_0": {
   339              "name": "Statue of Liberty", "cool rating": None
   340          },
   341          "usa_1": {
   342              "name": "Golden Gate Bridge", "cool rating": "very cool"
   343          },
   344          "usa_2": {
   345              "name": "Grand Canyon", "cool rating": "very very cool"
   346          },
   347          "aus_0": {
   348              "name": "Sydney Opera House", "cool rating": "amazing"
   349          },
   350          "aus_1": {
   351              "name": "Great Barrier Reef", "cool rating": None
   352          },
   353          "special_0": {
   354              "name": "Hogwarts School of WitchCraft and Wizardry",
   355              "cool rating": "magical"
   356          },
   357          "special_1": {
   358              "name": "Willy Wonka's Factory", "cool rating": None
   359          },
   360          "special_2": {
   361              "name": "Rivendell", "cool rating": "precious"
   362          },
   363      }
   364      stats = {
   365          "usa_gdp_per_capita": {
   366              "gdp_per_capita": 58559.675, "currency": "constant 2015 US$"
   367          },
   368          "usa_co2_emissions": {
   369              "co2 emissions": 15.241,
   370              "measurement": "metric tons per capita",
   371              "year": 2018
   372          },
   373          "aus_gdp_per_capita": {
   374              "gdp_per_capita": 58043.581, "currency": "constant 2015 US$"
   375          },
   376          "aus_co2_emissions": {
   377              "co2 emissions": 15.476,
   378              "measurement": "metric tons per capita",
   379              "year": 2018
   380          },
   381          "special_gdp_per_capita": {
   382              "gdp_per_capita": 421.70, "currency": "constant 200 BC gold"
   383          },
   384          "special_co2_emissions": {
   385              "co2 emissions": -10.79,
   386              "measurement": "metric tons per capita",
   387              "year": 2018
   388          }
   389      }
   390  
   391      data = {
   392          "usa": {
   393              "country": json.dumps(usa),
   394              "cities": {
   395                  "nyc": json.dumps(usa["cities"]["nyc"]),
   396                  "la": json.dumps(usa["cities"]["la"]),
   397                  "chicago": json.dumps(usa["cities"]["chicago"])
   398              },
   399              "landmarks": [
   400                  json.dumps(landmarks["usa_0"]),
   401                  json.dumps(landmarks["usa_1"]),
   402                  json.dumps(landmarks["usa_2"])
   403              ],
   404              "stats": {
   405                  "gdp_per_capita": json.dumps(stats["usa_gdp_per_capita"]),
   406                  "co2_emissions": json.dumps(stats["usa_co2_emissions"])
   407              }
   408          },
   409          "aus": {
   410              "country": json.dumps(aus),
   411              "cities": {
   412                  "sydney": json.dumps(aus["cities"]["sydney"]),
   413                  "melbourne": json.dumps(aus["cities"]["melbourne"]),
   414                  "brisbane": json.dumps(aus["cities"]["brisbane"])
   415              },
   416              "landmarks": [
   417                  json.dumps(landmarks["aus_0"]), json.dumps(landmarks["aus_1"])
   418              ],
   419              "stats": {
   420                  "gdp_per_capita": json.dumps(stats["aus_gdp_per_capita"]),
   421                  "co2_emissions": json.dumps(stats["aus_co2_emissions"])
   422              }
   423          },
   424          "special": {
   425              "country": json.dumps(special),
   426              "cities": {
   427                  "basingse": json.dumps(special["cities"]["basingse"]),
   428                  "bikinibottom": json.dumps(special["cities"]["bikinibottom"])
   429              },
   430              "landmarks": [
   431                  json.dumps(landmarks["special_0"]),
   432                  json.dumps(landmarks["special_1"]),
   433                  json.dumps(landmarks["special_2"])
   434              ],
   435              "stats": {
   436                  "gdp_per_capita": json.dumps(stats["special_gdp_per_capita"]),
   437                  "co2_emissions": json.dumps(stats["special_co2_emissions"])
   438              }
   439          }
   440      }
   441      return data
   442  
   443  
   444  if __name__ == '__main__':
   445    logging.getLogger().setLevel(logging.INFO)
   446    unittest.main()