github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/yaml/readme_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Runs the examples from the README.md file."""
    19  
    20  import argparse
    21  import logging
    22  import os
    23  import random
    24  import re
    25  import sys
    26  import tempfile
    27  import unittest
    28  
    29  import yaml
    30  from yaml.loader import SafeLoader
    31  
    32  import apache_beam as beam
    33  from apache_beam.options.pipeline_options import PipelineOptions
    34  from apache_beam.typehints import trivial_inference
    35  from apache_beam.yaml import yaml_transform
    36  
    37  
    38  class FakeSql(beam.PTransform):
    39    def __init__(self, query):
    40      self.query = query
    41  
    42    def default_label(self):
    43      return 'Sql'
    44  
    45    def expand(self, inputs):
    46      if isinstance(inputs, beam.PCollection):
    47        inputs = {'PCOLLECTION': inputs}
    48      # This only handles the most basic of queries, trying to infer the output
    49      # schema...
    50      m = re.match('select (.*?) from', self.query, flags=re.IGNORECASE)
    51      if not m:
    52        raise ValueError(self.query)
    53  
    54      def guess_name_and_type(expr):
    55        expr = expr.strip()
    56        parts = expr.split()
    57        if len(parts) >= 2 and parts[-2].lower() == 'as':
    58          name = parts[-1]
    59        elif re.match(r'[\w.]+', parts[0]):
    60          name = parts[0].split('.')[-1]
    61        else:
    62          name = f'expr{hash(expr)}'
    63        if '(' in expr:
    64          expr = expr.lower()
    65          if expr.startswith('count'):
    66            typ = int
    67          elif expr.startswith('avg'):
    68            typ = float
    69          else:
    70            typ = str
    71        else:
    72          part = parts[0]
    73          if '.' in part:
    74            table, field = part.split('.')
    75            typ = inputs[table].element_type.get_type_for(field)
    76          else:
    77            typ = next(iter(inputs.values())).element_type.get_type_for(name)
    78          # Handle optionals more gracefully.
    79          if (str(typ).startswith('typing.Union[') or
    80              str(typ).startswith('typing.Optional[')):
    81            if len(typ.__args__) == 2 and type(None) in typ.__args__:
    82              typ, = [t for t in typ.__args__ if t is not type(None)]
    83        return name, typ
    84  
    85      output_schema = [
    86          guess_name_and_type(expr) for expr in m.group(1).split(',')
    87      ]
    88      output_element = beam.Row(**{name: typ() for name, typ in output_schema})
    89      return next(iter(inputs.values())) | beam.Map(
    90          lambda _: output_element).with_output_types(
    91              trivial_inference.instance_to_type(output_element))
    92  
    93  
    94  class FakeReadFromPubSub(beam.PTransform):
    95    def __init__(self, topic):
    96      pass
    97  
    98    def expand(self, p):
    99      data = p | beam.Create([beam.Row(col1='a', col2=1, col3=0.5)])
   100      result = data | beam.Map(
   101          lambda row: beam.transforms.window.TimestampedValue(row, 0))
   102      # TODO(robertwb): Allow this to be inferred.
   103      result.element_type = data.element_type
   104      return result
   105  
   106  
   107  class FakeWriteToPubSub(beam.PTransform):
   108    def __init__(self, topic):
   109      pass
   110  
   111    def expand(self, pcoll):
   112      return pcoll
   113  
   114  
   115  class SomeAggregation(beam.PTransform):
   116    def expand(self, pcoll):
   117      return pcoll | beam.GroupBy(lambda _: 'key').aggregate_field(
   118          lambda _: 1, sum, 'count')
   119  
   120  
   121  RENDER_DIR = None
   122  TEST_PROVIDERS = {
   123      'Sql': FakeSql,
   124      'ReadFromPubSub': FakeReadFromPubSub,
   125      'WriteToPubSub': FakeWriteToPubSub,
   126      'SomeAggregation': SomeAggregation,
   127  }
   128  
   129  
   130  class TestEnvironment:
   131    def __enter__(self):
   132      self.tempdir = tempfile.TemporaryDirectory()
   133      return self
   134  
   135    def input_file(self, name, content):
   136      path = os.path.join(self.tempdir.name, name)
   137      with open(path, 'w') as fout:
   138        fout.write(content)
   139      return path
   140  
   141    def input_csv(self):
   142      return self.input_file('input.csv', 'col1,col2,col3\nabc,1,2.5\n')
   143  
   144    def input_json(self):
   145      return self.input_file(
   146          'input.json', '{"col1": "abc", "col2": 1, "col3": 2.5"}\n')
   147  
   148    def output_file(self):
   149      return os.path.join(
   150          self.tempdir.name, str(random.randint(0, 1000)) + '.out')
   151  
   152    def __exit__(self, *args):
   153      self.tempdir.cleanup()
   154  
   155  
   156  def replace_recursive(spec, transform_type, arg_name, arg_value):
   157    if isinstance(spec, dict):
   158      spec = {
   159          key: replace_recursive(value, transform_type, arg_name, arg_value)
   160          for (key, value) in spec.items()
   161      }
   162      if spec.get('type', None) == transform_type:
   163        spec[arg_name] = arg_value
   164      return spec
   165    elif isinstance(spec, list):
   166      return [
   167          replace_recursive(value, transform_type, arg_name, arg_value)
   168          for value in spec
   169      ]
   170    else:
   171      return spec
   172  
   173  
   174  def create_test_method(test_type, test_name, test_yaml):
   175    def test(self):
   176      with TestEnvironment() as env:
   177        spec = yaml.load(test_yaml, Loader=SafeLoader)
   178        if test_type == 'PARSE':
   179          return
   180        if 'ReadFromCsv' in test_yaml:
   181          spec = replace_recursive(spec, 'ReadFromCsv', 'path', env.input_csv())
   182        if 'ReadFromJson' in test_yaml:
   183          spec = replace_recursive(spec, 'ReadFromJson', 'path', env.input_json())
   184        for write in ['WriteToText', 'WriteToCsv', 'WriteToJson']:
   185          if write in test_yaml:
   186            spec = replace_recursive(spec, write, 'path', env.output_file())
   187        modified_yaml = yaml.dump(spec)
   188        options = {'pickle_library': 'cloudpickle'}
   189        if RENDER_DIR is not None:
   190          options['runner'] = 'apache_beam.runners.render.RenderRunner'
   191          options['render_output'] = [
   192              os.path.join(RENDER_DIR, test_name + '.png')
   193          ]
   194          options['render_leaf_composite_nodes'] = ['.*']
   195        p = beam.Pipeline(options=PipelineOptions(**options))
   196        yaml_transform.expand_pipeline(p, modified_yaml, TEST_PROVIDERS)
   197        if test_type == 'BUILD':
   198          return
   199        p.run().wait_until_finish()
   200  
   201    return test
   202  
   203  
   204  def parse_test_methods(markdown_lines):
   205    code_lines = None
   206    for ix, line in enumerate(markdown_lines):
   207      line = line.rstrip()
   208      if line == '```':
   209        if code_lines is None:
   210          code_lines = []
   211          test_type = 'RUN'
   212          test_name = f'test_line_{ix + 2}'
   213        else:
   214          if code_lines and code_lines[0] == 'pipeline:':
   215            yaml_pipeline = '\n'.join(code_lines)
   216            if 'providers:' in yaml_pipeline:
   217              test_type = 'PARSE'
   218            yield test_name, create_test_method(
   219                test_type,
   220                test_name,
   221                yaml_pipeline)
   222          code_lines = None
   223      elif code_lines is not None:
   224        code_lines.append(line)
   225  
   226  
   227  def createTestSuite():
   228    with open(os.path.join(os.path.dirname(__file__), 'README.md')) as readme:
   229      return type(
   230          'ReadMeTest', (unittest.TestCase, ), dict(parse_test_methods(readme)))
   231  
   232  
   233  ReadMeTest = createTestSuite()
   234  
   235  if __name__ == '__main__':
   236    parser = argparse.ArgumentParser()
   237    parser.add_argument('--render_dir', default=None)
   238    known_args, unknown_args = parser.parse_known_args(sys.argv)
   239    if known_args.render_dir:
   240      RENDER_DIR = known_args.render_dir
   241    logging.getLogger().setLevel(logging.INFO)
   242    unittest.main(argv=unknown_args)