github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/combiners.py (about) 1 # 2 # Licensed to the Apache Software Foundation (ASF) under one or more 3 # contributor license agreements. See the NOTICE file distributed with 4 # this work for additional information regarding copyright ownership. 5 # The ASF licenses this file to You under the Apache License, Version 2.0 6 # (the "License"); you may not use this file except in compliance with 7 # the License. You may obtain a copy of the License at 8 # 9 # http://www.apache.org/licenses/LICENSE-2.0 10 # 11 # Unless required by applicable law or agreed to in writing, software 12 # distributed under the License is distributed on an "AS IS" BASIS, 13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 # See the License for the specific language governing permissions and 15 # limitations under the License. 16 # 17 18 """A library of basic combiner PTransform subclasses.""" 19 20 # pytype: skip-file 21 22 import copy 23 import heapq 24 import itertools 25 import operator 26 import random 27 from typing import Any 28 from typing import Dict 29 from typing import Iterable 30 from typing import List 31 from typing import Set 32 from typing import Tuple 33 from typing import TypeVar 34 from typing import Union 35 36 import numpy as np 37 38 from apache_beam import typehints 39 from apache_beam.transforms import core 40 from apache_beam.transforms import cy_combiners 41 from apache_beam.transforms import ptransform 42 from apache_beam.transforms import window 43 from apache_beam.transforms.display import DisplayDataItem 44 from apache_beam.typehints import with_input_types 45 from apache_beam.typehints import with_output_types 46 from apache_beam.utils.timestamp import Duration 47 from apache_beam.utils.timestamp import Timestamp 48 49 __all__ = [ 50 'Count', 51 'Mean', 52 'Sample', 53 'Top', 54 'ToDict', 55 'ToList', 56 'ToSet', 57 'Latest', 58 'CountCombineFn', 59 'MeanCombineFn', 60 'SampleCombineFn', 61 'TopCombineFn', 62 'ToDictCombineFn', 63 'ToListCombineFn', 64 'ToSetCombineFn', 65 'LatestCombineFn', 66 ] 67 68 # Type variables 69 T = TypeVar('T') 70 K = TypeVar('K') 71 V = TypeVar('V') 72 TimestampType = Union[int, float, Timestamp, Duration] 73 74 75 class CombinerWithoutDefaults(ptransform.PTransform): 76 """Super class to inherit without_defaults to built-in Combiners.""" 77 def __init__(self, has_defaults=True): 78 super().__init__() 79 self.has_defaults = has_defaults 80 81 def with_defaults(self, has_defaults=True): 82 new = copy.copy(self) 83 new.has_defaults = has_defaults 84 return new 85 86 def without_defaults(self): 87 return self.with_defaults(False) 88 89 90 class Mean(object): 91 """Combiners for computing arithmetic means of elements.""" 92 class Globally(CombinerWithoutDefaults): 93 """combiners.Mean.Globally computes the arithmetic mean of the elements.""" 94 def expand(self, pcoll): 95 if self.has_defaults: 96 return pcoll | core.CombineGlobally(MeanCombineFn()) 97 else: 98 return pcoll | core.CombineGlobally(MeanCombineFn()).without_defaults() 99 100 class PerKey(ptransform.PTransform): 101 """combiners.Mean.PerKey finds the means of the values for each key.""" 102 def expand(self, pcoll): 103 return pcoll | core.CombinePerKey(MeanCombineFn()) 104 105 106 # TODO(laolu): This type signature is overly restrictive. This should be 107 # more general. 108 @with_input_types(Union[float, int, np.int64, np.float64]) 109 @with_output_types(float) 110 class MeanCombineFn(core.CombineFn): 111 """CombineFn for computing an arithmetic mean.""" 112 def create_accumulator(self): 113 return (0, 0) 114 115 def add_input(self, sum_count, element): 116 (sum_, count) = sum_count 117 return sum_ + element, count + 1 118 119 def merge_accumulators(self, accumulators): 120 sums, counts = zip(*accumulators) 121 return sum(sums), sum(counts) 122 123 def extract_output(self, sum_count): 124 (sum_, count) = sum_count 125 if count == 0: 126 return float('NaN') 127 return sum_ / float(count) 128 129 def for_input_type(self, input_type): 130 if input_type is int: 131 return cy_combiners.MeanInt64Fn() 132 elif input_type is float: 133 return cy_combiners.MeanFloatFn() 134 return self 135 136 137 class Count(object): 138 """Combiners for counting elements.""" 139 @with_input_types(T) 140 @with_output_types(int) 141 class Globally(CombinerWithoutDefaults): 142 """combiners.Count.Globally counts the total number of elements.""" 143 def expand(self, pcoll): 144 if self.has_defaults: 145 return pcoll | core.CombineGlobally(CountCombineFn()) 146 else: 147 return pcoll | core.CombineGlobally(CountCombineFn()).without_defaults() 148 149 @with_input_types(Tuple[K, V]) 150 @with_output_types(Tuple[K, int]) 151 class PerKey(ptransform.PTransform): 152 """combiners.Count.PerKey counts how many elements each unique key has.""" 153 def expand(self, pcoll): 154 return pcoll | core.CombinePerKey(CountCombineFn()) 155 156 @with_input_types(T) 157 @with_output_types(Tuple[T, int]) 158 class PerElement(ptransform.PTransform): 159 """combiners.Count.PerElement counts how many times each element occurs.""" 160 def expand(self, pcoll): 161 paired_with_void_type = typehints.Tuple[pcoll.element_type, Any] 162 output_type = typehints.KV[pcoll.element_type, int] 163 return ( 164 pcoll 165 | ( 166 '%s:PairWithVoid' % self.label >> core.Map( 167 lambda x: (x, None)).with_output_types(paired_with_void_type)) 168 | core.CombinePerKey(CountCombineFn()).with_output_types(output_type)) 169 170 171 @with_input_types(Any) 172 @with_output_types(int) 173 class CountCombineFn(core.CombineFn): 174 """CombineFn for computing PCollection size.""" 175 def create_accumulator(self): 176 return 0 177 178 def add_input(self, accumulator, element): 179 return accumulator + 1 180 181 def add_inputs(self, accumulator, elements): 182 return accumulator + len(list(elements)) 183 184 def merge_accumulators(self, accumulators): 185 return sum(accumulators) 186 187 def extract_output(self, accumulator): 188 return accumulator 189 190 191 class Top(object): 192 """Combiners for obtaining extremal elements.""" 193 194 # pylint: disable=no-self-argument 195 @with_input_types(T) 196 @with_output_types(List[T]) 197 class Of(CombinerWithoutDefaults): 198 """Returns the n greatest elements in the PCollection. 199 200 This transform will retrieve the n greatest elements in the PCollection 201 to which it is applied, where "greatest" is determined by a 202 function supplied as the `key` or `reverse` arguments. 203 """ 204 def __init__(self, n, key=None, reverse=False): 205 """Creates a global Top operation. 206 207 The arguments 'key' and 'reverse' may be passed as keyword arguments, 208 and have the same meaning as for Python's sort functions. 209 210 Args: 211 n: number of elements to extract from pcoll. 212 key: (optional) a mapping of elements to a comparable key, similar to 213 the key argument of Python's sorting methods. 214 reverse: (optional) whether to order things smallest to largest, rather 215 than largest to smallest 216 """ 217 super().__init__() 218 self._n = n 219 self._key = key 220 self._reverse = reverse 221 222 def default_label(self): 223 return 'Top(%d)' % self._n 224 225 def expand(self, pcoll): 226 if pcoll.windowing.is_default(): 227 # This is a more efficient global algorithm. 228 top_per_bundle = pcoll | core.ParDo( 229 _TopPerBundle(self._n, self._key, self._reverse)) 230 # If pcoll is empty, we can't guarantee that top_per_bundle 231 # won't be empty, so inject at least one empty accumulator 232 # so that downstream is guaranteed to produce non-empty output. 233 empty_bundle = ( 234 pcoll.pipeline | core.Create([(None, [])]).with_output_types( 235 top_per_bundle.element_type)) 236 return ((top_per_bundle, empty_bundle) | core.Flatten() 237 | core.GroupByKey() 238 | core.ParDo( 239 _MergeTopPerBundle(self._n, self._key, self._reverse))) 240 else: 241 if self.has_defaults: 242 return pcoll | core.CombineGlobally( 243 TopCombineFn(self._n, self._key, self._reverse)) 244 else: 245 return pcoll | core.CombineGlobally( 246 TopCombineFn(self._n, self._key, 247 self._reverse)).without_defaults() 248 249 @with_input_types(Tuple[K, V]) 250 @with_output_types(Tuple[K, List[V]]) 251 class PerKey(ptransform.PTransform): 252 """Identifies the N greatest elements associated with each key. 253 254 This transform will produce a PCollection mapping unique keys in the input 255 PCollection to the n greatest elements with which they are associated, where 256 "greatest" is determined by a function supplied as the `key` or 257 `reverse` arguments. 258 """ 259 def __init__(self, n, key=None, reverse=False): 260 """Creates a per-key Top operation. 261 262 The arguments 'key' and 'reverse' may be passed as keyword arguments, 263 and have the same meaning as for Python's sort functions. 264 265 Args: 266 n: number of elements to extract from pcoll. 267 key: (optional) a mapping of elements to a comparable key, similar to 268 the key argument of Python's sorting methods. 269 reverse: (optional) whether to order things smallest to largest, rather 270 than largest to smallest 271 """ 272 self._n = n 273 self._key = key 274 self._reverse = reverse 275 276 def default_label(self): 277 return 'TopPerKey(%d)' % self._n 278 279 def expand(self, pcoll): 280 """Expands the transform. 281 282 Raises TypeCheckError: If the output type of the input PCollection is not 283 compatible with Tuple[A, B]. 284 285 Args: 286 pcoll: PCollection to process 287 288 Returns: 289 the PCollection containing the result. 290 """ 291 return pcoll | core.CombinePerKey( 292 TopCombineFn(self._n, self._key, self._reverse)) 293 294 @staticmethod 295 @ptransform.ptransform_fn 296 def Largest(pcoll, n, has_defaults=True, key=None): 297 """Obtain a list of the greatest N elements in a PCollection.""" 298 if has_defaults: 299 return pcoll | Top.Of(n, key) 300 else: 301 return pcoll | Top.Of(n, key).without_defaults() 302 303 @staticmethod 304 @ptransform.ptransform_fn 305 def Smallest(pcoll, n, has_defaults=True, key=None): 306 """Obtain a list of the least N elements in a PCollection.""" 307 if has_defaults: 308 return pcoll | Top.Of(n, key, reverse=True) 309 else: 310 return pcoll | Top.Of(n, key, reverse=True).without_defaults() 311 312 @staticmethod 313 @ptransform.ptransform_fn 314 def LargestPerKey(pcoll, n, key=None): 315 """Identifies the N greatest elements associated with each key.""" 316 return pcoll | Top.PerKey(n, key) 317 318 @staticmethod 319 @ptransform.ptransform_fn 320 def SmallestPerKey(pcoll, n, *, key=None, reverse=None): 321 """Identifies the N least elements associated with each key.""" 322 return pcoll | Top.PerKey(n, key, reverse=True) 323 324 325 @with_input_types(T) 326 @with_output_types(Tuple[None, List[T]]) 327 class _TopPerBundle(core.DoFn): 328 def __init__(self, n, key, reverse): 329 self._n = n 330 self._compare = operator.gt if reverse else None 331 self._key = key 332 333 def start_bundle(self): 334 self._heap = [] 335 336 def process(self, element): 337 if self._compare or self._key: 338 element = cy_combiners.ComparableValue(element, self._compare, self._key) 339 if len(self._heap) < self._n: 340 heapq.heappush(self._heap, element) 341 else: 342 heapq.heappushpop(self._heap, element) 343 344 def finish_bundle(self): 345 # Though sorting here results in more total work, this allows us to 346 # skip most elements in the reducer. 347 # Essentially, given s map bundles, we are trading about O(sn) compares in 348 # the (single) reducer for O(sn log n) compares across all mappers. 349 self._heap.sort() 350 351 # Unwrap to avoid serialization via pickle. 352 if self._compare or self._key: 353 yield window.GlobalWindows.windowed_value( 354 (None, [wrapper.value for wrapper in self._heap])) 355 else: 356 yield window.GlobalWindows.windowed_value((None, self._heap)) 357 358 359 @with_input_types(Tuple[None, Iterable[List[T]]]) 360 @with_output_types(List[T]) 361 class _MergeTopPerBundle(core.DoFn): 362 def __init__(self, n, key, reverse): 363 self._n = n 364 self._compare = operator.gt if reverse else None 365 self._key = key 366 367 def process(self, key_and_bundles): 368 _, bundles = key_and_bundles 369 370 def push(hp, e): 371 if len(hp) < self._n: 372 heapq.heappush(hp, e) 373 return False 374 elif e < hp[0]: 375 # Because _TopPerBundle returns sorted lists, all other elements 376 # will also be smaller. 377 return True 378 else: 379 heapq.heappushpop(hp, e) 380 return False 381 382 if self._compare or self._key: 383 heapc = [] # type: List[cy_combiners.ComparableValue] 384 for bundle in bundles: 385 if not heapc: 386 heapc = [ 387 cy_combiners.ComparableValue(element, self._compare, self._key) 388 for element in bundle 389 ] 390 continue 391 # TODO(https://github.com/apache/beam/issues/21205): Remove this 392 # workaround once legacy dataflow correctly handles coders with 393 # combiner packing and/or is deprecated. 394 if not isinstance(bundle, list): 395 bundle = list(bundle) 396 for element in reversed(bundle): 397 if push(heapc, 398 cy_combiners.ComparableValue(element, 399 self._compare, 400 self._key)): 401 break 402 heapc.sort() 403 yield [wrapper.value for wrapper in reversed(heapc)] 404 405 else: 406 heap = [] 407 for bundle in bundles: 408 # TODO(https://github.com/apache/beam/issues/21205): Remove this 409 # workaround once legacy dataflow correctly handles coders with 410 # combiner packing and/or is deprecated. 411 if not isinstance(bundle, list): 412 bundle = list(bundle) 413 if not heap: 414 heap = bundle 415 continue 416 for element in reversed(bundle): 417 if push(heap, element): 418 break 419 heap.sort() 420 yield heap[::-1] 421 422 423 @with_input_types(T) 424 @with_output_types(List[T]) 425 class TopCombineFn(core.CombineFn): 426 """CombineFn doing the combining for all of the Top transforms. 427 428 This CombineFn uses a `key` or `reverse` operator to rank the elements. 429 430 Args: 431 key: (optional) a mapping of elements to a comparable key, similar to 432 the key argument of Python's sorting methods. 433 reverse: (optional) whether to order things smallest to largest, rather 434 than largest to smallest 435 """ 436 def __init__(self, n, key=None, reverse=False): 437 self._n = n 438 self._compare = operator.gt if reverse else operator.lt 439 self._key = key 440 441 def _hydrated_heap(self, heap): 442 if heap: 443 first = heap[0] 444 if isinstance(first, cy_combiners.ComparableValue): 445 if first.requires_hydration: 446 for comparable in heap: 447 assert comparable.requires_hydration 448 comparable.hydrate(self._compare, self._key) 449 assert not comparable.requires_hydration 450 return heap 451 else: 452 return heap 453 else: 454 return [ 455 cy_combiners.ComparableValue(element, self._compare, self._key) 456 for element in heap 457 ] 458 else: 459 return heap 460 461 def display_data(self): 462 return { 463 'n': self._n, 464 'compare': DisplayDataItem( 465 self._compare.__name__ if hasattr(self._compare, '__name__') else 466 self._compare.__class__.__name__).drop_if_none() 467 } 468 469 # The accumulator type is a tuple 470 # (bool, Union[List[T], List[ComparableValue[T]]) 471 # where the boolean indicates whether the second slot contains a List of T 472 # (False) or List of ComparableValue[T] (True). In either case, the List 473 # maintains heap invariance. When the contents of the List are 474 # ComparableValue[T] they either all 'requires_hydration' or none do. 475 # This accumulator representation allows us to minimize the data encoding 476 # overheads. Creation of ComparableValues is elided for performance reasons 477 # when there is no need for complicated comparison functions. 478 def create_accumulator(self, *args, **kwargs): 479 return (False, []) 480 481 def add_input(self, accumulator, element, *args, **kwargs): 482 # Caching to avoid paying the price of variadic expansion of args / kwargs 483 # when it's not needed (for the 'if' case below). 484 holds_comparables, heap = accumulator 485 if self._compare is not operator.lt or self._key: 486 heap = self._hydrated_heap(heap) 487 holds_comparables = True 488 else: 489 assert not holds_comparables 490 491 comparable = ( 492 cy_combiners.ComparableValue(element, self._compare, self._key) 493 if holds_comparables else element) 494 495 if len(heap) < self._n: 496 heapq.heappush(heap, comparable) 497 else: 498 heapq.heappushpop(heap, comparable) 499 return (holds_comparables, heap) 500 501 def merge_accumulators(self, accumulators, *args, **kwargs): 502 result_heap = None 503 holds_comparables = None 504 for accumulator in accumulators: 505 holds_comparables, heap = accumulator 506 if self._compare is not operator.lt or self._key: 507 heap = self._hydrated_heap(heap) 508 holds_comparables = True 509 else: 510 assert not holds_comparables 511 512 if result_heap is None: 513 result_heap = heap 514 else: 515 for comparable in heap: 516 _, result_heap = self.add_input( 517 (holds_comparables, result_heap), 518 comparable.value if holds_comparables else comparable) 519 520 assert result_heap is not None and holds_comparables is not None 521 return (holds_comparables, result_heap) 522 523 def compact(self, accumulator, *args, **kwargs): 524 holds_comparables, heap = accumulator 525 # Unwrap to avoid serialization via pickle. 526 if holds_comparables: 527 return (False, [comparable.value for comparable in heap]) 528 else: 529 return accumulator 530 531 def extract_output(self, accumulator, *args, **kwargs): 532 holds_comparables, heap = accumulator 533 if self._compare is not operator.lt or self._key: 534 if not holds_comparables: 535 heap = self._hydrated_heap(heap) 536 holds_comparables = True 537 else: 538 assert not holds_comparables 539 540 assert len(heap) <= self._n 541 heap.sort(reverse=True) 542 return [ 543 comparable.value if holds_comparables else comparable 544 for comparable in heap 545 ] 546 547 548 class Largest(TopCombineFn): 549 def default_label(self): 550 return 'Largest(%s)' % self._n 551 552 553 class Smallest(TopCombineFn): 554 def __init__(self, n): 555 super().__init__(n, reverse=True) 556 557 def default_label(self): 558 return 'Smallest(%s)' % self._n 559 560 561 class Sample(object): 562 """Combiners for sampling n elements without replacement.""" 563 564 # pylint: disable=no-self-argument 565 566 @with_input_types(T) 567 @with_output_types(List[T]) 568 class FixedSizeGlobally(CombinerWithoutDefaults): 569 """Sample n elements from the input PCollection without replacement.""" 570 def __init__(self, n): 571 super().__init__() 572 self._n = n 573 574 def expand(self, pcoll): 575 if self.has_defaults: 576 return pcoll | core.CombineGlobally(SampleCombineFn(self._n)) 577 else: 578 return pcoll | core.CombineGlobally(SampleCombineFn( 579 self._n)).without_defaults() 580 581 def display_data(self): 582 return {'n': self._n} 583 584 def default_label(self): 585 return 'FixedSizeGlobally(%d)' % self._n 586 587 @with_input_types(Tuple[K, V]) 588 @with_output_types(Tuple[K, List[V]]) 589 class FixedSizePerKey(ptransform.PTransform): 590 """Sample n elements associated with each key without replacement.""" 591 def __init__(self, n): 592 self._n = n 593 594 def expand(self, pcoll): 595 return pcoll | core.CombinePerKey(SampleCombineFn(self._n)) 596 597 def display_data(self): 598 return {'n': self._n} 599 600 def default_label(self): 601 return 'FixedSizePerKey(%d)' % self._n 602 603 604 @with_input_types(T) 605 @with_output_types(List[T]) 606 class SampleCombineFn(core.CombineFn): 607 """CombineFn for all Sample transforms.""" 608 def __init__(self, n): 609 super().__init__() 610 # Most of this combiner's work is done by a TopCombineFn. We could just 611 # subclass TopCombineFn to make this class, but since sampling is not 612 # really a kind of Top operation, we use a TopCombineFn instance as a 613 # helper instead. 614 self._top_combiner = TopCombineFn(n) 615 616 def setup(self): 617 self._top_combiner.setup() 618 619 def create_accumulator(self): 620 return self._top_combiner.create_accumulator() 621 622 def add_input(self, heap, element): 623 # Before passing elements to the Top combiner, we pair them with random 624 # numbers. The elements with the n largest random number "keys" will be 625 # selected for the output. 626 return self._top_combiner.add_input(heap, (random.random(), element)) 627 628 def merge_accumulators(self, heaps): 629 return self._top_combiner.merge_accumulators(heaps) 630 631 def compact(self, heap): 632 return self._top_combiner.compact(heap) 633 634 def extract_output(self, heap): 635 # Here we strip off the random number keys we added in add_input. 636 return [e for _, e in self._top_combiner.extract_output(heap)] 637 638 def teardown(self): 639 self._top_combiner.teardown() 640 641 642 class _TupleCombineFnBase(core.CombineFn): 643 def __init__(self, *combiners, merge_accumulators_batch_size=None): 644 self._combiners = [core.CombineFn.maybe_from_callable(c) for c in combiners] 645 self._named_combiners = combiners 646 # If the `merge_accumulators_batch_size` value is not specified, we chose a 647 # bounded default that is inversely proportional to the number of 648 # accumulators in merged tuples. 649 num_combiners = max(1, len(combiners)) 650 self._merge_accumulators_batch_size = ( 651 merge_accumulators_batch_size or max(10, 1000 // num_combiners)) 652 653 def display_data(self): 654 combiners = [ 655 c.__name__ if hasattr(c, '__name__') else c.__class__.__name__ 656 for c in self._named_combiners 657 ] 658 return { 659 'combiners': str(combiners), 660 'merge_accumulators_batch_size': self._merge_accumulators_batch_size 661 } 662 663 def setup(self, *args, **kwargs): 664 for c in self._combiners: 665 c.setup(*args, **kwargs) 666 667 def create_accumulator(self, *args, **kwargs): 668 return [c.create_accumulator(*args, **kwargs) for c in self._combiners] 669 670 def merge_accumulators(self, accumulators, *args, **kwargs): 671 # Make sure that `accumulators` is an iterator (so that the position is 672 # remembered). 673 accumulators = iter(accumulators) 674 result = next(accumulators) 675 while True: 676 # Load accumulators into memory and merge in batches to decrease peak 677 # memory usage. 678 accumulators_batch = [result] + list( 679 itertools.islice(accumulators, self._merge_accumulators_batch_size)) 680 if len(accumulators_batch) == 1: 681 break 682 result = [ 683 c.merge_accumulators(a, *args, **kwargs) for c, 684 a in zip(self._combiners, zip(*accumulators_batch)) 685 ] 686 return result 687 688 def compact(self, accumulator, *args, **kwargs): 689 return [ 690 c.compact(a, *args, **kwargs) for c, 691 a in zip(self._combiners, accumulator) 692 ] 693 694 def extract_output(self, accumulator, *args, **kwargs): 695 return tuple( 696 c.extract_output(a, *args, **kwargs) for c, 697 a in zip(self._combiners, accumulator)) 698 699 def teardown(self, *args, **kwargs): 700 for c in reversed(self._combiners): 701 c.teardown(*args, **kwargs) 702 703 704 class TupleCombineFn(_TupleCombineFnBase): 705 """A combiner for combining tuples via a tuple of combiners. 706 707 Takes as input a tuple of N CombineFns and combines N-tuples by 708 combining the k-th element of each tuple with the k-th CombineFn, 709 outputting a new N-tuple of combined values. 710 """ 711 def add_input(self, accumulator, element, *args, **kwargs): 712 return [ 713 c.add_input(a, e, *args, **kwargs) for c, 714 a, 715 e in zip(self._combiners, accumulator, element) 716 ] 717 718 def with_common_input(self): 719 return SingleInputTupleCombineFn(*self._combiners) 720 721 722 class SingleInputTupleCombineFn(_TupleCombineFnBase): 723 """A combiner for combining a single value via a tuple of combiners. 724 725 Takes as input a tuple of N CombineFns and combines elements by 726 applying each CombineFn to each input, producing an N-tuple of 727 the outputs corresponding to each of the N CombineFn's outputs. 728 """ 729 def add_input(self, accumulator, element, *args, **kwargs): 730 return [ 731 c.add_input(a, element, *args, **kwargs) for c, 732 a in zip(self._combiners, accumulator) 733 ] 734 735 736 @with_input_types(T) 737 @with_output_types(List[T]) 738 class ToList(CombinerWithoutDefaults): 739 """A global CombineFn that condenses a PCollection into a single list.""" 740 def expand(self, pcoll): 741 if self.has_defaults: 742 return pcoll | self.label >> core.CombineGlobally(ToListCombineFn()) 743 else: 744 return pcoll | self.label >> core.CombineGlobally( 745 ToListCombineFn()).without_defaults() 746 747 748 @with_input_types(T) 749 @with_output_types(List[T]) 750 class ToListCombineFn(core.CombineFn): 751 """CombineFn for to_list.""" 752 def create_accumulator(self): 753 return [] 754 755 def add_input(self, accumulator, element): 756 accumulator.append(element) 757 return accumulator 758 759 def merge_accumulators(self, accumulators): 760 return sum(accumulators, []) 761 762 def extract_output(self, accumulator): 763 return accumulator 764 765 766 @with_input_types(Tuple[K, V]) 767 @with_output_types(Dict[K, V]) 768 class ToDict(CombinerWithoutDefaults): 769 """A global CombineFn that condenses a PCollection into a single dict. 770 771 PCollections should consist of 2-tuples, notionally (key, value) pairs. 772 If multiple values are associated with the same key, only one of the values 773 will be present in the resulting dict. 774 """ 775 def expand(self, pcoll): 776 if self.has_defaults: 777 return pcoll | self.label >> core.CombineGlobally(ToDictCombineFn()) 778 else: 779 return pcoll | self.label >> core.CombineGlobally( 780 ToDictCombineFn()).without_defaults() 781 782 783 @with_input_types(Tuple[K, V]) 784 @with_output_types(Dict[K, V]) 785 class ToDictCombineFn(core.CombineFn): 786 """CombineFn for to_dict.""" 787 def create_accumulator(self): 788 return {} 789 790 def add_input(self, accumulator, element): 791 key, value = element 792 accumulator[key] = value 793 return accumulator 794 795 def merge_accumulators(self, accumulators): 796 result = {} 797 for a in accumulators: 798 result.update(a) 799 return result 800 801 def extract_output(self, accumulator): 802 return accumulator 803 804 805 @with_input_types(T) 806 @with_output_types(Set[T]) 807 class ToSet(CombinerWithoutDefaults): 808 """A global CombineFn that condenses a PCollection into a set.""" 809 def expand(self, pcoll): 810 if self.has_defaults: 811 return pcoll | self.label >> core.CombineGlobally(ToSetCombineFn()) 812 else: 813 return pcoll | self.label >> core.CombineGlobally( 814 ToSetCombineFn()).without_defaults() 815 816 817 @with_input_types(T) 818 @with_output_types(Set[T]) 819 class ToSetCombineFn(core.CombineFn): 820 """CombineFn for ToSet.""" 821 def create_accumulator(self): 822 return set() 823 824 def add_input(self, accumulator, element): 825 accumulator.add(element) 826 return accumulator 827 828 def merge_accumulators(self, accumulators): 829 return set.union(*accumulators) 830 831 def extract_output(self, accumulator): 832 return accumulator 833 834 835 class _CurriedFn(core.CombineFn): 836 """Wrapped CombineFn with extra arguments.""" 837 def __init__(self, fn, args, kwargs): 838 self.fn = fn 839 self.args = args 840 self.kwargs = kwargs 841 842 def setup(self): 843 self.fn.setup(*self.args, **self.kwargs) 844 845 def create_accumulator(self): 846 return self.fn.create_accumulator(*self.args, **self.kwargs) 847 848 def add_input(self, accumulator, element): 849 return self.fn.add_input(accumulator, element, *self.args, **self.kwargs) 850 851 def merge_accumulators(self, accumulators): 852 return self.fn.merge_accumulators(accumulators, *self.args, **self.kwargs) 853 854 def compact(self, accumulator): 855 return self.fn.compact(accumulator, *self.args, **self.kwargs) 856 857 def extract_output(self, accumulator): 858 return self.fn.extract_output(accumulator, *self.args, **self.kwargs) 859 860 def teardown(self): 861 self.fn.teardown(*self.args, **self.kwargs) 862 863 def apply(self, elements): 864 return self.fn.apply(elements, *self.args, **self.kwargs) 865 866 867 def curry_combine_fn(fn, args, kwargs): 868 if not args and not kwargs: 869 return fn 870 else: 871 return _CurriedFn(fn, args, kwargs) 872 873 874 class PhasedCombineFnExecutor(object): 875 """Executor for phases of combine operations.""" 876 def __init__(self, phase, fn, args, kwargs): 877 878 self.combine_fn = curry_combine_fn(fn, args, kwargs) 879 880 if phase == 'all': 881 self.apply = self.full_combine 882 elif phase == 'add': 883 self.apply = self.add_only 884 elif phase == 'merge': 885 self.apply = self.merge_only 886 elif phase == 'extract': 887 self.apply = self.extract_only 888 elif phase == 'convert': 889 self.apply = self.convert_to_accumulator 890 else: 891 raise ValueError('Unexpected phase: %s' % phase) 892 893 def full_combine(self, elements): 894 return self.combine_fn.apply(elements) 895 896 def add_only(self, elements): 897 return self.combine_fn.add_inputs( 898 self.combine_fn.create_accumulator(), elements) 899 900 def merge_only(self, accumulators): 901 return self.combine_fn.merge_accumulators(accumulators) 902 903 def extract_only(self, accumulator): 904 return self.combine_fn.extract_output(accumulator) 905 906 def convert_to_accumulator(self, element): 907 return self.combine_fn.add_input( 908 self.combine_fn.create_accumulator(), element) 909 910 911 class Latest(object): 912 """Combiners for computing the latest element""" 913 @with_input_types(T) 914 @with_output_types(T) 915 class Globally(CombinerWithoutDefaults): 916 """Compute the element with the latest timestamp from a 917 PCollection.""" 918 @staticmethod 919 def add_timestamp(element, timestamp=core.DoFn.TimestampParam): 920 return [(element, timestamp)] 921 922 def expand(self, pcoll): 923 if self.has_defaults: 924 return ( 925 pcoll 926 | core.ParDo(self.add_timestamp).with_output_types( 927 Tuple[T, TimestampType]) 928 | core.CombineGlobally(LatestCombineFn())) 929 else: 930 return ( 931 pcoll 932 | core.ParDo(self.add_timestamp).with_output_types( 933 Tuple[T, TimestampType]) 934 | core.CombineGlobally(LatestCombineFn()).without_defaults()) 935 936 @with_input_types(Tuple[K, V]) 937 @with_output_types(Tuple[K, V]) 938 class PerKey(ptransform.PTransform): 939 """Compute elements with the latest timestamp for each key 940 from a keyed PCollection""" 941 @staticmethod 942 def add_timestamp(element, timestamp=core.DoFn.TimestampParam): 943 key, value = element 944 return [(key, (value, timestamp))] 945 946 def expand(self, pcoll): 947 return ( 948 pcoll 949 | core.ParDo(self.add_timestamp).with_output_types( 950 Tuple[K, Tuple[T, TimestampType]]) 951 | core.CombinePerKey(LatestCombineFn())) 952 953 954 @with_input_types(Tuple[T, TimestampType]) 955 @with_output_types(T) 956 class LatestCombineFn(core.CombineFn): 957 """CombineFn to get the element with the latest timestamp 958 from a PCollection.""" 959 def create_accumulator(self): 960 return (None, window.MIN_TIMESTAMP) 961 962 def add_input(self, accumulator, element): 963 if accumulator[1] > element[1]: 964 return accumulator 965 else: 966 return element 967 968 def merge_accumulators(self, accumulators): 969 result = self.create_accumulator() 970 for accumulator in accumulators: 971 result = self.add_input(result, accumulator) 972 return result 973 974 def extract_output(self, accumulator): 975 return accumulator[0]