github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/interactive/pipeline_fragment_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Tests for apache_beam.runners.interactive.pipeline_fragment."""
    19  import unittest
    20  from unittest.mock import patch
    21  
    22  import apache_beam as beam
    23  from apache_beam.options.pipeline_options import StandardOptions
    24  from apache_beam.runners.interactive import interactive_beam as ib
    25  from apache_beam.runners.interactive import interactive_environment as ie
    26  from apache_beam.runners.interactive import interactive_runner as ir
    27  from apache_beam.runners.interactive import pipeline_fragment as pf
    28  from apache_beam.runners.interactive.testing.mock_ipython import mock_get_ipython
    29  from apache_beam.runners.interactive.testing.pipeline_assertion import assert_pipeline_equal
    30  from apache_beam.runners.interactive.testing.pipeline_assertion import assert_pipeline_proto_equal
    31  from apache_beam.testing.test_stream import TestStream
    32  
    33  
@unittest.skipIf(
    not ie.current_env().is_interactive_ready,
    '[interactive] dependency is not installed.')
class PipelineFragmentTest(unittest.TestCase):
  """Tests for pf.PipelineFragment under a mocked IPython/notebook setup.

  The whole class is skipped when the current interactive environment reports
  that it is not interactive ready.

  Tests decorated with @patch receive the mock_get_ipython replacement for
  IPython.get_ipython as `cell` and use it as a context manager; each
  `with cell:` block is labelled as one notebook cell (presumably simulating
  one cell execution -- see mock_ipython for the exact semantics).
  """
  def setUp(self):
    # Request a new interactive environment before each test.
    ie.new_env()
    # Assume a notebook frontend is connected to the mocked ipython kernel.
    ie.current_env()._is_in_ipython = True
    ie.current_env()._is_in_notebook = True

  @patch('IPython.get_ipython', new_callable=mock_get_ipython)
  def test_build_pipeline_fragment(self, cell):
    """Tests that the fragment deduced for `square` matches `p_expected`.

    `p` applies both 'Square' and 'Cube' to 'Init', while `p_expected` only
    applies 'Square'. The fragment deduced from `[square]` is asserted to be
    equal to `p_expected`.
    """
    with cell:  # Cell 1
      p = beam.Pipeline(ir.InteractiveRunner())
      p_expected = beam.Pipeline(ir.InteractiveRunner())
      # Watch local scope now to allow interactive beam to track the pipelines.
      ib.watch(locals())

    with cell:  # Cell 2
      # pylint: disable=bad-option-value
      init = p | 'Init' >> beam.Create(range(10))
      init_expected = p_expected | 'Init' >> beam.Create(range(10))

    with cell:  # Cell 3
      square = init | 'Square' >> beam.Map(lambda x: x * x)
      # 'Cube' exists only in `p`; it is not applied to the expected pipeline.
      _ = init | 'Cube' >> beam.Map(lambda x: x**3)
      _ = init_expected | 'Square' >> beam.Map(lambda x: x * x)

    # Watch every PCollection that has been defined so far in local scope.
    ib.watch(locals())
    fragment = pf.PipelineFragment([square]).deduce_fragment()
    assert_pipeline_equal(self, p_expected, fragment)

  @patch('IPython.get_ipython', new_callable=mock_get_ipython)
  def test_user_pipeline_intact_after_deducing_pipeline_fragment(self, cell):
    """Tests that deducing a fragment leaves the user pipeline unchanged.

    The runner API proto of `p` is taken before and after calling
    deduce_fragment() and the two protos are asserted to be equal.
    """
    with cell:  # Cell 1
      p = beam.Pipeline(ir.InteractiveRunner())
      # Watch the pipeline `p` immediately without calling locals().
      ib.watch({'p': p})

    with cell:  # Cell 2
      # pylint: disable=bad-option-value
      init = p | 'Init' >> beam.Create(range(10))

    with cell:  # Cell 3
      square = init | 'Square' >> beam.Map(lambda x: x * x)

    with cell:  # Cell 4
      cube = init | 'Cube' >> beam.Map(lambda x: x**3)

    # Watch every PCollection that has been defined so far in local scope
    # without calling locals().
    ib.watch({'init': init, 'square': square, 'cube': cube})
    user_pipeline_proto_before_deducing_fragment = p.to_runner_api(
        return_context=False)
    # The deduced fragment itself is not needed; only the effect of deducing
    # it on `p` is under test.
    _ = pf.PipelineFragment([square]).deduce_fragment()
    user_pipeline_proto_after_deducing_fragment = p.to_runner_api(
        return_context=False)
    assert_pipeline_proto_equal(
        self,
        user_pipeline_proto_before_deducing_fragment,
        user_pipeline_proto_after_deducing_fragment)

  @patch('IPython.get_ipython', new_callable=mock_get_ipython)
  def test_pipeline_fragment_produces_correct_data(self, cell):
    """Tests that running the fragment for `square` yields the squares.

    The result read back for `square` is asserted to be the squares of
    range(5), in order.
    """
    with cell:  # Cell 1
      p = beam.Pipeline(ir.InteractiveRunner())
      ib.watch({'p': p})

    with cell:  # Cell 2
      # pylint: disable=bad-option-value
      init = p | 'Init' >> beam.Create(range(5))

    with cell:  # Cell 3
      square = init | 'Square' >> beam.Map(lambda x: x * x)
      _ = init | 'Cube' >> beam.Map(lambda x: x**3)

    ib.watch(locals())
    result = pf.PipelineFragment([square]).run()
    self.assertEqual([0, 1, 4, 9, 16], list(result.get(square)))

  def test_fragment_does_not_prune_teststream(self):
    """Tests that the fragment does not prune the TestStream composite parts.
    """
    options = StandardOptions(streaming=True)
    p = beam.Pipeline(ir.InteractiveRunner(), options)

    test_stream = p | TestStream(output_tags=['a', 'b'])

    # pylint: disable=unused-variable
    a = test_stream['a'] | 'a' >> beam.Map(lambda _: _)
    b = test_stream['b'] | 'b' >> beam.Map(lambda _: _)

    # Only the 'b' branch is requested for the fragment.
    fragment = pf.PipelineFragment([b]).deduce_fragment()

    # If the fragment does prune the TestStream composite parts, then the
    # resulting graph is invalid and the following call will raise an exception.
    fragment.to_runner_api()

  @patch('IPython.get_ipython', new_callable=mock_get_ipython)
  def test_pipeline_composites(self, cell):
    """Tests that composites are supported.
    """
    with cell:  # Cell 1
      p = beam.Pipeline(ir.InteractiveRunner())
      ib.watch({'p': p})

    with cell:  # Cell 2
      # pylint: disable=bad-option-value
      init = p | 'Init' >> beam.Create(range(5))

    with cell:  # Cell 3
      # Have a composite within a composite to test that all transforms under a
      # composite are added.

      @beam.ptransform_fn
      def Bar(pcoll):
        return pcoll | beam.Map(lambda n: 2 * n)

      @beam.ptransform_fn
      def Foo(pcoll):
        p1 = pcoll | beam.Map(lambda n: 3 * n)
        p2 = pcoll | beam.Map(str)
        bar = p1 | Bar()
        return {'pc1': p1, 'pc2': p2, 'bar': bar}

      res = init | Foo()
      # `res` is the dict returned by Foo, so its keys are the watched names.
      ib.watch(res)

    pc = res['bar']

    result = pf.PipelineFragment([pc]).run()
    # 'bar' is Bar applied to p1, i.e. 2 * (3 * n) for each n in range(5).
    self.assertEqual([0, 6, 12, 18, 24], list(result.get(pc)))

  def test_ib_show_without_using_ir(self):
    """Tests that ib.show raises when the InteractiveRunner is not used.

    The pipeline is created without specifying a runner, and calling ib.show
    on one of its PCollections is expected to raise a RuntimeError.
    """
    p = beam.Pipeline()
    print_words = p | beam.Create(["this is a test"]) | beam.Map(print)
    with self.assertRaises(RuntimeError):
      ib.show(print_words)
   175  
   176  
# Run this module's tests when the file is executed directly as a script.
if __name__ == '__main__':
  unittest.main()