github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/interactive/cache_manager.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pytype: skip-file

import collections
import os
import tempfile
from urllib.parse import quote
from urllib.parse import unquote_to_bytes

import apache_beam as beam
from apache_beam import coders
from apache_beam.io import filesystems
from apache_beam.io import textio
from apache_beam.io import tfrecordio
from apache_beam.transforms import combiners


class CacheManager(object):
  """Abstract class for caching PCollections.

  A PCollection cache is identified by labels, which consist of a prefix (either
  'full' or 'sample') and a cache_label which is a hash of the PCollection
  derivation.
  """
  def exists(self, *labels):
    # type: (*str) -> bool

    """Returns True if the PCollection cache exists."""
    raise NotImplementedError

  def is_latest_version(self, version, *labels):
    # type: (int, *str) -> bool

    """Returns True if the given version number is the latest."""
    return version == self._latest_version(*labels)

  def _latest_version(self, *labels):
    # type: (*str) -> int

    """Returns the latest version number of the PCollection cache."""
    raise NotImplementedError

  def read(self, *labels, **args):
    # type: (*str, **Any) -> Tuple[Iterator[Any], int]

    """Returns an iterator over the elements of the PCollection as well as
    the version number.

    Args:
      *labels: List of labels for PCollection instance.
      **args: Dict of additional arguments. Currently only 'tail' as a
        boolean. When tail is True, will wait and read new elements until the
        cache is complete.

    Returns:
      A tuple containing an iterator for the items in the PCollection and the
        version number.

    It is possible that the version numbers from read() and _latest_version()
    are different. This usually means that the cache has been evicted (thus
    unavailable, so read() returns version = -1), but it had reached version
    n before eviction.
    """
    raise NotImplementedError
  def write(self, value, *labels):
    # type: (Any, *str) -> None

    """Writes the value to the given cache.

    Args:
      value: An encodable (with corresponding PCoder) value.
      *labels: List of labels for PCollection instance.
    """
    raise NotImplementedError

  def clear(self, *labels):
    # type: (*str) -> bool

    """Clears the cache entry of the given labels and returns True on success.

    Args:
      *labels: List of labels for PCollection instance.
    """
    raise NotImplementedError
  def source(self, *labels):
    # type: (*str) -> ptransform.PTransform

    """Returns a PTransform that reads the PCollection cache."""
    raise NotImplementedError

  def sink(self, labels, is_capture=False):
    # type: (List[str], bool) -> ptransform.PTransform

    """Returns a PTransform that writes the PCollection cache.

    TODO(BEAM-10514): Make sure labels will not be converted into an
    arbitrarily long file path: e.g., Windows has a 260-character path limit.
    """
    raise NotImplementedError
  def save_pcoder(self, pcoder, *labels):
    # type: (coders.Coder, *str) -> None

    """Saves the PCoder for the given PCollection.

    Correctly reading a PCollection from the cache requires its PCoder to be
    known. This method saves the desired PCoder for the PCollection, which
    will subsequently be used by the sink(...), source(...), and, most
    importantly, read(...) methods. The latter must be able to read a
    PCollection written by Beam using non-Beam IO.

    Args:
      pcoder: A PCoder to be used for reading and writing a PCollection.
      *labels: List of labels for PCollection instance.
    """
    raise NotImplementedError

  def load_pcoder(self, *labels):
    # type: (*str) -> coders.Coder

    """Returns the previously saved PCoder for reading and writing the
    PCollection."""
    raise NotImplementedError

  def cleanup(self):
    # type: () -> None

    """Cleans up all the PCollection caches."""
    raise NotImplementedError

  def size(self, *labels):
    # type: (*str) -> int

    """Returns the size of the PCollection on disk in bytes."""
    raise NotImplementedError

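
# A minimal sketch (not part of the Beam codebase) of how a concrete
# CacheManager could satisfy the contract above with a plain in-memory dict.
# The class name and storage layout are illustrative assumptions; the real,
# file-backed implementation follows below. source()/sink() are omitted
# because they only matter for pipeline-based access.
class _InMemoryCacheManagerSketch(CacheManager):
  """Illustrative only: caches PCollection values in process memory."""
  def __init__(self):
    self._store = {}  # Maps a labels tuple to a list of cached values.
    self._versions = collections.defaultdict(lambda: -1)

  def exists(self, *labels):
    return labels in self._store

  def _latest_version(self, *labels):
    return self._versions[labels]

  def read(self, *labels, **args):
    if not self.exists(*labels):
      return iter([]), -1
    return iter(self._store[labels]), self._latest_version(*labels)

  def write(self, value, *labels):
    self._store.setdefault(labels, []).append(value)
    self._versions[labels] += 1

  def clear(self, *labels):
    self._versions.pop(labels, None)
    return self._store.pop(labels, None) is not None

  def cleanup(self):
    self._store.clear()
    self._versions.clear()
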

class FileBasedCacheManager(CacheManager):
  """Maps PCollections to local temp files for materialization."""

  _available_formats = {
      'text': (textio.ReadFromText, textio.WriteToText),
      'tfrecord': (tfrecordio.ReadFromTFRecord, tfrecordio.WriteToTFRecord)
  }

  def __init__(self, cache_dir=None, cache_format='text'):
    if cache_dir:
      self._cache_dir = cache_dir
    else:
      self._cache_dir = tempfile.mkdtemp(
          prefix='ib-', dir=os.environ.get('TEST_TMPDIR', None))
    self._versions = collections.defaultdict(lambda: self._CacheVersion())
    self.cache_format = cache_format

    if cache_format not in self._available_formats:
      raise ValueError("Unsupported cache format: '%s'." % cache_format)
    self._reader_class, self._writer_class = self._available_formats[
        cache_format]
    self._default_pcoder = (
        SafeFastPrimitivesCoder() if cache_format == 'text' else None)

    # Dict of saved pcoders keyed by PCollection path. It is OK to keep this
    # mapping in memory because once a FileBasedCacheManager object is
    # destroyed/re-created it loses access to previously written cache
    # objects anyway, even if cache_dir still exists. In other words, it is
    # not possible to resume execution of a Beam pipeline from the saved
    # cache if the FileBasedCacheManager has been reset.
    #
    # However, to implement better cache persistence, one would need to keep
    # the cached PCollection and its PCoder type consistent.
    self._saved_pcoders = {}

  def size(self, *labels):
    if self.exists(*labels):
      matched_path = self._match(*labels)
      # If the first matched path has a 'gs://' prefix, the cache lives on
      # GCS, so sizes must be listed through the GCS API rather than the
      # local filesystem.
      if 'gs://' in matched_path[0]:
        from apache_beam.io.gcp import gcsio
        return sum(
            sum(gcsio.GcsIO().list_prefix(path).values())
            for path in matched_path)
      return sum(os.path.getsize(path) for path in matched_path)
    return 0

  def exists(self, *labels):
    # A cache entry is identified by a prefix plus at least one non-empty
    # label; anything else has nothing to match.
    if labels and any(labels[1:]):
      return bool(self._match(*labels))
    return False

  def _latest_version(self, *labels):
    timestamp = 0
    for path in self._match(*labels):
      timestamp = max(timestamp, filesystems.FileSystems.last_updated(path))
    result = self._versions["-".join(labels)].get_version(timestamp)
    return result

  def save_pcoder(self, pcoder, *labels):
    self._saved_pcoders[self._path(*labels)] = pcoder

  def load_pcoder(self, *labels):
    saved_pcoder = self._saved_pcoders.get(self._path(*labels), None)
    # Fall back to the default pcoder when none was saved, or when the saved
    # coder is a plain FastPrimitivesCoder, whose output is not safe for the
    # line-oriented text format.
    if saved_pcoder is None or isinstance(saved_pcoder,
                                          coders.FastPrimitivesCoder):
      return self._default_pcoder
    return saved_pcoder

  def read(self, *labels, **args):
    # Return an iterator over an empty list if the cache doesn't exist.
    if not self.exists(*labels):
      return iter([]), -1

    # Otherwise, return a generator over the cached PCollection.
    source = self.source(*labels)._source
    range_tracker = source.get_range_tracker(None, None)
    reader = source.read(range_tracker)
    version = self._latest_version(*labels)

    return reader, version

  def write(self, values, *labels):
    """Imitates how a WriteCache transform works without running a pipeline.

    For testing and cache manager development; not for production usage,
    because the write is not sharded and does not use the Beam execution
    model.
    """
    pcoder = coders.registry.get_coder(type(values[0]))
    # Save the pcoder for the actual labels.
    self.save_pcoder(pcoder, *labels)
    single_shard_labels = [*labels[:-1], '-00000-of-00001']
    # Save the pcoder for the labels that imitate the sharded cache file name
    # suffix.
    self.save_pcoder(pcoder, *single_shard_labels)
    # Put a '-%05d-of-%05d' suffix on the cache file.
    sink = self.sink(single_shard_labels)._sink
    path = self._path(*labels[:-1])
    writer = sink.open_writer(path, labels[-1])
    for v in values:
      writer.write(v)
    writer.close()

  def clear(self, *labels):
    if self.exists(*labels):
      filesystems.FileSystems.delete(self._match(*labels))
      return True
    return False

  def source(self, *labels):
    return self._reader_class(
        self._glob_path(*labels), coder=self.load_pcoder(*labels))

  def sink(self, labels, is_capture=False):
    return self._writer_class(
        self._path(*labels), coder=self.load_pcoder(*labels))

  def cleanup(self):
    if self._cache_dir.startswith('gs://'):
      from apache_beam.io.gcp import gcsfilesystem
      from apache_beam.options.pipeline_options import PipelineOptions
      fs = gcsfilesystem.GCSFileSystem(PipelineOptions())
      fs.delete([self._cache_dir + '/full/'])
    elif filesystems.FileSystems.exists(self._cache_dir):
      filesystems.FileSystems.delete([self._cache_dir])
    self._saved_pcoders = {}

  def _glob_path(self, *labels):
    # Match the sharded cache file names, e.g. '<path>-00000-of-00001'.
    return self._path(*labels) + '*-*-of-*'

  def _path(self, *labels):
    return filesystems.FileSystems.join(self._cache_dir, *labels)

  def _match(self, *labels):
    match = filesystems.FileSystems.match([self._glob_path(*labels)])
    assert len(match) == 1
    return [metadata.path for metadata in match[0].metadata_list]

  class _CacheVersion(object):
    """This class keeps track of the timestamp and the corresponding version."""
    def __init__(self):
      self.current_version = -1
      self.current_timestamp = 0

    def get_version(self, timestamp):
      """Updates version if necessary and returns the version number.

      Args:
        timestamp: (int) unix timestamp when the cache is updated. This value is
            zero if the cache has been evicted or doesn't exist.
      """
      # Do not update timestamp if the cache's been evicted.
      if timestamp != 0 and timestamp != self.current_timestamp:
        assert timestamp > self.current_timestamp
        self.current_version = self.current_version + 1
        self.current_timestamp = timestamp
      return self.current_version

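
# A hedged usage sketch (not part of the original module): write a small list
# through FileBasedCacheManager.write() and read it back. The 'full' prefix
# matches the module's convention, while the 'deadbeef' cache label stands in
# for the derivation hash that interactive Beam would normally compute; both
# the label and the helper name are illustrative assumptions.
def _example_file_cache_roundtrip():  # hypothetical helper, not in Beam
  cm = FileBasedCacheManager(cache_format='text')
  cm.write(['a', 'b', 'c'], 'full', 'deadbeef')
  # write() produces a single shard carrying a '-00000-of-00001' suffix, so
  # the glob in _match() can find it.
  assert cm.exists('full', 'deadbeef')
  reader, version = cm.read('full', 'deadbeef')
  # The first successful read observes version 0 (versions start at -1).
  assert list(reader) == ['a', 'b', 'c'] and version == 0
  cm.cleanup()
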

class ReadCache(beam.PTransform):
  """A PTransform that reads the PCollections from the cache."""
  def __init__(self, cache_manager, label):
    self._cache_manager = cache_manager
    self._label = label

  def expand(self, pbegin):
    # pylint: disable=expression-not-assigned
    return pbegin | 'Read' >> self._cache_manager.source('full', self._label)

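
# A hedged sketch of applying ReadCache: replay a previously cached
# PCollection into a new pipeline. It assumes the cache for the illustrative
# label 'deadbeef' was populated earlier; the helper name is hypothetical.
def _example_read_cache(cache_manager):  # hypothetical helper, not in Beam
  with beam.Pipeline() as p:
    cached = p | 'FromCache' >> ReadCache(cache_manager, 'deadbeef')
    # The replayed elements can feed any downstream transforms.
    cached | 'Print' >> beam.Map(print)
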

class WriteCache(beam.PTransform):
  """A PTransform that writes the PCollections to the cache."""
  def __init__(
      self,
      cache_manager,
      label,
      sample=False,
      sample_size=0,
      is_capture=False):
    self._cache_manager = cache_manager
    self._label = label
    self._sample = sample
    self._sample_size = sample_size
    self._is_capture = is_capture

  def expand(self, pcoll):
    prefix = 'sample' if self._sample else 'full'

    # Save the pcoder that is necessary for properly reading the cached
    # PCollection. The _cache_manager.sink(...) call below should use this
    # saved pcoder.
    self._cache_manager.save_pcoder(
        coders.registry.get_coder(pcoll.element_type), prefix, self._label)

    if self._sample:
      pcoll |= 'Sample' >> (
          combiners.Sample.FixedSizeGlobally(self._sample_size)
          | beam.FlatMap(lambda sample: sample))
    # pylint: disable=expression-not-assigned
    return pcoll | 'Write' >> self._cache_manager.sink(
        (prefix, self._label), is_capture=self._is_capture)

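
# A hedged sketch of applying WriteCache: materialize both the full
# PCollection and a bounded sample under their respective prefixes. The
# 'deadbeef' label and the helper name are illustrative assumptions.
def _example_write_cache(cache_manager):  # hypothetical helper, not in Beam
  with beam.Pipeline() as p:
    pcoll = p | 'Create' >> beam.Create([1, 2, 3])
    pcoll | 'CacheFull' >> WriteCache(cache_manager, 'deadbeef')
    pcoll | 'CacheSample' >> WriteCache(
        cache_manager, 'deadbeef', sample=True, sample_size=2)
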

class SafeFastPrimitivesCoder(coders.Coder):
  """This coder adds a quote/unquote step to escape special characters."""

  # pylint: disable=bad-option-value

  def encode(self, value):
    return quote(
        coders.coders.FastPrimitivesCoder().encode(value)).encode('utf-8')

  def decode(self, value):
    return coders.coders.FastPrimitivesCoder().decode(unquote_to_bytes(value))
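
# A quick round-trip sketch (not part of the original module) showing why the
# quoting matters: FastPrimitivesCoder output may contain newline bytes, which
# would corrupt the line-oriented text cache, while the quoted form is
# newline-free. The helper name is hypothetical.
def _example_safe_coder_roundtrip():  # hypothetical helper, not in Beam
  coder = SafeFastPrimitivesCoder()
  value = 'line1\nline2'
  encoded = coder.encode(value)
  assert b'\n' not in encoded  # Safe to store one element per line.
  assert coder.decode(encoded) == value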