github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/coders/row_coder.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  # pytype: skip-file
    19  
    20  from google.protobuf import json_format
    21  
    22  from apache_beam.coders import typecoders
    23  from apache_beam.coders.coder_impl import LogicalTypeCoderImpl
    24  from apache_beam.coders.coder_impl import RowCoderImpl
    25  from apache_beam.coders.coders import BigEndianShortCoder
    26  from apache_beam.coders.coders import BooleanCoder
    27  from apache_beam.coders.coders import BytesCoder
    28  from apache_beam.coders.coders import Coder
    29  from apache_beam.coders.coders import DecimalCoder
    30  from apache_beam.coders.coders import FastCoder
    31  from apache_beam.coders.coders import FloatCoder
    32  from apache_beam.coders.coders import IterableCoder
    33  from apache_beam.coders.coders import MapCoder
    34  from apache_beam.coders.coders import NullableCoder
    35  from apache_beam.coders.coders import SinglePrecisionFloatCoder
    36  from apache_beam.coders.coders import StrUtf8Coder
    37  from apache_beam.coders.coders import TimestampCoder
    38  from apache_beam.coders.coders import VarIntCoder
    39  from apache_beam.portability import common_urns
    40  from apache_beam.portability.api import schema_pb2
    41  from apache_beam.typehints import row_type
    42  from apache_beam.typehints.schemas import PYTHON_ANY_URN
    43  from apache_beam.typehints.schemas import LogicalType
    44  from apache_beam.typehints.schemas import named_tuple_from_schema
    45  from apache_beam.typehints.schemas import schema_from_element_type
    46  from apache_beam.utils import proto_utils
    47  
# Only RowCoder is listed for `from ... import *`; the helper functions and
# LogicalTypeCoder defined further down in this module are not listed here.
__all__ = ["RowCoder"]
    49  
    50  
    51  class RowCoder(FastCoder):
    52    """ Coder for `typing.NamedTuple` instances.
    53  
    54    Implements the beam:coder:row:v1 standard coder spec.
    55    """
    56    def __init__(self, schema, force_deterministic=False):
    57      """Initializes a :class:`RowCoder`.
    58  
    59      Args:
    60        schema (apache_beam.portability.api.schema_pb2.Schema): The protobuf
    61          representation of the schema of the data that the RowCoder will be used
    62          to encode/decode.
    63      """
    64      self.schema = schema
    65  
    66      # Eagerly generate type hint to escalate any issues with the Schema proto
    67      self._type_hint = named_tuple_from_schema(self.schema)
    68  
    69      # Use non-null coders because null values are represented separately
    70      self.components = [
    71          _nonnull_coder_from_type(field.type) for field in self.schema.fields
    72      ]
    73      if force_deterministic:
    74        self.components = [
    75            c.as_deterministic_coder(force_deterministic) for c in self.components
    76        ]
    77      self.forced_deterministic = bool(force_deterministic)
    78  
    79    def _create_impl(self):
    80      return RowCoderImpl(self.schema, self.components)
    81  
    82    def is_deterministic(self):
    83      return all(c.is_deterministic() for c in self.components)
    84  
    85    def as_deterministic_coder(self, step_label, error_message=None):
    86      if self.is_deterministic():
    87        return self
    88      else:
    89        return RowCoder(self.schema, error_message or step_label)
    90  
    91    def to_type_hint(self):
    92      return self._type_hint
    93  
    94    def as_cloud_object(self, coders_context=None):
    95      value = super().as_cloud_object(coders_context)
    96  
    97      value['schema'] = json_format.MessageToJson(self.schema).encode('utf-8')
    98  
    99      return value
   100  
   101    def __hash__(self):
   102      return hash(self.schema.SerializeToString())
   103  
   104    def __eq__(self, other):
   105      return (
   106          type(self) == type(other) and self.schema == other.schema and
   107          self.forced_deterministic == other.forced_deterministic)
   108  
   109    def to_runner_api_parameter(self, unused_context):
   110      return (common_urns.coders.ROW.urn, self.schema, [])
   111  
   112    @staticmethod
   113    @Coder.register_urn(common_urns.coders.ROW.urn, schema_pb2.Schema)
   114    def from_runner_api_parameter(schema, components, unused_context):
   115      return RowCoder(schema)
   116  
   117    @classmethod
   118    def from_type_hint(cls, type_hint, registry):
   119      # TODO(https://github.com/apache/beam/issues/21541): Remove once all
   120      # runners are portable.
   121      if isinstance(type_hint, str):
   122        import importlib
   123        main_module = importlib.import_module('__main__')
   124        type_hint = getattr(main_module, type_hint, type_hint)
   125      schema = schema_from_element_type(type_hint)
   126      return cls(schema)
   127  
   128    @staticmethod
   129    def from_payload(payload):
   130      # type: (bytes) -> RowCoder
   131      return RowCoder(proto_utils.parse_Bytes(payload, schema_pb2.Schema))
   132  
   133    def __reduce__(self):
   134      # when pickling, use bytes representation of the schema. schema_pb2.Schema
   135      # objects cannot be pickled.
   136      return (RowCoder.from_payload, (self.schema.SerializeToString(), ))
   137  
   138  
# Register RowCoder in the type-coder registry for both row type constraint
# classes from apache_beam.typehints.row_type.
# NOTE(review): presumably the registry then builds coders for these hints via
# RowCoder.from_type_hint -- confirm against the registry implementation.
typecoders.registry.register_coder(row_type.RowTypeConstraint, RowCoder)
typecoders.registry.register_coder(
    row_type.GeneratedClassRowTypeConstraint, RowCoder)
   142  
   143  
   144  def _coder_from_type(field_type):
   145    coder = _nonnull_coder_from_type(field_type)
   146    if field_type.nullable:
   147      return NullableCoder(coder)
   148    else:
   149      return coder
   150  
   151  
   152  def _nonnull_coder_from_type(field_type):
   153    type_info = field_type.WhichOneof("type_info")
   154    if type_info == "atomic_type":
   155      if field_type.atomic_type in (schema_pb2.INT32, schema_pb2.INT64):
   156        return VarIntCoder()
   157      if field_type.atomic_type == schema_pb2.INT16:
   158        return BigEndianShortCoder()
   159      elif field_type.atomic_type == schema_pb2.FLOAT:
   160        return SinglePrecisionFloatCoder()
   161      elif field_type.atomic_type == schema_pb2.DOUBLE:
   162        return FloatCoder()
   163      elif field_type.atomic_type == schema_pb2.STRING:
   164        return StrUtf8Coder()
   165      elif field_type.atomic_type == schema_pb2.BOOLEAN:
   166        return BooleanCoder()
   167      elif field_type.atomic_type == schema_pb2.BYTES:
   168        return BytesCoder()
   169    elif type_info == "array_type":
   170      return IterableCoder(_coder_from_type(field_type.array_type.element_type))
   171    elif type_info == "map_type":
   172      return MapCoder(
   173          _coder_from_type(field_type.map_type.key_type),
   174          _coder_from_type(field_type.map_type.value_type))
   175    elif type_info == "logical_type":
   176      if field_type.logical_type.urn == PYTHON_ANY_URN:
   177        # Special case for the Any logical type. Just use the default coder for an
   178        # unknown Python object.
   179        return typecoders.registry.get_coder(object)
   180      elif field_type.logical_type.urn == common_urns.millis_instant.urn:
   181        # Special case for millis instant logical type used to handle Java sdk's
   182        # millis Instant. It explicitly uses TimestampCoder which deals with fix
   183        # length 8-bytes big-endian-long instead of VarInt coder.
   184        return TimestampCoder()
   185      elif field_type.logical_type.urn == 'beam:logical_type:decimal:v1':
   186        return DecimalCoder()
   187  
   188      logical_type = LogicalType.from_runner_api(field_type.logical_type)
   189      return LogicalTypeCoder(
   190          logical_type, _coder_from_type(field_type.logical_type.representation))
   191    elif type_info == "row_type":
   192      return RowCoder(field_type.row_type.schema)
   193  
   194    # The Java SDK supports several more types, but the coders are not yet
   195    # standard, and are not implemented in Python.
   196    raise ValueError(
   197        "Encountered a type that is not currently supported by RowCoder: %s" %
   198        field_type)
   199  
   200  
   201  class LogicalTypeCoder(FastCoder):
   202    def __init__(self, logical_type, representation_coder):
   203      self.logical_type = logical_type
   204      self.representation_coder = representation_coder
   205  
   206    def _create_impl(self):
   207      return LogicalTypeCoderImpl(self.logical_type, self.representation_coder)
   208  
   209    def is_deterministic(self):
   210      return self.representation_coder.is_deterministic()
   211  
   212    def to_type_hint(self):
   213      return self.logical_type.language_type()