#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""UnitTests for Batched DoFn (process_batch) API."""

# pytype: skip-file

import unittest
from typing import Iterator
from typing import List
from typing import Tuple
from typing import no_type_check

from parameterized import parameterized_class

import apache_beam as beam


class ElementDoFn(beam.DoFn):
  """Fixture defining only the element-wise ``process`` (halves each int)."""
  def process(self, element: int, *args, **kwargs) -> Iterator[float]:
    yield element / 2


class BatchDoFn(beam.DoFn):
  """Fixture defining only ``process_batch`` (halves every int in a batch)."""
  def process_batch(self, batch: List[int], *args,
                    **kwargs) -> Iterator[List[float]]:
    yield [element / 2 for element in batch]


class NoReturnAnnotation(beam.DoFn):
  """Fixture whose ``process_batch`` annotates its input but not its output.

  The parameterized test below expects both the input and the output batch
  type to be reported as ``List[int]`` for this DoFn.
  """
  def process_batch(self, batch: List[int], *args, **kwargs):
    yield [element * 2 for element in batch]


class OverrideTypeInference(beam.DoFn):
  """Fixture with an unannotated ``process_batch``.

  Batch types are supplied by overriding ``get_input_batch_type`` and
  ``get_output_batch_type``, both of which return ``List[input_element_type]``.
  """
  def process_batch(self, batch, *args, **kwargs):
    yield [element * 2 for element in batch]

  def get_input_batch_type(self, input_element_type):
    return List[input_element_type]

  def get_output_batch_type(self, input_element_type):
    return List[input_element_type]


class EitherDoFn(beam.DoFn):
  """Fixture defining both ``process`` and ``process_batch``.

  Both implementations halve their int input(s) and are annotated with
  matching element/batch types.
  """
  def process(self, element: int, *args, **kwargs) -> Iterator[float]:
    yield element / 2

  def process_batch(self, batch: List[int], *args,
                    **kwargs) -> Iterator[List[float]]:
    yield [element / 2 for element in batch]


class ElementToBatchDoFn(beam.DoFn):
  """Fixture whose element-wise ``process`` is marked as yielding batches.

  Each input ``element`` produces one batch: ``element`` repeated ``element``
  times. ``infer_output_type`` returns the input element type unchanged.
  """
  @beam.DoFn.yields_batches
  def process(self, element: int, *args, **kwargs) -> Iterator[List[int]]:
    yield [element] * element

  def infer_output_type(self, input_element_type):
    return input_element_type


class BatchToElementDoFn(beam.DoFn):
  """Fixture whose ``process_batch`` is marked as yielding elements.

  Each input batch produces a single ``(sum, length)`` tuple.
  """
  @beam.DoFn.yields_elements
  def process_batch(self, batch: List[int], *args,
                    **kwargs) -> Iterator[Tuple[int, int]]:
    yield (sum(batch), len(batch))


def get_test_class_name(cls, num, params_dict):
  """Name a generated parameterized test class after the DoFn it checks.

  Passed as ``class_name_func`` to ``parameterized_class`` below.

  Args:
    cls: The test class being parameterized.
    num: Unused. It is kept because this function is registered as a
      ``class_name_func`` callback and must accept all three arguments.
    params_dict: The parameter dict for the generated class; its ``'dofn'``
      entry is the DoFn instance under test.

  Returns:
    ``"<test class name>_<DoFn class name>"``.
  """
  return f"{cls.__name__}_{params_dict['dofn'].__class__.__name__}"


@parameterized_class([
    {
        "dofn": ElementDoFn(),
        "input_element_type": int,
        "expected_process_defined": True,
        "expected_process_batch_defined": False,
        "expected_input_batch_type": None,
        "expected_output_batch_type": None
    },
    {
        "dofn": BatchDoFn(),
        "input_element_type": int,
        "expected_process_defined": False,
        "expected_process_batch_defined": True,
        "expected_input_batch_type": beam.typehints.List[int],
        "expected_output_batch_type": beam.typehints.List[float]
    },
    {
        "dofn": NoReturnAnnotation(),
        "input_element_type": int,
        "expected_process_defined": False,
        "expected_process_batch_defined": True,
        "expected_input_batch_type": beam.typehints.List[int],
        "expected_output_batch_type": beam.typehints.List[int]
    },
    {
        "dofn": OverrideTypeInference(),
        "input_element_type": int,
        "expected_process_defined": False,
        "expected_process_batch_defined": True,
        "expected_input_batch_type": beam.typehints.List[int],
        "expected_output_batch_type": beam.typehints.List[int]
    },
    {
        "dofn": EitherDoFn(),
        "input_element_type": int,
        "expected_process_defined": True,
        "expected_process_batch_defined": True,
        "expected_input_batch_type": beam.typehints.List[int],
        "expected_output_batch_type": beam.typehints.List[float]
    },
    {
        "dofn": ElementToBatchDoFn(),
        "input_element_type": int,
        "expected_process_defined": True,
        "expected_process_batch_defined": False,
        "expected_input_batch_type": None,
        "expected_output_batch_type": beam.typehints.List[int]
    },
    {
        "dofn": BatchToElementDoFn(),
        "input_element_type": int,
        "expected_process_defined": False,
        "expected_process_batch_defined": True,
        "expected_input_batch_type": beam.typehints.List[int],
        "expected_output_batch_type": None,
    },
],
                     class_name_func=get_test_class_name)
class BatchDoFnParameterizedTest(unittest.TestCase):
  """Checks the batch-related accessors of each fixture DoFn.

  ``parameterized_class`` generates one test class per dict above, exposing
  the dict entries (``dofn``, ``input_element_type``, ``expected_*``) as class
  attributes.
  """
  def test_process_defined(self):
    self.assertEqual(self.dofn._process_defined, self.expected_process_defined)

  def test_process_batch_defined(self):
    self.assertEqual(
        self.dofn._process_batch_defined, self.expected_process_batch_defined)

  def test_get_input_batch_type(self):
    self.assertEqual(
        self.dofn._get_input_batch_type_normalized(self.input_element_type),
        self.expected_input_batch_type)

  def test_get_output_batch_type(self):
    self.assertEqual(
        self.dofn._get_output_batch_type_normalized(self.input_element_type),
        self.expected_output_batch_type)

  def test_can_yield_batches(self):
    # A DoFn is expected to report that it can yield batches exactly when the
    # parameters give it a non-None output batch type.
    expected = self.expected_output_batch_type is not None
    self.assertEqual(self.dofn._can_yield_batches, expected)


class NoInputAnnotation(beam.DoFn):
  """Fixture whose ``process_batch`` has no annotation on ``batch``.

  Used by ``BatchDoFnTest.test_no_input_annotation_raises``, which expects a
  TypeError when this DoFn is applied.
  """
  def process_batch(self, batch, *args, **kwargs):
    yield [element * 2 for element in batch]


class MismatchedBatchProducingDoFn(beam.DoFn):
  """A DoFn that produces batches from both process and process_batch, with
  mismatched return types (one yields floats, the other ints). Should yield
  a construction time error when applied."""
  @beam.DoFn.yields_batches
  def process(self, element: int, *args, **kwargs) -> Iterator[List[int]]:
    yield [element]

  def process_batch(self, batch: List[int], *args,
                    **kwargs) -> Iterator[List[float]]:
    yield [element / 2 for element in batch]


class MismatchedElementProducingDoFn(beam.DoFn):
  """A DoFn that produces elements from both process and process_batch, with
  mismatched return types (one yields floats, the other ints). Should yield
  a construction time error when applied."""
  def process(self, element: int, *args, **kwargs) -> Iterator[float]:
    yield element / 2

  @beam.DoFn.yields_elements
  def process_batch(self, batch: List[int], *args, **kwargs) -> Iterator[int]:
    yield batch[0]


class NoElementOutputAnnotation(beam.DoFn):
  """Batch-only fixture with batch annotations and nothing else.

  It defines neither ``process`` nor an ``infer_output_type`` override. Used
  by ``BatchDoFnTest.test_cant_infer_batchconverter_output_raises``, which
  expects a TypeError that suggests overriding ``DoFn.infer_output_type``.
  """
  def process_batch(self, batch: List[int], *args,
                    **kwargs) -> Iterator[List[int]]:
    yield [element * 2 for element in batch]


class BatchDoFnTest(unittest.TestCase):
  """Non-parameterized checks: construction-time errors and output typehints."""
  def test_map_pardo(self):
    # Verify batch dofn accessors work well with beam.Map generated DoFn.
    # Checking this in parameterized test causes a circular reference issue.
    dofn = beam.Map(lambda x: x * 2).dofn

    self.assertTrue(dofn._process_defined)
    self.assertFalse(dofn._process_batch_defined)
    self.assertIsNone(dofn._get_input_batch_type_normalized(int))
    self.assertIsNone(dofn._get_output_batch_type_normalized(int))

  def test_no_input_annotation_raises(self):
    p = beam.Pipeline()
    pc = p | beam.Create([1, 2, 3])

    # The error should name the offending method.
    with self.assertRaisesRegex(TypeError, r'NoInputAnnotation.process_batch'):
      _ = pc | beam.ParDo(NoInputAnnotation())

  def test_unsupported_dofn_param_raises(self):
    class BadParam(beam.DoFn):
      @no_type_check
      def process_batch(self, batch: List[int], key=beam.DoFn.KeyParam):
        yield batch * key

    p = beam.Pipeline()
    pc = p | beam.Create([1, 2, 3])

    # The error should name both the DoFn and the unsupported parameter.
    with self.assertRaisesRegex(NotImplementedError, r'BadParam.*KeyParam'):
      _ = pc | beam.ParDo(BadParam())

  def test_mismatched_batch_producer_raises(self):
    p = beam.Pipeline()
    pc = p | beam.Create([1, 2, 3])

    # Note (?ms) enables the MULTILINE and DOTALL flags; DOTALL is what lets
    # . match newlines so the pattern can span a multi-line error message.
    # See (?aiLmsux) at
    # https://docs.python.org/3/library/re.html#regular-expression-syntax
    with self.assertRaisesRegex(
        TypeError,
        (r'(?ms)MismatchedBatchProducingDoFn.*'
         r'process: List\[<class \'int\'>\].*process_batch: '
         r'List\[<class \'float\'>\]')):
      _ = pc | beam.ParDo(MismatchedBatchProducingDoFn())

  def test_mismatched_element_producer_raises(self):
    p = beam.Pipeline()
    pc = p | beam.Create([1, 2, 3])

    # Note (?ms) enables the MULTILINE and DOTALL flags; DOTALL is what lets
    # . match newlines so the pattern can span a multi-line error message.
    # See (?aiLmsux) at
    # https://docs.python.org/3/library/re.html#regular-expression-syntax
    with self.assertRaisesRegex(
        TypeError,
        r'(?ms)MismatchedElementProducingDoFn.*process:.*process_batch:'):
      _ = pc | beam.ParDo(MismatchedElementProducingDoFn())

  def test_cant_infer_batchconverter_input_raises(self):
    # BatchDoFn declares List[int] input batches but is fed str elements.
    p = beam.Pipeline()
    pc = p | beam.Create(['a', 'b', 'c'])

    with self.assertRaisesRegex(
        TypeError,
        # Error should mention "input", and the name of the DoFn
        r'input.*BatchDoFn.*'):
      _ = pc | beam.ParDo(BatchDoFn())

  def test_cant_infer_batchconverter_output_raises(self):
    p = beam.Pipeline()
    pc = p | beam.Create([1, 2, 3])

    with self.assertRaisesRegex(
        TypeError,
        # Error should mention "output", the name of the DoFn, and suggest
        # overriding DoFn.infer_output_type
        r'output.*NoElementOutputAnnotation.*DoFn\.infer_output_type'):
      _ = pc | beam.ParDo(NoElementOutputAnnotation())

  def test_element_to_batch_dofn_typehint(self):
    # Verify that element to batch DoFn sets the correct typehint on the output
    # PCollection.

    p = beam.Pipeline()
    pc = (p | beam.Create([1, 2, 3]) | beam.ParDo(ElementToBatchDoFn()))

    self.assertEqual(pc.element_type, int)

  def test_batch_to_element_dofn_typehint(self):
    # Verify that batch to element DoFn sets the correct typehint on the output
    # PCollection.

    p = beam.Pipeline()
    pc = (p | beam.Create([1, 2, 3]) | beam.ParDo(BatchToElementDoFn()))

    self.assertEqual(pc.element_type, beam.typehints.Tuple[int, int])


if __name__ == '__main__':
  unittest.main()