#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for the Pipeline class."""

# pytype: skip-file

import copy
import platform
import unittest

import mock
import pytest

import apache_beam as beam
from apache_beam import typehints
from apache_beam.coders import BytesCoder
from apache_beam.io import Read
from apache_beam.metrics import Metrics
from apache_beam.options.pipeline_options import PortableOptions
from apache_beam.pipeline import Pipeline
from apache_beam.pipeline import PipelineOptions
from apache_beam.pipeline import PipelineVisitor
from apache_beam.pipeline import PTransformOverride
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.pvalue import AsSingleton
from apache_beam.pvalue import TaggedOutput
from apache_beam.runners.dataflow.native_io.iobase import NativeSource
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms import CombineGlobally
from apache_beam.transforms import Create
from apache_beam.transforms import DoFn
from apache_beam.transforms import FlatMap
from apache_beam.transforms import Map
from apache_beam.transforms import ParDo
from apache_beam.transforms import PTransform
from apache_beam.transforms import WindowInto
from apache_beam.transforms.display import DisplayDataItem
from apache_beam.transforms.environments import ProcessEnvironment
from apache_beam.transforms.resources import ResourceHint
from apache_beam.transforms.userstate import BagStateSpec
from apache_beam.transforms.window import SlidingWindows
from apache_beam.transforms.window import TimestampedValue
from apache_beam.utils import windowed_value
from apache_beam.utils.timestamp import MIN_TIMESTAMP

# TODO(BEAM-1555): Test is failing on the service, with FakeSource.


class FakeSource(NativeSource):
  """Fake source returning a fixed list of values."""
  class _Reader(object):
    def __init__(self, vals):
      self._vals = vals
      self._output_counter = Metrics.counter('main', 'outputs')

    def __enter__(self):
      return self

    def __exit__(self, exception_type, exception_value, traceback):
      pass

    def __iter__(self):
      for v in self._vals:
        self._output_counter.inc()
        yield v

  def __init__(self, vals):
    self._vals = vals

  def reader(self):
    return FakeSource._Reader(self._vals)


class FakeUnboundedSource(NativeSource):
  """Fake unbounded source. Does not work at runtime."""
  def reader(self):
    return None

  def is_bounded(self):
    return False


class DoubleParDo(beam.PTransform):
  def expand(self, input):
    return input | 'Inner' >> beam.Map(lambda a: a * 2)

  def to_runner_api_parameter(self, context):
    return self.to_runner_api_pickled(context)


class TripleParDo(beam.PTransform):
  def expand(self, input):
    # Keeping labels the same intentionally to make sure that there is no label
    # conflict due to replacement.
    return input | 'Inner' >> beam.Map(lambda a: a * 3)


class ToStringParDo(beam.PTransform):
  def expand(self, input):
    # We use copy.copy() here to make sure the typehint mechanism doesn't
    # automatically infer that the output type is str.
    return input | 'Inner' >> beam.Map(lambda a: copy.copy(str(a)))


class FlattenAndDouble(beam.PTransform):
  def expand(self, pcolls):
    return pcolls | beam.Flatten() | 'Double' >> DoubleParDo()


class FlattenAndTriple(beam.PTransform):
  def expand(self, pcolls):
    return pcolls | beam.Flatten() | 'Triple' >> TripleParDo()


class AddWithProductDoFn(beam.DoFn):
  def process(self, input, a, b):
    yield input + a * b


class AddThenMultiplyDoFn(beam.DoFn):
  def process(self, input, a, b):
    yield (input + a) * b


class AddThenMultiply(beam.PTransform):
  def expand(self, pvalues):
    return pvalues[0] | beam.ParDo(
        AddThenMultiplyDoFn(), AsSingleton(pvalues[1]), AsSingleton(pvalues[2]))


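# The transform pairs above (DoubleParDo/TripleParDo, FlattenAndDouble/
# FlattenAndTriple, AddWithProductDoFn/AddThenMultiply) serve as fixtures for
# the PTransformOverride tests below: the first of each pair is applied while
# building the pipeline and the second is swapped in via replace_all(). The
# general pattern (an illustrative sketch, not executed by this module):
#
#   p = Pipeline()
#   _ = p | beam.Create([1, 2, 3]) | 'Multiply' >> DoubleParDo()
#   p.replace_all([MyParDoOverride()])  # override that matches DoubleParDo
#   p.run()                             # runs TripleParDo in its place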
class PipelineTest(unittest.TestCase):
  @staticmethod
  def custom_callable(pcoll):
    return pcoll | '+1' >> FlatMap(lambda x: [x + 1])

  # Some of these tests designate a runner by name, others supply a runner.
  # This variation is just to verify that both means of runner specification
  # work and is not related to other aspects of the tests.

  class CustomTransform(PTransform):
    def expand(self, pcoll):
      return pcoll | '+1' >> FlatMap(lambda x: [x + 1])

  class Visitor(PipelineVisitor):
    def __init__(self, visited):
      self.visited = visited
      self.enter_composite = []
      self.leave_composite = []

    def visit_value(self, value, _):
      self.visited.append(value)

    def enter_composite_transform(self, transform_node):
      self.enter_composite.append(transform_node)

    def leave_composite_transform(self, transform_node):
      self.leave_composite.append(transform_node)

  def test_create(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'label1' >> Create([1, 2, 3])
      assert_that(pcoll, equal_to([1, 2, 3]))

      # Test if initial value is an iterator object.
      pcoll2 = pipeline | 'label2' >> Create(iter((4, 5, 6)))
      pcoll3 = pcoll2 | 'do' >> FlatMap(lambda x: [x + 10])
      assert_that(pcoll3, equal_to([14, 15, 16]), label='pcoll3')

  def test_flatmap_builtin(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'label1' >> Create([1, 2, 3])
      assert_that(pcoll, equal_to([1, 2, 3]))

      pcoll2 = pcoll | 'do' >> FlatMap(lambda x: [x + 10])
      assert_that(pcoll2, equal_to([11, 12, 13]), label='pcoll2')

      pcoll3 = pcoll2 | 'm1' >> Map(lambda x: [x, 12])
      assert_that(
          pcoll3, equal_to([[11, 12], [12, 12], [13, 12]]), label='pcoll3')

      pcoll4 = pcoll3 | 'do2' >> FlatMap(set)
      assert_that(pcoll4, equal_to([11, 12, 12, 12, 13]), label='pcoll4')

  def test_maptuple_builtin(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | Create([('e1', 'e2')])
      side1 = beam.pvalue.AsSingleton(pipeline | 'side1' >> Create(['s1']))
      side2 = beam.pvalue.AsSingleton(pipeline | 'side2' >> Create(['s2']))

      # A test function with a tuple input, an auxiliary parameter,
      # and some side inputs.
      fn = lambda e1, e2, t=DoFn.TimestampParam, s1=None, s2=None: (
          e1, e2, t, s1, s2)
      assert_that(
          pcoll | 'NoSides' >> beam.core.MapTuple(fn),
          equal_to([('e1', 'e2', MIN_TIMESTAMP, None, None)]),
          label='NoSidesCheck')
      assert_that(
          pcoll | 'StaticSides' >> beam.core.MapTuple(fn, 's1', 's2'),
          equal_to([('e1', 'e2', MIN_TIMESTAMP, 's1', 's2')]),
          label='StaticSidesCheck')
      assert_that(
          pcoll | 'DynamicSides' >> beam.core.MapTuple(fn, side1, side2),
          equal_to([('e1', 'e2', MIN_TIMESTAMP, 's1', 's2')]),
          label='DynamicSidesCheck')
      assert_that(
          pcoll | 'MixedSides' >> beam.core.MapTuple(fn, s2=side2),
          equal_to([('e1', 'e2', MIN_TIMESTAMP, None, 's2')]),
          label='MixedSidesCheck')

  def test_flatmaptuple_builtin(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | Create([('e1', 'e2')])
      side1 = beam.pvalue.AsSingleton(pipeline | 'side1' >> Create(['s1']))
      side2 = beam.pvalue.AsSingleton(pipeline | 'side2' >> Create(['s2']))

      # A test function with a tuple input, an auxiliary parameter,
      # and some side inputs.
      fn = lambda e1, e2, t=DoFn.TimestampParam, s1=None, s2=None: (
          e1, e2, t, s1, s2)
      assert_that(
          pcoll | 'NoSides' >> beam.core.FlatMapTuple(fn),
          equal_to(['e1', 'e2', MIN_TIMESTAMP, None, None]),
          label='NoSidesCheck')
      assert_that(
          pcoll | 'StaticSides' >> beam.core.FlatMapTuple(fn, 's1', 's2'),
          equal_to(['e1', 'e2', MIN_TIMESTAMP, 's1', 's2']),
          label='StaticSidesCheck')
      assert_that(
          pcoll
          | 'DynamicSides' >> beam.core.FlatMapTuple(fn, side1, side2),
          equal_to(['e1', 'e2', MIN_TIMESTAMP, 's1', 's2']),
          label='DynamicSidesCheck')
      assert_that(
          pcoll | 'MixedSides' >> beam.core.FlatMapTuple(fn, s2=side2),
          equal_to(['e1', 'e2', MIN_TIMESTAMP, None, 's2']),
          label='MixedSidesCheck')

  def test_create_singleton_pcollection(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'label' >> Create([[1, 2, 3]])
      assert_that(pcoll, equal_to([[1, 2, 3]]))

  # TODO(BEAM-1555): Test is failing on the service, with FakeSource.
  # @pytest.mark.it_validatesrunner
  def test_metrics_in_fake_source(self):
    pipeline = TestPipeline()
    pcoll = pipeline | Read(FakeSource([1, 2, 3, 4, 5, 6]))
    assert_that(pcoll, equal_to([1, 2, 3, 4, 5, 6]))
    res = pipeline.run()
    metric_results = res.metrics().query()
    outputs_counter = metric_results['counters'][0]
    self.assertEqual(outputs_counter.key.step, 'Read')
    self.assertEqual(outputs_counter.key.metric.name, 'outputs')
    self.assertEqual(outputs_counter.committed, 6)

  def test_fake_read(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'read' >> Read(FakeSource([1, 2, 3]))
      assert_that(pcoll, equal_to([1, 2, 3]))

  def test_visit_entire_graph(self):
    pipeline = Pipeline()
    pcoll1 = pipeline | 'pcoll' >> beam.Impulse()
    pcoll2 = pcoll1 | 'do1' >> FlatMap(lambda x: [x + 1])
    pcoll3 = pcoll2 | 'do2' >> FlatMap(lambda x: [x + 1])
    pcoll4 = pcoll2 | 'do3' >> FlatMap(lambda x: [x + 1])
    transform = PipelineTest.CustomTransform()
    pcoll5 = pcoll4 | transform

    visitor = PipelineTest.Visitor(visited=[])
    pipeline.visit(visitor)
    self.assertEqual({pcoll1, pcoll2, pcoll3, pcoll4, pcoll5},
                     set(visitor.visited))
    self.assertEqual(set(visitor.enter_composite), set(visitor.leave_composite))
    self.assertEqual(2, len(visitor.enter_composite))
    self.assertEqual(visitor.enter_composite[1].transform, transform)
    self.assertEqual(visitor.leave_composite[0].transform, transform)

  def test_apply_custom_transform(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'pcoll' >> Create([1, 2, 3])
      result = pcoll | PipelineTest.CustomTransform()
      assert_that(result, equal_to([2, 3, 4]))

  def test_reuse_custom_transform_instance(self):
    pipeline = Pipeline()
    pcoll1 = pipeline | 'pcoll1' >> Create([1, 2, 3])
    pcoll2 = pipeline | 'pcoll2' >> Create([4, 5, 6])
    transform = PipelineTest.CustomTransform()
    pcoll1 | transform
    with self.assertRaises(RuntimeError) as cm:
      pipeline.apply(transform, pcoll2)
    self.assertEqual(
        cm.exception.args[0],
        'A transform with label "CustomTransform" already exists in the '
        'pipeline. To apply a transform with a specified label write '
        'pvalue | "label" >> transform')

  def test_reuse_cloned_custom_transform_instance(self):
    with TestPipeline() as pipeline:
      pcoll1 = pipeline | 'pc1' >> Create([1, 2, 3])
      pcoll2 = pipeline | 'pc2' >> Create([4, 5, 6])
      transform = PipelineTest.CustomTransform()
      result1 = pcoll1 | transform
      result2 = pcoll2 | 'new_label' >> transform
      assert_that(result1, equal_to([2, 3, 4]), label='r1')
      assert_that(result2, equal_to([5, 6, 7]), label='r2')

  def test_transform_no_super_init(self):
    class AddSuffix(PTransform):
      def __init__(self, suffix):
        # No call to super(...).__init__
        self.suffix = suffix

      def expand(self, pcoll):
        return pcoll | Map(lambda x: x + self.suffix)

    self.assertEqual(['a-x', 'b-x', 'c-x'],
                     sorted(['a', 'b', 'c'] | 'AddSuffix' >> AddSuffix('-x')))

  @unittest.skip("Fails on some platforms with new urllib3.")
  def test_memory_usage(self):
    try:
      import resource
    except ImportError:
      # Skip the test if resource module is not available (e.g. non-Unix os).
      self.skipTest('resource module not available.')
    if platform.mac_ver()[0]:
      # Skip the test on macos, depending on version it returns ru_maxrss in
      # different units.
      self.skipTest('ru_maxrss is not in standard units.')

    def get_memory_usage_in_bytes():
      return resource.getrusage(resource.RUSAGE_SELF).ru_maxrss * (2**10)

    def check_memory(value, memory_threshold):
      memory_usage = get_memory_usage_in_bytes()
      if memory_usage > memory_threshold:
        raise RuntimeError(
            'High memory usage: %d > %d' % (memory_usage, memory_threshold))
      return value

    len_elements = 1000000
    num_elements = 10
    num_maps = 100

    # TODO(robertwb): reduce memory usage of FnApiRunner so that this test
    # passes.
    with TestPipeline(runner='BundleBasedDirectRunner') as pipeline:

      # Consumed memory should not be proportional to the number of maps.
      memory_threshold = (
          get_memory_usage_in_bytes() + (5 * len_elements * num_elements))

      # Plus small additional slack for memory fluctuations during the test.
      memory_threshold += 10 * (2**20)

      biglist = pipeline | 'oom:create' >> Create(
          ['x' * len_elements] * num_elements)
      for i in range(num_maps):
        biglist = biglist | ('oom:addone-%d' % i) >> Map(lambda x: x + 'y')
      result = biglist | 'oom:check' >> Map(check_memory, memory_threshold)
      assert_that(
          result,
          equal_to(['x' * len_elements + 'y' * num_maps] * num_elements))

  def test_aggregator_empty_input(self):
    actual = [] | CombineGlobally(max).without_defaults()
    self.assertEqual(actual, [])

  def test_pipeline_as_context(self):
    def raise_exception(exn):
      raise exn

    with self.assertRaises(ValueError):
      with Pipeline() as p:
        # pylint: disable=expression-not-assigned
        p | Create([ValueError('msg')]) | Map(raise_exception)

  def test_ptransform_overrides(self):
    class MyParDoOverride(PTransformOverride):
      def matches(self, applied_ptransform):
        return isinstance(applied_ptransform.transform, DoubleParDo)

      def get_replacement_transform_for_applied_ptransform(
          self, applied_ptransform):
        ptransform = applied_ptransform.transform
        if isinstance(ptransform, DoubleParDo):
          return TripleParDo()
        raise ValueError('Unsupported type of transform: %r' % ptransform)

    p = Pipeline()
    pcoll = p | beam.Create([1, 2, 3]) | 'Multiply' >> DoubleParDo()
    assert_that(pcoll, equal_to([3, 6, 9]))

    p.replace_all([MyParDoOverride()])
    p.run()

  def test_ptransform_override_type_hints(self):
    class NoTypeHintOverride(PTransformOverride):
      def matches(self, applied_ptransform):
        return isinstance(applied_ptransform.transform, DoubleParDo)

      def get_replacement_transform_for_applied_ptransform(
          self, applied_ptransform):
        return ToStringParDo()

    class WithTypeHintOverride(PTransformOverride):
      def matches(self, applied_ptransform):
        return isinstance(applied_ptransform.transform, DoubleParDo)

      def get_replacement_transform_for_applied_ptransform(
          self, applied_ptransform):
        return ToStringParDo().with_input_types(int).with_output_types(str)

    for override, expected_type in [(NoTypeHintOverride(), int),
                                    (WithTypeHintOverride(), str)]:
      p = TestPipeline()
      pcoll = (
          p
          | beam.Create([1, 2, 3])
          | 'Operate' >> DoubleParDo()
          | 'NoOp' >> beam.Map(lambda x: x))

      p.replace_all([override])
      self.assertEqual(pcoll.producer.inputs[0].element_type, expected_type)

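  # Note: the overrides in the next three tests implement the older
  # get_replacement_transform() hook; PTransformOverride's default
  # get_replacement_transform_for_applied_ptransform() delegates to it, so
  # both spellings of the override API are exercised in this file.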
  def test_ptransform_override_multiple_inputs(self):
    class MyParDoOverride(PTransformOverride):
      def matches(self, applied_ptransform):
        return isinstance(applied_ptransform.transform, FlattenAndDouble)

      def get_replacement_transform(self, applied_ptransform):
        return FlattenAndTriple()

    p = Pipeline()
    pcoll1 = p | 'pc1' >> beam.Create([1, 2, 3])
    pcoll2 = p | 'pc2' >> beam.Create([4, 5, 6])
    pcoll3 = (pcoll1, pcoll2) | 'FlattenAndMultiply' >> FlattenAndDouble()
    assert_that(pcoll3, equal_to([3, 6, 9, 12, 15, 18]))

    p.replace_all([MyParDoOverride()])
    p.run()

  def test_ptransform_override_side_inputs(self):
    class MyParDoOverride(PTransformOverride):
      def matches(self, applied_ptransform):
        return (
            isinstance(applied_ptransform.transform, ParDo) and
            isinstance(applied_ptransform.transform.fn, AddWithProductDoFn))

      def get_replacement_transform(self, transform):
        return AddThenMultiply()

    p = Pipeline()
    pcoll1 = p | 'pc1' >> beam.Create([2])
    pcoll2 = p | 'pc2' >> beam.Create([3])
    pcoll3 = p | 'pc3' >> beam.Create([4, 5, 6])
    result = pcoll3 | 'Operate' >> beam.ParDo(
        AddWithProductDoFn(), AsSingleton(pcoll1), AsSingleton(pcoll2))
    assert_that(result, equal_to([18, 21, 24]))

    p.replace_all([MyParDoOverride()])
    p.run()

  def test_ptransform_override_replacement_inputs(self):
    class MyParDoOverride(PTransformOverride):
      def matches(self, applied_ptransform):
        return (
            isinstance(applied_ptransform.transform, ParDo) and
            isinstance(applied_ptransform.transform.fn, AddWithProductDoFn))

      def get_replacement_transform(self, transform):
        return AddThenMultiply()

      def get_replacement_inputs(self, applied_ptransform):
        assert len(applied_ptransform.inputs) == 1
        assert len(applied_ptransform.side_inputs) == 2
        # Swap the order of the two side inputs
        return (
            applied_ptransform.inputs[0],
            applied_ptransform.side_inputs[1].pvalue,
            applied_ptransform.side_inputs[0].pvalue)

    p = Pipeline()
    pcoll1 = p | 'pc1' >> beam.Create([2])
    pcoll2 = p | 'pc2' >> beam.Create([3])
    pcoll3 = p | 'pc3' >> beam.Create([4, 5, 6])
    result = pcoll3 | 'Operate' >> beam.ParDo(
        AddWithProductDoFn(), AsSingleton(pcoll1), AsSingleton(pcoll2))
    assert_that(result, equal_to([14, 16, 18]))

    p.replace_all([MyParDoOverride()])
    p.run()

  def test_ptransform_override_multiple_outputs(self):
    class MultiOutputComposite(PTransform):
      def __init__(self):
        self.output_tags = set()

      def expand(self, pcoll):
        def mux_input(x):
          x = x * 2
          if isinstance(x, int):
            yield TaggedOutput('numbers', x)
          else:
            yield TaggedOutput('letters', x)

        multi = pcoll | 'MyReplacement' >> beam.ParDo(mux_input).with_outputs()
        letters = multi.letters | 'LettersComposite' >> beam.Map(
            lambda x: x * 3)
        numbers = multi.numbers | 'NumbersComposite' >> beam.Map(
            lambda x: x * 5)

        return {
            'letters': letters,
            'numbers': numbers,
        }

    class MultiOutputOverride(PTransformOverride):
      def matches(self, applied_ptransform):
        return applied_ptransform.full_label == 'MyMultiOutput'

      def get_replacement_transform_for_applied_ptransform(
          self, applied_ptransform):
        return MultiOutputComposite()

    def mux_input(x):
      if isinstance(x, int):
        yield TaggedOutput('numbers', x)
      else:
        yield TaggedOutput('letters', x)

    with TestPipeline() as p:
      multi = (
          p
          | beam.Create([1, 2, 3, 'a', 'b', 'c'])
          | 'MyMultiOutput' >> beam.ParDo(mux_input).with_outputs())
      letters = multi.letters | 'MyLetters' >> beam.Map(lambda x: x)
      numbers = multi.numbers | 'MyNumbers' >> beam.Map(lambda x: x)

      # Assert that the PCollection replacement worked correctly and that
      # elements are flowing through. The replacement transform first
      # multiples by 2 then the leaf nodes inside the composite multiply by
      # an additional 3 and 5. Use prime numbers to ensure that each
      # transform is getting executed once.
      assert_that(
          letters,
          equal_to(['a' * 2 * 3, 'b' * 2 * 3, 'c' * 2 * 3]),
          label='assert letters')
      assert_that(
          numbers,
          equal_to([1 * 2 * 5, 2 * 2 * 5, 3 * 2 * 5]),
          label='assert numbers')

      # Do the replacement and run the element assertions.
      p.replace_all([MultiOutputOverride()])

    # The following checks the graph to make sure the replacement occurred.
    visitor = PipelineTest.Visitor(visited=[])
    p.visit(visitor)
    pcollections = visitor.visited
    composites = visitor.enter_composite

    # Assert the replacement is in the composite list and retrieve the
    # AppliedPTransform.
    self.assertIn(
        MultiOutputComposite, [t.transform.__class__ for t in composites])
    multi_output_composite = list(
        filter(
            lambda t: t.transform.__class__ == MultiOutputComposite,
            composites))[0]

    # Assert that all of the replacement PCollections are in the graph.
    for output in multi_output_composite.outputs.values():
      self.assertIn(output, pcollections)

    # Assert that all of the "old"/replaced PCollections are not in the graph.
    self.assertNotIn(multi[None], visitor.visited)
    self.assertNotIn(multi.letters, visitor.visited)
    self.assertNotIn(multi.numbers, visitor.visited)

  def test_kv_ptransform_honor_type_hints(self):

    # The return type of this DoFn cannot be inferred by the default
    # Beam type inference
    class StatefulDoFn(DoFn):
      BYTES_STATE = BagStateSpec('bytes', BytesCoder())

      def return_recursive(self, count):
        if count == 0:
          return ["some string"]
        else:
          self.return_recursive(count - 1)

      def process(self, element, counter=DoFn.StateParam(BYTES_STATE)):
        return self.return_recursive(1)

    with TestPipeline() as p:
      pcoll = (
          p
          | beam.Create([(1, 1), (2, 2), (3, 3)])
          | beam.GroupByKey()
          | beam.ParDo(StatefulDoFn()))
      self.assertEqual(pcoll.element_type, typehints.Any)

    with TestPipeline() as p:
      pcoll = (
          p
          | beam.Create([(1, 1), (2, 2), (3, 3)])
          | beam.GroupByKey()
          | beam.ParDo(StatefulDoFn()).with_output_types(str))
      self.assertEqual(pcoll.element_type, str)

  def test_track_pcoll_unbounded(self):
    pipeline = TestPipeline()
    pcoll1 = pipeline | 'read' >> Read(FakeUnboundedSource())
    pcoll2 = pcoll1 | 'do1' >> FlatMap(lambda x: [x + 1])
    pcoll3 = pcoll2 | 'do2' >> FlatMap(lambda x: [x + 1])
    self.assertIs(pcoll1.is_bounded, False)
    self.assertIs(pcoll2.is_bounded, False)
    self.assertIs(pcoll3.is_bounded, False)

  def test_track_pcoll_bounded(self):
    pipeline = TestPipeline()
    pcoll1 = pipeline | 'label1' >> Create([1, 2, 3])
    pcoll2 = pcoll1 | 'do1' >> FlatMap(lambda x: [x + 1])
    pcoll3 = pcoll2 | 'do2' >> FlatMap(lambda x: [x + 1])
    self.assertIs(pcoll1.is_bounded, True)
    self.assertIs(pcoll2.is_bounded, True)
    self.assertIs(pcoll3.is_bounded, True)

  def test_track_pcoll_bounded_flatten(self):
    pipeline = TestPipeline()
    pcoll1_a = pipeline | 'label_a' >> Create([1, 2, 3])
    pcoll2_a = pcoll1_a | 'do_a' >> FlatMap(lambda x: [x + 1])

    pcoll1_b = pipeline | 'label_b' >> Create([1, 2, 3])
    pcoll2_b = pcoll1_b | 'do_b' >> FlatMap(lambda x: [x + 1])

    merged = (pcoll2_a, pcoll2_b) | beam.Flatten()

    self.assertIs(pcoll1_a.is_bounded, True)
    self.assertIs(pcoll2_a.is_bounded, True)
    self.assertIs(pcoll1_b.is_bounded, True)
    self.assertIs(pcoll2_b.is_bounded, True)
    self.assertIs(merged.is_bounded, True)

  def test_track_pcoll_unbounded_flatten(self):
    pipeline = TestPipeline()
    pcoll1_bounded = pipeline | 'label1' >> Create([1, 2, 3])
    pcoll2_bounded = pcoll1_bounded | 'do1' >> FlatMap(lambda x: [x + 1])

    pcoll1_unbounded = pipeline | 'read' >> Read(FakeUnboundedSource())
    pcoll2_unbounded = pcoll1_unbounded | 'do2' >> FlatMap(lambda x: [x + 1])

    merged = (pcoll2_bounded, pcoll2_unbounded) | beam.Flatten()

    self.assertIs(pcoll1_bounded.is_bounded, True)
    self.assertIs(pcoll2_bounded.is_bounded, True)
    self.assertIs(pcoll1_unbounded.is_bounded, False)
    self.assertIs(pcoll2_unbounded.is_bounded, False)
    self.assertIs(merged.is_bounded, False)

  def test_incompatible_submission_and_runtime_envs_fail_pipeline(self):
    with mock.patch(
        'apache_beam.transforms.environments.sdk_base_version_capability'
    ) as base_version:
      base_version.side_effect = [
          f"beam:version:sdk_base:apache/beam_python3.5_sdk:2.{i}.0"
          for i in range(100)
      ]
      with self.assertRaisesRegex(
          RuntimeError,
          'Pipeline construction environment and pipeline runtime '
          'environment are not compatible.'):
        with TestPipeline() as p:
          _ = p | Create([None])


class DoFnTest(unittest.TestCase):
  def test_element(self):
    class TestDoFn(DoFn):
      def process(self, element):
        yield element + 10

    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Create' >> Create([1, 2]) | 'Do' >> ParDo(TestDoFn())
      assert_that(pcoll, equal_to([11, 12]))

  def test_side_input_no_tag(self):
    class TestDoFn(DoFn):
      def process(self, element, prefix, suffix):
        return ['%s-%s-%s' % (prefix, element, suffix)]

    with TestPipeline() as pipeline:
      words_list = ['aa', 'bb', 'cc']
      words = pipeline | 'SomeWords' >> Create(words_list)
      prefix = 'zyx'
      suffix = pipeline | 'SomeString' >> Create(['xyz'])  # side in
      result = words | 'DecorateWordsDoFnNoTag' >> ParDo(
          TestDoFn(), prefix, suffix=AsSingleton(suffix))
      assert_that(result, equal_to(['zyx-%s-xyz' % x for x in words_list]))

  def test_side_input_tagged(self):
    class TestDoFn(DoFn):
      def process(self, element, prefix, suffix=DoFn.SideInputParam):
        return ['%s-%s-%s' % (prefix, element, suffix)]

    with TestPipeline() as pipeline:
      words_list = ['aa', 'bb', 'cc']
      words = pipeline | 'SomeWords' >> Create(words_list)
      prefix = 'zyx'
      suffix = pipeline | 'SomeString' >> Create(['xyz'])  # side in
      result = words | 'DecorateWordsDoFnNoTag' >> ParDo(
          TestDoFn(), prefix, suffix=AsSingleton(suffix))
      assert_that(result, equal_to(['zyx-%s-xyz' % x for x in words_list]))

  @pytest.mark.it_validatesrunner
  def test_element_param(self):
    pipeline = TestPipeline()
    input = [1, 2]
    pcoll = (
        pipeline
        | 'Create' >> Create(input)
        | 'Ele param' >> Map(lambda element=DoFn.ElementParam: element))
    assert_that(pcoll, equal_to(input))
    pipeline.run()

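  # Note: DoFn.ElementParam and DoFn.KeyParam (used above and below) are
  # sentinel default values; the runner recognizes them and substitutes the
  # actual element or key when the lambda is invoked.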
  @pytest.mark.it_validatesrunner
  def test_key_param(self):
    pipeline = TestPipeline()
    pcoll = (
        pipeline
        | 'Create' >> Create([('a', 1), ('b', 2)])
        | 'Key param' >> Map(lambda _, key=DoFn.KeyParam: key))
    assert_that(pcoll, equal_to(['a', 'b']))
    pipeline.run()

  def test_window_param(self):
    class TestDoFn(DoFn):
      def process(self, element, window=DoFn.WindowParam):
        yield (element, (float(window.start), float(window.end)))

    with TestPipeline() as pipeline:
      pcoll = (
          pipeline
          | Create([1, 7])
          | Map(lambda x: TimestampedValue(x, x))
          | WindowInto(windowfn=SlidingWindows(10, 5))
          | ParDo(TestDoFn()))
      assert_that(
          pcoll,
          equal_to([(1, (-5, 5)), (1, (0, 10)), (7, (0, 10)), (7, (5, 15))]))
      pcoll2 = pcoll | 'Again' >> ParDo(TestDoFn())
      assert_that(
          pcoll2,
          equal_to([((1, (-5, 5)), (-5, 5)), ((1, (0, 10)), (0, 10)),
                    ((7, (0, 10)), (0, 10)), ((7, (5, 15)), (5, 15))]),
          label='doubled windows')

  def test_timestamp_param(self):
    class TestDoFn(DoFn):
      def process(self, element, timestamp=DoFn.TimestampParam):
        yield timestamp

    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Create' >> Create([1, 2]) | 'Do' >> ParDo(TestDoFn())
      assert_that(pcoll, equal_to([MIN_TIMESTAMP, MIN_TIMESTAMP]))

  def test_timestamp_param_map(self):
    with TestPipeline() as p:
      assert_that(
          p | Create([1, 2]) | beam.Map(lambda _, t=DoFn.TimestampParam: t),
          equal_to([MIN_TIMESTAMP, MIN_TIMESTAMP]))

  def test_pane_info_param(self):
    with TestPipeline() as p:
      pc = p | Create([(None, None)])
      assert_that(
          pc | beam.Map(lambda _, p=DoFn.PaneInfoParam: p),
          equal_to([windowed_value.PANE_INFO_UNKNOWN]),
          label='CheckUngrouped')
      assert_that(
          pc | beam.GroupByKey() | beam.Map(lambda _, p=DoFn.PaneInfoParam: p),
          equal_to([
              windowed_value.PaneInfo(
                  is_first=True,
                  is_last=True,
                  timing=windowed_value.PaneInfoTiming.ON_TIME,
                  index=0,
                  nonspeculative_index=0)
          ]),
          label='CheckGrouped')

  def test_incomparable_default(self):
    class IncomparableType(object):
      def __eq__(self, other):
        raise RuntimeError()

      def __ne__(self, other):
        raise RuntimeError()

      def __hash__(self):
        raise RuntimeError()

    # Ensure that we don't use default values in a context where they must be
    # comparable (see BEAM-8301).
    with TestPipeline() as pipeline:
      pcoll = (
          pipeline
          | beam.Create([None])
          | Map(lambda e, x=IncomparableType(): (e, type(x).__name__)))
      assert_that(pcoll, equal_to([(None, 'IncomparableType')]))


class Bacon(PipelineOptions):
  @classmethod
  def _add_argparse_args(cls, parser):
    parser.add_argument('--slices', type=int)


class Eggs(PipelineOptions):
  @classmethod
  def _add_argparse_args(cls, parser):
    parser.add_argument('--style', default='scrambled')


class Breakfast(Bacon, Eggs):
  pass


class PipelineOptionsTest(unittest.TestCase):
  def test_flag_parsing(self):
    options = Breakfast(['--slices=3', '--style=sunny side up', '--ignored'])
    self.assertEqual(3, options.slices)
    self.assertEqual('sunny side up', options.style)

  def test_keyword_parsing(self):
    options = Breakfast(['--slices=3', '--style=sunny side up', '--ignored'],
                        slices=10)
    self.assertEqual(10, options.slices)
    self.assertEqual('sunny side up', options.style)

  def test_attribute_setting(self):
    options = Breakfast(slices=10)
    self.assertEqual(10, options.slices)
    options.slices = 20
    self.assertEqual(20, options.slices)

  def test_view_as(self):
    generic_options = PipelineOptions(['--slices=3'])
    self.assertEqual(3, generic_options.view_as(Bacon).slices)
    self.assertEqual(3, generic_options.view_as(Breakfast).slices)

    generic_options.view_as(Breakfast).slices = 10
    self.assertEqual(10, generic_options.view_as(Bacon).slices)

    with self.assertRaises(AttributeError):
      generic_options.slices  # pylint: disable=pointless-statement

    with self.assertRaises(AttributeError):
      generic_options.view_as(Eggs).slices  # pylint: disable=expression-not-assigned

  def test_defaults(self):
    options = Breakfast(['--slices=3'])
    self.assertEqual(3, options.slices)
    self.assertEqual('scrambled', options.style)

  def test_dir(self):
    options = Breakfast()
    self.assertEqual({
        'from_dictionary',
        'get_all_options',
        'slices',
        'style',
        'view_as',
        'display_data'
    },
                     {
                         attr
                         for attr in dir(options)
                         if not attr.startswith('_') and attr != 'next'
                     })
    self.assertEqual({
        'from_dictionary',
        'get_all_options',
        'style',
        'view_as',
        'display_data'
    },
                     {
                         attr
                         for attr in dir(options.view_as(Eggs))
                         if not attr.startswith('_') and attr != 'next'
                     })


class RunnerApiTest(unittest.TestCase):
  def test_parent_pointer(self):
    class MyPTransform(beam.PTransform):
      def expand(self, p):
        self.p = p
        return p | beam.Create([None])

    p = beam.Pipeline()
    p | MyPTransform()  # pylint: disable=expression-not-assigned
    p = Pipeline.from_runner_api(
        Pipeline.to_runner_api(p, use_fake_coders=True), None, None)
    self.assertIsNotNone(p.transforms_stack[0].parts[0].parent)
    self.assertEqual(
        p.transforms_stack[0].parts[0].parent, p.transforms_stack[0])

  def test_requirements(self):
    p = beam.Pipeline()
    _ = (
        p | beam.Create([])
        | beam.ParDo(lambda x, finalize=beam.DoFn.BundleFinalizerParam: None))
    proto = p.to_runner_api()
    # assertIn rather than assertTrue: the intent is a membership check, and
    # assertTrue's second argument is just the failure message.
    self.assertIn(
        common_urns.requirements.REQUIRES_BUNDLE_FINALIZATION.urn,
        proto.requirements)

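  # Note: annotation values are normalized when the pipeline is converted to
  # its Runner API proto: str values are encoded as UTF-8 bytes and proto
  # messages are serialized, which is what the assertions below rely on.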
  def test_annotations(self):
    some_proto = BytesCoder().to_runner_api(None)

    class EmptyTransform(beam.PTransform):
      def expand(self, pcoll):
        return pcoll

      def annotations(self):
        return {'foo': 'some_string'}

    class NonEmptyTransform(beam.PTransform):
      def expand(self, pcoll):
        return pcoll | beam.Map(lambda x: x)

      def annotations(self):
        return {
            'foo': b'some_bytes',
            'proto': some_proto,
        }

    p = beam.Pipeline()
    _ = p | beam.Create([]) | EmptyTransform() | NonEmptyTransform()
    proto = p.to_runner_api()

    seen = 0
    for transform in proto.components.transforms.values():
      if transform.unique_name == 'EmptyTransform':
        seen += 1
        self.assertEqual(transform.annotations['foo'], b'some_string')
      elif transform.unique_name == 'NonEmptyTransform':
        seen += 1
        self.assertEqual(transform.annotations['foo'], b'some_bytes')
        self.assertEqual(
            transform.annotations['proto'], some_proto.SerializeToString())
    self.assertEqual(seen, 2)

  def test_transform_ids(self):
    class MyPTransform(beam.PTransform):
      def expand(self, p):
        self.p = p
        return p | beam.Create([None])

    p = beam.Pipeline()
    p | MyPTransform()  # pylint: disable=expression-not-assigned
    runner_api_proto = Pipeline.to_runner_api(p)

    for transform_id in runner_api_proto.components.transforms:
      self.assertRegex(transform_id, r'[a-zA-Z0-9-_]+')

  def test_input_names(self):
    class MyPTransform(beam.PTransform):
      def expand(self, pcolls):
        return pcolls.values() | beam.Flatten()

    p = beam.Pipeline()
    input_names = set('ABC')
    inputs = {x: p | x >> beam.Create([x]) for x in input_names}
    inputs | MyPTransform()  # pylint: disable=expression-not-assigned
    runner_api_proto = Pipeline.to_runner_api(p)

    for transform_proto in runner_api_proto.components.transforms.values():
      if transform_proto.unique_name == 'MyPTransform':
        self.assertEqual(set(transform_proto.inputs.keys()), input_names)
        break
    else:
      self.fail('Unable to find transform.')

  def test_display_data(self):
    class MyParentTransform(beam.PTransform):
      def expand(self, p):
        self.p = p
        return p | beam.Create([None])

      def display_data(self):  # type: () -> dict
        parent_dd = super().display_data()
        parent_dd['p_dd_string'] = DisplayDataItem(
            'p_dd_string_value', label='p_dd_string_label')
        parent_dd['p_dd_string_2'] = DisplayDataItem('p_dd_string_value_2')
        parent_dd['p_dd_bool'] = DisplayDataItem(True, label='p_dd_bool_label')
        parent_dd['p_dd_int'] = DisplayDataItem(1, label='p_dd_int_label')
        return parent_dd

    class MyPTransform(MyParentTransform):
      def expand(self, p):
        self.p = p
        return p | beam.Create([None])

      def display_data(self):  # type: () -> dict
        parent_dd = super().display_data()
        parent_dd['dd_string'] = DisplayDataItem(
            'dd_string_value', label='dd_string_label')
        parent_dd['dd_string_2'] = DisplayDataItem('dd_string_value_2')
        parent_dd['dd_bool'] = DisplayDataItem(False, label='dd_bool_label')
        parent_dd['dd_double'] = DisplayDataItem(1.1, label='dd_double_label')
        return parent_dd

    p = beam.Pipeline()
    p | MyPTransform()  # pylint: disable=expression-not-assigned

    proto_pipeline = Pipeline.to_runner_api(p, use_fake_coders=True)
    my_transform, = [
        transform
        for transform in proto_pipeline.components.transforms.values()
        if transform.unique_name == 'MyPTransform'
    ]
    self.assertIsNotNone(my_transform)
    self.assertListEqual(
        list(my_transform.display_data),
        [
            beam_runner_api_pb2.DisplayData(
                urn=common_urns.StandardDisplayData.DisplayData.LABELLED.urn,
                payload=beam_runner_api_pb2.LabelledPayload(
                    label='p_dd_string_label',
                    key='p_dd_string',
                    namespace='apache_beam.pipeline_test.MyPTransform',
                    string_value='p_dd_string_value').SerializeToString()),
            beam_runner_api_pb2.DisplayData(
                urn=common_urns.StandardDisplayData.DisplayData.LABELLED.urn,
                payload=beam_runner_api_pb2.LabelledPayload(
                    label='p_dd_string_2',
                    key='p_dd_string_2',
                    namespace='apache_beam.pipeline_test.MyPTransform',
                    string_value='p_dd_string_value_2').SerializeToString()),
            beam_runner_api_pb2.DisplayData(
                urn=common_urns.StandardDisplayData.DisplayData.LABELLED.urn,
                payload=beam_runner_api_pb2.LabelledPayload(
                    label='p_dd_bool_label',
                    key='p_dd_bool',
                    namespace='apache_beam.pipeline_test.MyPTransform',
                    bool_value=True).SerializeToString()),
            beam_runner_api_pb2.DisplayData(
                urn=common_urns.StandardDisplayData.DisplayData.LABELLED.urn,
                payload=beam_runner_api_pb2.LabelledPayload(
                    label='p_dd_int_label',
                    key='p_dd_int',
                    namespace='apache_beam.pipeline_test.MyPTransform',
                    int_value=1).SerializeToString()),
            beam_runner_api_pb2.DisplayData(
                urn=common_urns.StandardDisplayData.DisplayData.LABELLED.urn,
                payload=beam_runner_api_pb2.LabelledPayload(
                    label='dd_string_label',
                    key='dd_string',
                    namespace='apache_beam.pipeline_test.MyPTransform',
                    string_value='dd_string_value').SerializeToString()),
            beam_runner_api_pb2.DisplayData(
                urn=common_urns.StandardDisplayData.DisplayData.LABELLED.urn,
                payload=beam_runner_api_pb2.LabelledPayload(
                    label='dd_string_2',
                    key='dd_string_2',
                    namespace='apache_beam.pipeline_test.MyPTransform',
                    string_value='dd_string_value_2').SerializeToString()),
            beam_runner_api_pb2.DisplayData(
                urn=common_urns.StandardDisplayData.DisplayData.LABELLED.urn,
                payload=beam_runner_api_pb2.LabelledPayload(
                    label='dd_bool_label',
                    key='dd_bool',
                    namespace='apache_beam.pipeline_test.MyPTransform',
                    bool_value=False).SerializeToString()),
            beam_runner_api_pb2.DisplayData(
                urn=common_urns.StandardDisplayData.DisplayData.LABELLED.urn,
                payload=beam_runner_api_pb2.LabelledPayload(
                    label='dd_double_label',
                    key='dd_double',
                    namespace='apache_beam.pipeline_test.MyPTransform',
                    double_value=1.1).SerializeToString()),
        ])

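  # Note: resource hints set on a transform end up on that transform's
  # environment in the Runner API proto, which is why the round-trip and
  # propagation tests below inspect proto.components.environments.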
  def test_runner_api_roundtrip_preserves_resource_hints(self):
    p = beam.Pipeline()
    _ = (
        p | beam.Create([1, 2])
        | beam.Map(lambda x: x + 1).with_resource_hints(accelerator='gpu'))

    self.assertEqual(
        p.transforms_stack[0].parts[1].transform.get_resource_hints(),
        {common_urns.resource_hints.ACCELERATOR.urn: b'gpu'})

    for _ in range(3):
      # Verify that DEFAULT environments are recreated during multiple
      # RunnerAPI translations and hints don't get lost.
      p = Pipeline.from_runner_api(Pipeline.to_runner_api(p), None, None)
      self.assertEqual(
          p.transforms_stack[0].parts[1].transform.get_resource_hints(),
          {common_urns.resource_hints.ACCELERATOR.urn: b'gpu'})

  def test_hints_on_composite_transforms_are_propagated_to_subtransforms(self):
    class FooHint(ResourceHint):
      urn = 'foo_urn'

    class BarHint(ResourceHint):
      urn = 'bar_urn'

    class BazHint(ResourceHint):
      urn = 'baz_urn'

    class QuxHint(ResourceHint):
      urn = 'qux_urn'

    class UseMaxValueHint(ResourceHint):
      urn = 'use_max_value_urn'

      @classmethod
      def get_merged_value(
          cls, outer_value, inner_value):  # type: (bytes, bytes) -> bytes
        return ResourceHint._use_max(outer_value, inner_value)

    ResourceHint.register_resource_hint('foo_hint', FooHint)
    ResourceHint.register_resource_hint('bar_hint', BarHint)
    ResourceHint.register_resource_hint('baz_hint', BazHint)
    ResourceHint.register_resource_hint('qux_hint', QuxHint)
    ResourceHint.register_resource_hint('use_max_value_hint', UseMaxValueHint)

    @beam.ptransform_fn
    def SubTransform(pcoll):
      return pcoll | beam.Map(lambda x: x + 1).with_resource_hints(
          foo_hint='set_on_subtransform', use_max_value_hint='10')

    @beam.ptransform_fn
    def CompositeTransform(pcoll):
      return pcoll | beam.Map(lambda x: x * 2) | SubTransform()

    p = beam.Pipeline()
    _ = (
        p | beam.Create([1, 2])
        | CompositeTransform().with_resource_hints(
            foo_hint='should_be_overriden_by_subtransform',
            bar_hint='set_on_composite',
            baz_hint='set_on_composite',
            use_max_value_hint='100'))
    options = PortableOptions([
        '--resource_hint=baz_hint=should_be_overriden_by_composite',
        '--resource_hint=qux_hint=set_via_options',
        '--environment_type=PROCESS',
        '--environment_option=process_command=foo',
        '--sdk_location=container',
    ])
    environment = ProcessEnvironment.from_options(options)
    proto = Pipeline.to_runner_api(p, default_environment=environment)

    # Initialize the flag so the final assertion fails (rather than raising
    # NameError) when no matching transform is found.
    found = False
    for t in proto.components.transforms.values():
      if "CompositeTransform/SubTransform/Map" in t.unique_name:
        environment = proto.components.environments.get(t.environment_id)
        self.assertEqual(
            environment.resource_hints.get('foo_urn'), b'set_on_subtransform')
        self.assertEqual(
            environment.resource_hints.get('bar_urn'), b'set_on_composite')
        self.assertEqual(
            environment.resource_hints.get('baz_urn'), b'set_on_composite')
        self.assertEqual(
            environment.resource_hints.get('qux_urn'), b'set_via_options')
        self.assertEqual(
            environment.resource_hints.get('use_max_value_urn'), b'100')
        found = True
    assert found

  def test_environments_with_same_resource_hints_are_reused(self):
    class HintX(ResourceHint):
      urn = 'X_urn'

    class HintY(ResourceHint):
      urn = 'Y_urn'

    class HintIsOdd(ResourceHint):
      urn = 'IsOdd_urn'

    ResourceHint.register_resource_hint('X', HintX)
    ResourceHint.register_resource_hint('Y', HintY)
    ResourceHint.register_resource_hint('IsOdd', HintIsOdd)

    p = beam.Pipeline()
    num_iter = 4
    for i in range(num_iter):
      _ = (
          p
          | f'NoHintCreate_{i}' >> beam.Create([1, 2])
          | f'NoHint_{i}' >> beam.Map(lambda x: x + 1))
      _ = (
          p
          | f'XCreate_{i}' >> beam.Create([1, 2])
          |
          f'HintX_{i}' >> beam.Map(lambda x: x + 1).with_resource_hints(X='X'))
      _ = (
          p
          | f'XYCreate_{i}' >> beam.Create([1, 2])
          | f'HintXY_{i}' >> beam.Map(lambda x: x + 1).with_resource_hints(
              X='X', Y='Y'))
      _ = (
          p
          | f'IsOddCreate_{i}' >> beam.Create([1, 2])
          | f'IsOdd_{i}' >>
          beam.Map(lambda x: x + 1).with_resource_hints(IsOdd=str(i % 2 != 0)))

    proto = Pipeline.to_runner_api(p)
    count_x = count_xy = count_is_odd = count_no_hints = 0
    env_ids = set()
    for _, t in proto.components.transforms.items():
      env = proto.components.environments[t.environment_id]
      if t.unique_name.startswith('HintX_'):
        count_x += 1
        env_ids.add(t.environment_id)
        self.assertEqual(env.resource_hints, {'X_urn': b'X'})

      if t.unique_name.startswith('HintXY_'):
        count_xy += 1
        env_ids.add(t.environment_id)
        self.assertEqual(env.resource_hints, {'X_urn': b'X', 'Y_urn': b'Y'})

      if t.unique_name.startswith('NoHint_'):
        count_no_hints += 1
        env_ids.add(t.environment_id)
        self.assertEqual(env.resource_hints, {})

      if t.unique_name.startswith('IsOdd_'):
        count_is_odd += 1
        env_ids.add(t.environment_id)
        self.assertTrue(
            env.resource_hints == {'IsOdd_urn': b'True'} or
            env.resource_hints == {'IsOdd_urn': b'False'})
    assert count_x == count_is_odd == count_xy == count_no_hints == num_iter
    assert num_iter > 1

    self.assertEqual(len(env_ids), 5)

  def test_multiple_application_of_the_same_transform_set_different_hints(self):
    class FooHint(ResourceHint):
      urn = 'foo_urn'

    class UseMaxValueHint(ResourceHint):
      urn = 'use_max_value_urn'

      @classmethod
      def get_merged_value(
          cls, outer_value, inner_value):  # type: (bytes, bytes) -> bytes
        return ResourceHint._use_max(outer_value, inner_value)

    ResourceHint.register_resource_hint('foo_hint', FooHint)
    ResourceHint.register_resource_hint('use_max_value_hint', UseMaxValueHint)

    @beam.ptransform_fn
    def SubTransform(pcoll):
      return pcoll | beam.Map(lambda x: x + 1)

    @beam.ptransform_fn
    def CompositeTransform(pcoll):
      sub = SubTransform()
      return (
          pcoll
          | 'first' >> sub.with_resource_hints(foo_hint='first_application')
          | 'second' >> sub.with_resource_hints(foo_hint='second_application'))

    p = beam.Pipeline()
    _ = (p | beam.Create([1, 2]) | CompositeTransform())
    proto = Pipeline.to_runner_api(p)
    count = 0
    for t in proto.components.transforms.values():
      if "CompositeTransform/first/Map" in t.unique_name:
        environment = proto.components.environments.get(t.environment_id)
        self.assertEqual(
            b'first_application', environment.resource_hints.get('foo_urn'))
        count += 1
      if "CompositeTransform/second/Map" in t.unique_name:
        environment = proto.components.environments.get(t.environment_id)
        self.assertEqual(
            b'second_application', environment.resource_hints.get('foo_urn'))
        count += 1
    assert count == 2

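  # Note: merge_compatible_environments (exercised below) collapses
  # environments that are fully equivalent: identical dependency sets
  # (matched by artifact hash and staged name) and identical remaining
  # fields such as resource hints.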
  def test_environments_are_deduplicated(self):
    def file_artifact(path, hash, staged_name):
      return beam_runner_api_pb2.ArtifactInformation(
          type_urn=common_urns.artifact_types.FILE.urn,
          type_payload=beam_runner_api_pb2.ArtifactFilePayload(
              path=path, sha256=hash).SerializeToString(),
          role_urn=common_urns.artifact_roles.STAGING_TO.urn,
          role_payload=beam_runner_api_pb2.ArtifactStagingToRolePayload(
              staged_name=staged_name).SerializeToString(),
      )

    proto = beam_runner_api_pb2.Pipeline(
        components=beam_runner_api_pb2.Components(
            transforms={
                f'transform{ix}': beam_runner_api_pb2.PTransform(
                    environment_id=f'e{ix}')
                for ix in range(8)
            },
            environments={
                # Same hash and destination.
                'e1': beam_runner_api_pb2.Environment(
                    dependencies=[file_artifact('a1', 'x', 'dest')]),
                'e2': beam_runner_api_pb2.Environment(
                    dependencies=[file_artifact('a2', 'x', 'dest')]),
                # Different hash.
                'e3': beam_runner_api_pb2.Environment(
                    dependencies=[file_artifact('a3', 'y', 'dest')]),
                # Different destination.
                'e4': beam_runner_api_pb2.Environment(
                    dependencies=[file_artifact('a4', 'y', 'dest2')]),
                # Multiple files with same hash and destinations.
                'e5': beam_runner_api_pb2.Environment(
                    dependencies=[
                        file_artifact('a1', 'x', 'dest'),
                        file_artifact('b1', 'xb', 'destB')
                    ]),
                'e6': beam_runner_api_pb2.Environment(
                    dependencies=[
                        file_artifact('a2', 'x', 'dest'),
                        file_artifact('b2', 'xb', 'destB')
                    ]),
                # Overlapping, but not identical, files.
                'e7': beam_runner_api_pb2.Environment(
                    dependencies=[
                        file_artifact('a1', 'x', 'dest'),
                        file_artifact('b2', 'y', 'destB')
                    ]),
                # Same files as first, but differing other properties.
                'e0': beam_runner_api_pb2.Environment(
                    resource_hints={'hint': b'value'},
                    dependencies=[file_artifact('a1', 'x', 'dest')]),
            }))
    Pipeline.merge_compatible_environments(proto)

    # These environments are equivalent.
    self.assertEqual(
        proto.components.transforms['transform1'].environment_id,
        proto.components.transforms['transform2'].environment_id)

    self.assertEqual(
        proto.components.transforms['transform5'].environment_id,
        proto.components.transforms['transform6'].environment_id)

    # These are not.
    self.assertNotEqual(
        proto.components.transforms['transform1'].environment_id,
        proto.components.transforms['transform3'].environment_id)
    self.assertNotEqual(
        proto.components.transforms['transform4'].environment_id,
        proto.components.transforms['transform3'].environment_id)
    self.assertNotEqual(
        proto.components.transforms['transform6'].environment_id,
        proto.components.transforms['transform7'].environment_id)
    self.assertNotEqual(
        proto.components.transforms['transform1'].environment_id,
        proto.components.transforms['transform0'].environment_id)

    self.assertEqual(len(proto.components.environments), 6)


if __name__ == '__main__':
  unittest.main()