github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/filebasedsink.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""File-based sink."""

# pytype: skip-file

import logging
import os
import re
import time
import uuid

from apache_beam.internal import util
from apache_beam.io import iobase
from apache_beam.io.filesystem import BeamIOError
from apache_beam.io.filesystem import CompressionTypes
from apache_beam.io.filesystems import FileSystems
from apache_beam.options.value_provider import StaticValueProvider
from apache_beam.options.value_provider import ValueProvider
from apache_beam.options.value_provider import check_accessible
from apache_beam.transforms.display import DisplayDataItem

DEFAULT_SHARD_NAME_TEMPLATE = '-SSSSS-of-NNNNN'

__all__ = ['FileBasedSink']

_LOGGER = logging.getLogger(__name__)


class FileBasedSink(iobase.Sink):
  """A sink to GCS or local files.

  To implement a file-based sink, extend this class and override
  either :meth:`.write_record()` or :meth:`.write_encoded_record()`.

  If needed, also override :meth:`.open()` and/or :meth:`.close()` to customize
  the file handling or write headers and footers.

  The output of this write is a :class:`~apache_beam.pvalue.PCollection` of
  all written shards.
  """

  # Max number of threads to be used for renaming.
  _MAX_RENAME_THREADS = 64
  __hash__ = None  # type: ignore[assignment]

  def __init__(
      self,
      file_path_prefix,
      coder,
      file_name_suffix='',
      num_shards=0,
      shard_name_template=None,
      mime_type='application/octet-stream',
      compression_type=CompressionTypes.AUTO,
      *,
      max_records_per_shard=None,
      max_bytes_per_shard=None,
      skip_if_empty=False):
    """
    Raises:
      TypeError: if file path parameters are not a :class:`str` or
        :class:`~apache_beam.options.value_provider.ValueProvider`, or if
        **compression_type** is not a member of
        :class:`~apache_beam.io.filesystem.CompressionTypes`.
      ValueError: if **shard_name_template** is not of expected format.
83 """ 84 if not isinstance(file_path_prefix, (str, ValueProvider)): 85 raise TypeError( 86 'file_path_prefix must be a string or ValueProvider;' 87 'got %r instead' % file_path_prefix) 88 if not isinstance(file_name_suffix, (str, ValueProvider)): 89 raise TypeError( 90 'file_name_suffix must be a string or ValueProvider;' 91 'got %r instead' % file_name_suffix) 92 93 if not CompressionTypes.is_valid_compression_type(compression_type): 94 raise TypeError( 95 'compression_type must be CompressionType object but ' 96 'was %s' % type(compression_type)) 97 if shard_name_template is None: 98 shard_name_template = DEFAULT_SHARD_NAME_TEMPLATE 99 elif shard_name_template == '': 100 num_shards = 1 101 if isinstance(file_path_prefix, str): 102 file_path_prefix = StaticValueProvider(str, file_path_prefix) 103 if isinstance(file_name_suffix, str): 104 file_name_suffix = StaticValueProvider(str, file_name_suffix) 105 self.file_path_prefix = file_path_prefix 106 self.file_name_suffix = file_name_suffix 107 self.num_shards = num_shards 108 self.coder = coder 109 self.shard_name_format = self._template_to_format(shard_name_template) 110 self.shard_name_glob_format = self._template_to_glob_format( 111 shard_name_template) 112 self.compression_type = compression_type 113 self.mime_type = mime_type 114 self.max_records_per_shard = max_records_per_shard 115 self.max_bytes_per_shard = max_bytes_per_shard 116 self.skip_if_empty = skip_if_empty 117 118 def display_data(self): 119 return { 120 'shards': DisplayDataItem(self.num_shards, 121 label='Number of Shards').drop_if_default(0), 122 'compression': DisplayDataItem(str(self.compression_type)), 123 'file_pattern': DisplayDataItem( 124 '{}{}{}'.format( 125 self.file_path_prefix, 126 self.shard_name_format, 127 self.file_name_suffix), 128 label='File Pattern') 129 } 130 131 @check_accessible(['file_path_prefix']) 132 def open(self, temp_path): 133 """Opens ``temp_path``, returning an opaque file handle object. 134 135 The returned file handle is passed to ``write_[encoded_]record`` and 136 ``close``. 137 """ 138 writer = FileSystems.create( 139 temp_path, self.mime_type, self.compression_type) 140 if self.max_bytes_per_shard: 141 self.byte_counter = _ByteCountingWriter(writer) 142 return self.byte_counter 143 else: 144 return writer 145 146 def write_record(self, file_handle, value): 147 """Writes a single record go the file handle returned by ``open()``. 148 149 By default, calls ``write_encoded_record`` after encoding the record with 150 this sink's Coder. 151 """ 152 self.write_encoded_record(file_handle, self.coder.encode(value)) 153 154 def write_encoded_record(self, file_handle, encoded_value): 155 """Writes a single encoded record to the file handle returned by ``open()``. 156 """ 157 raise NotImplementedError 158 159 def close(self, file_handle): 160 """Finalize and close the file handle returned from ``open()``. 161 162 Called after all records are written. 163 164 By default, calls ``file_handle.close()`` iff it is not None. 165 """ 166 if file_handle is not None: 167 file_handle.close() 168 169 @check_accessible(['file_path_prefix', 'file_name_suffix']) 170 def initialize_write(self): 171 file_path_prefix = self.file_path_prefix.get() 172 173 tmp_dir = self._create_temp_dir(file_path_prefix) 174 FileSystems.mkdirs(tmp_dir) 175 return tmp_dir 176 177 def _create_temp_dir(self, file_path_prefix): 178 base_path, last_component = FileSystems.split(file_path_prefix) 179 if not last_component: 180 # Trying to re-split the base_path to check if it's a root. 
  def _create_temp_dir(self, file_path_prefix):
    base_path, last_component = FileSystems.split(file_path_prefix)
    if not last_component:
      # Trying to re-split the base_path to check if it's a root.
      new_base_path, _ = FileSystems.split(base_path)
      if base_path == new_base_path:
        raise ValueError(
            'Cannot create a temporary directory for root path '
            'prefix %s. Please specify a file path prefix with '
            'at least two components.' % file_path_prefix)
    path_components = [
        base_path, 'beam-temp-' + last_component + '-' + uuid.uuid1().hex
    ]
    return FileSystems.join(*path_components)

  @check_accessible(['file_path_prefix', 'file_name_suffix'])
  def open_writer(self, init_result, uid):
    # A proper suffix is needed for AUTO compression detection.
    # We also ensure there will be no collisions with uid and a
    # (possibly unsharded) file_path_prefix and a (possibly empty)
    # file_name_suffix.
    file_path_prefix = self.file_path_prefix.get()
    file_name_suffix = self.file_name_suffix.get()
    suffix = ('.' + os.path.basename(file_path_prefix) + file_name_suffix)
    writer_path = FileSystems.join(init_result, uid) + suffix
    return FileBasedSinkWriter(self, writer_path)

  @check_accessible(['file_path_prefix', 'file_name_suffix'])
  def _get_final_name(self, shard_num, num_shards):
    return ''.join([
        self.file_path_prefix.get(),
        self.shard_name_format %
        dict(shard_num=shard_num, num_shards=num_shards),
        self.file_name_suffix.get()
    ])

  @check_accessible(['file_path_prefix', 'file_name_suffix'])
  def _get_final_name_glob(self, num_shards):
    return ''.join([
        self.file_path_prefix.get(),
        self.shard_name_glob_format % dict(num_shards=num_shards),
        self.file_name_suffix.get()
    ])
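
  # For example, with the default shard name template '-SSSSS-of-NNNNN', a sink
  # created with file_path_prefix='out' and file_name_suffix='.txt' names shard
  # 3 of 25 'out-00003-of-00025.txt', and _get_final_name_glob(25) yields the
  # glob 'out-*-of-00025.txt', used below to locate already-written output.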
  def pre_finalize(self, init_result, writer_results):
    num_shards = len(list(writer_results))
    dst_glob = self._get_final_name_glob(num_shards)
    dst_glob_files = [
        file_metadata.path for mr in FileSystems.match([dst_glob])
        for file_metadata in mr.metadata_list
    ]

    if dst_glob_files:
      _LOGGER.warning(
          'Deleting %d existing files in target path matching: %s',
          len(dst_glob_files),
          self.shard_name_glob_format)
      FileSystems.delete(dst_glob_files)

  def _check_state_for_finalize_write(self, writer_results, num_shards):
    """Checks writer output files' states.

    Returns:
      src_files, dst_files: Lists of files to rename. For each i, finalize_write
        should rename(src_files[i], dst_files[i]).
      delete_files: Src files to delete. These could be leftovers from an
        incomplete (non-atomic) rename operation.
      num_skipped: Tally of writer results files already renamed, such as from
        a previous run of finalize_write().
    """
    if not writer_results:
      return [], [], [], 0

    src_glob = FileSystems.join(FileSystems.split(writer_results[0])[0], '*')
    dst_glob = self._get_final_name_glob(num_shards)
    src_glob_files = set(
        file_metadata.path for mr in FileSystems.match([src_glob])
        for file_metadata in mr.metadata_list)
    dst_glob_files = set(
        file_metadata.path for mr in FileSystems.match([dst_glob])
        for file_metadata in mr.metadata_list)

    src_files = []
    dst_files = []
    delete_files = []
    num_skipped = 0
    for shard_num, src in enumerate(writer_results):
      final_name = self._get_final_name(shard_num, num_shards)
      dst = final_name
      src_exists = src in src_glob_files
      dst_exists = dst in dst_glob_files
      if not src_exists and not dst_exists:
        raise BeamIOError(
            'src and dst files do not exist. src: %s, dst: %s' % (src, dst))
      if not src_exists and dst_exists:
        _LOGGER.debug('src: %s -> dst: %s already renamed, skipping', src, dst)
        num_skipped += 1
        continue
      if (src_exists and dst_exists and
          FileSystems.checksum(src) == FileSystems.checksum(dst)):
        _LOGGER.debug('src: %s == dst: %s, deleting src', src, dst)
        delete_files.append(src)
        continue

      src_files.append(src)
      dst_files.append(dst)
    return src_files, dst_files, delete_files, num_skipped

  @check_accessible(['file_path_prefix'])
  def finalize_write(
      self, init_result, writer_results, unused_pre_finalize_results):
    writer_results = sorted(writer_results)
    num_shards = len(writer_results)

    src_files, dst_files, delete_files, num_skipped = (
        self._check_state_for_finalize_write(writer_results, num_shards))
    num_skipped += len(delete_files)
    FileSystems.delete(delete_files)
    num_shards_to_finalize = len(src_files)
    min_threads = min(num_shards_to_finalize, FileBasedSink._MAX_RENAME_THREADS)
    num_threads = max(1, min_threads)

    chunk_size = FileSystems.get_chunk_size(self.file_path_prefix.get())
    source_file_batch = [
        src_files[i:i + chunk_size]
        for i in range(0, len(src_files), chunk_size)
    ]
    destination_file_batch = [
        dst_files[i:i + chunk_size]
        for i in range(0, len(dst_files), chunk_size)
    ]

    if num_shards_to_finalize:
      _LOGGER.info(
          'Starting finalize_write threads with num_shards: %d (skipped: %d), '
          'batches: %d, num_threads: %d',
          num_shards_to_finalize,
          num_skipped,
          len(source_file_batch),
          num_threads)
      start_time = time.time()

      # Use a thread pool for renaming operations.
      def _rename_batch(batch):
        """_rename_batch executes batch rename operations."""
        source_files, destination_files = batch
        exceptions = []
        try:
          FileSystems.rename(source_files, destination_files)
          return exceptions
        except BeamIOError as exp:
          if exp.exception_details is None:
            raise
          for (src, dst), exception in exp.exception_details.items():
            if exception:
              _LOGGER.error(
                  ('Exception in _rename_batch. src: %s, '
                   'dst: %s, err: %s'),
                  src,
                  dst,
                  exception)
              exceptions.append(exception)
            else:
              _LOGGER.debug('Rename successful: %s -> %s', src, dst)
          return exceptions

      exception_batches = util.run_using_threadpool(
          _rename_batch,
          list(zip(source_file_batch, destination_file_batch)),
          num_threads)

      all_exceptions = [
          e for exception_batch in exception_batches for e in exception_batch
      ]
      if all_exceptions:
        raise Exception(
            'Encountered exceptions in finalize_write: %s' % all_exceptions)

      yield from dst_files

      _LOGGER.info(
          'Renamed %d shards in %.2f seconds.',
          num_shards_to_finalize,
          time.time() - start_time)
    else:
      _LOGGER.warning(
          'No shards found to finalize. num_shards: %d, skipped: %d',
          num_shards,
          num_skipped)

    try:
      FileSystems.delete([init_result])
    except IOError:
      # This error is not serious, we simply log it.
      _LOGGER.info('Unable to delete file: %s', init_result)

  @staticmethod
  def _template_replace_num_shards(shard_name_template):
    match = re.search('N+', shard_name_template)
    if match:
      shard_name_template = shard_name_template.replace(
          match.group(0), '%%(num_shards)0%dd' % len(match.group(0)))
    return shard_name_template

  @staticmethod
  def _template_to_format(shard_name_template):
    if not shard_name_template:
      return ''
    match = re.search('S+', shard_name_template)
    if match is None:
      raise ValueError(
          "Shard number pattern S+ not found in shard_name_template: %s" %
          shard_name_template)
    shard_name_format = shard_name_template.replace(
        match.group(0), '%%(shard_num)0%dd' % len(match.group(0)))
    return FileBasedSink._template_replace_num_shards(shard_name_format)

  @staticmethod
  def _template_to_glob_format(shard_name_template):
    if not shard_name_template:
      return ''
    match = re.search('S+', shard_name_template)
    if match is None:
      raise ValueError(
          "Shard number pattern S+ not found in shard_name_template: %s" %
          shard_name_template)
    shard_name_format = shard_name_template.replace(match.group(0), '*')
    return FileBasedSink._template_replace_num_shards(shard_name_format)

  def __eq__(self, other):
    # TODO: Clean up workitem_test which uses this.
    # pylint: disable=unidiomatic-typecheck
    return type(self) == type(other) and self.__dict__ == other.__dict__


class FileBasedSinkWriter(iobase.Writer):
  """The writer for FileBasedSink."""
  def __init__(self, sink, temp_shard_path):
    self.sink = sink
    self.temp_shard_path = temp_shard_path
    self.temp_handle = self.sink.open(temp_shard_path)
    self.num_records_written = 0

  def write(self, value):
    self.num_records_written += 1
    self.sink.write_record(self.temp_handle, value)

  def at_capacity(self):
    return (
        self.sink.max_records_per_shard and
        self.num_records_written >= self.sink.max_records_per_shard
    ) or (
        self.sink.max_bytes_per_shard and
        self.sink.byte_counter.bytes_written >= self.sink.max_bytes_per_shard)

  def close(self):
    self.sink.close(self.temp_handle)
    return self.temp_shard_path


class _ByteCountingWriter:
  """Wraps a write handle, counting the bytes written through it."""
  def __init__(self, writer):
    self.writer = writer
    self.bytes_written = 0

  def write(self, bs):
    self.bytes_written += len(bs)
    self.writer.write(bs)

  def flush(self):
    self.writer.flush()

  def close(self):
    self.writer.close()
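

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): a minimal concrete
# sink built on FileBasedSink. The class name and coder choice below are
# hypothetical; they only demonstrate the override points described in the
# FileBasedSink docstring, writing one newline-terminated encoded record at a
# time, roughly as apache_beam.io.textio does.
from apache_beam.coders import coders as _coders


class _ExampleNewlineSink(FileBasedSink):
  """Example sink writing each encoded record on its own line."""
  def __init__(self, file_path_prefix, file_name_suffix='.txt'):
    super().__init__(
        file_path_prefix,
        coder=_coders.ToBytesCoder(),
        file_name_suffix=file_name_suffix,
        mime_type='text/plain')

  def write_encoded_record(self, file_handle, encoded_value):
    # file_handle is whatever open() returned, i.e. a FileSystems stream.
    file_handle.write(encoded_value)
    file_handle.write(b'\n')


# A pipeline would typically apply such a sink through the generic Write
# transform, e.g. ``pcoll | iobase.Write(_ExampleNewlineSink('/tmp/out'))``.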