github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/dask/dask_runner_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  import inspect
    18  import unittest
    19  
    20  import apache_beam as beam
    21  from apache_beam.options.pipeline_options import PipelineOptions
    22  from apache_beam.testing import test_pipeline
    23  from apache_beam.testing.util import assert_that
    24  from apache_beam.testing.util import equal_to
    25  
    26  try:
    27    from apache_beam.runners.dask.dask_runner import DaskOptions
    28    from apache_beam.runners.dask.dask_runner import DaskRunner
    29    import dask
    30    import dask.distributed as ddist
    31  except (ImportError, ModuleNotFoundError):
    32    raise unittest.SkipTest('Dask must be installed to run tests.')
    33  
    34  
    35  class DaskOptionsTest(unittest.TestCase):
    36    def test_parses_connection_timeout__defaults_to_none(self):
    37      default_options = PipelineOptions([])
    38      default_dask_options = default_options.view_as(DaskOptions)
    39      self.assertEqual(None, default_dask_options.timeout)
    40  
    41    def test_parses_connection_timeout__parses_int(self):
    42      conn_options = PipelineOptions('--dask_connection_timeout 12'.split())
    43      dask_conn_options = conn_options.view_as(DaskOptions)
    44      self.assertEqual(12, dask_conn_options.timeout)
    45  
    46    def test_parses_connection_timeout__handles_bad_input(self):
    47      err_options = PipelineOptions('--dask_connection_timeout foo'.split())
    48      dask_err_options = err_options.view_as(DaskOptions)
    49      self.assertEqual(dask.config.no_default, dask_err_options.timeout)
    50  
    51    def test_parser_destinations__agree_with_dask_client(self):
    52      options = PipelineOptions(
    53          '--dask_client_address localhost:8080 --dask_connection_timeout 600 '
    54          '--dask_scheduler_file foobar.cfg --dask_client_name charlie '
    55          '--dask_connection_limit 1024'.split())
    56      dask_options = options.view_as(DaskOptions)
    57  
    58      # Get the argument names for the constructor.
    59      client_args = list(inspect.signature(ddist.Client).parameters)
    60  
    61      for opt_name in dask_options.get_all_options(drop_default=True).keys():
    62        with self.subTest(f'{opt_name} in dask.distributed.Client constructor'):
    63          self.assertIn(opt_name, client_args)
    64  
    65  
    66  class DaskRunnerRunPipelineTest(unittest.TestCase):
    67    """Test class used to introspect the dask runner via a debugger."""
    68    def setUp(self) -> None:
    69      self.pipeline = test_pipeline.TestPipeline(runner=DaskRunner())
    70  
    71    def test_create(self):
    72      with self.pipeline as p:
    73        pcoll = p | beam.Create([1])
    74        assert_that(pcoll, equal_to([1]))
    75  
    76    def test_create_and_map(self):
    77      def double(x):
    78        return x * 2
    79  
    80      with self.pipeline as p:
    81        pcoll = p | beam.Create([1]) | beam.Map(double)
    82        assert_that(pcoll, equal_to([2]))
    83  
    84    def test_create_map_and_groupby(self):
    85      def double(x):
    86        return x * 2, x
    87  
    88      with self.pipeline as p:
    89        pcoll = p | beam.Create([1]) | beam.Map(double) | beam.GroupByKey()
    90        assert_that(pcoll, equal_to([(2, [1])]))
    91  
    92  
    93  if __name__ == '__main__':
    94    unittest.main()