#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for apache_beam.io.tfrecordio (TFRecord reading and writing)."""

# pytype: skip-file

import binascii
import glob
import gzip
import io
import logging
import os
import pickle
import random
import re
import unittest
import zlib

import crcmod

import apache_beam as beam
from apache_beam import Create
from apache_beam import coders
from apache_beam.io.filesystem import CompressionTypes
from apache_beam.io.tfrecordio import ReadAllFromTFRecord
from apache_beam.io.tfrecordio import ReadFromTFRecord
from apache_beam.io.tfrecordio import WriteToTFRecord
from apache_beam.io.tfrecordio import _TFRecordSink
from apache_beam.io.tfrecordio import _TFRecordUtil
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.test_utils import TempDir
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to

# TensorFlow is optional: prefer the v1 compatibility API (which still exposes
# tf.python_io), fall back to the top-level package, and otherwise leave
# `tf` as None so the tests that need it are skipped.
try:
  import tensorflow.compat.v1 as tf  # pylint: disable=import-error
except ImportError:
  try:
    import tensorflow as tf  # pylint: disable=import-error
  except ImportError:
    tf = None  # pylint: disable=invalid-name
    logging.warning('Tensorflow is not installed, so skipping some tests.')

# Created by running the following code in Python:
# >>> import tensorflow as tf
# >>> import base64
# >>> writer = tf.python_io.TFRecordWriter('/tmp/python_foo.tfrecord')
# >>> writer.write(b'foo')
# >>> writer.close()
# >>> with open('/tmp/python_foo.tfrecord', 'rb') as f:
# ...   data = base64.b64encode(f.read())
# ...   print(data)
FOO_RECORD_BASE64 = b'AwAAAAAAAACwmUkOZm9vYYq+/g=='

# Same as above but containing two records [b'foo', b'bar'].
FOO_BAR_RECORD_BASE64 = b'AwAAAAAAAACwmUkOZm9vYYq+/gMAAAAAAAAAsJlJDmJhckYA5cg='


def _write_file(path, base64_records):
  """Decode ``base64_records`` and write the raw bytes to ``path``."""
  record = binascii.a2b_base64(base64_records)
  with open(path, 'wb') as f:
    f.write(record)


def _write_file_deflate(path, base64_records):
  """Like ``_write_file``, but zlib-compresses the decoded bytes."""
  record = binascii.a2b_base64(base64_records)
  with open(path, 'wb') as f:
    f.write(zlib.compress(record))


def _write_file_gzip(path, base64_records):
  """Like ``_write_file``, but writes the decoded bytes as a gzip file."""
  record = binascii.a2b_base64(base64_records)
  with gzip.GzipFile(path, 'wb') as f:
    f.write(record)


class TestTFRecordUtil(unittest.TestCase):
  """Tests for the record framing and CRC helpers in ``_TFRecordUtil``."""
  def setUp(self):
    # The 19-byte encoded form of the single record b'foo'.
    self.record = binascii.a2b_base64(FOO_RECORD_BASE64)

  def _as_file_handle(self, contents):
    """Return an in-memory binary file positioned at the start of contents."""
    result = io.BytesIO()
    result.write(contents)
    result.seek(0)
    return result

  def _increment_value_at_index(self, value, index):
    """Return a copy of ``value`` with the byte at ``index`` incremented.

    Used to corrupt exactly one byte of an otherwise valid record.
    """
    data = list(value)
    data[index] = data[index] + 1
    return bytes(data)

  def _test_error(self, record, error_text):
    """Assert that reading ``record`` raises a ValueError with error_text."""
    with self.assertRaisesRegex(ValueError, re.escape(error_text)):
      _TFRecordUtil.read_record(self._as_file_handle(record))

  def test_masked_crc32c(self):
    self.assertEqual(0xfd7fffa, _TFRecordUtil._masked_crc32c(b'\x00' * 32))
    self.assertEqual(0xf909b029, _TFRecordUtil._masked_crc32c(b'\xff' * 32))
    self.assertEqual(0xfebe8a61, _TFRecordUtil._masked_crc32c(b'foo'))
    self.assertEqual(
        0xe4999b0,
        _TFRecordUtil._masked_crc32c(b'\x03\x00\x00\x00\x00\x00\x00\x00'))

  def test_masked_crc32c_crcmod(self):
    # Same expectations as above, but with an explicitly supplied crcmod
    # CRC-32C implementation.
    crc32c_fn = crcmod.predefined.mkPredefinedCrcFun('crc-32c')
    self.assertEqual(
        0xfd7fffa,
        _TFRecordUtil._masked_crc32c(b'\x00' * 32, crc32c_fn=crc32c_fn))
    self.assertEqual(
        0xf909b029,
        _TFRecordUtil._masked_crc32c(b'\xff' * 32, crc32c_fn=crc32c_fn))
    self.assertEqual(
        0xfebe8a61, _TFRecordUtil._masked_crc32c(b'foo', crc32c_fn=crc32c_fn))
    self.assertEqual(
        0xe4999b0,
        _TFRecordUtil._masked_crc32c(
            b'\x03\x00\x00\x00\x00\x00\x00\x00', crc32c_fn=crc32c_fn))

  def test_write_record(self):
    file_handle = io.BytesIO()
    _TFRecordUtil.write_record(file_handle, b'foo')
    self.assertEqual(self.record, file_handle.getvalue())

  def test_read_record(self):
    actual = _TFRecordUtil.read_record(self._as_file_handle(self.record))
    self.assertEqual(b'foo', actual)

  def test_read_record_invalid_record(self):
    # Three bytes is too short to hold even the record header.
    self._test_error(b'bar', 'Not a valid TFRecord. Fewer than 12 bytes')

  def test_read_record_invalid_length_mask(self):
    # Byte 9 falls in the 4-byte CRC that follows the 8-byte length field.
    record = self._increment_value_at_index(self.record, 9)
    self._test_error(record, 'Mismatch of length mask')

  def test_read_record_invalid_data_mask(self):
    # Byte 16 falls in the trailing 4-byte CRC of the 19-byte record.
    record = self._increment_value_at_index(self.record, 16)
    self._test_error(record, 'Mismatch of data mask')

  def test_compatibility_read_write(self):
    for record in [b'', b'blah', b'another blah']:
      file_handle = io.BytesIO()
      _TFRecordUtil.write_record(file_handle, record)
      file_handle.seek(0)
      actual = _TFRecordUtil.read_record(file_handle)
      self.assertEqual(record, actual)


class TestTFRecordSink(unittest.TestCase):
  """Tests that ``_TFRecordSink`` writes byte-exact TFRecord files."""
  def _write_lines(self, sink, path, lines):
    """Open ``path`` through ``sink``, write every record, then close it."""
    f = sink.open(path)
    for line in lines:
      sink.write_record(f, line)
    sink.close(f)

  def test_write_record_single(self):
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      record = binascii.a2b_base64(FOO_RECORD_BASE64)
      sink = _TFRecordSink(
          path,
          coder=coders.BytesCoder(),
          file_name_suffix='',
          num_shards=0,
          shard_name_template=None,
          compression_type=CompressionTypes.UNCOMPRESSED)
      self._write_lines(sink, path, [b'foo'])

      with open(path, 'rb') as f:
        self.assertEqual(f.read(), record)

  def test_write_record_multiple(self):
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      record = binascii.a2b_base64(FOO_BAR_RECORD_BASE64)
      sink = _TFRecordSink(
          path,
          coder=coders.BytesCoder(),
          file_name_suffix='',
          num_shards=0,
          shard_name_template=None,
          compression_type=CompressionTypes.UNCOMPRESSED)
      self._write_lines(sink, path, [b'foo', b'bar'])

      with open(path, 'rb') as f:
        self.assertEqual(f.read(), record)


@unittest.skipIf(tf is None, 'tensorflow not installed.')
class TestWriteToTFRecord(unittest.TestCase):
  """Checks that files written by WriteToTFRecord are readable by TensorFlow.

  This derives from unittest.TestCase directly. Inheriting from
  TestTFRecordSink, as it previously did, re-ran every sink test under this
  class as well.
  """
  def _read_gzip_records(self, pattern):
    """Return all records from every gzip TFRecord file matching ``pattern``.

    Every output shard is read, so the result does not depend on how many
    shards the runner chose to write. Fails if no file matches.
    """
    file_names = glob.glob(pattern)
    self.assertTrue(file_names, 'No output files match %s' % pattern)
    options = tf.python_io.TFRecordOptions(
        tf.python_io.TFRecordCompressionType.GZIP)
    records = []
    for file_name in file_names:
      records.extend(
          tf.python_io.tf_record_iterator(file_name, options=options))
    return records

  def test_write_record_gzip(self):
    with TempDir() as temp_dir:
      file_path_prefix = temp_dir.create_temp_file('result')
      input_data = [b'foo', b'bar']
      with TestPipeline() as p:
        _ = p | beam.Create(input_data) | WriteToTFRecord(
            file_path_prefix, compression_type=CompressionTypes.GZIP)

      actual = self._read_gzip_records(file_path_prefix + '-*')
      self.assertEqual(sorted(actual), sorted(input_data))

  def test_write_record_auto(self):
    with TempDir() as temp_dir:
      file_path_prefix = temp_dir.create_temp_file('result')
      input_data = [b'foo', b'bar']
      with TestPipeline() as p:
        # No compression_type: the '.gz' suffix should select gzip.
        _ = p | beam.Create(input_data) | WriteToTFRecord(
            file_path_prefix, file_name_suffix='.gz')

      actual = self._read_gzip_records(file_path_prefix + '-*.gz')
      self.assertEqual(sorted(actual), sorted(input_data))


class TestReadFromTFRecord(unittest.TestCase):
  """Tests for the ReadFromTFRecord source transform."""
  def test_process_single(self):
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      _write_file(path, FOO_RECORD_BASE64)
      with TestPipeline() as p:
        result = (
            p
            | ReadFromTFRecord(
                path,
                coder=coders.BytesCoder(),
                compression_type=CompressionTypes.AUTO,
                validate=True))
        assert_that(result, equal_to([b'foo']))

  def test_process_multiple(self):
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      _write_file(path, FOO_BAR_RECORD_BASE64)
      with TestPipeline() as p:
        result = (
            p
            | ReadFromTFRecord(
                path,
                coder=coders.BytesCoder(),
                compression_type=CompressionTypes.AUTO,
                validate=True))
        assert_that(result, equal_to([b'foo', b'bar']))

  def test_process_deflate(self):
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      _write_file_deflate(path, FOO_BAR_RECORD_BASE64)
      with TestPipeline() as p:
        result = (
            p
            | ReadFromTFRecord(
                path,
                coder=coders.BytesCoder(),
                compression_type=CompressionTypes.DEFLATE,
                validate=True))
        assert_that(result, equal_to([b'foo', b'bar']))

  def test_process_gzip_with_coder(self):
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      _write_file_gzip(path, FOO_BAR_RECORD_BASE64)
      with TestPipeline() as p:
        result = (
            p
            | ReadFromTFRecord(
                path,
                coder=coders.BytesCoder(),
                compression_type=CompressionTypes.GZIP,
                validate=True))
        assert_that(result, equal_to([b'foo', b'bar']))

  def test_process_gzip_without_coder(self):
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      _write_file_gzip(path, FOO_BAR_RECORD_BASE64)
      with TestPipeline() as p:
        result = (
            p
            | ReadFromTFRecord(path, compression_type=CompressionTypes.GZIP))
        assert_that(result, equal_to([b'foo', b'bar']))

  def test_process_auto(self):
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result.gz')
      _write_file_gzip(path, FOO_BAR_RECORD_BASE64)
      with TestPipeline() as p:
        result = (
            p
            | ReadFromTFRecord(
                path,
                coder=coders.BytesCoder(),
                compression_type=CompressionTypes.AUTO,
                validate=True))
        assert_that(result, equal_to([b'foo', b'bar']))

  def test_process_gzip_auto(self):
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result.gz')
      _write_file_gzip(path, FOO_BAR_RECORD_BASE64)
      with TestPipeline() as p:
        result = (
            p
            | ReadFromTFRecord(path, compression_type=CompressionTypes.AUTO))
        assert_that(result, equal_to([b'foo', b'bar']))


class TestReadAllFromTFRecord(unittest.TestCase):
  """Tests for ReadAllFromTFRecord, which reads file patterns from a PCollection."""
  def _write_glob(self, temp_dir, suffix, include_empty=False):
    """Write three two-record files ending in ``suffix`` into ``temp_dir``.

    If ``include_empty`` is true, also write one empty file with that suffix.
    """
    for _ in range(3):
      path = temp_dir.create_temp_file(suffix)
      _write_file(path, FOO_BAR_RECORD_BASE64)
    if include_empty:
      path = temp_dir.create_temp_file(suffix)
      _write_file(path, b'')

  def test_process_single(self):
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      _write_file(path, FOO_RECORD_BASE64)
      with TestPipeline() as p:
        result = (
            p
            | Create([path])
            | ReadAllFromTFRecord(
                coder=coders.BytesCoder(),
                compression_type=CompressionTypes.AUTO))
        assert_that(result, equal_to([b'foo']))

  def test_process_multiple(self):
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      _write_file(path, FOO_BAR_RECORD_BASE64)
      with TestPipeline() as p:
        result = (
            p
            | Create([path])
            | ReadAllFromTFRecord(
                coder=coders.BytesCoder(),
                compression_type=CompressionTypes.AUTO))
        assert_that(result, equal_to([b'foo', b'bar']))

  def test_process_with_filename(self):
    with TempDir() as temp_dir:
      num_files = 3
      files = []
      for i in range(num_files):
        path = temp_dir.create_temp_file('result%s' % i)
        _write_file(path, FOO_BAR_RECORD_BASE64)
        files.append(path)
      content = [b'foo', b'bar']
      # with_filename=True yields (file_path, record) pairs.
      expected = [(file_path, record) for file_path in files
                  for record in content]
      pattern = temp_dir.get_path() + os.path.sep + '*'

      with TestPipeline() as p:
        result = (
            p
            | Create([pattern])
            | ReadAllFromTFRecord(
                coder=coders.BytesCoder(),
                compression_type=CompressionTypes.AUTO,
                with_filename=True))
        assert_that(result, equal_to(expected))

  def test_process_glob(self):
    with TempDir() as temp_dir:
      self._write_glob(temp_dir, 'result')
      # Named `pattern` (not `glob`) so the glob module is not shadowed.
      pattern = temp_dir.get_path() + os.path.sep + '*result'
      with TestPipeline() as p:
        result = (
            p
            | Create([pattern])
            | ReadAllFromTFRecord(
                coder=coders.BytesCoder(),
                compression_type=CompressionTypes.AUTO))
        assert_that(result, equal_to([b'foo', b'bar'] * 3))

  def test_process_glob_with_empty_file(self):
    with TempDir() as temp_dir:
      self._write_glob(temp_dir, 'result', include_empty=True)
      pattern = temp_dir.get_path() + os.path.sep + '*result'
      with TestPipeline() as p:
        result = (
            p
            | Create([pattern])
            | ReadAllFromTFRecord(
                coder=coders.BytesCoder(),
                compression_type=CompressionTypes.AUTO))
        # The empty file contributes no records.
        assert_that(result, equal_to([b'foo', b'bar'] * 3))

  def test_process_multiple_globs(self):
    with TempDir() as temp_dir:
      globs = []
      for i in range(3):
        suffix = 'result' + str(i)
        self._write_glob(temp_dir, suffix)
        globs.append(temp_dir.get_path() + os.path.sep + '*' + suffix)

      with TestPipeline() as p:
        result = (
            p
            | Create(globs)
            | ReadAllFromTFRecord(
                coder=coders.BytesCoder(),
                compression_type=CompressionTypes.AUTO))
        # 3 patterns x 3 files per pattern x 2 records per file.
        assert_that(result, equal_to([b'foo', b'bar'] * 9))

  def test_process_deflate(self):
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      _write_file_deflate(path, FOO_BAR_RECORD_BASE64)
      with TestPipeline() as p:
        result = (
            p
            | Create([path])
            | ReadAllFromTFRecord(
                coder=coders.BytesCoder(),
                compression_type=CompressionTypes.DEFLATE))
        assert_that(result, equal_to([b'foo', b'bar']))

  def test_process_gzip(self):
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      _write_file_gzip(path, FOO_BAR_RECORD_BASE64)
      with TestPipeline() as p:
        result = (
            p
            | Create([path])
            | ReadAllFromTFRecord(
                coder=coders.BytesCoder(),
                compression_type=CompressionTypes.GZIP))
        assert_that(result, equal_to([b'foo', b'bar']))

  def test_process_auto(self):
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result.gz')
      _write_file_gzip(path, FOO_BAR_RECORD_BASE64)
      with TestPipeline() as p:
        result = (
            p
            | Create([path])
            | ReadAllFromTFRecord(
                coder=coders.BytesCoder(),
                compression_type=CompressionTypes.AUTO))
        assert_that(result, equal_to([b'foo', b'bar']))


class TestEnd2EndWriteAndRead(unittest.TestCase):
  """Round-trip tests: write with WriteToTFRecord, read with ReadFromTFRecord."""
  def create_inputs(self):
    """Return a pickled 12x15 list of random floats in [-0.5, 0.5) as bytes."""
    input_array = [[random.random() - 0.5 for _ in range(15)]
                   for _ in range(12)]
    return pickle.dumps(input_array)

  def test_end2end(self):
    with TempDir() as temp_dir:
      file_path_prefix = temp_dir.create_temp_file('result')
      expected_data = [self.create_inputs() for _ in range(10)]

      # Generate a TFRecord file.
      with TestPipeline() as p:
        _ = p | beam.Create(expected_data) | WriteToTFRecord(file_path_prefix)

      # Read the file back and compare.
      with TestPipeline() as p:
        actual_data = p | ReadFromTFRecord(file_path_prefix + '-*')
        assert_that(actual_data, equal_to(expected_data))

  def test_end2end_auto_compression(self):
    with TempDir() as temp_dir:
      file_path_prefix = temp_dir.create_temp_file('result')
      expected_data = [self.create_inputs() for _ in range(10)]

      # Generate a TFRecord file.
      with TestPipeline() as p:
        _ = p | beam.Create(expected_data) | WriteToTFRecord(
            file_path_prefix, file_name_suffix='.gz')

      # Read the file back and compare.
      with TestPipeline() as p:
        actual_data = p | ReadFromTFRecord(file_path_prefix + '-*')
        assert_that(actual_data, equal_to(expected_data))

  def test_end2end_auto_compression_unsharded(self):
    with TempDir() as temp_dir:
      file_path_prefix = temp_dir.create_temp_file('result')
      expected_data = [self.create_inputs() for _ in range(10)]

      # Generate a TFRecord file. An empty shard_name_template writes one
      # file named exactly file_path_prefix + '.gz'.
      with TestPipeline() as p:
        _ = p | beam.Create(expected_data) | WriteToTFRecord(
            file_path_prefix + '.gz', shard_name_template='')

      # Read the file back and compare.
      with TestPipeline() as p:
        actual_data = p | ReadFromTFRecord(file_path_prefix + '.gz')
        assert_that(actual_data, equal_to(expected_data))

  @unittest.skipIf(tf is None, 'tensorflow not installed.')
  def test_end2end_example_proto(self):
    with TempDir() as temp_dir:
      file_path_prefix = temp_dir.create_temp_file('result')

      example = tf.train.Example()
      example.features.feature['int'].int64_list.value.extend(list(range(3)))
      example.features.feature['bytes'].bytes_list.value.extend(
          [b'foo', b'bar'])

      with TestPipeline() as p:
        _ = p | beam.Create([example]) | WriteToTFRecord(
            file_path_prefix, coder=beam.coders.ProtoCoder(example.__class__))

      # Read the file back and compare.
      with TestPipeline() as p:
        actual_data = (
            p | ReadFromTFRecord(
                file_path_prefix + '-*',
                coder=beam.coders.ProtoCoder(example.__class__)))
        assert_that(actual_data, equal_to([example]))

  def test_end2end_read_write_read(self):
    with TempDir() as temp_dir:
      path = temp_dir.create_temp_file('result')
      with TestPipeline() as p:
        # Initial read to validate the pipeline doesn't fail before the file is
        # created.
        _ = p | ReadFromTFRecord(path + '-*', validate=False)
        expected_data = [self.create_inputs() for _ in range(10)]
        _ = p | beam.Create(expected_data) | WriteToTFRecord(
            path, file_name_suffix='.gz')

      # Read the file back and compare.
      with TestPipeline() as p:
        actual_data = p | ReadFromTFRecord(path + '-*', validate=True)
        assert_that(actual_data, equal_to(expected_data))


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()