github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/direct/direct_userstate.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Support for user state in the BundleBasedDirectRunner."""
    19  # pytype: skip-file
    20  
    21  import copy
    22  import itertools
    23  
    24  from apache_beam.transforms import userstate
    25  from apache_beam.transforms.trigger import _ListStateTag
    26  from apache_beam.transforms.trigger import _ReadModifyWriteStateTag
    27  from apache_beam.transforms.trigger import _SetStateTag
    28  
    29  
    30  class DirectRuntimeState(userstate.RuntimeState):
    31    def __init__(self, state_spec, state_tag, current_value_accessor):
    32      self._state_spec = state_spec
    33      self._state_tag = state_tag
    34      self._current_value_accessor = current_value_accessor
    35  
    36    @staticmethod
    37    def for_spec(state_spec, state_tag, current_value_accessor):
    38      if isinstance(state_spec, userstate.ReadModifyWriteStateSpec):
    39        return ReadModifyWriteRuntimeState(
    40            state_spec, state_tag, current_value_accessor)
    41      elif isinstance(state_spec, userstate.BagStateSpec):
    42        return BagRuntimeState(state_spec, state_tag, current_value_accessor)
    43      elif isinstance(state_spec, userstate.CombiningValueStateSpec):
    44        return CombiningValueRuntimeState(
    45            state_spec, state_tag, current_value_accessor)
    46      elif isinstance(state_spec, userstate.SetStateSpec):
    47        return SetRuntimeState(state_spec, state_tag, current_value_accessor)
    48      else:
    49        raise ValueError('Invalid state spec: %s' % state_spec)
    50  
    51    def _encode(self, value):
    52      return self._state_spec.coder.encode(value)
    53  
    54    def _decode(self, value):
    55      return self._state_spec.coder.decode(value)
    56  
    57  
    58  # Sentinel designating an unread value.
    59  UNREAD_VALUE = object()
    60  
    61  
    62  class ReadModifyWriteRuntimeState(DirectRuntimeState,
    63                                    userstate.ReadModifyWriteRuntimeState):
    64    def __init__(self, state_spec, state_tag, current_value_accessor):
    65      super().__init__(state_spec, state_tag, current_value_accessor)
    66      self._value = UNREAD_VALUE
    67      self._cleared = False
    68      self._modified = False
    69  
    70    def read(self):
    71      if self._cleared:
    72        return None
    73      if self._value is UNREAD_VALUE:
    74        self._value = self._current_value_accessor()
    75      if not self._value:
    76        return None
    77      return self._decode(self._value[0])
    78  
    79    def write(self, value):
    80      self._cleared = False
    81      self._modified = True
    82      self._value = [self._encode(value)]
    83  
    84    def clear(self):
    85      self._cleared = True
    86      self._modified = False
    87      self._value = []
    88  
    89    def is_cleared(self):
    90      return self._cleared
    91  
    92    def is_modified(self):
    93      return self._modified
    94  
    95  
    96  class BagRuntimeState(DirectRuntimeState, userstate.BagRuntimeState):
    97    def __init__(self, state_spec, state_tag, current_value_accessor):
    98      super().__init__(state_spec, state_tag, current_value_accessor)
    99      self._cached_value = UNREAD_VALUE
   100      self._cleared = False
   101      self._new_values = []
   102  
   103    def read(self):
   104      if self._cached_value is UNREAD_VALUE:
   105        self._cached_value = self._current_value_accessor()
   106      if not self._cleared:
   107        encoded_values = itertools.chain(self._cached_value, self._new_values)
   108      else:
   109        encoded_values = self._new_values
   110      return (self._decode(v) for v in encoded_values)
   111  
   112    def add(self, value):
   113      self._new_values.append(self._encode(value))
   114  
   115    def clear(self):
   116      self._cleared = True
   117      self._cached_value = []
   118      self._new_values = []
   119  
   120  
   121  class SetRuntimeState(DirectRuntimeState, userstate.SetRuntimeState):
   122    def __init__(self, state_spec, state_tag, current_value_accessor):
   123      super().__init__(state_spec, state_tag, current_value_accessor)
   124      self._current_accumulator = UNREAD_VALUE
   125      self._modified = False
   126  
   127    def _read_initial_value(self):
   128      if self._current_accumulator is UNREAD_VALUE:
   129        self._current_accumulator = {
   130            self._decode(a)
   131            for a in self._current_value_accessor()
   132        }
   133  
   134    def read(self):
   135      self._read_initial_value()
   136      return self._current_accumulator
   137  
   138    def add(self, value):
   139      self._read_initial_value()
   140      self._modified = True
   141      self._current_accumulator.add(value)
   142  
   143    def clear(self):
   144      self._current_accumulator = set()
   145      self._modified = True
   146  
   147    def is_modified(self):
   148      return self._modified
   149  
   150  
   151  class CombiningValueRuntimeState(DirectRuntimeState,
   152                                   userstate.CombiningValueRuntimeState):
   153    """Combining value state interface object passed to user code."""
   154    def __init__(self, state_spec, state_tag, current_value_accessor):
   155      super().__init__(state_spec, state_tag, current_value_accessor)
   156      self._current_accumulator = UNREAD_VALUE
   157      self._modified = False
   158      self._combine_fn = copy.deepcopy(state_spec.combine_fn)
   159      self._combine_fn.setup()
   160      self._finalized = False
   161  
   162    def _read_initial_value(self):
   163      if self._current_accumulator is UNREAD_VALUE:
   164        existing_accumulators = list(
   165            self._decode(a) for a in self._current_value_accessor())
   166        if existing_accumulators:
   167          self._current_accumulator = self._combine_fn.merge_accumulators(
   168              existing_accumulators)
   169        else:
   170          self._current_accumulator = self._combine_fn.create_accumulator()
   171  
   172    def read(self):
   173      self._read_initial_value()
   174      return self._combine_fn.extract_output(self._current_accumulator)
   175  
   176    def add(self, value):
   177      self._read_initial_value()
   178      self._modified = True
   179      self._current_accumulator = self._combine_fn.add_input(
   180          self._current_accumulator, value)
   181  
   182    def clear(self):
   183      self._modified = True
   184      self._current_accumulator = self._combine_fn.create_accumulator()
   185  
   186    def finalize(self):
   187      if not self._finalized:
   188        self._combine_fn.teardown()
   189        self._finalized = True
   190  
   191  
   192  class DirectUserStateContext(userstate.UserStateContext):
   193    """userstate.UserStateContext for the BundleBasedDirectRunner.
   194  
   195    The DirectUserStateContext buffers up updates that are to be committed
   196    by the TransformEvaluator after running a DoFn.
   197    """
   198    def __init__(self, step_context, dofn, key_coder):
   199      self.step_context = step_context
   200      self.dofn = dofn
   201      self.key_coder = key_coder
   202  
   203      self.all_state_specs, self.all_timer_specs = userstate.get_dofn_specs(dofn)
   204      self.state_tags = {}
   205      for state_spec in self.all_state_specs:
   206        state_key = 'user/%s' % state_spec.name
   207        if isinstance(state_spec, userstate.ReadModifyWriteStateSpec):
   208          state_tag = _ReadModifyWriteStateTag(state_key)
   209        elif isinstance(state_spec, userstate.BagStateSpec):
   210          state_tag = _ListStateTag(state_key)
   211        elif isinstance(state_spec, userstate.CombiningValueStateSpec):
   212          state_tag = _ListStateTag(state_key)
   213        elif isinstance(state_spec, userstate.SetStateSpec):
   214          state_tag = _SetStateTag(state_key)
   215        else:
   216          raise ValueError('Invalid state spec: %s' % state_spec)
   217        self.state_tags[state_spec] = state_tag
   218  
   219      self.cached_states = {}
   220      self.cached_timers = {}
   221  
   222    def get_timer(
   223        self, timer_spec: userstate.TimerSpec, key, window, timestamp,
   224        pane) -> userstate.RuntimeTimer:
   225      assert timer_spec in self.all_timer_specs
   226      encoded_key = self.key_coder.encode(key)
   227      cache_key = (encoded_key, window, timer_spec)
   228      if cache_key not in self.cached_timers:
   229        self.cached_timers[cache_key] = userstate.RuntimeTimer()
   230      return self.cached_timers[cache_key]
   231  
   232    def get_state(self, state_spec, key, window):
   233      assert state_spec in self.all_state_specs
   234      encoded_key = self.key_coder.encode(key)
   235      cache_key = (encoded_key, window, state_spec)
   236      if cache_key not in self.cached_states:
   237        state_tag = self.state_tags[state_spec]
   238        value_accessor = (
   239            lambda: self._get_underlying_state(state_spec, key, window))
   240        self.cached_states[cache_key] = DirectRuntimeState.for_spec(
   241            state_spec, state_tag, value_accessor)
   242      return self.cached_states[cache_key]
   243  
   244    def _get_underlying_state(self, state_spec, key, window):
   245      state_tag = self.state_tags[state_spec]
   246      encoded_key = self.key_coder.encode(key)
   247      return (
   248          self.step_context.get_keyed_state(encoded_key).get_state(
   249              window, state_tag))
   250  
   251    def commit(self):
   252      # Commit state modifications.
   253      for cache_key, runtime_state in self.cached_states.items():
   254        encoded_key, window, state_spec = cache_key
   255        state = self.step_context.get_keyed_state(encoded_key)
   256        state_tag = self.state_tags[state_spec]
   257        if isinstance(state_spec, userstate.BagStateSpec):
   258          if runtime_state._cleared:
   259            state.clear_state(window, state_tag)
   260          for new_value in runtime_state._new_values:
   261            state.add_state(window, state_tag, new_value)
   262        elif isinstance(state_spec, userstate.CombiningValueStateSpec):
   263          if runtime_state._modified:
   264            state.clear_state(window, state_tag)
   265            state.add_state(
   266                window,
   267                state_tag,
   268                state_spec.coder.encode(runtime_state._current_accumulator))
   269        elif isinstance(state_spec, userstate.SetStateSpec):
   270          if runtime_state.is_modified():
   271            state.clear_state(window, state_tag)
   272            for new_value in runtime_state._current_accumulator:
   273              state.add_state(
   274                  window, state_tag, state_spec.coder.encode(new_value))
   275        elif isinstance(state_spec, userstate.ReadModifyWriteStateSpec):
   276          if runtime_state.is_cleared():
   277            state.clear_state(window, state_tag)
   278          if runtime_state.is_modified():
   279            state.clear_state(window, state_tag)
   280            state.add_state(window, state_tag, runtime_state._value)
   281        else:
   282          raise ValueError('Invalid state spec: %s' % state_spec)
   283  
   284      # Commit new timers.
   285      for cache_key, runtime_timer in self.cached_timers.items():
   286        encoded_key, window, timer_spec = cache_key
   287        state = self.step_context.get_keyed_state(encoded_key)
   288        timer_name = 'user/%s' % timer_spec.name
   289        for dynamic_timer_tag, timer in runtime_timer._timer_recordings.items():
   290          if timer.cleared:
   291            state.clear_timer(
   292                window,
   293                timer_name,
   294                timer_spec.time_domain,
   295                dynamic_timer_tag=dynamic_timer_tag)
   296          if timer.timestamp:
   297            # TODO(ccy): add corresponding watermark holds after the DirectRunner
   298            # allows for keyed watermark holds.
   299            state.set_timer(
   300                window,
   301                timer_name,
   302                timer_spec.time_domain,
   303                timer.timestamp,
   304                dynamic_timer_tag=dynamic_timer_tag)
   305  
   306    def reset(self):
   307      for state in self.cached_states.values():
   308        state.finalize()
   309      self.cached_states = {}
   310      self.cached_timers = {}