github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/render_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  # pytype: skip-file
    18  
    19  import argparse
    20  import logging
    21  import subprocess
    22  import unittest
    23  
    24  import apache_beam as beam
    25  from apache_beam.runners import render
    26  
    27  default_options = render.RenderOptions._add_argparse_args(
    28      argparse.ArgumentParser()).parse_args([])
    29  
    30  
    31  class RenderRunnerTest(unittest.TestCase):
    32    def test_basic_graph(self):
    33      p = beam.Pipeline()
    34      _ = (
    35          p | beam.Impulse() | beam.Map(lambda _: 2)
    36          | 'CustomName' >> beam.Map(lambda x: x * x))
    37      dot = render.PipelineRenderer(p.to_runner_api(), default_options).to_dot()
    38      self.assertIn('digraph', dot)
    39      self.assertIn('CustomName', dot)
    40      self.assertEqual(dot.count('->'), 2)
    41  
    42    def test_side_input(self):
    43      p = beam.Pipeline()
    44      pcoll = p | beam.Impulse() | beam.FlatMap(lambda x: [1, 2, 3])
    45      dot = render.PipelineRenderer(p.to_runner_api(), default_options).to_dot()
    46      self.assertEqual(dot.count('->'), 1)
    47      self.assertNotIn('dashed', dot)
    48  
    49      _ = pcoll | beam.Map(
    50          lambda x, side: x * side, side=beam.pvalue.AsList(pcoll))
    51      dot = render.PipelineRenderer(p.to_runner_api(), default_options).to_dot()
    52      self.assertEqual(dot.count('->'), 3)
    53      self.assertIn('dashed', dot)
    54  
    55    def test_composite_collapse(self):
    56      p = beam.Pipeline()
    57      _ = p | beam.Create([1, 2, 3]) | beam.Map(lambda x: x * x)
    58      pipeline_proto = p.to_runner_api()
    59      renderer = render.PipelineRenderer(pipeline_proto, default_options)
    60      self.assertEqual(renderer.to_dot().count('->'), 8)
    61      create_transform_id, = [
    62          id
    63          for (id, transform) in pipeline_proto.components.transforms.items()
    64          if transform.unique_name == 'Create']
    65      renderer.update(toggle=[create_transform_id])
    66      self.assertEqual(renderer.to_dot().count('->'), 1)
    67  
    68    def test_dot_well_formed(self):
    69      try:
    70        subprocess.run(['dot', '-V'], capture_output=True, check=True)
    71      except FileNotFoundError:
    72        self.skipTest('dot executable not installed')
    73      p = beam.Pipeline()
    74      _ = p | beam.Create([1, 2, 3]) | beam.Map(lambda x: x * x)
    75      pipeline_proto = p.to_runner_api()
    76      renderer = render.PipelineRenderer(pipeline_proto, default_options)
    77      # Doesn't actually look at the output, but ensures dot executes correctly.
    78      renderer.render_data()
    79      create_transform_id, = [
    80          id
    81          for (id, transform) in pipeline_proto.components.transforms.items()
    82          if transform.unique_name == 'Create']
    83      renderer.update(toggle=[create_transform_id])
    84      renderer.render_data()
    85  
    86  
    87  if __name__ == '__main__':
    88    logging.getLogger().setLevel(logging.INFO)
    89    unittest.main()