github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/worker/statecache_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Tests for state caching."""
    19  # pytype: skip-file
    20  
    21  import logging
    22  import re
    23  import sys
    24  import threading
    25  import time
    26  import unittest
    27  import weakref
    28  
    29  import objsize
    30  from hamcrest import assert_that
    31  from hamcrest import contains_string
    32  
    33  from apache_beam.runners.worker.statecache import CacheAware
    34  from apache_beam.runners.worker.statecache import StateCache
    35  from apache_beam.runners.worker.statecache import WeightedValue
    36  from apache_beam.runners.worker.statecache import _LoadingValue
    37  from apache_beam.runners.worker.statecache import get_deep_size
    38  
    39  
    40  class StateCacheTest(unittest.TestCase):
    41    def test_weakref(self):
    42      test_value = WeightedValue('test', 10 << 20)
    43  
    44      class WeightedValueRef():
    45        def __init__(self):
    46          self.ref = weakref.ref(test_value)
    47  
    48      cache = StateCache(5 << 20)
    49      wait_event = threading.Event()
    50      o = WeightedValueRef()
    51      cache.put('deep ref', o)
    52      # Ensure that the contents of the internal weak ref isn't sized
    53      self.assertIsNotNone(cache.peek('deep ref'))
    54      self.assertEqual(
    55          cache.describe_stats(),
    56          'used/max 0/5 MB, hit 100.00%, lookups 1, avg load time 0 ns, loads 0, '
    57          'evictions 0')
    58      cache.invalidate_all()
    59  
    60      # Ensure that putting in a weakref doesn't fail regardless of whether
    61      # it is alive or not
    62      o_ref = weakref.ref(o, lambda value: wait_event.set())
    63      cache.put('not deleted ref', o_ref)
    64      del o
    65      wait_event.wait()
    66      cache.put('deleted', o_ref)
    67  
    68    def test_weakref_proxy(self):
    69      test_value = WeightedValue('test', 10 << 20)
    70  
    71      class WeightedValueRef():
    72        def __init__(self):
    73          self.ref = weakref.ref(test_value)
    74  
    75      cache = StateCache(5 << 20)
    76      wait_event = threading.Event()
    77      o = WeightedValueRef()
    78      cache.put('deep ref', o)
    79      # Ensure that the contents of the internal weak ref isn't sized
    80      self.assertIsNotNone(cache.peek('deep ref'))
    81      self.assertEqual(
    82          cache.describe_stats(),
    83          'used/max 0/5 MB, hit 100.00%, lookups 1, avg load time 0 ns, loads 0, '
    84          'evictions 0')
    85      cache.invalidate_all()
    86  
    87      # Ensure that putting in a weakref doesn't fail regardless of whether
    88      # it is alive or not
    89      o_ref = weakref.proxy(o, lambda value: wait_event.set())
    90      cache.put('not deleted', o_ref)
    91      del o
    92      wait_event.wait()
    93      cache.put('deleted', o_ref)
    94  
    95    def test_size_of_fails(self):
    96      class BadSizeOf(object):
    97        def __sizeof__(self):
    98          raise RuntimeError("TestRuntimeError")
    99  
   100      cache = StateCache(5 << 20)
   101      with self.assertLogs('apache_beam.runners.worker.statecache',
   102                           level='WARNING') as context:
   103        cache.put('key', BadSizeOf())
   104        self.assertEqual(1, len(context.output))
   105        self.assertTrue('Failed to size' in context.output[0])
   106        # Test that we don't spam the logs
   107        cache.put('key', BadSizeOf())
   108        self.assertEqual(1, len(context.output))
   109  
   110    def test_empty_cache_peek(self):
   111      cache = StateCache(5 << 20)
   112      self.assertEqual(cache.peek("key"), None)
   113      self.assertEqual(
   114          cache.describe_stats(),
   115          (
   116              'used/max 0/5 MB, hit 0.00%, lookups 1, '
   117              'avg load time 0 ns, loads 0, evictions 0'))
   118  
   119    def test_put_peek(self):
   120      cache = StateCache(5 << 20)
   121      cache.put("key", WeightedValue("value", 1 << 20))
   122      self.assertEqual(cache.size(), 1)
   123      self.assertEqual(cache.peek("key"), "value")
   124      self.assertEqual(cache.peek("key2"), None)
   125      self.assertEqual(
   126          cache.describe_stats(),
   127          (
   128              'used/max 1/5 MB, hit 50.00%, lookups 2, '
   129              'avg load time 0 ns, loads 0, evictions 0'))
   130  
   131    def test_default_sized_put(self):
   132      cache = StateCache(5 << 20)
   133      cache.put("key", bytearray(1 << 20))
   134      cache.put("key2", bytearray(1 << 20))
   135      cache.put("key3", bytearray(1 << 20))
   136      self.assertEqual(cache.peek("key3"), bytearray(1 << 20))
   137      cache.put("key4", bytearray(1 << 20))
   138      cache.put("key5", bytearray(1 << 20))
   139      # note that each byte array instance takes slightly over 1 MB which is why
   140      # these 5 byte arrays can't all be stored in the cache causing a single
   141      # eviction
   142      self.assertEqual(
   143          cache.describe_stats(),
   144          (
   145              'used/max 4/5 MB, hit 100.00%, lookups 1, '
   146              'avg load time 0 ns, loads 0, evictions 1'))
   147  
   148    def test_max_size(self):
   149      cache = StateCache(2 << 20)
   150      cache.put("key", WeightedValue("value", 1 << 20))
   151      cache.put("key2", WeightedValue("value2", 1 << 20))
   152      self.assertEqual(cache.size(), 2)
   153      cache.put("key3", WeightedValue("value3", 1 << 20))
   154      self.assertEqual(cache.size(), 2)
   155      self.assertEqual(
   156          cache.describe_stats(),
   157          (
   158              'used/max 2/2 MB, hit 100.00%, lookups 0, '
   159              'avg load time 0 ns, loads 0, evictions 1'))
   160  
   161    def test_invalidate_all(self):
   162      cache = StateCache(5 << 20)
   163      cache.put("key", WeightedValue("value", 1 << 20))
   164      cache.put("key2", WeightedValue("value2", 1 << 20))
   165      self.assertEqual(cache.size(), 2)
   166      cache.invalidate_all()
   167      self.assertEqual(cache.size(), 0)
   168      self.assertEqual(cache.peek("key"), None)
   169      self.assertEqual(cache.peek("key2"), None)
   170      self.assertEqual(
   171          cache.describe_stats(),
   172          (
   173              'used/max 0/5 MB, hit 0.00%, lookups 2, '
   174              'avg load time 0 ns, loads 0, evictions 0'))
   175  
   176    def test_lru(self):
   177      cache = StateCache(5 << 20)
   178      cache.put("key", WeightedValue("value", 1 << 20))
   179      cache.put("key2", WeightedValue("value2", 1 << 20))
   180      cache.put("key3", WeightedValue("value0", 1 << 20))
   181      cache.put("key3", WeightedValue("value3", 1 << 20))
   182      cache.put("key4", WeightedValue("value4", 1 << 20))
   183      cache.put("key5", WeightedValue("value0", 1 << 20))
   184      cache.put("key5", WeightedValue(["value5"], 1 << 20))
   185      self.assertEqual(cache.size(), 5)
   186      self.assertEqual(cache.peek("key"), "value")
   187      self.assertEqual(cache.peek("key2"), "value2")
   188      self.assertEqual(cache.peek("key3"), "value3")
   189      self.assertEqual(cache.peek("key4"), "value4")
   190      self.assertEqual(cache.peek("key5"), ["value5"])
   191      # insert another key to trigger cache eviction
   192      cache.put("key6", WeightedValue("value6", 1 << 20))
   193      self.assertEqual(cache.size(), 5)
   194      # least recently used key should be gone ("key")
   195      self.assertEqual(cache.peek("key"), None)
   196      # trigger a read on "key2"
   197      cache.peek("key2")
   198      # insert another key to trigger cache eviction
   199      cache.put("key7", WeightedValue("value7", 1 << 20))
   200      self.assertEqual(cache.size(), 5)
   201      # least recently used key should be gone ("key3")
   202      self.assertEqual(cache.peek("key3"), None)
   203      # insert another key to trigger cache eviction
   204      cache.put("key8", WeightedValue("put", 1 << 20))
   205      self.assertEqual(cache.size(), 5)
   206      # insert another key to trigger cache eviction
   207      cache.put("key9", WeightedValue("value8", 1 << 20))
   208      self.assertEqual(cache.size(), 5)
   209      # least recently used key should be gone ("key4")
   210      self.assertEqual(cache.peek("key4"), None)
   211      # make "key5" used by writing to it
   212      cache.put("key5", WeightedValue("val", 1 << 20))
   213      # least recently used key should be gone ("key6")
   214      self.assertEqual(cache.peek("key6"), None)
   215      self.assertEqual(
   216          cache.describe_stats(),
   217          (
   218              'used/max 5/5 MB, hit 60.00%, lookups 10, '
   219              'avg load time 0 ns, loads 0, evictions 5'))
   220  
   221    def test_get(self):
   222      def check_key(key):
   223        self.assertEqual(key, "key")
   224        time.sleep(0.5)
   225        return "value"
   226  
   227      def raise_exception(key):
   228        time.sleep(0.5)
   229        raise Exception("TestException")
   230  
   231      cache = StateCache(5 << 20)
   232      self.assertEqual("value", cache.get("key", check_key))
   233      with cache._lock:
   234        self.assertFalse(isinstance(cache._cache["key"], _LoadingValue))
   235      self.assertEqual("value", cache.peek("key"))
   236      cache.invalidate_all()
   237  
   238      with self.assertRaisesRegex(Exception, "TestException"):
   239        cache.get("key", raise_exception)
   240      # The cache should not have the value after the failing load causing
   241      # check_key to load the value.
   242      self.assertEqual("value", cache.get("key", check_key))
   243      with cache._lock:
   244        self.assertFalse(isinstance(cache._cache["key"], _LoadingValue))
   245      self.assertEqual("value", cache.peek("key"))
   246  
   247      assert_that(cache.describe_stats(), contains_string(", loads 3,"))
   248      load_time_ns = re.search(
   249          ", avg load time (.+) ns,", cache.describe_stats()).group(1)
   250      # Load time should be larger then the sleep time and less than 2x sleep time
   251      self.assertGreater(int(load_time_ns), 0.5 * 1_000_000_000)
   252      self.assertLess(int(load_time_ns), 1_000_000_000)
   253  
   254    def test_concurrent_get_waits(self):
   255      event = threading.Semaphore(0)
   256      threads_running = threading.Barrier(3)
   257  
   258      def wait_for_event(key):
   259        with cache._lock:
   260          self.assertTrue(isinstance(cache._cache["key"], _LoadingValue))
   261        event.release()
   262        return "value"
   263  
   264      cache = StateCache(5 << 20)
   265  
   266      def load_key(output):
   267        threads_running.wait()
   268        output["value"] = cache.get("key", wait_for_event)
   269        output["time"] = time.time_ns()
   270  
   271      t1_output = {}
   272      t1 = threading.Thread(target=load_key, args=(t1_output, ))
   273      t1.start()
   274  
   275      t2_output = {}
   276      t2 = threading.Thread(target=load_key, args=(t2_output, ))
   277      t2.start()
   278  
   279      # Wait for both threads to start
   280      threads_running.wait()
   281      # Record the time and wait for the load to start
   282      current_time_ns = time.time_ns()
   283      event.acquire()
   284      t1.join()
   285      t2.join()
   286  
   287      # Ensure that only one thread did the loading and not both by checking that
   288      # the semaphore was only released once
   289      self.assertFalse(event.acquire(blocking=False))
   290  
   291      # Ensure that the load time is greater than the set time ensuring that
   292      # both loads had to wait for the event
   293      self.assertLessEqual(current_time_ns, t1_output["time"])
   294      self.assertLessEqual(current_time_ns, t2_output["time"])
   295      self.assertEqual("value", t1_output["value"])
   296      self.assertEqual("value", t2_output["value"])
   297      self.assertEqual("value", cache.peek("key"))
   298  
   299    def test_concurrent_get_superseded_by_put(self):
   300      load_happening = threading.Event()
   301      finish_loading = threading.Event()
   302  
   303      def wait_for_event(key):
   304        load_happening.set()
   305        finish_loading.wait()
   306        return "value"
   307  
   308      cache = StateCache(5 << 20)
   309  
   310      def load_key(output):
   311        output["value"] = cache.get("key", wait_for_event)
   312  
   313      t1_output = {}
   314      t1 = threading.Thread(target=load_key, args=(t1_output, ))
   315      t1.start()
   316  
   317      # Wait for the load to start, update the key, and then let the load finish
   318      load_happening.wait()
   319      cache.put("key", "value2")
   320      finish_loading.set()
   321      t1.join()
   322  
   323      # Ensure that the original value is loaded and returned and not the
   324      # updated value
   325      self.assertEqual("value", t1_output["value"])
   326      # Ensure that the updated value supersedes the loaded value.
   327      self.assertEqual("value2", cache.peek("key"))
   328  
   329    def test_is_cached_enabled(self):
   330      cache = StateCache(1 << 20)
   331      self.assertEqual(cache.is_cache_enabled(), True)
   332      self.assertEqual(
   333          cache.describe_stats(),
   334          (
   335              'used/max 0/1 MB, hit 100.00%, lookups 0, '
   336              'avg load time 0 ns, loads 0, evictions 0'))
   337      cache = StateCache(0)
   338      self.assertEqual(cache.is_cache_enabled(), False)
   339      self.assertEqual(
   340          cache.describe_stats(),
   341          (
   342              'used/max 0/0 MB, hit 100.00%, lookups 0, '
   343              'avg load time 0 ns, loads 0, evictions 0'))
   344  
   345    def test_get_referents_for_cache(self):
   346      class GetReferentsForCache(CacheAware):
   347        def __init__(self):
   348          self.measure_me = bytearray(1 << 20)
   349          self.ignore_me = bytearray(2 << 20)
   350  
   351        def get_referents_for_cache(self):
   352          return [self.measure_me]
   353  
   354      cache = StateCache(5 << 20)
   355      cache.put("key", GetReferentsForCache())
   356      self.assertEqual(
   357          cache.describe_stats(),
   358          (
   359              'used/max 1/5 MB, hit 100.00%, lookups 0, '
   360              'avg load time 0 ns, loads 0, evictions 0'))
   361  
   362    def test_get_deep_size_builtin_objects(self):
   363      """
   364      `statecache.get_deep_copy` should work same with objsize unless the `objs`
   365      has `CacheAware` or a filtered object. They should return the same size for
   366      built-in objects.
   367      """
   368      primitive_test_objects = [
   369          1,                    # int
   370          2.0,                  # float
   371          1+1j,                 # complex
   372          True,                 # bool
   373          'hello,world',        # str
   374          b'\00\01\02',         # bytes
   375      ]
   376  
   377      collection_test_objects = [
   378          [3, 4, 5],            # list
   379          (6, 7),               # tuple
   380          {'a', 'b', 'c'},      # set
   381          {'k': 8, 'l': 9},     # dict
   382      ]
   383  
   384      for obj in primitive_test_objects:
   385        self.assertEqual(
   386            get_deep_size(obj),
   387            objsize.get_deep_size(obj),
   388            f'different size for obj: `{obj}`, type: {type(obj)}')
   389        self.assertEqual(
   390            get_deep_size(obj),
   391            sys.getsizeof(obj),
   392            f'different size for obj: `{obj}`, type: {type(obj)}')
   393  
   394      for obj in collection_test_objects:
   395        self.assertEqual(
   396            get_deep_size(obj),
   397            objsize.get_deep_size(obj),
   398            f'different size for obj: `{obj}`, type: {type(obj)}')
   399  
   400    def test_current_weight_between_get_and_put(self):
   401      value = 1234567
   402      get_cache = StateCache(100)
   403      get_cache.get("key", lambda k: value)
   404  
   405      put_cache = StateCache(100)
   406      put_cache.put("key", value)
   407  
   408      self.assertEqual(get_cache._current_weight, put_cache._current_weight)
   409  
   410  
   411  if __name__ == '__main__':
   412    logging.getLogger().setLevel(logging.INFO)
   413    unittest.main()