# Source: github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/interactive/testing/pipeline_assertion.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Module to verify implicit cache transforms applied by Interactive Beam.

For internal use only; no backwards-compatibility guarantees.
This utility should only be used by Interactive Beam tests. For example, it can
be used to verify if the implicit cache transforms are applied as expected when
running a pipeline with the InteractiveRunner. It can also be used to verify if
a pipeline fragment has pruned unnecessary transforms. It shouldn't be used to
verify equivalence between pipelines if the code to be tested depends on or
mutates user code within transforms in pipelines.
"""


def assert_pipeline_equal(test_case, expected_pipeline, actual_pipeline):
  """Asserts the equivalence between two given apache_beam.Pipeline instances.

  Both pipelines are converted with ``to_runner_api(use_fake_coders=True)`` and
  the resulting protos are compared with ``assert_pipeline_proto_equal``.

  Args:
    test_case: (unittest.TestCase) the unittest testcase where it asserts.
    expected_pipeline: (Pipeline) the pipeline instance expected.
    actual_pipeline: (Pipeline) the actual pipeline instance to be asserted.
  """
  expected_pipeline_proto = expected_pipeline.to_runner_api(
      use_fake_coders=True)
  actual_pipeline_proto = actual_pipeline.to_runner_api(use_fake_coders=True)
  assert_pipeline_proto_equal(
      test_case, expected_pipeline_proto, actual_pipeline_proto)


def assert_pipeline_proto_equal(
    test_case, expected_pipeline_proto, actual_pipeline_proto):
  """Asserts the equivalence between two pipeline proto representations.

  The comparison is structural and shallow: it checks the number of transforms
  and pcollections, that the actual pipeline has at least as many windowing
  strategies and coders as the expected one, and then compares the first root
  transform of each pipeline with ``_assert_transform_equal``.

  Args:
    test_case: (unittest.TestCase) the unittest testcase where it asserts.
    expected_pipeline_proto: the pipeline proto expected.
    actual_pipeline_proto: the actual pipeline proto to be asserted.
  """
  expected_components = expected_pipeline_proto.components
  actual_components = actual_pipeline_proto.components
  test_case.assertEqual(
      len(expected_components.transforms), len(actual_components.transforms))
  test_case.assertEqual(
      len(expected_components.pcollections),
      len(actual_components.pcollections))

  # TODO(BEAM-7926): Update tests and make below 2 assertions assertEqual.
  test_case.assertLessEqual(
      len(expected_components.windowing_strategies),
      len(actual_components.windowing_strategies))
  test_case.assertLessEqual(
      len(expected_components.coders), len(actual_components.coders))

  # Pass the expected pipeline first so that it matches the parameter order of
  # _assert_transform_equal; otherwise failure messages report the compared
  # values in reverse (actual != expected).
  _assert_transform_equal(
      test_case,
      expected_pipeline_proto,
      expected_pipeline_proto.root_transform_ids[0],
      actual_pipeline_proto,
      actual_pipeline_proto.root_transform_ids[0])


def assert_pipeline_proto_contain_top_level_transform(
    test_case, pipeline_proto, transform_label):
  """Asserts the top level transforms contain a transform with the given
  transform label.

  Args:
    test_case: (unittest.TestCase) the unittest testcase where it asserts.
    pipeline_proto: the pipeline proto whose first root transform is inspected.
    transform_label: the label to look for among the top level transforms.
  """
  _assert_pipeline_proto_contains_top_level_transform(
      test_case, pipeline_proto, transform_label, True)


def assert_pipeline_proto_not_contain_top_level_transform(
    test_case, pipeline_proto, transform_label):
  """Asserts the top level transforms do not contain a transform with the given
  transform label.

  Args:
    test_case: (unittest.TestCase) the unittest testcase where it asserts.
    pipeline_proto: the pipeline proto whose first root transform is inspected.
    transform_label: the label that must be absent from the top level
      transforms.
  """
  _assert_pipeline_proto_contains_top_level_transform(
      test_case, pipeline_proto, transform_label, False)


def _assert_pipeline_proto_contains_top_level_transform(
    test_case, pipeline_proto, transform_label, contain):
  """Asserts whether any top level transform matches the given label.

  The top level transforms are the ``subtransforms`` entries of the first root
  transform. An entry matches when ``transform_label in entry`` is true, so for
  string entries this is a substring match rather than an exact comparison.

  Args:
    test_case: (unittest.TestCase) the unittest testcase where it asserts.
    pipeline_proto: the pipeline proto whose first root transform is inspected.
    transform_label: the label to look for.
    contain: (bool) True to assert that a match exists, False to assert that no
      match exists.
  """
  root_transform = pipeline_proto.components.transforms[
      pipeline_proto.root_transform_ids[0]]
  # NOTE(review): the entries are presumably transform ids that embed the
  # label, which is why containment is used here -- confirm against the proto.
  top_level_transforms = root_transform.subtransforms
  test_case.assertEqual(
      contain,
      any(
          transform_label in top_level_transform
          for top_level_transform in top_level_transforms))


def _assert_transform_equal(
    test_case,
    expected_pipeline_proto,
    expected_transform_id,
    actual_pipeline_proto,
    actual_transform_id):
  """Asserts the equivalence between transforms from two given pipelines.

  Only the spec URN, the number of subtransforms, and the sets built from the
  inputs and outputs are compared. Subtransforms are not compared recursively.

  Args:
    test_case: (unittest.TestCase) the unittest testcase where it asserts.
    expected_pipeline_proto: the pipeline proto holding the expected transform.
    expected_transform_id: id of the expected transform in its pipeline proto.
    actual_pipeline_proto: the pipeline proto holding the actual transform.
    actual_transform_id: id of the actual transform in its pipeline proto.
  """
  expected_transform = expected_pipeline_proto.components.transforms[
      expected_transform_id]
  actual_transform = actual_pipeline_proto.components.transforms[
      actual_transform_id]
  test_case.assertEqual(expected_transform.spec.urn, actual_transform.spec.urn)
  # Skipping payload checking because PTransforms of the same functionality
  # could generate different payloads.
  test_case.assertEqual(
      len(expected_transform.subtransforms),
      len(actual_transform.subtransforms))
  # set() over a mapping keeps only its keys, so for mapping-like inputs and
  # outputs the values they map to are not compared.
  test_case.assertSetEqual(
      set(expected_transform.inputs), set(actual_transform.inputs))
  test_case.assertSetEqual(
      set(expected_transform.outputs), set(actual_transform.outputs))