github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/external/xlang_jdbcio_it_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  # pytype: skip-file
    19  
    20  import logging
    21  import time
    22  import typing
    23  import unittest
    24  from decimal import Decimal
    25  from typing import Callable
    26  from typing import Union
    27  
    28  from parameterized import parameterized
    29  
    30  import apache_beam as beam
    31  from apache_beam import coders
    32  from apache_beam.io.jdbc import ReadFromJdbc
    33  from apache_beam.io.jdbc import WriteToJdbc
    34  from apache_beam.options.pipeline_options import StandardOptions
    35  from apache_beam.testing.test_pipeline import TestPipeline
    36  from apache_beam.testing.util import assert_that
    37  from apache_beam.testing.util import equal_to
    38  from apache_beam.typehints.schemas import LogicalType
    39  from apache_beam.typehints.schemas import MillisInstant
    40  from apache_beam.utils.timestamp import Timestamp
    41  
    42  # pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
    43  try:
    44    import sqlalchemy
    45  except ImportError:
    46    sqlalchemy = None
    47  # pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports
    48  
    49  # pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
    50  try:
    51    from testcontainers.postgres import PostgresContainer
    52    from testcontainers.mysql import MySqlContainer
    53  except ImportError:
    54    PostgresContainer = None
    55  # pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports
    56  
    57  ROW_COUNT = 10
    58  
    59  JdbcTestRow = typing.NamedTuple(
    60      "JdbcTestRow",
    61      [("f_id", int), ("f_float", float), ("f_char", str), ("f_varchar", str),
    62       ("f_bytes", bytes), ("f_varbytes", bytes), ("f_timestamp", Timestamp),
    63       ("f_decimal", Decimal)],
    64  )
    65  coders.registry.register_coder(JdbcTestRow, coders.RowCoder)
    66  
    67  
    68  @unittest.skipIf(sqlalchemy is None, 'sql alchemy package is not installed.')
    69  @unittest.skipIf(
    70      PostgresContainer is None, 'testcontainers package is not installed')
    71  @unittest.skipIf(
    72      TestPipeline().get_pipeline_options().view_as(StandardOptions).runner is
    73      None,
    74      'Do not run this test on precommit suites.')
    75  class CrossLanguageJdbcIOTest(unittest.TestCase):
    76    DbData = typing.NamedTuple(
    77        'DbData',
    78        [('container_fn', typing.Any), ('classpath', typing.List[str]),
    79         ('db_string', str), ('connector', str)])
    80    DB_CONTAINER_CLASSPATH_STRING = {
    81        'postgres': DbData(
    82            lambda: PostgresContainer('postgres:12.3'),
    83            None,
    84            'postgresql',
    85            'org.postgresql.Driver'),
    86        'mysql': DbData(
    87            lambda: MySqlContainer(), ['mysql:mysql-connector-java:8.0.28'],
    88            'mysql',
    89            'com.mysql.cj.jdbc.Driver')
    90    }
    91  
    92    def _setUpTestCase(
    93        self,
    94        container_init: Callable[[], Union[PostgresContainer, MySqlContainer]],
    95        db_string: str,
    96        driver: str):
    97      # This method is not the normal setUp from unittest, because the test has
    98      # beem parameterized. The setup then needs extra parameters to initialize.
    99      self.start_db_container(retries=3, container_init=container_init)
   100      self.engine = sqlalchemy.create_engine(self.db.get_connection_url())
   101      self.username = 'test'
   102      self.password = 'test'
   103      self.host = self.db.get_container_host_ip()
   104      self.port = self.db.get_exposed_port(self.db.port_to_expose)
   105      self.database_name = 'test'
   106      self.driver_class_name = driver
   107      self.jdbc_url = 'jdbc:{}://{}:{}/{}'.format(
   108          db_string, self.host, self.port, self.database_name)
   109  
   110    def tearDown(self):
   111      # Sometimes stopping the container raises ReadTimeout. We can ignore it
   112      # here to avoid the test failure.
   113      try:
   114        self.db.stop()
   115      except:  # pylint: disable=bare-except
   116        logging.error('Could not stop the postgreSQL container.')
   117  
   118    @parameterized.expand(['postgres', 'mysql'])
   119    def test_xlang_jdbc_write_read(self, database):
   120      container_init, classpath, db_string, driver = (
   121          CrossLanguageJdbcIOTest.DB_CONTAINER_CLASSPATH_STRING[database])
   122      self._setUpTestCase(container_init, db_string, driver)
   123      table_name = 'jdbc_external_test'
   124      if database == 'postgres':
   125        # postgres does not have BINARY and VARBINARY type, use equvalent.
   126        binary_type = ('BYTEA', 'BYTEA')
   127      else:
   128        binary_type = ('BINARY(10)', 'VARBINARY(10)')
   129  
   130      self.engine.execute(
   131          "CREATE TABLE IF NOT EXISTS {}".format(table_name) + "(f_id INTEGER, " +
   132          "f_float DOUBLE PRECISION, " + "f_char CHAR(10), " +
   133          "f_varchar VARCHAR(10), " + f"f_bytes {binary_type[0]}, " +
   134          f"f_varbytes {binary_type[1]}, " + "f_timestamp TIMESTAMP(3), " +
   135          "f_decimal DECIMAL(10, 2))")
   136      inserted_rows = [
   137          JdbcTestRow(
   138              i,
   139              i + 0.1,
   140              f'Test{i}',
   141              f'Test{i}',
   142              f'Test{i}'.encode(),
   143              f'Test{i}'.encode(),
   144              # In alignment with Java Instant which supports milli precision.
   145              Timestamp.of(seconds=round(time.time(), 3)),
   146              # Test both positive and negative numbers.
   147              Decimal(f'{i-1}.23')) for i in range(ROW_COUNT)
   148      ]
   149      expected_row = []
   150      for row in inserted_rows:
   151        f_char = row.f_char + ' ' * (10 - len(row.f_char))
   152        if database != 'postgres':
   153          # padding expected results
   154          f_bytes = row.f_bytes + b'\0' * (10 - len(row.f_bytes))
   155        else:
   156          f_bytes = row.f_bytes
   157        expected_row.append(
   158            JdbcTestRow(
   159                row.f_id,
   160                row.f_float,
   161                f_char,
   162                row.f_varchar,
   163                f_bytes,
   164                row.f_bytes,
   165                row.f_timestamp,
   166                row.f_decimal))
   167  
   168      with TestPipeline() as p:
   169        p.not_use_test_runner_api = True
   170        _ = (
   171            p
   172            | beam.Create(inserted_rows).with_output_types(JdbcTestRow)
   173            # TODO(https://github.com/apache/beam/issues/20446) Add test with
   174            # overridden write_statement
   175            | 'Write to jdbc' >> WriteToJdbc(
   176                table_name=table_name,
   177                driver_class_name=self.driver_class_name,
   178                jdbc_url=self.jdbc_url,
   179                username=self.username,
   180                password=self.password,
   181                classpath=classpath,
   182            ))
   183  
   184      # Register MillisInstant logical type to override the mapping from Timestamp
   185      # originally handled by MicrosInstant.
   186      LogicalType.register_logical_type(MillisInstant)
   187  
   188      with TestPipeline() as p:
   189        p.not_use_test_runner_api = True
   190        result = (
   191            p
   192            # TODO(https://github.com/apache/beam/issues/20446) Add test with
   193            # overridden read_query
   194            | 'Read from jdbc' >> ReadFromJdbc(
   195                table_name=table_name,
   196                driver_class_name=self.driver_class_name,
   197                jdbc_url=self.jdbc_url,
   198                username=self.username,
   199                password=self.password,
   200                classpath=classpath))
   201  
   202        assert_that(result, equal_to(expected_row))
   203  
   204      # Try the same read using the partitioned reader code path.
   205      # Outputs should be the same.
   206      with TestPipeline() as p:
   207        p.not_use_test_runner_api = True
   208        result = (
   209            p
   210            | 'Partitioned read from jdbc' >> ReadFromJdbc(
   211                table_name=table_name,
   212                partition_column='f_id',
   213                partitions=3,
   214                driver_class_name=self.driver_class_name,
   215                jdbc_url=self.jdbc_url,
   216                username=self.username,
   217                password=self.password,
   218                classpath=classpath))
   219  
   220        assert_that(result, equal_to(expected_row))
   221  
   222    # Creating a container with testcontainers sometimes raises ReadTimeout
   223    # error. In java there are 2 retries set by default.
   224    def start_db_container(self, retries, container_init):
   225      for i in range(retries):
   226        try:
   227          self.db = container_init()
   228          self.db.start()
   229          break
   230        except Exception as e:  # pylint: disable=bare-except
   231          if i == retries - 1:
   232            logging.error('Unable to initialize database container.')
   233            raise e
   234  
   235  
   236  if __name__ == '__main__':
   237    logging.getLogger().setLevel(logging.INFO)
   238    unittest.main()