github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/worker/data_plane.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

    18  """Implementation of ``DataChannel``s to communicate across the data plane."""
    19  
    20  # pytype: skip-file
    21  # mypy: disallow-untyped-defs
    22  
    23  import abc
    24  import collections
    25  import json
    26  import logging
    27  import queue
    28  import threading
    29  import time
    30  from typing import TYPE_CHECKING
    31  from typing import Any
    32  from typing import Callable
    33  from typing import Collection
    34  from typing import DefaultDict
    35  from typing import Dict
    36  from typing import Iterable
    37  from typing import Iterator
    38  from typing import List
    39  from typing import Mapping
    40  from typing import Optional
    41  from typing import Set
    42  from typing import Tuple
    43  from typing import Union
    44  
    45  import grpc
    46  
    47  from apache_beam.coders import coder_impl
    48  from apache_beam.portability.api import beam_fn_api_pb2
    49  from apache_beam.portability.api import beam_fn_api_pb2_grpc
    50  from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
    51  from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor
    52  
    53  if TYPE_CHECKING:
    54    import apache_beam.coders.slow_stream
    55  
    56    OutputStream = apache_beam.coders.slow_stream.OutputStream
    57    DataOrTimers = Union[beam_fn_api_pb2.Elements.Data,
    58                         beam_fn_api_pb2.Elements.Timers]
    59  else:
    60    OutputStream = type(coder_impl.create_OutputStream())
    61  
    62  _LOGGER = logging.getLogger(__name__)
    63  
_DEFAULT_SIZE_FLUSH_THRESHOLD = 10 << 20  # 10MB
_DEFAULT_TIME_FLUSH_THRESHOLD_MS = 0  # disable time-based flush by default

# Keep a set of completed instructions to discard late-arriving data. The set
# can hold up to _MAX_CLEANED_INSTRUCTIONS items. See _GrpcDataChannel.
_MAX_CLEANED_INSTRUCTIONS = 10000

# Retry on transient UNAVAILABLE gRPC errors from data channels.
_GRPC_SERVICE_CONFIG = json.dumps({
    "methodConfig": [{
        "name": [{
            "service": "org.apache.beam.model.fn_execution.v1.BeamFnData"
        }],
        "retryPolicy": {
            "maxAttempts": 5,
            "initialBackoff": "0.1s",
            "maxBackoff": "5s",
            "backoffMultiplier": 2,
            "retryableStatusCodes": ["UNAVAILABLE"],
        },
    }]
})
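
# A minimal sketch (kept as a comment; illustrative only) of how this service
# config reaches gRPC: it is passed verbatim as the "grpc.service_config"
# channel option, exactly as GrpcClientDataChannelFactory does below. The
# address is hypothetical.
#
#   channel = grpc.insecure_channel(
#       'localhost:50000',
#       options=[('grpc.service_config', _GRPC_SERVICE_CONFIG)])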


class ClosableOutputStream(OutputStream):
  """An OutputStream, for use with CoderImpls, that has a close() method."""
  def __init__(
      self,
      close_callback=None  # type: Optional[Callable[[bytes], None]]
  ):
    # type: (...) -> None
    super().__init__()
    self._close_callback = close_callback

  def close(self):
    # type: () -> None
    if self._close_callback:
      self._close_callback(self.get())

  def maybe_flush(self):
    # type: () -> None
    pass

  def flush(self):
    # type: () -> None
    pass

  @staticmethod
  def create(
      close_callback,  # type: Optional[Callable[[bytes], None]]
      flush_callback,  # type: Optional[Callable[[bytes], None]]
      data_buffer_time_limit_ms  # type: int
  ):
    # type: (...) -> ClosableOutputStream
    if data_buffer_time_limit_ms > 0:
      return TimeBasedBufferingClosableOutputStream(
          close_callback,
          flush_callback=flush_callback,
          time_flush_threshold_ms=data_buffer_time_limit_ms)
    else:
      return SizeBasedBufferingClosableOutputStream(
          close_callback, flush_callback=flush_callback)
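
# A minimal usage sketch for create() (kept as a comment; the callbacks are
# hypothetical). With data_buffer_time_limit_ms > 0 a time-based buffering
# stream is returned; otherwise flushing is purely size-based and happens only
# when maybe_flush() is called past the size threshold.
#
#   stream = ClosableOutputStream.create(
#       close_callback=lambda data: send_final(data),    # hypothetical
#       flush_callback=lambda data: send_partial(data),  # hypothetical
#       data_buffer_time_limit_ms=0)
#   stream.write(b'encoded-element')
#   stream.maybe_flush()  # safe between elements only
#   stream.close()        # invokes close_callback with the remaining bytes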


class SizeBasedBufferingClosableOutputStream(ClosableOutputStream):
  """A size-based buffering OutputStream."""

  def __init__(
      self,
      close_callback=None,  # type: Optional[Callable[[bytes], None]]
      flush_callback=None,  # type: Optional[Callable[[bytes], None]]
      size_flush_threshold=_DEFAULT_SIZE_FLUSH_THRESHOLD  # type: int
  ):
    super().__init__(close_callback)
    self._flush_callback = flush_callback
    self._size_flush_threshold = size_flush_threshold

  # This must be called explicitly, between elements, to avoid flushing
  # partial elements.
  def maybe_flush(self):
    # type: () -> None
    if self.size() > self._size_flush_threshold:
      self.flush()

  def flush(self):
    # type: () -> None
    if self._flush_callback:
      self._flush_callback(self.get())
      self._clear()


class TimeBasedBufferingClosableOutputStream(
    SizeBasedBufferingClosableOutputStream):
  """A buffering OutputStream with both time-based and size-based flushing."""
  _periodic_flusher = None  # type: Optional[PeriodicThread]

  def __init__(
      self,
      close_callback=None,  # type: Optional[Callable[[bytes], None]]
      flush_callback=None,  # type: Optional[Callable[[bytes], None]]
      size_flush_threshold=_DEFAULT_SIZE_FLUSH_THRESHOLD,  # type: int
      time_flush_threshold_ms=_DEFAULT_TIME_FLUSH_THRESHOLD_MS  # type: int
  ):
    # type: (...) -> None
    super().__init__(close_callback, flush_callback, size_flush_threshold)
    assert time_flush_threshold_ms > 0
    self._time_flush_threshold_ms = time_flush_threshold_ms
    self._flush_lock = threading.Lock()
    self._schedule_lock = threading.Lock()
    self._closed = False
    self._schedule_periodic_flush()

  def flush(self):
    # type: () -> None
    with self._flush_lock:
      super().flush()

  def close(self):
    # type: () -> None
    with self._schedule_lock:
      self._closed = True
      if self._periodic_flusher:
        self._periodic_flusher.cancel()
        self._periodic_flusher = None
    super().close()

  def _schedule_periodic_flush(self):
    # type: () -> None
    def _flush():
      # type: () -> None
      with self._schedule_lock:
        if not self._closed:
          self.flush()

    self._periodic_flusher = PeriodicThread(
        self._time_flush_threshold_ms / 1000.0, _flush)
    self._periodic_flusher.daemon = True
    self._periodic_flusher.start()


class PeriodicThread(threading.Thread):
  """Calls a function periodically at the specified interval, in seconds."""

  def __init__(
      self,
      interval,  # type: float
      function,  # type: Callable
      args=None,  # type: Optional[Iterable]
      kwargs=None  # type: Optional[Mapping[str, Any]]
  ):
    # type: (...) -> None
    threading.Thread.__init__(self)
    self._interval = interval
    self._function = function
    self._args = args if args is not None else []
    self._kwargs = kwargs if kwargs is not None else {}
    self._finished = threading.Event()

  def run(self):
    # type: () -> None
    next_call = time.time() + self._interval
    while not self._finished.wait(next_call - time.time()):
      next_call = next_call + self._interval
      self._function(*self._args, **self._kwargs)

  def cancel(self):
    # type: () -> None

    """Stop the thread if it hasn't finished yet."""
    self._finished.set()
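
# A minimal usage sketch for PeriodicThread (kept as a comment; the callback
# is hypothetical). cancel() prevents any further ticks.
#
#   ticker = PeriodicThread(0.5, lambda: _LOGGER.debug('tick'))
#   ticker.daemon = True
#   ticker.start()
#   ...
#   ticker.cancel()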


class DataChannel(metaclass=abc.ABCMeta):
  """Represents a channel for reading and writing data over the data plane.

  Read data and timers from this channel with the input_elements method::

    for elements_data in data_channel.input_elements(
        instruction_id, transform_ids, timers):
      [process elements_data]

  Write data to this channel using the output_stream method::

    out1 = data_channel.output_stream(instruction_id, transform_id)
    out1.write(...)
    out1.close()

  Write timers to this channel using the output_timer_stream method::

    out1 = data_channel.output_timer_stream(instruction_id,
                                            transform_id,
                                            timer_family_id)
    out1.write(...)
    out1.close()

  When all data and timers for all instructions have been written, close the
  channel::

    data_channel.close()
  """
  @abc.abstractmethod
  def input_elements(
      self,
      instruction_id,  # type: str
      expected_inputs,  # type: Collection[Union[str, Tuple[str, str]]]
      abort_callback=None  # type: Optional[Callable[[], bool]]
  ):
    # type: (...) -> Iterator[DataOrTimers]

    """Returns an iterable of all Elements.Data and Elements.Timers bundles
    for instruction_id.

    This iterable terminates only once the full set of data has been received
    for each of the expected transforms. It may block waiting for more data.

    Args:
        instruction_id: which instruction the results must belong to
        expected_inputs: which transforms to wait on for completion
        abort_callback: a callback, invoked while blocked waiting for data,
            that returns whether to abort before all the data is consumed
    """
    raise NotImplementedError(type(self))

  @abc.abstractmethod
  def output_stream(
      self,
      instruction_id,  # type: str
      transform_id  # type: str
  ):
    # type: (...) -> ClosableOutputStream

    """Returns an output stream writing elements to transform_id.

    Args:
        instruction_id: which instruction this stream belongs to
        transform_id: the transform_id of the returned stream
    """
    raise NotImplementedError(type(self))

  @abc.abstractmethod
  def output_timer_stream(
      self,
      instruction_id,  # type: str
      transform_id,  # type: str
      timer_family_id  # type: str
  ):
    # type: (...) -> ClosableOutputStream

    """Returns an output stream writing timers to transform_id.

    Args:
        instruction_id: which instruction this stream belongs to
        transform_id: the transform_id of the returned stream
        timer_family_id: the timer family of the written timers
    """
    raise NotImplementedError(type(self))

  @abc.abstractmethod
  def close(self):
    # type: () -> None

    """Closes this channel, indicating that all data has been written.

    Data can continue to be read.

    If this channel is shared by many instructions, this should only be called
    on worker shutdown.
    """
    raise NotImplementedError(type(self))


class InMemoryDataChannel(DataChannel):
  """An in-memory implementation of a DataChannel.

  This channel is two-sided.  What is written to one side is read by the
  other.  The inverse() method returns the other side of this instance.
  """
  def __init__(self, inverse=None, data_buffer_time_limit_ms=0):
    # type: (Optional[InMemoryDataChannel], int) -> None
    self._inputs = []  # type: List[DataOrTimers]
    self._data_buffer_time_limit_ms = data_buffer_time_limit_ms
    self._inverse = inverse or InMemoryDataChannel(
        self, data_buffer_time_limit_ms=data_buffer_time_limit_ms)

  def inverse(self):
    # type: () -> InMemoryDataChannel
    return self._inverse

  def input_elements(
      self,
      instruction_id,  # type: str
      unused_expected_inputs,  # type: Any
      abort_callback=None  # type: Optional[Callable[[], bool]]
  ):
    # type: (...) -> Iterator[DataOrTimers]
    other_inputs = []
    for element in self._inputs:
      if element.instruction_id == instruction_id:
        if isinstance(element, beam_fn_api_pb2.Elements.Timers):
          if not element.is_last:
            yield element
        if isinstance(element, beam_fn_api_pb2.Elements.Data):
          if element.data or element.is_last:
            yield element
      else:
        other_inputs.append(element)
    self._inputs = other_inputs

  def output_timer_stream(
      self,
      instruction_id,  # type: str
      transform_id,  # type: str
      timer_family_id  # type: str
  ):
    # type: (...) -> ClosableOutputStream
    def add_to_inverse_output(timer):
      # type: (bytes) -> None
      if timer:
        self._inverse._inputs.append(
            beam_fn_api_pb2.Elements.Timers(
                instruction_id=instruction_id,
                transform_id=transform_id,
                timer_family_id=timer_family_id,
                timers=timer,
                is_last=False))

    def close_stream(timer):
      # type: (bytes) -> None
      add_to_inverse_output(timer)
      self._inverse._inputs.append(
          beam_fn_api_pb2.Elements.Timers(
              instruction_id=instruction_id,
              transform_id=transform_id,
              timer_family_id='',
              is_last=True))

    # close_stream is the close callback and add_to_inverse_output the flush
    # callback, matching ClosableOutputStream.create's signature.
    return ClosableOutputStream.create(
        close_stream, add_to_inverse_output, self._data_buffer_time_limit_ms)

  def output_stream(self, instruction_id, transform_id):
    # type: (str, str) -> ClosableOutputStream
    def add_to_inverse_output(data):
      # type: (bytes) -> None
      self._inverse._inputs.append(  # pylint: disable=protected-access
          beam_fn_api_pb2.Elements.Data(
              instruction_id=instruction_id,
              transform_id=transform_id,
              data=data))

    return ClosableOutputStream.create(
        add_to_inverse_output,
        add_to_inverse_output,
        self._data_buffer_time_limit_ms)

  def close(self):
    # type: () -> None
    pass
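
# A minimal round-trip sketch (kept as a comment; the ids are hypothetical):
# bytes written to one side of an InMemoryDataChannel are read back from its
# inverse().
#
#   channel = InMemoryDataChannel()
#   out = channel.output_stream('instruction-1', 'transform-a')
#   out.write(b'encoded-elements')
#   out.close()
#   for element in channel.inverse().input_elements('instruction-1', None):
#     ...  # element is a beam_fn_api_pb2.Elements.Data message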


class _GrpcDataChannel(DataChannel):
  """Base class for implementing a BeamFnData-based DataChannel."""

  _WRITES_FINISHED = object()

  def __init__(self, data_buffer_time_limit_ms=0):
    # type: (int) -> None
    self._data_buffer_time_limit_ms = data_buffer_time_limit_ms
    self._to_send = queue.Queue()  # type: queue.Queue[DataOrTimers]
    self._received = collections.defaultdict(
        lambda: queue.Queue(maxsize=5)
    )  # type: DefaultDict[str, queue.Queue[DataOrTimers]]

    # Keep a cache of completed instructions. Data for completed instructions
    # must be discarded. See input_elements() and _clean_receiving_queue().
    # OrderedDict is used as a FIFO set, with every value being `True`.
    self._cleaned_instruction_ids = collections.OrderedDict(
    )  # type: collections.OrderedDict[str, bool]

    self._receive_lock = threading.Lock()
    self._reads_finished = threading.Event()
    self._closed = False
    self._exception = None  # type: Optional[Exception]

  def close(self):
    # type: () -> None
    self._to_send.put(self._WRITES_FINISHED)  # type: ignore[arg-type]
    self._closed = True

  def wait(self, timeout=None):
    # type: (Optional[int]) -> None
    self._reads_finished.wait(timeout)

  def _receiving_queue(self, instruction_id):
    # type: (str) -> Optional[queue.Queue[DataOrTimers]]

    """
    Gets or creates the queue for an instruction_id, or returns None if the
    instruction_id has already been cleaned up. This is best-effort, as we
    track only a limited number of cleaned-up instructions.
    """
    with self._receive_lock:
      if instruction_id in self._cleaned_instruction_ids:
        return None
      return self._received[instruction_id]

  def _clean_receiving_queue(self, instruction_id):
    # type: (str) -> None

    """
    Removes the queue and adds the instruction_id to the cleaned-up list. The
    instruction_id cannot be reused for a new queue.
    """
    with self._receive_lock:
      self._received.pop(instruction_id)
      self._cleaned_instruction_ids[instruction_id] = True
      while len(self._cleaned_instruction_ids) > _MAX_CLEANED_INSTRUCTIONS:
        self._cleaned_instruction_ids.popitem(last=False)

  def input_elements(
      self,
      instruction_id,  # type: str
      expected_inputs,  # type: Collection[Union[str, Tuple[str, str]]]
      abort_callback=None  # type: Optional[Callable[[], bool]]
  ):
    # type: (...) -> Iterator[DataOrTimers]

    """
    Generator to retrieve elements for an instruction_id.

    input_elements should be called only once per instruction_id.

    Args:
      instruction_id(str): instruction_id for which data is read
      expected_inputs(collection): expected inputs, including both data and
        timers.
    """
    received = self._receiving_queue(instruction_id)
    if received is None:
      raise RuntimeError('Instruction cleaned up already %s' % instruction_id)
    done_inputs = set()  # type: Set[Union[str, Tuple[str, str]]]
    abort_callback = abort_callback or (lambda: False)
    try:
      while len(done_inputs) < len(expected_inputs):
        try:
          element = received.get(timeout=1)
        except queue.Empty:
          if self._closed:
            raise RuntimeError('Channel closed prematurely.')
          if abort_callback():
            return
          if self._exception:
            raise self._exception from None
        else:
          if isinstance(element, beam_fn_api_pb2.Elements.Timers):
            if element.is_last:
              done_inputs.add((element.transform_id, element.timer_family_id))
            else:
              yield element
          elif isinstance(element, beam_fn_api_pb2.Elements.Data):
            if element.is_last:
              done_inputs.add(element.transform_id)
            else:
              assert element.transform_id not in done_inputs
              yield element
          else:
            raise ValueError('Unexpected input element type %s' % type(element))
    finally:
      # Instruction ids are not reusable, so clean up the queue once we are
      # done with an instruction_id.
      self._clean_receiving_queue(instruction_id)

  def output_stream(self, instruction_id, transform_id):
    # type: (str, str) -> ClosableOutputStream
    def add_to_send_queue(data):
      # type: (bytes) -> None
      if data:
        self._to_send.put(
            beam_fn_api_pb2.Elements.Data(
                instruction_id=instruction_id,
                transform_id=transform_id,
                data=data))

    def close_callback(data):
      # type: (bytes) -> None
      add_to_send_queue(data)
      # End of stream marker.
      self._to_send.put(
          beam_fn_api_pb2.Elements.Data(
              instruction_id=instruction_id,
              transform_id=transform_id,
              is_last=True))

    return ClosableOutputStream.create(
        close_callback, add_to_send_queue, self._data_buffer_time_limit_ms)

  def output_timer_stream(
      self,
      instruction_id,  # type: str
      transform_id,  # type: str
      timer_family_id  # type: str
  ):
    # type: (...) -> ClosableOutputStream
    def add_to_send_queue(timer):
      # type: (bytes) -> None
      if timer:
        self._to_send.put(
            beam_fn_api_pb2.Elements.Timers(
                instruction_id=instruction_id,
                transform_id=transform_id,
                timer_family_id=timer_family_id,
                timers=timer,
                is_last=False))

    def close_callback(timer):
      # type: (bytes) -> None
      add_to_send_queue(timer)
      self._to_send.put(
          beam_fn_api_pb2.Elements.Timers(
              instruction_id=instruction_id,
              transform_id=transform_id,
              timer_family_id=timer_family_id,
              is_last=True))

    return ClosableOutputStream.create(
        close_callback, add_to_send_queue, self._data_buffer_time_limit_ms)

  def _write_outputs(self):
    # type: () -> Iterator[beam_fn_api_pb2.Elements]
    stream_done = False
    while not stream_done:
      streams = [self._to_send.get()]
      try:
        # Coalesce up to 100 other items.
        for _ in range(100):
          streams.append(self._to_send.get_nowait())
      except queue.Empty:
        pass
      if streams[-1] is self._WRITES_FINISHED:
        stream_done = True
        streams.pop()
      if streams:
        data_stream = []
        timer_stream = []
        for stream in streams:
          if isinstance(stream, beam_fn_api_pb2.Elements.Timers):
            timer_stream.append(stream)
          elif isinstance(stream, beam_fn_api_pb2.Elements.Data):
            data_stream.append(stream)
          else:
            raise ValueError('Unexpected output element type %s' % type(stream))
        yield beam_fn_api_pb2.Elements(data=data_stream, timers=timer_stream)

  def _read_inputs(self, elements_iterator):
    # type: (Iterable[beam_fn_api_pb2.Elements]) -> None

    next_discard_log_time = 0  # type: float

    def _put_queue(instruction_id, element):
      # type: (str, Union[beam_fn_api_pb2.Elements.Data, beam_fn_api_pb2.Elements.Timers]) -> None

      """
      Puts element to the queue of the instruction_id, or discards it if the
      instruction_id is already cleaned up.
      """
      nonlocal next_discard_log_time
      start_time = time.time()
      next_waiting_log_time = start_time + 300
      while True:
        input_queue = self._receiving_queue(instruction_id)
        if input_queue is None:
          current_time = time.time()
          if next_discard_log_time <= current_time:
            # Log every 10 seconds across all _put_queue calls.
            _LOGGER.info(
                'Discard inputs for cleaned up instruction: %s', instruction_id)
            next_discard_log_time = current_time + 10
          return
        try:
          input_queue.put(element, timeout=1)
          return
        except queue.Full:
          current_time = time.time()
          if next_waiting_log_time <= current_time:
            # Log every 5 mins in each _put_queue call.
            _LOGGER.info(
                'Waiting on input queue of instruction: %s for %.2f seconds',
                instruction_id,
                current_time - start_time)
            next_waiting_log_time = current_time + 300

    try:
      for elements in elements_iterator:
        for timer in elements.timers:
          _put_queue(timer.instruction_id, timer)
        for data in elements.data:
          _put_queue(data.instruction_id, data)
    except Exception as e:
      if not self._closed:
        _LOGGER.exception('Failed to read inputs in the data plane.')
        self._exception = e
        raise
    finally:
      self._closed = True
      self._reads_finished.set()

  def set_inputs(self, elements_iterator):
    # type: (Iterable[beam_fn_api_pb2.Elements]) -> None
    reader = threading.Thread(
        target=lambda: self._read_inputs(elements_iterator),
        name='read_grpc_client_inputs')
    reader.daemon = True
    reader.start()


class GrpcClientDataChannel(_GrpcDataChannel):
  """A DataChannel wrapping the client side of a BeamFnData connection."""

  def __init__(
      self,
      data_stub,  # type: beam_fn_api_pb2_grpc.BeamFnDataStub
      data_buffer_time_limit_ms=0  # type: int
  ):
    # type: (...) -> None
    super().__init__(data_buffer_time_limit_ms)
    self.set_inputs(data_stub.Data(self._write_outputs()))


class BeamFnDataServicer(beam_fn_api_pb2_grpc.BeamFnDataServicer):
  """Implementation of BeamFnDataServicer for any number of clients."""
  def __init__(
      self,
      data_buffer_time_limit_ms=0  # type: int
  ):
    self._lock = threading.Lock()
    self._connections_by_worker_id = collections.defaultdict(
        lambda: _GrpcDataChannel(data_buffer_time_limit_ms)
    )  # type: DefaultDict[str, _GrpcDataChannel]

  def get_conn_by_worker_id(self, worker_id):
    # type: (str) -> _GrpcDataChannel
    with self._lock:
      return self._connections_by_worker_id[worker_id]

  def Data(
      self,
      elements_iterator,  # type: Iterable[beam_fn_api_pb2.Elements]
      context  # type: Any
  ):
    # type: (...) -> Iterator[beam_fn_api_pb2.Elements]
    worker_id = dict(context.invocation_metadata())['worker_id']
    data_conn = self.get_conn_by_worker_id(worker_id)
    data_conn.set_inputs(elements_iterator)
    for elements in data_conn._write_outputs():
      yield elements
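
# A minimal server-side sketch (kept as a comment; the port is hypothetical,
# and `from concurrent import futures` is assumed) showing how this servicer
# is attached to a gRPC server via the generated registration function:
#
#   server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
#   beam_fn_api_pb2_grpc.add_BeamFnDataServicer_to_server(
#       BeamFnDataServicer(), server)
#   server.add_insecure_port('[::]:50000')
#   server.start()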


class DataChannelFactory(metaclass=abc.ABCMeta):
  """An abstract factory for creating ``DataChannel``s."""
  @abc.abstractmethod
  def create_data_channel(self, remote_grpc_port):
    # type: (beam_fn_api_pb2.RemoteGrpcPort) -> GrpcClientDataChannel

    """Returns a ``DataChannel`` from the given RemoteGrpcPort."""
    raise NotImplementedError(type(self))

  @abc.abstractmethod
  def create_data_channel_from_url(self, url):
    # type: (str) -> Optional[GrpcClientDataChannel]

    """Returns a ``DataChannel`` from the given url."""
    raise NotImplementedError(type(self))

  @abc.abstractmethod
  def close(self):
    # type: () -> None

    """Close all channels that this factory owns."""
    raise NotImplementedError(type(self))


class GrpcClientDataChannelFactory(DataChannelFactory):
  """A factory for ``GrpcClientDataChannel``.

  Caches the created channels by ``data descriptor url``.
  """

  def __init__(
      self,
      credentials=None,  # type: Any
      worker_id=None,  # type: Optional[str]
      data_buffer_time_limit_ms=0  # type: int
  ):
    # type: (...) -> None
    self._data_channel_cache = {}  # type: Dict[str, GrpcClientDataChannel]
    self._lock = threading.Lock()
    self._credentials = None
    self._worker_id = worker_id
    self._data_buffer_time_limit_ms = data_buffer_time_limit_ms
    if credentials is not None:
      _LOGGER.info('Using secure channel creds.')
      self._credentials = credentials

  def create_data_channel_from_url(self, url):
    # type: (str) -> Optional[GrpcClientDataChannel]
    if not url:
      return None
    if url not in self._data_channel_cache:
      with self._lock:
        if url not in self._data_channel_cache:
          _LOGGER.info('Creating client data channel for %s', url)
          # Options to have no limits (-1) on the size of the messages
          # received or sent over the data plane. The actual buffer size
          # is controlled in a layer above.
          channel_options = [("grpc.max_receive_message_length", -1),
                             ("grpc.max_send_message_length", -1),
                             ("grpc.service_config", _GRPC_SERVICE_CONFIG)]
          grpc_channel = None
          if self._credentials is None:
            grpc_channel = GRPCChannelFactory.insecure_channel(
                url, options=channel_options)
          else:
            grpc_channel = GRPCChannelFactory.secure_channel(
                url, self._credentials, options=channel_options)
          # Add workerId to the grpc channel
          grpc_channel = grpc.intercept_channel(
              grpc_channel, WorkerIdInterceptor(self._worker_id))
          self._data_channel_cache[url] = GrpcClientDataChannel(
              beam_fn_api_pb2_grpc.BeamFnDataStub(grpc_channel),
              self._data_buffer_time_limit_ms)

    return self._data_channel_cache[url]

  def create_data_channel(self, remote_grpc_port):
    # type: (beam_fn_api_pb2.RemoteGrpcPort) -> GrpcClientDataChannel
    url = remote_grpc_port.api_service_descriptor.url
    # TODO(https://github.com/apache/beam/issues/19737): this can return None
    #  if url is falsey, but this seems incorrect, as code that calls this
    #  method seems to always expect non-Optional values.
    return self.create_data_channel_from_url(url)  # type: ignore[return-value]

  def close(self):
    # type: () -> None
    _LOGGER.info('Closing all cached grpc data channels.')
    for _, channel in self._data_channel_cache.items():
      channel.close()
    self._data_channel_cache.clear()
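
# A minimal client-side sketch (kept as a comment; the url and ids are
# hypothetical). The factory caches one channel per url, and close() tears
# down every cached channel.
#
#   factory = GrpcClientDataChannelFactory(worker_id='worker-0')
#   channel = factory.create_data_channel_from_url('localhost:50000')
#   out = channel.output_stream('instruction-1', 'transform-a')
#   out.write(b'encoded-elements')
#   out.close()
#   factory.close()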


class InMemoryDataChannelFactory(DataChannelFactory):
  """A singleton factory for ``InMemoryDataChannel``."""
  def __init__(self, in_memory_data_channel):
    # type: (GrpcClientDataChannel) -> None
    self._in_memory_data_channel = in_memory_data_channel

  def create_data_channel(self, unused_remote_grpc_port):
    # type: (beam_fn_api_pb2.RemoteGrpcPort) -> GrpcClientDataChannel
    return self._in_memory_data_channel

  def create_data_channel_from_url(self, url):
    # type: (Any) -> GrpcClientDataChannel
    return self._in_memory_data_channel

  def close(self):
    # type: () -> None
    pass