#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for the sources framework."""
# pytype: skip-file

import logging
import os
import tempfile
import unittest

import apache_beam as beam
from apache_beam import coders
from apache_beam.io import iobase
from apache_beam.io import range_trackers
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to


class LineSource(iobase.BoundedSource):
  """A simple source that reads lines from a given file.

  Each record is one line of the file with trailing ``b'\\n'`` stripped.
  Positions handed to the range tracker are byte offsets into the file.
  """

  # Bundle size in bytes used by ``split``. It is kept tiny so that even very
  # small test files are split into several bundles.
  TEST_BUNDLE_SIZE = 10

  def __init__(self, file_name):
    self._file_name = file_name

  def read(self, range_tracker):
    """Yields each line whose starting byte offset is claimed from the tracker.

    When the tracker's range starts mid-file, the (possibly partial) line
    that contains the byte just before ``start_position()`` is skipped. That
    line is read by the bundle that owns its starting offset.
    """
    with open(self._file_name, 'rb') as f:
      start = range_tracker.start_position()
      f.seek(start)
      if start > 0:
        # Step back one byte before skipping. If ``start`` is exactly the first
        # byte of a line, readline() then consumes only the preceding b'\n',
        # so that line is not lost.
        f.seek(-1, os.SEEK_CUR)
        start -= 1
        start += len(f.readline())
      current = start
      line = f.readline()
      while range_tracker.try_claim(current):
        if not line:
          return
        yield line.rstrip(b'\n')
        current += len(line)
        line = f.readline()

  def split(self, desired_bundle_size, start_position=None, stop_position=None):
    """Splits the whole file into consecutive byte-range bundles.

    ``desired_bundle_size`` is intentionally ignored in favour of
    ``TEST_BUNDLE_SIZE``, so tests get a predictable number of bundles.
    Only splitting of the full file (no explicit start/stop) is supported.
    """
    assert start_position is None
    assert stop_position is None
    size = self.estimate_size()

    bundle_start = 0
    while bundle_start < size:
      bundle_stop = min(bundle_start + LineSource.TEST_BUNDLE_SIZE, size)
      yield iobase.SourceBundle(
          bundle_stop - bundle_start, self, bundle_start, bundle_stop)
      bundle_start = bundle_stop

  def get_range_tracker(self, start_position, stop_position):
    """Returns an offset tracker over [start_position, stop_position).

    ``None`` bounds default to the start and the end of the file.
    """
    if start_position is None:
      start_position = 0
    if stop_position is None:
      stop_position = self._get_file_size()
    return range_trackers.OffsetRangeTracker(start_position, stop_position)

  def default_output_coder(self):
    """Records are raw bytes."""
    return coders.BytesCoder()

  def estimate_size(self):
    """Returns the file size in bytes."""
    return self._get_file_size()

  def _get_file_size(self):
    with open(self._file_name, 'rb') as f:
      f.seek(0, os.SEEK_END)
      return f.tell()


class SourcesTest(unittest.TestCase):
  def _create_temp_file(self, contents):
    """Writes ``contents`` (bytes) to a new temp file and returns its path.

    The file is created with ``delete=False`` so it can be reopened by name
    after this handle is closed. It is removed when the test finishes.
    """
    with tempfile.NamedTemporaryFile(delete=False) as f:
      f.write(contents)
    # With delete=False nothing else removes the file. Without this cleanup,
    # every test run would leak a file in the temp directory.
    self.addCleanup(os.remove, f.name)
    return f.name

  def test_read_from_source(self):
    file_name = self._create_temp_file(b'aaaa\nbbbb\ncccc\ndddd')

    source = LineSource(file_name)
    range_tracker = source.get_range_tracker(None, None)
    result = list(source.read(range_tracker))

    self.assertCountEqual([b'aaaa', b'bbbb', b'cccc', b'dddd'], result)
    # Reading to completion ends with a claim attempt at or beyond the end of
    # the range.
    self.assertGreaterEqual(
        range_tracker.last_attempted_record_start,
        range_tracker.stop_position())

  def test_source_estimated_size(self):
    file_name = self._create_temp_file(b'aaaa\n')

    source = LineSource(file_name)
    self.assertEqual(5, source.estimate_size())

  def test_run_direct(self):
    file_name = self._create_temp_file(b'aaaa\nbbbb\ncccc\ndddd')
    with TestPipeline() as pipeline:
      pcoll = pipeline | beam.io.Read(LineSource(file_name))
      assert_that(pcoll, equal_to([b'aaaa', b'bbbb', b'cccc', b'dddd']))
if __name__ == '__main__':
  # When this module is executed directly, show INFO-level log output from
  # the code under test.
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  unittest.main()