#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Utilities for handling side inputs."""

# pytype: skip-file

import logging
import queue
import threading
import traceback
from collections import abc

from apache_beam.coders import observable
from apache_beam.io import iobase
from apache_beam.runners.worker import opcounters
from apache_beam.transforms import window
from apache_beam.utils.sentinel import Sentinel

# Maximum number of reader threads for reading side input sources, per side
# input. Fewer threads are started when there are fewer sources than this.
MAX_SOURCE_READER_THREADS = 15

# Number of slots for elements in the side input element queue. This value is
# intentionally smaller than MAX_SOURCE_READER_THREADS to reduce the memory
# pressure of holding potentially-large elements in memory. A reader thread
# blocks while the queue is full, still holding the element it is trying to
# enqueue. Each side input can therefore have at most (number of reader
# threads + ELEMENT_QUEUE_SIZE) elements pending in memory, i.e. at most
# MAX_SOURCE_READER_THREADS + ELEMENT_QUEUE_SIZE.
ELEMENT_QUEUE_SIZE = 10

# Special element value sentinel for signaling reader state.
READER_THREAD_IS_DONE_SENTINEL = Sentinel.sentinel

# Used to efficiently window the values of non-windowed side inputs.
_globally_windowed = window.GlobalWindows.windowed_value(None).with_value

_LOGGER = logging.getLogger(__name__)


class PrefetchingSourceSetIterable(object):
  """Value iterator that reads concurrently from a set of sources.

  The constructor starts the reader threads. Each thread repeatedly takes a
  source from ``sources_queue`` and pushes that source's values, windowed,
  onto the bounded ``element_queue``. ``__iter__`` consumes that queue. An
  instance can be iterated only once.
  """
  def __init__(
      self,
      sources,
      max_reader_threads=MAX_SOURCE_READER_THREADS,
      read_counter=None,
      element_counter=None):
    """Creates the iterable and immediately starts its reader threads.

    Args:
      sources: Sized collection of sources to read. Each is either an
        ``iobase.BoundedSource`` or an object whose ``reader()`` returns a
        context-managed, iterable reader.
      max_reader_threads: Upper bound on the number of reader threads. Must be
        at least 1. At most ``len(sources)`` threads are started.
      read_counter: Optional counter. It is entered as a context manager
        around each blocking read from the element queue and receives byte
        counts from observable readers. A no-op counter is used when omitted.
      element_counter: Optional counter updated for every yielded element.

    Raises:
      ValueError: If ``max_reader_threads`` is less than 1.
    """
    if max_reader_threads < 1:
      # With zero reader threads no source would ever be read, so iteration
      # would silently produce nothing.
      raise ValueError(
          'max_reader_threads must be at least 1, got %r' %
          (max_reader_threads, ))
    self.sources = sources
    self.num_reader_threads = min(max_reader_threads, len(self.sources))

    # Queue for sources that are to be read.
    self.sources_queue = queue.Queue()
    for source in sources:
      self.sources_queue.put(source)
    # Queue for elements that have been read. It is bounded so that readers
    # block instead of buffering an unbounded number of elements.
    self.element_queue = queue.Queue(ELEMENT_QUEUE_SIZE)
    # Queue for exceptions encountered in reader threads; to be rethrown.
    self.reader_exceptions = queue.Queue()
    # Whether we have already iterated; this iterable can only be used once.
    self.already_iterated = False
    # Set when a source reader fails or the consumer abandons iteration.
    # Reader threads poll this flag to stop early.
    self.has_errored = False

    self.read_counter = read_counter or opcounters.NoOpTransformIOCounter()
    self.element_counter = element_counter
    self.reader_threads = []
    self._start_reader_threads()

  def add_byte_counter(self, reader):
    """Adds byte counter observer to a side input reader.

    Args:
      reader: A reader that should inherit from ObservableMixin to have
        bytes tracked. Other readers are left untouched.
    """
    def update_bytes_read(record_size, is_record_size=False, **kwargs):
      # Only notifications flagged as record sizes count toward bytes read.
      if is_record_size:
        self.read_counter.add_bytes_read(record_size)

    if isinstance(reader, observable.ObservableMixin):
      reader.register_observer(update_bytes_read)

  def _start_reader_threads(self):
    """Starts ``num_reader_threads`` daemon threads running _reader_thread."""
    for _ in range(self.num_reader_threads):
      t = threading.Thread(target=self._reader_thread)
      t.daemon = True
      t.start()
      self.reader_threads.append(t)

  def _reader_thread(self):
    """Reads sources until none remain, pushing their values to the queue.

    Posting READER_THREAD_IS_DONE_SENTINEL is always this thread's last action,
    whether it finishes normally, stops early, or fails.
    """
    # pylint: disable=too-many-nested-blocks
    try:
      while True:
        # Only the hand-off of the next source is guarded. A queue.Empty
        # raised while reading a source is a real failure and must reach the
        # generic handler below instead of being mistaken for "no sources
        # left".
        try:
          source = self.sources_queue.get_nowait()
        except queue.Empty:
          return
        if isinstance(source, iobase.BoundedSource):
          for value in source.read(source.get_range_tracker(None, None)):
            if self.has_errored:
              # If any reader has errored, just return.
              return
            if isinstance(value, window.WindowedValue):
              self.element_queue.put(value)
            else:
              self.element_queue.put(_globally_windowed(value))
        else:
          # Native dataflow source.
          with source.reader() as reader:
            # The tracking of time spent reading and bytes read from side
            # inputs is kept behind an experiment flag to test performance
            # impact.
            self.add_byte_counter(reader)
            returns_windowed_values = reader.returns_windowed_values
            for value in reader:
              if self.has_errored:
                # If any reader has errored, just return.
                return
              if returns_windowed_values:
                self.element_queue.put(value)
              else:
                self.element_queue.put(_globally_windowed(value))
    except Exception as e:  # pylint: disable=broad-except
      _LOGGER.error(
          'Encountered exception in PrefetchingSourceSetIterable '
          'reader thread: %s',
          traceback.format_exc())
      # Enqueue the exception BEFORE raising the flag: once the consumer sees
      # has_errored it calls reader_exceptions.get(), which must not block.
      self.reader_exceptions.put(e)
      self.has_errored = True
    finally:
      self.element_queue.put(READER_THREAD_IS_DONE_SENTINEL)

  def __iter__(self):
    """Yields the elements produced by the reader threads in arrival order.

    This is a generator, so the single-use check runs when the returned
    iterator is first advanced, not when ``iter()`` is called.

    Raises:
      RuntimeError: If this instance has already been iterated.
      Exception: An exception recorded by a failing reader thread is
        re-raised here.
    """
    # pylint: disable=too-many-nested-blocks
    if self.already_iterated:
      raise RuntimeError(
          'Can only iterate once over PrefetchingSourceSetIterable instance.')
    self.already_iterated = True

    # The invariants during execution are:
    # 1) A worker thread always posts the sentinel as the last thing it does
    #    before exiting.
    # 2) We always wait for all sentinels and then join all threads.
    num_readers_finished = 0
    try:
      # Loop on the sentinel count rather than `while True`. With no reader
      # threads (no sources), nothing is ever put on element_queue, so an
      # unconditional get() would block forever.
      while num_readers_finished < self.num_reader_threads:
        try:
          # Only the blocking wait is attributed to the read counter. Time the
          # consumer spends on a yielded element is not counted.
          with self.read_counter:
            element = self.element_queue.get()
          if element is READER_THREAD_IS_DONE_SENTINEL:
            num_readers_finished += 1
          elif self.element_counter:
            self.element_counter.update_from(element)
            yield element
            self.element_counter.update_collect()
          else:
            yield element
        finally:
          if self.has_errored:
            raise self.reader_exceptions.get()
    except GeneratorExit:
      # The consumer stopped early; signal the readers to stop as well.
      self.has_errored = True
      raise
    finally:
      # Drain until every reader has posted its sentinel. This also unblocks
      # readers waiting on a full element_queue, so the joins below finish.
      while num_readers_finished < self.num_reader_threads:
        element = self.element_queue.get()
        if element is READER_THREAD_IS_DONE_SENTINEL:
          num_readers_finished += 1
      for t in self.reader_threads:
        t.join()


def get_iterator_fn_for_sources(
    sources,
    max_reader_threads=MAX_SOURCE_READER_THREADS,
    read_counter=None,
    element_counter=None):
  """Returns callable that returns iterator over elements for given sources.

  Each call to the returned callable builds a new PrefetchingSourceSetIterable
  over ``sources``, which starts its own reader threads.

  Args:
    sources: Sources to read; forwarded to PrefetchingSourceSetIterable.
    max_reader_threads: Upper bound on reader threads per iteration.
    read_counter: Optional read counter; forwarded unchanged.
    element_counter: Optional element counter; forwarded unchanged.

  Returns:
    A zero-argument callable returning an iterator over the elements.
  """
  def _inner():
    return iter(
        PrefetchingSourceSetIterable(
            sources,
            max_reader_threads=max_reader_threads,
            read_counter=read_counter,
            element_counter=element_counter))

  return _inner


class EmulatedIterable(abc.Iterable):
  """Emulates an iterable for a side input.

  Every ``iter()`` call invokes ``iterator_fn`` to obtain the iterator. This
  pairs with callables such as those returned by get_iterator_fn_for_sources.
  """
  def __init__(self, iterator_fn):
    # Zero-argument callable returning an iterator; presumably a fresh one on
    # each call (e.g. get_iterator_fn_for_sources) -- verify for other callers.
    self.iterator_fn = iterator_fn

  def __iter__(self):
    return self.iterator_fn()