github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/coders/coders_test_common.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Tests common to all coder implementations."""
    19  # pytype: skip-file
    20  
    21  import base64
    22  import collections
    23  import enum
    24  import logging
    25  import math
    26  import unittest
    27  from decimal import Decimal
    28  from typing import Any
    29  from typing import List
    30  from typing import NamedTuple
    31  
    32  import pytest
    33  
    34  from apache_beam.coders import proto2_coder_test_messages_pb2 as test_message
    35  from apache_beam.coders import coders
    36  from apache_beam.coders import typecoders
    37  from apache_beam.internal import pickler
    38  from apache_beam.runners import pipeline_context
    39  from apache_beam.transforms import userstate
    40  from apache_beam.transforms import window
    41  from apache_beam.transforms.window import GlobalWindow
    42  from apache_beam.typehints import sharded_key_type
    43  from apache_beam.typehints import typehints
    44  from apache_beam.utils import timestamp
    45  from apache_beam.utils import windowed_value
    46  from apache_beam.utils.sharded_key import ShardedKey
    47  from apache_beam.utils.timestamp import MIN_TIMESTAMP
    48  
    49  from . import observable
    50  
    51  try:
    52    import dataclasses
    53  except ImportError:
    54    dataclasses = None  # type: ignore
    55  
# Named-tuple sample types used as values in test_deterministic_coder.
# The untyped namedtuple is created with the type name 'A', which differs
# from the module-level name it is bound to.
MyNamedTuple = collections.namedtuple('A', ['x', 'y'])
MyTypedNamedTuple = NamedTuple('MyTypedNamedTuple', [('f1', int), ('f2', str)])
    58  
    59  
class MyEnum(enum.Enum):
  # Sample enum whose members mix an explicit int, an enum.auto() value and a
  # str; all members are round-tripped in test_deterministic_coder.
  E1 = 5
  E2 = enum.auto()
  E3 = 'abc'
    64  
    65  
# Further enum flavours built with the functional API (members I1..I3 and
# F1..F3); all members are round-tripped in test_deterministic_coder.
MyIntEnum = enum.IntEnum('MyIntEnum', 'I1 I2 I3')
MyIntFlag = enum.IntFlag('MyIntFlag', 'F1 F2 F3')
MyFlag = enum.Flag('MyFlag', 'F1 F2 F3')  # pylint: disable=too-many-function-args
    69  
    70  
    71  class DefinesGetState:
    72    def __init__(self, value):
    73      self.value = value
    74  
    75    def __getstate__(self):
    76      return self.value
    77  
    78    def __eq__(self, other):
    79      return type(other) is type(self) and other.value == self.value
    80  
    81  
class DefinesGetAndSetState(DefinesGetState):
  # Adds the __setstate__ counterpart to the inherited __getstate__.
  def __setstate__(self, value):
    # Restore the wrapped value from the state object.
    self.value = value
    85  
    86  
    87  # Defined out of line for picklability.
    88  class CustomCoder(coders.Coder):
    89    def encode(self, x):
    90      return str(x + 1).encode('utf-8')
    91  
    92    def decode(self, encoded):
    93      return int(encoded) - 1
    94  
    95  
if dataclasses is not None:

  # Immutable sample dataclass; test_deterministic_coder expects
  # FrozenDataClass(1, 2) to round-trip through the deterministic coder.
  @dataclasses.dataclass(frozen=True)
  class FrozenDataClass:
    a: Any
    b: int

  # Mutable sample dataclass; test_deterministic_coder expects encoding an
  # instance with the deterministic coder to raise TypeError.
  @dataclasses.dataclass
  class UnFrozenDataClass:
    x: int
    y: int
   107  
   108  
   109  # These tests need to all be run in the same process due to the asserts
   110  # in tearDownClass.
   111  @pytest.mark.no_xdist
   112  class CodersTest(unittest.TestCase):
   113  
   114    # These class methods ensure that we test each defined coder in both
   115    # nested and unnested context.
   116  
   117    # Common test values representing Python's built-in types.
   118    test_values_deterministic: List[Any] = [
   119        None,
   120        1,
   121        -1,
   122        1.5,
   123        b'str\0str',
   124        u'unicode\0\u0101',
   125        (),
   126        (1, 2, 3),
   127        [],
   128        [1, 2, 3],
   129        True,
   130        False,
   131    ]
   132    test_values = test_values_deterministic + [
   133        {},
   134        {
   135            'a': 'b'
   136        },
   137        {
   138            0: {}, 1: len
   139        },
   140        set(),
   141        {'a', 'b'},
   142        len,
   143    ]
   144  
   145    @classmethod
   146    def setUpClass(cls):
   147      cls.seen = set()
   148      cls.seen_nested = set()
   149  
   150    @classmethod
   151    def tearDownClass(cls):
   152      standard = set(
   153          c for c in coders.__dict__.values() if isinstance(c, type) and
   154          issubclass(c, coders.Coder) and 'Base' not in c.__name__)
   155      standard -= set([
   156          coders.Coder,
   157          coders.AvroGenericCoder,
   158          coders.DeterministicProtoCoder,
   159          coders.FastCoder,
   160          coders.ListLikeCoder,
   161          coders.ProtoCoder,
   162          coders.ProtoPlusCoder,
   163          coders.BigEndianShortCoder,
   164          coders.SinglePrecisionFloatCoder,
   165          coders.ToBytesCoder,
   166          coders.BigIntegerCoder, # tested in DecimalCoder
   167      ])
   168      cls.seen_nested -= set(
   169          [coders.ProtoCoder, coders.ProtoPlusCoder, CustomCoder])
   170      assert not standard - cls.seen, str(standard - cls.seen)
   171      assert not cls.seen_nested - standard, str(cls.seen_nested - standard)
   172  
   173    @classmethod
   174    def _observe(cls, coder):
   175      cls.seen.add(type(coder))
   176      cls._observe_nested(coder)
   177  
   178    @classmethod
   179    def _observe_nested(cls, coder):
   180      if isinstance(coder, coders.TupleCoder):
   181        for c in coder.coders():
   182          cls.seen_nested.add(type(c))
   183          cls._observe_nested(c)
   184  
   185    def check_coder(self, coder, *values, **kwargs):
   186      context = kwargs.pop('context', pipeline_context.PipelineContext())
   187      test_size_estimation = kwargs.pop('test_size_estimation', True)
   188      assert not kwargs
   189      self._observe(coder)
   190      for v in values:
   191        self.assertEqual(v, coder.decode(coder.encode(v)))
   192        if test_size_estimation:
   193          self.assertEqual(coder.estimate_size(v), len(coder.encode(v)))
   194          self.assertEqual(
   195              coder.estimate_size(v), coder.get_impl().estimate_size(v))
   196          self.assertEqual(
   197              coder.get_impl().get_estimated_size_and_observables(v),
   198              (coder.get_impl().estimate_size(v), []))
   199        copy1 = pickler.loads(pickler.dumps(coder))
   200      copy2 = coders.Coder.from_runner_api(coder.to_runner_api(context), context)
   201      for v in values:
   202        self.assertEqual(v, copy1.decode(copy2.encode(v)))
   203        if coder.is_deterministic():
   204          self.assertEqual(copy1.encode(v), copy2.encode(v))
   205  
   206    def test_custom_coder(self):
   207  
   208      self.check_coder(CustomCoder(), 1, -10, 5)
   209      self.check_coder(
   210          coders.TupleCoder((CustomCoder(), coders.BytesCoder())), (1, b'a'),
   211          (-10, b'b'), (5, b'c'))
   212  
   213    def test_pickle_coder(self):
   214      coder = coders.PickleCoder()
   215      self.check_coder(coder, *self.test_values)
   216  
   217    def test_memoizing_pickle_coder(self):
   218      coder = coders._MemoizingPickleCoder()
   219      self.check_coder(coder, *self.test_values)
   220  
  def test_deterministic_coder(self):
    # Exercises DeterministicFastPrimitivesCoder with values it is expected to
    # round-trip and with values it is expected to reject with TypeError.
    coder = coders.FastPrimitivesCoder()
    deterministic_coder = coders.DeterministicFastPrimitivesCoder(coder, 'step')
    # Primitive values: unnested, each wrapped in a 1-tuple, and all together
    # in a single tuple.
    self.check_coder(deterministic_coder, *self.test_values_deterministic)
    for v in self.test_values_deterministic:
      self.check_coder(coders.TupleCoder((deterministic_coder, )), (v, ))
    self.check_coder(
        coders.TupleCoder(
            (deterministic_coder, ) * len(self.test_values_deterministic)),
        tuple(self.test_values_deterministic))

    # Dicts round-trip, but a dict mixing int and str keys is expected to
    # raise TypeError, also when it is nested inside a list.
    self.check_coder(deterministic_coder, {})
    self.check_coder(deterministic_coder, {2: 'x', 1: 'y'})
    with self.assertRaises(TypeError):
      self.check_coder(deterministic_coder, {1: 'x', 'y': 2})
    self.check_coder(deterministic_coder, [1, {}])
    with self.assertRaises(TypeError):
      self.check_coder(deterministic_coder, [1, {1: 'x', 'y': 2}])

    # Deterministic and plain coder side by side in one tuple.
    self.check_coder(
        coders.TupleCoder((deterministic_coder, coder)), (1, {}), ('a', [{}]))

    # Proto messages.
    self.check_coder(deterministic_coder, test_message.MessageA(field1='value'))

    # Untyped and typed named tuples.
    self.check_coder(
        deterministic_coder, [MyNamedTuple(1, 2), MyTypedNamedTuple(1, 'a')])

    if dataclasses is not None:
      # A frozen dataclass round-trips; an unfrozen one is expected to raise
      # TypeError, also when it is a field of a frozen dataclass or of a
      # named tuple.
      self.check_coder(deterministic_coder, FrozenDataClass(1, 2))

      with self.assertRaises(TypeError):
        self.check_coder(deterministic_coder, UnFrozenDataClass(1, 2))
      with self.assertRaises(TypeError):
        self.check_coder(
            deterministic_coder, FrozenDataClass(UnFrozenDataClass(1, 2), 3))
      with self.assertRaises(TypeError):
        self.check_coder(
            deterministic_coder, MyNamedTuple(UnFrozenDataClass(1, 2), 3))

    # All members of each enum flavour.
    self.check_coder(deterministic_coder, list(MyEnum))
    self.check_coder(deterministic_coder, list(MyIntEnum))
    self.check_coder(deterministic_coder, list(MyIntFlag))
    self.check_coder(deterministic_coder, list(MyFlag))

    # Objects defining both __getstate__ and __setstate__ round-trip.
    self.check_coder(
        deterministic_coder,
        [DefinesGetAndSetState(1), DefinesGetAndSetState((1, 2, 3))])

    # An object defining only __getstate__ is expected to raise TypeError, as
    # is one whose state is a dict with mixed key types.
    with self.assertRaises(TypeError):
      self.check_coder(deterministic_coder, DefinesGetState(1))
    with self.assertRaises(TypeError):
      self.check_coder(
          deterministic_coder, DefinesGetAndSetState({
              1: 'x', 'y': 2
          }))
   276  
   277    def test_dill_coder(self):
   278      cell_value = (lambda x: lambda: x)(0).__closure__[0]
   279      self.check_coder(coders.DillCoder(), 'a', 1, cell_value)
   280      self.check_coder(
   281          coders.TupleCoder((coders.VarIntCoder(), coders.DillCoder())),
   282          (1, cell_value))
   283  
   284    def test_fast_primitives_coder(self):
   285      coder = coders.FastPrimitivesCoder(coders.SingletonCoder(len))
   286      self.check_coder(coder, *self.test_values)
   287      for v in self.test_values:
   288        self.check_coder(coders.TupleCoder((coder, )), (v, ))
   289  
   290    def test_fast_primitives_coder_large_int(self):
   291      coder = coders.FastPrimitivesCoder()
   292      self.check_coder(coder, 10**100)
   293  
   294    def test_fake_deterministic_fast_primitives_coder(self):
   295      coder = coders.FakeDeterministicFastPrimitivesCoder(coders.PickleCoder())
   296      self.check_coder(coder, *self.test_values)
   297      for v in self.test_values:
   298        self.check_coder(coders.TupleCoder((coder, )), (v, ))
   299  
   300    def test_bytes_coder(self):
   301      self.check_coder(coders.BytesCoder(), b'a', b'\0', b'z' * 1000)
   302  
   303    def test_bool_coder(self):
   304      self.check_coder(coders.BooleanCoder(), True, False)
   305  
   306    def test_varint_coder(self):
   307      # Small ints.
   308      self.check_coder(coders.VarIntCoder(), *range(-10, 10))
   309      # Multi-byte encoding starts at 128
   310      self.check_coder(coders.VarIntCoder(), *range(120, 140))
   311      # Large values
   312      MAX_64_BIT_INT = 0x7fffffffffffffff
   313      self.check_coder(
   314          coders.VarIntCoder(),
   315          *[
   316              int(math.pow(-1, k) * math.exp(k))
   317              for k in range(0, int(math.log(MAX_64_BIT_INT)))
   318          ])
   319  
   320    def test_float_coder(self):
   321      self.check_coder(
   322          coders.FloatCoder(), *[float(0.1 * x) for x in range(-100, 100)])
   323      self.check_coder(
   324          coders.FloatCoder(), *[float(2**(0.1 * x)) for x in range(-100, 100)])
   325      self.check_coder(coders.FloatCoder(), float('-Inf'), float('Inf'))
   326      self.check_coder(
   327          coders.TupleCoder((coders.FloatCoder(), coders.FloatCoder())), (0, 1),
   328          (-100, 100), (0.5, 0.25))
   329  
   330    def test_singleton_coder(self):
   331      a = 'anything'
   332      b = 'something else'
   333      self.check_coder(coders.SingletonCoder(a), a)
   334      self.check_coder(coders.SingletonCoder(b), b)
   335      self.check_coder(
   336          coders.TupleCoder((coders.SingletonCoder(a), coders.SingletonCoder(b))),
   337          (a, b))
   338  
   339    def test_interval_window_coder(self):
   340      self.check_coder(
   341          coders.IntervalWindowCoder(),
   342          *[
   343              window.IntervalWindow(x, y) for x in [-2**52, 0, 2**52]
   344              for y in range(-100, 100)
   345          ])
   346      self.check_coder(
   347          coders.TupleCoder((coders.IntervalWindowCoder(), )),
   348          (window.IntervalWindow(0, 10), ))
   349  
   350    def test_timestamp_coder(self):
   351      self.check_coder(
   352          coders.TimestampCoder(),
   353          *[timestamp.Timestamp(micros=x) for x in (-1000, 0, 1000)])
   354      self.check_coder(
   355          coders.TimestampCoder(),
   356          timestamp.Timestamp(micros=-1234567000),
   357          timestamp.Timestamp(micros=1234567000))
   358      self.check_coder(
   359          coders.TimestampCoder(),
   360          timestamp.Timestamp(micros=-1234567890123456000),
   361          timestamp.Timestamp(micros=1234567890123456000))
   362      self.check_coder(
   363          coders.TupleCoder((coders.TimestampCoder(), coders.BytesCoder())),
   364          (timestamp.Timestamp.of(27), b'abc'))
   365  
   366    def test_timer_coder(self):
   367      self.check_coder(
   368          coders._TimerCoder(coders.StrUtf8Coder(), coders.GlobalWindowCoder()),
   369          *[
   370              userstate.Timer(
   371                  user_key="key",
   372                  dynamic_timer_tag="tag",
   373                  windows=(GlobalWindow(), ),
   374                  clear_bit=True,
   375                  fire_timestamp=None,
   376                  hold_timestamp=None,
   377                  paneinfo=None),
   378              userstate.Timer(
   379                  user_key="key",
   380                  dynamic_timer_tag="tag",
   381                  windows=(GlobalWindow(), ),
   382                  clear_bit=False,
   383                  fire_timestamp=timestamp.Timestamp.of(123),
   384                  hold_timestamp=timestamp.Timestamp.of(456),
   385                  paneinfo=windowed_value.PANE_INFO_UNKNOWN)
   386          ])
   387  
   388    def test_tuple_coder(self):
   389      kv_coder = coders.TupleCoder((coders.VarIntCoder(), coders.BytesCoder()))
   390      # Verify cloud object representation
   391      self.assertEqual({
   392          '@type': 'kind:pair',
   393          'is_pair_like': True,
   394          'component_encodings': [
   395              coders.VarIntCoder().as_cloud_object(),
   396              coders.BytesCoder().as_cloud_object()
   397          ],
   398      },
   399                       kv_coder.as_cloud_object())
   400      # Test binary representation
   401      self.assertEqual(b'\x04abc', kv_coder.encode((4, b'abc')))
   402      # Test unnested
   403      self.check_coder(kv_coder, (1, b'a'), (-2, b'a' * 100), (300, b'abc\0' * 5))
   404      # Test nested
   405      self.check_coder(
   406          coders.TupleCoder((
   407              coders.TupleCoder((coders.PickleCoder(), coders.VarIntCoder())),
   408              coders.StrUtf8Coder(),
   409              coders.BooleanCoder())), ((1, 2), 'a', True),
   410          ((-2, 5), u'a\u0101' * 100, False), ((300, 1), 'abc\0' * 5, True))
   411  
   412    def test_tuple_sequence_coder(self):
   413      int_tuple_coder = coders.TupleSequenceCoder(coders.VarIntCoder())
   414      self.check_coder(int_tuple_coder, (1, -1, 0), (), tuple(range(1000)))
   415      self.check_coder(
   416          coders.TupleCoder((coders.VarIntCoder(), int_tuple_coder)),
   417          (1, (1, 2, 3)))
   418  
   419    def test_base64_pickle_coder(self):
   420      self.check_coder(coders.Base64PickleCoder(), 'a', 1, 1.5, (1, 2, 3))
   421  
   422    def test_utf8_coder(self):
   423      self.check_coder(coders.StrUtf8Coder(), 'a', u'ab\u00FF', u'\u0101\0')
   424  
   425    def test_iterable_coder(self):
   426      iterable_coder = coders.IterableCoder(coders.VarIntCoder())
   427      # Verify cloud object representation
   428      self.assertEqual({
   429          '@type': 'kind:stream',
   430          'is_stream_like': True,
   431          'component_encodings': [coders.VarIntCoder().as_cloud_object()]
   432      },
   433                       iterable_coder.as_cloud_object())
   434      # Test unnested
   435      self.check_coder(iterable_coder, [1], [-1, 0, 100])
   436      # Test nested
   437      self.check_coder(
   438          coders.TupleCoder(
   439              (coders.VarIntCoder(), coders.IterableCoder(coders.VarIntCoder()))),
   440          (1, [1, 2, 3]))
   441  
   442    def test_iterable_coder_unknown_length(self):
   443      # Empty
   444      self._test_iterable_coder_of_unknown_length(0)
   445      # Single element
   446      self._test_iterable_coder_of_unknown_length(1)
   447      # Multiple elements
   448      self._test_iterable_coder_of_unknown_length(100)
   449      # Multiple elements with underlying stream buffer overflow.
   450      self._test_iterable_coder_of_unknown_length(80000)
   451  
   452    def _test_iterable_coder_of_unknown_length(self, count):
   453      def iter_generator(count):
   454        for i in range(count):
   455          yield i
   456  
   457      iterable_coder = coders.IterableCoder(coders.VarIntCoder())
   458      self.assertCountEqual(
   459          list(iter_generator(count)),
   460          iterable_coder.decode(iterable_coder.encode(iter_generator(count))))
   461  
   462    def test_list_coder(self):
   463      list_coder = coders.ListCoder(coders.VarIntCoder())
   464      # Test unnested
   465      self.check_coder(list_coder, [1], [-1, 0, 100])
   466      # Test nested
   467      self.check_coder(
   468          coders.TupleCoder((coders.VarIntCoder(), list_coder)), (1, [1, 2, 3]))
   469  
   470    def test_windowedvalue_coder_paneinfo(self):
   471      coder = coders.WindowedValueCoder(
   472          coders.VarIntCoder(), coders.GlobalWindowCoder())
   473      test_paneinfo_values = [
   474          windowed_value.PANE_INFO_UNKNOWN,
   475          windowed_value.PaneInfo(
   476              True, True, windowed_value.PaneInfoTiming.EARLY, 0, -1),
   477          windowed_value.PaneInfo(
   478              True, False, windowed_value.PaneInfoTiming.ON_TIME, 0, 0),
   479          windowed_value.PaneInfo(
   480              True, False, windowed_value.PaneInfoTiming.ON_TIME, 10, 0),
   481          windowed_value.PaneInfo(
   482              False, True, windowed_value.PaneInfoTiming.ON_TIME, 0, 23),
   483          windowed_value.PaneInfo(
   484              False, True, windowed_value.PaneInfoTiming.ON_TIME, 12, 23),
   485          windowed_value.PaneInfo(
   486              False, False, windowed_value.PaneInfoTiming.LATE, 0, 123),
   487      ]
   488  
   489      test_values = [
   490          windowed_value.WindowedValue(123, 234, (GlobalWindow(), ), p)
   491          for p in test_paneinfo_values
   492      ]
   493  
   494      # Test unnested.
   495      self.check_coder(
   496          coder,
   497          windowed_value.WindowedValue(
   498              123, 234, (GlobalWindow(), ), windowed_value.PANE_INFO_UNKNOWN))
   499      for value in test_values:
   500        self.check_coder(coder, value)
   501  
   502      # Test nested.
   503      for value1 in test_values:
   504        for value2 in test_values:
   505          self.check_coder(coders.TupleCoder((coder, coder)), (value1, value2))
   506  
  def test_windowed_value_coder(self):
    # Checks WindowedValueCoder's cloud object, exact encoded bytes and
    # round trips, both unnested and nested in a tuple.
    coder = coders.WindowedValueCoder(
        coders.VarIntCoder(), coders.GlobalWindowCoder())
    # Verify cloud object representation
    self.assertEqual({
        '@type': 'kind:windowed_value',
        'is_wrapper': True,
        'component_encodings': [
            coders.VarIntCoder().as_cloud_object(),
            coders.GlobalWindowCoder().as_cloud_object(),
        ],
    },
                     coder.as_cloud_object())
    # Test binary representation
    self.assertEqual(
        b'\x7f\xdf;dZ\x1c\xac\t\x00\x00\x00\x01\x0f\x01',
        coder.encode(window.GlobalWindows.windowed_value(1)))

    # Test decoding large timestamp: the bytes are expected to decode to
    # value 0 at MIN_TIMESTAMP in the global window.
    self.assertEqual(
        coder.decode(b'\x7f\xdf;dZ\x1c\xac\x08\x00\x00\x00\x01\x0f\x00'),
        windowed_value.create(0, MIN_TIMESTAMP.micros, (GlobalWindow(), )))

    # Test unnested (coder built without an explicit window coder)
    self.check_coder(
        coders.WindowedValueCoder(coders.VarIntCoder()),
        windowed_value.WindowedValue(3, -100, ()),
        windowed_value.WindowedValue(-1, 100, (1, 2, 3)))

    # Test Global Window
    self.check_coder(
        coders.WindowedValueCoder(
            coders.VarIntCoder(), coders.GlobalWindowCoder()),
        window.GlobalWindows.windowed_value(1))

    # Test nested
    self.check_coder(
        coders.TupleCoder((
            coders.WindowedValueCoder(coders.FloatCoder()),
            coders.WindowedValueCoder(coders.StrUtf8Coder()))),
        (
            windowed_value.WindowedValue(1.5, 0, ()),
            windowed_value.WindowedValue("abc", 10, ('window', ))))
   550  
  def test_param_windowed_value_coder(self):
    # Checks ParamWindowedValueCoder, which is constructed from an encoded
    # windowed value (the payload) plus a list of component coders.
    from apache_beam.transforms.window import IntervalWindow
    from apache_beam.utils.windowed_value import PaneInfo
    # Windowed value whose encoding serves as the payload below.
    wv = windowed_value.create(
        b'',
        # Milliseconds to microseconds
        1000 * 1000,
        (IntervalWindow(11, 21), ),
        PaneInfo(True, False, 1, 2, 3))
    windowed_value_coder = coders.WindowedValueCoder(
        coders.BytesCoder(), coders.IntervalWindowCoder())
    payload = windowed_value_coder.encode(wv)
    coder = coders.ParamWindowedValueCoder(
        payload, [coders.VarIntCoder(), coders.IntervalWindowCoder()])

    # Test binary representation: the value 1 is expected to encode to the
    # single byte b'\x01'.
    # NOTE(review): presumably timestamp, window and pane info are taken from
    # the payload instead of being encoded per element -- confirm.
    self.assertEqual(
        b'\x01', coder.encode(window.GlobalWindows.windowed_value(1)))

    # Test unnested; the windowed values use the same timestamp (1), window
    # and pane info that the payload was created with.
    self.check_coder(
        coders.ParamWindowedValueCoder(
            payload, [coders.VarIntCoder(), coders.IntervalWindowCoder()]),
        windowed_value.WindowedValue(
            3,
            1, (window.IntervalWindow(11, 21), ),
            PaneInfo(True, False, 1, 2, 3)),
        windowed_value.WindowedValue(
            1,
            1, (window.IntervalWindow(11, 21), ),
            PaneInfo(True, False, 1, 2, 3)))

    # Test nested
    self.check_coder(
        coders.TupleCoder((
            coders.ParamWindowedValueCoder(
                payload, [coders.FloatCoder(), coders.IntervalWindowCoder()]),
            coders.ParamWindowedValueCoder(
                payload,
                [coders.StrUtf8Coder(), coders.IntervalWindowCoder()]))),
        (
            windowed_value.WindowedValue(
                1.5,
                1, (window.IntervalWindow(11, 21), ),
                PaneInfo(True, False, 1, 2, 3)),
            windowed_value.WindowedValue(
                "abc",
                1, (window.IntervalWindow(11, 21), ),
                PaneInfo(True, False, 1, 2, 3))))
   600  
   601    def test_proto_coder(self):
   602      # For instructions on how these test proto message were generated,
   603      # see coders_test.py
   604      ma = test_message.MessageA()
   605      mab = ma.field2.add()
   606      mab.field1 = True
   607      ma.field1 = u'hello world'
   608  
   609      mb = test_message.MessageA()
   610      mb.field1 = u'beam'
   611  
   612      proto_coder = coders.ProtoCoder(ma.__class__)
   613      self.check_coder(proto_coder, ma)
   614      self.check_coder(
   615          coders.TupleCoder((proto_coder, coders.BytesCoder())), (ma, b'a'),
   616          (mb, b'b'))
   617  
   618    def test_global_window_coder(self):
   619      coder = coders.GlobalWindowCoder()
   620      value = window.GlobalWindow()
   621      # Verify cloud object representation
   622      self.assertEqual({'@type': 'kind:global_window'}, coder.as_cloud_object())
   623      # Test binary representation
   624      self.assertEqual(b'', coder.encode(value))
   625      self.assertEqual(value, coder.decode(b''))
   626      # Test unnested
   627      self.check_coder(coder, value)
   628      # Test nested
   629      self.check_coder(coders.TupleCoder((coder, coder)), (value, value))
   630  
   631    def test_length_prefix_coder(self):
   632      coder = coders.LengthPrefixCoder(coders.BytesCoder())
   633      # Verify cloud object representation
   634      self.assertEqual({
   635          '@type': 'kind:length_prefix',
   636          'component_encodings': [coders.BytesCoder().as_cloud_object()]
   637      },
   638                       coder.as_cloud_object())
   639      # Test binary representation
   640      self.assertEqual(b'\x00', coder.encode(b''))
   641      self.assertEqual(b'\x01a', coder.encode(b'a'))
   642      self.assertEqual(b'\x02bc', coder.encode(b'bc'))
   643      self.assertEqual(b'\xff\x7f' + b'z' * 16383, coder.encode(b'z' * 16383))
   644      # Test unnested
   645      self.check_coder(coder, b'', b'a', b'bc', b'def')
   646      # Test nested
   647      self.check_coder(
   648          coders.TupleCoder((coder, coder)), (b'', b'a'), (b'bc', b'def'))
   649  
   650    def test_nested_observables(self):
   651      class FakeObservableIterator(observable.ObservableMixin):
   652        def __iter__(self):
   653          return iter([1, 2, 3])
   654  
   655      # Coder for elements from the observable iterator.
   656      elem_coder = coders.VarIntCoder()
   657      iter_coder = coders.TupleSequenceCoder(elem_coder)
   658  
   659      # Test nested WindowedValue observable.
   660      coder = coders.WindowedValueCoder(iter_coder)
   661      observ = FakeObservableIterator()
   662      value = windowed_value.WindowedValue(observ, 0, ())
   663      self.assertEqual(
   664          coder.get_impl().get_estimated_size_and_observables(value)[1],
   665          [(observ, elem_coder.get_impl())])
   666  
   667      # Test nested tuple observable.
   668      coder = coders.TupleCoder((coders.StrUtf8Coder(), iter_coder))
   669      value = (u'123', observ)
   670      self.assertEqual(
   671          coder.get_impl().get_estimated_size_and_observables(value)[1],
   672          [(observ, elem_coder.get_impl())])
   673  
  def test_state_backed_iterable_coder(self):
    # Round-trips lists through StateBackedIterableCoder using an in-memory
    # dict as the state store.
    # pylint: disable=global-variable-undefined
    # required for pickling by reference
    global state
    state = {}

    def iterable_state_write(values, element_coder_impl):
      # Store the encoded elements under a fresh token and return the token.
      token = b'state_token_%d' % len(state)
      state[token] = [element_coder_impl.encode(e) for e in values]
      return token

    def iterable_state_read(token, element_coder_impl):
      # Decode the elements previously stored under the token.
      return [element_coder_impl.decode(s) for s in state[token]]

    # NOTE(review): the threshold of 1 is presumably chosen so that even tiny
    # iterables are written to state -- the assertNotEqual below relies on
    # state having been used; confirm against the coder's documentation.
    coder = coders.StateBackedIterableCoder(
        coders.VarIntCoder(),
        read_state=iterable_state_read,
        write_state=iterable_state_write,
        write_state_threshold=1)
    # Note: do not use check_coder
    # see https://github.com/cloudpipe/cloudpickle/issues/452
    self._observe(coder)
    self.assertEqual([1, 2, 3], coder.decode(coder.encode([1, 2, 3])))
    # Ensure that state was actually used.
    self.assertNotEqual(state, {})
    # Nested inside a tuple coder; observed manually for the same reason.
    tupleCoder = coders.TupleCoder((coder, coder))
    self._observe(tupleCoder)
    self.assertEqual(([1], [2, 3]),
                     tupleCoder.decode(tupleCoder.encode(([1], [2, 3]))))
   703  
  def test_nullable_coder(self):
    # Round-trips None and an integer through a nullable VarIntCoder.
    # NOTE(review): 2 * 64 evaluates to 128; if a power of two such as 2**64
    # was intended this is a typo -- confirm before changing, since the
    # wrapped coder is a VarIntCoder.
    self.check_coder(coders.NullableCoder(coders.VarIntCoder()), None, 2 * 64)
   706  
   707    def test_map_coder(self):
   708      values = [
   709          {1: "one", 300: "three hundred"}, # force yapf to be nice
   710          {},
   711          {i: str(i) for i in range(5000)}
   712      ]
   713      map_coder = coders.MapCoder(coders.VarIntCoder(), coders.StrUtf8Coder())
   714      self.check_coder(map_coder, *values)
   715      self.check_coder(map_coder.as_deterministic_coder("label"), *values)
   716  
  def test_sharded_key_coder(self):
    # Table-driven check of ShardedKeyCoder. Each entry is
    # (key, expected encoding of the key, coder for the key).
    key_and_coders = [(b'', b'\x00', coders.BytesCoder()),
                      (b'key', b'\x03key', coders.BytesCoder()),
                      ('key', b'\03\x6b\x65\x79', coders.StrUtf8Coder()),
                      (('k', 1),
                       b'\x01\x6b\x01',
                       coders.TupleCoder(
                           (coders.StrUtf8Coder(), coders.VarIntCoder())))]

    for key, bytes_repr, key_coder in key_and_coders:
      coder = coders.ShardedKeyCoder(key_coder)
      # Verify cloud object representation
      self.assertEqual({
          '@type': 'kind:sharded_key',
          'component_encodings': [key_coder.as_cloud_object()]
      },
                       coder.as_cloud_object())

      # Test str repr
      self.assertEqual('%s' % coder, 'ShardedKeyCoder[%s]' % key_coder)

      # Test binary representation: the length-prefixed shard id bytes are
      # expected to precede the encoded key.
      self.assertEqual(b'\x00' + bytes_repr, coder.encode(ShardedKey(key, b'')))
      self.assertEqual(
          b'\x03123' + bytes_repr, coder.encode(ShardedKey(key, b'123')))

      # Test unnested
      self.check_coder(coder, ShardedKey(key, b''))
      self.check_coder(coder, ShardedKey(key, b'123'))

      # Test type hints
      self.assertTrue(
          isinstance(
              coder.to_type_hint(), sharded_key_type.ShardedKeyTypeConstraint))
      key_type = coder.to_type_hint().key_type
      if isinstance(key_type, typehints.TupleConstraint):
        # Tuple keys: compare the component types one by one.
        self.assertEqual(key_type.tuple_types, (type(key[0]), type(key[1])))
      else:
        self.assertEqual(key_type, type(key))
      # The coder rebuilt from its own type hint must equal the original.
      self.assertEqual(
          coders.ShardedKeyCoder.from_type_hint(
              coder.to_type_hint(), typecoders.CoderRegistry()),
          coder)

      # Pair this coder with the coder of every table entry, itself included.
      for other_key, _, other_key_coder in key_and_coders:
        other_coder = coders.ShardedKeyCoder(other_key_coder)
        # Test nested
        self.check_coder(
            coders.TupleCoder((coder, other_coder)),
            (ShardedKey(key, b''), ShardedKey(other_key, b'')))
        self.check_coder(
            coders.TupleCoder((coder, other_coder)),
            (ShardedKey(key, b'123'), ShardedKey(other_key, b'')))
   769  
   770    def test_timestamp_prefixing_window_coder(self):
   771      self.check_coder(
   772          coders.TimestampPrefixingWindowCoder(coders.IntervalWindowCoder()),
   773          *[
   774              window.IntervalWindow(x, y) for x in [-2**52, 0, 2**52]
   775              for y in range(-100, 100)
   776          ])
   777      self.check_coder(
   778          coders.TupleCoder((
   779              coders.TimestampPrefixingWindowCoder(
   780                  coders.IntervalWindowCoder()), )),
   781          (window.IntervalWindow(0, 10), ))
   782  
   783    def test_decimal_coder(self):
   784      test_coder = coders.DecimalCoder()
   785  
   786      test_values = [
   787          Decimal("-10.5"),
   788          Decimal("-1"),
   789          Decimal(),
   790          Decimal("1"),
   791          Decimal("13.258"),
   792      ]
   793  
   794      test_encodings = ("AZc", "AP8", "AAA", "AAE", "AzPK")
   795  
   796      self.check_coder(test_coder, *test_values)
   797  
   798      for idx, value in enumerate(test_values):
   799        self.assertEqual(
   800            test_encodings[idx],
   801            base64.b64encode(test_coder.encode(value)).decode().rstrip("="))
   802  
   803  
   804  if __name__ == '__main__':
   805    logging.getLogger().setLevel(logging.INFO)
   806    unittest.main()