github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/snippets/transforms/aggregation/combineperkey.py (about)

     1  # coding=utf-8
     2  #
     3  # Licensed to the Apache Software Foundation (ASF) under one or more
     4  # contributor license agreements.  See the NOTICE file distributed with
     5  # this work for additional information regarding copyright ownership.
     6  # The ASF licenses this file to You under the Apache License, Version 2.0
     7  # (the "License"); you may not use this file except in compliance with
     8  # the License.  You may obtain a copy of the License at
     9  #
    10  #    http://www.apache.org/licenses/LICENSE-2.0
    11  #
    12  # Unless required by applicable law or agreed to in writing, software
    13  # distributed under the License is distributed on an "AS IS" BASIS,
    14  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    15  # See the License for the specific language governing permissions and
    16  # limitations under the License.
    17  #
    18  
    19  # pytype: skip-file
    20  
    21  
    22  def combineperkey_simple(test=None):
    23    # [START combineperkey_simple]
    24    import apache_beam as beam
    25  
    26    with beam.Pipeline() as pipeline:
    27      total = (
    28          pipeline
    29          | 'Create plant counts' >> beam.Create([
    30              ('🥕', 3),
    31              ('🥕', 2),
    32              ('🍆', 1),
    33              ('🍅', 4),
    34              ('🍅', 5),
    35              ('🍅', 3),
    36          ])
    37          | 'Sum' >> beam.CombinePerKey(sum)
    38          | beam.Map(print))
    39      # [END combineperkey_simple]
    40      if test:
    41        test(total)
    42  
    43  
    44  def combineperkey_function(test=None):
    45    # [START combineperkey_function]
    46    import apache_beam as beam
    47  
    48    def saturated_sum(values):
    49      max_value = 8
    50      return min(sum(values), max_value)
    51  
    52    with beam.Pipeline() as pipeline:
    53      saturated_total = (
    54          pipeline
    55          | 'Create plant counts' >> beam.Create([
    56              ('🥕', 3),
    57              ('🥕', 2),
    58              ('🍆', 1),
    59              ('🍅', 4),
    60              ('🍅', 5),
    61              ('🍅', 3),
    62          ])
    63          | 'Saturated sum' >> beam.CombinePerKey(saturated_sum)
    64          | beam.Map(print))
    65      # [END combineperkey_function]
    66      if test:
    67        test(saturated_total)
    68  
    69  
    70  def combineperkey_lambda(test=None):
    71    # [START combineperkey_lambda]
    72    import apache_beam as beam
    73  
    74    with beam.Pipeline() as pipeline:
    75      saturated_total = (
    76          pipeline
    77          | 'Create plant counts' >> beam.Create([
    78              ('🥕', 3),
    79              ('🥕', 2),
    80              ('🍆', 1),
    81              ('🍅', 4),
    82              ('🍅', 5),
    83              ('🍅', 3),
    84          ])
    85          | 'Saturated sum' >>
    86          beam.CombinePerKey(lambda values: min(sum(values), 8))
    87          | beam.Map(print))
    88      # [END combineperkey_lambda]
    89      if test:
    90        test(saturated_total)
    91  
    92  
    93  def combineperkey_multiple_arguments(test=None):
    94    # [START combineperkey_multiple_arguments]
    95    import apache_beam as beam
    96  
    97    with beam.Pipeline() as pipeline:
    98      saturated_total = (
    99          pipeline
   100          | 'Create plant counts' >> beam.Create([
   101              ('🥕', 3),
   102              ('🥕', 2),
   103              ('🍆', 1),
   104              ('🍅', 4),
   105              ('🍅', 5),
   106              ('🍅', 3),
   107          ])
   108          | 'Saturated sum' >> beam.CombinePerKey(
   109              lambda values, max_value: min(sum(values), max_value), max_value=8)
   110          | beam.Map(print))
   111      # [END combineperkey_multiple_arguments]
   112      if test:
   113        test(saturated_total)
   114  
   115  
   116  def combineperkey_side_inputs_singleton(test=None):
   117    # [START combineperkey_side_inputs_singleton]
   118    import apache_beam as beam
   119  
   120    with beam.Pipeline() as pipeline:
   121      max_value = pipeline | 'Create max_value' >> beam.Create([8])
   122  
   123      saturated_total = (
   124          pipeline
   125          | 'Create plant counts' >> beam.Create([
   126              ('🥕', 3),
   127              ('🥕', 2),
   128              ('🍆', 1),
   129              ('🍅', 4),
   130              ('🍅', 5),
   131              ('🍅', 3),
   132          ])
   133          | 'Saturated sum' >> beam.CombinePerKey(
   134              lambda values,
   135              max_value: min(sum(values), max_value),
   136              max_value=beam.pvalue.AsSingleton(max_value))
   137          | beam.Map(print))
   138      # [END combineperkey_side_inputs_singleton]
   139      if test:
   140        test(saturated_total)
   141  
   142  
   143  def combineperkey_side_inputs_iter(test=None):
   144    # [START combineperkey_side_inputs_iter]
   145    import apache_beam as beam
   146  
   147    def bounded_sum(values, data_range):
   148      min_value = min(data_range)
   149      result = sum(values)
   150      if result < min_value:
   151        return min_value
   152      max_value = max(data_range)
   153      if result > max_value:
   154        return max_value
   155      return result
   156  
   157    with beam.Pipeline() as pipeline:
   158      data_range = pipeline | 'Create data_range' >> beam.Create([2, 4, 8])
   159  
   160      bounded_total = (
   161          pipeline
   162          | 'Create plant counts' >> beam.Create([
   163              ('🥕', 3),
   164              ('🥕', 2),
   165              ('🍆', 1),
   166              ('🍅', 4),
   167              ('🍅', 5),
   168              ('🍅', 3),
   169          ])
   170          | 'Bounded sum' >> beam.CombinePerKey(
   171              bounded_sum, data_range=beam.pvalue.AsIter(data_range))
   172          | beam.Map(print))
   173      # [END combineperkey_side_inputs_iter]
   174      if test:
   175        test(bounded_total)
   176  
   177  
   178  def combineperkey_side_inputs_dict(test=None):
   179    # [START combineperkey_side_inputs_dict]
   180    import apache_beam as beam
   181  
   182    def bounded_sum(values, data_range):
   183      min_value = data_range['min']
   184      result = sum(values)
   185      if result < min_value:
   186        return min_value
   187      max_value = data_range['max']
   188      if result > max_value:
   189        return max_value
   190      return result
   191  
   192    with beam.Pipeline() as pipeline:
   193      data_range = pipeline | 'Create data_range' >> beam.Create([
   194          ('min', 2),
   195          ('max', 8),
   196      ])
   197  
   198      bounded_total = (
   199          pipeline
   200          | 'Create plant counts' >> beam.Create([
   201              ('🥕', 3),
   202              ('🥕', 2),
   203              ('🍆', 1),
   204              ('🍅', 4),
   205              ('🍅', 5),
   206              ('🍅', 3),
   207          ])
   208          | 'Bounded sum' >> beam.CombinePerKey(
   209              bounded_sum, data_range=beam.pvalue.AsDict(data_range))
   210          | beam.Map(print))
   211      # [END combineperkey_side_inputs_dict]
   212      if test:
   213        test(bounded_total)
   214  
   215  
   216  def combineperkey_combinefn(test=None):
   217    # [START combineperkey_combinefn]
   218    import apache_beam as beam
   219  
   220    class AverageFn(beam.CombineFn):
   221      def create_accumulator(self):
   222        sum = 0.0
   223        count = 0
   224        accumulator = sum, count
   225        return accumulator
   226  
   227      def add_input(self, accumulator, input):
   228        sum, count = accumulator
   229        return sum + input, count + 1
   230  
   231      def merge_accumulators(self, accumulators):
   232        # accumulators = [(sum1, count1), (sum2, count2), (sum3, count3), ...]
   233        sums, counts = zip(*accumulators)
   234        # sums = [sum1, sum2, sum3, ...]
   235        # counts = [count1, count2, count3, ...]
   236        return sum(sums), sum(counts)
   237  
   238      def extract_output(self, accumulator):
   239        sum, count = accumulator
   240        if count == 0:
   241          return float('NaN')
   242        return sum / count
   243  
   244    with beam.Pipeline() as pipeline:
   245      average = (
   246          pipeline
   247          | 'Create plant counts' >> beam.Create([
   248              ('🥕', 3),
   249              ('🥕', 2),
   250              ('🍆', 1),
   251              ('🍅', 4),
   252              ('🍅', 5),
   253              ('🍅', 3),
   254          ])
   255          | 'Average' >> beam.CombinePerKey(AverageFn())
   256          | beam.Map(print))
   257      # [END combineperkey_combinefn]
   258      if test:
   259        test(average)