github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/coders/stream_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Tests for the stream implementations."""
    19  # pytype: skip-file
    20  
    21  import logging
    22  import math
    23  import unittest
    24  
    25  import numpy as np
    26  
    27  from apache_beam.coders import slow_stream
    28  
    29  
    30  class StreamTest(unittest.TestCase):
    31    # pylint: disable=invalid-name
    32    InputStream = slow_stream.InputStream
    33    OutputStream = slow_stream.OutputStream
    34    ByteCountingOutputStream = slow_stream.ByteCountingOutputStream
    35  
    36    # pylint: enable=invalid-name
    37  
    38    def test_read_write(self):
    39      out_s = self.OutputStream()
    40      out_s.write(b'abc')
    41      out_s.write(b'\0\t\n')
    42      out_s.write(b'xyz', True)
    43      out_s.write(b'', True)
    44      in_s = self.InputStream(out_s.get())
    45      self.assertEqual(b'abc\0\t\n', in_s.read(6))
    46      self.assertEqual(b'xyz', in_s.read_all(True))
    47      self.assertEqual(b'', in_s.read_all(True))
    48  
    49    def test_read_all(self):
    50      out_s = self.OutputStream()
    51      out_s.write(b'abc')
    52      in_s = self.InputStream(out_s.get())
    53      self.assertEqual(b'abc', in_s.read_all(False))
    54  
    55    def test_read_write_byte(self):
    56      out_s = self.OutputStream()
    57      out_s.write_byte(1)
    58      out_s.write_byte(0)
    59      out_s.write_byte(0xFF)
    60      in_s = self.InputStream(out_s.get())
    61      self.assertEqual(1, in_s.read_byte())
    62      self.assertEqual(0, in_s.read_byte())
    63      self.assertEqual(0xFF, in_s.read_byte())
    64  
    65    def test_read_write_large(self):
    66      values = range(4 * 1024)
    67      out_s = self.OutputStream()
    68      for v in values:
    69        out_s.write_bigendian_int64(v)
    70      in_s = self.InputStream(out_s.get())
    71      for v in values:
    72        self.assertEqual(v, in_s.read_bigendian_int64())
    73  
    74    def run_read_write_var_int64(self, values):
    75      out_s = self.OutputStream()
    76      for v in values:
    77        out_s.write_var_int64(v)
    78      in_s = self.InputStream(out_s.get())
    79      for v in values:
    80        self.assertEqual(v, in_s.read_var_int64())
    81  
    82    def test_small_var_int64(self):
    83      self.run_read_write_var_int64(range(-10, 30))
    84  
    85    def test_medium_var_int64(self):
    86      base = -1.7
    87      self.run_read_write_var_int64([
    88          int(base**pow)
    89          for pow in range(1, int(63 * math.log(2) / math.log(-base)))
    90      ])
    91  
    92    def test_large_var_int64(self):
    93      self.run_read_write_var_int64([0, 2**63 - 1, -2**63, 2**63 - 3])
    94  
    95    def test_read_write_double(self):
    96      values = 0, 1, -1, 1e100, 1.0 / 3, math.pi, float('inf')
    97      out_s = self.OutputStream()
    98      for v in values:
    99        out_s.write_bigendian_double(v)
   100      in_s = self.InputStream(out_s.get())
   101      for v in values:
   102        self.assertEqual(v, in_s.read_bigendian_double())
   103  
   104    def test_read_write_float(self):
   105      values = 0, 1, -1, 1e20, 1.0 / 3, math.pi, float('inf')
   106      # Restrict to single precision before coder roundtrip
   107      values = tuple(float(np.float32(v)) for v in values)
   108      out_s = self.OutputStream()
   109      for v in values:
   110        out_s.write_bigendian_float(v)
   111      in_s = self.InputStream(out_s.get())
   112      for v in values:
   113        self.assertEqual(v, in_s.read_bigendian_float())
   114  
   115    def test_read_write_bigendian_int64(self):
   116      values = 0, 1, -1, 2**63 - 1, -2**63, int(2**61 * math.pi)
   117      out_s = self.OutputStream()
   118      for v in values:
   119        out_s.write_bigendian_int64(v)
   120      in_s = self.InputStream(out_s.get())
   121      for v in values:
   122        self.assertEqual(v, in_s.read_bigendian_int64())
   123  
   124    def test_read_write_bigendian_uint64(self):
   125      values = 0, 1, 2**64 - 1, int(2**61 * math.pi)
   126      out_s = self.OutputStream()
   127      for v in values:
   128        out_s.write_bigendian_uint64(v)
   129      in_s = self.InputStream(out_s.get())
   130      for v in values:
   131        self.assertEqual(v, in_s.read_bigendian_uint64())
   132  
   133    def test_read_write_bigendian_int32(self):
   134      values = 0, 1, -1, 2**31 - 1, -2**31, int(2**29 * math.pi)
   135      out_s = self.OutputStream()
   136      for v in values:
   137        out_s.write_bigendian_int32(v)
   138      in_s = self.InputStream(out_s.get())
   139      for v in values:
   140        self.assertEqual(v, in_s.read_bigendian_int32())
   141  
   142    def test_read_write_bigendian_int16(self):
   143      values = 0, 1, -1, 2**15 - 1, -2**15, int(2**13 * math.pi)
   144      out_s = self.OutputStream()
   145      for v in values:
   146        out_s.write_bigendian_int16(v)
   147      in_s = self.InputStream(out_s.get())
   148      for v in values:
   149        self.assertEqual(v, in_s.read_bigendian_int16())
   150  
   151    def test_byte_counting(self):
   152      bc_s = self.ByteCountingOutputStream()
   153      self.assertEqual(0, bc_s.get_count())
   154      bc_s.write(b'def')
   155      self.assertEqual(3, bc_s.get_count())
   156      bc_s.write(b'')
   157      self.assertEqual(3, bc_s.get_count())
   158      bc_s.write_byte(10)
   159      self.assertEqual(4, bc_s.get_count())
   160      # "nested" also writes the length of the string, which should
   161      # cause 1 extra byte to be counted.
   162      bc_s.write(b'2345', nested=True)
   163      self.assertEqual(9, bc_s.get_count())
   164      bc_s.write_var_int64(63)
   165      self.assertEqual(10, bc_s.get_count())
   166      bc_s.write_bigendian_int64(42)
   167      self.assertEqual(18, bc_s.get_count())
   168      bc_s.write_bigendian_int32(36)
   169      self.assertEqual(22, bc_s.get_count())
   170      bc_s.write_bigendian_double(6.25)
   171      self.assertEqual(30, bc_s.get_count())
   172      bc_s.write_bigendian_uint64(47)
   173      self.assertEqual(38, bc_s.get_count())
   174  
   175  
   176  try:
   177    # pylint: disable=wrong-import-position
   178    from apache_beam.coders import stream
   179  
   180    class FastStreamTest(StreamTest):
   181      """Runs the test with the compiled stream classes."""
   182      InputStream = stream.InputStream
   183      OutputStream = stream.OutputStream
   184      ByteCountingOutputStream = stream.ByteCountingOutputStream
   185  
   186    class SlowFastStreamTest(StreamTest):
   187      """Runs the test with compiled and uncompiled stream classes."""
   188      InputStream = stream.InputStream
   189      OutputStream = slow_stream.OutputStream
   190      ByteCountingOutputStream = slow_stream.ByteCountingOutputStream
   191  
   192    class FastSlowStreamTest(StreamTest):
   193      """Runs the test with uncompiled and compiled stream classes."""
   194      InputStream = slow_stream.InputStream
   195      OutputStream = stream.OutputStream
   196      ByteCountingOutputStream = stream.ByteCountingOutputStream
   197  
   198  except ImportError:
   199    pass
   200  
   201  if __name__ == '__main__':
   202    logging.getLogger().setLevel(logging.INFO)
   203    unittest.main()