github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/utils.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Utils for the io library.
    19  * CountingSource: Subclass of iobase.BoundedSource, used by
    20  transforms.ptransform_test.test_read_metrics.
    21  """
    22  
    23  # pytype: skip-file
    24  
    25  from apache_beam.io import iobase
    26  from apache_beam.io.range_trackers import OffsetRangeTracker
    27  from apache_beam.metrics import Metrics
    28  
    29  
    30  class CountingSource(iobase.BoundedSource):
    31    def __init__(self, count):
    32      self.records_read = Metrics.counter(self.__class__, 'recordsRead')
    33      self._count = count
    34  
    35    def estimate_size(self):
    36      return self._count
    37  
    38    def get_range_tracker(self, start_position, stop_position):
    39      if start_position is None:
    40        start_position = 0
    41      if stop_position is None:
    42        stop_position = self._count
    43  
    44      return OffsetRangeTracker(start_position, stop_position)
    45  
    46    def read(self, range_tracker):
    47      for i in range(range_tracker.start_position(),
    48                     range_tracker.stop_position()):
    49        if not range_tracker.try_claim(i):
    50          return
    51        self.records_read.inc()
    52        yield i
    53  
    54    def split(self, desired_bundle_size, start_position=None, stop_position=None):
    55      if start_position is None:
    56        start_position = 0
    57      if stop_position is None:
    58        stop_position = self._count
    59  
    60      bundle_start = start_position
    61      while bundle_start < stop_position:
    62        bundle_stop = min(stop_position, bundle_start + desired_bundle_size)
    63        yield iobase.SourceBundle(
    64            weight=(bundle_stop - bundle_start),
    65            source=self,
    66            start_position=bundle_start,
    67            stop_position=bundle_stop)
    68        bundle_start = bundle_stop