#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for the ValueProvider class."""

# pytype: skip-file

import logging
import unittest
from unittest.mock import Mock

from apache_beam.options.pipeline_options import DebugOptions
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.value_provider import NestedValueProvider
from apache_beam.options.value_provider import RuntimeValueProvider
from apache_beam.options.value_provider import StaticValueProvider


# TODO(https://github.com/apache/beam/issues/18197): Require unique names only
# within a test. For now, <file name acronym>_vp_arg<number> will be the
# convention to name value-provider arguments in tests, as opposed to
# <file name acronym>_non_vp_arg<number> for non-value-provider arguments.
# The number will grow per file as tests are added.
class ValueProviderTests(unittest.TestCase):
  """Tests for Static, Runtime and Nested value providers."""
  def setUp(self):
    # Reset runtime options to avoid side-effects caused by other tests.
    # Note that is_accessible assertions require runtime_options to
    # be uninitialized.
    RuntimeValueProvider.set_runtime_options(None)

  def tearDown(self):
    # Reset runtime options to avoid side-effects in other tests.
    RuntimeValueProvider.set_runtime_options(None)

  def test_static_value_provider_keyword_argument(self):
    """A keyword flag given at construction time is a StaticValueProvider."""
    class UserDefinedOptions(PipelineOptions):
      @classmethod
      def _add_argparse_args(cls, parser):
        parser.add_value_provider_argument(
            '--vpt_vp_arg1',
            help='This keyword argument is a value provider',
            default='some value')

    options = UserDefinedOptions(['--vpt_vp_arg1', 'abc'])
    self.assertIsInstance(options.vpt_vp_arg1, StaticValueProvider)
    self.assertTrue(options.vpt_vp_arg1.is_accessible())
    self.assertEqual(options.vpt_vp_arg1.get(), 'abc')

  def test_runtime_value_provider_keyword_argument(self):
    """An omitted keyword flag is an inaccessible RuntimeValueProvider."""
    class UserDefinedOptions(PipelineOptions):
      @classmethod
      def _add_argparse_args(cls, parser):
        parser.add_value_provider_argument(
            '--vpt_vp_arg2', help='This keyword argument is a value provider')

    options = UserDefinedOptions()
    self.assertIsInstance(options.vpt_vp_arg2, RuntimeValueProvider)
    self.assertFalse(options.vpt_vp_arg2.is_accessible())
    # Reading the value before runtime options are set must fail.
    with self.assertRaises(RuntimeError):
      options.vpt_vp_arg2.get()

  def test_static_value_provider_positional_argument(self):
    """A positional argument given at construction time is static."""
    class UserDefinedOptions(PipelineOptions):
      @classmethod
      def _add_argparse_args(cls, parser):
        parser.add_value_provider_argument(
            'vpt_vp_arg3',
            help='This positional argument is a value provider',
            default='some value')

    options = UserDefinedOptions(['abc'])
    self.assertIsInstance(options.vpt_vp_arg3, StaticValueProvider)
    self.assertTrue(options.vpt_vp_arg3.is_accessible())
    self.assertEqual(options.vpt_vp_arg3.get(), 'abc')

  def test_runtime_value_provider_positional_argument(self):
    """An omitted positional argument is an inaccessible runtime provider."""
    class UserDefinedOptions(PipelineOptions):
      @classmethod
      def _add_argparse_args(cls, parser):
        parser.add_value_provider_argument(
            'vpt_vp_arg4', help='This positional argument is a value provider')

    options = UserDefinedOptions([])
    self.assertIsInstance(options.vpt_vp_arg4, RuntimeValueProvider)
    self.assertFalse(options.vpt_vp_arg4.is_accessible())
    with self.assertRaises(RuntimeError):
      options.vpt_vp_arg4.get()

  def test_static_value_provider_type_cast(self):
    """The `type` callable is applied to a statically supplied value."""
    class UserDefinedOptions(PipelineOptions):
      @classmethod
      def _add_argparse_args(cls, parser):
        parser.add_value_provider_argument(
            '--vpt_vp_arg5', type=int, help='This flag is a value provider')

    options = UserDefinedOptions(['--vpt_vp_arg5', '123'])
    self.assertIsInstance(options.vpt_vp_arg5, StaticValueProvider)
    self.assertTrue(options.vpt_vp_arg5.is_accessible())
    self.assertEqual(options.vpt_vp_arg5.get(), 123)

  def test_set_runtime_option(self):
    """Runtime options make runtime providers accessible, using defaults."""
    # define ValueProvider options, with and without default values
    class UserDefinedOptions1(PipelineOptions):
      @classmethod
      def _add_argparse_args(cls, parser):
        parser.add_value_provider_argument(
            '--vpt_vp_arg6',
            help='This keyword argument is a value provider')  # set at runtime

        parser.add_value_provider_argument(  # not set, has default int
            '-v', '--vpt_vp_arg7',  # with short form
            default=123,
            type=int)

        parser.add_value_provider_argument(  # not set, has default str
            '--vpt_vp-arg8',  # with dash in name
            default='123',
            type=str)

        parser.add_value_provider_argument(  # not set and no default
            '--vpt_vp_arg9',
            type=float)

        parser.add_value_provider_argument(  # positional argument set
            'vpt_vp_arg10',  # default & runtime ignored
            help='This positional argument is a value provider',
            type=float,
            default=5.4)

    # provide values at graph-construction time
    # (options not provided here become of the type RuntimeValueProvider)
    options = UserDefinedOptions1(['1.2'])
    self.assertFalse(options.vpt_vp_arg6.is_accessible())
    # A default alone does not make an option accessible before runtime.
    self.assertFalse(options.vpt_vp_arg7.is_accessible())
    self.assertFalse(options.vpt_vp_arg8.is_accessible())
    self.assertFalse(options.vpt_vp_arg9.is_accessible())
    self.assertTrue(options.vpt_vp_arg10.is_accessible())

    # provide values at job-execution time
    # (options not provided here will use their default, if they have one)
    RuntimeValueProvider.set_runtime_options({
        'vpt_vp_arg6': 'abc', 'vpt_vp_arg10': '3.2'
    })
    self.assertTrue(options.vpt_vp_arg6.is_accessible())
    self.assertEqual(options.vpt_vp_arg6.get(), 'abc')
    self.assertTrue(options.vpt_vp_arg7.is_accessible())
    self.assertEqual(options.vpt_vp_arg7.get(), 123)
    self.assertTrue(options.vpt_vp_arg8.is_accessible())
    self.assertEqual(options.vpt_vp_arg8.get(), '123')
    self.assertTrue(options.vpt_vp_arg9.is_accessible())
    self.assertIsNone(options.vpt_vp_arg9.get())
    self.assertTrue(options.vpt_vp_arg10.is_accessible())
    # The static construction-time value wins over the runtime value.
    self.assertEqual(options.vpt_vp_arg10.get(), 1.2)

  def test_choices(self):
    """Plain (non-value-provider) arguments honor argparse `choices`."""
    class UserDefinedOptions(PipelineOptions):
      @classmethod
      def _add_argparse_args(cls, parser):
        parser.add_argument(
            '--vpt_non_vp_arg1',
            choices=['a', 'b'],
            help='This flag is a regular argument with concrete choices')
        parser.add_argument(
            '--vpt_non_vp_arg2',
            choices=[1, 2],
            type=int,
            help='This flag is a regular argument with concrete choices')

    options = UserDefinedOptions(
        ['--vpt_non_vp_arg1', 'a', '--vpt_non_vp_arg2', '2'])
    self.assertEqual(options.vpt_non_vp_arg1, 'a')
    self.assertEqual(options.vpt_non_vp_arg2, 2)

  def test_static_value_provider_choices(self):
    """Value-provider arguments honor argparse `choices` and `type`."""
    class UserDefinedOptions(PipelineOptions):
      @classmethod
      def _add_argparse_args(cls, parser):
        parser.add_value_provider_argument(
            '--vpt_vp_arg13',
            choices=['a', 'b'],
            help='This flag is a value provider with concrete choices')
        parser.add_value_provider_argument(
            '--vpt_vp_arg14',
            choices=[1, 2],
            type=int,
            help='This flag is a value provider with concrete choices')

    options = UserDefinedOptions(['--vpt_vp_arg13', 'a', '--vpt_vp_arg14', '2'])
    self.assertEqual(options.vpt_vp_arg13.get(), 'a')
    self.assertEqual(options.vpt_vp_arg14.get(), 2)

  def test_experiments_setup(self):
    """Runtime `experiments` are exposed as a set on RuntimeValueProvider."""
    self.assertNotIn('feature_1', RuntimeValueProvider.experiments)

    RuntimeValueProvider.set_runtime_options(
        {'experiments': ['feature_1', 'feature_2']})
    self.assertIsInstance(RuntimeValueProvider.experiments, set)
    self.assertIn('feature_1', RuntimeValueProvider.experiments)
    self.assertIn('feature_2', RuntimeValueProvider.experiments)

  def test_experiments_options_setup(self):
    """Repeated --experiments flags accumulate; values are not comma-split."""
    options = PipelineOptions(['--experiments', 'a', '--experiments', 'b,c'])
    options = options.view_as(DebugOptions)
    self.assertIn('a', options.experiments)
    self.assertIn('b,c', options.experiments)
    self.assertNotIn('c', options.experiments)

  def test_nested_value_provider_wrap_static(self):
    """A NestedValueProvider over a static provider applies its translator."""
    vp = NestedValueProvider(StaticValueProvider(int, 1), lambda x: x + 1)

    self.assertTrue(vp.is_accessible())
    self.assertEqual(vp.get(), 2)

  def test_nested_value_provider_caches_value(self):
    """The translator runs once; later get() calls reuse the cached value."""
    mock_fn = Mock()

    def translator(x):
      mock_fn()  # record each translator invocation
      return x

    vp = NestedValueProvider(StaticValueProvider(int, 1), translator)

    vp.get()
    self.assertEqual(mock_fn.call_count, 1)
    vp.get()
    self.assertEqual(mock_fn.call_count, 1)

  def test_nested_value_provider_wrap_runtime(self):
    """A NestedValueProvider over a runtime provider waits for runtime."""
    class UserDefinedOptions(PipelineOptions):
      @classmethod
      def _add_argparse_args(cls, parser):
        parser.add_value_provider_argument(
            '--vpt_vp_arg15',
            help='This keyword argument is a value provider')  # set at runtime

    options = UserDefinedOptions([])
    vp = NestedValueProvider(options.vpt_vp_arg15, lambda x: x + x)
    self.assertFalse(vp.is_accessible())

    RuntimeValueProvider.set_runtime_options({'vpt_vp_arg15': 'abc'})

    self.assertTrue(vp.is_accessible())
    self.assertEqual(vp.get(), 'abcabc')


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()