#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
Integration test for cross-language snowflake io operations.

Example run:

python setup.py nosetests \
--tests=apache_beam.io.external.xlang_snowflakeio_it_test \
--test-pipeline-options="
  --server_name=<SNOWFLAKE_SERVER_NAME>
  --username=<SNOWFLAKE_USERNAME>
  --password=<SNOWFLAKE_PASSWORD>
  --private_key_path=<PATH_TO_PRIVATE_KEY>
  --raw_private_key=<RAW_PRIVATE_KEY>
  --private_key_passphrase=<PASSWORD_TO_PRIVATE_KEY>
  --o_auth_token=<TOKEN>
  --staging_bucket_name=<GCP_BUCKET_PATH>
  --storage_integration_name=<SNOWFLAKE_STORAGE_INTEGRATION_NAME>
  --database=<DATABASE>
  --schema=<SCHEMA>
  --role=<ROLE>
  --warehouse=<WAREHOUSE>
  --table=<TABLE_NAME>
  --runner=FlinkRunner"
"""

# pytype: skip-file

import argparse
import binascii
import logging
import unittest
from typing import ByteString
from typing import NamedTuple

import apache_beam as beam
from apache_beam import coders
from apache_beam.io.snowflake import CreateDisposition
from apache_beam.io.snowflake import ReadFromSnowflake
from apache_beam.io.snowflake import WriteDisposition
from apache_beam.io.snowflake import WriteToSnowflake
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to

# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
try:
  from apache_beam.io.gcp.gcsfilesystem import GCSFileSystem
except ImportError:
  GCSFileSystem = None
# pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports

# Table schema handed to WriteToSnowflake; its three columns mirror the
# fields of TestRow below.
SCHEMA_STRING = """
{"schema":[
    {"dataType":{"type":"integer","precision":38,"scale":0},"name":"number_column","nullable":false},
    {"dataType":{"type":"boolean"},"name":"boolean_column","nullable":false},
    {"dataType":{"type":"binary","size":100},"name":"bytes_column","nullable":true}
]}
"""

# Row type produced by the write pipeline and expected back from the read
# pipeline.
TestRow = NamedTuple(
    'TestRow',
    [
        ('number_column', int),
        ('boolean_column', bool),
        ('bytes_column', ByteString),
    ])

coders.registry.register_coder(TestRow, coders.RowCoder)

# Number of rows written to, and then expected back from, the Snowflake table.
NUM_RECORDS = 100


@unittest.skipIf(GCSFileSystem is None, 'GCP dependencies are not installed')
@unittest.skipIf(
    TestPipeline().get_option('server_name') is None,
    'Snowflake IT test requires external configuration to be run.')
class SnowflakeTest(unittest.TestCase):
  """Round-trip test: write rows with WriteToSnowflake, read them back.

  The test is skipped unless the GCP dependencies are importable and a
  ``--server_name`` test pipeline option is supplied. All connection settings
  are parsed from the test pipeline options in ``setUpClass`` and stored as
  class attributes.
  """
  @classmethod
  def setUpClass(cls):
    """Parses the Snowflake settings out of the test pipeline options.

    Recognised options are stored as class attributes of the same name; every
    option the parser does not recognise is kept in ``cls.pipeline_args`` and
    later forwarded to the pipelines via ``PipelineOptions``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--server_name',
        required=True,
        help=(
            'Snowflake server name of the form '
            'https://<SNOWFLAKE_ACCOUNT_NAME>.snowflakecomputing.com'),
    )
    parser.add_argument(
        '--username',
        help='Snowflake username',
    )
    parser.add_argument(
        '--password',
        help='Snowflake password',
    )
    parser.add_argument(
        '--private_key_path',
        help='Path to private key',
    )
    parser.add_argument(
        '--raw_private_key',
        help='Raw private key',
    )
    parser.add_argument(
        '--private_key_passphrase',
        help='Password to private key',
    )
    parser.add_argument(
        '--o_auth_token',
        help='OAuth token',
    )
    parser.add_argument(
        '--staging_bucket_name',
        required=True,
        help='GCP staging bucket name (must end with a forward slash)',
    )
    parser.add_argument(
        '--storage_integration_name',
        required=True,
        help='Snowflake integration name',
    )
    parser.add_argument(
        '--database',
        required=True,
        help='Snowflake database name',
    )
    parser.add_argument(
        '--schema',
        required=True,
        help='Snowflake schema name',
    )
    parser.add_argument(
        '--table',
        required=True,
        help='Snowflake table name',
    )
    parser.add_argument(
        '--role',
        help='Snowflake role',
    )
    parser.add_argument(
        '--warehouse',
        help='Snowflake warehouse name',
    )
    parser.add_argument(
        '--expansion_service',
        help='Url to externally launched expansion service.',
    )

    pipeline = TestPipeline()
    argv = pipeline.get_full_options_as_args()

    # Unrecognised arguments (runner flags etc.) are kept for the pipelines.
    known_args, cls.pipeline_args = parser.parse_known_args(argv)

    cls.server_name = known_args.server_name
    cls.database = known_args.database
    cls.schema = known_args.schema
    cls.table = known_args.table
    cls.username = known_args.username
    cls.password = known_args.password
    cls.private_key_path = known_args.private_key_path
    cls.raw_private_key = known_args.raw_private_key
    cls.private_key_passphrase = known_args.private_key_passphrase
    cls.o_auth_token = known_args.o_auth_token
    cls.staging_bucket_name = known_args.staging_bucket_name
    cls.storage_integration_name = known_args.storage_integration_name
    cls.role = known_args.role
    cls.warehouse = known_args.warehouse
    cls.expansion_service = known_args.expansion_service

  @classmethod
  def tearDownClass(cls):
    """Deletes the staging bucket path used by the test."""
    GCSFileSystem(pipeline_options=PipelineOptions()) \
        .delete([cls.staging_bucket_name])

  def test_snowflake_write_read(self):
    """Writes NUM_RECORDS rows to Snowflake, then reads and verifies them."""
    # Order matters: the read step verifies what the write step stored.
    self.run_write()
    self.run_read()

  def run_write(self):
    """Runs a pipeline that writes NUM_RECORDS TestRow rows to the table.

    The table is created if needed and truncated before the write.
    """
    def user_data_mapper(test_row):
      # Turns a TestRow into a list of bytes values, one per column; the
      # binary column is hex-encoded.
      return [
          str(test_row.number_column).encode('utf-8'),
          str(test_row.boolean_column).encode('utf-8'),
          binascii.hexlify(test_row.bytes_column),
      ]

    with TestPipeline(options=PipelineOptions(self.pipeline_args)) as p:
      # NOTE(review): presumably disables TestPipeline's runner-API round
      # trip for the cross-language transform - confirm against TestPipeline.
      p.not_use_test_runner_api = True
      _ = (
          p
          | 'Impulse' >> beam.Impulse()
          | 'Generate' >> beam.FlatMap(lambda x: range(NUM_RECORDS))  # pylint: disable=bad-option-value
          | 'Map to TestRow' >> beam.Map(
              lambda num: TestRow(
                  num, num % 2 == 0, b"test" + str(num).encode()))
          | WriteToSnowflake(
              server_name=self.server_name,
              username=self.username,
              password=self.password,
              o_auth_token=self.o_auth_token,
              private_key_path=self.private_key_path,
              raw_private_key=self.raw_private_key,
              private_key_passphrase=self.private_key_passphrase,
              schema=self.schema,
              database=self.database,
              role=self.role,
              warehouse=self.warehouse,
              staging_bucket_name=self.staging_bucket_name,
              storage_integration_name=self.storage_integration_name,
              create_disposition=CreateDisposition.CREATE_IF_NEEDED,
              write_disposition=WriteDisposition.TRUNCATE,
              table_schema=SCHEMA_STRING,
              user_data_mapper=user_data_mapper,
              table=self.table,
              query=None,
              expansion_service=self.expansion_service,
          ))

  def run_read(self):
    """Runs a pipeline that reads the table and checks its contents.

    The rows read must equal the NUM_RECORDS rows generated by ``run_write``.
    """
    def csv_mapper(bytes_array):
      # Inverse of run_write's user_data_mapper: rebuilds a TestRow from the
      # per-column bytes values, hex-decoding the binary column.
      return TestRow(
          int(bytes_array[0]),
          bytes_array[1] == b'true',
          binascii.unhexlify(bytes_array[2]))

    with TestPipeline(options=PipelineOptions(self.pipeline_args)) as p:
      result = (
          p
          | ReadFromSnowflake(
              server_name=self.server_name,
              username=self.username,
              password=self.password,
              o_auth_token=self.o_auth_token,
              private_key_path=self.private_key_path,
              raw_private_key=self.raw_private_key,
              private_key_passphrase=self.private_key_passphrase,
              schema=self.schema,
              database=self.database,
              role=self.role,
              warehouse=self.warehouse,
              staging_bucket_name=self.staging_bucket_name,
              storage_integration_name=self.storage_integration_name,
              csv_mapper=csv_mapper,
              table=self.table,
              query=None,
              expansion_service=self.expansion_service,
          ).with_output_types(TestRow))

      # Must be attached inside the ``with`` block: the pipeline runs when
      # the context manager exits.
      assert_that(
          result,
          equal_to([
              TestRow(i, i % 2 == 0, b'test' + str(i).encode())
              for i in range(NUM_RECORDS)
          ]))


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()