github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/interactive/caching/streaming_cache.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pytype: skip-file

import logging
import os
import shutil
import tempfile
import time
import traceback
from collections import OrderedDict
# We don't have an explicit pathlib dependency because this code only works with
# the interactive target installed which has an indirect dependency on pathlib
# through ipython>=5.9.0.
from pathlib import Path

from google.protobuf.message import DecodeError

import apache_beam as beam
from apache_beam import coders
from apache_beam.portability.api import beam_interactive_api_pb2
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.interactive.cache_manager import CacheManager
from apache_beam.runners.interactive.cache_manager import SafeFastPrimitivesCoder
from apache_beam.runners.interactive.caching.cacheable import CacheKey
from apache_beam.testing.test_stream import OutputFormat
from apache_beam.testing.test_stream import ReverseTestStream
from apache_beam.utils import timestamp

_LOGGER = logging.getLogger(__name__)


class StreamingCacheSink(beam.PTransform):
  """A PTransform that writes TestStreamFile(Header|Record)s to file.

  This transform takes in an arbitrary element stream and writes the list of
  TestStream events (as TestStreamFileRecords) to file. When replayed, this
  produces a best-effort replay of the original job (e.g. some elements may be
  produced slightly out of order from the original stream).

  Note that this PTransform is assumed to only run on a single machine where
  the following assumptions hold: elements arrive in order, and no two
  transforms write to the same file. This PTransform is assumed to only
  run correctly with the DirectRunner.

  TODO(https://github.com/apache/beam/issues/20002): Generalize this to more
  source/sink types aside from file based. Also, generalize to cases where
  there might be multiple workers writing to the same sink.
  """
  def __init__(
      self,
      cache_dir,
      filename,
      sample_resolution_sec,
      coder=SafeFastPrimitivesCoder()):
    self._cache_dir = cache_dir
    self._filename = filename
    self._sample_resolution_sec = sample_resolution_sec
    self._coder = coder
    self._path = os.path.join(self._cache_dir, self._filename)

  @property
  def path(self):
    """Returns the path the sink leads to."""
    return self._path

  @property
  def size_in_bytes(self):
    """Returns the space usage in bytes of the sink."""
    try:
      return os.stat(self._path).st_size
    except OSError:
      _LOGGER.debug(
          'Failed to calculate cache size for file %s; the file might not '
          'have been created yet. Returning 0. %s',
          self._path,
          traceback.format_exc())
      return 0

  def expand(self, pcoll):
    class StreamingWriteToText(beam.DoFn):
      """DoFn that performs the writing.

      Note that the other file writing methods cannot be used in streaming
      contexts.
      """
      def __init__(self, full_path, coder=SafeFastPrimitivesCoder()):
        self._full_path = full_path
        self._coder = coder

        # Try to create the directory containing the given path.
        Path(os.path.dirname(full_path)).mkdir(parents=True, exist_ok=True)

      def start_bundle(self):
        # Open the file for 'append-mode' and writing 'bytes'.
        self._fh = open(self._full_path, 'ab')

      def finish_bundle(self):
        self._fh.close()

      def process(self, e):
        """Appends the given element to the file.
        """
        self._fh.write(self._coder.encode(e) + b'\n')

    return (
        pcoll
        | ReverseTestStream(
            output_tag=self._filename,
            sample_resolution_sec=self._sample_resolution_sec,
            output_format=OutputFormat.SERIALIZED_TEST_STREAM_FILE_RECORDS,
            coder=self._coder)
        | beam.ParDo(
            StreamingWriteToText(full_path=self._path, coder=self._coder)))


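# The following usage sketch was added for illustration and is not part of the
# original module. It shows how a StreamingCacheSink is typically wired into a
# pipeline; the element values, the 'my_label' filename, and the helper name
# `_example_streaming_cache_sink` are hypothetical.
def _example_streaming_cache_sink(cache_dir):
  """Builds (but does not run) a pipeline that captures elements to file."""
  sink = StreamingCacheSink(
      cache_dir=cache_dir,
      filename='my_label',
      sample_resolution_sec=0.1)
  p = beam.Pipeline()
  # ReverseTestStream (applied inside the sink) turns the elements into
  # TestStreamFileRecords; each record is coder-encoded and appended to
  # `sink.path` as its own line.
  _ = p | beam.Create([0, 1, 2]) | 'Capture' >> sink
  # The interactive runner would normally run `p` on the DirectRunner, which
  # is the only runner this sink is expected to work correctly on.
  return p, sink.path

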
class StreamingCacheSource:
  """A class that reads and parses TestStreamFile(Header|Record)s.

  This source operates in the following way:

    1. Wait for up to `timeout_secs` for the file to be available.
    2. Read, parse, and emit the entire contents of the file
    3. Wait for more events to come or until `is_cache_complete` returns True
    4. If there are more events, then go to 2
    5. Otherwise, stop emitting.

  This class is used to read from file and send the events to the TestStream
  via the StreamingCache.Reader.
  """
  def __init__(self, cache_dir, labels, is_cache_complete=None, coder=None):
    if not coder:
      coder = SafeFastPrimitivesCoder()

    if not is_cache_complete:
      is_cache_complete = lambda _: True

    self._cache_dir = cache_dir
    self._coder = coder
    self._labels = labels
    self._path = os.path.join(self._cache_dir, *self._labels)
    self._is_cache_complete = is_cache_complete
    self._pipeline_id = CacheKey.from_str(labels[-1]).pipeline_id

  def _wait_until_file_exists(self, timeout_secs=30):
    """Blocks until the file exists for a maximum of timeout_secs.
    """
    # Wait for up to `timeout_secs` for the file to be available.
    start = time.time()
    while not os.path.exists(self._path):
      time.sleep(1)
      if time.time() - start > timeout_secs:
        pcollection_var = CacheKey.from_str(self._labels[-1]).var
        raise RuntimeError(
            'Timed out waiting for cache file for PCollection `{}` to be '
            'available with path {}.'.format(pcollection_var, self._path))
    return open(self._path, mode='rb')

  def _emit_from_file(self, fh, tail):
    """Emits the TestStreamFile(Header|Record)s from file.

    This returns a generator to be able to read all lines from the given file.
    If `tail` is True, then it will wait until the cache is complete to exit.
    Otherwise, it will read the file only once.
    """
    # Always read at least once to read the whole file.
    while True:
      pos = fh.tell()
      line = fh.readline()

      # Check if we are at EOF or if we have an incomplete line.
      if not line or (line and line[-1] != b'\n'[0]):
        # Read at least the first line to get the header.
        if not tail and pos != 0:
          break

        # Complete reading only when the cache is complete.
        if self._is_cache_complete(self._pipeline_id):
          break

        # Otherwise wait for new data in the file to be written.
        time.sleep(0.5)
        fh.seek(pos)
      else:
        # The first line at pos = 0 is always the header. Read the line without
        # the trailing newline.
        to_decode = line[:-1]
        if pos == 0:
          proto_cls = beam_interactive_api_pb2.TestStreamFileHeader
        else:
          proto_cls = beam_interactive_api_pb2.TestStreamFileRecord
        msg = self._try_parse_as(proto_cls, to_decode)
        if msg:
          yield msg
        else:
          break

  def _try_parse_as(self, proto_cls, to_decode):
    try:
      msg = proto_cls()
      msg.ParseFromString(self._coder.decode(to_decode))
    except DecodeError:
      _LOGGER.error(
          'Could not parse as %s. This can indicate that the cache is '
          'corrupted. Please restart the kernel. '
          '\nfile: %s \nmessage: %s',
          proto_cls,
          self._path,
          to_decode)
      msg = None
    return msg

  def read(self, tail):
    """Reads all TestStreamFile(Header|Record)s from file.

    This returns a generator to be able to read all lines from the given file.
    If `tail` is True, then it will wait until the cache is complete to exit.
    Otherwise, it will read the file only once.
    """
    with self._wait_until_file_exists() as f:
      for e in self._emit_from_file(f, tail):
        yield e


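# Another illustrative sketch added to this listing (not in the original
# module): reading a cache file directly with StreamingCacheSource. The
# `cache_dir`/`key_str` arguments and the helper name are hypothetical;
# `key_str` is assumed to be the stringified CacheKey naming the PCollection.
def _example_streaming_cache_source(cache_dir, key_str):
  """Returns every proto currently stored under cache_dir/full/<key_str>."""
  source = StreamingCacheSource(cache_dir, ('full', key_str))
  # With tail=False the generator performs a single pass: the first message is
  # the TestStreamFileHeader, followed by the TestStreamFileRecords.
  return list(source.read(tail=False))

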
# TODO(victorhc): Add support for cache_dir locations that are on GCS
class StreamingCache(CacheManager):
  """Abstraction that holds the logic for reading and writing to cache.
  """
  def __init__(
      self,
      cache_dir,
      is_cache_complete=None,
      sample_resolution_sec=0.1,
      saved_pcoders=None):
    self._sample_resolution_sec = sample_resolution_sec
    self._is_cache_complete = is_cache_complete

    if cache_dir:
      self._cache_dir = cache_dir
    else:
      self._cache_dir = tempfile.mkdtemp(
          prefix='ib-', dir=os.environ.get('TEST_TMPDIR', None))

    # Dict of saved pcoders keyed by PCollection path. It is OK to keep this
    # dict in memory because once a FileBasedCacheManager object is
    # destroyed/re-created it loses access to previously written cache
    # objects anyway, even if cache_dir already exists. In other words,
    # it is not possible to resume execution of a Beam pipeline from the
    # saved cache if the FileBasedCacheManager has been reset.
    #
    # However, if we are to implement better cache persistence, one needs
    # to keep the cached PCollection and its PCoder type consistent.
    self._saved_pcoders = saved_pcoders or {}
    self._default_pcoder = SafeFastPrimitivesCoder()

    # The sinks to capture data from capturable sources.
    # Dict([str, StreamingCacheSink])
    self._capture_sinks = {}
    self._capture_keys = set()

  def size(self, *labels):
    if self.exists(*labels):
      return os.path.getsize(os.path.join(self._cache_dir, *labels))
    return 0

  @property
  def capture_size(self):
    return sum([sink.size_in_bytes for _, sink in self._capture_sinks.items()])

  @property
  def capture_paths(self):
    return list(self._capture_sinks.keys())

  @property
  def capture_keys(self):
    return self._capture_keys

  def exists(self, *labels):
    if labels and any(labels):
      path = os.path.join(self._cache_dir, *labels)
      return os.path.exists(path)
    return False

  # TODO(srohde): Modify this to return the correct version.
  def read(self, *labels, **args):
    """Returns a generator to read all records from file."""
    tail = args.pop('tail', False)

    # Only return immediately when the file doesn't exist and the user wants a
    # snapshot of the cache (i.e. when tail is False).
    if not self.exists(*labels) and not tail:
      return iter([]), -1

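    # The source below yields the file's TestStreamFileHeader first, followed
    # by its TestStreamFileRecords, all decoded with the PCollection's coder.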
    reader = StreamingCacheSource(
        self._cache_dir,
        labels,
        self._is_cache_complete,
        self.load_pcoder(*labels)).read(tail=tail)

    # Return an empty iterator if there is nothing in the file yet. This can
    # only happen when tail is False.
    try:
      header = next(reader)
    except StopIteration:
      return iter([]), -1
    return StreamingCache.Reader([header], [reader]).read(), 1

  def read_multiple(self, labels, tail=True):
    """Returns a generator to read all records from file.

    Tails the cache until it is complete. This is because it is used by the
    TestStreamServiceController to read from file during pipeline runtime,
    which needs to block.
    """
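    # Consume each stream's TestStreamFileHeader up front; the remaining
    # records are then merged into a single event stream by the Reader below.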
    readers = [
        StreamingCacheSource(
            self._cache_dir, l, self._is_cache_complete,
            self.load_pcoder(*l)).read(tail=tail) for l in labels
    ]
    headers = [next(r) for r in readers]
    return StreamingCache.Reader(headers, readers).read()

  def write(self, values, *labels):
    """Writes the given values to cache.
    """
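    # Each value must already be a TestStreamFileHeader or TestStreamFileRecord
    # proto; it is serialized, wrapped with the PCollection's coder, and
    # appended to the cache file as a single newline-terminated line.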
    directory = os.path.join(self._cache_dir, *labels[:-1])
    filepath = os.path.join(directory, labels[-1])
    if not os.path.exists(directory):
      os.makedirs(directory)
    with open(filepath, 'ab') as f:
      for v in values:
        if isinstance(v,
                      (beam_interactive_api_pb2.TestStreamFileHeader,
                       beam_interactive_api_pb2.TestStreamFileRecord)):
          val = v.SerializeToString()
        else:
          raise TypeError(
              'Values given to streaming cache should be either '
              'TestStreamFileHeader or TestStreamFileRecord.')
        f.write(self.load_pcoder(*labels).encode(val) + b'\n')

  def clear(self, *labels):
    directory = os.path.join(self._cache_dir, *labels[:-1])
    filepath = os.path.join(directory, labels[-1])
    self._capture_keys.discard(labels[-1])
    if os.path.exists(filepath):
      os.remove(filepath)
      return True
    return False

  def source(self, *labels):
    """Returns the StreamingCache source.

    This is beam.Impulse() because unbounded sources will be marked with this,
    and the PipelineInstrument will then replace them with a TestStream.
    """
    return beam.Impulse()

  def sink(self, labels, is_capture=False):
    """Returns a StreamingCacheSink to write elements to file.

    Note that this is assumed to only work in the DirectRunner as the underlying
    StreamingCacheSink assumes a single machine in order to have correct element
    ordering.
    """
    filename = labels[-1]
    cache_dir = os.path.join(self._cache_dir, *labels[:-1])
    sink = StreamingCacheSink(
        cache_dir,
        filename,
        self._sample_resolution_sec,
        self.load_pcoder(*labels))
    if is_capture:
      self._capture_sinks[sink.path] = sink
      self._capture_keys.add(filename)
    return sink

  def save_pcoder(self, pcoder, *labels):
    self._saved_pcoders[os.path.join(self._cache_dir, *labels)] = pcoder

  def load_pcoder(self, *labels):
    saved_pcoder = self._saved_pcoders.get(
        os.path.join(self._cache_dir, *labels), None)
    if saved_pcoder is None or isinstance(saved_pcoder,
                                          coders.FastPrimitivesCoder):
      return self._default_pcoder
    return saved_pcoder

  def cleanup(self):

    if os.path.exists(self._cache_dir):

      def on_fail_to_cleanup(function, path, excinfo):
        _LOGGER.warning(
            'Failed to clean up temporary files: %s. You may '
            'manually delete them if necessary. Error was: %s',
            path,
            excinfo)

      shutil.rmtree(self._cache_dir, onerror=on_fail_to_cleanup)
    self._saved_pcoders = {}
    self._capture_sinks = {}
    self._capture_keys = set()

  class Reader(object):
    """Abstraction that reads from PCollection readers.

    This class is an abstraction layer over multiple PCollection readers to be
    used for supplying a TestStream service with events.

    This class is also responsible for holding the state of the clock and for
    injecting clock advancement events and watermark advancement events.
    """
    def __init__(self, headers, readers):
      # This timestamp is used as the monotonic clock to order events in the
      # replay.
      self._monotonic_clock = timestamp.Timestamp.of(0)

      # The PCollection cache readers.
      self._readers = {}

      # The file headers that are metadata for that particular PCollection.
      # The header allows for metadata about an entire stream, so that the data
      # isn't copied per record.
      self._headers = {header.tag: header for header in headers}
      self._readers = OrderedDict(
          ((h.tag, r) for (h, r) in zip(headers, readers)))

      # The most recently read timestamp per tag.
      self._stream_times = {
          tag: timestamp.Timestamp(seconds=0)
          for tag in self._headers
      }

    def _test_stream_events_before_target(self, target_timestamp):
      """Reads the next iteration of elements from each stream.

      Retrieves an element from each stream iff the most recently read timestamp
      from that stream is less than the target_timestamp. Since the number of
      events may not fit into memory, this StreamingCache reads at most one
      element from each stream at a time.
      """
      records = []
      for tag, r in self._readers.items():
        # The target_timestamp is the maximum timestamp that was read from the
        # stream. Some readers may have elements that are less than this. Thus,
        # we skip all readers that already have elements that are at this
        # timestamp so that we don't read everything into memory.
        if self._stream_times[tag] >= target_timestamp:
          continue
        try:
          record = next(r).recorded_event
          if record.HasField('processing_time_event'):
            self._stream_times[tag] += timestamp.Duration(
                micros=record.processing_time_event.advance_duration)
          records.append((tag, record, self._stream_times[tag]))
        except StopIteration:
          pass
      return records

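    # Merged events are kept sorted newest-first so that pop() in read() below
    # always returns the oldest pending event; with that ordering the last
    # element of the list carries the minimum timestamp.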
    def _merge_sort(self, previous_events, new_events):
      return sorted(
          previous_events + new_events, key=lambda x: x[2], reverse=True)

    def _min_timestamp_of(self, events):
      return events[-1][2] if events else timestamp.MAX_TIMESTAMP

    def _event_stream_caught_up_to_target(self, events, target_timestamp):
      empty_events = not events
      stream_is_past_target = self._min_timestamp_of(events) > target_timestamp
      return empty_events or stream_is_past_target

    def read(self):
      """Reads records from PCollection readers.
      """

      # The largest timestamp read from the different streams.
      target_timestamp = timestamp.MAX_TIMESTAMP

      # The events from last iteration that are past the target timestamp.
      unsent_events = []

      # Emit events until all events have been read.
      while True:
        # Read the next set of events. The read events will most likely be
        # out of order if there are multiple readers. Here we sort them into
        # a more manageable state.
        new_events = self._test_stream_events_before_target(target_timestamp)
        events_to_send = self._merge_sort(unsent_events, new_events)
        if not events_to_send:
          break

        # Get the next largest timestamp in the stream. This is used as the
        # timestamp for readers to "catch-up" to. This will only read from
        # readers with a timestamp less than this.
        target_timestamp = self._min_timestamp_of(events_to_send)

        # Loop through the elements with the correct timestamp.
        while not self._event_stream_caught_up_to_target(events_to_send,
                                                         target_timestamp):

          # First advance the clock to match the time of the stream. This has
          # a side-effect of also advancing this cache's clock.
          tag, r, curr_timestamp = events_to_send.pop()
          if curr_timestamp > self._monotonic_clock:
            yield self._advance_processing_time(curr_timestamp)

          # Then, send either a new element or watermark.
          if r.HasField('element_event'):
            r.element_event.tag = tag
            yield r
          elif r.HasField('watermark_event'):
            r.watermark_event.tag = tag
            yield r
        unsent_events = events_to_send
        target_timestamp = self._min_timestamp_of(unsent_events)

    def _advance_processing_time(self, new_timestamp):
      """Advances the internal clock and returns an AdvanceProcessingTime event.
      """
      advance_by = new_timestamp.micros - self._monotonic_clock.micros
      e = beam_runner_api_pb2.TestStreamPayload.Event(
          processing_time_event=beam_runner_api_pb2.TestStreamPayload.Event.
          AdvanceProcessingTime(advance_duration=advance_by))
      self._monotonic_clock = new_timestamp
      return e
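

# The sketches below were added to this listing for illustration; they are not
# part of the original module. Function names, the 'full' label, and the
# CacheKey arguments are hypothetical, and the CacheKey constructor/`to_str`
# helper are assumed from apache_beam.runners.interactive.caching.cacheable.
def _example_streaming_cache_roundtrip():
  """Writes a header and one element record to a cache, then reads them back."""
  cache = StreamingCache(cache_dir=None)  # Uses a temporary directory.
  key = CacheKey('letters', '', '', 'example_pipeline_id').to_str()

  header = beam_interactive_api_pb2.TestStreamFileHeader(tag=key)
  record = beam_interactive_api_pb2.TestStreamFileRecord(
      recorded_event=beam_runner_api_pb2.TestStreamPayload.Event(
          element_event=beam_runner_api_pb2.TestStreamPayload.Event.AddElements(
              tag=key)))
  cache.write([header, record], 'full', key)

  # `read` returns (generator of TestStreamPayload.Events, version); with
  # tail=False it does a single pass over the file.
  events, _ = cache.read('full', key, tail=False)
  return list(events)


def _example_reader_merge():
  """Shows how StreamingCache.Reader merges two cached streams by timestamp."""
  def advance(seconds):
    return beam_interactive_api_pb2.TestStreamFileRecord(
        recorded_event=beam_runner_api_pb2.TestStreamPayload.Event(
            processing_time_event=beam_runner_api_pb2.TestStreamPayload.Event.
            AdvanceProcessingTime(advance_duration=seconds * 10**6)))

  headers = [
      beam_interactive_api_pb2.TestStreamFileHeader(tag='a'),
      beam_interactive_api_pb2.TestStreamFileHeader(tag='b'),
  ]
  readers = [iter([advance(1)]), iter([advance(2)])]
  # The Reader interleaves the two streams, emitting AdvanceProcessingTime
  # events that keep its monotonic clock in step with the recorded times.
  return list(StreamingCache.Reader(headers, readers).read())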