#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pytype: skip-file

import logging
import time
import typing
import unittest
from decimal import Decimal
from typing import Callable
from typing import Optional
from typing import Union

from parameterized import parameterized

import apache_beam as beam
from apache_beam import coders
from apache_beam.io.jdbc import ReadFromJdbc
from apache_beam.io.jdbc import WriteToJdbc
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.typehints.schemas import LogicalType
from apache_beam.typehints.schemas import MillisInstant
from apache_beam.utils.timestamp import Timestamp

# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
try:
  import sqlalchemy
except ImportError:
  sqlalchemy = None
# pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports

# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
try:
  from testcontainers.postgres import PostgresContainer
  from testcontainers.mysql import MySqlContainer
except ImportError:
  # Bind both names: they appear in the _setUpTestCase annotation, which is
  # evaluated while the class body runs. Leaving MySqlContainer undefined made
  # importing this module fail with NameError instead of letting the skipIf
  # decorator below skip the test.
  PostgresContainer = None
  MySqlContainer = None
# pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports

# Number of rows written to, and read back from, each database.
ROW_COUNT = 10

# Row schema used for the round trip; each field maps to one column of the
# table created by CrossLanguageJdbcIOTest._create_test_table.
JdbcTestRow = typing.NamedTuple(
    "JdbcTestRow",
    [("f_id", int), ("f_float", float), ("f_char", str), ("f_varchar", str),
     ("f_bytes", bytes), ("f_varbytes", bytes), ("f_timestamp", Timestamp),
     ("f_decimal", Decimal)],
)
coders.registry.register_coder(JdbcTestRow, coders.RowCoder)


@unittest.skipIf(sqlalchemy is None, 'sql alchemy package is not installed.')
@unittest.skipIf(
    PostgresContainer is None, 'testcontainers package is not installed')
@unittest.skipIf(
    TestPipeline().get_pipeline_options().view_as(StandardOptions).runner is
    None,
    'Do not run this test on precommit suites.')
class CrossLanguageJdbcIOTest(unittest.TestCase):
  """Round-trips rows through the cross-language WriteToJdbc / ReadFromJdbc.

  Each parameterized case starts a database container, writes ROW_COUNT rows
  with WriteToJdbc, then reads them back with ReadFromJdbc (plain and
  partitioned) and compares the results with the expected rows.
  """
  DbData = typing.NamedTuple(
      'DbData',
      [('container_fn', typing.Any),
       ('classpath', Optional[typing.List[str]]), ('db_string', str),
       ('connector', str)])
  # Per-database settings keyed by the name given to parameterized.expand:
  # container factory, extra classpath entries for the JDBC driver (None passes
  # no extra classpath), JDBC URL scheme, and JDBC driver class name.
  DB_CONTAINER_CLASSPATH_STRING = {
      'postgres': DbData(
          lambda: PostgresContainer('postgres:12.3'),
          None,
          'postgresql',
          'org.postgresql.Driver'),
      'mysql': DbData(
          lambda: MySqlContainer(), ['mysql:mysql-connector-java:8.0.28'],
          'mysql',
          'com.mysql.cj.jdbc.Driver')
  }

  def _setUpTestCase(
      self,
      container_init: Callable[[], Union[PostgresContainer, MySqlContainer]],
      db_string: str,
      driver: str):
    """Starts the database container and derives the JDBC connection settings.

    This is not the normal unittest ``setUp`` because the test has been
    parameterized, so the setup needs extra per-database arguments.

    Args:
      container_init: zero-argument factory returning an unstarted container.
      db_string: JDBC URL scheme, e.g. ``'postgresql'`` or ``'mysql'``.
      driver: fully qualified JDBC driver class name.
    """
    self.start_db_container(retries=3, container_init=container_init)
    self.engine = sqlalchemy.create_engine(self.db.get_connection_url())
    # NOTE(review): these credentials / database name are assumed to match the
    # defaults of the testcontainers images used above — confirm on upgrades.
    self.username = 'test'
    self.password = 'test'
    self.host = self.db.get_container_host_ip()
    self.port = self.db.get_exposed_port(self.db.port_to_expose)
    self.database_name = 'test'
    self.driver_class_name = driver
    self.jdbc_url = 'jdbc:{}://{}:{}/{}'.format(
        db_string, self.host, self.port, self.database_name)

  def tearDown(self):
    """Disposes the SQLAlchemy engine and stops the database container.

    Both are looked up with getattr because the per-test setup runs inside the
    test method and may have failed before creating them.
    """
    try:
      engine = getattr(self, 'engine', None)
      if engine is not None:
        # Release pooled connections before the database goes away.
        engine.dispose()
    finally:
      db = getattr(self, 'db', None)
      if db is not None:
        self._stop_container_quietly(db)

  @staticmethod
  def _stop_container_quietly(container):
    """Stops ``container``, logging instead of raising on failure.

    Stopping a container sometimes raises ReadTimeout; that should neither
    fail an otherwise passing test nor mask an earlier error.
    """
    try:
      container.stop()
    except Exception:  # pylint: disable=broad-except
      logging.exception('Could not stop the database container.')

  def _create_test_table(self, table_name, database):
    """Creates ``table_name`` with one column per JdbcTestRow field."""
    if database == 'postgres':
      # postgres does not have BINARY and VARBINARY types; use the equivalent.
      binary_type = ('BYTEA', 'BYTEA')
    else:
      binary_type = ('BINARY(10)', 'VARBINARY(10)')
    ddl = (
        f'CREATE TABLE IF NOT EXISTS {table_name}('
        'f_id INTEGER, '
        'f_float DOUBLE PRECISION, '
        'f_char CHAR(10), '
        'f_varchar VARCHAR(10), '
        f'f_bytes {binary_type[0]}, '
        f'f_varbytes {binary_type[1]}, '
        'f_timestamp TIMESTAMP(3), '
        'f_decimal DECIMAL(10, 2))')
    # Engine.execute() was removed in SQLAlchemy 2.0; a transactional
    # connection plus text() works on 1.x and 2.x and commits on exit.
    with self.engine.begin() as conn:
      conn.execute(sqlalchemy.text(ddl))

  @staticmethod
  def _expected_row(row, database):
    """Returns ``row`` as the test expects the database to hand it back.

    The CHAR(10) column is expected back right-padded with spaces. Outside
    postgres, the fixed-width BINARY(10) column is expected back right-padded
    with zero bytes; postgres uses BYTEA for both byte columns, so they are
    expected unchanged. f_varbytes is always expected unchanged.
    """
    f_bytes = row.f_bytes
    if database != 'postgres':
      f_bytes = f_bytes.ljust(10, b'\0')
    return row._replace(f_char=row.f_char.ljust(10), f_bytes=f_bytes)

  @parameterized.expand(['postgres', 'mysql'])
  def test_xlang_jdbc_write_read(self, database):
    """Writes rows to ``database`` and checks both read paths return them."""
    container_init, classpath, db_string, driver = (
        CrossLanguageJdbcIOTest.DB_CONTAINER_CLASSPATH_STRING[database])
    self._setUpTestCase(container_init, db_string, driver)
    table_name = 'jdbc_external_test'
    self._create_test_table(table_name, database)

    inserted_rows = [
        JdbcTestRow(
            i,
            i + 0.1,
            f'Test{i}',
            f'Test{i}',
            f'Test{i}'.encode(),
            f'Test{i}'.encode(),
            # In alignment with Java Instant which supports milli precision.
            Timestamp.of(seconds=round(time.time(), 3)),
            # Test both positive and negative numbers.
            Decimal(f'{i-1}.23')) for i in range(ROW_COUNT)
    ]
    expected_row = [
        self._expected_row(row, database) for row in inserted_rows
    ]

    with TestPipeline() as p:
      p.not_use_test_runner_api = True
      _ = (
          p
          | beam.Create(inserted_rows).with_output_types(JdbcTestRow)
          # TODO(https://github.com/apache/beam/issues/20446) Add test with
          # overridden write_statement
          | 'Write to jdbc' >> WriteToJdbc(
              table_name=table_name,
              driver_class_name=self.driver_class_name,
              jdbc_url=self.jdbc_url,
              username=self.username,
              password=self.password,
              classpath=classpath,
          ))

    # Register MillisInstant logical type to override the mapping from Timestamp
    # originally handled by MicrosInstant.
    # NOTE(review): this presumably updates a process-wide registry and is not
    # undone here, so later tests in the same process also see it — confirm.
    LogicalType.register_logical_type(MillisInstant)

    with TestPipeline() as p:
      p.not_use_test_runner_api = True
      result = (
          p
          # TODO(https://github.com/apache/beam/issues/20446) Add test with
          # overridden read_query
          | 'Read from jdbc' >> ReadFromJdbc(
              table_name=table_name,
              driver_class_name=self.driver_class_name,
              jdbc_url=self.jdbc_url,
              username=self.username,
              password=self.password,
              classpath=classpath))

      assert_that(result, equal_to(expected_row))

    # Try the same read using the partitioned reader code path.
    # Outputs should be the same.
    with TestPipeline() as p:
      p.not_use_test_runner_api = True
      result = (
          p
          | 'Partitioned read from jdbc' >> ReadFromJdbc(
              table_name=table_name,
              partition_column='f_id',
              partitions=3,
              driver_class_name=self.driver_class_name,
              jdbc_url=self.jdbc_url,
              username=self.username,
              password=self.password,
              classpath=classpath))

      assert_that(result, equal_to(expected_row))

  def start_db_container(self, retries, container_init):
    """Creates and starts a database container, retrying on failure.

    Creating a container with testcontainers sometimes raises a ReadTimeout
    error, so several attempts are made (in Java there are 2 retries set by
    default). On success the started container is stored in ``self.db``.

    Args:
      retries: total number of attempts; must be at least 1.
      container_init: zero-argument factory returning an unstarted container.

    Raises:
      ValueError: if ``retries`` is less than 1.
      Exception: whatever the last attempt raised, once every attempt failed.
    """
    if retries < 1:
      raise ValueError(f'retries must be at least 1, got {retries}')
    for attempt in range(1, retries + 1):
      container = None
      try:
        container = container_init()
        container.start()
      except Exception:  # pylint: disable=broad-except
        # Do not leak a container whose start() failed part way through.
        if container is not None:
          self._stop_container_quietly(container)
        if attempt == retries:
          logging.error('Unable to initialize database container.')
          raise
        logging.warning(
            'Attempt %d of %d to start the database container failed; '
            'retrying.',
            attempt,
            retries,
            exc_info=True)
      else:
        self.db = container
        return


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()