#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pytype: skip-file

import base64
import logging
import unittest

import proto
import pytest

from apache_beam import typehints
from apache_beam.coders import proto2_coder_test_messages_pb2 as test_message
from apache_beam.coders import coders
from apache_beam.coders.avro_record import AvroRecord
from apache_beam.coders.typecoders import registry as coders_registry


class PickleCoderTest(unittest.TestCase):
  """Round-trip and equality checks for PickleCoder and Base64PickleCoder."""
  def test_basics(self):
    value = ('a' * 10, 'b' * 90)

    # Each coder must give back exactly what was put in.
    plain = coders.PickleCoder()
    self.assertEqual(value, plain.decode(plain.encode(value)))
    wrapped = coders.Base64PickleCoder()
    self.assertEqual(value, wrapped.decode(wrapped.encode(value)))

    # The base64 variant's output must equal the base64 encoding of the plain
    # pickle coder's output.
    base64_encoded = coders.Base64PickleCoder().encode(value)
    pickle_encoded = coders.PickleCoder().encode(value)
    self.assertEqual(base64_encoded, base64.b64encode(pickle_encoded))

  def test_equality(self):
    pickle_coder = coders.PickleCoder
    base64_coder = coders.Base64PickleCoder

    # Separate instances of the same coder class compare equal ...
    self.assertEqual(pickle_coder(), pickle_coder())
    self.assertEqual(base64_coder(), base64_coder())
    # ... but not to a different coder class or to an unrelated object.
    self.assertNotEqual(base64_coder(), pickle_coder())
    self.assertNotEqual(base64_coder(), object())


class CodersTest(unittest.TestCase):
  """Checks the coder that the registry returns for ``bytes``."""
  def test_str_utf8_coder(self):
    payload = b'abc'
    registered = coders_registry.get_coder(bytes)
    reference = coders.BytesCoder()

    # The registered coder must encode like BytesCoder and round-trip.
    self.assertEqual(registered.encode(payload), reference.encode(payload))
    self.assertEqual(payload, registered.decode(registered.encode(payload)))


# The test proto message file was generated by running the following:
#
# `cd <beam repo>`
# `cp sdks/java/extensions/protobuf/src/test/proto/\
#   proto2_coder_test_messages.proto sdks/python/apache_beam/coders/`
# `cd sdks/python`
# `protoc apache_beam/coders/proto2_coder_test_messages.proto
#   --python_out=.`
# `rm apache_beam/coders/proto2_coder_test_messages.proto`
#
# Note: The protoc version should match the protobuf library version specified
# in setup.py.
#
# TODO(https://github.com/apache/beam/issues/22319): The proto file should be
# placed in a common directory that can be shared between java and python.
class ProtoCoderTest(unittest.TestCase):
  """Checks the coder registered for a generated protobuf message class."""
  def test_proto_coder(self):
    message = test_message.MessageA()
    message.field2.add().field1 = True
    message.field1 = u'hello world'
    message_type = message.__class__

    reference = coders.ProtoCoder(message_type)
    registered = coders_registry.get_coder(message_type)

    self.assertEqual(reference, registered)
    self.assertEqual(registered.encode(message), reference.encode(message))
    self.assertEqual(message, registered.decode(registered.encode(message)))
    self.assertEqual(message_type, registered.to_type_hint())


class DeterministicProtoCoderTest(unittest.TestCase):
  """Checks DeterministicProtoCoder lookup, round-trip and stable output."""
  def test_deterministic_proto_coder(self):
    message = test_message.MessageA()
    message.field2.add().field1 = True
    message.field1 = u'hello world'
    message_type = message.__class__

    reference = coders.DeterministicProtoCoder(message_type)
    registered = coders_registry.get_coder(message_type)
    deterministic = registered.as_deterministic_coder(step_label='unused')

    self.assertTrue(deterministic.is_deterministic())
    self.assertEqual(reference, deterministic)
    self.assertEqual(
        deterministic.encode(message), reference.encode(message))
    self.assertEqual(
        message, deterministic.decode(deterministic.encode(message)))

  def test_deterministic_proto_coder_determinism(self):
    for _ in range(10):
      names = [str(index) for index in range(20)]

      # Fill two maps with the same entries in opposite insertion order.
      forward = test_message.MessageWithMap()
      for name in names:
        forward.field1[name].field1 = name
      backward = test_message.MessageWithMap()
      for name in reversed(names):
        backward.field1[name].field1 = name

      # Insertion order must not show up in the encoded bytes.
      coder = coders.DeterministicProtoCoder(forward.__class__)
      self.assertEqual(coder.encode(forward), coder.encode(backward))


class ProtoPlusMessageB(proto.Message):
  """proto-plus message with a single boolean field."""
  field1 = proto.Field(proto.BOOL, number=1)


class ProtoPlusMessageA(proto.Message):
  """proto-plus message with a string and repeated ProtoPlusMessageB."""
  field1 = proto.Field(proto.STRING, number=1)
  field2 = proto.RepeatedField(ProtoPlusMessageB, number=2)


class ProtoPlusMessageWithMap(proto.Message):
  """proto-plus message with a string-keyed map of ProtoPlusMessageA."""
  field1 = proto.MapField(proto.STRING, ProtoPlusMessageA, number=1)


class ProtoPlusCoderTest(unittest.TestCase):
  """Checks the coder registered for proto-plus message classes."""
  def test_proto_plus_coder(self):
    message = ProtoPlusMessageA()
    message.field2 = [ProtoPlusMessageB(field1=True)]
    message.field1 = u'hello world'
    message_type = message.__class__

    reference = coders.ProtoPlusCoder(message_type)
    registered = coders_registry.get_coder(message_type)

    self.assertTrue(issubclass(message_type, proto.Message))
    self.assertEqual(reference, registered)
    self.assertTrue(registered.is_deterministic())
    self.assertEqual(registered.encode(message), reference.encode(message))
    self.assertEqual(message, registered.decode(registered.encode(message)))

  def test_proto_plus_coder_determinism(self):
    for _ in range(10):
      names = [str(index) for index in range(20)]

      # Fill two maps with the same entries in opposite insertion order.
      forward = ProtoPlusMessageWithMap()
      for name in names:
        forward.field1[name] = ProtoPlusMessageA(field1=name)  # pylint: disable=E1137
      backward = ProtoPlusMessageWithMap()
      for name in reversed(names):
        backward.field1[name] = ProtoPlusMessageA(field1=name)  # pylint: disable=E1137

      # Insertion order must not show up in the encoded bytes.
      coder = coders.ProtoPlusCoder(ProtoPlusMessageWithMap)
      self.assertEqual(coder.encode(forward), coder.encode(backward))


class AvroTestCoder(coders.AvroGenericCoder):
  """AvroGenericCoder bound to the fixed ``testrecord`` schema below."""
  SCHEMA = """
  {
    "type": "record", "name": "testrecord",
    "fields": [
      {"name": "name", "type": "string"},
      {"name": "age", "type": "int"}
    ]
  }
  """

  def __init__(self):
    super().__init__(self.SCHEMA)


class AvroTestRecord(AvroRecord):
  """AvroRecord subclass used as the registry key for AvroTestCoder."""
  pass


coders_registry.register_coder(AvroTestRecord, AvroTestCoder)


class AvroCoderTest(unittest.TestCase):
  """Checks the coder registered for AvroTestRecord."""
  def test_avro_record_coder(self):
    registered = coders_registry.get_coder(AvroTestRecord)
    reference = AvroTestCoder()

    # The registered coder must encode like a fresh AvroTestCoder.
    encoded_by_registered = registered.encode(
        AvroTestRecord({
            "name": "Daenerys targaryen", "age": 23
        }))
    encoded_by_reference = reference.encode(
        AvroTestRecord({
            "name": "Daenerys targaryen", "age": 23
        }))
    self.assertEqual(encoded_by_registered, encoded_by_reference)

    # A record must survive an encode/decode round trip.
    expected = AvroTestRecord({"name": "Jon Snow", "age": 23})
    round_tripped = registered.decode(
        registered.encode(AvroTestRecord({
            "name": "Jon Snow", "age": 23
        })))
    self.assertEqual(expected, round_tripped)


class DummyClass(object):
  """A class with no registered coder."""
  def __init__(self):
    pass

  def __eq__(self, other):
    # Any two instances of this class (or a subclass) compare equal.
    return isinstance(other, self.__class__)

  def __hash__(self):
    return hash(type(self))


class FallbackCoderTest(unittest.TestCase):
  """Checks which coder the registry returns for an unregistered class."""
  def test_default_fallback_path(self):
    """Test fallback path picks a matching coder if no coder is registered."""

    fallback = coders_registry.get_coder(DummyClass)
    # No matching coder, so picks the last fallback coder which is a
    # FastPrimitivesCoder.
    self.assertEqual(fallback, coders.FastPrimitivesCoder())
    round_tripped = fallback.decode(fallback.encode(DummyClass()))
    self.assertEqual(DummyClass(), round_tripped)


class NullableCoderTest(unittest.TestCase):
  """Checks as_deterministic_coder() on coders for optional values."""
  def test_determinism(self):
    # These two conversions only need to complete without raising.
    optional_int_coder = coders_registry.get_coder(typehints.Optional[int])
    optional_int_coder.as_deterministic_coder('label')

    optional_dummy_coder = coders_registry.get_coder(
        typehints.Optional[DummyClass])
    optional_dummy_coder.as_deterministic_coder('label')

    # Wrapping a Base64PickleCoder must be rejected with ValueError.
    pickle_backed = coders.NullableCoder(coders.Base64PickleCoder())
    with pytest.raises(ValueError):
      pickle_backed.as_deterministic_coder('label')


if __name__ == '__main__':
  root_logger = logging.getLogger()
  root_logger.setLevel(logging.INFO)
  unittest.main()