#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Tests for state caching."""
# pytype: skip-file

import logging
import re
import sys
import threading
import time
import unittest
import weakref

import objsize
from hamcrest import assert_that
from hamcrest import contains_string

from apache_beam.runners.worker.statecache import CacheAware
from apache_beam.runners.worker.statecache import StateCache
from apache_beam.runners.worker.statecache import WeightedValue
from apache_beam.runners.worker.statecache import _LoadingValue
from apache_beam.runners.worker.statecache import get_deep_size

# Upper bound (in seconds) on how long a test blocks waiting for a weak
# reference callback. The callback is expected to fire almost immediately; the
# bound only exists so that a regression fails the test instead of hanging the
# whole suite.
_WAIT_TIMEOUT_SECS = 60


class StateCacheTest(unittest.TestCase):
  """Exercises StateCache put/peek/get, eviction, sizing and statistics."""
  def test_weakref(self):
    """Caches an object holding a weakref, then a live and a dead weakref."""
    test_value = WeightedValue('test', 10 << 20)

    class WeightedValueRef():
      def __init__(self):
        self.ref = weakref.ref(test_value)

    cache = StateCache(5 << 20)
    wait_event = threading.Event()
    o = WeightedValueRef()
    cache.put('deep ref', o)
    # Ensure that the contents of the internal weak ref isn't sized
    self.assertIsNotNone(cache.peek('deep ref'))
    self.assertEqual(
        cache.describe_stats(),
        'used/max 0/5 MB, hit 100.00%, lookups 1, avg load time 0 ns, loads 0, '
        'evictions 0')
    cache.invalidate_all()

    # Ensure that putting in a weakref doesn't fail regardless of whether
    # it is alive or not
    o_ref = weakref.ref(o, lambda value: wait_event.set())
    cache.put('not deleted ref', o_ref)
    del o
    # Bounded wait: fail instead of hanging if the callback never fires.
    self.assertTrue(
        wait_event.wait(_WAIT_TIMEOUT_SECS),
        'weak reference callback was not invoked')
    cache.put('deleted', o_ref)

  def test_weakref_proxy(self):
    """Caches an object holding a weakref, then a live and a dead proxy."""
    test_value = WeightedValue('test', 10 << 20)

    class WeightedValueRef():
      def __init__(self):
        self.ref = weakref.ref(test_value)

    cache = StateCache(5 << 20)
    wait_event = threading.Event()
    o = WeightedValueRef()
    cache.put('deep ref', o)
    # Ensure that the contents of the internal weak ref isn't sized
    self.assertIsNotNone(cache.peek('deep ref'))
    self.assertEqual(
        cache.describe_stats(),
        'used/max 0/5 MB, hit 100.00%, lookups 1, avg load time 0 ns, loads 0, '
        'evictions 0')
    cache.invalidate_all()

    # Ensure that putting in a weakref proxy doesn't fail regardless of
    # whether its referent is alive or not
    o_ref = weakref.proxy(o, lambda value: wait_event.set())
    cache.put('not deleted', o_ref)
    del o
    # Bounded wait: fail instead of hanging if the callback never fires.
    self.assertTrue(
        wait_event.wait(_WAIT_TIMEOUT_SECS),
        'weak reference proxy callback was not invoked')
    cache.put('deleted', o_ref)

  def test_size_of_fails(self):
    """A value whose __sizeof__ raises is accepted and logged only once."""
    class BadSizeOf(object):
      def __sizeof__(self):
        raise RuntimeError("TestRuntimeError")

    cache = StateCache(5 << 20)
    with self.assertLogs('apache_beam.runners.worker.statecache',
                         level='WARNING') as context:
      cache.put('key', BadSizeOf())
      self.assertEqual(1, len(context.output))
      self.assertIn('Failed to size', context.output[0])
      # Test that we don't spam the logs
      cache.put('key', BadSizeOf())
      self.assertEqual(1, len(context.output))

  def test_empty_cache_peek(self):
    """Peeking a missing key returns None and is counted as a lookup."""
    cache = StateCache(5 << 20)
    self.assertIsNone(cache.peek("key"))
    self.assertEqual(
        cache.describe_stats(),
        (
            'used/max 0/5 MB, hit 0.00%, lookups 1, '
            'avg load time 0 ns, loads 0, evictions 0'))

  def test_put_peek(self):
    """A stored value is returned by peek; an unknown key yields None."""
    cache = StateCache(5 << 20)
    cache.put("key", WeightedValue("value", 1 << 20))
    self.assertEqual(cache.size(), 1)
    self.assertEqual(cache.peek("key"), "value")
    self.assertIsNone(cache.peek("key2"))
    self.assertEqual(
        cache.describe_stats(),
        (
            'used/max 1/5 MB, hit 50.00%, lookups 2, '
            'avg load time 0 ns, loads 0, evictions 0'))

  def test_default_sized_put(self):
    """Values without an explicit weight are sized by the cache itself."""
    cache = StateCache(5 << 20)
    cache.put("key", bytearray(1 << 20))
    cache.put("key2", bytearray(1 << 20))
    cache.put("key3", bytearray(1 << 20))
    self.assertEqual(cache.peek("key3"), bytearray(1 << 20))
    cache.put("key4", bytearray(1 << 20))
    cache.put("key5", bytearray(1 << 20))
    # note that each byte array instance takes slightly over 1 MB which is why
    # these 5 byte arrays can't all be stored in the cache causing a single
    # eviction
    self.assertEqual(
        cache.describe_stats(),
        (
            'used/max 4/5 MB, hit 100.00%, lookups 1, '
            'avg load time 0 ns, loads 0, evictions 1'))

  def test_max_size(self):
    """Exceeding the maximum weight evicts an entry."""
    cache = StateCache(2 << 20)
    cache.put("key", WeightedValue("value", 1 << 20))
    cache.put("key2", WeightedValue("value2", 1 << 20))
    self.assertEqual(cache.size(), 2)
    cache.put("key3", WeightedValue("value3", 1 << 20))
    self.assertEqual(cache.size(), 2)
    self.assertEqual(
        cache.describe_stats(),
        (
            'used/max 2/2 MB, hit 100.00%, lookups 0, '
            'avg load time 0 ns, loads 0, evictions 1'))

  def test_invalidate_all(self):
    """invalidate_all removes every entry without counting evictions."""
    cache = StateCache(5 << 20)
    cache.put("key", WeightedValue("value", 1 << 20))
    cache.put("key2", WeightedValue("value2", 1 << 20))
    self.assertEqual(cache.size(), 2)
    cache.invalidate_all()
    self.assertEqual(cache.size(), 0)
    self.assertIsNone(cache.peek("key"))
    self.assertIsNone(cache.peek("key2"))
    self.assertEqual(
        cache.describe_stats(),
        (
            'used/max 0/5 MB, hit 0.00%, lookups 2, '
            'avg load time 0 ns, loads 0, evictions 0'))

  def test_lru(self):
    """Entries are evicted in least-recently-used order."""
    cache = StateCache(5 << 20)
    cache.put("key", WeightedValue("value", 1 << 20))
    cache.put("key2", WeightedValue("value2", 1 << 20))
    cache.put("key3", WeightedValue("value0", 1 << 20))
    cache.put("key3", WeightedValue("value3", 1 << 20))
    cache.put("key4", WeightedValue("value4", 1 << 20))
    cache.put("key5", WeightedValue("value0", 1 << 20))
    cache.put("key5", WeightedValue(["value5"], 1 << 20))
    self.assertEqual(cache.size(), 5)
    self.assertEqual(cache.peek("key"), "value")
    self.assertEqual(cache.peek("key2"), "value2")
    self.assertEqual(cache.peek("key3"), "value3")
    self.assertEqual(cache.peek("key4"), "value4")
    self.assertEqual(cache.peek("key5"), ["value5"])
    # insert another key to trigger cache eviction
    cache.put("key6", WeightedValue("value6", 1 << 20))
    self.assertEqual(cache.size(), 5)
    # least recently used key should be gone ("key")
    self.assertIsNone(cache.peek("key"))
    # trigger a read on "key2"
    cache.peek("key2")
    # insert another key to trigger cache eviction
    cache.put("key7", WeightedValue("value7", 1 << 20))
    self.assertEqual(cache.size(), 5)
    # least recently used key should be gone ("key3")
    self.assertIsNone(cache.peek("key3"))
    # insert another key to trigger cache eviction
    cache.put("key8", WeightedValue("put", 1 << 20))
    self.assertEqual(cache.size(), 5)
    # insert another key to trigger cache eviction
    cache.put("key9", WeightedValue("value8", 1 << 20))
    self.assertEqual(cache.size(), 5)
    # least recently used key should be gone ("key4")
    self.assertIsNone(cache.peek("key4"))
    # make "key5" used by writing to it
    cache.put("key5", WeightedValue("val", 1 << 20))
    # least recently used key should be gone ("key6")
    self.assertIsNone(cache.peek("key6"))
    self.assertEqual(
        cache.describe_stats(),
        (
            'used/max 5/5 MB, hit 60.00%, lookups 10, '
            'avg load time 0 ns, loads 0, evictions 5'))

  def test_get(self):
    """get() loads missing keys, propagates loader errors and times loads."""
    def check_key(key):
      self.assertEqual(key, "key")
      time.sleep(0.5)
      return "value"

    def raise_exception(key):
      time.sleep(0.5)
      raise Exception("TestException")

    cache = StateCache(5 << 20)
    self.assertEqual("value", cache.get("key", check_key))
    with cache._lock:
      self.assertNotIsInstance(cache._cache["key"], _LoadingValue)
    self.assertEqual("value", cache.peek("key"))
    cache.invalidate_all()

    with self.assertRaisesRegex(Exception, "TestException"):
      cache.get("key", raise_exception)
    # The cache should not have the value after the failing load causing
    # check_key to load the value.
    self.assertEqual("value", cache.get("key", check_key))
    with cache._lock:
      self.assertNotIsInstance(cache._cache["key"], _LoadingValue)
    self.assertEqual("value", cache.peek("key"))

    assert_that(cache.describe_stats(), contains_string(", loads 3,"))
    load_time_ns = re.search(
        ", avg load time (.+) ns,", cache.describe_stats()).group(1)
    # Load time should be larger than the sleep time and less than 2x sleep time
    self.assertGreater(int(load_time_ns), 0.5 * 1_000_000_000)
    self.assertLess(int(load_time_ns), 1_000_000_000)

  def test_concurrent_get_waits(self):
    """Two concurrent get() calls for one key trigger a single load."""
    event = threading.Semaphore(0)
    # Two loader threads plus the main test thread.
    threads_running = threading.Barrier(3)

    def wait_for_event(key):
      with cache._lock:
        self.assertIsInstance(cache._cache["key"], _LoadingValue)
      event.release()
      return "value"

    cache = StateCache(5 << 20)

    def load_key(output):
      threads_running.wait()
      output["value"] = cache.get("key", wait_for_event)
      output["time"] = time.time_ns()

    t1_output = {}
    t1 = threading.Thread(target=load_key, args=(t1_output, ))
    t1.start()

    t2_output = {}
    t2 = threading.Thread(target=load_key, args=(t2_output, ))
    t2.start()

    # Wait for both threads to start
    threads_running.wait()
    # Record the time and wait for the load to start
    current_time_ns = time.time_ns()
    event.acquire()
    t1.join()
    t2.join()

    # Ensure that only one thread did the loading and not both by checking that
    # the semaphore was only released once
    self.assertFalse(event.acquire(blocking=False))

    # Ensure that the load time is greater than the set time ensuring that
    # both loads had to wait for the event
    self.assertLessEqual(current_time_ns, t1_output["time"])
    self.assertLessEqual(current_time_ns, t2_output["time"])
    self.assertEqual("value", t1_output["value"])
    self.assertEqual("value", t2_output["value"])
    self.assertEqual("value", cache.peek("key"))

  def test_concurrent_get_superseded_by_put(self):
    """A put() issued during an in-flight load wins over the loaded value."""
    load_happening = threading.Event()
    finish_loading = threading.Event()

    def wait_for_event(key):
      load_happening.set()
      finish_loading.wait()
      return "value"

    cache = StateCache(5 << 20)

    def load_key(output):
      output["value"] = cache.get("key", wait_for_event)

    t1_output = {}
    t1 = threading.Thread(target=load_key, args=(t1_output, ))
    t1.start()

    # Wait for the load to start, update the key, and then let the load finish
    load_happening.wait()
    cache.put("key", "value2")
    finish_loading.set()
    t1.join()

    # Ensure that the original value is loaded and returned and not the
    # updated value
    self.assertEqual("value", t1_output["value"])
    # Ensure that the updated value supersedes the loaded value.
    self.assertEqual("value2", cache.peek("key"))

  def test_is_cached_enabled(self):
    """is_cache_enabled is True for a 1 MB cache and False for size 0."""
    cache = StateCache(1 << 20)
    self.assertEqual(cache.is_cache_enabled(), True)
    self.assertEqual(
        cache.describe_stats(),
        (
            'used/max 0/1 MB, hit 100.00%, lookups 0, '
            'avg load time 0 ns, loads 0, evictions 0'))
    cache = StateCache(0)
    self.assertEqual(cache.is_cache_enabled(), False)
    self.assertEqual(
        cache.describe_stats(),
        (
            'used/max 0/0 MB, hit 100.00%, lookups 0, '
            'avg load time 0 ns, loads 0, evictions 0'))

  def test_get_referents_for_cache(self):
    """Only objects returned by get_referents_for_cache are sized."""
    class GetReferentsForCache(CacheAware):
      def __init__(self):
        self.measure_me = bytearray(1 << 20)
        self.ignore_me = bytearray(2 << 20)

      def get_referents_for_cache(self):
        return [self.measure_me]

    cache = StateCache(5 << 20)
    cache.put("key", GetReferentsForCache())
    self.assertEqual(
        cache.describe_stats(),
        (
            'used/max 1/5 MB, hit 100.00%, lookups 0, '
            'avg load time 0 ns, loads 0, evictions 0'))

  def test_get_deep_size_builtin_objects(self):
    """
    `statecache.get_deep_size` should work same with objsize unless the `objs`
    has `CacheAware` or a filtered object. They should return the same size for
    built-in objects.
    """
    primitive_test_objects = [
        1,  # int
        2.0,  # float
        1+1j,  # complex
        True,  # bool
        'hello,world',  # str
        b'\00\01\02',  # bytes
    ]

    collection_test_objects = [
        [3, 4, 5],  # list
        (6, 7),  # tuple
        {'a', 'b', 'c'},  # set
        {'k': 8, 'l': 9},  # dict
    ]

    for obj in primitive_test_objects:
      self.assertEqual(
          get_deep_size(obj),
          objsize.get_deep_size(obj),
          f'different size for obj: `{obj}`, type: {type(obj)}')
      # Primitives reference nothing else, so deep size equals shallow size.
      self.assertEqual(
          get_deep_size(obj),
          sys.getsizeof(obj),
          f'different size for obj: `{obj}`, type: {type(obj)}')

    for obj in collection_test_objects:
      self.assertEqual(
          get_deep_size(obj),
          objsize.get_deep_size(obj),
          f'different size for obj: `{obj}`, type: {type(obj)}')

  def test_current_weight_between_get_and_put(self):
    """A value stored via get() and via put() is given the same weight."""
    value = 1234567
    get_cache = StateCache(100)
    get_cache.get("key", lambda k: value)

    put_cache = StateCache(100)
    put_cache.put("key", value)

    self.assertEqual(get_cache._current_weight, put_cache._current_weight)


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()