github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/coders/coder_impl.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  # cython: language_level=3
    19  
    20  """Coder implementations.
    21  
    22  The actual encode/decode implementations are split off from coders to
    23  allow conditional (compiled/pure) implementations, which can be used to
    24  encode many elements with minimal overhead.
    25  
    26  This module may be optionally compiled with Cython, using the corresponding
    27  coder_impl.pxd file for type hints.  In particular, because CoderImpls are
    28  never pickled and sent across the wire (unlike Coders themselves) the workers
    29  can use compiled Impls even if the main program does not (or vice versa).
    30  
    31  For internal use only; no backwards-compatibility guarantees.
    32  """
    33  # pytype: skip-file
    34  
    35  import decimal
    36  import enum
    37  import itertools
    38  import json
    39  import logging
    40  import pickle
    41  from io import BytesIO
    42  from typing import TYPE_CHECKING
    43  from typing import Any
    44  from typing import Callable
    45  from typing import Dict
    46  from typing import Iterable
    47  from typing import Iterator
    48  from typing import List
    49  from typing import Optional
    50  from typing import Sequence
    51  from typing import Set
    52  from typing import Tuple
    53  from typing import Type
    54  
    55  import dill
    56  import numpy as np
    57  from fastavro import parse_schema
    58  from fastavro import schemaless_reader
    59  from fastavro import schemaless_writer
    60  
    61  from apache_beam.coders import observable
    62  from apache_beam.coders.avro_record import AvroRecord
    63  from apache_beam.typehints.schemas import named_tuple_from_schema
    64  from apache_beam.utils import proto_utils
    65  from apache_beam.utils import windowed_value
    66  from apache_beam.utils.sharded_key import ShardedKey
    67  from apache_beam.utils.timestamp import MAX_TIMESTAMP
    68  from apache_beam.utils.timestamp import MIN_TIMESTAMP
    69  from apache_beam.utils.timestamp import Timestamp
    70  
    71  try:
    72    import dataclasses
    73  except ImportError:
    74    dataclasses = None  # type: ignore
    75  
    76  if TYPE_CHECKING:
    77    import proto
    78    from apache_beam.transforms import userstate
    79    from apache_beam.transforms.window import IntervalWindow
    80  
# Probe for the optional `stream` module. SLOW_STREAM records whether the
# import failed, in which case the `slow_stream` implementations are used.
try:
  from . import stream  # pylint: disable=unused-import
except ImportError:
  SLOW_STREAM = True
else:
  SLOW_STREAM = False

# Updated below once it is known which stream implementation is in use.
is_compiled = False
# True iff x lies in the signed 64-bit range [-2**63, 2**63 - 1].
fits_in_64_bits = lambda x: -(1 << 63) <= x <= (1 << 63) - 1

if TYPE_CHECKING or SLOW_STREAM:
  from .slow_stream import InputStream as create_InputStream
  from .slow_stream import OutputStream as create_OutputStream
  from .slow_stream import ByteCountingOutputStream
  from .slow_stream import get_varint_size

  try:
    import cython
    is_compiled = cython.compiled
  except ImportError:
    # Without cython installed, bind a stand-in whose only attribute is
    # `cast`, implemented as the identity on its second argument.
    globals()['cython'] = type('fake_cython', (), {'cast': lambda typ, x: x})

else:
  # pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
  from .stream import InputStream as create_InputStream
  from .stream import OutputStream as create_OutputStream
  from .stream import ByteCountingOutputStream
  from .stream import get_varint_size
  # Make it possible to import create_InputStream and other cdef-classes
  # from apache_beam.coders.coder_impl when Cython codepath is used.
  globals()['create_InputStream'] = create_InputStream
  globals()['create_OutputStream'] = create_OutputStream
  globals()['ByteCountingOutputStream'] = ByteCountingOutputStream
  # pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports
  is_compiled = True

_LOGGER = logging.getLogger(__name__)

# 2**63. NOTE(review): not referenced in the part of the file reviewed here;
# presumably an offset used by timestamp-related coders -- confirm at use site.
_TIME_SHIFT = 1 << 63
# Module-level copies of the timestamp bounds' `micros` attributes.
MIN_TIMESTAMP_micros = MIN_TIMESTAMP.micros
MAX_TIMESTAMP_micros = MAX_TIMESTAMP.micros

# Callable taking (bytes, CoderImpl) and returning an Iterable.
IterableStateReader = Callable[[bytes, 'CoderImpl'], Iterable]
# Callable taking (Iterable, CoderImpl) and returning bytes.
IterableStateWriter = Callable[[Iterable, 'CoderImpl'], bytes]
# List of (observable object, CoderImpl for the elements it emits) pairs.
Observables = List[Tuple[observable.ObservableMixin, 'CoderImpl']]
   126  
   127  
class CoderImpl(object):
  """For internal use only; no backwards-compatibility guarantees.

  Base class for coder implementations. The four primitive operations
  (encode_to_stream, decode_from_stream, encode, decode) raise
  NotImplementedError here; the remaining methods are built on top of
  encode_to_stream and decode_from_stream.
  """
  def encode_to_stream(self, value, stream, nested):
    # type: (Any, create_OutputStream, bool) -> None

    """Writes object to potentially-nested encoding in stream."""
    raise NotImplementedError

  def decode_from_stream(self, stream, nested):
    # type: (create_InputStream, bool) -> Any

    """Reads object from potentially-nested encoding in stream."""
    raise NotImplementedError

  def encode(self, value):
    # type: (Any) -> bytes

    """Encodes an object to an unnested string."""
    raise NotImplementedError

  def decode(self, encoded):
    # type: (bytes) -> Any

    """Decodes an object from an unnested string."""
    raise NotImplementedError

  def encode_all(self, values):
    # type: (Iterable[Any]) -> bytes

    """Encodes every value, each with nested=True, into one byte string."""
    out = create_OutputStream()
    for value in values:
      self.encode_to_stream(value, out, True)
    return out.get()

  def decode_all(self, encoded):
    # type: (bytes) -> Iterator[Any]

    """Yields values decoded with nested=True until the input is exhausted."""
    input_stream = create_InputStream(encoded)
    while input_stream.size() > 0:
      yield self.decode_from_stream(input_stream, True)

  def encode_nested(self, value):
    # type: (Any) -> bytes

    """Encodes a single value with nested=True and returns the bytes."""
    out = create_OutputStream()
    self.encode_to_stream(value, out, True)
    return out.get()

  def decode_nested(self, encoded):
    # type: (bytes) -> Any

    """Decodes a single value from encoded with nested=True."""
    return self.decode_from_stream(create_InputStream(encoded), True)

  def estimate_size(self, value, nested=False):
    # type: (Any, bool) -> int

    """Estimates the encoded size of the given value, in bytes."""
    # Encode into a stream that only counts the bytes written to it.
    out = ByteCountingOutputStream()
    self.encode_to_stream(value, out, nested)
    return out.get_count()

  def _get_nested_size(self, inner_size, nested):
    # Returns inner_size unchanged when not nested; otherwise adds
    # get_varint_size(inner_size) to it.
    if not nested:
      return inner_size
    varint_size = get_varint_size(inner_size)
    return varint_size + inner_size

  def get_estimated_size_and_observables(self, value, nested=False):
    # type: (Any, bool) -> Tuple[int, Observables]

    """Returns estimated size of value along with any nested observables.

    The list of nested observables is returned as a list of 2-tuples of
    (obj, coder_impl), where obj is an instance of observable.ObservableMixin,
    and coder_impl is the CoderImpl that can be used to encode elements sent by
    obj to its observers.

    Arguments:
      value: the value whose encoded size is to be estimated.
      nested: whether the value is nested.

    Returns:
      The estimated encoded size of the given value and a list of observables
      whose elements are 2-tuples of (obj, coder_impl) as described above.
    """
    # The base implementation never reports observables.
    return self.estimate_size(value, nested), []
   210  
   211  
   212  class SimpleCoderImpl(CoderImpl):
   213    """For internal use only; no backwards-compatibility guarantees.
   214  
   215    Subclass of CoderImpl implementing stream methods using encode/decode."""
   216    def encode_to_stream(self, value, stream, nested):
   217      # type: (Any, create_OutputStream, bool) -> None
   218  
   219      """Reads object from potentially-nested encoding in stream."""
   220      stream.write(self.encode(value), nested)
   221  
   222    def decode_from_stream(self, stream, nested):
   223      # type: (create_InputStream, bool) -> Any
   224  
   225      """Reads object from potentially-nested encoding in stream."""
   226      return self.decode(stream.read_all(nested))
   227  
   228  
   229  class StreamCoderImpl(CoderImpl):
   230    """For internal use only; no backwards-compatibility guarantees.
   231  
   232    Subclass of CoderImpl implementing encode/decode using stream methods."""
   233    def encode(self, value):
   234      # type: (Any) -> bytes
   235      out = create_OutputStream()
   236      self.encode_to_stream(value, out, False)
   237      return out.get()
   238  
   239    def decode(self, encoded):
   240      # type: (bytes) -> Any
   241      return self.decode_from_stream(create_InputStream(encoded), False)
   242  
   243    def estimate_size(self, value, nested=False):
   244      # type: (Any, bool) -> int
   245  
   246      """Estimates the encoded size of the given value, in bytes."""
   247      out = ByteCountingOutputStream()
   248      self.encode_to_stream(value, out, nested)
   249      return out.get_count()
   250  
   251  
   252  class CallbackCoderImpl(CoderImpl):
   253    """For internal use only; no backwards-compatibility guarantees.
   254  
   255    A CoderImpl that calls back to the _impl methods on the Coder itself.
   256  
   257    This is the default implementation used if Coder._get_impl()
   258    is not overwritten.
   259    """
   260    def __init__(self, encoder, decoder, size_estimator=None):
   261      self._encoder = encoder
   262      self._decoder = decoder
   263      self._size_estimator = size_estimator or self._default_size_estimator
   264  
   265    def _default_size_estimator(self, value):
   266      return len(self.encode(value))
   267  
   268    def encode_to_stream(self, value, stream, nested):
   269      # type: (Any, create_OutputStream, bool) -> None
   270      return stream.write(self._encoder(value), nested)
   271  
   272    def decode_from_stream(self, stream, nested):
   273      # type: (create_InputStream, bool) -> Any
   274      return self._decoder(stream.read_all(nested))
   275  
   276    def encode(self, value):
   277      return self._encoder(value)
   278  
   279    def decode(self, encoded):
   280      return self._decoder(encoded)
   281  
   282    def estimate_size(self, value, nested=False):
   283      # type: (Any, bool) -> int
   284      return self._get_nested_size(self._size_estimator(value), nested)
   285  
   286    def get_estimated_size_and_observables(self, value, nested=False):
   287      # type: (Any, bool) -> Tuple[int, Observables]
   288      # TODO(robertwb): Remove this once all coders are correct.
   289      if isinstance(value, observable.ObservableMixin):
   290        # CallbackCoderImpl can presumably encode the elements too.
   291        return 1, [(value, self)]
   292  
   293      return self.estimate_size(value, nested), []
   294  
   295    def __repr__(self):
   296      return 'CallbackCoderImpl[encoder=%s, decoder=%s]' % (
   297          self._encoder, self._decoder)
   298  
   299  
   300  class ProtoCoderImpl(SimpleCoderImpl):
   301    """For internal use only; no backwards-compatibility guarantees."""
   302    def __init__(self, proto_message_type):
   303      self.proto_message_type = proto_message_type
   304  
   305    def encode(self, value):
   306      return value.SerializePartialToString()
   307  
   308    def decode(self, encoded):
   309      proto_message = self.proto_message_type()
   310      proto_message.ParseFromString(encoded)  # This is in effect "ParsePartial".
   311      return proto_message
   312  
   313  
class DeterministicProtoCoderImpl(ProtoCoderImpl):
  """For internal use only; no backwards-compatibility guarantees.

  Same as ProtoCoderImpl, except that encode requests deterministic
  serialization."""
  def encode(self, value):
    # Only the encoding side differs from the parent: deterministic=True.
    return value.SerializePartialToString(deterministic=True)
   318  
   319  
class ProtoPlusCoderImpl(SimpleCoderImpl):
  """For internal use only; no backwards-compatibility guarantees.

  Coder implementation for the message type given at construction
  (annotated as a proto.Message subclass)."""
  def __init__(self, proto_plus_type):
    # type: (Type[proto.Message]) -> None
    self.proto_plus_type = proto_plus_type

  def encode(self, value):
    # Serializes the object held in the value's `_pb` attribute, asking for
    # deterministic output.
    return value._pb.SerializePartialToString(deterministic=True)

  def decode(self, value):
    # `value` is the encoded bytes; decoding is delegated to the type's own
    # deserialize().
    return self.proto_plus_type.deserialize(value)
   331  
   332  
# One-byte type tags. FastPrimitivesCoderImpl writes one of these as the first
# byte of every encoding and dispatches on it when decoding, so the numeric
# values are part of the encoded format. (BOOL_TYPE is listed out of numeric
# order; its value is 9.)
UNKNOWN_TYPE = 0xFF  # Payload is produced by the fallback coder.
NONE_TYPE = 0
INT_TYPE = 1
FLOAT_TYPE = 2
BYTES_TYPE = 3
UNICODE_TYPE = 4
BOOL_TYPE = 9
LIST_TYPE = 5
TUPLE_TYPE = 6
DICT_TYPE = 7
SET_TYPE = 8
ITERABLE_LIKE_TYPE = 10

# Tags written only by FastPrimitivesCoderImpl.encode_special_deterministic.
PROTO_TYPE = 100
DATACLASS_TYPE = 101
NAMED_TUPLE_TYPE = 102
ENUM_TYPE = 103
NESTED_STATE_TYPE = 104

# Types that can be encoded as iterables, but are not literally
# lists, etc. due to being lazy.  The actual type is not preserved
# through encoding, only the elements. This is particularly useful
# for the value list types created in GroupByKey.
# Populated through FastPrimitivesCoderImpl.register_iterable_like_type.
_ITERABLE_LIKE_TYPES = set()  # type: Set[Type]
   357  
   358  
class FastPrimitivesCoderImpl(StreamCoderImpl):
  """For internal use only; no backwards-compatibility guarantees.

  Encodes a value as a one-byte type tag followed by a type-specific payload.
  Exact built-in types (None, int, float, bytes, str, list, tuple, bool, dict,
  set) and registered iterable-like types have dedicated tags. Any other value
  is either handed to the fallback coder under UNKNOWN_TYPE or, when
  requires_deterministic_step_label is set, to encode_special_deterministic.
  """
  def __init__(
      self, fallback_coder_impl, requires_deterministic_step_label=None):
    # fallback_coder_impl: coder used for values without a dedicated tag (and
    #   for ints that do not fit in 64 bits).
    # requires_deterministic_step_label: when not None, dict items and set
    #   elements are sorted before encoding and the fallback coder is not
    #   used for untagged types; the label itself appears in error and
    #   warning messages.
    self.fallback_coder_impl = fallback_coder_impl
    self.iterable_coder_impl = IterableCoderImpl(self)
    self.requires_deterministic_step_label = requires_deterministic_step_label
    # Ensures the deterministic-fallback warning is logged once per instance.
    self.warn_deterministic_fallback = True

  @staticmethod
  def register_iterable_like_type(t):
    # Adds t to the module-level set consulted by encode_to_stream.
    _ITERABLE_LIKE_TYPES.add(t)

  def get_estimated_size_and_observables(self, value, nested=False):
    # type: (Any, bool) -> Tuple[int, Observables]
    if isinstance(value, observable.ObservableMixin):
      # FastPrimitivesCoderImpl can presumably encode the elements too.
      return 1, [(value, self)]

    # Otherwise measure by encoding into a byte-counting stream.
    out = ByteCountingOutputStream()
    self.encode_to_stream(value, out, nested)
    return out.get_count(), []

  def encode_to_stream(self, value, stream, nested):
    # type: (Any, create_OutputStream, bool) -> None
    # Dispatch is on the exact type (identity checks), so subclasses of the
    # built-in types do not match these branches.
    t = type(value)
    if value is None:
      stream.write_byte(NONE_TYPE)
    elif t is int:
      # In Python 3, an int may be larger than 64 bits.
      # We need to check whether value fits into a 64 bit integer before
      # writing the marker byte.
      try:
        # In Cython-compiled code this will throw an overflow error
        # when value does not fit into int64.
        int_value = value
        # If Cython is not used, we must do a (slower) check ourselves.
        if not TYPE_CHECKING and not is_compiled:
          if not fits_in_64_bits(value):
            raise OverflowError()
        stream.write_byte(INT_TYPE)
        stream.write_var_int64(int_value)
      except OverflowError:
        # Too large for int64: delegate to the fallback coder.
        stream.write_byte(UNKNOWN_TYPE)
        self.fallback_coder_impl.encode_to_stream(value, stream, nested)
    elif t is float:
      stream.write_byte(FLOAT_TYPE)
      stream.write_bigendian_double(value)
    elif t is bytes:
      stream.write_byte(BYTES_TYPE)
      stream.write(value, nested)
    elif t is str:
      unicode_value = value  # for typing
      stream.write_byte(UNICODE_TYPE)
      stream.write(unicode_value.encode('utf-8'), nested)
    elif t is list or t is tuple:
      # Tag, element count, then each element encoded with nested=True.
      stream.write_byte(LIST_TYPE if t is list else TUPLE_TYPE)
      stream.write_var_int64(len(value))
      for e in value:
        self.encode_to_stream(e, stream, True)
    elif t is bool:
      stream.write_byte(BOOL_TYPE)
      stream.write_byte(value)
    elif t in _ITERABLE_LIKE_TYPES:
      stream.write_byte(ITERABLE_LIKE_TYPE)
      self.iterable_coder_impl.encode_to_stream(value, stream, nested)
    elif t is dict:
      dict_value = value  # for typing
      stream.write_byte(DICT_TYPE)
      stream.write_var_int64(len(dict_value))
      if self.requires_deterministic_step_label is not None:
        # Deterministic mode: emit the items in sorted order.
        try:
          ordered_kvs = sorted(dict_value.items())
        except Exception as exn:
          raise TypeError(
              "Unable to deterministically order keys of dict for '%s'" %
              self.requires_deterministic_step_label) from exn
        for k, v in ordered_kvs:
          self.encode_to_stream(k, stream, True)
          self.encode_to_stream(v, stream, True)
      else:
        # Loop over dict.items() is optimized by Cython.
        for k, v in dict_value.items():
          self.encode_to_stream(k, stream, True)
          self.encode_to_stream(v, stream, True)
    elif t is set:
      stream.write_byte(SET_TYPE)
      stream.write_var_int64(len(value))
      if self.requires_deterministic_step_label is not None:
        # Deterministic mode: emit the elements in sorted order.
        try:
          value = sorted(value)
        except Exception as exn:
          raise TypeError(
              "Unable to deterministically order element of set for '%s'" %
              self.requires_deterministic_step_label) from exn
      for e in value:
        self.encode_to_stream(e, stream, True)
    # All possibly deterministic encodings should be above this clause,
    # all non-deterministic ones below.
    elif self.requires_deterministic_step_label is not None:
      self.encode_special_deterministic(value, stream)
    else:
      stream.write_byte(UNKNOWN_TYPE)
      self.fallback_coder_impl.encode_to_stream(value, stream, nested)

  def encode_special_deterministic(self, value, stream):
    # Deterministic-mode encoding for values without a dedicated primitive
    # tag. Checked in order: proto messages, dataclasses, named tuples, enums,
    # then objects exposing __getstate__/__setstate__. Raises TypeError for
    # anything else, and when a component cannot itself be encoded.
    if self.warn_deterministic_fallback:
      _LOGGER.warning(
          "Using fallback deterministic coder for type '%s' in '%s'. ",
          type(value),
          self.requires_deterministic_step_label)
      self.warn_deterministic_fallback = False
    if isinstance(value, proto_utils.message_types):
      stream.write_byte(PROTO_TYPE)
      self.encode_type(type(value), stream)
      stream.write(value.SerializePartialToString(deterministic=True), True)
    elif dataclasses and dataclasses.is_dataclass(value):
      stream.write_byte(DATACLASS_TYPE)
      # Only frozen dataclasses are accepted here.
      if not type(value).__dataclass_params__.frozen:
        raise TypeError(
            "Unable to deterministically encode non-frozen '%s' of type '%s' "
            "for the input of '%s'" %
            (value, type(value), self.requires_deterministic_step_label))
      self.encode_type(type(value), stream)
      # Field values, in the order reported by dataclasses.fields().
      values = [
          getattr(value, field.name) for field in dataclasses.fields(value)
      ]
      try:
        self.iterable_coder_impl.encode_to_stream(values, stream, True)
      except Exception as e:
        raise TypeError(self._deterministic_encoding_error_msg(value)) from e
    elif isinstance(value, tuple) and hasattr(type(value), '_fields'):
      # A tuple whose type defines `_fields` is treated as a named tuple.
      stream.write_byte(NAMED_TUPLE_TYPE)
      self.encode_type(type(value), stream)
      try:
        self.iterable_coder_impl.encode_to_stream(value, stream, True)
      except Exception as e:
        raise TypeError(self._deterministic_encoding_error_msg(value)) from e
    elif isinstance(value, enum.Enum):
      stream.write_byte(ENUM_TYPE)
      self.encode_type(type(value), stream)
      # Enum values can be of any type.
      try:
        self.encode_to_stream(value.value, stream, True)
      except Exception as e:
        raise TypeError(self._deterministic_encoding_error_msg(value)) from e
    elif hasattr(value, "__getstate__"):
      # Decoding calls __setstate__, so it must be present as well.
      if not hasattr(value, "__setstate__"):
        raise TypeError(
            "Unable to deterministically encode '%s' of type '%s', "
            "for the input of '%s'. The object defines __getstate__ but not "
            "__setstate__." %
            (value, type(value), self.requires_deterministic_step_label))
      stream.write_byte(NESTED_STATE_TYPE)
      self.encode_type(type(value), stream)
      state_value = value.__getstate__()
      try:
        self.encode_to_stream(state_value, stream, True)
      except Exception as e:
        raise TypeError(self._deterministic_encoding_error_msg(value)) from e
    else:
      raise TypeError(self._deterministic_encoding_error_msg(value))

  def _deterministic_encoding_error_msg(self, value):
    # Builds the message shared by the TypeErrors raised above.
    return (
        "Unable to deterministically encode '%s' of type '%s', "
        "please provide a type hint for the input of '%s'" %
        (value, type(value), self.requires_deterministic_step_label))

  def encode_type(self, t, stream):
    # Writes the dill-serialized type with nested=True.
    stream.write(dill.dumps(t), True)

  def decode_type(self, stream):
    # Reads the bytes written by encode_type and resolves them through the
    # caching _unpickle_type helper.
    return _unpickle_type(stream.read_all(True))

  def decode_from_stream(self, stream, nested):
    # type: (create_InputStream, bool) -> Any
    # Reads the tag byte and decodes the payload the matching encode branch
    # wrote. Raises ValueError for an unrecognized tag.
    t = stream.read_byte()
    if t == NONE_TYPE:
      return None
    elif t == INT_TYPE:
      return stream.read_var_int64()
    elif t == FLOAT_TYPE:
      return stream.read_bigendian_double()
    elif t == BYTES_TYPE:
      return stream.read_all(nested)
    elif t == UNICODE_TYPE:
      return stream.read_all(nested).decode('utf-8')
    elif t == LIST_TYPE or t == TUPLE_TYPE or t == SET_TYPE:
      # Shared layout: element count, then that many nested elements.
      vlen = stream.read_var_int64()
      vlist = [self.decode_from_stream(stream, True) for _ in range(vlen)]
      if t == LIST_TYPE:
        return vlist
      elif t == TUPLE_TYPE:
        return tuple(vlist)
      return set(vlist)
    elif t == DICT_TYPE:
      vlen = stream.read_var_int64()
      v = {}
      for _ in range(vlen):
        # Key first, then value, matching the order they were written.
        k = self.decode_from_stream(stream, True)
        v[k] = self.decode_from_stream(stream, True)
      return v
    elif t == BOOL_TYPE:
      # Converts the byte to a bool.
      return not not stream.read_byte()
    elif t == ITERABLE_LIKE_TYPE:
      return self.iterable_coder_impl.decode_from_stream(stream, nested)
    elif t == PROTO_TYPE:
      cls = self.decode_type(stream)
      msg = cls()
      msg.ParseFromString(stream.read_all(True))
      return msg
    elif t == DATACLASS_TYPE or t == NAMED_TUPLE_TYPE:
      # Rebuilt by passing the decoded values positionally to the type.
      cls = self.decode_type(stream)
      return cls(*self.iterable_coder_impl.decode_from_stream(stream, True))
    elif t == ENUM_TYPE:
      cls = self.decode_type(stream)
      return cls(self.decode_from_stream(stream, True))
    elif t == NESTED_STATE_TYPE:
      cls = self.decode_type(stream)
      state = self.decode_from_stream(stream, True)
      # Create the instance without calling __init__, then restore its state.
      value = cls.__new__(cls)
      value.__setstate__(state)
      return value
    elif t == UNKNOWN_TYPE:
      return self.fallback_coder_impl.decode_from_stream(stream, nested)
    else:
      raise ValueError('Unknown type tag %x' % t)
   587  
   588  
   589  _unpickled_types = {}  # type: Dict[bytes, type]
   590  
   591  
   592  def _unpickle_type(bs):
   593    t = _unpickled_types.get(bs, None)
   594    if t is None:
   595      t = _unpickled_types[bs] = dill.loads(bs)
   596      # Fix unpicklable anonymous named tuples for Python 3.6.
   597      if t.__base__ is tuple and hasattr(t, '_fields'):
   598        try:
   599          pickle.loads(pickle.dumps(t))
   600        except pickle.PicklingError:
   601          t.__reduce__ = lambda self: (_unpickle_named_tuple, (bs, tuple(self)))
   602    return t
   603  
   604  
   605  def _unpickle_named_tuple(bs, items):
   606    return _unpickle_type(bs)(*items)
   607  
   608  
   609  class BytesCoderImpl(CoderImpl):
   610    """For internal use only; no backwards-compatibility guarantees.
   611  
   612    A coder for bytes/str objects."""
   613    def encode_to_stream(self, value, out, nested):
   614      # type: (bytes, create_OutputStream, bool) -> None
   615  
   616      # value might be of type np.bytes if passed from encode_batch, and cython
   617      # does not recognize it as bytes.
   618      if is_compiled and isinstance(value, np.bytes_):
   619        value = bytes(value)
   620  
   621      out.write(value, nested)
   622  
   623    def decode_from_stream(self, in_stream, nested):
   624      # type: (create_InputStream, bool) -> bytes
   625      return in_stream.read_all(nested)
   626  
   627    def encode(self, value):
   628      assert isinstance(value, bytes), (value, type(value))
   629      return value
   630  
   631    def decode(self, encoded):
   632      return encoded
   633  
   634  
   635  class BooleanCoderImpl(CoderImpl):
   636    """For internal use only; no backwards-compatibility guarantees.
   637  
   638    A coder for bool objects."""
   639    def encode_to_stream(self, value, out, nested):
   640      out.write_byte(1 if value else 0)
   641  
   642    def decode_from_stream(self, in_stream, nested):
   643      value = in_stream.read_byte()
   644      if value == 0:
   645        return False
   646      elif value == 1:
   647        return True
   648      raise ValueError("Expected 0 or 1, got %s" % value)
   649  
   650    def encode(self, value):
   651      return b'\x01' if value else b'\x00'
   652  
   653    def decode(self, encoded):
   654      value = ord(encoded)
   655      if value == 0:
   656        return False
   657      elif value == 1:
   658        return True
   659      raise ValueError("Expected 0 or 1, got %s" % value)
   660  
   661    def estimate_size(self, unused_value, nested=False):
   662      # Note that booleans are encoded the same way regardless of nesting.
   663      return 1
   664  
   665  
class MapCoderImpl(StreamCoderImpl):
  """For internal use only; no backwards-compatibility guarantees.

  Note this implementation always uses nested context when encoding keys
  and values. This differs from Java's MapCoder, which uses
  nested=False if possible for the last value encoded.

  This difference is acceptable because MapCoder is not standard. It is only
  used in a standard context by RowCoder which always uses nested context for
  attribute values.

  The encoding is a big-endian int32 entry count followed by each key and
  value in turn.

  A coder for typing.Mapping objects."""
  def __init__(
      self,
      key_coder,  # type: CoderImpl
      value_coder,  # type: CoderImpl
      is_deterministic = False
  ):
    # is_deterministic: when true, entries are written in sorted order.
    self._key_coder = key_coder
    self._value_coder = value_coder
    self._is_deterministic = is_deterministic

  def encode_to_stream(self, dict_value, out, nested):
    # The `nested` argument is not consulted; see the class docstring.
    out.write_bigendian_int32(len(dict_value))
    # Note this implementation always uses nested context when encoding keys
    # and values which differs from Java. See note in docstring.
    if self._is_deterministic:
      for key, value in sorted(dict_value.items()):
        self._key_coder.encode_to_stream(key, out, True)
        self._value_coder.encode_to_stream(value, out, True)
    else:
      # This loop is separate from the above so the loop over dict.items()
      # will be optimized by Cython.
      for key, value in dict_value.items():
        self._key_coder.encode_to_stream(key, out, True)
        self._value_coder.encode_to_stream(value, out, True)

  def decode_from_stream(self, in_stream, nested):
    # Reads the entry count, then that many key/value pairs, into a dict.
    size = in_stream.read_bigendian_int32()
    result = {}
    for _ in range(size):
      # Note this implementation always uses nested context when encoding keys
      # and values which differs from Java. See note in docstring.
      key = self._key_coder.decode_from_stream(in_stream, True)
      value = self._value_coder.decode_from_stream(in_stream, True)
      result[key] = value

    return result

  def estimate_size(self, unused_value, nested=False):
    # Despite its name, unused_value is the mapping being measured.
    estimate = 4  # 4 bytes for int32 size prefix
    for key, value in unused_value.items():
      estimate += self._key_coder.estimate_size(key, True)
      estimate += self._value_coder.estimate_size(value, True)
    return estimate
   721  
   722  
   723  class NullableCoderImpl(StreamCoderImpl):
   724    """For internal use only; no backwards-compatibility guarantees.
   725  
   726    A coder for typing.Optional objects."""
   727  
   728    ENCODE_NULL = 0
   729    ENCODE_PRESENT = 1
   730  
   731    def __init__(
   732        self,
   733        value_coder  # type: CoderImpl
   734    ):
   735      self._value_coder = value_coder
   736  
   737    def encode_to_stream(self, value, out, nested):
   738      if value is None:
   739        out.write_byte(self.ENCODE_NULL)
   740      else:
   741        out.write_byte(self.ENCODE_PRESENT)
   742        self._value_coder.encode_to_stream(value, out, nested)
   743  
   744    def decode_from_stream(self, in_stream, nested):
   745      null_indicator = in_stream.read_byte()
   746      if null_indicator == self.ENCODE_NULL:
   747        return None
   748      elif null_indicator == self.ENCODE_PRESENT:
   749        return self._value_coder.decode_from_stream(in_stream, nested)
   750      else:
   751        raise ValueError(
   752            "Encountered unexpected value for null indicator: '%s'" %
   753            null_indicator)
   754  
   755    def estimate_size(self, unused_value, nested=False):
   756      return 1 + (
   757          self._value_coder.estimate_size(unused_value)
   758          if unused_value is not None else 0)
   759  
   760  
class BigEndianShortCoderImpl(StreamCoderImpl):
  """For internal use only; no backwards-compatibility guarantees.

  Codes a value through the stream's big-endian int16 read/write methods.
  """
  def encode_to_stream(self, value, out, nested):
    # type: (int, create_OutputStream, bool) -> None
    # ``nested`` is not consulted.
    out.write_bigendian_int16(value)

  def decode_from_stream(self, in_stream, nested):
    # type: (create_InputStream, bool) -> int
    # ``nested`` is not consulted.
    return in_stream.read_bigendian_int16()

  def estimate_size(self, unused_value, nested=False):
    # type: (Any, bool) -> int
    # A short is encoded as 2 bytes, regardless of nesting.
    return 2
   775  
   776  
class SinglePrecisionFloatCoderImpl(StreamCoderImpl):
  """For internal use only; no backwards-compatibility guarantees.

  Codes a float through the stream's big-endian float read/write methods.
  """
  def encode_to_stream(self, value, out, nested):
    # type: (float, create_OutputStream, bool) -> None
    # ``nested`` is not consulted.
    out.write_bigendian_float(value)

  def decode_from_stream(self, in_stream, nested):
    # type: (create_InputStream, bool) -> float
    # ``nested`` is not consulted.
    return in_stream.read_bigendian_float()

  def estimate_size(self, unused_value, nested=False):
    # type: (Any, bool) -> int
    # A float is encoded as 4 bytes, regardless of nesting.
    return 4
   791  
   792  
class FloatCoderImpl(StreamCoderImpl):
  """For internal use only; no backwards-compatibility guarantees.

  Codes a float through the stream's big-endian double read/write methods.
  """
  def encode_to_stream(self, value, out, nested):
    # type: (float, create_OutputStream, bool) -> None
    # ``nested`` is not consulted.
    out.write_bigendian_double(value)

  def decode_from_stream(self, in_stream, nested):
    # type: (create_InputStream, bool) -> float
    # ``nested`` is not consulted.
    return in_stream.read_bigendian_double()

  def estimate_size(self, unused_value, nested=False):
    # type: (Any, bool) -> int
    # A double is encoded as 8 bytes, regardless of nesting.
    return 8
   807  
   808  
# Module-level placeholder for the IntervalWindow class.  It stays None until
# IntervalWindowCoderImpl.decode_from_stream imports the real class from
# apache_beam.transforms.window on first use.
# NOTE(review): the deferred import presumably avoids an import cycle --
# confirm.
if not TYPE_CHECKING:
  IntervalWindow = None
   811  
   812  
   813  class IntervalWindowCoderImpl(StreamCoderImpl):
   814    """For internal use only; no backwards-compatibility guarantees."""
   815  
   816    # TODO: Fn Harness only supports millis. Is this important enough to fix?
   817    def _to_normal_time(self, value):
   818      """Convert "lexicographically ordered unsigned" to signed."""
   819      return value - _TIME_SHIFT
   820  
   821    def _from_normal_time(self, value):
   822      """Convert signed to "lexicographically ordered unsigned"."""
   823      return value + _TIME_SHIFT
   824  
   825    def encode_to_stream(self, value, out, nested):
   826      # type: (IntervalWindow, create_OutputStream, bool) -> None
   827      typed_value = value
   828      span_millis = (
   829          typed_value._end_micros // 1000 - typed_value._start_micros // 1000)
   830      out.write_bigendian_uint64(
   831          self._from_normal_time(typed_value._end_micros // 1000))
   832      out.write_var_int64(span_millis)
   833  
   834    def decode_from_stream(self, in_, nested):
   835      # type: (create_InputStream, bool) -> IntervalWindow
   836      if not TYPE_CHECKING:
   837        global IntervalWindow  # pylint: disable=global-variable-not-assigned
   838        if IntervalWindow is None:
   839          from apache_beam.transforms.window import IntervalWindow
   840      # instantiating with None is not part of the public interface
   841      typed_value = IntervalWindow(None, None)  # type: ignore[arg-type]
   842      typed_value._end_micros = (
   843          1000 * self._to_normal_time(in_.read_bigendian_uint64()))
   844      typed_value._start_micros = (
   845          typed_value._end_micros - 1000 * in_.read_var_int64())
   846      return typed_value
   847  
   848    def estimate_size(self, value, nested=False):
   849      # type: (Any, bool) -> int
   850      # An IntervalWindow is context-insensitive, with a timestamp (8 bytes)
   851      # and a varint timespam.
   852      typed_value = value
   853      span_millis = (
   854          typed_value._end_micros // 1000 - typed_value._start_micros // 1000)
   855      return 8 + get_varint_size(span_millis)
   856  
   857  
   858  class TimestampCoderImpl(StreamCoderImpl):
   859    """For internal use only; no backwards-compatibility guarantees.
   860  
   861    TODO: SDK agnostic encoding
   862    For interoperability with Java SDK, encoding needs to match
   863    that of the Java SDK InstantCoder.
   864    https://github.com/apache/beam/blob/f5029b4f0dfff404310b2ef55e2632bbacc7b04f/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/InstantCoder.java#L79
   865    """
   866    def encode_to_stream(self, value, out, nested):
   867      # type: (Timestamp, create_OutputStream, bool) -> None
   868      millis = value.micros // 1000
   869      if millis >= 0:
   870        millis = millis - _TIME_SHIFT
   871      else:
   872        millis = millis + _TIME_SHIFT
   873      out.write_bigendian_int64(millis)
   874  
   875    def decode_from_stream(self, in_stream, nested):
   876      # type: (create_InputStream, bool) -> Timestamp
   877      millis = in_stream.read_bigendian_int64()
   878      if millis < 0:
   879        millis = millis + _TIME_SHIFT
   880      else:
   881        millis = millis - _TIME_SHIFT
   882      return Timestamp(micros=millis * 1000)
   883  
   884    def estimate_size(self, unused_value, nested=False):
   885      # A Timestamp is encoded as a 64-bit integer in 8 bytes, regardless of
   886      # nesting.
   887      return 8
   888  
   889  
   890  class TimerCoderImpl(StreamCoderImpl):
   891    """For internal use only; no backwards-compatibility guarantees."""
   892    def __init__(self, key_coder_impl, window_coder_impl):
   893      self._timestamp_coder_impl = TimestampCoderImpl()
   894      self._boolean_coder_impl = BooleanCoderImpl()
   895      self._pane_info_coder_impl = PaneInfoCoderImpl()
   896      self._key_coder_impl = key_coder_impl
   897      self._windows_coder_impl = TupleSequenceCoderImpl(window_coder_impl)
   898      from apache_beam.coders.coders import StrUtf8Coder
   899      self._tag_coder_impl = StrUtf8Coder().get_impl()
   900  
   901    def encode_to_stream(self, value, out, nested):
   902      # type: (userstate.Timer, create_OutputStream, bool) -> None
   903      self._key_coder_impl.encode_to_stream(value.user_key, out, True)
   904      self._tag_coder_impl.encode_to_stream(value.dynamic_timer_tag, out, True)
   905      self._windows_coder_impl.encode_to_stream(value.windows, out, True)
   906      self._boolean_coder_impl.encode_to_stream(value.clear_bit, out, True)
   907      if not value.clear_bit:
   908        self._timestamp_coder_impl.encode_to_stream(
   909            value.fire_timestamp, out, True)
   910        self._timestamp_coder_impl.encode_to_stream(
   911            value.hold_timestamp, out, True)
   912        self._pane_info_coder_impl.encode_to_stream(value.paneinfo, out, True)
   913  
   914    def decode_from_stream(self, in_stream, nested):
   915      # type: (create_InputStream, bool) -> userstate.Timer
   916      from apache_beam.transforms import userstate
   917      user_key = self._key_coder_impl.decode_from_stream(in_stream, True)
   918      dynamic_timer_tag = self._tag_coder_impl.decode_from_stream(in_stream, True)
   919      windows = self._windows_coder_impl.decode_from_stream(in_stream, True)
   920      clear_bit = self._boolean_coder_impl.decode_from_stream(in_stream, True)
   921      if clear_bit:
   922        return userstate.Timer(
   923            user_key=user_key,
   924            dynamic_timer_tag=dynamic_timer_tag,
   925            windows=windows,
   926            clear_bit=clear_bit,
   927            fire_timestamp=None,
   928            hold_timestamp=None,
   929            paneinfo=None)
   930  
   931      return userstate.Timer(
   932          user_key=user_key,
   933          dynamic_timer_tag=dynamic_timer_tag,
   934          windows=windows,
   935          clear_bit=clear_bit,
   936          fire_timestamp=self._timestamp_coder_impl.decode_from_stream(
   937              in_stream, True),
   938          hold_timestamp=self._timestamp_coder_impl.decode_from_stream(
   939              in_stream, True),
   940          paneinfo=self._pane_info_coder_impl.decode_from_stream(in_stream, True))
   941  
   942  
   943  small_ints = [chr(_).encode('latin-1') for _ in range(128)]
   944  
   945  
   946  class VarIntCoderImpl(StreamCoderImpl):
   947    """For internal use only; no backwards-compatibility guarantees.
   948  
   949    A coder for int objects."""
   950    def encode_to_stream(self, value, out, nested):
   951      # type: (int, create_OutputStream, bool) -> None
   952      out.write_var_int64(value)
   953  
   954    def decode_from_stream(self, in_stream, nested):
   955      # type: (create_InputStream, bool) -> int
   956      return in_stream.read_var_int64()
   957  
   958    def encode(self, value):
   959      ivalue = value  # type cast
   960      if 0 <= ivalue < len(small_ints):
   961        return small_ints[ivalue]
   962      return StreamCoderImpl.encode(self, value)
   963  
   964    def decode(self, encoded):
   965      if len(encoded) == 1:
   966        i = ord(encoded)
   967        if 0 <= i < 128:
   968          return i
   969      return StreamCoderImpl.decode(self, encoded)
   970  
   971    def estimate_size(self, value, nested=False):
   972      # type: (Any, bool) -> int
   973      # Note that VarInts are encoded the same way regardless of nesting.
   974      return get_varint_size(value)
   975  
   976  
class SingletonCoderImpl(CoderImpl):
  """For internal use only; no backwards-compatibility guarantees.

  A coder that always encodes exactly one value.

  Encoding writes nothing, and decoding returns the value given to the
  constructor without reading any input.
  """
  def __init__(self, value):
    # The value returned by every decode call.
    self._value = value

  def encode_to_stream(self, value, stream, nested):
    # type: (Any, create_OutputStream, bool) -> None
    # Intentionally writes nothing; ``value`` is not inspected.
    pass

  def decode_from_stream(self, stream, nested):
    # type: (create_InputStream, bool) -> Any
    # The stream is left untouched.
    return self._value

  def encode(self, value):
    # Always the empty byte string, whatever ``value`` is.
    b = b''  # avoid byte vs str vs unicode error
    return b

  def decode(self, encoded):
    # ``encoded`` is ignored.
    return self._value

  def estimate_size(self, value, nested=False):
    # type: (Any, bool) -> int
    # Nothing is ever written, so the encoded size is zero.
    return 0
  1002  
  1003  
  1004  class AbstractComponentCoderImpl(StreamCoderImpl):
  1005    """For internal use only; no backwards-compatibility guarantees.
  1006  
  1007    CoderImpl for coders that are comprised of several component coders."""
  1008    def __init__(self, coder_impls):
  1009      for c in coder_impls:
  1010        assert isinstance(c, CoderImpl), c
  1011      self._coder_impls = tuple(coder_impls)
  1012  
  1013    def _extract_components(self, value):
  1014      raise NotImplementedError
  1015  
  1016    def _construct_from_components(self, components):
  1017      raise NotImplementedError
  1018  
  1019    def encode_to_stream(self, value, out, nested):
  1020      # type: (Any, create_OutputStream, bool) -> None
  1021      values = self._extract_components(value)
  1022      if len(self._coder_impls) != len(values):
  1023        raise ValueError('Number of components does not match number of coders.')
  1024      for i in range(0, len(self._coder_impls)):
  1025        c = self._coder_impls[i]  # type cast
  1026        c.encode_to_stream(
  1027            values[i], out, nested or i + 1 < len(self._coder_impls))
  1028  
  1029    def decode_from_stream(self, in_stream, nested):
  1030      # type: (create_InputStream, bool) -> Any
  1031      return self._construct_from_components([
  1032          c.decode_from_stream(
  1033              in_stream, nested or i + 1 < len(self._coder_impls)) for i,
  1034          c in enumerate(self._coder_impls)
  1035      ])
  1036  
  1037    def estimate_size(self, value, nested=False):
  1038      # type: (Any, bool) -> int
  1039  
  1040      """Estimates the encoded size of the given value, in bytes."""
  1041      # TODO(ccy): This ignores sizes of observable components.
  1042      estimated_size, _ = (self.get_estimated_size_and_observables(value))
  1043      return estimated_size
  1044  
  1045    def get_estimated_size_and_observables(self, value, nested=False):
  1046      # type: (Any, bool) -> Tuple[int, Observables]
  1047  
  1048      """Returns estimated size of value along with any nested observables."""
  1049      values = self._extract_components(value)
  1050      estimated_size = 0
  1051      observables = []  # type: Observables
  1052      for i in range(0, len(self._coder_impls)):
  1053        c = self._coder_impls[i]  # type cast
  1054        child_size, child_observables = (
  1055            c.get_estimated_size_and_observables(
  1056                values[i], nested=nested or i + 1 < len(self._coder_impls)))
  1057        estimated_size += child_size
  1058        observables += child_observables
  1059      return estimated_size, observables
  1060  
  1061  
  1062  class AvroCoderImpl(SimpleCoderImpl):
  1063    """For internal use only; no backwards-compatibility guarantees."""
  1064    def __init__(self, schema):
  1065      self.parsed_schema = parse_schema(json.loads(schema))
  1066  
  1067    def encode(self, value):
  1068      assert issubclass(type(value), AvroRecord)
  1069      with BytesIO() as buf:
  1070        schemaless_writer(buf, self.parsed_schema, value.record)
  1071        return buf.getvalue()
  1072  
  1073    def decode(self, encoded):
  1074      with BytesIO(encoded) as buf:
  1075        return AvroRecord(schemaless_reader(buf, self.parsed_schema))
  1076  
  1077  
class TupleCoderImpl(AbstractComponentCoderImpl):
  """A coder for tuple objects."""
  def _extract_components(self, value):
    # The components of a tuple value are simply its items.
    return tuple(value)

  def _construct_from_components(self, components):
    # Reassemble the decoded components into a tuple.
    return tuple(components)
  1085  
  1086  
  1087  class _ConcatSequence(object):
  1088    def __init__(self, head, tail):
  1089      # type: (Iterable[Any], Iterable[Any]) -> None
  1090      self._head = head
  1091      self._tail = tail
  1092  
  1093    def __iter__(self):
  1094      # type: () -> Iterator[Any]
  1095      for elem in self._head:
  1096        yield elem
  1097      for elem in self._tail:
  1098        yield elem
  1099  
  1100    def __eq__(self, other):
  1101      return list(self) == list(other)
  1102  
  1103    def __hash__(self):
  1104      raise NotImplementedError
  1105  
  1106    def __reduce__(self):
  1107      return list, (list(self), )
  1108  
  1109  
# Register _ConcatSequence with FastPrimitivesCoderImpl as an iterable-like
# type.
# NOTE(review): presumably this lets the fast primitives coder encode
# instances like other iterables -- confirm against
# register_iterable_like_type.
FastPrimitivesCoderImpl.register_iterable_like_type(_ConcatSequence)
  1111  
  1112  
class SequenceCoderImpl(StreamCoderImpl):
  """For internal use only; no backwards-compatibility guarantees.

  A coder for sequences.

  If the length of the sequence is known we encode the length as a 32 bit
  ``int`` followed by the encoded bytes.

  If the length of the sequence is unknown, we encode the length as ``-1``
  followed by the encoding of elements buffered up to 64K bytes before prefixing
  the count of number of elements. A ``0`` is encoded at the end to indicate the
  end of stream.

  The resulting encoding would look like this::

    -1
    countA element(0) element(1) ... element(countA - 1)
    countB element(0) element(1) ... element(countB - 1)
    ...
    countX element(0) element(1) ... element(countX - 1)
    0

  If writing to state is enabled, the final terminating 0 will instead be
  replaced with::

    varInt64(-1)
    len(state_token)
    state_token

  where state_token is a bytes object used to retrieve the remainder of the
  iterable via the state API.

  Subclasses implement ``_construct_from_sequence`` to choose the type of the
  decoded result.
  """

  # Default buffer size of 64kB for handling iterables of unknown length.
  _DEFAULT_BUFFER_SIZE = 64 * 1024

  def __init__(
      self,
      elem_coder,  # type: CoderImpl
      read_state=None,  # type: Optional[IterableStateReader]
      write_state=None,  # type: Optional[IterableStateWriter]
      write_state_threshold=0  # type: int
  ):
    """Stores the element coder and the optional state reader/writer.

    Args:
      elem_coder: coder impl applied to every element, always nested.
      read_state: callable invoked as ``read_state(state_token, elem_coder)``
        when decoding an iterable whose tail was written to state.
      write_state: callable invoked as ``write_state(tail, elem_coder)`` that
        returns the state token; when set, the unknown-length encoding is
        always used.
      write_state_threshold: number of bytes written to the output after
        which the remainder of the iterable is handed to ``write_state``.
    """
    self._elem_coder = elem_coder
    self._read_state = read_state
    self._write_state = write_state
    self._write_state_threshold = write_state_threshold

  def _construct_from_sequence(self, values):
    """Converts the decoded elements into the result type (subclass hook)."""
    raise NotImplementedError

  def encode_to_stream(self, value, out, nested):
    # type: (Sequence, create_OutputStream, bool) -> None

    """Encodes ``value`` using the format described in the class docstring.

    The known-length form is used only when ``value`` has ``__len__`` and no
    state writer is configured; otherwise elements are written in counted
    blocks.
    """
    # Compatible with Java's IterableLikeCoder.
    if hasattr(value, '__len__') and self._write_state is None:
      out.write_bigendian_int32(len(value))
      for elem in value:
        self._elem_coder.encode_to_stream(elem, out, True)
    else:
      # We don't know the size without traversing it so use a fixed size buffer
      # and encode as many elements as possible into it before outputting
      # the size followed by the elements.

      # -1 to indicate that the length is not known.
      out.write_bigendian_int32(-1)
      buffer = create_OutputStream()
      if self._write_state is None:
        target_buffer_size = self._DEFAULT_BUFFER_SIZE
      else:
        target_buffer_size = min(
            self._DEFAULT_BUFFER_SIZE, self._write_state_threshold)
      # index is the position of the last element encoded; prev_index is the
      # position of the last element already flushed to ``out``.
      prev_index = index = -1
      # Don't want to miss out on fast list iteration optimization.
      value_iter = value if isinstance(value, (list, tuple)) else iter(value)
      start_size = out.size()
      for elem in value_iter:
        index += 1
        self._elem_coder.encode_to_stream(elem, buffer, True)
        if buffer.size() > target_buffer_size:
          # Flush the full block: element count, then the buffered bytes.
          out.write_var_int64(index - prev_index)
          out.write(buffer.get())
          prev_index = index
          buffer = create_OutputStream()
          if (self._write_state is not None and
              out.size() - start_size > self._write_state_threshold):
            # Hand everything not yet encoded to the state writer; for an
            # iterator that is simply what remains to be consumed.
            tail = (
                value_iter[index +
                           1:] if isinstance(value_iter,
                                             (list, tuple)) else value_iter)
            state_token = self._write_state(tail, self._elem_coder)
            out.write_var_int64(-1)
            out.write(state_token, True)
            break
      else:
        # Reached only when the loop was not cut short by a state write:
        # flush any partial block and terminate the stream with 0.
        if index > prev_index:
          out.write_var_int64(index - prev_index)
          out.write(buffer.get())
        out.write_var_int64(0)

  def decode_from_stream(self, in_stream, nested):
    # type: (create_InputStream, bool) -> Sequence

    """Decodes a sequence written by ``encode_to_stream``.

    Raises:
      ValueError: if the stream refers to a state-backed tail but no state
        reader was configured.
    """
    size = in_stream.read_bigendian_int32()

    if size >= 0:
      elements = [
          self._elem_coder.decode_from_stream(in_stream, True)
          for _ in range(size)
      ]  # type: Iterable[Any]
    else:
      elements = []
      # Read counted blocks until a non-positive count ends the stream.
      count = in_stream.read_var_int64()
      while count > 0:
        for _ in range(count):
          elements.append(self._elem_coder.decode_from_stream(in_stream, True))
        count = in_stream.read_var_int64()

      if count == -1:
        # The remainder of the iterable lives in state.
        if self._read_state is None:
          raise ValueError(
              'Cannot read state-written iterable without state reader.')

        state_token = in_stream.read_all(True)
        elements = _ConcatSequence(
            elements, self._read_state(state_token, self._elem_coder))

    return self._construct_from_sequence(elements)

  def estimate_size(self, value, nested=False):
    # type: (Any, bool) -> int

    """Estimates the encoded size of the given value, in bytes."""
    # TODO(ccy): This ignores element sizes.
    estimated_size, _ = (self.get_estimated_size_and_observables(value))
    return estimated_size

  def get_estimated_size_and_observables(self, value, nested=False):
    # type: (Any, bool) -> Tuple[int, Observables]

    """Returns estimated size of value along with any nested observables."""
    estimated_size = 0
    # Size of 32-bit integer storing number of elements.
    estimated_size += 4
    if isinstance(value, observable.ObservableMixin):
      # Observable values are not traversed here; they are reported back
      # paired with the element coder.
      return estimated_size, [(value, self._elem_coder)]

    observables = []  # type: Observables
    for elem in value:
      child_size, child_observables = (
          self._elem_coder.get_estimated_size_and_observables(
              elem, nested=True))
      estimated_size += child_size
      observables += child_observables
    # TODO: (https://github.com/apache/beam/issues/18169) Update to use an
    # accurate count depending on size and count, currently we are
    # underestimating the size by up to 10 bytes per block of data since we are
    # not including the count prefix which occurs at most once per 64k of data
    # and is upto 10 bytes long. The upper bound of the underestimate is
    # 10 / 65536 ~= 0.0153% of the actual size.
    # TODO: More efficient size estimation in the case of state-backed
    # iterables.
    return estimated_size, observables
  1274  
  1275  
class TupleSequenceCoderImpl(SequenceCoderImpl):
  """For internal use only; no backwards-compatibility guarantees.

  A coder for homogeneous tuple objects."""
  def _construct_from_sequence(self, components):
    # Decoded elements are always returned as a tuple.
    return tuple(components)
  1282  
  1283  
  1284  class _AbstractIterable(object):
  1285    """Wraps an iterable hiding methods that might not always be available."""
  1286    def __init__(self, contents):
  1287      self._contents = contents
  1288  
  1289    def __iter__(self):
  1290      return iter(self._contents)
  1291  
  1292    def __repr__(self):
  1293      head = [repr(e) for e in itertools.islice(self, 4)]
  1294      if len(head) == 4:
  1295        head[-1] = '...'
  1296      return '_AbstractIterable([%s])' % ', '.join(head)
  1297  
  1298    # Mostly useful for tests.
  1299    def __eq__(left, right):
  1300      end = object()
  1301      for a, b in itertools.zip_longest(left, right, fillvalue=end):
  1302        if a != b:
  1303          return False
  1304      return True
  1305  
  1306  
# Register _AbstractIterable with FastPrimitivesCoderImpl as an iterable-like
# type.
FastPrimitivesCoderImpl.register_iterable_like_type(_AbstractIterable)

# TODO(https://github.com/apache/beam/issues/21167): Enable using abstract
# iterables permanently
# Default consulted by IterableCoderImpl.__init__ when its
# ``use_abstract_iterable`` argument is None.
_iterable_coder_uses_abstract_iterable_by_default = False
  1312  
  1313  
  1314  class IterableCoderImpl(SequenceCoderImpl):
  1315    """For internal use only; no backwards-compatibility guarantees.
  1316  
  1317    A coder for homogeneous iterable objects."""
  1318    def __init__(self, *args, use_abstract_iterable=None, **kwargs):
  1319      super().__init__(*args, **kwargs)
  1320      if use_abstract_iterable is None:
  1321        use_abstract_iterable = _iterable_coder_uses_abstract_iterable_by_default
  1322      self._use_abstract_iterable = use_abstract_iterable
  1323  
  1324    def _construct_from_sequence(self, components):
  1325      if self._use_abstract_iterable:
  1326        return _AbstractIterable(components)
  1327      else:
  1328        return components
  1329  
  1330  
  1331  class ListCoderImpl(SequenceCoderImpl):
  1332    """For internal use only; no backwards-compatibility guarantees.
  1333  
  1334    A coder for homogeneous list objects."""
  1335    def _construct_from_sequence(self, components):
  1336      return components if isinstance(components, list) else list(components)
  1337  
  1338  
class PaneInfoEncoding(object):
  """For internal use only; no backwards-compatibility guarantees.

  Encoding used to describe a PaneInfo descriptor.  A PaneInfo descriptor
  can be encoded in three different ways: with a single byte (FIRST), with a
  single byte followed by a varint describing a single index (ONE_INDEX) or
  with a single byte followed by two varints describing two separate indices:
  the index and nonspeculative index.
  """

  # Single byte only.
  FIRST = 0
  # Single byte followed by one varint index.
  ONE_INDEX = 1
  # Single byte followed by the index and the nonspeculative index.
  TWO_INDICES = 2
  1352  
  1353  
# These are cdef'd to ints to optimize the common case.
# Module-level aliases used by PaneInfoCoderImpl in place of the attribute
# lookups.
PaneInfoTiming_UNKNOWN = windowed_value.PaneInfoTiming.UNKNOWN
PaneInfoEncoding_FIRST = PaneInfoEncoding.FIRST
  1357  
  1358  
  1359  class PaneInfoCoderImpl(StreamCoderImpl):
  1360    """For internal use only; no backwards-compatibility guarantees.
  1361  
  1362    Coder for a PaneInfo descriptor."""
  1363    def _choose_encoding(self, value):
  1364      if ((value._index == 0 and value._nonspeculative_index == 0) or
  1365          value._timing == PaneInfoTiming_UNKNOWN):
  1366        return PaneInfoEncoding_FIRST
  1367      elif (value._index == value._nonspeculative_index or
  1368            value._timing == windowed_value.PaneInfoTiming.EARLY):
  1369        return PaneInfoEncoding.ONE_INDEX
  1370      else:
  1371        return PaneInfoEncoding.TWO_INDICES
  1372  
  1373    def encode_to_stream(self, value, out, nested):
  1374      # type: (windowed_value.PaneInfo, create_OutputStream, bool) -> None
  1375      pane_info = value  # cast
  1376      encoding_type = self._choose_encoding(pane_info)
  1377      out.write_byte(pane_info._encoded_byte | (encoding_type << 4))
  1378      if encoding_type == PaneInfoEncoding_FIRST:
  1379        return
  1380      elif encoding_type == PaneInfoEncoding.ONE_INDEX:
  1381        out.write_var_int64(value.index)
  1382      elif encoding_type == PaneInfoEncoding.TWO_INDICES:
  1383        out.write_var_int64(value.index)
  1384        out.write_var_int64(value.nonspeculative_index)
  1385      else:
  1386        raise NotImplementedError('Invalid PaneInfoEncoding: %s' % encoding_type)
  1387  
  1388    def decode_from_stream(self, in_stream, nested):
  1389      # type: (create_InputStream, bool) -> windowed_value.PaneInfo
  1390      encoded_first_byte = in_stream.read_byte()
  1391      base = windowed_value._BYTE_TO_PANE_INFO[encoded_first_byte & 0xF]
  1392      assert base is not None
  1393      encoding_type = encoded_first_byte >> 4
  1394      if encoding_type == PaneInfoEncoding_FIRST:
  1395        return base
  1396      elif encoding_type == PaneInfoEncoding.ONE_INDEX:
  1397        index = in_stream.read_var_int64()
  1398        if base.timing == windowed_value.PaneInfoTiming.EARLY:
  1399          nonspeculative_index = -1
  1400        else:
  1401          nonspeculative_index = index
  1402      elif encoding_type == PaneInfoEncoding.TWO_INDICES:
  1403        index = in_stream.read_var_int64()
  1404        nonspeculative_index = in_stream.read_var_int64()
  1405      else:
  1406        raise NotImplementedError('Invalid PaneInfoEncoding: %s' % encoding_type)
  1407      return windowed_value.PaneInfo(
  1408          base.is_first, base.is_last, base.timing, index, nonspeculative_index)
  1409  
  1410    def estimate_size(self, value, nested=False):
  1411      # type: (Any, bool) -> int
  1412  
  1413      """Estimates the encoded size of the given value, in bytes."""
  1414      size = 1
  1415      encoding_type = self._choose_encoding(value)
  1416      if encoding_type == PaneInfoEncoding.ONE_INDEX:
  1417        size += get_varint_size(value.index)
  1418      elif encoding_type == PaneInfoEncoding.TWO_INDICES:
  1419        size += get_varint_size(value.index)
  1420        size += get_varint_size(value.nonspeculative_index)
  1421      return size
  1422  
  1423  
  1424  class WindowedValueCoderImpl(StreamCoderImpl):
  1425    """For internal use only; no backwards-compatibility guarantees.
  1426  
  1427    A coder for windowed values."""
  1428  
  1429    # Ensure that lexicographic ordering of the bytes corresponds to
  1430    # chronological order of timestamps.
  1431    # TODO(https://github.com/apache/beam/issues/18190): Clean this up once we
  1432    # have a BEAM wide consensus on byte representation of timestamps.
  def _to_normal_time(self, value):
    """Convert "lexicographically ordered unsigned" to signed.

    Subtracts ``_TIME_SHIFT``; the inverse of ``_from_normal_time``.
    """
    return value - _TIME_SHIFT
  1436  
  def _from_normal_time(self, value):
    """Convert signed to "lexicographically ordered unsigned".

    Adds ``_TIME_SHIFT``; the inverse of ``_to_normal_time``.
    """
    return value + _TIME_SHIFT
  1440  
  def __init__(self, value_coder, timestamp_coder, window_coder):
    """Stores the component coders used to code a windowed value.

    Args:
      value_coder: coder impl for the wrapped value.
      timestamp_coder: stored on the instance; the TODO below plans to
        remove it.
      window_coder: coder impl for a single window; it is wrapped in a
        TupleSequenceCoderImpl so that a tuple of windows can be coded.
    """
    # TODO(lcwik): Remove the timestamp coder field
    self._value_coder = value_coder
    self._timestamp_coder = timestamp_coder
    self._windows_coder = TupleSequenceCoderImpl(window_coder)
    self._pane_info_coder = PaneInfoCoderImpl()
  1447  
  def encode_to_stream(self, value, out, nested):
    # type: (windowed_value.WindowedValue, create_OutputStream, bool) -> None

    """Writes the timestamp, windows, pane info and value, in that order.

    The timestamp is written in milliseconds as a shifted big-endian uint64.
    Windows and pane info are always nested; the value inherits ``nested``.
    """
    wv = value  # type cast
    # Avoid creation of Timestamp object.
    restore_sign = -1 if wv.timestamp_micros < 0 else 1
    out.write_bigendian_uint64(
        # Convert to positive number and divide, since python rounds off to the
        # lower negative number. For ex: -3 / 2 = -2, but we expect it to be -1,
        # to be consistent across SDKs.
        # TODO(https://github.com/apache/beam/issues/18190): Clean this up once
        # we have a BEAM wide consensus on precision of timestamps.
        # Timestamps below MIN_TIMESTAMP_micros are clamped to it first.
        self._from_normal_time(
            restore_sign * (
                abs(
                    MIN_TIMESTAMP_micros if wv.timestamp_micros <
                    MIN_TIMESTAMP_micros else wv.timestamp_micros) // 1000)))
    self._windows_coder.encode_to_stream(wv.windows, out, True)
    # Default PaneInfo encoded byte representing NO_FIRING.
    self._pane_info_coder.encode_to_stream(wv.pane_info, out, True)
    self._value_coder.encode_to_stream(wv.value, out, nested)
  1468  
  def decode_from_stream(self, in_stream, nested):
    # type: (create_InputStream, bool) -> windowed_value.WindowedValue

    """Read a windowed value written by ``encode_to_stream``.

    Reads, in order: the shifted millisecond timestamp, the windows, the pane
    info and finally the value (using the caller-supplied ``nested`` flag).
    """
    timestamp = self._to_normal_time(in_stream.read_bigendian_uint64())
    # Restore MIN/MAX timestamps to their actual values as encoding incurs loss
    # of precision while converting to millis.
    # Note: This is only a best effort here as there is no way to know if these
    # were indeed MIN/MAX timestamps.
    # TODO(https://github.com/apache/beam/issues/18190): Clean this up once we
    # have a BEAM wide consensus on precision of timestamps.
    min_millis = -(abs(MIN_TIMESTAMP_micros) // 1000)
    max_millis = MAX_TIMESTAMP_micros // 1000
    if timestamp <= min_millis:
      timestamp = MIN_TIMESTAMP_micros
    elif timestamp >= max_millis:
      timestamp = MAX_TIMESTAMP_micros
    else:
      # Ordinary timestamp: scale the millis back up to micros.
      timestamp *= 1000

    windows = self._windows_coder.decode_from_stream(in_stream, True)
    pane_info = self._pane_info_coder.decode_from_stream(in_stream, True)
    value = self._value_coder.decode_from_stream(in_stream, nested)
    # Passing the raw micros avoids creation of a Timestamp object.
    return windowed_value.create(value, timestamp, windows, pane_info)
  1494  
  def get_estimated_size_and_observables(self, value, nested=False):
    # type: (Any, bool) -> Tuple[int, Observables]

    """Estimate the encoded size of ``value`` and collect its observables.

    The estimate is the sum of the value coder's estimate (which also supplies
    the observables) and the nested size estimates of the timestamp, windows
    and pane info.
    """
    if isinstance(value, observable.ObservableMixin):
      # Should never be here.
      # TODO(robertwb): Remove when coders are set correctly.
      return 0, [(value, self._value_coder)]
    value_size, value_observables = (
        self._value_coder.get_estimated_size_and_observables(
            value.value, nested=nested))
    observables = []  # type: Observables
    observables += value_observables
    estimated_size = (
        value_size +
        self._timestamp_coder.estimate_size(value.timestamp, nested=True) +
        self._windows_coder.estimate_size(value.windows, nested=True) +
        self._pane_info_coder.estimate_size(value.pane_info, nested=True))
    return estimated_size, observables
  1517  
  1518  
  1519  class ParamWindowedValueCoderImpl(WindowedValueCoderImpl):
  1520    """For internal use only; no backwards-compatibility guarantees.
  1521  
  1522    A coder for windowed values with constant timestamp, windows and
  1523    pane info. The coder drops timestamp, windows and pane info during
  1524    encoding, and uses the supplied parameterized timestamp, windows
  1525    and pane info values during decoding when reconstructing the windowed
  1526    value."""
  1527    def __init__(self, value_coder, window_coder, payload):
  1528      super().__init__(value_coder, TimestampCoderImpl(), window_coder)
  1529      self._timestamp, self._windows, self._pane_info = self._from_proto(
  1530          payload, window_coder)
  1531  
  1532    def _from_proto(self, payload, window_coder):
  1533      windowed_value_coder = WindowedValueCoderImpl(
  1534          BytesCoderImpl(), TimestampCoderImpl(), window_coder)
  1535      wv = windowed_value_coder.decode(payload)
  1536      return wv.timestamp_micros, wv.windows, wv.pane_info
  1537  
  1538    def encode_to_stream(self, value, out, nested):
  1539      wv = value  # type cast
  1540      self._value_coder.encode_to_stream(wv.value, out, nested)
  1541  
  1542    def decode_from_stream(self, in_stream, nested):
  1543      value = self._value_coder.decode_from_stream(in_stream, nested)
  1544      return windowed_value.create(
  1545          value, self._timestamp, self._windows, self._pane_info)
  1546  
  1547    def get_estimated_size_and_observables(self, value, nested=False):
  1548      """Returns estimated size of value along with any nested observables."""
  1549      if isinstance(value, observable.ObservableMixin):
  1550        # Should never be here.
  1551        # TODO(robertwb): Remove when coders are set correctly.
  1552        return 0, [(value, self._value_coder)]
  1553      estimated_size = 0
  1554      observables = []
  1555      value_estimated_size, value_observables = (
  1556          self._value_coder.get_estimated_size_and_observables(
  1557              value.value, nested=nested))
  1558      estimated_size += value_estimated_size
  1559      observables += value_observables
  1560      return estimated_size, observables
  1561  
  1562  
  1563  class LengthPrefixCoderImpl(StreamCoderImpl):
  1564    """For internal use only; no backwards-compatibility guarantees.
  1565  
  1566    Coder which prefixes the length of the encoded object in the stream."""
  1567    def __init__(self, value_coder):
  1568      # type: (CoderImpl) -> None
  1569      self._value_coder = value_coder
  1570  
  1571    def encode_to_stream(self, value, out, nested):
  1572      # type: (Any, create_OutputStream, bool) -> None
  1573      encoded_value = self._value_coder.encode(value)
  1574      out.write_var_int64(len(encoded_value))
  1575      out.write(encoded_value)
  1576  
  1577    def decode_from_stream(self, in_stream, nested):
  1578      # type: (create_InputStream, bool) -> Any
  1579      value_length = in_stream.read_var_int64()
  1580      return self._value_coder.decode(in_stream.read(value_length))
  1581  
  1582    def estimate_size(self, value, nested=False):
  1583      # type: (Any, bool) -> int
  1584      value_size = self._value_coder.estimate_size(value)
  1585      return get_varint_size(value_size) + value_size
  1586  
  1587  
  1588  class ShardedKeyCoderImpl(StreamCoderImpl):
  1589    """For internal use only; no backwards-compatibility guarantees.
  1590  
  1591    A coder for sharded user keys.
  1592  
  1593    The encoding and decoding should follow the order:
  1594        shard id byte string
  1595        encoded user key
  1596    """
  1597    def __init__(self, key_coder_impl):
  1598      self._shard_id_coder_impl = BytesCoderImpl()
  1599      self._key_coder_impl = key_coder_impl
  1600  
  1601    def encode_to_stream(self, value, out, nested):
  1602      # type: (ShardedKey, create_OutputStream, bool) -> None
  1603      self._shard_id_coder_impl.encode_to_stream(value._shard_id, out, True)
  1604      self._key_coder_impl.encode_to_stream(value.key, out, True)
  1605  
  1606    def decode_from_stream(self, in_stream, nested):
  1607      # type: (create_InputStream, bool) -> ShardedKey
  1608      shard_id = self._shard_id_coder_impl.decode_from_stream(in_stream, True)
  1609      key = self._key_coder_impl.decode_from_stream(in_stream, True)
  1610      return ShardedKey(key=key, shard_id=shard_id)
  1611  
  1612    def estimate_size(self, value, nested=False):
  1613      # type: (Any, bool) -> int
  1614      estimated_size = 0
  1615      estimated_size += (
  1616          self._shard_id_coder_impl.estimate_size(value._shard_id, nested=True))
  1617      estimated_size += (
  1618          self._key_coder_impl.estimate_size(value.key, nested=True))
  1619      return estimated_size
  1620  
  1621  
  1622  class TimestampPrefixingWindowCoderImpl(StreamCoderImpl):
  1623    """For internal use only; no backwards-compatibility guarantees.
  1624  
  1625    A coder for custom window types, which prefix required max_timestamp to
  1626    encoded original window.
  1627  
  1628    The coder encodes and decodes custom window types with following format:
  1629      window's max_timestamp()
  1630      encoded window using it's own coder.
  1631    """
  1632    def __init__(self, window_coder_impl: CoderImpl) -> None:
  1633      self._window_coder_impl = window_coder_impl
  1634  
  1635    def encode_to_stream(self, value, stream, nested):
  1636      TimestampCoderImpl().encode_to_stream(value.max_timestamp(), stream, nested)
  1637      self._window_coder_impl.encode_to_stream(value, stream, nested)
  1638  
  1639    def decode_from_stream(self, stream, nested):
  1640      TimestampCoderImpl().decode_from_stream(stream, nested)
  1641      return self._window_coder_impl.decode_from_stream(stream, nested)
  1642  
  1643    def estimate_size(self, value: Any, nested: bool = False) -> int:
  1644      estimated_size = 0
  1645      estimated_size += TimestampCoderImpl().estimate_size(value)
  1646      estimated_size += self._window_coder_impl.estimate_size(value, nested)
  1647      return estimated_size
  1648  
  1649  
  1650  row_coders_registered = False
  1651  
  1652  
  1653  class RowColumnEncoder:
  1654    ROW_ENCODERS = {1: 12345}
  1655  
  1656    @classmethod
  1657    def register(cls, field_type, coder_impl):
  1658      cls.ROW_ENCODERS[field_type, coder_impl] = cls
  1659  
  1660    @classmethod
  1661    def create(cls, field_type, coder_impl, column):
  1662      global row_coders_registered
  1663      if not row_coders_registered:
  1664        try:
  1665          # pylint: disable=unused-import
  1666          from apache_beam.coders import coder_impl_row_encoders
  1667        except ImportError:
  1668          pass
  1669        row_coders_registered = True
  1670      return cls.ROW_ENCODERS.get((field_type, column.dtype),
  1671                                  GenericRowColumnEncoder)(coder_impl, column)
  1672  
  1673    def null_flags(self):
  1674      raise NotImplementedError(type(self))
  1675  
  1676    def encode_to_stream(self, index, out):
  1677      raise NotImplementedError(type(self))
  1678  
  1679    def decode_from_stream(self, index, in_stream):
  1680      raise NotImplementedError(type(self))
  1681  
  1682    def finalize_write(self):
  1683      raise NotImplementedError(type(self))
  1684  
  1685  
  1686  class GenericRowColumnEncoder(RowColumnEncoder):
  1687    def __init__(self, coder_impl, column):
  1688      self.coder_impl = coder_impl
  1689      self.column = column
  1690  
  1691    def null_flags(self):
  1692      # pylint: disable=singleton-comparison
  1693      return self.column == None  # This is an array.
  1694  
  1695    def encode_to_stream(self, index, out):
  1696      self.coder_impl.encode_to_stream(self.column[index], out, True)
  1697  
  1698    def decode_from_stream(self, index, in_stream):
  1699      self.column[index] = self.coder_impl.decode_from_stream(in_stream, True)
  1700  
  1701    def finalize_write(self):
  1702      pass
  1703  
  1704  
  1705  class RowCoderImpl(StreamCoderImpl):
  1706    """For internal use only; no backwards-compatibility guarantees."""
  1707    def __init__(self, schema, components):
  1708      self.schema = schema
  1709      self.num_fields = len(self.schema.fields)
  1710      self.field_names = [f.name for f in self.schema.fields]
  1711      self.field_nullable = [field.type.nullable for field in self.schema.fields]
  1712      self.constructor = named_tuple_from_schema(schema)
  1713      self.encoding_positions = list(range(len(self.schema.fields)))
  1714      if self.schema.encoding_positions_set:
  1715        # should never be duplicate encoding positions.
  1716        enc_posx = list(
  1717            set(field.encoding_position for field in self.schema.fields))
  1718        if len(enc_posx) != len(self.schema.fields):
  1719          raise ValueError(
  1720              f'''Schema with id {schema.id} has encoding_positions_set=True,
  1721              but not all fields have encoding_position set''')
  1722        self.encoding_positions = list(
  1723            field.encoding_position for field in self.schema.fields)
  1724      self.encoding_positions_argsort = list(np.argsort(self.encoding_positions))
  1725      self.encoding_positions_are_trivial = self.encoding_positions == list(
  1726          range(len(self.encoding_positions)))
  1727      self.components = list(
  1728          components[self.encoding_positions.index(i)].get_impl()
  1729          for i in self.encoding_positions)
  1730      self.has_nullable_fields = any(
  1731          field.type.nullable for field in self.schema.fields)
  1732  
  1733    def encode_to_stream(self, value, out, nested):
  1734      out.write_var_int64(self.num_fields)
  1735      attrs = [getattr(value, name) for name in self.field_names]
  1736  
  1737      if self.has_nullable_fields:
  1738        any_nulls = False
  1739        for attr in attrs:
  1740          if attr is None:
  1741            any_nulls = True
  1742            break
  1743        if any_nulls:
  1744          out.write_var_int64((self.num_fields + 7) // 8)
  1745          # Pack the bits, little-endian, in consecutive bytes.
  1746          running = 0
  1747          for i, attr in enumerate(attrs):
  1748            if i and i % 8 == 0:
  1749              out.write_byte(running)
  1750              running = 0
  1751            running |= (attr is None) << (i % 8)
  1752          out.write_byte(running)
  1753        else:
  1754          out.write_byte(0)
  1755      else:
  1756        out.write_byte(0)
  1757  
  1758      for i in range(self.num_fields):
  1759        if not self.encoding_positions_are_trivial:
  1760          i = self.encoding_positions_argsort[i]
  1761        attr = attrs[i]
  1762        if attr is None:
  1763          if not self.field_nullable[i]:
  1764            raise ValueError(
  1765                "Attempted to encode null for non-nullable field \"{}\".".format(
  1766                    self.schema.fields[i].name))
  1767          continue
  1768        component_coder = self.components[i]  # for typing
  1769        component_coder.encode_to_stream(attr, out, True)
  1770  
  1771    def _row_column_encoders(self, columns):
  1772      return [
  1773          RowColumnEncoder.create(
  1774              self.schema.fields[i].type.atomic_type,
  1775              self.components[i],
  1776              columns[name]) for i,
  1777          name in enumerate(self.field_names)
  1778      ]
  1779  
  1780    def encode_batch_to_stream(self, columns: Dict[str, np.ndarray], out):
  1781      attrs = self._row_column_encoders(columns)
  1782      n = len(next(iter(columns.values())))
  1783      if self.has_nullable_fields:
  1784        null_flags_py = np.zeros((n, self.num_fields), dtype=np.uint8)
  1785        null_bits_len = (self.num_fields + 7) // 8
  1786        null_bits_py = np.zeros((n, null_bits_len), dtype=np.uint8)
  1787        for i, attr in enumerate(attrs):
  1788          attr_null_flags = attr.null_flags()
  1789          if attr_null_flags is not None and attr_null_flags.any():
  1790            null_flags_py[:, i] = attr_null_flags
  1791            null_bits_py[:, i // 8] |= attr_null_flags << np.uint8(i % 8)
  1792        has_null_bits = (null_bits_py.sum(axis=1) != 0).astype(np.uint8)
  1793        null_bits = null_bits_py
  1794        null_flags = null_flags_py
  1795      else:
  1796        has_null_bits = np.zeros((n, ), dtype=np.uint8)
  1797  
  1798      for k in range(n):
  1799        out.write_var_int64(self.num_fields)
  1800        if has_null_bits[k]:
  1801          out.write_byte(null_bits_len)
  1802          for i in range(null_bits_len):
  1803            out.write_byte(null_bits[k, i])
  1804        else:
  1805          out.write_byte(0)
  1806        for i in range(self.num_fields):
  1807          if not self.encoding_positions_are_trivial:
  1808            i = self.encoding_positions_argsort[i]
  1809          if has_null_bits[k] and null_flags[k, i]:
  1810            if not self.field_nullable[i]:
  1811              raise ValueError(
  1812                  "Attempted to encode null for non-nullable field \"{}\".".
  1813                  format(self.schema.fields[i].name))
  1814          else:
  1815            cython.cast(RowColumnEncoder, attrs[i]).encode_to_stream(k, out)
  1816  
  1817    def decode_from_stream(self, in_stream, nested):
  1818      nvals = in_stream.read_var_int64()
  1819      null_mask_len = in_stream.read_var_int64()
  1820      if null_mask_len:
  1821        # pylint: disable=unused-variable
  1822        null_mask_c = null_mask_py = in_stream.read(null_mask_len)
  1823  
  1824      # Note that if this coder's schema has *fewer* attributes than the encoded
  1825      # value, we just need to ignore the additional values, which will occur
  1826      # here because we only decode as many values as we have coders for.
  1827  
  1828      sorted_components = []
  1829      for i in range(min(self.num_fields, nvals)):
  1830        if not self.encoding_positions_are_trivial:
  1831          i = self.encoding_positions_argsort[i]
  1832        if (null_mask_len and i >> 3 < null_mask_len and
  1833            null_mask_c[i >> 3] & (0x01 << (i & 0x07))):
  1834          item = None
  1835        else:
  1836          component_coder = self.components[i]  # for typing
  1837          item = component_coder.decode_from_stream(in_stream, True)
  1838        sorted_components.append(item)
  1839  
  1840      # If this coder's schema has more attributes than the encoded value, then
  1841      # the schema must have changed. Populate the unencoded fields with nulls.
  1842      while len(sorted_components) < self.num_fields:
  1843        sorted_components.append(None)
  1844  
  1845      return self.constructor(
  1846          *(
  1847              sorted_components if self.encoding_positions_are_trivial else
  1848              [sorted_components[i] for i in self.encoding_positions]))
  1849  
  1850    def decode_batch_from_stream(self, dest: Dict[str, np.ndarray], in_stream):
  1851      attrs = self._row_column_encoders(dest)
  1852      n = len(next(iter(dest.values())))
  1853      for k in range(n):
  1854        if in_stream.size() == 0:
  1855          break
  1856        nvals = in_stream.read_var_int64()
  1857        null_mask_len = in_stream.read_var_int64()
  1858        if null_mask_len:
  1859          # pylint: disable=unused-variable
  1860          null_mask_c = null_mask = in_stream.read(null_mask_len)
  1861  
  1862        for i in range(min(self.num_fields, nvals)):
  1863          if not self.encoding_positions_are_trivial:
  1864            i = self.encoding_positions_argsort[i]
  1865          if (null_mask_len and i >> 3 < null_mask_len and
  1866              null_mask_c[i >> 3] & (0x01 << (i & 0x07))):
  1867            continue
  1868          else:
  1869            cython.cast(RowColumnEncoder,
  1870                        attrs[i]).decode_from_stream(k, in_stream)
  1871      else:
  1872        # Loop variable will be n-1 on normal exit.
  1873        k = n
  1874  
  1875      for attr in attrs:
  1876        attr.finalize_write()
  1877      return k
  1878  
  1879  
  1880  class LogicalTypeCoderImpl(StreamCoderImpl):
  1881    def __init__(self, logical_type, representation_coder):
  1882      self.logical_type = logical_type
  1883      self.representation_coder = representation_coder.get_impl()
  1884  
  1885    def encode_to_stream(self, value, out, nested):
  1886      return self.representation_coder.encode_to_stream(
  1887          self.logical_type.to_representation_type(value), out, nested)
  1888  
  1889    def decode_from_stream(self, in_stream, nested):
  1890      return self.logical_type.to_language_type(
  1891          self.representation_coder.decode_from_stream(in_stream, nested))
  1892  
  1893  
  1894  class BigIntegerCoderImpl(StreamCoderImpl):
  1895    """For internal use only; no backwards-compatibility guarantees.
  1896  
  1897    For interoperability with Java SDK, encoding needs to match that of the Java
  1898    SDK BigIntegerCoder."""
  1899    def encode_to_stream(self, value, out, nested):
  1900      # type: (int, create_OutputStream, bool) -> None
  1901      if value < 0:
  1902        byte_length = ((value + 1).bit_length() + 8) // 8
  1903      else:
  1904        byte_length = (value.bit_length() + 8) // 8
  1905      encoded_value = value.to_bytes(
  1906          length=byte_length, byteorder='big', signed=True)
  1907      out.write(encoded_value, nested)
  1908  
  1909    def decode_from_stream(self, in_stream, nested):
  1910      # type: (create_InputStream, bool) -> int
  1911      encoded_value = in_stream.read_all(nested)
  1912      return int.from_bytes(encoded_value, byteorder='big', signed=True)
  1913  
  1914  
  1915  class DecimalCoderImpl(StreamCoderImpl):
  1916    """For internal use only; no backwards-compatibility guarantees.
  1917  
  1918    For interoperability with Java SDK, encoding needs to match that of the Java
  1919    SDK BigDecimalCoder."""
  1920  
  1921    BIG_INT_CODER_IMPL = BigIntegerCoderImpl()
  1922  
  1923    def encode_to_stream(self, value, out, nested):
  1924      # type: (decimal.Decimal, create_OutputStream, bool) -> None
  1925      scale = -value.as_tuple().exponent
  1926      int_value = int(value.scaleb(scale))
  1927      out.write_var_int64(scale)
  1928      self.BIG_INT_CODER_IMPL.encode_to_stream(int_value, out, nested)
  1929  
  1930    def decode_from_stream(self, in_stream, nested):
  1931      # type: (create_InputStream, bool) -> decimal.Decimal
  1932      scale = in_stream.read_var_int64()
  1933      int_value = self.BIG_INT_CODER_IMPL.decode_from_stream(in_stream, nested)
  1934      value = decimal.Decimal(int_value).scaleb(-scale)
  1935      return value