github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/worker/sdk_worker.py (about) 1 # 2 # Licensed to the Apache Software Foundation (ASF) under one or more 3 # contributor license agreements. See the NOTICE file distributed with 4 # this work for additional information regarding copyright ownership. 5 # The ASF licenses this file to You under the Apache License, Version 2.0 6 # (the "License"); you may not use this file except in compliance with 7 # the License. You may obtain a copy of the License at 8 # 9 # http://www.apache.org/licenses/LICENSE-2.0 10 # 11 # Unless required by applicable law or agreed to in writing, software 12 # distributed under the License is distributed on an "AS IS" BASIS, 13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 # See the License for the specific language governing permissions and 15 # limitations under the License. 16 # 17 18 """SDK harness for executing Python Fns via the Fn API.""" 19 20 # pytype: skip-file 21 # mypy: disallow-untyped-defs 22 23 import abc 24 import collections 25 import contextlib 26 import functools 27 import json 28 import logging 29 import queue 30 import sys 31 import threading 32 import time 33 import traceback 34 from concurrent import futures 35 from typing import TYPE_CHECKING 36 from typing import Any 37 from typing import Callable 38 from typing import DefaultDict 39 from typing import Dict 40 from typing import FrozenSet 41 from typing import Generic 42 from typing import Iterable 43 from typing import Iterator 44 from typing import List 45 from typing import MutableMapping 46 from typing import Optional 47 from typing import Tuple 48 from typing import TypeVar 49 from typing import Union 50 51 import grpc 52 53 from apache_beam.coders import coder_impl 54 from apache_beam.metrics import monitoring_infos 55 from apache_beam.metrics.execution import MetricsEnvironment 56 from apache_beam.portability.api import beam_fn_api_pb2 57 from apache_beam.portability.api import beam_fn_api_pb2_grpc 58 from apache_beam.portability.api import metrics_pb2 59 from apache_beam.runners.worker import bundle_processor 60 from apache_beam.runners.worker import data_plane 61 from apache_beam.runners.worker import data_sampler 62 from apache_beam.runners.worker import statesampler 63 from apache_beam.runners.worker.channel_factory import GRPCChannelFactory 64 from apache_beam.runners.worker.data_plane import PeriodicThread 65 from apache_beam.runners.worker.statecache import CacheAware 66 from apache_beam.runners.worker.statecache import StateCache 67 from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor 68 from apache_beam.runners.worker.worker_status import FnApiWorkerStatusHandler 69 from apache_beam.utils import thread_pool_executor 70 from apache_beam.utils.sentinel import Sentinel 71 72 if TYPE_CHECKING: 73 from apache_beam.portability.api import endpoints_pb2 74 from apache_beam.utils.profiler import Profile 75 76 T = TypeVar('T') 77 _KT = TypeVar('_KT') 78 _VT = TypeVar('_VT') 79 80 _LOGGER = logging.getLogger(__name__) 81 82 DEFAULT_BUNDLE_PROCESSOR_CACHE_SHUTDOWN_THRESHOLD_S = 60 83 # The number of ProcessBundleRequest instruction ids the BundleProcessorCache 84 # will remember for not running instructions. 85 MAX_KNOWN_NOT_RUNNING_INSTRUCTIONS = 1000 86 # The number of ProcessBundleRequest instruction ids that BundleProcessorCache 87 # will remember for failed instructions. 88 MAX_FAILED_INSTRUCTIONS = 10000 89 90 # retry on transient UNAVAILABLE grpc error from state channels. 91 _GRPC_SERVICE_CONFIG = json.dumps({ 92 "methodConfig": [{ 93 "name": [{ 94 "service": "org.apache.beam.model.fn_execution.v1.BeamFnState" 95 }], 96 "retryPolicy": { 97 "maxAttempts": 5, 98 "initialBackoff": "0.1s", 99 "maxBackoff": "5s", 100 "backoffMultiplier": 2, 101 "retryableStatusCodes": ["UNAVAILABLE"], 102 }, 103 }] 104 }) 105 106 107 class ShortIdCache(object): 108 """ Cache for MonitoringInfo "short ids" 109 """ 110 def __init__(self): 111 # type: () -> None 112 self._lock = threading.Lock() 113 self._last_short_id = 0 114 self._info_key_to_short_id = {} # type: Dict[FrozenSet, str] 115 self._short_id_to_info = {} # type: Dict[str, metrics_pb2.MonitoringInfo] 116 117 def get_short_id(self, monitoring_info): 118 # type: (metrics_pb2.MonitoringInfo) -> str 119 120 """ Returns the assigned shortId for a given MonitoringInfo, assigns one if 121 not assigned already. 122 """ 123 key = monitoring_infos.to_key(monitoring_info) 124 with self._lock: 125 try: 126 return self._info_key_to_short_id[key] 127 except KeyError: 128 self._last_short_id += 1 129 130 # Convert to a hex string (and drop the '0x') for some compression 131 shortId = hex(self._last_short_id)[2:] 132 133 payload_cleared = metrics_pb2.MonitoringInfo() 134 payload_cleared.CopyFrom(monitoring_info) 135 payload_cleared.ClearField('payload') 136 137 self._info_key_to_short_id[key] = shortId 138 self._short_id_to_info[shortId] = payload_cleared 139 return shortId 140 141 def get_infos(self, short_ids): 142 #type: (Iterable[str]) -> Dict[str, metrics_pb2.MonitoringInfo] 143 144 """ Gets the base MonitoringInfo (with payload cleared) for each short ID. 145 146 Throws KeyError if an unassigned short ID is encountered. 147 """ 148 return { 149 short_id: self._short_id_to_info[short_id] 150 for short_id in short_ids 151 } 152 153 154 SHORT_ID_CACHE = ShortIdCache() 155 156 157 class SdkHarness(object): 158 REQUEST_METHOD_PREFIX = '_request_' 159 160 def __init__( 161 self, 162 control_address, # type: str 163 credentials=None, # type: Optional[grpc.ChannelCredentials] 164 worker_id=None, # type: Optional[str] 165 # Caching is disabled by default 166 state_cache_size=0, # type: int 167 # time-based data buffering is disabled by default 168 data_buffer_time_limit_ms=0, # type: int 169 profiler_factory=None, # type: Optional[Callable[..., Profile]] 170 status_address=None, # type: Optional[str] 171 # Heap dump through status api is disabled by default 172 enable_heap_dump=False, # type: bool 173 data_sampler=None, # type: Optional[data_sampler.DataSampler] 174 # Unrecoverable SDK harness initialization error (if any) 175 # that should be reported to the runner when proocessing the first bundle. 176 deferred_exception=None, # type: Optional[Exception] 177 ): 178 # type: (...) -> None 179 self._alive = True 180 self._worker_index = 0 181 self._worker_id = worker_id 182 self._state_cache = StateCache(state_cache_size) 183 self._deferred_exception = deferred_exception 184 options = [('grpc.max_receive_message_length', -1), 185 ('grpc.max_send_message_length', -1)] 186 if credentials is None: 187 _LOGGER.info('Creating insecure control channel for %s.', control_address) 188 self._control_channel = GRPCChannelFactory.insecure_channel( 189 control_address, options=options) 190 else: 191 _LOGGER.info('Creating secure control channel for %s.', control_address) 192 self._control_channel = GRPCChannelFactory.secure_channel( 193 control_address, credentials, options=options) 194 grpc.channel_ready_future(self._control_channel).result(timeout=60) 195 _LOGGER.info('Control channel established.') 196 197 self._control_channel = grpc.intercept_channel( 198 self._control_channel, WorkerIdInterceptor(self._worker_id)) 199 self._data_channel_factory = data_plane.GrpcClientDataChannelFactory( 200 credentials, self._worker_id, data_buffer_time_limit_ms) 201 self._state_handler_factory = GrpcStateHandlerFactory( 202 self._state_cache, credentials) 203 self._profiler_factory = profiler_factory 204 self.data_sampler = data_sampler 205 206 def default_factory(id): 207 # type: (str) -> beam_fn_api_pb2.ProcessBundleDescriptor 208 return self._control_stub.GetProcessBundleDescriptor( 209 beam_fn_api_pb2.GetProcessBundleDescriptorRequest( 210 process_bundle_descriptor_id=id)) 211 212 self._fns = KeyedDefaultDict(default_factory) 213 # BundleProcessor cache across all workers. 214 self._bundle_processor_cache = BundleProcessorCache( 215 state_handler_factory=self._state_handler_factory, 216 data_channel_factory=self._data_channel_factory, 217 fns=self._fns, 218 data_sampler=self.data_sampler, 219 ) 220 221 if status_address: 222 try: 223 self._status_handler = FnApiWorkerStatusHandler( 224 status_address, 225 self._bundle_processor_cache, 226 self._state_cache, 227 enable_heap_dump) # type: Optional[FnApiWorkerStatusHandler] 228 except Exception: 229 traceback_string = traceback.format_exc() 230 _LOGGER.warning( 231 'Error creating worker status request handler, ' 232 'skipping status report. Trace back: %s' % traceback_string) 233 else: 234 self._status_handler = None 235 236 # TODO(BEAM-8998) use common 237 # thread_pool_executor.shared_unbounded_instance() to process bundle 238 # progress once dataflow runner's excessive progress polling is removed. 239 self._report_progress_executor = futures.ThreadPoolExecutor(max_workers=1) 240 self._worker_thread_pool = thread_pool_executor.shared_unbounded_instance() 241 self._responses = queue.Queue( 242 ) # type: queue.Queue[Union[beam_fn_api_pb2.InstructionResponse, Sentinel]] 243 _LOGGER.info('Initializing SDKHarness with unbounded number of workers.') 244 245 def run(self): 246 # type: () -> None 247 self._control_stub = beam_fn_api_pb2_grpc.BeamFnControlStub( 248 self._control_channel) 249 no_more_work = Sentinel.sentinel 250 251 def get_responses(): 252 # type: () -> Iterator[beam_fn_api_pb2.InstructionResponse] 253 while True: 254 response = self._responses.get() 255 if response is no_more_work: 256 return 257 yield response 258 259 self._alive = True 260 261 try: 262 for work_request in self._control_stub.Control(get_responses()): 263 _LOGGER.debug('Got work %s', work_request.instruction_id) 264 request_type = work_request.WhichOneof('request') 265 # Name spacing the request method with 'request_'. The called method 266 # will be like self.request_register(request) 267 getattr(self, SdkHarness.REQUEST_METHOD_PREFIX + request_type)( 268 work_request) 269 finally: 270 self._alive = False 271 272 _LOGGER.info('No more requests from control plane') 273 _LOGGER.info('SDK Harness waiting for in-flight requests to complete') 274 # Wait until existing requests are processed. 275 self._worker_thread_pool.shutdown() 276 # get_responses may be blocked on responses.get(), but we need to return 277 # control to its caller. 278 self._responses.put(no_more_work) 279 # Stop all the workers and clean all the associated resources 280 self._data_channel_factory.close() 281 self._state_handler_factory.close() 282 self._bundle_processor_cache.shutdown() 283 if self._status_handler: 284 self._status_handler.close() 285 _LOGGER.info('Done consuming work.') 286 287 def _execute( 288 self, 289 task, # type: Callable[[], beam_fn_api_pb2.InstructionResponse] 290 request # type: beam_fn_api_pb2.InstructionRequest 291 ): 292 # type: (...) -> None 293 with statesampler.instruction_id(request.instruction_id): 294 try: 295 response = task() 296 except: # pylint: disable=bare-except 297 traceback_string = traceback.format_exc() 298 print(traceback_string, file=sys.stderr) 299 _LOGGER.error( 300 'Error processing instruction %s. Original traceback is\n%s\n', 301 request.instruction_id, 302 traceback_string) 303 response = beam_fn_api_pb2.InstructionResponse( 304 instruction_id=request.instruction_id, error=traceback_string) 305 self._responses.put(response) 306 307 def _request_register(self, request): 308 # type: (beam_fn_api_pb2.InstructionRequest) -> None 309 # registration request is handled synchronously 310 self._execute(lambda: self.create_worker().do_instruction(request), request) 311 312 def _request_process_bundle(self, request): 313 # type: (beam_fn_api_pb2.InstructionRequest) -> None 314 if self._deferred_exception: 315 raise self._deferred_exception 316 self._bundle_processor_cache.activate(request.instruction_id) 317 self._request_execute(request) 318 319 def _request_process_bundle_split(self, request): 320 # type: (beam_fn_api_pb2.InstructionRequest) -> None 321 self._request_process_bundle_action(request) 322 323 def _request_process_bundle_progress(self, request): 324 # type: (beam_fn_api_pb2.InstructionRequest) -> None 325 self._request_process_bundle_action(request) 326 327 def _request_process_bundle_action(self, request): 328 # type: (beam_fn_api_pb2.InstructionRequest) -> None 329 def task(): 330 # type: () -> None 331 self._execute( 332 lambda: self.create_worker().do_instruction(request), request) 333 334 self._report_progress_executor.submit(task) 335 336 def _request_finalize_bundle(self, request): 337 # type: (beam_fn_api_pb2.InstructionRequest) -> None 338 self._request_execute(request) 339 340 def _request_harness_monitoring_infos(self, request): 341 # type: (beam_fn_api_pb2.InstructionRequest) -> None 342 process_wide_monitoring_infos = MetricsEnvironment.process_wide_container( 343 ).to_runner_api_monitoring_infos(None).values() 344 self._execute( 345 lambda: beam_fn_api_pb2.InstructionResponse( 346 instruction_id=request.instruction_id, 347 harness_monitoring_infos=( 348 beam_fn_api_pb2.HarnessMonitoringInfosResponse( 349 monitoring_data={ 350 SHORT_ID_CACHE.get_short_id(info): info.payload 351 for info in process_wide_monitoring_infos 352 }))), 353 request) 354 355 def _request_monitoring_infos(self, request): 356 # type: (beam_fn_api_pb2.InstructionRequest) -> None 357 self._execute( 358 lambda: beam_fn_api_pb2.InstructionResponse( 359 instruction_id=request.instruction_id, 360 monitoring_infos=beam_fn_api_pb2.MonitoringInfosMetadataResponse( 361 monitoring_info=SHORT_ID_CACHE.get_infos( 362 request.monitoring_infos.monitoring_info_id))), 363 request) 364 365 def _request_execute(self, request): 366 # type: (beam_fn_api_pb2.InstructionRequest) -> None 367 def task(): 368 # type: () -> None 369 self._execute( 370 lambda: self.create_worker().do_instruction(request), request) 371 372 self._worker_thread_pool.submit(task) 373 _LOGGER.debug( 374 "Currently using %s threads." % len(self._worker_thread_pool._workers)) 375 376 def _request_sample_data(self, request): 377 # type: (beam_fn_api_pb2.InstructionRequest) -> None 378 379 def get_samples(request): 380 # type: (beam_fn_api_pb2.InstructionRequest) -> beam_fn_api_pb2.InstructionResponse 381 samples: Dict[str, List[bytes]] = {} 382 if self.data_sampler: 383 samples = self.data_sampler.samples(request.sample_data.pcollection_ids) 384 385 sample_response = beam_fn_api_pb2.SampleDataResponse() 386 for pcoll_id in samples: 387 sample_response.element_samples[pcoll_id].elements.extend( 388 beam_fn_api_pb2.SampledElement(element=s) 389 for s in samples[pcoll_id]) 390 391 return beam_fn_api_pb2.InstructionResponse( 392 instruction_id=request.instruction_id, sample_data=sample_response) 393 394 self._execute(lambda: get_samples(request), request) 395 396 def create_worker(self): 397 # type: () -> SdkWorker 398 return SdkWorker( 399 self._bundle_processor_cache, profiler_factory=self._profiler_factory) 400 401 402 class BundleProcessorCache(object): 403 """A cache for ``BundleProcessor``s. 404 405 ``BundleProcessor`` objects are cached by the id of their 406 ``beam_fn_api_pb2.ProcessBundleDescriptor``. 407 408 Attributes: 409 fns (dict): A dictionary that maps bundle descriptor IDs to instances of 410 ``beam_fn_api_pb2.ProcessBundleDescriptor``. 411 state_handler_factory (``StateHandlerFactory``): Used to create state 412 handlers to be used by a ``bundle_processor.BundleProcessor`` during 413 processing. 414 data_channel_factory (``data_plane.DataChannelFactory``) 415 active_bundle_processors (dict): A dictionary, indexed by instruction IDs, 416 containing ``bundle_processor.BundleProcessor`` objects that are currently 417 active processing the corresponding instruction. 418 cached_bundle_processors (dict): A dictionary, indexed by bundle processor 419 id, of cached ``bundle_processor.BundleProcessor`` that are not currently 420 performing processing. 421 """ 422 periodic_shutdown = None # type: Optional[PeriodicThread] 423 424 def __init__( 425 self, 426 state_handler_factory, # type: StateHandlerFactory 427 data_channel_factory, # type: data_plane.DataChannelFactory 428 fns, # type: MutableMapping[str, beam_fn_api_pb2.ProcessBundleDescriptor] 429 data_sampler=None, # type: Optional[data_sampler.DataSampler] 430 ): 431 # type: (...) -> None 432 self.fns = fns 433 self.state_handler_factory = state_handler_factory 434 self.data_channel_factory = data_channel_factory 435 self.known_not_running_instruction_ids = collections.OrderedDict( 436 ) # type: collections.OrderedDict[str, bool] 437 self.failed_instruction_ids = collections.OrderedDict( 438 ) # type: collections.OrderedDict[str, bool] 439 self.active_bundle_processors = { 440 } # type: Dict[str, Tuple[str, bundle_processor.BundleProcessor]] 441 self.cached_bundle_processors = collections.defaultdict( 442 list) # type: DefaultDict[str, List[bundle_processor.BundleProcessor]] 443 self.last_access_times = collections.defaultdict( 444 float) # type: DefaultDict[str, float] 445 self._schedule_periodic_shutdown() 446 self._lock = threading.Lock() 447 self.data_sampler = data_sampler 448 449 def register(self, bundle_descriptor): 450 # type: (beam_fn_api_pb2.ProcessBundleDescriptor) -> None 451 452 """Register a ``beam_fn_api_pb2.ProcessBundleDescriptor`` by its id.""" 453 self.fns[bundle_descriptor.id] = bundle_descriptor 454 455 def activate(self, instruction_id): 456 # type: (str) -> None 457 458 """Makes the ``instruction_id`` known to the bundle processor. 459 460 Allows ``lookup`` to return ``None``. Necessary if ``lookup`` can occur 461 before ``get``. 462 """ 463 with self._lock: 464 self.known_not_running_instruction_ids[instruction_id] = True 465 466 def get(self, instruction_id, bundle_descriptor_id): 467 # type: (str, str) -> bundle_processor.BundleProcessor 468 469 """ 470 Return the requested ``BundleProcessor``, creating it if necessary. 471 472 Moves the ``BundleProcessor`` from the inactive to the active cache. 473 """ 474 with self._lock: 475 try: 476 # pop() is threadsafe 477 processor = self.cached_bundle_processors[bundle_descriptor_id].pop() 478 self.active_bundle_processors[ 479 instruction_id] = bundle_descriptor_id, processor 480 try: 481 del self.known_not_running_instruction_ids[instruction_id] 482 except KeyError: 483 # The instruction may have not been pre-registered before execution 484 # since activate() may have never been invoked 485 pass 486 return processor 487 except IndexError: 488 pass 489 490 # Make sure we instantiate the processor while not holding the lock. 491 processor = bundle_processor.BundleProcessor( 492 self.fns[bundle_descriptor_id], 493 self.state_handler_factory.create_state_handler( 494 self.fns[bundle_descriptor_id].state_api_service_descriptor), 495 self.data_channel_factory, 496 self.data_sampler) 497 with self._lock: 498 self.active_bundle_processors[ 499 instruction_id] = bundle_descriptor_id, processor 500 try: 501 del self.known_not_running_instruction_ids[instruction_id] 502 except KeyError: 503 # The instruction may have not been pre-registered before execution 504 # since activate() may have never been invoked 505 pass 506 return processor 507 508 def lookup(self, instruction_id): 509 # type: (str) -> Optional[bundle_processor.BundleProcessor] 510 511 """ 512 Return the requested ``BundleProcessor`` from the cache. 513 514 Will return ``None`` if the BundleProcessor is known but not yet ready. Will 515 raise an error if the ``instruction_id`` is not known or has been discarded. 516 """ 517 with self._lock: 518 if instruction_id in self.failed_instruction_ids: 519 raise RuntimeError( 520 'Bundle processing associated with %s has failed. ' 521 'Check prior failing response for details.' % instruction_id) 522 processor = self.active_bundle_processors.get( 523 instruction_id, (None, None))[-1] 524 if processor: 525 return processor 526 if instruction_id in self.known_not_running_instruction_ids: 527 return None 528 raise RuntimeError('Unknown process bundle id %s.' % instruction_id) 529 530 def discard(self, instruction_id): 531 # type: (str) -> None 532 533 """ 534 Marks the instruction id as failed shutting down the ``BundleProcessor``. 535 """ 536 with self._lock: 537 self.failed_instruction_ids[instruction_id] = True 538 while len(self.failed_instruction_ids) > MAX_FAILED_INSTRUCTIONS: 539 self.failed_instruction_ids.popitem(last=False) 540 processor = self.active_bundle_processors[instruction_id][1] 541 del self.active_bundle_processors[instruction_id] 542 543 # Perform the shutdown while not holding the lock. 544 processor.shutdown() 545 546 def release(self, instruction_id): 547 # type: (str) -> None 548 549 """ 550 Release the requested ``BundleProcessor``. 551 552 Resets the ``BundleProcessor`` and moves it from the active to the 553 inactive cache. 554 """ 555 with self._lock: 556 self.known_not_running_instruction_ids[instruction_id] = True 557 while len(self.known_not_running_instruction_ids 558 ) > MAX_KNOWN_NOT_RUNNING_INSTRUCTIONS: 559 self.known_not_running_instruction_ids.popitem(last=False) 560 descriptor_id, processor = ( 561 self.active_bundle_processors.pop(instruction_id)) 562 563 # Make sure that we reset the processor while not holding the lock. 564 processor.reset() 565 with self._lock: 566 self.last_access_times[descriptor_id] = time.time() 567 self.cached_bundle_processors[descriptor_id].append(processor) 568 569 def shutdown(self): 570 # type: () -> None 571 572 """ 573 Shutdown all ``BundleProcessor``s in the cache. 574 """ 575 if self.periodic_shutdown: 576 self.periodic_shutdown.cancel() 577 self.periodic_shutdown.join() 578 self.periodic_shutdown = None 579 580 for instruction_id in list(self.active_bundle_processors.keys()): 581 self.discard(instruction_id) 582 for cached_bundle_processors in self.cached_bundle_processors.values(): 583 BundleProcessorCache._shutdown_cached_bundle_processors( 584 cached_bundle_processors) 585 586 def _schedule_periodic_shutdown(self): 587 # type: () -> None 588 def shutdown_inactive_bundle_processors(): 589 # type: () -> None 590 for descriptor_id, last_access_time in self.last_access_times.items(): 591 if (time.time() - last_access_time > 592 DEFAULT_BUNDLE_PROCESSOR_CACHE_SHUTDOWN_THRESHOLD_S): 593 BundleProcessorCache._shutdown_cached_bundle_processors( 594 self.cached_bundle_processors[descriptor_id]) 595 596 self.periodic_shutdown = PeriodicThread( 597 DEFAULT_BUNDLE_PROCESSOR_CACHE_SHUTDOWN_THRESHOLD_S, 598 shutdown_inactive_bundle_processors) 599 self.periodic_shutdown.daemon = True 600 self.periodic_shutdown.start() 601 602 @staticmethod 603 def _shutdown_cached_bundle_processors(cached_bundle_processors): 604 # type: (List[bundle_processor.BundleProcessor]) -> None 605 try: 606 while True: 607 # pop() is threadsafe 608 bundle_processor = cached_bundle_processors.pop() 609 bundle_processor.shutdown() 610 except IndexError: 611 pass 612 613 614 class SdkWorker(object): 615 def __init__( 616 self, 617 bundle_processor_cache, # type: BundleProcessorCache 618 profiler_factory=None, # type: Optional[Callable[..., Profile]] 619 ): 620 # type: (...) -> None 621 self.bundle_processor_cache = bundle_processor_cache 622 self.profiler_factory = profiler_factory 623 624 def do_instruction(self, request): 625 # type: (beam_fn_api_pb2.InstructionRequest) -> beam_fn_api_pb2.InstructionResponse 626 request_type = request.WhichOneof('request') 627 if request_type: 628 # E.g. if register is set, this will call self.register(request.register)) 629 return getattr(self, request_type)( 630 getattr(request, request_type), request.instruction_id) 631 else: 632 raise NotImplementedError 633 634 def register( 635 self, 636 request, # type: beam_fn_api_pb2.RegisterRequest 637 instruction_id # type: str 638 ): 639 # type: (...) -> beam_fn_api_pb2.InstructionResponse 640 641 """Registers a set of ``beam_fn_api_pb2.ProcessBundleDescriptor``s. 642 643 This set of ``beam_fn_api_pb2.ProcessBundleDescriptor`` come as part of a 644 ``beam_fn_api_pb2.RegisterRequest``, which the runner sends to the SDK 645 worker before starting processing to register stages. 646 """ 647 648 for process_bundle_descriptor in request.process_bundle_descriptor: 649 self.bundle_processor_cache.register(process_bundle_descriptor) 650 return beam_fn_api_pb2.InstructionResponse( 651 instruction_id=instruction_id, 652 register=beam_fn_api_pb2.RegisterResponse()) 653 654 def process_bundle( 655 self, 656 request, # type: beam_fn_api_pb2.ProcessBundleRequest 657 instruction_id # type: str 658 ): 659 # type: (...) -> beam_fn_api_pb2.InstructionResponse 660 bundle_processor = self.bundle_processor_cache.get( 661 instruction_id, request.process_bundle_descriptor_id) 662 try: 663 with bundle_processor.state_handler.process_instruction_id( 664 instruction_id, request.cache_tokens): 665 with self.maybe_profile(instruction_id): 666 delayed_applications, requests_finalization = ( 667 bundle_processor.process_bundle(instruction_id)) 668 monitoring_infos = bundle_processor.monitoring_infos() 669 response = beam_fn_api_pb2.InstructionResponse( 670 instruction_id=instruction_id, 671 process_bundle=beam_fn_api_pb2.ProcessBundleResponse( 672 residual_roots=delayed_applications, 673 monitoring_infos=monitoring_infos, 674 monitoring_data={ 675 SHORT_ID_CACHE.get_short_id(info): info.payload 676 for info in monitoring_infos 677 }, 678 requires_finalization=requests_finalization)) 679 # Don't release here if finalize is needed. 680 if not requests_finalization: 681 self.bundle_processor_cache.release(instruction_id) 682 return response 683 except: # pylint: disable=bare-except 684 # Don't re-use bundle processors on failure. 685 self.bundle_processor_cache.discard(instruction_id) 686 raise 687 688 def process_bundle_split( 689 self, 690 request, # type: beam_fn_api_pb2.ProcessBundleSplitRequest 691 instruction_id # type: str 692 ): 693 # type: (...) -> beam_fn_api_pb2.InstructionResponse 694 try: 695 processor = self.bundle_processor_cache.lookup(request.instruction_id) 696 except RuntimeError: 697 return beam_fn_api_pb2.InstructionResponse( 698 instruction_id=instruction_id, error=traceback.format_exc()) 699 # Return an empty response if we aren't running. This can happen 700 # if the ProcessBundleRequest has not started or already finished. 701 process_bundle_split = ( 702 processor.try_split(request) 703 if processor else beam_fn_api_pb2.ProcessBundleSplitResponse()) 704 return beam_fn_api_pb2.InstructionResponse( 705 instruction_id=instruction_id, 706 process_bundle_split=process_bundle_split) 707 708 def process_bundle_progress( 709 self, 710 request, # type: beam_fn_api_pb2.ProcessBundleProgressRequest 711 instruction_id # type: str 712 ): 713 # type: (...) -> beam_fn_api_pb2.InstructionResponse 714 try: 715 processor = self.bundle_processor_cache.lookup(request.instruction_id) 716 except RuntimeError: 717 return beam_fn_api_pb2.InstructionResponse( 718 instruction_id=instruction_id, error=traceback.format_exc()) 719 if processor: 720 monitoring_infos = processor.monitoring_infos() 721 else: 722 # Return an empty response if we aren't running. This can happen 723 # if the ProcessBundleRequest has not started or already finished. 724 monitoring_infos = [] 725 return beam_fn_api_pb2.InstructionResponse( 726 instruction_id=instruction_id, 727 process_bundle_progress=beam_fn_api_pb2.ProcessBundleProgressResponse( 728 monitoring_infos=monitoring_infos, 729 monitoring_data={ 730 SHORT_ID_CACHE.get_short_id(info): info.payload 731 for info in monitoring_infos 732 })) 733 734 def finalize_bundle( 735 self, 736 request, # type: beam_fn_api_pb2.FinalizeBundleRequest 737 instruction_id # type: str 738 ): 739 # type: (...) -> beam_fn_api_pb2.InstructionResponse 740 try: 741 processor = self.bundle_processor_cache.lookup(request.instruction_id) 742 except RuntimeError: 743 return beam_fn_api_pb2.InstructionResponse( 744 instruction_id=instruction_id, error=traceback.format_exc()) 745 if processor: 746 try: 747 finalize_response = processor.finalize_bundle() 748 self.bundle_processor_cache.release(request.instruction_id) 749 return beam_fn_api_pb2.InstructionResponse( 750 instruction_id=instruction_id, finalize_bundle=finalize_response) 751 except: 752 self.bundle_processor_cache.discard(request.instruction_id) 753 raise 754 # We can reach this state if there was an erroneous request to finalize 755 # the bundle while it is being initialized or has already been finalized 756 # and released. 757 raise RuntimeError( 758 'Bundle is not in a finalizable state for %s' % instruction_id) 759 760 @contextlib.contextmanager 761 def maybe_profile(self, instruction_id): 762 # type: (str) -> Iterator[None] 763 if self.profiler_factory: 764 profiler = self.profiler_factory(instruction_id) 765 if profiler: 766 with profiler: 767 yield 768 else: 769 yield 770 else: 771 yield 772 773 774 class StateHandler(metaclass=abc.ABCMeta): 775 """An abstract object representing a ``StateHandler``.""" 776 @abc.abstractmethod 777 def get_raw( 778 self, 779 state_key, # type: beam_fn_api_pb2.StateKey 780 continuation_token=None # type: Optional[bytes] 781 ): 782 # type: (...) -> Tuple[bytes, Optional[bytes]] 783 784 """Gets the contents of state for the given state key. 785 786 State is associated to a state key, AND an instruction_id, which is set 787 when calling process_instruction_id. 788 789 Returns a tuple with the contents in state, and an optional continuation 790 token, which is used to page the API. 791 """ 792 raise NotImplementedError(type(self)) 793 794 @abc.abstractmethod 795 def append_raw( 796 self, 797 state_key, # type: beam_fn_api_pb2.StateKey 798 data # type: bytes 799 ): 800 # type: (...) -> _Future 801 802 """Append the input data into the state key. 803 804 Returns a future that allows one to wait for the completion of the call. 805 806 State is associated to a state key, AND an instruction_id, which is set 807 when calling process_instruction_id. 808 """ 809 raise NotImplementedError(type(self)) 810 811 @abc.abstractmethod 812 def clear(self, state_key): 813 # type: (beam_fn_api_pb2.StateKey) -> _Future 814 815 """Clears the contents of a cell for the input state key. 816 817 Returns a future that allows one to wait for the completion of the call. 818 819 State is associated to a state key, AND an instruction_id, which is set 820 when calling process_instruction_id. 821 """ 822 raise NotImplementedError(type(self)) 823 824 @abc.abstractmethod 825 @contextlib.contextmanager 826 def process_instruction_id(self, bundle_id): 827 # type: (str) -> Iterator[None] 828 829 """Switch the context of the state handler to a specific instruction. 830 831 This must be called before performing any write or read operations on the 832 existing state. 833 """ 834 raise NotImplementedError(type(self)) 835 836 @abc.abstractmethod 837 def done(self): 838 # type: () -> None 839 840 """Mark the state handler as done, and potentially delete all context.""" 841 raise NotImplementedError(type(self)) 842 843 844 class StateHandlerFactory(metaclass=abc.ABCMeta): 845 """An abstract factory for creating ``DataChannel``.""" 846 @abc.abstractmethod 847 def create_state_handler(self, api_service_descriptor): 848 # type: (endpoints_pb2.ApiServiceDescriptor) -> CachingStateHandler 849 850 """Returns a ``StateHandler`` from the given ApiServiceDescriptor.""" 851 raise NotImplementedError(type(self)) 852 853 @abc.abstractmethod 854 def close(self): 855 # type: () -> None 856 857 """Close all channels that this factory owns.""" 858 raise NotImplementedError(type(self)) 859 860 861 class GrpcStateHandlerFactory(StateHandlerFactory): 862 """A factory for ``GrpcStateHandler``. 863 864 Caches the created channels by ``state descriptor url``. 865 """ 866 def __init__(self, state_cache, credentials=None): 867 # type: (StateCache, Optional[grpc.ChannelCredentials]) -> None 868 self._state_handler_cache = {} # type: Dict[str, CachingStateHandler] 869 self._lock = threading.Lock() 870 self._throwing_state_handler = ThrowingStateHandler() 871 self._credentials = credentials 872 self._state_cache = state_cache 873 874 def create_state_handler(self, api_service_descriptor): 875 # type: (endpoints_pb2.ApiServiceDescriptor) -> CachingStateHandler 876 if not api_service_descriptor: 877 return self._throwing_state_handler 878 url = api_service_descriptor.url 879 if url not in self._state_handler_cache: 880 with self._lock: 881 if url not in self._state_handler_cache: 882 # Options to have no limits (-1) on the size of the messages 883 # received or sent over the data plane. The actual buffer size is 884 # controlled in a layer above. 885 options = [('grpc.max_receive_message_length', -1), 886 ('grpc.max_send_message_length', -1), 887 ('grpc.service_config', _GRPC_SERVICE_CONFIG)] 888 if self._credentials is None: 889 _LOGGER.info('Creating insecure state channel for %s.', url) 890 grpc_channel = GRPCChannelFactory.insecure_channel( 891 url, options=options) 892 else: 893 _LOGGER.info('Creating secure state channel for %s.', url) 894 grpc_channel = GRPCChannelFactory.secure_channel( 895 url, self._credentials, options=options) 896 _LOGGER.info('State channel established.') 897 # Add workerId to the grpc channel 898 grpc_channel = grpc.intercept_channel( 899 grpc_channel, WorkerIdInterceptor()) 900 self._state_handler_cache[url] = GlobalCachingStateHandler( 901 self._state_cache, 902 GrpcStateHandler( 903 beam_fn_api_pb2_grpc.BeamFnStateStub(grpc_channel))) 904 return self._state_handler_cache[url] 905 906 def close(self): 907 # type: () -> None 908 _LOGGER.info('Closing all cached gRPC state handlers.') 909 for _, state_handler in self._state_handler_cache.items(): 910 state_handler.done() 911 self._state_handler_cache.clear() 912 self._state_cache.invalidate_all() 913 914 915 class CachingStateHandler(metaclass=abc.ABCMeta): 916 @abc.abstractmethod 917 @contextlib.contextmanager 918 def process_instruction_id(self, bundle_id, cache_tokens): 919 # type: (str, Iterable[beam_fn_api_pb2.ProcessBundleRequest.CacheToken]) -> Iterator[None] 920 raise NotImplementedError(type(self)) 921 922 @abc.abstractmethod 923 def blocking_get( 924 self, 925 state_key, # type: beam_fn_api_pb2.StateKey 926 coder, # type: coder_impl.CoderImpl 927 ): 928 # type: (...) -> Iterable[Any] 929 raise NotImplementedError(type(self)) 930 931 @abc.abstractmethod 932 def extend( 933 self, 934 state_key, # type: beam_fn_api_pb2.StateKey 935 coder, # type: coder_impl.CoderImpl 936 elements, # type: Iterable[Any] 937 ): 938 # type: (...) -> _Future 939 raise NotImplementedError(type(self)) 940 941 @abc.abstractmethod 942 def clear(self, state_key): 943 # type: (beam_fn_api_pb2.StateKey) -> _Future 944 raise NotImplementedError(type(self)) 945 946 @abc.abstractmethod 947 def done(self): 948 # type: () -> None 949 raise NotImplementedError(type(self)) 950 951 952 class ThrowingStateHandler(CachingStateHandler): 953 """A caching state handler that errors on any requests.""" 954 @contextlib.contextmanager 955 def process_instruction_id(self, bundle_id, cache_tokens): 956 # type: (str, Iterable[beam_fn_api_pb2.ProcessBundleRequest.CacheToken]) -> Iterator[None] 957 raise RuntimeError( 958 'Unable to handle state requests for ProcessBundleDescriptor ' 959 'for bundle id %s.' % bundle_id) 960 961 def blocking_get( 962 self, 963 state_key, # type: beam_fn_api_pb2.StateKey 964 coder, # type: coder_impl.CoderImpl 965 ): 966 # type: (...) -> Iterable[Any] 967 raise RuntimeError( 968 'Unable to handle state requests for ProcessBundleDescriptor without ' 969 'state ApiServiceDescriptor for state key %s.' % state_key) 970 971 def extend( 972 self, 973 state_key, # type: beam_fn_api_pb2.StateKey 974 coder, # type: coder_impl.CoderImpl 975 elements, # type: Iterable[Any] 976 ): 977 # type: (...) -> _Future 978 raise RuntimeError( 979 'Unable to handle state requests for ProcessBundleDescriptor without ' 980 'state ApiServiceDescriptor for state key %s.' % state_key) 981 982 def clear(self, state_key): 983 # type: (beam_fn_api_pb2.StateKey) -> _Future 984 raise RuntimeError( 985 'Unable to handle state requests for ProcessBundleDescriptor without ' 986 'state ApiServiceDescriptor for state key %s.' % state_key) 987 988 def done(self): 989 # type: () -> None 990 raise RuntimeError( 991 'Unable to handle state requests for ProcessBundleDescriptor.') 992 993 994 class GrpcStateHandler(StateHandler): 995 996 _DONE = Sentinel.sentinel 997 998 def __init__(self, state_stub): 999 # type: (beam_fn_api_pb2_grpc.BeamFnStateStub) -> None 1000 self._lock = threading.Lock() 1001 self._state_stub = state_stub 1002 self._requests = queue.Queue( 1003 ) # type: queue.Queue[Union[beam_fn_api_pb2.StateRequest, Sentinel]] 1004 self._responses_by_id = {} # type: Dict[str, _Future] 1005 self._last_id = 0 1006 self._exception = None # type: Optional[Exception] 1007 self._context = threading.local() 1008 self.start() 1009 1010 @contextlib.contextmanager 1011 def process_instruction_id(self, bundle_id): 1012 # type: (str) -> Iterator[None] 1013 if getattr(self._context, 'process_instruction_id', None) is not None: 1014 raise RuntimeError( 1015 'Already bound to %r' % self._context.process_instruction_id) 1016 self._context.process_instruction_id = bundle_id 1017 try: 1018 yield 1019 finally: 1020 self._context.process_instruction_id = None 1021 1022 def start(self): 1023 # type: () -> None 1024 self._done = False 1025 1026 def request_iter(): 1027 # type: () -> Iterator[beam_fn_api_pb2.StateRequest] 1028 while True: 1029 request = self._requests.get() 1030 if request is self._DONE or self._done: 1031 break 1032 yield request 1033 1034 responses = self._state_stub.State(request_iter()) 1035 1036 def pull_responses(): 1037 # type: () -> None 1038 try: 1039 for response in responses: 1040 # Popping an item from a dictionary is atomic in cPython 1041 future = self._responses_by_id.pop(response.id) 1042 future.set(response) 1043 if self._done: 1044 break 1045 except Exception as e: 1046 self._exception = e 1047 raise 1048 1049 reader = threading.Thread(target=pull_responses, name='read_state') 1050 reader.daemon = True 1051 reader.start() 1052 1053 def done(self): 1054 # type: () -> None 1055 self._done = True 1056 self._requests.put(self._DONE) 1057 1058 def get_raw( 1059 self, 1060 state_key, # type: beam_fn_api_pb2.StateKey 1061 continuation_token=None # type: Optional[bytes] 1062 ): 1063 # type: (...) -> Tuple[bytes, Optional[bytes]] 1064 response = self._blocking_request( 1065 beam_fn_api_pb2.StateRequest( 1066 state_key=state_key, 1067 get=beam_fn_api_pb2.StateGetRequest( 1068 continuation_token=continuation_token))) 1069 return response.get.data, response.get.continuation_token 1070 1071 def append_raw( 1072 self, 1073 state_key, # type: Optional[beam_fn_api_pb2.StateKey] 1074 data # type: bytes 1075 ): 1076 # type: (...) -> _Future 1077 return self._request( 1078 beam_fn_api_pb2.StateRequest( 1079 state_key=state_key, 1080 append=beam_fn_api_pb2.StateAppendRequest(data=data))) 1081 1082 def clear(self, state_key): 1083 # type: (Optional[beam_fn_api_pb2.StateKey]) -> _Future 1084 return self._request( 1085 beam_fn_api_pb2.StateRequest( 1086 state_key=state_key, clear=beam_fn_api_pb2.StateClearRequest())) 1087 1088 def _request(self, request): 1089 # type: (beam_fn_api_pb2.StateRequest) -> _Future[beam_fn_api_pb2.StateResponse] 1090 request.id = self._next_id() 1091 request.instruction_id = self._context.process_instruction_id 1092 # Adding a new item to a dictionary is atomic in cPython 1093 self._responses_by_id[request.id] = future = _Future[ 1094 beam_fn_api_pb2.StateResponse]() 1095 # Request queue is thread-safe 1096 self._requests.put(request) 1097 return future 1098 1099 def _blocking_request(self, request): 1100 # type: (beam_fn_api_pb2.StateRequest) -> beam_fn_api_pb2.StateResponse 1101 req_future = self._request(request) 1102 while not req_future.wait(timeout=1): 1103 if self._exception: 1104 raise self._exception 1105 elif self._done: 1106 raise RuntimeError() 1107 response = req_future.get() 1108 if response.error: 1109 raise RuntimeError(response.error) 1110 else: 1111 return response 1112 1113 def _next_id(self): 1114 # type: () -> str 1115 with self._lock: 1116 # Use a lock here because this GrpcStateHandler is shared across all 1117 # requests which have the same process bundle descriptor. State requests 1118 # can concurrently access this section if a Runner uses threads / workers 1119 # (aka "parallelism") to send data to this SdkHarness and its workers. 1120 self._last_id += 1 1121 request_id = self._last_id 1122 return str(request_id) 1123 1124 1125 class GlobalCachingStateHandler(CachingStateHandler): 1126 """ A State handler which retrieves and caches state. 1127 If caching is activated, caches across bundles using a supplied cache token. 1128 If activated but no cache token is supplied, caching is done at the bundle 1129 level. 1130 """ 1131 def __init__( 1132 self, 1133 global_state_cache, # type: StateCache 1134 underlying_state # type: StateHandler 1135 ): 1136 # type: (...) -> None 1137 self._underlying = underlying_state 1138 self._state_cache = global_state_cache 1139 self._context = threading.local() 1140 1141 @contextlib.contextmanager 1142 def process_instruction_id(self, bundle_id, cache_tokens): 1143 # type: (str, Iterable[beam_fn_api_pb2.ProcessBundleRequest.CacheToken]) -> Iterator[None] 1144 if getattr(self._context, 'user_state_cache_token', None) is not None: 1145 raise RuntimeError( 1146 'Cache tokens already set to %s' % 1147 self._context.user_state_cache_token) 1148 self._context.side_input_cache_tokens = {} 1149 user_state_cache_token = None 1150 for cache_token_struct in cache_tokens: 1151 if cache_token_struct.HasField("user_state"): 1152 # There should only be one user state token present 1153 assert not user_state_cache_token 1154 user_state_cache_token = cache_token_struct.token 1155 elif cache_token_struct.HasField("side_input"): 1156 self._context.side_input_cache_tokens[ 1157 cache_token_struct.side_input.transform_id, 1158 cache_token_struct.side_input. 1159 side_input_id] = cache_token_struct.token 1160 # TODO: Consider a two-level cache to avoid extra logic and locking 1161 # for items cached at the bundle level. 1162 self._context.bundle_cache_token = bundle_id 1163 try: 1164 self._context.user_state_cache_token = user_state_cache_token 1165 with self._underlying.process_instruction_id(bundle_id): 1166 yield 1167 finally: 1168 self._context.side_input_cache_tokens = {} 1169 self._context.user_state_cache_token = None 1170 self._context.bundle_cache_token = None 1171 1172 def blocking_get( 1173 self, 1174 state_key, # type: beam_fn_api_pb2.StateKey 1175 coder, # type: coder_impl.CoderImpl 1176 ): 1177 # type: (...) -> Iterable[Any] 1178 cache_token = self._get_cache_token(state_key) 1179 if not cache_token: 1180 # Cache disabled / no cache token. Can't do a lookup/store in the cache. 1181 # Fall back to lazily materializing the state, one element at a time. 1182 return self._lazy_iterator(state_key, coder) 1183 # Cache lookup 1184 cache_state_key = self._convert_to_cache_key(state_key) 1185 return self._state_cache.get( 1186 (cache_state_key, cache_token), 1187 lambda key: self._partially_cached_iterable(state_key, coder)) 1188 1189 def extend( 1190 self, 1191 state_key, # type: beam_fn_api_pb2.StateKey 1192 coder, # type: coder_impl.CoderImpl 1193 elements, # type: Iterable[Any] 1194 ): 1195 # type: (...) -> _Future 1196 cache_token = self._get_cache_token(state_key) 1197 if cache_token: 1198 # Update the cache if the value is already present and 1199 # can be updated. 1200 cache_key = self._convert_to_cache_key(state_key) 1201 cached_value = self._state_cache.peek((cache_key, cache_token)) 1202 if isinstance(cached_value, list): 1203 # The state is fully cached and can be extended 1204 1205 # Materialize provided iterable to ensure reproducible iterations, 1206 # here and when writing to the state handler below. 1207 elements = list(elements) 1208 cached_value.extend(elements) 1209 # Re-insert into the cache the updated value so the updated size is 1210 # reflected. 1211 self._state_cache.put((cache_key, cache_token), cached_value) 1212 1213 # Write to state handler 1214 futures = [] 1215 out = coder_impl.create_OutputStream() 1216 for element in elements: 1217 coder.encode_to_stream(element, out, True) 1218 if out.size() > data_plane._DEFAULT_SIZE_FLUSH_THRESHOLD: 1219 futures.append(self._underlying.append_raw(state_key, out.get())) 1220 out = coder_impl.create_OutputStream() 1221 if out.size(): 1222 futures.append(self._underlying.append_raw(state_key, out.get())) 1223 return _DeferredCall( 1224 lambda *results: beam_fn_api_pb2.StateResponse( 1225 error='\n'.join( 1226 result.error for result in results if result and result.error), 1227 append=beam_fn_api_pb2.StateAppendResponse()), 1228 *futures) 1229 1230 def clear(self, state_key): 1231 # type: (beam_fn_api_pb2.StateKey) -> _Future 1232 cache_token = self._get_cache_token(state_key) 1233 if cache_token: 1234 cache_key = self._convert_to_cache_key(state_key) 1235 self._state_cache.put((cache_key, cache_token), []) 1236 return self._underlying.clear(state_key) 1237 1238 def done(self): 1239 # type: () -> None 1240 self._underlying.done() 1241 1242 def _lazy_iterator( 1243 self, 1244 state_key, # type: beam_fn_api_pb2.StateKey 1245 coder, # type: coder_impl.CoderImpl 1246 continuation_token=None # type: Optional[bytes] 1247 ): 1248 # type: (...) -> Iterator[Any] 1249 1250 """Materializes the state lazily, one element at a time. 1251 :return A generator which returns the next element if advanced. 1252 """ 1253 while True: 1254 data, continuation_token = ( 1255 self._underlying.get_raw(state_key, continuation_token)) 1256 input_stream = coder_impl.create_InputStream(data) 1257 while input_stream.size() > 0: 1258 yield coder.decode_from_stream(input_stream, True) 1259 if not continuation_token: 1260 break 1261 1262 def _get_cache_token(self, state_key): 1263 # type: (beam_fn_api_pb2.StateKey) -> Optional[bytes] 1264 if not self._state_cache.is_cache_enabled(): 1265 return None 1266 elif state_key.HasField('bag_user_state'): 1267 if self._context.user_state_cache_token: 1268 return self._context.user_state_cache_token 1269 else: 1270 return self._context.bundle_cache_token 1271 elif state_key.WhichOneof('type').endswith('_side_input'): 1272 side_input = getattr(state_key, state_key.WhichOneof('type')) 1273 return self._context.side_input_cache_tokens.get( 1274 (side_input.transform_id, side_input.side_input_id), 1275 self._context.bundle_cache_token) 1276 return None 1277 1278 def _partially_cached_iterable( 1279 self, 1280 state_key, # type: beam_fn_api_pb2.StateKey 1281 coder # type: coder_impl.CoderImpl 1282 ): 1283 # type: (...) -> Iterable[Any] 1284 1285 """Materialized the first page of data, concatenated with a lazy iterable 1286 of the rest, if any. 1287 """ 1288 data, continuation_token = self._underlying.get_raw(state_key, None) 1289 head = [] 1290 input_stream = coder_impl.create_InputStream(data) 1291 while input_stream.size() > 0: 1292 head.append(coder.decode_from_stream(input_stream, True)) 1293 1294 if not continuation_token: 1295 return head 1296 else: 1297 return self.ContinuationIterable( 1298 head, 1299 functools.partial( 1300 self._lazy_iterator, state_key, coder, continuation_token)) 1301 1302 class ContinuationIterable(Generic[T], CacheAware): 1303 def __init__(self, head, continue_iterator_fn): 1304 # type: (Iterable[T], Callable[[], Iterable[T]]) -> None 1305 self.head = head 1306 self.continue_iterator_fn = continue_iterator_fn 1307 1308 def __iter__(self): 1309 # type: () -> Iterator[T] 1310 for item in self.head: 1311 yield item 1312 for item in self.continue_iterator_fn(): 1313 yield item 1314 1315 def get_referents_for_cache(self): 1316 # type: () -> List[Any] 1317 # Only capture the size of the elements and not the 1318 # continuation iterator since it references objects 1319 # we don't want to include in the cache measurement. 1320 return [self.head] 1321 1322 @staticmethod 1323 def _convert_to_cache_key(state_key): 1324 # type: (beam_fn_api_pb2.StateKey) -> bytes 1325 return state_key.SerializeToString() 1326 1327 1328 class _Future(Generic[T]): 1329 """A simple future object to implement blocking requests. 1330 """ 1331 def __init__(self): 1332 # type: () -> None 1333 self._event = threading.Event() 1334 1335 def wait(self, timeout=None): 1336 # type: (Optional[float]) -> bool 1337 return self._event.wait(timeout) 1338 1339 def get(self, timeout=None): 1340 # type: (Optional[float]) -> T 1341 if self.wait(timeout): 1342 return self._value 1343 else: 1344 raise LookupError() 1345 1346 def set(self, value): 1347 # type: (T) -> _Future[T] 1348 self._value = value 1349 self._event.set() 1350 return self 1351 1352 @classmethod 1353 def done(cls): 1354 # type: () -> _Future[None] 1355 if not hasattr(cls, 'DONE'): 1356 done_future = _Future[None]() 1357 done_future.set(None) 1358 cls.DONE = done_future # type: ignore[attr-defined] 1359 return cls.DONE # type: ignore[attr-defined] 1360 1361 1362 class _DeferredCall(_Future[T]): 1363 def __init__(self, func, *args): 1364 # type: (Callable[..., Any], *Any) -> None 1365 self._func = func 1366 self._args = [ 1367 arg if isinstance(arg, _Future) else _Future().set(arg) for arg in args 1368 ] 1369 1370 def wait(self, timeout=None): 1371 # type: (Optional[float]) -> bool 1372 return all(arg.wait(timeout) for arg in self._args) 1373 1374 def get(self, timeout=None): 1375 # type: (Optional[float]) -> T 1376 return self._func(*(arg.get(timeout) for arg in self._args)) 1377 1378 def set(self, value): 1379 # type: (T) -> _Future[T] 1380 raise NotImplementedError() 1381 1382 1383 class KeyedDefaultDict(DefaultDict[_KT, _VT]): 1384 if TYPE_CHECKING: 1385 # we promise to only use a subset of what DefaultDict can do 1386 def __init__(self, default_factory): 1387 # type: (Callable[[_KT], _VT]) -> None 1388 pass 1389 1390 def __missing__(self, key): 1391 # type: (_KT) -> _VT 1392 # typing: default_factory takes an arg, but the base class does not 1393 self[key] = self.default_factory(key) # type: ignore # pylint: disable=E1137 1394 return self[key]