github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/worker/statecache.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """A module for caching state reads/writes in Beam applications."""
    19  # pytype: skip-file
    20  # mypy: disallow-untyped-defs
    21  
    22  import collections
    23  import gc
    24  import logging
    25  import sys
    26  import threading
    27  import time
    28  import types
    29  import weakref
    30  from typing import Any
    31  from typing import Callable
    32  from typing import List
    33  from typing import Tuple
    34  from typing import Union
    35  
    36  import objsize
    37  
    38  _LOGGER = logging.getLogger(__name__)
    39  _DEFAULT_WEIGHT = 8
    40  _TYPES_TO_NOT_MEASURE = (
    41      # Do not measure shared types
    42      type,
    43      types.ModuleType,
    44      types.FrameType,
    45      types.BuiltinFunctionType,
    46      # Do not measure lambdas as they typically share lots of state
    47      types.FunctionType,
    48      types.LambdaType,
    49      # Do not measure weak references as they will be deleted and not counted
    50      *weakref.ProxyTypes,
    51      weakref.ReferenceType)
    52  
    53  
    54  class WeightedValue(object):
    55    """Value type that stores corresponding weight.
    56  
    57    :arg value The value to be stored.
    58    :arg weight The associated weight of the value. If unspecified, the objects
    59    size will be used.
    60    """
    61    def __init__(self, value, weight):
    62      # type: (Any, int) -> None
    63      self._value = value
    64      if weight <= 0:
    65        raise ValueError(
    66            'Expected weight to be > 0 for %s but received %d' % (value, weight))
    67      self._weight = weight
    68  
    69    def weight(self):
    70      # type: () -> int
    71      return self._weight
    72  
    73    def value(self):
    74      # type: () -> Any
    75      return self._value
    76  
    77  
    78  class CacheAware(object):
    79    """Allows cache users to override what objects are measured."""
    80    def __init__(self):
    81      # type: () -> None
    82      pass
    83  
    84    def get_referents_for_cache(self):
    85      # type: () -> List[Any]
    86  
    87      """Returns the list of objects accounted during cache measurement."""
    88      raise NotImplementedError()
    89  
    90  
    91  def _safe_isinstance(obj, type):
    92    # type: (Any, Union[type, Tuple[type, ...]]) -> bool
    93  
    94    """
    95    Return whether an object is an instance of a class or of a subclass thereof.
    96    See `isinstance()` for more information.
    97  
    98    Returns false on `isinstance()` failure. For example applying `isinstance()`
    99    on `weakref.proxy` objects attempts to dereference the proxy objects, which
   100    may yield an exception. See https://github.com/apache/beam/issues/23389 for
   101    additional details.
   102    """
   103    try:
   104      return isinstance(obj, type)
   105    except Exception:
   106      return False
   107  
   108  
   109  def _size_func(obj):
   110    # type: (Any) -> int
   111  
   112    """
   113    Returns the size of the object or a default size if an error occurred during
   114    sizing.
   115    """
   116    try:
   117      return sys.getsizeof(obj)
   118    except Exception as e:
   119      current_time = time.time()
   120      # Limit outputting this log so we don't spam the logs on these
   121      # occurrences.
   122      if _size_func.last_log_time + 300 < current_time:  # type: ignore
   123        _LOGGER.warning(
   124            'Failed to size %s of type %s. Note that this may '
   125            'impact cache sizing such that the cache is over '
   126            'utilized which may lead to out of memory errors.',
   127            obj,
   128            type(obj),
   129            exc_info=e)
   130        _size_func.last_log_time = current_time  # type: ignore
   131      # Use an arbitrary default size that would account for some of the object
   132      # overhead.
   133      return _DEFAULT_WEIGHT
   134  
   135  
   136  _size_func.last_log_time = 0  # type: ignore
   137  
   138  
   139  def _get_referents_func(*objs):
   140    # type: (List[Any]) -> List[Any]
   141  
   142    """Returns the list of objects accounted during cache measurement.
   143  
   144    Users can inherit CacheAware to override which referents should be
   145    used when measuring the deep size of the object. The default is to
   146    use gc.get_referents(*objs).
   147    """
   148    rval = []
   149    for obj in objs:
   150      if _safe_isinstance(obj, CacheAware):
   151        rval.extend(obj.get_referents_for_cache())  # type: ignore
   152      else:
   153        rval.extend(gc.get_referents(obj))
   154    return rval
   155  
   156  
   157  def _filter_func(o):
   158    # type: (Any) -> bool
   159  
   160    """
   161    Filter out specific types from being measured.
   162  
   163    Note that we do want to measure the cost of weak references as they will only
   164    stay in scope as long as other code references them and will effectively be
   165    garbage collected as soon as there isn't a strong reference anymore.
   166  
   167    Note that we cannot use the default filter function due to isinstance raising
   168    an error on weakref.proxy types. See
   169    https://github.com/liran-funaro/objsize/issues/6 for additional details.
   170    """
   171    return not _safe_isinstance(o, _TYPES_TO_NOT_MEASURE)
   172  
   173  
   174  def get_deep_size(*objs):
   175    # type: (Any) -> int
   176  
   177    """Calculates the deep size of all the arguments in bytes."""
   178    return objsize.get_deep_size(
   179        *objs,
   180        get_size_func=_size_func,
   181        get_referents_func=_get_referents_func,
   182        filter_func=_filter_func)
   183  
   184  
   185  class _LoadingValue(WeightedValue):
   186    """Allows concurrent users of the cache to wait for a value to be loaded."""
   187    def __init__(self):
   188      # type: () -> None
   189      super().__init__(None, 1)
   190      self._wait_event = threading.Event()
   191  
   192    def load(self, key, loading_fn):
   193      # type: (Any, Callable[[Any], Any]) -> None
   194      try:
   195        self._value = loading_fn(key)
   196      except Exception as err:
   197        self._error = err
   198      finally:
   199        self._wait_event.set()
   200  
   201    def value(self):
   202      # type: () -> Any
   203      self._wait_event.wait()
   204      err = getattr(self, "_error", None)
   205      if err:
   206        raise err
   207      return self._value
   208  
   209  
   210  class StateCache(object):
   211    """LRU cache for Beam state access, scoped by state key and cache_token.
   212       Assumes a bag state implementation.
   213  
   214    For a given key, caches a value and allows to
   215      a) peek at the cache (peek),
   216             returns the value for the provided key or None if it doesn't exist.
   217             Will never block.
   218      b) read from the cache (get),
   219             returns the value for the provided key or loads it using the
   220             supplied function. Multiple calls for the same key will block
   221             until the value is loaded.
   222      c) write to the cache (put),
   223             store the provided value overwriting any previous result
   224      d) invalidate a cached element (invalidate)
   225             removes the value from the cache for the provided key
   226      e) invalidate all cached elements (invalidate_all)
   227  
   228    The operations on the cache are thread-safe for use by multiple workers.
   229  
   230    :arg max_weight The maximum weight of entries to store in the cache in bytes.
   231    """
   232    def __init__(self, max_weight):
   233      # type: (int) -> None
   234      _LOGGER.info('Creating state cache with size %s', max_weight)
   235      self._max_weight = max_weight
   236      self._current_weight = 0
   237      self._cache = collections.OrderedDict(
   238      )  # type: collections.OrderedDict[Any, WeightedValue]
   239      self._hit_count = 0
   240      self._miss_count = 0
   241      self._evict_count = 0
   242      self._load_time_ns = 0
   243      self._load_count = 0
   244      self._lock = threading.RLock()
   245  
   246    def peek(self, key):
   247      # type: (Any) -> Any
   248      assert self.is_cache_enabled()
   249      with self._lock:
   250        value = self._cache.get(key, None)
   251        if value is None or _safe_isinstance(value, _LoadingValue):
   252          self._miss_count += 1
   253          return None
   254  
   255        self._cache.move_to_end(key)
   256        self._hit_count += 1
   257      return value.value()
   258  
   259    def get(self, key, loading_fn):
   260      # type: (Any, Callable[[Any], Any]) -> Any
   261      assert self.is_cache_enabled() and callable(loading_fn)
   262  
   263      self._lock.acquire()
   264      value = self._cache.get(key, None)
   265  
   266      # Return the already cached value
   267      if value is not None:
   268        self._cache.move_to_end(key)
   269        self._hit_count += 1
   270        self._lock.release()
   271        return value.value()
   272  
   273      # Load the value since it isn't in the cache.
   274      self._miss_count += 1
   275      loading_value = _LoadingValue()
   276      self._cache[key] = loading_value
   277      self._current_weight += loading_value.weight()
   278  
   279      # Ensure that we unlock the lock while loading to allow for parallel gets
   280      self._lock.release()
   281  
   282      start_time_ns = time.time_ns()
   283      loading_value.load(key, loading_fn)
   284      elapsed_time_ns = time.time_ns() - start_time_ns
   285  
   286      try:
   287        value = loading_value.value()
   288      except Exception as err:
   289        # If loading failed then delete the value from the cache allowing for
   290        # the next lookup to possibly succeed.
   291        with self._lock:
   292          self._load_count += 1
   293          self._load_time_ns += elapsed_time_ns
   294          # Don't remove values that have already been replaced with a different
   295          # value by a put/invalidate that occurred concurrently with the load.
   296          # The put/invalidate will have been responsible for updating the
   297          # cache weight appropriately already.
   298          old_value = self._cache.get(key, None)
   299          if old_value is not loading_value:
   300            raise err
   301          self._current_weight -= loading_value.weight()
   302          del self._cache[key]
   303        raise err
   304  
   305      # Replace the value in the cache with a weighted value now that the
   306      # loading has completed successfully.
   307      weight = get_deep_size(value)
   308      if weight <= 0:
   309        _LOGGER.warning(
   310            'Expected object size to be >= 0 for %s but received %d.',
   311            value,
   312            weight)
   313        weight = 8
   314      value = WeightedValue(value, weight)
   315      with self._lock:
   316        self._load_count += 1
   317        self._load_time_ns += elapsed_time_ns
   318        # Don't replace values that have already been replaced with a different
   319        # value by a put/invalidate that occurred concurrently with the load.
   320        # The put/invalidate will have been responsible for updating the
   321        # cache weight appropriately already.
   322        old_value = self._cache.get(key, None)
   323        if old_value is not loading_value:
   324          return value.value()
   325  
   326        self._current_weight -= loading_value.weight()
   327        self._cache[key] = value
   328        self._current_weight += value.weight()
   329        while self._current_weight > self._max_weight:
   330          (_, weighted_value) = self._cache.popitem(last=False)
   331          self._current_weight -= weighted_value.weight()
   332          self._evict_count += 1
   333  
   334      return value.value()
   335  
   336    def put(self, key, value):
   337      # type: (Any, Any) -> None
   338      assert self.is_cache_enabled()
   339      if not _safe_isinstance(value, WeightedValue):
   340        weight = get_deep_size(value)
   341        if weight <= 0:
   342          _LOGGER.warning(
   343              'Expected object size to be >= 0 for %s but received %d.',
   344              value,
   345              weight)
   346          weight = _DEFAULT_WEIGHT
   347        value = WeightedValue(value, weight)
   348      with self._lock:
   349        old_value = self._cache.pop(key, None)
   350        if old_value is not None:
   351          self._current_weight -= old_value.weight()
   352        self._cache[key] = value
   353        self._current_weight += value.weight()
   354        while self._current_weight > self._max_weight:
   355          (_, weighted_value) = self._cache.popitem(last=False)
   356          self._current_weight -= weighted_value.weight()
   357          self._evict_count += 1
   358  
   359    def invalidate(self, key):
   360      # type: (Any) -> None
   361      assert self.is_cache_enabled()
   362      with self._lock:
   363        weighted_value = self._cache.pop(key, None)
   364        if weighted_value is not None:
   365          self._current_weight -= weighted_value.weight()
   366  
   367    def invalidate_all(self):
   368      # type: () -> None
   369      with self._lock:
   370        self._cache.clear()
   371        self._current_weight = 0
   372  
   373    def describe_stats(self):
   374      # type: () -> str
   375      with self._lock:
   376        request_count = self._hit_count + self._miss_count
   377        if request_count > 0:
   378          hit_ratio = 100.0 * self._hit_count / request_count
   379        else:
   380          hit_ratio = 100.0
   381        return (
   382            'used/max %d/%d MB, hit %.2f%%, lookups %d, '
   383            'avg load time %.0f ns, loads %d, evictions %d') % (
   384                self._current_weight >> 20,
   385                self._max_weight >> 20,
   386                hit_ratio,
   387                request_count,
   388                self._load_time_ns /
   389                self._load_count if self._load_count > 0 else 0,
   390                self._load_count,
   391                self._evict_count)
   392  
   393    def is_cache_enabled(self):
   394      # type: () -> bool
   395      return self._max_weight > 0
   396  
   397    def size(self):
   398      # type: () -> int
   399      with self._lock:
   400        return len(self._cache)