github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/sql_test.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Tests for transforms that use the SQL Expansion service."""

# pytype: skip-file

import logging
import typing
import unittest

import pytest

import apache_beam as beam
from apache_beam import coders
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms.sql import SqlTransform

SimpleRow = typing.NamedTuple(
    "SimpleRow", [("id", int), ("str", str), ("flt", float)])
coders.registry.register_coder(SimpleRow, coders.RowCoder)

Enrich = typing.NamedTuple("Enrich", [("id", int), ("metadata", str)])
coders.registry.register_coder(Enrich, coders.RowCoder)

Shopper = typing.NamedTuple(
    "Shopper", [("shopper", str), ("cart", typing.Mapping[str, int])])
coders.registry.register_coder(Shopper, coders.RowCoder)


@pytest.mark.xlang_sql_expansion_service
@unittest.skipIf(
    TestPipeline().get_pipeline_options().view_as(StandardOptions).runner is
    None,
    "Must be run with a runner that supports staging java artifacts.")
class SqlTransformTest(unittest.TestCase):
  """Tests that exercise the cross-language SqlTransform (implemented in java).

  Note this test must be executed with pipeline options that run jobs on a
  local job server. The easiest way to accomplish this is to run the
  `validatesCrossLanguageRunnerPythonUsingSql` gradle target for a particular
  job server, which will start the runner and job server for you. For example,
  `:runners:flink:1.13:job-server:validatesCrossLanguageRunnerPythonUsingSql`
  to test on Flink 1.13.

  Alternatively, you may be able to iterate faster if you run the tests
  directly using a runner like `FlinkRunner`, which can start a local Flink
  cluster and job server for you:
    $ pip install -e './sdks/python[gcp,test]'
    $ pytest apache_beam/transforms/sql_test.py \\
        --test-pipeline-options="--runner=FlinkRunner"
  """
  _multiprocess_can_split_ = True

  def test_generate_data(self):
    with TestPipeline() as p:
      out = p | SqlTransform(
          """SELECT
            CAST(1 AS INT) AS `id`,
            CAST('foo' AS VARCHAR) AS `str`,
            CAST(3.14 AS DOUBLE) AS `flt`""")
      assert_that(out, equal_to([(1, "foo", 3.14)]))

  def test_project(self):
    with TestPipeline() as p:
      out = (
          p | beam.Create([SimpleRow(1, "foo", 3.14)])
          | SqlTransform("SELECT `id`, `flt` FROM PCOLLECTION"))
      assert_that(out, equal_to([(1, 3.14)]))

  def test_filter(self):
    with TestPipeline() as p:
      out = (
          p
          | beam.Create([SimpleRow(1, "foo", 3.14), SimpleRow(2, "bar", 1.414)])
          | SqlTransform("SELECT * FROM PCOLLECTION WHERE `str` = 'bar'"))
      assert_that(out, equal_to([(2, "bar", 1.414)]))

  def test_agg(self):
    with TestPipeline() as p:
      out = (
          p
          | beam.Create([
              SimpleRow(1, "foo", 1.),
              SimpleRow(1, "foo", 2.),
              SimpleRow(1, "foo", 3.),
              SimpleRow(2, "bar", 1.414),
              SimpleRow(2, "bar", 1.414),
              SimpleRow(2, "bar", 1.414),
              SimpleRow(2, "bar", 1.414),
          ])
          | SqlTransform(
              """
              SELECT
                `str`,
                COUNT(*) AS `count`,
                SUM(`id`) AS `sum`,
                AVG(`flt`) AS `avg`
              FROM PCOLLECTION GROUP BY `str`"""))
      assert_that(out, equal_to([("foo", 3, 3, 2), ("bar", 4, 8, 1.414)]))

  def test_tagged_join(self):
    with TestPipeline() as p:
      enrich = (
          p | "Create enrich" >> beam.Create(
              [Enrich(1, "a"), Enrich(2, "b"), Enrich(26, "z")]))
      simple = (
          p | "Create simple" >> beam.Create([
              SimpleRow(1, "foo", 3.14),
              SimpleRow(26, "bar", 1.11),
              SimpleRow(1, "baz", 2.34)
          ]))
      out = ({
          'simple': simple, 'enrich': enrich
      }
             | SqlTransform(
                 """
                 SELECT
                   simple.`id` AS `id`,
                   enrich.metadata AS metadata
                 FROM simple
                 JOIN enrich
                 ON simple.`id` = enrich.`id`"""))
      assert_that(out, equal_to([(1, "a"), (26, "z"), (1, "a")]))

  def test_row(self):
    with TestPipeline() as p:
      out = (
          p
          | beam.Create([1, 2, 10])
          | beam.Map(lambda x: beam.Row(a=x, b=str(x)))
          | SqlTransform("SELECT a*a as s, LENGTH(b) AS c FROM PCOLLECTION"))
      assert_that(out, equal_to([(1, 1), (4, 1), (100, 2)]))

  def test_zetasql_generate_data(self):
    with TestPipeline() as p:
      out = p | SqlTransform(
          """SELECT
            CAST(1 AS INT64) AS `int`,
            CAST('foo' AS STRING) AS `str`,
            CAST(3.14 AS FLOAT64) AS `flt`""",
          dialect="zetasql")
      assert_that(out, equal_to([(1, "foo", 3.14)]))

  def test_windowing_before_sql(self):
    with TestPipeline() as p:
      out = (
          p | beam.Create([
              SimpleRow(5, "foo", 1.),
              SimpleRow(15, "bar", 2.),
              SimpleRow(25, "baz", 3.)
          ])
          | beam.Map(lambda v: beam.window.TimestampedValue(v, v.id)).
          with_output_types(SimpleRow)
          | beam.WindowInto(
              beam.window.FixedWindows(10)).with_output_types(SimpleRow)
          | SqlTransform("SELECT COUNT(*) as `count` FROM PCOLLECTION"))
      assert_that(out, equal_to([(1, ), (1, ), (1, )]))

  def test_map(self):
    with TestPipeline() as p:
      out = (
          p
          | beam.Create([
              Shopper('bob', {
                  'bananas': 6, 'cherries': 3
              }),
              Shopper('alice', {
                  'apples': 2, 'bananas': 3
              })
          ]).with_output_types(Shopper)
          | SqlTransform("SELECT * FROM PCOLLECTION WHERE shopper = 'alice'"))
      assert_that(out, equal_to([('alice', {'apples': 2, 'bananas': 3})]))


if __name__ == "__main__":
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()