#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for the helpers in ``apache_beam.io.source_test_utils``."""

# pytype: skip-file

import logging
import os
import tempfile
import unittest

from apache_beam.io import source_test_utils
from apache_beam.io.filebasedsource_test import LineSource


def _remove_file_best_effort(path):
  """Deletes ``path``, logging (rather than raising) if that fails.

  Used as a test cleanup hook: failing to delete a scratch file should not
  turn an otherwise passing test into an error.
  """
  try:
    os.remove(path)
  except OSError as e:
    logging.warning('Could not remove temporary file %s: %s', path, e)


class SourceTestUtilsTest(unittest.TestCase):
  """Exercises ``source_test_utils`` against a ``LineSource`` over a temp file."""
  def _create_file_with_data(self, lines):
    """Writes ``lines`` to a new temporary file and returns its path.

    Args:
      lines: a ``list`` of ``bytes`` objects; each is written followed by a
        newline byte.

    Returns:
      The name of the temporary file. The file is deleted automatically when
      the current test finishes.
    """
    assert isinstance(lines, list)
    with tempfile.NamedTemporaryFile(delete=False) as f:
      # delete=False keeps the file on disk after it is closed so that the
      # source can reopen it by name; register an explicit cleanup so the
      # file does not outlive the test.
      self.addCleanup(_remove_file_best_effort, f.name)
      for line in lines:
        f.write(line + b'\n')

      return f.name

  def _create_data(self, num_lines):
    """Returns ``num_lines`` byte strings: ``b'line 0'``, ``b'line 1'``, ..."""
    return [b'line ' + str(i).encode('latin1') for i in range(num_lines)]

  def _create_source(self, data):
    """Writes ``data`` to a temporary file and returns a source that reads it.

    Args:
      data: a ``list`` of ``bytes`` lines, as produced by ``_create_data``.

    Returns:
      The ``source`` attribute of the first bundle produced by splitting a
      ``LineSource`` over the file (implicitly ``None`` if splitting yields
      no bundles).
    """
    source = LineSource(self._create_file_with_data(data))
    # By performing initial splitting, we can get a source for a single file.
    # This source, that uses OffsetRangeTracker, is better for testing purposes,
    # than using the original source for a file-pattern.
    for bundle in source.split(float('inf')):
      return bundle.source

  def test_read_from_source(self):
    """Reading the whole source yields exactly the lines that were written."""
    data = self._create_data(100)
    source = self._create_source(data)
    self.assertCountEqual(
        data, source_test_utils.read_from_source(source, None, None))

  def test_source_equals_reference_source(self):
    """The splits of a source, taken together, match the unsplit source."""
    data = self._create_data(100)
    reference_source = self._create_source(data)
    sources_info = [(split.source, split.start_position, split.stop_position)
                    for split in reference_source.split(desired_bundle_size=50)]
    if len(sources_info) < 2:
      raise ValueError(
          'Test is too trivial since splitting only generated %d '
          'bundles. Please adjust the test so that at least '
          'two splits get generated.' % len(sources_info))

    source_test_utils.assert_sources_equal_reference_source(
        (reference_source, None, None), sources_info)

  def test_split_at_fraction_successful(self):
    """Successful splits at a given fraction are consistent with each other.

    The two values returned for a split are checked to sum to the total
    number of lines (100) -- presumably the item counts on either side of
    the split; confirm against ``assert_split_at_fraction_behavior``.
    """
    data = self._create_data(100)
    source = self._create_source(data)
    result1 = source_test_utils.assert_split_at_fraction_behavior(
        source,
        10,
        0.5,
        source_test_utils.ExpectedSplitOutcome.MUST_SUCCEED_AND_BE_CONSISTENT)
    result2 = source_test_utils.assert_split_at_fraction_behavior(
        source,
        20,
        0.5,
        source_test_utils.ExpectedSplitOutcome.MUST_SUCCEED_AND_BE_CONSISTENT)
    # Same fraction, different amount read before splitting: same outcome.
    self.assertEqual(result1, result2)
    self.assertEqual(100, result1[0] + result1[1])

    result3 = source_test_utils.assert_split_at_fraction_behavior(
        source,
        30,
        0.8,
        source_test_utils.ExpectedSplitOutcome.MUST_SUCCEED_AND_BE_CONSISTENT)
    result4 = source_test_utils.assert_split_at_fraction_behavior(
        source,
        50,
        0.8,
        source_test_utils.ExpectedSplitOutcome.MUST_SUCCEED_AND_BE_CONSISTENT)
    self.assertEqual(result3, result4)
    self.assertEqual(100, result3[0] + result3[1])

    # Splitting at a larger fraction grows the first value and shrinks the
    # second.
    self.assertLess(result1[0], result3[0])
    self.assertGreater(result1[1], result3[1])

  def test_split_at_fraction_fails(self):
    """Checks the MUST_FAIL outcome for both a failing and a succeeding split."""
    data = self._create_data(100)
    source = self._create_source(data)

    # Splitting at fraction 0.1 after 90 has been passed as the third
    # argument is expected to fail; the helper then returns (100, -1).
    result = source_test_utils.assert_split_at_fraction_behavior(
        source, 90, 0.1, source_test_utils.ExpectedSplitOutcome.MUST_FAIL)
    self.assertEqual(result[0], 100)
    self.assertEqual(result[1], -1)

    # Declaring MUST_FAIL for a split that succeeds (see
    # test_split_at_fraction_successful) makes the helper raise.
    with self.assertRaises(ValueError):
      source_test_utils.assert_split_at_fraction_behavior(
          source, 10, 0.5, source_test_utils.ExpectedSplitOutcome.MUST_FAIL)

  def test_split_at_fraction_binary(self):
    """The binary-search helper records fractions in the stats object."""
    data = self._create_data(100)
    source = self._create_source(data)

    stats = source_test_utils.SplitFractionStatistics([], [])
    source_test_utils.assert_split_at_fraction_binary(
        source, data, 10, 0.5, None, 0.8, None, stats)

    # These lists should not be empty now.
    self.assertTrue(stats.successful_fractions)
    self.assertTrue(stats.non_trivial_fractions)

  def test_split_at_fraction_exhaustive(self):
    """Runs the exhaustive split-at-fraction check on a small source."""
    data = self._create_data(10)
    source = self._create_source(data)
    source_test_utils.assert_split_at_fraction_exhaustive(source)


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()