github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/worker/sideinputs.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Utilities for handling side inputs."""
    19  
    20  # pytype: skip-file
    21  
    22  import logging
    23  import queue
    24  import threading
    25  import traceback
    26  from collections import abc
    27  
    28  from apache_beam.coders import observable
    29  from apache_beam.io import iobase
    30  from apache_beam.runners.worker import opcounters
    31  from apache_beam.transforms import window
    32  from apache_beam.utils.sentinel import Sentinel
    33  
# Maximum number of reader threads for reading side input sources, per side
# input.  The actual number of threads is capped by the number of sources.
MAX_SOURCE_READER_THREADS = 15

# Number of slots for elements in the side input element queue.  This value is
# intentionally smaller than MAX_SOURCE_READER_THREADS to reduce the memory
# pressure of holding potentially large elements in memory.  Each reader
# thread can additionally hold one element while it waits for a free slot, so
# the number of pending elements in memory can reach
# MAX_SOURCE_READER_THREADS + ELEMENT_QUEUE_SIZE.
ELEMENT_QUEUE_SIZE = 10

# Special element value posted on the element queue by a reader thread as the
# last thing it does; the consumer counts these to learn when all readers are
# done.
READER_THREAD_IS_DONE_SENTINEL = Sentinel.sentinel

# Used to efficiently window the values of non-windowed side inputs: a single
# globally windowed value is built once and its bound `with_value` method is
# reused for every element.
_globally_windowed = window.GlobalWindows.windowed_value(None).with_value

# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)
    52  
    53  
    54  class PrefetchingSourceSetIterable(object):
    55    """Value iterator that reads concurrently from a set of sources."""
    56    def __init__(
    57        self,
    58        sources,
    59        max_reader_threads=MAX_SOURCE_READER_THREADS,
    60        read_counter=None,
    61        element_counter=None):
    62      self.sources = sources
    63      self.num_reader_threads = min(max_reader_threads, len(self.sources))
    64  
    65      # Queue for sources that are to be read.
    66      self.sources_queue = queue.Queue()
    67      for source in sources:
    68        self.sources_queue.put(source)
    69      # Queue for elements that have been read.
    70      self.element_queue = queue.Queue(ELEMENT_QUEUE_SIZE)
    71      # Queue for exceptions encountered in reader threads; to be rethrown.
    72      self.reader_exceptions = queue.Queue()
    73      # Whether we have already iterated; this iterable can only be used once.
    74      self.already_iterated = False
    75      # Whether an error was encountered in any source reader.
    76      self.has_errored = False
    77  
    78      self.read_counter = read_counter or opcounters.NoOpTransformIOCounter()
    79      self.element_counter = element_counter
    80      self.reader_threads = []
    81      self._start_reader_threads()
    82  
    83    def add_byte_counter(self, reader):
    84      """Adds byte counter observer to a side input reader.
    85  
    86      Args:
    87        reader: A reader that should inherit from ObservableMixin to have
    88          bytes tracked.
    89      """
    90      def update_bytes_read(record_size, is_record_size=False, **kwargs):
    91        # Let the reader report block size.
    92        if is_record_size:
    93          self.read_counter.add_bytes_read(record_size)
    94  
    95      if isinstance(reader, observable.ObservableMixin):
    96        reader.register_observer(update_bytes_read)
    97  
    98    def _start_reader_threads(self):
    99      for _ in range(0, self.num_reader_threads):
   100        t = threading.Thread(target=self._reader_thread)
   101        t.daemon = True
   102        t.start()
   103        self.reader_threads.append(t)
   104  
   105    def _reader_thread(self):
   106      # pylint: disable=too-many-nested-blocks
   107      try:
   108        while True:
   109          try:
   110            source = self.sources_queue.get_nowait()
   111            if isinstance(source, iobase.BoundedSource):
   112              for value in source.read(source.get_range_tracker(None, None)):
   113                if self.has_errored:
   114                  # If any reader has errored, just return.
   115                  return
   116                if isinstance(value, window.WindowedValue):
   117                  self.element_queue.put(value)
   118                else:
   119                  self.element_queue.put(_globally_windowed(value))
   120            else:
   121              # Native dataflow source.
   122              with source.reader() as reader:
   123                # The tracking of time spend reading and bytes read from side
   124                # inputs is kept behind an experiment flag to test performance
   125                # impact.
   126                self.add_byte_counter(reader)
   127                returns_windowed_values = reader.returns_windowed_values
   128                for value in reader:
   129                  if self.has_errored:
   130                    # If any reader has errored, just return.
   131                    return
   132                  if returns_windowed_values:
   133                    self.element_queue.put(value)
   134                  else:
   135                    self.element_queue.put(_globally_windowed(value))
   136          except queue.Empty:
   137            return
   138      except Exception as e:  # pylint: disable=broad-except
   139        _LOGGER.error(
   140            'Encountered exception in PrefetchingSourceSetIterable '
   141            'reader thread: %s',
   142            traceback.format_exc())
   143        self.reader_exceptions.put(e)
   144        self.has_errored = True
   145      finally:
   146        self.element_queue.put(READER_THREAD_IS_DONE_SENTINEL)
   147  
   148    def __iter__(self):
   149      # pylint: disable=too-many-nested-blocks
   150      if self.already_iterated:
   151        raise RuntimeError(
   152            'Can only iterate once over PrefetchingSourceSetIterable instance.')
   153      self.already_iterated = True
   154  
   155      # The invariants during execution are:
   156      # 1) A worker thread always posts the sentinel as the last thing it does
   157      #    before exiting.
   158      # 2) We always wait for all sentinels and then join all threads.
   159      num_readers_finished = 0
   160      try:
   161        while True:
   162          try:
   163            with self.read_counter:
   164              element = self.element_queue.get()
   165            if element is READER_THREAD_IS_DONE_SENTINEL:
   166              num_readers_finished += 1
   167              if num_readers_finished == self.num_reader_threads:
   168                return
   169            else:
   170              if self.element_counter:
   171                self.element_counter.update_from(element)
   172                yield element
   173                self.element_counter.update_collect()
   174              else:
   175                yield element
   176          finally:
   177            if self.has_errored:
   178              raise self.reader_exceptions.get()
   179      except GeneratorExit:
   180        self.has_errored = True
   181        raise
   182      finally:
   183        while num_readers_finished < self.num_reader_threads:
   184          element = self.element_queue.get()
   185          if element is READER_THREAD_IS_DONE_SENTINEL:
   186            num_readers_finished += 1
   187        for t in self.reader_threads:
   188          t.join()
   189  
   190  
   191  def get_iterator_fn_for_sources(
   192      sources,
   193      max_reader_threads=MAX_SOURCE_READER_THREADS,
   194      read_counter=None,
   195      element_counter=None):
   196    """Returns callable that returns iterator over elements for given sources."""
   197    def _inner():
   198      return iter(
   199          PrefetchingSourceSetIterable(
   200              sources,
   201              max_reader_threads=max_reader_threads,
   202              read_counter=read_counter,
   203              element_counter=element_counter))
   204  
   205    return _inner
   206  
   207  
   208  class EmulatedIterable(abc.Iterable):
   209    """Emulates an iterable for a side input."""
   210    def __init__(self, iterator_fn):
   211      self.iterator_fn = iterator_fn
   212  
   213    def __iter__(self):
   214      return self.iterator_fn()