#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pytype: skip-file

import collections
import os
import tempfile
from urllib.parse import quote
from urllib.parse import unquote_to_bytes

import apache_beam as beam
from apache_beam import coders
from apache_beam.io import filesystems
from apache_beam.io import textio
from apache_beam.io import tfrecordio
from apache_beam.transforms import combiners


class CacheManager(object):
  """Abstract class for caching PCollections.

  A PCollection cache is identified by labels, which consist of a prefix (either
  'full' or 'sample') and a cache_label which is a hash of the PCollection
  derivation.
  """
  def exists(self, *labels):
    # type (*str) -> bool

    """Returns if the PCollection cache exists."""
    raise NotImplementedError

  def is_latest_version(self, version, *labels):
    # type (str, *str) -> bool

    """Returns if the given version number is the latest."""
    return version == self._latest_version(*labels)

  def _latest_version(self, *labels):
    # type (*str) -> str

    """Returns the latest version number of the PCollection cache."""
    raise NotImplementedError

  def read(self, *labels, **args):
    # type (*str, Dict[str, Any]) -> Tuple[str, Generator[Any]]

    """Return the PCollection as a list as well as the version number.

    Args:
      *labels: List of labels for PCollection instance.
      **args: Dict of additional arguments. Currently only 'tail' as a boolean.
        When tail is True, will wait and read new elements until the cache is
        complete.

    Returns:
      A tuple containing an iterator for the items in the PCollection and the
        version number.

      It is possible that the version numbers from read() and
      _latest_version() are different. This usually means that the cache's
      been evicted (thus unavailable => read() returns version = -1), but it
      had reached version n before eviction.
    """
    raise NotImplementedError

  def write(self, value, *labels):
    # type (Any, *str) -> None

    """Writes the value to the given cache.

    Args:
      value: An encodable (with corresponding PCoder) value
      *labels: List of labels for PCollection instance
    """
    raise NotImplementedError

  def clear(self, *labels):
    # type (*str) -> Boolean

    """Clears the cache entry of the given labels and returns True on success.

    Args:
      *labels: List of labels for PCollection instance

    Returns:
      True if the cache entry was cleared.
    """
    raise NotImplementedError

  def source(self, *labels):
    # type (*str) -> ptransform.PTransform

    """Returns a PTransform that reads the PCollection cache."""
    raise NotImplementedError

  def sink(self, labels, is_capture=False):
    # type (*str, bool) -> ptransform.PTransform

    """Returns a PTransform that writes the PCollection cache.

    TODO(BEAM-10514): Make sure labels will not be converted into an
    arbitrarily long file path: e.g., windows has a 260 path limit.
    """
    raise NotImplementedError

  def save_pcoder(self, pcoder, *labels):
    # type (coders.Coder, *str) -> None

    """Saves pcoder for given PCollection.

    Correct reading of PCollection from Cache requires PCoder to be known.
    This method saves desired PCoder for PCollection that will subsequently
    be used by sink(...), source(...), and, most importantly, read(...) method.
    The latter must be able to read a PCollection written by Beam using
    non-Beam IO.

    Args:
      pcoder: A PCoder to be used for reading and writing a PCollection.
      *labels: List of labels for PCollection instance.
    """
    raise NotImplementedError

  def load_pcoder(self, *labels):
    # type (*str) -> coders.Coder

    """Returns previously saved PCoder for reading and writing PCollection."""
    raise NotImplementedError

  def cleanup(self):
    # type () -> None

    """Cleans up all the PCollection caches."""
    raise NotImplementedError

  def size(self, *labels):
    # type: (*str) -> int

    """Returns the size of the PCollection on disk in bytes."""
    raise NotImplementedError


class FileBasedCacheManager(CacheManager):
  """Maps PCollections to local temp files for materialization."""

  # Supported cache formats, each mapped to its (reader, writer) transforms.
  _available_formats = {
      'text': (textio.ReadFromText, textio.WriteToText),
      'tfrecord': (tfrecordio.ReadFromTFRecord, tfrecordio.WriteToTFRecord)
  }

  def __init__(self, cache_dir=None, cache_format='text'):
    """Initializes the cache manager.

    Args:
      cache_dir: Directory under which cache files are stored. When falsy, a
        fresh temporary directory prefixed with 'ib-' is created (inside
        $TEST_TMPDIR when that environment variable is set).
      cache_format: One of the keys of _available_formats ('text' or
        'tfrecord').

    Raises:
      ValueError: If cache_format is not a supported format.
    """
    # Validate before touching the filesystem so that a bad format does not
    # leave an orphaned temporary directory behind.
    if cache_format not in self._available_formats:
      raise ValueError("Unsupported cache format: '%s'." % cache_format)

    if cache_dir:
      self._cache_dir = cache_dir
    else:
      self._cache_dir = tempfile.mkdtemp(
          prefix='ib-', dir=os.environ.get('TEST_TMPDIR', None))
    self._versions = collections.defaultdict(self._CacheVersion)
    self.cache_format = cache_format

    self._reader_class, self._writer_class = self._available_formats[
        cache_format]
    self._default_pcoder = (
        SafeFastPrimitivesCoder() if cache_format == 'text' else None)

    # List of saved pcoders keyed by PCollection path. It is OK to keep this
    # list in memory because once FileBasedCacheManager object is
    # destroyed/re-created it loses the access to previously written cache
    # objects anyways even if cache_dir already exists. In other words,
    # it is not possible to resume execution of Beam pipeline from the
    # saved cache if FileBasedCacheManager has been reset.
    #
    # However, if we are to implement better cache persistence, one needs
    # to take care of keeping consistency between the cached PCollection
    # and its PCoder type.
    self._saved_pcoders = {}

  def size(self, *labels):
    """Returns the total size in bytes of the cache files for the labels.

    Returns 0 when the cache does not exist.
    """
    if not self.exists(*labels):
      return 0
    matched_paths = self._match(*labels)
    # The files may have been removed between exists() and _match().
    if not matched_paths:
      return 0
    # If the matched paths have a gs:// prefix, they must be cached on GCS.
    if matched_paths[0].startswith('gs://'):
      from apache_beam.io.gcp import gcsio
      gcs = gcsio.GcsIO()
      return sum(
          sum(gcs.list_prefix(path).values()) for path in matched_paths)
    return sum(os.path.getsize(path) for path in matched_paths)

  def exists(self, *labels):
    """Returns True if at least one cache file matches the labels.

    The filesystem is only consulted when some label after the first one is
    truthy; otherwise this returns False.
    """
    if labels and any(labels[1:]):
      return bool(self._match(*labels))
    return False

  def _latest_version(self, *labels):
    """Returns the version derived from the newest matching cache file."""
    # A timestamp of 0 means that no cache file was matched.
    timestamp = 0
    for path in self._match(*labels):
      timestamp = max(timestamp, filesystems.FileSystems.last_updated(path))
    return self._versions["-".join(labels)].get_version(timestamp)

  def save_pcoder(self, pcoder, *labels):
    """Remembers pcoder in memory, keyed by the cache path of the labels."""
    self._saved_pcoders[self._path(*labels)] = pcoder

  def load_pcoder(self, *labels):
    """Returns the coder to use for the cache identified by the labels.

    The default coder of this cache manager is returned when no coder has
    been saved for the labels or when the saved coder is a
    coders.FastPrimitivesCoder.
    """
    saved_pcoder = self._saved_pcoders.get(self._path(*labels), None)
    if saved_pcoder is None or isinstance(saved_pcoder,
                                          coders.FastPrimitivesCoder):
      return self._default_pcoder
    return saved_pcoder

  def read(self, *labels, **args):
    """Returns a (reader, version) tuple for the cached PCollection.

    When the cache does not exist, an iterator over an empty list and the
    version -1 are returned. **args is accepted for interface compatibility
    and is not used by this implementation.
    """
    # Return an iterator to an empty list if it doesn't exist.
    if not self.exists(*labels):
      return iter([]), -1

    # Otherwise, return a generator to the cached PCollection.
    source = self.source(*labels)._source
    range_tracker = source.get_range_tracker(None, None)
    reader = source.read(range_tracker)
    version = self._latest_version(*labels)

    return reader, version

  def write(self, values, *labels):
    """Imitates how a WriteCache transform works without running a pipeline.

    For testing and cache manager development, not for production usage because
    the write is not sharded and does not use Beam execution model.

    Args:
      values: A non-empty, indexable collection of elements. The coder is
        looked up from the type of the first element.
      *labels: List of labels for PCollection instance.
    """
    pcoder = coders.registry.get_coder(type(values[0]))
    # Save the pcoder for the actual labels.
    self.save_pcoder(pcoder, *labels)
    single_shard_labels = [*labels[:-1], '-00000-of-00001']
    # Save the pcoder for the labels that imitates the sharded cache file name
    # suffix.
    self.save_pcoder(pcoder, *single_shard_labels)
    # Put a '-%05d-of-%05d' suffix to the cache file.
    sink = self.sink(single_shard_labels)._sink
    path = self._path(*labels[:-1])
    writer = sink.open_writer(path, labels[-1])
    # Always close the writer, even if writing one of the values fails.
    try:
      for v in values:
        writer.write(v)
    finally:
      writer.close()

  def clear(self, *labels):
    """Deletes the cache files of the labels; returns True if any existed."""
    if self.exists(*labels):
      filesystems.FileSystems.delete(self._match(*labels))
      return True
    return False

  def source(self, *labels):
    """Returns a read transform over all cache files of the labels."""
    return self._reader_class(
        self._glob_path(*labels), coder=self.load_pcoder(*labels))

  def sink(self, labels, is_capture=False):
    """Returns a write transform targeting the cache path of the labels.

    is_capture is accepted for interface compatibility and is not used by
    this implementation.
    """
    return self._writer_class(
        self._path(*labels), coder=self.load_pcoder(*labels))

  def cleanup(self):
    """Deletes the cached files and forgets all saved pcoders.

    For a gs:// cache directory the 'full/' sub-directory is deleted;
    otherwise the whole cache directory is deleted if it exists.
    """
    if self._cache_dir.startswith('gs://'):
      from apache_beam.io.gcp import gcsfilesystem
      from apache_beam.options.pipeline_options import PipelineOptions
      fs = gcsfilesystem.GCSFileSystem(PipelineOptions())
      fs.delete([self._cache_dir + '/full/'])
    elif filesystems.FileSystems.exists(self._cache_dir):
      filesystems.FileSystems.delete([self._cache_dir])
    self._saved_pcoders = {}

  def _glob_path(self, *labels):
    """Returns a glob pattern matching the sharded cache files of the labels."""
    return self._path(*labels) + '*-*-of-*'

  def _path(self, *labels):
    """Returns the labels joined onto the cache directory."""
    return filesystems.FileSystems.join(self._cache_dir, *labels)

  def _match(self, *labels):
    """Returns the paths of all existing cache files for the labels."""
    match = filesystems.FileSystems.match([self._glob_path(*labels)])
    # A single pattern was passed in, so a single match result is expected.
    assert len(match) == 1
    return [metadata.path for metadata in match[0].metadata_list]

  class _CacheVersion(object):
    """This class keeps track of the timestamp and the corresponding version."""
    def __init__(self):
      self.current_version = -1
      self.current_timestamp = 0

    def get_version(self, timestamp):
      """Updates version if necessary and returns the version number.

      Args:
        timestamp: (int) unix timestamp when the cache is updated. This value is
            zero if the cache has been evicted or doesn't exist.
      """
      # Do not update timestamp if the cache's been evicted.
      if timestamp != 0 and timestamp != self.current_timestamp:
        assert timestamp > self.current_timestamp
        self.current_version = self.current_version + 1
        self.current_timestamp = timestamp
      return self.current_version


class ReadCache(beam.PTransform):
  """A PTransform that reads the PCollections from the cache."""
  def __init__(self, cache_manager, label):
    self._cache_manager = cache_manager
    self._label = label

  def expand(self, pbegin):
    # pylint: disable=expression-not-assigned
    return pbegin | 'Read' >> self._cache_manager.source('full', self._label)


class WriteCache(beam.PTransform):
  """A PTransform that writes the PCollections to the cache."""
  def __init__(
      self,
      cache_manager,
      label,
      sample=False,
      sample_size=0,
      is_capture=False):
    self._cache_manager = cache_manager
    self._label = label
    self._sample = sample
    self._sample_size = sample_size
    self._is_capture = is_capture

  def expand(self, pcoll):
    prefix = 'sample' if self._sample else 'full'

    # We save pcoder that is necessary for proper reading of
    # cached PCollection. _cache_manager.sink(...) call below
    # should be using this saved pcoder.
    self._cache_manager.save_pcoder(
        coders.registry.get_coder(pcoll.element_type), prefix, self._label)

    if self._sample:
      pcoll |= 'Sample' >> (
          combiners.Sample.FixedSizeGlobally(self._sample_size)
          | beam.FlatMap(lambda sample: sample))
    # pylint: disable=expression-not-assigned
    return pcoll | 'Write' >> self._cache_manager.sink(
        (prefix, self._label), is_capture=self._is_capture)


class SafeFastPrimitivesCoder(coders.Coder):
  """A FastPrimitivesCoder that adds a quote/unquote step to escape special
  characters.

  NOTE(review): presumably this keeps encoded elements safe to store in the
  'text' cache format; confirm against the text IO behavior.
  """

  # pylint: disable=bad-option-value

  def encode(self, value):
    """Encodes value with FastPrimitivesCoder, then percent-quotes it."""
    return quote(
        coders.coders.FastPrimitivesCoder().encode(value)).encode('utf-8')

  def decode(self, value):
    """Reverses encode(): unquotes to bytes, then decodes the element."""
    return coders.coders.FastPrimitivesCoder().decode(unquote_to_bytes(value))