#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Pipelines that exercise the setup/teardown call order of a CombineFn.

``run_combine`` drives the CombineFn through ``beam.CombineGlobally`` and
``run_pardo`` drives it through a ``CombiningValueStateSpec`` inside a DoFn.
"""

# pytype: skip-file

from typing import Set
from typing import Tuple

import apache_beam as beam
from apache_beam.options.pipeline_options import TypeOptions
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms import combiners
from apache_beam.transforms import trigger
from apache_beam.transforms import userstate
from apache_beam.transforms import window
from apache_beam.typehints import with_input_types
from apache_beam.typehints import with_output_types


@with_input_types(int)
@with_output_types(int)
class CallSequenceEnforcingCombineFn(beam.CombineFn):
  """Integer-summing CombineFn that asserts its lifecycle call order.

  Every method asserts that ``setup`` has already run and that ``teardown``
  has not. ``setup`` and ``teardown`` additionally assert that they are each
  called at most once per instance.
  """

  # Every instance on which setup() has run. Shared at class level so that
  # callers can inspect the instances after pipeline execution.
  instances = set()  # type: Set[CallSequenceEnforcingCombineFn]

  def __init__(self):
    super().__init__()
    self._setup_called = False
    self._teardown_called = False

  def _assert_active(self):
    """Asserts that setup() has been called and teardown() has not."""
    assert self._setup_called, 'setup should have been called'
    assert not self._teardown_called, 'teardown should not have been called'

  def setup(self, *args, **kwargs):
    """Marks the instance as set up and records it in ``instances``."""
    assert not self._setup_called, 'setup should not be called twice'
    assert not self._teardown_called, 'setup should be called before teardown'
    # Keep track of instances so that we can check if teardown is called
    # properly after pipeline execution.
    self.instances.add(self)
    self._setup_called = True

  def create_accumulator(self, *args, **kwargs):
    """Returns the initial accumulator, ``0``."""
    self._assert_active()
    return 0

  def add_input(self, mutable_accumulator, element, *args, **kwargs):
    """Returns the accumulator increased by ``element``."""
    self._assert_active()
    mutable_accumulator += element
    return mutable_accumulator

  def add_inputs(self, mutable_accumulator, elements, *args, **kwargs):
    """Returns the accumulator increased by the sum of ``elements``."""
    # Delegating to add_input means the lifecycle assertions run here too.
    return self.add_input(mutable_accumulator, sum(elements))

  def merge_accumulators(self, accumulators, *args, **kwargs):
    """Returns the sum of all ``accumulators``."""
    self._assert_active()
    return sum(accumulators)

  def extract_output(self, accumulator, *args, **kwargs):
    """Returns the accumulator unchanged as the final output."""
    self._assert_active()
    return accumulator

  def teardown(self, *args, **kwargs):
    """Marks the instance as torn down; asserts it happens only once."""
    assert self._setup_called, 'setup should have been called'
    assert not self._teardown_called, 'teardown should not be called twice'
    self._teardown_called = True


@with_input_types(Tuple[None, str])
@with_output_types(Tuple[int, str])
class IndexAssigningDoFn(beam.DoFn):
  """Pairs each value with the current reading of a combining state cell.

  The state cell named ``'index'`` combines its inputs with
  ``CallSequenceEnforcingCombineFn``, so reading from and adding to the
  state runs that CombineFn's lifecycle assertions.
  """

  state_param = beam.DoFn.StateParam(
      userstate.CombiningValueStateSpec(
          'index', beam.coders.VarIntCoder(),
          CallSequenceEnforcingCombineFn()))

  def process(self, element, state=state_param):
    """Yields ``(current_index, value)`` and then adds 1 to the state.

    Args:
      element: a ``(key, value)`` pair; the key is ignored.
      state: the combining state cell declared by ``state_param``.
    """
    _, value = element
    current_index = state.read()
    yield current_index, value
    state.add(1)


def run_combine(pipeline, input_elements=5, lift_combiners=True):
  """Runs a global combine over ``range(input_elements)`` and checks the sum.

  Two ``CallSequenceEnforcingCombineFn`` instances are wrapped in a
  ``SingleInputTupleCombineFn``, so the expected output is a single pair
  holding the same sum twice.

  Args:
    pipeline: the pipeline to build on and run; it is used as a context
      manager, so it executes when this function leaves the ``with`` block.
    input_elements: number of consecutive integers, starting at 0, to sum.
    lift_combiners: when False, an ``AfterCount`` trigger is applied to the
      input to prevent combiners from being lifted.
  """
  # Calculate the expected result, which is the sum of an arithmetic sequence.
  # By default, this is equal to: 0 + 1 + 2 + 3 + 4 = 10
  # Floor division keeps the value an int, matching the CombineFn's declared
  # int output; n * (n - 1) is always even, so no precision is lost.
  expected_result = input_elements * (input_elements - 1) // 2

  # Enable runtime type checking in order to cover TypeCheckCombineFn by
  # the test.
  pipeline.get_pipeline_options().view_as(TypeOptions).runtime_type_check = True
  pipeline.get_pipeline_options().view_as(
      TypeOptions).allow_unsafe_triggers = True

  with pipeline as p:
    pcoll = p | 'Start' >> beam.Create(range(input_elements))

    # Certain triggers, such as AfterCount, are incompatible with combiner
    # lifting. We can use that fact to prevent combiners from being lifted.
    if not lift_combiners:
      pcoll |= beam.WindowInto(
          window.GlobalWindows(),
          trigger=trigger.AfterCount(input_elements),
          accumulation_mode=trigger.AccumulationMode.DISCARDING)

    # Pass an additional 'None' in order to cover _CurriedFn by the test.
    pcoll |= 'Do' >> beam.CombineGlobally(
        combiners.SingleInputTupleCombineFn(
            CallSequenceEnforcingCombineFn(), CallSequenceEnforcingCombineFn()),
        None).with_fanout(fanout=1)
    assert_that(pcoll, equal_to([(expected_result, expected_result)]))


def run_pardo(pipeline, input_elements=10):
  """Runs ``IndexAssigningDoFn`` over ``input_elements`` copies of 'Hello'.

  Every element is keyed with ``None`` before it reaches the DoFn. The
  output is discarded; lifecycle violations surface as assertion failures
  raised inside ``CallSequenceEnforcingCombineFn``.

  Args:
    pipeline: the pipeline to build on and run; it is used as a context
      manager, so it executes when this function leaves the ``with`` block.
    input_elements: number of 'Hello' elements to create.
  """
  with pipeline as p:
    _ = (
        p
        | 'Start' >> beam.Create(('Hello' for _ in range(input_elements)))
        | 'KeyWithNone' >> beam.Map(lambda elem: (None, elem))
        | 'Do' >> beam.ParDo(IndexAssigningDoFn()))