github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/interactive/options/capture_limiters.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Module to condition how Interactive Beam stops capturing data.
    19  
    20  For internal use only; no backwards-compatibility guarantees.
    21  """
    22  
    23  import threading
    24  
    25  import pandas as pd
    26  
    27  from apache_beam.portability.api import beam_interactive_api_pb2
    28  from apache_beam.portability.api import beam_runner_api_pb2
    29  from apache_beam.runners.interactive import interactive_environment as ie
    30  from apache_beam.utils.windowed_value import WindowedValue
    31  
    32  
    33  class Limiter:
    34    """Limits an aspect of the caching layer."""
    35    def is_triggered(self):
    36      # type: () -> bool
    37  
    38      """Returns True if the limiter has triggered, and caching should stop."""
    39      raise NotImplementedError
    40  
    41  
    42  class ElementLimiter(Limiter):
    43    """A `Limiter` that limits reading from cache based on some property of an
    44    element.
    45    """
    46    def update(self, e):
    47      # type: (Any) -> None # noqa: F821
    48  
    49      """Update the internal state based on some property of an element.
    50  
    51      This is executed on every element that is read from cache.
    52      """
    53      raise NotImplementedError
    54  
    55  
    56  class SizeLimiter(Limiter):
    57    """Limits the cache size to a specified byte limit."""
    58    def __init__(
    59        self,
    60        size_limit  # type: int
    61    ):
    62      self._size_limit = size_limit
    63  
    64    def is_triggered(self):
    65      total_capture_size = 0
    66      ie.current_env().track_user_pipelines()
    67      for user_pipeline in ie.current_env().tracked_user_pipelines:
    68        cache_manager = ie.current_env().get_cache_manager(user_pipeline)
    69        if hasattr(cache_manager, 'capture_size'):
    70          total_capture_size += cache_manager.capture_size
    71      return total_capture_size >= self._size_limit
    72  
    73  
    74  class DurationLimiter(Limiter):
    75    """Limits the duration of the capture."""
    76    def __init__(
    77        self,
    78        duration_limit  # type: datetime.timedelta # noqa: F821
    79    ):
    80      self._duration_limit = duration_limit
    81      self._timer = threading.Timer(duration_limit.total_seconds(), self._trigger)
    82      self._timer.daemon = True
    83      self._triggered = False
    84      self._timer.start()
    85  
    86    def _trigger(self):
    87      self._triggered = True
    88  
    89    def is_triggered(self):
    90      return self._triggered
    91  
    92  
    93  class CountLimiter(ElementLimiter):
    94    """Limits by counting the number of elements seen."""
    95    def __init__(self, max_count):
    96      self._max_count = max_count
    97      self._count = 0
    98  
    99    def update(self, e):
   100      # A TestStreamFileRecord can contain many elements at once. If e is a file
   101      # record, then count the number of elements in the bundle.
   102      if isinstance(e, beam_interactive_api_pb2.TestStreamFileRecord):
   103        if not e.recorded_event.element_event:
   104          return
   105        self._count += len(e.recorded_event.element_event.elements)
   106  
   107      # Otherwise, count everything else but the header of the file since it is
   108      # not an element.
   109      elif not isinstance(e, beam_interactive_api_pb2.TestStreamFileHeader):
   110        # When elements are DataFrames, we want the output to be constrained by
   111        # how many rows we have read, not how many DataFrames we have read.
   112        if isinstance(e, WindowedValue) and isinstance(e.value, pd.DataFrame):
   113          self._count += len(e.value)
   114        else:
   115          self._count += 1
   116  
   117    def is_triggered(self):
   118      return self._count >= self._max_count
   119  
   120  
   121  class ProcessingTimeLimiter(ElementLimiter):
   122    """Limits by how long the ProcessingTime passed in the element stream.
   123  
   124    Reads all elements from the timespan [start, start + duration).
   125  
   126    This measures the duration from the first element in the stream. Each
   127    subsequent element has a delta "advance_duration" that moves the internal
   128    clock forward. This triggers when the duration from the internal clock and
   129    the start exceeds the given duration.
   130    """
   131    def __init__(self, max_duration_secs):
   132      """Initialize the ProcessingTimeLimiter."""
   133      self._max_duration_us = max_duration_secs * 1e6
   134      self._start_us = 0
   135      self._cur_time_us = 0
   136  
   137    def update(self, e):
   138      # Only look at TestStreamFileRecords which hold the processing time.
   139      if not isinstance(e, beam_runner_api_pb2.TestStreamPayload.Event):
   140        return
   141  
   142      if not e.HasField('processing_time_event'):
   143        return
   144  
   145      if self._start_us == 0:
   146        self._start_us = e.processing_time_event.advance_duration
   147      self._cur_time_us += e.processing_time_event.advance_duration
   148  
   149    def is_triggered(self):
   150      return self._cur_time_us - self._start_us >= self._max_duration_us