github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/tfrecordio.py (about) 1 # 2 # Licensed to the Apache Software Foundation (ASF) under one or more 3 # contributor license agreements. See the NOTICE file distributed with 4 # this work for additional information regarding copyright ownership. 5 # The ASF licenses this file to You under the Apache License, Version 2.0 6 # (the "License"); you may not use this file except in compliance with 7 # the License. You may obtain a copy of the License at 8 # 9 # http://www.apache.org/licenses/LICENSE-2.0 10 # 11 # Unless required by applicable law or agreed to in writing, software 12 # distributed under the License is distributed on an "AS IS" BASIS, 13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 # See the License for the specific language governing permissions and 15 # limitations under the License. 16 # 17 18 """TFRecord sources and sinks.""" 19 20 # pytype: skip-file 21 22 import codecs 23 import logging 24 import struct 25 from functools import partial 26 27 import crcmod 28 29 from apache_beam import coders 30 from apache_beam.io import filebasedsink 31 from apache_beam.io.filebasedsource import FileBasedSource 32 from apache_beam.io.filebasedsource import ReadAllFiles 33 from apache_beam.io.filesystem import CompressionTypes 34 from apache_beam.io.iobase import Read 35 from apache_beam.io.iobase import Write 36 from apache_beam.transforms import PTransform 37 38 __all__ = ['ReadFromTFRecord', 'ReadAllFromTFRecord', 'WriteToTFRecord'] 39 40 _LOGGER = logging.getLogger(__name__) 41 42 43 def _default_crc32c_fn(value): 44 """Calculates crc32c of a bytes object using either snappy or crcmod.""" 45 46 if not _default_crc32c_fn.fn: 47 try: 48 import snappy # pylint: disable=import-error 49 # Support multiple versions of python-snappy: 50 # https://github.com/andrix/python-snappy/pull/53 51 if getattr(snappy, '_crc32c', None): 52 _default_crc32c_fn.fn = snappy._crc32c # pylint: disable=protected-access 53 elif getattr(snappy, '_snappy', None): 54 _default_crc32c_fn.fn = snappy._snappy._crc32c # pylint: disable=protected-access 55 except ImportError: 56 pass 57 58 if not _default_crc32c_fn.fn: 59 _LOGGER.warning( 60 'Couldn\'t find python-snappy so the implementation of ' 61 '_TFRecordUtil._masked_crc32c is not as fast as it could ' 62 'be.') 63 _default_crc32c_fn.fn = crcmod.predefined.mkPredefinedCrcFun('crc-32c') 64 return _default_crc32c_fn.fn(value) 65 66 67 _default_crc32c_fn.fn = None # type: ignore 68 69 70 class _TFRecordUtil(object): 71 """Provides basic TFRecord encoding/decoding with consistency checks. 72 73 For detailed TFRecord format description see: 74 https://www.tensorflow.org/versions/r1.11/api_guides/python/python_io#TFRecords_Format_Details 75 76 Note that masks and length are represented in LittleEndian order. 77 """ 78 @classmethod 79 def _masked_crc32c(cls, value, crc32c_fn=_default_crc32c_fn): 80 """Compute a masked crc32c checksum for a value. 81 82 Args: 83 value: A bytes object for which we compute the crc. 84 crc32c_fn: A function that can compute a crc32c. 85 This is a performance hook that also helps with testing. Callers are 86 not expected to make use of it directly. 87 Returns: 88 Masked crc32c checksum. 89 """ 90 91 crc = crc32c_fn(value) 92 return (((crc >> 15) | (crc << 17)) + 0xa282ead8) & 0xffffffff 93 94 @staticmethod 95 def encoded_num_bytes(record): 96 """Return the number of bytes consumed by a record in its encoded form.""" 97 # 16 = 8 (Length) + 4 (crc of length) + 4 (crc of data) 98 return len(record) + 16 99 100 @classmethod 101 def write_record(cls, file_handle, value): 102 """Encode a value as a TFRecord. 103 104 Args: 105 file_handle: The file to write to. 106 value: A bytes object representing content of the record. 107 """ 108 encoded_length = struct.pack(b'<Q', len(value)) 109 file_handle.write( 110 b''.join([ 111 encoded_length, 112 struct.pack(b'<I', cls._masked_crc32c(encoded_length)), 113 value, 114 struct.pack(b'<I', cls._masked_crc32c(value)) 115 ])) 116 117 @classmethod 118 def read_record(cls, file_handle): 119 """Read a record from a TFRecords file. 120 121 Args: 122 file_handle: The file to read from. 123 Returns: 124 None if EOF is reached; the paylod of the record otherwise. 125 Raises: 126 ValueError: If file appears to not be a valid TFRecords file. 127 """ 128 buf_length_expected = 12 129 buf = file_handle.read(buf_length_expected) 130 if not buf: 131 return None # EOF Reached. 132 133 # Validate all length related payloads. 134 if len(buf) != buf_length_expected: 135 raise ValueError( 136 'Not a valid TFRecord. Fewer than %d bytes: %s' % 137 (buf_length_expected, codecs.encode(buf, 'hex'))) 138 length, length_mask_expected = struct.unpack('<QI', buf) 139 length_mask_actual = cls._masked_crc32c(buf[:8]) 140 if length_mask_actual != length_mask_expected: 141 raise ValueError( 142 'Not a valid TFRecord. Mismatch of length mask: %s' % 143 codecs.encode(buf, 'hex')) 144 145 # Validate all data related payloads. 146 buf_length_expected = length + 4 147 buf = file_handle.read(buf_length_expected) 148 if len(buf) != buf_length_expected: 149 raise ValueError( 150 'Not a valid TFRecord. Fewer than %d bytes: %s' % 151 (buf_length_expected, codecs.encode(buf, 'hex'))) 152 data, data_mask_expected = struct.unpack('<%dsI' % length, buf) 153 data_mask_actual = cls._masked_crc32c(data) 154 if data_mask_actual != data_mask_expected: 155 raise ValueError( 156 'Not a valid TFRecord. Mismatch of data mask: %s' % 157 codecs.encode(buf, 'hex')) 158 159 # All validation checks passed. 160 return data 161 162 163 class _TFRecordSource(FileBasedSource): 164 """A File source for reading files of TFRecords. 165 166 For detailed TFRecords format description see: 167 https://www.tensorflow.org/versions/r1.11/api_guides/python/python_io#TFRecords_Format_Details 168 """ 169 def __init__(self, file_pattern, coder, compression_type, validate): 170 """Initialize a TFRecordSource. See ReadFromTFRecord for details.""" 171 super().__init__( 172 file_pattern=file_pattern, 173 compression_type=compression_type, 174 splittable=False, 175 validate=validate) 176 self._coder = coder 177 178 def read_records(self, file_name, offset_range_tracker): 179 if offset_range_tracker.start_position(): 180 raise ValueError( 181 'Start position not 0:%s' % offset_range_tracker.start_position()) 182 183 current_offset = offset_range_tracker.start_position() 184 with self.open_file(file_name) as file_handle: 185 while True: 186 if not offset_range_tracker.try_claim(current_offset): 187 raise RuntimeError('Unable to claim position: %s' % current_offset) 188 record = _TFRecordUtil.read_record(file_handle) 189 if record is None: 190 return # Reached EOF 191 else: 192 current_offset += _TFRecordUtil.encoded_num_bytes(record) 193 yield self._coder.decode(record) 194 195 196 def _create_tfrecordio_source( 197 file_pattern=None, coder=None, compression_type=None): 198 # We intentionally disable validation for ReadAll pattern so that reading does 199 # not fail for globs (elements) that are empty. 200 return _TFRecordSource(file_pattern, coder, compression_type, validate=False) 201 202 203 class ReadAllFromTFRecord(PTransform): 204 """A ``PTransform`` for reading a ``PCollection`` of TFRecord files.""" 205 def __init__( 206 self, 207 coder=coders.BytesCoder(), 208 compression_type=CompressionTypes.AUTO, 209 with_filename=False): 210 """Initialize the ``ReadAllFromTFRecord`` transform. 211 212 Args: 213 coder: Coder used to decode each record. 214 compression_type: Used to handle compressed input files. Default value 215 is CompressionTypes.AUTO, in which case the file_path's extension will 216 be used to detect the compression. 217 with_filename: If True, returns a Key Value with the key being the file 218 name and the value being the actual data. If False, it only returns 219 the data. 220 """ 221 super().__init__() 222 source_from_file = partial( 223 _create_tfrecordio_source, 224 compression_type=compression_type, 225 coder=coder) 226 # Desired and min bundle sizes do not matter since TFRecord files are 227 # unsplittable. 228 self._read_all_files = ReadAllFiles( 229 splittable=False, 230 compression_type=compression_type, 231 desired_bundle_size=0, 232 min_bundle_size=0, 233 source_from_file=source_from_file, 234 with_filename=with_filename) 235 236 def expand(self, pvalue): 237 return pvalue | 'ReadAllFiles' >> self._read_all_files 238 239 240 class ReadFromTFRecord(PTransform): 241 """Transform for reading TFRecord sources.""" 242 def __init__( 243 self, 244 file_pattern, 245 coder=coders.BytesCoder(), 246 compression_type=CompressionTypes.AUTO, 247 validate=True): 248 """Initialize a ReadFromTFRecord transform. 249 250 Args: 251 file_pattern: A file glob pattern to read TFRecords from. 252 coder: Coder used to decode each record. 253 compression_type: Used to handle compressed input files. Default value 254 is CompressionTypes.AUTO, in which case the file_path's extension will 255 be used to detect the compression. 256 validate: Boolean flag to verify that the files exist during the pipeline 257 creation time. 258 259 Returns: 260 A ReadFromTFRecord transform object. 261 """ 262 super().__init__() 263 self._source = _TFRecordSource( 264 file_pattern, coder, compression_type, validate) 265 266 def expand(self, pvalue): 267 return pvalue.pipeline | Read(self._source) 268 269 270 class _TFRecordSink(filebasedsink.FileBasedSink): 271 """Sink for writing TFRecords files. 272 273 For detailed TFRecord format description see: 274 https://www.tensorflow.org/versions/r1.11/api_guides/python/python_io#TFRecords_Format_Details 275 """ 276 def __init__( 277 self, 278 file_path_prefix, 279 coder, 280 file_name_suffix, 281 num_shards, 282 shard_name_template, 283 compression_type): 284 """Initialize a TFRecordSink. See WriteToTFRecord for details.""" 285 286 super().__init__( 287 file_path_prefix=file_path_prefix, 288 coder=coder, 289 file_name_suffix=file_name_suffix, 290 num_shards=num_shards, 291 shard_name_template=shard_name_template, 292 mime_type='application/octet-stream', 293 compression_type=compression_type) 294 295 def write_encoded_record(self, file_handle, value): 296 _TFRecordUtil.write_record(file_handle, value) 297 298 299 class WriteToTFRecord(PTransform): 300 """Transform for writing to TFRecord sinks.""" 301 def __init__( 302 self, 303 file_path_prefix, 304 coder=coders.BytesCoder(), 305 file_name_suffix='', 306 num_shards=0, 307 shard_name_template=None, 308 compression_type=CompressionTypes.AUTO): 309 """Initialize WriteToTFRecord transform. 310 311 Args: 312 file_path_prefix: The file path to write to. The files written will begin 313 with this prefix, followed by a shard identifier (see num_shards), and 314 end in a common extension, if given by file_name_suffix. 315 coder: Coder used to encode each record. 316 file_name_suffix: Suffix for the files written. 317 num_shards: The number of files (shards) used for output. If not set, the 318 default value will be used. 319 shard_name_template: A template string containing placeholders for 320 the shard number and shard count. When constructing a filename for a 321 particular shard number, the upper-case letters 'S' and 'N' are 322 replaced with the 0-padded shard number and shard count respectively. 323 This argument can be '' in which case it behaves as if num_shards was 324 set to 1 and only one file will be generated. The default pattern used 325 is '-SSSSS-of-NNNNN' if None is passed as the shard_name_template. 326 compression_type: Used to handle compressed output files. Typical value 327 is CompressionTypes.AUTO, in which case the file_path's extension will 328 be used to detect the compression. 329 330 Returns: 331 A WriteToTFRecord transform object. 332 """ 333 super().__init__() 334 self._sink = _TFRecordSink( 335 file_path_prefix, 336 coder, 337 file_name_suffix, 338 num_shards, 339 shard_name_template, 340 compression_type) 341 342 def expand(self, pcoll): 343 return pcoll | Write(self._sink)