github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/ptransform_test.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for the PTransform and descendants."""

# pytype: skip-file

import collections
import operator
import os
import pickle
import random
import re
import typing
import unittest
from functools import reduce
from typing import Optional
from unittest.mock import patch

import hamcrest as hc
import numpy as np
import pytest
from parameterized import parameterized_class

import apache_beam as beam
import apache_beam.transforms.combiners as combine
from apache_beam import pvalue
from apache_beam import typehints
from apache_beam.io.iobase import Read
from apache_beam.metrics import Metrics
from apache_beam.metrics.metric import MetricsFilter
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import TypeOptions
from apache_beam.portability import common_urns
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.test_stream import TestStream
from apache_beam.testing.util import SortLists
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms import WindowInto
from apache_beam.transforms import trigger
from apache_beam.transforms import window
from apache_beam.transforms.display import DisplayData
from apache_beam.transforms.display import DisplayDataItem
from apache_beam.transforms.ptransform import PTransform
from apache_beam.transforms.window import TimestampedValue
from apache_beam.typehints import with_input_types
from apache_beam.typehints import with_output_types
from apache_beam.typehints.typehints_test import TypeHintTestCase
from apache_beam.utils.timestamp import Timestamp
from apache_beam.utils.windowed_value import WindowedValue

# Disable frequent lint warning due to pipe operator for chaining transforms.
# pylint: disable=expression-not-assigned


class PTransformTest(unittest.TestCase):
  def assertStartswith(self, msg, prefix):
    self.assertTrue(
        msg.startswith(prefix), '"%s" does not start with "%s"' % (msg, prefix))

  def test_str(self):
    self.assertEqual(
        '<PTransform(PTransform) label=[PTransform]>', str(PTransform()))

    pa = TestPipeline()
    res = pa | 'ALabel' >> beam.Impulse()
    self.assertEqual('AppliedPTransform(ALabel, Impulse)', str(res.producer))

    pc = TestPipeline()
    res = pc | beam.Impulse()
    inputs_tr = res.producer.transform
    inputs_tr.inputs = ('ci', )
    self.assertEqual(
        "<Impulse(PTransform) label=[Impulse] inputs=('ci',)>", str(inputs_tr))

    pd = TestPipeline()
    res = pd | beam.Impulse()
    side_tr = res.producer.transform
    side_tr.side_inputs = (4, )
    self.assertEqual(
        '<Impulse(PTransform) label=[Impulse] side_inputs=(4,)>', str(side_tr))

    inputs_tr.side_inputs = ('cs', )
    self.assertEqual(
        """<Impulse(PTransform) label=[Impulse] """
        """inputs=('ci',) side_inputs=('cs',)>""",
        str(inputs_tr))

  def test_do_with_do_fn(self):
    class AddNDoFn(beam.DoFn):
      def process(self, element, addon):
        return [element + addon]

    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
      result = pcoll | 'Do' >> beam.ParDo(AddNDoFn(), 10)
      assert_that(result, equal_to([11, 12, 13]))

  def test_do_with_unconstructed_do_fn(self):
    class MyDoFn(beam.DoFn):
      def process(self):
        pass

    with self.assertRaises(ValueError):
      with TestPipeline() as pipeline:
        pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
        pcoll | 'Do' >> beam.ParDo(MyDoFn)  # Note the lack of ()'s

  def test_do_with_callable(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
      result = pcoll | 'Do' >> beam.FlatMap(lambda x, addon: [x + addon], 10)
      assert_that(result, equal_to([11, 12, 13]))

  def test_do_with_side_input_as_arg(self):
    with TestPipeline() as pipeline:
      side = pipeline | 'Side' >> beam.Create([10])
      pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
      result = pcoll | 'Do' >> beam.FlatMap(
          lambda x, addon: [x + addon], pvalue.AsSingleton(side))
      assert_that(result, equal_to([11, 12, 13]))

  def test_do_with_side_input_as_keyword_arg(self):
    with TestPipeline() as pipeline:
      side = pipeline | 'Side' >> beam.Create([10])
      pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
      result = pcoll | 'Do' >> beam.FlatMap(
          lambda x, addon: [x + addon], addon=pvalue.AsSingleton(side))
      assert_that(result, equal_to([11, 12, 13]))
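
  def test_do_with_side_input_as_list_sketch(self):
    # Illustrative sketch, not part of the original suite: the same pattern
    # also works with pvalue.AsList, which materializes the side PCollection
    # as a Python list argument to the callable.
    with TestPipeline() as pipeline:
      side = pipeline | 'Side' >> beam.Create([1, 2, 3])
      pcoll = pipeline | 'Start' >> beam.Create([10, 20])
      result = pcoll | 'Do' >> beam.FlatMap(
          lambda x, addons: [x + sum(addons)], pvalue.AsList(side))
      assert_that(result, equal_to([16, 26]))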

  def test_do_with_do_fn_returning_string_raises_warning(self):
    with self.assertRaises(typehints.TypeCheckError) as cm:
      with TestPipeline() as pipeline:
        pipeline._options.view_as(TypeOptions).runtime_type_check = True
        pcoll = pipeline | 'Start' >> beam.Create(['2', '9', '3'])
        pcoll | 'Do' >> beam.FlatMap(lambda x: x + '1')

        # Since the DoFn directly returns a string we should get an
        # error warning us when the pipeline runs.

    expected_error_prefix = (
        'Returning a str from a ParDo or FlatMap '
        'is discouraged.')
    self.assertStartswith(cm.exception.args[0], expected_error_prefix)

  def test_do_with_do_fn_returning_dict_raises_warning(self):
    with self.assertRaises(typehints.TypeCheckError) as cm:
      with TestPipeline() as pipeline:
        pipeline._options.view_as(TypeOptions).runtime_type_check = True
        pcoll = pipeline | 'Start' >> beam.Create(['2', '9', '3'])
        pcoll | 'Do' >> beam.FlatMap(lambda x: {x: '1'})

        # Since the DoFn directly returns a dict we should get an error
        # warning us when the pipeline runs.

    expected_error_prefix = (
        'Returning a dict from a ParDo or FlatMap '
        'is discouraged.')
    self.assertStartswith(cm.exception.args[0], expected_error_prefix)

  def test_do_with_multiple_outputs_maintains_unique_name(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
      r1 = pcoll | 'A' >> beam.FlatMap(lambda x: [x + 1]).with_outputs(main='m')
      r2 = pcoll | 'B' >> beam.FlatMap(lambda x: [x + 2]).with_outputs(main='m')
      assert_that(r1.m, equal_to([2, 3, 4]), label='r1')
      assert_that(r2.m, equal_to([3, 4, 5]), label='r2')

  @pytest.mark.it_validatesrunner
  def test_impulse(self):
    with TestPipeline() as pipeline:
      result = pipeline | beam.Impulse() | beam.Map(lambda _: 0)
      assert_that(result, equal_to([0]))

  # TODO(BEAM-3544): Disable this test in streaming temporarily.
  # Remove sickbay-streaming tag after it's resolved.
  @pytest.mark.no_sickbay_streaming
  @pytest.mark.it_validatesrunner
  def test_read_metrics(self):
    from apache_beam.io.utils import CountingSource

    class CounterDoFn(beam.DoFn):
      def __init__(self):
        # This counter is unused.
        self.received_records = Metrics.counter(
            self.__class__, 'receivedRecords')

      def process(self, element):
        self.received_records.inc()

    pipeline = TestPipeline()
    (pipeline | Read(CountingSource(100)) | beam.ParDo(CounterDoFn()))
    res = pipeline.run()
    res.wait_until_finish()
    # This counter is defined in utils.CountingSource.
    metric_results = res.metrics().query(
        MetricsFilter().with_name('recordsRead'))
    outputs_counter = metric_results['counters'][0]
    self.assertStartswith(outputs_counter.key.step, 'Read')
    self.assertEqual(outputs_counter.key.metric.name, 'recordsRead')
    self.assertEqual(outputs_counter.committed, 100)
    self.assertEqual(outputs_counter.attempted, 100)

  @pytest.mark.it_validatesrunner
  def test_par_do_with_multiple_outputs_and_using_yield(self):
    class SomeDoFn(beam.DoFn):
      """A custom DoFn using yield."""
      def process(self, element):
        yield element
        if element % 2 == 0:
          yield pvalue.TaggedOutput('even', element)
        else:
          yield pvalue.TaggedOutput('odd', element)

    with TestPipeline() as pipeline:
      nums = pipeline | 'Some Numbers' >> beam.Create([1, 2, 3, 4])
      results = nums | 'ClassifyNumbers' >> beam.ParDo(SomeDoFn()).with_outputs(
          'odd', 'even', main='main')
      assert_that(results.main, equal_to([1, 2, 3, 4]))
      assert_that(results.odd, equal_to([1, 3]), label='assert:odd')
      assert_that(results.even, equal_to([2, 4]), label='assert:even')

  @pytest.mark.it_validatesrunner
  def test_par_do_with_multiple_outputs_and_using_return(self):
    def some_fn(v):
      if v % 2 == 0:
        return [v, pvalue.TaggedOutput('even', v)]
      return [v, pvalue.TaggedOutput('odd', v)]

    with TestPipeline() as pipeline:
      nums = pipeline | 'Some Numbers' >> beam.Create([1, 2, 3, 4])
      results = nums | 'ClassifyNumbers' >> beam.FlatMap(some_fn).with_outputs(
          'odd', 'even', main='main')
      assert_that(results.main, equal_to([1, 2, 3, 4]))
      assert_that(results.odd, equal_to([1, 3]), label='assert:odd')
      assert_that(results.even, equal_to([2, 4]), label='assert:even')

  @pytest.mark.it_validatesrunner
  def test_undeclared_outputs(self):
    with TestPipeline() as pipeline:
      nums = pipeline | 'Some Numbers' >> beam.Create([1, 2, 3, 4])
      results = nums | 'ClassifyNumbers' >> beam.FlatMap(
          lambda x: [
              x,
              pvalue.TaggedOutput('even' if x % 2 == 0 else 'odd', x),
              pvalue.TaggedOutput('extra', x)
          ]).with_outputs()
      assert_that(results[None], equal_to([1, 2, 3, 4]))
      assert_that(results.odd, equal_to([1, 3]), label='assert:odd')
      assert_that(results.even, equal_to([2, 4]), label='assert:even')

  @pytest.mark.it_validatesrunner
  def test_multiple_empty_outputs(self):
    with TestPipeline() as pipeline:
      nums = pipeline | 'Some Numbers' >> beam.Create([1, 3, 5])
      results = nums | 'ClassifyNumbers' >> beam.FlatMap(
          lambda x: [x, pvalue.TaggedOutput('even' if x % 2 == 0 else 'odd', x)]
      ).with_outputs()
      assert_that(results[None], equal_to([1, 3, 5]))
      assert_that(results.odd, equal_to([1, 3, 5]), label='assert:odd')
      assert_that(results.even, equal_to([]), label='assert:even')

  def test_do_requires_do_fn_returning_iterable(self):
    # This function is incorrect because it returns an object that isn't an
    # iterable.
    def incorrect_par_do_fn(x):
      return x + 5

    with self.assertRaises(typehints.TypeCheckError) as cm:
      with TestPipeline() as pipeline:
        pipeline._options.view_as(TypeOptions).runtime_type_check = True
        pcoll = pipeline | 'Start' >> beam.Create([2, 9, 3])
        pcoll | 'Do' >> beam.FlatMap(incorrect_par_do_fn)
        # It's a requirement that all user-defined functions passed to a ParDo
        # return an iterable.

    expected_error_prefix = 'FlatMap and ParDo must return an iterable.'
    self.assertStartswith(cm.exception.args[0], expected_error_prefix)

  def test_do_fn_with_finish(self):
    class MyDoFn(beam.DoFn):
      def process(self, element):
        pass

      def finish_bundle(self):
        yield WindowedValue('finish', -1, [window.GlobalWindow()])

    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
      result = pcoll | 'Do' >> beam.ParDo(MyDoFn())

      # May have many bundles, but each has a start and finish.
      def matcher():
        def match(actual):
          equal_to(['finish'])(list(set(actual)))
          equal_to([1])([actual.count('finish')])

        return match

      assert_that(result, matcher())

  def test_do_fn_with_windowing_in_finish_bundle(self):
    windowfn = window.FixedWindows(2)

    class MyDoFn(beam.DoFn):
      def process(self, element):
        yield TimestampedValue('process' + str(element), 5)

      def finish_bundle(self):
        yield WindowedValue('finish', 1, [windowfn])

    with TestPipeline() as pipeline:
      result = (
          pipeline
          | 'Start' >> beam.Create([1])
          | beam.ParDo(MyDoFn())
          | WindowInto(windowfn)
          | 'create tuple' >> beam.Map(
              lambda v,
              t=beam.DoFn.TimestampParam,
              w=beam.DoFn.WindowParam: (v, t, w.start, w.end)))
      expected_process = [
          ('process1', Timestamp(5), Timestamp(4), Timestamp(6))
      ]
      expected_finish = [('finish', Timestamp(1), Timestamp(0), Timestamp(2))]

      assert_that(result, equal_to(expected_process + expected_finish))

  def test_do_fn_with_start(self):
    class MyDoFn(beam.DoFn):
      def __init__(self):
        self.state = 'init'

      def start_bundle(self):
        self.state = 'started'

      def process(self, element):
        if self.state == 'started':
          yield 'started'
        self.state = 'process'

    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
      result = pcoll | 'Do' >> beam.ParDo(MyDoFn())

      # May have many bundles, but each has a start and finish.
      def matcher():
        def match(actual):
          equal_to(['started'])(list(set(actual)))
          equal_to([1])([actual.count('started')])

        return match

      assert_that(result, matcher())

  def test_do_fn_with_start_error(self):
    class MyDoFn(beam.DoFn):
      def start_bundle(self):
        return [1]

      def process(self, element):
        pass

    with self.assertRaises(RuntimeError):
      with TestPipeline() as p:
        p | 'Start' >> beam.Create([1, 2, 3]) | 'Do' >> beam.ParDo(MyDoFn())

  def test_map_builtin(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Start' >> beam.Create([[1, 2], [1], [1, 2, 3]])
      result = pcoll | beam.Map(len)
      assert_that(result, equal_to([1, 2, 3]))

  def test_flatmap_builtin(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Start' >> beam.Create([
          [np.array([1, 2, 3])] * 3, [np.array([5, 4, 3]), np.array([5, 6, 7])]
      ])
      result = pcoll | beam.FlatMap(sum)
      assert_that(result, equal_to([3, 6, 9, 10, 10, 10]))

  def test_filter_builtin(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Start' >> beam.Create([[], [2], [], [4]])
      result = pcoll | 'Filter' >> beam.Filter(len)
      assert_that(result, equal_to([[2], [4]]))

  def test_filter(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3, 4])
      result = pcoll | 'Filter' >> beam.Filter(lambda x: x % 2 == 0)
      assert_that(result, equal_to([2, 4]))

  class _MeanCombineFn(beam.CombineFn):
    def create_accumulator(self):
      return (0, 0)

    def add_input(self, sum_count, element):
      (sum_, count) = sum_count
      return sum_ + element, count + 1

    def merge_accumulators(self, accumulators):
      sums, counts = zip(*accumulators)
      return sum(sums), sum(counts)

    def extract_output(self, sum_count):
      (sum_, count) = sum_count
      if not count:
        return float('nan')
      return sum_ / float(count)
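
  def test_mean_combine_fn_lifecycle_sketch(self):
    # Illustrative sketch, not part of the original suite: exercises the
    # CombineFn lifecycle directly (create_accumulator, add_input,
    # merge_accumulators, extract_output) without running a pipeline.
    fn = self._MeanCombineFn()
    acc_1 = fn.add_input(fn.add_input(fn.create_accumulator(), 1), 3)
    acc_2 = fn.add_input(fn.create_accumulator(), 8)
    merged = fn.merge_accumulators([acc_1, acc_2])
    # (1 + 3 + 8) / 3 elements == 4.0
    self.assertEqual(4.0, fn.extract_output(merged))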

  def test_combine_with_combine_fn(self):
    vals = [1, 2, 3, 4, 5, 6, 7]
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Start' >> beam.Create(vals)
      result = pcoll | 'Mean' >> beam.CombineGlobally(self._MeanCombineFn())
      assert_that(result, equal_to([sum(vals) // len(vals)]))

  def test_combine_with_callable(self):
    vals = [1, 2, 3, 4, 5, 6, 7]
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Start' >> beam.Create(vals)
      result = pcoll | beam.CombineGlobally(sum)
      assert_that(result, equal_to([sum(vals)]))

  def test_combine_with_side_input_as_arg(self):
    values = [1, 2, 3, 4, 5, 6, 7]
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Start' >> beam.Create(values)
      divisor = pipeline | 'Divisor' >> beam.Create([2])
      result = pcoll | 'Max' >> beam.CombineGlobally(
          # Multiples of divisor only.
          lambda vals, d: max(v for v in vals if v % d == 0),
          pvalue.AsSingleton(divisor)).without_defaults()
      filt_vals = [v for v in values if v % 2 == 0]
      assert_that(result, equal_to([max(filt_vals)]))

  def test_combine_per_key_with_combine_fn(self):
    vals_1 = [1, 2, 3, 4, 5, 6, 7]
    vals_2 = [2, 4, 6, 8, 10, 12, 14]
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Start' >> beam.Create(
          ([('a', x) for x in vals_1] + [('b', x) for x in vals_2]))
      result = pcoll | 'Mean' >> beam.CombinePerKey(self._MeanCombineFn())
      assert_that(
          result,
          equal_to([('a', sum(vals_1) // len(vals_1)),
                    ('b', sum(vals_2) // len(vals_2))]))

  def test_combine_per_key_with_callable(self):
    vals_1 = [1, 2, 3, 4, 5, 6, 7]
    vals_2 = [2, 4, 6, 8, 10, 12, 14]
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Start' >> beam.Create(
          ([('a', x) for x in vals_1] + [('b', x) for x in vals_2]))
      result = pcoll | beam.CombinePerKey(sum)
      assert_that(result, equal_to([('a', sum(vals_1)), ('b', sum(vals_2))]))

  def test_combine_per_key_with_side_input_as_arg(self):
    vals_1 = [1, 2, 3, 4, 5, 6, 7]
    vals_2 = [2, 4, 6, 8, 10, 12, 14]
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Start' >> beam.Create(
          ([('a', x) for x in vals_1] + [('b', x) for x in vals_2]))
      divisor = pipeline | 'Divisor' >> beam.Create([2])
      result = pcoll | beam.CombinePerKey(
          lambda vals, d: max(v for v in vals if v % d == 0),
          pvalue.AsSingleton(divisor))  # Multiples of divisor only.
      m_1 = max(v for v in vals_1 if v % 2 == 0)
      m_2 = max(v for v in vals_2 if v % 2 == 0)
      assert_that(result, equal_to([('a', m_1), ('b', m_2)]))

  def test_group_by_key(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'start' >> beam.Create([(1, 1), (2, 1), (3, 1), (1, 2),
                                                 (2, 2), (1, 3)])
      result = pcoll | 'Group' >> beam.GroupByKey() | SortLists
      assert_that(result, equal_to([(1, [1, 2, 3]), (2, [1, 2]), (3, [1])]))

  def test_group_by_key_unbounded_global_default_trigger(self):
    test_options = PipelineOptions()
    test_options.view_as(TypeOptions).allow_unsafe_triggers = False
    with self.assertRaisesRegex(
        ValueError,
        'GroupByKey cannot be applied to an unbounded PCollection with ' +
        'global windowing and a default trigger'):
      with TestPipeline(options=test_options) as pipeline:
        pipeline | TestStream() | beam.GroupByKey()

  def test_group_by_key_unsafe_trigger(self):
    test_options = PipelineOptions()
    test_options.view_as(TypeOptions).allow_unsafe_triggers = False
    with self.assertRaisesRegex(ValueError, 'Unsafe trigger'):
      with TestPipeline(options=test_options) as pipeline:
        _ = (
            pipeline
            | beam.Create([(None, None)])
            | WindowInto(
                window.GlobalWindows(),
                trigger=trigger.AfterCount(5),
                accumulation_mode=trigger.AccumulationMode.ACCUMULATING)
            | beam.GroupByKey())

  def test_group_by_key_allow_unsafe_triggers(self):
    test_options = PipelineOptions(flags=['--allow_unsafe_triggers'])
    with TestPipeline(options=test_options) as pipeline:
      pcoll = (
          pipeline
          | beam.Create([(1, 1), (1, 2), (1, 3), (1, 4)])
          | WindowInto(
              window.GlobalWindows(),
              trigger=trigger.AfterCount(4),
              accumulation_mode=trigger.AccumulationMode.ACCUMULATING)
          | beam.GroupByKey())
      assert_that(pcoll, equal_to([(1, [1, 2, 3, 4])]))
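
  # A note on the two trigger tests above: an AfterCount trigger on the
  # global window is flagged as unsafe because once it fires its final pane,
  # later elements for that key may be dropped; passing
  # --allow_unsafe_triggers acknowledges that risk explicitly.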

  def test_group_by_key_reiteration(self):
    class MyDoFn(beam.DoFn):
      def process(self, gbk_result):
        key, value_list = gbk_result
        sum_val = 0
        # Iterate over the GBK result multiple times.
        for _ in range(0, 17):
          sum_val += sum(value_list)
        return [(key, sum_val)]

    with TestPipeline() as pipeline:
      pcoll = pipeline | 'start' >> beam.Create([(1, 1), (1, 2), (1, 3),
                                                 (1, 4)])
      result = (
          pcoll | 'Group' >> beam.GroupByKey()
          | 'Reiteration-Sum' >> beam.ParDo(MyDoFn()))
      assert_that(result, equal_to([(1, 170)]))

  def test_group_by_key_deterministic_coder(self):
    # pylint: disable=global-variable-not-assigned
    global MyObject  # for pickling of the class instance

    class MyObject:
      def __init__(self, value):
        self.value = value

      def __eq__(self, other):
        return self.value == other.value

      def __hash__(self):
        return hash(self.value)

    class MyObjectCoder(beam.coders.Coder):
      def encode(self, o):
        return pickle.dumps((o.value, random.random()))

      def decode(self, encoded):
        return MyObject(pickle.loads(encoded)[0])

      def as_deterministic_coder(self, *args):
        return MydeterministicObjectCoder()

      def to_type_hint(self):
        return MyObject

    class MydeterministicObjectCoder(beam.coders.Coder):
      def encode(self, o):
        return pickle.dumps(o.value)

      def decode(self, encoded):
        return MyObject(pickle.loads(encoded))

      def is_deterministic(self):
        return True

    beam.coders.registry.register_coder(MyObject, MyObjectCoder)

    with TestPipeline() as pipeline:
      pcoll = pipeline | beam.Create([(MyObject(k % 2), k) for k in range(10)])
      grouped = pcoll | beam.GroupByKey() | beam.MapTuple(
          lambda k, vs: (k.value, sorted(vs)))
      combined = pcoll | beam.CombinePerKey(sum) | beam.MapTuple(
          lambda k, v: (k.value, v))
      assert_that(
          grouped,
          equal_to([(0, [0, 2, 4, 6, 8]), (1, [1, 3, 5, 7, 9])]),
          'CheckGrouped')
      assert_that(combined, equal_to([(0, 20), (1, 25)]), 'CheckCombined')

  def test_group_by_key_non_deterministic_coder(self):
    with self.assertRaisesRegex(Exception, r'deterministic'):
      with TestPipeline() as pipeline:
        _ = (
            pipeline
            | beam.Create([(PickledObject(10), None)])
            | beam.GroupByKey()
            | beam.MapTuple(lambda k, v: list(v)))

  def test_group_by_key_allow_non_deterministic_coder(self):
    with TestPipeline() as pipeline:
      # The GroupByKey below would fail without this option.
      pipeline._options.view_as(
          TypeOptions).allow_non_deterministic_key_coders = True
      grouped = (
          pipeline
          | beam.Create([(PickledObject(10), None)])
          | beam.GroupByKey()
          | beam.MapTuple(lambda k, v: list(v)))
      assert_that(grouped, equal_to([[None]]))
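
  def test_non_deterministic_encoding_sketch(self):
    # Illustrative sketch, not part of the original suite: pickling a value
    # together with a random payload, as MyObjectCoder above does, yields
    # different byte strings for equal values. That instability is what makes
    # such a coder unusable for GroupByKey keys.
    self.assertNotEqual(
        pickle.dumps((10, random.random())),
        pickle.dumps((10, random.random())))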

  def test_group_by_key_fake_deterministic_coder(self):
    fresh_registry = beam.coders.typecoders.CoderRegistry()
    with patch.object(
        beam.coders, 'registry', fresh_registry), patch.object(
            beam.coders.typecoders, 'registry', fresh_registry):
      with TestPipeline() as pipeline:
        # The GroupByKey below would fail without this registration.
        beam.coders.registry.register_fallback_coder(
            beam.coders.coders.FakeDeterministicFastPrimitivesCoder())
        grouped = (
            pipeline
            | beam.Create([(PickledObject(10), None)])
            | beam.GroupByKey()
            | beam.MapTuple(lambda k, v: list(v)))
        assert_that(grouped, equal_to([[None]]))

  def test_partition_with_partition_fn(self):
    class SomePartitionFn(beam.PartitionFn):
      def partition_for(self, element, num_partitions, offset):
        return (element % 3) + offset

    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Start' >> beam.Create([0, 1, 2, 3, 4, 5, 6, 7, 8])
      # Attempt nominal partition operation.
      partitions = pcoll | 'Part 1' >> beam.Partition(SomePartitionFn(), 4, 1)
      assert_that(partitions[0], equal_to([]))
      assert_that(partitions[1], equal_to([0, 3, 6]), label='p1')
      assert_that(partitions[2], equal_to([1, 4, 7]), label='p2')
      assert_that(partitions[3], equal_to([2, 5, 8]), label='p3')

    # Check that a bad partition label will yield an error. For the
    # DirectRunner, this error manifests as an exception.
    with self.assertRaises(ValueError):
      with TestPipeline() as pipeline:
        pcoll = pipeline | 'Start' >> beam.Create([0, 1, 2, 3, 4, 5, 6, 7, 8])
        partitions = pcoll | beam.Partition(SomePartitionFn(), 4, 10000)

  def test_partition_with_callable(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Start' >> beam.Create([0, 1, 2, 3, 4, 5, 6, 7, 8])
      partitions = (
          pcoll |
          'part' >> beam.Partition(lambda e, n, offset: (e % 3) + offset, 4, 1))
      assert_that(partitions[0], equal_to([]))
      assert_that(partitions[1], equal_to([0, 3, 6]), label='p1')
      assert_that(partitions[2], equal_to([1, 4, 7]), label='p2')
      assert_that(partitions[3], equal_to([2, 5, 8]), label='p3')

  def test_partition_with_callable_and_side_input(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Start' >> beam.Create([0, 1, 2, 3, 4, 5, 6, 7, 8])
      side_input = pipeline | 'Side Input' >> beam.Create([100, 1000])
      partitions = (
          pcoll | 'part' >> beam.Partition(
              lambda e, n, offset, si_list: ((e + len(si_list)) % 3) + offset,
              4,
              1,
              pvalue.AsList(side_input)))
      assert_that(partitions[0], equal_to([]))
      assert_that(partitions[1], equal_to([1, 4, 7]), label='p1')
      assert_that(partitions[2], equal_to([2, 5, 8]), label='p2')
      assert_that(partitions[3], equal_to([0, 3, 6]), label='p3')

  def test_partition_followed_by_flatten_and_groupbykey(self):
    """Regression test for an issue with how partitions are handled."""
    with TestPipeline() as pipeline:
      contents = [('aa', 1), ('bb', 2), ('aa', 2)]
      created = pipeline | 'A' >> beam.Create(contents)
      partitioned = created | 'B' >> beam.Partition(lambda x, n: len(x) % n, 3)
      flattened = partitioned | 'C' >> beam.Flatten()
      grouped = flattened | 'D' >> beam.GroupByKey() | SortLists
      assert_that(grouped, equal_to([('aa', [1, 2]), ('bb', [2])]))

  @pytest.mark.it_validatesrunner
  def test_flatten_pcollections(self):
    with TestPipeline() as pipeline:
      pcoll_1 = pipeline | 'Start 1' >> beam.Create([0, 1, 2, 3])
      pcoll_2 = pipeline | 'Start 2' >> beam.Create([4, 5, 6, 7])
      result = (pcoll_1, pcoll_2) | 'Flatten' >> beam.Flatten()
      assert_that(result, equal_to([0, 1, 2, 3, 4, 5, 6, 7]))

  def test_flatten_no_pcollections(self):
    with TestPipeline() as pipeline:
      with self.assertRaises(ValueError):
        () | 'PipelineArgMissing' >> beam.Flatten()
      result = () | 'Empty' >> beam.Flatten(pipeline=pipeline)
      assert_that(result, equal_to([]))

  @pytest.mark.it_validatesrunner
  def test_flatten_one_single_pcollection(self):
    with TestPipeline() as pipeline:
      input = [0, 1, 2, 3]
      pcoll = pipeline | 'Input' >> beam.Create(input)
      result = (pcoll, ) | 'Single Flatten' >> beam.Flatten()
      assert_that(result, equal_to(input))

  # TODO(https://github.com/apache/beam/issues/20067): Does not work in
  # streaming mode on Dataflow.
  @pytest.mark.no_sickbay_streaming
  @pytest.mark.it_validatesrunner
  def test_flatten_same_pcollections(self):
    with TestPipeline() as pipeline:
      pc = pipeline | beam.Create(['a', 'b'])
      assert_that((pc, pc, pc) | beam.Flatten(), equal_to(['a', 'b'] * 3))

  def test_flatten_pcollections_in_iterable(self):
    with TestPipeline() as pipeline:
      pcoll_1 = pipeline | 'Start 1' >> beam.Create([0, 1, 2, 3])
      pcoll_2 = pipeline | 'Start 2' >> beam.Create([4, 5, 6, 7])
      result = [pcoll for pcoll in (pcoll_1, pcoll_2)] | beam.Flatten()
      assert_that(result, equal_to([0, 1, 2, 3, 4, 5, 6, 7]))

  @pytest.mark.it_validatesrunner
  def test_flatten_a_flattened_pcollection(self):
    with TestPipeline() as pipeline:
      pcoll_1 = pipeline | 'Start 1' >> beam.Create([0, 1, 2, 3])
      pcoll_2 = pipeline | 'Start 2' >> beam.Create([4, 5, 6, 7])
      pcoll_3 = pipeline | 'Start 3' >> beam.Create([8, 9])
      pcoll_12 = (pcoll_1, pcoll_2) | 'Flatten' >> beam.Flatten()
      pcoll_123 = (pcoll_12, pcoll_3) | 'Flatten again' >> beam.Flatten()
      assert_that(pcoll_123, equal_to([x for x in range(10)]))

  def test_flatten_input_type_must_be_iterable(self):
    # Inputs to flatten *must* be an iterable.
    with self.assertRaises(ValueError):
      4 | beam.Flatten()

  def test_flatten_input_type_must_be_iterable_of_pcolls(self):
    # Inputs to flatten *must* be an iterable of PCollections.
    with self.assertRaises(TypeError):
      {'l': 'test'} | beam.Flatten()
    with self.assertRaises(TypeError):
      set([1, 2, 3]) | beam.Flatten()

  @pytest.mark.it_validatesrunner
  def test_flatten_multiple_pcollections_having_multiple_consumers(self):
    with TestPipeline() as pipeline:
      input = pipeline | 'Start' >> beam.Create(['AA', 'BBB', 'CC'])

      def split_even_odd(element):
        tag = 'even_length' if len(element) % 2 == 0 else 'odd_length'
        return pvalue.TaggedOutput(tag, element)

      even_length, odd_length = (input | beam.Map(split_even_odd)
                                 .with_outputs('even_length', 'odd_length'))
      merged = (even_length, odd_length) | 'Flatten' >> beam.Flatten()

      assert_that(merged, equal_to(['AA', 'BBB', 'CC']))
      assert_that(even_length, equal_to(['AA', 'CC']), label='assert:even')
      assert_that(odd_length, equal_to(['BBB']), label='assert:odd')

  def test_group_by_key_input_must_be_kv_pairs(self):
    with self.assertRaises(typehints.TypeCheckError) as e:
      with TestPipeline() as pipeline:
        pcolls = pipeline | 'A' >> beam.Create([1, 2, 3, 4, 5])
        pcolls | 'D' >> beam.GroupByKey()

    self.assertStartswith(
        e.exception.args[0],
        'Input type hint violation at D: expected '
        'Tuple[TypeVariable[K], TypeVariable[V]]')

  def test_group_by_key_only_input_must_be_kv_pairs(self):
    with self.assertRaises(typehints.TypeCheckError) as cm:
      with TestPipeline() as pipeline:
        pcolls = pipeline | 'A' >> beam.Create(['a', 'b', 'f'])
        pcolls | 'D' >> beam.GroupByKey()

    expected_error_prefix = (
        'Input type hint violation at D: expected '
        'Tuple[TypeVariable[K], TypeVariable[V]]')
    self.assertStartswith(cm.exception.args[0], expected_error_prefix)

  def test_keys_and_values(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Start' >> beam.Create([(3, 1), (2, 1), (1, 1), (3, 2),
                                                 (2, 2), (3, 3)])
      keys = pcoll.apply(beam.Keys('keys'))
      vals = pcoll.apply(beam.Values('vals'))
      assert_that(keys, equal_to([1, 2, 2, 3, 3, 3]), label='assert:keys')
      assert_that(vals, equal_to([1, 1, 1, 2, 2, 3]), label='assert:vals')

  def test_kv_swap(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Start' >> beam.Create([(6, 3), (1, 2), (7, 1), (5, 2),
                                                 (3, 2)])
      result = pcoll.apply(beam.KvSwap(), label='swap')
      assert_that(result, equal_to([(1, 7), (2, 1), (2, 3), (2, 5), (3, 6)]))

  def test_distinct(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Start' >> beam.Create(
          [6, 3, 1, 1, 9, 'pleat', 'pleat', 'kazoo', 'navel'])
      result = pcoll.apply(beam.Distinct())
      assert_that(result, equal_to([1, 3, 6, 9, 'pleat', 'kazoo', 'navel']))

  def test_chained_ptransforms(self):
    with TestPipeline() as pipeline:
      t = (
          beam.Map(lambda x: (x, 1))
          | beam.GroupByKey()
          | beam.Map(lambda x_ones: (x_ones[0], sum(x_ones[1]))))
      result = pipeline | 'Start' >> beam.Create(['a', 'a', 'b']) | t
      assert_that(result, equal_to([('a', 2), ('b', 1)]))

  def test_apply_to_list(self):
    self.assertCountEqual([1, 2, 3],
                          [0, 1, 2] | 'AddOne' >> beam.Map(lambda x: x + 1))
    self.assertCountEqual([1],
                          [0, 1, 2] | 'Odd' >> beam.Filter(lambda x: x % 2))
    self.assertCountEqual([1, 2, 100, 3], ([1, 2, 3], [100]) | beam.Flatten())
    join_input = ([('k', 'a')], [('k', 'b'), ('k', 'c')])
    self.assertCountEqual([('k', (['a'], ['b', 'c']))],
                          join_input | beam.CoGroupByKey() | SortLists)
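
  def test_single_input_composite_sketch(self):
    # Illustrative sketch, not part of the original suite: a minimal
    # composite transform. expand() receives the input (here a single
    # PCollection) and returns whatever the transforms it applies produce,
    # as the multi-input variant below does for a tuple of PCollections.
    class Double(PTransform):
      def expand(self, pcoll):
        return pcoll | beam.Map(lambda x: x * 2)

    self.assertEqual([2, 4, 6], sorted([1, 2, 3] | Double()))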

  def test_multi_input_ptransform(self):
    class DisjointUnion(PTransform):
      def expand(self, pcollections):
        return (
            pcollections
            | beam.Flatten()
            | beam.Map(lambda x: (x, None))
            | beam.GroupByKey()
            | beam.Map(lambda kv: kv[0]))

    self.assertEqual([1, 2, 3], sorted(([1, 2], [2, 3]) | DisjointUnion()))

  def test_apply_to_crazy_pvaluish(self):
    class NestedFlatten(PTransform):
      """A PTransform taking and returning nested PValueish.

      Takes as input a list of dicts, and returns a dict with the corresponding
      values flattened.
      """
      def _extract_input_pvalues(self, pvalueish):
        pvalueish = list(pvalueish)
        return pvalueish, sum([list(p.values()) for p in pvalueish], [])

      def expand(self, pcoll_dicts):
        keys = reduce(operator.or_, [set(p.keys()) for p in pcoll_dicts])
        res = {}
        for k in keys:
          res[k] = [p[k] for p in pcoll_dicts if k in p] | k >> beam.Flatten()
        return res

    res = [{
        'a': [1, 2, 3]
    }, {
        'a': [4, 5, 6], 'b': ['x', 'y', 'z']
    }, {
        'a': [7, 8], 'b': ['x', 'y'], 'c': []
    }] | NestedFlatten()
    self.assertEqual(3, len(res))
    self.assertEqual([1, 2, 3, 4, 5, 6, 7, 8], sorted(res['a']))
    self.assertEqual(['x', 'x', 'y', 'y', 'z'], sorted(res['b']))
    self.assertEqual([], sorted(res['c']))

  def test_named_tuple(self):
    MinMax = collections.namedtuple('MinMax', ['min', 'max'])

    class MinMaxTransform(PTransform):
      def expand(self, pcoll):
        return MinMax(
            min=pcoll | beam.CombineGlobally(min).without_defaults(),
            max=pcoll | beam.CombineGlobally(max).without_defaults())

    res = [1, 2, 4, 8] | MinMaxTransform()
    self.assertIsInstance(res, MinMax)
    self.assertEqual(res, MinMax(min=[1], max=[8]))

    flat = res | beam.Flatten()
    self.assertEqual(sorted(flat), [1, 8])

  def test_tuple_twice(self):
    class Duplicate(PTransform):
      def expand(self, pcoll):
        return pcoll, pcoll

    res1, res2 = [1, 2, 4, 8] | Duplicate()
    self.assertEqual(sorted(res1), [1, 2, 4, 8])
    self.assertEqual(sorted(res2), [1, 2, 4, 8])

  def test_resource_hint_application_is_additive(self):
    t = beam.Map(lambda x: x + 1).with_resource_hints(
        accelerator='gpu').with_resource_hints(min_ram=1).with_resource_hints(
            accelerator='tpu')
    self.assertEqual(
        t.get_resource_hints(),
        {
            common_urns.resource_hints.ACCELERATOR.urn: b'tpu',
            common_urns.resource_hints.MIN_RAM_BYTES.urn: b'1'
        })


class TestGroupBy(unittest.TestCase):
  def test_lambdas(self):
    def normalize(key, values):
      return tuple(key) if isinstance(key, tuple) else key, sorted(values)

    with TestPipeline() as p:
      pcoll = p | beam.Create(range(6))
      assert_that(
          pcoll | beam.GroupBy() | beam.MapTuple(normalize),
          equal_to([((), [0, 1, 2, 3, 4, 5])]),
          'GroupAll')
      assert_that(
          pcoll | beam.GroupBy(lambda x: x % 2)
          | 'n2' >> beam.MapTuple(normalize),
          equal_to([(0, [0, 2, 4]), (1, [1, 3, 5])]),
          'GroupOne')
      assert_that(
          pcoll | 'G2' >> beam.GroupBy(lambda x: x % 2).force_tuple_keys()
          | 'n3' >> beam.MapTuple(normalize),
          equal_to([((0, ), [0, 2, 4]), ((1, ), [1, 3, 5])]),
          'GroupOneTuple')
      assert_that(
          pcoll | beam.GroupBy(a=lambda x: x % 2, b=lambda x: x < 4)
          | 'n4' >> beam.MapTuple(normalize),
          equal_to([((0, True), [0, 2]), ((1, True), [1, 3]), ((0, False), [4]),
                    ((1, False), [5])]),
          'GroupTwo')

  def test_fields(self):
    def normalize(key, values):
      if isinstance(key, tuple):
        key = beam.Row(
            **{name: value
               for name, value in zip(type(key)._fields, key)})
      return key, sorted(v.value for v in values)

    with TestPipeline() as p:
      pcoll = p | beam.Create(range(-2, 3)) | beam.Map(int) | beam.Map(
          lambda x: beam.Row(
              value=x, square=x * x, sign=x // abs(x) if x else 0))
      assert_that(
          pcoll | beam.GroupBy('square') | beam.MapTuple(normalize),
          equal_to([
              (0, [0]),
              (1, [-1, 1]),
              (4, [-2, 2]),
          ]),
          'GroupSquare')
      assert_that(
          pcoll | 'G2' >> beam.GroupBy('square').force_tuple_keys()
          | 'n2' >> beam.MapTuple(normalize),
          equal_to([
              (beam.Row(square=0), [0]),
              (beam.Row(square=1), [-1, 1]),
              (beam.Row(square=4), [-2, 2]),
          ]),
          'GroupSquareTupleKey')
      assert_that(
          pcoll | beam.GroupBy('square', 'sign')
          | 'n3' >> beam.MapTuple(normalize),
          equal_to([
              (beam.Row(square=0, sign=0), [0]),
              (beam.Row(square=1, sign=1), [1]),
              (beam.Row(square=4, sign=1), [2]),
              (beam.Row(square=1, sign=-1), [-1]),
              (beam.Row(square=4, sign=-1), [-2]),
          ]),
          'GroupSquareSign')
      assert_that(
          pcoll | beam.GroupBy('square', big=lambda x: x.value > 1)
          | 'n4' >> beam.MapTuple(normalize),
          equal_to([
              (beam.Row(square=0, big=False), [0]),
              (beam.Row(square=1, big=False), [-1, 1]),
              (beam.Row(square=4, big=False), [-2]),
              (beam.Row(square=4, big=True), [2]),
          ]),
          'GroupSquareNonzero')

  def test_aggregate(self):
    def named_tuple_to_row(t):
      return beam.Row(
          **{name: value
             for name, value in zip(type(t)._fields, t)})

    with TestPipeline() as p:
      pcoll = p | beam.Create(range(-2, 3)) | beam.Map(
          lambda x: beam.Row(
              value=x, square=x * x, sign=x // abs(x) if x else 0))

      assert_that(
          pcoll
          | beam.GroupBy('square', big=lambda x: x.value > 1)
          .aggregate_field('value', sum, 'sum')
          .aggregate_field(lambda x: x.sign == 1, all, 'positive')
          | beam.Map(named_tuple_to_row),
          equal_to([
              beam.Row(square=0, big=False, sum=0, positive=False),  # [0]
              beam.Row(square=1, big=False, sum=0, positive=False),  # [-1, 1]
              beam.Row(square=4, big=False, sum=-2, positive=False),  # [-2]
              beam.Row(square=4, big=True, sum=2, positive=True),  # [2]
          ]))

  def test_pickled_field(self):
    with TestPipeline() as p:
      assert_that(
          p
          | beam.Create(['a', 'a', 'b'])
          | beam.Map(
              lambda s: beam.Row(
                  key1=PickledObject(s), key2=s.upper(), value=0))
          | beam.GroupBy('key1', 'key2')
          | beam.MapTuple(lambda k, vs: (k.key1.value, k.key2, len(list(vs)))),
          equal_to([('a', 'A', 2), ('b', 'B', 1)]))


class SelectTest(unittest.TestCase):
  def test_simple(self):
    with TestPipeline() as p:
      rows = (
          p | beam.Create([1, 2, 10])
          | beam.Select(a=lambda x: x * x, b=lambda x: -x))

      assert_that(
          rows,
          equal_to([
              beam.Row(a=1, b=-1),
              beam.Row(a=4, b=-2),
              beam.Row(a=100, b=-10),
          ]),
          label='CheckFromLambdas')

      from_attr = rows | beam.Select('b', z='a')
      assert_that(
          from_attr,
          equal_to([
              beam.Row(b=-1, z=1),
              beam.Row(b=-2, z=4),
              beam.Row(b=-10, z=100),
          ]),
          label='CheckFromAttrs')
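

# A sketch of the mechanism behind the decorator below: @beam.ptransform_fn
# wraps the function in a PTransform whose expand() calls the function with
# the input PCollection (plus any extra arguments given at construction
# time), so it can be applied as `pcoll | 'Label' >> SamplePTransform()`.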
@beam.ptransform_fn
def SamplePTransform(pcoll):
  """Sample transform using the @ptransform_fn decorator."""
  map_transform = 'ToPairs' >> beam.Map(lambda v: (v, None))
  combine_transform = 'Group' >> beam.CombinePerKey(lambda vs: None)
  keys_transform = 'Distinct' >> beam.Keys()
  return pcoll | map_transform | combine_transform | keys_transform


class PTransformLabelsTest(unittest.TestCase):
  class CustomTransform(beam.PTransform):

    pardo = None  # type: Optional[beam.PTransform]

    def expand(self, pcoll):
      self.pardo = '*Do*' >> beam.FlatMap(lambda x: [x + 1])
      return pcoll | self.pardo

  def test_chained_ptransforms(self):
    """Tests that chaining gets proper nesting."""
    with TestPipeline() as pipeline:
      map1 = 'Map1' >> beam.Map(lambda x: (x, 1))
      gbk = 'Gbk' >> beam.GroupByKey()
      map2 = 'Map2' >> beam.Map(lambda x_ones2: (x_ones2[0], sum(x_ones2[1])))
      t = (map1 | gbk | map2)
      result = pipeline | 'Start' >> beam.Create(['a', 'a', 'b']) | t
      self.assertTrue('Map1|Gbk|Map2/Map1' in pipeline.applied_labels)
      self.assertTrue('Map1|Gbk|Map2/Gbk' in pipeline.applied_labels)
      self.assertTrue('Map1|Gbk|Map2/Map2' in pipeline.applied_labels)
      assert_that(result, equal_to([('a', 2), ('b', 1)]))

  def test_apply_custom_transform_without_label(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'PColl' >> beam.Create([1, 2, 3])
      custom = PTransformLabelsTest.CustomTransform()
      result = pipeline.apply(custom, pcoll)
      self.assertTrue('CustomTransform' in pipeline.applied_labels)
      self.assertTrue('CustomTransform/*Do*' in pipeline.applied_labels)
      assert_that(result, equal_to([2, 3, 4]))

  def test_apply_custom_transform_with_label(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'PColl' >> beam.Create([1, 2, 3])
      custom = PTransformLabelsTest.CustomTransform('*Custom*')
      result = pipeline.apply(custom, pcoll)
      self.assertTrue('*Custom*' in pipeline.applied_labels)
      self.assertTrue('*Custom*/*Do*' in pipeline.applied_labels)
      assert_that(result, equal_to([2, 3, 4]))

  def test_combine_without_label(self):
    vals = [1, 2, 3, 4, 5, 6, 7]
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Start' >> beam.Create(vals)
      combine = beam.CombineGlobally(sum)
      result = pcoll | combine
      self.assertTrue('CombineGlobally(sum)' in pipeline.applied_labels)
      assert_that(result, equal_to([sum(vals)]))

  def test_apply_ptransform_using_decorator(self):
    pipeline = TestPipeline()
    pcoll = pipeline | 'PColl' >> beam.Create([1, 2, 3])
    _ = pcoll | '*Sample*' >> SamplePTransform()
    self.assertTrue('*Sample*' in pipeline.applied_labels)
    self.assertTrue('*Sample*/ToPairs' in pipeline.applied_labels)
    self.assertTrue('*Sample*/Group' in pipeline.applied_labels)
    self.assertTrue('*Sample*/Distinct' in pipeline.applied_labels)

  def test_combine_with_label(self):
    vals = [1, 2, 3, 4, 5, 6, 7]
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Start' >> beam.Create(vals)
      combine = '*Sum*' >> beam.CombineGlobally(sum)
      result = pcoll | combine
      self.assertTrue('*Sum*' in pipeline.applied_labels)
      assert_that(result, equal_to([sum(vals)]))

  def check_label(self, ptransform, expected_label):
    pipeline = TestPipeline()
    pipeline | 'Start' >> beam.Create([('a', 1)]) | ptransform
    actual_label = sorted(
        label for label in pipeline.applied_labels
        if not label.startswith('Start'))[0]
    self.assertEqual(expected_label, re.sub(r'\d{3,}', '#', actual_label))

  def test_default_labels(self):
    def my_function(*args):
      pass

    self.check_label(beam.Map(len), 'Map(len)')
    self.check_label(beam.Map(my_function), 'Map(my_function)')
    self.check_label(
        beam.Map(lambda x: x), 'Map(<lambda at ptransform_test.py:#>)')
    self.check_label(beam.FlatMap(list), 'FlatMap(list)')
    self.check_label(beam.FlatMap(my_function), 'FlatMap(my_function)')
    self.check_label(beam.Filter(sum), 'Filter(sum)')
    self.check_label(beam.CombineGlobally(sum), 'CombineGlobally(sum)')
    self.check_label(beam.CombinePerKey(sum), 'CombinePerKey(sum)')

    class MyDoFn(beam.DoFn):
      def process(self, unused_element):
        pass

    self.check_label(beam.ParDo(MyDoFn()), 'ParDo(MyDoFn)')

  def test_label_propogation(self):
    self.check_label('TestMap' >> beam.Map(len), 'TestMap')
    self.check_label('TestLambda' >> beam.Map(lambda x: x), 'TestLambda')
    self.check_label('TestFlatMap' >> beam.FlatMap(list), 'TestFlatMap')
    self.check_label('TestFilter' >> beam.Filter(sum), 'TestFilter')
    self.check_label('TestCG' >> beam.CombineGlobally(sum), 'TestCG')
    self.check_label('TestCPK' >> beam.CombinePerKey(sum), 'TestCPK')

    class MyDoFn(beam.DoFn):
      def process(self, unused_element):
        pass

    self.check_label('TestParDo' >> beam.ParDo(MyDoFn()), 'TestParDo')


class PTransformTestDisplayData(unittest.TestCase):
  def test_map_named_function(self):
    tr = beam.Map(len)
    dd = DisplayData.create_from(tr)
    nspace = 'apache_beam.transforms.core.CallableWrapperDoFn'
    expected_item = DisplayDataItem(
        'len', key='fn', label='Transform Function', namespace=nspace)
    hc.assert_that(dd.items, hc.has_item(expected_item))

  def test_map_anonymous_function(self):
    tr = beam.Map(lambda x: x)
    dd = DisplayData.create_from(tr)
    nspace = 'apache_beam.transforms.core.CallableWrapperDoFn'
    expected_item = DisplayDataItem(
        '<lambda>', key='fn', label='Transform Function', namespace=nspace)
    hc.assert_that(dd.items, hc.has_item(expected_item))

  def test_flatmap_named_function(self):
    tr = beam.FlatMap(list)
    dd = DisplayData.create_from(tr)
    nspace = 'apache_beam.transforms.core.CallableWrapperDoFn'
    expected_item = DisplayDataItem(
        'list', key='fn', label='Transform Function', namespace=nspace)
    hc.assert_that(dd.items, hc.has_item(expected_item))

  def test_flatmap_anonymous_function(self):
    tr = beam.FlatMap(lambda x: [x])
    dd = DisplayData.create_from(tr)
    nspace = 'apache_beam.transforms.core.CallableWrapperDoFn'
    expected_item = DisplayDataItem(
        '<lambda>', key='fn', label='Transform Function', namespace=nspace)
    hc.assert_that(dd.items, hc.has_item(expected_item))

  def test_filter_named_function(self):
    tr = beam.Filter(sum)
    dd = DisplayData.create_from(tr)
    nspace = 'apache_beam.transforms.core.CallableWrapperDoFn'
    expected_item = DisplayDataItem(
        'sum', key='fn', label='Transform Function', namespace=nspace)
    hc.assert_that(dd.items, hc.has_item(expected_item))

  def test_filter_anonymous_function(self):
    tr = beam.Filter(lambda x: x // 30)
    dd = DisplayData.create_from(tr)
    nspace = 'apache_beam.transforms.core.CallableWrapperDoFn'
    expected_item = DisplayDataItem(
        '<lambda>', key='fn', label='Transform Function', namespace=nspace)
    hc.assert_that(dd.items, hc.has_item(expected_item))


class PTransformTypeCheckTestCase(TypeHintTestCase):
  def assertStartswith(self, msg, prefix):
    self.assertTrue(
        msg.startswith(prefix), '"%s" does not start with "%s"' % (msg, prefix))

  def setUp(self):
    self.p = TestPipeline()

  def test_do_fn_pipeline_pipeline_type_check_satisfied(self):
    @with_input_types(int, int)
    @with_output_types(int)
    class AddWithFive(beam.DoFn):
      def process(self, element, five):
        return [element + five]

    d = (
        self.p
        | 'T' >> beam.Create([1, 2, 3]).with_output_types(int)
        | 'Add' >> beam.ParDo(AddWithFive(), 5))

    assert_that(d, equal_to([6, 7, 8]))
    self.p.run()

  def test_do_fn_pipeline_pipeline_type_check_violated(self):
    @with_input_types(str, str)
    @with_output_types(str)
    class ToUpperCaseWithPrefix(beam.DoFn):
      def process(self, element, prefix):
        return [prefix + element.upper()]

    with self.assertRaises(typehints.TypeCheckError) as e:
      (
          self.p
          | 'T' >> beam.Create([1, 2, 3]).with_output_types(int)
          | 'Upper' >> beam.ParDo(ToUpperCaseWithPrefix(), 'hello'))

    self.assertStartswith(
        e.exception.args[0],
        "Type hint violation for 'Upper': "
        "requires {} but got {} for element".format(str, int))

  def test_do_fn_pipeline_runtime_type_check_satisfied(self):
    self.p._options.view_as(TypeOptions).runtime_type_check = True

    @with_input_types(int, int)
    @with_output_types(int)
    class AddWithNum(beam.DoFn):
      def process(self, element, num):
        return [element + num]

    d = (
        self.p
        | 'T' >> beam.Create([1, 2, 3]).with_output_types(int)
        | 'Add' >> beam.ParDo(AddWithNum(), 5))

    assert_that(d, equal_to([6, 7, 8]))
    self.p.run()

  def test_do_fn_pipeline_runtime_type_check_violated(self):
    self.p._options.view_as(TypeOptions).runtime_type_check = True

    @with_input_types(int, int)
    @with_output_types(int)
    class AddWithNum(beam.DoFn):
      def process(self, element, num):
        return [element + num]

    with self.assertRaises(typehints.TypeCheckError) as e:
      (
          self.p
          | 'T' >> beam.Create(['1', '2', '3']).with_output_types(str)
          | 'Add' >> beam.ParDo(AddWithNum(), 5))
      self.p.run()

    self.assertStartswith(
        e.exception.args[0],
        "Type hint violation for 'Add': "
        "requires {} but got {} for element".format(int, str))

  def test_pardo_does_not_type_check_using_type_hint_decorators(self):
    @with_input_types(a=int)
    @with_output_types(typing.List[str])
    def int_to_str(a):
      return [str(a)]

    # The function above is expecting an int for its only parameter. However,
    # it will receive a str instead, which should result in a raised exception.
    with self.assertRaises(typehints.TypeCheckError) as e:
      (
          self.p
          | 'S' >> beam.Create(['b', 'a', 'r']).with_output_types(str)
          | 'ToStr' >> beam.FlatMap(int_to_str))

    self.assertStartswith(
        e.exception.args[0],
        "Type hint violation for 'ToStr': "
        "requires {} but got {} for a".format(int, str))
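
  def test_decorator_hints_attach_to_function_sketch(self):
    # Illustrative sketch, not part of the original suite: the decorators used
    # throughout these tests attach IOTypeHints to the function object itself,
    # which Beam later reads when the function is wrapped in a transform.
    from apache_beam.typehints.decorators import get_type_hints

    @with_input_types(int)
    @with_output_types(str)
    def int_to_str(x):
      return str(x)

    hints = get_type_hints(int_to_str)
    self.assertEqual(((int, ), {}), hints.input_types)
    self.assertEqual(((str, ), {}), hints.output_types)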

  def test_pardo_properly_type_checks_using_type_hint_decorators(self):
    @with_input_types(a=str)
    @with_output_types(typing.List[str])
    def to_all_upper_case(a):
      return [a.upper()]

    # If this type-checks then no error should be raised.
    d = (
        self.p
        | 'T' >> beam.Create(['t', 'e', 's', 't']).with_output_types(str)
        | 'Case' >> beam.FlatMap(to_all_upper_case))
    assert_that(d, equal_to(['T', 'E', 'S', 'T']))
    self.p.run()

    # Output type should have been recognized as 'str' rather than List[str] to
    # do the flatten part of FlatMap.
    self.assertEqual(str, d.element_type)

  def test_pardo_does_not_type_check_using_type_hint_methods(self):
    # The first ParDo outputs pcolls of type int, however the second ParDo is
    # expecting pcolls of type str instead.
    with self.assertRaises(typehints.TypeCheckError) as e:
      (
          self.p
          | 'S' >> beam.Create(['t', 'e', 's', 't']).with_output_types(str)
          | (
              'Score' >> beam.FlatMap(lambda x: [1] if x == 't' else [2]).
              with_input_types(str).with_output_types(int))
          | (
              'Upper' >> beam.FlatMap(lambda x: [x.upper()]).with_input_types(
                  str).with_output_types(str)))

    self.assertStartswith(
        e.exception.args[0],
        "Type hint violation for 'Upper': "
        "requires {} but got {} for x".format(str, int))

  def test_pardo_properly_type_checks_using_type_hint_methods(self):
    # Pipeline should be created successfully without an error.
    d = (
        self.p
        | 'S' >> beam.Create(['t', 'e', 's', 't']).with_output_types(str)
        | 'Dup' >> beam.FlatMap(lambda x: [x + x]).with_input_types(
            str).with_output_types(str)
        | 'Upper' >> beam.FlatMap(lambda x: [x.upper()]).with_input_types(
            str).with_output_types(str))

    assert_that(d, equal_to(['TT', 'EE', 'SS', 'TT']))
    self.p.run()

  def test_map_does_not_type_check_using_type_hints_methods(self):
    # The transform before 'Map' has indicated that it outputs PCollections
    # with ints, while Map is expecting one of str.
    with self.assertRaises(typehints.TypeCheckError) as e:
      (
          self.p
          | 'S' >> beam.Create([1, 2, 3, 4]).with_output_types(int)
          | 'Upper' >> beam.Map(lambda x: x.upper()).with_input_types(
              str).with_output_types(str))

    self.assertStartswith(
        e.exception.args[0],
        "Type hint violation for 'Upper': "
        "requires {} but got {} for x".format(str, int))

  def test_map_properly_type_checks_using_type_hints_methods(self):
    # No error should be raised if this type-checks properly.
    d = (
        self.p
        | 'S' >> beam.Create([1, 2, 3, 4]).with_output_types(int)
        | 'ToStr' >>
        beam.Map(lambda x: str(x)).with_input_types(int).with_output_types(str))
    assert_that(d, equal_to(['1', '2', '3', '4']))
    self.p.run()

  def test_map_does_not_type_check_using_type_hints_decorator(self):
    @with_input_types(s=str)
    @with_output_types(str)
    def upper(s):
      return s.upper()

    # The hinted function above expects a str at pipeline construction.
    # However, 'Map' should detect that Create has hinted an int instead.
    with self.assertRaises(typehints.TypeCheckError) as e:
      (
          self.p
          | 'S' >> beam.Create([1, 2, 3, 4]).with_output_types(int)
          | 'Upper' >> beam.Map(upper))

    self.assertStartswith(
        e.exception.args[0],
        "Type hint violation for 'Upper': "
        "requires {} but got {} for s".format(str, int))

  def test_map_properly_type_checks_using_type_hints_decorator(self):
    @with_input_types(a=bool)
    @with_output_types(int)
    def bool_to_int(a):
      return int(a)

    # If this type-checks then no error should be raised.
    d = (
        self.p
        | 'Bools' >> beam.Create([True, False, True]).with_output_types(bool)
        | 'ToInts' >> beam.Map(bool_to_int))
    assert_that(d, equal_to([1, 0, 1]))
    self.p.run()

  def test_filter_does_not_type_check_using_type_hints_method(self):
    # Filter is expecting an int but instead looks to the 'left' and sees a str
    # incoming.
    with self.assertRaises(typehints.TypeCheckError) as e:
      (
          self.p
          | 'Strs' >> beam.Create(['1', '2', '3', '4', '5'
                                   ]).with_output_types(str)
          | 'Lower' >> beam.Map(lambda x: x.lower()).with_input_types(
              str).with_output_types(str)
          | 'Below 3' >> beam.Filter(lambda x: x < 3).with_input_types(int))

    self.assertStartswith(
        e.exception.args[0],
        "Type hint violation for 'Below 3': "
        "requires {} but got {} for x".format(int, str))

  def test_filter_type_checks_using_type_hints_method(self):
    # No error should be raised if this type-checks properly.
    d = (
        self.p
        | beam.Create(['1', '2', '3', '4', '5']).with_output_types(str)
        | 'ToInt' >>
        beam.Map(lambda x: int(x)).with_input_types(str).with_output_types(int)
        | 'Below 3' >> beam.Filter(lambda x: x < 3).with_input_types(int))
    assert_that(d, equal_to([1, 2]))
    self.p.run()

  def test_filter_does_not_type_check_using_type_hints_decorator(self):
    @with_input_types(a=float)
    def more_than_half(a):
      return a > 0.50

    # The function above was hinted to only take a float, yet an int will be
    # passed.
    with self.assertRaises(typehints.TypeCheckError) as e:
      (
          self.p
          | 'Ints' >> beam.Create([1, 2, 3, 4]).with_output_types(int)
          | 'Half' >> beam.Filter(more_than_half))

    self.assertStartswith(
        e.exception.args[0],
        "Type hint violation for 'Half': "
        "requires {} but got {} for a".format(float, int))

  def test_filter_type_checks_using_type_hints_decorator(self):
    @with_input_types(b=int)
    def half(b):
      return bool(random.choice([0, 1]))

    # Filter should deduce that it returns the same type that it takes.
    (
        self.p
        | 'Str' >> beam.Create(range(5)).with_output_types(int)
        | 'Half' >> beam.Filter(half)
        | 'ToBool' >> beam.Map(lambda x: bool(x)).with_input_types(
            int).with_output_types(bool))

  def test_pardo_like_inheriting_output_types_from_annotation(self):
    def fn1(x: str) -> int:
      return 1

    def fn1_flat(x: str) -> typing.List[int]:
      return [1]

    def fn2(x: int, y: str) -> str:
      return y

    def fn2_flat(x: int, y: str) -> typing.List[str]:
      return [y]

    # We only need the args section of the hints.
    def output_hints(transform):
      return transform.default_type_hints().output_types[0][0]

    self.assertEqual(int, output_hints(beam.Map(fn1)))
    self.assertEqual(int, output_hints(beam.FlatMap(fn1_flat)))

    self.assertEqual(str, output_hints(beam.MapTuple(fn2)))
    self.assertEqual(str, output_hints(beam.FlatMapTuple(fn2_flat)))

    def add(a: typing.Iterable[int]) -> int:
      return sum(a)

    self.assertCompatible(
        typing.Tuple[typing.TypeVar('K'), int],
        output_hints(beam.CombinePerKey(add)))

  def test_group_by_key_only_output_type_deduction(self):
    d = (
        self.p
        | 'Str' >> beam.Create(['t', 'e', 's', 't']).with_output_types(str)
        | (
            'Pair' >> beam.Map(lambda x: (x, ord(x))).with_output_types(
                typing.Tuple[str, str]))
        | beam.GroupByKey())

    # Output type should correctly be deduced.
    # GBK-only should deduce that Tuple[A, B] is turned into
    # Tuple[A, Iterable[B]].
    self.assertCompatible(
        typing.Tuple[str, typing.Iterable[str]], d.element_type)

  def test_group_by_key_output_type_deduction(self):
    d = (
        self.p
        | 'Str' >> beam.Create(range(20)).with_output_types(int)
        | (
            'PairNegative' >> beam.Map(lambda x: (x % 5, -x)).with_output_types(
                typing.Tuple[int, int]))
        | beam.GroupByKey())

    # Output type should correctly be deduced.
    # GBK should deduce that Tuple[A, B] is turned into Tuple[A, Iterable[B]].
    self.assertCompatible(
        typing.Tuple[int, typing.Iterable[int]], d.element_type)

  def test_group_by_key_only_does_not_type_check(self):
    # GBK will be passed raw ints here instead of some form of Tuple[A, B].
    with self.assertRaises(typehints.TypeCheckError) as e:
      (
          self.p
          | beam.Create([1, 2, 3]).with_output_types(int)
          | 'F' >> beam.GroupByKey())

    self.assertStartswith(
        e.exception.args[0],
        "Input type hint violation at F: "
        "expected Tuple[TypeVariable[K], TypeVariable[V]], "
        "got {}".format(int))

  def test_group_by_does_not_type_check(self):
    # Create is returning an Iterable[int] here, rather than the
    # Tuple[TypeVariable[K], TypeVariable[V]] key-value pairs that
    # GroupByKey requires.
    with self.assertRaises(typehints.TypeCheckError) as e:
      (
          self.p
          | (beam.Create([[1], [2]]).with_output_types(typing.Iterable[int]))
          | 'T' >> beam.GroupByKey())

    self.assertStartswith(
        e.exception.args[0],
        "Input type hint violation at T: "
        "expected Tuple[TypeVariable[K], TypeVariable[V]], "
        "got Iterable[<class 'int'>]")

  def test_pipeline_checking_pardo_insufficient_type_information(self):
    self.p._options.view_as(TypeOptions).type_check_strictness = 'ALL_REQUIRED'

    # Type checking is enabled, but 'Create' doesn't pass on any relevant type
    # information to the ParDo.
    with self.assertRaises(typehints.TypeCheckError) as e:
      (
          self.p
          | 'Nums' >> beam.Create(range(5))
          | 'ModDup' >> beam.FlatMap(lambda x: (x % 2, x)))

    self.assertEqual(
        'Pipeline type checking is enabled, however no output '
        'type-hint was found for the PTransform Create(Nums)',
        e.exception.args[0])
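
  def test_pipeline_checking_pardo_sufficient_type_information_sketch(self):
    # Illustrative sketch, not part of the original suite: the failing
    # pipeline above should pass strict checking once both transforms
    # declare their types explicitly.
    self.p._options.view_as(TypeOptions).type_check_strictness = 'ALL_REQUIRED'

    (
        self.p
        | 'Nums' >> beam.Create(range(5)).with_output_types(int)
        | 'ModDup' >> beam.FlatMap(lambda x: (x % 2, x)).with_input_types(
            int).with_output_types(int))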
1609 with self.assertRaises(typehints.TypeCheckError) as e: 1610 ( 1611 self.p 1612 | 'Nums' >> beam.Create(range(5)).with_output_types(int) 1613 | 'ModDup' >> beam.Map(lambda x: (x % 2, x)) 1614 | beam.GroupByKey()) 1615 1616 self.assertEqual( 1617 'Pipeline type checking is enabled, however no output ' 1618 'type-hint was found for the PTransform ' 1619 'ParDo(ModDup)', 1620 e.exception.args[0]) 1621 1622 def test_disable_pipeline_type_check(self): 1623 self.p._options.view_as(TypeOptions).pipeline_type_check = False 1624 1625 # The pipeline below should raise a TypeError, however pipeline type 1626 # checking was disabled above. 1627 ( 1628 self.p 1629 | 'T' >> beam.Create([1, 2, 3]).with_output_types(int) 1630 | 'Lower' >> beam.Map(lambda x: x.lower()).with_input_types( 1631 str).with_output_types(str)) 1632 1633 def test_run_time_type_checking_enabled_type_violation(self): 1634 self.p._options.view_as(TypeOptions).pipeline_type_check = False 1635 self.p._options.view_as(TypeOptions).runtime_type_check = True 1636 1637 @with_output_types(str) 1638 @with_input_types(x=int) 1639 def int_to_string(x): 1640 return str(x) 1641 1642 # The function above has been type-hinted to accept only an int, but during 1643 # pipeline execution it will be passed a string due to the output of Create. 1644 ( 1645 self.p 1646 | 'T' >> beam.Create(['some_string']) 1647 | 'ToStr' >> beam.Map(int_to_string)) 1648 with self.assertRaises(typehints.TypeCheckError) as e: 1649 self.p.run() 1650 1651 self.assertStartswith( 1652 e.exception.args[0], 1653 "Runtime type violation detected within ParDo(ToStr): " 1654 "Type-hint for argument: 'x' violated. " 1655 "Expected an instance of {}, " 1656 "instead found some_string, an instance of {}.".format(int, str)) 1657 1658 def test_run_time_type_checking_enabled_types_satisfied(self): 1659 self.p._options.view_as(TypeOptions).pipeline_type_check = False 1660 self.p._options.view_as(TypeOptions).runtime_type_check = True 1661 1662 @with_output_types(typing.Tuple[int, str]) 1663 @with_input_types(x=str) 1664 def group_with_upper_ord(x): 1665 return (ord(x.upper()) % 5, x) 1666 1667 # Pipeline checking is off, but the above function should satisfy types at 1668 # run-time. 1669 result = ( 1670 self.p 1671 | 'T' >> beam.Create(['t', 'e', 's', 't', 'i', 'n', 'g' 1672 ]).with_output_types(str) 1673 | 'GenKeys' >> beam.Map(group_with_upper_ord) 1674 | 'O' >> beam.GroupByKey() 1675 | SortLists) 1676 1677 assert_that( 1678 result, 1679 equal_to([(1, ['g']), (3, ['i', 'n', 's']), (4, ['e', 't', 't'])])) 1680 self.p.run() 1681 1682 def test_pipeline_checking_satisfied_but_run_time_types_violate(self): 1683 self.p._options.view_as(TypeOptions).pipeline_type_check = False 1684 self.p._options.view_as(TypeOptions).runtime_type_check = True 1685 1686 @with_output_types(typing.Tuple[bool, int]) 1687 @with_input_types(a=int) 1688 def is_even_as_key(a): 1689 # Simulate a programming error; this should be: return (a % 2 == 0, a). 1690 # As written it returns Tuple[int, int]. 1691 return (a % 2, a) 1692 1693 ( 1694 self.p 1695 | 'Nums' >> beam.Create(range(5)).with_output_types(int) 1696 | 'IsEven' >> beam.Map(is_even_as_key) 1697 | 'Parity' >> beam.GroupByKey()) 1698 1699 # Although all the types appear to be correct when checked at pipeline 1700 # construction, runtime type-checking should detect that 'is_even_as_key' is 1701 # returning Tuple[int, int] instead of Tuple[bool, int].
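# Note that bool is a subclass of int in Python, not the reverse, so
# isinstance(a % 2, bool) is False and the Tuple[bool, int] output hint is
# violated at runtime, as asserted below.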
1702 with self.assertRaises(typehints.TypeCheckError) as e: 1703 self.p.run() 1704 1705 self.assertStartswith( 1706 e.exception.args[0], 1707 "Runtime type violation detected within ParDo(IsEven): " 1708 "Tuple[<class 'bool'>, <class 'int'>] hint type-constraint violated. " 1709 "The type of element #0 in the passed tuple is incorrect. " 1710 "Expected an instance of type <class 'bool'>, " 1711 "instead received an instance of type int.") 1712 1713 def test_pipeline_checking_satisfied_run_time_checking_satisfied(self): 1714 self.p._options.view_as(TypeOptions).pipeline_type_check = False 1715 1716 @with_output_types(typing.Tuple[bool, int]) 1717 @with_input_types(a=int) 1718 def is_even_as_key(a): 1719 # The programming error in the above test-case has now been fixed. 1720 # Everything should properly type-check. 1721 return (a % 2 == 0, a) 1722 1723 result = ( 1724 self.p 1725 | 'Nums' >> beam.Create(range(5)).with_output_types(int) 1726 | 'IsEven' >> beam.Map(is_even_as_key) 1727 | 'Parity' >> beam.GroupByKey() 1728 | SortLists) 1729 1730 assert_that(result, equal_to([(False, [1, 3]), (True, [0, 2, 4])])) 1731 self.p.run() 1732 1733 def test_pipeline_runtime_checking_violation_simple_type_input(self): 1734 self.p._options.view_as(TypeOptions).runtime_type_check = True 1735 self.p._options.view_as(TypeOptions).pipeline_type_check = False 1736 1737 # The type-hint applied via the 'with_input_types()' method indicates the 1738 # ParDo should receive an instance of type 'str', however an 'int' will be 1739 # passed instead. 1740 with self.assertRaises(typehints.TypeCheckError) as e: 1741 ( 1742 self.p 1743 | beam.Create([1, 1, 1]) 1744 | ( 1745 'ToInt' >> beam.FlatMap(lambda x: [int(x)]).with_input_types( 1746 str).with_output_types(int))) 1747 self.p.run() 1748 1749 self.assertStartswith( 1750 e.exception.args[0], 1751 "Runtime type violation detected within ParDo(ToInt): " 1752 "Type-hint for argument: 'x' violated. " 1753 "Expected an instance of {}, " 1754 "instead found 1, an instance of {}.".format(str, int)) 1755 1756 def test_pipeline_runtime_checking_violation_composite_type_input(self): 1757 self.p._options.view_as(TypeOptions).runtime_type_check = True 1758 self.p._options.view_as(TypeOptions).pipeline_type_check = False 1759 1760 with self.assertRaises(typehints.TypeCheckError) as e: 1761 ( 1762 self.p 1763 | beam.Create([(1, 3.0), (2, 4.9), (3, 9.5)]) 1764 | ( 1765 'Add' >> 1766 beam.FlatMap(lambda x_y: [x_y[0] + x_y[1]]).with_input_types( 1767 typing.Tuple[int, int]).with_output_types(int))) 1768 self.p.run() 1769 1770 self.assertStartswith( 1771 e.exception.args[0], 1772 "Runtime type violation detected within ParDo(Add): " 1773 "Type-hint for argument: 'x_y' violated: " 1774 "Tuple[<class 'int'>, <class 'int'>] hint type-constraint violated. " 1775 "The type of element #1 in the passed tuple is incorrect. " 1776 "Expected an instance of type <class 'int'>, instead received an " 1777 "instance of type float.") 1778 1779 def test_pipeline_runtime_checking_violation_simple_type_output(self): 1780 self.p._options.view_as(TypeOptions).runtime_type_check = True 1781 self.p._options.view_as(TypeOptions).pipeline_type_check = False 1782 1783 # The type-hint applied via the 'with_output_types()' method indicates the 1784 # ParDo should output an instance of type 'int', however a 'float' will be 1785 # generated instead.
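# Which message is produced depends on whether runtime_type_check or
# performance_runtime_type_check is in effect; both variants are covered by
# the branches asserted below.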
1791 with self.assertRaises(typehints.TypeCheckError) as e: 1792 ( 1793 self.p 1794 | beam.Create([1, 1, 1]) 1795 | ( 1796 'ToInt' >> beam.FlatMap(lambda x: [float(x)]).with_input_types( 1797 int).with_output_types(int))) 1798 self.p.run() 1799 1800 if self.p._options.view_as(TypeOptions).runtime_type_check: 1801 self.assertStartswith( 1802 e.exception.args[0], 1803 "Runtime type violation detected within " 1804 "ParDo(ToInt): " 1805 "According to type-hint expected output should be " 1806 "of type {}. Instead, received '1.0', " 1807 "an instance of type {}.".format(int, float)) 1808 1809 if self.p._options.view_as(TypeOptions).performance_runtime_type_check: 1810 self.assertStartswith( 1811 e.exception.args[0], 1812 "Runtime type violation detected within ToInt: " 1813 "Type-hint for argument: 'x' violated. " 1814 "Expected an instance of {}, " 1815 "instead found 1.0, an instance of {}".format(int, float)) 1816 1817 def test_pipeline_runtime_checking_violation_composite_type_output(self): 1818 self.p._options.view_as(TypeOptions).runtime_type_check = True 1819 self.p._options.view_as(TypeOptions).pipeline_type_check = False 1820 1821 # The type-hint applied via the 'with_output_types()' method indicates the 1822 # ParDo should return an instance of type Tuple[float, int]. However, an 1823 # instance of 'float' will be generated instead. 1824 with self.assertRaises(typehints.TypeCheckError) as e: 1825 ( 1826 self.p 1827 | beam.Create([(1, 3.0), (2, 4.9), (3, 9.5)]) 1828 | ( 1829 'Swap' >> 1830 beam.FlatMap(lambda x_y1: [x_y1[0] + x_y1[1]]).with_input_types( 1831 typing.Tuple[int, float]).with_output_types( 1832 typing.Tuple[float, int]))) 1833 self.p.run() 1834 1835 if self.p._options.view_as(TypeOptions).runtime_type_check: 1836 self.assertStartswith( 1837 e.exception.args[0], 1838 "Runtime type violation detected within " 1839 "ParDo(Swap): Tuple type constraint violated. " 1840 "Valid object instance must be of type 'tuple'. Instead, " 1841 "an instance of 'float' was received.") 1842 1843 if self.p._options.view_as(TypeOptions).performance_runtime_type_check: 1844 self.assertStartswith( 1845 e.exception.args[0], 1846 "Runtime type violation detected within " 1847 "Swap: Type-hint for argument: 'x_y1' violated: " 1848 "Tuple type constraint violated. " 1849 "Valid object instance must be of type 'tuple'. ") 1850 1851 def test_pipeline_runtime_checking_violation_with_side_inputs_decorator(self): 1852 self.p._options.view_as(TypeOptions).pipeline_type_check = False 1853 self.p._options.view_as(TypeOptions).runtime_type_check = True 1854 1855 @with_output_types(int) 1856 @with_input_types(a=int, b=int) 1857 def add(a, b): 1858 return a + b 1859 1860 with self.assertRaises(typehints.TypeCheckError) as e: 1861 (self.p | beam.Create([1, 2, 3, 4]) | 'Add 1' >> beam.Map(add, 1.0)) 1862 self.p.run() 1863 1864 self.assertStartswith( 1865 e.exception.args[0], 1866 "Runtime type violation detected within ParDo(Add 1): " 1867 "Type-hint for argument: 'b' violated. 
" 1868 "Expected an instance of {}, " 1869 "instead found 1.0, an instance of {}.".format(int, float)) 1870 1871 def test_pipeline_runtime_checking_violation_with_side_inputs_via_method(self): # pylint: disable=line-too-long 1872 self.p._options.view_as(TypeOptions).runtime_type_check = True 1873 self.p._options.view_as(TypeOptions).pipeline_type_check = False 1874 1875 with self.assertRaises(typehints.TypeCheckError) as e: 1876 ( 1877 self.p 1878 | beam.Create([1, 2, 3, 4]) 1879 | ( 1880 'Add 1' >> beam.Map(lambda x, one: x + one, 1.0).with_input_types( 1881 int, int).with_output_types(float))) 1882 self.p.run() 1883 1884 self.assertStartswith( 1885 e.exception.args[0], 1886 "Runtime type violation detected within ParDo(Add 1): " 1887 "Type-hint for argument: 'one' violated. " 1888 "Expected an instance of {}, " 1889 "instead found 1.0, an instance of {}.".format(int, float)) 1890 1891 def test_combine_properly_pipeline_type_checks_using_decorator(self): 1892 @with_output_types(int) 1893 @with_input_types(ints=typing.Iterable[int]) 1894 def sum_ints(ints): 1895 return sum(ints) 1896 1897 d = ( 1898 self.p 1899 | 'T' >> beam.Create([1, 2, 3]).with_output_types(int) 1900 | 'Sum' >> beam.CombineGlobally(sum_ints)) 1901 1902 self.assertEqual(int, d.element_type) 1903 assert_that(d, equal_to([6])) 1904 self.p.run() 1905 1906 def test_combine_properly_pipeline_type_checks_without_decorator(self): 1907 def sum_ints(ints): 1908 return sum(ints) 1909 1910 d = ( 1911 self.p 1912 | beam.Create([1, 2, 3]) 1913 | beam.Map(lambda x: ('key', x)) 1914 | beam.CombinePerKey(sum_ints)) 1915 1916 self.assertEqual(typehints.Tuple[str, typehints.Any], d.element_type) 1917 self.p.run() 1918 1919 def test_combine_func_type_hint_does_not_take_iterable_using_decorator(self): 1920 @with_output_types(int) 1921 @with_input_types(a=int) 1922 def bad_combine(a): 1923 5 + a 1924 1925 with self.assertRaises(typehints.TypeCheckError) as e: 1926 ( 1927 self.p 1928 | 'M' >> beam.Create([1, 2, 3]).with_output_types(int) 1929 | 'Add' >> beam.CombineGlobally(bad_combine)) 1930 1931 self.assertEqual( 1932 "All functions for a Combine PTransform must accept a " 1933 "single argument compatible with: Iterable[Any]. 
" 1934 "Instead a function with input type: {} was received.".format(int), 1935 e.exception.args[0]) 1936 1937 def test_combine_pipeline_type_propagation_using_decorators(self): 1938 @with_output_types(int) 1939 @with_input_types(ints=typing.Iterable[int]) 1940 def sum_ints(ints): 1941 return sum(ints) 1942 1943 @with_output_types(typing.List[int]) 1944 @with_input_types(n=int) 1945 def range_from_zero(n): 1946 return list(range(n + 1)) 1947 1948 d = ( 1949 self.p 1950 | 'T' >> beam.Create([1, 2, 3]).with_output_types(int) 1951 | 'Sum' >> beam.CombineGlobally(sum_ints) 1952 | 'Range' >> beam.ParDo(range_from_zero)) 1953 1954 self.assertEqual(int, d.element_type) 1955 assert_that(d, equal_to([0, 1, 2, 3, 4, 5, 6])) 1956 self.p.run() 1957 1958 def test_combine_runtime_type_check_satisfied_using_decorators(self): 1959 self.p._options.view_as(TypeOptions).pipeline_type_check = False 1960 1961 @with_output_types(int) 1962 @with_input_types(ints=typing.Iterable[int]) 1963 def iter_mul(ints): 1964 return reduce(operator.mul, ints, 1) 1965 1966 d = ( 1967 self.p 1968 | 'K' >> beam.Create([5, 5, 5, 5]).with_output_types(int) 1969 | 'Mul' >> beam.CombineGlobally(iter_mul)) 1970 1971 assert_that(d, equal_to([625])) 1972 self.p.run() 1973 1974 def test_combine_runtime_type_check_violation_using_decorators(self): 1975 self.p._options.view_as(TypeOptions).pipeline_type_check = False 1976 self.p._options.view_as(TypeOptions).runtime_type_check = True 1977 1978 # Combine fn is returning the incorrect type 1979 @with_output_types(int) 1980 @with_input_types(ints=typing.Iterable[int]) 1981 def iter_mul(ints): 1982 return str(reduce(operator.mul, ints, 1)) 1983 1984 with self.assertRaises(typehints.TypeCheckError) as e: 1985 ( 1986 self.p 1987 | 'K' >> beam.Create([5, 5, 5, 5]).with_output_types(int) 1988 | 'Mul' >> beam.CombineGlobally(iter_mul)) 1989 self.p.run() 1990 1991 self.assertStartswith( 1992 e.exception.args[0], 1993 "Runtime type violation detected within " 1994 "Mul/CombinePerKey: " 1995 "Type-hint for return type violated. " 1996 "Expected an instance of {}, instead found".format(int)) 1997 1998 def test_combine_pipeline_type_check_using_methods(self): 1999 d = ( 2000 self.p 2001 | beam.Create(['t', 'e', 's', 't']).with_output_types(str) 2002 | ( 2003 'concat' >> beam.CombineGlobally(lambda s: ''.join(s)). 2004 with_input_types(str).with_output_types(str))) 2005 2006 def matcher(expected): 2007 def match(actual): 2008 equal_to(expected)(list(actual[0])) 2009 2010 return match 2011 2012 assert_that(d, matcher('estt')) 2013 self.p.run() 2014 2015 def test_combine_runtime_type_check_using_methods(self): 2016 self.p._options.view_as(TypeOptions).pipeline_type_check = False 2017 self.p._options.view_as(TypeOptions).runtime_type_check = True 2018 2019 d = ( 2020 self.p 2021 | beam.Create(range(5)).with_output_types(int) 2022 | ( 2023 'Sum' >> beam.CombineGlobally(lambda s: sum(s)).with_input_types( 2024 int).with_output_types(int))) 2025 2026 assert_that(d, equal_to([10])) 2027 self.p.run() 2028 2029 def test_combine_pipeline_type_check_violation_using_methods(self): 2030 with self.assertRaises(typehints.TypeCheckError) as e: 2031 ( 2032 self.p 2033 | beam.Create(range(3)).with_output_types(int) 2034 | ( 2035 'SortJoin' >> beam.CombineGlobally(lambda s: ''.join(sorted(s))). 
2036 with_input_types(str).with_output_types(str))) 2037 2038 self.assertStartswith( 2039 e.exception.args[0], 2040 "Input type hint violation at SortJoin: " 2041 "expected {}, got {}".format(str, int)) 2042 2043 def test_combine_runtime_type_check_violation_using_methods(self): 2044 self.p._options.view_as(TypeOptions).pipeline_type_check = False 2045 self.p._options.view_as(TypeOptions).runtime_type_check = True 2046 2047 with self.assertRaises(typehints.TypeCheckError) as e: 2048 ( 2049 self.p 2050 | beam.Create([0]).with_output_types(int) 2051 | ( 2052 'SortJoin' >> beam.CombineGlobally(lambda s: ''.join(sorted(s))). 2053 with_input_types(str).with_output_types(str))) 2054 self.p.run() 2055 2056 self.assertStartswith( 2057 e.exception.args[0], 2058 "Runtime type violation detected within " 2059 "ParDo(SortJoin/KeyWithVoid): " 2060 "Type-hint for argument: 'v' violated. " 2061 "Expected an instance of {}, " 2062 "instead found 0, an instance of {}.".format(str, int)) 2063 2064 def test_combine_insufficient_type_hint_information(self): 2065 self.p._options.view_as(TypeOptions).type_check_strictness = 'ALL_REQUIRED' 2066 2067 with self.assertRaises(typehints.TypeCheckError) as e: 2068 ( 2069 self.p 2070 | 'E' >> beam.Create(range(3)).with_output_types(int) 2071 | 'SortJoin' >> beam.CombineGlobally(lambda s: ''.join(sorted(s))) 2072 | 'F' >> beam.Map(lambda x: x + 1)) 2073 2074 self.assertStartswith( 2075 e.exception.args[0], 2076 'Pipeline type checking is enabled, ' 2077 'however no output type-hint was found for the PTransform ' 2078 'ParDo(' 2079 'SortJoin/CombinePerKey/') 2080 2081 def test_mean_globally_pipeline_checking_satisfied(self): 2082 d = ( 2083 self.p 2084 | 'C' >> beam.Create(range(5)).with_output_types(int) 2085 | 'Mean' >> combine.Mean.Globally()) 2086 2087 self.assertEqual(float, d.element_type) 2088 assert_that(d, equal_to([2.0])) 2089 self.p.run() 2090 2091 def test_mean_globally_pipeline_checking_violated(self): 2092 with self.assertRaises(typehints.TypeCheckError) as e: 2093 ( 2094 self.p 2095 | 'C' >> beam.Create(['test']).with_output_types(str) 2096 | 'Mean' >> combine.Mean.Globally()) 2097 2098 expected_msg = \ 2099 "Type hint violation for 'CombinePerKey': " \ 2100 "requires Tuple[TypeVariable[K], Union[<class 'float'>, <class 'int'>, " \ 2101 "<class 'numpy.float64'>, <class 'numpy.int64'>]] " \ 2102 "but got Tuple[None, <class 'str'>] for element" 2103 2104 self.assertStartswith(e.exception.args[0], expected_msg) 2105 2106 def test_mean_globally_runtime_checking_satisfied(self): 2107 self.p._options.view_as(TypeOptions).runtime_type_check = True 2108 2109 d = ( 2110 self.p 2111 | 'C' >> beam.Create(range(5)).with_output_types(int) 2112 | 'Mean' >> combine.Mean.Globally()) 2113 2114 self.assertEqual(float, d.element_type) 2115 assert_that(d, equal_to([2.0])) 2116 self.p.run() 2117 2118 def test_mean_globally_runtime_checking_violated(self): 2119 self.p._options.view_as(TypeOptions).pipeline_type_check = False 2120 self.p._options.view_as(TypeOptions).runtime_type_check = True 2121 2122 with self.assertRaises(typehints.TypeCheckError) as e: 2123 ( 2124 self.p 2125 | 'C' >> beam.Create(['t', 'e', 's', 't']).with_output_types(str) 2126 | 'Mean' >> combine.Mean.Globally()) 2127 self.p.run() 2128 self.assertEqual( 2129 "Runtime type violation detected for transform input " 2130 "when executing ParDoFlatMap(Combine): Tuple[Any, " 2131 "Iterable[Union[int, float]]] hint type-constraint " 2132 "violated. The type of element #1 in the passed tuple " 2133 "is incorrect. 
Iterable[Union[int, float]] hint " 2134 "type-constraint violated. The type of element #0 in " 2135 "the passed Iterable is incorrect: Union[int, float] " 2136 "type-constraint violated. Expected an instance of one " 2137 "of: ('int', 'float'), received str instead.", 2138 e.exception.args[0]) 2139 2140 def test_mean_per_key_pipeline_checking_satisfied(self): 2141 d = ( 2142 self.p 2143 | beam.Create(range(5)).with_output_types(int) 2144 | ( 2145 'EvenGroup' >> beam.Map(lambda x: (not x % 2, x)).with_output_types( 2146 typing.Tuple[bool, int])) 2147 | 'EvenMean' >> combine.Mean.PerKey()) 2148 2149 self.assertCompatible(typing.Tuple[bool, float], d.element_type) 2150 assert_that(d, equal_to([(False, 2.0), (True, 2.0)])) 2151 self.p.run() 2152 2153 def test_mean_per_key_pipeline_checking_violated(self): 2154 with self.assertRaises(typehints.TypeCheckError) as e: 2155 ( 2156 self.p 2157 | beam.Create(map(str, range(5))).with_output_types(str) 2158 | ( 2159 'UpperPair' >> beam.Map(lambda x: 2160 (x.upper(), x)).with_output_types( 2161 typing.Tuple[str, str])) 2162 | 'EvenMean' >> combine.Mean.PerKey()) 2163 self.p.run() 2164 2165 expected_msg = \ 2166 "Type hint violation for 'CombinePerKey(MeanCombineFn)': " \ 2167 "requires Tuple[TypeVariable[K], Union[<class 'float'>, <class 'int'>, " \ 2168 "<class 'numpy.float64'>, <class 'numpy.int64'>]] " \ 2169 "but got Tuple[<class 'str'>, <class 'str'>] for element" 2170 2171 self.assertStartswith(e.exception.args[0], expected_msg) 2172 2173 def test_mean_per_key_runtime_checking_satisfied(self): 2174 self.p._options.view_as(TypeOptions).runtime_type_check = True 2175 2176 d = ( 2177 self.p 2178 | beam.Create(range(5)).with_output_types(int) 2179 | ( 2180 'OddGroup' >> beam.Map(lambda x: 2181 (bool(x % 2), x)).with_output_types( 2182 typing.Tuple[bool, int])) 2183 | 'OddMean' >> combine.Mean.PerKey()) 2184 2185 self.assertCompatible(typing.Tuple[bool, float], d.element_type) 2186 assert_that(d, equal_to([(False, 2.0), (True, 2.0)])) 2187 self.p.run() 2188 2189 def test_mean_per_key_runtime_checking_violated(self): 2190 self.p._options.view_as(TypeOptions).pipeline_type_check = False 2191 self.p._options.view_as(TypeOptions).runtime_type_check = True 2192 2193 with self.assertRaises(typehints.TypeCheckError) as e: 2194 ( 2195 self.p 2196 | beam.Create(range(5)).with_output_types(int) 2197 | ( 2198 'OddGroup' >> beam.Map(lambda x: 2199 (x, str(bool(x % 2)))).with_output_types( 2200 typing.Tuple[int, str])) 2201 | 'OddMean' >> combine.Mean.PerKey()) 2202 self.p.run() 2203 2204 expected_msg = \ 2205 "Runtime type violation detected within " \ 2206 "OddMean/CombinePerKey(MeanCombineFn): " \ 2207 "Type-hint for argument: 'element' violated: " \ 2208 "Union[<class 'float'>, <class 'int'>, <class 'numpy.float64'>, <class " \ 2209 "'numpy.int64'>] type-constraint violated. 
" \ 2210 "Expected an instance of one of: (\"<class 'float'>\", \"<class " \ 2211 "'int'>\", \"<class 'numpy.float64'>\", \"<class 'numpy.int64'>\"), " \ 2212 "received str instead" 2213 2214 self.assertStartswith(e.exception.args[0], expected_msg) 2215 2216 def test_count_globally_pipeline_type_checking_satisfied(self): 2217 d = ( 2218 self.p 2219 | 'P' >> beam.Create(range(5)).with_output_types(int) 2220 | 'CountInt' >> combine.Count.Globally()) 2221 2222 self.assertEqual(int, d.element_type) 2223 assert_that(d, equal_to([5])) 2224 self.p.run() 2225 2226 def test_count_globally_runtime_type_checking_satisfied(self): 2227 self.p._options.view_as(TypeOptions).runtime_type_check = True 2228 2229 d = ( 2230 self.p 2231 | 'P' >> beam.Create(range(5)).with_output_types(int) 2232 | 'CountInt' >> combine.Count.Globally()) 2233 2234 self.assertEqual(int, d.element_type) 2235 assert_that(d, equal_to([5])) 2236 self.p.run() 2237 2238 def test_count_perkey_pipeline_type_checking_satisfied(self): 2239 d = ( 2240 self.p 2241 | beam.Create(range(5)).with_output_types(int) 2242 | 'EvenGroup' >> beam.Map(lambda x: (not x % 2, x)).with_output_types( 2243 typing.Tuple[bool, int]) 2244 | 'CountInt' >> combine.Count.PerKey()) 2245 2246 self.assertCompatible(typing.Tuple[bool, int], d.element_type) 2247 assert_that(d, equal_to([(False, 2), (True, 3)])) 2248 self.p.run() 2249 2250 def test_count_perkey_pipeline_type_checking_violated(self): 2251 with self.assertRaises(typehints.TypeCheckError) as e: 2252 ( 2253 self.p 2254 | beam.Create(range(5)).with_output_types(int) 2255 | 'CountInt' >> combine.Count.PerKey()) 2256 2257 self.assertStartswith( 2258 e.exception.args[0], 'Input type hint violation at CountInt') 2259 2260 def test_count_perkey_runtime_type_checking_satisfied(self): 2261 self.p._options.view_as(TypeOptions).runtime_type_check = True 2262 2263 d = ( 2264 self.p 2265 | beam.Create(['t', 'e', 's', 't']).with_output_types(str) 2266 | 'DupKey' >> beam.Map(lambda x: (x, x)).with_output_types( 2267 typing.Tuple[str, str]) 2268 | 'CountDups' >> combine.Count.PerKey()) 2269 2270 self.assertCompatible(typing.Tuple[str, int], d.element_type) 2271 assert_that(d, equal_to([('e', 1), ('s', 1), ('t', 2)])) 2272 self.p.run() 2273 2274 def test_count_perelement_pipeline_type_checking_satisfied(self): 2275 d = ( 2276 self.p 2277 | beam.Create([1, 1, 2, 3]).with_output_types(int) 2278 | 'CountElems' >> combine.Count.PerElement()) 2279 2280 self.assertCompatible(typing.Tuple[int, int], d.element_type) 2281 assert_that(d, equal_to([(1, 2), (2, 1), (3, 1)])) 2282 self.p.run() 2283 2284 def test_count_perelement_pipeline_type_checking_violated(self): 2285 self.p._options.view_as(TypeOptions).type_check_strictness = 'ALL_REQUIRED' 2286 2287 with self.assertRaises(typehints.TypeCheckError) as e: 2288 ( 2289 self.p 2290 | 'f' >> beam.Create([1, 1, 2, 3]) 2291 | 'CountElems' >> combine.Count.PerElement()) 2292 2293 self.assertEqual( 2294 'Pipeline type checking is enabled, however no output ' 2295 'type-hint was found for the PTransform ' 2296 'Create(f)', 2297 e.exception.args[0]) 2298 2299 def test_count_perelement_runtime_type_checking_satisfied(self): 2300 self.p._options.view_as(TypeOptions).runtime_type_check = True 2301 2302 d = ( 2303 self.p 2304 | beam.Create([True, True, False, True, True]).with_output_types(bool) 2305 | 'CountElems' >> combine.Count.PerElement()) 2306 2307 self.assertCompatible(typing.Tuple[bool, int], d.element_type) 2308 assert_that(d, equal_to([(False, 1), (True, 4)])) 2309 self.p.run() 2310 
2311 def test_top_of_pipeline_checking_satisfied(self): 2312 d = ( 2313 self.p 2314 | beam.Create(range(5, 11)).with_output_types(int) 2315 | 'Top 3' >> combine.Top.Of(3)) 2316 2317 self.assertCompatible(typing.Iterable[int], d.element_type) 2318 assert_that(d, equal_to([[10, 9, 8]])) 2319 self.p.run() 2320 2321 def test_top_of_runtime_checking_satisfied(self): 2322 self.p._options.view_as(TypeOptions).runtime_type_check = True 2323 2324 d = ( 2325 self.p 2326 | beam.Create(list('testing')).with_output_types(str) 2327 | 'AsciiTop' >> combine.Top.Of(3)) 2328 2329 self.assertCompatible(typing.Iterable[str], d.element_type) 2330 assert_that(d, equal_to([['t', 't', 's']])) 2331 self.p.run() 2332 2333 def test_per_key_pipeline_checking_violated(self): 2334 with self.assertRaises(typehints.TypeCheckError) as e: 2335 ( 2336 self.p 2337 | beam.Create(range(100)).with_output_types(int) 2338 | 'Num + 1' >> beam.Map(lambda x: x + 1).with_output_types(int) 2339 | 'TopMod' >> combine.Top.PerKey(1)) 2340 2341 self.assertStartswith( 2342 e.exception.args[0], 2343 "Input type hint violation at TopMod: expected Tuple[TypeVariable[K], " 2344 "TypeVariable[V]], got {}".format(int)) 2345 2346 def test_per_key_pipeline_checking_satisfied(self): 2347 d = ( 2348 self.p 2349 | beam.Create(range(100)).with_output_types(int) 2350 | ( 2351 'GroupMod 3' >> beam.Map(lambda x: (x % 3, x)).with_output_types( 2352 typing.Tuple[int, int])) 2353 | 'TopMod' >> combine.Top.PerKey(1)) 2354 2355 self.assertCompatible( 2356 typing.Tuple[int, typing.Iterable[int]], d.element_type) 2357 assert_that(d, equal_to([(0, [99]), (1, [97]), (2, [98])])) 2358 self.p.run() 2359 2360 def test_per_key_runtime_checking_satisfied(self): 2361 self.p._options.view_as(TypeOptions).runtime_type_check = True 2362 2363 d = ( 2364 self.p 2365 | beam.Create(range(21)) 2366 | ( 2367 'GroupMod 3' >> beam.Map(lambda x: (x % 3, x)).with_output_types( 2368 typing.Tuple[int, int])) 2369 | 'TopMod' >> combine.Top.PerKey(1)) 2370 2371 self.assertCompatible( 2372 typing.Tuple[int, typing.Iterable[int]], d.element_type) 2373 assert_that(d, equal_to([(0, [18]), (1, [19]), (2, [20])])) 2374 self.p.run() 2375 2376 def test_sample_globally_pipeline_satisfied(self): 2377 d = ( 2378 self.p 2379 | beam.Create([2, 2, 3, 3]).with_output_types(int) 2380 | 'Sample' >> combine.Sample.FixedSizeGlobally(3)) 2381 2382 self.assertCompatible(typing.Iterable[int], d.element_type) 2383 2384 def matcher(expected_len): 2385 def match(actual): 2386 equal_to([expected_len])([len(actual[0])]) 2387 2388 return match 2389 2390 assert_that(d, matcher(3)) 2391 self.p.run() 2392 2393 def test_sample_globally_runtime_satisfied(self): 2394 self.p._options.view_as(TypeOptions).runtime_type_check = True 2395 2396 d = ( 2397 self.p 2398 | beam.Create([2, 2, 3, 3]).with_output_types(int) 2399 | 'Sample' >> combine.Sample.FixedSizeGlobally(2)) 2400 2401 self.assertCompatible(typing.Iterable[int], d.element_type) 2402 2403 def matcher(expected_len): 2404 def match(actual): 2405 equal_to([expected_len])([len(actual[0])]) 2406 2407 return match 2408 2409 assert_that(d, matcher(2)) 2410 self.p.run() 2411 2412 def test_sample_per_key_pipeline_satisfied(self): 2413 d = ( 2414 self.p 2415 | ( 2416 beam.Create([(1, 2), (1, 2), (2, 3), 2417 (2, 3)]).with_output_types(typing.Tuple[int, int])) 2418 | 'Sample' >> combine.Sample.FixedSizePerKey(2)) 2419 2420 self.assertCompatible( 2421 typing.Tuple[int, typing.Iterable[int]], d.element_type) 2422 2423 def matcher(expected_len): 2424 def match(actual): 2425 for
_, sample in actual: 2426 equal_to([expected_len])([len(sample)]) 2427 2428 return match 2429 2430 assert_that(d, matcher(2)) 2431 self.p.run() 2432 2433 def test_sample_per_key_runtime_satisfied(self): 2434 self.p._options.view_as(TypeOptions).runtime_type_check = True 2435 2436 d = ( 2437 self.p 2438 | ( 2439 beam.Create([(1, 2), (1, 2), (2, 3), 2440 (2, 3)]).with_output_types(typing.Tuple[int, int])) 2441 | 'Sample' >> combine.Sample.FixedSizePerKey(1)) 2442 2443 self.assertCompatible( 2444 typing.Tuple[int, typing.Iterable[int]], d.element_type) 2445 2446 def matcher(expected_len): 2447 def match(actual): 2448 for _, sample in actual: 2449 equal_to([expected_len])([len(sample)]) 2450 2451 return match 2452 2453 assert_that(d, matcher(1)) 2454 self.p.run() 2455 2456 def test_to_list_pipeline_check_satisfied(self): 2457 d = ( 2458 self.p 2459 | beam.Create((1, 2, 3, 4)).with_output_types(int) 2460 | combine.ToList()) 2461 2462 self.assertCompatible(typing.List[int], d.element_type) 2463 2464 def matcher(expected): 2465 def match(actual): 2466 equal_to(expected)(actual[0]) 2467 2468 return match 2469 2470 assert_that(d, matcher([1, 2, 3, 4])) 2471 self.p.run() 2472 2473 def test_to_list_runtime_check_satisfied(self): 2474 self.p._options.view_as(TypeOptions).runtime_type_check = True 2475 2476 d = ( 2477 self.p 2478 | beam.Create(list('test')).with_output_types(str) 2479 | combine.ToList()) 2480 2481 self.assertCompatible(typing.List[str], d.element_type) 2482 2483 def matcher(expected): 2484 def match(actual): 2485 equal_to(expected)(actual[0]) 2486 2487 return match 2488 2489 assert_that(d, matcher(['e', 's', 't', 't'])) 2490 self.p.run() 2491 2492 def test_to_dict_pipeline_check_violated(self): 2493 with self.assertRaises(typehints.TypeCheckError) as e: 2494 ( 2495 self.p 2496 | beam.Create([1, 2, 3, 4]).with_output_types(int) 2497 | combine.ToDict()) 2498 2499 self.assertStartswith( 2500 e.exception.args[0], 2501 "Input type hint violation at ToDict: expected Tuple[TypeVariable[K], " 2502 "TypeVariable[V]], got {}".format(int)) 2503 2504 def test_to_dict_pipeline_check_satisfied(self): 2505 d = ( 2506 self.p 2507 | beam.Create([(1, 2), 2508 (3, 4)]).with_output_types(typing.Tuple[int, int]) 2509 | combine.ToDict()) 2510 2511 self.assertCompatible(typing.Dict[int, int], d.element_type) 2512 assert_that(d, equal_to([{1: 2, 3: 4}])) 2513 self.p.run() 2514 2515 def test_to_dict_runtime_check_satisfied(self): 2516 self.p._options.view_as(TypeOptions).runtime_type_check = True 2517 2518 d = ( 2519 self.p 2520 | ( 2521 beam.Create([('1', 2), 2522 ('3', 4)]).with_output_types(typing.Tuple[str, int])) 2523 | combine.ToDict()) 2524 2525 self.assertCompatible(typing.Dict[str, int], d.element_type) 2526 assert_that(d, equal_to([{'1': 2, '3': 4}])) 2527 self.p.run() 2528 2529 def test_runtime_type_check_python_type_error(self): 2530 self.p._options.view_as(TypeOptions).runtime_type_check = True 2531 2532 with self.assertRaises(TypeError) as e: 2533 ( 2534 self.p 2535 | beam.Create([1, 2, 3]).with_output_types(int) 2536 | 'Len' >> beam.Map(lambda x: len(x)).with_output_types(int)) 2537 self.p.run() 2538 2539 # Our special type-checking related TypeError shouldn't have been raised. 2540 # Instead the above pipeline should have triggered a regular Python runtime 2541 # TypeError. 
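# The runner surfaces errors raised inside a DoFn with the transform label
# appended, hence the "[while running 'Len']" suffix asserted below.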
2542 self.assertEqual( 2543 "object of type 'int' has no len() [while running 'Len']", 2544 e.exception.args[0]) 2545 self.assertFalse(isinstance(e, typehints.TypeCheckError)) 2546 2547 def test_pardo_type_inference(self): 2548 self.assertEqual(int, beam.Filter(lambda x: False).infer_output_type(int)) 2549 self.assertEqual( 2550 typehints.Tuple[str, int], 2551 beam.Map(lambda x: (x, 1)).infer_output_type(str)) 2552 2553 def test_gbk_type_inference(self): 2554 self.assertEqual( 2555 typehints.Tuple[str, typehints.Iterable[int]], 2556 beam.GroupByKey().infer_output_type(typehints.KV[str, int])) 2557 2558 def test_pipeline_inference(self): 2559 created = self.p | beam.Create(['a', 'b', 'c']) 2560 mapped = created | 'pair with 1' >> beam.Map(lambda x: (x, 1)) 2561 grouped = mapped | beam.GroupByKey() 2562 self.assertEqual(str, created.element_type) 2563 self.assertEqual(typehints.KV[str, int], mapped.element_type) 2564 self.assertEqual( 2565 typehints.KV[str, typehints.Iterable[int]], grouped.element_type) 2566 2567 def test_inferred_bad_kv_type(self): 2568 with self.assertRaises(typehints.TypeCheckError) as e: 2569 _ = ( 2570 self.p 2571 | beam.Create(['a', 'b', 'c']) 2572 | 'Ungroupable' >> beam.Map(lambda x: (x, 0, 1.0)) 2573 | beam.GroupByKey()) 2574 2575 self.assertStartswith( 2576 e.exception.args[0], 2577 "Input type hint violation at GroupByKey: " 2578 "expected Tuple[TypeVariable[K], TypeVariable[V]], " 2579 "got Tuple[<class 'str'>, <class 'int'>, <class 'float'>]") 2580 2581 def test_type_inference_command_line_flag_toggle(self): 2582 self.p._options.view_as(TypeOptions).pipeline_type_check = False 2583 x = self.p | 'C1' >> beam.Create([1, 2, 3, 4]) 2584 self.assertIsNone(x.element_type) 2585 2586 self.p._options.view_as(TypeOptions).pipeline_type_check = True 2587 x = self.p | 'C2' >> beam.Create([1, 2, 3, 4]) 2588 self.assertEqual(int, x.element_type) 2589 2590 def test_eager_execution(self): 2591 doubled = [1, 2, 3, 4] | beam.Map(lambda x: 2 * x) 2592 self.assertEqual([2, 4, 6, 8], doubled) 2593 2594 def test_eager_execution_tagged_outputs(self): 2595 result = [1, 2, 3, 4] | beam.Map( 2596 lambda x: pvalue.TaggedOutput('bar', 2 * x)).with_outputs('bar') 2597 self.assertEqual([2, 4, 6, 8], result.bar) 2598 with self.assertRaises(KeyError, 2599 msg='Tag \'foo\' is not a defined output tag'): 2600 result.foo 2601 2602 2603 @parameterized_class([{'use_subprocess': False}, {'use_subprocess': True}]) 2604 class DeadLettersTest(unittest.TestCase): 2605 @classmethod 2606 def die(cls, x): 2607 if cls.use_subprocess: 2608 os._exit(x) 2609 else: 2610 raise ValueError(x) 2611 2612 @classmethod 2613 def die_if_negative(cls, x): 2614 if x < 0: 2615 cls.die(x) 2616 else: 2617 return x 2618 2619 @classmethod 2620 def exception_if_negative(cls, x): 2621 if x < 0: 2622 raise ValueError(x) 2623 else: 2624 return x 2625 2626 @classmethod 2627 def die_if_less(cls, x, bound=0): 2628 if x < bound: 2629 cls.die(x) 2630 else: 2631 return x, bound 2632 2633 def test_error_messages(self): 2634 with TestPipeline() as p: 2635 good, bad = ( 2636 p 2637 | beam.Create([-1, 10, -100, 2, 0]) 2638 | beam.Map(self.exception_if_negative).with_exception_handling()) 2639 assert_that(good, equal_to([0, 2, 10]), label='CheckGood') 2640 assert_that( 2641 bad | 2642 beam.MapTuple(lambda e, exc_info: (e, exc_info[1].replace(',', ''))), 2643 equal_to([(-1, 'ValueError(-1)'), (-100, 'ValueError(-100)')]), 2644 label='CheckBad') 2645 2646 def test_filters_exceptions(self): 2647 with TestPipeline() as p: 2648 good, _ = ( 2649 
p 2650 | beam.Create([-1, 10, -100, 2, 0]) 2651 | beam.Map(self.exception_if_negative).with_exception_handling( 2652 use_subprocess=self.use_subprocess, 2653 exc_class=(ValueError, TypeError))) 2654 assert_that(good, equal_to([0, 2, 10]), label='CheckGood') 2655 2656 with self.assertRaises(Exception): 2657 with TestPipeline() as p: 2658 good, _ = ( 2659 p 2660 | beam.Create([-1, 10, -100, 2, 0]) 2661 | beam.Map(self.die_if_negative).with_exception_handling( 2662 use_subprocess=self.use_subprocess, 2663 exc_class=TypeError)) 2664 2665 def test_tuples(self): 2666 2667 with TestPipeline() as p: 2668 good, _ = ( 2669 p 2670 | beam.Create([(1, 2), (3, 2), (1, -10)]) 2671 | beam.MapTuple(self.die_if_less).with_exception_handling( 2672 use_subprocess=self.use_subprocess)) 2673 assert_that(good, equal_to([(3, 2), (1, -10)]), label='CheckGood') 2674 2675 def test_side_inputs(self): 2676 2677 with TestPipeline() as p: 2678 input = p | beam.Create([-1, 10, 100]) 2679 2680 assert_that(( 2681 input 2682 | 'Default' >> beam.Map(self.die_if_less).with_exception_handling( 2683 use_subprocess=self.use_subprocess)).good, 2684 equal_to([(10, 0), (100, 0)]), 2685 label='CheckDefault') 2686 assert_that(( 2687 input 2688 | 'Pos' >> beam.Map(self.die_if_less, 20).with_exception_handling( 2689 use_subprocess=self.use_subprocess)).good, 2690 equal_to([(100, 20)]), 2691 label='PosSideInput') 2692 assert_that(( 2693 input 2694 | 2695 'Key' >> beam.Map(self.die_if_less, bound=30).with_exception_handling( 2696 use_subprocess=self.use_subprocess)).good, 2697 equal_to([(100, 30)]), 2698 label='KeySideInput') 2699 2700 def test_multiple_outputs(self): 2701 die = type(self).die 2702 2703 def die_on_negative_even_odd(x): 2704 if x < 0: 2705 die(x) 2706 elif x % 2 == 0: 2707 return pvalue.TaggedOutput('even', x) 2708 elif x % 2 == 1: 2709 return pvalue.TaggedOutput('odd', x) 2710 2711 with TestPipeline() as p: 2712 results = ( 2713 p 2714 | beam.Create([1, -1, 2, -2, 3]) 2715 | beam.Map(die_on_negative_even_odd).with_exception_handling( 2716 use_subprocess=self.use_subprocess)) 2717 assert_that(results.even, equal_to([2]), label='CheckEven') 2718 assert_that(results.odd, equal_to([1, 3]), label='CheckOdd') 2719 2720 def test_params(self): 2721 die = type(self).die 2722 2723 def die_if_negative_with_timestamp(x, ts=beam.DoFn.TimestampParam): 2724 if x < 0: 2725 die(x) 2726 else: 2727 return x, ts 2728 2729 with TestPipeline() as p: 2730 good, _ = ( 2731 p 2732 | beam.Create([-1, 0, 1]) 2733 | beam.Map(lambda x: TimestampedValue(x, x)) 2734 | beam.Map(die_if_negative_with_timestamp).with_exception_handling( 2735 use_subprocess=self.use_subprocess)) 2736 assert_that(good, equal_to([(0, Timestamp(0)), (1, Timestamp(1))])) 2737 2738 def test_timeout(self): 2739 import time 2740 timeout = 1 if self.use_subprocess else .1 2741 2742 with TestPipeline() as p: 2743 good, bad = ( 2744 p 2745 | beam.Create('records starting with lowercase S are slow'.split()) 2746 | beam.Map( 2747 lambda x: time.sleep(2.5 * timeout) if x.startswith('s') else x) 2748 .with_exception_handling( 2749 use_subprocess=self.use_subprocess, timeout=timeout)) 2750 assert_that( 2751 good, 2752 equal_to(['records', 'with', 'lowercase', 'S', 'are']), 2753 label='CheckGood') 2754 assert_that( 2755 bad | 2756 beam.MapTuple(lambda e, exc_info: (e, exc_info[1].replace(',', ''))), 2757 equal_to([('starting', 'TimeoutError()'), 2758 ('slow', 'TimeoutError()')]), 2759 label='CheckBad') 2760 2761 def test_lifecycle(self): 2762 die = type(self).die 2763 2764 class 
MyDoFn(beam.DoFn): 2765 state = None 2766 2767 def setup(self): 2768 assert self.state is None 2769 self.state = 'setup' 2770 2771 def start_bundle(self): 2772 assert self.state in ('setup', 'finish_bundle'), self.state 2773 self.state = 'start_bundle' 2774 2775 def finish_bundle(self): 2776 assert self.state in ('start_bundle', ), self.state 2777 self.state = 'finish_bundle' 2778 2779 def teardown(self): 2780 assert self.state in ('setup', 'finish_bundle'), self.state 2781 self.state = 'teardown' 2782 2783 def process(self, x): 2784 if x < 0: 2785 die(x) 2786 else: 2787 yield self.state 2788 2789 with TestPipeline() as p: 2790 good, _ = ( 2791 p 2792 | beam.Create([-1, 0, 1, -10, 10]) 2793 | beam.ParDo(MyDoFn()).with_exception_handling( 2794 use_subprocess=self.use_subprocess)) 2795 assert_that(good, equal_to(['start_bundle'] * 3)) 2796 2797 def test_partial(self): 2798 if self.use_subprocess: 2799 self.skipTest('Subprocess and partial are mutually exclusive.') 2800 2801 def die_if_negative_iter(elements): 2802 for element in elements: 2803 if element < 0: 2804 raise ValueError(element) 2805 yield element 2806 2807 with TestPipeline() as p: 2808 input = p | beam.Create([(-1, 1, 11), (2, -2, 22), (3, 33, -3), (4, 44)]) 2809 2810 assert_that(( 2811 input 2812 | 'Partial' >> beam.FlatMap( 2813 die_if_negative_iter).with_exception_handling(partial=True)).good, 2814 equal_to([2, 3, 33, 4, 44]), 2815 'CheckPartial') 2816 2817 assert_that(( 2818 input 2819 | 'Complete' >> beam.FlatMap(die_if_negative_iter). 2820 with_exception_handling(partial=False)).good, 2821 equal_to([4, 44]), 2822 'CheckComplete') 2823 2824 def test_threshold(self): 2825 # The threshold is high enough. 2826 with TestPipeline() as p: 2827 _ = ( 2828 p 2829 | beam.Create([-1, -2, 0, 1, 2, 3, 4, 5]) 2830 | beam.Map(self.die_if_negative).with_exception_handling( 2831 threshold=0.5, use_subprocess=self.use_subprocess)) 2832 2833 # The threshold is too low. 2834 with self.assertRaisesRegex(Exception, "2 / 8 = 0.25 > 0.1"): 2835 with TestPipeline() as p: 2836 _ = ( 2837 p 2838 | beam.Create([-1, -2, 0, 1, 2, 3, 4, 5]) 2839 | beam.Map(self.die_if_negative).with_exception_handling( 2840 threshold=0.1, use_subprocess=self.use_subprocess)) 2841 2842 # The threshold is too low per window. 2843 with self.assertRaisesRegex(Exception, "2 / 2 = 1.0 > 0.5"): 2844 with TestPipeline() as p: 2845 _ = ( 2846 p 2847 | beam.Create([-1, -2, 0, 1, 2, 3, 4, 5]) 2848 | beam.Map(lambda x: TimestampedValue(x, x)) 2849 | beam.Map(self.die_if_negative).with_exception_handling( 2850 threshold=0.5, 2851 threshold_windowing=window.FixedWindows(10), 2852 use_subprocess=self.use_subprocess)) 2853 2854 2855 class TestPTransformFn(TypeHintTestCase): 2856 def test_type_checking_fail(self): 2857 @beam.ptransform_fn 2858 def MyTransform(pcoll): 2859 return pcoll | beam.ParDo(lambda x: [x]).with_output_types(str) 2860 2861 p = TestPipeline() 2862 with self.assertRaisesRegex(beam.typehints.TypeCheckError, 2863 r'expected.*int.*got.*str'): 2864 _ = (p | beam.Create([1, 2]) | MyTransform().with_output_types(int)) 2865 2866 def test_type_checking_success(self): 2867 @beam.ptransform_fn 2868 def MyTransform(pcoll): 2869 return pcoll | beam.ParDo(lambda x: [x]).with_output_types(int) 2870 2871 with TestPipeline() as p: 2872 _ = (p | beam.Create([1, 2]) | MyTransform().with_output_types(int)) 2873 2874 def test_type_hints_arg(self): 2875 # Tests passing type hints via the magic 'type_hints' argument name.
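# A @beam.ptransform_fn function that declares a parameter literally named
# 'type_hints' receives the hints applied to the resulting transform (here,
# the int output hint set via with_output_types), which the body can forward.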
2876 @beam.ptransform_fn 2877 def MyTransform(pcoll, type_hints, test_arg): 2878 self.assertEqual(test_arg, 'test') 2879 return ( 2880 pcoll 2881 | beam.ParDo(lambda x: [x]).with_output_types( 2882 type_hints.output_types[0][0])) 2883 2884 with TestPipeline() as p: 2885 _ = (p | beam.Create([1, 2]) | MyTransform('test').with_output_types(int)) 2886 2887 2888 class PickledObject(object): 2889 def __init__(self, value): 2890 self.value = value 2891 2892 2893 if __name__ == '__main__': 2894 unittest.main()