github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/interactive/caching/streaming_cache.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pytype: skip-file

import logging
import os
import shutil
import tempfile
import time
import traceback
from collections import OrderedDict
# We don't have an explicit pathlib dependency because this code only works
# with the interactive target installed, which has an indirect dependency on
# pathlib through ipython>=5.9.0.
from pathlib import Path

from google.protobuf.message import DecodeError

import apache_beam as beam
from apache_beam import coders
from apache_beam.portability.api import beam_interactive_api_pb2
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.interactive.cache_manager import CacheManager
from apache_beam.runners.interactive.cache_manager import SafeFastPrimitivesCoder
from apache_beam.runners.interactive.caching.cacheable import CacheKey
from apache_beam.testing.test_stream import OutputFormat
from apache_beam.testing.test_stream import ReverseTestStream
from apache_beam.utils import timestamp

_LOGGER = logging.getLogger(__name__)


class StreamingCacheSink(beam.PTransform):
  """A PTransform that writes TestStreamFile(Header|Record)s to file.

  This transform takes in an arbitrary element stream and writes the list of
  TestStream events (as TestStreamFileRecords) to file. When replayed, this
  will produce a best-effort replay of the original job (e.g. some elements
  may be produced slightly out of order from the original stream).

  Note that this PTransform is assumed to only run on a single machine where
  the following assumptions are correct: elements come in ordered, and no two
  transforms write to the same file. This PTransform is assumed to only run
  correctly with the DirectRunner.

  TODO(https://github.com/apache/beam/issues/20002): Generalize this to more
  source/sink types aside from file based. Also, generalize to cases where
  there might be multiple workers writing to the same sink.
  """
  def __init__(
      self,
      cache_dir,
      filename,
      sample_resolution_sec,
      coder=SafeFastPrimitivesCoder()):
    self._cache_dir = cache_dir
    self._filename = filename
    self._sample_resolution_sec = sample_resolution_sec
    self._coder = coder
    self._path = os.path.join(self._cache_dir, self._filename)

  @property
  def path(self):
    """Returns the path the sink leads to."""
    return self._path

  @property
  def size_in_bytes(self):
    """Returns the space usage in bytes of the sink."""
    try:
      return os.stat(self._path).st_size
    except OSError:
      _LOGGER.debug(
          'Failed to calculate cache size for file %s; the file might not '
          'have been created yet. Returning 0. %s',
          self._path,
          traceback.format_exc())
    return 0

  def expand(self, pcoll):
    class StreamingWriteToText(beam.DoFn):
      """DoFn that performs the writing.

      Note that the other file writing methods cannot be used in streaming
      contexts.
      """
      def __init__(self, full_path, coder=SafeFastPrimitivesCoder()):
        self._full_path = full_path
        self._coder = coder

        # Try to create the given path.
        Path(os.path.dirname(full_path)).mkdir(parents=True, exist_ok=True)

      def start_bundle(self):
        # Open the file in 'append' mode for writing 'bytes'.
        self._fh = open(self._full_path, 'ab')

      def finish_bundle(self):
        self._fh.close()

      def process(self, e):
        """Appends the given element to the file."""
        self._fh.write(self._coder.encode(e) + b'\n')

    return (
        pcoll
        | ReverseTestStream(
            output_tag=self._filename,
            sample_resolution_sec=self._sample_resolution_sec,
            output_format=OutputFormat.SERIALIZED_TEST_STREAM_FILE_RECORDS,
            coder=self._coder)
        | beam.ParDo(
            StreamingWriteToText(full_path=self._path, coder=self._coder)))

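# A minimal, illustrative sketch (not from the Beam codebase): it shows the
# on-disk framing that StreamingWriteToText above appends to the cache file,
# namely one coder-encoded payload per b'\n'-terminated line, and how the
# payloads can be recovered by splitting on newlines and decoding each line.
# It only covers the line framing, not the ReverseTestStream step that the
# sink applies first. The helper name, the temporary directory, and the sample
# payloads are hypothetical.


def _example_line_framing_sketch():
  coder = SafeFastPrimitivesCoder()
  payloads = [b'first record', b'second record']

  # Hypothetical cache location used only for this sketch.
  path = os.path.join(tempfile.mkdtemp(prefix='ib-sketch-'), 'records')

  # Mirror StreamingWriteToText.process: append one encoded payload per line.
  with open(path, 'ab') as fh:
    for p in payloads:
      fh.write(coder.encode(p) + b'\n')

  # Recover the payloads by decoding every complete line (newline stripped).
  with open(path, 'rb') as fh:
    decoded = [coder.decode(line[:-1]) for line in fh if line.endswith(b'\n')]

  assert decoded == payloads
  shutil.rmtree(os.path.dirname(path))
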
class StreamingCacheSource:
  """A class that reads and parses TestStreamFile(Header|Record)s.

  This source operates in the following way:

    1. Wait for up to `timeout_secs` for the file to be available.
    2. Read, parse, and emit the entire contents of the file.
    3. Wait for more events to come or until `is_cache_complete` returns True.
    4. If there are more events, then go to 2.
    5. Otherwise, stop emitting.

  This class is used to read from file and send it to the TestStream via the
  StreamingCache.Reader.
  """
  def __init__(self, cache_dir, labels, is_cache_complete=None, coder=None):
    if not coder:
      coder = SafeFastPrimitivesCoder()

    if not is_cache_complete:
      is_cache_complete = lambda _: True

    self._cache_dir = cache_dir
    self._coder = coder
    self._labels = labels
    self._path = os.path.join(self._cache_dir, *self._labels)
    self._is_cache_complete = is_cache_complete
    self._pipeline_id = CacheKey.from_str(labels[-1]).pipeline_id

  def _wait_until_file_exists(self, timeout_secs=30):
    """Blocks until the file exists for a maximum of timeout_secs."""
    # Wait for up to `timeout_secs` for the file to be available.
    start = time.time()
    while not os.path.exists(self._path):
      time.sleep(1)
      if time.time() - start > timeout_secs:
        pcollection_var = CacheKey.from_str(self._labels[-1]).var
        raise RuntimeError(
            'Timed out waiting for cache file for PCollection `{}` to be '
            'available with path {}.'.format(pcollection_var, self._path))
    return open(self._path, mode='rb')

  def _emit_from_file(self, fh, tail):
    """Emits the TestStreamFile(Header|Record)s from file.

    This returns a generator to be able to read all lines from the given file.
    If `tail` is True, then it will wait until the cache is complete to exit.
    Otherwise, it will read the file only once.
    """
    # Always read at least once to read the whole file.
    while True:
      pos = fh.tell()
      line = fh.readline()

      # Check if we are at EOF or if we have an incomplete line.
      if not line or (line and line[-1] != b'\n'[0]):
        # Read at least the first line to get the header.
        if not tail and pos != 0:
          break

        # Complete reading only when the cache is complete.
        if self._is_cache_complete(self._pipeline_id):
          break

        # Otherwise wait for new data in the file to be written.
        time.sleep(0.5)
        fh.seek(pos)
      else:
        # The first line at pos = 0 is always the header. Read the line without
        # the new line.
        to_decode = line[:-1]
        if pos == 0:
          proto_cls = beam_interactive_api_pb2.TestStreamFileHeader
        else:
          proto_cls = beam_interactive_api_pb2.TestStreamFileRecord
        msg = self._try_parse_as(proto_cls, to_decode)
        if msg:
          yield msg
        else:
          break

  def _try_parse_as(self, proto_cls, to_decode):
    try:
      msg = proto_cls()
      msg.ParseFromString(self._coder.decode(to_decode))
    except DecodeError:
      _LOGGER.error(
          'Could not parse as %s. This can indicate that the cache is '
          'corrupted. Please restart the kernel. '
          '\nfile: %s \nmessage: %s',
          proto_cls,
          self._path,
          to_decode)
      msg = None
    return msg

  def read(self, tail):
    """Reads all TestStreamFile(Header|Record)s from file.

    This returns a generator to be able to read all lines from the given file.
    If `tail` is True, then it will wait until the cache is complete to exit.
    Otherwise, it will read the file only once.
    """
    with self._wait_until_file_exists() as f:
      for e in self._emit_from_file(f, tail):
        yield e

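# A minimal, illustrative sketch (not from the Beam codebase): the tail-follow
# pattern used by StreamingCacheSource._emit_from_file above, reduced to its
# core. It remembers the offset before each readline(); on EOF or a partial
# line it either stops (when the caller's `done()` predicate says the writer
# has finished) or seeks back and polls again, so a half-written line is never
# emitted. The helper name, the `done` callback, and the poll interval are
# hypothetical simplifications of the real header/record handling.


def _example_tail_lines_sketch(fh, done, poll_secs=0.5):
  """Yields complete newline-terminated lines, waiting for more until done()."""
  while True:
    pos = fh.tell()
    line = fh.readline()
    if not line or not line.endswith(b'\n'):
      if done():
        break
      # Partial line or EOF: rewind and poll again once more data is written.
      time.sleep(poll_secs)
      fh.seek(pos)
    else:
      yield line[:-1]
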
# TODO(victorhc): Add support for cache_dir locations that are on GCS
class StreamingCache(CacheManager):
  """Abstraction that holds the logic for reading and writing to cache."""
  def __init__(
      self,
      cache_dir,
      is_cache_complete=None,
      sample_resolution_sec=0.1,
      saved_pcoders=None):
    self._sample_resolution_sec = sample_resolution_sec
    self._is_cache_complete = is_cache_complete

    if cache_dir:
      self._cache_dir = cache_dir
    else:
      self._cache_dir = tempfile.mkdtemp(
          prefix='ib-', dir=os.environ.get('TEST_TMPDIR', None))

    # List of saved pcoders keyed by PCollection path. It is OK to keep this
    # list in memory because once the FileBasedCacheManager object is
    # destroyed/re-created it loses access to previously written cache
    # objects anyway, even if cache_dir already exists. In other words,
    # it is not possible to resume execution of a Beam pipeline from the
    # saved cache if the FileBasedCacheManager has been reset.
    #
    # However, if we are to implement better cache persistence, one needs
    # to take care of keeping consistency between the cached PCollection
    # and its PCoder type.
    self._saved_pcoders = saved_pcoders or {}
    self._default_pcoder = SafeFastPrimitivesCoder()

    # The sinks to capture data from capturable sources.
    # Dict([str, StreamingCacheSink])
    self._capture_sinks = {}
    self._capture_keys = set()

  def size(self, *labels):
    if self.exists(*labels):
      return os.path.getsize(os.path.join(self._cache_dir, *labels))
    return 0

  @property
  def capture_size(self):
    return sum([sink.size_in_bytes for _, sink in self._capture_sinks.items()])

  @property
  def capture_paths(self):
    return list(self._capture_sinks.keys())

  @property
  def capture_keys(self):
    return self._capture_keys

  def exists(self, *labels):
    if labels and any(labels):
      path = os.path.join(self._cache_dir, *labels)
      return os.path.exists(path)
    return False

  # TODO(srohde): Modify this to return the correct version.
  def read(self, *labels, **args):
    """Returns a generator to read all records from file."""
    tail = args.pop('tail', False)

    # Return immediately only when the file doesn't exist and the user wants a
    # snapshot of the cache (i.e. when tail is False).
    if not self.exists(*labels) and not tail:
      return iter([]), -1

    reader = StreamingCacheSource(
        self._cache_dir,
        labels,
        self._is_cache_complete,
        self.load_pcoder(*labels)).read(tail=tail)

    # Return an empty iterator if there is nothing in the file yet. This can
    # only happen when tail is False.
    try:
      header = next(reader)
    except StopIteration:
      return iter([]), -1
    return StreamingCache.Reader([header], [reader]).read(), 1

  def read_multiple(self, labels, tail=True):
    """Returns a generator to read all records from file.

    Tails the files until the cache is complete. This is because it is used by
    the TestStreamServiceController to read from file during pipeline runtime,
    which needs to block.
    """
    readers = [
        StreamingCacheSource(
            self._cache_dir, l, self._is_cache_complete,
            self.load_pcoder(*l)).read(tail=tail) for l in labels
    ]
    headers = [next(r) for r in readers]
    return StreamingCache.Reader(headers, readers).read()

  def write(self, values, *labels):
    """Writes the given values to cache."""
    directory = os.path.join(self._cache_dir, *labels[:-1])
    filepath = os.path.join(directory, labels[-1])
    if not os.path.exists(directory):
      os.makedirs(directory)
    with open(filepath, 'ab') as f:
      for v in values:
        if isinstance(v,
                      (beam_interactive_api_pb2.TestStreamFileHeader,
                       beam_interactive_api_pb2.TestStreamFileRecord)):
          val = v.SerializeToString()
        else:
          raise TypeError(
              'Values given to streaming cache should be either '
              'TestStreamFileHeader or TestStreamFileRecord.')
        f.write(self.load_pcoder(*labels).encode(val) + b'\n')

  def clear(self, *labels):
    directory = os.path.join(self._cache_dir, *labels[:-1])
    filepath = os.path.join(directory, labels[-1])
    self._capture_keys.discard(labels[-1])
    if os.path.exists(filepath):
      os.remove(filepath)
      return True
    return False

  def source(self, *labels):
    """Returns the StreamingCache source.

    This is beam.Impulse() because unbounded sources will be marked with this
    and then the PipelineInstrument will replace these with a TestStream.
    """
    return beam.Impulse()

  def sink(self, labels, is_capture=False):
    """Returns a StreamingCacheSink to write elements to file.

    Note that this is assumed to work only in the DirectRunner, as the
    underlying StreamingCacheSink assumes a single machine to have correct
    element ordering.
    """
    filename = labels[-1]
    cache_dir = os.path.join(self._cache_dir, *labels[:-1])
    sink = StreamingCacheSink(
        cache_dir,
        filename,
        self._sample_resolution_sec,
        self.load_pcoder(*labels))
    if is_capture:
      self._capture_sinks[sink.path] = sink
      self._capture_keys.add(filename)
    return sink

  def save_pcoder(self, pcoder, *labels):
    self._saved_pcoders[os.path.join(self._cache_dir, *labels)] = pcoder

  def load_pcoder(self, *labels):
    saved_pcoder = self._saved_pcoders.get(
        os.path.join(self._cache_dir, *labels), None)
    if saved_pcoder is None or isinstance(saved_pcoder,
                                          coders.FastPrimitivesCoder):
      return self._default_pcoder
    return saved_pcoder

  def cleanup(self):
    if os.path.exists(self._cache_dir):

      def on_fail_to_cleanup(function, path, excinfo):
        _LOGGER.warning(
            'Failed to clean up temporary files: %s. You may '
            'manually delete them if necessary. Error was: %s',
            path,
            excinfo)

      shutil.rmtree(self._cache_dir, onerror=on_fail_to_cleanup)
    self._saved_pcoders = {}
    self._capture_sinks = {}
    self._capture_keys = set()

  class Reader(object):
    """Abstraction that reads from PCollection readers.

    This class is an abstraction layer over multiple PCollection readers to be
    used for supplying a TestStream service with events.

    This class is also responsible for holding the state of the clock and for
    injecting clock advancement events and watermark advancement events.
    """
    def __init__(self, headers, readers):
      # This timestamp is used as the monotonic clock to order events in the
      # replay.
      self._monotonic_clock = timestamp.Timestamp.of(0)

      # The PCollection cache readers.
      self._readers = {}

      # The file headers that are metadata for that particular PCollection.
      # The header allows for metadata about an entire stream, so that the data
      # isn't copied per record.
      self._headers = {header.tag: header for header in headers}
      self._readers = OrderedDict(
          ((h.tag, r) for (h, r) in zip(headers, readers)))

      # The most recently read timestamp per tag.
      self._stream_times = {
          tag: timestamp.Timestamp(seconds=0)
          for tag in self._headers
      }

    def _test_stream_events_before_target(self, target_timestamp):
      """Reads the next iteration of elements from each stream.

      Retrieves an element from each stream iff the most recently read
      timestamp from that stream is less than the target_timestamp. Since the
      amount of events may not fit into memory, this StreamingCache reads at
      most one element from each stream at a time.
      """
      records = []
      for tag, r in self._readers.items():
        # The target_timestamp is the maximum timestamp that was read from the
        # stream. Some readers may have elements that are less than this. Thus,
        # we skip all readers that already have elements that are at this
        # timestamp so that we don't read everything into memory.
        if self._stream_times[tag] >= target_timestamp:
          continue
        try:
          record = next(r).recorded_event
          if record.HasField('processing_time_event'):
            self._stream_times[tag] += timestamp.Duration(
                micros=record.processing_time_event.advance_duration)
          records.append((tag, record, self._stream_times[tag]))
        except StopIteration:
          pass
      return records

    def _merge_sort(self, previous_events, new_events):
      return sorted(
          previous_events + new_events, key=lambda x: x[2], reverse=True)

    def _min_timestamp_of(self, events):
      return events[-1][2] if events else timestamp.MAX_TIMESTAMP

    def _event_stream_caught_up_to_target(self, events, target_timestamp):
      empty_events = not events
      stream_is_past_target = self._min_timestamp_of(events) > target_timestamp
      return empty_events or stream_is_past_target

    def read(self):
      """Reads records from PCollection readers."""

      # The largest timestamp read from the different streams.
      target_timestamp = timestamp.MAX_TIMESTAMP

      # The events from the last iteration that are past the target timestamp.
      unsent_events = []

      # Emit events until all events have been read.
      while True:
        # Read the next set of events. The read events will most likely be
        # out of order if there are multiple readers. Here we sort them into
        # a more manageable state.
        new_events = self._test_stream_events_before_target(target_timestamp)
        events_to_send = self._merge_sort(unsent_events, new_events)
        if not events_to_send:
          break

        # Get the next largest timestamp in the stream. This is used as the
        # timestamp for readers to "catch-up" to. This will only read from
        # readers with a timestamp less than this.
        target_timestamp = self._min_timestamp_of(events_to_send)

        # Loop through the elements with the correct timestamp.
        while not self._event_stream_caught_up_to_target(events_to_send,
                                                         target_timestamp):

          # First advance the clock to match the time of the stream. This has
          # a side-effect of also advancing this cache's clock.
          tag, r, curr_timestamp = events_to_send.pop()
          if curr_timestamp > self._monotonic_clock:
            yield self._advance_processing_time(curr_timestamp)

          # Then, send either a new element or watermark.
          if r.HasField('element_event'):
            r.element_event.tag = tag
            yield r
          elif r.HasField('watermark_event'):
            r.watermark_event.tag = tag
            yield r
        unsent_events = events_to_send
        target_timestamp = self._min_timestamp_of(unsent_events)

    def _advance_processing_time(self, new_timestamp):
      """Advances the internal clock and returns an AdvanceProcessingTime event.
      """
      advance_by = new_timestamp.micros - self._monotonic_clock.micros
      e = beam_runner_api_pb2.TestStreamPayload.Event(
          processing_time_event=beam_runner_api_pb2.TestStreamPayload.Event.
          AdvanceProcessingTime(advance_duration=advance_by))
      self._monotonic_clock = new_timestamp
      return e

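# A minimal, illustrative sketch (not from the Beam codebase): the
# header-then-records layout accepted by StreamingCache.write above. A
# TestStreamFileHeader names the output tag, and each TestStreamFileRecord
# wraps a TestStreamPayload.Event (here a processing-time advance followed by
# an element bundle). The helper name, the 'example' label, and the sample
# element are hypothetical; cleanup() removes the temporary cache directory.


def _example_write_header_and_records_sketch():
  cache = StreamingCache(cache_dir=None)

  header = beam_interactive_api_pb2.TestStreamFileHeader(tag='example')
  records = [
      beam_interactive_api_pb2.TestStreamFileRecord(
          recorded_event=beam_runner_api_pb2.TestStreamPayload.Event(
              processing_time_event=beam_runner_api_pb2.TestStreamPayload.
              Event.AdvanceProcessingTime(advance_duration=1000000))),
      beam_interactive_api_pb2.TestStreamFileRecord(
          recorded_event=beam_runner_api_pb2.TestStreamPayload.Event(
              element_event=beam_runner_api_pb2.TestStreamPayload.Event.
              AddElements(
                  elements=[
                      beam_runner_api_pb2.TestStreamPayload.TimestampedElement(
                          encoded_element=SafeFastPrimitivesCoder().encode(
                              'hello'),
                          timestamp=0)
                  ]))),
  ]

  # The header must be the first line of the cache file, so write it first.
  cache.write([header], 'example')
  cache.write(records, 'example')

  assert cache.exists('example')
  assert cache.size('example') > 0
  cache.cleanup()
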
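# A minimal, illustrative sketch (not from the Beam codebase): the ordering
# trick used by StreamingCache.Reader.read above, reduced to plain tuples.
# Events from several streams are merged by sorting descending on timestamp
# and popping from the end, so the earliest pending event is always emitted
# next, and a processing-time advance is emitted whenever the merged stream
# moves the monotonic clock forward. The helper name and the sample data are
# hypothetical.


def _example_merge_by_timestamp_sketch():
  # (tag, payload, timestamp_micros) tuples, as two per-stream batches.
  stream_a = [('a', 'a0', 1000), ('a', 'a1', 3000)]
  stream_b = [('b', 'b0', 2000)]

  monotonic_clock = 0
  emitted = []

  # Descending sort so that pop() always returns the earliest pending event.
  pending = sorted(stream_a + stream_b, key=lambda x: x[2], reverse=True)
  while pending:
    tag, payload, ts = pending.pop()
    if ts > monotonic_clock:
      # Advance the replay clock before emitting the event itself.
      emitted.append(('advance_processing_time', ts - monotonic_clock))
      monotonic_clock = ts
    emitted.append((tag, payload))

  assert emitted == [
      ('advance_processing_time', 1000), ('a', 'a0'),
      ('advance_processing_time', 1000), ('b', 'b0'),
      ('advance_processing_time', 1000), ('a', 'a1'),
  ]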