github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/create_source.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  # pytype: skip-file
    19  
    20  from apache_beam.io import iobase
    21  from apache_beam.transforms.core import Create
    22  
    23  
    24  class _CreateSource(iobase.BoundedSource):
    25    """Internal source that is used by Create()"""
    26    def __init__(self, serialized_values, coder):
    27      self._coder = coder
    28      self._serialized_values = []
    29      self._total_size = 0
    30      self._serialized_values = serialized_values
    31      self._total_size = sum(map(len, self._serialized_values))
    32  
    33    def read(self, range_tracker):
    34      start_position = range_tracker.start_position()
    35      current_position = start_position
    36  
    37      def split_points_unclaimed(stop_position):
    38        if current_position >= stop_position:
    39          return 0
    40        return stop_position - current_position - 1
    41  
    42      range_tracker.set_split_points_unclaimed_callback(split_points_unclaimed)
    43      element_iter = iter(self._serialized_values[start_position:])
    44      for i in range(start_position, range_tracker.stop_position()):
    45        if not range_tracker.try_claim(i):
    46          return
    47        current_position = i
    48        yield self._coder.decode(next(element_iter))
    49  
    50    def split(self, desired_bundle_size, start_position=None, stop_position=None):
    51      if len(self._serialized_values) < 2:
    52        yield iobase.SourceBundle(
    53            weight=0,
    54            source=self,
    55            start_position=0,
    56            stop_position=len(self._serialized_values))
    57      else:
    58        if start_position is None:
    59          start_position = 0
    60        if stop_position is None:
    61          stop_position = len(self._serialized_values)
    62        avg_size_per_value = self._total_size // len(self._serialized_values)
    63        num_values_per_split = max(
    64            int(desired_bundle_size // avg_size_per_value), 1)
    65        start = start_position
    66        while start < stop_position:
    67          end = min(start + num_values_per_split, stop_position)
    68          remaining = stop_position - end
    69          # Avoid having a too small bundle at the end.
    70          if remaining < (num_values_per_split // 4):
    71            end = stop_position
    72          sub_source = Create._create_source(
    73              self._serialized_values[start:end], self._coder)
    74          yield iobase.SourceBundle(
    75              weight=(end - start),
    76              source=sub_source,
    77              start_position=0,
    78              stop_position=(end - start))
    79          start = end
    80  
    81    def get_range_tracker(self, start_position, stop_position):
    82      if start_position is None:
    83        start_position = 0
    84      if stop_position is None:
    85        stop_position = len(self._serialized_values)
    86      from apache_beam import io
    87      return io.OffsetRangeTracker(start_position, stop_position)
    88  
    89    def estimate_size(self):
    90      return self._total_size