github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/worker/sdk_worker.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """SDK harness for executing Python Fns via the Fn API."""
    19  
    20  # pytype: skip-file
    21  # mypy: disallow-untyped-defs
    22  
    23  import abc
    24  import collections
    25  import contextlib
    26  import functools
    27  import json
    28  import logging
    29  import queue
    30  import sys
    31  import threading
    32  import time
    33  import traceback
    34  from concurrent import futures
    35  from typing import TYPE_CHECKING
    36  from typing import Any
    37  from typing import Callable
    38  from typing import DefaultDict
    39  from typing import Dict
    40  from typing import FrozenSet
    41  from typing import Generic
    42  from typing import Iterable
    43  from typing import Iterator
    44  from typing import List
    45  from typing import MutableMapping
    46  from typing import Optional
    47  from typing import Tuple
    48  from typing import TypeVar
    49  from typing import Union
    50  
    51  import grpc
    52  
    53  from apache_beam.coders import coder_impl
    54  from apache_beam.metrics import monitoring_infos
    55  from apache_beam.metrics.execution import MetricsEnvironment
    56  from apache_beam.portability.api import beam_fn_api_pb2
    57  from apache_beam.portability.api import beam_fn_api_pb2_grpc
    58  from apache_beam.portability.api import metrics_pb2
    59  from apache_beam.runners.worker import bundle_processor
    60  from apache_beam.runners.worker import data_plane
    61  from apache_beam.runners.worker import data_sampler
    62  from apache_beam.runners.worker import statesampler
    63  from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
    64  from apache_beam.runners.worker.data_plane import PeriodicThread
    65  from apache_beam.runners.worker.statecache import CacheAware
    66  from apache_beam.runners.worker.statecache import StateCache
    67  from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor
    68  from apache_beam.runners.worker.worker_status import FnApiWorkerStatusHandler
    69  from apache_beam.utils import thread_pool_executor
    70  from apache_beam.utils.sentinel import Sentinel
    71  
    72  if TYPE_CHECKING:
    73    from apache_beam.portability.api import endpoints_pb2
    74    from apache_beam.utils.profiler import Profile
    75  
    76  T = TypeVar('T')
    77  _KT = TypeVar('_KT')
    78  _VT = TypeVar('_VT')
    79  
    80  _LOGGER = logging.getLogger(__name__)
    81  
    82  DEFAULT_BUNDLE_PROCESSOR_CACHE_SHUTDOWN_THRESHOLD_S = 60
    83  # The number of ProcessBundleRequest instruction ids the BundleProcessorCache
    84  # will remember for not running instructions.
    85  MAX_KNOWN_NOT_RUNNING_INSTRUCTIONS = 1000
    86  # The number of ProcessBundleRequest instruction ids that BundleProcessorCache
    87  # will remember for failed instructions.
    88  MAX_FAILED_INSTRUCTIONS = 10000
    89  
    90  # retry on transient UNAVAILABLE grpc error from state channels.
    91  _GRPC_SERVICE_CONFIG = json.dumps({
    92      "methodConfig": [{
    93          "name": [{
    94              "service": "org.apache.beam.model.fn_execution.v1.BeamFnState"
    95          }],
    96          "retryPolicy": {
    97              "maxAttempts": 5,
    98              "initialBackoff": "0.1s",
    99              "maxBackoff": "5s",
   100              "backoffMultiplier": 2,
   101              "retryableStatusCodes": ["UNAVAILABLE"],
   102          },
   103      }]
   104  })
   105  
   106  
   107  class ShortIdCache(object):
   108    """ Cache for MonitoringInfo "short ids"
   109    """
   110    def __init__(self):
   111      # type: () -> None
   112      self._lock = threading.Lock()
   113      self._last_short_id = 0
   114      self._info_key_to_short_id = {}  # type: Dict[FrozenSet, str]
   115      self._short_id_to_info = {}  # type: Dict[str, metrics_pb2.MonitoringInfo]
   116  
   117    def get_short_id(self, monitoring_info):
   118      # type: (metrics_pb2.MonitoringInfo) -> str
   119  
   120      """ Returns the assigned shortId for a given MonitoringInfo, assigns one if
   121      not assigned already.
   122      """
   123      key = monitoring_infos.to_key(monitoring_info)
   124      with self._lock:
   125        try:
   126          return self._info_key_to_short_id[key]
   127        except KeyError:
   128          self._last_short_id += 1
   129  
   130          # Convert to a hex string (and drop the '0x') for some compression
   131          shortId = hex(self._last_short_id)[2:]
   132  
   133          payload_cleared = metrics_pb2.MonitoringInfo()
   134          payload_cleared.CopyFrom(monitoring_info)
   135          payload_cleared.ClearField('payload')
   136  
   137          self._info_key_to_short_id[key] = shortId
   138          self._short_id_to_info[shortId] = payload_cleared
   139          return shortId
   140  
   141    def get_infos(self, short_ids):
   142      #type: (Iterable[str]) -> Dict[str, metrics_pb2.MonitoringInfo]
   143  
   144      """ Gets the base MonitoringInfo (with payload cleared) for each short ID.
   145  
   146      Throws KeyError if an unassigned short ID is encountered.
   147      """
   148      return {
   149          short_id: self._short_id_to_info[short_id]
   150          for short_id in short_ids
   151      }
   152  
   153  
   154  SHORT_ID_CACHE = ShortIdCache()
   155  
   156  
   157  class SdkHarness(object):
   158    REQUEST_METHOD_PREFIX = '_request_'
   159  
   160    def __init__(
   161        self,
   162        control_address,  # type: str
   163        credentials=None,  # type: Optional[grpc.ChannelCredentials]
   164        worker_id=None,  # type: Optional[str]
   165        # Caching is disabled by default
   166        state_cache_size=0,  # type: int
   167        # time-based data buffering is disabled by default
   168        data_buffer_time_limit_ms=0,  # type: int
   169        profiler_factory=None,  # type: Optional[Callable[..., Profile]]
   170        status_address=None,  # type: Optional[str]
   171        # Heap dump through status api is disabled by default
   172        enable_heap_dump=False,  # type: bool
   173        data_sampler=None,  # type: Optional[data_sampler.DataSampler]
   174        # Unrecoverable SDK harness initialization error (if any)
   175        # that should be reported to the runner when proocessing the first bundle.
   176        deferred_exception=None, # type: Optional[Exception]
   177    ):
   178      # type: (...) -> None
   179      self._alive = True
   180      self._worker_index = 0
   181      self._worker_id = worker_id
   182      self._state_cache = StateCache(state_cache_size)
   183      self._deferred_exception = deferred_exception
   184      options = [('grpc.max_receive_message_length', -1),
   185                 ('grpc.max_send_message_length', -1)]
   186      if credentials is None:
   187        _LOGGER.info('Creating insecure control channel for %s.', control_address)
   188        self._control_channel = GRPCChannelFactory.insecure_channel(
   189            control_address, options=options)
   190      else:
   191        _LOGGER.info('Creating secure control channel for %s.', control_address)
   192        self._control_channel = GRPCChannelFactory.secure_channel(
   193            control_address, credentials, options=options)
   194      grpc.channel_ready_future(self._control_channel).result(timeout=60)
   195      _LOGGER.info('Control channel established.')
   196  
   197      self._control_channel = grpc.intercept_channel(
   198          self._control_channel, WorkerIdInterceptor(self._worker_id))
   199      self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
   200          credentials, self._worker_id, data_buffer_time_limit_ms)
   201      self._state_handler_factory = GrpcStateHandlerFactory(
   202          self._state_cache, credentials)
   203      self._profiler_factory = profiler_factory
   204      self.data_sampler = data_sampler
   205  
   206      def default_factory(id):
   207        # type: (str) -> beam_fn_api_pb2.ProcessBundleDescriptor
   208        return self._control_stub.GetProcessBundleDescriptor(
   209            beam_fn_api_pb2.GetProcessBundleDescriptorRequest(
   210                process_bundle_descriptor_id=id))
   211  
   212      self._fns = KeyedDefaultDict(default_factory)
   213      # BundleProcessor cache across all workers.
   214      self._bundle_processor_cache = BundleProcessorCache(
   215          state_handler_factory=self._state_handler_factory,
   216          data_channel_factory=self._data_channel_factory,
   217          fns=self._fns,
   218          data_sampler=self.data_sampler,
   219      )
   220  
   221      if status_address:
   222        try:
   223          self._status_handler = FnApiWorkerStatusHandler(
   224              status_address,
   225              self._bundle_processor_cache,
   226              self._state_cache,
   227              enable_heap_dump)  # type: Optional[FnApiWorkerStatusHandler]
   228        except Exception:
   229          traceback_string = traceback.format_exc()
   230          _LOGGER.warning(
   231              'Error creating worker status request handler, '
   232              'skipping status report. Trace back: %s' % traceback_string)
   233      else:
   234        self._status_handler = None
   235  
   236      # TODO(BEAM-8998) use common
   237      # thread_pool_executor.shared_unbounded_instance() to process bundle
   238      # progress once dataflow runner's excessive progress polling is removed.
   239      self._report_progress_executor = futures.ThreadPoolExecutor(max_workers=1)
   240      self._worker_thread_pool = thread_pool_executor.shared_unbounded_instance()
   241      self._responses = queue.Queue(
   242      )  # type: queue.Queue[Union[beam_fn_api_pb2.InstructionResponse, Sentinel]]
   243      _LOGGER.info('Initializing SDKHarness with unbounded number of workers.')
   244  
   245    def run(self):
   246      # type: () -> None
   247      self._control_stub = beam_fn_api_pb2_grpc.BeamFnControlStub(
   248          self._control_channel)
   249      no_more_work = Sentinel.sentinel
   250  
   251      def get_responses():
   252        # type: () -> Iterator[beam_fn_api_pb2.InstructionResponse]
   253        while True:
   254          response = self._responses.get()
   255          if response is no_more_work:
   256            return
   257          yield response
   258  
   259      self._alive = True
   260  
   261      try:
   262        for work_request in self._control_stub.Control(get_responses()):
   263          _LOGGER.debug('Got work %s', work_request.instruction_id)
   264          request_type = work_request.WhichOneof('request')
   265          # Name spacing the request method with 'request_'. The called method
   266          # will be like self.request_register(request)
   267          getattr(self, SdkHarness.REQUEST_METHOD_PREFIX + request_type)(
   268              work_request)
   269      finally:
   270        self._alive = False
   271  
   272      _LOGGER.info('No more requests from control plane')
   273      _LOGGER.info('SDK Harness waiting for in-flight requests to complete')
   274      # Wait until existing requests are processed.
   275      self._worker_thread_pool.shutdown()
   276      # get_responses may be blocked on responses.get(), but we need to return
   277      # control to its caller.
   278      self._responses.put(no_more_work)
   279      # Stop all the workers and clean all the associated resources
   280      self._data_channel_factory.close()
   281      self._state_handler_factory.close()
   282      self._bundle_processor_cache.shutdown()
   283      if self._status_handler:
   284        self._status_handler.close()
   285      _LOGGER.info('Done consuming work.')
   286  
   287    def _execute(
   288        self,
   289        task,  # type: Callable[[], beam_fn_api_pb2.InstructionResponse]
   290        request  # type:  beam_fn_api_pb2.InstructionRequest
   291    ):
   292      # type: (...) -> None
   293      with statesampler.instruction_id(request.instruction_id):
   294        try:
   295          response = task()
   296        except:  # pylint: disable=bare-except
   297          traceback_string = traceback.format_exc()
   298          print(traceback_string, file=sys.stderr)
   299          _LOGGER.error(
   300              'Error processing instruction %s. Original traceback is\n%s\n',
   301              request.instruction_id,
   302              traceback_string)
   303          response = beam_fn_api_pb2.InstructionResponse(
   304              instruction_id=request.instruction_id, error=traceback_string)
   305        self._responses.put(response)
   306  
   307    def _request_register(self, request):
   308      # type: (beam_fn_api_pb2.InstructionRequest) -> None
   309      # registration request is handled synchronously
   310      self._execute(lambda: self.create_worker().do_instruction(request), request)
   311  
   312    def _request_process_bundle(self, request):
   313      # type: (beam_fn_api_pb2.InstructionRequest) -> None
   314      if self._deferred_exception:
   315        raise self._deferred_exception
   316      self._bundle_processor_cache.activate(request.instruction_id)
   317      self._request_execute(request)
   318  
   319    def _request_process_bundle_split(self, request):
   320      # type: (beam_fn_api_pb2.InstructionRequest) -> None
   321      self._request_process_bundle_action(request)
   322  
   323    def _request_process_bundle_progress(self, request):
   324      # type: (beam_fn_api_pb2.InstructionRequest) -> None
   325      self._request_process_bundle_action(request)
   326  
   327    def _request_process_bundle_action(self, request):
   328      # type: (beam_fn_api_pb2.InstructionRequest) -> None
   329      def task():
   330        # type: () -> None
   331        self._execute(
   332            lambda: self.create_worker().do_instruction(request), request)
   333  
   334      self._report_progress_executor.submit(task)
   335  
   336    def _request_finalize_bundle(self, request):
   337      # type: (beam_fn_api_pb2.InstructionRequest) -> None
   338      self._request_execute(request)
   339  
   340    def _request_harness_monitoring_infos(self, request):
   341      # type: (beam_fn_api_pb2.InstructionRequest) -> None
   342      process_wide_monitoring_infos = MetricsEnvironment.process_wide_container(
   343      ).to_runner_api_monitoring_infos(None).values()
   344      self._execute(
   345          lambda: beam_fn_api_pb2.InstructionResponse(
   346              instruction_id=request.instruction_id,
   347              harness_monitoring_infos=(
   348                  beam_fn_api_pb2.HarnessMonitoringInfosResponse(
   349                      monitoring_data={
   350                          SHORT_ID_CACHE.get_short_id(info): info.payload
   351                          for info in process_wide_monitoring_infos
   352                      }))),
   353          request)
   354  
   355    def _request_monitoring_infos(self, request):
   356      # type: (beam_fn_api_pb2.InstructionRequest) -> None
   357      self._execute(
   358          lambda: beam_fn_api_pb2.InstructionResponse(
   359              instruction_id=request.instruction_id,
   360              monitoring_infos=beam_fn_api_pb2.MonitoringInfosMetadataResponse(
   361                  monitoring_info=SHORT_ID_CACHE.get_infos(
   362                      request.monitoring_infos.monitoring_info_id))),
   363          request)
   364  
   365    def _request_execute(self, request):
   366      # type: (beam_fn_api_pb2.InstructionRequest) -> None
   367      def task():
   368        # type: () -> None
   369        self._execute(
   370            lambda: self.create_worker().do_instruction(request), request)
   371  
   372      self._worker_thread_pool.submit(task)
   373      _LOGGER.debug(
   374          "Currently using %s threads." % len(self._worker_thread_pool._workers))
   375  
   376    def _request_sample_data(self, request):
   377      # type: (beam_fn_api_pb2.InstructionRequest) -> None
   378  
   379      def get_samples(request):
   380        # type: (beam_fn_api_pb2.InstructionRequest) -> beam_fn_api_pb2.InstructionResponse
   381        samples: Dict[str, List[bytes]] = {}
   382        if self.data_sampler:
   383          samples = self.data_sampler.samples(request.sample_data.pcollection_ids)
   384  
   385        sample_response = beam_fn_api_pb2.SampleDataResponse()
   386        for pcoll_id in samples:
   387          sample_response.element_samples[pcoll_id].elements.extend(
   388              beam_fn_api_pb2.SampledElement(element=s)
   389              for s in samples[pcoll_id])
   390  
   391        return beam_fn_api_pb2.InstructionResponse(
   392            instruction_id=request.instruction_id, sample_data=sample_response)
   393  
   394      self._execute(lambda: get_samples(request), request)
   395  
   396    def create_worker(self):
   397      # type: () -> SdkWorker
   398      return SdkWorker(
   399          self._bundle_processor_cache, profiler_factory=self._profiler_factory)
   400  
   401  
   402  class BundleProcessorCache(object):
   403    """A cache for ``BundleProcessor``s.
   404  
   405    ``BundleProcessor`` objects are cached by the id of their
   406    ``beam_fn_api_pb2.ProcessBundleDescriptor``.
   407  
   408    Attributes:
   409      fns (dict): A dictionary that maps bundle descriptor IDs to instances of
   410        ``beam_fn_api_pb2.ProcessBundleDescriptor``.
   411      state_handler_factory (``StateHandlerFactory``): Used to create state
   412        handlers to be used by a ``bundle_processor.BundleProcessor`` during
   413        processing.
   414      data_channel_factory (``data_plane.DataChannelFactory``)
   415      active_bundle_processors (dict): A dictionary, indexed by instruction IDs,
   416        containing ``bundle_processor.BundleProcessor`` objects that are currently
   417        active processing the corresponding instruction.
   418      cached_bundle_processors (dict): A dictionary, indexed by bundle processor
   419        id, of cached ``bundle_processor.BundleProcessor`` that are not currently
   420        performing processing.
   421    """
   422    periodic_shutdown = None  # type: Optional[PeriodicThread]
   423  
   424    def __init__(
   425        self,
   426        state_handler_factory,  # type: StateHandlerFactory
   427        data_channel_factory,  # type: data_plane.DataChannelFactory
   428        fns,  # type: MutableMapping[str, beam_fn_api_pb2.ProcessBundleDescriptor]
   429        data_sampler=None,  # type: Optional[data_sampler.DataSampler]
   430    ):
   431      # type: (...) -> None
   432      self.fns = fns
   433      self.state_handler_factory = state_handler_factory
   434      self.data_channel_factory = data_channel_factory
   435      self.known_not_running_instruction_ids = collections.OrderedDict(
   436      )  # type: collections.OrderedDict[str, bool]
   437      self.failed_instruction_ids = collections.OrderedDict(
   438      )  # type: collections.OrderedDict[str, bool]
   439      self.active_bundle_processors = {
   440      }  # type: Dict[str, Tuple[str, bundle_processor.BundleProcessor]]
   441      self.cached_bundle_processors = collections.defaultdict(
   442          list)  # type: DefaultDict[str, List[bundle_processor.BundleProcessor]]
   443      self.last_access_times = collections.defaultdict(
   444          float)  # type: DefaultDict[str, float]
   445      self._schedule_periodic_shutdown()
   446      self._lock = threading.Lock()
   447      self.data_sampler = data_sampler
   448  
   449    def register(self, bundle_descriptor):
   450      # type: (beam_fn_api_pb2.ProcessBundleDescriptor) -> None
   451  
   452      """Register a ``beam_fn_api_pb2.ProcessBundleDescriptor`` by its id."""
   453      self.fns[bundle_descriptor.id] = bundle_descriptor
   454  
   455    def activate(self, instruction_id):
   456      # type: (str) -> None
   457  
   458      """Makes the ``instruction_id`` known to the bundle processor.
   459  
   460      Allows ``lookup`` to return ``None``. Necessary if ``lookup`` can occur
   461      before ``get``.
   462      """
   463      with self._lock:
   464        self.known_not_running_instruction_ids[instruction_id] = True
   465  
   466    def get(self, instruction_id, bundle_descriptor_id):
   467      # type: (str, str) -> bundle_processor.BundleProcessor
   468  
   469      """
   470      Return the requested ``BundleProcessor``, creating it if necessary.
   471  
   472      Moves the ``BundleProcessor`` from the inactive to the active cache.
   473      """
   474      with self._lock:
   475        try:
   476          # pop() is threadsafe
   477          processor = self.cached_bundle_processors[bundle_descriptor_id].pop()
   478          self.active_bundle_processors[
   479            instruction_id] = bundle_descriptor_id, processor
   480          try:
   481            del self.known_not_running_instruction_ids[instruction_id]
   482          except KeyError:
   483            # The instruction may have not been pre-registered before execution
   484            # since activate() may have never been invoked
   485            pass
   486          return processor
   487        except IndexError:
   488          pass
   489  
   490      # Make sure we instantiate the processor while not holding the lock.
   491      processor = bundle_processor.BundleProcessor(
   492          self.fns[bundle_descriptor_id],
   493          self.state_handler_factory.create_state_handler(
   494              self.fns[bundle_descriptor_id].state_api_service_descriptor),
   495          self.data_channel_factory,
   496          self.data_sampler)
   497      with self._lock:
   498        self.active_bundle_processors[
   499          instruction_id] = bundle_descriptor_id, processor
   500        try:
   501          del self.known_not_running_instruction_ids[instruction_id]
   502        except KeyError:
   503          # The instruction may have not been pre-registered before execution
   504          # since activate() may have never been invoked
   505          pass
   506      return processor
   507  
   508    def lookup(self, instruction_id):
   509      # type: (str) -> Optional[bundle_processor.BundleProcessor]
   510  
   511      """
   512      Return the requested ``BundleProcessor`` from the cache.
   513  
   514      Will return ``None`` if the BundleProcessor is known but not yet ready. Will
   515      raise an error if the ``instruction_id`` is not known or has been discarded.
   516      """
   517      with self._lock:
   518        if instruction_id in self.failed_instruction_ids:
   519          raise RuntimeError(
   520              'Bundle processing associated with %s has failed. '
   521              'Check prior failing response for details.' % instruction_id)
   522        processor = self.active_bundle_processors.get(
   523            instruction_id, (None, None))[-1]
   524        if processor:
   525          return processor
   526        if instruction_id in self.known_not_running_instruction_ids:
   527          return None
   528        raise RuntimeError('Unknown process bundle id %s.' % instruction_id)
   529  
   530    def discard(self, instruction_id):
   531      # type: (str) -> None
   532  
   533      """
   534      Marks the instruction id as failed shutting down the ``BundleProcessor``.
   535      """
   536      with self._lock:
   537        self.failed_instruction_ids[instruction_id] = True
   538        while len(self.failed_instruction_ids) > MAX_FAILED_INSTRUCTIONS:
   539          self.failed_instruction_ids.popitem(last=False)
   540        processor = self.active_bundle_processors[instruction_id][1]
   541        del self.active_bundle_processors[instruction_id]
   542  
   543      # Perform the shutdown while not holding the lock.
   544      processor.shutdown()
   545  
   546    def release(self, instruction_id):
   547      # type: (str) -> None
   548  
   549      """
   550      Release the requested ``BundleProcessor``.
   551  
   552      Resets the ``BundleProcessor`` and moves it from the active to the
   553      inactive cache.
   554      """
   555      with self._lock:
   556        self.known_not_running_instruction_ids[instruction_id] = True
   557        while len(self.known_not_running_instruction_ids
   558                  ) > MAX_KNOWN_NOT_RUNNING_INSTRUCTIONS:
   559          self.known_not_running_instruction_ids.popitem(last=False)
   560        descriptor_id, processor = (
   561            self.active_bundle_processors.pop(instruction_id))
   562  
   563      # Make sure that we reset the processor while not holding the lock.
   564      processor.reset()
   565      with self._lock:
   566        self.last_access_times[descriptor_id] = time.time()
   567        self.cached_bundle_processors[descriptor_id].append(processor)
   568  
   569    def shutdown(self):
   570      # type: () -> None
   571  
   572      """
   573      Shutdown all ``BundleProcessor``s in the cache.
   574      """
   575      if self.periodic_shutdown:
   576        self.periodic_shutdown.cancel()
   577        self.periodic_shutdown.join()
   578        self.periodic_shutdown = None
   579  
   580      for instruction_id in list(self.active_bundle_processors.keys()):
   581        self.discard(instruction_id)
   582      for cached_bundle_processors in self.cached_bundle_processors.values():
   583        BundleProcessorCache._shutdown_cached_bundle_processors(
   584            cached_bundle_processors)
   585  
   586    def _schedule_periodic_shutdown(self):
   587      # type: () -> None
   588      def shutdown_inactive_bundle_processors():
   589        # type: () -> None
   590        for descriptor_id, last_access_time in self.last_access_times.items():
   591          if (time.time() - last_access_time >
   592              DEFAULT_BUNDLE_PROCESSOR_CACHE_SHUTDOWN_THRESHOLD_S):
   593            BundleProcessorCache._shutdown_cached_bundle_processors(
   594                self.cached_bundle_processors[descriptor_id])
   595  
   596      self.periodic_shutdown = PeriodicThread(
   597          DEFAULT_BUNDLE_PROCESSOR_CACHE_SHUTDOWN_THRESHOLD_S,
   598          shutdown_inactive_bundle_processors)
   599      self.periodic_shutdown.daemon = True
   600      self.periodic_shutdown.start()
   601  
   602    @staticmethod
   603    def _shutdown_cached_bundle_processors(cached_bundle_processors):
   604      # type: (List[bundle_processor.BundleProcessor]) -> None
   605      try:
   606        while True:
   607          # pop() is threadsafe
   608          bundle_processor = cached_bundle_processors.pop()
   609          bundle_processor.shutdown()
   610      except IndexError:
   611        pass
   612  
   613  
   614  class SdkWorker(object):
   615    def __init__(
   616        self,
   617        bundle_processor_cache,  # type: BundleProcessorCache
   618        profiler_factory=None,  # type: Optional[Callable[..., Profile]]
   619    ):
   620      # type: (...) -> None
   621      self.bundle_processor_cache = bundle_processor_cache
   622      self.profiler_factory = profiler_factory
   623  
   624    def do_instruction(self, request):
   625      # type: (beam_fn_api_pb2.InstructionRequest) -> beam_fn_api_pb2.InstructionResponse
   626      request_type = request.WhichOneof('request')
   627      if request_type:
   628        # E.g. if register is set, this will call self.register(request.register))
   629        return getattr(self, request_type)(
   630            getattr(request, request_type), request.instruction_id)
   631      else:
   632        raise NotImplementedError
   633  
   634    def register(
   635        self,
   636        request,  # type: beam_fn_api_pb2.RegisterRequest
   637        instruction_id  # type: str
   638    ):
   639      # type: (...) -> beam_fn_api_pb2.InstructionResponse
   640  
   641      """Registers a set of ``beam_fn_api_pb2.ProcessBundleDescriptor``s.
   642  
   643      This set of ``beam_fn_api_pb2.ProcessBundleDescriptor`` come as part of a
   644      ``beam_fn_api_pb2.RegisterRequest``, which the runner sends to the SDK
   645      worker before starting processing to register stages.
   646      """
   647  
   648      for process_bundle_descriptor in request.process_bundle_descriptor:
   649        self.bundle_processor_cache.register(process_bundle_descriptor)
   650      return beam_fn_api_pb2.InstructionResponse(
   651          instruction_id=instruction_id,
   652          register=beam_fn_api_pb2.RegisterResponse())
   653  
  def process_bundle(
      self,
      request,  # type: beam_fn_api_pb2.ProcessBundleRequest
      instruction_id  # type: str
  ):
    # type: (...) -> beam_fn_api_pb2.InstructionResponse

    """Executes one bundle and builds its ``ProcessBundleResponse``.

    Obtains a processor from the cache, runs the bundle inside the state
    handler's per-instruction scope and ``self.maybe_profile``, and reports
    residual roots, monitoring infos (full and short-id keyed) and whether
    finalization is required.

    On success the processor is released back to the cache unless the bundle
    requires finalization (then it stays active). On any exception the
    processor is discarded (shut down and the id marked failed) and the
    exception is re-raised.
    """
    bundle_processor = self.bundle_processor_cache.get(
        instruction_id, request.process_bundle_descriptor_id)
    try:
      with bundle_processor.state_handler.process_instruction_id(
          instruction_id, request.cache_tokens):
        with self.maybe_profile(instruction_id):
          delayed_applications, requests_finalization = (
              bundle_processor.process_bundle(instruction_id))
          monitoring_infos = bundle_processor.monitoring_infos()
          response = beam_fn_api_pb2.InstructionResponse(
              instruction_id=instruction_id,
              process_bundle=beam_fn_api_pb2.ProcessBundleResponse(
                  residual_roots=delayed_applications,
                  monitoring_infos=monitoring_infos,
                  # Same infos keyed by short id, payload only.
                  monitoring_data={
                      SHORT_ID_CACHE.get_short_id(info): info.payload
                      for info in monitoring_infos
                  },
                  requires_finalization=requests_finalization))
      # Don't release here if finalize is needed.
      # NOTE(review): release() pops the instruction from the active map
      # before calling processor.reset(); if reset() raises, the discard()
      # below will fail with KeyError on the already-removed id.
      if not requests_finalization:
        self.bundle_processor_cache.release(instruction_id)
      return response
    except:  # pylint: disable=bare-except
      # Don't re-use bundle processors on failure.
      self.bundle_processor_cache.discard(instruction_id)
      raise
   687  
   688    def process_bundle_split(
   689        self,
   690        request,  # type: beam_fn_api_pb2.ProcessBundleSplitRequest
   691        instruction_id  # type: str
   692    ):
   693      # type: (...) -> beam_fn_api_pb2.InstructionResponse
   694      try:
   695        processor = self.bundle_processor_cache.lookup(request.instruction_id)
   696      except RuntimeError:
   697        return beam_fn_api_pb2.InstructionResponse(
   698            instruction_id=instruction_id, error=traceback.format_exc())
   699      # Return an empty response if we aren't running. This can happen
   700      # if the ProcessBundleRequest has not started or already finished.
   701      process_bundle_split = (
   702          processor.try_split(request)
   703          if processor else beam_fn_api_pb2.ProcessBundleSplitResponse())
   704      return beam_fn_api_pb2.InstructionResponse(
   705          instruction_id=instruction_id,
   706          process_bundle_split=process_bundle_split)
   707  
   708    def process_bundle_progress(
   709        self,
   710        request,  # type: beam_fn_api_pb2.ProcessBundleProgressRequest
   711        instruction_id  # type: str
   712    ):
   713      # type: (...) -> beam_fn_api_pb2.InstructionResponse
   714      try:
   715        processor = self.bundle_processor_cache.lookup(request.instruction_id)
   716      except RuntimeError:
   717        return beam_fn_api_pb2.InstructionResponse(
   718            instruction_id=instruction_id, error=traceback.format_exc())
   719      if processor:
   720        monitoring_infos = processor.monitoring_infos()
   721      else:
   722        # Return an empty response if we aren't running. This can happen
   723        # if the ProcessBundleRequest has not started or already finished.
   724        monitoring_infos = []
   725      return beam_fn_api_pb2.InstructionResponse(
   726          instruction_id=instruction_id,
   727          process_bundle_progress=beam_fn_api_pb2.ProcessBundleProgressResponse(
   728              monitoring_infos=monitoring_infos,
   729              monitoring_data={
   730                  SHORT_ID_CACHE.get_short_id(info): info.payload
   731                  for info in monitoring_infos
   732              }))
   733  
   734    def finalize_bundle(
   735        self,
   736        request,  # type: beam_fn_api_pb2.FinalizeBundleRequest
   737        instruction_id  # type: str
   738    ):
   739      # type: (...) -> beam_fn_api_pb2.InstructionResponse
   740      try:
   741        processor = self.bundle_processor_cache.lookup(request.instruction_id)
   742      except RuntimeError:
   743        return beam_fn_api_pb2.InstructionResponse(
   744            instruction_id=instruction_id, error=traceback.format_exc())
   745      if processor:
   746        try:
   747          finalize_response = processor.finalize_bundle()
   748          self.bundle_processor_cache.release(request.instruction_id)
   749          return beam_fn_api_pb2.InstructionResponse(
   750              instruction_id=instruction_id, finalize_bundle=finalize_response)
   751        except:
   752          self.bundle_processor_cache.discard(request.instruction_id)
   753          raise
   754      # We can reach this state if there was an erroneous request to finalize
   755      # the bundle while it is being initialized or has already been finalized
   756      # and released.
   757      raise RuntimeError(
   758          'Bundle is not in a finalizable state for %s' % instruction_id)
   759  
   760    @contextlib.contextmanager
   761    def maybe_profile(self, instruction_id):
   762      # type: (str) -> Iterator[None]
   763      if self.profiler_factory:
   764        profiler = self.profiler_factory(instruction_id)
   765        if profiler:
   766          with profiler:
   767            yield
   768        else:
   769          yield
   770      else:
   771        yield
   772  
   773  
   774  class StateHandler(metaclass=abc.ABCMeta):
   775    """An abstract object representing a ``StateHandler``."""
   776    @abc.abstractmethod
   777    def get_raw(
   778        self,
   779        state_key,  # type: beam_fn_api_pb2.StateKey
   780        continuation_token=None  # type: Optional[bytes]
   781    ):
   782      # type: (...) -> Tuple[bytes, Optional[bytes]]
   783  
   784      """Gets the contents of state for the given state key.
   785  
   786      State is associated to a state key, AND an instruction_id, which is set
   787      when calling process_instruction_id.
   788  
   789      Returns a tuple with the contents in state, and an optional continuation
   790      token, which is used to page the API.
   791      """
   792      raise NotImplementedError(type(self))
   793  
   794    @abc.abstractmethod
   795    def append_raw(
   796        self,
   797        state_key,  # type: beam_fn_api_pb2.StateKey
   798        data  # type: bytes
   799    ):
   800      # type: (...) -> _Future
   801  
   802      """Append the input data into the state key.
   803  
   804      Returns a future that allows one to wait for the completion of the call.
   805  
   806      State is associated to a state key, AND an instruction_id, which is set
   807      when calling process_instruction_id.
   808      """
   809      raise NotImplementedError(type(self))
   810  
   811    @abc.abstractmethod
   812    def clear(self, state_key):
   813      # type: (beam_fn_api_pb2.StateKey) -> _Future
   814  
   815      """Clears the contents of a cell for the input state key.
   816  
   817      Returns a future that allows one to wait for the completion of the call.
   818  
   819      State is associated to a state key, AND an instruction_id, which is set
   820      when calling process_instruction_id.
   821      """
   822      raise NotImplementedError(type(self))
   823  
   824    @abc.abstractmethod
   825    @contextlib.contextmanager
   826    def process_instruction_id(self, bundle_id):
   827      # type: (str) -> Iterator[None]
   828  
   829      """Switch the context of the state handler to a specific instruction.
   830  
   831      This must be called before performing any write or read operations on the
   832      existing state.
   833      """
   834      raise NotImplementedError(type(self))
   835  
   836    @abc.abstractmethod
   837    def done(self):
   838      # type: () -> None
   839  
   840      """Mark the state handler as done, and potentially delete all context."""
   841      raise NotImplementedError(type(self))
   842  
   843  
   844  class StateHandlerFactory(metaclass=abc.ABCMeta):
   845    """An abstract factory for creating ``DataChannel``."""
   846    @abc.abstractmethod
   847    def create_state_handler(self, api_service_descriptor):
   848      # type: (endpoints_pb2.ApiServiceDescriptor) -> CachingStateHandler
   849  
   850      """Returns a ``StateHandler`` from the given ApiServiceDescriptor."""
   851      raise NotImplementedError(type(self))
   852  
   853    @abc.abstractmethod
   854    def close(self):
   855      # type: () -> None
   856  
   857      """Close all channels that this factory owns."""
   858      raise NotImplementedError(type(self))
   859  
   860  
   861  class GrpcStateHandlerFactory(StateHandlerFactory):
   862    """A factory for ``GrpcStateHandler``.
   863  
   864    Caches the created channels by ``state descriptor url``.
   865    """
   866    def __init__(self, state_cache, credentials=None):
   867      # type: (StateCache, Optional[grpc.ChannelCredentials]) -> None
   868      self._state_handler_cache = {}  # type: Dict[str, CachingStateHandler]
   869      self._lock = threading.Lock()
   870      self._throwing_state_handler = ThrowingStateHandler()
   871      self._credentials = credentials
   872      self._state_cache = state_cache
   873  
   874    def create_state_handler(self, api_service_descriptor):
   875      # type: (endpoints_pb2.ApiServiceDescriptor) -> CachingStateHandler
   876      if not api_service_descriptor:
   877        return self._throwing_state_handler
   878      url = api_service_descriptor.url
   879      if url not in self._state_handler_cache:
   880        with self._lock:
   881          if url not in self._state_handler_cache:
   882            # Options to have no limits (-1) on the size of the messages
   883            # received or sent over the data plane. The actual buffer size is
   884            # controlled in a layer above.
   885            options = [('grpc.max_receive_message_length', -1),
   886                       ('grpc.max_send_message_length', -1),
   887                       ('grpc.service_config', _GRPC_SERVICE_CONFIG)]
   888            if self._credentials is None:
   889              _LOGGER.info('Creating insecure state channel for %s.', url)
   890              grpc_channel = GRPCChannelFactory.insecure_channel(
   891                  url, options=options)
   892            else:
   893              _LOGGER.info('Creating secure state channel for %s.', url)
   894              grpc_channel = GRPCChannelFactory.secure_channel(
   895                  url, self._credentials, options=options)
   896            _LOGGER.info('State channel established.')
   897            # Add workerId to the grpc channel
   898            grpc_channel = grpc.intercept_channel(
   899                grpc_channel, WorkerIdInterceptor())
   900            self._state_handler_cache[url] = GlobalCachingStateHandler(
   901                self._state_cache,
   902                GrpcStateHandler(
   903                    beam_fn_api_pb2_grpc.BeamFnStateStub(grpc_channel)))
   904      return self._state_handler_cache[url]
   905  
   906    def close(self):
   907      # type: () -> None
   908      _LOGGER.info('Closing all cached gRPC state handlers.')
   909      for _, state_handler in self._state_handler_cache.items():
   910        state_handler.done()
   911      self._state_handler_cache.clear()
   912      self._state_cache.invalidate_all()
   913  
   914  
   915  class CachingStateHandler(metaclass=abc.ABCMeta):
   916    @abc.abstractmethod
   917    @contextlib.contextmanager
   918    def process_instruction_id(self, bundle_id, cache_tokens):
   919      # type: (str, Iterable[beam_fn_api_pb2.ProcessBundleRequest.CacheToken]) -> Iterator[None]
   920      raise NotImplementedError(type(self))
   921  
   922    @abc.abstractmethod
   923    def blocking_get(
   924        self,
   925        state_key,  # type: beam_fn_api_pb2.StateKey
   926        coder,  # type: coder_impl.CoderImpl
   927    ):
   928      # type: (...) -> Iterable[Any]
   929      raise NotImplementedError(type(self))
   930  
   931    @abc.abstractmethod
   932    def extend(
   933        self,
   934        state_key,  # type: beam_fn_api_pb2.StateKey
   935        coder,  # type: coder_impl.CoderImpl
   936        elements,  # type: Iterable[Any]
   937    ):
   938      # type: (...) -> _Future
   939      raise NotImplementedError(type(self))
   940  
   941    @abc.abstractmethod
   942    def clear(self, state_key):
   943      # type: (beam_fn_api_pb2.StateKey) -> _Future
   944      raise NotImplementedError(type(self))
   945  
   946    @abc.abstractmethod
   947    def done(self):
   948      # type: () -> None
   949      raise NotImplementedError(type(self))
   950  
   951  
   952  class ThrowingStateHandler(CachingStateHandler):
   953    """A caching state handler that errors on any requests."""
   954    @contextlib.contextmanager
   955    def process_instruction_id(self, bundle_id, cache_tokens):
   956      # type: (str, Iterable[beam_fn_api_pb2.ProcessBundleRequest.CacheToken]) -> Iterator[None]
   957      raise RuntimeError(
   958          'Unable to handle state requests for ProcessBundleDescriptor '
   959          'for bundle id %s.' % bundle_id)
   960  
   961    def blocking_get(
   962        self,
   963        state_key,  # type: beam_fn_api_pb2.StateKey
   964        coder,  # type: coder_impl.CoderImpl
   965    ):
   966      # type: (...) -> Iterable[Any]
   967      raise RuntimeError(
   968          'Unable to handle state requests for ProcessBundleDescriptor without '
   969          'state ApiServiceDescriptor for state key %s.' % state_key)
   970  
   971    def extend(
   972        self,
   973        state_key,  # type: beam_fn_api_pb2.StateKey
   974        coder,  # type: coder_impl.CoderImpl
   975        elements,  # type: Iterable[Any]
   976    ):
   977      # type: (...) -> _Future
   978      raise RuntimeError(
   979          'Unable to handle state requests for ProcessBundleDescriptor without '
   980          'state ApiServiceDescriptor for state key %s.' % state_key)
   981  
   982    def clear(self, state_key):
   983      # type: (beam_fn_api_pb2.StateKey) -> _Future
   984      raise RuntimeError(
   985          'Unable to handle state requests for ProcessBundleDescriptor without '
   986          'state ApiServiceDescriptor for state key %s.' % state_key)
   987  
   988    def done(self):
   989      # type: () -> None
   990      raise RuntimeError(
   991          'Unable to handle state requests for ProcessBundleDescriptor.')
   992  
   993  
   994  class GrpcStateHandler(StateHandler):
   995  
   996    _DONE = Sentinel.sentinel
   997  
   998    def __init__(self, state_stub):
   999      # type: (beam_fn_api_pb2_grpc.BeamFnStateStub) -> None
  1000      self._lock = threading.Lock()
  1001      self._state_stub = state_stub
  1002      self._requests = queue.Queue(
  1003      )  # type: queue.Queue[Union[beam_fn_api_pb2.StateRequest, Sentinel]]
  1004      self._responses_by_id = {}  # type: Dict[str, _Future]
  1005      self._last_id = 0
  1006      self._exception = None  # type: Optional[Exception]
  1007      self._context = threading.local()
  1008      self.start()
  1009  
  1010    @contextlib.contextmanager
  1011    def process_instruction_id(self, bundle_id):
  1012      # type: (str) -> Iterator[None]
  1013      if getattr(self._context, 'process_instruction_id', None) is not None:
  1014        raise RuntimeError(
  1015            'Already bound to %r' % self._context.process_instruction_id)
  1016      self._context.process_instruction_id = bundle_id
  1017      try:
  1018        yield
  1019      finally:
  1020        self._context.process_instruction_id = None
  1021  
  1022    def start(self):
  1023      # type: () -> None
  1024      self._done = False
  1025  
  1026      def request_iter():
  1027        # type: () -> Iterator[beam_fn_api_pb2.StateRequest]
  1028        while True:
  1029          request = self._requests.get()
  1030          if request is self._DONE or self._done:
  1031            break
  1032          yield request
  1033  
  1034      responses = self._state_stub.State(request_iter())
  1035  
  1036      def pull_responses():
  1037        # type: () -> None
  1038        try:
  1039          for response in responses:
  1040            # Popping an item from a dictionary is atomic in cPython
  1041            future = self._responses_by_id.pop(response.id)
  1042            future.set(response)
  1043            if self._done:
  1044              break
  1045        except Exception as e:
  1046          self._exception = e
  1047          raise
  1048  
  1049      reader = threading.Thread(target=pull_responses, name='read_state')
  1050      reader.daemon = True
  1051      reader.start()
  1052  
  1053    def done(self):
  1054      # type: () -> None
  1055      self._done = True
  1056      self._requests.put(self._DONE)
  1057  
  1058    def get_raw(
  1059        self,
  1060        state_key,  # type: beam_fn_api_pb2.StateKey
  1061        continuation_token=None  # type: Optional[bytes]
  1062    ):
  1063      # type: (...) -> Tuple[bytes, Optional[bytes]]
  1064      response = self._blocking_request(
  1065          beam_fn_api_pb2.StateRequest(
  1066              state_key=state_key,
  1067              get=beam_fn_api_pb2.StateGetRequest(
  1068                  continuation_token=continuation_token)))
  1069      return response.get.data, response.get.continuation_token
  1070  
  1071    def append_raw(
  1072        self,
  1073        state_key,  # type: Optional[beam_fn_api_pb2.StateKey]
  1074        data  # type: bytes
  1075    ):
  1076      # type: (...) -> _Future
  1077      return self._request(
  1078          beam_fn_api_pb2.StateRequest(
  1079              state_key=state_key,
  1080              append=beam_fn_api_pb2.StateAppendRequest(data=data)))
  1081  
  1082    def clear(self, state_key):
  1083      # type: (Optional[beam_fn_api_pb2.StateKey]) -> _Future
  1084      return self._request(
  1085          beam_fn_api_pb2.StateRequest(
  1086              state_key=state_key, clear=beam_fn_api_pb2.StateClearRequest()))
  1087  
  1088    def _request(self, request):
  1089      # type: (beam_fn_api_pb2.StateRequest) -> _Future[beam_fn_api_pb2.StateResponse]
  1090      request.id = self._next_id()
  1091      request.instruction_id = self._context.process_instruction_id
  1092      # Adding a new item to a dictionary is atomic in cPython
  1093      self._responses_by_id[request.id] = future = _Future[
  1094          beam_fn_api_pb2.StateResponse]()
  1095      # Request queue is thread-safe
  1096      self._requests.put(request)
  1097      return future
  1098  
  1099    def _blocking_request(self, request):
  1100      # type: (beam_fn_api_pb2.StateRequest) -> beam_fn_api_pb2.StateResponse
  1101      req_future = self._request(request)
  1102      while not req_future.wait(timeout=1):
  1103        if self._exception:
  1104          raise self._exception
  1105        elif self._done:
  1106          raise RuntimeError()
  1107      response = req_future.get()
  1108      if response.error:
  1109        raise RuntimeError(response.error)
  1110      else:
  1111        return response
  1112  
  1113    def _next_id(self):
  1114      # type: () -> str
  1115      with self._lock:
  1116        # Use a lock here because this GrpcStateHandler is shared across all
  1117        # requests which have the same process bundle descriptor. State requests
  1118        # can concurrently access this section if a Runner uses threads / workers
  1119        # (aka "parallelism") to send data to this SdkHarness and its workers.
  1120        self._last_id += 1
  1121        request_id = self._last_id
  1122      return str(request_id)
  1123  
  1124  
  1125  class GlobalCachingStateHandler(CachingStateHandler):
  1126    """ A State handler which retrieves and caches state.
  1127     If caching is activated, caches across bundles using a supplied cache token.
  1128     If activated but no cache token is supplied, caching is done at the bundle
  1129     level.
  1130    """
  1131    def __init__(
  1132        self,
  1133        global_state_cache,  # type: StateCache
  1134        underlying_state  # type: StateHandler
  1135    ):
  1136      # type: (...) -> None
  1137      self._underlying = underlying_state
  1138      self._state_cache = global_state_cache
  1139      self._context = threading.local()
  1140  
  1141    @contextlib.contextmanager
  1142    def process_instruction_id(self, bundle_id, cache_tokens):
  1143      # type: (str, Iterable[beam_fn_api_pb2.ProcessBundleRequest.CacheToken]) -> Iterator[None]
  1144      if getattr(self._context, 'user_state_cache_token', None) is not None:
  1145        raise RuntimeError(
  1146            'Cache tokens already set to %s' %
  1147            self._context.user_state_cache_token)
  1148      self._context.side_input_cache_tokens = {}
  1149      user_state_cache_token = None
  1150      for cache_token_struct in cache_tokens:
  1151        if cache_token_struct.HasField("user_state"):
  1152          # There should only be one user state token present
  1153          assert not user_state_cache_token
  1154          user_state_cache_token = cache_token_struct.token
  1155        elif cache_token_struct.HasField("side_input"):
  1156          self._context.side_input_cache_tokens[
  1157              cache_token_struct.side_input.transform_id,
  1158              cache_token_struct.side_input.
  1159              side_input_id] = cache_token_struct.token
  1160      # TODO: Consider a two-level cache to avoid extra logic and locking
  1161      # for items cached at the bundle level.
  1162      self._context.bundle_cache_token = bundle_id
  1163      try:
  1164        self._context.user_state_cache_token = user_state_cache_token
  1165        with self._underlying.process_instruction_id(bundle_id):
  1166          yield
  1167      finally:
  1168        self._context.side_input_cache_tokens = {}
  1169        self._context.user_state_cache_token = None
  1170        self._context.bundle_cache_token = None
  1171  
  1172    def blocking_get(
  1173        self,
  1174        state_key,  # type: beam_fn_api_pb2.StateKey
  1175        coder,  # type: coder_impl.CoderImpl
  1176    ):
  1177      # type: (...) -> Iterable[Any]
  1178      cache_token = self._get_cache_token(state_key)
  1179      if not cache_token:
  1180        # Cache disabled / no cache token. Can't do a lookup/store in the cache.
  1181        # Fall back to lazily materializing the state, one element at a time.
  1182        return self._lazy_iterator(state_key, coder)
  1183      # Cache lookup
  1184      cache_state_key = self._convert_to_cache_key(state_key)
  1185      return self._state_cache.get(
  1186          (cache_state_key, cache_token),
  1187          lambda key: self._partially_cached_iterable(state_key, coder))
  1188  
  1189    def extend(
  1190        self,
  1191        state_key,  # type: beam_fn_api_pb2.StateKey
  1192        coder,  # type: coder_impl.CoderImpl
  1193        elements,  # type: Iterable[Any]
  1194    ):
  1195      # type: (...) -> _Future
  1196      cache_token = self._get_cache_token(state_key)
  1197      if cache_token:
  1198        # Update the cache if the value is already present and
  1199        # can be updated.
  1200        cache_key = self._convert_to_cache_key(state_key)
  1201        cached_value = self._state_cache.peek((cache_key, cache_token))
  1202        if isinstance(cached_value, list):
  1203          # The state is fully cached and can be extended
  1204  
  1205          # Materialize provided iterable to ensure reproducible iterations,
  1206          # here and when writing to the state handler below.
  1207          elements = list(elements)
  1208          cached_value.extend(elements)
  1209          # Re-insert into the cache the updated value so the updated size is
  1210          # reflected.
  1211          self._state_cache.put((cache_key, cache_token), cached_value)
  1212  
  1213      # Write to state handler
  1214      futures = []
  1215      out = coder_impl.create_OutputStream()
  1216      for element in elements:
  1217        coder.encode_to_stream(element, out, True)
  1218        if out.size() > data_plane._DEFAULT_SIZE_FLUSH_THRESHOLD:
  1219          futures.append(self._underlying.append_raw(state_key, out.get()))
  1220          out = coder_impl.create_OutputStream()
  1221      if out.size():
  1222        futures.append(self._underlying.append_raw(state_key, out.get()))
  1223      return _DeferredCall(
  1224          lambda *results: beam_fn_api_pb2.StateResponse(
  1225              error='\n'.join(
  1226                  result.error for result in results if result and result.error),
  1227              append=beam_fn_api_pb2.StateAppendResponse()),
  1228          *futures)
  1229  
  1230    def clear(self, state_key):
  1231      # type: (beam_fn_api_pb2.StateKey) -> _Future
  1232      cache_token = self._get_cache_token(state_key)
  1233      if cache_token:
  1234        cache_key = self._convert_to_cache_key(state_key)
  1235        self._state_cache.put((cache_key, cache_token), [])
  1236      return self._underlying.clear(state_key)
  1237  
  1238    def done(self):
  1239      # type: () -> None
  1240      self._underlying.done()
  1241  
  1242    def _lazy_iterator(
  1243        self,
  1244        state_key,  # type: beam_fn_api_pb2.StateKey
  1245        coder,  # type: coder_impl.CoderImpl
  1246        continuation_token=None  # type: Optional[bytes]
  1247    ):
  1248      # type: (...) -> Iterator[Any]
  1249  
  1250      """Materializes the state lazily, one element at a time.
  1251         :return A generator which returns the next element if advanced.
  1252      """
  1253      while True:
  1254        data, continuation_token = (
  1255            self._underlying.get_raw(state_key, continuation_token))
  1256        input_stream = coder_impl.create_InputStream(data)
  1257        while input_stream.size() > 0:
  1258          yield coder.decode_from_stream(input_stream, True)
  1259        if not continuation_token:
  1260          break
  1261  
  1262    def _get_cache_token(self, state_key):
  1263      # type: (beam_fn_api_pb2.StateKey) -> Optional[bytes]
  1264      if not self._state_cache.is_cache_enabled():
  1265        return None
  1266      elif state_key.HasField('bag_user_state'):
  1267        if self._context.user_state_cache_token:
  1268          return self._context.user_state_cache_token
  1269        else:
  1270          return self._context.bundle_cache_token
  1271      elif state_key.WhichOneof('type').endswith('_side_input'):
  1272        side_input = getattr(state_key, state_key.WhichOneof('type'))
  1273        return self._context.side_input_cache_tokens.get(
  1274            (side_input.transform_id, side_input.side_input_id),
  1275            self._context.bundle_cache_token)
  1276      return None
  1277  
  1278    def _partially_cached_iterable(
  1279        self,
  1280        state_key,  # type: beam_fn_api_pb2.StateKey
  1281        coder  # type: coder_impl.CoderImpl
  1282    ):
  1283      # type: (...) -> Iterable[Any]
  1284  
  1285      """Materialized the first page of data, concatenated with a lazy iterable
  1286      of the rest, if any.
  1287      """
  1288      data, continuation_token = self._underlying.get_raw(state_key, None)
  1289      head = []
  1290      input_stream = coder_impl.create_InputStream(data)
  1291      while input_stream.size() > 0:
  1292        head.append(coder.decode_from_stream(input_stream, True))
  1293  
  1294      if not continuation_token:
  1295        return head
  1296      else:
  1297        return self.ContinuationIterable(
  1298            head,
  1299            functools.partial(
  1300                self._lazy_iterator, state_key, coder, continuation_token))
  1301  
  1302    class ContinuationIterable(Generic[T], CacheAware):
  1303      def __init__(self, head, continue_iterator_fn):
  1304        # type: (Iterable[T], Callable[[], Iterable[T]]) -> None
  1305        self.head = head
  1306        self.continue_iterator_fn = continue_iterator_fn
  1307  
  1308      def __iter__(self):
  1309        # type: () -> Iterator[T]
  1310        for item in self.head:
  1311          yield item
  1312        for item in self.continue_iterator_fn():
  1313          yield item
  1314  
  1315      def get_referents_for_cache(self):
  1316        # type: () -> List[Any]
  1317        # Only capture the size of the elements and not the
  1318        # continuation iterator since it references objects
  1319        # we don't want to include in the cache measurement.
  1320        return [self.head]
  1321  
  1322    @staticmethod
  1323    def _convert_to_cache_key(state_key):
  1324      # type: (beam_fn_api_pb2.StateKey) -> bytes
  1325      return state_key.SerializeToString()
  1326  
  1327  
  1328  class _Future(Generic[T]):
  1329    """A simple future object to implement blocking requests.
  1330    """
  1331    def __init__(self):
  1332      # type: () -> None
  1333      self._event = threading.Event()
  1334  
  1335    def wait(self, timeout=None):
  1336      # type: (Optional[float]) -> bool
  1337      return self._event.wait(timeout)
  1338  
  1339    def get(self, timeout=None):
  1340      # type: (Optional[float]) -> T
  1341      if self.wait(timeout):
  1342        return self._value
  1343      else:
  1344        raise LookupError()
  1345  
  1346    def set(self, value):
  1347      # type: (T) -> _Future[T]
  1348      self._value = value
  1349      self._event.set()
  1350      return self
  1351  
  1352    @classmethod
  1353    def done(cls):
  1354      # type: () -> _Future[None]
  1355      if not hasattr(cls, 'DONE'):
  1356        done_future = _Future[None]()
  1357        done_future.set(None)
  1358        cls.DONE = done_future  # type: ignore[attr-defined]
  1359      return cls.DONE  # type: ignore[attr-defined]
  1360  
  1361  
  1362  class _DeferredCall(_Future[T]):
  1363    def __init__(self, func, *args):
  1364      # type: (Callable[..., Any], *Any) -> None
  1365      self._func = func
  1366      self._args = [
  1367          arg if isinstance(arg, _Future) else _Future().set(arg) for arg in args
  1368      ]
  1369  
  1370    def wait(self, timeout=None):
  1371      # type: (Optional[float]) -> bool
  1372      return all(arg.wait(timeout) for arg in self._args)
  1373  
  1374    def get(self, timeout=None):
  1375      # type: (Optional[float]) -> T
  1376      return self._func(*(arg.get(timeout) for arg in self._args))
  1377  
  1378    def set(self, value):
  1379      # type: (T) -> _Future[T]
  1380      raise NotImplementedError()
  1381  
  1382  
  1383  class KeyedDefaultDict(DefaultDict[_KT, _VT]):
  1384    if TYPE_CHECKING:
  1385      # we promise to only use a subset of what DefaultDict can do
  1386      def __init__(self, default_factory):
  1387        # type: (Callable[[_KT], _VT]) -> None
  1388        pass
  1389  
  1390    def __missing__(self, key):
  1391      # type: (_KT) -> _VT
  1392      # typing: default_factory takes an arg, but the base class does not
  1393      self[key] = self.default_factory(key)  # type: ignore # pylint: disable=E1137
  1394      return self[key]