github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/interactive/interactive_runner_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
"""Tests for apache_beam.runners.interactive.interactive_runner.
    19  
    20  This module is experimental. No backwards-compatibility guarantees.
    21  """
    22  
    23  # pytype: skip-file
    24  
    25  import sys
    26  import unittest
    27  from typing import NamedTuple
    28  
    29  import pandas as pd
    30  
    31  import apache_beam as beam
    32  from apache_beam.dataframe.convert import to_dataframe
    33  from apache_beam.options.pipeline_options import FlinkRunnerOptions
    34  from apache_beam.options.pipeline_options import PipelineOptions
    35  from apache_beam.options.pipeline_options import StandardOptions
    36  from apache_beam.options.pipeline_options import WorkerOptions
    37  from apache_beam.runners.direct import direct_runner
    38  from apache_beam.runners.interactive import interactive_beam as ib
    39  from apache_beam.runners.interactive import interactive_environment as ie
    40  from apache_beam.runners.interactive import interactive_runner
    41  from apache_beam.runners.interactive.dataproc.dataproc_cluster_manager import DataprocClusterManager
    42  from apache_beam.runners.interactive.dataproc.types import ClusterMetadata
    43  from apache_beam.runners.interactive.testing.mock_env import isolated_env
    44  from apache_beam.runners.portability.flink_runner import FlinkRunner
    45  from apache_beam.testing.test_stream import TestStream
    46  from apache_beam.transforms.window import GlobalWindow
    47  from apache_beam.transforms.window import IntervalWindow
    48  from apache_beam.utils.timestamp import Timestamp
    49  from apache_beam.utils.windowed_value import PaneInfo
    50  from apache_beam.utils.windowed_value import PaneInfoTiming
    51  from apache_beam.utils.windowed_value import WindowedValue
    52  
    53  
    54  def print_with_message(msg):
    55    def printer(elem):
    56      print(msg, elem)
    57      return elem
    58  
    59    return printer
    60  
    61  
    62  class Record(NamedTuple):
    63    name: str
    64    age: int
    65    height: int
    66  
    67  
    68  @isolated_env
    69  class InteractiveRunnerTest(unittest.TestCase):
    70    @unittest.skipIf(sys.platform == "win32", "[BEAM-10627]")
    71    def test_basic(self):
    72      p = beam.Pipeline(
    73          runner=interactive_runner.InteractiveRunner(
    74              direct_runner.DirectRunner()))
    75      ib.watch({'p': p})
    76      p.run().wait_until_finish()
    77      pc0 = (
    78          p | 'read' >> beam.Create([1, 2, 3])
    79          | 'Print1.1' >> beam.Map(print_with_message('Run1.1')))
    80      pc = pc0 | 'Print1.2' >> beam.Map(print_with_message('Run1.2'))
    81      ib.watch(locals())
    82      p.run().wait_until_finish()
    83      _ = pc | 'Print2' >> beam.Map(print_with_message('Run2'))
    84      p.run().wait_until_finish()
    85      _ = pc0 | 'Print3' >> beam.Map(print_with_message('Run3'))
    86      p.run().wait_until_finish()
    87  
    88    @unittest.skipIf(sys.platform == "win32", "[BEAM-10627]")
    89    def test_wordcount(self):
    90      class WordExtractingDoFn(beam.DoFn):
    91        def process(self, element):
    92          text_line = element.strip()
    93          words = text_line.split()
    94          return words
    95  
    96      p = beam.Pipeline(
    97          runner=interactive_runner.InteractiveRunner(
    98              direct_runner.DirectRunner()))
    99  
   100      # Count the occurrences of each word.
   101      counts = (
   102          p
   103          | beam.Create(['to be or not to be that is the question'])
   104          | 'split' >> beam.ParDo(WordExtractingDoFn())
   105          | 'pair_with_one' >> beam.Map(lambda x: (x, 1))
   106          | 'group' >> beam.GroupByKey()
   107          | 'count' >> beam.Map(lambda wordones: (wordones[0], sum(wordones[1]))))
   108  
   109      # Watch the local scope for Interactive Beam so that counts will be cached.
   110      ib.watch(locals())
   111  
   112      # This is normally done in the interactive_utils when a transform is
   113      # applied but needs an IPython environment. So we manually run this here.
   114      ie.current_env().track_user_pipelines()
   115  
   116      result = p.run()
   117      result.wait_until_finish()
   118  
   119      actual = list(result.get(counts))
   120      self.assertSetEqual(
   121          set(actual),
   122          set([
   123              ('or', 1),
   124              ('that', 1),
   125              ('be', 2),
   126              ('is', 1),
   127              ('question', 1),
   128              ('to', 2),
   129              ('the', 1),
   130              ('not', 1),
   131          ]))
   132  
   133      # Truncate the precision to millis because the window coder uses millis
   134      # as units then gets upcast to micros.
   135      end_of_window = (GlobalWindow().max_timestamp().micros // 1000) * 1000
   136      df_counts = ib.collect(counts, include_window_info=True, n=10)
   137      df_expected = pd.DataFrame({
   138          0: [e[0] for e in actual],
   139          1: [e[1] for e in actual],
   140          'event_time': [end_of_window for _ in actual],
   141          'windows': [[GlobalWindow()] for _ in actual],
   142          'pane_info': [
   143              PaneInfo(True, True, PaneInfoTiming.ON_TIME, 0, 0) for _ in actual
   144          ]
   145      },
   146                                 columns=[
   147                                     0, 1, 'event_time', 'windows', 'pane_info'
   148                                 ])
   149  
   150      pd.testing.assert_frame_equal(df_expected, df_counts)
   151  
   152      actual_reified = result.get(counts, include_window_info=True)
   153      expected_reified = [
   154          WindowedValue(
   155              e,
   156              Timestamp(micros=end_of_window), [GlobalWindow()],
   157              PaneInfo(True, True, PaneInfoTiming.ON_TIME, 0, 0)) for e in actual
   158      ]
   159      self.assertEqual(actual_reified, expected_reified)
   160  
   161    def test_streaming_wordcount(self):
   162      class WordExtractingDoFn(beam.DoFn):
   163        def process(self, element):
   164          text_line = element.strip()
   165          words = text_line.split()
   166          return words
   167  
   168      # Add the TestStream so that it can be cached.
   169      ib.options.recordable_sources.add(TestStream)
   170  
   171      p = beam.Pipeline(
   172          runner=interactive_runner.InteractiveRunner(),
   173          options=StandardOptions(streaming=True))
   174  
   175      data = (
   176          p
   177          | TestStream()
   178              .advance_watermark_to(0)
   179              .advance_processing_time(1)
   180              .add_elements(['to', 'be', 'or', 'not', 'to', 'be'])
   181              .advance_watermark_to(20)
   182              .advance_processing_time(1)
   183              .add_elements(['that', 'is', 'the', 'question'])
   184              .advance_watermark_to(30)
   185              .advance_processing_time(1)
   186              .advance_watermark_to(40)
   187              .advance_processing_time(1)
   188              .advance_watermark_to(50)
   189              .advance_processing_time(1)
   190          | beam.WindowInto(beam.window.FixedWindows(10))) # yapf: disable
   191  
   192      counts = (
   193          data
   194          | 'split' >> beam.ParDo(WordExtractingDoFn())
   195          | 'pair_with_one' >> beam.Map(lambda x: (x, 1))
   196          | 'group' >> beam.GroupByKey()
   197          | 'count' >> beam.Map(lambda wordones: (wordones[0], sum(wordones[1]))))
   198  
   199      # Watch the local scope for Interactive Beam so that referenced PCollections
   200      # will be cached.
   201      ib.watch(locals())
   202  
   203      # This is normally done in the interactive_utils when a transform is
   204      # applied but needs an IPython environment. So we manually run this here.
   205      ie.current_env().track_user_pipelines()
   206  
   207      # This tests that the data was correctly cached.
   208      pane_info = PaneInfo(True, True, PaneInfoTiming.UNKNOWN, 0, 0)
   209      expected_data_df = pd.DataFrame([
   210          ('to', 0, [IntervalWindow(0, 10)], pane_info),
   211          ('be', 0, [IntervalWindow(0, 10)], pane_info),
   212          ('or', 0, [IntervalWindow(0, 10)], pane_info),
   213          ('not', 0, [IntervalWindow(0, 10)], pane_info),
   214          ('to', 0, [IntervalWindow(0, 10)], pane_info),
   215          ('be', 0, [IntervalWindow(0, 10)], pane_info),
   216          ('that', 20000000, [IntervalWindow(20, 30)], pane_info),
   217          ('is', 20000000, [IntervalWindow(20, 30)], pane_info),
   218          ('the', 20000000, [IntervalWindow(20, 30)], pane_info),
   219          ('question', 20000000, [IntervalWindow(20, 30)], pane_info)
   220      ], columns=[0, 'event_time', 'windows', 'pane_info']) # yapf: disable
   221  
   222      data_df = ib.collect(data, n=10, include_window_info=True)
   223      pd.testing.assert_frame_equal(expected_data_df, data_df)
   224  
   225      # This tests that the windowing was passed correctly so that all the data
   226      # is aggregated also correctly.
   227      pane_info = PaneInfo(True, False, PaneInfoTiming.ON_TIME, 0, 0)
   228      expected_counts_df = pd.DataFrame([
   229          ('be', 2, 9999999, [IntervalWindow(0, 10)], pane_info),
   230          ('not', 1, 9999999, [IntervalWindow(0, 10)], pane_info),
   231          ('or', 1, 9999999, [IntervalWindow(0, 10)], pane_info),
   232          ('to', 2, 9999999, [IntervalWindow(0, 10)], pane_info),
   233          ('is', 1, 29999999, [IntervalWindow(20, 30)], pane_info),
   234          ('question', 1, 29999999, [IntervalWindow(20, 30)], pane_info),
   235          ('that', 1, 29999999, [IntervalWindow(20, 30)], pane_info),
   236          ('the', 1, 29999999, [IntervalWindow(20, 30)], pane_info),
   237      ], columns=[0, 1, 'event_time', 'windows', 'pane_info']) # yapf: disable
   238  
   239      counts_df = ib.collect(counts, n=8, include_window_info=True)
   240  
   241      # The group by key has no guarantee of order. So we post-process the DF by
   242      # sorting so we can test equality.
   243      sorted_counts_df = (counts_df
   244                          .sort_values(['event_time', 0], ascending=True)
   245                          .reset_index(drop=True)) # yapf: disable
   246      pd.testing.assert_frame_equal(expected_counts_df, sorted_counts_df)
   247  
   248    def test_session(self):
   249      class MockPipelineRunner(object):
   250        def __init__(self):
   251          self._in_session = False
   252  
   253        def __enter__(self):
   254          self._in_session = True
   255  
   256        def __exit__(self, exc_type, exc_val, exc_tb):
   257          self._in_session = False
   258  
   259      underlying_runner = MockPipelineRunner()
   260      runner = interactive_runner.InteractiveRunner(underlying_runner)
   261      runner.start_session()
   262      self.assertTrue(underlying_runner._in_session)
   263      runner.end_session()
   264      self.assertFalse(underlying_runner._in_session)
   265  
   266    @unittest.skipIf(
   267        not ie.current_env().is_interactive_ready,
   268        '[interactive] dependency is not installed.')
   269    def test_mark_pcollection_completed_after_successful_run(self):
   270      with self.cell:  # Cell 1
   271        p = beam.Pipeline(interactive_runner.InteractiveRunner())
   272        ib.watch({'p': p})
   273  
   274      with self.cell:  # Cell 2
   275        # pylint: disable=bad-option-value
   276        init = p | 'Init' >> beam.Create(range(5))
   277  
   278      with self.cell:  # Cell 3
   279        square = init | 'Square' >> beam.Map(lambda x: x * x)
   280        cube = init | 'Cube' >> beam.Map(lambda x: x**3)
   281  
   282      ib.watch(locals())
   283      result = p.run()
   284      self.assertTrue(init in ie.current_env().computed_pcollections)
   285      self.assertEqual({0, 1, 2, 3, 4}, set(result.get(init)))
   286      self.assertTrue(square in ie.current_env().computed_pcollections)
   287      self.assertEqual({0, 1, 4, 9, 16}, set(result.get(square)))
   288      self.assertTrue(cube in ie.current_env().computed_pcollections)
   289      self.assertEqual({0, 1, 8, 27, 64}, set(result.get(cube)))
   290  
   291    @unittest.skipIf(sys.platform == "win32", "[BEAM-10627]")
   292    def test_dataframes(self):
   293      p = beam.Pipeline(
   294          runner=interactive_runner.InteractiveRunner(
   295              direct_runner.DirectRunner()))
   296      data = p | beam.Create(
   297          [1, 2, 3]) | beam.Map(lambda x: beam.Row(square=x * x, cube=x * x * x))
   298      df = to_dataframe(data)
   299  
   300      # Watch the local scope for Interactive Beam so that values will be cached.
   301      ib.watch(locals())
   302  
   303      # This is normally done in the interactive_utils when a transform is
   304      # applied but needs an IPython environment. So we manually run this here.
   305      ie.current_env().track_user_pipelines()
   306  
   307      df_expected = pd.DataFrame({'square': [1, 4, 9], 'cube': [1, 8, 27]})
   308      pd.testing.assert_frame_equal(
   309          df_expected, ib.collect(df, n=10).reset_index(drop=True))
   310  
   311    @unittest.skipIf(sys.platform == "win32", "[BEAM-10627]")
   312    def test_dataframes_with_grouped_index(self):
   313      p = beam.Pipeline(
   314          runner=interactive_runner.InteractiveRunner(
   315              direct_runner.DirectRunner()))
   316  
   317      data = [
   318          Record('a', 20, 170),
   319          Record('a', 30, 170),
   320          Record('b', 22, 180),
   321          Record('c', 18, 150)
   322      ]
   323  
   324      aggregate = lambda df: df.groupby('height').mean()
   325  
   326      deferred_df = aggregate(to_dataframe(p | beam.Create(data)))
   327      df_expected = aggregate(pd.DataFrame(data))
   328  
   329      # Watch the local scope for Interactive Beam so that values will be cached.
   330      ib.watch(locals())
   331  
   332      # This is normally done in the interactive_utils when a transform is
   333      # applied but needs an IPython environment. So we manually run this here.
   334      ie.current_env().track_user_pipelines()
   335  
   336      pd.testing.assert_frame_equal(df_expected, ib.collect(deferred_df, n=10))
   337  
   338    @unittest.skipIf(sys.platform == "win32", "[BEAM-10627]")
   339    def test_dataframes_with_multi_index(self):
   340      p = beam.Pipeline(
   341          runner=interactive_runner.InteractiveRunner(
   342              direct_runner.DirectRunner()))
   343  
   344      data = [
   345          Record('a', 20, 170),
   346          Record('a', 30, 170),
   347          Record('b', 22, 180),
   348          Record('c', 18, 150)
   349      ]
   350  
   351      aggregate = lambda df: df.groupby(['name', 'height']).mean()
   352  
   353      deferred_df = aggregate(to_dataframe(p | beam.Create(data)))
   354      df_input = pd.DataFrame(data)
   355      df_input.name = df_input.name.astype(pd.StringDtype())
   356      df_expected = aggregate(df_input)
   357  
   358      # Watch the local scope for Interactive Beam so that values will be cached.
   359      ib.watch(locals())
   360  
   361      # This is normally done in the interactive_utils when a transform is
   362      # applied but needs an IPython environment. So we manually run this here.
   363      ie.current_env().track_user_pipelines()
   364  
   365      pd.testing.assert_frame_equal(df_expected, ib.collect(deferred_df, n=10))
   366  
   367    @unittest.skipIf(sys.platform == "win32", "[BEAM-10627]")
   368    def test_dataframes_with_multi_index_get_result(self):
   369      p = beam.Pipeline(
   370          runner=interactive_runner.InteractiveRunner(
   371              direct_runner.DirectRunner()))
   372  
   373      data = [
   374          Record('a', 20, 170),
   375          Record('a', 30, 170),
   376          Record('b', 22, 180),
   377          Record('c', 18, 150)
   378      ]
   379  
   380      aggregate = lambda df: df.groupby(['name', 'height']).mean()['age']
   381  
   382      deferred_df = aggregate(to_dataframe(p | beam.Create(data)))
   383      df_input = pd.DataFrame(data)
   384      df_input.name = df_input.name.astype(pd.StringDtype())
   385      df_expected = aggregate(df_input)
   386  
   387      # Watch the local scope for Interactive Beam so that values will be cached.
   388      ib.watch(locals())
   389  
   390      # This is normally done in the interactive_utils when a transform is
   391      # applied but needs an IPython environment. So we manually run this here.
   392      ie.current_env().track_user_pipelines()
   393  
   394      pd.testing.assert_series_equal(df_expected, ib.collect(deferred_df, n=10))
   395  
   396    @unittest.skipIf(sys.platform == "win32", "[BEAM-10627]")
   397    def test_dataframes_same_cell_twice(self):
   398      p = beam.Pipeline(
   399          runner=interactive_runner.InteractiveRunner(
   400              direct_runner.DirectRunner()))
   401      data = p | beam.Create(
   402          [1, 2, 3]) | beam.Map(lambda x: beam.Row(square=x * x, cube=x * x * x))
   403      df = to_dataframe(data)
   404  
   405      # Watch the local scope for Interactive Beam so that values will be cached.
   406      ib.watch(locals())
   407  
   408      # This is normally done in the interactive_utils when a transform is
   409      # applied but needs an IPython environment. So we manually run this here.
   410      ie.current_env().track_user_pipelines()
   411  
   412      df_expected = pd.DataFrame({'square': [1, 4, 9], 'cube': [1, 8, 27]})
   413      pd.testing.assert_series_equal(
   414          df_expected['square'],
   415          ib.collect(df['square'], n=10).reset_index(drop=True))
   416      pd.testing.assert_series_equal(
   417          df_expected['cube'],
   418          ib.collect(df['cube'], n=10).reset_index(drop=True))
   419  
   420    @unittest.skipIf(
   421        not ie.current_env().is_interactive_ready,
   422        '[interactive] dependency is not installed.')
   423    @unittest.skipIf(sys.platform == "win32", "[BEAM-10627]")
   424    def test_dataframe_caching(self):
   425      # Create a pipeline that exercises the DataFrame API. This will also use
   426      # caching in the background.
   427      with self.cell:  # Cell 1
   428        p = beam.Pipeline(interactive_runner.InteractiveRunner())
   429        ib.watch({'p': p})
   430  
   431      with self.cell:  # Cell 2
   432        data = p | beam.Create([
   433            1, 2, 3
   434        ]) | beam.Map(lambda x: beam.Row(square=x * x, cube=x * x * x))
   435  
   436        with beam.dataframe.allow_non_parallel_operations():
   437          df = to_dataframe(data).reset_index(drop=True)
   438  
   439        ib.collect(df)
   440  
   441      with self.cell:  # Cell 3
   442        df['output'] = df['square'] * df['cube']
   443        ib.collect(df)
   444  
   445      with self.cell:  # Cell 4
   446        df['output'] = 0
   447        ib.collect(df)
   448  
   449      # We use a trace through the graph to perform an isomorphism test. The end
   450      # output should look like a linear graph. This indicates that the dataframe
   451      # transform was correctly broken into separate pieces to cache. If caching
   452      # isn't enabled, all the dataframe computation nodes are connected to a
   453      # single shared node.
   454      trace = []
   455  
   456      # Only look at the top-level transforms for the isomorphism. The test
   457      # doesn't care about the transform implementations, just the overall shape.
   458      class TopLevelTracer(beam.pipeline.PipelineVisitor):
   459        def _find_root_producer(self, node: beam.pipeline.AppliedPTransform):
   460          if node is None or not node.full_label:
   461            return None
   462  
   463          parent = self._find_root_producer(node.parent)
   464          if parent is None:
   465            return node
   466  
   467          return parent
   468  
   469        def _add_to_trace(self, node, trace):
   470          if '/' not in str(node):
   471            if node.inputs:
   472              producer = self._find_root_producer(node.inputs[0].producer)
   473              producer_name = producer.full_label if producer else ''
   474              trace.append((producer_name, node.full_label))
   475  
   476        def visit_transform(self, node: beam.pipeline.AppliedPTransform):
   477          self._add_to_trace(node, trace)
   478  
   479        def enter_composite_transform(
   480            self, node: beam.pipeline.AppliedPTransform):
   481          self._add_to_trace(node, trace)
   482  
   483      p.visit(TopLevelTracer())
   484  
   485      # Do the isomorphism test which states that the topological sort of the
   486      # graph yields a linear graph.
   487      trace_string = '\n'.join(str(t) for t in trace)
   488      prev_producer = ''
   489      for producer, consumer in trace:
   490        self.assertEqual(producer, prev_producer, trace_string)
   491        prev_producer = consumer
   492  
   493  
   494  @unittest.skipIf(
   495      not ie.current_env().is_interactive_ready,
   496      '[interactive] dependency is not installed.')
   497  @isolated_env
   498  class ConfigForFlinkTest(unittest.TestCase):
   499    def setUp(self):
   500      self.current_env.options.cache_root = 'gs://fake'
   501  
   502    def tearDown(self):
   503      self.current_env.options.cache_root = None
   504  
   505    def test_create_a_new_cluster_for_a_new_pipeline(self):
   506      clusters = self.current_env.clusters
   507      runner = interactive_runner.InteractiveRunner(
   508          underlying_runner=FlinkRunner())
   509      options = PipelineOptions(project='test-project', region='test-region')
   510      p = beam.Pipeline(runner=runner, options=options)
   511      runner.configure_for_flink(p, options)
   512  
   513      # Fetch the metadata and assert all side effects.
   514      meta = clusters.cluster_metadata(p)
   515      # The metadata should have all fields populated.
   516      self.assertEqual(meta.project_id, 'test-project')
   517      self.assertEqual(meta.region, 'test-region')
   518      self.assertTrue(meta.cluster_name.startswith('interactive-beam-'))
   519      self.assertTrue(meta.master_url.startswith('test-url'))
   520      self.assertEqual(meta.dashboard, 'test-dashboard')
   521      # The cluster is known now.
   522      self.assertIn(meta, clusters.dataproc_cluster_managers)
   523      self.assertIn(meta.master_url, clusters.master_urls)
   524      self.assertIn(p, clusters.pipelines)
   525      # The default cluster is updated to the created cluster.
   526      self.assertIs(meta, clusters.default_cluster_metadata)
   527      # The pipeline options is tuned for execution on the cluster.
   528      flink_options = options.view_as(FlinkRunnerOptions)
   529      self.assertEqual(flink_options.flink_master, meta.master_url)
   530      self.assertEqual(
   531          flink_options.flink_version, clusters.DATAPROC_FLINK_VERSION)
   532  
   533    def test_reuse_a_cluster_for_a_known_pipeline(self):
   534      clusters = self.current_env.clusters
   535      runner = interactive_runner.InteractiveRunner(
   536          underlying_runner=FlinkRunner())
   537      options = PipelineOptions(project='test-project', region='test-region')
   538      p = beam.Pipeline(runner=runner, options=options)
   539      meta = ClusterMetadata(project_id='test-project', region='test-region')
   540      dcm = DataprocClusterManager(meta)
   541      # Configure the clusters so that the pipeline is known.
   542      clusters.pipelines[p] = dcm
   543      runner.configure_for_flink(p, options)
   544  
   545      # A known cluster is reused.
   546      tuned_meta = clusters.cluster_metadata(p)
   547      self.assertIs(tuned_meta, meta)
   548  
   549    def test_reuse_a_known_cluster_for_unknown_pipeline(self):
   550      clusters = self.current_env.clusters
   551      runner = interactive_runner.InteractiveRunner(
   552          underlying_runner=FlinkRunner())
   553      options = PipelineOptions(project='test-project', region='test-region')
   554      p = beam.Pipeline(runner=runner, options=options)
   555      meta = ClusterMetadata(project_id='test-project', region='test-region')
   556      dcm = DataprocClusterManager(meta)
   557      # Configure the clusters so that the cluster is known.
   558      clusters.dataproc_cluster_managers[meta] = dcm
   559      clusters.set_default_cluster(meta)
   560      runner.configure_for_flink(p, options)
   561  
   562      # A known cluster is reused.
   563      tuned_meta = clusters.cluster_metadata(p)
   564      self.assertIs(tuned_meta, meta)
   565      # The pipeline is known.
   566      self.assertIn(p, clusters.pipelines)
   567      registered_dcm = clusters.pipelines[p]
   568      self.assertIn(p, registered_dcm.pipelines)
   569  
   570    def test_reuse_default_cluster_if_not_configured(self):
   571      clusters = self.current_env.clusters
   572      runner = interactive_runner.InteractiveRunner(
   573          underlying_runner=FlinkRunner())
   574      options = PipelineOptions()
   575      # Pipeline is not configured to run on Cloud.
   576      p = beam.Pipeline(runner=runner, options=options)
   577      meta = ClusterMetadata(project_id='test-project', region='test-region')
   578      meta.master_url = 'test-url'
   579      meta.dashboard = 'test-dashboard'
   580      dcm = DataprocClusterManager(meta)
   581      # Configure the clusters so that a default cluster is known.
   582      clusters.dataproc_cluster_managers[meta] = dcm
   583      clusters.set_default_cluster(meta)
   584      runner.configure_for_flink(p, options)
   585  
   586      # The default cluster is used.
   587      tuned_meta = clusters.cluster_metadata(p)
   588      self.assertIs(tuned_meta, clusters.default_cluster_metadata)
   589      # The pipeline is known.
   590      self.assertIn(p, clusters.pipelines)
   591      registered_dcm = clusters.pipelines[p]
   592      self.assertIn(p, registered_dcm.pipelines)
   593      # The pipeline options is tuned for execution on the cluster.
   594      flink_options = options.view_as(FlinkRunnerOptions)
   595      self.assertEqual(flink_options.flink_master, tuned_meta.master_url)
   596      self.assertEqual(
   597          flink_options.flink_version, clusters.DATAPROC_FLINK_VERSION)
   598  
   599    def test_worker_options_to_cluster_metadata(self):
   600      clusters = self.current_env.clusters
   601      runner = interactive_runner.InteractiveRunner(
   602          underlying_runner=FlinkRunner())
   603      options = PipelineOptions(project='test-project', region='test-region')
   604      worker_options = options.view_as(WorkerOptions)
   605      worker_options.num_workers = 2
   606      worker_options.subnetwork = 'test-network'
   607      worker_options.machine_type = 'test-machine-type'
   608      p = beam.Pipeline(runner=runner, options=options)
   609      runner.configure_for_flink(p, options)
   610  
   611      configured_meta = clusters.cluster_metadata(p)
   612      self.assertEqual(configured_meta.num_workers, worker_options.num_workers)
   613      self.assertEqual(configured_meta.subnetwork, worker_options.subnetwork)
   614      self.assertEqual(configured_meta.machine_type, worker_options.machine_type)
   615  
   616    def test_configure_flink_options(self):
   617      clusters = self.current_env.clusters
   618      runner = interactive_runner.InteractiveRunner(
   619          underlying_runner=FlinkRunner())
   620      options = PipelineOptions(project='test-project', region='test-region')
   621      p = beam.Pipeline(runner=runner, options=options)
   622      runner.configure_for_flink(p, options)
   623  
   624      flink_options = options.view_as(FlinkRunnerOptions)
   625      self.assertEqual(
   626          flink_options.flink_version, clusters.DATAPROC_FLINK_VERSION)
   627      self.assertTrue(flink_options.flink_master.startswith('test-url-'))
   628  
   629    def test_configure_flink_options_with_flink_version_overridden(self):
   630      clusters = self.current_env.clusters
   631      runner = interactive_runner.InteractiveRunner(
   632          underlying_runner=FlinkRunner())
   633      options = PipelineOptions(project='test-project', region='test-region')
   634      flink_options = options.view_as(FlinkRunnerOptions)
   635      flink_options.flink_version = 'test-version'
   636      p = beam.Pipeline(runner=runner, options=options)
   637      runner.configure_for_flink(p, options)
   638  
   639      # The version is overridden to the flink version used by the EMR solution,
   640      # currently only 1: Cloud Dataproc.
   641      self.assertEqual(
   642          flink_options.flink_version, clusters.DATAPROC_FLINK_VERSION)
   643  
   644    def test_strip_http_protocol_from_flink_master(self):
   645      runner = interactive_runner.InteractiveRunner(
   646          underlying_runner=FlinkRunner())
   647      stripped = runner._strip_protocol_if_any('https://flink-master')
   648  
   649      self.assertEqual('flink-master', stripped)
   650  
   651    def test_no_strip_from_flink_master(self):
   652      runner = interactive_runner.InteractiveRunner(
   653          underlying_runner=FlinkRunner())
   654      stripped = runner._strip_protocol_if_any('flink-master')
   655  
   656      self.assertEqual('flink-master', stripped)
   657  
   658    def test_no_strip_from_non_flink_master(self):
   659      runner = interactive_runner.InteractiveRunner(
   660          underlying_runner=FlinkRunner())
   661      stripped = runner._strip_protocol_if_any(None)
   662  
   663      self.assertIsNone(stripped)
   664  
   665  
   666  if __name__ == '__main__':
   667    unittest.main()