github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/dataflow/internal/apiclient_test.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for the apiclient module."""

# pytype: skip-file

import itertools
import json
import logging
import os
import sys
import unittest

import mock

from apache_beam.io.filesystems import FileSystems
from apache_beam.metrics.cells import DistributionData
from apache_beam.options.pipeline_options import GoogleCloudOptions
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.pipeline import Pipeline
from apache_beam.portability import common_urns
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.dataflow.internal import names
from apache_beam.runners.dataflow.internal.clients import dataflow
from apache_beam.transforms import Create
from apache_beam.transforms import DataflowDistributionCounter
from apache_beam.transforms import DoFn
from apache_beam.transforms import ParDo
from apache_beam.transforms.environments import DockerEnvironment

# Protect against environments where apitools library is not available.
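# apiclient is part of the optional GCP extra, so this import can fail in a
# minimal installation; in that case the UtilTest class below is skipped via
# the skipIf decorator.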
# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
try:
  from apache_beam.runners.dataflow.internal import apiclient
except ImportError:
  apiclient = None  # type: ignore
# pylint: enable=wrong-import-order, wrong-import-position

FAKE_PIPELINE_URL = "gs://invalid-bucket/anywhere"
_LOGGER = logging.getLogger(__name__)


@unittest.skipIf(apiclient is None, 'GCP dependencies are not installed')
class UtilTest(unittest.TestCase):
  @unittest.skip("Enable once BEAM-1080 is fixed.")
  def test_create_application_client(self):
    pipeline_options = PipelineOptions()
    apiclient.DataflowApplicationClient(pipeline_options)

  def test_pipeline_url(self):
    pipeline_options = PipelineOptions([
        '--subnetwork',
        '/regions/MY/subnetworks/SUBNETWORK',
        '--temp_location',
        'gs://any-location/temp'
    ])
    env = apiclient.Environment(
        [],
        pipeline_options,
        '2.0.0',  # any environment version
        FAKE_PIPELINE_URL)

    recovered_options = None
    for additionalProperty in env.proto.sdkPipelineOptions.additionalProperties:
      if additionalProperty.key == 'options':
        recovered_options = additionalProperty.value
        break
    else:
      self.fail(
          'No pipeline options found in %s' % env.proto.sdkPipelineOptions)

    pipeline_url = None
    for property in recovered_options.object_value.properties:
      if property.key == 'pipelineUrl':
        pipeline_url = property.value
        break
    else:
      self.fail('No pipeline_url found in %s' % recovered_options)

    self.assertEqual(pipeline_url.string_value, FAKE_PIPELINE_URL)

  def test_set_network(self):
    pipeline_options = PipelineOptions([
        '--network',
        'anetworkname',
        '--temp_location',
        'gs://any-location/temp'
    ])
    env = apiclient.Environment(
        [],  #packages
        pipeline_options,
        '2.0.0',  #any environment version
        FAKE_PIPELINE_URL)
    self.assertEqual(env.proto.workerPools[0].network, 'anetworkname')

  def test_set_subnetwork(self):
    pipeline_options = PipelineOptions([
        '--subnetwork',
        '/regions/MY/subnetworks/SUBNETWORK',
        '--temp_location',
        'gs://any-location/temp'
    ])

    env = apiclient.Environment(
        [],  #packages
        pipeline_options,
        '2.0.0',  #any environment version
        FAKE_PIPELINE_URL)
    self.assertEqual(
        env.proto.workerPools[0].subnetwork,
        '/regions/MY/subnetworks/SUBNETWORK')

  def test_flexrs_blank(self):
    pipeline_options = PipelineOptions(
        ['--temp_location', 'gs://any-location/temp'])

    env = apiclient.Environment(
        [],  #packages
        pipeline_options,
        '2.0.0',  #any environment version
        FAKE_PIPELINE_URL)
    self.assertEqual(env.proto.flexResourceSchedulingGoal, None)

  def test_flexrs_cost(self):
    pipeline_options = PipelineOptions([
        '--flexrs_goal',
        'COST_OPTIMIZED',
        '--temp_location',
        'gs://any-location/temp'
    ])

    env = apiclient.Environment(
        [],  #packages
        pipeline_options,
        '2.0.0',  #any environment version
        FAKE_PIPELINE_URL)
    self.assertEqual(
        env.proto.flexResourceSchedulingGoal,
        (
            dataflow.Environment.FlexResourceSchedulingGoalValueValuesEnum.
            FLEXRS_COST_OPTIMIZED))

  def test_flexrs_speed(self):
    pipeline_options = PipelineOptions([
        '--flexrs_goal',
        'SPEED_OPTIMIZED',
        '--temp_location',
        'gs://any-location/temp'
    ])

    env = apiclient.Environment(
        [],  #packages
        pipeline_options,
        '2.0.0',  #any environment version
        FAKE_PIPELINE_URL)
    self.assertEqual(
        env.proto.flexResourceSchedulingGoal,
        (
            dataflow.Environment.FlexResourceSchedulingGoalValueValuesEnum.
            FLEXRS_SPEED_OPTIMIZED))

  def _verify_sdk_harness_container_images_get_set(self, pipeline_options):
    pipeline = Pipeline(options=pipeline_options)
    pipeline | Create([1, 2, 3]) | ParDo(DoFn())  # pylint:disable=expression-not-assigned

    test_environment = DockerEnvironment(container_image='test_default_image')
    proto_pipeline, _ = pipeline.to_runner_api(
        return_context=True, default_environment=test_environment)

    dummy_env = beam_runner_api_pb2.Environment(
        urn=common_urns.environments.DOCKER.urn,
        payload=(
            beam_runner_api_pb2.DockerPayload(
                container_image='dummy_image')).SerializeToString())
    dummy_env.capabilities.append(
        common_urns.protocols.MULTI_CORE_BUNDLE_PROCESSING.urn)
    proto_pipeline.components.environments['dummy_env_id'].CopyFrom(dummy_env)

    dummy_transform = beam_runner_api_pb2.PTransform(
        environment_id='dummy_env_id')
    proto_pipeline.components.transforms['dummy_transform_id'].CopyFrom(
        dummy_transform)

    env = apiclient.Environment(
        [],  # packages
        pipeline_options,
        '2.0.0',  # any environment version
        FAKE_PIPELINE_URL,
        proto_pipeline)
    worker_pool = env.proto.workerPools[0]

    self.assertEqual(2, len(worker_pool.sdkHarnessContainerImages))
    # Only one of the environments is missing MULTI_CORE_BUNDLE_PROCESSING.
    self.assertEqual(
        1,
        sum(
            c.useSingleCorePerContainer
            for c in worker_pool.sdkHarnessContainerImages))

    env_and_image = [(item.environmentId, item.containerImage)
                     for item in worker_pool.sdkHarnessContainerImages]
    self.assertIn(('dummy_env_id', 'dummy_image'), env_and_image)
    self.assertIn((mock.ANY, 'test_default_image'), env_and_image)

  def test_sdk_harness_container_images_get_set_runner_v2(self):
    pipeline_options = PipelineOptions([
        '--experiments=use_runner_v2',
        '--temp_location',
        'gs://any-location/temp'
    ])

    self._verify_sdk_harness_container_images_get_set(pipeline_options)

  def test_sdk_harness_container_images_get_set_prime(self):
    pipeline_options = PipelineOptions([
        '--dataflow_service_options=enable_prime',
        '--temp_location',
        'gs://any-location/temp'
    ])

    self._verify_sdk_harness_container_images_get_set(pipeline_options)

  def _verify_sdk_harness_container_image_overrides(self, pipeline_options):
    test_environment = DockerEnvironment(
        container_image='dummy_container_image')
    proto_pipeline, _ = Pipeline().to_runner_api(
        return_context=True, default_environment=test_environment)

    # Accessing non-public method for testing.
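    # The override map is keyed by regular expressions that are matched
    # against the container image names in the pipeline proto; environments
    # with a matching image get the mapped replacement, patterns that match
    # nothing are ignored.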
    apiclient.DataflowApplicationClient._apply_sdk_environment_overrides(
        proto_pipeline,
        {
            '.*dummy.*': 'new_dummy_container_image',
            '.*notfound.*': 'new_dummy_container_image_2'
        },
        pipeline_options)

    self.assertIsNotNone(1, len(proto_pipeline.components.environments))
    env = list(proto_pipeline.components.environments.values())[0]

    from apache_beam.utils import proto_utils
    docker_payload = proto_utils.parse_Bytes(
        env.payload, beam_runner_api_pb2.DockerPayload)

    # Container image should be overridden by the given override.
    self.assertEqual(
        docker_payload.container_image, 'new_dummy_container_image')

  def test_sdk_harness_container_image_overrides_runner_v2(self):
    pipeline_options = PipelineOptions([
        '--experiments=use_runner_v2',
        '--temp_location',
        'gs://any-location/temp'
    ])

    self._verify_sdk_harness_container_image_overrides(pipeline_options)

  def test_sdk_harness_container_image_overrides_prime(self):
    pipeline_options = PipelineOptions([
        '--dataflow_service_options=enable_prime',
        '--temp_location',
        'gs://any-location/temp'
    ])

    self._verify_sdk_harness_container_image_overrides(pipeline_options)

  def _verify_dataflow_container_image_override(self, pipeline_options):
    pipeline = Pipeline(options=pipeline_options)
    pipeline | Create([1, 2, 3]) | ParDo(DoFn())  # pylint:disable=expression-not-assigned

    dummy_env = DockerEnvironment(
        container_image='apache/beam_dummy_name:dummy_tag')
    proto_pipeline, _ = pipeline.to_runner_api(
        return_context=True, default_environment=dummy_env)

    # Accessing non-public method for testing.
    apiclient.DataflowApplicationClient._apply_sdk_environment_overrides(
        proto_pipeline, {}, pipeline_options)

    from apache_beam.utils import proto_utils
    found_override = False
    for env in proto_pipeline.components.environments.values():
      docker_payload = proto_utils.parse_Bytes(
          env.payload, beam_runner_api_pb2.DockerPayload)
      if docker_payload.container_image.startswith(
          names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY):
        found_override = True

    self.assertTrue(found_override)

  def test_dataflow_container_image_override_runner_v2(self):
    pipeline_options = PipelineOptions([
        '--experiments=use_runner_v2',
        '--temp_location',
        'gs://any-location/temp'
    ])

    self._verify_dataflow_container_image_override(pipeline_options)

  def test_dataflow_container_image_override_prime(self):
    pipeline_options = PipelineOptions([
        '--dataflow_service_options=enable_prime',
        '--temp_location',
        'gs://any-location/temp'
    ])

    self._verify_dataflow_container_image_override(pipeline_options)

  def _verify_non_apache_container_not_overridden(self, pipeline_options):
    pipeline = Pipeline(options=pipeline_options)
    pipeline | Create([1, 2, 3]) | ParDo(DoFn())  # pylint:disable=expression-not-assigned

    dummy_env = DockerEnvironment(
        container_image='other_org/dummy_name:dummy_tag')
    proto_pipeline, _ = pipeline.to_runner_api(
        return_context=True, default_environment=dummy_env)

    # Accessing non-public method for testing.
    apiclient.DataflowApplicationClient._apply_sdk_environment_overrides(
        proto_pipeline, {}, pipeline_options)

    self.assertIsNotNone(2, len(proto_pipeline.components.environments))

    from apache_beam.utils import proto_utils
    found_override = False
    for env in proto_pipeline.components.environments.values():
      docker_payload = proto_utils.parse_Bytes(
          env.payload, beam_runner_api_pb2.DockerPayload)
      if docker_payload.container_image.startswith(
          names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY):
        found_override = True

    self.assertFalse(found_override)

  def test_non_apache_container_not_overridden_runner_v2(self):
    pipeline_options = PipelineOptions([
        '--experiments=use_runner_v2',
        '--temp_location',
        'gs://any-location/temp'
    ])

    self._verify_non_apache_container_not_overridden(pipeline_options)

  def test_non_apache_container_not_overridden_prime(self):
    pipeline_options = PipelineOptions([
        '--dataflow_service_options=enable_prime',
        '--temp_location',
        'gs://any-location/temp'
    ])

    self._verify_non_apache_container_not_overridden(pipeline_options)

  def _verify_pipeline_sdk_not_overridden(self, pipeline_options):
    pipeline = Pipeline(options=pipeline_options)
    pipeline | Create([1, 2, 3]) | ParDo(DoFn())  # pylint:disable=expression-not-assigned

    proto_pipeline, _ = pipeline.to_runner_api(return_context=True)

    dummy_env = DockerEnvironment(
        container_image='dummy_prefix/dummy_name:dummy_tag')
    proto_pipeline, _ = pipeline.to_runner_api(
        return_context=True, default_environment=dummy_env)

    # Accessing non-public method for testing.
    apiclient.DataflowApplicationClient._apply_sdk_environment_overrides(
        proto_pipeline, {}, pipeline_options)

    self.assertIsNotNone(2, len(proto_pipeline.components.environments))

    from apache_beam.utils import proto_utils
    found_override = False
    for env in proto_pipeline.components.environments.values():
      docker_payload = proto_utils.parse_Bytes(
          env.payload, beam_runner_api_pb2.DockerPayload)
      if docker_payload.container_image.startswith(
          names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY):
        found_override = True

    self.assertFalse(found_override)

  def test_pipeline_sdk_not_overridden_runner_v2(self):
    pipeline_options = PipelineOptions([
        '--experiments=use_runner_v2',
        '--temp_location',
        'gs://any-location/temp',
        '--sdk_container_image=dummy_prefix/dummy_name:dummy_tag'
    ])

    self._verify_pipeline_sdk_not_overridden(pipeline_options)

  def test_pipeline_sdk_not_overridden_prime(self):
    pipeline_options = PipelineOptions([
        '--dataflow_service_options=enable_prime',
        '--temp_location',
        'gs://any-location/temp',
        '--sdk_container_image=dummy_prefix/dummy_name:dummy_tag'
    ])

    self._verify_pipeline_sdk_not_overridden(pipeline_options)

  def test_invalid_default_job_name(self):
    # Regexp for job names in dataflow.
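    # Names must start with a lowercase letter, contain only lowercase
    # letters, digits and hyphens, end with a letter or digit, and be at
    # most 63 characters long.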
    regexp = '^[a-z]([-a-z0-9]{0,61}[a-z0-9])?$'

    job_name = apiclient.Job._build_default_job_name('invalid.-_user_n*/ame')
    self.assertRegex(job_name, regexp)

    job_name = apiclient.Job._build_default_job_name(
        'invalid-extremely-long.username_that_shouldbeshortened_or_is_invalid')
    self.assertRegex(job_name, regexp)

  def test_default_job_name(self):
    job_name = apiclient.Job.default_job_name(None)
    regexp = 'beamapp-.*-[0-9]{10}-[0-9]{6}-[a-z0-9]{8}$'
    self.assertRegex(job_name, regexp)

  def test_split_int(self):
    number = 12345
    split_number = apiclient.to_split_int(number)
    self.assertEqual((split_number.lowBits, split_number.highBits), (number, 0))
    shift_number = number << 32
    split_number = apiclient.to_split_int(shift_number)
    self.assertEqual((split_number.lowBits, split_number.highBits), (0, number))

  def test_translate_distribution_using_accumulator(self):
    metric_update = dataflow.CounterUpdate()
    accumulator = mock.Mock()
    accumulator.min = 1
    accumulator.max = 15
    accumulator.sum = 16
    accumulator.count = 2
    apiclient.translate_distribution(accumulator, metric_update)
    self.assertEqual(metric_update.distribution.min.lowBits, accumulator.min)
    self.assertEqual(metric_update.distribution.max.lowBits, accumulator.max)
    self.assertEqual(metric_update.distribution.sum.lowBits, accumulator.sum)
    self.assertEqual(
        metric_update.distribution.count.lowBits, accumulator.count)

  def test_translate_distribution_using_distribution_data(self):
    metric_update = dataflow.CounterUpdate()
    distribution_update = DistributionData(16, 2, 1, 15)
    apiclient.translate_distribution(distribution_update, metric_update)
    self.assertEqual(
        metric_update.distribution.min.lowBits, distribution_update.min)
    self.assertEqual(
        metric_update.distribution.max.lowBits, distribution_update.max)
    self.assertEqual(
        metric_update.distribution.sum.lowBits, distribution_update.sum)
    self.assertEqual(
        metric_update.distribution.count.lowBits, distribution_update.count)

  def test_translate_distribution_using_dataflow_distribution_counter(self):
    counter_update = DataflowDistributionCounter()
    counter_update.add_input(1)
    counter_update.add_input(3)
    metric_proto = dataflow.CounterUpdate()
    apiclient.translate_distribution(counter_update, metric_proto)
    histogram = mock.Mock(firstBucketOffset=None, bucketCounts=None)
    counter_update.translate_to_histogram(histogram)
    self.assertEqual(metric_proto.distribution.min.lowBits, counter_update.min)
    self.assertEqual(metric_proto.distribution.max.lowBits, counter_update.max)
    self.assertEqual(metric_proto.distribution.sum.lowBits, counter_update.sum)
    self.assertEqual(
        metric_proto.distribution.count.lowBits, counter_update.count)
    self.assertEqual(
        metric_proto.distribution.histogram.bucketCounts,
        histogram.bucketCounts)
    self.assertEqual(
        metric_proto.distribution.histogram.firstBucketOffset,
        histogram.firstBucketOffset)

  def test_translate_means(self):
    metric_update = dataflow.CounterUpdate()
    accumulator = mock.Mock()
    accumulator.sum = 16
    accumulator.count = 2
    apiclient.MetricUpdateTranslators.translate_scalar_mean_int(
        accumulator, metric_update)
    self.assertEqual(metric_update.integerMean.sum.lowBits, accumulator.sum)
    self.assertEqual(metric_update.integerMean.count.lowBits, accumulator.count)

    accumulator.sum = 16.0
    accumulator.count = 2
    apiclient.MetricUpdateTranslators.translate_scalar_mean_float(
        accumulator, metric_update)
    self.assertEqual(metric_update.floatingPointMean.sum, accumulator.sum)
    self.assertEqual(
        metric_update.floatingPointMean.count.lowBits, accumulator.count)

  def test_translate_means_using_distribution_accumulator(self):
    # This is the special case for MeanByteCount.
    # Which is reported over the FnAPI as a beam distribution,
    # and to the service as a MetricUpdate IntegerMean.
    metric_update = dataflow.CounterUpdate()
    accumulator = mock.Mock()
    accumulator.min = 7
    accumulator.max = 9
    accumulator.sum = 16
    accumulator.count = 2
    apiclient.MetricUpdateTranslators.translate_scalar_mean_int(
        accumulator, metric_update)
    self.assertEqual(metric_update.integerMean.sum.lowBits, accumulator.sum)
    self.assertEqual(metric_update.integerMean.count.lowBits, accumulator.count)

    accumulator.sum = 16.0
    accumulator.count = 2
    apiclient.MetricUpdateTranslators.translate_scalar_mean_float(
        accumulator, metric_update)
    self.assertEqual(metric_update.floatingPointMean.sum, accumulator.sum)
    self.assertEqual(
        metric_update.floatingPointMean.count.lowBits, accumulator.count)

  def test_default_ip_configuration(self):
    pipeline_options = PipelineOptions(
        ['--temp_location', 'gs://any-location/temp'])
    env = apiclient.Environment([],
                                pipeline_options,
                                '2.0.0',
                                FAKE_PIPELINE_URL)
    self.assertEqual(env.proto.workerPools[0].ipConfiguration, None)

  def test_public_ip_configuration(self):
    pipeline_options = PipelineOptions(
        ['--temp_location', 'gs://any-location/temp', '--use_public_ips'])
    env = apiclient.Environment([],
                                pipeline_options,
                                '2.0.0',
                                FAKE_PIPELINE_URL)
    self.assertEqual(
        env.proto.workerPools[0].ipConfiguration,
        dataflow.WorkerPool.IpConfigurationValueValuesEnum.WORKER_IP_PUBLIC)

  def test_private_ip_configuration(self):
    pipeline_options = PipelineOptions(
        ['--temp_location', 'gs://any-location/temp', '--no_use_public_ips'])
    env = apiclient.Environment([],
                                pipeline_options,
                                '2.0.0',
                                FAKE_PIPELINE_URL)
    self.assertEqual(
        env.proto.workerPools[0].ipConfiguration,
        dataflow.WorkerPool.IpConfigurationValueValuesEnum.WORKER_IP_PRIVATE)

  def test_number_of_worker_harness_threads(self):
    pipeline_options = PipelineOptions([
        '--temp_location',
        'gs://any-location/temp',
        '--number_of_worker_harness_threads',
        '2'
    ])
    env = apiclient.Environment([],
                                pipeline_options,
                                '2.0.0',
                                FAKE_PIPELINE_URL)
    self.assertEqual(env.proto.workerPools[0].numThreadsPerWorker, 2)

  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.'
      'beam_version.__version__',
      '2.2.0')
  def test_harness_override_absent_with_runner_v2(self):
    pipeline_options = PipelineOptions([
        '--temp_location',
        'gs://any-location/temp',
        '--streaming',
        '--experiments=use_runner_v2'
    ])
    env = apiclient.Environment(
        [],  #packages
        pipeline_options,
        '2.0.0',  #any environment version
        FAKE_PIPELINE_URL)
    if env.proto.experiments:
      for experiment in env.proto.experiments:
        self.assertNotIn('runner_harness_container_image=', experiment)

  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.'
      'beam_version.__version__',
      '2.2.0')
  def test_custom_harness_override_present_with_runner_v2(self):
    pipeline_options = PipelineOptions([
        '--temp_location',
        'gs://any-location/temp',
        '--streaming',
        '--experiments=runner_harness_container_image=fake_image',
        '--experiments=use_runner_v2',
    ])
    env = apiclient.Environment(
        [],  #packages
        pipeline_options,
        '2.0.0',  #any environment version
        FAKE_PIPELINE_URL)
    self.assertEqual(
        1,
        len([
            x for x in env.proto.experiments
            if x.startswith('runner_harness_container_image=')
        ]))
    self.assertIn(
        'runner_harness_container_image=fake_image', env.proto.experiments)

  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.'
      'beam_version.__version__',
      '2.2.0.dev')
  def test_pinned_worker_harness_image_tag_used_in_dev_sdk(self):
    # streaming, fnapi pipeline.
    pipeline_options = PipelineOptions(
        ['--temp_location', 'gs://any-location/temp', '--streaming'])
    env = apiclient.Environment(
        [],  #packages
        pipeline_options,
        '2.0.0',  #any environment version
        FAKE_PIPELINE_URL)
    self.assertEqual(
        env.proto.workerPools[0].workerHarnessContainerImage,
        (
            names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY +
            '/beam_python%d.%d_sdk:%s' % (
                sys.version_info[0],
                sys.version_info[1],
                names.BEAM_FNAPI_CONTAINER_VERSION)))

    # batch, legacy pipeline.
    pipeline_options = PipelineOptions(
        ['--temp_location', 'gs://any-location/temp'])
    env = apiclient.Environment(
        [],  #packages
        pipeline_options,
        '2.0.0',  #any environment version
        FAKE_PIPELINE_URL)
    self.assertEqual(
        env.proto.workerPools[0].workerHarnessContainerImage,
        (
            names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY + '/python%d%d:%s' % (
                sys.version_info[0],
                sys.version_info[1],
                names.BEAM_CONTAINER_VERSION)))

  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.'
      'beam_version.__version__',
      '2.2.0')
  def test_worker_harness_image_tag_matches_released_sdk_version(self):
    # streaming, fnapi pipeline.
    pipeline_options = PipelineOptions(
        ['--temp_location', 'gs://any-location/temp', '--streaming'])
    env = apiclient.Environment(
        [],  #packages
        pipeline_options,
        '2.0.0',  #any environment version
        FAKE_PIPELINE_URL)
    self.assertEqual(
        env.proto.workerPools[0].workerHarnessContainerImage,
        (
            names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY +
            '/beam_python%d.%d_sdk:2.2.0' %
            (sys.version_info[0], sys.version_info[1])))

    # batch, legacy pipeline.
    pipeline_options = PipelineOptions(
        ['--temp_location', 'gs://any-location/temp'])
    env = apiclient.Environment(
        [],  #packages
        pipeline_options,
        '2.0.0',  #any environment version
        FAKE_PIPELINE_URL)
    self.assertEqual(
        env.proto.workerPools[0].workerHarnessContainerImage,
        (
            names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY + '/python%d%d:2.2.0' %
            (sys.version_info[0], sys.version_info[1])))

  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.'
      'beam_version.__version__',
      '2.2.0.rc1')
  def test_worker_harness_image_tag_matches_base_sdk_version_of_an_rc(self):
    # streaming, fnapi pipeline.
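    # For the RC build patched above (2.2.0.rc1), the container tag is
    # expected to be the base release version, 2.2.0.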
    pipeline_options = PipelineOptions(
        ['--temp_location', 'gs://any-location/temp', '--streaming'])
    env = apiclient.Environment(
        [],  #packages
        pipeline_options,
        '2.0.0',  #any environment version
        FAKE_PIPELINE_URL)
    self.assertEqual(
        env.proto.workerPools[0].workerHarnessContainerImage,
        (
            names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY +
            '/beam_python%d.%d_sdk:2.2.0' %
            (sys.version_info[0], sys.version_info[1])))

    # batch, legacy pipeline.
    pipeline_options = PipelineOptions(
        ['--temp_location', 'gs://any-location/temp'])
    env = apiclient.Environment(
        [],  #packages
        pipeline_options,
        '2.0.0',  #any environment version
        FAKE_PIPELINE_URL)
    self.assertEqual(
        env.proto.workerPools[0].workerHarnessContainerImage,
        (
            names.DATAFLOW_CONTAINER_IMAGE_REPOSITORY + '/python%d%d:2.2.0' %
            (sys.version_info[0], sys.version_info[1])))

  def test_worker_harness_override_takes_precedence_over_sdk_defaults(self):
    # streaming, fnapi pipeline.
    pipeline_options = PipelineOptions([
        '--temp_location',
        'gs://any-location/temp',
        '--streaming',
        '--sdk_container_image=some:image'
    ])
    env = apiclient.Environment(
        [],  #packages
        pipeline_options,
        '2.0.0',  #any environment version
        FAKE_PIPELINE_URL)
    self.assertEqual(
        env.proto.workerPools[0].workerHarnessContainerImage, 'some:image')
    # batch, legacy pipeline.
    pipeline_options = PipelineOptions([
        '--temp_location',
        'gs://any-location/temp',
        '--sdk_container_image=some:image'
    ])
    env = apiclient.Environment(
        [],  #packages
        pipeline_options,
        '2.0.0',  #any environment version
        FAKE_PIPELINE_URL)
    self.assertEqual(
        env.proto.workerPools[0].workerHarnessContainerImage, 'some:image')

  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.Job.'
      'job_id_for_name',
      return_value='test_id')
  def test_transform_name_mapping(self, mock_job):
    pipeline_options = PipelineOptions([
        '--project',
        'test_project',
        '--job_name',
        'test_job_name',
        '--temp_location',
        'gs://test-location/temp',
        '--update',
        '--transform_name_mapping',
        '{\"from\":\"to\"}'
    ])
    job = apiclient.Job(pipeline_options, beam_runner_api_pb2.Pipeline())
    self.assertIsNotNone(job.proto.transformNameMapping)

  def test_created_from_snapshot_id(self):
    pipeline_options = PipelineOptions([
        '--project',
        'test_project',
        '--job_name',
        'test_job_name',
        '--temp_location',
        'gs://test-location/temp',
        '--create_from_snapshot',
        'test_snapshot_id'
    ])
    job = apiclient.Job(pipeline_options, beam_runner_api_pb2.Pipeline())
    self.assertEqual('test_snapshot_id', job.proto.createdFromSnapshotId)

  def test_labels(self):
    pipeline_options = PipelineOptions([
        '--project',
        'test_project',
        '--job_name',
        'test_job_name',
        '--temp_location',
        'gs://test-location/temp'
    ])
    job = apiclient.Job(pipeline_options, beam_runner_api_pb2.Pipeline())
    self.assertIsNone(job.proto.labels)

    pipeline_options = PipelineOptions([
        '--project',
        'test_project',
        '--job_name',
        'test_job_name',
        '--temp_location',
        'gs://test-location/temp',
        '--label',
        'key1=value1',
        '--label',
        'key2',
        '--label',
        'key3=value3',
        '--labels',
        'key4=value4',
        '--labels',
        'key5'
    ])
    job = apiclient.Job(pipeline_options, beam_runner_api_pb2.Pipeline())
    self.assertEqual(5, len(job.proto.labels.additionalProperties))
    self.assertEqual('key1', job.proto.labels.additionalProperties[0].key)
    self.assertEqual('value1', job.proto.labels.additionalProperties[0].value)
    self.assertEqual('key2', job.proto.labels.additionalProperties[1].key)
    self.assertEqual('', job.proto.labels.additionalProperties[1].value)
    self.assertEqual('key3', job.proto.labels.additionalProperties[2].key)
    self.assertEqual('value3', job.proto.labels.additionalProperties[2].value)
    self.assertEqual('key4', job.proto.labels.additionalProperties[3].key)
    self.assertEqual('value4', job.proto.labels.additionalProperties[3].value)
    self.assertEqual('key5', job.proto.labels.additionalProperties[4].key)
    self.assertEqual('', job.proto.labels.additionalProperties[4].value)

    pipeline_options = PipelineOptions([
        '--project',
        'test_project',
        '--job_name',
        'test_job_name',
        '--temp_location',
        'gs://test-location/temp',
        '--labels',
        '{ "name": "wrench", "mass": "1_3kg", "count": "3" }'
    ])
    job = apiclient.Job(pipeline_options, beam_runner_api_pb2.Pipeline())
    self.assertEqual(3, len(job.proto.labels.additionalProperties))
    self.assertEqual('name', job.proto.labels.additionalProperties[0].key)
    self.assertEqual('wrench', job.proto.labels.additionalProperties[0].value)
    self.assertEqual('mass', job.proto.labels.additionalProperties[1].key)
    self.assertEqual('1_3kg', job.proto.labels.additionalProperties[1].value)
    self.assertEqual('count', job.proto.labels.additionalProperties[2].key)
    self.assertEqual('3', job.proto.labels.additionalProperties[2].value)

  def test_experiment_use_multiple_sdk_containers(self):
    pipeline_options = PipelineOptions([
        '--project',
        'test_project',
        '--job_name',
        'test_job_name',
        '--temp_location',
        'gs://test-location/temp',
        '--experiments',
        'beam_fn_api'
    ])
    environment = apiclient.Environment([],
                                        pipeline_options,
                                        1,
                                        FAKE_PIPELINE_URL)
    self.assertIn('use_multiple_sdk_containers', environment.proto.experiments)

    pipeline_options = PipelineOptions([
        '--project',
        'test_project',
        '--job_name',
        'test_job_name',
        '--temp_location',
        'gs://test-location/temp',
        '--experiments',
        'beam_fn_api',
        '--experiments',
        'use_multiple_sdk_containers'
    ])
    environment = apiclient.Environment([],
                                        pipeline_options,
                                        1,
                                        FAKE_PIPELINE_URL)
    self.assertIn('use_multiple_sdk_containers', environment.proto.experiments)

    pipeline_options = PipelineOptions([
        '--project',
        'test_project',
        '--job_name',
        'test_job_name',
        '--temp_location',
        'gs://test-location/temp',
        '--experiments',
        'beam_fn_api',
        '--experiments',
        'no_use_multiple_sdk_containers'
    ])
    environment = apiclient.Environment([],
                                        pipeline_options,
                                        1,
                                        FAKE_PIPELINE_URL)
    self.assertNotIn(
        'use_multiple_sdk_containers', environment.proto.experiments)

  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.sys.version_info',
      (3, 8))
  def test_get_python_sdk_name(self):
    pipeline_options = PipelineOptions([
        '--project',
        'test_project',
        '--job_name',
        'test_job_name',
        '--temp_location',
        'gs://test-location/temp',
        '--experiments',
        'beam_fn_api',
        '--experiments',
        'use_multiple_sdk_containers'
    ])
    environment = apiclient.Environment([],
                                        pipeline_options,
                                        1,
                                        FAKE_PIPELINE_URL)
    self.assertEqual(
        'Apache Beam Python 3.8 SDK', environment._get_python_sdk_name())

  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.sys.version_info',
      (2, 7))
  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.'
      'beam_version.__version__',
      '2.2.0')
  def test_interpreter_version_check_fails_py27(self):
    pipeline_options = PipelineOptions([])
    self.assertRaises(
        Exception,
        apiclient._verify_interpreter_version_is_supported,
        pipeline_options)

  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.sys.version_info',
      (3, 0, 0))
  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.'
      'beam_version.__version__',
      '2.2.0.dev')
  def test_interpreter_version_check_passes_on_dev_sdks(self):
    pipeline_options = PipelineOptions([])
    apiclient._verify_interpreter_version_is_supported(pipeline_options)

  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.'
      'beam_version.__version__',
      '2.2.0')
  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.sys.version_info',
      (3, 0, 0))
  def test_interpreter_version_check_passes_with_experiment(self):
    pipeline_options = PipelineOptions(
        ["--experiment=use_unsupported_python_version"])
    apiclient._verify_interpreter_version_is_supported(pipeline_options)

  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.sys.version_info',
      (3, 8, 2))
  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.'
      'beam_version.__version__',
      '2.2.0')
  def test_interpreter_version_check_passes_py38(self):
    pipeline_options = PipelineOptions([])
    apiclient._verify_interpreter_version_is_supported(pipeline_options)

  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.sys.version_info',
      (3, 12, 0))
  @mock.patch(
      'apache_beam.runners.dataflow.internal.apiclient.'
      'beam_version.__version__',
      '2.2.0')
  def test_interpreter_version_check_fails_on_not_yet_supported_version(self):
    pipeline_options = PipelineOptions([])
    self.assertRaises(
        Exception,
        apiclient._verify_interpreter_version_is_supported,
        pipeline_options)

  def test_get_response_encoding(self):
    encoding = apiclient.get_response_encoding()

    assert encoding == 'utf8'

  def test_graph_is_uploaded(self):
    pipeline_options = PipelineOptions([
        '--project',
        'test_project',
        '--job_name',
        'test_job_name',
        '--temp_location',
        'gs://test-location/temp',
        '--experiments',
        'beam_fn_api',
        '--experiments',
        'upload_graph'
    ])
    job = apiclient.Job(pipeline_options, beam_runner_api_pb2.Pipeline())
    pipeline_options.view_as(GoogleCloudOptions).no_auth = True
    client = apiclient.DataflowApplicationClient(pipeline_options)
    with mock.patch.object(client, 'stage_file', side_effect=None):
      with mock.patch.object(client, 'create_job_description',
                             side_effect=None):
        with mock.patch.object(client,
                               'submit_job_description',
                               side_effect=None):
          client.create_job(job)
          client.stage_file.assert_called_once_with(
              mock.ANY, "dataflow_graph.json", mock.ANY)
          client.create_job_description.assert_called_once()

  def test_create_job_returns_existing_job(self):
    pipeline_options = PipelineOptions([
        '--project',
        'test_project',
        '--job_name',
        'test_job_name',
        '--temp_location',
        'gs://test-location/temp',
    ])
    job = apiclient.Job(pipeline_options, beam_runner_api_pb2.Pipeline())
    self.assertTrue(job.proto.clientRequestId)  # asserts non-empty string
    pipeline_options.view_as(GoogleCloudOptions).no_auth = True
    client = apiclient.DataflowApplicationClient(pipeline_options)

    response = dataflow.Job()
    # different clientRequestId from `job`
    response.clientRequestId = "20210821081910123456-1234"
    response.name = 'test_job_name'
    response.id = '2021-08-19_21_18_43-9756917246311111021'

    with mock.patch.object(client._client.projects_locations_jobs,
                           'Create',
                           side_effect=[response]):
      with mock.patch.object(client, 'create_job_description',
                             side_effect=None):
        with self.assertRaises(
            apiclient.DataflowJobAlreadyExistsError) as context:
          client.create_job(job)

      self.assertEqual(
          str(context.exception),
          'There is already active job named %s with id: %s. If you want to '
          'submit a second job, try again by setting a different name using '
          '--job_name.' % ('test_job_name', response.id))

  def test_update_job_returns_existing_job(self):
    pipeline_options = PipelineOptions([
        '--project',
        'test_project',
        '--job_name',
        'test_job_name',
        '--temp_location',
        'gs://test-location/temp',
        '--region',
        'us-central1',
        '--update',
    ])
    replace_job_id = '2021-08-21_00_00_01-6081497447916622336'
    with mock.patch('apache_beam.runners.dataflow.internal.apiclient.Job.'
                    'job_id_for_name',
                    return_value=replace_job_id) as job_id_for_name_mock:
      job = apiclient.Job(pipeline_options, beam_runner_api_pb2.Pipeline())
      job_id_for_name_mock.assert_called_once()

    self.assertTrue(job.proto.clientRequestId)  # asserts non-empty string

    pipeline_options.view_as(GoogleCloudOptions).no_auth = True
    client = apiclient.DataflowApplicationClient(pipeline_options)

    response = dataflow.Job()
    # different clientRequestId from `job`
    response.clientRequestId = "20210821083254123456-1234"
    response.name = 'test_job_name'
    response.id = '2021-08-19_21_29_07-5725551945600207770'

    with mock.patch.object(client, 'create_job_description', side_effect=None):
      with mock.patch.object(client._client.projects_locations_jobs,
                             'Create',
                             side_effect=[response]):

        with self.assertRaises(
            apiclient.DataflowJobAlreadyExistsError) as context:
          client.create_job(job)

      self.assertEqual(
          str(context.exception),
          'The job named %s with id: %s has already been updated into job '
          'id: %s and cannot be updated again.' %
          ('test_job_name', replace_job_id, response.id))

  def test_template_file_generation_with_upload_graph(self):
    pipeline_options = PipelineOptions([
        '--project',
        'test_project',
        '--job_name',
        'test_job_name',
        '--temp_location',
        'gs://test-location/temp',
        '--experiments',
        'upload_graph',
        '--template_location',
        'gs://test-location/template'
    ])
    job = apiclient.Job(pipeline_options, beam_runner_api_pb2.Pipeline())
    job.proto.steps.append(dataflow.Step(name='test_step_name'))

    pipeline_options.view_as(GoogleCloudOptions).no_auth = True
    client = apiclient.DataflowApplicationClient(pipeline_options)
    with mock.patch.object(client, 'stage_file', side_effect=None):
      with mock.patch.object(client, 'create_job_description',
                             side_effect=None):
        with mock.patch.object(client,
                               'submit_job_description',
                               side_effect=None):
          client.create_job(job)

          client.stage_file.assert_has_calls([
              mock.call(mock.ANY, 'dataflow_graph.json', mock.ANY),
              mock.call(mock.ANY, 'template', mock.ANY)
          ])
          client.create_job_description.assert_called_once()
          # template is generated, but job should not be submitted to the
          # service.
          client.submit_job_description.assert_not_called()

          template_filename = client.stage_file.call_args_list[-1][0][1]
          self.assertTrue('template' in template_filename)
          template_content = client.stage_file.call_args_list[-1][0][2].read(
          ).decode('utf-8')
          template_obj = json.loads(template_content)
          self.assertFalse(template_obj.get('steps'))
          self.assertTrue(template_obj['stepsLocation'])

  def test_stage_resources(self):
    pipeline_options = PipelineOptions([
        '--temp_location',
        'gs://test-location/temp',
        '--staging_location',
        'gs://test-location/staging',
        '--no_auth'
    ])
    pipeline = beam_runner_api_pb2.Pipeline(
        components=beam_runner_api_pb2.Components(
            environments={
                'env1': beam_runner_api_pb2.Environment(
                    dependencies=[
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.FILE.urn,
                            type_payload=beam_runner_api_pb2.
                            ArtifactFilePayload(
                                path='/tmp/foo1').SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='foo1').SerializeToString()),
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.FILE.urn,
                            type_payload=beam_runner_api_pb2.
                            ArtifactFilePayload(
                                path='/tmp/bar1').SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='bar1').SerializeToString()),
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.FILE.urn,
                            type_payload=beam_runner_api_pb2.
                            ArtifactFilePayload(
                                path='/tmp/baz').SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='baz1').SerializeToString()),
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.FILE.urn,
                            type_payload=beam_runner_api_pb2.
                            ArtifactFilePayload(
                                path='/tmp/renamed1',
                                sha256='abcdefg').SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='renamed1').SerializeToString())
                    ]),
                'env2': beam_runner_api_pb2.Environment(
                    dependencies=[
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.FILE.urn,
                            type_payload=beam_runner_api_pb2.
                            ArtifactFilePayload(
                                path='/tmp/foo2').SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='foo2').SerializeToString()),
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.FILE.urn,
                            type_payload=beam_runner_api_pb2.
                            ArtifactFilePayload(
                                path='/tmp/bar2').SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='bar2').SerializeToString()),
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.FILE.urn,
                            type_payload=beam_runner_api_pb2.
                            ArtifactFilePayload(
                                path='/tmp/baz').SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='baz2').SerializeToString()),
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.FILE.urn,
                            type_payload=beam_runner_api_pb2.
                            ArtifactFilePayload(
                                path='/tmp/renamed2',
                                sha256='abcdefg').SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='renamed2').SerializeToString())
                    ])
            }))
    client = apiclient.DataflowApplicationClient(pipeline_options)
    with mock.patch.object(apiclient._LegacyDataflowStager,
                           'stage_job_resources') as mock_stager:
      client._stage_resources(pipeline, pipeline_options)
      mock_stager.assert_called_once_with(
          [('/tmp/foo1', 'foo1', ''), ('/tmp/bar1', 'bar1', ''),
           ('/tmp/baz', 'baz1', ''), ('/tmp/renamed1', 'renamed1', 'abcdefg'),
           ('/tmp/foo2', 'foo2', ''), ('/tmp/bar2', 'bar2', '')],
          staging_location='gs://test-location/staging')

    pipeline_expected = beam_runner_api_pb2.Pipeline(
        components=beam_runner_api_pb2.Components(
            environments={
                'env1': beam_runner_api_pb2.Environment(
                    dependencies=[
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.URL.urn,
                            type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
                                url='gs://test-location/staging/foo1'
                            ).SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='foo1').SerializeToString()),
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.URL.urn,
                            type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
                                url='gs://test-location/staging/bar1').
                            SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='bar1').SerializeToString()),
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.URL.urn,
                            type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
                                url='gs://test-location/staging/baz1').
                            SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='baz1').SerializeToString()),
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.URL.urn,
                            type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
                                url='gs://test-location/staging/renamed1',
                                sha256='abcdefg').SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='renamed1').SerializeToString())
                    ]),
                'env2': beam_runner_api_pb2.Environment(
                    dependencies=[
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.URL.urn,
                            type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
                                url='gs://test-location/staging/foo2').
                            SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='foo2').SerializeToString()),
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.URL.urn,
                            type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
                                url='gs://test-location/staging/bar2').
                            SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='bar2').SerializeToString()),
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.URL.urn,
                            type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
                                url='gs://test-location/staging/baz1').
                            SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='baz1').SerializeToString()),
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.URL.urn,
                            type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
                                url='gs://test-location/staging/renamed1',
                                sha256='abcdefg').SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='renamed1').SerializeToString())
                    ])
            }))
    self.assertEqual(pipeline, pipeline_expected)

  def test_set_dataflow_service_option(self):
    pipeline_options = PipelineOptions([
        '--dataflow_service_option',
        'whizz=bang',
        '--temp_location',
        'gs://any-location/temp'
    ])
    env = apiclient.Environment(
        [],  #packages
        pipeline_options,
        '2.0.0',  #any environment version
        FAKE_PIPELINE_URL)
    self.assertEqual(env.proto.serviceOptions, ['whizz=bang'])

  def test_enable_hot_key_logging(self):
    # Tests that the enable_hot_key_logging is not set by default.
    pipeline_options = PipelineOptions(
        ['--temp_location', 'gs://any-location/temp'])
    env = apiclient.Environment(
        [],  #packages
        pipeline_options,
        '2.0.0',  #any environment version
        FAKE_PIPELINE_URL)
    self.assertIsNone(env.proto.debugOptions)

    # Now test that it is set when given.
    pipeline_options = PipelineOptions([
        '--enable_hot_key_logging', '--temp_location', 'gs://any-location/temp'
    ])
    env = apiclient.Environment(
        [],  #packages
        pipeline_options,
        '2.0.0',  #any environment version
        FAKE_PIPELINE_URL)
    self.assertEqual(
        env.proto.debugOptions, dataflow.DebugOptions(enableHotKeyLogging=True))

  def _mock_uncached_copy(self, staging_root, src, sha256, dst_name=None):
    sha_prefix = sha256[0:2]
    gcs_cache_path = FileSystems.join(
        staging_root,
        apiclient.DataflowApplicationClient._GCS_CACHE_PREFIX,
        sha_prefix,
        sha256)

    if not dst_name:
      _, dst_name = os.path.split(src)
    return [
        mock.call.gcs_exists(gcs_cache_path),
        mock.call.gcs_upload(src, gcs_cache_path),
        mock.call.gcs_gcs_copy(
            source_file_names=[gcs_cache_path],
            destination_file_names=[f'gs://test-location/staging/{dst_name}'])
    ]

  def _mock_cached_copy(self, staging_root, src, sha256, dst_name=None):
    uncached = self._mock_uncached_copy(staging_root, src, sha256, dst_name)
    uncached.pop(1)
    return uncached

  def test_stage_artifacts_with_caching(self):
    pipeline_options = PipelineOptions([
        '--temp_location',
        'gs://test-location/temp',
        '--staging_location',
        'gs://test-location/staging',
        '--no_auth',
        '--enable_artifact_caching'
    ])
    pipeline = beam_runner_api_pb2.Pipeline(
        components=beam_runner_api_pb2.Components(
            environments={
                'env1': beam_runner_api_pb2.Environment(
                    dependencies=[
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.FILE.urn,
                            type_payload=beam_runner_api_pb2.
                            ArtifactFilePayload(
                                path='/tmp/foo1',
                                sha256='abcd').SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='foo1').SerializeToString()),
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.FILE.urn,
                            type_payload=beam_runner_api_pb2.
                            ArtifactFilePayload(
                                path='/tmp/bar1',
                                sha256='defg').SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='bar1').SerializeToString()),
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.FILE.urn,
                            type_payload=beam_runner_api_pb2.
                            ArtifactFilePayload(path='/tmp/baz', sha256='hijk'
                                                ).SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='baz1').SerializeToString()),
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.FILE.urn,
                            type_payload=beam_runner_api_pb2.
                            ArtifactFilePayload(
                                path='/tmp/renamed1',
                                sha256='abcdefg').SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='renamed1').SerializeToString())
                    ]),
                'env2': beam_runner_api_pb2.Environment(
                    dependencies=[
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.FILE.urn,
                            type_payload=beam_runner_api_pb2.
                            ArtifactFilePayload(
                                path='/tmp/foo2',
                                sha256='lmno').SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='foo2').SerializeToString()),
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.FILE.urn,
                            type_payload=beam_runner_api_pb2.
                            ArtifactFilePayload(
                                path='/tmp/bar2',
                                sha256='pqrs').SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='bar2').SerializeToString()),
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.FILE.urn,
                            type_payload=beam_runner_api_pb2.
                            ArtifactFilePayload(path='/tmp/baz', sha256='tuv'
                                                ).SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='baz2').SerializeToString()),
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.FILE.urn,
                            type_payload=beam_runner_api_pb2.
                            ArtifactFilePayload(
                                path='/tmp/renamed2',
                                sha256='abcdefg').SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='renamed2').SerializeToString())
                    ])
            }))
    client = apiclient.DataflowApplicationClient(pipeline_options)
    staging_root = 'gs://test-location/staging'

    # every other artifact already exists
    n = [0]

    def exists_return_value(*args):
      n[0] += 1
      return n[0] % 2 == 0

    with mock.patch.object(FileSystems,
                           'exists',
                           side_effect=exists_return_value) as mock_gcs_exists:
      with mock.patch.object(apiclient.DataflowApplicationClient,
                             '_uncached_gcs_file_copy') as mock_gcs_copy:
        with mock.patch.object(FileSystems, 'copy') as mock_gcs_gcs_copy:

          manager = mock.Mock()
          manager.attach_mock(mock_gcs_exists, 'gcs_exists')
          manager.attach_mock(mock_gcs_copy, 'gcs_upload')
          manager.attach_mock(mock_gcs_gcs_copy, 'gcs_gcs_copy')

          client._stage_resources(pipeline, pipeline_options)
          expected_calls = list(
              itertools.chain.from_iterable([
                  self._mock_uncached_copy(staging_root, '/tmp/foo1', 'abcd'),
                  self._mock_cached_copy(staging_root, '/tmp/bar1', 'defg'),
                  self._mock_uncached_copy(
                      staging_root, '/tmp/baz', 'hijk', 'baz1'),
                  self._mock_cached_copy(
                      staging_root, '/tmp/renamed1', 'abcdefg'),
                  self._mock_uncached_copy(staging_root, '/tmp/foo2', 'lmno'),
                  self._mock_cached_copy(staging_root, '/tmp/bar2', 'pqrs'),
              ]))
          assert manager.mock_calls == expected_calls

    pipeline_expected = beam_runner_api_pb2.Pipeline(
        components=beam_runner_api_pb2.Components(
            environments={
                'env1': beam_runner_api_pb2.Environment(
                    dependencies=[
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.URL.urn,
                            type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
                                url='gs://test-location/staging/foo1',
                                sha256='abcd').SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='foo1').SerializeToString()),
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.URL.urn,
                            type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
                                url='gs://test-location/staging/bar1',
                                sha256='defg').SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='bar1').SerializeToString()),
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.URL.urn,
                            type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
                                url='gs://test-location/staging/baz1',
                                sha256='hijk').SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='baz1').SerializeToString()),
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.URL.urn,
                            type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
                                url='gs://test-location/staging/renamed1',
                                sha256='abcdefg').SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='renamed1').SerializeToString())
                    ]),
                'env2': beam_runner_api_pb2.Environment(
                    dependencies=[
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.URL.urn,
                            type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
                                url='gs://test-location/staging/foo2',
                                sha256='lmno').SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='foo2').SerializeToString()),
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.URL.urn,
                            type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
                                url='gs://test-location/staging/bar2',
                                sha256='pqrs').SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='bar2').SerializeToString()),
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.URL.urn,
                            type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
                                url='gs://test-location/staging/baz1',
                                sha256='tuv').SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='baz1').SerializeToString()),
                        beam_runner_api_pb2.ArtifactInformation(
                            type_urn=common_urns.artifact_types.URL.urn,
                            type_payload=beam_runner_api_pb2.ArtifactUrlPayload(
                                url='gs://test-location/staging/renamed1',
                                sha256='abcdefg').SerializeToString(),
                            role_urn=common_urns.artifact_roles.STAGING_TO.urn,
                            role_payload=beam_runner_api_pb2.
                            ArtifactStagingToRolePayload(
                                staged_name='renamed1').SerializeToString())
                    ])
            }))
    self.assertEqual(pipeline, pipeline_expected)


if __name__ == '__main__':
  unittest.main()