github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/concat_source.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """For internal use only; no backwards-compatibility guarantees.
    19  
    20  Concat Source, which reads the union of several other sources.
    21  """
    22  # pytype: skip-file
    23  
    24  import bisect
    25  import threading
    26  
    27  from apache_beam.io import iobase
    28  
    29  
    30  class ConcatSource(iobase.BoundedSource):
    31    """For internal use only; no backwards-compatibility guarantees.
    32  
    33    A ``BoundedSource`` that can group a set of ``BoundedSources``.
    34  
    35    Primarily for internal use, use the ``apache_beam.Flatten`` transform
    36    to create the union of several reads.
    37    """
    38    def __init__(self, sources):
    39      self._source_bundles = [
    40          source if isinstance(source, iobase.SourceBundle) else
    41          iobase.SourceBundle(None, source, None, None) for source in sources
    42      ]
    43  
    44    @property
    45    def sources(self):
    46      return [s.source for s in self._source_bundles]
    47  
    48    def estimate_size(self):
    49      return sum(s.source.estimate_size() for s in self._source_bundles)
    50  
    51    def split(
    52        self, desired_bundle_size=None, start_position=None, stop_position=None):
    53      if start_position or stop_position:
    54        raise ValueError(
    55            'Multi-level initial splitting is not supported. Expected start and '
    56            'stop positions to be None. Received %r and %r respectively.' %
    57            (start_position, stop_position))
    58  
    59      for source in self._source_bundles:
    60        # We assume all sub-sources to produce bundles that specify weight using
    61        # the same unit. For example, all sub-sources may specify the size in
    62        # bytes as their weight.
    63        for bundle in source.source.split(desired_bundle_size,
    64                                          source.start_position,
    65                                          source.stop_position):
    66          yield bundle
    67  
    68    def get_range_tracker(self, start_position=None, stop_position=None):
    69      if start_position is None:
    70        start_position = (0, None)
    71      if stop_position is None:
    72        stop_position = (len(self._source_bundles), None)
    73      return ConcatRangeTracker(
    74          start_position, stop_position, self._source_bundles)
    75  
    76    def read(self, range_tracker):
    77      start_source, _ = range_tracker.start_position()
    78      stop_source, stop_pos = range_tracker.stop_position()
    79      if stop_pos is not None:
    80        stop_source += 1
    81      for source_ix in range(start_source, stop_source):
    82        if not range_tracker.try_claim((source_ix, None)):
    83          break
    84        for record in self._source_bundles[source_ix].source.read(
    85            range_tracker.sub_range_tracker(source_ix)):
    86          yield record
    87  
    88    def default_output_coder(self):
    89      if self._source_bundles:
    90        # Getting coder from the first sub-sources. This assumes all sub-sources
    91        # to produce the same coder.
    92        return self._source_bundles[0].source.default_output_coder()
    93      else:
    94        return super().default_output_coder()
    95  
    96  
    97  class ConcatRangeTracker(iobase.RangeTracker):
    98    """For internal use only; no backwards-compatibility guarantees.
    99  
   100    Range tracker for ConcatSource"""
   101    def __init__(self, start, end, source_bundles):
   102      """Initializes ``ConcatRangeTracker``
   103  
   104      Args:
   105        start: start position, a tuple of (source_index, source_position)
   106        end: end position, a tuple of (source_index, source_position)
   107        source_bundles: the list of source bundles in the ConcatSource
   108      """
   109      super().__init__()
   110      self._start = start
   111      self._end = end
   112      self._source_bundles = source_bundles
   113      self._lock = threading.RLock()
   114      # Lazily-initialized list of RangeTrackers corresponding to each source.
   115      self._range_trackers = [None] * len(source_bundles)
   116      # The currently-being-iterated-over (and latest claimed) source.
   117      self._claimed_source_ix = self._start[0]
   118      # Now compute cumulative progress through the sources for converting
   119      # between global fractions and fractions within specific sources.
   120      # TODO(robertwb): Implement fraction-at-position to properly scale
   121      # partial start and end sources.
   122      # Note, however, that in practice splits are typically on source
   123      # boundaries anyways.
   124      last = end[0] if end[1] is None else end[0] + 1
   125      self._cumulative_weights = (
   126          [0] * start[0] +
   127          self._compute_cumulative_weights(source_bundles[start[0]:last]) + [1] *
   128          (len(source_bundles) - last - start[0]))
   129  
   130    @staticmethod
   131    def _compute_cumulative_weights(source_bundles):
   132      # Two adjacent sources must differ so that they can be uniquely
   133      # identified by a single global fraction.  Let min_diff be the
   134      # smallest allowable difference between sources.
   135      min_diff = 1e-5
   136      # For the computation below, we need weights for all sources.
   137      # Substitute average weights for those whose weights are
   138      # unspecified (or 1.0 for everything if none are known).
   139      known = [s.weight for s in source_bundles if s.weight is not None]
   140      avg = sum(known) / len(known) if known else 1.0
   141      weights = [s.weight or avg for s in source_bundles]
   142  
   143      # Now compute running totals of the percent done upon reaching
   144      # each source, with respect to the start and end positions.
   145      # E.g. if the weights were [100, 20, 3] we would produce
   146      # [0.0, 100/123, 120/123, 1.0]
   147      total = float(sum(weights))
   148      running_total = [0]
   149      for w in weights:
   150        running_total.append(max(min_diff, min(1, running_total[-1] + w / total)))
   151      running_total[-1] = 1  # In case of rounding error.
   152      # There are issues if, due to rouding error or greatly differing sizes,
   153      # two adjacent running total weights are equal. Normalize this things so
   154      # that this never happens.
   155      for k in range(1, len(running_total)):
   156        if running_total[k] == running_total[k - 1]:
   157          for j in range(k):
   158            running_total[j] *= (1 - min_diff)
   159      return running_total
   160  
   161    def start_position(self):
   162      return self._start
   163  
   164    def stop_position(self):
   165      return self._end
   166  
   167    def try_claim(self, pos):
   168      source_ix, source_pos = pos
   169      with self._lock:
   170        if source_ix > self._end[0]:
   171          return False
   172        elif source_ix == self._end[0] and self._end[1] is None:
   173          return False
   174        else:
   175          assert source_ix >= self._claimed_source_ix
   176          self._claimed_source_ix = source_ix
   177          if source_pos is None:
   178            return True
   179          else:
   180            return self.sub_range_tracker(source_ix).try_claim(source_pos)
   181  
   182    def try_split(self, pos):
   183      source_ix, source_pos = pos
   184      with self._lock:
   185        if source_ix < self._claimed_source_ix:
   186          # Already claimed.
   187          return None
   188        elif source_ix > self._end[0]:
   189          # After end.
   190          return None
   191        elif source_ix == self._end[0] and self._end[1] is None:
   192          # At/after end.
   193          return None
   194        else:
   195          if source_ix > self._claimed_source_ix:
   196            # Prefer to split on even boundary.
   197            split_pos = None
   198            ratio = self._cumulative_weights[source_ix]
   199          else:
   200            # Split the current subsource.
   201            split = self.sub_range_tracker(source_ix).try_split(source_pos)
   202            if not split:
   203              return None
   204            split_pos, frac = split
   205            ratio = self.local_to_global(source_ix, frac)
   206  
   207          self._end = source_ix, split_pos
   208          self._cumulative_weights = [
   209              min(w / ratio, 1) for w in self._cumulative_weights
   210          ]
   211          return (source_ix, split_pos), ratio
   212  
   213    def set_current_position(self, pos):
   214      raise NotImplementedError('Should only be called on sub-trackers')
   215  
   216    def position_at_fraction(self, fraction):
   217      source_ix, source_frac = self.global_to_local(fraction)
   218      last = self._end[0] if self._end[1] is None else self._end[0] + 1
   219      if source_ix == last:
   220        return (source_ix, None)
   221      else:
   222        return (
   223            source_ix,
   224            self.sub_range_tracker(source_ix).position_at_fraction(source_frac))
   225  
   226    def fraction_consumed(self):
   227      with self._lock:
   228        if self._claimed_source_ix == len(self._source_bundles):
   229          return 1.0
   230        else:
   231          return self.local_to_global(
   232              self._claimed_source_ix,
   233              self.sub_range_tracker(self._claimed_source_ix).fraction_consumed())
   234  
   235    def local_to_global(self, source_ix, source_frac):
   236      cw = self._cumulative_weights
   237      # The global fraction is the fraction to source_ix plus some portion of
   238      # the way towards the next source.
   239      return cw[source_ix] + source_frac * (cw[source_ix + 1] - cw[source_ix])
   240  
   241    def global_to_local(self, frac):
   242      if frac == 1:
   243        last = self._end[0] if self._end[1] is None else self._end[0] + 1
   244        return (last, None)
   245      else:
   246        cw = self._cumulative_weights
   247        # Find the last source that starts at or before frac.
   248        source_ix = bisect.bisect(cw, frac) - 1
   249        # Return this source, converting what's left of frac after starting
   250        # this source into a value in [0.0, 1.0) representing how far we are
   251        # towards the next source.
   252        return (
   253            source_ix,
   254            (frac - cw[source_ix]) / (cw[source_ix + 1] - cw[source_ix]))
   255  
   256    def sub_range_tracker(self, source_ix):
   257      assert self._start[0] <= source_ix <= self._end[0]
   258      if self._range_trackers[source_ix] is None:
   259        with self._lock:
   260          if self._range_trackers[source_ix] is None:
   261            source = self._source_bundles[source_ix]
   262            if source_ix == self._start[0] and self._start[1] is not None:
   263              start = self._start[1]
   264            else:
   265              start = source.start_position
   266            if source_ix == self._end[0] and self._end[1] is not None:
   267              stop = self._end[1]
   268            else:
   269              stop = source.stop_position
   270            self._range_trackers[source_ix] = source.source.get_range_tracker(
   271                start, stop)
   272      return self._range_trackers[source_ix]