github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/coders/coders_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  # pytype: skip-file
    18  
    19  import base64
    20  import logging
    21  import unittest
    22  
    23  import proto
    24  import pytest
    25  
    26  from apache_beam import typehints
    27  from apache_beam.coders import proto2_coder_test_messages_pb2 as test_message
    28  from apache_beam.coders import coders
    29  from apache_beam.coders.avro_record import AvroRecord
    30  from apache_beam.coders.typecoders import registry as coders_registry
    31  
    32  
    33  class PickleCoderTest(unittest.TestCase):
    34    def test_basics(self):
    35      v = ('a' * 10, 'b' * 90)
    36      pickler = coders.PickleCoder()
    37      self.assertEqual(v, pickler.decode(pickler.encode(v)))
    38      pickler = coders.Base64PickleCoder()
    39      self.assertEqual(v, pickler.decode(pickler.encode(v)))
    40      self.assertEqual(
    41          coders.Base64PickleCoder().encode(v),
    42          base64.b64encode(coders.PickleCoder().encode(v)))
    43  
    44    def test_equality(self):
    45      self.assertEqual(coders.PickleCoder(), coders.PickleCoder())
    46      self.assertEqual(coders.Base64PickleCoder(), coders.Base64PickleCoder())
    47      self.assertNotEqual(coders.Base64PickleCoder(), coders.PickleCoder())
    48      self.assertNotEqual(coders.Base64PickleCoder(), object())
    49  
    50  
    51  class CodersTest(unittest.TestCase):
    52    def test_str_utf8_coder(self):
    53      real_coder = coders_registry.get_coder(bytes)
    54      expected_coder = coders.BytesCoder()
    55      self.assertEqual(real_coder.encode(b'abc'), expected_coder.encode(b'abc'))
    56      self.assertEqual(b'abc', real_coder.decode(real_coder.encode(b'abc')))
    57  
    58  
    59  # The test proto message file was generated by running the following:
    60  #
    61  # `cd <beam repo>`
    62  # `cp sdks/java/extensions/protobuf/src/test/proto/\
    63  #    proto2_coder_test_messages.proto sdks/python/apache_beam/coders/`
    64  # `cd sdks/python`
# `protoc apache_beam/coders/proto2_coder_test_messages.proto \
#    --python_out=.`
# `rm apache_beam/coders/proto2_coder_test_messages.proto`
    68  #
    69  # Note: The protoc version should match the protobuf library version specified
    70  # in setup.py.
    71  #
    72  # TODO(https://github.com/apache/beam/issues/22319): The proto file should be
    73  # placed in a common directory that can be shared between java and python.
    74  class ProtoCoderTest(unittest.TestCase):
    75    def test_proto_coder(self):
    76      ma = test_message.MessageA()
    77      mb = ma.field2.add()
    78      mb.field1 = True
    79      ma.field1 = u'hello world'
    80      expected_coder = coders.ProtoCoder(ma.__class__)
    81      real_coder = coders_registry.get_coder(ma.__class__)
    82      self.assertEqual(expected_coder, real_coder)
    83      self.assertEqual(real_coder.encode(ma), expected_coder.encode(ma))
    84      self.assertEqual(ma, real_coder.decode(real_coder.encode(ma)))
    85      self.assertEqual(ma.__class__, real_coder.to_type_hint())
    86  
    87  
    88  class DeterministicProtoCoderTest(unittest.TestCase):
    89    def test_deterministic_proto_coder(self):
    90      ma = test_message.MessageA()
    91      mb = ma.field2.add()
    92      mb.field1 = True
    93      ma.field1 = u'hello world'
    94      expected_coder = coders.DeterministicProtoCoder(ma.__class__)
    95      real_coder = (
    96          coders_registry.get_coder(
    97              ma.__class__).as_deterministic_coder(step_label='unused'))
    98      self.assertTrue(real_coder.is_deterministic())
    99      self.assertEqual(expected_coder, real_coder)
   100      self.assertEqual(real_coder.encode(ma), expected_coder.encode(ma))
   101      self.assertEqual(ma, real_coder.decode(real_coder.encode(ma)))
   102  
   103    def test_deterministic_proto_coder_determinism(self):
   104      for _ in range(10):
   105        keys = list(range(20))
   106        mm_forward = test_message.MessageWithMap()
   107        for key in keys:
   108          mm_forward.field1[str(key)].field1 = str(key)
   109        mm_reverse = test_message.MessageWithMap()
   110        for key in reversed(keys):
   111          mm_reverse.field1[str(key)].field1 = str(key)
   112        coder = coders.DeterministicProtoCoder(mm_forward.__class__)
   113        self.assertEqual(coder.encode(mm_forward), coder.encode(mm_reverse))
   114  
   115  
   116  class ProtoPlusMessageB(proto.Message):
   117    field1 = proto.Field(proto.BOOL, number=1)
   118  
   119  
   120  class ProtoPlusMessageA(proto.Message):
   121    field1 = proto.Field(proto.STRING, number=1)
   122    field2 = proto.RepeatedField(ProtoPlusMessageB, number=2)
   123  
   124  
   125  class ProtoPlusMessageWithMap(proto.Message):
   126    field1 = proto.MapField(proto.STRING, ProtoPlusMessageA, number=1)
   127  
   128  
   129  class ProtoPlusCoderTest(unittest.TestCase):
   130    def test_proto_plus_coder(self):
   131      ma = ProtoPlusMessageA()
   132      ma.field2 = [ProtoPlusMessageB(field1=True)]
   133      ma.field1 = u'hello world'
   134      expected_coder = coders.ProtoPlusCoder(ma.__class__)
   135      real_coder = coders_registry.get_coder(ma.__class__)
   136      self.assertTrue(issubclass(ma.__class__, proto.Message))
   137      self.assertEqual(expected_coder, real_coder)
   138      self.assertTrue(real_coder.is_deterministic())
   139      self.assertEqual(real_coder.encode(ma), expected_coder.encode(ma))
   140      self.assertEqual(ma, real_coder.decode(real_coder.encode(ma)))
   141  
   142    def test_proto_plus_coder_determinism(self):
   143      for _ in range(10):
   144        keys = list(range(20))
   145        mm_forward = ProtoPlusMessageWithMap()
   146        for key in keys:
   147          mm_forward.field1[str(key)] = ProtoPlusMessageA(field1=str(key))  # pylint: disable=E1137
   148        mm_reverse = ProtoPlusMessageWithMap()
   149        for key in reversed(keys):
   150          mm_reverse.field1[str(key)] = ProtoPlusMessageA(field1=str(key))  # pylint: disable=E1137
   151        coder = coders.ProtoPlusCoder(ProtoPlusMessageWithMap)
   152        self.assertEqual(coder.encode(mm_forward), coder.encode(mm_reverse))
   153  
   154  
   155  class AvroTestCoder(coders.AvroGenericCoder):
   156    SCHEMA = """
   157    {
   158      "type": "record", "name": "testrecord",
   159      "fields": [
   160        {"name": "name", "type": "string"},
   161        {"name": "age", "type": "int"}
   162      ]
   163    }
   164    """
   165  
   166    def __init__(self):
   167      super().__init__(self.SCHEMA)
   168  
   169  
   170  class AvroTestRecord(AvroRecord):
   171    pass
   172  
   173  
# Register AvroTestCoder as the coder for AvroTestRecord at import time so the
# registry lookup in AvroCoderTest has an entry for this record type.
coders_registry.register_coder(AvroTestRecord, AvroTestCoder)
   175  
   176  
   177  class AvroCoderTest(unittest.TestCase):
   178    def test_avro_record_coder(self):
   179      real_coder = coders_registry.get_coder(AvroTestRecord)
   180      expected_coder = AvroTestCoder()
   181      self.assertEqual(
   182          real_coder.encode(
   183              AvroTestRecord({
   184                  "name": "Daenerys targaryen", "age": 23
   185              })),
   186          expected_coder.encode(
   187              AvroTestRecord({
   188                  "name": "Daenerys targaryen", "age": 23
   189              })))
   190      self.assertEqual(
   191          AvroTestRecord({
   192              "name": "Jon Snow", "age": 23
   193          }),
   194          real_coder.decode(
   195              real_coder.encode(AvroTestRecord({
   196                  "name": "Jon Snow", "age": 23
   197              }))))
   198  
   199  
   200  class DummyClass(object):
   201    """A class with no registered coder."""
   202    def __init__(self):
   203      pass
   204  
   205    def __eq__(self, other):
   206      if isinstance(other, self.__class__):
   207        return True
   208      return False
   209  
   210    def __hash__(self):
   211      return hash(type(self))
   212  
   213  
   214  class FallbackCoderTest(unittest.TestCase):
   215    def test_default_fallback_path(self):
   216      """Test fallback path picks a matching coder if no coder is registered."""
   217  
   218      coder = coders_registry.get_coder(DummyClass)
   219      # No matching coder, so picks the last fallback coder which is a
   220      # FastPrimitivesCoder.
   221      self.assertEqual(coder, coders.FastPrimitivesCoder())
   222      self.assertEqual(DummyClass(), coder.decode(coder.encode(DummyClass())))
   223  
   224  
   225  class NullableCoderTest(unittest.TestCase):
   226    def test_determinism(self):
   227      deterministic = coders_registry.get_coder(typehints.Optional[int])
   228      deterministic.as_deterministic_coder('label')
   229  
   230      complex_deterministic = coders_registry.get_coder(
   231          typehints.Optional[DummyClass])
   232      complex_deterministic.as_deterministic_coder('label')
   233  
   234      nondeterministic = coders.NullableCoder(coders.Base64PickleCoder())
   235      with pytest.raises(ValueError):
   236        nondeterministic.as_deterministic_coder('label')
   237  
   238  
   239  if __name__ == '__main__':
   240    logging.getLogger().setLevel(logging.INFO)
   241    unittest.main()