github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/coders/standard_coders_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Unit tests for coders that must be consistent across all Beam SDKs.
    19  """
    20  # pytype: skip-file
    21  
    22  import json
    23  import logging
    24  import math
    25  import os.path
    26  import sys
    27  import unittest
    28  from copy import deepcopy
    29  from typing import Dict
    30  from typing import Tuple
    31  
    32  import numpy as np
    33  import yaml
    34  from numpy.testing import assert_array_equal
    35  
    36  from apache_beam.coders import coder_impl
    37  from apache_beam.portability.api import beam_runner_api_pb2
    38  from apache_beam.portability.api import schema_pb2
    39  from apache_beam.runners import pipeline_context
    40  from apache_beam.transforms import userstate
    41  from apache_beam.transforms import window
    42  from apache_beam.transforms.window import IntervalWindow
    43  from apache_beam.typehints import schemas
    44  from apache_beam.utils import windowed_value
    45  from apache_beam.utils.sharded_key import ShardedKey
    46  from apache_beam.utils.timestamp import Timestamp
    47  from apache_beam.utils.windowed_value import PaneInfo
    48  from apache_beam.utils.windowed_value import PaneInfoTiming
    49  
    50  STANDARD_CODERS_YAML = os.path.normpath(
    51      os.path.join(
    52          os.path.dirname(__file__), '../portability/api/standard_coders.yaml'))
    53  
    54  
    55  def _load_test_cases(test_yaml):
    56    """Load test data from yaml file and return an iterable of test cases.
    57  
    58    See ``standard_coders.yaml`` for more details.
    59    """
    60    if not os.path.exists(test_yaml):
    61      raise ValueError('Could not find the test spec: %s' % test_yaml)
    62    with open(test_yaml, 'rb') as coder_spec:
    63      for ix, spec in enumerate(
    64          yaml.load_all(coder_spec, Loader=yaml.SafeLoader)):
    65        spec['index'] = ix
    66        name = spec.get('name', spec['coder']['urn'].split(':')[-2])
    67        yield [name, spec]
    68  
    69  
    70  def parse_float(s):
    71    x = float(s)
    72    if math.isnan(x):
    73      # In Windows, float('NaN') has opposite sign from other platforms.
    74      # For the purpose of this test, we just need consistency.
    75      x = abs(x)
    76    return x
    77  
    78  
    79  def value_parser_from_schema(schema):
    80    def attribute_parser_from_type(type_):
    81      parser = nonnull_attribute_parser_from_type(type_)
    82      if type_.nullable:
    83        return lambda x: None if x is None else parser(x)
    84      else:
    85        return parser
    86  
    87    def nonnull_attribute_parser_from_type(type_):
    88      # TODO: This should be exhaustive
    89      type_info = type_.WhichOneof("type_info")
    90      if type_info == "atomic_type":
    91        if type_.atomic_type == schema_pb2.BYTES:
    92          return lambda x: x.encode("utf-8")
    93        else:
    94          return schemas.ATOMIC_TYPE_TO_PRIMITIVE[type_.atomic_type]
    95      elif type_info == "array_type":
    96        element_parser = attribute_parser_from_type(type_.array_type.element_type)
    97        return lambda x: list(map(element_parser, x))
    98      elif type_info == "map_type":
    99        key_parser = attribute_parser_from_type(type_.map_type.key_type)
   100        value_parser = attribute_parser_from_type(type_.map_type.value_type)
   101        return lambda x: dict(
   102            (key_parser(k), value_parser(v)) for k, v in x.items())
   103      elif type_info == "row_type":
   104        return value_parser_from_schema(type_.row_type.schema)
   105      elif type_info == "logical_type":
   106        # In YAML logical types are represented with their representation types.
   107        to_language_type = schemas.LogicalType.from_runner_api(
   108            type_.logical_type).to_language_type
   109        parse_representation = attribute_parser_from_type(
   110            type_.logical_type.representation)
   111        return lambda x: to_language_type(parse_representation(x))
   112  
   113    parsers = [(field.name, attribute_parser_from_type(field.type))
   114               for field in schema.fields]
   115  
   116    constructor = schemas.named_tuple_from_schema(schema)
   117  
   118    def value_parser(x):
   119      result = []
   120      x = deepcopy(x)
   121      for name, parser in parsers:
   122        value = x.pop(name)
   123        result.append(None if value is None else parser(value))
   124  
   125      if len(x):
   126        raise ValueError(
   127            "Test data contains attributes that don't exist in the schema: {}".
   128            format(', '.join(x.keys())))
   129  
   130      return constructor(*result)
   131  
   132    return value_parser
   133  
   134  
   135  class StandardCodersTest(unittest.TestCase):
   136  
   137    _urn_to_json_value_parser = {
   138        'beam:coder:bytes:v1': lambda x: x.encode('utf-8'),
   139        'beam:coder:bool:v1': lambda x: x,
   140        'beam:coder:string_utf8:v1': lambda x: x,
   141        'beam:coder:varint:v1': lambda x: x,
   142        'beam:coder:kv:v1': lambda x,
   143        key_parser,
   144        value_parser: (key_parser(x['key']), value_parser(x['value'])),
   145        'beam:coder:interval_window:v1': lambda x: IntervalWindow(
   146            start=Timestamp(micros=(x['end'] - x['span']) * 1000),
   147            end=Timestamp(micros=x['end'] * 1000)),
   148        'beam:coder:iterable:v1': lambda x,
   149        parser: list(map(parser, x)),
   150        'beam:coder:state_backed_iterable:v1': lambda x,
   151        parser: list(map(parser, x)),
   152        'beam:coder:global_window:v1': lambda x: window.GlobalWindow(),
   153        'beam:coder:windowed_value:v1': lambda x,
   154        value_parser,
   155        window_parser: windowed_value.create(
   156            value_parser(x['value']),
   157            x['timestamp'] * 1000,
   158            tuple(window_parser(w) for w in x['windows'])),
   159        'beam:coder:param_windowed_value:v1': lambda x,
   160        value_parser,
   161        window_parser: windowed_value.create(
   162            value_parser(x['value']),
   163            x['timestamp'] * 1000,
   164            tuple(window_parser(w) for w in x['windows']),
   165            PaneInfo(
   166                x['pane']['is_first'],
   167                x['pane']['is_last'],
   168                PaneInfoTiming.from_string(x['pane']['timing']),
   169                x['pane']['index'],
   170                x['pane']['on_time_index'])),
   171        'beam:coder:timer:v1': lambda x,
   172        value_parser,
   173        window_parser: userstate.Timer(
   174            user_key=value_parser(x['userKey']),
   175            dynamic_timer_tag=x['dynamicTimerTag'],
   176            clear_bit=x['clearBit'],
   177            windows=tuple(window_parser(w) for w in x['windows']),
   178            fire_timestamp=None,
   179            hold_timestamp=None,
   180            paneinfo=None) if x['clearBit'] else userstate.Timer(
   181                user_key=value_parser(x['userKey']),
   182                dynamic_timer_tag=x['dynamicTimerTag'],
   183                clear_bit=x['clearBit'],
   184                fire_timestamp=Timestamp(micros=x['fireTimestamp'] * 1000),
   185                hold_timestamp=Timestamp(micros=x['holdTimestamp'] * 1000),
   186                windows=tuple(window_parser(w) for w in x['windows']),
   187                paneinfo=PaneInfo(
   188                    x['pane']['is_first'],
   189                    x['pane']['is_last'],
   190                    PaneInfoTiming.from_string(x['pane']['timing']),
   191                    x['pane']['index'],
   192                    x['pane']['on_time_index'])),
   193        'beam:coder:double:v1': parse_float,
   194        'beam:coder:sharded_key:v1': lambda x,
   195        value_parser: ShardedKey(
   196            key=value_parser(x['key']), shard_id=x['shardId'].encode('utf-8')),
   197        'beam:coder:custom_window:v1': lambda x,
   198        window_parser: window_parser(x['window']),
   199        'beam:coder:nullable:v1': lambda x,
   200        value_parser: x.encode('utf-8') if x else None
   201    }
   202  
   203    def test_standard_coders(self):
   204      for name, spec in _load_test_cases(STANDARD_CODERS_YAML):
   205        logging.info('Executing %s test.', name)
   206        self._run_standard_coder(name, spec)
   207  
   208    def _run_standard_coder(self, name, spec):
   209      def assert_equal(actual, expected):
   210        """Handle nan values which self.assertEqual fails on."""
   211        if (isinstance(actual, float) and isinstance(expected, float) and
   212            math.isnan(actual) and math.isnan(expected)):
   213          return
   214        self.assertEqual(actual, expected)
   215  
   216      coder = self.parse_coder(spec['coder'])
   217      parse_value = self.json_value_parser(spec['coder'])
   218      nested_list = [spec['nested']] if 'nested' in spec else [True, False]
   219      for nested in nested_list:
   220        for expected_encoded, json_value in spec['examples'].items():
   221          value = parse_value(json_value)
   222          expected_encoded = expected_encoded.encode('latin1')
   223          if not spec['coder'].get('non_deterministic', False):
   224            actual_encoded = encode_nested(coder, value, nested)
   225            if self.fix and actual_encoded != expected_encoded:
   226              self.to_fix[spec['index'], expected_encoded] = actual_encoded
   227            else:
   228              self.assertEqual(expected_encoded, actual_encoded)
   229              decoded = decode_nested(coder, expected_encoded, nested)
   230              assert_equal(decoded, value)
   231          else:
   232            # Only verify decoding for a non-deterministic coder
   233            self.assertEqual(
   234                decode_nested(coder, expected_encoded, nested), value)
   235  
   236      if spec['coder']['urn'] == 'beam:coder:row:v1':
   237        # Test batch encoding/decoding as well.
   238        values = [
   239            parse_value(json_value) for json_value in spec['examples'].values()
   240        ]
   241        columnar = {
   242            field.name: np.array([getattr(value, field.name) for value in values])
   243            for field in coder.schema.fields
   244        }
   245        dest = {
   246            field: np.empty_like(values)
   247            for field, values in columnar.items()
   248        }
   249        for column in dest.values():
   250          column[:] = 0 if 'int' in column.dtype.name else None
   251        expected_encoded = ''.join(spec['examples'].keys()).encode('latin1')
   252        actual_encoded = encode_batch(coder, columnar)
   253        assert_equal(expected_encoded, actual_encoded)
   254        decoded_count = decode_batch(coder, expected_encoded, dest)
   255        assert_equal(len(spec['examples']), decoded_count)
   256        for field, values in dest.items():
   257          assert_array_equal(columnar[field], dest[field])
   258  
   259    def parse_coder(self, spec):
   260      context = pipeline_context.PipelineContext()
   261      coder_id = str(hash(str(spec)))
   262      component_ids = [
   263          context.coders.get_id(self.parse_coder(c))
   264          for c in spec.get('components', ())
   265      ]
   266      if spec.get('state'):
   267  
   268        def iterable_state_read(state_token, elem_coder):
   269          state = spec.get('state').get(state_token.decode('latin1'))
   270          if state is None:
   271            state = ''
   272          input_stream = coder_impl.create_InputStream(state.encode('latin1'))
   273          while input_stream.size() > 0:
   274            yield elem_coder.decode_from_stream(input_stream, True)
   275  
   276        context.iterable_state_read = iterable_state_read
   277  
   278      context.coders.put_proto(
   279          coder_id,
   280          beam_runner_api_pb2.Coder(
   281              spec=beam_runner_api_pb2.FunctionSpec(
   282                  urn=spec['urn'],
   283                  payload=spec.get('payload', '').encode('latin1')),
   284              component_coder_ids=component_ids))
   285      return context.coders.get_by_id(coder_id)
   286  
   287    def json_value_parser(self, coder_spec):
   288      # TODO: integrate this with the logic for the other parsers
   289      if coder_spec['urn'] == 'beam:coder:row:v1':
   290        schema = schema_pb2.Schema.FromString(
   291            coder_spec['payload'].encode('latin1'))
   292        return value_parser_from_schema(schema)
   293  
   294      component_parsers = [
   295          self.json_value_parser(c) for c in coder_spec.get('components', ())
   296      ]
   297      return lambda x: self._urn_to_json_value_parser[coder_spec['urn']](
   298          x, *component_parsers)
   299  
   300    # Used when --fix is passed.
   301  
   302    fix = False
   303    to_fix = {}  # type: Dict[Tuple[int, bytes], bytes]
   304  
   305    @classmethod
   306    def tearDownClass(cls):
   307      if cls.fix and cls.to_fix:
   308        print("FIXING", len(cls.to_fix), "TESTS")
   309        doc_sep = '\n---\n'
   310        docs = open(STANDARD_CODERS_YAML).read().split(doc_sep)
   311  
   312        def quote(s):
   313          return json.dumps(s.decode('latin1')).replace(r'\u0000', r'\0')
   314  
   315        for (doc_ix, expected_encoded), actual_encoded in cls.to_fix.items():
   316          print(quote(expected_encoded), "->", quote(actual_encoded))
   317          docs[doc_ix] = docs[doc_ix].replace(
   318              quote(expected_encoded) + ':', quote(actual_encoded) + ':')
   319        open(STANDARD_CODERS_YAML, 'w').write(doc_sep.join(docs))
   320  
   321  
   322  def encode_nested(coder, value, nested=True):
   323    out = coder_impl.create_OutputStream()
   324    coder.get_impl().encode_to_stream(value, out, nested)
   325    return out.get()
   326  
   327  
   328  def decode_nested(coder, encoded, nested=True):
   329    return coder.get_impl().decode_from_stream(
   330        coder_impl.create_InputStream(encoded), nested)
   331  
   332  
   333  def encode_batch(row_coder, values):
   334    out = coder_impl.create_OutputStream()
   335    row_coder.get_impl().encode_batch_to_stream(values, out)
   336    return out.get()
   337  
   338  
   339  def decode_batch(row_coder, encoded, dest):
   340    return row_coder.get_impl().decode_batch_from_stream(
   341        dest, coder_impl.create_InputStream(encoded))
   342  
   343  
   344  if __name__ == '__main__':
   345    if '--fix' in sys.argv:
   346      StandardCodersTest.fix = True
   347      sys.argv.remove('--fix')
   348    unittest.main()