#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""For internal use only; no backwards-compatibility guarantees.

Concat Source, which reads the union of several other sources.
"""
# pytype: skip-file

import bisect
import threading

from apache_beam.io import iobase


class ConcatSource(iobase.BoundedSource):
  """For internal use only; no backwards-compatibility guarantees.

  A ``BoundedSource`` that can group a set of ``BoundedSources``.

  Primarily for internal use, use the ``apache_beam.Flatten`` transform
  to create the union of several reads.
  """
  def __init__(self, sources):
    """Initializes ``ConcatSource``.

    Args:
      sources: an iterable whose elements are either ``iobase.SourceBundle``
        instances or plain sources. A plain source is wrapped in a
        ``SourceBundle`` whose weight, start position and stop position are
        all ``None``.
    """
    self._source_bundles = [
        source if isinstance(source, iobase.SourceBundle) else
        iobase.SourceBundle(None, source, None, None) for source in sources
    ]

  @property
  def sources(self):
    """The wrapped sub-sources, in order, without their bundle metadata."""
    return [s.source for s in self._source_bundles]

  def estimate_size(self):
    """Returns the sum of the sub-sources' ``estimate_size()`` values."""
    return sum(s.source.estimate_size() for s in self._source_bundles)

  def split(
      self, desired_bundle_size=None, start_position=None, stop_position=None):
    """Splits every sub-source and yields all of the resulting bundles.

    Args:
      desired_bundle_size: passed unchanged to each sub-source's ``split``.
      start_position: must be falsy (normally ``None``); multi-level initial
        splitting is not supported.
      stop_position: must be falsy (normally ``None``), as above.

    Yields:
      The bundles produced by each sub-source's ``split`` (restricted to that
      sub-bundle's own start/stop positions), in sub-source order.

    Raises:
      ValueError: if a truthy ``start_position`` or ``stop_position`` is given.
    """
    if start_position or stop_position:
      raise ValueError(
          'Multi-level initial splitting is not supported. Expected start and '
          'stop positions to be None. Received %r and %r respectively.' %
          (start_position, stop_position))

    for source in self._source_bundles:
      # We assume all sub-sources to produce bundles that specify weight using
      # the same unit. For example, all sub-sources may specify the size in
      # bytes as their weight.
      yield from source.source.split(
          desired_bundle_size, source.start_position, source.stop_position)

  def get_range_tracker(self, start_position=None, stop_position=None):
    """Returns a ``ConcatRangeTracker`` over the given range of sub-sources.

    Positions are ``(source_index, source_position)`` tuples. The defaults,
    ``(0, None)`` and ``(len(sources), None)``, cover the whole concatenation.
    """
    if start_position is None:
      start_position = (0, None)
    if stop_position is None:
      stop_position = (len(self._source_bundles), None)
    return ConcatRangeTracker(
        start_position, stop_position, self._source_bundles)

  def read(self, range_tracker):
    """Reads the records of every sub-source within ``range_tracker``'s range.

    Each sub-source is claimed with ``try_claim((index, None))`` before it is
    read. Reading stops at the first sub-source that can no longer be claimed,
    e.g. because a dynamic split moved the end of the range.

    Args:
      range_tracker: presumably the tracker returned by
        ``get_range_tracker``; it must provide ``sub_range_tracker``.

    Yields:
      The records produced by each sub-source's ``read``, in order.
    """
    start_source, _ = range_tracker.start_position()
    stop_source, stop_pos = range_tracker.stop_position()
    if stop_pos is not None:
      # The range ends part-way through `stop_source`, so that source must
      # be read as well.
      stop_source += 1
    for source_ix in range(start_source, stop_source):
      if not range_tracker.try_claim((source_ix, None)):
        break
      yield from self._source_bundles[source_ix].source.read(
          range_tracker.sub_range_tracker(source_ix))

  def default_output_coder(self):
    """Returns the first sub-source's coder, or the base default if empty."""
    if self._source_bundles:
      # Getting coder from the first sub-sources. This assumes all sub-sources
      # to produce the same coder.
      return self._source_bundles[0].source.default_output_coder()
    else:
      return super().default_output_coder()


class ConcatRangeTracker(iobase.RangeTracker):
  """For internal use only; no backwards-compatibility guarantees.

  Range tracker for ConcatSource"""
  def __init__(self, start, end, source_bundles):
    """Initializes ``ConcatRangeTracker``

    Args:
      start: start position, a tuple of (source_index, source_position)
      end: end position, a tuple of (source_index, source_position)
      source_bundles: the list of source bundles in the ConcatSource
    """
    super().__init__()
    self._start = start
    self._end = end
    self._source_bundles = source_bundles
    self._lock = threading.RLock()
    # Lazily-initialized list of RangeTrackers corresponding to each source.
    self._range_trackers = [None] * len(source_bundles)
    # The currently-being-iterated-over (and latest claimed) source.
    self._claimed_source_ix = self._start[0]
    # Now compute cumulative progress through the sources for converting
    # between global fractions and fractions within specific sources.
    # TODO(robertwb): Implement fraction-at-position to properly scale
    # partial start and end sources.
    # Note, however, that in practice splits are typically on source
    # boundaries anyways.
    last = end[0] if end[1] is None else end[0] + 1
    # One entry per source boundary, len(source_bundles) + 1 in total:
    # - zeros for the boundaries before the range (indices 0..start[0]-1);
    # - the running totals for sources start[0]..last-1 (indices
    #   start[0]..last);
    # - ones for every boundary after the range (indices last+1..len).
    self._cumulative_weights = (
        [0] * start[0] +
        self._compute_cumulative_weights(source_bundles[start[0]:last]) +
        [1] * (len(source_bundles) - last))

  @staticmethod
  def _compute_cumulative_weights(source_bundles):
    """Returns the approximate fraction of work done at each source boundary.

    Args:
      source_bundles: the contiguous run of source bundles to weigh.

    Returns:
      A list of ``len(source_bundles) + 1`` floats that starts at 0 and ends
      at 1. Entry ``k`` is (approximately) the fraction of the total weight
      held by bundles ``0..k-1``, and adjacent entries are never equal.
    """
    # Two adjacent sources must differ so that they can be uniquely
    # identified by a single global fraction. Let min_diff be the
    # smallest allowable difference between sources.
    min_diff = 1e-5
    # For the computation below, we need weights for all sources.
    # Substitute average weights for those whose weights are
    # unspecified (or 1.0 for everything if none are known).
    known = [s.weight for s in source_bundles if s.weight is not None]
    avg = sum(known) / len(known) if known else 1.0
    # Note: a weight of 0 is falsy, so it is also replaced by `avg` here.
    weights = [s.weight or avg for s in source_bundles]

    total = float(sum(weights))
    if not total:
      # Every weight is zero (e.g. every bundle reports a weight of 0).
      # Treat the sources as equally sized instead of dividing by zero below.
      weights = [1.0] * len(weights)
      total = float(len(weights)) or 1.0

    # Now compute running totals of the percent done upon reaching
    # each source, with respect to the start and end positions.
    # E.g. if the weights were [100, 20, 3] we would produce
    # [0.0, 100/123, 120/123, 1.0]
    running_total = [0]
    for w in weights:
      running_total.append(max(min_diff, min(1, running_total[-1] + w / total)))
    running_total[-1] = 1  # In case of rounding error.
    # There are issues if, due to rounding error or greatly differing sizes,
    # two adjacent running total weights are equal. Normalize things so that
    # this never happens.
    for k in range(1, len(running_total)):
      if running_total[k] == running_total[k - 1]:
        for j in range(k):
          running_total[j] *= (1 - min_diff)
    return running_total

  def start_position(self):
    """Returns the start position as a ``(source_index, position)`` tuple."""
    return self._start

  def stop_position(self):
    """Returns the current (possibly split) end position tuple."""
    return self._end

  def try_claim(self, pos):
    """Attempts to claim ``pos`` for reading.

    Args:
      pos: a tuple ``(source_index, source_position)``. A ``source_position``
        of ``None`` claims the whole sub-source at ``source_index``.

    Returns:
      ``False`` if ``source_index`` lies outside the current range: after the
      end source, or equal to it when the end is that source's boundary.
      Otherwise records ``source_index`` as the claimed source and returns
      ``True`` (for a ``None`` source position) or the result of the
      sub-source tracker's ``try_claim``.
    """
    source_ix, source_pos = pos
    with self._lock:
      if source_ix > self._end[0]:
        return False
      elif source_ix == self._end[0] and self._end[1] is None:
        return False
      else:
        # Claims must never move backwards to an earlier source.
        assert source_ix >= self._claimed_source_ix
        self._claimed_source_ix = source_ix
        if source_pos is None:
          return True
        else:
          return self.sub_range_tracker(source_ix).try_claim(source_pos)

  def try_split(self, pos):
    """Attempts to split this range at ``pos``, shrinking it to end there.

    A split in a source after the claimed one is made on that source's
    boundary, and ``source_position`` is ignored. A split in the claimed
    source is delegated to that source's sub-tracker.

    Args:
      pos: a tuple ``(source_index, source_position)``.

    Returns:
      ``None`` if no split is possible. Otherwise a tuple
      ``((source_index, split_position), fraction)``, where ``fraction`` is
      the split point as a fraction of the range before the split.
    """
    source_ix, source_pos = pos
    with self._lock:
      if source_ix < self._claimed_source_ix:
        # Already claimed.
        return None
      elif source_ix > self._end[0]:
        # After end.
        return None
      elif source_ix == self._end[0] and self._end[1] is None:
        # At/after end.
        return None
      else:
        if source_ix > self._claimed_source_ix:
          # Prefer to split on even boundary.
          split_pos = None
          ratio = self._cumulative_weights[source_ix]
        else:
          # Split the current subsource.
          split = self.sub_range_tracker(source_ix).try_split(source_pos)
          if not split:
            return None
          split_pos, frac = split
          ratio = self.local_to_global(source_ix, frac)

        self._end = source_ix, split_pos
        # Rescale so that the new (shorter) range again spans [0, 1].
        self._cumulative_weights = [
            min(w / ratio, 1) for w in self._cumulative_weights
        ]
        return (source_ix, split_pos), ratio

  def set_current_position(self, pos):
    """Not supported; positions are tracked by the per-source sub-trackers."""
    raise NotImplementedError('Should only be called on sub-trackers')

  def position_at_fraction(self, fraction):
    """Maps a global ``fraction`` to a ``(source_index, position)`` tuple.

    A fraction that falls at the end of the range yields
    ``(last_source_index, None)``. Otherwise the position within the source
    comes from that source's sub-tracker.
    """
    source_ix, source_frac = self.global_to_local(fraction)
    last = self._end[0] if self._end[1] is None else self._end[0] + 1
    if source_ix == last:
      return (source_ix, None)
    else:
      return (
          source_ix,
          self.sub_range_tracker(source_ix).position_at_fraction(source_frac))

  def fraction_consumed(self):
    """Returns the global fraction of the range consumed so far."""
    with self._lock:
      if self._claimed_source_ix == len(self._source_bundles):
        return 1.0
      else:
        return self.local_to_global(
            self._claimed_source_ix,
            self.sub_range_tracker(self._claimed_source_ix).fraction_consumed())

  def local_to_global(self, source_ix, source_frac):
    """Converts a fraction within source ``source_ix`` to a global fraction."""
    cw = self._cumulative_weights
    # The global fraction is the fraction to source_ix plus some portion of
    # the way towards the next source.
    return cw[source_ix] + source_frac * (cw[source_ix + 1] - cw[source_ix])

  def global_to_local(self, frac):
    """Converts a global fraction to a ``(source_index, local_fraction)`` pair.

    A global fraction of exactly 1 maps to ``(last, None)``, where ``last``
    is the index just past the last source in the range.
    """
    if frac == 1:
      last = self._end[0] if self._end[1] is None else self._end[0] + 1
      return (last, None)
    else:
      cw = self._cumulative_weights
      # Find the last source that starts at or before frac.
      source_ix = bisect.bisect(cw, frac) - 1
      # Return this source, converting what's left of frac after starting
      # this source into a value in [0.0, 1.0) representing how far we are
      # towards the next source.
      return (
          source_ix,
          (frac - cw[source_ix]) / (cw[source_ix + 1] - cw[source_ix]))

  def sub_range_tracker(self, source_ix):
    """Returns (creating it on first use) the tracker for source ``source_ix``.

    The tracker's start and stop are clipped to this range's start and end
    positions when ``source_ix`` is the start or end source. Otherwise they
    come from the source bundle.
    """
    assert self._start[0] <= source_ix <= self._end[0]
    if self._range_trackers[source_ix] is None:
      with self._lock:
        # Re-check under the lock so that only one tracker is ever created.
        if self._range_trackers[source_ix] is None:
          source = self._source_bundles[source_ix]
          if source_ix == self._start[0] and self._start[1] is not None:
            start = self._start[1]
          else:
            start = source.start_position
          if source_ix == self._end[0] and self._end[1] is not None:
            stop = self._end[1]
          else:
            stop = source.stop_position
          self._range_trackers[source_ix] = source.source.get_range_tracker(
              start, stop)
    return self._range_trackers[source_ix]