github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/userstate.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """User-facing interfaces for the Beam State and Timer APIs."""
    19  
    20  # pytype: skip-file
    21  # mypy: disallow-untyped-defs
    22  
    23  import collections
    24  import types
    25  from typing import TYPE_CHECKING
    26  from typing import Any
    27  from typing import Callable
    28  from typing import Dict
    29  from typing import Iterable
    30  from typing import NamedTuple
    31  from typing import Optional
    32  from typing import Set
    33  from typing import Tuple
    34  from typing import TypeVar
    35  
    36  from apache_beam.coders import Coder
    37  from apache_beam.coders import coders
    38  from apache_beam.portability import common_urns
    39  from apache_beam.portability.api import beam_runner_api_pb2
    40  from apache_beam.transforms.timeutil import TimeDomain
    41  
    42  if TYPE_CHECKING:
    43    from apache_beam.runners.pipeline_context import PipelineContext
    44    from apache_beam.transforms.core import CombineFn, DoFn
    45    from apache_beam.utils import windowed_value
    46    from apache_beam.utils.timestamp import Timestamp
    47  
# Type variable bound to Callable; used in on_timer's type comment so the
# decorator is typed as returning the same callable type it receives.
CallableT = TypeVar('CallableT', bound=Callable)
    49  
    50  
    51  class StateSpec(object):
    52    """Specification for a user DoFn state cell."""
    53    def __init__(self, name, coder):
    54      # type: (str, Coder) -> None
    55      if not isinstance(name, str):
    56        raise TypeError("name is not a string")
    57      if not isinstance(coder, Coder):
    58        raise TypeError("coder is not of type Coder")
    59      self.name = name
    60      self.coder = coder
    61  
    62    def __repr__(self):
    63      # type: () -> str
    64      return '%s(%s)' % (self.__class__.__name__, self.name)
    65  
    66    def to_runner_api(self, context):
    67      # type: (PipelineContext) -> beam_runner_api_pb2.StateSpec
    68      raise NotImplementedError
    69  
    70  
    71  class ReadModifyWriteStateSpec(StateSpec):
    72    """Specification for a user DoFn value state cell."""
    73    def to_runner_api(self, context):
    74      # type: (PipelineContext) -> beam_runner_api_pb2.StateSpec
    75      return beam_runner_api_pb2.StateSpec(
    76          read_modify_write_spec=beam_runner_api_pb2.ReadModifyWriteStateSpec(
    77              coder_id=context.coders.get_id(self.coder)),
    78          protocol=beam_runner_api_pb2.FunctionSpec(
    79              urn=common_urns.user_state.BAG.urn))
    80  
    81  
    82  class BagStateSpec(StateSpec):
    83    """Specification for a user DoFn bag state cell."""
    84    def to_runner_api(self, context):
    85      # type: (PipelineContext) -> beam_runner_api_pb2.StateSpec
    86      return beam_runner_api_pb2.StateSpec(
    87          bag_spec=beam_runner_api_pb2.BagStateSpec(
    88              element_coder_id=context.coders.get_id(self.coder)),
    89          protocol=beam_runner_api_pb2.FunctionSpec(
    90              urn=common_urns.user_state.BAG.urn))
    91  
    92  
    93  class SetStateSpec(StateSpec):
    94    """Specification for a user DoFn Set State cell"""
    95    def to_runner_api(self, context):
    96      # type: (PipelineContext) -> beam_runner_api_pb2.StateSpec
    97      return beam_runner_api_pb2.StateSpec(
    98          set_spec=beam_runner_api_pb2.SetStateSpec(
    99              element_coder_id=context.coders.get_id(self.coder)),
   100          protocol=beam_runner_api_pb2.FunctionSpec(
   101              urn=common_urns.user_state.BAG.urn))
   102  
   103  
   104  class CombiningValueStateSpec(StateSpec):
   105    """Specification for a user DoFn combining value state cell."""
   106    def __init__(self, name, coder=None, combine_fn=None):
   107      # type: (str, Optional[Coder], Any) -> None
   108  
   109      """Initialize the specification for CombiningValue state.
   110  
   111      CombiningValueStateSpec(name, combine_fn) -> Coder-inferred combining value
   112        state spec.
   113      CombiningValueStateSpec(name, coder, combine_fn) -> Combining value state
   114        spec with coder and combine_fn specified.
   115  
   116      Args:
   117        name (str): The name by which the state is identified.
   118        coder (Coder): Coder specifying how to encode the values to be combined.
   119          May be inferred.
   120        combine_fn (``CombineFn`` or ``callable``): Function specifying how to
   121          combine the values passed to state.
   122      """
   123      # Avoid circular import.
   124      from apache_beam.transforms.core import CombineFn
   125      # We want the coder to be optional, but unfortunately it comes
   126      # before the non-optional combine_fn parameter, which we can't
   127      # change for backwards compatibility reasons.
   128      #
   129      # Instead, allow it to be omitted (by either passing two arguments
   130      # or combine_fn by keyword.)
   131      if combine_fn is None:
   132        if coder is None:
   133          raise ValueError('combine_fn must be provided')
   134        else:
   135          coder, combine_fn = None, coder
   136      self.combine_fn = CombineFn.maybe_from_callable(combine_fn)
   137      # The coder here should be for the accumulator type of the given CombineFn.
   138      if coder is None:
   139        coder = self.combine_fn.get_accumulator_coder()
   140  
   141      super().__init__(name, coder)
   142  
   143    def to_runner_api(self, context):
   144      # type: (PipelineContext) -> beam_runner_api_pb2.StateSpec
   145      return beam_runner_api_pb2.StateSpec(
   146          combining_spec=beam_runner_api_pb2.CombiningStateSpec(
   147              combine_fn=self.combine_fn.to_runner_api(context),
   148              accumulator_coder_id=context.coders.get_id(self.coder)),
   149          protocol=beam_runner_api_pb2.FunctionSpec(
   150              urn=common_urns.user_state.BAG.urn))
   151  
   152  
   153  # TODO(BEAM-9562): Update Timer to have of() and clear() APIs.
   154  Timer = NamedTuple(
   155      'Timer',
   156      [
   157          ('user_key', Any),
   158          ('dynamic_timer_tag', str),
   159          ('windows', Tuple['windowed_value.BoundedWindow', ...]),
   160          ('clear_bit', bool),
   161          ('fire_timestamp', Optional['Timestamp']),
   162          ('hold_timestamp', Optional['Timestamp']),
   163          ('paneinfo', Optional['windowed_value.PaneInfo']),
   164      ])
   165  
   166  
   167  # TODO(BEAM-9562): Plumb through actual key_coder and window_coder.
   168  class TimerSpec(object):
   169    """Specification for a user stateful DoFn timer."""
   170    prefix = "ts-"
   171  
   172    def __init__(self, name, time_domain):
   173      # type: (str, str) -> None
   174      self.name = self.prefix + name
   175      if time_domain not in (TimeDomain.WATERMARK, TimeDomain.REAL_TIME):
   176        raise ValueError('Unsupported TimeDomain: %r.' % (time_domain, ))
   177      self.time_domain = time_domain
   178      self._attached_callback = None  # type: Optional[Callable]
   179  
   180    def __repr__(self):
   181      # type: () -> str
   182      return '%s(%s)' % (self.__class__.__name__, self.name)
   183  
   184    def to_runner_api(self, context, key_coder, window_coder):
   185      # type: (PipelineContext, Coder, Coder) -> beam_runner_api_pb2.TimerFamilySpec
   186      return beam_runner_api_pb2.TimerFamilySpec(
   187          time_domain=TimeDomain.to_runner_api(self.time_domain),
   188          timer_family_coder_id=context.coders.get_id(
   189              coders._TimerCoder(key_coder, window_coder)))
   190  
   191  
   192  def on_timer(timer_spec):
   193    # type: (TimerSpec) -> Callable[[CallableT], CallableT]
   194  
   195    """Decorator for timer firing DoFn method.
   196  
   197    This decorator allows a user to specify an on_timer processing method
   198    in a stateful DoFn.  Sample usage::
   199  
   200      class MyDoFn(DoFn):
   201        TIMER_SPEC = TimerSpec('timer', TimeDomain.WATERMARK)
   202  
   203        @on_timer(TIMER_SPEC)
   204        def my_timer_expiry_callback(self):
   205          logging.info('Timer expired!')
   206    """
   207  
   208    if not isinstance(timer_spec, TimerSpec):
   209      raise ValueError('@on_timer decorator expected TimerSpec.')
   210  
   211    def _inner(method):
   212      # type: (CallableT) -> CallableT
   213      if not callable(method):
   214        raise ValueError('@on_timer decorator expected callable.')
   215      if timer_spec._attached_callback:
   216        raise ValueError(
   217            'Multiple on_timer callbacks registered for %r.' % timer_spec)
   218      timer_spec._attached_callback = method
   219      return method
   220  
   221    return _inner
   222  
   223  
   224  def get_dofn_specs(dofn):
   225    # type: (DoFn) -> Tuple[Set[StateSpec], Set[TimerSpec]]
   226  
   227    """Gets the state and timer specs for a DoFn, if any.
   228  
   229    Args:
   230      dofn (apache_beam.transforms.core.DoFn): The DoFn instance to introspect for
   231        timer and state specs.
   232    """
   233  
   234    # Avoid circular import.
   235    from apache_beam.runners.common import MethodWrapper
   236    from apache_beam.transforms.core import _DoFnParam
   237    from apache_beam.transforms.core import _StateDoFnParam
   238    from apache_beam.transforms.core import _TimerDoFnParam
   239  
   240    all_state_specs = set()
   241    all_timer_specs = set()
   242  
   243    # Validate params to process(), start_bundle(), finish_bundle() and to
   244    # any on_timer callbacks.
   245    for method_name in dir(dofn):
   246      if not isinstance(getattr(dofn, method_name, None), types.MethodType):
   247        continue
   248      method = MethodWrapper(dofn, method_name)
   249      param_ids = [
   250          d.param_id for d in method.defaults if isinstance(d, _DoFnParam)
   251      ]
   252      if len(param_ids) != len(set(param_ids)):
   253        raise ValueError(
   254            'DoFn %r has duplicate %s method parameters: %s.' %
   255            (dofn, method_name, param_ids))
   256      for d in method.defaults:
   257        if isinstance(d, _StateDoFnParam):
   258          all_state_specs.add(d.state_spec)
   259        elif isinstance(d, _TimerDoFnParam):
   260          all_timer_specs.add(d.timer_spec)
   261  
   262    return all_state_specs, all_timer_specs
   263  
   264  
   265  def is_stateful_dofn(dofn):
   266    # type: (DoFn) -> bool
   267  
   268    """Determines whether a given DoFn is a stateful DoFn."""
   269  
   270    # A Stateful DoFn is a DoFn that uses user state or timers.
   271    all_state_specs, all_timer_specs = get_dofn_specs(dofn)
   272    return bool(all_state_specs or all_timer_specs)
   273  
   274  
   275  def validate_stateful_dofn(dofn):
   276    # type: (DoFn) -> None
   277  
   278    """Validates the proper specification of a stateful DoFn."""
   279  
   280    # Get state and timer specs.
   281    all_state_specs, all_timer_specs = get_dofn_specs(dofn)
   282  
   283    # Reject DoFns that have multiple state or timer specs with the same name.
   284    if len(all_state_specs) != len(set(s.name for s in all_state_specs)):
   285      raise ValueError(
   286          'DoFn %r has multiple StateSpecs with the same name: %s.' %
   287          (dofn, all_state_specs))
   288    if len(all_timer_specs) != len(set(s.name for s in all_timer_specs)):
   289      raise ValueError(
   290          'DoFn %r has multiple TimerSpecs with the same name: %s.' %
   291          (dofn, all_timer_specs))
   292  
   293    # Reject DoFns that use timer specs without corresponding timer callbacks.
   294    for timer_spec in all_timer_specs:
   295      if not timer_spec._attached_callback:
   296        raise ValueError((
   297            'DoFn %r has a TimerSpec without an associated on_timer '
   298            'callback: %s.') % (dofn, timer_spec))
   299      method_name = timer_spec._attached_callback.__name__
   300      if (timer_spec._attached_callback != getattr(dofn, method_name,
   301                                                   None).__func__):
   302        raise ValueError((
   303            'The on_timer callback for %s is not the specified .%s method '
   304            'for DoFn %r (perhaps it was overwritten?).') %
   305                         (timer_spec, method_name, dofn))
   306  
   307  
   308  class BaseTimer(object):
   309    def clear(self, dynamic_timer_tag=''):
   310      # type: (str) -> None
   311      raise NotImplementedError
   312  
   313    def set(self, timestamp, dynamic_timer_tag=''):
   314      # type: (Timestamp, str) -> None
   315      raise NotImplementedError
   316  
   317  
   318  _TimerTuple = collections.namedtuple('timer_tuple', ('cleared', 'timestamp'))
   319  
   320  
   321  class RuntimeTimer(BaseTimer):
   322    """Timer interface object passed to user code."""
   323    def __init__(self) -> None:
   324      self._timer_recordings = {}  # type: Dict[str, _TimerTuple]
   325      self._cleared = False
   326      self._new_timestamp = None  # type: Optional[Timestamp]
   327  
   328    def clear(self, dynamic_timer_tag=''):
   329      # type: (str) -> None
   330      self._timer_recordings[dynamic_timer_tag] = _TimerTuple(
   331          cleared=True, timestamp=None)
   332  
   333    def set(self, timestamp, dynamic_timer_tag=''):
   334      # type: (Timestamp, str) -> None
   335      self._timer_recordings[dynamic_timer_tag] = _TimerTuple(
   336          cleared=False, timestamp=timestamp)
   337  
   338  
   339  class RuntimeState(object):
   340    """State interface object passed to user code."""
   341    def prefetch(self):
   342      # type: () -> None
   343      # The default implementation here does nothing.
   344      pass
   345  
   346    def finalize(self):
   347      # type: () -> None
   348      pass
   349  
   350  
   351  class ReadModifyWriteRuntimeState(RuntimeState):
   352    def read(self):
   353      # type: () -> Any
   354      raise NotImplementedError(type(self))
   355  
   356    def write(self, value):
   357      # type: (Any) -> None
   358      raise NotImplementedError(type(self))
   359  
   360    def clear(self):
   361      # type: () -> None
   362      raise NotImplementedError(type(self))
   363  
   364    def commit(self):
   365      # type: () -> None
   366      raise NotImplementedError(type(self))
   367  
   368  
   369  class AccumulatingRuntimeState(RuntimeState):
   370    def read(self):
   371      # type: () -> Iterable[Any]
   372      raise NotImplementedError(type(self))
   373  
   374    def add(self, value):
   375      # type: (Any) -> None
   376      raise NotImplementedError(type(self))
   377  
   378    def clear(self):
   379      # type: () -> None
   380      raise NotImplementedError(type(self))
   381  
   382    def commit(self):
   383      # type: () -> None
   384      raise NotImplementedError(type(self))
   385  
   386  
   387  class BagRuntimeState(AccumulatingRuntimeState):
   388    """Bag state interface object passed to user code."""
   389  
   390  
   391  class SetRuntimeState(AccumulatingRuntimeState):
   392    """Set state interface object passed to user code."""
   393  
   394  
   395  class CombiningValueRuntimeState(AccumulatingRuntimeState):
   396    """Combining value state interface object passed to user code."""
   397  
   398  
   399  class UserStateContext(object):
   400    """Wrapper allowing user state and timers to be accessed by a DoFnInvoker."""
   401    def get_timer(self,
   402                  timer_spec,  # type: TimerSpec
   403                  key,  # type: Any
   404                  window,  # type: windowed_value.BoundedWindow
   405                  timestamp,  # type: Timestamp
   406                  pane,  # type: windowed_value.PaneInfo
   407                 ):
   408      # type: (...) -> BaseTimer
   409      raise NotImplementedError(type(self))
   410  
   411    def get_state(self,
   412                  state_spec,  # type: StateSpec
   413                  key,  # type: Any
   414                  window,  # type: windowed_value.BoundedWindow
   415                 ):
   416      # type: (...) -> RuntimeState
   417      raise NotImplementedError(type(self))
   418  
   419    def commit(self):
   420      # type: () -> None
   421      raise NotImplementedError(type(self))