#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for the DisplayData API."""

# pytype: skip-file

import unittest
from datetime import datetime

# pylint: disable=ungrouped-imports
import hamcrest as hc
from hamcrest.core.base_matcher import BaseMatcher

import apache_beam as beam
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.transforms.display import DisplayData
from apache_beam.transforms.display import DisplayDataItem
from apache_beam.transforms.display import HasDisplayData

# pylint: enable=ungrouped-imports


class DisplayDataItemMatcher(BaseMatcher):
  """Hamcrest matcher for DisplayDataItems in unit tests.

  Each constructor argument is an expected attribute value of the item. An
  argument left at ``IGNORED`` is not checked, so ``None`` (or any other
  value) can still be matched explicitly.
  """
  # Sentinel meaning "don't check this attribute"; always compared by identity.
  IGNORED = object()

  # Item attributes the matcher can check, in the order describe_to uses.
  _ATTRIBUTES = ('key', 'value', 'namespace', 'label', 'shortValue')

  def __init__(
      self,
      key=IGNORED,
      value=IGNORED,
      namespace=IGNORED,
      label=IGNORED,
      shortValue=IGNORED):
    """Stores the expected attribute values.

    Raises:
      ValueError: if every attribute is left as IGNORED, since such a matcher
        would accept any item.
    """
    if all(member is DisplayDataItemMatcher.IGNORED
           for member in (key, value, namespace, label, shortValue)):
      raise ValueError('Must receive at least one item attribute to match')

    self.key = key
    self.value = value
    self.namespace = namespace
    self.label = label
    self.shortValue = shortValue

  def _expected(self):
    """Yields ``(attribute_name, expected_value)`` for each checked attribute."""
    for name in self._ATTRIBUTES:
      expected = getattr(self, name)
      # Identity check: ``==`` against the sentinel would invoke the expected
      # value's own (possibly overloaded) __eq__/__ne__.
      if expected is not DisplayDataItemMatcher.IGNORED:
        yield name, expected

  def _matches(self, item):
    """Returns True if every non-ignored attribute of ``item`` is equal."""
    return all(
        getattr(item, name) == expected
        for name, expected in self._expected())

  def describe_to(self, description):
    """Appends e.g. ``key is k and value is v`` to ``description``."""
    description.append_text(
        ' and '.join(
            '{} is {}'.format(name, expected)
            for name, expected in self._expected()))


class DisplayDataTest(unittest.TestCase):
  """Tests for building DisplayData from HasDisplayData components."""
  def test_display_data_item_matcher(self):
    with self.assertRaises(ValueError):
      DisplayDataItemMatcher()

  def test_inheritance_ptransform(self):
    class MyTransform(beam.PTransform):
      pass

    display_pt = MyTransform()
    # PTransform inherits from HasDisplayData.
    self.assertIsInstance(display_pt, HasDisplayData)
    self.assertEqual(display_pt.display_data(), {})

  def test_inheritance_dofn(self):
    class MyDoFn(beam.DoFn):
      pass

    display_dofn = MyDoFn()
    self.assertIsInstance(display_dofn, HasDisplayData)
    self.assertEqual(display_dofn.display_data(), {})

  def test_unsupported_type_display_data(self):
    class MyDisplayComponent(HasDisplayData):
      def display_data(self):
        return {'item_key': 'item_value'}

    # A component that is not PipelineOptions is expected to be rejected.
    with self.assertRaises(ValueError):
      DisplayData.create_from_options(MyDisplayComponent())

  def test_value_provider_display_data(self):
    class TestOptions(PipelineOptions):
      @classmethod
      def _add_argparse_args(cls, parser):
        parser.add_value_provider_argument(
            '--int_flag', type=int, help='int_flag description')
        parser.add_value_provider_argument(
            '--str_flag',
            type=str,
            default='hello',
            help='str_flag description')
        parser.add_value_provider_argument(
            '--float_flag', type=float, help='float_flag description')

    # Only --int_flag is given on the command line; the other two flags are
    # expected to appear as their RuntimeValueProvider representation.
    options = TestOptions(['--int_flag', '1'])
    items = DisplayData.create_from_options(options).items
    expected_items = [
        DisplayDataItemMatcher('int_flag', '1'),
        DisplayDataItemMatcher(
            'str_flag',
            'RuntimeValueProvider(option: str_flag,'
            ' type: str, default_value: \'hello\')'),
        DisplayDataItemMatcher(
            'float_flag',
            'RuntimeValueProvider(option: float_flag,'
            ' type: float, default_value: None)')
    ]
    hc.assert_that(items, hc.has_items(*expected_items))

  def test_create_list_display_data(self):
    flags = ['--extra_package', 'package1', '--extra_package', 'package2']
    pipeline_options = PipelineOptions(flags=flags)
    items = DisplayData.create_from_options(pipeline_options).items
    hc.assert_that(
        items,
        hc.has_items(
            DisplayDataItemMatcher(
                'extra_packages', str(['package1', 'package2']))))

  def test_unicode_type_display_data(self):
    class MyDoFn(beam.DoFn):
      def display_data(self):
        return {
            'unicode_string': 'my string',
            'unicode_literal_string': u'my literal string'
        }

    fn = MyDoFn()
    dd = DisplayData.create_from(fn)
    # Pin the produced keys first: without this, the loop below would pass
    # vacuously if no items were created.
    self.assertCountEqual([item.key for item in dd.items],
                          ['unicode_string', 'unicode_literal_string'])
    for item in dd.items:
      self.assertEqual(item.type, 'STRING')

  def test_base_cases(self):
    """Tests basic display data cases (key:value, key:DisplayDataItem).

    It does not test subcomponent inclusion.
    """
    class MyDoFn(beam.DoFn):
      def __init__(self, my_display_data=None):
        self.my_display_data = my_display_data

      def process(self, element):
        yield element + 1

      def display_data(self):
        return {
            'static_integer': 120,
            'static_string': 'static me!',
            'complex_url': DisplayDataItem(
                'github.com', url='http://github.com', label='The URL'),
            'python_class': HasDisplayData,
            'my_dd': self.my_display_data
        }

    now = datetime.now()
    fn = MyDoFn(my_display_data=now)
    dd = DisplayData.create_from(fn)
    nspace = '{}.{}'.format(fn.__module__, fn.__class__.__name__)
    expected_items = [
        DisplayDataItemMatcher(
            key='complex_url',
            value='github.com',
            namespace=nspace,
            label='The URL'),
        DisplayDataItemMatcher(key='my_dd', value=now, namespace=nspace),
        DisplayDataItemMatcher(
            key='python_class',
            value=HasDisplayData,
            namespace=nspace,
            shortValue='HasDisplayData'),
        DisplayDataItemMatcher(
            key='static_integer', value=120, namespace=nspace),
        DisplayDataItemMatcher(
            key='static_string', value='static me!', namespace=nspace)
    ]

    hc.assert_that(dd.items, hc.has_items(*expected_items))

  def test_drop_if_none(self):
    class MyDoFn(beam.DoFn):
      def display_data(self):
        return {
            'some_val': DisplayDataItem('something').drop_if_none(),
            'non_val': DisplayDataItem(None).drop_if_none(),
            'def_val': DisplayDataItem(True).drop_if_default(True),
            'nodef_val': DisplayDataItem(True).drop_if_default(False)
        }

    dd = DisplayData.create_from(MyDoFn())
    expected_items = [
        DisplayDataItemMatcher('some_val', 'something'),
        DisplayDataItemMatcher('nodef_val', True)
    ]
    hc.assert_that(dd.items, hc.has_items(*expected_items))
    # has_items only checks presence; also verify the dropped items are gone.
    for dropped_key in ('non_val', 'def_val'):
      hc.assert_that(
          dd.items,
          hc.is_not(hc.has_item(DisplayDataItemMatcher(key=dropped_key))))

  def test_subcomponent(self):
    class SpecialDoFn(beam.DoFn):
      def display_data(self):
        return {'dofn_value': 42}

    dofn = SpecialDoFn()
    pardo = beam.ParDo(dofn)
    dd = DisplayData.create_from(pardo)
    dofn_nspace = '{}.{}'.format(dofn.__module__, dofn.__class__.__name__)
    pardo_nspace = '{}.{}'.format(pardo.__module__, pardo.__class__.__name__)
    expected_items = [
        DisplayDataItemMatcher('dofn_value', 42, dofn_nspace),
        DisplayDataItemMatcher('fn', SpecialDoFn, pardo_nspace)
    ]

    hc.assert_that(dd.items, hc.has_items(*expected_items))


# TODO: Test __repr__ function
# TODO: Test PATH when added by swegner@
if __name__ == '__main__':
  unittest.main()