#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import inspect
import unittest

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.testing import test_pipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to

# The Dask runner is optional: when it (or Dask itself) cannot be imported,
# the whole module is reported as skipped instead of failing at import time.
try:
  from apache_beam.runners.dask.dask_runner import DaskOptions
  from apache_beam.runners.dask.dask_runner import DaskRunner
  import dask
  import dask.distributed as ddist
except (ImportError, ModuleNotFoundError):
  raise unittest.SkipTest('Dask must be installed to run tests.')


class DaskOptionsTest(unittest.TestCase):
  """Checks how `DaskOptions` parses its command-line flags."""
  def test_parses_connection_timeout__defaults_to_none(self):
    # With no flags supplied, the timeout option is expected to be None.
    parsed = PipelineOptions([]).view_as(DaskOptions)
    self.assertEqual(None, parsed.timeout)

  def test_parses_connection_timeout__parses_int(self):
    # A numeric flag value is expected to come back as the int 12.
    argv = '--dask_connection_timeout 12'.split()
    parsed = PipelineOptions(argv).view_as(DaskOptions)
    self.assertEqual(12, parsed.timeout)

  def test_parses_connection_timeout__handles_bad_input(self):
    # A non-numeric flag value is expected to yield Dask's `no_default`
    # sentinel rather than raising.
    argv = '--dask_connection_timeout foo'.split()
    parsed = PipelineOptions(argv).view_as(DaskOptions)
    self.assertEqual(dask.config.no_default, parsed.timeout)

  def test_parser_destinations__agree_with_dask_client(self):
    # Every explicitly-set option name must also be a parameter name of the
    # `dask.distributed.Client` constructor.
    argv = (
        '--dask_client_address localhost:8080 --dask_connection_timeout 600 '
        '--dask_scheduler_file foobar.cfg --dask_client_name charlie '
        '--dask_connection_limit 1024'.split())
    dask_options = PipelineOptions(argv).view_as(DaskOptions)

    # Parameter names accepted by the Client constructor.
    client_signature = inspect.signature(ddist.Client)
    client_args = list(client_signature.parameters)

    explicit_options = dask_options.get_all_options(drop_default=True)
    for opt_name in explicit_options.keys():
      with self.subTest(f'{opt_name} in dask.distributed.Client constructor'):
        self.assertIn(opt_name, client_args)


class DaskRunnerRunPipelineTest(unittest.TestCase):
  """Test class used to introspect the dask runner via a debugger.

  Each test builds a small pipeline on a `TestPipeline` backed by
  `DaskRunner` and asserts on the contents of the resulting PCollection.
  """
  def setUp(self) -> None:
    runner = DaskRunner()
    self.pipeline = test_pipeline.TestPipeline(runner=runner)

  def test_create(self):
    # A single-element Create should produce exactly that element.
    with self.pipeline as pipeline:
      created = pipeline | beam.Create([1])
      assert_that(created, equal_to([1]))

  def test_create_and_map(self):
    def double(x):
      return x * 2

    # Create followed by Map: the element 1 should come out as 2.
    with self.pipeline as pipeline:
      created = pipeline | beam.Create([1])
      doubled = created | beam.Map(double)
      assert_that(doubled, equal_to([2]))

  def test_create_map_and_groupby(self):
    def double(x):
      # Emits a (key, value) pair: the doubled input keyed to the input.
      return (x * 2, x)

    # Create -> Map -> GroupByKey: values are collected under their key.
    with self.pipeline as pipeline:
      created = pipeline | beam.Create([1])
      keyed = created | beam.Map(double)
      grouped = keyed | beam.GroupByKey()
      assert_that(grouped, equal_to([(2, [1])]))


if __name__ == '__main__':
  unittest.main()