github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/interactive/testing/pipeline_assertion.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Module to verify implicit cache transforms applied by Interactive Beam.
    19  
    20  For internal use only; no backwards-compatibility guarantees.
    21  This utility should only be used by Interactive Beam tests. For example, it can
    22  be used to verify if the implicit cache transforms are applied as expected when
    23  running a pipeline with the InteractiveRunner. It can also be used to verify if
    24  a pipeline fragment has pruned unnecessary transforms. It shouldn't be used to
    25  verify equivalence between pipelines if the code to be tested depends on or
    26  mutates user code within transforms in pipelines.
    27  """
    28  
    29  
    30  def assert_pipeline_equal(test_case, expected_pipeline, actual_pipeline):
    31    """Asserts the equivalence between two given apache_beam.Pipeline instances.
    32  
    33      Args:
    34        test_case: (unittest.TestCase) the unittest testcase where it asserts.
    35        expected_pipeline: (Pipeline) the pipeline instance expected.
    36        actual_pipeline: (Pipeline) the actual pipeline instance to be asserted.
    37      """
    38    expected_pipeline_proto = expected_pipeline.to_runner_api(
    39        use_fake_coders=True)
    40    actual_pipeline_proto = actual_pipeline.to_runner_api(use_fake_coders=True)
    41    assert_pipeline_proto_equal(
    42        test_case, expected_pipeline_proto, actual_pipeline_proto)
    43  
    44  
    45  def assert_pipeline_proto_equal(
    46      test_case, expected_pipeline_proto, actual_pipeline_proto):
    47    """Asserts the equivalence between two pipeline proto representations."""
    48    components1 = expected_pipeline_proto.components
    49    components2 = actual_pipeline_proto.components
    50    test_case.assertEqual(
    51        len(components1.transforms), len(components2.transforms))
    52    test_case.assertEqual(
    53        len(components1.pcollections), len(components2.pcollections))
    54  
    55    # TODO(BEAM-7926): Update tests and make below 2 assertions assertEqual.
    56    test_case.assertLessEqual(
    57        len(components1.windowing_strategies),
    58        len(components2.windowing_strategies))
    59    test_case.assertLessEqual(len(components1.coders), len(components2.coders))
    60  
    61    _assert_transform_equal(
    62        test_case,
    63        actual_pipeline_proto,
    64        actual_pipeline_proto.root_transform_ids[0],
    65        expected_pipeline_proto,
    66        expected_pipeline_proto.root_transform_ids[0])
    67  
    68  
    69  def assert_pipeline_proto_contain_top_level_transform(
    70      test_case, pipeline_proto, transform_label):
    71    """Asserts the top level transforms contain a transform with the given
    72     transform label."""
    73    _assert_pipeline_proto_contains_top_level_transform(
    74        test_case, pipeline_proto, transform_label, True)
    75  
    76  
    77  def assert_pipeline_proto_not_contain_top_level_transform(
    78      test_case, pipeline_proto, transform_label):
    79    """Asserts the top level transforms do not contain a transform with the given
    80     transform label."""
    81    _assert_pipeline_proto_contains_top_level_transform(
    82        test_case, pipeline_proto, transform_label, False)
    83  
    84  
    85  def _assert_pipeline_proto_contains_top_level_transform(
    86      test_case, pipeline_proto, transform_label, contain):
    87    top_level_transform_labels = pipeline_proto.components.transforms[
    88        pipeline_proto.root_transform_ids[0]].subtransforms
    89    test_case.assertEqual(
    90        contain,
    91        any(
    92            transform_label in top_level_transform_label
    93            for top_level_transform_label in top_level_transform_labels))
    94  
    95  
    96  def _assert_transform_equal(
    97      test_case,
    98      expected_pipeline_proto,
    99      expected_transform_id,
   100      actual_pipeline_proto,
   101      actual_transform_id):
   102    """Asserts the equivalence between transforms from two given pipelines. """
   103    transform_proto1 = expected_pipeline_proto.components.transforms[
   104        expected_transform_id]
   105    transform_proto2 = actual_pipeline_proto.components.transforms[
   106        actual_transform_id]
   107    test_case.assertEqual(transform_proto1.spec.urn, transform_proto2.spec.urn)
   108    # Skipping payload checking because PTransforms of the same functionality
   109    # could generate different payloads.
   110    test_case.assertEqual(
   111        len(transform_proto1.subtransforms), len(transform_proto2.subtransforms))
   112    test_case.assertSetEqual(
   113        set(transform_proto1.inputs), set(transform_proto2.inputs))
   114    test_case.assertSetEqual(
   115        set(transform_proto1.outputs), set(transform_proto2.outputs))