# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pytype: skip-file

"""Example snippets for ``beam.CombinePerKey``.

Every function builds a small pipeline over ``(plant, count)`` pairs, combines
the counts per key and prints each result. The code between the
``# [START ...]`` and ``# [END ...]`` markers is the snippet itself; keep it
self-contained (including the local ``import apache_beam``).

Each function takes an optional ``test`` callable. When given, it is called
with the final PCollection from inside the ``with`` block, i.e. before the
pipeline is run on context exit, so the callback can attach checks to it.
"""


def combineperkey_simple(test=None):
  """Sum the counts of each plant with the builtin ``sum``.

  Args:
    test: Optional callable invoked with the resulting PCollection.
  """
  # [START combineperkey_simple]
  import apache_beam as beam

  with beam.Pipeline() as pipeline:
    total = (
        pipeline
        | 'Create plant counts' >> beam.Create([
            ('🥕', 3),
            ('🥕', 2),
            ('🍆', 1),
            ('🍅', 4),
            ('🍅', 5),
            ('🍅', 3),
        ])
        | 'Sum' >> beam.CombinePerKey(sum)
        | beam.Map(print))
    # [END combineperkey_simple]
    if test:
      test(total)


def combineperkey_function(test=None):
  """Combine per key with a named function that caps the sum at 8.

  Args:
    test: Optional callable invoked with the resulting PCollection.
  """
  # [START combineperkey_function]
  import apache_beam as beam

  def saturated_sum(values):
    max_value = 8
    return min(sum(values), max_value)

  with beam.Pipeline() as pipeline:
    saturated_total = (
        pipeline
        | 'Create plant counts' >> beam.Create([
            ('🥕', 3),
            ('🥕', 2),
            ('🍆', 1),
            ('🍅', 4),
            ('🍅', 5),
            ('🍅', 3),
        ])
        | 'Saturated sum' >> beam.CombinePerKey(saturated_sum)
        | beam.Map(print))
    # [END combineperkey_function]
    if test:
      test(saturated_total)


def combineperkey_lambda(test=None):
  """Combine per key with an inline lambda that caps the sum at 8.

  Args:
    test: Optional callable invoked with the resulting PCollection.
  """
  # [START combineperkey_lambda]
  import apache_beam as beam

  with beam.Pipeline() as pipeline:
    saturated_total = (
        pipeline
        | 'Create plant counts' >> beam.Create([
            ('🥕', 3),
            ('🥕', 2),
            ('🍆', 1),
            ('🍅', 4),
            ('🍅', 5),
            ('🍅', 3),
        ])
        | 'Saturated sum' >>
        beam.CombinePerKey(lambda values: min(sum(values), 8))
        | beam.Map(print))
    # [END combineperkey_lambda]
    if test:
      test(saturated_total)


def combineperkey_multiple_arguments(test=None):
  """Pass the cap to the combining lambda as an extra keyword argument.

  Args:
    test: Optional callable invoked with the resulting PCollection.
  """
  # [START combineperkey_multiple_arguments]
  import apache_beam as beam

  with beam.Pipeline() as pipeline:
    saturated_total = (
        pipeline
        | 'Create plant counts' >> beam.Create([
            ('🥕', 3),
            ('🥕', 2),
            ('🍆', 1),
            ('🍅', 4),
            ('🍅', 5),
            ('🍅', 3),
        ])
        | 'Saturated sum' >> beam.CombinePerKey(
            lambda values, max_value: min(sum(values), max_value), max_value=8)
        | beam.Map(print))
    # [END combineperkey_multiple_arguments]
    if test:
      test(saturated_total)


def combineperkey_side_inputs_singleton(test=None):
  """Pass the cap as a singleton side input (``beam.pvalue.AsSingleton``).

  Args:
    test: Optional callable invoked with the resulting PCollection.
  """
  # [START combineperkey_side_inputs_singleton]
  import apache_beam as beam

  with beam.Pipeline() as pipeline:
    max_value = pipeline | 'Create max_value' >> beam.Create([8])

    saturated_total = (
        pipeline
        | 'Create plant counts' >> beam.Create([
            ('🥕', 3),
            ('🥕', 2),
            ('🍆', 1),
            ('🍅', 4),
            ('🍅', 5),
            ('🍅', 3),
        ])
        | 'Saturated sum' >> beam.CombinePerKey(
            lambda values, max_value: min(sum(values), max_value),
            max_value=beam.pvalue.AsSingleton(max_value))
        | beam.Map(print))
    # [END combineperkey_side_inputs_singleton]
    if test:
      test(saturated_total)


def combineperkey_side_inputs_iter(test=None):
  """Clamp each sum to the min/max of an iterable side input (``AsIter``).

  Args:
    test: Optional callable invoked with the resulting PCollection.
  """
  # [START combineperkey_side_inputs_iter]
  import apache_beam as beam

  def bounded_sum(values, data_range):
    min_value = min(data_range)
    result = sum(values)
    if result < min_value:
      return min_value
    max_value = max(data_range)
    if result > max_value:
      return max_value
    return result

  with beam.Pipeline() as pipeline:
    data_range = pipeline | 'Create data_range' >> beam.Create([2, 4, 8])

    bounded_total = (
        pipeline
        | 'Create plant counts' >> beam.Create([
            ('🥕', 3),
            ('🥕', 2),
            ('🍆', 1),
            ('🍅', 4),
            ('🍅', 5),
            ('🍅', 3),
        ])
        | 'Bounded sum' >> beam.CombinePerKey(
            bounded_sum, data_range=beam.pvalue.AsIter(data_range))
        | beam.Map(print))
    # [END combineperkey_side_inputs_iter]
    if test:
      test(bounded_total)


def combineperkey_side_inputs_dict(test=None):
  """Clamp each sum to the 'min'/'max' entries of a dict side input.

  The side input is a PCollection of ``(key, value)`` pairs passed through
  ``beam.pvalue.AsDict``.

  Args:
    test: Optional callable invoked with the resulting PCollection.
  """
  # [START combineperkey_side_inputs_dict]
  import apache_beam as beam

  def bounded_sum(values, data_range):
    min_value = data_range['min']
    result = sum(values)
    if result < min_value:
      return min_value
    max_value = data_range['max']
    if result > max_value:
      return max_value
    return result

  with beam.Pipeline() as pipeline:
    data_range = pipeline | 'Create data_range' >> beam.Create([
        ('min', 2),
        ('max', 8),
    ])

    bounded_total = (
        pipeline
        | 'Create plant counts' >> beam.Create([
            ('🥕', 3),
            ('🥕', 2),
            ('🍆', 1),
            ('🍅', 4),
            ('🍅', 5),
            ('🍅', 3),
        ])
        | 'Bounded sum' >> beam.CombinePerKey(
            bounded_sum, data_range=beam.pvalue.AsDict(data_range))
        | beam.Map(print))
    # [END combineperkey_side_inputs_dict]
    if test:
      test(bounded_total)


def combineperkey_combinefn(test=None):
  """Average the counts of each plant with a custom ``beam.CombineFn``.

  The accumulator is a ``(total, count)`` pair; the output is
  ``total / count``, or NaN when no input was accumulated.

  Args:
    test: Optional callable invoked with the resulting PCollection.
  """
  # [START combineperkey_combinefn]
  import apache_beam as beam

  class AverageFn(beam.CombineFn):
    def create_accumulator(self):
      total = 0.0
      count = 0
      accumulator = total, count
      return accumulator

    def add_input(self, accumulator, element):
      total, count = accumulator
      return total + element, count + 1

    def merge_accumulators(self, accumulators):
      # accumulators = [(total1, count1), (total2, count2), ...]
      totals, counts = zip(*accumulators)
      # totals = [total1, total2, total3, ...]
      # counts = [count1, count2, count3, ...]
      return sum(totals), sum(counts)

    def extract_output(self, accumulator):
      total, count = accumulator
      if count == 0:
        return float('NaN')
      return total / count

  with beam.Pipeline() as pipeline:
    average = (
        pipeline
        | 'Create plant counts' >> beam.Create([
            ('🥕', 3),
            ('🥕', 2),
            ('🍆', 1),
            ('🍅', 4),
            ('🍅', 5),
            ('🍅', 3),
        ])
        | 'Average' >> beam.CombinePerKey(AverageFn())
        | beam.Map(print))
    # [END combineperkey_combinefn]
    if test:
      test(average)