github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/combinefn_lifecycle_pipeline.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  # pytype: skip-file
    19  
    20  from typing import Set
    21  from typing import Tuple
    22  
    23  import apache_beam as beam
    24  from apache_beam.options.pipeline_options import TypeOptions
    25  from apache_beam.testing.util import assert_that
    26  from apache_beam.testing.util import equal_to
    27  from apache_beam.transforms import combiners
    28  from apache_beam.transforms import trigger
    29  from apache_beam.transforms import userstate
    30  from apache_beam.transforms import window
    31  from apache_beam.typehints import with_input_types
    32  from apache_beam.typehints import with_output_types
    33  
    34  
    35  @with_input_types(int)
    36  @with_output_types(int)
    37  class CallSequenceEnforcingCombineFn(beam.CombineFn):
    38    instances = set()  # type: Set[CallSequenceEnforcingCombineFn]
    39  
    40    def __init__(self):
    41      super().__init__()
    42      self._setup_called = False
    43      self._teardown_called = False
    44  
    45    def setup(self, *args, **kwargs):
    46      assert not self._setup_called, 'setup should not be called twice'
    47      assert not self._teardown_called, 'setup should be called before teardown'
    48      # Keep track of instances so that we can check if teardown is called
    49      # properly after pipeline execution.
    50      self.instances.add(self)
    51      self._setup_called = True
    52  
    53    def create_accumulator(self, *args, **kwargs):
    54      assert self._setup_called, 'setup should have been called'
    55      assert not self._teardown_called, 'teardown should not have been called'
    56      return 0
    57  
    58    def add_input(self, mutable_accumulator, element, *args, **kwargs):
    59      assert self._setup_called, 'setup should have been called'
    60      assert not self._teardown_called, 'teardown should not have been called'
    61      mutable_accumulator += element
    62      return mutable_accumulator
    63  
    64    def add_inputs(self, mutable_accumulator, elements, *args, **kwargs):
    65      return self.add_input(mutable_accumulator, sum(elements))
    66  
    67    def merge_accumulators(self, accumulators, *args, **kwargs):
    68      assert self._setup_called, 'setup should have been called'
    69      assert not self._teardown_called, 'teardown should not have been called'
    70      return sum(accumulators)
    71  
    72    def extract_output(self, accumulator, *args, **kwargs):
    73      assert self._setup_called, 'setup should have been called'
    74      assert not self._teardown_called, 'teardown should not have been called'
    75      return accumulator
    76  
    77    def teardown(self, *args, **kwargs):
    78      assert self._setup_called, 'setup should have been called'
    79      assert not self._teardown_called, 'teardown should not be called twice'
    80      self._teardown_called = True
    81  
    82  
    83  @with_input_types(Tuple[None, str])
    84  @with_output_types(Tuple[int, str])
    85  class IndexAssigningDoFn(beam.DoFn):
    86    state_param = beam.DoFn.StateParam(
    87        userstate.CombiningValueStateSpec(
    88            'index', beam.coders.VarIntCoder(), CallSequenceEnforcingCombineFn()))
    89  
    90    def process(self, element, state=state_param):
    91      _, value = element
    92      current_index = state.read()
    93      yield current_index, value
    94      state.add(1)
    95  
    96  
    97  def run_combine(pipeline, input_elements=5, lift_combiners=True):
    98    # Calculate the expected result, which is the sum of an arithmetic sequence.
    99    # By default, this is equal to: 0 + 1 + 2 + 3 + 4 = 10
   100    expected_result = input_elements * (input_elements - 1) / 2
   101  
   102    # Enable runtime type checking in order to cover TypeCheckCombineFn by
   103    # the test.
   104    pipeline.get_pipeline_options().view_as(TypeOptions).runtime_type_check = True
   105    pipeline.get_pipeline_options().view_as(
   106        TypeOptions).allow_unsafe_triggers = True
   107  
   108    with pipeline as p:
   109      pcoll = p | 'Start' >> beam.Create(range(input_elements))
   110  
   111      # Certain triggers, such as AfterCount, are incompatible with combiner
   112      # lifting. We can use that fact to prevent combiners from being lifted.
   113      if not lift_combiners:
   114        pcoll |= beam.WindowInto(
   115            window.GlobalWindows(),
   116            trigger=trigger.AfterCount(input_elements),
   117            accumulation_mode=trigger.AccumulationMode.DISCARDING)
   118  
   119      # Pass an additional 'None' in order to cover _CurriedFn by the test.
   120      pcoll |= 'Do' >> beam.CombineGlobally(
   121          combiners.SingleInputTupleCombineFn(
   122              CallSequenceEnforcingCombineFn(), CallSequenceEnforcingCombineFn()),
   123          None).with_fanout(fanout=1)
   124      assert_that(pcoll, equal_to([(expected_result, expected_result)]))
   125  
   126  
   127  def run_pardo(pipeline, input_elements=10):
   128    with pipeline as p:
   129      _ = (
   130          p
   131          | 'Start' >> beam.Create(('Hello' for _ in range(input_elements)))
   132          | 'KeyWithNone' >> beam.Map(lambda elem: (None, elem))
   133          | 'Do' >> beam.ParDo(IndexAssigningDoFn()))