# Source: github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/utils/shared_test.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Test for Shared class."""

import gc
import threading
import time
import unittest

from apache_beam.utils import shared


class Count(object):
  """Lock-protected counter of Marker instances.

  Tracks how many references were ever added (total) and how many have not
  yet been released (active).
  """
  def __init__(self):
    self._lock = threading.Lock()
    self._total = 0
    self._active = 0

  def add_ref(self):
    """Records a newly created reference."""
    with self._lock:
      self._total += 1
      self._active += 1

  def release_ref(self):
    """Records that a reference has been destroyed."""
    with self._lock:
      self._active -= 1

  def get_active(self):
    """Returns the number of references added but not yet released."""
    with self._lock:
      return self._active

  def get_total(self):
    """Returns the number of references ever added."""
    with self._lock:
      return self._total


class Marker(object):
  """Object that reports its construction and destruction to a Count."""
  def __init__(self, count):
    self._count = count
    self._count.add_ref()

  def __del__(self):
    # Called when the object is finalized. This lets the tests observe
    # whether Shared still holds a reference to the value it handed out.
    self._count.release_ref()


class NamedObject(object):
  """Trivial value object carrying a name, used to identify acquisitions."""
  def __init__(self, name):
    self._name = name

  def get_name(self):
    """Returns the name given at construction."""
    return self._name


class Sequence(object):
  """Factory of acquire functions that produce sequentially named objects."""
  def __init__(self):
    self._sequence = 0

  def make_acquire_fn(self):
    """Returns an acquire function backed by this Sequence's counter.

    Every time the returned function is called, it increases the sequence
    number and returns a NamedObject with that sequence number. All
    functions made by the same Sequence share one counter.
    """
    def acquire_fn():
      self._sequence += 1
      return NamedObject('sequence%d' % self._sequence)

    return acquire_fn


class SharedTest(unittest.TestCase):
  def testKeepalive(self):
    """A released value stays alive until another handle acquires."""
    count = Count()
    shared_handle = shared.Shared()
    other_shared_handle = shared.Shared()

    def dummy_acquire_fn():
      return None

    def acquire_fn():
      return Marker(count)

    p1 = shared_handle.acquire(acquire_fn)
    self.assertEqual(1, count.get_total())
    self.assertEqual(1, count.get_active())
    del p1
    gc.collect()
    # Won't be garbage collected, because of the keep-alive.
    self.assertEqual(1, count.get_active())

    # Reacquire.
    p2 = shared_handle.acquire(acquire_fn)
    self.assertEqual(1, count.get_total())  # No reinitialisation.
    self.assertEqual(1, count.get_active())

    # Get rid of the keep-alive by acquiring through a different handle.
    other_shared_handle.acquire(dummy_acquire_fn)
    del p2
    gc.collect()
    self.assertEqual(0, count.get_active())

  def testMultiple(self):
    """Repeated acquisitions share one value until all references drop."""
    count = Count()
    shared_handle = shared.Shared()
    other_shared_handle = shared.Shared()

    def dummy_acquire_fn():
      return None

    def acquire_fn():
      return Marker(count)

    p = shared_handle.acquire(acquire_fn)
    other_shared_handle.acquire(dummy_acquire_fn)  # Get rid of the keepalive
    self.assertEqual(1, count.get_total())
    self.assertEqual(1, count.get_active())
    del p
    gc.collect()
    # Shared value should be garbage collected.
    self.assertEqual(0, count.get_active())

    # Acquiring multiple times only results in one initialisation.
    p1 = shared_handle.acquire(acquire_fn)
    # Since shared value was released, expect a reinitialisation.
    self.assertEqual(2, count.get_total())
    self.assertEqual(1, count.get_active())
    p2 = shared_handle.acquire(acquire_fn)
    self.assertEqual(2, count.get_total())
    self.assertEqual(1, count.get_active())

    other_shared_handle.acquire(dummy_acquire_fn)  # Get rid of the keepalive

    # Check that shared object isn't destroyed if there's still a reference to
    # it.
    del p2
    gc.collect()
    self.assertEqual(1, count.get_active())

    del p1
    gc.collect()
    self.assertEqual(0, count.get_active())

  def testConcurrentCallsDeduped(self):
    """Only one of many concurrent acquire calls runs the init function."""
    count = Count()
    shared_handle = shared.Shared()
    other_shared_handle = shared.Shared()

    refs = []
    ref_lock = threading.Lock()

    def dummy_acquire_fn():
      return None

    def acquire_fn():
      # Slow initialisation widens the window in which the other threads call
      # acquire while the first call is still in progress.
      time.sleep(1)
      return Marker(count)

    def thread_fn():
      p = shared_handle.acquire(acquire_fn)
      with ref_lock:
        refs.append(p)

    threads = []
    for _ in range(100):
      t = threading.Thread(target=thread_fn)
      threads.append(t)
      t.start()

    for t in threads:
      t.join()

    self.assertEqual(1, count.get_total())
    self.assertEqual(1, count.get_active())

    other_shared_handle.acquire(dummy_acquire_fn)  # Get rid of the keepalive

    with ref_lock:
      refs.clear()
    gc.collect()

    self.assertEqual(0, count.get_active())

  def testDifferentObjects(self):
    """Distinct Shared handles hold independent values."""
    sequence = Sequence()

    def dummy_acquire_fn():
      return None

    first_handle = shared.Shared()
    second_handle = shared.Shared()
    dummy_handle = shared.Shared()

    f1 = first_handle.acquire(sequence.make_acquire_fn())
    s1 = second_handle.acquire(sequence.make_acquire_fn())

    self.assertEqual('sequence1', f1.get_name())
    self.assertEqual('sequence2', s1.get_name())

    f2 = first_handle.acquire(sequence.make_acquire_fn())
    s2 = second_handle.acquire(sequence.make_acquire_fn())

    # Check that the repeated acquisitions return the earlier objects.
    self.assertEqual('sequence1', f2.get_name())
    self.assertEqual('sequence2', s2.get_name())

    # Release all references and force garbage-collection.
    del f1
    del f2
    del s1
    del s2
    dummy_handle.acquire(dummy_acquire_fn)  # Get rid of the keepalive
    gc.collect()

    # Check that acquiring again after they're released gives new objects.
    f3 = first_handle.acquire(sequence.make_acquire_fn())
    s3 = second_handle.acquire(sequence.make_acquire_fn())
    self.assertEqual('sequence3', f3.get_name())
    self.assertEqual('sequence4', s3.get_name())

  def testTagCacheEviction(self):
    """A different tag forces re-acquisition; no tag keeps the cached value."""
    shared1 = shared.Shared()
    shared2 = shared.Shared()

    def acquire_fn_1():
      return NamedObject('obj_1')

    def acquire_fn_2():
      return NamedObject('obj_2')

    # With no tag, the shared handle does not know when to evict objects.
    # Use assertEqual rather than bare assert so the checks survive `-O`.
    p1 = shared1.acquire(acquire_fn_1)
    self.assertEqual('obj_1', p1.get_name())
    p2 = shared1.acquire(acquire_fn_2)
    self.assertEqual('obj_1', p2.get_name())

    # Cache eviction can be forced by specifying different tags.
    p1 = shared2.acquire(acquire_fn_1, tag='1')
    self.assertEqual('obj_1', p1.get_name())
    p2 = shared2.acquire(acquire_fn_2, tag='2')
    self.assertEqual('obj_2', p2.get_name())

  def testTagReturnsCached(self):
    """Acquiring again with the same tag returns the cached object."""
    sequence = Sequence()
    handle = shared.Shared()

    f1 = handle.acquire(sequence.make_acquire_fn(), tag='1')
    self.assertEqual('sequence1', f1.get_name())

    # Should return the cached object, not call the new acquire function.
    f1 = handle.acquire(sequence.make_acquire_fn(), tag='1')
    self.assertEqual('sequence1', f1.get_name())


if __name__ == '__main__':
  unittest.main()