# Origin: github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/create_test.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for the Create and _CreateSource classes."""
# pytype: skip-file

import logging
import unittest

from apache_beam import Create
from apache_beam import coders
from apache_beam.coders import FastPrimitivesCoder
from apache_beam.internal import pickler
from apache_beam.io import source_test_utils
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to


class CreateTest(unittest.TestCase):
  """Tests for the `Create` transform and the source it builds."""
  def setUp(self):
    """Creates the coder shared by the source-level tests."""
    self.coder = FastPrimitivesCoder()

  def test_create_transform(self):
    """Runs `Create` in a pipeline for empty, single and multiple values."""
    with TestPipeline() as p:
      assert_that(p | 'Empty' >> Create([]), equal_to([]), label='empty')
      assert_that(p | 'One' >> Create([None]), equal_to([None]), label='one')
      assert_that(p | Create(list(range(10))), equal_to(list(range(10))))

  def test_create_source_read(self):
    """Reads back sources built from zero, one and multiple values."""
    self.check_read([], self.coder)
    self.check_read([1], self.coder)
    # multiple values.
    self.check_read(list(range(10)), self.coder)

  def check_read(self, values, coder):
    """Asserts that a source built from `values` yields the same elements.

    Both sides are sorted before comparing, so the order in which the source
    returns elements is not checked.

    Args:
      values: the elements to build the source from.
      coder: the coder passed to `Create._create_source_from_iterable`.
    """
    source = Create._create_source_from_iterable(values, coder)
    read_values = source_test_utils.read_from_source(source)
    self.assertEqual(sorted(values), sorted(read_values))

  def test_create_source_read_with_initial_splits(self):
    """Checks initial splitting for a range of split counts."""
    self.check_read_with_initial_splits([], self.coder, num_splits=2)
    self.check_read_with_initial_splits([1], self.coder, num_splits=2)
    values = list(range(8))
    # multiple values with a single split.
    self.check_read_with_initial_splits(values, self.coder, num_splits=1)
    # multiple values with a single split with a large desired bundle size
    # (num_splits < 1 makes the desired bundle size exceed the total size).
    self.check_read_with_initial_splits(values, self.coder, num_splits=0.5)
    # multiple values with many splits.
    self.check_read_with_initial_splits(values, self.coder, num_splits=3)
    # multiple values with uneven sized splits.
    self.check_read_with_initial_splits(values, self.coder, num_splits=4)
    # multiple values with num splits equal to num values.
    self.check_read_with_initial_splits(
        values, self.coder, num_splits=len(values))
    # multiple values with num splits greater than num values.
    self.check_read_with_initial_splits(values, self.coder, num_splits=30)

  def check_read_with_initial_splits(self, values, coder, num_splits):
    """A test that splits the given source into `num_splits` and verifies that
    the data read from original source is equal to the union of the data read
    from the split sources.

    Args:
      values: the elements to build the source from.
      coder: the coder passed to `Create._create_source_from_iterable`.
      num_splits: divisor applied to the source's total size to derive the
        desired bundle size; may be fractional.
    """
    source = Create._create_source_from_iterable(values, coder)
    desired_bundle_size = source._total_size // num_splits
    splits = source.split(desired_bundle_size)
    splits_info = [(split.source, split.start_position, split.stop_position)
                   for split in splits]
    source_test_utils.assert_sources_equal_reference_source(
        (source, None, None), splits_info)

  def test_create_source_read_reentrant(self):
    """Checks re-entrant reads on an unsplit source."""
    source = Create._create_source_from_iterable(range(9), self.coder)
    source_test_utils.assert_reentrant_reads_succeed((source, None, None))

  def test_create_source_read_reentrant_with_initial_splits(self):
    """Checks re-entrant reads on every bundle of an initial split."""
    source = Create._create_source_from_iterable(range(24), self.coder)
    for split in source.split(desired_bundle_size=5):
      source_test_utils.assert_reentrant_reads_succeed(
          (split.source, split.start_position, split.stop_position))

  def test_create_source_dynamic_splitting(self):
    """Exhaustively checks split-at-fraction for small and larger sources."""
    # 2 values
    source = Create._create_source_from_iterable(range(2), self.coder)
    source_test_utils.assert_split_at_fraction_exhaustive(source)
    # Multiple values.
    source = Create._create_source_from_iterable(range(11), self.coder)
    source_test_utils.assert_split_at_fraction_exhaustive(
        source, perform_multi_threaded_test=True)

  def test_create_source_progress(self):
    """Checks the progress reported by the range tracker during a read."""
    num_values = 10
    source = Create._create_source_from_iterable(range(num_values), self.coder)
    splits = list(source.split(desired_bundle_size=100))
    # Use a unittest assertion rather than a bare `assert`, which is stripped
    # when Python runs with -O.
    self.assertEqual(1, len(splits))
    fraction_consumed_report = []
    split_points_report = []
    range_tracker = splits[0].source.get_range_tracker(
        splits[0].start_position, splits[0].stop_position)
    # Sample the tracker once per element yielded by the read.
    for _ in splits[0].source.read(range_tracker):
      fraction_consumed_report.append(range_tracker.fraction_consumed())
      split_points_report.append(range_tracker.split_points())

    self.assertEqual([i / num_values for i in range(num_values)],
                     fraction_consumed_report)

    # While the i-th (0-based) element is being read, the tracker is expected
    # to report the pair (i, num_values - i).
    expected_split_points_report = [(i, num_values - i)
                                    for i in range(num_values)]

    self.assertEqual(expected_split_points_report, split_points_report)

  def test_create_uses_coder_for_pickling(self):
    """Checks that `Create` pickles its values through the registered coder."""
    coders.registry.register_coder(_Unpicklable, _UnpicklableCoder)
    create = Create([_Unpicklable(1), _Unpicklable(2), _Unpicklable(3)])
    unpickled_create = pickler.loads(pickler.dumps(create))
    self.assertEqual(
        sorted(create.values, key=lambda v: v.value),
        sorted(unpickled_create.values, key=lambda v: v.value))

    with self.assertRaises(NotImplementedError):
      # As there is no special coder for Union types, this will fall back to
      # FastPrimitivesCoder, which in turn falls back to pickling.
      create_mixed_types = Create([_Unpicklable(1), 2])
      pickler.dumps(create_mixed_types)


class _Unpicklable(object):
  """Value holder whose pickling hooks always raise NotImplementedError."""
  def __init__(self, value):
    self.value = value

  def __eq__(self, other):
    # Defer to Python's default handling for foreign types instead of failing
    # with AttributeError on a missing `.value`.
    if not isinstance(other, _Unpicklable):
      return NotImplemented
    return self.value == other.value

  def __getstate__(self):
    raise NotImplementedError()

  def __setstate__(self, state):
    raise NotImplementedError()


class _UnpicklableCoder(coders.Coder):
  """Coder that encodes an `_Unpicklable` as the text of its integer value."""
  def encode(self, value):
    return str(value.value).encode()

  def decode(self, encoded):
    return _Unpicklable(int(encoded.decode()))

  def to_type_hint(self):
    return _Unpicklable


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()