github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/options/value_provider_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Unit tests for the ValueProvider class."""
    19  
    20  # pytype: skip-file
    21  
    22  import logging
    23  import unittest
    24  
    25  from mock import Mock
    26  
    27  from apache_beam.options.pipeline_options import DebugOptions
    28  from apache_beam.options.pipeline_options import PipelineOptions
    29  from apache_beam.options.value_provider import NestedValueProvider
    30  from apache_beam.options.value_provider import RuntimeValueProvider
    31  from apache_beam.options.value_provider import StaticValueProvider
    32  
    33  
    34  # TODO(https://github.com/apache/beam/issues/18197): Require unique names only
    35  # within a test. For now, <file name acronym>_vp_arg<number> will be the
    36  # convention to name value-provider arguments in tests, as opposed to
    37  # <file name acronym>_non_vp_arg<number> for non-value-provider arguments.
    38  # The number will grow per file as tests are added.
    39  class ValueProviderTests(unittest.TestCase):
    40    def setUp(self):
    41      # Reset runtime options to avoid side-effects caused by other tests.
    42      # Note that is_accessible assertions require runtime_options to
    43      # be uninitialized.
    44      RuntimeValueProvider.set_runtime_options(None)
    45  
    46    def tearDown(self):
    47      # Reset runtime options to avoid side-effects in other tests.
    48      RuntimeValueProvider.set_runtime_options(None)
    49  
    50    def test_static_value_provider_keyword_argument(self):
    51      class UserDefinedOptions(PipelineOptions):
    52        @classmethod
    53        def _add_argparse_args(cls, parser):
    54          parser.add_value_provider_argument(
    55              '--vpt_vp_arg1',
    56              help='This keyword argument is a value provider',
    57              default='some value')
    58  
    59      options = UserDefinedOptions(['--vpt_vp_arg1', 'abc'])
    60      self.assertTrue(isinstance(options.vpt_vp_arg1, StaticValueProvider))
    61      self.assertTrue(options.vpt_vp_arg1.is_accessible())
    62      self.assertEqual(options.vpt_vp_arg1.get(), 'abc')
    63  
    64    def test_runtime_value_provider_keyword_argument(self):
    65      class UserDefinedOptions(PipelineOptions):
    66        @classmethod
    67        def _add_argparse_args(cls, parser):
    68          parser.add_value_provider_argument(
    69              '--vpt_vp_arg2', help='This keyword argument is a value provider')
    70  
    71      options = UserDefinedOptions()
    72      self.assertTrue(isinstance(options.vpt_vp_arg2, RuntimeValueProvider))
    73      self.assertFalse(options.vpt_vp_arg2.is_accessible())
    74      with self.assertRaises(RuntimeError):
    75        options.vpt_vp_arg2.get()
    76  
    77    def test_static_value_provider_positional_argument(self):
    78      class UserDefinedOptions(PipelineOptions):
    79        @classmethod
    80        def _add_argparse_args(cls, parser):
    81          parser.add_value_provider_argument(
    82              'vpt_vp_arg3',
    83              help='This positional argument is a value provider',
    84              default='some value')
    85  
    86      options = UserDefinedOptions(['abc'])
    87      self.assertTrue(isinstance(options.vpt_vp_arg3, StaticValueProvider))
    88      self.assertTrue(options.vpt_vp_arg3.is_accessible())
    89      self.assertEqual(options.vpt_vp_arg3.get(), 'abc')
    90  
    91    def test_runtime_value_provider_positional_argument(self):
    92      class UserDefinedOptions(PipelineOptions):
    93        @classmethod
    94        def _add_argparse_args(cls, parser):
    95          parser.add_value_provider_argument(
    96              'vpt_vp_arg4', help='This positional argument is a value provider')
    97  
    98      options = UserDefinedOptions([])
    99      self.assertTrue(isinstance(options.vpt_vp_arg4, RuntimeValueProvider))
   100      self.assertFalse(options.vpt_vp_arg4.is_accessible())
   101      with self.assertRaises(RuntimeError):
   102        options.vpt_vp_arg4.get()
   103  
   104    def test_static_value_provider_type_cast(self):
   105      class UserDefinedOptions(PipelineOptions):
   106        @classmethod
   107        def _add_argparse_args(cls, parser):
   108          parser.add_value_provider_argument(
   109              '--vpt_vp_arg5', type=int, help='This flag is a value provider')
   110  
   111      options = UserDefinedOptions(['--vpt_vp_arg5', '123'])
   112      self.assertTrue(isinstance(options.vpt_vp_arg5, StaticValueProvider))
   113      self.assertTrue(options.vpt_vp_arg5.is_accessible())
   114      self.assertEqual(options.vpt_vp_arg5.get(), 123)
   115  
   116    def test_set_runtime_option(self):
   117      # define ValueProvider options, with and without default values
   118      class UserDefinedOptions1(PipelineOptions):
   119        @classmethod
   120        def _add_argparse_args(cls, parser):
   121          parser.add_value_provider_argument(
   122              '--vpt_vp_arg6',
   123              help='This keyword argument is a value provider')  # set at runtime
   124  
   125          parser.add_value_provider_argument(         # not set, had default int
   126              '-v', '--vpt_vp_arg7',                      # with short form
   127              default=123,
   128              type=int)
   129  
   130          parser.add_value_provider_argument(         # not set, had default str
   131              '--vpt_vp-arg8',                            # with dash in name
   132              default='123',
   133              type=str)
   134  
   135          parser.add_value_provider_argument(         # not set and no default
   136              '--vpt_vp_arg9',
   137              type=float)
   138  
   139          parser.add_value_provider_argument(         # positional argument set
   140              'vpt_vp_arg10',                         # default & runtime ignored
   141              help='This positional argument is a value provider',
   142              type=float,
   143              default=5.4)
   144  
   145      # provide values at graph-construction time
   146      # (options not provided here become of the type RuntimeValueProvider)
   147      options = UserDefinedOptions1(['1.2'])
   148      self.assertFalse(options.vpt_vp_arg6.is_accessible())
   149      self.assertFalse(options.vpt_vp_arg7.is_accessible())
   150      self.assertFalse(options.vpt_vp_arg8.is_accessible())
   151      self.assertFalse(options.vpt_vp_arg9.is_accessible())
   152      self.assertTrue(options.vpt_vp_arg10.is_accessible())
   153  
   154      # provide values at job-execution time
   155      # (options not provided here will use their default, if they have one)
   156      RuntimeValueProvider.set_runtime_options({
   157          'vpt_vp_arg6': 'abc', 'vpt_vp_arg10': '3.2'
   158      })
   159      self.assertTrue(options.vpt_vp_arg6.is_accessible())
   160      self.assertEqual(options.vpt_vp_arg6.get(), 'abc')
   161      self.assertTrue(options.vpt_vp_arg7.is_accessible())
   162      self.assertEqual(options.vpt_vp_arg7.get(), 123)
   163      self.assertTrue(options.vpt_vp_arg8.is_accessible())
   164      self.assertEqual(options.vpt_vp_arg8.get(), '123')
   165      self.assertTrue(options.vpt_vp_arg9.is_accessible())
   166      self.assertIsNone(options.vpt_vp_arg9.get())
   167      self.assertTrue(options.vpt_vp_arg10.is_accessible())
   168      self.assertEqual(options.vpt_vp_arg10.get(), 1.2)
   169  
   170    def test_choices(self):
   171      class UserDefinedOptions(PipelineOptions):
   172        @classmethod
   173        def _add_argparse_args(cls, parser):
   174          parser.add_argument(
   175              '--vpt_vp_arg11',
   176              choices=['a', 'b'],
   177              help='This flag is a value provider with concrete choices')
   178          parser.add_argument(
   179              '--vpt_vp_arg12',
   180              choices=[1, 2],
   181              type=int,
   182              help='This flag is a value provider with concrete choices')
   183  
   184      options = UserDefinedOptions(['--vpt_vp_arg11', 'a', '--vpt_vp_arg12', '2'])
   185      self.assertEqual(options.vpt_vp_arg11, 'a')
   186      self.assertEqual(options.vpt_vp_arg12, 2)
   187  
   188    def test_static_value_provider_choices(self):
   189      class UserDefinedOptions(PipelineOptions):
   190        @classmethod
   191        def _add_argparse_args(cls, parser):
   192          parser.add_value_provider_argument(
   193              '--vpt_vp_arg13',
   194              choices=['a', 'b'],
   195              help='This flag is a value provider with concrete choices')
   196          parser.add_value_provider_argument(
   197              '--vpt_vp_arg14',
   198              choices=[1, 2],
   199              type=int,
   200              help='This flag is a value provider with concrete choices')
   201  
   202      options = UserDefinedOptions(['--vpt_vp_arg13', 'a', '--vpt_vp_arg14', '2'])
   203      self.assertEqual(options.vpt_vp_arg13.get(), 'a')
   204      self.assertEqual(options.vpt_vp_arg14.get(), 2)
   205  
   206    def test_experiments_setup(self):
   207      self.assertFalse('feature_1' in RuntimeValueProvider.experiments)
   208  
   209      RuntimeValueProvider.set_runtime_options(
   210          {'experiments': ['feature_1', 'feature_2']})
   211      self.assertTrue(isinstance(RuntimeValueProvider.experiments, set))
   212      self.assertTrue('feature_1' in RuntimeValueProvider.experiments)
   213      self.assertTrue('feature_2' in RuntimeValueProvider.experiments)
   214  
   215    def test_experiments_options_setup(self):
   216      options = PipelineOptions(['--experiments', 'a', '--experiments', 'b,c'])
   217      options = options.view_as(DebugOptions)
   218      self.assertIn('a', options.experiments)
   219      self.assertIn('b,c', options.experiments)
   220      self.assertNotIn('c', options.experiments)
   221  
   222    def test_nested_value_provider_wrap_static(self):
   223      vp = NestedValueProvider(StaticValueProvider(int, 1), lambda x: x + 1)
   224  
   225      self.assertTrue(vp.is_accessible())
   226      self.assertEqual(vp.get(), 2)
   227  
   228    def test_nested_value_provider_caches_value(self):
   229      mock_fn = Mock()
   230  
   231      def translator(x):
   232        mock_fn()
   233        return x
   234  
   235      vp = NestedValueProvider(StaticValueProvider(int, 1), translator)
   236  
   237      vp.get()
   238      self.assertEqual(mock_fn.call_count, 1)
   239      vp.get()
   240      self.assertEqual(mock_fn.call_count, 1)
   241  
   242    def test_nested_value_provider_wrap_runtime(self):
   243      class UserDefinedOptions(PipelineOptions):
   244        @classmethod
   245        def _add_argparse_args(cls, parser):
   246          parser.add_value_provider_argument(
   247              '--vpt_vp_arg15',
   248              help='This keyword argument is a value provider')  # set at runtime
   249  
   250      options = UserDefinedOptions([])
   251      vp = NestedValueProvider(options.vpt_vp_arg15, lambda x: x + x)
   252      self.assertFalse(vp.is_accessible())
   253  
   254      RuntimeValueProvider.set_runtime_options({'vpt_vp_arg15': 'abc'})
   255  
   256      self.assertTrue(vp.is_accessible())
   257      self.assertEqual(vp.get(), 'abcabc')
   258  
   259  
   260  if __name__ == '__main__':
   261    logging.getLogger().setLevel(logging.INFO)
   262    unittest.main()