# Source: github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/interactive/testing/test_cache_manager.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import collections
import itertools
import sys

import apache_beam as beam
from apache_beam import coders
from apache_beam.portability.api import beam_interactive_api_pb2
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners.interactive.cache_manager import CacheManager
from apache_beam.utils.timestamp import Duration
from apache_beam.utils.timestamp import Timestamp


class InMemoryCache(CacheManager):
  """A cache that stores all PCollections in an in-memory map.

  This is only used for checking the pipeline shape. It can't be used for
  running the pipeline because the cache isn't shared between the SDK and the
  Runner.
  """
  def __init__(self):
    # Maps a '/'-joined label key to the list of values written under it.
    self._cached = {}
    # Maps a '/'-joined label key to the coder saved for it.
    self._pcoders = {}

  def exists(self, *labels):
    """Returns True if an entry exists under the given labels."""
    return self._key(*labels) in self._cached

  def _latest_version(self, *labels):
    # No versioning is tracked in memory; this always reports True.
    return True

  def read(self, *labels, **args):
    """Returns a tuple of (iterator over cached values, version).

    If nothing is cached under the labels, the iterator is empty and the
    version is -1; otherwise the version is None. Keyword arguments are
    accepted for interface compatibility and ignored.
    """
    if not self.exists(*labels):
      return itertools.chain([]), -1

    return itertools.chain(self._cached[self._key(*labels)]), None

  def write(self, value, *labels):
    """Appends every item of the iterable `value` to the entry for labels.

    The entry is created empty first if it does not exist yet.
    """
    key = self._key(*labels)
    if not self.exists(*labels):
      self._cached[key] = []
    # `+=` on a list extends it in place with the items of `value`.
    self._cached[key] += value

  def save_pcoder(self, pcoder, *labels):
    """Stores `pcoder` under the given labels."""
    self._pcoders[self._key(*labels)] = pcoder

  def load_pcoder(self, *labels):
    """Returns the coder saved under labels; raises KeyError if none."""
    return self._pcoders[self._key(*labels)]

  def cleanup(self):
    """Drops all cached values and saved coders."""
    # NOTE(review): unlike __init__, this resets _cached to a defaultdict(list).
    # Afterwards, source() on an unknown label returns an empty Create (and
    # inserts an empty entry, so exists() becomes True) instead of raising
    # KeyError. Kept as-is because callers may rely on it.
    self._cached = collections.defaultdict(list)
    self._pcoders = {}

  def clear(self, *labels):
    # Noop because in-memory.
    pass

  def source(self, *labels):
    """Returns a beam.Create transform over the values cached under labels."""
    vals = self._cached[self._key(*labels)]
    return beam.Create(vals)

  def sink(self, labels, is_capture=False):
    """Returns an identity beam.Map; nothing is written to this cache."""
    return beam.Map(lambda _: _)

  def size(self, *labels):
    """Returns sys.getsizeof of the cached list, or 0 if nothing is cached.

    sys.getsizeof is shallow: it measures the list object itself, not the
    elements it references.
    """
    if self.exists(*labels):
      return sys.getsizeof(self._cached[self._key(*labels)])
    return 0

  def _key(self, *labels):
    # Labels are joined with '/' to form a single dictionary key.
    return '/'.join(labels)


class NoopSink(beam.PTransform):
  """A PTransform that passes its input PCollection through unchanged."""
  def expand(self, pcoll):
    return pcoll | beam.Map(lambda x: x)


class FileRecordsBuilder(object):
  """Builds a TestStreamFileHeader followed by TestStreamFileRecords.

  Each add/advance method appends one record and returns self so calls can be
  chained; build() returns [header] + records.
  """
  def __init__(self, tag=None):
    self._header = beam_interactive_api_pb2.TestStreamFileHeader(tag=tag)
    self._records = []
    # Elements are encoded with FastPrimitivesCoder before being recorded.
    self._coder = coders.FastPrimitivesCoder()

  def add_element(self, element, event_time_secs):
    """Appends an AddElements event holding `element` at `event_time_secs`."""
    element_payload = beam_runner_api_pb2.TestStreamPayload.TimestampedElement(
        encoded_element=self._coder.encode(element),
        timestamp=Timestamp.of(event_time_secs).micros)
    record = beam_interactive_api_pb2.TestStreamFileRecord(
        recorded_event=beam_runner_api_pb2.TestStreamPayload.Event(
            element_event=beam_runner_api_pb2.TestStreamPayload.Event.
            AddElements(elements=[element_payload])))
    self._records.append(record)
    return self

  def advance_watermark(self, watermark_secs):
    """Appends an AdvanceWatermark event to `watermark_secs` (in micros)."""
    record = beam_interactive_api_pb2.TestStreamFileRecord(
        recorded_event=beam_runner_api_pb2.TestStreamPayload.Event(
            watermark_event=beam_runner_api_pb2.TestStreamPayload.
            Event.AdvanceWatermark(
                new_watermark=Timestamp.of(watermark_secs).micros)))
    self._records.append(record)
    return self

  def advance_processing_time(self, delta_secs):
    """Appends an AdvanceProcessingTime event by `delta_secs` (in micros)."""
    record = beam_interactive_api_pb2.TestStreamFileRecord(
        recorded_event=beam_runner_api_pb2.TestStreamPayload.Event(
            processing_time_event=beam_runner_api_pb2.TestStreamPayload.Event.
            AdvanceProcessingTime(
                advance_duration=Duration.of(delta_secs).micros)))
    self._records.append(record)
    return self

  def build(self):
    """Returns the header followed by all records, in insertion order."""
    return [self._header] + self._records