#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""User-facing interfaces for the Beam State and Timer APIs."""

# pytype: skip-file
# mypy: disallow-untyped-defs

import collections
import types
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Dict
from typing import Iterable
from typing import NamedTuple
from typing import Optional
from typing import Set
from typing import Tuple
from typing import TypeVar

from apache_beam.coders import Coder
from apache_beam.coders import coders
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.transforms.timeutil import TimeDomain

if TYPE_CHECKING:
  from apache_beam.runners.pipeline_context import PipelineContext
  from apache_beam.transforms.core import CombineFn, DoFn
  from apache_beam.utils import windowed_value
  from apache_beam.utils.timestamp import Timestamp

CallableT = TypeVar('CallableT', bound=Callable)


class StateSpec(object):
  """Specification for a user DoFn state cell."""
  def __init__(self, name, coder):
    # type: (str, Coder) -> None

    """Initialize the state specification.

    Args:
      name (str): The name by which the state is identified.
      coder (Coder): Coder stored on the spec as ``self.coder``.

    Raises:
      TypeError: If ``name`` is not a ``str`` or ``coder`` is not a ``Coder``.
    """
    if not isinstance(name, str):
      raise TypeError("name is not a string")
    if not isinstance(coder, Coder):
      raise TypeError("coder is not of type Coder")
    self.name = name
    self.coder = coder

  def __repr__(self):
    # type: () -> str
    return '%s(%s)' % (self.__class__.__name__, self.name)

  def to_runner_api(self, context):
    # type: (PipelineContext) -> beam_runner_api_pb2.StateSpec
    # Subclasses provide the concrete proto conversion.
    raise NotImplementedError


class ReadModifyWriteStateSpec(StateSpec):
  """Specification for a user DoFn value state cell."""
  def to_runner_api(self, context):
    # type: (PipelineContext) -> beam_runner_api_pb2.StateSpec
    return beam_runner_api_pb2.StateSpec(
        read_modify_write_spec=beam_runner_api_pb2.ReadModifyWriteStateSpec(
            coder_id=context.coders.get_id(self.coder)),
        protocol=beam_runner_api_pb2.FunctionSpec(
            urn=common_urns.user_state.BAG.urn))


class BagStateSpec(StateSpec):
  """Specification for a user DoFn bag state cell."""
  def to_runner_api(self, context):
    # type: (PipelineContext) -> beam_runner_api_pb2.StateSpec
    return beam_runner_api_pb2.StateSpec(
        bag_spec=beam_runner_api_pb2.BagStateSpec(
            element_coder_id=context.coders.get_id(self.coder)),
        protocol=beam_runner_api_pb2.FunctionSpec(
            urn=common_urns.user_state.BAG.urn))


class SetStateSpec(StateSpec):
  """Specification for a user DoFn Set State cell"""
  def to_runner_api(self, context):
    # type: (PipelineContext) -> beam_runner_api_pb2.StateSpec
    return beam_runner_api_pb2.StateSpec(
        set_spec=beam_runner_api_pb2.SetStateSpec(
            element_coder_id=context.coders.get_id(self.coder)),
        protocol=beam_runner_api_pb2.FunctionSpec(
            urn=common_urns.user_state.BAG.urn))


class CombiningValueStateSpec(StateSpec):
  """Specification for a user DoFn combining value state cell."""
  def __init__(self, name, coder=None, combine_fn=None):
    # type: (str, Optional[Coder], Any) -> None

    """Initialize the specification for CombiningValue state.

    CombiningValueStateSpec(name, combine_fn) -> Coder-inferred combining value
      state spec.
    CombiningValueStateSpec(name, coder, combine_fn) -> Combining value state
      spec with coder and combine_fn specified.

    Args:
      name (str): The name by which the state is identified.
      coder (Coder): Coder specifying how to encode the values to be combined.
        May be inferred.
      combine_fn (``CombineFn`` or ``callable``): Function specifying how to
        combine the values passed to state.

    Raises:
      ValueError: If neither ``coder`` nor ``combine_fn`` is provided.
    """
    # Avoid circular import.
    from apache_beam.transforms.core import CombineFn
    # We want the coder to be optional, but unfortunately it comes
    # before the non-optional combine_fn parameter, which we can't
    # change for backwards compatibility reasons.
    #
    # Instead, allow it to be omitted (by either passing two arguments
    # or combine_fn by keyword.)
    if combine_fn is None:
      if coder is None:
        raise ValueError('combine_fn must be provided')
      else:
        coder, combine_fn = None, coder
    self.combine_fn = CombineFn.maybe_from_callable(combine_fn)
    # The coder here should be for the accumulator type of the given CombineFn.
    if coder is None:
      coder = self.combine_fn.get_accumulator_coder()

    super().__init__(name, coder)

  def to_runner_api(self, context):
    # type: (PipelineContext) -> beam_runner_api_pb2.StateSpec
    return beam_runner_api_pb2.StateSpec(
        combining_spec=beam_runner_api_pb2.CombiningStateSpec(
            combine_fn=self.combine_fn.to_runner_api(context),
            accumulator_coder_id=context.coders.get_id(self.coder)),
        protocol=beam_runner_api_pb2.FunctionSpec(
            urn=common_urns.user_state.BAG.urn))


# TODO(BEAM-9562): Update Timer to have of() and clear() APIs.
Timer = NamedTuple(
    'Timer',
    [
        ('user_key', Any),
        ('dynamic_timer_tag', str),
        ('windows', Tuple['windowed_value.BoundedWindow', ...]),
        ('clear_bit', bool),
        ('fire_timestamp', Optional['Timestamp']),
        ('hold_timestamp', Optional['Timestamp']),
        ('paneinfo', Optional['windowed_value.PaneInfo']),
    ])


# TODO(BEAM-9562): Plumb through actual key_coder and window_coder.
class TimerSpec(object):
  """Specification for a user stateful DoFn timer."""
  prefix = "ts-"

  def __init__(self, name, time_domain):
    # type: (str, str) -> None

    """Initialize the timer specification.

    Args:
      name (str): Timer name; stored with ``prefix`` prepended.
      time_domain (str): Either ``TimeDomain.WATERMARK`` or
        ``TimeDomain.REAL_TIME``.

    Raises:
      ValueError: If ``time_domain`` is not one of the two supported domains.
    """
    self.name = self.prefix + name
    if time_domain not in (TimeDomain.WATERMARK, TimeDomain.REAL_TIME):
      raise ValueError('Unsupported TimeDomain: %r.' % (time_domain, ))
    self.time_domain = time_domain
    # Populated by the @on_timer decorator.
    self._attached_callback = None  # type: Optional[Callable]

  def __repr__(self):
    # type: () -> str
    return '%s(%s)' % (self.__class__.__name__, self.name)

  def to_runner_api(self, context, key_coder, window_coder):
    # type: (PipelineContext, Coder, Coder) -> beam_runner_api_pb2.TimerFamilySpec
    return beam_runner_api_pb2.TimerFamilySpec(
        time_domain=TimeDomain.to_runner_api(self.time_domain),
        timer_family_coder_id=context.coders.get_id(
            coders._TimerCoder(key_coder, window_coder)))


def on_timer(timer_spec):
  # type: (TimerSpec) -> Callable[[CallableT], CallableT]

  """Decorator for timer firing DoFn method.

  This decorator allows a user to specify an on_timer processing method
  in a stateful DoFn.  Sample usage::

    class MyDoFn(DoFn):
      TIMER_SPEC = TimerSpec('timer', TimeDomain.WATERMARK)

      @on_timer(TIMER_SPEC)
      def my_timer_expiry_callback(self):
        logging.info('Timer expired!')

  Raises:
    ValueError: If ``timer_spec`` is not a ``TimerSpec``, if the decorated
      object is not callable, or if ``timer_spec`` already has a callback
      attached.
  """

  if not isinstance(timer_spec, TimerSpec):
    raise ValueError('@on_timer decorator expected TimerSpec.')

  def _inner(method):
    # type: (CallableT) -> CallableT
    if not callable(method):
      raise ValueError('@on_timer decorator expected callable.')
    if timer_spec._attached_callback:
      raise ValueError(
          'Multiple on_timer callbacks registered for %r.' % timer_spec)
    timer_spec._attached_callback = method
    return method

  return _inner


def get_dofn_specs(dofn):
  # type: (DoFn) -> Tuple[Set[StateSpec], Set[TimerSpec]]

  """Gets the state and timer specs for a DoFn, if any.

  Args:
    dofn (apache_beam.transforms.core.DoFn): The DoFn instance to introspect for
      timer and state specs.

  Returns:
    A ``(state_specs, timer_specs)`` tuple of sets collected from the default
    argument values of every bound method found on ``dofn``.

  Raises:
    ValueError: If a single method declares two ``_DoFnParam`` defaults with
      the same ``param_id``.
  """

  # Avoid circular import.
  from apache_beam.runners.common import MethodWrapper
  from apache_beam.transforms.core import _DoFnParam
  from apache_beam.transforms.core import _StateDoFnParam
  from apache_beam.transforms.core import _TimerDoFnParam

  all_state_specs = set()
  all_timer_specs = set()

  # Validate params to process(), start_bundle(), finish_bundle() and to
  # any on_timer callbacks.
  for method_name in dir(dofn):
    if not isinstance(getattr(dofn, method_name, None), types.MethodType):
      continue
    method = MethodWrapper(dofn, method_name)
    param_ids = [
        d.param_id for d in method.defaults if isinstance(d, _DoFnParam)
    ]
    if len(param_ids) != len(set(param_ids)):
      raise ValueError(
          'DoFn %r has duplicate %s method parameters: %s.' %
          (dofn, method_name, param_ids))
    for d in method.defaults:
      if isinstance(d, _StateDoFnParam):
        all_state_specs.add(d.state_spec)
      elif isinstance(d, _TimerDoFnParam):
        all_timer_specs.add(d.timer_spec)

  return all_state_specs, all_timer_specs


def is_stateful_dofn(dofn):
  # type: (DoFn) -> bool

  """Determines whether a given DoFn is a stateful DoFn."""

  # A Stateful DoFn is a DoFn that uses user state or timers.
  all_state_specs, all_timer_specs = get_dofn_specs(dofn)
  return bool(all_state_specs or all_timer_specs)


def validate_stateful_dofn(dofn):
  # type: (DoFn) -> None

  """Validates the proper specification of a stateful DoFn.

  Raises:
    ValueError: If two state specs or two timer specs share a name, if a
      timer spec has no attached on_timer callback, or if the attribute of
      ``dofn`` named after the callback is not that callback.
  """

  # Get state and timer specs.
  all_state_specs, all_timer_specs = get_dofn_specs(dofn)

  # Reject DoFns that have multiple state or timer specs with the same name.
  if len(all_state_specs) != len(set(s.name for s in all_state_specs)):
    raise ValueError(
        'DoFn %r has multiple StateSpecs with the same name: %s.' %
        (dofn, all_state_specs))
  if len(all_timer_specs) != len(set(s.name for s in all_timer_specs)):
    raise ValueError(
        'DoFn %r has multiple TimerSpecs with the same name: %s.' %
        (dofn, all_timer_specs))

  # Reject DoFns that use timer specs without corresponding timer callbacks.
  for timer_spec in all_timer_specs:
    if not timer_spec._attached_callback:
      raise ValueError((
          'DoFn %r has a TimerSpec without an associated on_timer '
          'callback: %s.') % (dofn, timer_spec))
    method_name = timer_spec._attached_callback.__name__
    # The attribute may be missing, or may have been overwritten with
    # something that is not a bound method and so has no ``__func__``.
    # Report both through the ValueError below rather than letting an
    # AttributeError escape.
    current_func = getattr(
        getattr(dofn, method_name, None), '__func__', None)
    if timer_spec._attached_callback != current_func:
      raise ValueError((
          'The on_timer callback for %s is not the specified .%s method '
          'for DoFn %r (perhaps it was overwritten?).') %
                       (timer_spec, method_name, dofn))


class BaseTimer(object):
  """Timer interface; both operations raise NotImplementedError here."""
  def clear(self, dynamic_timer_tag=''):
    # type: (str) -> None
    raise NotImplementedError

  def set(self, timestamp, dynamic_timer_tag=''):
    # type: (Timestamp, str) -> None
    raise NotImplementedError


_TimerTuple = collections.namedtuple('timer_tuple', ('cleared', 'timestamp'))


class RuntimeTimer(BaseTimer):
  """Timer interface object passed to user code."""
  def __init__(self) -> None:
    # Latest set/clear request recorded per dynamic timer tag.
    self._timer_recordings = {}  # type: Dict[str, _TimerTuple]
    self._cleared = False
    self._new_timestamp = None  # type: Optional[Timestamp]

  def clear(self, dynamic_timer_tag=''):
    # type: (str) -> None
    # Overwrites any earlier recording for the same tag.
    self._timer_recordings[dynamic_timer_tag] = _TimerTuple(
        cleared=True, timestamp=None)

  def set(self, timestamp, dynamic_timer_tag=''):
    # type: (Timestamp, str) -> None
    # Overwrites any earlier recording for the same tag.
    self._timer_recordings[dynamic_timer_tag] = _TimerTuple(
        cleared=False, timestamp=timestamp)


class RuntimeState(object):
  """State interface object passed to user code."""
  def prefetch(self):
    # type: () -> None
    # The default implementation here does nothing.
    pass

  def finalize(self):
    # type: () -> None
    # The default implementation here does nothing.
    pass


class ReadModifyWriteRuntimeState(RuntimeState):
  """Read/write/clear/commit interface; each raises NotImplementedError."""
  def read(self):
    # type: () -> Any
    raise NotImplementedError(type(self))

  def write(self, value):
    # type: (Any) -> None
    raise NotImplementedError(type(self))

  def clear(self):
    # type: () -> None
    raise NotImplementedError(type(self))

  def commit(self):
    # type: () -> None
    raise NotImplementedError(type(self))


class AccumulatingRuntimeState(RuntimeState):
  """Read/add/clear/commit interface; each raises NotImplementedError."""
  def read(self):
    # type: () -> Iterable[Any]
    raise NotImplementedError(type(self))

  def add(self, value):
    # type: (Any) -> None
    raise NotImplementedError(type(self))

  def clear(self):
    # type: () -> None
    raise NotImplementedError(type(self))

  def commit(self):
    # type: () -> None
    raise NotImplementedError(type(self))


class BagRuntimeState(AccumulatingRuntimeState):
  """Bag state interface object passed to user code."""


class SetRuntimeState(AccumulatingRuntimeState):
  """Set state interface object passed to user code."""


class CombiningValueRuntimeState(AccumulatingRuntimeState):
  """Combining value state interface object passed to user code."""


class UserStateContext(object):
  """Wrapper allowing user state and timers to be accessed by a DoFnInvoker."""
  def get_timer(self,
                timer_spec,  # type: TimerSpec
                key,  # type: Any
                window,  # type: windowed_value.BoundedWindow
                timestamp,  # type: Timestamp
                pane,  # type: windowed_value.PaneInfo
               ):
    # type: (...) -> BaseTimer
    raise NotImplementedError(type(self))

  def get_state(self,
                state_spec,  # type: StateSpec
                key,  # type: Any
                window,  # type: windowed_value.BoundedWindow
               ):
    # type: (...) -> RuntimeState
    raise NotImplementedError(type(self))

  def commit(self):
    # type: () -> None
    raise NotImplementedError(type(self))