github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/sql_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Tests for transforms that use the SQL Expansion service."""
    19  
    20  # pytype: skip-file
    21  
    22  import logging
    23  import typing
    24  import unittest
    25  
    26  import pytest
    27  
    28  import apache_beam as beam
    29  from apache_beam import coders
    30  from apache_beam.options.pipeline_options import StandardOptions
    31  from apache_beam.testing.test_pipeline import TestPipeline
    32  from apache_beam.testing.util import assert_that
    33  from apache_beam.testing.util import equal_to
    34  from apache_beam.transforms.sql import SqlTransform
    35  
    36  SimpleRow = typing.NamedTuple(
    37      "SimpleRow", [("id", int), ("str", str), ("flt", float)])
    38  coders.registry.register_coder(SimpleRow, coders.RowCoder)
    39  
    40  Enrich = typing.NamedTuple("Enrich", [("id", int), ("metadata", str)])
    41  coders.registry.register_coder(Enrich, coders.RowCoder)
    42  
    43  Shopper = typing.NamedTuple(
    44      "Shopper", [("shopper", str), ("cart", typing.Mapping[str, int])])
    45  coders.registry.register_coder(Shopper, coders.RowCoder)
    46  
    47  
    48  @pytest.mark.xlang_sql_expansion_service
    49  @unittest.skipIf(
    50      TestPipeline().get_pipeline_options().view_as(StandardOptions).runner is
    51      None,
    52      "Must be run with a runner that supports staging java artifacts.")
    53  class SqlTransformTest(unittest.TestCase):
    54    """Tests that exercise the cross-language SqlTransform (implemented in java).
    55  
    56    Note this test must be executed with pipeline options that run jobs on a local
    57    job server. The easiest way to accomplish this is to run the
    58    `validatesCrossLanguageRunnerPythonUsingSql` gradle target for a particular
    59    job server, which will start the runner and job server for you. For example,
    60    `:runners:flink:1.13:job-server:validatesCrossLanguageRunnerPythonUsingSql` to
    61    test on Flink 1.13.
    62  
    63    Alternatively, you may be able to iterate faster if you run the tests directly
    64    using a runner like `FlinkRunner`, which can start a local Flink cluster and
    65    job server for you:
    66      $ pip install -e './sdks/python[gcp,test]'
    67      $ pytest apache_beam/transforms/sql_test.py \\
    68          --test-pipeline-options="--runner=FlinkRunner"
    69    """
    70    _multiprocess_can_split_ = True
    71  
    72    def test_generate_data(self):
    73      with TestPipeline() as p:
    74        out = p | SqlTransform(
    75            """SELECT
    76              CAST(1 AS INT) AS `id`,
    77              CAST('foo' AS VARCHAR) AS `str`,
    78              CAST(3.14  AS DOUBLE) AS `flt`""")
    79        assert_that(out, equal_to([(1, "foo", 3.14)]))
    80  
    81    def test_project(self):
    82      with TestPipeline() as p:
    83        out = (
    84            p | beam.Create([SimpleRow(1, "foo", 3.14)])
    85            | SqlTransform("SELECT `id`, `flt` FROM PCOLLECTION"))
    86        assert_that(out, equal_to([(1, 3.14)]))
    87  
    88    def test_filter(self):
    89      with TestPipeline() as p:
    90        out = (
    91            p
    92            | beam.Create([SimpleRow(1, "foo", 3.14), SimpleRow(2, "bar", 1.414)])
    93            | SqlTransform("SELECT * FROM PCOLLECTION WHERE `str` = 'bar'"))
    94        assert_that(out, equal_to([(2, "bar", 1.414)]))
    95  
    96    def test_agg(self):
    97      with TestPipeline() as p:
    98        out = (
    99            p
   100            | beam.Create([
   101                SimpleRow(1, "foo", 1.),
   102                SimpleRow(1, "foo", 2.),
   103                SimpleRow(1, "foo", 3.),
   104                SimpleRow(2, "bar", 1.414),
   105                SimpleRow(2, "bar", 1.414),
   106                SimpleRow(2, "bar", 1.414),
   107                SimpleRow(2, "bar", 1.414),
   108            ])
   109            | SqlTransform(
   110                """
   111                SELECT
   112                  `str`,
   113                  COUNT(*) AS `count`,
   114                  SUM(`id`) AS `sum`,
   115                  AVG(`flt`) AS `avg`
   116                FROM PCOLLECTION GROUP BY `str`"""))
   117        assert_that(out, equal_to([("foo", 3, 3, 2), ("bar", 4, 8, 1.414)]))
   118  
   119    def test_tagged_join(self):
   120      with TestPipeline() as p:
   121        enrich = (
   122            p | "Create enrich" >> beam.Create(
   123                [Enrich(1, "a"), Enrich(2, "b"), Enrich(26, "z")]))
   124        simple = (
   125            p | "Create simple" >> beam.Create([
   126                SimpleRow(1, "foo", 3.14),
   127                SimpleRow(26, "bar", 1.11),
   128                SimpleRow(1, "baz", 2.34)
   129            ]))
   130        out = ({
   131            'simple': simple, 'enrich': enrich
   132        }
   133               | SqlTransform(
   134                   """
   135                SELECT
   136                  simple.`id` AS `id`,
   137                  enrich.metadata AS metadata
   138                FROM simple
   139                JOIN enrich
   140                ON simple.`id` = enrich.`id`"""))
   141        assert_that(out, equal_to([(1, "a"), (26, "z"), (1, "a")]))
   142  
   143    def test_row(self):
   144      with TestPipeline() as p:
   145        out = (
   146            p
   147            | beam.Create([1, 2, 10])
   148            | beam.Map(lambda x: beam.Row(a=x, b=str(x)))
   149            | SqlTransform("SELECT a*a as s, LENGTH(b) AS c FROM PCOLLECTION"))
   150        assert_that(out, equal_to([(1, 1), (4, 1), (100, 2)]))
   151  
   152    def test_zetasql_generate_data(self):
   153      with TestPipeline() as p:
   154        out = p | SqlTransform(
   155            """SELECT
   156              CAST(1 AS INT64) AS `int`,
   157              CAST('foo' AS STRING) AS `str`,
   158              CAST(3.14  AS FLOAT64) AS `flt`""",
   159            dialect="zetasql")
   160        assert_that(out, equal_to([(1, "foo", 3.14)]))
   161  
   162    def test_windowing_before_sql(self):
   163      with TestPipeline() as p:
   164        out = (
   165            p | beam.Create([
   166                SimpleRow(5, "foo", 1.),
   167                SimpleRow(15, "bar", 2.),
   168                SimpleRow(25, "baz", 3.)
   169            ])
   170            | beam.Map(lambda v: beam.window.TimestampedValue(v, v.id)).
   171            with_output_types(SimpleRow)
   172            | beam.WindowInto(
   173                beam.window.FixedWindows(10)).with_output_types(SimpleRow)
   174            | SqlTransform("SELECT COUNT(*) as `count` FROM PCOLLECTION"))
   175        assert_that(out, equal_to([(1, ), (1, ), (1, )]))
   176  
   177    def test_map(self):
   178      with TestPipeline() as p:
   179        out = (
   180            p
   181            | beam.Create([
   182                Shopper('bob', {
   183                    'bananas': 6, 'cherries': 3
   184                }),
   185                Shopper('alice', {
   186                    'apples': 2, 'bananas': 3
   187                })
   188            ]).with_output_types(Shopper)
   189            | SqlTransform("SELECT * FROM PCOLLECTION WHERE shopper = 'alice'"))
   190        assert_that(out, equal_to([('alice', {'apples': 2, 'bananas': 3})]))
   191  
   192  
   193  if __name__ == "__main__":
   194    logging.getLogger().setLevel(logging.INFO)
   195    unittest.main()