github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/coders/row_coder_test.py

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  # pytype: skip-file
    18  
    19  import logging
    20  import typing
    21  import unittest
    22  from itertools import chain
    23  
    24  import numpy as np
    25  from google.protobuf import json_format
    26  from numpy.testing import assert_array_equal
    27  
    28  import apache_beam as beam
    29  from apache_beam.coders import RowCoder
    30  from apache_beam.coders import coder_impl
    31  from apache_beam.coders.typecoders import registry as coders_registry
    32  from apache_beam.internal import pickler
    33  from apache_beam.portability.api import schema_pb2
    34  from apache_beam.testing.test_pipeline import TestPipeline
    35  from apache_beam.testing.util import assert_that
    36  from apache_beam.testing.util import equal_to
    37  from apache_beam.typehints.schemas import named_tuple_from_schema
    38  from apache_beam.typehints.schemas import typing_to_runner_api
    39  from apache_beam.utils.timestamp import Timestamp
    40  
    41  Person = typing.NamedTuple(
    42      "Person",
    43      [
    44          ("name", str),
    45          ("age", np.int32),
    46          ("address", typing.Optional[str]),
    47          ("aliases", typing.List[str]),
    48          ("knows_javascript", bool),
    49          ("payload", typing.Optional[bytes]),
    50          ("custom_metadata", typing.Mapping[str, int]),
    51          ("favorite_time", Timestamp),
    52      ])
    53  
    54  NullablePerson = typing.NamedTuple(
    55      "NullablePerson",
    56      [("name", typing.Optional[str]), ("age", np.int32),
    57       ("address", typing.Optional[str]), ("aliases", typing.List[str]),
    58       ("knows_javascript", bool), ("payload", typing.Optional[bytes]),
    59       ("custom_metadata", typing.Mapping[str, int]),
    60       ("favorite_time", typing.Optional[Timestamp]),
    61       ("one_more_field", typing.Optional[str])])
    62  
    63  
    64  class People(typing.NamedTuple):
    65    primary: Person
    66    partner: typing.Optional[Person]
    67  
    68  
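         # Registering RowCoder for these NamedTuple types makes
         # coders_registry.get_coder(Person) resolve to a schema-aware RowCoder,
         # which the tests below compare against a RowCoder built directly from
         # the portable schema proto. A minimal usage sketch (illustrative only,
         # not part of the test suite; `some_person` is a hypothetical value):
         #
         #   coder = coders_registry.get_coder(Person)
         #   assert coder.decode(coder.encode(some_person)) == some_person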
    69  coders_registry.register_coder(Person, RowCoder)
    70  coders_registry.register_coder(People, RowCoder)
    71  
    72  
    73  class RowCoderTest(unittest.TestCase):
    74    JON_SNOW = Person(
    75        name="Jon Snow",
    76        age=np.int32(23),
    77        address=None,
    78        aliases=["crow", "wildling"],
    79        knows_javascript=False,
    80        payload=None,
    81        custom_metadata={},
    82        favorite_time=Timestamp.from_rfc3339('2016-03-18T23:22:59.123456Z'),
    83    )
    84    PEOPLE = [
    85        JON_SNOW,
    86        Person(
    87            "Daenerys Targaryen",
    88            np.int32(25),
    89            "Westeros",
    90            ["Mother of Dragons"],
    91            False,
    92            None,
    93            {"dragons": 3},
    94            Timestamp.from_rfc3339('1970-04-26T17:46:40Z'),
    95        ),
    96        Person(
    97            "Michael Bluth",
    98            np.int32(30),
    99            None, [],
   100            True,
   101            b"I've made a huge mistake", {},
   102            Timestamp.from_rfc3339('2020-08-12T15:51:00.032Z'))
   103    ]
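           # Between them these fixtures cover None for every optional field,
           # empty and non-empty lists and maps, a bytes payload, and sub-second
           # timestamps, which is what the round-trip tests below rely on.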
   104  
   105    def test_row_accepts_trailing_zeros_truncated(self):
   106      expected_coder = RowCoder(
   107          typing_to_runner_api(NullablePerson).row_type.schema)
   108      person = NullablePerson(
   109          None,
   110          np.int32(25),
   111          "Westeros", ["Mother of Dragons"],
   112          False,
   113          None, {"dragons": 3},
   114          None,
   115          "NotNull")
   116      out = expected_coder.encode(person)
    117      # 9 fields, a 1-byte null bitmap; fields 0, 5 and 7 are marked null
   118      new_payload = bytes([9, 1, 1 | 1 << 5 | 1 << 7]) + out[4:]
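             # Illustrative byte-level view (an assumption based on how this
             # payload is constructed, not an authoritative spec): the encoding
             # appears to start with a varint field count, then a length-prefixed
             # null bitmap, then the non-null field values, so for `person` above:
             #   out[0] == 9           field count
             #   out[1] == 2           bitmap length in bytes (9 fields need 2)
             #   out[2] == 0b10100001  fields 0, 5 and 7 marked null
             #   out[3] == 0           all-zero trailing bitmap byte
             # new_payload keeps only a 1-byte bitmap, so decode() must treat the
             # missing trailing bitmap bits as "not null".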
   119      new_value = expected_coder.decode(new_payload)
   120      self.assertEqual(person, new_value)
   121  
   122    def test_create_row_coder_from_named_tuple(self):
   123      expected_coder = RowCoder(typing_to_runner_api(Person).row_type.schema)
   124      real_coder = coders_registry.get_coder(Person)
   125  
   126      for test_case in self.PEOPLE:
   127        self.assertEqual(
   128            expected_coder.encode(test_case), real_coder.encode(test_case))
   129  
   130        self.assertEqual(
   131            test_case, real_coder.decode(real_coder.encode(test_case)))
   132  
   133    def test_create_row_coder_from_nested_named_tuple(self):
   134      expected_coder = RowCoder(typing_to_runner_api(People).row_type.schema)
   135      real_coder = coders_registry.get_coder(People)
   136  
   137      for primary in self.PEOPLE:
   138        for other in self.PEOPLE + [None]:
   139          test_case = People(primary=primary, partner=other)
   140          self.assertEqual(
   141              expected_coder.encode(test_case), real_coder.encode(test_case))
   142  
   143          self.assertEqual(
   144              test_case, real_coder.decode(real_coder.encode(test_case)))
   145  
   146    def test_create_row_coder_from_schema(self):
   147      schema = schema_pb2.Schema(
   148          id="person",
   149          fields=[
   150              schema_pb2.Field(
   151                  name="name",
   152                  type=schema_pb2.FieldType(atomic_type=schema_pb2.STRING)),
   153              schema_pb2.Field(
   154                  name="age",
   155                  type=schema_pb2.FieldType(atomic_type=schema_pb2.INT32)),
   156              schema_pb2.Field(
   157                  name="address",
   158                  type=schema_pb2.FieldType(
   159                      atomic_type=schema_pb2.STRING, nullable=True)),
   160              schema_pb2.Field(
   161                  name="aliases",
   162                  type=schema_pb2.FieldType(
   163                      array_type=schema_pb2.ArrayType(
   164                          element_type=schema_pb2.FieldType(
   165                              atomic_type=schema_pb2.STRING)))),
   166              schema_pb2.Field(
   167                  name="knows_javascript",
   168                  type=schema_pb2.FieldType(atomic_type=schema_pb2.BOOLEAN)),
   169              schema_pb2.Field(
   170                  name="payload",
   171                  type=schema_pb2.FieldType(
   172                      atomic_type=schema_pb2.BYTES, nullable=True)),
   173              schema_pb2.Field(
   174                  name="custom_metadata",
   175                  type=schema_pb2.FieldType(
   176                      map_type=schema_pb2.MapType(
   177                          key_type=schema_pb2.FieldType(
   178                              atomic_type=schema_pb2.STRING),
   179                          value_type=schema_pb2.FieldType(
   180                              atomic_type=schema_pb2.INT64),
   181                      ))),
   182              schema_pb2.Field(
   183                  name="favorite_time",
   184                  type=schema_pb2.FieldType(
   185                      logical_type=schema_pb2.LogicalType(
   186                          urn="beam:logical_type:micros_instant:v1",
   187                          representation=schema_pb2.FieldType(
   188                              row_type=schema_pb2.RowType(
   189                                  schema=schema_pb2.Schema(
   190                                      id="micros_instant",
   191                                      fields=[
   192                                          schema_pb2.Field(
   193                                              name="seconds",
   194                                              type=schema_pb2.FieldType(
   195                                                  atomic_type=schema_pb2.INT64)),
   196                                          schema_pb2.Field(
   197                                              name="micros",
   198                                              type=schema_pb2.FieldType(
   199                                                  atomic_type=schema_pb2.INT64)),
   200                                      ])))))),
   201          ])
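             # This hand-written schema is intended to mirror what
             # typing_to_runner_api(Person).row_type.schema produces: np.int32
             # maps to INT32, Mapping[str, int] to a MapType of STRING to INT64,
             # and Timestamp to the beam:logical_type:micros_instant:v1 logical
             # type, so the PEOPLE values should round-trip through it unchanged.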
   202      coder = RowCoder(schema)
   203  
   204      for test_case in self.PEOPLE:
   205        self.assertEqual(test_case, coder.decode(coder.encode(test_case)))
   206  
   207    @unittest.skip(
   208        "https://github.com/apache/beam/issues/19696 - Overflow behavior in "
   209        "VarIntCoder is currently inconsistent")
   210    def test_overflows(self):
   211      IntTester = typing.NamedTuple(
   212          'IntTester',
   213          [
   214              # TODO(https://github.com/apache/beam/issues/19815): Test int8 and
   215              # int16 here as well when those types are supported
   216              # ('i8', typing.Optional[np.int8]),
   217              # ('i16', typing.Optional[np.int16]),
   218              ('i32', typing.Optional[np.int32]),
   219              ('i64', typing.Optional[np.int64]),
   220          ])
   221  
   222      c = RowCoder.from_type_hint(IntTester, None)
   223  
   224      no_overflow = chain(
   225          (IntTester(i32=i, i64=None) for i in (-2**31, 2**31 - 1)),
   226          (IntTester(i32=None, i64=i) for i in (-2**63, 2**63 - 1)),
   227      )
   228  
    229      # Encode max/min ints to make sure they don't raise an error
   230      for case in no_overflow:
   231        c.encode(case)
   232  
   233      overflow = chain(
   234          (IntTester(i32=i, i64=None) for i in (-2**31 - 1, 2**31)),
   235          (IntTester(i32=None, i64=i) for i in (-2**63 - 1, 2**63)),
   236      )
   237  
    238      # Encode max+1/min-1 ints to make sure they DO raise an error
   239      # pylint: disable=cell-var-from-loop
   240      for case in overflow:
   241        self.assertRaises(OverflowError, lambda: c.encode(case))
   242  
   243    def test_none_in_non_nullable_field_throws(self):
   244      Test = typing.NamedTuple('Test', [('foo', str)])
   245  
   246      c = RowCoder.from_type_hint(Test, None)
   247      self.assertRaises(ValueError, lambda: c.encode(Test(foo=None)))
   248  
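           # The next three tests exercise simple schema evolution: a RowCoder
           # for a newer schema should tolerate rows encoded with extra trailing
           # fields (the extra values are dropped) or with missing trailing
           # nullable fields (those decode as None).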
   249    def test_schema_remove_column(self):
   250      fields = [("field1", str), ("field2", str)]
   251      # new schema is missing one field that was in the old schema
   252      Old = typing.NamedTuple('Old', fields)
   253      New = typing.NamedTuple('New', fields[:-1])
   254  
   255      old_coder = RowCoder.from_type_hint(Old, None)
   256      new_coder = RowCoder.from_type_hint(New, None)
   257  
   258      self.assertEqual(
   259          New("foo"), new_coder.decode(old_coder.encode(Old("foo", "bar"))))
   260  
   261    def test_schema_add_column(self):
   262      fields = [("field1", str), ("field2", typing.Optional[str])]
   263      # new schema has one (optional) field that didn't exist in the old schema
   264      Old = typing.NamedTuple('Old', fields[:-1])
   265      New = typing.NamedTuple('New', fields)
   266  
   267      old_coder = RowCoder.from_type_hint(Old, None)
   268      new_coder = RowCoder.from_type_hint(New, None)
   269  
   270      self.assertEqual(
   271          New("bar", None), new_coder.decode(old_coder.encode(Old("bar"))))
   272  
   273    def test_schema_add_column_with_null_value(self):
   274      fields = [("field1", typing.Optional[str]), ("field2", str),
   275                ("field3", typing.Optional[str])]
   276      # new schema has one (optional) field that didn't exist in the old schema
   277      Old = typing.NamedTuple('Old', fields[:-1])
   278      New = typing.NamedTuple('New', fields)
   279  
   280      old_coder = RowCoder.from_type_hint(Old, None)
   281      new_coder = RowCoder.from_type_hint(New, None)
   282  
   283      self.assertEqual(
   284          New(None, "baz", None),
   285          new_coder.decode(old_coder.encode(Old(None, "baz"))))
   286  
   287    def test_row_coder_picklable(self):
    288      # Coders can occasionally get pickled; RowCoder should be able to handle that
   289      coder = coders_registry.get_coder(Person)
   290      roundtripped = pickler.loads(pickler.dumps(coder))
   291  
   292      self.assertEqual(roundtripped, coder)
   293  
    294    def test_row_coder_in_pipeline(self):
   295      with TestPipeline() as p:
   296        res = (
   297            p
   298            | beam.Create(self.PEOPLE)
   299            | beam.Filter(lambda person: person.name == "Jon Snow"))
   300        assert_that(res, equal_to([self.JON_SNOW]))
   301  
   302    def test_row_coder_nested_struct(self):
   303      Pair = typing.NamedTuple('Pair', [('left', Person), ('right', Person)])
   304  
   305      value = Pair(self.PEOPLE[0], self.PEOPLE[1])
   306      coder = RowCoder(typing_to_runner_api(Pair).row_type.schema)
   307  
   308      self.assertEqual(value, coder.decode(coder.encode(value)))
   309  
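           # The two tests below rely on encoding_positions_set: when it is True,
           # fields are encoded and decoded in encoding_position order rather
           # than declaration order, so a schema that lists its fields in a
           # different order (or appends new ones) can still read bytes written
           # under the original field order.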
   310    def test_encoding_position_reorder_fields(self):
   311      schema1 = schema_pb2.Schema(
   312          id="reorder_test_schema1",
   313          fields=[
   314              schema_pb2.Field(
   315                  name="f_int32",
   316                  type=schema_pb2.FieldType(atomic_type=schema_pb2.INT32),
   317              ),
   318              schema_pb2.Field(
   319                  name="f_str",
   320                  type=schema_pb2.FieldType(atomic_type=schema_pb2.STRING),
   321              ),
   322          ])
   323      schema2 = schema_pb2.Schema(
   324          id="reorder_test_schema2",
   325          encoding_positions_set=True,
   326          fields=[
   327              schema_pb2.Field(
   328                  name="f_str",
   329                  type=schema_pb2.FieldType(atomic_type=schema_pb2.STRING),
   330                  encoding_position=1,
   331              ),
   332              schema_pb2.Field(
   333                  name="f_int32",
   334                  type=schema_pb2.FieldType(atomic_type=schema_pb2.INT32),
   335                  encoding_position=0,
   336              ),
   337          ])
   338  
   339      RowSchema1 = named_tuple_from_schema(schema1)
   340      RowSchema2 = named_tuple_from_schema(schema2)
   341      roundtripped = RowCoder(schema2).decode(
   342          RowCoder(schema1).encode(RowSchema1(42, "Hello World!")))
   343  
   344      self.assertEqual(RowSchema2(f_int32=42, f_str="Hello World!"), roundtripped)
   345  
   346    def test_encoding_position_add_fields_and_reorder(self):
   347      old_schema = schema_pb2.Schema(
   348          id="add_test_old",
   349          fields=[
   350              schema_pb2.Field(
   351                  name="f_int32",
   352                  type=schema_pb2.FieldType(atomic_type=schema_pb2.INT32),
   353              ),
   354              schema_pb2.Field(
   355                  name="f_str",
   356                  type=schema_pb2.FieldType(atomic_type=schema_pb2.STRING),
   357              ),
   358          ])
   359      new_schema = schema_pb2.Schema(
   360          encoding_positions_set=True,
   361          id="add_test_new",
   362          fields=[
   363              schema_pb2.Field(
   364                  name="f_new_str",
   365                  type=schema_pb2.FieldType(
   366                      atomic_type=schema_pb2.STRING, nullable=True),
   367                  encoding_position=2,
   368              ),
   369              schema_pb2.Field(
   370                  name="f_int32",
   371                  type=schema_pb2.FieldType(atomic_type=schema_pb2.INT32),
   372                  encoding_position=0,
   373              ),
   374              schema_pb2.Field(
   375                  name="f_str",
   376                  type=schema_pb2.FieldType(atomic_type=schema_pb2.STRING),
   377                  encoding_position=1,
   378              ),
   379          ])
   380  
   381      Old = named_tuple_from_schema(old_schema)
   382      New = named_tuple_from_schema(new_schema)
   383      roundtripped = RowCoder(new_schema).decode(
   384          RowCoder(old_schema).encode(Old(42, "Hello World!")))
   385  
   386      self.assertEqual(
   387          New(f_new_str=None, f_int32=42, f_str="Hello World!"), roundtripped)
   388  
   389    def test_row_coder_fail_early_bad_schema(self):
   390      schema_proto = schema_pb2.Schema(
   391          fields=[
   392              schema_pb2.Field(
   393                  name="type_with_no_typeinfo", type=schema_pb2.FieldType())
   394          ],
   395          id='bad-schema')
   396  
   397      # Should raise an exception referencing the problem field
   398      self.assertRaisesRegex(
   399          ValueError, "type_with_no_typeinfo", lambda: RowCoder(schema_proto))
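             # Constructing RowCoder(schema_proto) is itself what raises here, so
             # an invalid schema should surface when the coder is built rather
             # than when the first element is encoded.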
   400  
   401    def test_row_coder_cloud_object_schema(self):
   402      schema_proto = schema_pb2.Schema(id='some-cloud-object-schema')
   403      schema_proto_json = json_format.MessageToJson(schema_proto).encode('utf-8')
   404  
   405      coder = RowCoder(schema_proto)
   406  
   407      cloud_object = coder.as_cloud_object()
   408  
   409      self.assertEqual(schema_proto_json, cloud_object['schema'])
   410  
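           # Beyond encoding one row at a time, the RowCoder impl also handles a
           # columnar batch layout: a dict mapping each field name to a numpy
           # array of values. The batch encoding should be byte-for-byte equal to
           # encoding the same rows individually, and decode_batch_from_stream
           # fills caller-allocated arrays and returns how many rows it read.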
   411    def test_batch_encode_decode(self):
   412      coder = RowCoder(typing_to_runner_api(Person).row_type.schema).get_impl()
   413      seq_out = coder_impl.create_OutputStream()
   414      for person in self.PEOPLE:
   415        coder.encode_to_stream(person, seq_out, False)
   416  
   417      batch_out = coder_impl.create_OutputStream()
   418      columnar = {
   419          field: np.array([getattr(person, field) for person in self.PEOPLE],
   420                          ndmin=1,
   421                          dtype=object)
   422          for field in Person._fields
   423      }
   424      coder.encode_batch_to_stream(columnar, batch_out)
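             # If the two encodings differ, dump both byte streams in 25-byte
             # chunks so the first divergence is easy to spot before the
             # assertion below fails with an unreadable diff.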
   425      if seq_out.get() != batch_out.get():
   426        a, b = seq_out.get(), batch_out.get()
   427        N = 25
   428        for k in range(0, max(len(a), len(b)), N):
   429          print(k, a[k:k + N] == b[k:k + N])
   430          print(a[k:k + N])
   431          print(b[k:k + N])
   432      self.assertEqual(seq_out.get(), batch_out.get())
   433  
   434      for size in [len(self.PEOPLE) - 1, len(self.PEOPLE), len(self.PEOPLE) + 1]:
   435        dest = {
   436            field: np.ndarray((size, ), dtype=a.dtype)
   437            for field,
   438            a in columnar.items()
   439        }
   440        n = min(size, len(self.PEOPLE))
   441        self.assertEqual(
   442            n,
   443            coder.decode_batch_from_stream(
   444                dest, coder_impl.create_InputStream(seq_out.get())))
   445        for field, a in columnar.items():
   446          assert_array_equal(a[:n], dest[field][:n])
   447  
   448  
   449  if __name__ == "__main__":
   450    logging.getLogger().setLevel(logging.INFO)
   451    unittest.main()