github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/worker/statecache.py (about) 1 # 2 # Licensed to the Apache Software Foundation (ASF) under one or more 3 # contributor license agreements. See the NOTICE file distributed with 4 # this work for additional information regarding copyright ownership. 5 # The ASF licenses this file to You under the Apache License, Version 2.0 6 # (the "License"); you may not use this file except in compliance with 7 # the License. You may obtain a copy of the License at 8 # 9 # http://www.apache.org/licenses/LICENSE-2.0 10 # 11 # Unless required by applicable law or agreed to in writing, software 12 # distributed under the License is distributed on an "AS IS" BASIS, 13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 # See the License for the specific language governing permissions and 15 # limitations under the License. 16 # 17 18 """A module for caching state reads/writes in Beam applications.""" 19 # pytype: skip-file 20 # mypy: disallow-untyped-defs 21 22 import collections 23 import gc 24 import logging 25 import sys 26 import threading 27 import time 28 import types 29 import weakref 30 from typing import Any 31 from typing import Callable 32 from typing import List 33 from typing import Tuple 34 from typing import Union 35 36 import objsize 37 38 _LOGGER = logging.getLogger(__name__) 39 _DEFAULT_WEIGHT = 8 40 _TYPES_TO_NOT_MEASURE = ( 41 # Do not measure shared types 42 type, 43 types.ModuleType, 44 types.FrameType, 45 types.BuiltinFunctionType, 46 # Do not measure lambdas as they typically share lots of state 47 types.FunctionType, 48 types.LambdaType, 49 # Do not measure weak references as they will be deleted and not counted 50 *weakref.ProxyTypes, 51 weakref.ReferenceType) 52 53 54 class WeightedValue(object): 55 """Value type that stores corresponding weight. 56 57 :arg value The value to be stored. 58 :arg weight The associated weight of the value. If unspecified, the objects 59 size will be used. 60 """ 61 def __init__(self, value, weight): 62 # type: (Any, int) -> None 63 self._value = value 64 if weight <= 0: 65 raise ValueError( 66 'Expected weight to be > 0 for %s but received %d' % (value, weight)) 67 self._weight = weight 68 69 def weight(self): 70 # type: () -> int 71 return self._weight 72 73 def value(self): 74 # type: () -> Any 75 return self._value 76 77 78 class CacheAware(object): 79 """Allows cache users to override what objects are measured.""" 80 def __init__(self): 81 # type: () -> None 82 pass 83 84 def get_referents_for_cache(self): 85 # type: () -> List[Any] 86 87 """Returns the list of objects accounted during cache measurement.""" 88 raise NotImplementedError() 89 90 91 def _safe_isinstance(obj, type): 92 # type: (Any, Union[type, Tuple[type, ...]]) -> bool 93 94 """ 95 Return whether an object is an instance of a class or of a subclass thereof. 96 See `isinstance()` for more information. 97 98 Returns false on `isinstance()` failure. For example applying `isinstance()` 99 on `weakref.proxy` objects attempts to dereference the proxy objects, which 100 may yield an exception. See https://github.com/apache/beam/issues/23389 for 101 additional details. 102 """ 103 try: 104 return isinstance(obj, type) 105 except Exception: 106 return False 107 108 109 def _size_func(obj): 110 # type: (Any) -> int 111 112 """ 113 Returns the size of the object or a default size if an error occurred during 114 sizing. 115 """ 116 try: 117 return sys.getsizeof(obj) 118 except Exception as e: 119 current_time = time.time() 120 # Limit outputting this log so we don't spam the logs on these 121 # occurrences. 122 if _size_func.last_log_time + 300 < current_time: # type: ignore 123 _LOGGER.warning( 124 'Failed to size %s of type %s. Note that this may ' 125 'impact cache sizing such that the cache is over ' 126 'utilized which may lead to out of memory errors.', 127 obj, 128 type(obj), 129 exc_info=e) 130 _size_func.last_log_time = current_time # type: ignore 131 # Use an arbitrary default size that would account for some of the object 132 # overhead. 133 return _DEFAULT_WEIGHT 134 135 136 _size_func.last_log_time = 0 # type: ignore 137 138 139 def _get_referents_func(*objs): 140 # type: (List[Any]) -> List[Any] 141 142 """Returns the list of objects accounted during cache measurement. 143 144 Users can inherit CacheAware to override which referents should be 145 used when measuring the deep size of the object. The default is to 146 use gc.get_referents(*objs). 147 """ 148 rval = [] 149 for obj in objs: 150 if _safe_isinstance(obj, CacheAware): 151 rval.extend(obj.get_referents_for_cache()) # type: ignore 152 else: 153 rval.extend(gc.get_referents(obj)) 154 return rval 155 156 157 def _filter_func(o): 158 # type: (Any) -> bool 159 160 """ 161 Filter out specific types from being measured. 162 163 Note that we do want to measure the cost of weak references as they will only 164 stay in scope as long as other code references them and will effectively be 165 garbage collected as soon as there isn't a strong reference anymore. 166 167 Note that we cannot use the default filter function due to isinstance raising 168 an error on weakref.proxy types. See 169 https://github.com/liran-funaro/objsize/issues/6 for additional details. 170 """ 171 return not _safe_isinstance(o, _TYPES_TO_NOT_MEASURE) 172 173 174 def get_deep_size(*objs): 175 # type: (Any) -> int 176 177 """Calculates the deep size of all the arguments in bytes.""" 178 return objsize.get_deep_size( 179 *objs, 180 get_size_func=_size_func, 181 get_referents_func=_get_referents_func, 182 filter_func=_filter_func) 183 184 185 class _LoadingValue(WeightedValue): 186 """Allows concurrent users of the cache to wait for a value to be loaded.""" 187 def __init__(self): 188 # type: () -> None 189 super().__init__(None, 1) 190 self._wait_event = threading.Event() 191 192 def load(self, key, loading_fn): 193 # type: (Any, Callable[[Any], Any]) -> None 194 try: 195 self._value = loading_fn(key) 196 except Exception as err: 197 self._error = err 198 finally: 199 self._wait_event.set() 200 201 def value(self): 202 # type: () -> Any 203 self._wait_event.wait() 204 err = getattr(self, "_error", None) 205 if err: 206 raise err 207 return self._value 208 209 210 class StateCache(object): 211 """LRU cache for Beam state access, scoped by state key and cache_token. 212 Assumes a bag state implementation. 213 214 For a given key, caches a value and allows to 215 a) peek at the cache (peek), 216 returns the value for the provided key or None if it doesn't exist. 217 Will never block. 218 b) read from the cache (get), 219 returns the value for the provided key or loads it using the 220 supplied function. Multiple calls for the same key will block 221 until the value is loaded. 222 c) write to the cache (put), 223 store the provided value overwriting any previous result 224 d) invalidate a cached element (invalidate) 225 removes the value from the cache for the provided key 226 e) invalidate all cached elements (invalidate_all) 227 228 The operations on the cache are thread-safe for use by multiple workers. 229 230 :arg max_weight The maximum weight of entries to store in the cache in bytes. 231 """ 232 def __init__(self, max_weight): 233 # type: (int) -> None 234 _LOGGER.info('Creating state cache with size %s', max_weight) 235 self._max_weight = max_weight 236 self._current_weight = 0 237 self._cache = collections.OrderedDict( 238 ) # type: collections.OrderedDict[Any, WeightedValue] 239 self._hit_count = 0 240 self._miss_count = 0 241 self._evict_count = 0 242 self._load_time_ns = 0 243 self._load_count = 0 244 self._lock = threading.RLock() 245 246 def peek(self, key): 247 # type: (Any) -> Any 248 assert self.is_cache_enabled() 249 with self._lock: 250 value = self._cache.get(key, None) 251 if value is None or _safe_isinstance(value, _LoadingValue): 252 self._miss_count += 1 253 return None 254 255 self._cache.move_to_end(key) 256 self._hit_count += 1 257 return value.value() 258 259 def get(self, key, loading_fn): 260 # type: (Any, Callable[[Any], Any]) -> Any 261 assert self.is_cache_enabled() and callable(loading_fn) 262 263 self._lock.acquire() 264 value = self._cache.get(key, None) 265 266 # Return the already cached value 267 if value is not None: 268 self._cache.move_to_end(key) 269 self._hit_count += 1 270 self._lock.release() 271 return value.value() 272 273 # Load the value since it isn't in the cache. 274 self._miss_count += 1 275 loading_value = _LoadingValue() 276 self._cache[key] = loading_value 277 self._current_weight += loading_value.weight() 278 279 # Ensure that we unlock the lock while loading to allow for parallel gets 280 self._lock.release() 281 282 start_time_ns = time.time_ns() 283 loading_value.load(key, loading_fn) 284 elapsed_time_ns = time.time_ns() - start_time_ns 285 286 try: 287 value = loading_value.value() 288 except Exception as err: 289 # If loading failed then delete the value from the cache allowing for 290 # the next lookup to possibly succeed. 291 with self._lock: 292 self._load_count += 1 293 self._load_time_ns += elapsed_time_ns 294 # Don't remove values that have already been replaced with a different 295 # value by a put/invalidate that occurred concurrently with the load. 296 # The put/invalidate will have been responsible for updating the 297 # cache weight appropriately already. 298 old_value = self._cache.get(key, None) 299 if old_value is not loading_value: 300 raise err 301 self._current_weight -= loading_value.weight() 302 del self._cache[key] 303 raise err 304 305 # Replace the value in the cache with a weighted value now that the 306 # loading has completed successfully. 307 weight = get_deep_size(value) 308 if weight <= 0: 309 _LOGGER.warning( 310 'Expected object size to be >= 0 for %s but received %d.', 311 value, 312 weight) 313 weight = 8 314 value = WeightedValue(value, weight) 315 with self._lock: 316 self._load_count += 1 317 self._load_time_ns += elapsed_time_ns 318 # Don't replace values that have already been replaced with a different 319 # value by a put/invalidate that occurred concurrently with the load. 320 # The put/invalidate will have been responsible for updating the 321 # cache weight appropriately already. 322 old_value = self._cache.get(key, None) 323 if old_value is not loading_value: 324 return value.value() 325 326 self._current_weight -= loading_value.weight() 327 self._cache[key] = value 328 self._current_weight += value.weight() 329 while self._current_weight > self._max_weight: 330 (_, weighted_value) = self._cache.popitem(last=False) 331 self._current_weight -= weighted_value.weight() 332 self._evict_count += 1 333 334 return value.value() 335 336 def put(self, key, value): 337 # type: (Any, Any) -> None 338 assert self.is_cache_enabled() 339 if not _safe_isinstance(value, WeightedValue): 340 weight = get_deep_size(value) 341 if weight <= 0: 342 _LOGGER.warning( 343 'Expected object size to be >= 0 for %s but received %d.', 344 value, 345 weight) 346 weight = _DEFAULT_WEIGHT 347 value = WeightedValue(value, weight) 348 with self._lock: 349 old_value = self._cache.pop(key, None) 350 if old_value is not None: 351 self._current_weight -= old_value.weight() 352 self._cache[key] = value 353 self._current_weight += value.weight() 354 while self._current_weight > self._max_weight: 355 (_, weighted_value) = self._cache.popitem(last=False) 356 self._current_weight -= weighted_value.weight() 357 self._evict_count += 1 358 359 def invalidate(self, key): 360 # type: (Any) -> None 361 assert self.is_cache_enabled() 362 with self._lock: 363 weighted_value = self._cache.pop(key, None) 364 if weighted_value is not None: 365 self._current_weight -= weighted_value.weight() 366 367 def invalidate_all(self): 368 # type: () -> None 369 with self._lock: 370 self._cache.clear() 371 self._current_weight = 0 372 373 def describe_stats(self): 374 # type: () -> str 375 with self._lock: 376 request_count = self._hit_count + self._miss_count 377 if request_count > 0: 378 hit_ratio = 100.0 * self._hit_count / request_count 379 else: 380 hit_ratio = 100.0 381 return ( 382 'used/max %d/%d MB, hit %.2f%%, lookups %d, ' 383 'avg load time %.0f ns, loads %d, evictions %d') % ( 384 self._current_weight >> 20, 385 self._max_weight >> 20, 386 hit_ratio, 387 request_count, 388 self._load_time_ns / 389 self._load_count if self._load_count > 0 else 0, 390 self._load_count, 391 self._evict_count) 392 393 def is_cache_enabled(self): 394 # type: () -> bool 395 return self._max_weight > 0 396 397 def size(self): 398 # type: () -> int 399 with self._lock: 400 return len(self._cache)