github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/interactive/pipeline_instrument_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Tests for apache_beam.runners.interactive.pipeline_instrument."""
    19  # pytype: skip-file
    20  
    21  import unittest
    22  
    23  import apache_beam as beam
    24  from apache_beam import coders
    25  from apache_beam.pipeline import PipelineVisitor
    26  from apache_beam.runners.interactive import cache_manager as cache
    27  from apache_beam.runners.interactive import interactive_beam as ib
    28  from apache_beam.runners.interactive import interactive_environment as ie
    29  from apache_beam.runners.interactive import pipeline_instrument as instr
    30  from apache_beam.runners.interactive import interactive_runner
    31  from apache_beam.runners.interactive import utils
    32  from apache_beam.runners.interactive.caching.cacheable import Cacheable
    33  from apache_beam.runners.interactive.caching.cacheable import CacheKey
    34  from apache_beam.runners.interactive.caching.streaming_cache import StreamingCache
    35  from apache_beam.runners.interactive.testing.pipeline_assertion import assert_pipeline_equal
    36  from apache_beam.runners.interactive.testing.pipeline_assertion import assert_pipeline_proto_contain_top_level_transform
    37  from apache_beam.runners.interactive.testing.pipeline_assertion import assert_pipeline_proto_equal
    38  from apache_beam.runners.interactive.testing.pipeline_assertion import \
    39      assert_pipeline_proto_not_contain_top_level_transform
    40  from apache_beam.runners.interactive.testing.test_cache_manager import InMemoryCache
    41  from apache_beam.testing.test_stream import TestStream
    42  
    43  
    44  class PipelineInstrumentTest(unittest.TestCase):
    45    def setUp(self):
    46      ie.new_env()
    47  
    48    def cache_key_of(self, name, pcoll):
    49      return CacheKey.from_pcoll(name, pcoll).to_str()
    50  
    51    def test_pcoll_to_pcoll_id(self):
    52      p = beam.Pipeline(interactive_runner.InteractiveRunner())
    53      ie.current_env().set_cache_manager(InMemoryCache(), p)
    54      # pylint: disable=bad-option-value
    55      init_pcoll = p | 'Init Create' >> beam.Impulse()
    56      _, ctx = p.to_runner_api(return_context=True)
    57      self.assertEqual(
    58          instr.pcoll_to_pcoll_id(p, ctx),
    59          {str(init_pcoll): 'ref_PCollection_PCollection_1'})
    60  
    61    def test_pcoll_id_with_user_pipeline(self):
    62      p_id_user = beam.Pipeline(interactive_runner.InteractiveRunner())
    63      ie.current_env().set_cache_manager(InMemoryCache(), p_id_user)
    64      init_pcoll = p_id_user | 'Init Create' >> beam.Create([1, 2, 3])
    65      instrumentation = instr.build_pipeline_instrument(p_id_user)
    66      self.assertEqual(
    67          instrumentation.pcoll_id(init_pcoll), 'ref_PCollection_PCollection_8')
    68  
  def test_pcoll_id_with_runner_pipeline(self):
    """Checks ``pcoll_id`` on an instrument built from a derived pipeline.

    A first pipeline is built and its locals are watched; a second pipeline
    is registered as derived from it. The instrument is built from the second
    pipeline and queried with the second pipeline's PCollection.
    """
    p_id_runner = beam.Pipeline(interactive_runner.InteractiveRunner())
    ie.current_env().set_cache_manager(InMemoryCache(), p_id_runner)
    # pylint: disable=possibly-unused-variable
    init_pcoll = p_id_runner | 'Init Create' >> beam.Create([1, 2, 3])
    # The local variable names above are what gets watched, so they must not
    # be renamed or reordered casually.
    ib.watch(locals())

    # It's normal that, when executing, the pipeline object is a different
    # but equivalent instance from what the user has built. The pipeline
    # instrument should be able to identify if the original instance has
    # changed in an interactive env while mutating the other instance for
    # execution. The version map can be used to figure out what the
    # PCollection instances are in the original instance and if the evaluation
    # has changed since last execution.
    p2_id_runner = beam.Pipeline(interactive_runner.InteractiveRunner())
    # pylint: disable=bad-option-value
    init_pcoll_2 = p2_id_runner | 'Init Create' >> beam.Create(range(10))
    # Register p2_id_runner as derived from the watched p_id_runner.
    ie.current_env().add_derived_pipeline(p_id_runner, p2_id_runner)

    instrumentation = instr.build_pipeline_instrument(p2_id_runner)
    # The cache_key should use id(init_pcoll) as prefix even when
    # init_pcoll_2 is supplied as long as the version map is given.
    self.assertEqual(
        instrumentation.pcoll_id(init_pcoll_2), 'ref_PCollection_PCollection_8')
    93  
  def test_cache_key(self):
    """Checks ``cache_key`` against keys built from the watched variable names.

    For each of the three PCollections, the instrument's cache key must equal
    ``cache_key_of(<local variable name>, <pcoll>)``.
    """
    p = beam.Pipeline(interactive_runner.InteractiveRunner())
    ie.current_env().set_cache_manager(InMemoryCache(), p)
    # pylint: disable=bad-option-value
    init_pcoll = p | 'Init Create' >> beam.Create(range(10))
    squares = init_pcoll | 'Square' >> beam.Map(lambda x: x * x)
    cubes = init_pcoll | 'Cube' >> beam.Map(lambda x: x**3)
    # Watch the local variables, i.e., the Beam pipeline defined.
    # The variable names above are the names used in the expected keys below.
    ib.watch(locals())

    pipeline_instrument = instr.build_pipeline_instrument(p)
    self.assertEqual(
        pipeline_instrument.cache_key(init_pcoll),
        self.cache_key_of('init_pcoll', init_pcoll))
    self.assertEqual(
        pipeline_instrument.cache_key(squares),
        self.cache_key_of('squares', squares))
    self.assertEqual(
        pipeline_instrument.cache_key(cubes), self.cache_key_of('cubes', cubes))
   113  
  def test_cacheables(self):
    """Checks the instrument's ``_cacheables`` mapping for watched PCollections.

    The expected mapping is keyed by ``pcoll_id`` and holds one ``Cacheable``
    per watched PCollection, with ``var`` set to the local variable name and
    the versions set to ``str(id(...))`` of the PCollection and its producer.
    """
    p_cacheables = beam.Pipeline(interactive_runner.InteractiveRunner())
    ie.current_env().set_cache_manager(InMemoryCache(), p_cacheables)
    # pylint: disable=bad-option-value
    init_pcoll = p_cacheables | 'Init Create' >> beam.Create(range(10))
    squares = init_pcoll | 'Square' >> beam.Map(lambda x: x * x)
    cubes = init_pcoll | 'Cube' >> beam.Map(lambda x: x**3)
    # The local variable names above are the `var` values expected below.
    ib.watch(locals())

    pipeline_instrument = instr.build_pipeline_instrument(p_cacheables)

    self.assertEqual(
        pipeline_instrument._cacheables,
        {
            pipeline_instrument.pcoll_id(init_pcoll): Cacheable(
                var='init_pcoll',
                version=str(id(init_pcoll)),
                producer_version=str(id(init_pcoll.producer)),
                pcoll=init_pcoll),
            pipeline_instrument.pcoll_id(squares): Cacheable(
                var='squares',
                version=str(id(squares)),
                producer_version=str(id(squares.producer)),
                pcoll=squares),
            pipeline_instrument.pcoll_id(cubes): Cacheable(
                var='cubes',
                version=str(id(cubes)),
                producer_version=str(id(cubes.producer)),
                pcoll=cubes)
        })
   144  
  def test_background_caching_pipeline_proto(self):
    """Checks the proto from ``background_caching_pipeline_proto``.

    The instrumented pipeline reads two PubSub sources, flattens them and
    maps the result. The expected pipeline keeps only the two sources, each
    followed by a Map and a ``cache.WriteCache``.
    """
    p = beam.Pipeline(interactive_runner.InteractiveRunner())
    ie.current_env().set_cache_manager(StreamingCache(cache_dir=None), p)

    # Test that the two ReadFromPubSub are correctly cut out.
    a = p | 'ReadUnboundedSourceA' >> beam.io.ReadFromPubSub(
        subscription='projects/fake-project/subscriptions/fake_sub')
    b = p | 'ReadUnboundedSourceB' >> beam.io.ReadFromPubSub(
        subscription='projects/fake-project/subscriptions/fake_sub')

    # Add some extra PTransform afterwards to make sure that only the unbounded
    # sources remain.
    c = (a, b) | beam.Flatten()
    _ = c | beam.Map(lambda x: x)

    # Watch the locals defined so far, then build the actual proto.
    ib.watch(locals())
    instrumenter = instr.build_pipeline_instrument(p)
    actual_pipeline = instrumenter.background_caching_pipeline_proto()

    # Now recreate the expected pipeline, which should only have the unbounded
    # sources. Note that `p`, `a` and `b` are rebound to new objects here.
    p = beam.Pipeline(interactive_runner.InteractiveRunner())
    ie.current_env().set_cache_manager(StreamingCache(cache_dir=None), p)
    a = p | 'ReadUnboundedSourceA' >> beam.io.ReadFromPubSub(
        subscription='projects/fake-project/subscriptions/fake_sub')
    _ = (
        a
        | 'reify a' >> beam.Map(lambda _: _)
        | 'a' >> cache.WriteCache(ie.current_env().get_cache_manager(p), ''))

    b = p | 'ReadUnboundedSourceB' >> beam.io.ReadFromPubSub(
        subscription='projects/fake-project/subscriptions/fake_sub')
    _ = (
        b
        | 'reify b' >> beam.Map(lambda _: _)
        | 'b' >> cache.WriteCache(ie.current_env().get_cache_manager(p), ''))

    expected_pipeline = p.to_runner_api(return_context=False)
    assert_pipeline_proto_equal(self, expected_pipeline, actual_pipeline)
   184  
  def _example_pipeline(self, watch=True, bounded=True):
    """Builds a two-step example pipeline on an InteractiveRunner.

    Args:
      watch: if true, this method's locals are passed to ``ib.watch`` so the
        pipeline and PCollections are watched under their local names here
        (``p_example``, ``init_pcoll``, ``second_pcoll``).
      bounded: if true, the source is ``beam.Create(range(10))``; otherwise
        it is a ``beam.io.ReadFromPubSub`` on a fake subscription.

    Returns:
      A tuple ``(pipeline, init_pcoll, second_pcoll)``.
    """
    p_example = beam.Pipeline(interactive_runner.InteractiveRunner())
    ie.current_env().set_cache_manager(InMemoryCache(), p_example)
    # pylint: disable=bad-option-value
    if bounded:
      source = beam.Create(range(10))
    else:
      source = beam.io.ReadFromPubSub(
          subscription='projects/fake-project/subscriptions/fake_sub')

    init_pcoll = p_example | 'Init Source' >> source
    second_pcoll = init_pcoll | 'Second' >> beam.Map(lambda x: x * x)
    if watch:
      # Callers build cache keys from the names 'init_pcoll' and
      # 'second_pcoll', so these locals must keep their names.
      ib.watch(locals())
    return (p_example, init_pcoll, second_pcoll)
   200  
   201    def _mock_write_cache(self, pipeline, values, cache_key):
   202      """Cache the PCollection where cache.WriteCache would write to."""
   203      labels = ['full', cache_key]
   204  
   205      # Usually, the pcoder will be inferred from `pcoll.element_type`
   206      pcoder = coders.registry.get_coder(object)
   207      cache_manager = ie.current_env().get_cache_manager(pipeline)
   208      cache_manager.save_pcoder(pcoder, *labels)
   209      cache_manager.write(values, *labels)
   210  
   211    def test_instrument_example_pipeline_to_write_cache(self):
   212      # Original instance defined by user code has all variables handlers.
   213      p_origin, init_pcoll, second_pcoll = self._example_pipeline()
   214      # Copied instance when execution has no user defined variables.
   215      p_copy, _, _ = self._example_pipeline(watch=False)
   216      ie.current_env().add_derived_pipeline(p_origin, p_copy)
   217      # Instrument the copied pipeline.
   218      pipeline_instrument = instr.build_pipeline_instrument(p_copy)
   219      # Manually instrument original pipeline with expected pipeline transforms.
   220      init_pcoll_cache_key = pipeline_instrument.cache_key(init_pcoll)
   221      _ = (
   222          init_pcoll
   223          | 'reify init' >> beam.Map(lambda _: _)
   224          | '_WriteCache_' + init_pcoll_cache_key >> cache.WriteCache(
   225              ie.current_env().get_cache_manager(p_origin), init_pcoll_cache_key))
   226      second_pcoll_cache_key = pipeline_instrument.cache_key(second_pcoll)
   227      _ = (
   228          second_pcoll
   229          | 'reify second' >> beam.Map(lambda _: _)
   230          | '_WriteCache_' + second_pcoll_cache_key >> cache.WriteCache(
   231              ie.current_env().get_cache_manager(p_origin),
   232              second_pcoll_cache_key))
   233      # The 2 pipelines should be the same now.
   234      assert_pipeline_equal(self, p_copy, p_origin)
   235  
   236    def test_instrument_example_pipeline_to_read_cache(self):
   237      p_origin, init_pcoll, second_pcoll = self._example_pipeline()
   238      p_copy, _, _ = self._example_pipeline(False)
   239  
   240      # Mock as if cacheable PCollections are cached.
   241      init_pcoll_cache_key = self.cache_key_of('init_pcoll', init_pcoll)
   242      self._mock_write_cache(p_origin, [b'1', b'2', b'3'], init_pcoll_cache_key)
   243      second_pcoll_cache_key = self.cache_key_of('second_pcoll', second_pcoll)
   244      self._mock_write_cache(p_origin, [b'1', b'4', b'9'], second_pcoll_cache_key)
   245      # Mark the completeness of PCollections from the original(user) pipeline.
   246      ie.current_env().mark_pcollection_computed((init_pcoll, second_pcoll))
   247      ie.current_env().add_derived_pipeline(p_origin, p_copy)
   248      instr.build_pipeline_instrument(p_copy)
   249  
   250      cached_init_pcoll = (
   251          p_origin
   252          | '_ReadCache_' + init_pcoll_cache_key >> cache.ReadCache(
   253              ie.current_env().get_cache_manager(p_origin), init_pcoll_cache_key)
   254          | 'unreify' >> beam.Map(lambda _: _))
   255  
   256      # second_pcoll is never used as input and there is no need to read cache.
   257  
   258      class TestReadCacheWireVisitor(PipelineVisitor):
   259        """Replace init_pcoll with cached_init_pcoll for all occuring inputs."""
   260        def enter_composite_transform(self, transform_node):
   261          self.visit_transform(transform_node)
   262  
   263        def visit_transform(self, transform_node):
   264          if transform_node.inputs:
   265            main_inputs = dict(transform_node.main_inputs)
   266            for tag, main_input in main_inputs.items():
   267              if main_input == init_pcoll:
   268                main_inputs[tag] = cached_init_pcoll
   269            transform_node.main_inputs = main_inputs
   270  
   271      v = TestReadCacheWireVisitor()
   272      p_origin.visit(v)
   273      assert_pipeline_equal(self, p_origin, p_copy)
   274  
   275    def test_find_out_correct_user_pipeline(self):
   276      # This is the user pipeline instance we care in the watched scope.
   277      user_pipeline, _, _ = self._example_pipeline()
   278      # This is a new runner pipeline instance with the same pipeline graph to
   279      # what the user_pipeline represents.
   280      runner_pipeline = beam.pipeline.Pipeline.from_runner_api(
   281          user_pipeline.to_runner_api(), user_pipeline.runner, options=None)
   282      ie.current_env().add_derived_pipeline(user_pipeline, runner_pipeline)
   283      # This is a totally irrelevant user pipeline in the watched scope.
   284      irrelevant_user_pipeline = beam.Pipeline(
   285          interactive_runner.InteractiveRunner())
   286      ib.watch({'irrelevant_user_pipeline': irrelevant_user_pipeline})
   287      # Build instrument from the runner pipeline.
   288      pipeline_instrument = instr.build_pipeline_instrument(runner_pipeline)
   289      self.assertIs(pipeline_instrument.user_pipeline, user_pipeline)
   290  
   291    def test_instrument_example_unbounded_pipeline_to_read_cache(self):
   292      """Tests that the instrumenter works for a single unbounded source.
   293      """
   294      # Create the pipeline that will be instrumented.
   295      p_original = beam.Pipeline(interactive_runner.InteractiveRunner())
   296      ie.current_env().set_cache_manager(
   297          StreamingCache(cache_dir=None), p_original)
   298      source_1 = p_original | 'source1' >> beam.io.ReadFromPubSub(
   299          subscription='projects/fake-project/subscriptions/fake_sub')
   300      # pylint: disable=possibly-unused-variable
   301      pcoll_1 = source_1 | 'square1' >> beam.Map(lambda x: x * x)
   302  
   303      # Mock as if cacheable PCollections are cached.
   304      ib.watch(locals())
   305      # This should be noop.
   306      utils.watch_sources(p_original)
   307      for name, pcoll in locals().items():
   308        if not isinstance(pcoll, beam.pvalue.PCollection):
   309          continue
   310        cache_key = self.cache_key_of(name, pcoll)
   311        self._mock_write_cache(p_original, [], cache_key)
   312  
   313      # Instrument the original pipeline to create the pipeline the user will see.
   314      instrumenter = instr.build_pipeline_instrument(p_original)
   315      actual_pipeline = beam.Pipeline.from_runner_api(
   316          proto=instrumenter.instrumented_pipeline_proto(),
   317          runner=interactive_runner.InteractiveRunner(),
   318          options=None)
   319  
   320      # Now, build the expected pipeline which replaces the unbounded source with
   321      # a TestStream.
   322      source_1_cache_key = self.cache_key_of('source_1', source_1)
   323      p_expected = beam.Pipeline()
   324      test_stream = (p_expected | TestStream(output_tags=[source_1_cache_key]))
   325      # pylint: disable=expression-not-assigned
   326      test_stream[source_1_cache_key] | 'square1' >> beam.Map(lambda x: x * x)
   327  
   328      # Test that the TestStream is outputting to the correct PCollection.
   329      class TestStreamVisitor(PipelineVisitor):
   330        def __init__(self):
   331          self.output_tags = set()
   332  
   333        def enter_composite_transform(self, transform_node):
   334          self.visit_transform(transform_node)
   335  
   336        def visit_transform(self, transform_node):
   337          transform = transform_node.transform
   338          if isinstance(transform, TestStream):
   339            self.output_tags = transform.output_tags
   340  
   341      v = TestStreamVisitor()
   342      actual_pipeline.visit(v)
   343      expected_output_tags = set([source_1_cache_key])
   344      actual_output_tags = v.output_tags
   345      self.assertSetEqual(expected_output_tags, actual_output_tags)
   346  
   347      # Test that the pipeline is as expected.
   348      assert_pipeline_proto_equal(
   349          self,
   350          p_expected.to_runner_api(),
   351          instrumenter.instrumented_pipeline_proto())
   352  
   353    def test_able_to_cache_intermediate_unbounded_source_pcollection(self):
   354      """Tests being able to cache an intermediate source PCollection.
   355  
   356      In the following pipeline, the source doesn't have a reference and so is
   357      not automatically cached in the watch() command. This tests that this case
   358      is taken care of.
   359      """
   360      # Create the pipeline that will be instrumented.
   361      from apache_beam.options.pipeline_options import StandardOptions
   362      options = StandardOptions(streaming=True)
   363      streaming_cache_manager = StreamingCache(cache_dir=None)
   364      p_original_cache_source = beam.Pipeline(
   365          interactive_runner.InteractiveRunner(), options)
   366      ie.current_env().set_cache_manager(
   367          streaming_cache_manager, p_original_cache_source)
   368  
   369      # pylint: disable=possibly-unused-variable
   370      source_1 = (
   371          p_original_cache_source
   372          | 'source1' >> beam.io.ReadFromPubSub(
   373              subscription='projects/fake-project/subscriptions/fake_sub')
   374          | beam.Map(lambda e: e))
   375  
   376      # Watch but do not cache the PCollections.
   377      ib.watch(locals())
   378      # Make sure that sources without a user reference are still cached.
   379      utils.watch_sources(p_original_cache_source)
   380  
   381      intermediate_source_pcoll = None
   382      for watching in ie.current_env().watching():
   383        watching = list(watching)
   384        for var, watchable in watching:
   385          if 'synthetic' in var:
   386            intermediate_source_pcoll = watchable
   387            break
   388  
   389      # Instrument the original pipeline to create the pipeline the user will see.
   390      p_copy = beam.Pipeline.from_runner_api(
   391          p_original_cache_source.to_runner_api(),
   392          runner=interactive_runner.InteractiveRunner(),
   393          options=options)
   394      ie.current_env().add_derived_pipeline(p_original_cache_source, p_copy)
   395      instrumenter = instr.build_pipeline_instrument(p_copy)
   396      actual_pipeline = beam.Pipeline.from_runner_api(
   397          proto=instrumenter.instrumented_pipeline_proto(),
   398          runner=interactive_runner.InteractiveRunner(),
   399          options=options)
   400      ie.current_env().add_derived_pipeline(
   401          p_original_cache_source, actual_pipeline)
   402  
   403      # Now, build the expected pipeline which replaces the unbounded source with
   404      # a TestStream.
   405      intermediate_source_pcoll_cache_key = \
   406          self.cache_key_of('synthetic_var_' + str(id(intermediate_source_pcoll)),
   407                       intermediate_source_pcoll)
   408      p_expected = beam.Pipeline()
   409      ie.current_env().set_cache_manager(streaming_cache_manager, p_expected)
   410      test_stream = (
   411          p_expected
   412          | TestStream(output_tags=[intermediate_source_pcoll_cache_key]))
   413      # pylint: disable=expression-not-assigned
   414      (
   415          test_stream[intermediate_source_pcoll_cache_key]
   416          | 'square1' >> beam.Map(lambda e: e)
   417          | 'reify' >> beam.Map(lambda _: _)
   418          | cache.WriteCache(
   419              ie.current_env().get_cache_manager(p_expected), 'unused'))
   420  
   421      # Test that the TestStream is outputting to the correct PCollection.
   422      class TestStreamVisitor(PipelineVisitor):
   423        def __init__(self):
   424          self.output_tags = set()
   425  
   426        def enter_composite_transform(self, transform_node):
   427          self.visit_transform(transform_node)
   428  
   429        def visit_transform(self, transform_node):
   430          transform = transform_node.transform
   431          if isinstance(transform, TestStream):
   432            self.output_tags = transform.output_tags
   433  
   434      v = TestStreamVisitor()
   435      actual_pipeline.visit(v)
   436      expected_output_tags = set([intermediate_source_pcoll_cache_key])
   437      actual_output_tags = v.output_tags
   438      self.assertSetEqual(expected_output_tags, actual_output_tags)
   439  
   440      # Test that the pipeline is as expected.
   441      assert_pipeline_proto_equal(
   442          self,
   443          p_expected.to_runner_api(),
   444          instrumenter.instrumented_pipeline_proto())
   445  
  def test_instrument_mixed_streaming_batch(self):
    """Tests caching for both batch and streaming sources in the same pipeline.

    This ensures that cached bounded and unbounded sources are read from the
    TestStream.
    """
    # Create the pipeline that will be instrumented.
    from apache_beam.options.pipeline_options import StandardOptions
    options = StandardOptions(streaming=True)
    p_original = beam.Pipeline(interactive_runner.InteractiveRunner(), options)
    streaming_cache_manager = StreamingCache(cache_dir=None)
    ie.current_env().set_cache_manager(streaming_cache_manager, p_original)
    source_1 = p_original | 'source1' >> beam.io.ReadFromPubSub(
        subscription='projects/fake-project/subscriptions/fake_sub')
    source_2 = p_original | 'source2' >> beam.Create([1, 2, 3, 4, 5])

    # pylint: disable=possibly-unused-variable
    pcoll_1 = ((source_1, source_2)
               | beam.Flatten()
               | 'square1' >> beam.Map(lambda x: x * x))

    # Watch but do not cache the PCollections.
    ib.watch(locals())
    # This should be noop.
    utils.watch_sources(p_original)
    # Mock-cache the bounded source (with no elements) and mark it computed.
    self._mock_write_cache(
        p_original, [], self.cache_key_of('source_2', source_2))
    ie.current_env().mark_pcollection_computed([source_2])

    # Instrument the original pipeline to create the pipeline the user will see.
    p_copy = beam.Pipeline.from_runner_api(
        p_original.to_runner_api(),
        runner=interactive_runner.InteractiveRunner(),
        options=options)
    ie.current_env().add_derived_pipeline(p_original, p_copy)
    instrumenter = instr.build_pipeline_instrument(p_copy)
    actual_pipeline = beam.Pipeline.from_runner_api(
        proto=instrumenter.instrumented_pipeline_proto(),
        runner=interactive_runner.InteractiveRunner(),
        options=options)

    # Now, build the expected pipeline which replaces the unbounded source with
    # a TestStream.
    source_1_cache_key = self.cache_key_of('source_1', source_1)
    source_2_cache_key = self.cache_key_of('source_2', source_2)
    p_expected = beam.Pipeline()
    ie.current_env().set_cache_manager(streaming_cache_manager, p_expected)
    test_stream = (
        p_expected
        | TestStream(output_tags=[source_1_cache_key, source_2_cache_key]))
    # The TestStream outputs are indexed by the same two cache keys computed
    # above (recomputed inline here).
    # pylint: disable=expression-not-assigned
    ((
        test_stream[self.cache_key_of('source_1', source_1)],
        test_stream[self.cache_key_of('source_2', source_2)])
     | beam.Flatten()
     | 'square1' >> beam.Map(lambda x: x * x)
     | 'reify' >> beam.Map(lambda _: _)
     | cache.WriteCache(
         ie.current_env().get_cache_manager(p_expected), 'unused'))

    # Test that the TestStream is outputting to the correct PCollection.
    class TestStreamVisitor(PipelineVisitor):
      """Records the output tags of the TestStream transform it visits."""
      def __init__(self):
        self.output_tags = set()

      def enter_composite_transform(self, transform_node):
        self.visit_transform(transform_node)

      def visit_transform(self, transform_node):
        transform = transform_node.transform
        if isinstance(transform, TestStream):
          self.output_tags = transform.output_tags

    v = TestStreamVisitor()
    actual_pipeline.visit(v)
    expected_output_tags = set([source_1_cache_key, source_2_cache_key])
    actual_output_tags = v.output_tags
    self.assertSetEqual(expected_output_tags, actual_output_tags)

    # Test that the pipeline is as expected.
    assert_pipeline_proto_equal(
        self,
        p_expected.to_runner_api(),
        instrumenter.instrumented_pipeline_proto())
   530  
  def test_instrument_example_unbounded_pipeline_direct_from_source(self):
    """Tests that it caches PCollections from a source.

    The pipeline consists of a single PubSub source; the expected pipeline
    consists of a single TestStream tagged with that source's cache key.
    """
    # Create the pipeline that will be instrumented.
    from apache_beam.options.pipeline_options import StandardOptions
    options = StandardOptions(streaming=True)
    p_original_direct_source = beam.Pipeline(
        interactive_runner.InteractiveRunner(), options)
    ie.current_env().set_cache_manager(
        StreamingCache(cache_dir=None), p_original_direct_source)
    source_1 = p_original_direct_source | 'source1' >> beam.io.ReadFromPubSub(
        subscription='projects/fake-project/subscriptions/fake_sub')
    # The expected pipeline is built here, before ib.watch(locals()), so
    # p_expected and test_stream are among the watched locals as well.
    # pylint: disable=possibly-unused-variable
    p_expected = beam.Pipeline()
    # pylint: disable=unused-variable
    test_stream = (
        p_expected
        | TestStream(output_tags=[self.cache_key_of('source_1', source_1)]))
    # Watch but do not cache the PCollections.
    ib.watch(locals())
    # This should be noop.
    utils.watch_sources(p_original_direct_source)
    # Instrument the original pipeline to create the pipeline the user will see.
    p_copy = beam.Pipeline.from_runner_api(
        p_original_direct_source.to_runner_api(),
        runner=interactive_runner.InteractiveRunner(),
        options=options)
    ie.current_env().add_derived_pipeline(p_original_direct_source, p_copy)
    instrumenter = instr.build_pipeline_instrument(p_copy)
    actual_pipeline = beam.Pipeline.from_runner_api(
        proto=instrumenter.instrumented_pipeline_proto(),
        runner=interactive_runner.InteractiveRunner(),
        options=options)
    ie.current_env().add_derived_pipeline(
        p_original_direct_source, actual_pipeline)

    # The expected pipeline (p_expected, built above) replaces the unbounded
    # source with a TestStream whose only output tag is this cache key.
    source_1_cache_key = self.cache_key_of('source_1', source_1)

    # Test that the TestStream is outputting to the correct PCollection.
    class TestStreamVisitor(PipelineVisitor):
      """Records the output tags of the TestStream transform it visits."""
      def __init__(self):
        self.output_tags = set()

      def enter_composite_transform(self, transform_node):
        self.visit_transform(transform_node)

      def visit_transform(self, transform_node):
        transform = transform_node.transform
        if isinstance(transform, TestStream):
          self.output_tags = transform.output_tags

    v = TestStreamVisitor()
    actual_pipeline.visit(v)
    expected_output_tags = set([source_1_cache_key])
    actual_output_tags = v.output_tags
    self.assertSetEqual(expected_output_tags, actual_output_tags)

    # Test that the pipeline is as expected.
    assert_pipeline_proto_equal(
        self,
        p_expected.to_runner_api(),
        instrumenter.instrumented_pipeline_proto())
   595  
   596    def test_instrument_example_unbounded_pipeline_to_read_cache_not_cached(self):
   597      """Tests that the instrumenter works when the PCollection is not cached.
   598      """
   599      # Create the pipeline that will be instrumented.
   600      from apache_beam.options.pipeline_options import StandardOptions
   601      options = StandardOptions(streaming=True)
   602      p_original_read_cache = beam.Pipeline(
   603          interactive_runner.InteractiveRunner(), options)
   604      ie.current_env().set_cache_manager(
   605          StreamingCache(cache_dir=None), p_original_read_cache)
   606      source_1 = p_original_read_cache | 'source1' >> beam.io.ReadFromPubSub(
   607          subscription='projects/fake-project/subscriptions/fake_sub')
   608      # pylint: disable=possibly-unused-variable
   609      pcoll_1 = source_1 | 'square1' >> beam.Map(lambda x: x * x)
   610  
   611      # Watch but do not cache the PCollections.
   612      ib.watch(locals())
   613      # This should be noop.
   614      utils.watch_sources(p_original_read_cache)
   615      # Instrument the original pipeline to create the pipeline the user will see.
   616      p_copy = beam.Pipeline.from_runner_api(
   617          p_original_read_cache.to_runner_api(),
   618          runner=interactive_runner.InteractiveRunner(),
   619          options=options)
   620      ie.current_env().add_derived_pipeline(p_original_read_cache, p_copy)
   621      instrumenter = instr.build_pipeline_instrument(p_copy)
   622      actual_pipeline = beam.Pipeline.from_runner_api(
   623          proto=instrumenter.instrumented_pipeline_proto(),
   624          runner=interactive_runner.InteractiveRunner(),
   625          options=options)
   626  
   627      # Now, build the expected pipeline which replaces the unbounded source with
   628      # a TestStream.
   629      source_1_cache_key = self.cache_key_of('source_1', source_1)
   630      p_expected = beam.Pipeline()
   631      ie.current_env().set_cache_manager(
   632          StreamingCache(cache_dir=None), p_expected)
   633      test_stream = (p_expected | TestStream(output_tags=[source_1_cache_key]))
   634      # pylint: disable=expression-not-assigned
   635      (
   636          test_stream[source_1_cache_key]
   637          | 'square1' >> beam.Map(lambda x: x * x)
   638          | 'reify' >> beam.Map(lambda _: _)
   639          | cache.WriteCache(
   640              ie.current_env().get_cache_manager(p_expected), 'unused'))
   641  
   642      # Test that the TestStream is outputting to the correct PCollection.
   643      class TestStreamVisitor(PipelineVisitor):
   644        def __init__(self):
   645          self.output_tags = set()
   646  
   647        def enter_composite_transform(self, transform_node):
   648          self.visit_transform(transform_node)
   649  
   650        def visit_transform(self, transform_node):
   651          transform = transform_node.transform
   652          if isinstance(transform, TestStream):
   653            self.output_tags = transform.output_tags
   654  
   655      v = TestStreamVisitor()
   656      actual_pipeline.visit(v)
   657      expected_output_tags = set([source_1_cache_key])
   658      actual_output_tags = v.output_tags
   659      self.assertSetEqual(expected_output_tags, actual_output_tags)
   660  
   661      # Test that the pipeline is as expected.
   662      assert_pipeline_proto_equal(
   663          self,
   664          p_expected.to_runner_api(),
   665          instrumenter.instrumented_pipeline_proto())
   666  
   667    def test_instrument_example_unbounded_pipeline_to_multiple_read_cache(self):
   668      """Tests that the instrumenter works for multiple unbounded sources.
   669      """
   670      # Create the pipeline that will be instrumented.
   671      p_original = beam.Pipeline(interactive_runner.InteractiveRunner())
   672      ie.current_env().set_cache_manager(
   673          StreamingCache(cache_dir=None), p_original)
   674      source_1 = p_original | 'source1' >> beam.io.ReadFromPubSub(
   675          subscription='projects/fake-project/subscriptions/fake_sub')
   676      source_2 = p_original | 'source2' >> beam.io.ReadFromPubSub(
   677          subscription='projects/fake-project/subscriptions/fake_sub')
   678      # pylint: disable=possibly-unused-variable
   679      pcoll_1 = source_1 | 'square1' >> beam.Map(lambda x: x * x)
   680      # pylint: disable=possibly-unused-variable
   681      pcoll_2 = source_2 | 'square2' >> beam.Map(lambda x: x * x)
   682  
   683      # Mock as if cacheable PCollections are cached.
   684      ib.watch(locals())
   685      # This should be noop.
   686      utils.watch_sources(p_original)
   687      for name, pcoll in locals().items():
   688        if not isinstance(pcoll, beam.pvalue.PCollection):
   689          continue
   690        cache_key = self.cache_key_of(name, pcoll)
   691        self._mock_write_cache(p_original, [], cache_key)
   692  
   693      # Instrument the original pipeline to create the pipeline the user will see.
   694      instrumenter = instr.build_pipeline_instrument(p_original)
   695      actual_pipeline = beam.Pipeline.from_runner_api(
   696          proto=instrumenter.instrumented_pipeline_proto(),
   697          runner=interactive_runner.InteractiveRunner(),
   698          options=None)
   699  
   700      # Now, build the expected pipeline which replaces the unbounded source with
   701      # a TestStream.
   702      source_1_cache_key = self.cache_key_of('source_1', source_1)
   703      source_2_cache_key = self.cache_key_of('source_2', source_2)
   704      p_expected = beam.Pipeline()
   705      test_stream = (
   706          p_expected
   707          | TestStream(
   708              output_tags=[
   709                  self.cache_key_of('source_1', source_1),
   710                  self.cache_key_of('source_2', source_2)
   711              ]))
   712      # pylint: disable=expression-not-assigned
   713      test_stream[source_1_cache_key] | 'square1' >> beam.Map(lambda x: x * x)
   714      # pylint: disable=expression-not-assigned
   715      test_stream[source_2_cache_key] | 'square2' >> beam.Map(lambda x: x * x)
   716  
   717      # Test that the TestStream is outputting to the correct PCollection.
   718      class TestStreamVisitor(PipelineVisitor):
   719        def __init__(self):
   720          self.output_tags = set()
   721  
   722        def enter_composite_transform(self, transform_node):
   723          self.visit_transform(transform_node)
   724  
   725        def visit_transform(self, transform_node):
   726          transform = transform_node.transform
   727          if isinstance(transform, TestStream):
   728            self.output_tags = transform.output_tags
   729  
   730      v = TestStreamVisitor()
   731      actual_pipeline.visit(v)
   732      expected_output_tags = set([source_1_cache_key, source_2_cache_key])
   733      actual_output_tags = v.output_tags
   734      self.assertSetEqual(expected_output_tags, actual_output_tags)
   735  
   736      # Test that the pipeline is as expected.
   737      assert_pipeline_proto_equal(
   738          self,
   739          p_expected.to_runner_api(),
   740          instrumenter.instrumented_pipeline_proto())
   741  
   742    def test_pipeline_pruned_when_input_pcoll_is_cached(self):
   743      user_pipeline, init_pcoll, _ = self._example_pipeline()
   744      runner_pipeline = beam.Pipeline.from_runner_api(
   745          user_pipeline.to_runner_api(), user_pipeline.runner, None)
   746      ie.current_env().add_derived_pipeline(user_pipeline, runner_pipeline)
   747  
   748      # Mock as if init_pcoll is cached.
   749      init_pcoll_cache_key = self.cache_key_of('init_pcoll', init_pcoll)
   750      self._mock_write_cache(
   751          user_pipeline, [b'1', b'2', b'3'], init_pcoll_cache_key)
   752      ie.current_env().mark_pcollection_computed([init_pcoll])
   753      # Build an instrument from the runner pipeline.
   754      pipeline_instrument = instr.build_pipeline_instrument(runner_pipeline)
   755  
   756      pruned_proto = pipeline_instrument.instrumented_pipeline_proto()
   757      # Skip the prune step for comparison, it should contain the sub-graph that
   758      # produces init_pcoll but not useful anymore.
   759      full_proto = pipeline_instrument._pipeline.to_runner_api()
   760      self.assertEqual(
   761          len(
   762              pruned_proto.components.transforms[
   763                  'ref_AppliedPTransform_AppliedPTransform_1'].subtransforms),
   764          5)
   765      assert_pipeline_proto_not_contain_top_level_transform(
   766          self, pruned_proto, 'Init Source')
   767      self.assertEqual(
   768          len(
   769              full_proto.components.transforms[
   770                  'ref_AppliedPTransform_AppliedPTransform_1'].subtransforms),
   771          6)
   772      assert_pipeline_proto_contain_top_level_transform(
   773          self, full_proto, 'Init-Source')
   774  
   775    def test_side_effect_pcoll_is_included(self):
   776      pipeline_with_side_effect = beam.Pipeline(
   777          interactive_runner.InteractiveRunner())
   778      ie.current_env().set_cache_manager(
   779          InMemoryCache(), pipeline_with_side_effect)
   780      # Deliberately not assign the result to a variable to make it a
   781      # "side effect" transform. Note we never watch anything from
   782      # the pipeline defined locally either.
   783      # pylint: disable=bad-option-value,expression-not-assigned
   784      pipeline_with_side_effect | 'Init Create' >> beam.Create(range(10))
   785      pipeline_instrument = instr.build_pipeline_instrument(
   786          pipeline_with_side_effect)
   787      self.assertTrue(pipeline_instrument._extended_targets)
   788  
   789  
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
  unittest.main()