#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Origin:
# github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/coders/stream_test.py

"""Tests for the stream implementations."""
# pytype: skip-file

import logging
import math
import unittest

import numpy as np

from apache_beam.coders import slow_stream


class StreamTest(unittest.TestCase):
  """Round-trip tests run against the uncompiled ``slow_stream`` classes.

  Every test builds its streams through the three class attributes below, so
  a subclass can re-run the whole suite against another implementation (or a
  mix of implementations) simply by overriding them.
  """
  # pylint: disable=invalid-name
  InputStream = slow_stream.InputStream
  OutputStream = slow_stream.OutputStream
  ByteCountingOutputStream = slow_stream.ByteCountingOutputStream

  # pylint: enable=invalid-name

  def _assert_round_trip(self, values, write_method, read_method):
    """Write every value, then read them all back and compare.

    Args:
      values: iterable of values to encode.
      write_method: name of the ``OutputStream`` method that writes one value.
      read_method: name of the ``InputStream`` method that reads one value.
    """
    # Materialize once: the values are iterated twice (write, then read).
    values = list(values)
    out_s = self.OutputStream()
    write = getattr(out_s, write_method)
    for v in values:
      write(v)
    in_s = self.InputStream(out_s.get())
    read = getattr(in_s, read_method)
    for v in values:
      self.assertEqual(v, read())

  def test_read_write(self):
    """Plain and nested byte-string writes are read back unchanged."""
    out_s = self.OutputStream()
    out_s.write(b'abc')
    out_s.write(b'\0\t\n')
    out_s.write(b'xyz', True)
    out_s.write(b'', True)
    in_s = self.InputStream(out_s.get())
    # The two un-nested writes are simply concatenated.
    self.assertEqual(b'abc\0\t\n', in_s.read(6))
    self.assertEqual(b'xyz', in_s.read_all(True))
    self.assertEqual(b'', in_s.read_all(True))

  def test_read_all(self):
    """An un-nested read_all returns everything that was written."""
    out_s = self.OutputStream()
    out_s.write(b'abc')
    in_s = self.InputStream(out_s.get())
    self.assertEqual(b'abc', in_s.read_all(False))

  def test_read_write_byte(self):
    """Single bytes, including 0 and 0xFF, round-trip."""
    self._assert_round_trip((1, 0, 0xFF), 'write_byte', 'read_byte')

  def test_read_write_large(self):
    """4096 consecutive int64 values round-trip through one stream."""
    self._assert_round_trip(
        range(4 * 1024), 'write_bigendian_int64', 'read_bigendian_int64')

  def run_read_write_var_int64(self, values):
    """Round-trip ``values`` through the variable-length int64 encoding."""
    self._assert_round_trip(values, 'write_var_int64', 'read_var_int64')

  def test_small_var_int64(self):
    """Small negative and positive integers round-trip as var-int64."""
    self.run_read_write_var_int64(range(-10, 30))

  def test_medium_var_int64(self):
    """Powers of -1.7 (alternating sign, magnitude below 2**63) round-trip."""
    base = -1.7
    # Largest exponent bound that keeps |base**exponent| under 2**63.
    max_exponent = int(63 * math.log(2) / math.log(-base))
    self.run_read_write_var_int64(
        [int(base**exponent) for exponent in range(1, max_exponent)])

  def test_large_var_int64(self):
    """Zero and values at/near the int64 limits round-trip as var-int64."""
    self.run_read_write_var_int64([0, 2**63 - 1, -2**63, 2**63 - 3])

  def test_read_write_double(self):
    """Doubles, including infinity, round-trip exactly."""
    values = 0, 1, -1, 1e100, 1.0 / 3, math.pi, float('inf')
    self._assert_round_trip(
        values, 'write_bigendian_double', 'read_bigendian_double')

  def test_read_write_float(self):
    """Single-precision floats round-trip exactly."""
    values = 0, 1, -1, 1e20, 1.0 / 3, math.pi, float('inf')
    # Restrict to single precision before coder roundtrip
    values = tuple(float(np.float32(v)) for v in values)
    self._assert_round_trip(
        values, 'write_bigendian_float', 'read_bigendian_float')

  def test_read_write_bigendian_int64(self):
    """Signed 64-bit integers, including both extremes, round-trip."""
    values = 0, 1, -1, 2**63 - 1, -2**63, int(2**61 * math.pi)
    self._assert_round_trip(
        values, 'write_bigendian_int64', 'read_bigendian_int64')

  def test_read_write_bigendian_uint64(self):
    """Unsigned 64-bit integers, including the maximum, round-trip."""
    values = 0, 1, 2**64 - 1, int(2**61 * math.pi)
    self._assert_round_trip(
        values, 'write_bigendian_uint64', 'read_bigendian_uint64')

  def test_read_write_bigendian_int32(self):
    """Signed 32-bit integers, including both extremes, round-trip."""
    values = 0, 1, -1, 2**31 - 1, -2**31, int(2**29 * math.pi)
    self._assert_round_trip(
        values, 'write_bigendian_int32', 'read_bigendian_int32')

  def test_read_write_bigendian_int16(self):
    """Signed 16-bit integers, including both extremes, round-trip."""
    values = 0, 1, -1, 2**15 - 1, -2**15, int(2**13 * math.pi)
    self._assert_round_trip(
        values, 'write_bigendian_int16', 'read_bigendian_int16')

  def test_byte_counting(self):
    """The byte-counting stream reports the running size after each write."""
    bc_s = self.ByteCountingOutputStream()
    self.assertEqual(0, bc_s.get_count())
    bc_s.write(b'def')
    self.assertEqual(3, bc_s.get_count())
    bc_s.write(b'')
    self.assertEqual(3, bc_s.get_count())
    bc_s.write_byte(10)
    self.assertEqual(4, bc_s.get_count())
    # "nested" also writes the length of the string, which should
    # cause 1 extra byte to be counted.
    bc_s.write(b'2345', nested=True)
    self.assertEqual(9, bc_s.get_count())
    bc_s.write_var_int64(63)
    self.assertEqual(10, bc_s.get_count())
    bc_s.write_bigendian_int64(42)
    self.assertEqual(18, bc_s.get_count())
    bc_s.write_bigendian_int32(36)
    self.assertEqual(22, bc_s.get_count())
    bc_s.write_bigendian_double(6.25)
    self.assertEqual(30, bc_s.get_count())
    bc_s.write_bigendian_uint64(47)
    self.assertEqual(38, bc_s.get_count())


# The compiled ``stream`` module is optional. When it cannot be imported the
# extra test classes are simply not defined and only StreamTest runs.
try:
  # pylint: disable=wrong-import-position
  from apache_beam.coders import stream

  class FastStreamTest(StreamTest):
    """Runs the test with the compiled stream classes."""
    InputStream = stream.InputStream
    OutputStream = stream.OutputStream
    ByteCountingOutputStream = stream.ByteCountingOutputStream

  class SlowFastStreamTest(StreamTest):
    """Runs the test reading uncompiled output with the compiled input."""
    InputStream = stream.InputStream
    OutputStream = slow_stream.OutputStream
    ByteCountingOutputStream = slow_stream.ByteCountingOutputStream

  class FastSlowStreamTest(StreamTest):
    """Runs the test reading compiled output with the uncompiled input."""
    InputStream = slow_stream.InputStream
    OutputStream = stream.OutputStream
    ByteCountingOutputStream = stream.ByteCountingOutputStream

except ImportError:
  pass

if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()