github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/coders/row_coder_test.py (about) 1 # 2 # Licensed to the Apache Software Foundation (ASF) under one or more 3 # contributor license agreements. See the NOTICE file distributed with 4 # this work for additional information regarding copyright ownership. 5 # The ASF licenses this file to You under the Apache License, Version 2.0 6 # (the "License"); you may not use this file except in compliance with 7 # the License. You may obtain a copy of the License at 8 # 9 # http://www.apache.org/licenses/LICENSE-2.0 10 # 11 # Unless required by applicable law or agreed to in writing, software 12 # distributed under the License is distributed on an "AS IS" BASIS, 13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 # See the License for the specific language governing permissions and 15 # limitations under the License. 16 # 17 # pytype: skip-file 18 19 import logging 20 import typing 21 import unittest 22 from itertools import chain 23 24 import numpy as np 25 from google.protobuf import json_format 26 from numpy.testing import assert_array_equal 27 28 import apache_beam as beam 29 from apache_beam.coders import RowCoder 30 from apache_beam.coders import coder_impl 31 from apache_beam.coders.typecoders import registry as coders_registry 32 from apache_beam.internal import pickler 33 from apache_beam.portability.api import schema_pb2 34 from apache_beam.testing.test_pipeline import TestPipeline 35 from apache_beam.testing.util import assert_that 36 from apache_beam.testing.util import equal_to 37 from apache_beam.typehints.schemas import named_tuple_from_schema 38 from apache_beam.typehints.schemas import typing_to_runner_api 39 from apache_beam.utils.timestamp import Timestamp 40 41 Person = typing.NamedTuple( 42 "Person", 43 [ 44 ("name", str), 45 ("age", np.int32), 46 ("address", typing.Optional[str]), 47 ("aliases", typing.List[str]), 48 ("knows_javascript", bool), 49 ("payload", typing.Optional[bytes]), 50 ("custom_metadata", typing.Mapping[str, int]), 51 ("favorite_time", Timestamp), 52 ]) 53 54 NullablePerson = typing.NamedTuple( 55 "NullablePerson", 56 [("name", typing.Optional[str]), ("age", np.int32), 57 ("address", typing.Optional[str]), ("aliases", typing.List[str]), 58 ("knows_javascript", bool), ("payload", typing.Optional[bytes]), 59 ("custom_metadata", typing.Mapping[str, int]), 60 ("favorite_time", typing.Optional[Timestamp]), 61 ("one_more_field", typing.Optional[str])]) 62 63 64 class People(typing.NamedTuple): 65 primary: Person 66 partner: typing.Optional[Person] 67 68 69 coders_registry.register_coder(Person, RowCoder) 70 coders_registry.register_coder(People, RowCoder) 71 72 73 class RowCoderTest(unittest.TestCase): 74 JON_SNOW = Person( 75 name="Jon Snow", 76 age=np.int32(23), 77 address=None, 78 aliases=["crow", "wildling"], 79 knows_javascript=False, 80 payload=None, 81 custom_metadata={}, 82 favorite_time=Timestamp.from_rfc3339('2016-03-18T23:22:59.123456Z'), 83 ) 84 PEOPLE = [ 85 JON_SNOW, 86 Person( 87 "Daenerys Targaryen", 88 np.int32(25), 89 "Westeros", 90 ["Mother of Dragons"], 91 False, 92 None, 93 {"dragons": 3}, 94 Timestamp.from_rfc3339('1970-04-26T17:46:40Z'), 95 ), 96 Person( 97 "Michael Bluth", 98 np.int32(30), 99 None, [], 100 True, 101 b"I've made a huge mistake", {}, 102 Timestamp.from_rfc3339('2020-08-12T15:51:00.032Z')) 103 ] 104 105 def test_row_accepts_trailing_zeros_truncated(self): 106 expected_coder = RowCoder( 107 typing_to_runner_api(NullablePerson).row_type.schema) 108 person = NullablePerson( 109 None, 110 np.int32(25), 111 "Westeros", ["Mother of Dragons"], 112 False, 113 None, {"dragons": 3}, 114 None, 115 "NotNull") 116 out = expected_coder.encode(person) 117 # 9 fields, 1 null byte, field 0, 5, 7 are null 118 new_payload = bytes([9, 1, 1 | 1 << 5 | 1 << 7]) + out[4:] 119 new_value = expected_coder.decode(new_payload) 120 self.assertEqual(person, new_value) 121 122 def test_create_row_coder_from_named_tuple(self): 123 expected_coder = RowCoder(typing_to_runner_api(Person).row_type.schema) 124 real_coder = coders_registry.get_coder(Person) 125 126 for test_case in self.PEOPLE: 127 self.assertEqual( 128 expected_coder.encode(test_case), real_coder.encode(test_case)) 129 130 self.assertEqual( 131 test_case, real_coder.decode(real_coder.encode(test_case))) 132 133 def test_create_row_coder_from_nested_named_tuple(self): 134 expected_coder = RowCoder(typing_to_runner_api(People).row_type.schema) 135 real_coder = coders_registry.get_coder(People) 136 137 for primary in self.PEOPLE: 138 for other in self.PEOPLE + [None]: 139 test_case = People(primary=primary, partner=other) 140 self.assertEqual( 141 expected_coder.encode(test_case), real_coder.encode(test_case)) 142 143 self.assertEqual( 144 test_case, real_coder.decode(real_coder.encode(test_case))) 145 146 def test_create_row_coder_from_schema(self): 147 schema = schema_pb2.Schema( 148 id="person", 149 fields=[ 150 schema_pb2.Field( 151 name="name", 152 type=schema_pb2.FieldType(atomic_type=schema_pb2.STRING)), 153 schema_pb2.Field( 154 name="age", 155 type=schema_pb2.FieldType(atomic_type=schema_pb2.INT32)), 156 schema_pb2.Field( 157 name="address", 158 type=schema_pb2.FieldType( 159 atomic_type=schema_pb2.STRING, nullable=True)), 160 schema_pb2.Field( 161 name="aliases", 162 type=schema_pb2.FieldType( 163 array_type=schema_pb2.ArrayType( 164 element_type=schema_pb2.FieldType( 165 atomic_type=schema_pb2.STRING)))), 166 schema_pb2.Field( 167 name="knows_javascript", 168 type=schema_pb2.FieldType(atomic_type=schema_pb2.BOOLEAN)), 169 schema_pb2.Field( 170 name="payload", 171 type=schema_pb2.FieldType( 172 atomic_type=schema_pb2.BYTES, nullable=True)), 173 schema_pb2.Field( 174 name="custom_metadata", 175 type=schema_pb2.FieldType( 176 map_type=schema_pb2.MapType( 177 key_type=schema_pb2.FieldType( 178 atomic_type=schema_pb2.STRING), 179 value_type=schema_pb2.FieldType( 180 atomic_type=schema_pb2.INT64), 181 ))), 182 schema_pb2.Field( 183 name="favorite_time", 184 type=schema_pb2.FieldType( 185 logical_type=schema_pb2.LogicalType( 186 urn="beam:logical_type:micros_instant:v1", 187 representation=schema_pb2.FieldType( 188 row_type=schema_pb2.RowType( 189 schema=schema_pb2.Schema( 190 id="micros_instant", 191 fields=[ 192 schema_pb2.Field( 193 name="seconds", 194 type=schema_pb2.FieldType( 195 atomic_type=schema_pb2.INT64)), 196 schema_pb2.Field( 197 name="micros", 198 type=schema_pb2.FieldType( 199 atomic_type=schema_pb2.INT64)), 200 ])))))), 201 ]) 202 coder = RowCoder(schema) 203 204 for test_case in self.PEOPLE: 205 self.assertEqual(test_case, coder.decode(coder.encode(test_case))) 206 207 @unittest.skip( 208 "https://github.com/apache/beam/issues/19696 - Overflow behavior in " 209 "VarIntCoder is currently inconsistent") 210 def test_overflows(self): 211 IntTester = typing.NamedTuple( 212 'IntTester', 213 [ 214 # TODO(https://github.com/apache/beam/issues/19815): Test int8 and 215 # int16 here as well when those types are supported 216 # ('i8', typing.Optional[np.int8]), 217 # ('i16', typing.Optional[np.int16]), 218 ('i32', typing.Optional[np.int32]), 219 ('i64', typing.Optional[np.int64]), 220 ]) 221 222 c = RowCoder.from_type_hint(IntTester, None) 223 224 no_overflow = chain( 225 (IntTester(i32=i, i64=None) for i in (-2**31, 2**31 - 1)), 226 (IntTester(i32=None, i64=i) for i in (-2**63, 2**63 - 1)), 227 ) 228 229 # Encode max/min ints to make sure they don't throw any error 230 for case in no_overflow: 231 c.encode(case) 232 233 overflow = chain( 234 (IntTester(i32=i, i64=None) for i in (-2**31 - 1, 2**31)), 235 (IntTester(i32=None, i64=i) for i in (-2**63 - 1, 2**63)), 236 ) 237 238 # Encode max+1/min-1 ints to make sure they DO throw an error 239 # pylint: disable=cell-var-from-loop 240 for case in overflow: 241 self.assertRaises(OverflowError, lambda: c.encode(case)) 242 243 def test_none_in_non_nullable_field_throws(self): 244 Test = typing.NamedTuple('Test', [('foo', str)]) 245 246 c = RowCoder.from_type_hint(Test, None) 247 self.assertRaises(ValueError, lambda: c.encode(Test(foo=None))) 248 249 def test_schema_remove_column(self): 250 fields = [("field1", str), ("field2", str)] 251 # new schema is missing one field that was in the old schema 252 Old = typing.NamedTuple('Old', fields) 253 New = typing.NamedTuple('New', fields[:-1]) 254 255 old_coder = RowCoder.from_type_hint(Old, None) 256 new_coder = RowCoder.from_type_hint(New, None) 257 258 self.assertEqual( 259 New("foo"), new_coder.decode(old_coder.encode(Old("foo", "bar")))) 260 261 def test_schema_add_column(self): 262 fields = [("field1", str), ("field2", typing.Optional[str])] 263 # new schema has one (optional) field that didn't exist in the old schema 264 Old = typing.NamedTuple('Old', fields[:-1]) 265 New = typing.NamedTuple('New', fields) 266 267 old_coder = RowCoder.from_type_hint(Old, None) 268 new_coder = RowCoder.from_type_hint(New, None) 269 270 self.assertEqual( 271 New("bar", None), new_coder.decode(old_coder.encode(Old("bar")))) 272 273 def test_schema_add_column_with_null_value(self): 274 fields = [("field1", typing.Optional[str]), ("field2", str), 275 ("field3", typing.Optional[str])] 276 # new schema has one (optional) field that didn't exist in the old schema 277 Old = typing.NamedTuple('Old', fields[:-1]) 278 New = typing.NamedTuple('New', fields) 279 280 old_coder = RowCoder.from_type_hint(Old, None) 281 new_coder = RowCoder.from_type_hint(New, None) 282 283 self.assertEqual( 284 New(None, "baz", None), 285 new_coder.decode(old_coder.encode(Old(None, "baz")))) 286 287 def test_row_coder_picklable(self): 288 # occasionally coders can get pickled, RowCoder should be able to handle it 289 coder = coders_registry.get_coder(Person) 290 roundtripped = pickler.loads(pickler.dumps(coder)) 291 292 self.assertEqual(roundtripped, coder) 293 294 def test_row_coder_in_pipeine(self): 295 with TestPipeline() as p: 296 res = ( 297 p 298 | beam.Create(self.PEOPLE) 299 | beam.Filter(lambda person: person.name == "Jon Snow")) 300 assert_that(res, equal_to([self.JON_SNOW])) 301 302 def test_row_coder_nested_struct(self): 303 Pair = typing.NamedTuple('Pair', [('left', Person), ('right', Person)]) 304 305 value = Pair(self.PEOPLE[0], self.PEOPLE[1]) 306 coder = RowCoder(typing_to_runner_api(Pair).row_type.schema) 307 308 self.assertEqual(value, coder.decode(coder.encode(value))) 309 310 def test_encoding_position_reorder_fields(self): 311 schema1 = schema_pb2.Schema( 312 id="reorder_test_schema1", 313 fields=[ 314 schema_pb2.Field( 315 name="f_int32", 316 type=schema_pb2.FieldType(atomic_type=schema_pb2.INT32), 317 ), 318 schema_pb2.Field( 319 name="f_str", 320 type=schema_pb2.FieldType(atomic_type=schema_pb2.STRING), 321 ), 322 ]) 323 schema2 = schema_pb2.Schema( 324 id="reorder_test_schema2", 325 encoding_positions_set=True, 326 fields=[ 327 schema_pb2.Field( 328 name="f_str", 329 type=schema_pb2.FieldType(atomic_type=schema_pb2.STRING), 330 encoding_position=1, 331 ), 332 schema_pb2.Field( 333 name="f_int32", 334 type=schema_pb2.FieldType(atomic_type=schema_pb2.INT32), 335 encoding_position=0, 336 ), 337 ]) 338 339 RowSchema1 = named_tuple_from_schema(schema1) 340 RowSchema2 = named_tuple_from_schema(schema2) 341 roundtripped = RowCoder(schema2).decode( 342 RowCoder(schema1).encode(RowSchema1(42, "Hello World!"))) 343 344 self.assertEqual(RowSchema2(f_int32=42, f_str="Hello World!"), roundtripped) 345 346 def test_encoding_position_add_fields_and_reorder(self): 347 old_schema = schema_pb2.Schema( 348 id="add_test_old", 349 fields=[ 350 schema_pb2.Field( 351 name="f_int32", 352 type=schema_pb2.FieldType(atomic_type=schema_pb2.INT32), 353 ), 354 schema_pb2.Field( 355 name="f_str", 356 type=schema_pb2.FieldType(atomic_type=schema_pb2.STRING), 357 ), 358 ]) 359 new_schema = schema_pb2.Schema( 360 encoding_positions_set=True, 361 id="add_test_new", 362 fields=[ 363 schema_pb2.Field( 364 name="f_new_str", 365 type=schema_pb2.FieldType( 366 atomic_type=schema_pb2.STRING, nullable=True), 367 encoding_position=2, 368 ), 369 schema_pb2.Field( 370 name="f_int32", 371 type=schema_pb2.FieldType(atomic_type=schema_pb2.INT32), 372 encoding_position=0, 373 ), 374 schema_pb2.Field( 375 name="f_str", 376 type=schema_pb2.FieldType(atomic_type=schema_pb2.STRING), 377 encoding_position=1, 378 ), 379 ]) 380 381 Old = named_tuple_from_schema(old_schema) 382 New = named_tuple_from_schema(new_schema) 383 roundtripped = RowCoder(new_schema).decode( 384 RowCoder(old_schema).encode(Old(42, "Hello World!"))) 385 386 self.assertEqual( 387 New(f_new_str=None, f_int32=42, f_str="Hello World!"), roundtripped) 388 389 def test_row_coder_fail_early_bad_schema(self): 390 schema_proto = schema_pb2.Schema( 391 fields=[ 392 schema_pb2.Field( 393 name="type_with_no_typeinfo", type=schema_pb2.FieldType()) 394 ], 395 id='bad-schema') 396 397 # Should raise an exception referencing the problem field 398 self.assertRaisesRegex( 399 ValueError, "type_with_no_typeinfo", lambda: RowCoder(schema_proto)) 400 401 def test_row_coder_cloud_object_schema(self): 402 schema_proto = schema_pb2.Schema(id='some-cloud-object-schema') 403 schema_proto_json = json_format.MessageToJson(schema_proto).encode('utf-8') 404 405 coder = RowCoder(schema_proto) 406 407 cloud_object = coder.as_cloud_object() 408 409 self.assertEqual(schema_proto_json, cloud_object['schema']) 410 411 def test_batch_encode_decode(self): 412 coder = RowCoder(typing_to_runner_api(Person).row_type.schema).get_impl() 413 seq_out = coder_impl.create_OutputStream() 414 for person in self.PEOPLE: 415 coder.encode_to_stream(person, seq_out, False) 416 417 batch_out = coder_impl.create_OutputStream() 418 columnar = { 419 field: np.array([getattr(person, field) for person in self.PEOPLE], 420 ndmin=1, 421 dtype=object) 422 for field in Person._fields 423 } 424 coder.encode_batch_to_stream(columnar, batch_out) 425 if seq_out.get() != batch_out.get(): 426 a, b = seq_out.get(), batch_out.get() 427 N = 25 428 for k in range(0, max(len(a), len(b)), N): 429 print(k, a[k:k + N] == b[k:k + N]) 430 print(a[k:k + N]) 431 print(b[k:k + N]) 432 self.assertEqual(seq_out.get(), batch_out.get()) 433 434 for size in [len(self.PEOPLE) - 1, len(self.PEOPLE), len(self.PEOPLE) + 1]: 435 dest = { 436 field: np.ndarray((size, ), dtype=a.dtype) 437 for field, 438 a in columnar.items() 439 } 440 n = min(size, len(self.PEOPLE)) 441 self.assertEqual( 442 n, 443 coder.decode_batch_from_stream( 444 dest, coder_impl.create_InputStream(seq_out.get()))) 445 for field, a in columnar.items(): 446 assert_array_equal(a[:n], dest[field][:n]) 447 448 449 if __name__ == "__main__": 450 logging.getLogger().setLevel(logging.INFO) 451 unittest.main()