github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/tfrecordio_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  # pytype: skip-file
    19  
    20  import binascii
    21  import glob
    22  import gzip
    23  import io
    24  import logging
    25  import os
    26  import pickle
    27  import random
    28  import re
    29  import unittest
    30  import zlib
    31  
    32  import crcmod
    33  
    34  import apache_beam as beam
    35  from apache_beam import Create
    36  from apache_beam import coders
    37  from apache_beam.io.filesystem import CompressionTypes
    38  from apache_beam.io.tfrecordio import ReadAllFromTFRecord
    39  from apache_beam.io.tfrecordio import ReadFromTFRecord
    40  from apache_beam.io.tfrecordio import WriteToTFRecord
    41  from apache_beam.io.tfrecordio import _TFRecordSink
    42  from apache_beam.io.tfrecordio import _TFRecordUtil
    43  from apache_beam.testing.test_pipeline import TestPipeline
    44  from apache_beam.testing.test_utils import TempDir
    45  from apache_beam.testing.util import assert_that
    46  from apache_beam.testing.util import equal_to
    47  
# TensorFlow is an optional dependency. Prefer the TF1-compatible API, fall
# back to the top-level package, and leave `tf` as None when neither can be
# imported so that the TensorFlow-dependent tests below can be skipped.
try:
  import tensorflow.compat.v1 as tf  # pylint: disable=import-error
except ImportError:
  try:
    import tensorflow as tf  # pylint: disable=import-error
  except ImportError:
    tf = None  # pylint: disable=invalid-name
    logging.warning('Tensorflow is not installed, so skipping some tests.')
    56  
# Base64 form of a TFRecord file holding the single record b'foo'.
# Created by running the following code in Python:
# >>> import tensorflow as tf
# >>> import base64
# >>> writer = tf.python_io.TFRecordWriter('/tmp/python_foo.tfrecord')
# >>> writer.write(b'foo')
# >>> writer.close()
# >>> with open('/tmp/python_foo.tfrecord', 'rb') as f:
# ...   data =  base64.b64encode(f.read())
# ...   print(data)
FOO_RECORD_BASE64 = b'AwAAAAAAAACwmUkOZm9vYYq+/g=='

# Same as above, but containing two records: [b'foo', b'bar'].
FOO_BAR_RECORD_BASE64 = b'AwAAAAAAAACwmUkOZm9vYYq+/gMAAAAAAAAAsJlJDmJhckYA5cg='
    70  
    71  
    72  def _write_file(path, base64_records):
    73    record = binascii.a2b_base64(base64_records)
    74    with open(path, 'wb') as f:
    75      f.write(record)
    76  
    77  
    78  def _write_file_deflate(path, base64_records):
    79    record = binascii.a2b_base64(base64_records)
    80    with open(path, 'wb') as f:
    81      f.write(zlib.compress(record))
    82  
    83  
    84  def _write_file_gzip(path, base64_records):
    85    record = binascii.a2b_base64(base64_records)
    86    with gzip.GzipFile(path, 'wb') as f:
    87      f.write(record)
    88  
    89  
    90  class TestTFRecordUtil(unittest.TestCase):
    91    def setUp(self):
    92      self.record = binascii.a2b_base64(FOO_RECORD_BASE64)
    93  
    94    def _as_file_handle(self, contents):
    95      result = io.BytesIO()
    96      result.write(contents)
    97      result.seek(0)
    98      return result
    99  
   100    def _increment_value_at_index(self, value, index):
   101      l = list(value)
   102      l[index] = l[index] + 1
   103      return bytes(l)
   104  
   105    def _test_error(self, record, error_text):
   106      with self.assertRaisesRegex(ValueError, re.escape(error_text)):
   107        _TFRecordUtil.read_record(self._as_file_handle(record))
   108  
   109    def test_masked_crc32c(self):
   110      self.assertEqual(0xfd7fffa, _TFRecordUtil._masked_crc32c(b'\x00' * 32))
   111      self.assertEqual(0xf909b029, _TFRecordUtil._masked_crc32c(b'\xff' * 32))
   112      self.assertEqual(0xfebe8a61, _TFRecordUtil._masked_crc32c(b'foo'))
   113      self.assertEqual(
   114          0xe4999b0,
   115          _TFRecordUtil._masked_crc32c(b'\x03\x00\x00\x00\x00\x00\x00\x00'))
   116  
   117    def test_masked_crc32c_crcmod(self):
   118      crc32c_fn = crcmod.predefined.mkPredefinedCrcFun('crc-32c')
   119      self.assertEqual(
   120          0xfd7fffa,
   121          _TFRecordUtil._masked_crc32c(b'\x00' * 32, crc32c_fn=crc32c_fn))
   122      self.assertEqual(
   123          0xf909b029,
   124          _TFRecordUtil._masked_crc32c(b'\xff' * 32, crc32c_fn=crc32c_fn))
   125      self.assertEqual(
   126          0xfebe8a61, _TFRecordUtil._masked_crc32c(b'foo', crc32c_fn=crc32c_fn))
   127      self.assertEqual(
   128          0xe4999b0,
   129          _TFRecordUtil._masked_crc32c(
   130              b'\x03\x00\x00\x00\x00\x00\x00\x00', crc32c_fn=crc32c_fn))
   131  
   132    def test_write_record(self):
   133      file_handle = io.BytesIO()
   134      _TFRecordUtil.write_record(file_handle, b'foo')
   135      self.assertEqual(self.record, file_handle.getvalue())
   136  
   137    def test_read_record(self):
   138      actual = _TFRecordUtil.read_record(self._as_file_handle(self.record))
   139      self.assertEqual(b'foo', actual)
   140  
   141    def test_read_record_invalid_record(self):
   142      self._test_error(b'bar', 'Not a valid TFRecord. Fewer than 12 bytes')
   143  
   144    def test_read_record_invalid_length_mask(self):
   145      record = self._increment_value_at_index(self.record, 9)
   146      self._test_error(record, 'Mismatch of length mask')
   147  
   148    def test_read_record_invalid_data_mask(self):
   149      record = self._increment_value_at_index(self.record, 16)
   150      self._test_error(record, 'Mismatch of data mask')
   151  
   152    def test_compatibility_read_write(self):
   153      for record in [b'', b'blah', b'another blah']:
   154        file_handle = io.BytesIO()
   155        _TFRecordUtil.write_record(file_handle, record)
   156        file_handle.seek(0)
   157        actual = _TFRecordUtil.read_record(file_handle)
   158        self.assertEqual(record, actual)
   159  
   160  
   161  class TestTFRecordSink(unittest.TestCase):
   162    def _write_lines(self, sink, path, lines):
   163      f = sink.open(path)
   164      for l in lines:
   165        sink.write_record(f, l)
   166      sink.close(f)
   167  
   168    def test_write_record_single(self):
   169      with TempDir() as temp_dir:
   170        path = temp_dir.create_temp_file('result')
   171        record = binascii.a2b_base64(FOO_RECORD_BASE64)
   172        sink = _TFRecordSink(
   173            path,
   174            coder=coders.BytesCoder(),
   175            file_name_suffix='',
   176            num_shards=0,
   177            shard_name_template=None,
   178            compression_type=CompressionTypes.UNCOMPRESSED)
   179        self._write_lines(sink, path, [b'foo'])
   180  
   181        with open(path, 'rb') as f:
   182          self.assertEqual(f.read(), record)
   183  
   184    def test_write_record_multiple(self):
   185      with TempDir() as temp_dir:
   186        path = temp_dir.create_temp_file('result')
   187        record = binascii.a2b_base64(FOO_BAR_RECORD_BASE64)
   188        sink = _TFRecordSink(
   189            path,
   190            coder=coders.BytesCoder(),
   191            file_name_suffix='',
   192            num_shards=0,
   193            shard_name_template=None,
   194            compression_type=CompressionTypes.UNCOMPRESSED)
   195        self._write_lines(sink, path, [b'foo', b'bar'])
   196  
   197        with open(path, 'rb') as f:
   198          self.assertEqual(f.read(), record)
   199  
   200  
   201  @unittest.skipIf(tf is None, 'tensorflow not installed.')
   202  class TestWriteToTFRecord(TestTFRecordSink):
   203    def test_write_record_gzip(self):
   204      with TempDir() as temp_dir:
   205        file_path_prefix = temp_dir.create_temp_file('result')
   206        with TestPipeline() as p:
   207          input_data = [b'foo', b'bar']
   208          _ = p | beam.Create(input_data) | WriteToTFRecord(
   209              file_path_prefix, compression_type=CompressionTypes.GZIP)
   210  
   211        actual = []
   212        file_name = glob.glob(file_path_prefix + '-*')[0]
   213        for r in tf.python_io.tf_record_iterator(
   214            file_name,
   215            options=tf.python_io.TFRecordOptions(
   216                tf.python_io.TFRecordCompressionType.GZIP)):
   217          actual.append(r)
   218        self.assertEqual(sorted(actual), sorted(input_data))
   219  
   220    def test_write_record_auto(self):
   221      with TempDir() as temp_dir:
   222        file_path_prefix = temp_dir.create_temp_file('result')
   223        with TestPipeline() as p:
   224          input_data = [b'foo', b'bar']
   225          _ = p | beam.Create(input_data) | WriteToTFRecord(
   226              file_path_prefix, file_name_suffix='.gz')
   227  
   228        actual = []
   229        file_name = glob.glob(file_path_prefix + '-*.gz')[0]
   230        for r in tf.python_io.tf_record_iterator(
   231            file_name,
   232            options=tf.python_io.TFRecordOptions(
   233                tf.python_io.TFRecordCompressionType.GZIP)):
   234          actual.append(r)
   235        self.assertEqual(sorted(actual), sorted(input_data))
   236  
   237  
   238  class TestReadFromTFRecord(unittest.TestCase):
   239    def test_process_single(self):
   240      with TempDir() as temp_dir:
   241        path = temp_dir.create_temp_file('result')
   242        _write_file(path, FOO_RECORD_BASE64)
   243        with TestPipeline() as p:
   244          result = (
   245              p
   246              | ReadFromTFRecord(
   247                  path,
   248                  coder=coders.BytesCoder(),
   249                  compression_type=CompressionTypes.AUTO,
   250                  validate=True))
   251          assert_that(result, equal_to([b'foo']))
   252  
   253    def test_process_multiple(self):
   254      with TempDir() as temp_dir:
   255        path = temp_dir.create_temp_file('result')
   256        _write_file(path, FOO_BAR_RECORD_BASE64)
   257        with TestPipeline() as p:
   258          result = (
   259              p
   260              | ReadFromTFRecord(
   261                  path,
   262                  coder=coders.BytesCoder(),
   263                  compression_type=CompressionTypes.AUTO,
   264                  validate=True))
   265          assert_that(result, equal_to([b'foo', b'bar']))
   266  
   267    def test_process_deflate(self):
   268      with TempDir() as temp_dir:
   269        path = temp_dir.create_temp_file('result')
   270        _write_file_deflate(path, FOO_BAR_RECORD_BASE64)
   271        with TestPipeline() as p:
   272          result = (
   273              p
   274              | ReadFromTFRecord(
   275                  path,
   276                  coder=coders.BytesCoder(),
   277                  compression_type=CompressionTypes.DEFLATE,
   278                  validate=True))
   279          assert_that(result, equal_to([b'foo', b'bar']))
   280  
   281    def test_process_gzip_with_coder(self):
   282      with TempDir() as temp_dir:
   283        path = temp_dir.create_temp_file('result')
   284        _write_file_gzip(path, FOO_BAR_RECORD_BASE64)
   285        with TestPipeline() as p:
   286          result = (
   287              p
   288              | ReadFromTFRecord(
   289                  path,
   290                  coder=coders.BytesCoder(),
   291                  compression_type=CompressionTypes.GZIP,
   292                  validate=True))
   293          assert_that(result, equal_to([b'foo', b'bar']))
   294  
   295    def test_process_gzip_without_coder(self):
   296      with TempDir() as temp_dir:
   297        path = temp_dir.create_temp_file('result')
   298        _write_file_gzip(path, FOO_BAR_RECORD_BASE64)
   299        with TestPipeline() as p:
   300          result = (
   301              p
   302              | ReadFromTFRecord(path, compression_type=CompressionTypes.GZIP))
   303          assert_that(result, equal_to([b'foo', b'bar']))
   304  
   305    def test_process_auto(self):
   306      with TempDir() as temp_dir:
   307        path = temp_dir.create_temp_file('result.gz')
   308        _write_file_gzip(path, FOO_BAR_RECORD_BASE64)
   309        with TestPipeline() as p:
   310          result = (
   311              p
   312              | ReadFromTFRecord(
   313                  path,
   314                  coder=coders.BytesCoder(),
   315                  compression_type=CompressionTypes.AUTO,
   316                  validate=True))
   317          assert_that(result, equal_to([b'foo', b'bar']))
   318  
   319    def test_process_gzip_auto(self):
   320      with TempDir() as temp_dir:
   321        path = temp_dir.create_temp_file('result.gz')
   322        _write_file_gzip(path, FOO_BAR_RECORD_BASE64)
   323        with TestPipeline() as p:
   324          result = (
   325              p
   326              | ReadFromTFRecord(path, compression_type=CompressionTypes.AUTO))
   327          assert_that(result, equal_to([b'foo', b'bar']))
   328  
   329  
   330  class TestReadAllFromTFRecord(unittest.TestCase):
   331    def _write_glob(self, temp_dir, suffix, include_empty=False):
   332      for _ in range(3):
   333        path = temp_dir.create_temp_file(suffix)
   334        _write_file(path, FOO_BAR_RECORD_BASE64)
   335      if include_empty:
   336        path = temp_dir.create_temp_file(suffix)
   337        _write_file(path, '')
   338  
   339    def test_process_single(self):
   340      with TempDir() as temp_dir:
   341        path = temp_dir.create_temp_file('result')
   342        _write_file(path, FOO_RECORD_BASE64)
   343        with TestPipeline() as p:
   344          result = (
   345              p
   346              | Create([path])
   347              | ReadAllFromTFRecord(
   348                  coder=coders.BytesCoder(),
   349                  compression_type=CompressionTypes.AUTO))
   350          assert_that(result, equal_to([b'foo']))
   351  
   352    def test_process_multiple(self):
   353      with TempDir() as temp_dir:
   354        path = temp_dir.create_temp_file('result')
   355        _write_file(path, FOO_BAR_RECORD_BASE64)
   356        with TestPipeline() as p:
   357          result = (
   358              p
   359              | Create([path])
   360              | ReadAllFromTFRecord(
   361                  coder=coders.BytesCoder(),
   362                  compression_type=CompressionTypes.AUTO))
   363          assert_that(result, equal_to([b'foo', b'bar']))
   364  
   365    def test_process_with_filename(self):
   366      with TempDir() as temp_dir:
   367        num_files = 3
   368        files = []
   369        for i in range(num_files):
   370          path = temp_dir.create_temp_file('result%s' % i)
   371          _write_file(path, FOO_BAR_RECORD_BASE64)
   372          files.append(path)
   373        content = [b'foo', b'bar']
   374        expected = [(file, line) for file in files for line in content]
   375        pattern = temp_dir.get_path() + os.path.sep + '*'
   376  
   377        with TestPipeline() as p:
   378          result = (
   379              p
   380              | Create([pattern])
   381              | ReadAllFromTFRecord(
   382                  coder=coders.BytesCoder(),
   383                  compression_type=CompressionTypes.AUTO,
   384                  with_filename=True))
   385          assert_that(result, equal_to(expected))
   386  
   387    def test_process_glob(self):
   388      with TempDir() as temp_dir:
   389        self._write_glob(temp_dir, 'result')
   390        glob = temp_dir.get_path() + os.path.sep + '*result'
   391        with TestPipeline() as p:
   392          result = (
   393              p
   394              | Create([glob])
   395              | ReadAllFromTFRecord(
   396                  coder=coders.BytesCoder(),
   397                  compression_type=CompressionTypes.AUTO))
   398          assert_that(result, equal_to([b'foo', b'bar'] * 3))
   399  
   400    def test_process_glob_with_empty_file(self):
   401      with TempDir() as temp_dir:
   402        self._write_glob(temp_dir, 'result', include_empty=True)
   403        glob = temp_dir.get_path() + os.path.sep + '*result'
   404        with TestPipeline() as p:
   405          result = (
   406              p
   407              | Create([glob])
   408              | ReadAllFromTFRecord(
   409                  coder=coders.BytesCoder(),
   410                  compression_type=CompressionTypes.AUTO))
   411          assert_that(result, equal_to([b'foo', b'bar'] * 3))
   412  
   413    def test_process_multiple_globs(self):
   414      with TempDir() as temp_dir:
   415        globs = []
   416        for i in range(3):
   417          suffix = 'result' + str(i)
   418          self._write_glob(temp_dir, suffix)
   419          globs.append(temp_dir.get_path() + os.path.sep + '*' + suffix)
   420  
   421        with TestPipeline() as p:
   422          result = (
   423              p
   424              | Create(globs)
   425              | ReadAllFromTFRecord(
   426                  coder=coders.BytesCoder(),
   427                  compression_type=CompressionTypes.AUTO))
   428          assert_that(result, equal_to([b'foo', b'bar'] * 9))
   429  
   430    def test_process_deflate(self):
   431      with TempDir() as temp_dir:
   432        path = temp_dir.create_temp_file('result')
   433        _write_file_deflate(path, FOO_BAR_RECORD_BASE64)
   434        with TestPipeline() as p:
   435          result = (
   436              p
   437              | Create([path])
   438              | ReadAllFromTFRecord(
   439                  coder=coders.BytesCoder(),
   440                  compression_type=CompressionTypes.DEFLATE))
   441          assert_that(result, equal_to([b'foo', b'bar']))
   442  
   443    def test_process_gzip(self):
   444      with TempDir() as temp_dir:
   445        path = temp_dir.create_temp_file('result')
   446        _write_file_gzip(path, FOO_BAR_RECORD_BASE64)
   447        with TestPipeline() as p:
   448          result = (
   449              p
   450              | Create([path])
   451              | ReadAllFromTFRecord(
   452                  coder=coders.BytesCoder(),
   453                  compression_type=CompressionTypes.GZIP))
   454          assert_that(result, equal_to([b'foo', b'bar']))
   455  
   456    def test_process_auto(self):
   457      with TempDir() as temp_dir:
   458        path = temp_dir.create_temp_file('result.gz')
   459        _write_file_gzip(path, FOO_BAR_RECORD_BASE64)
   460        with TestPipeline() as p:
   461          result = (
   462              p
   463              | Create([path])
   464              | ReadAllFromTFRecord(
   465                  coder=coders.BytesCoder(),
   466                  compression_type=CompressionTypes.AUTO))
   467          assert_that(result, equal_to([b'foo', b'bar']))
   468  
   469  
   470  class TestEnd2EndWriteAndRead(unittest.TestCase):
   471    def create_inputs(self):
   472      input_array = [[random.random() - 0.5 for _ in range(15)]
   473                     for _ in range(12)]
   474      memfile = io.BytesIO()
   475      pickle.dump(input_array, memfile)
   476      return memfile.getvalue()
   477  
   478    def test_end2end(self):
   479      with TempDir() as temp_dir:
   480        file_path_prefix = temp_dir.create_temp_file('result')
   481  
   482        # Generate a TFRecord file.
   483        with TestPipeline() as p:
   484          expected_data = [self.create_inputs() for _ in range(0, 10)]
   485          _ = p | beam.Create(expected_data) | WriteToTFRecord(file_path_prefix)
   486  
   487        # Read the file back and compare.
   488        with TestPipeline() as p:
   489          actual_data = p | ReadFromTFRecord(file_path_prefix + '-*')
   490          assert_that(actual_data, equal_to(expected_data))
   491  
   492    def test_end2end_auto_compression(self):
   493      with TempDir() as temp_dir:
   494        file_path_prefix = temp_dir.create_temp_file('result')
   495  
   496        # Generate a TFRecord file.
   497        with TestPipeline() as p:
   498          expected_data = [self.create_inputs() for _ in range(0, 10)]
   499          _ = p | beam.Create(expected_data) | WriteToTFRecord(
   500              file_path_prefix, file_name_suffix='.gz')
   501  
   502        # Read the file back and compare.
   503        with TestPipeline() as p:
   504          actual_data = p | ReadFromTFRecord(file_path_prefix + '-*')
   505          assert_that(actual_data, equal_to(expected_data))
   506  
   507    def test_end2end_auto_compression_unsharded(self):
   508      with TempDir() as temp_dir:
   509        file_path_prefix = temp_dir.create_temp_file('result')
   510  
   511        # Generate a TFRecord file.
   512        with TestPipeline() as p:
   513          expected_data = [self.create_inputs() for _ in range(0, 10)]
   514          _ = p | beam.Create(expected_data) | WriteToTFRecord(
   515              file_path_prefix + '.gz', shard_name_template='')
   516  
   517        # Read the file back and compare.
   518        with TestPipeline() as p:
   519          actual_data = p | ReadFromTFRecord(file_path_prefix + '.gz')
   520          assert_that(actual_data, equal_to(expected_data))
   521  
   522    @unittest.skipIf(tf is None, 'tensorflow not installed.')
   523    def test_end2end_example_proto(self):
   524      with TempDir() as temp_dir:
   525        file_path_prefix = temp_dir.create_temp_file('result')
   526  
   527        example = tf.train.Example()
   528        example.features.feature['int'].int64_list.value.extend(list(range(3)))
   529        example.features.feature['bytes'].bytes_list.value.extend(
   530            [b'foo', b'bar'])
   531  
   532        with TestPipeline() as p:
   533          _ = p | beam.Create([example]) | WriteToTFRecord(
   534              file_path_prefix, coder=beam.coders.ProtoCoder(example.__class__))
   535  
   536        # Read the file back and compare.
   537        with TestPipeline() as p:
   538          actual_data = (
   539              p | ReadFromTFRecord(
   540                  file_path_prefix + '-*',
   541                  coder=beam.coders.ProtoCoder(example.__class__)))
   542          assert_that(actual_data, equal_to([example]))
   543  
   544    def test_end2end_read_write_read(self):
   545      with TempDir() as temp_dir:
   546        path = temp_dir.create_temp_file('result')
   547        with TestPipeline() as p:
   548          # Initial read to validate the pipeline doesn't fail before the file is
   549          # created.
   550          _ = p | ReadFromTFRecord(path + '-*', validate=False)
   551          expected_data = [self.create_inputs() for _ in range(0, 10)]
   552          _ = p | beam.Create(expected_data) | WriteToTFRecord(
   553              path, file_name_suffix='.gz')
   554  
   555        # Read the file back and compare.
   556        with TestPipeline() as p:
   557          actual_data = p | ReadFromTFRecord(path + '-*', validate=True)
   558          assert_that(actual_data, equal_to(expected_data))
   559  
   560  
# Script entry point: runs only when the module is executed directly.
if __name__ == '__main__':
  # Set the root logger to INFO before handing control to unittest.
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()