github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/coders/coder_impl.py (about) 1 # 2 # Licensed to the Apache Software Foundation (ASF) under one or more 3 # contributor license agreements. See the NOTICE file distributed with 4 # this work for additional information regarding copyright ownership. 5 # The ASF licenses this file to You under the Apache License, Version 2.0 6 # (the "License"); you may not use this file except in compliance with 7 # the License. You may obtain a copy of the License at 8 # 9 # http://www.apache.org/licenses/LICENSE-2.0 10 # 11 # Unless required by applicable law or agreed to in writing, software 12 # distributed under the License is distributed on an "AS IS" BASIS, 13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 # See the License for the specific language governing permissions and 15 # limitations under the License. 16 # 17 18 # cython: language_level=3 19 20 """Coder implementations. 21 22 The actual encode/decode implementations are split off from coders to 23 allow conditional (compiled/pure) implementations, which can be used to 24 encode many elements with minimal overhead. 25 26 This module may be optionally compiled with Cython, using the corresponding 27 coder_impl.pxd file for type hints. In particular, because CoderImpls are 28 never pickled and sent across the wire (unlike Coders themselves) the workers 29 can use compiled Impls even if the main program does not (or vice versa). 30 31 For internal use only; no backwards-compatibility guarantees. 32 """ 33 # pytype: skip-file 34 35 import decimal 36 import enum 37 import itertools 38 import json 39 import logging 40 import pickle 41 from io import BytesIO 42 from typing import TYPE_CHECKING 43 from typing import Any 44 from typing import Callable 45 from typing import Dict 46 from typing import Iterable 47 from typing import Iterator 48 from typing import List 49 from typing import Optional 50 from typing import Sequence 51 from typing import Set 52 from typing import Tuple 53 from typing import Type 54 55 import dill 56 import numpy as np 57 from fastavro import parse_schema 58 from fastavro import schemaless_reader 59 from fastavro import schemaless_writer 60 61 from apache_beam.coders import observable 62 from apache_beam.coders.avro_record import AvroRecord 63 from apache_beam.typehints.schemas import named_tuple_from_schema 64 from apache_beam.utils import proto_utils 65 from apache_beam.utils import windowed_value 66 from apache_beam.utils.sharded_key import ShardedKey 67 from apache_beam.utils.timestamp import MAX_TIMESTAMP 68 from apache_beam.utils.timestamp import MIN_TIMESTAMP 69 from apache_beam.utils.timestamp import Timestamp 70 71 try: 72 import dataclasses 73 except ImportError: 74 dataclasses = None # type: ignore 75 76 if TYPE_CHECKING: 77 import proto 78 from apache_beam.transforms import userstate 79 from apache_beam.transforms.window import IntervalWindow 80 81 try: 82 from . import stream # pylint: disable=unused-import 83 except ImportError: 84 SLOW_STREAM = True 85 else: 86 SLOW_STREAM = False 87 88 is_compiled = False 89 fits_in_64_bits = lambda x: -(1 << 63) <= x <= (1 << 63) - 1 90 91 if TYPE_CHECKING or SLOW_STREAM: 92 from .slow_stream import InputStream as create_InputStream 93 from .slow_stream import OutputStream as create_OutputStream 94 from .slow_stream import ByteCountingOutputStream 95 from .slow_stream import get_varint_size 96 97 try: 98 import cython 99 is_compiled = cython.compiled 100 except ImportError: 101 globals()['cython'] = type('fake_cython', (), {'cast': lambda typ, x: x}) 102 103 else: 104 # pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports 105 from .stream import InputStream as create_InputStream 106 from .stream import OutputStream as create_OutputStream 107 from .stream import ByteCountingOutputStream 108 from .stream import get_varint_size 109 # Make it possible to import create_InputStream and other cdef-classes 110 # from apache_beam.coders.coder_impl when Cython codepath is used. 111 globals()['create_InputStream'] = create_InputStream 112 globals()['create_OutputStream'] = create_OutputStream 113 globals()['ByteCountingOutputStream'] = ByteCountingOutputStream 114 # pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports 115 is_compiled = True 116 117 _LOGGER = logging.getLogger(__name__) 118 119 _TIME_SHIFT = 1 << 63 120 MIN_TIMESTAMP_micros = MIN_TIMESTAMP.micros 121 MAX_TIMESTAMP_micros = MAX_TIMESTAMP.micros 122 123 IterableStateReader = Callable[[bytes, 'CoderImpl'], Iterable] 124 IterableStateWriter = Callable[[Iterable, 'CoderImpl'], bytes] 125 Observables = List[Tuple[observable.ObservableMixin, 'CoderImpl']] 126 127 128 class CoderImpl(object): 129 """For internal use only; no backwards-compatibility guarantees.""" 130 def encode_to_stream(self, value, stream, nested): 131 # type: (Any, create_OutputStream, bool) -> None 132 133 """Reads object from potentially-nested encoding in stream.""" 134 raise NotImplementedError 135 136 def decode_from_stream(self, stream, nested): 137 # type: (create_InputStream, bool) -> Any 138 139 """Reads object from potentially-nested encoding in stream.""" 140 raise NotImplementedError 141 142 def encode(self, value): 143 # type: (Any) -> bytes 144 145 """Encodes an object to an unnested string.""" 146 raise NotImplementedError 147 148 def decode(self, encoded): 149 # type: (bytes) -> Any 150 151 """Decodes an object to an unnested string.""" 152 raise NotImplementedError 153 154 def encode_all(self, values): 155 # type: (Iterable[Any]) -> bytes 156 out = create_OutputStream() 157 for value in values: 158 self.encode_to_stream(value, out, True) 159 return out.get() 160 161 def decode_all(self, encoded): 162 # type: (bytes) -> Iterator[Any] 163 input_stream = create_InputStream(encoded) 164 while input_stream.size() > 0: 165 yield self.decode_from_stream(input_stream, True) 166 167 def encode_nested(self, value): 168 # type: (Any) -> bytes 169 out = create_OutputStream() 170 self.encode_to_stream(value, out, True) 171 return out.get() 172 173 def decode_nested(self, encoded): 174 # type: (bytes) -> Any 175 return self.decode_from_stream(create_InputStream(encoded), True) 176 177 def estimate_size(self, value, nested=False): 178 # type: (Any, bool) -> int 179 180 """Estimates the encoded size of the given value, in bytes.""" 181 out = ByteCountingOutputStream() 182 self.encode_to_stream(value, out, nested) 183 return out.get_count() 184 185 def _get_nested_size(self, inner_size, nested): 186 if not nested: 187 return inner_size 188 varint_size = get_varint_size(inner_size) 189 return varint_size + inner_size 190 191 def get_estimated_size_and_observables(self, value, nested=False): 192 # type: (Any, bool) -> Tuple[int, Observables] 193 194 """Returns estimated size of value along with any nested observables. 195 196 The list of nested observables is returned as a list of 2-tuples of 197 (obj, coder_impl), where obj is an instance of observable.ObservableMixin, 198 and coder_impl is the CoderImpl that can be used to encode elements sent by 199 obj to its observers. 200 201 Arguments: 202 value: the value whose encoded size is to be estimated. 203 nested: whether the value is nested. 204 205 Returns: 206 The estimated encoded size of the given value and a list of observables 207 whose elements are 2-tuples of (obj, coder_impl) as described above. 208 """ 209 return self.estimate_size(value, nested), [] 210 211 212 class SimpleCoderImpl(CoderImpl): 213 """For internal use only; no backwards-compatibility guarantees. 214 215 Subclass of CoderImpl implementing stream methods using encode/decode.""" 216 def encode_to_stream(self, value, stream, nested): 217 # type: (Any, create_OutputStream, bool) -> None 218 219 """Reads object from potentially-nested encoding in stream.""" 220 stream.write(self.encode(value), nested) 221 222 def decode_from_stream(self, stream, nested): 223 # type: (create_InputStream, bool) -> Any 224 225 """Reads object from potentially-nested encoding in stream.""" 226 return self.decode(stream.read_all(nested)) 227 228 229 class StreamCoderImpl(CoderImpl): 230 """For internal use only; no backwards-compatibility guarantees. 231 232 Subclass of CoderImpl implementing encode/decode using stream methods.""" 233 def encode(self, value): 234 # type: (Any) -> bytes 235 out = create_OutputStream() 236 self.encode_to_stream(value, out, False) 237 return out.get() 238 239 def decode(self, encoded): 240 # type: (bytes) -> Any 241 return self.decode_from_stream(create_InputStream(encoded), False) 242 243 def estimate_size(self, value, nested=False): 244 # type: (Any, bool) -> int 245 246 """Estimates the encoded size of the given value, in bytes.""" 247 out = ByteCountingOutputStream() 248 self.encode_to_stream(value, out, nested) 249 return out.get_count() 250 251 252 class CallbackCoderImpl(CoderImpl): 253 """For internal use only; no backwards-compatibility guarantees. 254 255 A CoderImpl that calls back to the _impl methods on the Coder itself. 256 257 This is the default implementation used if Coder._get_impl() 258 is not overwritten. 259 """ 260 def __init__(self, encoder, decoder, size_estimator=None): 261 self._encoder = encoder 262 self._decoder = decoder 263 self._size_estimator = size_estimator or self._default_size_estimator 264 265 def _default_size_estimator(self, value): 266 return len(self.encode(value)) 267 268 def encode_to_stream(self, value, stream, nested): 269 # type: (Any, create_OutputStream, bool) -> None 270 return stream.write(self._encoder(value), nested) 271 272 def decode_from_stream(self, stream, nested): 273 # type: (create_InputStream, bool) -> Any 274 return self._decoder(stream.read_all(nested)) 275 276 def encode(self, value): 277 return self._encoder(value) 278 279 def decode(self, encoded): 280 return self._decoder(encoded) 281 282 def estimate_size(self, value, nested=False): 283 # type: (Any, bool) -> int 284 return self._get_nested_size(self._size_estimator(value), nested) 285 286 def get_estimated_size_and_observables(self, value, nested=False): 287 # type: (Any, bool) -> Tuple[int, Observables] 288 # TODO(robertwb): Remove this once all coders are correct. 289 if isinstance(value, observable.ObservableMixin): 290 # CallbackCoderImpl can presumably encode the elements too. 291 return 1, [(value, self)] 292 293 return self.estimate_size(value, nested), [] 294 295 def __repr__(self): 296 return 'CallbackCoderImpl[encoder=%s, decoder=%s]' % ( 297 self._encoder, self._decoder) 298 299 300 class ProtoCoderImpl(SimpleCoderImpl): 301 """For internal use only; no backwards-compatibility guarantees.""" 302 def __init__(self, proto_message_type): 303 self.proto_message_type = proto_message_type 304 305 def encode(self, value): 306 return value.SerializePartialToString() 307 308 def decode(self, encoded): 309 proto_message = self.proto_message_type() 310 proto_message.ParseFromString(encoded) # This is in effect "ParsePartial". 311 return proto_message 312 313 314 class DeterministicProtoCoderImpl(ProtoCoderImpl): 315 """For internal use only; no backwards-compatibility guarantees.""" 316 def encode(self, value): 317 return value.SerializePartialToString(deterministic=True) 318 319 320 class ProtoPlusCoderImpl(SimpleCoderImpl): 321 """For internal use only; no backwards-compatibility guarantees.""" 322 def __init__(self, proto_plus_type): 323 # type: (Type[proto.Message]) -> None 324 self.proto_plus_type = proto_plus_type 325 326 def encode(self, value): 327 return value._pb.SerializePartialToString(deterministic=True) 328 329 def decode(self, value): 330 return self.proto_plus_type.deserialize(value) 331 332 333 UNKNOWN_TYPE = 0xFF 334 NONE_TYPE = 0 335 INT_TYPE = 1 336 FLOAT_TYPE = 2 337 BYTES_TYPE = 3 338 UNICODE_TYPE = 4 339 BOOL_TYPE = 9 340 LIST_TYPE = 5 341 TUPLE_TYPE = 6 342 DICT_TYPE = 7 343 SET_TYPE = 8 344 ITERABLE_LIKE_TYPE = 10 345 346 PROTO_TYPE = 100 347 DATACLASS_TYPE = 101 348 NAMED_TUPLE_TYPE = 102 349 ENUM_TYPE = 103 350 NESTED_STATE_TYPE = 104 351 352 # Types that can be encoded as iterables, but are not literally 353 # lists, etc. due to being lazy. The actual type is not preserved 354 # through encoding, only the elements. This is particularly useful 355 # for the value list types created in GroupByKey. 356 _ITERABLE_LIKE_TYPES = set() # type: Set[Type] 357 358 359 class FastPrimitivesCoderImpl(StreamCoderImpl): 360 """For internal use only; no backwards-compatibility guarantees.""" 361 def __init__( 362 self, fallback_coder_impl, requires_deterministic_step_label=None): 363 self.fallback_coder_impl = fallback_coder_impl 364 self.iterable_coder_impl = IterableCoderImpl(self) 365 self.requires_deterministic_step_label = requires_deterministic_step_label 366 self.warn_deterministic_fallback = True 367 368 @staticmethod 369 def register_iterable_like_type(t): 370 _ITERABLE_LIKE_TYPES.add(t) 371 372 def get_estimated_size_and_observables(self, value, nested=False): 373 # type: (Any, bool) -> Tuple[int, Observables] 374 if isinstance(value, observable.ObservableMixin): 375 # FastPrimitivesCoderImpl can presumably encode the elements too. 376 return 1, [(value, self)] 377 378 out = ByteCountingOutputStream() 379 self.encode_to_stream(value, out, nested) 380 return out.get_count(), [] 381 382 def encode_to_stream(self, value, stream, nested): 383 # type: (Any, create_OutputStream, bool) -> None 384 t = type(value) 385 if value is None: 386 stream.write_byte(NONE_TYPE) 387 elif t is int: 388 # In Python 3, an int may be larger than 64 bits. 389 # We need to check whether value fits into a 64 bit integer before 390 # writing the marker byte. 391 try: 392 # In Cython-compiled code this will throw an overflow error 393 # when value does not fit into int64. 394 int_value = value 395 # If Cython is not used, we must do a (slower) check ourselves. 396 if not TYPE_CHECKING and not is_compiled: 397 if not fits_in_64_bits(value): 398 raise OverflowError() 399 stream.write_byte(INT_TYPE) 400 stream.write_var_int64(int_value) 401 except OverflowError: 402 stream.write_byte(UNKNOWN_TYPE) 403 self.fallback_coder_impl.encode_to_stream(value, stream, nested) 404 elif t is float: 405 stream.write_byte(FLOAT_TYPE) 406 stream.write_bigendian_double(value) 407 elif t is bytes: 408 stream.write_byte(BYTES_TYPE) 409 stream.write(value, nested) 410 elif t is str: 411 unicode_value = value # for typing 412 stream.write_byte(UNICODE_TYPE) 413 stream.write(unicode_value.encode('utf-8'), nested) 414 elif t is list or t is tuple: 415 stream.write_byte(LIST_TYPE if t is list else TUPLE_TYPE) 416 stream.write_var_int64(len(value)) 417 for e in value: 418 self.encode_to_stream(e, stream, True) 419 elif t is bool: 420 stream.write_byte(BOOL_TYPE) 421 stream.write_byte(value) 422 elif t in _ITERABLE_LIKE_TYPES: 423 stream.write_byte(ITERABLE_LIKE_TYPE) 424 self.iterable_coder_impl.encode_to_stream(value, stream, nested) 425 elif t is dict: 426 dict_value = value # for typing 427 stream.write_byte(DICT_TYPE) 428 stream.write_var_int64(len(dict_value)) 429 if self.requires_deterministic_step_label is not None: 430 try: 431 ordered_kvs = sorted(dict_value.items()) 432 except Exception as exn: 433 raise TypeError( 434 "Unable to deterministically order keys of dict for '%s'" % 435 self.requires_deterministic_step_label) from exn 436 for k, v in ordered_kvs: 437 self.encode_to_stream(k, stream, True) 438 self.encode_to_stream(v, stream, True) 439 else: 440 # Loop over dict.items() is optimized by Cython. 441 for k, v in dict_value.items(): 442 self.encode_to_stream(k, stream, True) 443 self.encode_to_stream(v, stream, True) 444 elif t is set: 445 stream.write_byte(SET_TYPE) 446 stream.write_var_int64(len(value)) 447 if self.requires_deterministic_step_label is not None: 448 try: 449 value = sorted(value) 450 except Exception as exn: 451 raise TypeError( 452 "Unable to deterministically order element of set for '%s'" % 453 self.requires_deterministic_step_label) from exn 454 for e in value: 455 self.encode_to_stream(e, stream, True) 456 # All possibly deterministic encodings should be above this clause, 457 # all non-deterministic ones below. 458 elif self.requires_deterministic_step_label is not None: 459 self.encode_special_deterministic(value, stream) 460 else: 461 stream.write_byte(UNKNOWN_TYPE) 462 self.fallback_coder_impl.encode_to_stream(value, stream, nested) 463 464 def encode_special_deterministic(self, value, stream): 465 if self.warn_deterministic_fallback: 466 _LOGGER.warning( 467 "Using fallback deterministic coder for type '%s' in '%s'. ", 468 type(value), 469 self.requires_deterministic_step_label) 470 self.warn_deterministic_fallback = False 471 if isinstance(value, proto_utils.message_types): 472 stream.write_byte(PROTO_TYPE) 473 self.encode_type(type(value), stream) 474 stream.write(value.SerializePartialToString(deterministic=True), True) 475 elif dataclasses and dataclasses.is_dataclass(value): 476 stream.write_byte(DATACLASS_TYPE) 477 if not type(value).__dataclass_params__.frozen: 478 raise TypeError( 479 "Unable to deterministically encode non-frozen '%s' of type '%s' " 480 "for the input of '%s'" % 481 (value, type(value), self.requires_deterministic_step_label)) 482 self.encode_type(type(value), stream) 483 values = [ 484 getattr(value, field.name) for field in dataclasses.fields(value) 485 ] 486 try: 487 self.iterable_coder_impl.encode_to_stream(values, stream, True) 488 except Exception as e: 489 raise TypeError(self._deterministic_encoding_error_msg(value)) from e 490 elif isinstance(value, tuple) and hasattr(type(value), '_fields'): 491 stream.write_byte(NAMED_TUPLE_TYPE) 492 self.encode_type(type(value), stream) 493 try: 494 self.iterable_coder_impl.encode_to_stream(value, stream, True) 495 except Exception as e: 496 raise TypeError(self._deterministic_encoding_error_msg(value)) from e 497 elif isinstance(value, enum.Enum): 498 stream.write_byte(ENUM_TYPE) 499 self.encode_type(type(value), stream) 500 # Enum values can be of any type. 501 try: 502 self.encode_to_stream(value.value, stream, True) 503 except Exception as e: 504 raise TypeError(self._deterministic_encoding_error_msg(value)) from e 505 elif hasattr(value, "__getstate__"): 506 if not hasattr(value, "__setstate__"): 507 raise TypeError( 508 "Unable to deterministically encode '%s' of type '%s', " 509 "for the input of '%s'. The object defines __getstate__ but not " 510 "__setstate__." % 511 (value, type(value), self.requires_deterministic_step_label)) 512 stream.write_byte(NESTED_STATE_TYPE) 513 self.encode_type(type(value), stream) 514 state_value = value.__getstate__() 515 try: 516 self.encode_to_stream(state_value, stream, True) 517 except Exception as e: 518 raise TypeError(self._deterministic_encoding_error_msg(value)) from e 519 else: 520 raise TypeError(self._deterministic_encoding_error_msg(value)) 521 522 def _deterministic_encoding_error_msg(self, value): 523 return ( 524 "Unable to deterministically encode '%s' of type '%s', " 525 "please provide a type hint for the input of '%s'" % 526 (value, type(value), self.requires_deterministic_step_label)) 527 528 def encode_type(self, t, stream): 529 stream.write(dill.dumps(t), True) 530 531 def decode_type(self, stream): 532 return _unpickle_type(stream.read_all(True)) 533 534 def decode_from_stream(self, stream, nested): 535 # type: (create_InputStream, bool) -> Any 536 t = stream.read_byte() 537 if t == NONE_TYPE: 538 return None 539 elif t == INT_TYPE: 540 return stream.read_var_int64() 541 elif t == FLOAT_TYPE: 542 return stream.read_bigendian_double() 543 elif t == BYTES_TYPE: 544 return stream.read_all(nested) 545 elif t == UNICODE_TYPE: 546 return stream.read_all(nested).decode('utf-8') 547 elif t == LIST_TYPE or t == TUPLE_TYPE or t == SET_TYPE: 548 vlen = stream.read_var_int64() 549 vlist = [self.decode_from_stream(stream, True) for _ in range(vlen)] 550 if t == LIST_TYPE: 551 return vlist 552 elif t == TUPLE_TYPE: 553 return tuple(vlist) 554 return set(vlist) 555 elif t == DICT_TYPE: 556 vlen = stream.read_var_int64() 557 v = {} 558 for _ in range(vlen): 559 k = self.decode_from_stream(stream, True) 560 v[k] = self.decode_from_stream(stream, True) 561 return v 562 elif t == BOOL_TYPE: 563 return not not stream.read_byte() 564 elif t == ITERABLE_LIKE_TYPE: 565 return self.iterable_coder_impl.decode_from_stream(stream, nested) 566 elif t == PROTO_TYPE: 567 cls = self.decode_type(stream) 568 msg = cls() 569 msg.ParseFromString(stream.read_all(True)) 570 return msg 571 elif t == DATACLASS_TYPE or t == NAMED_TUPLE_TYPE: 572 cls = self.decode_type(stream) 573 return cls(*self.iterable_coder_impl.decode_from_stream(stream, True)) 574 elif t == ENUM_TYPE: 575 cls = self.decode_type(stream) 576 return cls(self.decode_from_stream(stream, True)) 577 elif t == NESTED_STATE_TYPE: 578 cls = self.decode_type(stream) 579 state = self.decode_from_stream(stream, True) 580 value = cls.__new__(cls) 581 value.__setstate__(state) 582 return value 583 elif t == UNKNOWN_TYPE: 584 return self.fallback_coder_impl.decode_from_stream(stream, nested) 585 else: 586 raise ValueError('Unknown type tag %x' % t) 587 588 589 _unpickled_types = {} # type: Dict[bytes, type] 590 591 592 def _unpickle_type(bs): 593 t = _unpickled_types.get(bs, None) 594 if t is None: 595 t = _unpickled_types[bs] = dill.loads(bs) 596 # Fix unpicklable anonymous named tuples for Python 3.6. 597 if t.__base__ is tuple and hasattr(t, '_fields'): 598 try: 599 pickle.loads(pickle.dumps(t)) 600 except pickle.PicklingError: 601 t.__reduce__ = lambda self: (_unpickle_named_tuple, (bs, tuple(self))) 602 return t 603 604 605 def _unpickle_named_tuple(bs, items): 606 return _unpickle_type(bs)(*items) 607 608 609 class BytesCoderImpl(CoderImpl): 610 """For internal use only; no backwards-compatibility guarantees. 611 612 A coder for bytes/str objects.""" 613 def encode_to_stream(self, value, out, nested): 614 # type: (bytes, create_OutputStream, bool) -> None 615 616 # value might be of type np.bytes if passed from encode_batch, and cython 617 # does not recognize it as bytes. 618 if is_compiled and isinstance(value, np.bytes_): 619 value = bytes(value) 620 621 out.write(value, nested) 622 623 def decode_from_stream(self, in_stream, nested): 624 # type: (create_InputStream, bool) -> bytes 625 return in_stream.read_all(nested) 626 627 def encode(self, value): 628 assert isinstance(value, bytes), (value, type(value)) 629 return value 630 631 def decode(self, encoded): 632 return encoded 633 634 635 class BooleanCoderImpl(CoderImpl): 636 """For internal use only; no backwards-compatibility guarantees. 637 638 A coder for bool objects.""" 639 def encode_to_stream(self, value, out, nested): 640 out.write_byte(1 if value else 0) 641 642 def decode_from_stream(self, in_stream, nested): 643 value = in_stream.read_byte() 644 if value == 0: 645 return False 646 elif value == 1: 647 return True 648 raise ValueError("Expected 0 or 1, got %s" % value) 649 650 def encode(self, value): 651 return b'\x01' if value else b'\x00' 652 653 def decode(self, encoded): 654 value = ord(encoded) 655 if value == 0: 656 return False 657 elif value == 1: 658 return True 659 raise ValueError("Expected 0 or 1, got %s" % value) 660 661 def estimate_size(self, unused_value, nested=False): 662 # Note that booleans are encoded the same way regardless of nesting. 663 return 1 664 665 666 class MapCoderImpl(StreamCoderImpl): 667 """For internal use only; no backwards-compatibility guarantees. 668 669 Note this implementation always uses nested context when encoding keys 670 and values. This differs from Java's MapCoder, which uses 671 nested=False if possible for the last value encoded. 672 673 This difference is acceptable because MapCoder is not standard. It is only 674 used in a standard context by RowCoder which always uses nested context for 675 attribute values. 676 677 A coder for typing.Mapping objects.""" 678 def __init__( 679 self, 680 key_coder, # type: CoderImpl 681 value_coder, # type: CoderImpl 682 is_deterministic = False 683 ): 684 self._key_coder = key_coder 685 self._value_coder = value_coder 686 self._is_deterministic = is_deterministic 687 688 def encode_to_stream(self, dict_value, out, nested): 689 out.write_bigendian_int32(len(dict_value)) 690 # Note this implementation always uses nested context when encoding keys 691 # and values which differs from Java. See note in docstring. 692 if self._is_deterministic: 693 for key, value in sorted(dict_value.items()): 694 self._key_coder.encode_to_stream(key, out, True) 695 self._value_coder.encode_to_stream(value, out, True) 696 else: 697 # This loop is separate from the above so the loop over dict.items() 698 # will be optimized by Cython. 699 for key, value in dict_value.items(): 700 self._key_coder.encode_to_stream(key, out, True) 701 self._value_coder.encode_to_stream(value, out, True) 702 703 def decode_from_stream(self, in_stream, nested): 704 size = in_stream.read_bigendian_int32() 705 result = {} 706 for _ in range(size): 707 # Note this implementation always uses nested context when encoding keys 708 # and values which differs from Java. See note in docstring. 709 key = self._key_coder.decode_from_stream(in_stream, True) 710 value = self._value_coder.decode_from_stream(in_stream, True) 711 result[key] = value 712 713 return result 714 715 def estimate_size(self, unused_value, nested=False): 716 estimate = 4 # 4 bytes for int32 size prefix 717 for key, value in unused_value.items(): 718 estimate += self._key_coder.estimate_size(key, True) 719 estimate += self._value_coder.estimate_size(value, True) 720 return estimate 721 722 723 class NullableCoderImpl(StreamCoderImpl): 724 """For internal use only; no backwards-compatibility guarantees. 725 726 A coder for typing.Optional objects.""" 727 728 ENCODE_NULL = 0 729 ENCODE_PRESENT = 1 730 731 def __init__( 732 self, 733 value_coder # type: CoderImpl 734 ): 735 self._value_coder = value_coder 736 737 def encode_to_stream(self, value, out, nested): 738 if value is None: 739 out.write_byte(self.ENCODE_NULL) 740 else: 741 out.write_byte(self.ENCODE_PRESENT) 742 self._value_coder.encode_to_stream(value, out, nested) 743 744 def decode_from_stream(self, in_stream, nested): 745 null_indicator = in_stream.read_byte() 746 if null_indicator == self.ENCODE_NULL: 747 return None 748 elif null_indicator == self.ENCODE_PRESENT: 749 return self._value_coder.decode_from_stream(in_stream, nested) 750 else: 751 raise ValueError( 752 "Encountered unexpected value for null indicator: '%s'" % 753 null_indicator) 754 755 def estimate_size(self, unused_value, nested=False): 756 return 1 + ( 757 self._value_coder.estimate_size(unused_value) 758 if unused_value is not None else 0) 759 760 761 class BigEndianShortCoderImpl(StreamCoderImpl): 762 """For internal use only; no backwards-compatibility guarantees.""" 763 def encode_to_stream(self, value, out, nested): 764 # type: (int, create_OutputStream, bool) -> None 765 out.write_bigendian_int16(value) 766 767 def decode_from_stream(self, in_stream, nested): 768 # type: (create_InputStream, bool) -> float 769 return in_stream.read_bigendian_int16() 770 771 def estimate_size(self, unused_value, nested=False): 772 # type: (Any, bool) -> int 773 # A short is encoded as 2 bytes, regardless of nesting. 774 return 2 775 776 777 class SinglePrecisionFloatCoderImpl(StreamCoderImpl): 778 """For internal use only; no backwards-compatibility guarantees.""" 779 def encode_to_stream(self, value, out, nested): 780 # type: (float, create_OutputStream, bool) -> None 781 out.write_bigendian_float(value) 782 783 def decode_from_stream(self, in_stream, nested): 784 # type: (create_InputStream, bool) -> float 785 return in_stream.read_bigendian_float() 786 787 def estimate_size(self, unused_value, nested=False): 788 # type: (Any, bool) -> int 789 # A float is encoded as 4 bytes, regardless of nesting. 790 return 4 791 792 793 class FloatCoderImpl(StreamCoderImpl): 794 """For internal use only; no backwards-compatibility guarantees.""" 795 def encode_to_stream(self, value, out, nested): 796 # type: (float, create_OutputStream, bool) -> None 797 out.write_bigendian_double(value) 798 799 def decode_from_stream(self, in_stream, nested): 800 # type: (create_InputStream, bool) -> float 801 return in_stream.read_bigendian_double() 802 803 def estimate_size(self, unused_value, nested=False): 804 # type: (Any, bool) -> int 805 # A double is encoded as 8 bytes, regardless of nesting. 806 return 8 807 808 809 if not TYPE_CHECKING: 810 IntervalWindow = None 811 812 813 class IntervalWindowCoderImpl(StreamCoderImpl): 814 """For internal use only; no backwards-compatibility guarantees.""" 815 816 # TODO: Fn Harness only supports millis. Is this important enough to fix? 817 def _to_normal_time(self, value): 818 """Convert "lexicographically ordered unsigned" to signed.""" 819 return value - _TIME_SHIFT 820 821 def _from_normal_time(self, value): 822 """Convert signed to "lexicographically ordered unsigned".""" 823 return value + _TIME_SHIFT 824 825 def encode_to_stream(self, value, out, nested): 826 # type: (IntervalWindow, create_OutputStream, bool) -> None 827 typed_value = value 828 span_millis = ( 829 typed_value._end_micros // 1000 - typed_value._start_micros // 1000) 830 out.write_bigendian_uint64( 831 self._from_normal_time(typed_value._end_micros // 1000)) 832 out.write_var_int64(span_millis) 833 834 def decode_from_stream(self, in_, nested): 835 # type: (create_InputStream, bool) -> IntervalWindow 836 if not TYPE_CHECKING: 837 global IntervalWindow # pylint: disable=global-variable-not-assigned 838 if IntervalWindow is None: 839 from apache_beam.transforms.window import IntervalWindow 840 # instantiating with None is not part of the public interface 841 typed_value = IntervalWindow(None, None) # type: ignore[arg-type] 842 typed_value._end_micros = ( 843 1000 * self._to_normal_time(in_.read_bigendian_uint64())) 844 typed_value._start_micros = ( 845 typed_value._end_micros - 1000 * in_.read_var_int64()) 846 return typed_value 847 848 def estimate_size(self, value, nested=False): 849 # type: (Any, bool) -> int 850 # An IntervalWindow is context-insensitive, with a timestamp (8 bytes) 851 # and a varint timespam. 852 typed_value = value 853 span_millis = ( 854 typed_value._end_micros // 1000 - typed_value._start_micros // 1000) 855 return 8 + get_varint_size(span_millis) 856 857 858 class TimestampCoderImpl(StreamCoderImpl): 859 """For internal use only; no backwards-compatibility guarantees. 860 861 TODO: SDK agnostic encoding 862 For interoperability with Java SDK, encoding needs to match 863 that of the Java SDK InstantCoder. 864 https://github.com/apache/beam/blob/f5029b4f0dfff404310b2ef55e2632bbacc7b04f/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/InstantCoder.java#L79 865 """ 866 def encode_to_stream(self, value, out, nested): 867 # type: (Timestamp, create_OutputStream, bool) -> None 868 millis = value.micros // 1000 869 if millis >= 0: 870 millis = millis - _TIME_SHIFT 871 else: 872 millis = millis + _TIME_SHIFT 873 out.write_bigendian_int64(millis) 874 875 def decode_from_stream(self, in_stream, nested): 876 # type: (create_InputStream, bool) -> Timestamp 877 millis = in_stream.read_bigendian_int64() 878 if millis < 0: 879 millis = millis + _TIME_SHIFT 880 else: 881 millis = millis - _TIME_SHIFT 882 return Timestamp(micros=millis * 1000) 883 884 def estimate_size(self, unused_value, nested=False): 885 # A Timestamp is encoded as a 64-bit integer in 8 bytes, regardless of 886 # nesting. 887 return 8 888 889 890 class TimerCoderImpl(StreamCoderImpl): 891 """For internal use only; no backwards-compatibility guarantees.""" 892 def __init__(self, key_coder_impl, window_coder_impl): 893 self._timestamp_coder_impl = TimestampCoderImpl() 894 self._boolean_coder_impl = BooleanCoderImpl() 895 self._pane_info_coder_impl = PaneInfoCoderImpl() 896 self._key_coder_impl = key_coder_impl 897 self._windows_coder_impl = TupleSequenceCoderImpl(window_coder_impl) 898 from apache_beam.coders.coders import StrUtf8Coder 899 self._tag_coder_impl = StrUtf8Coder().get_impl() 900 901 def encode_to_stream(self, value, out, nested): 902 # type: (userstate.Timer, create_OutputStream, bool) -> None 903 self._key_coder_impl.encode_to_stream(value.user_key, out, True) 904 self._tag_coder_impl.encode_to_stream(value.dynamic_timer_tag, out, True) 905 self._windows_coder_impl.encode_to_stream(value.windows, out, True) 906 self._boolean_coder_impl.encode_to_stream(value.clear_bit, out, True) 907 if not value.clear_bit: 908 self._timestamp_coder_impl.encode_to_stream( 909 value.fire_timestamp, out, True) 910 self._timestamp_coder_impl.encode_to_stream( 911 value.hold_timestamp, out, True) 912 self._pane_info_coder_impl.encode_to_stream(value.paneinfo, out, True) 913 914 def decode_from_stream(self, in_stream, nested): 915 # type: (create_InputStream, bool) -> userstate.Timer 916 from apache_beam.transforms import userstate 917 user_key = self._key_coder_impl.decode_from_stream(in_stream, True) 918 dynamic_timer_tag = self._tag_coder_impl.decode_from_stream(in_stream, True) 919 windows = self._windows_coder_impl.decode_from_stream(in_stream, True) 920 clear_bit = self._boolean_coder_impl.decode_from_stream(in_stream, True) 921 if clear_bit: 922 return userstate.Timer( 923 user_key=user_key, 924 dynamic_timer_tag=dynamic_timer_tag, 925 windows=windows, 926 clear_bit=clear_bit, 927 fire_timestamp=None, 928 hold_timestamp=None, 929 paneinfo=None) 930 931 return userstate.Timer( 932 user_key=user_key, 933 dynamic_timer_tag=dynamic_timer_tag, 934 windows=windows, 935 clear_bit=clear_bit, 936 fire_timestamp=self._timestamp_coder_impl.decode_from_stream( 937 in_stream, True), 938 hold_timestamp=self._timestamp_coder_impl.decode_from_stream( 939 in_stream, True), 940 paneinfo=self._pane_info_coder_impl.decode_from_stream(in_stream, True)) 941 942 943 small_ints = [chr(_).encode('latin-1') for _ in range(128)] 944 945 946 class VarIntCoderImpl(StreamCoderImpl): 947 """For internal use only; no backwards-compatibility guarantees. 948 949 A coder for int objects.""" 950 def encode_to_stream(self, value, out, nested): 951 # type: (int, create_OutputStream, bool) -> None 952 out.write_var_int64(value) 953 954 def decode_from_stream(self, in_stream, nested): 955 # type: (create_InputStream, bool) -> int 956 return in_stream.read_var_int64() 957 958 def encode(self, value): 959 ivalue = value # type cast 960 if 0 <= ivalue < len(small_ints): 961 return small_ints[ivalue] 962 return StreamCoderImpl.encode(self, value) 963 964 def decode(self, encoded): 965 if len(encoded) == 1: 966 i = ord(encoded) 967 if 0 <= i < 128: 968 return i 969 return StreamCoderImpl.decode(self, encoded) 970 971 def estimate_size(self, value, nested=False): 972 # type: (Any, bool) -> int 973 # Note that VarInts are encoded the same way regardless of nesting. 974 return get_varint_size(value) 975 976 977 class SingletonCoderImpl(CoderImpl): 978 """For internal use only; no backwards-compatibility guarantees. 979 980 A coder that always encodes exactly one value.""" 981 def __init__(self, value): 982 self._value = value 983 984 def encode_to_stream(self, value, stream, nested): 985 # type: (Any, create_OutputStream, bool) -> None 986 pass 987 988 def decode_from_stream(self, stream, nested): 989 # type: (create_InputStream, bool) -> Any 990 return self._value 991 992 def encode(self, value): 993 b = b'' # avoid byte vs str vs unicode error 994 return b 995 996 def decode(self, encoded): 997 return self._value 998 999 def estimate_size(self, value, nested=False): 1000 # type: (Any, bool) -> int 1001 return 0 1002 1003 1004 class AbstractComponentCoderImpl(StreamCoderImpl): 1005 """For internal use only; no backwards-compatibility guarantees. 1006 1007 CoderImpl for coders that are comprised of several component coders.""" 1008 def __init__(self, coder_impls): 1009 for c in coder_impls: 1010 assert isinstance(c, CoderImpl), c 1011 self._coder_impls = tuple(coder_impls) 1012 1013 def _extract_components(self, value): 1014 raise NotImplementedError 1015 1016 def _construct_from_components(self, components): 1017 raise NotImplementedError 1018 1019 def encode_to_stream(self, value, out, nested): 1020 # type: (Any, create_OutputStream, bool) -> None 1021 values = self._extract_components(value) 1022 if len(self._coder_impls) != len(values): 1023 raise ValueError('Number of components does not match number of coders.') 1024 for i in range(0, len(self._coder_impls)): 1025 c = self._coder_impls[i] # type cast 1026 c.encode_to_stream( 1027 values[i], out, nested or i + 1 < len(self._coder_impls)) 1028 1029 def decode_from_stream(self, in_stream, nested): 1030 # type: (create_InputStream, bool) -> Any 1031 return self._construct_from_components([ 1032 c.decode_from_stream( 1033 in_stream, nested or i + 1 < len(self._coder_impls)) for i, 1034 c in enumerate(self._coder_impls) 1035 ]) 1036 1037 def estimate_size(self, value, nested=False): 1038 # type: (Any, bool) -> int 1039 1040 """Estimates the encoded size of the given value, in bytes.""" 1041 # TODO(ccy): This ignores sizes of observable components. 1042 estimated_size, _ = (self.get_estimated_size_and_observables(value)) 1043 return estimated_size 1044 1045 def get_estimated_size_and_observables(self, value, nested=False): 1046 # type: (Any, bool) -> Tuple[int, Observables] 1047 1048 """Returns estimated size of value along with any nested observables.""" 1049 values = self._extract_components(value) 1050 estimated_size = 0 1051 observables = [] # type: Observables 1052 for i in range(0, len(self._coder_impls)): 1053 c = self._coder_impls[i] # type cast 1054 child_size, child_observables = ( 1055 c.get_estimated_size_and_observables( 1056 values[i], nested=nested or i + 1 < len(self._coder_impls))) 1057 estimated_size += child_size 1058 observables += child_observables 1059 return estimated_size, observables 1060 1061 1062 class AvroCoderImpl(SimpleCoderImpl): 1063 """For internal use only; no backwards-compatibility guarantees.""" 1064 def __init__(self, schema): 1065 self.parsed_schema = parse_schema(json.loads(schema)) 1066 1067 def encode(self, value): 1068 assert issubclass(type(value), AvroRecord) 1069 with BytesIO() as buf: 1070 schemaless_writer(buf, self.parsed_schema, value.record) 1071 return buf.getvalue() 1072 1073 def decode(self, encoded): 1074 with BytesIO(encoded) as buf: 1075 return AvroRecord(schemaless_reader(buf, self.parsed_schema)) 1076 1077 1078 class TupleCoderImpl(AbstractComponentCoderImpl): 1079 """A coder for tuple objects.""" 1080 def _extract_components(self, value): 1081 return tuple(value) 1082 1083 def _construct_from_components(self, components): 1084 return tuple(components) 1085 1086 1087 class _ConcatSequence(object): 1088 def __init__(self, head, tail): 1089 # type: (Iterable[Any], Iterable[Any]) -> None 1090 self._head = head 1091 self._tail = tail 1092 1093 def __iter__(self): 1094 # type: () -> Iterator[Any] 1095 for elem in self._head: 1096 yield elem 1097 for elem in self._tail: 1098 yield elem 1099 1100 def __eq__(self, other): 1101 return list(self) == list(other) 1102 1103 def __hash__(self): 1104 raise NotImplementedError 1105 1106 def __reduce__(self): 1107 return list, (list(self), ) 1108 1109 1110 FastPrimitivesCoderImpl.register_iterable_like_type(_ConcatSequence) 1111 1112 1113 class SequenceCoderImpl(StreamCoderImpl): 1114 """For internal use only; no backwards-compatibility guarantees. 1115 1116 A coder for sequences. 1117 1118 If the length of the sequence in known we encode the length as a 32 bit 1119 ``int`` followed by the encoded bytes. 1120 1121 If the length of the sequence is unknown, we encode the length as ``-1`` 1122 followed by the encoding of elements buffered up to 64K bytes before prefixing 1123 the count of number of elements. A ``0`` is encoded at the end to indicate the 1124 end of stream. 1125 1126 The resulting encoding would look like this:: 1127 1128 -1 1129 countA element(0) element(1) ... element(countA - 1) 1130 countB element(0) element(1) ... element(countB - 1) 1131 ... 1132 countX element(0) element(1) ... element(countX - 1) 1133 0 1134 1135 If writing to state is enabled, the final terminating 0 will instead be 1136 repaced with:: 1137 1138 varInt64(-1) 1139 len(state_token) 1140 state_token 1141 1142 where state_token is a bytes object used to retrieve the remainder of the 1143 iterable via the state API. 1144 """ 1145 1146 # Default buffer size of 64kB of handling iterables of unknown length. 1147 _DEFAULT_BUFFER_SIZE = 64 * 1024 1148 1149 def __init__( 1150 self, 1151 elem_coder, # type: CoderImpl 1152 read_state=None, # type: Optional[IterableStateReader] 1153 write_state=None, # type: Optional[IterableStateWriter] 1154 write_state_threshold=0 # type: int 1155 ): 1156 self._elem_coder = elem_coder 1157 self._read_state = read_state 1158 self._write_state = write_state 1159 self._write_state_threshold = write_state_threshold 1160 1161 def _construct_from_sequence(self, values): 1162 raise NotImplementedError 1163 1164 def encode_to_stream(self, value, out, nested): 1165 # type: (Sequence, create_OutputStream, bool) -> None 1166 # Compatible with Java's IterableLikeCoder. 1167 if hasattr(value, '__len__') and self._write_state is None: 1168 out.write_bigendian_int32(len(value)) 1169 for elem in value: 1170 self._elem_coder.encode_to_stream(elem, out, True) 1171 else: 1172 # We don't know the size without traversing it so use a fixed size buffer 1173 # and encode as many elements as possible into it before outputting 1174 # the size followed by the elements. 1175 1176 # -1 to indicate that the length is not known. 1177 out.write_bigendian_int32(-1) 1178 buffer = create_OutputStream() 1179 if self._write_state is None: 1180 target_buffer_size = self._DEFAULT_BUFFER_SIZE 1181 else: 1182 target_buffer_size = min( 1183 self._DEFAULT_BUFFER_SIZE, self._write_state_threshold) 1184 prev_index = index = -1 1185 # Don't want to miss out on fast list iteration optimization. 1186 value_iter = value if isinstance(value, (list, tuple)) else iter(value) 1187 start_size = out.size() 1188 for elem in value_iter: 1189 index += 1 1190 self._elem_coder.encode_to_stream(elem, buffer, True) 1191 if buffer.size() > target_buffer_size: 1192 out.write_var_int64(index - prev_index) 1193 out.write(buffer.get()) 1194 prev_index = index 1195 buffer = create_OutputStream() 1196 if (self._write_state is not None and 1197 out.size() - start_size > self._write_state_threshold): 1198 tail = ( 1199 value_iter[index + 1200 1:] if isinstance(value_iter, 1201 (list, tuple)) else value_iter) 1202 state_token = self._write_state(tail, self._elem_coder) 1203 out.write_var_int64(-1) 1204 out.write(state_token, True) 1205 break 1206 else: 1207 if index > prev_index: 1208 out.write_var_int64(index - prev_index) 1209 out.write(buffer.get()) 1210 out.write_var_int64(0) 1211 1212 def decode_from_stream(self, in_stream, nested): 1213 # type: (create_InputStream, bool) -> Sequence 1214 size = in_stream.read_bigendian_int32() 1215 1216 if size >= 0: 1217 elements = [ 1218 self._elem_coder.decode_from_stream(in_stream, True) 1219 for _ in range(size) 1220 ] # type: Iterable[Any] 1221 else: 1222 elements = [] 1223 count = in_stream.read_var_int64() 1224 while count > 0: 1225 for _ in range(count): 1226 elements.append(self._elem_coder.decode_from_stream(in_stream, True)) 1227 count = in_stream.read_var_int64() 1228 1229 if count == -1: 1230 if self._read_state is None: 1231 raise ValueError( 1232 'Cannot read state-written iterable without state reader.') 1233 1234 state_token = in_stream.read_all(True) 1235 elements = _ConcatSequence( 1236 elements, self._read_state(state_token, self._elem_coder)) 1237 1238 return self._construct_from_sequence(elements) 1239 1240 def estimate_size(self, value, nested=False): 1241 # type: (Any, bool) -> int 1242 1243 """Estimates the encoded size of the given value, in bytes.""" 1244 # TODO(ccy): This ignores element sizes. 1245 estimated_size, _ = (self.get_estimated_size_and_observables(value)) 1246 return estimated_size 1247 1248 def get_estimated_size_and_observables(self, value, nested=False): 1249 # type: (Any, bool) -> Tuple[int, Observables] 1250 1251 """Returns estimated size of value along with any nested observables.""" 1252 estimated_size = 0 1253 # Size of 32-bit integer storing number of elements. 1254 estimated_size += 4 1255 if isinstance(value, observable.ObservableMixin): 1256 return estimated_size, [(value, self._elem_coder)] 1257 1258 observables = [] # type: Observables 1259 for elem in value: 1260 child_size, child_observables = ( 1261 self._elem_coder.get_estimated_size_and_observables( 1262 elem, nested=True)) 1263 estimated_size += child_size 1264 observables += child_observables 1265 # TODO: (https://github.com/apache/beam/issues/18169) Update to use an 1266 # accurate count depending on size and count, currently we are 1267 # underestimating the size by up to 10 bytes per block of data since we are 1268 # not including the count prefix which occurs at most once per 64k of data 1269 # and is upto 10 bytes long. The upper bound of the underestimate is 1270 # 10 / 65536 ~= 0.0153% of the actual size. 1271 # TODO: More efficient size estimation in the case of state-backed 1272 # iterables. 1273 return estimated_size, observables 1274 1275 1276 class TupleSequenceCoderImpl(SequenceCoderImpl): 1277 """For internal use only; no backwards-compatibility guarantees. 1278 1279 A coder for homogeneous tuple objects.""" 1280 def _construct_from_sequence(self, components): 1281 return tuple(components) 1282 1283 1284 class _AbstractIterable(object): 1285 """Wraps an iterable hiding methods that might not always be available.""" 1286 def __init__(self, contents): 1287 self._contents = contents 1288 1289 def __iter__(self): 1290 return iter(self._contents) 1291 1292 def __repr__(self): 1293 head = [repr(e) for e in itertools.islice(self, 4)] 1294 if len(head) == 4: 1295 head[-1] = '...' 1296 return '_AbstractIterable([%s])' % ', '.join(head) 1297 1298 # Mostly useful for tests. 1299 def __eq__(left, right): 1300 end = object() 1301 for a, b in itertools.zip_longest(left, right, fillvalue=end): 1302 if a != b: 1303 return False 1304 return True 1305 1306 1307 FastPrimitivesCoderImpl.register_iterable_like_type(_AbstractIterable) 1308 1309 # TODO(https://github.com/apache/beam/issues/21167): Enable using abstract 1310 # iterables permanently 1311 _iterable_coder_uses_abstract_iterable_by_default = False 1312 1313 1314 class IterableCoderImpl(SequenceCoderImpl): 1315 """For internal use only; no backwards-compatibility guarantees. 1316 1317 A coder for homogeneous iterable objects.""" 1318 def __init__(self, *args, use_abstract_iterable=None, **kwargs): 1319 super().__init__(*args, **kwargs) 1320 if use_abstract_iterable is None: 1321 use_abstract_iterable = _iterable_coder_uses_abstract_iterable_by_default 1322 self._use_abstract_iterable = use_abstract_iterable 1323 1324 def _construct_from_sequence(self, components): 1325 if self._use_abstract_iterable: 1326 return _AbstractIterable(components) 1327 else: 1328 return components 1329 1330 1331 class ListCoderImpl(SequenceCoderImpl): 1332 """For internal use only; no backwards-compatibility guarantees. 1333 1334 A coder for homogeneous list objects.""" 1335 def _construct_from_sequence(self, components): 1336 return components if isinstance(components, list) else list(components) 1337 1338 1339 class PaneInfoEncoding(object): 1340 """For internal use only; no backwards-compatibility guarantees. 1341 1342 Encoding used to describe a PaneInfo descriptor. A PaneInfo descriptor 1343 can be encoded in three different ways: with a single byte (FIRST), with a 1344 single byte followed by a varint describing a single index (ONE_INDEX) or 1345 with a single byte followed by two varints describing two separate indices: 1346 the index and nonspeculative index. 1347 """ 1348 1349 FIRST = 0 1350 ONE_INDEX = 1 1351 TWO_INDICES = 2 1352 1353 1354 # These are cdef'd to ints to optimized the common case. 1355 PaneInfoTiming_UNKNOWN = windowed_value.PaneInfoTiming.UNKNOWN 1356 PaneInfoEncoding_FIRST = PaneInfoEncoding.FIRST 1357 1358 1359 class PaneInfoCoderImpl(StreamCoderImpl): 1360 """For internal use only; no backwards-compatibility guarantees. 1361 1362 Coder for a PaneInfo descriptor.""" 1363 def _choose_encoding(self, value): 1364 if ((value._index == 0 and value._nonspeculative_index == 0) or 1365 value._timing == PaneInfoTiming_UNKNOWN): 1366 return PaneInfoEncoding_FIRST 1367 elif (value._index == value._nonspeculative_index or 1368 value._timing == windowed_value.PaneInfoTiming.EARLY): 1369 return PaneInfoEncoding.ONE_INDEX 1370 else: 1371 return PaneInfoEncoding.TWO_INDICES 1372 1373 def encode_to_stream(self, value, out, nested): 1374 # type: (windowed_value.PaneInfo, create_OutputStream, bool) -> None 1375 pane_info = value # cast 1376 encoding_type = self._choose_encoding(pane_info) 1377 out.write_byte(pane_info._encoded_byte | (encoding_type << 4)) 1378 if encoding_type == PaneInfoEncoding_FIRST: 1379 return 1380 elif encoding_type == PaneInfoEncoding.ONE_INDEX: 1381 out.write_var_int64(value.index) 1382 elif encoding_type == PaneInfoEncoding.TWO_INDICES: 1383 out.write_var_int64(value.index) 1384 out.write_var_int64(value.nonspeculative_index) 1385 else: 1386 raise NotImplementedError('Invalid PaneInfoEncoding: %s' % encoding_type) 1387 1388 def decode_from_stream(self, in_stream, nested): 1389 # type: (create_InputStream, bool) -> windowed_value.PaneInfo 1390 encoded_first_byte = in_stream.read_byte() 1391 base = windowed_value._BYTE_TO_PANE_INFO[encoded_first_byte & 0xF] 1392 assert base is not None 1393 encoding_type = encoded_first_byte >> 4 1394 if encoding_type == PaneInfoEncoding_FIRST: 1395 return base 1396 elif encoding_type == PaneInfoEncoding.ONE_INDEX: 1397 index = in_stream.read_var_int64() 1398 if base.timing == windowed_value.PaneInfoTiming.EARLY: 1399 nonspeculative_index = -1 1400 else: 1401 nonspeculative_index = index 1402 elif encoding_type == PaneInfoEncoding.TWO_INDICES: 1403 index = in_stream.read_var_int64() 1404 nonspeculative_index = in_stream.read_var_int64() 1405 else: 1406 raise NotImplementedError('Invalid PaneInfoEncoding: %s' % encoding_type) 1407 return windowed_value.PaneInfo( 1408 base.is_first, base.is_last, base.timing, index, nonspeculative_index) 1409 1410 def estimate_size(self, value, nested=False): 1411 # type: (Any, bool) -> int 1412 1413 """Estimates the encoded size of the given value, in bytes.""" 1414 size = 1 1415 encoding_type = self._choose_encoding(value) 1416 if encoding_type == PaneInfoEncoding.ONE_INDEX: 1417 size += get_varint_size(value.index) 1418 elif encoding_type == PaneInfoEncoding.TWO_INDICES: 1419 size += get_varint_size(value.index) 1420 size += get_varint_size(value.nonspeculative_index) 1421 return size 1422 1423 1424 class WindowedValueCoderImpl(StreamCoderImpl): 1425 """For internal use only; no backwards-compatibility guarantees. 1426 1427 A coder for windowed values.""" 1428 1429 # Ensure that lexicographic ordering of the bytes corresponds to 1430 # chronological order of timestamps. 1431 # TODO(https://github.com/apache/beam/issues/18190): Clean this up once we 1432 # have a BEAM wide consensus on byte representation of timestamps. 1433 def _to_normal_time(self, value): 1434 """Convert "lexicographically ordered unsigned" to signed.""" 1435 return value - _TIME_SHIFT 1436 1437 def _from_normal_time(self, value): 1438 """Convert signed to "lexicographically ordered unsigned".""" 1439 return value + _TIME_SHIFT 1440 1441 def __init__(self, value_coder, timestamp_coder, window_coder): 1442 # TODO(lcwik): Remove the timestamp coder field 1443 self._value_coder = value_coder 1444 self._timestamp_coder = timestamp_coder 1445 self._windows_coder = TupleSequenceCoderImpl(window_coder) 1446 self._pane_info_coder = PaneInfoCoderImpl() 1447 1448 def encode_to_stream(self, value, out, nested): 1449 # type: (windowed_value.WindowedValue, create_OutputStream, bool) -> None 1450 wv = value # type cast 1451 # Avoid creation of Timestamp object. 1452 restore_sign = -1 if wv.timestamp_micros < 0 else 1 1453 out.write_bigendian_uint64( 1454 # Convert to postive number and divide, since python rounds off to the 1455 # lower negative number. For ex: -3 / 2 = -2, but we expect it to be -1, 1456 # to be consistent across SDKs. 1457 # TODO(https://github.com/apache/beam/issues/18190): Clean this up once 1458 # we have a BEAM wide consensus on precision of timestamps. 1459 self._from_normal_time( 1460 restore_sign * ( 1461 abs( 1462 MIN_TIMESTAMP_micros if wv.timestamp_micros < 1463 MIN_TIMESTAMP_micros else wv.timestamp_micros) // 1000))) 1464 self._windows_coder.encode_to_stream(wv.windows, out, True) 1465 # Default PaneInfo encoded byte representing NO_FIRING. 1466 self._pane_info_coder.encode_to_stream(wv.pane_info, out, True) 1467 self._value_coder.encode_to_stream(wv.value, out, nested) 1468 1469 def decode_from_stream(self, in_stream, nested): 1470 # type: (create_InputStream, bool) -> windowed_value.WindowedValue 1471 timestamp = self._to_normal_time(in_stream.read_bigendian_uint64()) 1472 # Restore MIN/MAX timestamps to their actual values as encoding incurs loss 1473 # of precision while converting to millis. 1474 # Note: This is only a best effort here as there is no way to know if these 1475 # were indeed MIN/MAX timestamps. 1476 # TODO(https://github.com/apache/beam/issues/18190): Clean this up once we 1477 # have a BEAM wide consensus on precision of timestamps. 1478 if timestamp <= -(abs(MIN_TIMESTAMP_micros) // 1000): 1479 timestamp = MIN_TIMESTAMP_micros 1480 elif timestamp >= MAX_TIMESTAMP_micros // 1000: 1481 timestamp = MAX_TIMESTAMP_micros 1482 else: 1483 timestamp *= 1000 1484 1485 windows = self._windows_coder.decode_from_stream(in_stream, True) 1486 # Read PaneInfo encoded byte. 1487 pane_info = self._pane_info_coder.decode_from_stream(in_stream, True) 1488 value = self._value_coder.decode_from_stream(in_stream, nested) 1489 return windowed_value.create( 1490 value, 1491 timestamp, # Avoid creation of Timestamp object. 1492 windows, 1493 pane_info) 1494 1495 def get_estimated_size_and_observables(self, value, nested=False): 1496 # type: (Any, bool) -> Tuple[int, Observables] 1497 1498 """Returns estimated size of value along with any nested observables.""" 1499 if isinstance(value, observable.ObservableMixin): 1500 # Should never be here. 1501 # TODO(robertwb): Remove when coders are set correctly. 1502 return 0, [(value, self._value_coder)] 1503 estimated_size = 0 1504 observables = [] # type: Observables 1505 value_estimated_size, value_observables = ( 1506 self._value_coder.get_estimated_size_and_observables( 1507 value.value, nested=nested)) 1508 estimated_size += value_estimated_size 1509 observables += value_observables 1510 estimated_size += ( 1511 self._timestamp_coder.estimate_size(value.timestamp, nested=True)) 1512 estimated_size += ( 1513 self._windows_coder.estimate_size(value.windows, nested=True)) 1514 estimated_size += ( 1515 self._pane_info_coder.estimate_size(value.pane_info, nested=True)) 1516 return estimated_size, observables 1517 1518 1519 class ParamWindowedValueCoderImpl(WindowedValueCoderImpl): 1520 """For internal use only; no backwards-compatibility guarantees. 1521 1522 A coder for windowed values with constant timestamp, windows and 1523 pane info. The coder drops timestamp, windows and pane info during 1524 encoding, and uses the supplied parameterized timestamp, windows 1525 and pane info values during decoding when reconstructing the windowed 1526 value.""" 1527 def __init__(self, value_coder, window_coder, payload): 1528 super().__init__(value_coder, TimestampCoderImpl(), window_coder) 1529 self._timestamp, self._windows, self._pane_info = self._from_proto( 1530 payload, window_coder) 1531 1532 def _from_proto(self, payload, window_coder): 1533 windowed_value_coder = WindowedValueCoderImpl( 1534 BytesCoderImpl(), TimestampCoderImpl(), window_coder) 1535 wv = windowed_value_coder.decode(payload) 1536 return wv.timestamp_micros, wv.windows, wv.pane_info 1537 1538 def encode_to_stream(self, value, out, nested): 1539 wv = value # type cast 1540 self._value_coder.encode_to_stream(wv.value, out, nested) 1541 1542 def decode_from_stream(self, in_stream, nested): 1543 value = self._value_coder.decode_from_stream(in_stream, nested) 1544 return windowed_value.create( 1545 value, self._timestamp, self._windows, self._pane_info) 1546 1547 def get_estimated_size_and_observables(self, value, nested=False): 1548 """Returns estimated size of value along with any nested observables.""" 1549 if isinstance(value, observable.ObservableMixin): 1550 # Should never be here. 1551 # TODO(robertwb): Remove when coders are set correctly. 1552 return 0, [(value, self._value_coder)] 1553 estimated_size = 0 1554 observables = [] 1555 value_estimated_size, value_observables = ( 1556 self._value_coder.get_estimated_size_and_observables( 1557 value.value, nested=nested)) 1558 estimated_size += value_estimated_size 1559 observables += value_observables 1560 return estimated_size, observables 1561 1562 1563 class LengthPrefixCoderImpl(StreamCoderImpl): 1564 """For internal use only; no backwards-compatibility guarantees. 1565 1566 Coder which prefixes the length of the encoded object in the stream.""" 1567 def __init__(self, value_coder): 1568 # type: (CoderImpl) -> None 1569 self._value_coder = value_coder 1570 1571 def encode_to_stream(self, value, out, nested): 1572 # type: (Any, create_OutputStream, bool) -> None 1573 encoded_value = self._value_coder.encode(value) 1574 out.write_var_int64(len(encoded_value)) 1575 out.write(encoded_value) 1576 1577 def decode_from_stream(self, in_stream, nested): 1578 # type: (create_InputStream, bool) -> Any 1579 value_length = in_stream.read_var_int64() 1580 return self._value_coder.decode(in_stream.read(value_length)) 1581 1582 def estimate_size(self, value, nested=False): 1583 # type: (Any, bool) -> int 1584 value_size = self._value_coder.estimate_size(value) 1585 return get_varint_size(value_size) + value_size 1586 1587 1588 class ShardedKeyCoderImpl(StreamCoderImpl): 1589 """For internal use only; no backwards-compatibility guarantees. 1590 1591 A coder for sharded user keys. 1592 1593 The encoding and decoding should follow the order: 1594 shard id byte string 1595 encoded user key 1596 """ 1597 def __init__(self, key_coder_impl): 1598 self._shard_id_coder_impl = BytesCoderImpl() 1599 self._key_coder_impl = key_coder_impl 1600 1601 def encode_to_stream(self, value, out, nested): 1602 # type: (ShardedKey, create_OutputStream, bool) -> None 1603 self._shard_id_coder_impl.encode_to_stream(value._shard_id, out, True) 1604 self._key_coder_impl.encode_to_stream(value.key, out, True) 1605 1606 def decode_from_stream(self, in_stream, nested): 1607 # type: (create_InputStream, bool) -> ShardedKey 1608 shard_id = self._shard_id_coder_impl.decode_from_stream(in_stream, True) 1609 key = self._key_coder_impl.decode_from_stream(in_stream, True) 1610 return ShardedKey(key=key, shard_id=shard_id) 1611 1612 def estimate_size(self, value, nested=False): 1613 # type: (Any, bool) -> int 1614 estimated_size = 0 1615 estimated_size += ( 1616 self._shard_id_coder_impl.estimate_size(value._shard_id, nested=True)) 1617 estimated_size += ( 1618 self._key_coder_impl.estimate_size(value.key, nested=True)) 1619 return estimated_size 1620 1621 1622 class TimestampPrefixingWindowCoderImpl(StreamCoderImpl): 1623 """For internal use only; no backwards-compatibility guarantees. 1624 1625 A coder for custom window types, which prefix required max_timestamp to 1626 encoded original window. 1627 1628 The coder encodes and decodes custom window types with following format: 1629 window's max_timestamp() 1630 encoded window using it's own coder. 1631 """ 1632 def __init__(self, window_coder_impl: CoderImpl) -> None: 1633 self._window_coder_impl = window_coder_impl 1634 1635 def encode_to_stream(self, value, stream, nested): 1636 TimestampCoderImpl().encode_to_stream(value.max_timestamp(), stream, nested) 1637 self._window_coder_impl.encode_to_stream(value, stream, nested) 1638 1639 def decode_from_stream(self, stream, nested): 1640 TimestampCoderImpl().decode_from_stream(stream, nested) 1641 return self._window_coder_impl.decode_from_stream(stream, nested) 1642 1643 def estimate_size(self, value: Any, nested: bool = False) -> int: 1644 estimated_size = 0 1645 estimated_size += TimestampCoderImpl().estimate_size(value) 1646 estimated_size += self._window_coder_impl.estimate_size(value, nested) 1647 return estimated_size 1648 1649 1650 row_coders_registered = False 1651 1652 1653 class RowColumnEncoder: 1654 ROW_ENCODERS = {1: 12345} 1655 1656 @classmethod 1657 def register(cls, field_type, coder_impl): 1658 cls.ROW_ENCODERS[field_type, coder_impl] = cls 1659 1660 @classmethod 1661 def create(cls, field_type, coder_impl, column): 1662 global row_coders_registered 1663 if not row_coders_registered: 1664 try: 1665 # pylint: disable=unused-import 1666 from apache_beam.coders import coder_impl_row_encoders 1667 except ImportError: 1668 pass 1669 row_coders_registered = True 1670 return cls.ROW_ENCODERS.get((field_type, column.dtype), 1671 GenericRowColumnEncoder)(coder_impl, column) 1672 1673 def null_flags(self): 1674 raise NotImplementedError(type(self)) 1675 1676 def encode_to_stream(self, index, out): 1677 raise NotImplementedError(type(self)) 1678 1679 def decode_from_stream(self, index, in_stream): 1680 raise NotImplementedError(type(self)) 1681 1682 def finalize_write(self): 1683 raise NotImplementedError(type(self)) 1684 1685 1686 class GenericRowColumnEncoder(RowColumnEncoder): 1687 def __init__(self, coder_impl, column): 1688 self.coder_impl = coder_impl 1689 self.column = column 1690 1691 def null_flags(self): 1692 # pylint: disable=singleton-comparison 1693 return self.column == None # This is an array. 1694 1695 def encode_to_stream(self, index, out): 1696 self.coder_impl.encode_to_stream(self.column[index], out, True) 1697 1698 def decode_from_stream(self, index, in_stream): 1699 self.column[index] = self.coder_impl.decode_from_stream(in_stream, True) 1700 1701 def finalize_write(self): 1702 pass 1703 1704 1705 class RowCoderImpl(StreamCoderImpl): 1706 """For internal use only; no backwards-compatibility guarantees.""" 1707 def __init__(self, schema, components): 1708 self.schema = schema 1709 self.num_fields = len(self.schema.fields) 1710 self.field_names = [f.name for f in self.schema.fields] 1711 self.field_nullable = [field.type.nullable for field in self.schema.fields] 1712 self.constructor = named_tuple_from_schema(schema) 1713 self.encoding_positions = list(range(len(self.schema.fields))) 1714 if self.schema.encoding_positions_set: 1715 # should never be duplicate encoding positions. 1716 enc_posx = list( 1717 set(field.encoding_position for field in self.schema.fields)) 1718 if len(enc_posx) != len(self.schema.fields): 1719 raise ValueError( 1720 f'''Schema with id {schema.id} has encoding_positions_set=True, 1721 but not all fields have encoding_position set''') 1722 self.encoding_positions = list( 1723 field.encoding_position for field in self.schema.fields) 1724 self.encoding_positions_argsort = list(np.argsort(self.encoding_positions)) 1725 self.encoding_positions_are_trivial = self.encoding_positions == list( 1726 range(len(self.encoding_positions))) 1727 self.components = list( 1728 components[self.encoding_positions.index(i)].get_impl() 1729 for i in self.encoding_positions) 1730 self.has_nullable_fields = any( 1731 field.type.nullable for field in self.schema.fields) 1732 1733 def encode_to_stream(self, value, out, nested): 1734 out.write_var_int64(self.num_fields) 1735 attrs = [getattr(value, name) for name in self.field_names] 1736 1737 if self.has_nullable_fields: 1738 any_nulls = False 1739 for attr in attrs: 1740 if attr is None: 1741 any_nulls = True 1742 break 1743 if any_nulls: 1744 out.write_var_int64((self.num_fields + 7) // 8) 1745 # Pack the bits, little-endian, in consecutive bytes. 1746 running = 0 1747 for i, attr in enumerate(attrs): 1748 if i and i % 8 == 0: 1749 out.write_byte(running) 1750 running = 0 1751 running |= (attr is None) << (i % 8) 1752 out.write_byte(running) 1753 else: 1754 out.write_byte(0) 1755 else: 1756 out.write_byte(0) 1757 1758 for i in range(self.num_fields): 1759 if not self.encoding_positions_are_trivial: 1760 i = self.encoding_positions_argsort[i] 1761 attr = attrs[i] 1762 if attr is None: 1763 if not self.field_nullable[i]: 1764 raise ValueError( 1765 "Attempted to encode null for non-nullable field \"{}\".".format( 1766 self.schema.fields[i].name)) 1767 continue 1768 component_coder = self.components[i] # for typing 1769 component_coder.encode_to_stream(attr, out, True) 1770 1771 def _row_column_encoders(self, columns): 1772 return [ 1773 RowColumnEncoder.create( 1774 self.schema.fields[i].type.atomic_type, 1775 self.components[i], 1776 columns[name]) for i, 1777 name in enumerate(self.field_names) 1778 ] 1779 1780 def encode_batch_to_stream(self, columns: Dict[str, np.ndarray], out): 1781 attrs = self._row_column_encoders(columns) 1782 n = len(next(iter(columns.values()))) 1783 if self.has_nullable_fields: 1784 null_flags_py = np.zeros((n, self.num_fields), dtype=np.uint8) 1785 null_bits_len = (self.num_fields + 7) // 8 1786 null_bits_py = np.zeros((n, null_bits_len), dtype=np.uint8) 1787 for i, attr in enumerate(attrs): 1788 attr_null_flags = attr.null_flags() 1789 if attr_null_flags is not None and attr_null_flags.any(): 1790 null_flags_py[:, i] = attr_null_flags 1791 null_bits_py[:, i // 8] |= attr_null_flags << np.uint8(i % 8) 1792 has_null_bits = (null_bits_py.sum(axis=1) != 0).astype(np.uint8) 1793 null_bits = null_bits_py 1794 null_flags = null_flags_py 1795 else: 1796 has_null_bits = np.zeros((n, ), dtype=np.uint8) 1797 1798 for k in range(n): 1799 out.write_var_int64(self.num_fields) 1800 if has_null_bits[k]: 1801 out.write_byte(null_bits_len) 1802 for i in range(null_bits_len): 1803 out.write_byte(null_bits[k, i]) 1804 else: 1805 out.write_byte(0) 1806 for i in range(self.num_fields): 1807 if not self.encoding_positions_are_trivial: 1808 i = self.encoding_positions_argsort[i] 1809 if has_null_bits[k] and null_flags[k, i]: 1810 if not self.field_nullable[i]: 1811 raise ValueError( 1812 "Attempted to encode null for non-nullable field \"{}\".". 1813 format(self.schema.fields[i].name)) 1814 else: 1815 cython.cast(RowColumnEncoder, attrs[i]).encode_to_stream(k, out) 1816 1817 def decode_from_stream(self, in_stream, nested): 1818 nvals = in_stream.read_var_int64() 1819 null_mask_len = in_stream.read_var_int64() 1820 if null_mask_len: 1821 # pylint: disable=unused-variable 1822 null_mask_c = null_mask_py = in_stream.read(null_mask_len) 1823 1824 # Note that if this coder's schema has *fewer* attributes than the encoded 1825 # value, we just need to ignore the additional values, which will occur 1826 # here because we only decode as many values as we have coders for. 1827 1828 sorted_components = [] 1829 for i in range(min(self.num_fields, nvals)): 1830 if not self.encoding_positions_are_trivial: 1831 i = self.encoding_positions_argsort[i] 1832 if (null_mask_len and i >> 3 < null_mask_len and 1833 null_mask_c[i >> 3] & (0x01 << (i & 0x07))): 1834 item = None 1835 else: 1836 component_coder = self.components[i] # for typing 1837 item = component_coder.decode_from_stream(in_stream, True) 1838 sorted_components.append(item) 1839 1840 # If this coder's schema has more attributes than the encoded value, then 1841 # the schema must have changed. Populate the unencoded fields with nulls. 1842 while len(sorted_components) < self.num_fields: 1843 sorted_components.append(None) 1844 1845 return self.constructor( 1846 *( 1847 sorted_components if self.encoding_positions_are_trivial else 1848 [sorted_components[i] for i in self.encoding_positions])) 1849 1850 def decode_batch_from_stream(self, dest: Dict[str, np.ndarray], in_stream): 1851 attrs = self._row_column_encoders(dest) 1852 n = len(next(iter(dest.values()))) 1853 for k in range(n): 1854 if in_stream.size() == 0: 1855 break 1856 nvals = in_stream.read_var_int64() 1857 null_mask_len = in_stream.read_var_int64() 1858 if null_mask_len: 1859 # pylint: disable=unused-variable 1860 null_mask_c = null_mask = in_stream.read(null_mask_len) 1861 1862 for i in range(min(self.num_fields, nvals)): 1863 if not self.encoding_positions_are_trivial: 1864 i = self.encoding_positions_argsort[i] 1865 if (null_mask_len and i >> 3 < null_mask_len and 1866 null_mask_c[i >> 3] & (0x01 << (i & 0x07))): 1867 continue 1868 else: 1869 cython.cast(RowColumnEncoder, 1870 attrs[i]).decode_from_stream(k, in_stream) 1871 else: 1872 # Loop variable will be n-1 on normal exit. 1873 k = n 1874 1875 for attr in attrs: 1876 attr.finalize_write() 1877 return k 1878 1879 1880 class LogicalTypeCoderImpl(StreamCoderImpl): 1881 def __init__(self, logical_type, representation_coder): 1882 self.logical_type = logical_type 1883 self.representation_coder = representation_coder.get_impl() 1884 1885 def encode_to_stream(self, value, out, nested): 1886 return self.representation_coder.encode_to_stream( 1887 self.logical_type.to_representation_type(value), out, nested) 1888 1889 def decode_from_stream(self, in_stream, nested): 1890 return self.logical_type.to_language_type( 1891 self.representation_coder.decode_from_stream(in_stream, nested)) 1892 1893 1894 class BigIntegerCoderImpl(StreamCoderImpl): 1895 """For internal use only; no backwards-compatibility guarantees. 1896 1897 For interoperability with Java SDK, encoding needs to match that of the Java 1898 SDK BigIntegerCoder.""" 1899 def encode_to_stream(self, value, out, nested): 1900 # type: (int, create_OutputStream, bool) -> None 1901 if value < 0: 1902 byte_length = ((value + 1).bit_length() + 8) // 8 1903 else: 1904 byte_length = (value.bit_length() + 8) // 8 1905 encoded_value = value.to_bytes( 1906 length=byte_length, byteorder='big', signed=True) 1907 out.write(encoded_value, nested) 1908 1909 def decode_from_stream(self, in_stream, nested): 1910 # type: (create_InputStream, bool) -> int 1911 encoded_value = in_stream.read_all(nested) 1912 return int.from_bytes(encoded_value, byteorder='big', signed=True) 1913 1914 1915 class DecimalCoderImpl(StreamCoderImpl): 1916 """For internal use only; no backwards-compatibility guarantees. 1917 1918 For interoperability with Java SDK, encoding needs to match that of the Java 1919 SDK BigDecimalCoder.""" 1920 1921 BIG_INT_CODER_IMPL = BigIntegerCoderImpl() 1922 1923 def encode_to_stream(self, value, out, nested): 1924 # type: (decimal.Decimal, create_OutputStream, bool) -> None 1925 scale = -value.as_tuple().exponent 1926 int_value = int(value.scaleb(scale)) 1927 out.write_var_int64(scale) 1928 self.BIG_INT_CODER_IMPL.encode_to_stream(int_value, out, nested) 1929 1930 def decode_from_stream(self, in_stream, nested): 1931 # type: (create_InputStream, bool) -> decimal.Decimal 1932 scale = in_stream.read_var_int64() 1933 int_value = self.BIG_INT_CODER_IMPL.decode_from_stream(in_stream, nested) 1934 value = decimal.Decimal(int_value).scaleb(-scale) 1935 return value