#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Tests common to all coder implementations."""
# pytype: skip-file

import base64
import collections
import enum
import logging
import math
import unittest
from decimal import Decimal
from typing import Any
from typing import List
from typing import NamedTuple

import pytest

from apache_beam.coders import proto2_coder_test_messages_pb2 as test_message
from apache_beam.coders import coders
from apache_beam.coders import typecoders
from apache_beam.internal import pickler
from apache_beam.runners import pipeline_context
from apache_beam.transforms import userstate
from apache_beam.transforms import window
from apache_beam.transforms.window import GlobalWindow
from apache_beam.typehints import sharded_key_type
from apache_beam.typehints import typehints
from apache_beam.utils import timestamp
from apache_beam.utils import windowed_value
from apache_beam.utils.sharded_key import ShardedKey
from apache_beam.utils.timestamp import MIN_TIMESTAMP

from . import observable

try:
  import dataclasses
except ImportError:
  dataclasses = None  # type: ignore

MyNamedTuple = collections.namedtuple('A', ['x', 'y'])
MyTypedNamedTuple = NamedTuple('MyTypedNamedTuple', [('f1', int), ('f2', str)])


class MyEnum(enum.Enum):
  """Enum mixing explicit int, auto() and str values."""
  E1 = 5
  E2 = enum.auto()
  E3 = 'abc'


MyIntEnum = enum.IntEnum('MyIntEnum', 'I1 I2 I3')
MyIntFlag = enum.IntFlag('MyIntFlag', 'F1 F2 F3')
MyFlag = enum.Flag('MyFlag', 'F1 F2 F3')  # pylint: disable=too-many-function-args


class DefinesGetState:
  """Value holder that defines __getstate__ but not __setstate__."""
  def __init__(self, value):
    self.value = value

  def __getstate__(self):
    return self.value

  def __eq__(self, other):
    # Exact type match: a subclass instance never equals a base instance.
    return type(other) is type(self) and other.value == self.value


class DefinesGetAndSetState(DefinesGetState):
  """Value holder that defines both __getstate__ and __setstate__."""
  def __setstate__(self, value):
    self.value = value


# Defined out of line for picklability.
class CustomCoder(coders.Coder):
  """Coder for ints: encodes x as the UTF-8 decimal text of x + 1."""
  def encode(self, x):
    return str(x + 1).encode('utf-8')

  def decode(self, encoded):
    return int(encoded) - 1


if dataclasses is not None:

  @dataclasses.dataclass(frozen=True)
  class FrozenDataClass:
    a: Any
    b: int

  @dataclasses.dataclass
  class UnFrozenDataClass:
    x: int
    y: int


# These tests need to all be run in the same process due to the asserts
# in tearDownClass.
@pytest.mark.no_xdist
class CodersTest(unittest.TestCase):
  """Round-trip tests for the coders defined in apache_beam.coders.coders.

  Every coder passed to check_coder (or _observe) is recorded, and
  tearDownClass fails if a coder class defined in the coders module was
  never observed.
  """

  # These class methods ensure that we test each defined coder in both
  # nested and unnested context.

  # Common test values representing Python's built-in types.
  test_values_deterministic: List[Any] = [
      None,
      1,
      -1,
      1.5,
      b'str\0str',
      u'unicode\0\u0101',
      (),
      (1, 2, 3),
      [],
      [1, 2, 3],
      True,
      False,
  ]
  test_values = test_values_deterministic + [
      {},
      {
          'a': 'b'
      },
      {
          0: {}, 1: len
      },
      set(),
      {'a', 'b'},
      len,
  ]

  @classmethod
  def setUpClass(cls):
    """Resets the sets of coder types observed by this test class."""
    cls.seen = set()
    cls.seen_nested = set()

  @classmethod
  def tearDownClass(cls):
    """Verifies that every standard coder class was exercised.

    Raises:
      AssertionError: if a coder class from the coders module (minus the
        exclusions below) was never observed, or if a coder type seen inside
        a TupleCoder is not one of those standard classes.
    """
    standard = {
        c
        for c in coders.__dict__.values() if isinstance(c, type) and
        issubclass(c, coders.Coder) and 'Base' not in c.__name__
    }
    standard -= {
        coders.Coder,
        coders.AvroGenericCoder,
        coders.DeterministicProtoCoder,
        coders.FastCoder,
        coders.ListLikeCoder,
        coders.ProtoCoder,
        coders.ProtoPlusCoder,
        coders.BigEndianShortCoder,
        coders.SinglePrecisionFloatCoder,
        coders.ToBytesCoder,
        coders.BigIntegerCoder,  # tested in DecimalCoder
    }
    cls.seen_nested -= {coders.ProtoCoder, coders.ProtoPlusCoder, CustomCoder}
    # Raise explicitly (rather than `assert`) so the coverage check is not
    # stripped when Python runs with -O.
    untested = standard - cls.seen
    if untested:
      raise AssertionError(str(untested))
    unexpected_nested = cls.seen_nested - standard
    if unexpected_nested:
      raise AssertionError(str(unexpected_nested))

  @classmethod
  def _observe(cls, coder):
    """Records the type of coder and of any coders nested in a TupleCoder."""
    cls.seen.add(type(coder))
    cls._observe_nested(coder)

  @classmethod
  def _observe_nested(cls, coder):
    """Recursively records component coder types of TupleCoders."""
    if isinstance(coder, coders.TupleCoder):
      for c in coder.coders():
        cls.seen_nested.add(type(c))
        cls._observe_nested(c)

  def check_coder(
      self, coder, *values, context=None, test_size_estimation=True):
    """Checks that coder round-trips each value, directly and via copies.

    Args:
      coder: the coder under test; its type is recorded via _observe.
      *values: values that must survive encode followed by decode.
      context: pipeline context used for the runner API round trip; a new
        PipelineContext is created when omitted.
      test_size_estimation: whether to also check that the size estimates
        of the coder and its impl equal the encoded length.
    """
    if context is None:
      context = pipeline_context.PipelineContext()
    self._observe(coder)
    for v in values:
      self.assertEqual(v, coder.decode(coder.encode(v)))
      if test_size_estimation:
        self.assertEqual(coder.estimate_size(v), len(coder.encode(v)))
        self.assertEqual(
            coder.estimate_size(v), coder.get_impl().estimate_size(v))
        self.assertEqual(
            coder.get_impl().get_estimated_size_and_observables(v),
            (coder.get_impl().estimate_size(v), []))
    # Two independent copies: one through pickling, one through the runner
    # API proto. Values encoded by one must be decodable by the other.
    copy1 = pickler.loads(pickler.dumps(coder))
    copy2 = coders.Coder.from_runner_api(coder.to_runner_api(context), context)
    for v in values:
      self.assertEqual(v, copy1.decode(copy2.encode(v)))
      if coder.is_deterministic():
        self.assertEqual(copy1.encode(v), copy2.encode(v))

  def test_custom_coder(self):

    self.check_coder(CustomCoder(), 1, -10, 5)
    self.check_coder(
        coders.TupleCoder((CustomCoder(), coders.BytesCoder())), (1, b'a'),
        (-10, b'b'), (5, b'c'))

  def test_pickle_coder(self):
    coder = coders.PickleCoder()
    self.check_coder(coder, *self.test_values)

  def test_memoizing_pickle_coder(self):
    coder = coders._MemoizingPickleCoder()
    self.check_coder(coder, *self.test_values)

  def test_deterministic_coder(self):
    coder = coders.FastPrimitivesCoder()
    deterministic_coder = coders.DeterministicFastPrimitivesCoder(coder, 'step')
    self.check_coder(deterministic_coder, *self.test_values_deterministic)
    for v in self.test_values_deterministic:
      self.check_coder(coders.TupleCoder((deterministic_coder, )), (v, ))
    self.check_coder(
        coders.TupleCoder(
            (deterministic_coder, ) * len(self.test_values_deterministic)),
        tuple(self.test_values_deterministic))

    self.check_coder(deterministic_coder, {})
    self.check_coder(deterministic_coder, {2: 'x', 1: 'y'})
    # Dicts whose keys are of mixed types are expected to be rejected.
    with self.assertRaises(TypeError):
      self.check_coder(deterministic_coder, {1: 'x', 'y': 2})
    self.check_coder(deterministic_coder, [1, {}])
    with self.assertRaises(TypeError):
      self.check_coder(deterministic_coder, [1, {1: 'x', 'y': 2}])

    self.check_coder(
        coders.TupleCoder((deterministic_coder, coder)), (1, {}), ('a', [{}]))

    self.check_coder(deterministic_coder, test_message.MessageA(field1='value'))

    self.check_coder(
        deterministic_coder, [MyNamedTuple(1, 2), MyTypedNamedTuple(1, 'a')])

    if dataclasses is not None:
      self.check_coder(deterministic_coder, FrozenDataClass(1, 2))

      # Non-frozen dataclasses are expected to be rejected, also when nested.
      with self.assertRaises(TypeError):
        self.check_coder(deterministic_coder, UnFrozenDataClass(1, 2))
      with self.assertRaises(TypeError):
        self.check_coder(
            deterministic_coder, FrozenDataClass(UnFrozenDataClass(1, 2), 3))
      with self.assertRaises(TypeError):
        self.check_coder(
            deterministic_coder, MyNamedTuple(UnFrozenDataClass(1, 2), 3))

    self.check_coder(deterministic_coder, list(MyEnum))
    self.check_coder(deterministic_coder, list(MyIntEnum))
    self.check_coder(deterministic_coder, list(MyIntFlag))
    self.check_coder(deterministic_coder, list(MyFlag))

    self.check_coder(
        deterministic_coder,
        [DefinesGetAndSetState(1), DefinesGetAndSetState((1, 2, 3))])

    # __getstate__ without __setstate__ is expected to be rejected.
    with self.assertRaises(TypeError):
      self.check_coder(deterministic_coder, DefinesGetState(1))
    with self.assertRaises(TypeError):
      self.check_coder(
          deterministic_coder, DefinesGetAndSetState({
              1: 'x', 'y': 2
          }))

  def test_dill_coder(self):
    # A closure cell object, taken from a lambda capturing the value 0.
    cell_value = (lambda x: lambda: x)(0).__closure__[0]
    self.check_coder(coders.DillCoder(), 'a', 1, cell_value)
    self.check_coder(
        coders.TupleCoder((coders.VarIntCoder(), coders.DillCoder())),
        (1, cell_value))

  def test_fast_primitives_coder(self):
    coder = coders.FastPrimitivesCoder(coders.SingletonCoder(len))
    self.check_coder(coder, *self.test_values)
    for v in self.test_values:
      self.check_coder(coders.TupleCoder((coder, )), (v, ))

  def test_fast_primitives_coder_large_int(self):
    coder = coders.FastPrimitivesCoder()
    self.check_coder(coder, 10**100)

  def test_fake_deterministic_fast_primitives_coder(self):
    coder = coders.FakeDeterministicFastPrimitivesCoder(coders.PickleCoder())
    self.check_coder(coder, *self.test_values)
    for v in self.test_values:
      self.check_coder(coders.TupleCoder((coder, )), (v, ))

  def test_bytes_coder(self):
    self.check_coder(coders.BytesCoder(), b'a', b'\0', b'z' * 1000)

  def test_bool_coder(self):
    self.check_coder(coders.BooleanCoder(), True, False)

  def test_varint_coder(self):
    # Small ints.
    self.check_coder(coders.VarIntCoder(), *range(-10, 10))
    # Multi-byte encoding starts at 128
    self.check_coder(coders.VarIntCoder(), *range(120, 140))
    # Large values
    MAX_64_BIT_INT = 0x7fffffffffffffff
    self.check_coder(
        coders.VarIntCoder(),
        *[
            int(math.pow(-1, k) * math.exp(k))
            for k in range(0, int(math.log(MAX_64_BIT_INT)))
        ])

  def test_float_coder(self):
    self.check_coder(
        coders.FloatCoder(), *[float(0.1 * x) for x in range(-100, 100)])
    self.check_coder(
        coders.FloatCoder(), *[float(2**(0.1 * x)) for x in range(-100, 100)])
    self.check_coder(coders.FloatCoder(), float('-Inf'), float('Inf'))
    self.check_coder(
        coders.TupleCoder((coders.FloatCoder(), coders.FloatCoder())), (0, 1),
        (-100, 100), (0.5, 0.25))

  def test_singleton_coder(self):
    a = 'anything'
    b = 'something else'
    self.check_coder(coders.SingletonCoder(a), a)
    self.check_coder(coders.SingletonCoder(b), b)
    self.check_coder(
        coders.TupleCoder((coders.SingletonCoder(a), coders.SingletonCoder(b))),
        (a, b))

  def test_interval_window_coder(self):
    self.check_coder(
        coders.IntervalWindowCoder(),
        *[
            window.IntervalWindow(x, y) for x in [-2**52, 0, 2**52]
            for y in range(-100, 100)
        ])
    self.check_coder(
        coders.TupleCoder((coders.IntervalWindowCoder(), )),
        (window.IntervalWindow(0, 10), ))

  def test_timestamp_coder(self):
    self.check_coder(
        coders.TimestampCoder(),
        *[timestamp.Timestamp(micros=x) for x in (-1000, 0, 1000)])
    self.check_coder(
        coders.TimestampCoder(),
        timestamp.Timestamp(micros=-1234567000),
        timestamp.Timestamp(micros=1234567000))
    self.check_coder(
        coders.TimestampCoder(),
        timestamp.Timestamp(micros=-1234567890123456000),
        timestamp.Timestamp(micros=1234567890123456000))
    self.check_coder(
        coders.TupleCoder((coders.TimestampCoder(), coders.BytesCoder())),
        (timestamp.Timestamp.of(27), b'abc'))

  def test_timer_coder(self):
    # One cleared timer without timestamps, one set timer with all fields.
    self.check_coder(
        coders._TimerCoder(coders.StrUtf8Coder(), coders.GlobalWindowCoder()),
        *[
            userstate.Timer(
                user_key="key",
                dynamic_timer_tag="tag",
                windows=(GlobalWindow(), ),
                clear_bit=True,
                fire_timestamp=None,
                hold_timestamp=None,
                paneinfo=None),
            userstate.Timer(
                user_key="key",
                dynamic_timer_tag="tag",
                windows=(GlobalWindow(), ),
                clear_bit=False,
                fire_timestamp=timestamp.Timestamp.of(123),
                hold_timestamp=timestamp.Timestamp.of(456),
                paneinfo=windowed_value.PANE_INFO_UNKNOWN)
        ])

  def test_tuple_coder(self):
    kv_coder = coders.TupleCoder((coders.VarIntCoder(), coders.BytesCoder()))
    # Verify cloud object representation
    self.assertEqual({
        '@type': 'kind:pair',
        'is_pair_like': True,
        'component_encodings': [
            coders.VarIntCoder().as_cloud_object(),
            coders.BytesCoder().as_cloud_object()
        ],
    },
                     kv_coder.as_cloud_object())
    # Test binary representation
    self.assertEqual(b'\x04abc', kv_coder.encode((4, b'abc')))
    # Test unnested
    self.check_coder(kv_coder, (1, b'a'), (-2, b'a' * 100), (300, b'abc\0' * 5))
    # Test nested
    self.check_coder(
        coders.TupleCoder((
            coders.TupleCoder((coders.PickleCoder(), coders.VarIntCoder())),
            coders.StrUtf8Coder(),
            coders.BooleanCoder())), ((1, 2), 'a', True),
        ((-2, 5), u'a\u0101' * 100, False), ((300, 1), 'abc\0' * 5, True))

  def test_tuple_sequence_coder(self):
    int_tuple_coder = coders.TupleSequenceCoder(coders.VarIntCoder())
    self.check_coder(int_tuple_coder, (1, -1, 0), (), tuple(range(1000)))
    self.check_coder(
        coders.TupleCoder((coders.VarIntCoder(), int_tuple_coder)),
        (1, (1, 2, 3)))

  def test_base64_pickle_coder(self):
    self.check_coder(coders.Base64PickleCoder(), 'a', 1, 1.5, (1, 2, 3))

  def test_utf8_coder(self):
    self.check_coder(coders.StrUtf8Coder(), 'a', u'ab\u00FF', u'\u0101\0')

  def test_iterable_coder(self):
    iterable_coder = coders.IterableCoder(coders.VarIntCoder())
    # Verify cloud object representation
    self.assertEqual({
        '@type': 'kind:stream',
        'is_stream_like': True,
        'component_encodings': [coders.VarIntCoder().as_cloud_object()]
    },
                     iterable_coder.as_cloud_object())
    # Test unnested
    self.check_coder(iterable_coder, [1], [-1, 0, 100])
    # Test nested
    self.check_coder(
        coders.TupleCoder(
            (coders.VarIntCoder(), coders.IterableCoder(coders.VarIntCoder()))),
        (1, [1, 2, 3]))

  def test_iterable_coder_unknown_length(self):
    # Empty
    self._test_iterable_coder_of_unknown_length(0)
    # Single element
    self._test_iterable_coder_of_unknown_length(1)
    # Multiple elements
    self._test_iterable_coder_of_unknown_length(100)
    # Multiple elements with underlying stream buffer overflow.
    self._test_iterable_coder_of_unknown_length(80000)

  def _test_iterable_coder_of_unknown_length(self, count):
    """Round-trips a generator of `count` ints through an IterableCoder.

    A generator is used so that the iterable has no len().
    """
    def iter_generator(count):
      yield from range(count)

    iterable_coder = coders.IterableCoder(coders.VarIntCoder())
    self.assertCountEqual(
        list(iter_generator(count)),
        iterable_coder.decode(iterable_coder.encode(iter_generator(count))))

  def test_list_coder(self):
    list_coder = coders.ListCoder(coders.VarIntCoder())
    # Test unnested
    self.check_coder(list_coder, [1], [-1, 0, 100])
    # Test nested
    self.check_coder(
        coders.TupleCoder((coders.VarIntCoder(), list_coder)), (1, [1, 2, 3]))

  def test_windowedvalue_coder_paneinfo(self):
    coder = coders.WindowedValueCoder(
        coders.VarIntCoder(), coders.GlobalWindowCoder())
    test_paneinfo_values = [
        windowed_value.PANE_INFO_UNKNOWN,
        windowed_value.PaneInfo(
            True, True, windowed_value.PaneInfoTiming.EARLY, 0, -1),
        windowed_value.PaneInfo(
            True, False, windowed_value.PaneInfoTiming.ON_TIME, 0, 0),
        windowed_value.PaneInfo(
            True, False, windowed_value.PaneInfoTiming.ON_TIME, 10, 0),
        windowed_value.PaneInfo(
            False, True, windowed_value.PaneInfoTiming.ON_TIME, 0, 23),
        windowed_value.PaneInfo(
            False, True, windowed_value.PaneInfoTiming.ON_TIME, 12, 23),
        windowed_value.PaneInfo(
            False, False, windowed_value.PaneInfoTiming.LATE, 0, 123),
    ]

    test_values = [
        windowed_value.WindowedValue(123, 234, (GlobalWindow(), ), p)
        for p in test_paneinfo_values
    ]

    # Test unnested.
    self.check_coder(
        coder,
        windowed_value.WindowedValue(
            123, 234, (GlobalWindow(), ), windowed_value.PANE_INFO_UNKNOWN))
    for value in test_values:
      self.check_coder(coder, value)

    # Test nested.
    for value1 in test_values:
      for value2 in test_values:
        self.check_coder(coders.TupleCoder((coder, coder)), (value1, value2))

  def test_windowed_value_coder(self):
    coder = coders.WindowedValueCoder(
        coders.VarIntCoder(), coders.GlobalWindowCoder())
    # Verify cloud object representation
    self.assertEqual({
        '@type': 'kind:windowed_value',
        'is_wrapper': True,
        'component_encodings': [
            coders.VarIntCoder().as_cloud_object(),
            coders.GlobalWindowCoder().as_cloud_object(),
        ],
    },
                     coder.as_cloud_object())
    # Test binary representation
    self.assertEqual(
        b'\x7f\xdf;dZ\x1c\xac\t\x00\x00\x00\x01\x0f\x01',
        coder.encode(window.GlobalWindows.windowed_value(1)))

    # Test decoding large timestamp
    self.assertEqual(
        coder.decode(b'\x7f\xdf;dZ\x1c\xac\x08\x00\x00\x00\x01\x0f\x00'),
        windowed_value.create(0, MIN_TIMESTAMP.micros, (GlobalWindow(), )))

    # Test unnested
    self.check_coder(
        coders.WindowedValueCoder(coders.VarIntCoder()),
        windowed_value.WindowedValue(3, -100, ()),
        windowed_value.WindowedValue(-1, 100, (1, 2, 3)))

    # Test Global Window
    self.check_coder(
        coders.WindowedValueCoder(
            coders.VarIntCoder(), coders.GlobalWindowCoder()),
        window.GlobalWindows.windowed_value(1))

    # Test nested
    self.check_coder(
        coders.TupleCoder((
            coders.WindowedValueCoder(coders.FloatCoder()),
            coders.WindowedValueCoder(coders.StrUtf8Coder()))),
        (
            windowed_value.WindowedValue(1.5, 0, ()),
            windowed_value.WindowedValue("abc", 10, ('window', ))))

  def test_param_windowed_value_coder(self):
    from apache_beam.transforms.window import IntervalWindow
    from apache_beam.utils.windowed_value import PaneInfo
    wv = windowed_value.create(
        b'',
        # Milliseconds to microseconds
        1000 * 1000,
        (IntervalWindow(11, 21), ),
        PaneInfo(True, False, 1, 2, 3))
    windowed_value_coder = coders.WindowedValueCoder(
        coders.BytesCoder(), coders.IntervalWindowCoder())
    # The encoded windowed value is the payload that parameterizes the coder.
    payload = windowed_value_coder.encode(wv)
    coder = coders.ParamWindowedValueCoder(
        payload, [coders.VarIntCoder(), coders.IntervalWindowCoder()])

    # Test binary representation
    self.assertEqual(
        b'\x01', coder.encode(window.GlobalWindows.windowed_value(1)))

    # Test unnested
    self.check_coder(
        coders.ParamWindowedValueCoder(
            payload, [coders.VarIntCoder(), coders.IntervalWindowCoder()]),
        windowed_value.WindowedValue(
            3,
            1, (window.IntervalWindow(11, 21), ),
            PaneInfo(True, False, 1, 2, 3)),
        windowed_value.WindowedValue(
            1,
            1, (window.IntervalWindow(11, 21), ),
            PaneInfo(True, False, 1, 2, 3)))

    # Test nested
    self.check_coder(
        coders.TupleCoder((
            coders.ParamWindowedValueCoder(
                payload, [coders.FloatCoder(), coders.IntervalWindowCoder()]),
            coders.ParamWindowedValueCoder(
                payload,
                [coders.StrUtf8Coder(), coders.IntervalWindowCoder()]))),
        (
            windowed_value.WindowedValue(
                1.5,
                1, (window.IntervalWindow(11, 21), ),
                PaneInfo(True, False, 1, 2, 3)),
            windowed_value.WindowedValue(
                "abc",
                1, (window.IntervalWindow(11, 21), ),
                PaneInfo(True, False, 1, 2, 3))))

  def test_proto_coder(self):
    # For instructions on how these test proto message were generated,
    # see coders_test.py
    ma = test_message.MessageA()
    mab = ma.field2.add()
    mab.field1 = True
    ma.field1 = u'hello world'

    mb = test_message.MessageA()
    mb.field1 = u'beam'

    proto_coder = coders.ProtoCoder(ma.__class__)
    self.check_coder(proto_coder, ma)
    self.check_coder(
        coders.TupleCoder((proto_coder, coders.BytesCoder())), (ma, b'a'),
        (mb, b'b'))

  def test_global_window_coder(self):
    coder = coders.GlobalWindowCoder()
    value = window.GlobalWindow()
    # Verify cloud object representation
    self.assertEqual({'@type': 'kind:global_window'}, coder.as_cloud_object())
    # Test binary representation
    self.assertEqual(b'', coder.encode(value))
    self.assertEqual(value, coder.decode(b''))
    # Test unnested
    self.check_coder(coder, value)
    # Test nested
    self.check_coder(coders.TupleCoder((coder, coder)), (value, value))

  def test_length_prefix_coder(self):
    coder = coders.LengthPrefixCoder(coders.BytesCoder())
    # Verify cloud object representation
    self.assertEqual({
        '@type': 'kind:length_prefix',
        'component_encodings': [coders.BytesCoder().as_cloud_object()]
    },
                     coder.as_cloud_object())
    # Test binary representation
    self.assertEqual(b'\x00', coder.encode(b''))
    self.assertEqual(b'\x01a', coder.encode(b'a'))
    self.assertEqual(b'\x02bc', coder.encode(b'bc'))
    self.assertEqual(b'\xff\x7f' + b'z' * 16383, coder.encode(b'z' * 16383))
    # Test unnested
    self.check_coder(coder, b'', b'a', b'bc', b'def')
    # Test nested
    self.check_coder(
        coders.TupleCoder((coder, coder)), (b'', b'a'), (b'bc', b'def'))

  def test_nested_observables(self):
    class FakeObservableIterator(observable.ObservableMixin):
      def __iter__(self):
        return iter([1, 2, 3])

    # Coder for elements from the observable iterator.
    elem_coder = coders.VarIntCoder()
    iter_coder = coders.TupleSequenceCoder(elem_coder)

    # Test nested WindowedValue observable.
    coder = coders.WindowedValueCoder(iter_coder)
    observ = FakeObservableIterator()
    value = windowed_value.WindowedValue(observ, 0, ())
    self.assertEqual(
        coder.get_impl().get_estimated_size_and_observables(value)[1],
        [(observ, elem_coder.get_impl())])

    # Test nested tuple observable.
    coder = coders.TupleCoder((coders.StrUtf8Coder(), iter_coder))
    value = (u'123', observ)
    self.assertEqual(
        coder.get_impl().get_estimated_size_and_observables(value)[1],
        [(observ, elem_coder.get_impl())])

  def test_state_backed_iterable_coder(self):
    # pylint: disable=global-variable-undefined
    # required for pickling by reference
    global state
    state = {}

    def iterable_state_write(values, element_coder_impl):
      token = b'state_token_%d' % len(state)
      state[token] = [element_coder_impl.encode(e) for e in values]
      return token

    def iterable_state_read(token, element_coder_impl):
      return [element_coder_impl.decode(s) for s in state[token]]

    coder = coders.StateBackedIterableCoder(
        coders.VarIntCoder(),
        read_state=iterable_state_read,
        write_state=iterable_state_write,
        write_state_threshold=1)
    # Note: do not use check_coder
    # see https://github.com/cloudpipe/cloudpickle/issues/452
    self._observe(coder)
    self.assertEqual([1, 2, 3], coder.decode(coder.encode([1, 2, 3])))
    # Ensure that state was actually used.
    self.assertNotEqual(state, {})
    tupleCoder = coders.TupleCoder((coder, coder))
    self._observe(tupleCoder)
    self.assertEqual(([1], [2, 3]),
                     tupleCoder.decode(tupleCoder.encode(([1], [2, 3]))))

  def test_nullable_coder(self):
    # NOTE(review): `2 * 64` evaluates to 128; possibly `2**64` was intended,
    # but that may exceed what VarIntCoder supports -- confirm before changing.
    self.check_coder(coders.NullableCoder(coders.VarIntCoder()), None, 2 * 64)

  def test_map_coder(self):
    values = [
        {1: "one", 300: "three hundred"}, # force yapf to be nice
        {},
        {i: str(i) for i in range(5000)}
    ]
    map_coder = coders.MapCoder(coders.VarIntCoder(), coders.StrUtf8Coder())
    self.check_coder(map_coder, *values)
    self.check_coder(map_coder.as_deterministic_coder("label"), *values)

  def test_sharded_key_coder(self):
    # Each entry: (key, expected encoding of the key, coder for the key).
    key_and_coders = [(b'', b'\x00', coders.BytesCoder()),
                      (b'key', b'\x03key', coders.BytesCoder()),
                      ('key', b'\03\x6b\x65\x79', coders.StrUtf8Coder()),
                      (('k', 1),
                       b'\x01\x6b\x01',
                       coders.TupleCoder(
                           (coders.StrUtf8Coder(), coders.VarIntCoder())))]

    for key, bytes_repr, key_coder in key_and_coders:
      coder = coders.ShardedKeyCoder(key_coder)
      # Verify cloud object representation
      self.assertEqual({
          '@type': 'kind:sharded_key',
          'component_encodings': [key_coder.as_cloud_object()]
      },
                       coder.as_cloud_object())

      # Test str repr
      self.assertEqual('%s' % coder, 'ShardedKeyCoder[%s]' % key_coder)

      self.assertEqual(b'\x00' + bytes_repr, coder.encode(ShardedKey(key, b'')))
      self.assertEqual(
          b'\x03123' + bytes_repr, coder.encode(ShardedKey(key, b'123')))

      # Test unnested
      self.check_coder(coder, ShardedKey(key, b''))
      self.check_coder(coder, ShardedKey(key, b'123'))

      # Test type hints
      self.assertIsInstance(
          coder.to_type_hint(), sharded_key_type.ShardedKeyTypeConstraint)
      key_type = coder.to_type_hint().key_type
      if isinstance(key_type, typehints.TupleConstraint):
        self.assertEqual(key_type.tuple_types, (type(key[0]), type(key[1])))
      else:
        self.assertEqual(key_type, type(key))
      self.assertEqual(
          coders.ShardedKeyCoder.from_type_hint(
              coder.to_type_hint(), typecoders.CoderRegistry()),
          coder)

      for other_key, _, other_key_coder in key_and_coders:
        other_coder = coders.ShardedKeyCoder(other_key_coder)
        # Test nested
        self.check_coder(
            coders.TupleCoder((coder, other_coder)),
            (ShardedKey(key, b''), ShardedKey(other_key, b'')))
        self.check_coder(
            coders.TupleCoder((coder, other_coder)),
            (ShardedKey(key, b'123'), ShardedKey(other_key, b'')))

  def test_timestamp_prefixing_window_coder(self):
    self.check_coder(
        coders.TimestampPrefixingWindowCoder(coders.IntervalWindowCoder()),
        *[
            window.IntervalWindow(x, y) for x in [-2**52, 0, 2**52]
            for y in range(-100, 100)
        ])
    self.check_coder(
        coders.TupleCoder((
            coders.TimestampPrefixingWindowCoder(
                coders.IntervalWindowCoder()), )),
        (window.IntervalWindow(0, 10), ))

  def test_decimal_coder(self):
    test_coder = coders.DecimalCoder()

    test_values = [
        Decimal("-10.5"),
        Decimal("-1"),
        Decimal(),
        Decimal("1"),
        Decimal("13.258"),
    ]

    # Expected encodings as base64 with the '=' padding stripped, in the
    # same order as test_values.
    test_encodings = ("AZc", "AP8", "AAA", "AAE", "AzPK")

    self.check_coder(test_coder, *test_values)

    # zip() would silently truncate on a length mismatch, so check first.
    self.assertEqual(len(test_values), len(test_encodings))
    for expected, value in zip(test_encodings, test_values):
      self.assertEqual(
          expected,
          base64.b64encode(test_coder.encode(value)).decode().rstrip("="))


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()