github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/display_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Unit tests for the DisplayData API."""
    19  
    20  # pytype: skip-file
    21  
    22  import unittest
    23  from datetime import datetime
    24  
    25  # pylint: disable=ungrouped-imports
    26  import hamcrest as hc
    27  from hamcrest.core.base_matcher import BaseMatcher
    28  
    29  import apache_beam as beam
    30  from apache_beam.options.pipeline_options import PipelineOptions
    31  from apache_beam.transforms.display import DisplayData
    32  from apache_beam.transforms.display import DisplayDataItem
    33  from apache_beam.transforms.display import HasDisplayData
    34  
    35  # pylint: enable=ungrouped-imports
    36  
    37  
    38  class DisplayDataItemMatcher(BaseMatcher):
    39    """ Matcher class for DisplayDataItems in unit tests.
    40    """
    41    IGNORED = object()
    42  
    43    def __init__(
    44        self,
    45        key=IGNORED,
    46        value=IGNORED,
    47        namespace=IGNORED,
    48        label=IGNORED,
    49        shortValue=IGNORED):
    50      if all(member == DisplayDataItemMatcher.IGNORED
    51             for member in [key, value, namespace, label, shortValue]):
    52        raise ValueError('Must receive at least one item attribute to match')
    53  
    54      self.key = key
    55      self.value = value
    56      self.namespace = namespace
    57      self.label = label
    58      self.shortValue = shortValue
    59  
    60    def _matches(self, item):
    61      if self.key != DisplayDataItemMatcher.IGNORED and item.key != self.key:
    62        return False
    63      if (self.namespace != DisplayDataItemMatcher.IGNORED and
    64          item.namespace != self.namespace):
    65        return False
    66      if (self.value != DisplayDataItemMatcher.IGNORED and
    67          item.value != self.value):
    68        return False
    69      if (self.label != DisplayDataItemMatcher.IGNORED and
    70          item.label != self.label):
    71        return False
    72      if (self.shortValue != DisplayDataItemMatcher.IGNORED and
    73          item.shortValue != self.shortValue):
    74        return False
    75      return True
    76  
    77    def describe_to(self, description):
    78      descriptors = []
    79      if self.key != DisplayDataItemMatcher.IGNORED:
    80        descriptors.append('key is {}'.format(self.key))
    81      if self.value != DisplayDataItemMatcher.IGNORED:
    82        descriptors.append('value is {}'.format(self.value))
    83      if self.namespace != DisplayDataItemMatcher.IGNORED:
    84        descriptors.append('namespace is {}'.format(self.namespace))
    85      if self.label != DisplayDataItemMatcher.IGNORED:
    86        descriptors.append('label is {}'.format(self.label))
    87      if self.shortValue != DisplayDataItemMatcher.IGNORED:
    88        descriptors.append('shortValue is {}'.format(self.shortValue))
    89  
    90      item_description = '{}'.format(' and '.join(descriptors))
    91      description.append(item_description)
    92  
    93  
    94  class DisplayDataTest(unittest.TestCase):
    95    def test_display_data_item_matcher(self):
    96      with self.assertRaises(ValueError):
    97        DisplayDataItemMatcher()
    98  
    99    def test_inheritance_ptransform(self):
   100      class MyTransform(beam.PTransform):
   101        pass
   102  
   103      display_pt = MyTransform()
   104      # PTransform inherits from HasDisplayData.
   105      self.assertTrue(isinstance(display_pt, HasDisplayData))
   106      self.assertEqual(display_pt.display_data(), {})
   107  
   108    def test_inheritance_dofn(self):
   109      class MyDoFn(beam.DoFn):
   110        pass
   111  
   112      display_dofn = MyDoFn()
   113      self.assertTrue(isinstance(display_dofn, HasDisplayData))
   114      self.assertEqual(display_dofn.display_data(), {})
   115  
   116    def test_unsupported_type_display_data(self):
   117      class MyDisplayComponent(HasDisplayData):
   118        def display_data(self):
   119          return {'item_key': 'item_value'}
   120  
   121      with self.assertRaises(ValueError):
   122        DisplayData.create_from_options(MyDisplayComponent())
   123  
   124    def test_value_provider_display_data(self):
   125      class TestOptions(PipelineOptions):
   126        @classmethod
   127        def _add_argparse_args(cls, parser):
   128          parser.add_value_provider_argument(
   129              '--int_flag', type=int, help='int_flag description')
   130          parser.add_value_provider_argument(
   131              '--str_flag',
   132              type=str,
   133              default='hello',
   134              help='str_flag description')
   135          parser.add_value_provider_argument(
   136              '--float_flag', type=float, help='float_flag description')
   137  
   138      options = TestOptions(['--int_flag', '1'])
   139      items = DisplayData.create_from_options(options).items
   140      expected_items = [
   141          DisplayDataItemMatcher('int_flag', '1'),
   142          DisplayDataItemMatcher(
   143              'str_flag',
   144              'RuntimeValueProvider(option: str_flag,'
   145              ' type: str, default_value: \'hello\')'),
   146          DisplayDataItemMatcher(
   147              'float_flag',
   148              'RuntimeValueProvider(option: float_flag,'
   149              ' type: float, default_value: None)')
   150      ]
   151      hc.assert_that(items, hc.has_items(*expected_items))
   152  
   153    def test_create_list_display_data(self):
   154      flags = ['--extra_package', 'package1', '--extra_package', 'package2']
   155      pipeline_options = PipelineOptions(flags=flags)
   156      items = DisplayData.create_from_options(pipeline_options).items
   157      hc.assert_that(
   158          items,
   159          hc.has_items(
   160              DisplayDataItemMatcher(
   161                  'extra_packages', str(['package1', 'package2']))))
   162  
   163    def test_unicode_type_display_data(self):
   164      class MyDoFn(beam.DoFn):
   165        def display_data(self):
   166          return {
   167              'unicode_string': 'my string',
   168              'unicode_literal_string': u'my literal string'
   169          }
   170  
   171      fn = MyDoFn()
   172      dd = DisplayData.create_from(fn)
   173      for item in dd.items:
   174        self.assertEqual(item.type, 'STRING')
   175  
   176    def test_base_cases(self):
   177      """ Tests basic display data cases (key:value, key:dict)
   178      It does not test subcomponent inclusion
   179      """
   180      class MyDoFn(beam.DoFn):
   181        def __init__(self, my_display_data=None):
   182          self.my_display_data = my_display_data
   183  
   184        def process(self, element):
   185          yield element + 1
   186  
   187        def display_data(self):
   188          return {
   189              'static_integer': 120,
   190              'static_string': 'static me!',
   191              'complex_url': DisplayDataItem(
   192                  'github.com', url='http://github.com', label='The URL'),
   193              'python_class': HasDisplayData,
   194              'my_dd': self.my_display_data
   195          }
   196  
   197      now = datetime.now()
   198      fn = MyDoFn(my_display_data=now)
   199      dd = DisplayData.create_from(fn)
   200      nspace = '{}.{}'.format(fn.__module__, fn.__class__.__name__)
   201      expected_items = [
   202          DisplayDataItemMatcher(
   203              key='complex_url',
   204              value='github.com',
   205              namespace=nspace,
   206              label='The URL'),
   207          DisplayDataItemMatcher(key='my_dd', value=now, namespace=nspace),
   208          DisplayDataItemMatcher(
   209              key='python_class',
   210              value=HasDisplayData,
   211              namespace=nspace,
   212              shortValue='HasDisplayData'),
   213          DisplayDataItemMatcher(
   214              key='static_integer', value=120, namespace=nspace),
   215          DisplayDataItemMatcher(
   216              key='static_string', value='static me!', namespace=nspace)
   217      ]
   218  
   219      hc.assert_that(dd.items, hc.has_items(*expected_items))
   220  
   221    def test_drop_if_none(self):
   222      class MyDoFn(beam.DoFn):
   223        def display_data(self):
   224          return {
   225              'some_val': DisplayDataItem('something').drop_if_none(),
   226              'non_val': DisplayDataItem(None).drop_if_none(),
   227              'def_val': DisplayDataItem(True).drop_if_default(True),
   228              'nodef_val': DisplayDataItem(True).drop_if_default(False)
   229          }
   230  
   231      dd = DisplayData.create_from(MyDoFn())
   232      expected_items = [
   233          DisplayDataItemMatcher('some_val', 'something'),
   234          DisplayDataItemMatcher('nodef_val', True)
   235      ]
   236      hc.assert_that(dd.items, hc.has_items(*expected_items))
   237  
   238    def test_subcomponent(self):
   239      class SpecialDoFn(beam.DoFn):
   240        def display_data(self):
   241          return {'dofn_value': 42}
   242  
   243      dofn = SpecialDoFn()
   244      pardo = beam.ParDo(dofn)
   245      dd = DisplayData.create_from(pardo)
   246      dofn_nspace = '{}.{}'.format(dofn.__module__, dofn.__class__.__name__)
   247      pardo_nspace = '{}.{}'.format(pardo.__module__, pardo.__class__.__name__)
   248      expected_items = [
   249          DisplayDataItemMatcher('dofn_value', 42, dofn_nspace),
   250          DisplayDataItemMatcher('fn', SpecialDoFn, pardo_nspace)
   251      ]
   252  
   253      hc.assert_that(dd.items, hc.has_items(*expected_items))
   254  
   255  
   256  # TODO: Test __repr__ function
   257  # TODO: Test PATH when added by swegner@
   258  if __name__ == '__main__':
   259    unittest.main()