#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pytype: skip-file

from apache_beam.io import iobase
from apache_beam.transforms.core import Create


class _CreateSource(iobase.BoundedSource):
  """Internal source that is used by Create().

  Holds a list of already-encoded elements and decodes them with ``coder``
  while reading. Positions are indices into that list, so the offset range
  ``[start, stop)`` selects ``serialized_values[start:stop]``.
  """
  def __init__(self, serialized_values, coder):
    """Creates the source.

    Args:
      serialized_values: sequence of encoded elements; each entry must
        support ``len()`` (presumably ``bytes`` from ``coder.encode`` --
        confirm against callers).
      coder: coder whose ``decode`` turns each entry back into an element.
    """
    self._coder = coder
    self._serialized_values = serialized_values
    # Total encoded size; reported by estimate_size() and used by split()
    # to size bundles.
    self._total_size = sum(map(len, self._serialized_values))

  def read(self, range_tracker):
    """Yields the decoded elements whose indices the tracker allows claiming.

    Args:
      range_tracker: offset range tracker over indices into the value list.

    Yields:
      Decoded elements, in index order, until the range is exhausted or a
      claim is refused.
    """
    start_position = range_tracker.start_position()
    current_position = start_position

    def split_points_unclaimed(stop_position):
      # Reports how many indices remain strictly after the one currently
      # being read. ``current_position`` is read through the closure, so it
      # reflects the latest value assigned in the loop below.
      if current_position >= stop_position:
        return 0
      return stop_position - current_position - 1

    range_tracker.set_split_points_unclaimed_callback(split_points_unclaimed)
    for i in range(start_position, range_tracker.stop_position()):
      # The range may have been shrunk by a split; stop on the first
      # refused claim.
      if not range_tracker.try_claim(i):
        return
      current_position = i
      # Index directly rather than iterating over a copied slice of the list.
      yield self._coder.decode(self._serialized_values[i])

  def split(self, desired_bundle_size, start_position=None, stop_position=None):
    """Splits the values into bundles of roughly ``desired_bundle_size``.

    Args:
      desired_bundle_size: target size of each bundle, in the same units as
        ``len()`` of the serialized values.
      start_position: first index to split; defaults to 0. Ignored when
        there are fewer than two values.
      stop_position: index one past the last to split; defaults to the
        number of values. Ignored when there are fewer than two values.

    Yields:
      iobase.SourceBundle objects. With fewer than two values, a single
      bundle wrapping this source with weight 0. Otherwise, one bundle per
      chunk, each wrapping a new source over that chunk's values, with
      positions relative to the chunk and weight equal to its value count.
    """
    num_values = len(self._serialized_values)
    if num_values < 2:
      # Nothing worth splitting; return this source as one bundle.
      yield iobase.SourceBundle(
          weight=0, source=self, start_position=0, stop_position=num_values)
      return

    if start_position is None:
      start_position = 0
    if stop_position is None:
      stop_position = num_values
    # Clamp to at least 1. If the total encoded size is smaller than the
    # number of values (e.g. values that encode to empty byte strings), the
    # floor division yields 0 and the division below would raise
    # ZeroDivisionError.
    avg_size_per_value = max(self._total_size // num_values, 1)
    num_values_per_split = max(
        int(desired_bundle_size // avg_size_per_value), 1)
    start = start_position
    while start < stop_position:
      end = min(start + num_values_per_split, stop_position)
      remaining = stop_position - end
      # Avoid having a too small bundle at the end: fold a short tail into
      # the current bundle.
      if remaining < (num_values_per_split // 4):
        end = stop_position
      sub_source = Create._create_source(
          self._serialized_values[start:end], self._coder)
      yield iobase.SourceBundle(
          weight=(end - start),
          source=sub_source,
          start_position=0,
          stop_position=(end - start))
      start = end

  def get_range_tracker(self, start_position, stop_position):
    """Returns an OffsetRangeTracker over ``[start_position, stop_position)``.

    ``None`` bounds default to 0 and to the number of values respectively.
    """
    if start_position is None:
      start_position = 0
    if stop_position is None:
      stop_position = len(self._serialized_values)
    from apache_beam import io
    return io.OffsetRangeTracker(start_position, stop_position)

  def estimate_size(self):
    """Returns the summed ``len()`` of all serialized values."""
    return self._total_size