github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/tfrecordio.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""TFRecord sources and sinks."""

# pytype: skip-file

import codecs
import logging
import struct
from functools import partial

import crcmod

from apache_beam import coders
from apache_beam.io import filebasedsink
from apache_beam.io.filebasedsource import FileBasedSource
from apache_beam.io.filebasedsource import ReadAllFiles
from apache_beam.io.filesystem import CompressionTypes
from apache_beam.io.iobase import Read
from apache_beam.io.iobase import Write
from apache_beam.transforms import PTransform

__all__ = ['ReadFromTFRecord', 'ReadAllFromTFRecord', 'WriteToTFRecord']

_LOGGER = logging.getLogger(__name__)


def _default_crc32c_fn(value):
  """Calculates crc32c of a bytes object using either snappy or crcmod."""

  if not _default_crc32c_fn.fn:
    try:
      import snappy  # pylint: disable=import-error
      # Support multiple versions of python-snappy:
      # https://github.com/andrix/python-snappy/pull/53
      if getattr(snappy, '_crc32c', None):
        _default_crc32c_fn.fn = snappy._crc32c  # pylint: disable=protected-access
      elif getattr(snappy, '_snappy', None):
        _default_crc32c_fn.fn = snappy._snappy._crc32c  # pylint: disable=protected-access
    except ImportError:
      pass

    if not _default_crc32c_fn.fn:
      _LOGGER.warning(
          'Couldn\'t find python-snappy so the implementation of '
          '_TFRecordUtil._masked_crc32c is not as fast as it could '
          'be.')
      _default_crc32c_fn.fn = crcmod.predefined.mkPredefinedCrcFun('crc-32c')
  return _default_crc32c_fn.fn(value)


_default_crc32c_fn.fn = None  # type: ignore
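# The chosen crc32c implementation is memoized on the function object itself,
# so the snappy import above is attempted at most once per process.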


class _TFRecordUtil(object):
  """Provides basic TFRecord encoding/decoding with consistency checks.

  For detailed TFRecord format description see:
    https://www.tensorflow.org/versions/r1.11/api_guides/python/python_io#TFRecords_Format_Details

  Note that masks and lengths are represented in little-endian byte order.
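
  On-disk layout of a single record (per the format notes linked above):

    uint64 length            (little-endian)
    uint32 masked_crc32c_of_length
    bytes  data[length]
    uint32 masked_crc32c_of_data

  Example round-trip (a minimal sketch; ``io.BytesIO`` stands in for a real
  file handle):

    import io
    buf = io.BytesIO()
    _TFRecordUtil.write_record(buf, b'hello')
    buf.seek(0)
    assert _TFRecordUtil.read_record(buf) == b'hello'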
  """
  @classmethod
  def _masked_crc32c(cls, value, crc32c_fn=_default_crc32c_fn):
    """Compute a masked crc32c checksum for a value.

    Args:
      value: A bytes object for which we compute the crc.
      crc32c_fn: A function that can compute a crc32c.
        This is a performance hook that also helps with testing. Callers are
        not expected to make use of it directly.
    Returns:
      Masked crc32c checksum.
    """

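    # Right-rotate the 32-bit crc by 15 bits, add the TFRecord masking
    # constant, and truncate the result back to 32 bits.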
    crc = crc32c_fn(value)
    return (((crc >> 15) | (crc << 17)) + 0xa282ead8) & 0xffffffff

  @staticmethod
  def encoded_num_bytes(record):
    """Return the number of bytes consumed by a record in its encoded form."""
    # 16 = 8 (Length) + 4 (crc of length) + 4 (crc of data)
    return len(record) + 16

  @classmethod
  def write_record(cls, file_handle, value):
    """Encode a value as a TFRecord.

    Args:
      file_handle: The file to write to.
      value: A bytes object representing content of the record.
    """
    encoded_length = struct.pack(b'<Q', len(value))
    file_handle.write(
        b''.join([
            encoded_length,
            struct.pack(b'<I', cls._masked_crc32c(encoded_length)),
            value,
            struct.pack(b'<I', cls._masked_crc32c(value))
        ]))

  @classmethod
  def read_record(cls, file_handle):
    """Read a record from a TFRecords file.

    Args:
      file_handle: The file to read from.
    Returns:
      None if EOF is reached; the payload of the record otherwise.
    Raises:
      ValueError: If the file does not appear to be a valid TFRecords file.
    """
    buf_length_expected = 12
    buf = file_handle.read(buf_length_expected)
    if not buf:
      return None  # EOF reached.

    # Validate all length related payloads.
    if len(buf) != buf_length_expected:
      raise ValueError(
          'Not a valid TFRecord. Fewer than %d bytes: %s' %
          (buf_length_expected, codecs.encode(buf, 'hex')))
    length, length_mask_expected = struct.unpack('<QI', buf)
    length_mask_actual = cls._masked_crc32c(buf[:8])
    if length_mask_actual != length_mask_expected:
      raise ValueError(
          'Not a valid TFRecord. Mismatch of length mask: %s' %
          codecs.encode(buf, 'hex'))

    # Validate all data related payloads.
    buf_length_expected = length + 4
    buf = file_handle.read(buf_length_expected)
    if len(buf) != buf_length_expected:
      raise ValueError(
          'Not a valid TFRecord. Fewer than %d bytes: %s' %
          (buf_length_expected, codecs.encode(buf, 'hex')))
    data, data_mask_expected = struct.unpack('<%dsI' % length, buf)
    data_mask_actual = cls._masked_crc32c(data)
    if data_mask_actual != data_mask_expected:
      raise ValueError(
          'Not a valid TFRecord. Mismatch of data mask: %s' %
          codecs.encode(buf, 'hex'))

    # All validation checks passed.
    return data


class _TFRecordSource(FileBasedSource):
  """A file source for reading files of TFRecords.

  For detailed TFRecords format description see:
    https://www.tensorflow.org/versions/r1.11/api_guides/python/python_io#TFRecords_Format_Details
  """
  def __init__(self, file_pattern, coder, compression_type, validate):
    """Initialize a TFRecordSource. See ReadFromTFRecord for details."""
    super().__init__(
        file_pattern=file_pattern,
        compression_type=compression_type,
        splittable=False,
        validate=validate)
    self._coder = coder

  def read_records(self, file_name, offset_range_tracker):
    if offset_range_tracker.start_position():
      raise ValueError(
          'Start position not 0: %s' % offset_range_tracker.start_position())

    current_offset = offset_range_tracker.start_position()
    with self.open_file(file_name) as file_handle:
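      # The source is declared unsplittable, so the tracker covers the whole
      # file; each record's starting offset is claimed before it is emitted.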
      while True:
        if not offset_range_tracker.try_claim(current_offset):
          raise RuntimeError('Unable to claim position: %s' % current_offset)
        record = _TFRecordUtil.read_record(file_handle)
        if record is None:
          return  # Reached EOF
        else:
          current_offset += _TFRecordUtil.encoded_num_bytes(record)
          yield self._coder.decode(record)


def _create_tfrecordio_source(
    file_pattern=None, coder=None, compression_type=None):
  # We intentionally disable validation for the ReadAll pattern so that
  # reading does not fail for empty globs (elements).
  return _TFRecordSource(file_pattern, coder, compression_type, validate=False)


class ReadAllFromTFRecord(PTransform):
  """A ``PTransform`` for reading a ``PCollection`` of TFRecord files."""
  def __init__(
      self,
      coder=coders.BytesCoder(),
      compression_type=CompressionTypes.AUTO,
      with_filename=False):
    """Initialize the ``ReadAllFromTFRecord`` transform.

    Args:
      coder: Coder used to decode each record.
      compression_type: Used to handle compressed input files. Default value
          is CompressionTypes.AUTO, in which case the file_path's extension will
          be used to detect the compression.
      with_filename: If True, returns a key-value pair with the key being the
        file name and the value being the actual data. If False, it returns
        only the data.
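
    Example (a minimal sketch; the glob and the ``beam`` alias are
    illustrative, assuming ``import apache_beam as beam``):

      with beam.Pipeline() as p:
        records = (
            p
            | beam.Create(['/path/to/files*.tfrecord'])
            | ReadAllFromTFRecord())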
    """
    super().__init__()
    source_from_file = partial(
        _create_tfrecordio_source,
        compression_type=compression_type,
        coder=coder)
    # Desired and min bundle sizes do not matter since TFRecord files are
    # unsplittable.
    self._read_all_files = ReadAllFiles(
        splittable=False,
        compression_type=compression_type,
        desired_bundle_size=0,
        min_bundle_size=0,
        source_from_file=source_from_file,
        with_filename=with_filename)

  def expand(self, pvalue):
    return pvalue | 'ReadAllFiles' >> self._read_all_files


class ReadFromTFRecord(PTransform):
  """Transform for reading TFRecord sources."""
  def __init__(
      self,
      file_pattern,
      coder=coders.BytesCoder(),
      compression_type=CompressionTypes.AUTO,
      validate=True):
    """Initialize a ReadFromTFRecord transform.

    Args:
      file_pattern: A file glob pattern to read TFRecords from.
      coder: Coder used to decode each record.
      compression_type: Used to handle compressed input files. Default value
          is CompressionTypes.AUTO, in which case the file_path's extension will
          be used to detect the compression.
      validate: Boolean flag to verify that the files exist during pipeline
          creation time.

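    Example (a minimal sketch; the pattern and the ``beam`` alias are
    illustrative, assuming ``import apache_beam as beam``):

      with beam.Pipeline() as p:
        records = p | ReadFromTFRecord('/path/to/input*.tfrecord')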
    """
    super().__init__()
    self._source = _TFRecordSource(
        file_pattern, coder, compression_type, validate)

  def expand(self, pvalue):
    return pvalue.pipeline | Read(self._source)


class _TFRecordSink(filebasedsink.FileBasedSink):
  """Sink for writing TFRecords files.

  For detailed TFRecord format description see:
    https://www.tensorflow.org/versions/r1.11/api_guides/python/python_io#TFRecords_Format_Details
  """
  def __init__(
      self,
      file_path_prefix,
      coder,
      file_name_suffix,
      num_shards,
      shard_name_template,
      compression_type):
    """Initialize a TFRecordSink. See WriteToTFRecord for details."""

    super().__init__(
        file_path_prefix=file_path_prefix,
        coder=coder,
        file_name_suffix=file_name_suffix,
        num_shards=num_shards,
        shard_name_template=shard_name_template,
        mime_type='application/octet-stream',
        compression_type=compression_type)

  def write_encoded_record(self, file_handle, value):
    _TFRecordUtil.write_record(file_handle, value)


class WriteToTFRecord(PTransform):
  """Transform for writing to TFRecord sinks."""
  def __init__(
      self,
      file_path_prefix,
      coder=coders.BytesCoder(),
      file_name_suffix='',
      num_shards=0,
      shard_name_template=None,
      compression_type=CompressionTypes.AUTO):
    """Initialize WriteToTFRecord transform.

    Args:
      file_path_prefix: The file path to write to. The files written will begin
        with this prefix, followed by a shard identifier (see num_shards), and
        end in a common extension, if given by file_name_suffix.
      coder: Coder used to encode each record.
      file_name_suffix: Suffix for the files written.
      num_shards: The number of files (shards) used for output. If not set, the
        default value will be used.
      shard_name_template: A template string containing placeholders for
        the shard number and shard count. When constructing a filename for a
        particular shard number, the upper-case letters 'S' and 'N' are
        replaced with the 0-padded shard number and shard count respectively.
        This argument can be '', in which case it behaves as if num_shards was
        set to 1 and only one file will be generated. If None is passed, the
        default template '-SSSSS-of-NNNNN' is used: for example, the second of
        three shards would end in '-00001-of-00003'.
      compression_type: Used to handle compressed output files. Typical value
          is CompressionTypes.AUTO, in which case the file_path's extension will
          be used to detect the compression.

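    Example (a minimal sketch; the prefix and the ``beam`` alias are
    illustrative, assuming ``import apache_beam as beam``):

      with beam.Pipeline() as p:
        _ = (
            p
            | beam.Create([b'one', b'two'])
            | WriteToTFRecord('/path/to/output', file_name_suffix='.tfrecord'))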
    """
    super().__init__()
    self._sink = _TFRecordSink(
        file_path_prefix,
        coder,
        file_name_suffix,
        num_shards,
        shard_name_template,
        compression_type)

  def expand(self, pcoll):
    return pcoll | Write(self._sink)