github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/yaml/readme_test.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Runs the examples from the README.md file."""

import argparse
import logging
import os
import random
import re
import sys
import tempfile
import unittest

import yaml
from yaml.loader import SafeLoader

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.typehints import trivial_inference
from apache_beam.yaml import yaml_transform


class FakeSql(beam.PTransform):
  def __init__(self, query):
    self.query = query

  def default_label(self):
    return 'Sql'

  def expand(self, inputs):
    if isinstance(inputs, beam.PCollection):
      inputs = {'PCOLLECTION': inputs}
    # This only handles the most basic of queries, trying to infer the output
    # schema...
    m = re.match('select (.*?) from', self.query, flags=re.IGNORECASE)
    if not m:
      raise ValueError(self.query)

    def guess_name_and_type(expr):
      expr = expr.strip()
      parts = expr.split()
      if len(parts) >= 2 and parts[-2].lower() == 'as':
        name = parts[-1]
      elif re.match(r'[\w.]+', parts[0]):
        name = parts[0].split('.')[-1]
      else:
        name = f'expr{hash(expr)}'
      if '(' in expr:
        expr = expr.lower()
        if expr.startswith('count'):
          typ = int
        elif expr.startswith('avg'):
          typ = float
        else:
          typ = str
      else:
        part = parts[0]
        if '.' in part:
          table, field = part.split('.')
          typ = inputs[table].element_type.get_type_for(field)
        else:
          typ = next(iter(inputs.values())).element_type.get_type_for(name)
        # Handle optionals more gracefully.
        if (str(typ).startswith('typing.Union[') or
            str(typ).startswith('typing.Optional[')):
          if len(typ.__args__) == 2 and type(None) in typ.__args__:
            typ, = [t for t in typ.__args__ if t is not type(None)]
      return name, typ

    output_schema = [
        guess_name_and_type(expr) for expr in m.group(1).split(',')
    ]
    output_element = beam.Row(**{name: typ() for name, typ in output_schema})
    return next(iter(inputs.values())) | beam.Map(
        lambda _: output_element).with_output_types(
            trivial_inference.instance_to_type(output_element))


class FakeReadFromPubSub(beam.PTransform):
  def __init__(self, topic):
    pass

  def expand(self, p):
    data = p | beam.Create([beam.Row(col1='a', col2=1, col3=0.5)])
    result = data | beam.Map(
        lambda row: beam.transforms.window.TimestampedValue(row, 0))
    # TODO(robertwb): Allow this to be inferred.
    result.element_type = data.element_type
    return result
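
# The sketch below is an editorial addition, not part of the upstream file:
# it illustrates how the fake source above stands in for real IO, emitting a
# single fixed, timestamped row so README pipelines can run hermetically.
# The function name, topic string, and lambda are assumptions for
# illustration only; nothing in the test suite calls this.
def _fake_read_from_pubsub_example():
  with beam.Pipeline() as p:
    rows = p | FakeReadFromPubSub(topic='projects/any/topics/ignored')
    # Each element is the fixed Row(col1='a', col2=1, col3=0.5) at
    # timestamp 0, so schema-aware downstream transforms still type-check.
    _ = rows | beam.Map(lambda row: row.col1)
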
class FakeWriteToPubSub(beam.PTransform):
  def __init__(self, topic):
    pass

  def expand(self, pcoll):
    return pcoll


class SomeAggregation(beam.PTransform):
  def expand(self, pcoll):
    return pcoll | beam.GroupBy(lambda _: 'key').aggregate_field(
        lambda _: 1, sum, 'count')


RENDER_DIR = None
TEST_PROVIDERS = {
    'Sql': FakeSql,
    'ReadFromPubSub': FakeReadFromPubSub,
    'WriteToPubSub': FakeWriteToPubSub,
    'SomeAggregation': SomeAggregation,
}


class TestEnvironment:
  def __enter__(self):
    self.tempdir = tempfile.TemporaryDirectory()
    return self

  def input_file(self, name, content):
    path = os.path.join(self.tempdir.name, name)
    with open(path, 'w') as fout:
      fout.write(content)
    return path

  def input_csv(self):
    return self.input_file('input.csv', 'col1,col2,col3\nabc,1,2.5\n')

  def input_json(self):
    return self.input_file(
        'input.json', '{"col1": "abc", "col2": 1, "col3": 2.5}\n')

  def output_file(self):
    return os.path.join(
        self.tempdir.name, str(random.randint(0, 1000)) + '.out')

  def __exit__(self, *args):
    self.tempdir.cleanup()


def replace_recursive(spec, transform_type, arg_name, arg_value):
  if isinstance(spec, dict):
    spec = {
        key: replace_recursive(value, transform_type, arg_name, arg_value)
        for (key, value) in spec.items()
    }
    if spec.get('type', None) == transform_type:
      spec[arg_name] = arg_value
    return spec
  elif isinstance(spec, list):
    return [
        replace_recursive(value, transform_type, arg_name, arg_value)
        for value in spec
    ]
  else:
    return spec


def create_test_method(test_type, test_name, test_yaml):
  def test(self):
    with TestEnvironment() as env:
      spec = yaml.load(test_yaml, Loader=SafeLoader)
      if test_type == 'PARSE':
        return
      if 'ReadFromCsv' in test_yaml:
        spec = replace_recursive(spec, 'ReadFromCsv', 'path', env.input_csv())
      if 'ReadFromJson' in test_yaml:
        spec = replace_recursive(
            spec, 'ReadFromJson', 'path', env.input_json())
      for write in ['WriteToText', 'WriteToCsv', 'WriteToJson']:
        if write in test_yaml:
          spec = replace_recursive(spec, write, 'path', env.output_file())
      modified_yaml = yaml.dump(spec)
      options = {'pickle_library': 'cloudpickle'}
      if RENDER_DIR is not None:
        options['runner'] = 'apache_beam.runners.render.RenderRunner'
        options['render_output'] = [
            os.path.join(RENDER_DIR, test_name + '.png')
        ]
        options['render_leaf_composite_nodes'] = ['.*']
      p = beam.Pipeline(options=PipelineOptions(**options))
      yaml_transform.expand_pipeline(p, modified_yaml, TEST_PROVIDERS)
      if test_type == 'BUILD':
        return
      p.run().wait_until_finish()

  return test


def parse_test_methods(markdown_lines):
  code_lines = None
  for ix, line in enumerate(markdown_lines):
    line = line.rstrip()
    if line == '```':
      if code_lines is None:
        code_lines = []
        test_type = 'RUN'
        test_name = f'test_line_{ix + 2}'
      else:
        if code_lines and code_lines[0] == 'pipeline:':
          yaml_pipeline = '\n'.join(code_lines)
          if 'providers:' in yaml_pipeline:
            test_type = 'PARSE'
          yield test_name, create_test_method(
              test_type,
              test_name,
              yaml_pipeline)
        code_lines = None
    elif code_lines is not None:
      code_lines.append(line)
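
# Editorial sketch (not in the upstream file) of the two helpers above:
# parse_test_methods yields one (name, callable) pair per fenced README
# block that starts with 'pipeline:', and replace_recursive rewrites an
# argument of a named transform anywhere inside the parsed spec. The
# README lines and the '/tmp/input.csv' path below are hypothetical
# example values; nothing in the test suite calls this.
def _readme_parsing_example():
  readme_lines = [
      'Some prose before the example.',
      '```',
      'pipeline:',
      '  transforms:',
      '    - type: ReadFromCsv',
      '      path: placeholder',
      '```',
  ]
  for name, _ in parse_test_methods(readme_lines):
    print(name)  # test_line_3, derived from the opening fence's index

  spec = yaml.load('\n'.join(readme_lines[2:6]), Loader=SafeLoader)
  spec = replace_recursive(spec, 'ReadFromCsv', 'path', '/tmp/input.csv')
  assert spec['pipeline']['transforms'][0]['path'] == '/tmp/input.csv'
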
def createTestSuite():
  with open(os.path.join(os.path.dirname(__file__), 'README.md')) as readme:
    return type(
        'ReadMeTest', (unittest.TestCase, ), dict(parse_test_methods(readme)))


ReadMeTest = createTestSuite()

if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument('--render_dir', default=None)
  known_args, unknown_args = parser.parse_known_args(sys.argv)
  if known_args.render_dir:
    RENDER_DIR = known_args.render_dir
  logging.getLogger().setLevel(logging.INFO)
  unittest.main(argv=unknown_args)
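
# Example invocations (an editorial note; the render directory path is
# hypothetical). Without --render_dir the README pipelines actually execute;
# with it, each test instead runs under the RenderRunner and writes one .png
# pipeline graph per test into the given directory:
#
#   python -m apache_beam.yaml.readme_test
#   python -m apache_beam.yaml.readme_test --render_dir=/tmp/readme_renders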