github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/filebasedsink.py

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """File-based sink."""
    19  
    20  # pytype: skip-file
    21  
    22  import logging
    23  import os
    24  import re
    25  import time
    26  import uuid
    27  
    28  from apache_beam.internal import util
    29  from apache_beam.io import iobase
    30  from apache_beam.io.filesystem import BeamIOError
    31  from apache_beam.io.filesystem import CompressionTypes
    32  from apache_beam.io.filesystems import FileSystems
    33  from apache_beam.options.value_provider import StaticValueProvider
    34  from apache_beam.options.value_provider import ValueProvider
    35  from apache_beam.options.value_provider import check_accessible
    36  from apache_beam.transforms.display import DisplayDataItem
    37  
    38  DEFAULT_SHARD_NAME_TEMPLATE = '-SSSSS-of-NNNNN'
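         # For example, _template_to_format below expands this default template to
         # '-%(shard_num)05d-of-%(num_shards)05d', so with num_shards=2 the shard
         # name fragments are '-00000-of-00002' and '-00001-of-00002'.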
    39  
    40  __all__ = ['FileBasedSink']
    41  
    42  _LOGGER = logging.getLogger(__name__)
    43  
    44  
    45  class FileBasedSink(iobase.Sink):
     46    """A sink that writes to GCS or local files.
    47  
    48    To implement a file-based sink, extend this class and override
    49    either :meth:`.write_record()` or :meth:`.write_encoded_record()`.
    50  
     51    If needed, also override :meth:`.open()` and/or :meth:`.close()` to customize
    52    the file handling or write headers and footers.
    53  
    54    The output of this write is a :class:`~apache_beam.pvalue.PCollection` of
    55    all written shards.
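
           For illustration, a minimal subclass might look like the following
           sketch (the name ``_SimpleTextSink`` is hypothetical, not part of this
           module)::

             class _SimpleTextSink(FileBasedSink):
               def write_encoded_record(self, file_handle, encoded_value):
                 # Write each encoded record as-is; a real sink may add framing.
                 file_handle.write(encoded_value)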
    56    """
    57  
    58    # Max number of threads to be used for renaming.
    59    _MAX_RENAME_THREADS = 64
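           # Instances are explicitly unhashable: __eq__ is overridden below and no
           # consistent __hash__ is provided.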
    60    __hash__ = None  # type: ignore[assignment]
    61  
    62    def __init__(
    63        self,
    64        file_path_prefix,
    65        coder,
    66        file_name_suffix='',
    67        num_shards=0,
    68        shard_name_template=None,
    69        mime_type='application/octet-stream',
    70        compression_type=CompressionTypes.AUTO,
    71        *,
    72        max_records_per_shard=None,
    73        max_bytes_per_shard=None,
    74        skip_if_empty=False):
    75      """
     76      Raises:
    77        TypeError: if file path parameters are not a :class:`str` or
    78          :class:`~apache_beam.options.value_provider.ValueProvider`, or if
    79          **compression_type** is not member of
    80          :class:`~apache_beam.io.filesystem.CompressionTypes`.
    81        ValueError: if **shard_name_template** is not of expected
    82          format.
    83      """
    84      if not isinstance(file_path_prefix, (str, ValueProvider)):
    85        raise TypeError(
     86            'file_path_prefix must be a string or ValueProvider; '
    87            'got %r instead' % file_path_prefix)
    88      if not isinstance(file_name_suffix, (str, ValueProvider)):
    89        raise TypeError(
     90            'file_name_suffix must be a string or ValueProvider; '
    91            'got %r instead' % file_name_suffix)
    92  
    93      if not CompressionTypes.is_valid_compression_type(compression_type):
    94        raise TypeError(
     95            'compression_type must be a CompressionTypes object but '
    96            'was %s' % type(compression_type))
    97      if shard_name_template is None:
    98        shard_name_template = DEFAULT_SHARD_NAME_TEMPLATE
    99      elif shard_name_template == '':
   100        num_shards = 1
   101      if isinstance(file_path_prefix, str):
   102        file_path_prefix = StaticValueProvider(str, file_path_prefix)
   103      if isinstance(file_name_suffix, str):
   104        file_name_suffix = StaticValueProvider(str, file_name_suffix)
   105      self.file_path_prefix = file_path_prefix
   106      self.file_name_suffix = file_name_suffix
   107      self.num_shards = num_shards
   108      self.coder = coder
   109      self.shard_name_format = self._template_to_format(shard_name_template)
   110      self.shard_name_glob_format = self._template_to_glob_format(
   111          shard_name_template)
   112      self.compression_type = compression_type
   113      self.mime_type = mime_type
   114      self.max_records_per_shard = max_records_per_shard
   115      self.max_bytes_per_shard = max_bytes_per_shard
   116      self.skip_if_empty = skip_if_empty
   117  
   118    def display_data(self):
   119      return {
   120          'shards': DisplayDataItem(self.num_shards,
   121                                    label='Number of Shards').drop_if_default(0),
   122          'compression': DisplayDataItem(str(self.compression_type)),
   123          'file_pattern': DisplayDataItem(
   124              '{}{}{}'.format(
   125                  self.file_path_prefix,
   126                  self.shard_name_format,
   127                  self.file_name_suffix),
   128              label='File Pattern')
   129      }
   130  
   131    @check_accessible(['file_path_prefix'])
   132    def open(self, temp_path):
   133      """Opens ``temp_path``, returning an opaque file handle object.
   134  
   135      The returned file handle is passed to ``write_[encoded_]record`` and
   136      ``close``.
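
             A subclass that needs a file header could, for example, wrap this
             method as in the following sketch (the header bytes are illustrative)::

               def open(self, temp_path):
                 file_handle = super().open(temp_path)
                 file_handle.write(b'my-header')
                 return file_handle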
   137      """
   138      writer = FileSystems.create(
   139          temp_path, self.mime_type, self.compression_type)
   140      if self.max_bytes_per_shard:
   141        self.byte_counter = _ByteCountingWriter(writer)
   142        return self.byte_counter
   143      else:
   144        return writer
   145  
   146    def write_record(self, file_handle, value):
    147      """Writes a single record to the file handle returned by ``open()``.
   148  
   149      By default, calls ``write_encoded_record`` after encoding the record with
   150      this sink's Coder.
   151      """
   152      self.write_encoded_record(file_handle, self.coder.encode(value))
   153  
   154    def write_encoded_record(self, file_handle, encoded_value):
   155      """Writes a single encoded record to the file handle returned by ``open()``.
   156      """
   157      raise NotImplementedError
   158  
   159    def close(self, file_handle):
   160      """Finalize and close the file handle returned from ``open()``.
   161  
   162      Called after all records are written.
   163  
   164      By default, calls ``file_handle.close()`` iff it is not None.
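
             Similarly, a subclass that needs a footer could write it here before
             delegating to this implementation (a sketch; the footer bytes are
             illustrative)::

               def close(self, file_handle):
                 if file_handle is not None:
                   file_handle.write(b'my-footer')
                 super().close(file_handle)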
   165      """
   166      if file_handle is not None:
   167        file_handle.close()
   168  
   169    @check_accessible(['file_path_prefix', 'file_name_suffix'])
   170    def initialize_write(self):
   171      file_path_prefix = self.file_path_prefix.get()
   172  
   173      tmp_dir = self._create_temp_dir(file_path_prefix)
   174      FileSystems.mkdirs(tmp_dir)
   175      return tmp_dir
   176  
   177    def _create_temp_dir(self, file_path_prefix):
   178      base_path, last_component = FileSystems.split(file_path_prefix)
   179      if not last_component:
    180        # Re-split the base_path to check whether it is a filesystem root.
   181        new_base_path, _ = FileSystems.split(base_path)
   182        if base_path == new_base_path:
   183          raise ValueError(
   184              'Cannot create a temporary directory for root path '
   185              'prefix %s. Please specify a file path prefix with '
   186              'at least two components.' % file_path_prefix)
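             # For example, a prefix like '/data/out' (illustrative) yields a
             # temporary directory such as '/data/beam-temp-out-<uuid1 hex>'.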
   187      path_components = [
   188          base_path, 'beam-temp-' + last_component + '-' + uuid.uuid1().hex
   189      ]
   190      return FileSystems.join(*path_components)
   191  
   192    @check_accessible(['file_path_prefix', 'file_name_suffix'])
   193    def open_writer(self, init_result, uid):
   194      # A proper suffix is needed for AUTO compression detection.
    195      # We also use the uid together with the (possibly unsharded)
    196      # file_path_prefix and the (possibly empty) file_name_suffix to
    197      # ensure there will be no collisions.
   198      file_path_prefix = self.file_path_prefix.get()
   199      file_name_suffix = self.file_name_suffix.get()
   200      suffix = ('.' + os.path.basename(file_path_prefix) + file_name_suffix)
   201      writer_path = FileSystems.join(init_result, uid) + suffix
   202      return FileBasedSinkWriter(self, writer_path)
   203  
   204    @check_accessible(['file_path_prefix', 'file_name_suffix'])
   205    def _get_final_name(self, shard_num, num_shards):
   206      return ''.join([
   207          self.file_path_prefix.get(),
   208          self.shard_name_format %
   209          dict(shard_num=shard_num, num_shards=num_shards),
   210          self.file_name_suffix.get()
   211      ])
   212  
   213    @check_accessible(['file_path_prefix', 'file_name_suffix'])
   214    def _get_final_name_glob(self, num_shards):
   215      return ''.join([
   216          self.file_path_prefix.get(),
   217          self.shard_name_glob_format % dict(num_shards=num_shards),
   218          self.file_name_suffix.get()
   219      ])
   220  
   221    def pre_finalize(self, init_result, writer_results):
   222      num_shards = len(list(writer_results))
   223      dst_glob = self._get_final_name_glob(num_shards)
   224      dst_glob_files = [
   225          file_metadata.path for mr in FileSystems.match([dst_glob])
   226          for file_metadata in mr.metadata_list
   227      ]
   228  
   229      if dst_glob_files:
   230        _LOGGER.warning(
   231            'Deleting %d existing files in target path matching: %s',
   232            len(dst_glob_files),
    233            dst_glob)
   234        FileSystems.delete(dst_glob_files)
   235  
   236    def _check_state_for_finalize_write(self, writer_results, num_shards):
   237      """Checks writer output files' states.
   238  
   239      Returns:
   240        src_files, dst_files: Lists of files to rename. For each i, finalize_write
   241          should rename(src_files[i], dst_files[i]).
   242        delete_files: Src files to delete. These could be leftovers from an
   243          incomplete (non-atomic) rename operation.
    244        num_skipped: Tally of writer result files already renamed, such as from
   245          a previous run of finalize_write().
   246      """
   247      if not writer_results:
   248        return [], [], [], 0
   249  
   250      src_glob = FileSystems.join(FileSystems.split(writer_results[0])[0], '*')
   251      dst_glob = self._get_final_name_glob(num_shards)
   252      src_glob_files = set(
   253          file_metadata.path for mr in FileSystems.match([src_glob])
   254          for file_metadata in mr.metadata_list)
   255      dst_glob_files = set(
   256          file_metadata.path for mr in FileSystems.match([dst_glob])
   257          for file_metadata in mr.metadata_list)
   258  
   259      src_files = []
   260      dst_files = []
   261      delete_files = []
   262      num_skipped = 0
   263      for shard_num, src in enumerate(writer_results):
   264        final_name = self._get_final_name(shard_num, num_shards)
   265        dst = final_name
   266        src_exists = src in src_glob_files
   267        dst_exists = dst in dst_glob_files
   268        if not src_exists and not dst_exists:
   269          raise BeamIOError(
   270              'src and dst files do not exist. src: %s, dst: %s' % (src, dst))
   271        if not src_exists and dst_exists:
   272          _LOGGER.debug('src: %s -> dst: %s already renamed, skipping', src, dst)
   273          num_skipped += 1
   274          continue
   275        if (src_exists and dst_exists and
   276            FileSystems.checksum(src) == FileSystems.checksum(dst)):
   277          _LOGGER.debug('src: %s == dst: %s, deleting src', src, dst)
   278          delete_files.append(src)
   279          continue
   280  
   281        src_files.append(src)
   282        dst_files.append(dst)
   283      return src_files, dst_files, delete_files, num_skipped
   284  
   285    @check_accessible(['file_path_prefix'])
   286    def finalize_write(
   287        self, init_result, writer_results, unused_pre_finalize_results):
   288      writer_results = sorted(writer_results)
   289      num_shards = len(writer_results)
   290  
   291      src_files, dst_files, delete_files, num_skipped = (
   292          self._check_state_for_finalize_write(writer_results, num_shards))
   293      num_skipped += len(delete_files)
   294      FileSystems.delete(delete_files)
   295      num_shards_to_finalize = len(src_files)
   296      min_threads = min(num_shards_to_finalize, FileBasedSink._MAX_RENAME_THREADS)
   297      num_threads = max(1, min_threads)
   298  
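             # Group the renames into batches of at most chunk_size files; the
             # batches are then processed in parallel by the thread pool below.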
   299      chunk_size = FileSystems.get_chunk_size(self.file_path_prefix.get())
   300      source_file_batch = [
   301          src_files[i:i + chunk_size]
   302          for i in range(0, len(src_files), chunk_size)
   303      ]
   304      destination_file_batch = [
   305          dst_files[i:i + chunk_size]
   306          for i in range(0, len(dst_files), chunk_size)
   307      ]
   308  
   309      if num_shards_to_finalize:
   310        _LOGGER.info(
   311            'Starting finalize_write threads with num_shards: %d (skipped: %d), '
   312            'batches: %d, num_threads: %d',
   313            num_shards_to_finalize,
   314            num_skipped,
   315            len(source_file_batch),
   316            num_threads)
   317        start_time = time.time()
   318  
   319        # Use a thread pool for renaming operations.
   320        def _rename_batch(batch):
   321          """_rename_batch executes batch rename operations."""
   322          source_files, destination_files = batch
   323          exceptions = []
   324          try:
   325            FileSystems.rename(source_files, destination_files)
   326            return exceptions
   327          except BeamIOError as exp:
   328            if exp.exception_details is None:
   329              raise
   330            for (src, dst), exception in exp.exception_details.items():
   331              if exception:
   332                _LOGGER.error(
   333                    ('Exception in _rename_batch. src: %s, '
   334                     'dst: %s, err: %s'),
   335                    src,
   336                    dst,
   337                    exception)
   338                exceptions.append(exception)
   339              else:
   340                _LOGGER.debug('Rename successful: %s -> %s', src, dst)
   341            return exceptions
   342  
   343        exception_batches = util.run_using_threadpool(
   344            _rename_batch,
   345            list(zip(source_file_batch, destination_file_batch)),
   346            num_threads)
   347  
   348        all_exceptions = [
   349            e for exception_batch in exception_batches for e in exception_batch
   350        ]
   351        if all_exceptions:
   352          raise Exception(
   353              'Encountered exceptions in finalize_write: %s' % all_exceptions)
   354  
   355        yield from dst_files
   356  
   357        _LOGGER.info(
   358            'Renamed %d shards in %.2f seconds.',
   359            num_shards_to_finalize,
   360            time.time() - start_time)
   361      else:
   362        _LOGGER.warning(
   363            'No shards found to finalize. num_shards: %d, skipped: %d',
   364            num_shards,
   365            num_skipped)
   366  
   367      try:
   368        FileSystems.delete([init_result])
   369      except IOError:
    370      # This error is not serious; we simply log it.
   371        _LOGGER.info('Unable to delete file: %s', init_result)
   372  
   373    @staticmethod
   374    def _template_replace_num_shards(shard_name_template):
   375      match = re.search('N+', shard_name_template)
   376      if match:
   377        shard_name_template = shard_name_template.replace(
   378            match.group(0), '%%(num_shards)0%dd' % len(match.group(0)))
   379      return shard_name_template
   380  
   381    @staticmethod
   382    def _template_to_format(shard_name_template):
   383      if not shard_name_template:
   384        return ''
   385      match = re.search('S+', shard_name_template)
   386      if match is None:
   387        raise ValueError(
   388            "Shard number pattern S+ not found in shard_name_template: %s" %
   389            shard_name_template)
   390      shard_name_format = shard_name_template.replace(
   391          match.group(0), '%%(shard_num)0%dd' % len(match.group(0)))
   392      return FileBasedSink._template_replace_num_shards(shard_name_format)
   393  
   394    @staticmethod
   395    def _template_to_glob_format(shard_name_template):
   396      if not shard_name_template:
   397        return ''
   398      match = re.search('S+', shard_name_template)
   399      if match is None:
   400        raise ValueError(
   401            "Shard number pattern S+ not found in shard_name_template: %s" %
   402            shard_name_template)
   403      shard_name_format = shard_name_template.replace(match.group(0), '*')
   404      return FileBasedSink._template_replace_num_shards(shard_name_format)
   405  
   406    def __eq__(self, other):
   407      # TODO: Clean up workitem_test which uses this.
   408      # pylint: disable=unidiomatic-typecheck
   409      return type(self) == type(other) and self.__dict__ == other.__dict__
   410  
   411  
   412  class FileBasedSinkWriter(iobase.Writer):
   413    """The writer for FileBasedSink.
   414    """
   415    def __init__(self, sink, temp_shard_path):
   416      self.sink = sink
   417      self.temp_shard_path = temp_shard_path
   418      self.temp_handle = self.sink.open(temp_shard_path)
   419      self.num_records_written = 0
   420  
   421    def write(self, value):
   422      self.num_records_written += 1
   423      self.sink.write_record(self.temp_handle, value)
   424  
   425    def at_capacity(self):
   426      return (
   427          self.sink.max_records_per_shard and
   428          self.num_records_written >= self.sink.max_records_per_shard
   429      ) or (
   430          self.sink.max_bytes_per_shard and
   431          self.sink.byte_counter.bytes_written >= self.sink.max_bytes_per_shard)
   432  
   433    def close(self):
   434      self.sink.close(self.temp_handle)
   435      return self.temp_shard_path
   436  
   437  
   438  class _ByteCountingWriter:
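           """A pass-through wrapper that counts the bytes written to the wrapped
           writer; used by FileBasedSink to enforce ``max_bytes_per_shard``.
           """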
   439    def __init__(self, writer):
   440      self.writer = writer
   441      self.bytes_written = 0
   442  
   443    def write(self, bs):
   444      self.bytes_written += len(bs)
   445      self.writer.write(bs)
   446  
   447    def flush(self):
   448      self.writer.flush()
   449  
   450    def close(self):
   451      self.writer.close()