github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/external/xlang_snowflakeio_it_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """
    19  Integration test for cross-language snowflake io operations.
    20  
    21  Example run:
    22  
python setup.py nosetests --tests=apache_beam.io.external.xlang_snowflakeio_it_test \
    24  --test-pipeline-options="
    25    --server_name=<SNOWFLAKE_SERVER_NAME>
    26    --username=<SNOWFLAKE_USERNAME>
    27    --password=<SNOWFLAKE_PASSWORD>
    28    --private_key_path=<PATH_TO_PRIVATE_KEY>
    29    --raw_private_key=<RAW_PRIVATE_KEY>
    30    --private_key_passphrase=<PASSWORD_TO_PRIVATE_KEY>
    31    --o_auth_token=<TOKEN>
    32    --staging_bucket_name=<GCP_BUCKET_PATH>
    33    --storage_integration_name=<SNOWFLAKE_STORAGE_INTEGRATION_NAME>
    34    --database=<DATABASE>
    35    --schema=<SCHEMA>
    36    --role=<ROLE>
    37    --warehouse=<WAREHOUSE>
    38    --table=<TABLE_NAME>
    39    --runner=FlinkRunner"
    40  """
    41  
    42  # pytype: skip-file
    43  
    44  import argparse
    45  import binascii
    46  import logging
    47  import unittest
    48  from typing import ByteString
    49  from typing import NamedTuple
    50  
    51  import apache_beam as beam
    52  from apache_beam import coders
    53  from apache_beam.io.snowflake import CreateDisposition
    54  from apache_beam.io.snowflake import ReadFromSnowflake
    55  from apache_beam.io.snowflake import WriteDisposition
    56  from apache_beam.io.snowflake import WriteToSnowflake
    57  from apache_beam.options.pipeline_options import PipelineOptions
    58  from apache_beam.testing.test_pipeline import TestPipeline
    59  from apache_beam.testing.util import assert_that
    60  from apache_beam.testing.util import equal_to
    61  
    62  # pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
    63  try:
    64    from apache_beam.io.gcp.gcsfilesystem import GCSFileSystem
    65  except ImportError:
    66    GCSFileSystem = None
    67  # pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports
    68  
# JSON description of the Snowflake table layout, passed to WriteToSnowflake
# as table_schema so the table can be created when CREATE_IF_NEEDED applies.
SCHEMA_STRING = """
{"schema":[
    {"dataType":{"type":"integer","precision":38,"scale":0},"name":"number_column","nullable":false},
    {"dataType":{"type":"boolean"},"name":"boolean_column","nullable":false},
    {"dataType":{"type":"binary","size":100},"name":"bytes_column","nullable":true}
]}
"""

# Row type round-tripped through Snowflake; field names and types mirror the
# columns declared in SCHEMA_STRING.
# NOTE(review): typing.ByteString is deprecated (removed in Python 3.14);
# plain `bytes` is the likely replacement — confirm RowCoder maps both to the
# same schema field type before changing.
TestRow = NamedTuple(
    'TestRow',
    [
        ('number_column', int),
        ('boolean_column', bool),
        ('bytes_column', ByteString),
    ])

# Register TestRow with the schema-aware RowCoder so its elements can be
# encoded portably (this is a cross-language pipeline).
coders.registry.register_coder(TestRow, coders.RowCoder)

# Number of rows written to Snowflake and expected back on read.
NUM_RECORDS = 100
    88  
    89  
    90  @unittest.skipIf(GCSFileSystem is None, 'GCP dependencies are not installed')
    91  @unittest.skipIf(
    92      TestPipeline().get_option('server_name') is None,
    93      'Snowflake IT test requires external configuration to be run.')
    94  class SnowflakeTest(unittest.TestCase):
    95    def test_snowflake_write_read(self):
    96      self.run_write()
    97      self.run_read()
    98  
    99    def run_write(self):
   100      def user_data_mapper(test_row):
   101        return [
   102            str(test_row.number_column).encode('utf-8'),
   103            str(test_row.boolean_column).encode('utf-8'),
   104            binascii.hexlify(test_row.bytes_column),
   105        ]
   106  
   107      with TestPipeline(options=PipelineOptions(self.pipeline_args)) as p:
   108        p.not_use_test_runner_api = True
   109        _ = (
   110            p
   111            | 'Impulse' >> beam.Impulse()
   112            | 'Generate' >> beam.FlatMap(lambda x: range(NUM_RECORDS))  # pylint: disable=bad-option-value
   113            | 'Map to TestRow' >> beam.Map(
   114                lambda num: TestRow(
   115                    num, num % 2 == 0, b"test" + str(num).encode()))
   116            | WriteToSnowflake(
   117                server_name=self.server_name,
   118                username=self.username,
   119                password=self.password,
   120                o_auth_token=self.o_auth_token,
   121                private_key_path=self.private_key_path,
   122                raw_private_key=self.raw_private_key,
   123                private_key_passphrase=self.private_key_passphrase,
   124                schema=self.schema,
   125                database=self.database,
   126                role=self.role,
   127                warehouse=self.warehouse,
   128                staging_bucket_name=self.staging_bucket_name,
   129                storage_integration_name=self.storage_integration_name,
   130                create_disposition=CreateDisposition.CREATE_IF_NEEDED,
   131                write_disposition=WriteDisposition.TRUNCATE,
   132                table_schema=SCHEMA_STRING,
   133                user_data_mapper=user_data_mapper,
   134                table=self.table,
   135                query=None,
   136                expansion_service=self.expansion_service,
   137            ))
   138  
   139    def run_read(self):
   140      def csv_mapper(bytes_array):
   141        return TestRow(
   142            int(bytes_array[0]),
   143            bytes_array[1] == b'true',
   144            binascii.unhexlify(bytes_array[2]))
   145  
   146      with TestPipeline(options=PipelineOptions(self.pipeline_args)) as p:
   147        result = (
   148            p
   149            | ReadFromSnowflake(
   150                server_name=self.server_name,
   151                username=self.username,
   152                password=self.password,
   153                o_auth_token=self.o_auth_token,
   154                private_key_path=self.private_key_path,
   155                raw_private_key=self.raw_private_key,
   156                private_key_passphrase=self.private_key_passphrase,
   157                schema=self.schema,
   158                database=self.database,
   159                role=self.role,
   160                warehouse=self.warehouse,
   161                staging_bucket_name=self.staging_bucket_name,
   162                storage_integration_name=self.storage_integration_name,
   163                csv_mapper=csv_mapper,
   164                table=self.table,
   165                query=None,
   166                expansion_service=self.expansion_service,
   167            ).with_output_types(TestRow))
   168  
   169        assert_that(
   170            result,
   171            equal_to([
   172                TestRow(i, i % 2 == 0, b'test' + str(i).encode())
   173                for i in range(NUM_RECORDS)
   174            ]))
   175  
   176    @classmethod
   177    def tearDownClass(cls):
   178      GCSFileSystem(pipeline_options=PipelineOptions()) \
   179          .delete([cls.staging_bucket_name])
   180  
   181    @classmethod
   182    def setUpClass(cls):
   183      parser = argparse.ArgumentParser()
   184      parser.add_argument(
   185          '--server_name',
   186          required=True,
   187          help=(
   188              'Snowflake server name of the form '
   189              'https://<SNOWFLAKE_ACCOUNT_NAME>.snowflakecomputing.com'),
   190      )
   191      parser.add_argument(
   192          '--username',
   193          help='Snowflake username',
   194      )
   195      parser.add_argument(
   196          '--password',
   197          help='Snowflake password',
   198      )
   199      parser.add_argument(
   200          '--private_key_path',
   201          help='Path to private key',
   202      )
   203      parser.add_argument(
   204          '--raw_private_key',
   205          help='Raw private key',
   206      )
   207      parser.add_argument(
   208          '--private_key_passphrase',
   209          help='Password to private key',
   210      )
   211      parser.add_argument(
   212          '--o_auth_token',
   213          help='OAuth token',
   214      )
   215      parser.add_argument(
   216          '--staging_bucket_name',
   217          required=True,
   218          help='GCP staging bucket name (must end with backslash)',
   219      )
   220      parser.add_argument(
   221          '--storage_integration_name',
   222          required=True,
   223          help='Snowflake integration name',
   224      )
   225      parser.add_argument(
   226          '--database',
   227          required=True,
   228          help='Snowflake database name',
   229      )
   230      parser.add_argument(
   231          '--schema',
   232          required=True,
   233          help='Snowflake schema name',
   234      )
   235      parser.add_argument(
   236          '--table',
   237          required=True,
   238          help='Snowflake table name',
   239      )
   240      parser.add_argument(
   241          '--role',
   242          help='Snowflake role',
   243      )
   244      parser.add_argument(
   245          '--warehouse',
   246          help='Snowflake warehouse name',
   247      )
   248      parser.add_argument(
   249          '--expansion_service',
   250          help='Url to externally launched expansion service.',
   251      )
   252  
   253      pipeline = TestPipeline()
   254      argv = pipeline.get_full_options_as_args()
   255  
   256      known_args, cls.pipeline_args = parser.parse_known_args(argv)
   257  
   258      cls.server_name = known_args.server_name
   259      cls.database = known_args.database
   260      cls.schema = known_args.schema
   261      cls.table = known_args.table
   262      cls.username = known_args.username
   263      cls.password = known_args.password
   264      cls.private_key_path = known_args.private_key_path
   265      cls.raw_private_key = known_args.raw_private_key
   266      cls.private_key_passphrase = known_args.private_key_passphrase
   267      cls.o_auth_token = known_args.o_auth_token
   268      cls.staging_bucket_name = known_args.staging_bucket_name
   269      cls.storage_integration_name = known_args.storage_integration_name
   270      cls.role = known_args.role
   271      cls.warehouse = known_args.warehouse
   272      cls.expansion_service = known_args.expansion_service
   273  
   274  
if __name__ == '__main__':
  # Surface INFO-level pipeline logs when the test module is run directly.
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()