github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/trigger.py (about) 1 # 2 # Licensed to the Apache Software Foundation (ASF) under one or more 3 # contributor license agreements. See the NOTICE file distributed with 4 # this work for additional information regarding copyright ownership. 5 # The ASF licenses this file to You under the Apache License, Version 2.0 6 # (the "License"); you may not use this file except in compliance with 7 # the License. You may obtain a copy of the License at 8 # 9 # http://www.apache.org/licenses/LICENSE-2.0 10 # 11 # Unless required by applicable law or agreed to in writing, software 12 # distributed under the License is distributed on an "AS IS" BASIS, 13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 # See the License for the specific language governing permissions and 15 # limitations under the License. 16 # 17 18 """Support for Apache Beam triggers. 19 20 Triggers control when in processing time windows get emitted. 21 """ 22 23 # pytype: skip-file 24 25 import collections 26 import copy 27 import logging 28 import numbers 29 from abc import ABCMeta 30 from abc import abstractmethod 31 from collections import abc as collections_abc # ambiguty with direct abc 32 from enum import Flag 33 from enum import auto 34 from itertools import zip_longest 35 36 from apache_beam.coders import coder_impl 37 from apache_beam.coders import observable 38 from apache_beam.portability.api import beam_runner_api_pb2 39 from apache_beam.transforms import combiners 40 from apache_beam.transforms import core 41 from apache_beam.transforms.timeutil import TimeDomain 42 from apache_beam.transforms.window import GlobalWindow 43 from apache_beam.transforms.window import GlobalWindows 44 from apache_beam.transforms.window import TimestampCombiner 45 from apache_beam.transforms.window import WindowedValue 46 from apache_beam.transforms.window import WindowFn 47 from apache_beam.utils import windowed_value 48 from apache_beam.utils.timestamp import MAX_TIMESTAMP 49 from apache_beam.utils.timestamp import MIN_TIMESTAMP 50 from apache_beam.utils.timestamp import TIME_GRANULARITY 51 52 __all__ = [ 53 'AccumulationMode', 54 'TriggerFn', 55 'DefaultTrigger', 56 'AfterWatermark', 57 'AfterProcessingTime', 58 'AfterCount', 59 'Repeatedly', 60 'AfterAny', 61 'AfterAll', 62 'AfterEach', 63 'OrFinally', 64 ] 65 66 _LOGGER = logging.getLogger(__name__) 67 68 69 class AccumulationMode(object): 70 """Controls what to do with data when a trigger fires multiple times.""" 71 DISCARDING = beam_runner_api_pb2.AccumulationMode.DISCARDING 72 ACCUMULATING = beam_runner_api_pb2.AccumulationMode.ACCUMULATING 73 # TODO(robertwb): Provide retractions of previous outputs. 74 # RETRACTING = 3 75 76 77 class _StateTag(metaclass=ABCMeta): 78 """An identifier used to store and retrieve typed, combinable state. 79 80 The given tag must be unique for this step.""" 81 def __init__(self, tag): 82 self.tag = tag 83 84 85 class _ReadModifyWriteStateTag(_StateTag): 86 """StateTag pointing to an element.""" 87 def __repr__(self): 88 return 'ValueStateTag(%s)' % (self.tag) 89 90 def with_prefix(self, prefix): 91 return _ReadModifyWriteStateTag(prefix + self.tag) 92 93 94 class _SetStateTag(_StateTag): 95 """StateTag pointing to an element.""" 96 def __repr__(self): 97 return 'SetStateTag({tag})'.format(tag=self.tag) 98 99 def with_prefix(self, prefix): 100 return _SetStateTag(prefix + self.tag) 101 102 103 class _CombiningValueStateTag(_StateTag): 104 """StateTag pointing to an element, accumulated with a combiner. 105 106 The given tag must be unique for this step. The given CombineFn will be 107 applied (possibly incrementally and eagerly) when adding elements.""" 108 109 # TODO(robertwb): Also store the coder (perhaps extracted from the combine_fn) 110 def __init__(self, tag, combine_fn): 111 super().__init__(tag) 112 if not combine_fn: 113 raise ValueError('combine_fn must be specified.') 114 if not isinstance(combine_fn, core.CombineFn): 115 combine_fn = core.CombineFn.from_callable(combine_fn) 116 self.combine_fn = combine_fn 117 118 def __repr__(self): 119 return 'CombiningValueStateTag(%s, %s)' % (self.tag, self.combine_fn) 120 121 def with_prefix(self, prefix): 122 return _CombiningValueStateTag(prefix + self.tag, self.combine_fn) 123 124 def without_extraction(self): 125 class NoExtractionCombineFn(core.CombineFn): 126 setup = self.combine_fn.setup 127 create_accumulator = self.combine_fn.create_accumulator 128 add_input = self.combine_fn.add_input 129 merge_accumulators = self.combine_fn.merge_accumulators 130 compact = self.combine_fn.compact 131 extract_output = staticmethod(lambda x: x) 132 teardown = self.combine_fn.teardown 133 134 return _CombiningValueStateTag(self.tag, NoExtractionCombineFn()) 135 136 137 class _ListStateTag(_StateTag): 138 """StateTag pointing to a list of elements.""" 139 def __repr__(self): 140 return 'ListStateTag(%s)' % self.tag 141 142 def with_prefix(self, prefix): 143 return _ListStateTag(prefix + self.tag) 144 145 146 class _WatermarkHoldStateTag(_StateTag): 147 def __init__(self, tag, timestamp_combiner_impl): 148 super().__init__(tag) 149 self.timestamp_combiner_impl = timestamp_combiner_impl 150 151 def __repr__(self): 152 return 'WatermarkHoldStateTag(%s, %s)' % ( 153 self.tag, self.timestamp_combiner_impl) 154 155 def with_prefix(self, prefix): 156 return _WatermarkHoldStateTag( 157 prefix + self.tag, self.timestamp_combiner_impl) 158 159 160 class DataLossReason(Flag): 161 """Enum defining potential reasons that a trigger may cause data loss. 162 163 These flags should only cover when the trigger is the cause, though windowing 164 can be taken into account. For instance, AfterWatermark may not flag itself 165 as finishing if the windowing doesn't allow lateness. 166 """ 167 168 # Trigger will never be the source of data loss. 169 NO_POTENTIAL_LOSS = 0 170 171 # Trigger may finish. In this case, data that comes in after the trigger may 172 # be lost. Example: AfterCount(1) will stop firing after the first element. 173 MAY_FINISH = auto() 174 175 # Deprecated: Beam will emit buffered data at GC time. Any other behavior 176 # should be treated as a bug with the runner used. 177 CONDITION_NOT_GUARANTEED = auto() 178 179 180 # Convenience functions for checking if a flag is included. Each is equivalent 181 # to `reason & flag == flag` 182 183 184 def _IncludesMayFinish(reason): 185 # type: (DataLossReason) -> bool 186 return reason & DataLossReason.MAY_FINISH == DataLossReason.MAY_FINISH 187 188 189 # pylint: disable=unused-argument 190 # TODO(robertwb): Provisional API, Java likely to change as well. 191 class TriggerFn(metaclass=ABCMeta): 192 """A TriggerFn determines when window (panes) are emitted. 193 194 See https://beam.apache.org/documentation/programming-guide/#triggers 195 """ 196 @abstractmethod 197 def on_element(self, element, window, context): 198 """Called when a new element arrives in a window. 199 200 Args: 201 element: the element being added 202 window: the window to which the element is being added 203 context: a context (e.g. a TriggerContext instance) for managing state 204 and setting timers 205 """ 206 pass 207 208 @abstractmethod 209 def on_merge(self, to_be_merged, merge_result, context): 210 """Called when multiple windows are merged. 211 212 Args: 213 to_be_merged: the set of windows to be merged 214 merge_result: the window into which the windows are being merged 215 context: a context (e.g. a TriggerContext instance) for managing state 216 and setting timers 217 """ 218 pass 219 220 @abstractmethod 221 def should_fire(self, time_domain, timestamp, window, context): 222 """Whether this trigger should cause the window to fire. 223 224 Args: 225 time_domain: WATERMARK for event-time timers and REAL_TIME for 226 processing-time timers. 227 timestamp: for time_domain WATERMARK, it represents the 228 watermark: (a lower bound on) the watermark of the system 229 and for time_domain REAL_TIME, it represents the 230 trigger: timestamp of the processing-time timer. 231 window: the window whose trigger is being considered 232 context: a context (e.g. a TriggerContext instance) for managing state 233 and setting timers 234 235 Returns: 236 whether this trigger should cause a firing 237 """ 238 pass 239 240 @abstractmethod 241 def has_ontime_pane(self): 242 """Whether this trigger creates an empty pane even if there are no elements. 243 244 Returns: 245 True if this trigger guarantees that there will always be an ON_TIME pane 246 even if there are no elements in that pane. 247 """ 248 pass 249 250 @abstractmethod 251 def on_fire(self, watermark, window, context): 252 """Called when a trigger actually fires. 253 254 Args: 255 watermark: (a lower bound on) the watermark of the system 256 window: the window whose trigger is being fired 257 context: a context (e.g. a TriggerContext instance) for managing state 258 and setting timers 259 260 Returns: 261 whether this trigger is finished 262 """ 263 pass 264 265 @abstractmethod 266 def reset(self, window, context): 267 """Clear any state and timers used by this TriggerFn.""" 268 pass 269 270 def may_lose_data(self, unused_windowing): 271 # type: (core.Windowing) -> DataLossReason 272 273 """Returns whether or not this trigger could cause data loss. 274 275 A trigger can cause data loss in the following scenarios: 276 277 * The trigger has a chance to finish. For instance, AfterWatermark() 278 without a late trigger would cause all late data to be lost. This 279 scenario is only accounted for if the windowing strategy allows 280 late data. Otherwise, the trigger is not responsible for the data 281 loss. 282 283 Note that this only returns the potential for loss. It does not mean that 284 there will be data loss. It also only accounts for loss related to the 285 trigger, not other potential causes. 286 287 Args: 288 windowing: The Windowing that this trigger belongs to. It does not need 289 to be the top-level trigger. 290 291 Returns: 292 The DataLossReason. If there is no potential loss, 293 DataLossReason.NO_POTENTIAL_LOSS is returned. Otherwise, all the 294 potential reasons are returned as a single value. 295 """ 296 # For backwards compatibility's sake, we're assuming the trigger is safe. 297 return DataLossReason.NO_POTENTIAL_LOSS 298 299 300 # pylint: enable=unused-argument 301 302 @staticmethod 303 def from_runner_api(proto, context): 304 return { 305 'after_all': AfterAll, 306 'after_any': AfterAny, 307 'after_each': AfterEach, 308 'after_end_of_window': AfterWatermark, 309 'after_processing_time': AfterProcessingTime, 310 # after_processing_time, after_synchronized_processing_time 311 'always': Always, 312 'default': DefaultTrigger, 313 'element_count': AfterCount, 314 'never': _Never, 315 'or_finally': OrFinally, 316 'repeat': Repeatedly, 317 }[proto.WhichOneof('trigger')].from_runner_api(proto, context) 318 319 @abstractmethod 320 def to_runner_api(self, unused_context): 321 pass 322 323 324 class DefaultTrigger(TriggerFn): 325 """Semantically Repeatedly(AfterWatermark()), but more optimized.""" 326 def __init__(self): 327 pass 328 329 def __repr__(self): 330 return 'DefaultTrigger()' 331 332 def on_element(self, element, window, context): 333 context.set_timer(str(window), TimeDomain.WATERMARK, window.end) 334 335 def on_merge(self, to_be_merged, merge_result, context): 336 for window in to_be_merged: 337 context.clear_timer(str(window), TimeDomain.WATERMARK) 338 339 def should_fire(self, time_domain, watermark, window, context): 340 if watermark >= window.end: 341 # Explicitly clear the timer so that late elements are not emitted again 342 # when the timer is fired. 343 context.clear_timer(str(window), TimeDomain.WATERMARK) 344 return watermark >= window.end 345 346 def on_fire(self, watermark, window, context): 347 return False 348 349 def reset(self, window, context): 350 context.clear_timer(str(window), TimeDomain.WATERMARK) 351 352 def may_lose_data(self, unused_windowing): 353 return DataLossReason.NO_POTENTIAL_LOSS 354 355 def __eq__(self, other): 356 return type(self) == type(other) 357 358 def __hash__(self): 359 return hash(type(self)) 360 361 @staticmethod 362 def from_runner_api(proto, context): 363 return DefaultTrigger() 364 365 def to_runner_api(self, unused_context): 366 return beam_runner_api_pb2.Trigger( 367 default=beam_runner_api_pb2.Trigger.Default()) 368 369 def has_ontime_pane(self): 370 return True 371 372 373 class AfterProcessingTime(TriggerFn): 374 """Fire exactly once after a specified delay from processing time.""" 375 376 STATE_TAG = _SetStateTag('has_timer') 377 378 def __init__(self, delay=0): 379 """Initialize a processing time trigger with a delay in seconds.""" 380 self.delay = delay 381 382 def __repr__(self): 383 return 'AfterProcessingTime(delay=%d)' % self.delay 384 385 def on_element(self, element, window, context): 386 if not context.get_state(self.STATE_TAG): 387 context.set_timer( 388 '', TimeDomain.REAL_TIME, context.get_current_time() + self.delay) 389 context.add_state(self.STATE_TAG, True) 390 391 def on_merge(self, to_be_merged, merge_result, context): 392 # timers will be kept through merging 393 pass 394 395 def should_fire(self, time_domain, timestamp, window, context): 396 if time_domain == TimeDomain.REAL_TIME: 397 return True 398 399 def on_fire(self, timestamp, window, context): 400 return True 401 402 def reset(self, window, context): 403 context.clear_state(self.STATE_TAG) 404 405 def may_lose_data(self, unused_windowing): 406 """AfterProcessingTime may finish.""" 407 return DataLossReason.MAY_FINISH 408 409 @staticmethod 410 def from_runner_api(proto, context): 411 return AfterProcessingTime( 412 delay=( 413 proto.after_processing_time.timestamp_transforms[0].delay. 414 delay_millis) // 1000) 415 416 def to_runner_api(self, context): 417 delay_proto = beam_runner_api_pb2.TimestampTransform( 418 delay=beam_runner_api_pb2.TimestampTransform.Delay( 419 delay_millis=self.delay * 1000)) 420 return beam_runner_api_pb2.Trigger( 421 after_processing_time=beam_runner_api_pb2.Trigger.AfterProcessingTime( 422 timestamp_transforms=[delay_proto])) 423 424 def has_ontime_pane(self): 425 return False 426 427 428 class Always(TriggerFn): 429 """Repeatedly invoke the given trigger, never finishing.""" 430 def __init__(self): 431 pass 432 433 def __repr__(self): 434 return 'Always' 435 436 def __eq__(self, other): 437 return type(self) == type(other) 438 439 def __hash__(self): 440 return 1 441 442 def on_element(self, element, window, context): 443 pass 444 445 def on_merge(self, to_be_merged, merge_result, context): 446 pass 447 448 def has_ontime_pane(self): 449 return False 450 451 def reset(self, window, context): 452 pass 453 454 def should_fire(self, time_domain, watermark, window, context): 455 return True 456 457 def on_fire(self, watermark, window, context): 458 return False 459 460 def may_lose_data(self, unused_windowing): 461 """No potential loss, since the trigger always fires.""" 462 return DataLossReason.NO_POTENTIAL_LOSS 463 464 @staticmethod 465 def from_runner_api(proto, context): 466 return Always() 467 468 def to_runner_api(self, context): 469 return beam_runner_api_pb2.Trigger( 470 always=beam_runner_api_pb2.Trigger.Always()) 471 472 473 class _Never(TriggerFn): 474 """A trigger that never fires. 475 476 Data may still be released at window closing. 477 """ 478 def __init__(self): 479 pass 480 481 def __repr__(self): 482 return 'Never' 483 484 def __eq__(self, other): 485 return type(self) == type(other) 486 487 def __hash__(self): 488 return hash(type(self)) 489 490 def on_element(self, element, window, context): 491 pass 492 493 def on_merge(self, to_be_merged, merge_result, context): 494 pass 495 496 def has_ontime_pane(self): 497 False 498 499 def reset(self, window, context): 500 pass 501 502 def should_fire(self, time_domain, watermark, window, context): 503 return False 504 505 def on_fire(self, watermark, window, context): 506 return True 507 508 def may_lose_data(self, unused_windowing): 509 """No potential data loss. 510 511 Though Never doesn't explicitly trigger, it still collects data on 512 windowing closing. 513 """ 514 return DataLossReason.NO_POTENTIAL_LOSS 515 516 @staticmethod 517 def from_runner_api(proto, context): 518 return _Never() 519 520 def to_runner_api(self, context): 521 return beam_runner_api_pb2.Trigger( 522 never=beam_runner_api_pb2.Trigger.Never()) 523 524 525 class AfterWatermark(TriggerFn): 526 """Fire exactly once when the watermark passes the end of the window. 527 528 Args: 529 early: if not None, a speculative trigger to repeatedly evaluate before 530 the watermark passes the end of the window 531 late: if not None, a speculative trigger to repeatedly evaluate after 532 the watermark passes the end of the window 533 """ 534 LATE_TAG = _CombiningValueStateTag('is_late', any) 535 536 def __init__(self, early=None, late=None): 537 # TODO(zhoufek): Maybe don't wrap early/late if they are already Repeatedly 538 self.early = Repeatedly(early) if early else None 539 self.late = Repeatedly(late) if late else None 540 541 def __repr__(self): 542 qualifiers = [] 543 if self.early: 544 qualifiers.append('early=%s' % self.early.underlying) 545 if self.late: 546 qualifiers.append('late=%s' % self.late.underlying) 547 return 'AfterWatermark(%s)' % ', '.join(qualifiers) 548 549 def is_late(self, context): 550 return self.late and context.get_state(self.LATE_TAG) 551 552 def on_element(self, element, window, context): 553 if self.is_late(context): 554 self.late.on_element(element, window, NestedContext(context, 'late')) 555 else: 556 context.set_timer('', TimeDomain.WATERMARK, window.end) 557 if self.early: 558 self.early.on_element(element, window, NestedContext(context, 'early')) 559 560 def on_merge(self, to_be_merged, merge_result, context): 561 # TODO(robertwb): Figure out whether the 'rewind' semantics could be used 562 # here. 563 if self.is_late(context): 564 self.late.on_merge( 565 to_be_merged, merge_result, NestedContext(context, 'late')) 566 else: 567 # Note: Timer clearing solely an optimization. 568 for window in to_be_merged: 569 if window.end != merge_result.end: 570 context.clear_timer('', TimeDomain.WATERMARK) 571 if self.early: 572 self.early.on_merge( 573 to_be_merged, merge_result, NestedContext(context, 'early')) 574 575 def should_fire(self, time_domain, watermark, window, context): 576 if self.is_late(context): 577 return self.late.should_fire( 578 time_domain, watermark, window, NestedContext(context, 'late')) 579 elif watermark >= window.end: 580 # Explicitly clear the timer so that late elements are not emitted again 581 # when the timer is fired. 582 context.clear_timer('', TimeDomain.WATERMARK) 583 return True 584 elif self.early: 585 return self.early.should_fire( 586 time_domain, watermark, window, NestedContext(context, 'early')) 587 return False 588 589 def on_fire(self, watermark, window, context): 590 if self.is_late(context): 591 return self.late.on_fire( 592 watermark, window, NestedContext(context, 'late')) 593 elif watermark >= window.end: 594 context.add_state(self.LATE_TAG, True) 595 return not self.late 596 elif self.early: 597 self.early.on_fire(watermark, window, NestedContext(context, 'early')) 598 return False 599 600 def reset(self, window, context): 601 if self.late: 602 context.clear_state(self.LATE_TAG) 603 if self.early: 604 self.early.reset(window, NestedContext(context, 'early')) 605 if self.late: 606 self.late.reset(window, NestedContext(context, 'late')) 607 608 def may_lose_data(self, windowing): 609 """May cause data loss if lateness allowed and no late trigger set.""" 610 if windowing.allowed_lateness == 0: 611 return DataLossReason.NO_POTENTIAL_LOSS 612 if self.late is None: 613 return DataLossReason.MAY_FINISH 614 return self.late.may_lose_data(windowing) 615 616 def __eq__(self, other): 617 return ( 618 type(self) == type(other) and self.early == other.early and 619 self.late == other.late) 620 621 def __hash__(self): 622 return hash((type(self), self.early, self.late)) 623 624 @staticmethod 625 def from_runner_api(proto, context): 626 return AfterWatermark( 627 early=TriggerFn.from_runner_api( 628 proto.after_end_of_window.early_firings, context) 629 if proto.after_end_of_window.HasField('early_firings') else None, 630 late=TriggerFn.from_runner_api( 631 proto.after_end_of_window.late_firings, context) 632 if proto.after_end_of_window.HasField('late_firings') else None) 633 634 def to_runner_api(self, context): 635 early_proto = self.early.underlying.to_runner_api( 636 context) if self.early else None 637 late_proto = self.late.underlying.to_runner_api( 638 context) if self.late else None 639 return beam_runner_api_pb2.Trigger( 640 after_end_of_window=beam_runner_api_pb2.Trigger.AfterEndOfWindow( 641 early_firings=early_proto, late_firings=late_proto)) 642 643 def has_ontime_pane(self): 644 return True 645 646 647 class AfterCount(TriggerFn): 648 """Fire when there are at least count elements in this window pane.""" 649 650 COUNT_TAG = _CombiningValueStateTag('count', combiners.CountCombineFn()) 651 652 def __init__(self, count): 653 if not isinstance(count, numbers.Integral) or count < 1: 654 raise ValueError("count (%d) must be a positive integer." % count) 655 self.count = count 656 657 def __repr__(self): 658 return 'AfterCount(%s)' % self.count 659 660 def __eq__(self, other): 661 return type(self) == type(other) and self.count == other.count 662 663 def __hash__(self): 664 return hash(self.count) 665 666 def on_element(self, element, window, context): 667 context.add_state(self.COUNT_TAG, 1) 668 669 def on_merge(self, to_be_merged, merge_result, context): 670 # states automatically merged 671 pass 672 673 def should_fire(self, time_domain, watermark, window, context): 674 return context.get_state(self.COUNT_TAG) >= self.count 675 676 def on_fire(self, watermark, window, context): 677 return True 678 679 def reset(self, window, context): 680 context.clear_state(self.COUNT_TAG) 681 682 def may_lose_data(self, unused_windowing): 683 """AfterCount may finish.""" 684 return DataLossReason.MAY_FINISH 685 686 @staticmethod 687 def from_runner_api(proto, unused_context): 688 return AfterCount(proto.element_count.element_count) 689 690 def to_runner_api(self, unused_context): 691 return beam_runner_api_pb2.Trigger( 692 element_count=beam_runner_api_pb2.Trigger.ElementCount( 693 element_count=self.count)) 694 695 def has_ontime_pane(self): 696 return False 697 698 699 class Repeatedly(TriggerFn): 700 """Repeatedly invoke the given trigger, never finishing.""" 701 def __init__(self, underlying): 702 self.underlying = underlying 703 704 def __repr__(self): 705 return 'Repeatedly(%s)' % self.underlying 706 707 def __eq__(self, other): 708 return type(self) == type(other) and self.underlying == other.underlying 709 710 def __hash__(self): 711 return hash(self.underlying) 712 713 def on_element(self, element, window, context): 714 self.underlying.on_element(element, window, context) 715 716 def on_merge(self, to_be_merged, merge_result, context): 717 self.underlying.on_merge(to_be_merged, merge_result, context) 718 719 def should_fire(self, time_domain, watermark, window, context): 720 return self.underlying.should_fire(time_domain, watermark, window, context) 721 722 def on_fire(self, watermark, window, context): 723 if self.underlying.on_fire(watermark, window, context): 724 self.underlying.reset(window, context) 725 return False 726 727 def reset(self, window, context): 728 self.underlying.reset(window, context) 729 730 def may_lose_data(self, windowing): 731 """Repeatedly will run in a loop and pick up whatever is left at GC.""" 732 return DataLossReason.NO_POTENTIAL_LOSS 733 734 @staticmethod 735 def from_runner_api(proto, context): 736 return Repeatedly( 737 TriggerFn.from_runner_api(proto.repeat.subtrigger, context)) 738 739 def to_runner_api(self, context): 740 return beam_runner_api_pb2.Trigger( 741 repeat=beam_runner_api_pb2.Trigger.Repeat( 742 subtrigger=self.underlying.to_runner_api(context))) 743 744 def has_ontime_pane(self): 745 return self.underlying.has_ontime_pane() 746 747 748 class _ParallelTriggerFn(TriggerFn, metaclass=ABCMeta): 749 def __init__(self, *triggers): 750 self.triggers = triggers 751 752 def __repr__(self): 753 return '%s(%s)' % ( 754 self.__class__.__name__, ', '.join(str(t) for t in self.triggers)) 755 756 def __eq__(self, other): 757 return type(self) == type(other) and self.triggers == other.triggers 758 759 def __hash__(self): 760 return hash(self.triggers) 761 762 @abstractmethod 763 def combine_op(self, trigger_results): 764 pass 765 766 def on_element(self, element, window, context): 767 for ix, trigger in enumerate(self.triggers): 768 trigger.on_element(element, window, self._sub_context(context, ix)) 769 770 def on_merge(self, to_be_merged, merge_result, context): 771 for ix, trigger in enumerate(self.triggers): 772 trigger.on_merge( 773 to_be_merged, merge_result, self._sub_context(context, ix)) 774 775 def should_fire(self, time_domain, watermark, window, context): 776 self._time_domain = time_domain 777 return self.combine_op( 778 trigger.should_fire( 779 time_domain, watermark, window, self._sub_context(context, ix)) 780 for ix, 781 trigger in enumerate(self.triggers)) 782 783 def on_fire(self, watermark, window, context): 784 finished = [] 785 for ix, trigger in enumerate(self.triggers): 786 nested_context = self._sub_context(context, ix) 787 if trigger.should_fire(TimeDomain.WATERMARK, 788 watermark, 789 window, 790 nested_context): 791 finished.append(trigger.on_fire(watermark, window, nested_context)) 792 return self.combine_op(finished) 793 794 def may_lose_data(self, windowing): 795 may_finish = self.combine_op( 796 _IncludesMayFinish(t.may_lose_data(windowing)) for t in self.triggers) 797 return ( 798 DataLossReason.MAY_FINISH 799 if may_finish else DataLossReason.NO_POTENTIAL_LOSS) 800 801 def reset(self, window, context): 802 for ix, trigger in enumerate(self.triggers): 803 trigger.reset(window, self._sub_context(context, ix)) 804 805 @staticmethod 806 def _sub_context(context, index): 807 return NestedContext(context, '%d/' % index) 808 809 @staticmethod 810 def from_runner_api(proto, context): 811 subtriggers = [ 812 TriggerFn.from_runner_api(subtrigger, context) for subtrigger in 813 proto.after_all.subtriggers or proto.after_any.subtriggers 814 ] 815 if proto.after_all.subtriggers: 816 return AfterAll(*subtriggers) 817 else: 818 return AfterAny(*subtriggers) 819 820 def to_runner_api(self, context): 821 subtriggers = [ 822 subtrigger.to_runner_api(context) for subtrigger in self.triggers 823 ] 824 if self.combine_op == all: 825 return beam_runner_api_pb2.Trigger( 826 after_all=beam_runner_api_pb2.Trigger.AfterAll( 827 subtriggers=subtriggers)) 828 elif self.combine_op == any: 829 return beam_runner_api_pb2.Trigger( 830 after_any=beam_runner_api_pb2.Trigger.AfterAny( 831 subtriggers=subtriggers)) 832 else: 833 raise NotImplementedError(self) 834 835 def has_ontime_pane(self): 836 return any(t.has_ontime_pane() for t in self.triggers) 837 838 839 class AfterAny(_ParallelTriggerFn): 840 """Fires when any subtrigger fires. 841 842 Also finishes when any subtrigger finishes. 843 """ 844 combine_op = any 845 846 847 class AfterAll(_ParallelTriggerFn): 848 """Fires when all subtriggers have fired. 849 850 Also finishes when all subtriggers have finished. 851 """ 852 combine_op = all 853 854 855 class AfterEach(TriggerFn): 856 857 INDEX_TAG = _CombiningValueStateTag( 858 'index', (lambda indices: 0 if not indices else max(indices))) 859 860 def __init__(self, *triggers): 861 self.triggers = triggers 862 863 def __repr__(self): 864 return '%s(%s)' % ( 865 self.__class__.__name__, ', '.join(str(t) for t in self.triggers)) 866 867 def __eq__(self, other): 868 return type(self) == type(other) and self.triggers == other.triggers 869 870 def __hash__(self): 871 return hash(self.triggers) 872 873 def on_element(self, element, window, context): 874 ix = context.get_state(self.INDEX_TAG) 875 if ix < len(self.triggers): 876 self.triggers[ix].on_element( 877 element, window, self._sub_context(context, ix)) 878 879 def on_merge(self, to_be_merged, merge_result, context): 880 # This takes the furthest window on merging. 881 # TODO(robertwb): Revisit this when merging windows logic is settled for 882 # all possible merging situations. 883 ix = context.get_state(self.INDEX_TAG) 884 if ix < len(self.triggers): 885 self.triggers[ix].on_merge( 886 to_be_merged, merge_result, self._sub_context(context, ix)) 887 888 def should_fire(self, time_domain, watermark, window, context): 889 ix = context.get_state(self.INDEX_TAG) 890 if ix < len(self.triggers): 891 return self.triggers[ix].should_fire( 892 time_domain, watermark, window, self._sub_context(context, ix)) 893 894 def on_fire(self, watermark, window, context): 895 ix = context.get_state(self.INDEX_TAG) 896 if ix < len(self.triggers): 897 if self.triggers[ix].on_fire(watermark, 898 window, 899 self._sub_context(context, ix)): 900 ix += 1 901 context.add_state(self.INDEX_TAG, ix) 902 return ix == len(self.triggers) 903 904 def reset(self, window, context): 905 context.clear_state(self.INDEX_TAG) 906 for ix, trigger in enumerate(self.triggers): 907 trigger.reset(window, self._sub_context(context, ix)) 908 909 def may_lose_data(self, windowing): 910 """If all sub-triggers may finish, this may finish.""" 911 may_finish = all( 912 _IncludesMayFinish(t.may_lose_data(windowing)) for t in self.triggers) 913 return ( 914 DataLossReason.MAY_FINISH 915 if may_finish else DataLossReason.NO_POTENTIAL_LOSS) 916 917 @staticmethod 918 def _sub_context(context, index): 919 return NestedContext(context, '%d/' % index) 920 921 @staticmethod 922 def from_runner_api(proto, context): 923 return AfterEach( 924 *[ 925 TriggerFn.from_runner_api(subtrigger, context) 926 for subtrigger in proto.after_each.subtriggers 927 ]) 928 929 def to_runner_api(self, context): 930 return beam_runner_api_pb2.Trigger( 931 after_each=beam_runner_api_pb2.Trigger.AfterEach( 932 subtriggers=[ 933 subtrigger.to_runner_api(context) 934 for subtrigger in self.triggers 935 ])) 936 937 def has_ontime_pane(self): 938 return any(t.has_ontime_pane() for t in self.triggers) 939 940 941 class OrFinally(AfterAny): 942 @staticmethod 943 def from_runner_api(proto, context): 944 return OrFinally( 945 TriggerFn.from_runner_api(proto.or_finally.main, context), 946 # getattr is used as finally is a keyword in Python 947 TriggerFn.from_runner_api( 948 getattr(proto.or_finally, 'finally'), context)) 949 950 def to_runner_api(self, context): 951 return beam_runner_api_pb2.Trigger( 952 or_finally=beam_runner_api_pb2.Trigger.OrFinally( 953 main=self.triggers[0].to_runner_api(context), 954 # dict keyword argument is used as finally is a keyword in Python 955 **{'finally': self.triggers[1].to_runner_api(context)})) 956 957 958 class TriggerContext(object): 959 def __init__(self, outer, window, clock): 960 self._outer = outer 961 self._window = window 962 self._clock = clock 963 964 def get_current_time(self): 965 return self._clock.time() 966 967 def set_timer(self, name, time_domain, timestamp): 968 self._outer.set_timer(self._window, name, time_domain, timestamp) 969 970 def clear_timer(self, name, time_domain): 971 self._outer.clear_timer(self._window, name, time_domain) 972 973 def add_state(self, tag, value): 974 self._outer.add_state(self._window, tag, value) 975 976 def get_state(self, tag): 977 return self._outer.get_state(self._window, tag) 978 979 def clear_state(self, tag): 980 return self._outer.clear_state(self._window, tag) 981 982 983 class NestedContext(object): 984 """Namespaced context useful for defining composite triggers.""" 985 def __init__(self, outer, prefix): 986 self._outer = outer 987 self._prefix = prefix 988 989 def get_current_time(self): 990 return self._outer.get_current_time() 991 992 def set_timer(self, name, time_domain, timestamp): 993 self._outer.set_timer(self._prefix + name, time_domain, timestamp) 994 995 def clear_timer(self, name, time_domain): 996 self._outer.clear_timer(self._prefix + name, time_domain) 997 998 def add_state(self, tag, value): 999 self._outer.add_state(tag.with_prefix(self._prefix), value) 1000 1001 def get_state(self, tag): 1002 return self._outer.get_state(tag.with_prefix(self._prefix)) 1003 1004 def clear_state(self, tag): 1005 self._outer.clear_state(tag.with_prefix(self._prefix)) 1006 1007 1008 # pylint: disable=unused-argument 1009 class SimpleState(metaclass=ABCMeta): 1010 """Basic state storage interface used for triggering. 1011 1012 Only timers must hold the watermark (by their timestamp). 1013 """ 1014 @abstractmethod 1015 def set_timer( 1016 self, window, name, time_domain, timestamp, dynamic_timer_tag=''): 1017 pass 1018 1019 @abstractmethod 1020 def get_window(self, window_id): 1021 pass 1022 1023 @abstractmethod 1024 def clear_timer(self, window, name, time_domain, dynamic_timer_tag=''): 1025 pass 1026 1027 @abstractmethod 1028 def add_state(self, window, tag, value): 1029 pass 1030 1031 @abstractmethod 1032 def get_state(self, window, tag): 1033 pass 1034 1035 @abstractmethod 1036 def clear_state(self, window, tag): 1037 pass 1038 1039 def at(self, window, clock): 1040 return NestedContext(TriggerContext(self, window, clock), 'trigger') 1041 1042 1043 class UnmergedState(SimpleState): 1044 """State suitable for use in TriggerDriver. 1045 1046 This class must be implemented by each backend. 1047 """ 1048 @abstractmethod 1049 def set_global_state(self, tag, value): 1050 pass 1051 1052 @abstractmethod 1053 def get_global_state(self, tag, default=None): 1054 pass 1055 1056 1057 # pylint: enable=unused-argument 1058 1059 1060 class MergeableStateAdapter(SimpleState): 1061 """Wraps an UnmergedState, tracking merged windows.""" 1062 # TODO(robertwb): A similar indirection could be used for sliding windows 1063 # or other window_fns when a single element typically belongs to many windows. 1064 1065 WINDOW_IDS = _ReadModifyWriteStateTag('window_ids') 1066 1067 def __init__(self, raw_state): 1068 self.raw_state = raw_state 1069 self.window_ids = self.raw_state.get_global_state(self.WINDOW_IDS, {}) 1070 self.counter = None 1071 1072 def set_timer( 1073 self, window, name, time_domain, timestamp, dynamic_timer_tag=''): 1074 self.raw_state.set_timer( 1075 self._get_id(window), 1076 name, 1077 time_domain, 1078 timestamp, 1079 dynamic_timer_tag=dynamic_timer_tag) 1080 1081 def clear_timer(self, window, name, time_domain, dynamic_timer_tag=''): 1082 for window_id in self._get_ids(window): 1083 self.raw_state.clear_timer( 1084 window_id, name, time_domain, dynamic_timer_tag=dynamic_timer_tag) 1085 1086 def add_state(self, window, tag, value): 1087 if isinstance(tag, _ReadModifyWriteStateTag): 1088 raise ValueError( 1089 'Merging requested for non-mergeable state tag: %r.' % tag) 1090 elif isinstance(tag, _CombiningValueStateTag): 1091 tag = tag.without_extraction() 1092 self.raw_state.add_state(self._get_id(window), tag, value) 1093 1094 def get_state(self, window, tag): 1095 if isinstance(tag, _CombiningValueStateTag): 1096 original_tag, tag = tag, tag.without_extraction() 1097 values = [ 1098 self.raw_state.get_state(window_id, tag) 1099 for window_id in self._get_ids(window) 1100 ] 1101 if isinstance(tag, _ReadModifyWriteStateTag): 1102 raise ValueError( 1103 'Merging requested for non-mergeable state tag: %r.' % tag) 1104 elif isinstance(tag, _CombiningValueStateTag): 1105 return original_tag.combine_fn.extract_output( 1106 original_tag.combine_fn.merge_accumulators(values)) 1107 elif isinstance(tag, _ListStateTag): 1108 return [v for vs in values for v in vs] 1109 elif isinstance(tag, _SetStateTag): 1110 return {v for vs in values for v in vs} 1111 elif isinstance(tag, _WatermarkHoldStateTag): 1112 return tag.timestamp_combiner_impl.combine_all(values) 1113 else: 1114 raise ValueError('Invalid tag.', tag) 1115 1116 def clear_state(self, window, tag): 1117 for window_id in self._get_ids(window): 1118 self.raw_state.clear_state(window_id, tag) 1119 if tag is None: 1120 del self.window_ids[window] 1121 self._persist_window_ids() 1122 1123 def merge(self, to_be_merged, merge_result): 1124 for window in to_be_merged: 1125 if window != merge_result: 1126 if window in self.window_ids: 1127 if merge_result in self.window_ids: 1128 merge_window_ids = self.window_ids[merge_result] 1129 else: 1130 merge_window_ids = self.window_ids[merge_result] = [] 1131 merge_window_ids.extend(self.window_ids.pop(window)) 1132 self._persist_window_ids() 1133 1134 def known_windows(self): 1135 return list(self.window_ids) 1136 1137 def get_window(self, window_id): 1138 for window, ids in self.window_ids.items(): 1139 if window_id in ids: 1140 return window 1141 raise ValueError('No window for %s' % window_id) 1142 1143 def _get_id(self, window): 1144 if window in self.window_ids: 1145 return self.window_ids[window][0] 1146 1147 window_id = self._get_next_counter() 1148 self.window_ids[window] = [window_id] 1149 self._persist_window_ids() 1150 return window_id 1151 1152 def _get_ids(self, window): 1153 return self.window_ids.get(window, []) 1154 1155 def _get_next_counter(self): 1156 if not self.window_ids: 1157 self.counter = 0 1158 elif self.counter is None: 1159 self.counter = max(k for ids in self.window_ids.values() for k in ids) 1160 self.counter += 1 1161 return self.counter 1162 1163 def _persist_window_ids(self): 1164 self.raw_state.set_global_state(self.WINDOW_IDS, self.window_ids) 1165 1166 def __repr__(self): 1167 return '\n\t'.join([repr(self.window_ids)] + 1168 repr(self.raw_state).split('\n')) 1169 1170 1171 def create_trigger_driver( 1172 windowing, is_batch=False, phased_combine_fn=None, clock=None): 1173 """Create the TriggerDriver for the given windowing and options.""" 1174 1175 # TODO(https://github.com/apache/beam/issues/20165): Respect closing and 1176 # on-time behaviors. For batch, we should always fire once, no matter what. 1177 if is_batch and windowing.triggerfn == _Never(): 1178 windowing = copy.copy(windowing) 1179 windowing.triggerfn = Always() 1180 1181 # TODO(robertwb): We can do more if we know elements are in timestamp 1182 # sorted order. 1183 if windowing.is_default() and is_batch: 1184 driver = BatchGlobalTriggerDriver() 1185 elif (windowing.windowfn == GlobalWindows() and 1186 (windowing.triggerfn in [AfterCount(1), Always()]) and is_batch): 1187 # Here we also just pass through all the values exactly once. 1188 driver = BatchGlobalTriggerDriver() 1189 else: 1190 driver = GeneralTriggerDriver(windowing, clock) 1191 1192 if phased_combine_fn: 1193 # TODO(ccy): Refactor GeneralTriggerDriver to combine values eagerly using 1194 # the known phased_combine_fn here. 1195 driver = CombiningTriggerDriver(phased_combine_fn, driver) 1196 return driver 1197 1198 1199 class TriggerDriver(metaclass=ABCMeta): 1200 """Breaks a series of bundle and timer firings into window (pane)s.""" 1201 @abstractmethod 1202 def process_elements( 1203 self, 1204 state, 1205 windowed_values, 1206 output_watermark, 1207 input_watermark=MIN_TIMESTAMP): 1208 pass 1209 1210 @abstractmethod 1211 def process_timer( 1212 self, 1213 window_id, 1214 name, 1215 time_domain, 1216 timestamp, 1217 state, 1218 input_watermark=None): 1219 pass 1220 1221 def process_entire_key(self, key, windowed_values): 1222 # This state holds per-key, multi-window state. 1223 state = InMemoryUnmergedState() 1224 for wvalue in self.process_elements(state, 1225 windowed_values, 1226 MIN_TIMESTAMP, 1227 MIN_TIMESTAMP): 1228 yield wvalue.with_value((key, wvalue.value)) 1229 while state.timers: 1230 fired = state.get_and_clear_timers() 1231 for timer_window, (name, time_domain, fire_time, _) in fired: 1232 for wvalue in self.process_timer(timer_window, 1233 name, 1234 time_domain, 1235 fire_time, 1236 state): 1237 yield wvalue.with_value((key, wvalue.value)) 1238 1239 1240 class _UnwindowedValues(observable.ObservableMixin): 1241 """Exposes iterable of windowed values as iterable of unwindowed values.""" 1242 def __init__(self, windowed_values): 1243 super().__init__() 1244 self._windowed_values = windowed_values 1245 1246 def __iter__(self): 1247 for wv in self._windowed_values: 1248 unwindowed_value = wv.value 1249 self.notify_observers(unwindowed_value) 1250 yield unwindowed_value 1251 1252 def __repr__(self): 1253 return '<_UnwindowedValues of %s>' % self._windowed_values 1254 1255 def __reduce__(self): 1256 return list, (list(self), ) 1257 1258 def __eq__(self, other): 1259 if isinstance(other, collections_abc.Iterable): 1260 return all( 1261 a == b for a, b in zip_longest(self, other, fillvalue=object())) 1262 else: 1263 return NotImplemented 1264 1265 def __hash__(self): 1266 return hash(tuple(self)) 1267 1268 1269 coder_impl.FastPrimitivesCoderImpl.register_iterable_like_type( 1270 _UnwindowedValues) 1271 1272 1273 class BatchGlobalTriggerDriver(TriggerDriver): 1274 """Groups all received values together. 1275 """ 1276 GLOBAL_WINDOW_TUPLE = (GlobalWindow(), ) 1277 ONLY_FIRING = windowed_value.PaneInfo( 1278 is_first=True, 1279 is_last=True, 1280 timing=windowed_value.PaneInfoTiming.ON_TIME, 1281 index=0, 1282 nonspeculative_index=0) 1283 1284 def process_elements( 1285 self, 1286 state, 1287 windowed_values, 1288 unused_output_watermark, 1289 unused_input_watermark=MIN_TIMESTAMP): 1290 yield WindowedValue( 1291 _UnwindowedValues(windowed_values), 1292 MIN_TIMESTAMP, 1293 self.GLOBAL_WINDOW_TUPLE, 1294 self.ONLY_FIRING) 1295 1296 def process_timer( 1297 self, 1298 window_id, 1299 name, 1300 time_domain, 1301 timestamp, 1302 state, 1303 input_watermark=None): 1304 raise TypeError('Triggers never set or called for batch default windowing.') 1305 1306 1307 class CombiningTriggerDriver(TriggerDriver): 1308 """Uses a phased_combine_fn to process output of wrapped TriggerDriver.""" 1309 def __init__(self, phased_combine_fn, underlying): 1310 self.phased_combine_fn = phased_combine_fn 1311 self.underlying = underlying 1312 1313 def process_elements( 1314 self, 1315 state, 1316 windowed_values, 1317 output_watermark, 1318 input_watermark=MIN_TIMESTAMP): 1319 uncombined = self.underlying.process_elements( 1320 state, windowed_values, output_watermark, input_watermark) 1321 for output in uncombined: 1322 yield output.with_value(self.phased_combine_fn.apply(output.value)) 1323 1324 def process_timer( 1325 self, 1326 window_id, 1327 name, 1328 time_domain, 1329 timestamp, 1330 state, 1331 input_watermark=None): 1332 uncombined = self.underlying.process_timer( 1333 window_id, name, time_domain, timestamp, state, input_watermark) 1334 for output in uncombined: 1335 yield output.with_value(self.phased_combine_fn.apply(output.value)) 1336 1337 1338 class GeneralTriggerDriver(TriggerDriver): 1339 """Breaks a series of bundle and timer firings into window (pane)s. 1340 1341 Suitable for all variants of Windowing. 1342 """ 1343 ELEMENTS = _ListStateTag('elements') 1344 TOMBSTONE = _CombiningValueStateTag('tombstone', combiners.CountCombineFn()) 1345 INDEX = _CombiningValueStateTag('index', combiners.CountCombineFn()) 1346 NONSPECULATIVE_INDEX = _CombiningValueStateTag( 1347 'nonspeculative_index', combiners.CountCombineFn()) 1348 1349 def __init__(self, windowing, clock): 1350 self.clock = clock 1351 self.allowed_lateness = windowing.allowed_lateness 1352 self.window_fn = windowing.windowfn 1353 self.timestamp_combiner_impl = TimestampCombiner.get_impl( 1354 windowing.timestamp_combiner, self.window_fn) 1355 # pylint: disable=invalid-name 1356 self.WATERMARK_HOLD = _WatermarkHoldStateTag( 1357 'watermark', self.timestamp_combiner_impl) 1358 # pylint: enable=invalid-name 1359 self.trigger_fn = windowing.triggerfn 1360 self.accumulation_mode = windowing.accumulation_mode 1361 self.is_merging = True 1362 1363 def process_elements( 1364 self, 1365 state, 1366 windowed_values, 1367 output_watermark, 1368 input_watermark=MIN_TIMESTAMP): 1369 if self.is_merging: 1370 state = MergeableStateAdapter(state) 1371 1372 windows_to_elements = collections.defaultdict(list) 1373 for wv in windowed_values: 1374 for window in wv.windows: 1375 # ignore expired windows 1376 if input_watermark > window.end + self.allowed_lateness: 1377 continue 1378 windows_to_elements[window].append((wv.value, wv.timestamp)) 1379 1380 # First handle merging. 1381 if self.is_merging: 1382 old_windows = set(state.known_windows()) 1383 all_windows = old_windows.union(list(windows_to_elements)) 1384 1385 if all_windows != old_windows: 1386 merged_away = {} 1387 1388 class TriggerMergeContext(WindowFn.MergeContext): 1389 def merge(_, to_be_merged, merge_result): # pylint: disable=no-self-argument 1390 for window in to_be_merged: 1391 if window != merge_result: 1392 merged_away[window] = merge_result 1393 # Clear state associated with PaneInfo since it is 1394 # not preserved across merges. 1395 state.clear_state(window, self.INDEX) 1396 state.clear_state(window, self.NONSPECULATIVE_INDEX) 1397 state.merge(to_be_merged, merge_result) 1398 # using the outer self argument. 1399 self.trigger_fn.on_merge( 1400 to_be_merged, merge_result, state.at(merge_result, self.clock)) 1401 1402 self.window_fn.merge(TriggerMergeContext(all_windows)) 1403 1404 merged_windows_to_elements = collections.defaultdict(list) 1405 for window, values in windows_to_elements.items(): 1406 while window in merged_away: 1407 window = merged_away[window] 1408 merged_windows_to_elements[window].extend(values) 1409 windows_to_elements = merged_windows_to_elements 1410 1411 for window in merged_away: 1412 state.clear_state(window, self.WATERMARK_HOLD) 1413 1414 # Next handle element adding. 1415 for window, elements in windows_to_elements.items(): 1416 if state.get_state(window, self.TOMBSTONE): 1417 continue 1418 # Add watermark hold. 1419 # TODO(ccy): Add late data and garbage-collection hold support. 1420 output_time = self.timestamp_combiner_impl.merge( 1421 window, 1422 ( 1423 element_output_time for element_output_time in ( 1424 self.timestamp_combiner_impl.assign_output_time( 1425 window, timestamp) for unused_value, 1426 timestamp in elements) 1427 if element_output_time >= output_watermark)) 1428 if output_time is not None: 1429 state.add_state(window, self.WATERMARK_HOLD, output_time) 1430 1431 context = state.at(window, self.clock) 1432 for value, unused_timestamp in elements: 1433 state.add_state(window, self.ELEMENTS, value) 1434 self.trigger_fn.on_element(value, window, context) 1435 1436 # Maybe fire this window. 1437 if self.trigger_fn.should_fire(TimeDomain.WATERMARK, 1438 input_watermark, 1439 window, 1440 context): 1441 finished = self.trigger_fn.on_fire(input_watermark, window, context) 1442 yield self._output(window, finished, state, output_watermark, False) 1443 1444 def process_timer( 1445 self, 1446 window_id, 1447 unused_name, 1448 time_domain, 1449 timestamp, 1450 state, 1451 input_watermark=None): 1452 if input_watermark is None: 1453 input_watermark = timestamp 1454 1455 if self.is_merging: 1456 state = MergeableStateAdapter(state) 1457 window = state.get_window(window_id) 1458 if state.get_state(window, self.TOMBSTONE): 1459 return 1460 1461 if time_domain in (TimeDomain.WATERMARK, TimeDomain.REAL_TIME): 1462 if not self.is_merging or window in state.known_windows(): 1463 context = state.at(window, self.clock) 1464 if self.trigger_fn.should_fire(time_domain, timestamp, window, context): 1465 finished = self.trigger_fn.on_fire(timestamp, window, context) 1466 yield self._output( 1467 window, 1468 finished, 1469 state, 1470 timestamp, 1471 time_domain == TimeDomain.WATERMARK) 1472 else: 1473 raise Exception('Unexpected time domain: %s' % time_domain) 1474 1475 def _output(self, window, finished, state, output_watermark, maybe_ontime): 1476 """Output window and clean up if appropriate.""" 1477 index = state.get_state(window, self.INDEX) 1478 state.add_state(window, self.INDEX, 1) 1479 if output_watermark <= window.max_timestamp(): 1480 nonspeculative_index = -1 1481 timing = windowed_value.PaneInfoTiming.EARLY 1482 if state.get_state(window, self.NONSPECULATIVE_INDEX): 1483 nonspeculative_index = state.get_state( 1484 window, self.NONSPECULATIVE_INDEX) 1485 state.add_state(window, self.NONSPECULATIVE_INDEX, 1) 1486 _LOGGER.warning( 1487 'Watermark moved backwards in time ' 1488 'or late data moved window end forward.') 1489 else: 1490 nonspeculative_index = state.get_state(window, self.NONSPECULATIVE_INDEX) 1491 state.add_state(window, self.NONSPECULATIVE_INDEX, 1) 1492 timing = ( 1493 windowed_value.PaneInfoTiming.ON_TIME if maybe_ontime and 1494 nonspeculative_index == 0 else windowed_value.PaneInfoTiming.LATE) 1495 pane_info = windowed_value.PaneInfo( 1496 index == 0, finished, timing, index, nonspeculative_index) 1497 1498 values = state.get_state(window, self.ELEMENTS) 1499 if finished: 1500 # TODO(robertwb): allowed lateness 1501 state.clear_state(window, self.ELEMENTS) 1502 state.add_state(window, self.TOMBSTONE, 1) 1503 elif self.accumulation_mode == AccumulationMode.DISCARDING: 1504 state.clear_state(window, self.ELEMENTS) 1505 1506 timestamp = state.get_state(window, self.WATERMARK_HOLD) 1507 if timestamp is None: 1508 # If no watermark hold was set, output at end of window. 1509 timestamp = window.max_timestamp() 1510 elif output_watermark < window.end and self.trigger_fn.has_ontime_pane(): 1511 # Hold the watermark in case there is an empty pane that needs to be fired 1512 # at the end of the window. 1513 pass 1514 else: 1515 state.clear_state(window, self.WATERMARK_HOLD) 1516 1517 return WindowedValue(values, timestamp, (window, ), pane_info) 1518 1519 1520 class InMemoryUnmergedState(UnmergedState): 1521 """In-memory implementation of UnmergedState. 1522 1523 Used for batch and testing. 1524 """ 1525 def __init__(self, defensive_copy=False): 1526 # TODO(robertwb): Clean defensive_copy. It is too expensive in production. 1527 self.timers = collections.defaultdict(dict) 1528 self.state = collections.defaultdict(lambda: collections.defaultdict(list)) 1529 self.global_state = {} 1530 self.defensive_copy = defensive_copy 1531 1532 def copy(self): 1533 cloned_object = InMemoryUnmergedState(defensive_copy=self.defensive_copy) 1534 cloned_object.timers = copy.deepcopy(self.timers) 1535 cloned_object.global_state = copy.deepcopy(self.global_state) 1536 for window in self.state: 1537 for tag in self.state[window]: 1538 cloned_object.state[window][tag] = copy.copy(self.state[window][tag]) 1539 return cloned_object 1540 1541 def set_global_state(self, tag, value): 1542 assert isinstance(tag, _ReadModifyWriteStateTag) 1543 if self.defensive_copy: 1544 value = copy.deepcopy(value) 1545 self.global_state[tag.tag] = value 1546 1547 def get_global_state(self, tag, default=None): 1548 return self.global_state.get(tag.tag, default) 1549 1550 def set_timer( 1551 self, window, name, time_domain, timestamp, dynamic_timer_tag=''): 1552 self.timers[window][(name, time_domain, dynamic_timer_tag)] = timestamp 1553 1554 def clear_timer(self, window, name, time_domain, dynamic_timer_tag=''): 1555 self.timers[window].pop((name, time_domain, dynamic_timer_tag), None) 1556 if not self.timers[window]: 1557 del self.timers[window] 1558 1559 def get_window(self, window_id): 1560 return window_id 1561 1562 def add_state(self, window, tag, value): 1563 if self.defensive_copy: 1564 value = copy.deepcopy(value) 1565 if isinstance(tag, _ReadModifyWriteStateTag): 1566 self.state[window][tag.tag] = value 1567 elif isinstance(tag, _CombiningValueStateTag): 1568 # TODO(robertwb): Store merged accumulators. 1569 self.state[window][tag.tag].append(value) 1570 elif isinstance(tag, _ListStateTag): 1571 self.state[window][tag.tag].append(value) 1572 elif isinstance(tag, _SetStateTag): 1573 self.state[window][tag.tag].append(value) 1574 elif isinstance(tag, _WatermarkHoldStateTag): 1575 self.state[window][tag.tag].append(value) 1576 else: 1577 raise ValueError('Invalid tag.', tag) 1578 1579 def get_state(self, window, tag): 1580 values = self.state[window][tag.tag] 1581 if isinstance(tag, _ReadModifyWriteStateTag): 1582 return values 1583 elif isinstance(tag, _CombiningValueStateTag): 1584 return tag.combine_fn.apply(values) 1585 elif isinstance(tag, _ListStateTag): 1586 return values 1587 elif isinstance(tag, _SetStateTag): 1588 return values 1589 elif isinstance(tag, _WatermarkHoldStateTag): 1590 return tag.timestamp_combiner_impl.combine_all(values) 1591 else: 1592 raise ValueError('Invalid tag.', tag) 1593 1594 def clear_state(self, window, tag): 1595 self.state[window].pop(tag.tag, None) 1596 if not self.state[window]: 1597 self.state.pop(window, None) 1598 1599 def get_timers( 1600 self, clear=False, watermark=MAX_TIMESTAMP, processing_time=None): 1601 """Gets expired timers and reports if there 1602 are any realtime timers set per state. 1603 1604 Expiration is measured against the watermark for event-time timers, 1605 and against a wall clock for processing-time timers. 1606 """ 1607 expired = [] 1608 has_realtime_timer = False 1609 for window, timers in list(self.timers.items()): 1610 for (name, time_domain, dynamic_timer_tag), timestamp in list( 1611 timers.items()): 1612 if time_domain == TimeDomain.REAL_TIME: 1613 time_marker = processing_time 1614 has_realtime_timer = True 1615 elif time_domain == TimeDomain.WATERMARK: 1616 time_marker = watermark 1617 else: 1618 _LOGGER.error( 1619 'TimeDomain error: No timers defined for time domain %s.', 1620 time_domain) 1621 if timestamp <= time_marker: 1622 expired.append( 1623 (window, (name, time_domain, timestamp, dynamic_timer_tag))) 1624 if clear: 1625 del timers[(name, time_domain, dynamic_timer_tag)] 1626 if not timers and clear: 1627 del self.timers[window] 1628 return expired, has_realtime_timer 1629 1630 def get_and_clear_timers(self, watermark=MAX_TIMESTAMP): 1631 return self.get_timers(clear=True, watermark=watermark)[0] 1632 1633 def get_earliest_hold(self): 1634 earliest_hold = MAX_TIMESTAMP 1635 for unused_window, tagged_states in self.state.items(): 1636 # TODO(https://github.com/apache/beam/issues/18441): currently, this 1637 # assumes that the watermark hold tag is named "watermark". This is 1638 # currently only true because the only place watermark holds are set is 1639 # in the GeneralTriggerDriver, where we use this name. We should fix 1640 # this by allowing enumeration of the tag types used in adding state. 1641 if 'watermark' in tagged_states and tagged_states['watermark']: 1642 hold = min(tagged_states['watermark']) - TIME_GRANULARITY 1643 earliest_hold = min(earliest_hold, hold) 1644 return earliest_hold 1645 1646 def __repr__(self): 1647 state_str = '\n'.join( 1648 '%s: %s' % (key, dict(state)) for key, state in self.state.items()) 1649 return 'timers: %s\nstate: %s' % (dict(self.timers), state_str)