github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/pipeline_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Unit tests for the Pipeline class."""
    19  
    20  # pytype: skip-file
    21  
    22  import copy
    23  import platform
    24  import unittest
    25  
    26  import mock
    27  import pytest
    28  
    29  import apache_beam as beam
    30  from apache_beam import typehints
    31  from apache_beam.coders import BytesCoder
    32  from apache_beam.io import Read
    33  from apache_beam.metrics import Metrics
    34  from apache_beam.options.pipeline_options import PortableOptions
    35  from apache_beam.pipeline import Pipeline
    36  from apache_beam.pipeline import PipelineOptions
    37  from apache_beam.pipeline import PipelineVisitor
    38  from apache_beam.pipeline import PTransformOverride
    39  from apache_beam.portability import common_urns
    40  from apache_beam.portability.api import beam_runner_api_pb2
    41  from apache_beam.pvalue import AsSingleton
    42  from apache_beam.pvalue import TaggedOutput
    43  from apache_beam.runners.dataflow.native_io.iobase import NativeSource
    44  from apache_beam.testing.test_pipeline import TestPipeline
    45  from apache_beam.testing.util import assert_that
    46  from apache_beam.testing.util import equal_to
    47  from apache_beam.transforms import CombineGlobally
    48  from apache_beam.transforms import Create
    49  from apache_beam.transforms import DoFn
    50  from apache_beam.transforms import FlatMap
    51  from apache_beam.transforms import Map
    52  from apache_beam.transforms import ParDo
    53  from apache_beam.transforms import PTransform
    54  from apache_beam.transforms import WindowInto
    55  from apache_beam.transforms.display import DisplayDataItem
    56  from apache_beam.transforms.environments import ProcessEnvironment
    57  from apache_beam.transforms.resources import ResourceHint
    58  from apache_beam.transforms.userstate import BagStateSpec
    59  from apache_beam.transforms.window import SlidingWindows
    60  from apache_beam.transforms.window import TimestampedValue
    61  from apache_beam.utils import windowed_value
    62  from apache_beam.utils.timestamp import MIN_TIMESTAMP
    63  
    64  # TODO(BEAM-1555): Test is failing on the service, with FakeSource.
    65  
    66  
    67  class FakeSource(NativeSource):
    68    """Fake source returning a fixed list of values."""
    69    class _Reader(object):
    70      def __init__(self, vals):
    71        self._vals = vals
    72        self._output_counter = Metrics.counter('main', 'outputs')
    73  
    74      def __enter__(self):
    75        return self
    76  
    77      def __exit__(self, exception_type, exception_value, traceback):
    78        pass
    79  
    80      def __iter__(self):
    81        for v in self._vals:
    82          self._output_counter.inc()
    83          yield v
    84  
    85    def __init__(self, vals):
    86      self._vals = vals
    87  
    88    def reader(self):
    89      return FakeSource._Reader(self._vals)
    90  
    91  
    92  class FakeUnboundedSource(NativeSource):
    93    """Fake unbounded source. Does not work at runtime"""
    94    def reader(self):
    95      return None
    96  
    97    def is_bounded(self):
    98      return False
    99  
   100  
   101  class DoubleParDo(beam.PTransform):
   102    def expand(self, input):
   103      return input | 'Inner' >> beam.Map(lambda a: a * 2)
   104  
   105    def to_runner_api_parameter(self, context):
   106      return self.to_runner_api_pickled(context)
   107  
   108  
   109  class TripleParDo(beam.PTransform):
   110    def expand(self, input):
   111      # Keeping labels the same intentionally to make sure that there is no label
   112      # conflict due to replacement.
   113      return input | 'Inner' >> beam.Map(lambda a: a * 3)
   114  
   115  
   116  class ToStringParDo(beam.PTransform):
   117    def expand(self, input):
   118      # We use copy.copy() here to make sure the typehint mechanism doesn't
   119      # automatically infer that the output type is str.
   120      return input | 'Inner' >> beam.Map(lambda a: copy.copy(str(a)))
   121  
   122  
   123  class FlattenAndDouble(beam.PTransform):
   124    def expand(self, pcolls):
   125      return pcolls | beam.Flatten() | 'Double' >> DoubleParDo()
   126  
   127  
   128  class FlattenAndTriple(beam.PTransform):
   129    def expand(self, pcolls):
   130      return pcolls | beam.Flatten() | 'Triple' >> TripleParDo()
   131  
   132  
   133  class AddWithProductDoFn(beam.DoFn):
   134    def process(self, input, a, b):
   135      yield input + a * b
   136  
   137  
   138  class AddThenMultiplyDoFn(beam.DoFn):
   139    def process(self, input, a, b):
   140      yield (input + a) * b
   141  
   142  
   143  class AddThenMultiply(beam.PTransform):
   144    def expand(self, pvalues):
   145      return pvalues[0] | beam.ParDo(
   146          AddThenMultiplyDoFn(), AsSingleton(pvalues[1]), AsSingleton(pvalues[2]))
   147  
   148  
   149  class PipelineTest(unittest.TestCase):
   150    @staticmethod
   151    def custom_callable(pcoll):
   152      return pcoll | '+1' >> FlatMap(lambda x: [x + 1])
   153  
   154    # Some of these tests designate a runner by name, others supply a runner.
   155    # This variation is just to verify that both means of runner specification
   156    # work and is not related to other aspects of the tests.
   157  
   158    class CustomTransform(PTransform):
   159      def expand(self, pcoll):
   160        return pcoll | '+1' >> FlatMap(lambda x: [x + 1])
   161  
   162    class Visitor(PipelineVisitor):
   163      def __init__(self, visited):
   164        self.visited = visited
   165        self.enter_composite = []
   166        self.leave_composite = []
   167  
   168      def visit_value(self, value, _):
   169        self.visited.append(value)
   170  
   171      def enter_composite_transform(self, transform_node):
   172        self.enter_composite.append(transform_node)
   173  
   174      def leave_composite_transform(self, transform_node):
   175        self.leave_composite.append(transform_node)
   176  
   177    def test_create(self):
   178      with TestPipeline() as pipeline:
   179        pcoll = pipeline | 'label1' >> Create([1, 2, 3])
   180        assert_that(pcoll, equal_to([1, 2, 3]))
   181  
   182        # Test if initial value is an iterator object.
   183        pcoll2 = pipeline | 'label2' >> Create(iter((4, 5, 6)))
   184        pcoll3 = pcoll2 | 'do' >> FlatMap(lambda x: [x + 10])
   185        assert_that(pcoll3, equal_to([14, 15, 16]), label='pcoll3')
   186  
   187    def test_flatmap_builtin(self):
   188      with TestPipeline() as pipeline:
   189        pcoll = pipeline | 'label1' >> Create([1, 2, 3])
   190        assert_that(pcoll, equal_to([1, 2, 3]))
   191  
   192        pcoll2 = pcoll | 'do' >> FlatMap(lambda x: [x + 10])
   193        assert_that(pcoll2, equal_to([11, 12, 13]), label='pcoll2')
   194  
   195        pcoll3 = pcoll2 | 'm1' >> Map(lambda x: [x, 12])
   196        assert_that(
   197            pcoll3, equal_to([[11, 12], [12, 12], [13, 12]]), label='pcoll3')
   198  
   199        pcoll4 = pcoll3 | 'do2' >> FlatMap(set)
   200        assert_that(pcoll4, equal_to([11, 12, 12, 12, 13]), label='pcoll4')
   201  
   202    def test_maptuple_builtin(self):
   203      with TestPipeline() as pipeline:
   204        pcoll = pipeline | Create([('e1', 'e2')])
   205        side1 = beam.pvalue.AsSingleton(pipeline | 'side1' >> Create(['s1']))
   206        side2 = beam.pvalue.AsSingleton(pipeline | 'side2' >> Create(['s2']))
   207  
   208        # A test function with a tuple input, an auxiliary parameter,
   209        # and some side inputs.
   210        fn = lambda e1, e2, t=DoFn.TimestampParam, s1=None, s2=None: (
   211            e1, e2, t, s1, s2)
   212        assert_that(
   213            pcoll | 'NoSides' >> beam.core.MapTuple(fn),
   214            equal_to([('e1', 'e2', MIN_TIMESTAMP, None, None)]),
   215            label='NoSidesCheck')
   216        assert_that(
   217            pcoll | 'StaticSides' >> beam.core.MapTuple(fn, 's1', 's2'),
   218            equal_to([('e1', 'e2', MIN_TIMESTAMP, 's1', 's2')]),
   219            label='StaticSidesCheck')
   220        assert_that(
   221            pcoll | 'DynamicSides' >> beam.core.MapTuple(fn, side1, side2),
   222            equal_to([('e1', 'e2', MIN_TIMESTAMP, 's1', 's2')]),
   223            label='DynamicSidesCheck')
   224        assert_that(
   225            pcoll | 'MixedSides' >> beam.core.MapTuple(fn, s2=side2),
   226            equal_to([('e1', 'e2', MIN_TIMESTAMP, None, 's2')]),
   227            label='MixedSidesCheck')
   228  
   229    def test_flatmaptuple_builtin(self):
   230      with TestPipeline() as pipeline:
   231        pcoll = pipeline | Create([('e1', 'e2')])
   232        side1 = beam.pvalue.AsSingleton(pipeline | 'side1' >> Create(['s1']))
   233        side2 = beam.pvalue.AsSingleton(pipeline | 'side2' >> Create(['s2']))
   234  
   235        # A test function with a tuple input, an auxiliary parameter,
   236        # and some side inputs.
   237        fn = lambda e1, e2, t=DoFn.TimestampParam, s1=None, s2=None: (
   238            e1, e2, t, s1, s2)
   239        assert_that(
   240            pcoll | 'NoSides' >> beam.core.FlatMapTuple(fn),
   241            equal_to(['e1', 'e2', MIN_TIMESTAMP, None, None]),
   242            label='NoSidesCheck')
   243        assert_that(
   244            pcoll | 'StaticSides' >> beam.core.FlatMapTuple(fn, 's1', 's2'),
   245            equal_to(['e1', 'e2', MIN_TIMESTAMP, 's1', 's2']),
   246            label='StaticSidesCheck')
   247        assert_that(
   248            pcoll
   249            | 'DynamicSides' >> beam.core.FlatMapTuple(fn, side1, side2),
   250            equal_to(['e1', 'e2', MIN_TIMESTAMP, 's1', 's2']),
   251            label='DynamicSidesCheck')
   252        assert_that(
   253            pcoll | 'MixedSides' >> beam.core.FlatMapTuple(fn, s2=side2),
   254            equal_to(['e1', 'e2', MIN_TIMESTAMP, None, 's2']),
   255            label='MixedSidesCheck')
   256  
   257    def test_create_singleton_pcollection(self):
   258      with TestPipeline() as pipeline:
   259        pcoll = pipeline | 'label' >> Create([[1, 2, 3]])
   260        assert_that(pcoll, equal_to([[1, 2, 3]]))
   261  
   262    # TODO(BEAM-1555): Test is failing on the service, with FakeSource.
   263    # @pytest.mark.it_validatesrunner
   264    def test_metrics_in_fake_source(self):
   265      pipeline = TestPipeline()
   266      pcoll = pipeline | Read(FakeSource([1, 2, 3, 4, 5, 6]))
   267      assert_that(pcoll, equal_to([1, 2, 3, 4, 5, 6]))
   268      res = pipeline.run()
   269      metric_results = res.metrics().query()
   270      outputs_counter = metric_results['counters'][0]
   271      self.assertEqual(outputs_counter.key.step, 'Read')
   272      self.assertEqual(outputs_counter.key.metric.name, 'outputs')
   273      self.assertEqual(outputs_counter.committed, 6)
   274  
   275    def test_fake_read(self):
   276      with TestPipeline() as pipeline:
   277        pcoll = pipeline | 'read' >> Read(FakeSource([1, 2, 3]))
   278        assert_that(pcoll, equal_to([1, 2, 3]))
   279  
   280    def test_visit_entire_graph(self):
   281      pipeline = Pipeline()
   282      pcoll1 = pipeline | 'pcoll' >> beam.Impulse()
   283      pcoll2 = pcoll1 | 'do1' >> FlatMap(lambda x: [x + 1])
   284      pcoll3 = pcoll2 | 'do2' >> FlatMap(lambda x: [x + 1])
   285      pcoll4 = pcoll2 | 'do3' >> FlatMap(lambda x: [x + 1])
   286      transform = PipelineTest.CustomTransform()
   287      pcoll5 = pcoll4 | transform
   288  
   289      visitor = PipelineTest.Visitor(visited=[])
   290      pipeline.visit(visitor)
   291      self.assertEqual({pcoll1, pcoll2, pcoll3, pcoll4, pcoll5},
   292                       set(visitor.visited))
   293      self.assertEqual(set(visitor.enter_composite), set(visitor.leave_composite))
   294      self.assertEqual(2, len(visitor.enter_composite))
   295      self.assertEqual(visitor.enter_composite[1].transform, transform)
   296      self.assertEqual(visitor.leave_composite[0].transform, transform)
   297  
   298    def test_apply_custom_transform(self):
   299      with TestPipeline() as pipeline:
   300        pcoll = pipeline | 'pcoll' >> Create([1, 2, 3])
   301        result = pcoll | PipelineTest.CustomTransform()
   302        assert_that(result, equal_to([2, 3, 4]))
   303  
   304    def test_reuse_custom_transform_instance(self):
   305      pipeline = Pipeline()
   306      pcoll1 = pipeline | 'pcoll1' >> Create([1, 2, 3])
   307      pcoll2 = pipeline | 'pcoll2' >> Create([4, 5, 6])
   308      transform = PipelineTest.CustomTransform()
   309      pcoll1 | transform
   310      with self.assertRaises(RuntimeError) as cm:
   311        pipeline.apply(transform, pcoll2)
   312      self.assertEqual(
   313          cm.exception.args[0],
   314          'A transform with label "CustomTransform" already exists in the '
   315          'pipeline. To apply a transform with a specified label write '
   316          'pvalue | "label" >> transform')
   317  
   318    def test_reuse_cloned_custom_transform_instance(self):
   319      with TestPipeline() as pipeline:
   320        pcoll1 = pipeline | 'pc1' >> Create([1, 2, 3])
   321        pcoll2 = pipeline | 'pc2' >> Create([4, 5, 6])
   322        transform = PipelineTest.CustomTransform()
   323        result1 = pcoll1 | transform
   324        result2 = pcoll2 | 'new_label' >> transform
   325        assert_that(result1, equal_to([2, 3, 4]), label='r1')
   326        assert_that(result2, equal_to([5, 6, 7]), label='r2')
   327  
   328    def test_transform_no_super_init(self):
   329      class AddSuffix(PTransform):
   330        def __init__(self, suffix):
   331          # No call to super(...).__init__
   332          self.suffix = suffix
   333  
   334        def expand(self, pcoll):
   335          return pcoll | Map(lambda x: x + self.suffix)
   336  
   337      self.assertEqual(['a-x', 'b-x', 'c-x'],
   338                       sorted(['a', 'b', 'c'] | 'AddSuffix' >> AddSuffix('-x')))
   339  
   340    @unittest.skip("Fails on some platforms with new urllib3.")
   341    def test_memory_usage(self):
   342      try:
   343        import resource
   344      except ImportError:
   345        # Skip the test if resource module is not available (e.g. non-Unix os).
   346        self.skipTest('resource module not available.')
   347      if platform.mac_ver()[0]:
   348        # Skip the test on macos, depending on version it returns ru_maxrss in
   349        # different units.
   350        self.skipTest('ru_maxrss is not in standard units.')
   351  
   352      def get_memory_usage_in_bytes():
   353        return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss * (2**10)
   354  
   355      def check_memory(value, memory_threshold):
   356        memory_usage = get_memory_usage_in_bytes()
   357        if memory_usage > memory_threshold:
   358          raise RuntimeError(
   359              'High memory usage: %d > %d' % (memory_usage, memory_threshold))
   360        return value
   361  
   362      len_elements = 1000000
   363      num_elements = 10
   364      num_maps = 100
   365  
   366      # TODO(robertwb): reduce memory usage of FnApiRunner so that this test
   367      # passes.
   368      with TestPipeline(runner='BundleBasedDirectRunner') as pipeline:
   369  
   370        # Consumed memory should not be proportional to the number of maps.
   371        memory_threshold = (
   372            get_memory_usage_in_bytes() + (5 * len_elements * num_elements))
   373  
   374        # Plus small additional slack for memory fluctuations during the test.
   375        memory_threshold += 10 * (2**20)
   376  
   377        biglist = pipeline | 'oom:create' >> Create(
   378            ['x' * len_elements] * num_elements)
   379        for i in range(num_maps):
   380          biglist = biglist | ('oom:addone-%d' % i) >> Map(lambda x: x + 'y')
   381        result = biglist | 'oom:check' >> Map(check_memory, memory_threshold)
   382        assert_that(
   383            result,
   384            equal_to(['x' * len_elements + 'y' * num_maps] * num_elements))
   385  
   386    def test_aggregator_empty_input(self):
   387      actual = [] | CombineGlobally(max).without_defaults()
   388      self.assertEqual(actual, [])
   389  
   390    def test_pipeline_as_context(self):
   391      def raise_exception(exn):
   392        raise exn
   393  
   394      with self.assertRaises(ValueError):
   395        with Pipeline() as p:
   396          # pylint: disable=expression-not-assigned
   397          p | Create([ValueError('msg')]) | Map(raise_exception)
   398  
   399    def test_ptransform_overrides(self):
   400      class MyParDoOverride(PTransformOverride):
   401        def matches(self, applied_ptransform):
   402          return isinstance(applied_ptransform.transform, DoubleParDo)
   403  
   404        def get_replacement_transform_for_applied_ptransform(
   405            self, applied_ptransform):
   406          ptransform = applied_ptransform.transform
   407          if isinstance(ptransform, DoubleParDo):
   408            return TripleParDo()
   409          raise ValueError('Unsupported type of transform: %r' % ptransform)
   410  
   411      p = Pipeline()
   412      pcoll = p | beam.Create([1, 2, 3]) | 'Multiply' >> DoubleParDo()
   413      assert_that(pcoll, equal_to([3, 6, 9]))
   414  
   415      p.replace_all([MyParDoOverride()])
   416      p.run()
   417  
   418    def test_ptransform_override_type_hints(self):
   419      class NoTypeHintOverride(PTransformOverride):
   420        def matches(self, applied_ptransform):
   421          return isinstance(applied_ptransform.transform, DoubleParDo)
   422  
   423        def get_replacement_transform_for_applied_ptransform(
   424            self, applied_ptransform):
   425          return ToStringParDo()
   426  
   427      class WithTypeHintOverride(PTransformOverride):
   428        def matches(self, applied_ptransform):
   429          return isinstance(applied_ptransform.transform, DoubleParDo)
   430  
   431        def get_replacement_transform_for_applied_ptransform(
   432            self, applied_ptransform):
   433          return ToStringParDo().with_input_types(int).with_output_types(str)
   434  
   435      for override, expected_type in [(NoTypeHintOverride(), int),
   436                                      (WithTypeHintOverride(), str)]:
   437        p = TestPipeline()
   438        pcoll = (
   439            p
   440            | beam.Create([1, 2, 3])
   441            | 'Operate' >> DoubleParDo()
   442            | 'NoOp' >> beam.Map(lambda x: x))
   443  
   444        p.replace_all([override])
   445        self.assertEqual(pcoll.producer.inputs[0].element_type, expected_type)
   446  
   447    def test_ptransform_override_multiple_inputs(self):
   448      class MyParDoOverride(PTransformOverride):
   449        def matches(self, applied_ptransform):
   450          return isinstance(applied_ptransform.transform, FlattenAndDouble)
   451  
   452        def get_replacement_transform(self, applied_ptransform):
   453          return FlattenAndTriple()
   454  
   455      p = Pipeline()
   456      pcoll1 = p | 'pc1' >> beam.Create([1, 2, 3])
   457      pcoll2 = p | 'pc2' >> beam.Create([4, 5, 6])
   458      pcoll3 = (pcoll1, pcoll2) | 'FlattenAndMultiply' >> FlattenAndDouble()
   459      assert_that(pcoll3, equal_to([3, 6, 9, 12, 15, 18]))
   460  
   461      p.replace_all([MyParDoOverride()])
   462      p.run()
   463  
   464    def test_ptransform_override_side_inputs(self):
   465      class MyParDoOverride(PTransformOverride):
   466        def matches(self, applied_ptransform):
   467          return (
   468              isinstance(applied_ptransform.transform, ParDo) and
   469              isinstance(applied_ptransform.transform.fn, AddWithProductDoFn))
   470  
   471        def get_replacement_transform(self, transform):
   472          return AddThenMultiply()
   473  
   474      p = Pipeline()
   475      pcoll1 = p | 'pc1' >> beam.Create([2])
   476      pcoll2 = p | 'pc2' >> beam.Create([3])
   477      pcoll3 = p | 'pc3' >> beam.Create([4, 5, 6])
   478      result = pcoll3 | 'Operate' >> beam.ParDo(
   479          AddWithProductDoFn(), AsSingleton(pcoll1), AsSingleton(pcoll2))
   480      assert_that(result, equal_to([18, 21, 24]))
   481  
   482      p.replace_all([MyParDoOverride()])
   483      p.run()
   484  
   485    def test_ptransform_override_replacement_inputs(self):
   486      class MyParDoOverride(PTransformOverride):
   487        def matches(self, applied_ptransform):
   488          return (
   489              isinstance(applied_ptransform.transform, ParDo) and
   490              isinstance(applied_ptransform.transform.fn, AddWithProductDoFn))
   491  
   492        def get_replacement_transform(self, transform):
   493          return AddThenMultiply()
   494  
   495        def get_replacement_inputs(self, applied_ptransform):
   496          assert len(applied_ptransform.inputs) == 1
   497          assert len(applied_ptransform.side_inputs) == 2
   498          # Swap the order of the two side inputs
   499          return (
   500              applied_ptransform.inputs[0],
   501              applied_ptransform.side_inputs[1].pvalue,
   502              applied_ptransform.side_inputs[0].pvalue)
   503  
   504      p = Pipeline()
   505      pcoll1 = p | 'pc1' >> beam.Create([2])
   506      pcoll2 = p | 'pc2' >> beam.Create([3])
   507      pcoll3 = p | 'pc3' >> beam.Create([4, 5, 6])
   508      result = pcoll3 | 'Operate' >> beam.ParDo(
   509          AddWithProductDoFn(), AsSingleton(pcoll1), AsSingleton(pcoll2))
   510      assert_that(result, equal_to([14, 16, 18]))
   511  
   512      p.replace_all([MyParDoOverride()])
   513      p.run()
   514  
   515    def test_ptransform_override_multiple_outputs(self):
   516      class MultiOutputComposite(PTransform):
   517        def __init__(self):
   518          self.output_tags = set()
   519  
   520        def expand(self, pcoll):
   521          def mux_input(x):
   522            x = x * 2
   523            if isinstance(x, int):
   524              yield TaggedOutput('numbers', x)
   525            else:
   526              yield TaggedOutput('letters', x)
   527  
   528          multi = pcoll | 'MyReplacement' >> beam.ParDo(mux_input).with_outputs()
   529          letters = multi.letters | 'LettersComposite' >> beam.Map(
   530              lambda x: x * 3)
   531          numbers = multi.numbers | 'NumbersComposite' >> beam.Map(
   532              lambda x: x * 5)
   533  
   534          return {
   535              'letters': letters,
   536              'numbers': numbers,
   537          }
   538  
   539      class MultiOutputOverride(PTransformOverride):
   540        def matches(self, applied_ptransform):
   541          return applied_ptransform.full_label == 'MyMultiOutput'
   542  
   543        def get_replacement_transform_for_applied_ptransform(
   544            self, applied_ptransform):
   545          return MultiOutputComposite()
   546  
   547      def mux_input(x):
   548        if isinstance(x, int):
   549          yield TaggedOutput('numbers', x)
   550        else:
   551          yield TaggedOutput('letters', x)
   552  
   553      with TestPipeline() as p:
   554        multi = (
   555            p
   556            | beam.Create([1, 2, 3, 'a', 'b', 'c'])
   557            | 'MyMultiOutput' >> beam.ParDo(mux_input).with_outputs())
   558        letters = multi.letters | 'MyLetters' >> beam.Map(lambda x: x)
   559        numbers = multi.numbers | 'MyNumbers' >> beam.Map(lambda x: x)
   560  
   561        # Assert that the PCollection replacement worked correctly and that
   562        # elements are flowing through. The replacement transform first
   563        # multiples by 2 then the leaf nodes inside the composite multiply by
   564        # an additional 3 and 5. Use prime numbers to ensure that each
   565        # transform is getting executed once.
   566        assert_that(
   567            letters,
   568            equal_to(['a' * 2 * 3, 'b' * 2 * 3, 'c' * 2 * 3]),
   569            label='assert letters')
   570        assert_that(
   571            numbers,
   572            equal_to([1 * 2 * 5, 2 * 2 * 5, 3 * 2 * 5]),
   573            label='assert numbers')
   574  
   575        # Do the replacement and run the element assertions.
   576        p.replace_all([MultiOutputOverride()])
   577  
   578      # The following checks the graph to make sure the replacement occurred.
   579      visitor = PipelineTest.Visitor(visited=[])
   580      p.visit(visitor)
   581      pcollections = visitor.visited
   582      composites = visitor.enter_composite
   583  
   584      # Assert the replacement is in the composite list and retrieve the
   585      # AppliedPTransform.
   586      self.assertIn(
   587          MultiOutputComposite, [t.transform.__class__ for t in composites])
   588      multi_output_composite = list(
   589          filter(
   590              lambda t: t.transform.__class__ == MultiOutputComposite,
   591              composites))[0]
   592  
   593      # Assert that all of the replacement PCollections are in the graph.
   594      for output in multi_output_composite.outputs.values():
   595        self.assertIn(output, pcollections)
   596  
   597      # Assert that all of the "old"/replaced PCollections are not in the graph.
   598      self.assertNotIn(multi[None], visitor.visited)
   599      self.assertNotIn(multi.letters, visitor.visited)
   600      self.assertNotIn(multi.numbers, visitor.visited)
   601  
   602    def test_kv_ptransform_honor_type_hints(self):
   603  
   604      # The return type of this DoFn cannot be inferred by the default
   605      # Beam type inference
   606      class StatefulDoFn(DoFn):
   607        BYTES_STATE = BagStateSpec('bytes', BytesCoder())
   608  
   609        def return_recursive(self, count):
   610          if count == 0:
   611            return ["some string"]
   612          else:
   613            self.return_recursive(count - 1)
   614  
   615        def process(self, element, counter=DoFn.StateParam(BYTES_STATE)):
   616          return self.return_recursive(1)
   617  
   618      with TestPipeline() as p:
   619        pcoll = (
   620            p
   621            | beam.Create([(1, 1), (2, 2), (3, 3)])
   622            | beam.GroupByKey()
   623            | beam.ParDo(StatefulDoFn()))
   624      self.assertEqual(pcoll.element_type, typehints.Any)
   625  
   626      with TestPipeline() as p:
   627        pcoll = (
   628            p
   629            | beam.Create([(1, 1), (2, 2), (3, 3)])
   630            | beam.GroupByKey()
   631            | beam.ParDo(StatefulDoFn()).with_output_types(str))
   632      self.assertEqual(pcoll.element_type, str)
   633  
   634    def test_track_pcoll_unbounded(self):
   635      pipeline = TestPipeline()
   636      pcoll1 = pipeline | 'read' >> Read(FakeUnboundedSource())
   637      pcoll2 = pcoll1 | 'do1' >> FlatMap(lambda x: [x + 1])
   638      pcoll3 = pcoll2 | 'do2' >> FlatMap(lambda x: [x + 1])
   639      self.assertIs(pcoll1.is_bounded, False)
   640      self.assertIs(pcoll2.is_bounded, False)
   641      self.assertIs(pcoll3.is_bounded, False)
   642  
   643    def test_track_pcoll_bounded(self):
   644      pipeline = TestPipeline()
   645      pcoll1 = pipeline | 'label1' >> Create([1, 2, 3])
   646      pcoll2 = pcoll1 | 'do1' >> FlatMap(lambda x: [x + 1])
   647      pcoll3 = pcoll2 | 'do2' >> FlatMap(lambda x: [x + 1])
   648      self.assertIs(pcoll1.is_bounded, True)
   649      self.assertIs(pcoll2.is_bounded, True)
   650      self.assertIs(pcoll3.is_bounded, True)
   651  
   652    def test_track_pcoll_bounded_flatten(self):
   653      pipeline = TestPipeline()
   654      pcoll1_a = pipeline | 'label_a' >> Create([1, 2, 3])
   655      pcoll2_a = pcoll1_a | 'do_a' >> FlatMap(lambda x: [x + 1])
   656  
   657      pcoll1_b = pipeline | 'label_b' >> Create([1, 2, 3])
   658      pcoll2_b = pcoll1_b | 'do_b' >> FlatMap(lambda x: [x + 1])
   659  
   660      merged = (pcoll2_a, pcoll2_b) | beam.Flatten()
   661  
   662      self.assertIs(pcoll1_a.is_bounded, True)
   663      self.assertIs(pcoll2_a.is_bounded, True)
   664      self.assertIs(pcoll1_b.is_bounded, True)
   665      self.assertIs(pcoll2_b.is_bounded, True)
   666      self.assertIs(merged.is_bounded, True)
   667  
   668    def test_track_pcoll_unbounded_flatten(self):
   669      pipeline = TestPipeline()
   670      pcoll1_bounded = pipeline | 'label1' >> Create([1, 2, 3])
   671      pcoll2_bounded = pcoll1_bounded | 'do1' >> FlatMap(lambda x: [x + 1])
   672  
   673      pcoll1_unbounded = pipeline | 'read' >> Read(FakeUnboundedSource())
   674      pcoll2_unbounded = pcoll1_unbounded | 'do2' >> FlatMap(lambda x: [x + 1])
   675  
   676      merged = (pcoll2_bounded, pcoll2_unbounded) | beam.Flatten()
   677  
   678      self.assertIs(pcoll1_bounded.is_bounded, True)
   679      self.assertIs(pcoll2_bounded.is_bounded, True)
   680      self.assertIs(pcoll1_unbounded.is_bounded, False)
   681      self.assertIs(pcoll2_unbounded.is_bounded, False)
   682      self.assertIs(merged.is_bounded, False)
   683  
   684    def test_incompatible_submission_and_runtime_envs_fail_pipeline(self):
   685      with mock.patch(
   686          'apache_beam.transforms.environments.sdk_base_version_capability'
   687      ) as base_version:
   688        base_version.side_effect = [
   689            f"beam:version:sdk_base:apache/beam_python3.5_sdk:2.{i}.0"
   690            for i in range(100)
   691        ]
   692        with self.assertRaisesRegex(
   693            RuntimeError,
   694            'Pipeline construction environment and pipeline runtime '
   695            'environment are not compatible.'):
   696          with TestPipeline() as p:
   697            _ = p | Create([None])
   698  
   699  
   700  class DoFnTest(unittest.TestCase):
   701    def test_element(self):
   702      class TestDoFn(DoFn):
   703        def process(self, element):
   704          yield element + 10
   705  
   706      with TestPipeline() as pipeline:
   707        pcoll = pipeline | 'Create' >> Create([1, 2]) | 'Do' >> ParDo(TestDoFn())
   708        assert_that(pcoll, equal_to([11, 12]))
   709  
   710    def test_side_input_no_tag(self):
   711      class TestDoFn(DoFn):
   712        def process(self, element, prefix, suffix):
   713          return ['%s-%s-%s' % (prefix, element, suffix)]
   714  
   715      with TestPipeline() as pipeline:
   716        words_list = ['aa', 'bb', 'cc']
   717        words = pipeline | 'SomeWords' >> Create(words_list)
   718        prefix = 'zyx'
   719        suffix = pipeline | 'SomeString' >> Create(['xyz'])  # side in
   720        result = words | 'DecorateWordsDoFnNoTag' >> ParDo(
   721            TestDoFn(), prefix, suffix=AsSingleton(suffix))
   722        assert_that(result, equal_to(['zyx-%s-xyz' % x for x in words_list]))
   723  
   724    def test_side_input_tagged(self):
   725      class TestDoFn(DoFn):
   726        def process(self, element, prefix, suffix=DoFn.SideInputParam):
   727          return ['%s-%s-%s' % (prefix, element, suffix)]
   728  
   729      with TestPipeline() as pipeline:
   730        words_list = ['aa', 'bb', 'cc']
   731        words = pipeline | 'SomeWords' >> Create(words_list)
   732        prefix = 'zyx'
   733        suffix = pipeline | 'SomeString' >> Create(['xyz'])  # side in
   734        result = words | 'DecorateWordsDoFnNoTag' >> ParDo(
   735            TestDoFn(), prefix, suffix=AsSingleton(suffix))
   736        assert_that(result, equal_to(['zyx-%s-xyz' % x for x in words_list]))
   737  
   738    @pytest.mark.it_validatesrunner
   739    def test_element_param(self):
   740      pipeline = TestPipeline()
   741      input = [1, 2]
   742      pcoll = (
   743          pipeline
   744          | 'Create' >> Create(input)
   745          | 'Ele param' >> Map(lambda element=DoFn.ElementParam: element))
   746      assert_that(pcoll, equal_to(input))
   747      pipeline.run()
   748  
   749    @pytest.mark.it_validatesrunner
   750    def test_key_param(self):
   751      pipeline = TestPipeline()
   752      pcoll = (
   753          pipeline
   754          | 'Create' >> Create([('a', 1), ('b', 2)])
   755          | 'Key param' >> Map(lambda _, key=DoFn.KeyParam: key))
   756      assert_that(pcoll, equal_to(['a', 'b']))
   757      pipeline.run()
   758  
   759    def test_window_param(self):
   760      class TestDoFn(DoFn):
   761        def process(self, element, window=DoFn.WindowParam):
   762          yield (element, (float(window.start), float(window.end)))
   763  
   764      with TestPipeline() as pipeline:
   765        pcoll = (
   766            pipeline
   767            | Create([1, 7])
   768            | Map(lambda x: TimestampedValue(x, x))
   769            | WindowInto(windowfn=SlidingWindows(10, 5))
   770            | ParDo(TestDoFn()))
   771        assert_that(
   772            pcoll,
   773            equal_to([(1, (-5, 5)), (1, (0, 10)), (7, (0, 10)), (7, (5, 15))]))
   774        pcoll2 = pcoll | 'Again' >> ParDo(TestDoFn())
   775        assert_that(
   776            pcoll2,
   777            equal_to([((1, (-5, 5)), (-5, 5)), ((1, (0, 10)), (0, 10)),
   778                      ((7, (0, 10)), (0, 10)), ((7, (5, 15)), (5, 15))]),
   779            label='doubled windows')
   780  
   781    def test_timestamp_param(self):
   782      class TestDoFn(DoFn):
   783        def process(self, element, timestamp=DoFn.TimestampParam):
   784          yield timestamp
   785  
   786      with TestPipeline() as pipeline:
   787        pcoll = pipeline | 'Create' >> Create([1, 2]) | 'Do' >> ParDo(TestDoFn())
   788        assert_that(pcoll, equal_to([MIN_TIMESTAMP, MIN_TIMESTAMP]))
   789  
   790    def test_timestamp_param_map(self):
   791      with TestPipeline() as p:
   792        assert_that(
   793            p | Create([1, 2]) | beam.Map(lambda _, t=DoFn.TimestampParam: t),
   794            equal_to([MIN_TIMESTAMP, MIN_TIMESTAMP]))
   795  
   796    def test_pane_info_param(self):
   797      with TestPipeline() as p:
   798        pc = p | Create([(None, None)])
   799        assert_that(
   800            pc | beam.Map(lambda _, p=DoFn.PaneInfoParam: p),
   801            equal_to([windowed_value.PANE_INFO_UNKNOWN]),
   802            label='CheckUngrouped')
   803        assert_that(
   804            pc | beam.GroupByKey() | beam.Map(lambda _, p=DoFn.PaneInfoParam: p),
   805            equal_to([
   806                windowed_value.PaneInfo(
   807                    is_first=True,
   808                    is_last=True,
   809                    timing=windowed_value.PaneInfoTiming.ON_TIME,
   810                    index=0,
   811                    nonspeculative_index=0)
   812            ]),
   813            label='CheckGrouped')
   814  
   815    def test_incomparable_default(self):
   816      class IncomparableType(object):
   817        def __eq__(self, other):
   818          raise RuntimeError()
   819  
   820        def __ne__(self, other):
   821          raise RuntimeError()
   822  
   823        def __hash__(self):
   824          raise RuntimeError()
   825  
   826      # Ensure that we don't use default values in a context where they must be
   827      # comparable (see BEAM-8301).
   828      with TestPipeline() as pipeline:
   829        pcoll = (
   830            pipeline
   831            | beam.Create([None])
   832            | Map(lambda e, x=IncomparableType(): (e, type(x).__name__)))
   833        assert_that(pcoll, equal_to([(None, 'IncomparableType')]))
   834  
   835  
   836  class Bacon(PipelineOptions):
   837    @classmethod
   838    def _add_argparse_args(cls, parser):
   839      parser.add_argument('--slices', type=int)
   840  
   841  
   842  class Eggs(PipelineOptions):
   843    @classmethod
   844    def _add_argparse_args(cls, parser):
   845      parser.add_argument('--style', default='scrambled')
   846  
   847  
   848  class Breakfast(Bacon, Eggs):
   849    pass
   850  
   851  
   852  class PipelineOptionsTest(unittest.TestCase):
   853    def test_flag_parsing(self):
   854      options = Breakfast(['--slices=3', '--style=sunny side up', '--ignored'])
   855      self.assertEqual(3, options.slices)
   856      self.assertEqual('sunny side up', options.style)
   857  
   858    def test_keyword_parsing(self):
   859      options = Breakfast(['--slices=3', '--style=sunny side up', '--ignored'],
   860                          slices=10)
   861      self.assertEqual(10, options.slices)
   862      self.assertEqual('sunny side up', options.style)
   863  
   864    def test_attribute_setting(self):
   865      options = Breakfast(slices=10)
   866      self.assertEqual(10, options.slices)
   867      options.slices = 20
   868      self.assertEqual(20, options.slices)
   869  
   870    def test_view_as(self):
   871      generic_options = PipelineOptions(['--slices=3'])
   872      self.assertEqual(3, generic_options.view_as(Bacon).slices)
   873      self.assertEqual(3, generic_options.view_as(Breakfast).slices)
   874  
   875      generic_options.view_as(Breakfast).slices = 10
   876      self.assertEqual(10, generic_options.view_as(Bacon).slices)
   877  
   878      with self.assertRaises(AttributeError):
   879        generic_options.slices  # pylint: disable=pointless-statement
   880  
   881      with self.assertRaises(AttributeError):
   882        generic_options.view_as(Eggs).slices  # pylint: disable=expression-not-assigned
   883  
   884    def test_defaults(self):
   885      options = Breakfast(['--slices=3'])
   886      self.assertEqual(3, options.slices)
   887      self.assertEqual('scrambled', options.style)
   888  
   889    def test_dir(self):
   890      options = Breakfast()
   891      self.assertEqual({
   892          'from_dictionary',
   893          'get_all_options',
   894          'slices',
   895          'style',
   896          'view_as',
   897          'display_data'
   898      },
   899                       {
   900                           attr
   901                           for attr in dir(options)
   902                           if not attr.startswith('_') and attr != 'next'
   903                       })
   904      self.assertEqual({
   905          'from_dictionary',
   906          'get_all_options',
   907          'style',
   908          'view_as',
   909          'display_data'
   910      },
   911                       {
   912                           attr
   913                           for attr in dir(options.view_as(Eggs))
   914                           if not attr.startswith('_') and attr != 'next'
   915                       })
   916  
   917  
   918  class RunnerApiTest(unittest.TestCase):
   919    def test_parent_pointer(self):
   920      class MyPTransform(beam.PTransform):
   921        def expand(self, p):
   922          self.p = p
   923          return p | beam.Create([None])
   924  
   925      p = beam.Pipeline()
   926      p | MyPTransform()  # pylint: disable=expression-not-assigned
   927      p = Pipeline.from_runner_api(
   928          Pipeline.to_runner_api(p, use_fake_coders=True), None, None)
   929      self.assertIsNotNone(p.transforms_stack[0].parts[0].parent)
   930      self.assertEqual(
   931          p.transforms_stack[0].parts[0].parent, p.transforms_stack[0])
   932  
   933    def test_requirements(self):
   934      p = beam.Pipeline()
   935      _ = (
   936          p | beam.Create([])
   937          | beam.ParDo(lambda x, finalize=beam.DoFn.BundleFinalizerParam: None))
   938      proto = p.to_runner_api()
   939      self.assertTrue(
   940          common_urns.requirements.REQUIRES_BUNDLE_FINALIZATION.urn,
   941          proto.requirements)
   942  
   943    def test_annotations(self):
   944      some_proto = BytesCoder().to_runner_api(None)
   945  
   946      class EmptyTransform(beam.PTransform):
   947        def expand(self, pcoll):
   948          return pcoll
   949  
   950        def annotations(self):
   951          return {'foo': 'some_string'}
   952  
   953      class NonEmptyTransform(beam.PTransform):
   954        def expand(self, pcoll):
   955          return pcoll | beam.Map(lambda x: x)
   956  
   957        def annotations(self):
   958          return {
   959              'foo': b'some_bytes',
   960              'proto': some_proto,
   961          }
   962  
   963      p = beam.Pipeline()
   964      _ = p | beam.Create([]) | EmptyTransform() | NonEmptyTransform()
   965      proto = p.to_runner_api()
   966  
   967      seen = 0
   968      for transform in proto.components.transforms.values():
   969        if transform.unique_name == 'EmptyTransform':
   970          seen += 1
   971          self.assertEqual(transform.annotations['foo'], b'some_string')
   972        elif transform.unique_name == 'NonEmptyTransform':
   973          seen += 1
   974          self.assertEqual(transform.annotations['foo'], b'some_bytes')
   975          self.assertEqual(
   976              transform.annotations['proto'], some_proto.SerializeToString())
   977      self.assertEqual(seen, 2)
   978  
   979    def test_transform_ids(self):
   980      class MyPTransform(beam.PTransform):
   981        def expand(self, p):
   982          self.p = p
   983          return p | beam.Create([None])
   984  
   985      p = beam.Pipeline()
   986      p | MyPTransform()  # pylint: disable=expression-not-assigned
   987      runner_api_proto = Pipeline.to_runner_api(p)
   988  
   989      for transform_id in runner_api_proto.components.transforms:
   990        self.assertRegex(transform_id, r'[a-zA-Z0-9-_]+')
   991  
   992    def test_input_names(self):
   993      class MyPTransform(beam.PTransform):
   994        def expand(self, pcolls):
   995          return pcolls.values() | beam.Flatten()
   996  
   997      p = beam.Pipeline()
   998      input_names = set('ABC')
   999      inputs = {x: p | x >> beam.Create([x]) for x in input_names}
  1000      inputs | MyPTransform()  # pylint: disable=expression-not-assigned
  1001      runner_api_proto = Pipeline.to_runner_api(p)
  1002  
  1003      for transform_proto in runner_api_proto.components.transforms.values():
  1004        if transform_proto.unique_name == 'MyPTransform':
  1005          self.assertEqual(set(transform_proto.inputs.keys()), input_names)
  1006          break
  1007      else:
  1008        self.fail('Unable to find transform.')
  1009  
  1010    def test_display_data(self):
  1011      class MyParentTransform(beam.PTransform):
  1012        def expand(self, p):
  1013          self.p = p
  1014          return p | beam.Create([None])
  1015  
  1016        def display_data(self):  # type: () -> dict
  1017          parent_dd = super().display_data()
  1018          parent_dd['p_dd_string'] = DisplayDataItem(
  1019              'p_dd_string_value', label='p_dd_string_label')
  1020          parent_dd['p_dd_string_2'] = DisplayDataItem('p_dd_string_value_2')
  1021          parent_dd['p_dd_bool'] = DisplayDataItem(True, label='p_dd_bool_label')
  1022          parent_dd['p_dd_int'] = DisplayDataItem(1, label='p_dd_int_label')
  1023          return parent_dd
  1024  
  1025      class MyPTransform(MyParentTransform):
  1026        def expand(self, p):
  1027          self.p = p
  1028          return p | beam.Create([None])
  1029  
  1030        def display_data(self):  # type: () -> dict
  1031          parent_dd = super().display_data()
  1032          parent_dd['dd_string'] = DisplayDataItem(
  1033              'dd_string_value', label='dd_string_label')
  1034          parent_dd['dd_string_2'] = DisplayDataItem('dd_string_value_2')
  1035          parent_dd['dd_bool'] = DisplayDataItem(False, label='dd_bool_label')
  1036          parent_dd['dd_double'] = DisplayDataItem(1.1, label='dd_double_label')
  1037          return parent_dd
  1038  
  1039      p = beam.Pipeline()
  1040      p | MyPTransform()  # pylint: disable=expression-not-assigned
  1041  
  1042      proto_pipeline = Pipeline.to_runner_api(p, use_fake_coders=True)
  1043      my_transform, = [
  1044          transform
  1045          for transform in proto_pipeline.components.transforms.values()
  1046          if transform.unique_name == 'MyPTransform'
  1047      ]
  1048      self.assertIsNotNone(my_transform)
  1049      self.assertListEqual(
  1050          list(my_transform.display_data),
  1051          [
  1052              beam_runner_api_pb2.DisplayData(
  1053                  urn=common_urns.StandardDisplayData.DisplayData.LABELLED.urn,
  1054                  payload=beam_runner_api_pb2.LabelledPayload(
  1055                      label='p_dd_string_label',
  1056                      key='p_dd_string',
  1057                      namespace='apache_beam.pipeline_test.MyPTransform',
  1058                      string_value='p_dd_string_value').SerializeToString()),
  1059              beam_runner_api_pb2.DisplayData(
  1060                  urn=common_urns.StandardDisplayData.DisplayData.LABELLED.urn,
  1061                  payload=beam_runner_api_pb2.LabelledPayload(
  1062                      label='p_dd_string_2',
  1063                      key='p_dd_string_2',
  1064                      namespace='apache_beam.pipeline_test.MyPTransform',
  1065                      string_value='p_dd_string_value_2').SerializeToString()),
  1066              beam_runner_api_pb2.DisplayData(
  1067                  urn=common_urns.StandardDisplayData.DisplayData.LABELLED.urn,
  1068                  payload=beam_runner_api_pb2.LabelledPayload(
  1069                      label='p_dd_bool_label',
  1070                      key='p_dd_bool',
  1071                      namespace='apache_beam.pipeline_test.MyPTransform',
  1072                      bool_value=True).SerializeToString()),
  1073              beam_runner_api_pb2.DisplayData(
  1074                  urn=common_urns.StandardDisplayData.DisplayData.LABELLED.urn,
  1075                  payload=beam_runner_api_pb2.LabelledPayload(
  1076                      label='p_dd_int_label',
  1077                      key='p_dd_int',
  1078                      namespace='apache_beam.pipeline_test.MyPTransform',
  1079                      int_value=1).SerializeToString()),
  1080              beam_runner_api_pb2.DisplayData(
  1081                  urn=common_urns.StandardDisplayData.DisplayData.LABELLED.urn,
  1082                  payload=beam_runner_api_pb2.LabelledPayload(
  1083                      label='dd_string_label',
  1084                      key='dd_string',
  1085                      namespace='apache_beam.pipeline_test.MyPTransform',
  1086                      string_value='dd_string_value').SerializeToString()),
  1087              beam_runner_api_pb2.DisplayData(
  1088                  urn=common_urns.StandardDisplayData.DisplayData.LABELLED.urn,
  1089                  payload=beam_runner_api_pb2.LabelledPayload(
  1090                      label='dd_string_2',
  1091                      key='dd_string_2',
  1092                      namespace='apache_beam.pipeline_test.MyPTransform',
  1093                      string_value='dd_string_value_2').SerializeToString()),
  1094              beam_runner_api_pb2.DisplayData(
  1095                  urn=common_urns.StandardDisplayData.DisplayData.LABELLED.urn,
  1096                  payload=beam_runner_api_pb2.LabelledPayload(
  1097                      label='dd_bool_label',
  1098                      key='dd_bool',
  1099                      namespace='apache_beam.pipeline_test.MyPTransform',
  1100                      bool_value=False).SerializeToString()),
  1101              beam_runner_api_pb2.DisplayData(
  1102                  urn=common_urns.StandardDisplayData.DisplayData.LABELLED.urn,
  1103                  payload=beam_runner_api_pb2.LabelledPayload(
  1104                      label='dd_double_label',
  1105                      key='dd_double',
  1106                      namespace='apache_beam.pipeline_test.MyPTransform',
  1107                      double_value=1.1).SerializeToString()),
  1108          ])
  1109  
  1110    def test_runner_api_roundtrip_preserves_resource_hints(self):
  1111      p = beam.Pipeline()
  1112      _ = (
  1113          p | beam.Create([1, 2])
  1114          | beam.Map(lambda x: x + 1).with_resource_hints(accelerator='gpu'))
  1115  
  1116      self.assertEqual(
  1117          p.transforms_stack[0].parts[1].transform.get_resource_hints(),
  1118          {common_urns.resource_hints.ACCELERATOR.urn: b'gpu'})
  1119  
  1120      for _ in range(3):
  1121        # Verify that DEFAULT environments are recreated during multiple RunnerAPI
  1122        # translation and hints don't get lost.
  1123        p = Pipeline.from_runner_api(Pipeline.to_runner_api(p), None, None)
  1124        self.assertEqual(
  1125            p.transforms_stack[0].parts[1].transform.get_resource_hints(),
  1126            {common_urns.resource_hints.ACCELERATOR.urn: b'gpu'})
  1127  
  1128    def test_hints_on_composite_transforms_are_propagated_to_subtransforms(self):
  1129      class FooHint(ResourceHint):
  1130        urn = 'foo_urn'
  1131  
  1132      class BarHint(ResourceHint):
  1133        urn = 'bar_urn'
  1134  
  1135      class BazHint(ResourceHint):
  1136        urn = 'baz_urn'
  1137  
  1138      class QuxHint(ResourceHint):
  1139        urn = 'qux_urn'
  1140  
  1141      class UseMaxValueHint(ResourceHint):
  1142        urn = 'use_max_value_urn'
  1143  
  1144        @classmethod
  1145        def get_merged_value(
  1146            cls, outer_value, inner_value):  # type: (bytes, bytes) -> bytes
  1147          return ResourceHint._use_max(outer_value, inner_value)
  1148  
  1149      ResourceHint.register_resource_hint('foo_hint', FooHint)
  1150      ResourceHint.register_resource_hint('bar_hint', BarHint)
  1151      ResourceHint.register_resource_hint('baz_hint', BazHint)
  1152      ResourceHint.register_resource_hint('qux_hint', QuxHint)
  1153      ResourceHint.register_resource_hint('use_max_value_hint', UseMaxValueHint)
  1154  
  1155      @beam.ptransform_fn
  1156      def SubTransform(pcoll):
  1157        return pcoll | beam.Map(lambda x: x + 1).with_resource_hints(
  1158            foo_hint='set_on_subtransform', use_max_value_hint='10')
  1159  
  1160      @beam.ptransform_fn
  1161      def CompositeTransform(pcoll):
  1162        return pcoll | beam.Map(lambda x: x * 2) | SubTransform()
  1163  
  1164      p = beam.Pipeline()
  1165      _ = (
  1166          p | beam.Create([1, 2])
  1167          | CompositeTransform().with_resource_hints(
  1168              foo_hint='should_be_overriden_by_subtransform',
  1169              bar_hint='set_on_composite',
  1170              baz_hint='set_on_composite',
  1171              use_max_value_hint='100'))
  1172      options = PortableOptions([
  1173          '--resource_hint=baz_hint=should_be_overriden_by_composite',
  1174          '--resource_hint=qux_hint=set_via_options',
  1175          '--environment_type=PROCESS',
  1176          '--environment_option=process_command=foo',
  1177          '--sdk_location=container',
  1178      ])
  1179      environment = ProcessEnvironment.from_options(options)
  1180      proto = Pipeline.to_runner_api(p, default_environment=environment)
  1181  
  1182      for t in proto.components.transforms.values():
  1183        if "CompositeTransform/SubTransform/Map" in t.unique_name:
  1184          environment = proto.components.environments.get(t.environment_id)
  1185          self.assertEqual(
  1186              environment.resource_hints.get('foo_urn'), b'set_on_subtransform')
  1187          self.assertEqual(
  1188              environment.resource_hints.get('bar_urn'), b'set_on_composite')
  1189          self.assertEqual(
  1190              environment.resource_hints.get('baz_urn'), b'set_on_composite')
  1191          self.assertEqual(
  1192              environment.resource_hints.get('qux_urn'), b'set_via_options')
  1193          self.assertEqual(
  1194              environment.resource_hints.get('use_max_value_urn'), b'100')
  1195          found = True
  1196      assert found
  1197  
  1198    def test_environments_with_same_resource_hints_are_reused(self):
  1199      class HintX(ResourceHint):
  1200        urn = 'X_urn'
  1201  
  1202      class HintY(ResourceHint):
  1203        urn = 'Y_urn'
  1204  
  1205      class HintIsOdd(ResourceHint):
  1206        urn = 'IsOdd_urn'
  1207  
  1208      ResourceHint.register_resource_hint('X', HintX)
  1209      ResourceHint.register_resource_hint('Y', HintY)
  1210      ResourceHint.register_resource_hint('IsOdd', HintIsOdd)
  1211  
  1212      p = beam.Pipeline()
  1213      num_iter = 4
  1214      for i in range(num_iter):
  1215        _ = (
  1216            p
  1217            | f'NoHintCreate_{i}' >> beam.Create([1, 2])
  1218            | f'NoHint_{i}' >> beam.Map(lambda x: x + 1))
  1219        _ = (
  1220            p
  1221            | f'XCreate_{i}' >> beam.Create([1, 2])
  1222            |
  1223            f'HintX_{i}' >> beam.Map(lambda x: x + 1).with_resource_hints(X='X'))
  1224        _ = (
  1225            p
  1226            | f'XYCreate_{i}' >> beam.Create([1, 2])
  1227            | f'HintXY_{i}' >> beam.Map(lambda x: x + 1).with_resource_hints(
  1228                X='X', Y='Y'))
  1229        _ = (
  1230            p
  1231            | f'IsOddCreate_{i}' >> beam.Create([1, 2])
  1232            | f'IsOdd_{i}' >>
  1233            beam.Map(lambda x: x + 1).with_resource_hints(IsOdd=str(i % 2 != 0)))
  1234  
  1235      proto = Pipeline.to_runner_api(p)
  1236      count_x = count_xy = count_is_odd = count_no_hints = 0
  1237      env_ids = set()
  1238      for _, t in proto.components.transforms.items():
  1239        env = proto.components.environments[t.environment_id]
  1240        if t.unique_name.startswith('HintX_'):
  1241          count_x += 1
  1242          env_ids.add(t.environment_id)
  1243          self.assertEqual(env.resource_hints, {'X_urn': b'X'})
  1244  
  1245        if t.unique_name.startswith('HintXY_'):
  1246          count_xy += 1
  1247          env_ids.add(t.environment_id)
  1248          self.assertEqual(env.resource_hints, {'X_urn': b'X', 'Y_urn': b'Y'})
  1249  
  1250        if t.unique_name.startswith('NoHint_'):
  1251          count_no_hints += 1
  1252          env_ids.add(t.environment_id)
  1253          self.assertEqual(env.resource_hints, {})
  1254  
  1255        if t.unique_name.startswith('IsOdd_'):
  1256          count_is_odd += 1
  1257          env_ids.add(t.environment_id)
  1258          self.assertTrue(
  1259              env.resource_hints == {'IsOdd_urn': b'True'} or
  1260              env.resource_hints == {'IsOdd_urn': b'False'})
  1261      assert count_x == count_is_odd == count_xy == count_no_hints == num_iter
  1262      assert num_iter > 1
  1263  
  1264      self.assertEqual(len(env_ids), 5)
  1265  
  1266    def test_multiple_application_of_the_same_transform_set_different_hints(self):
  1267      class FooHint(ResourceHint):
  1268        urn = 'foo_urn'
  1269  
  1270      class UseMaxValueHint(ResourceHint):
  1271        urn = 'use_max_value_urn'
  1272  
  1273        @classmethod
  1274        def get_merged_value(
  1275            cls, outer_value, inner_value):  # type: (bytes, bytes) -> bytes
  1276          return ResourceHint._use_max(outer_value, inner_value)
  1277  
  1278      ResourceHint.register_resource_hint('foo_hint', FooHint)
  1279      ResourceHint.register_resource_hint('use_max_value_hint', UseMaxValueHint)
  1280  
  1281      @beam.ptransform_fn
  1282      def SubTransform(pcoll):
  1283        return pcoll | beam.Map(lambda x: x + 1)
  1284  
  1285      @beam.ptransform_fn
  1286      def CompositeTransform(pcoll):
  1287        sub = SubTransform()
  1288        return (
  1289            pcoll
  1290            | 'first' >> sub.with_resource_hints(foo_hint='first_application')
  1291            | 'second' >> sub.with_resource_hints(foo_hint='second_application'))
  1292  
  1293      p = beam.Pipeline()
  1294      _ = (p | beam.Create([1, 2]) | CompositeTransform())
  1295      proto = Pipeline.to_runner_api(p)
  1296      count = 0
  1297      for t in proto.components.transforms.values():
  1298        if "CompositeTransform/first/Map" in t.unique_name:
  1299          environment = proto.components.environments.get(t.environment_id)
  1300          self.assertEqual(
  1301              b'first_application', environment.resource_hints.get('foo_urn'))
  1302          count += 1
  1303        if "CompositeTransform/second/Map" in t.unique_name:
  1304          environment = proto.components.environments.get(t.environment_id)
  1305          self.assertEqual(
  1306              b'second_application', environment.resource_hints.get('foo_urn'))
  1307          count += 1
  1308      assert count == 2
  1309  
  1310    def test_environments_are_deduplicated(self):
  1311      def file_artifact(path, hash, staged_name):
  1312        return beam_runner_api_pb2.ArtifactInformation(
  1313            type_urn=common_urns.artifact_types.FILE.urn,
  1314            type_payload=beam_runner_api_pb2.ArtifactFilePayload(
  1315                path=path, sha256=hash).SerializeToString(),
  1316            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
  1317            role_payload=beam_runner_api_pb2.ArtifactStagingToRolePayload(
  1318                staged_name=staged_name).SerializeToString(),
  1319        )
  1320  
  1321      proto = beam_runner_api_pb2.Pipeline(
  1322          components=beam_runner_api_pb2.Components(
  1323              transforms={
  1324                  f'transform{ix}': beam_runner_api_pb2.PTransform(
  1325                      environment_id=f'e{ix}')
  1326                  for ix in range(8)
  1327              },
  1328              environments={
  1329                  # Same hash and destination.
  1330                  'e1': beam_runner_api_pb2.Environment(
  1331                      dependencies=[file_artifact('a1', 'x', 'dest')]),
  1332                  'e2': beam_runner_api_pb2.Environment(
  1333                      dependencies=[file_artifact('a2', 'x', 'dest')]),
  1334                  # Different hash.
  1335                  'e3': beam_runner_api_pb2.Environment(
  1336                      dependencies=[file_artifact('a3', 'y', 'dest')]),
  1337                  # Different destination.
  1338                  'e4': beam_runner_api_pb2.Environment(
  1339                      dependencies=[file_artifact('a4', 'y', 'dest2')]),
  1340                  # Multiple files with same hash and destinations.
  1341                  'e5': beam_runner_api_pb2.Environment(
  1342                      dependencies=[
  1343                          file_artifact('a1', 'x', 'dest'),
  1344                          file_artifact('b1', 'xb', 'destB')
  1345                      ]),
  1346                  'e6': beam_runner_api_pb2.Environment(
  1347                      dependencies=[
  1348                          file_artifact('a2', 'x', 'dest'),
  1349                          file_artifact('b2', 'xb', 'destB')
  1350                      ]),
  1351                  # Overlapping, but not identical, files.
  1352                  'e7': beam_runner_api_pb2.Environment(
  1353                      dependencies=[
  1354                          file_artifact('a1', 'x', 'dest'),
  1355                          file_artifact('b2', 'y', 'destB')
  1356                      ]),
  1357                  # Same files as first, but differing other properties.
  1358                  'e0': beam_runner_api_pb2.Environment(
  1359                      resource_hints={'hint': b'value'},
  1360                      dependencies=[file_artifact('a1', 'x', 'dest')]),
  1361              }))
  1362      Pipeline.merge_compatible_environments(proto)
  1363  
  1364      # These environments are equivalent.
  1365      self.assertEqual(
  1366          proto.components.transforms['transform1'].environment_id,
  1367          proto.components.transforms['transform2'].environment_id)
  1368  
  1369      self.assertEqual(
  1370          proto.components.transforms['transform5'].environment_id,
  1371          proto.components.transforms['transform6'].environment_id)
  1372  
  1373      # These are not.
  1374      self.assertNotEqual(
  1375          proto.components.transforms['transform1'].environment_id,
  1376          proto.components.transforms['transform3'].environment_id)
  1377      self.assertNotEqual(
  1378          proto.components.transforms['transform4'].environment_id,
  1379          proto.components.transforms['transform3'].environment_id)
  1380      self.assertNotEqual(
  1381          proto.components.transforms['transform6'].environment_id,
  1382          proto.components.transforms['transform7'].environment_id)
  1383      self.assertNotEqual(
  1384          proto.components.transforms['transform1'].environment_id,
  1385          proto.components.transforms['transform0'].environment_id)
  1386  
  1387      self.assertEqual(len(proto.components.environments), 6)
  1388  
  1389  
  1390  if __name__ == '__main__':
  1391    unittest.main()