github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/utils/shared_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
"""Tests for the ``Shared`` class in ``apache_beam.utils.shared``."""
    19  
    20  import gc
    21  import threading
    22  import time
    23  import unittest
    24  
    25  from apache_beam.utils import shared
    26  
    27  
    28  class Count(object):
    29    def __init__(self):
    30      self._lock = threading.Lock()
    31      self._total = 0
    32      self._active = 0
    33  
    34    def add_ref(self):
    35      with self._lock:
    36        self._total += 1
    37        self._active += 1
    38  
    39    def release_ref(self):
    40      with self._lock:
    41        self._active -= 1
    42  
    43    def get_active(self):
    44      with self._lock:
    45        return self._active
    46  
    47    def get_total(self):
    48      with self._lock:
    49        return self._total
    50  
    51  
    52  class Marker(object):
    53    def __init__(self, count):
    54      self._count = count
    55      self._count.add_ref()
    56  
    57    def __del__(self):
    58      self._count.release_ref()
    59  
    60  
    61  class NamedObject(object):
    62    def __init__(self, name):
    63      self._name = name
    64  
    65    def get_name(self):
    66      return self._name
    67  
    68  
    69  class Sequence(object):
    70    def __init__(self):
    71      self._sequence = 0
    72  
    73    def make_acquire_fn(self):
    74      # Every time acquire_fn is called, increases the sequence number and returns
    75      # a NamedObject with that sequenece number.
    76      def acquire_fn():
    77        self._sequence += 1
    78        return NamedObject('sequence%d' % self._sequence)
    79  
    80      return acquire_fn
    81  
    82  
    83  class SharedTest(unittest.TestCase):
    84    def testKeepalive(self):
    85      count = Count()
    86      shared_handle = shared.Shared()
    87      other_shared_handle = shared.Shared()
    88  
    89      def dummy_acquire_fn():
    90        return None
    91  
    92      def acquire_fn():
    93        return Marker(count)
    94  
    95      p1 = shared_handle.acquire(acquire_fn)
    96      self.assertEqual(1, count.get_total())
    97      self.assertEqual(1, count.get_active())
    98      del p1
    99      gc.collect()
   100      # Won't be garbage collected, because of the keep-alive
   101      self.assertEqual(1, count.get_active())
   102  
   103      # Reacquire.
   104      p2 = shared_handle.acquire(acquire_fn)
   105      self.assertEqual(1, count.get_total())  # No reinitialisation.
   106      self.assertEqual(1, count.get_active())
   107  
   108      # Get rid of the keepalive
   109      other_shared_handle.acquire(dummy_acquire_fn)
   110      del p2
   111      gc.collect()
   112      self.assertEqual(0, count.get_active())
   113  
   114    def testMultiple(self):
   115      count = Count()
   116      shared_handle = shared.Shared()
   117      other_shared_handle = shared.Shared()
   118  
   119      def dummy_acquire_fn():
   120        return None
   121  
   122      def acquire_fn():
   123        return Marker(count)
   124  
   125      p = shared_handle.acquire(acquire_fn)
   126      other_shared_handle.acquire(dummy_acquire_fn)  # Get rid of the keepalive
   127      self.assertEqual(1, count.get_total())
   128      self.assertEqual(1, count.get_active())
   129      del p
   130      gc.collect()
   131      self.assertEqual(0, count.get_active())
   132      # Shared value should be garbage collected.
   133  
   134      # Acquiring multiple times only results in one initialisation
   135      p1 = shared_handle.acquire(acquire_fn)
   136      # Since shared value was released, expect a reinitialisation.
   137      self.assertEqual(2, count.get_total())
   138      self.assertEqual(1, count.get_active())
   139      p2 = shared_handle.acquire(acquire_fn)
   140      self.assertEqual(2, count.get_total())
   141      self.assertEqual(1, count.get_active())
   142  
   143      other_shared_handle.acquire(dummy_acquire_fn)  # Get rid of the keepalive
   144  
   145      # Check that shared object isn't destroyed if there's still a reference to
   146      # it.
   147      del p2
   148      gc.collect()
   149      self.assertEqual(1, count.get_active())
   150  
   151      del p1
   152      gc.collect()
   153      self.assertEqual(0, count.get_active())
   154  
   155    def testConcurrentCallsDeduped(self):
   156      # Test that only one among many calls to acquire will actually run the
   157      # initialisation function.
   158  
   159      count = Count()
   160      shared_handle = shared.Shared()
   161      other_shared_handle = shared.Shared()
   162  
   163      refs = []
   164      ref_lock = threading.Lock()
   165  
   166      def dummy_acquire_fn():
   167        return None
   168  
   169      def acquire_fn():
   170        time.sleep(1)
   171        return Marker(count)
   172  
   173      def thread_fn():
   174        p = shared_handle.acquire(acquire_fn)
   175        with ref_lock:
   176          refs.append(p)
   177  
   178      threads = []
   179      for _ in range(100):
   180        t = threading.Thread(target=thread_fn)
   181        threads.append(t)
   182        t.start()
   183  
   184      for t in threads:
   185        t.join()
   186  
   187      self.assertEqual(1, count.get_total())
   188      self.assertEqual(1, count.get_active())
   189  
   190      other_shared_handle.acquire(dummy_acquire_fn)  # Get rid of the keepalive
   191  
   192      with ref_lock:
   193        del refs[:]
   194      gc.collect()
   195  
   196      self.assertEqual(0, count.get_active())
   197  
   198    def testDifferentObjects(self):
   199      sequence = Sequence()
   200  
   201      def dummy_acquire_fn():
   202        return None
   203  
   204      first_handle = shared.Shared()
   205      second_handle = shared.Shared()
   206      dummy_handle = shared.Shared()
   207  
   208      f1 = first_handle.acquire(sequence.make_acquire_fn())
   209      s1 = second_handle.acquire(sequence.make_acquire_fn())
   210  
   211      self.assertEqual('sequence1', f1.get_name())
   212      self.assertEqual('sequence2', s1.get_name())
   213  
   214      f2 = first_handle.acquire(sequence.make_acquire_fn())
   215      s2 = second_handle.acquire(sequence.make_acquire_fn())
   216  
   217      # Check that the repeated acquisitions return the earlier objects
   218      self.assertEqual('sequence1', f2.get_name())
   219      self.assertEqual('sequence2', s2.get_name())
   220  
   221      # Release all references and force garbage-collection
   222      del f1
   223      del f2
   224      del s1
   225      del s2
   226      dummy_handle.acquire(dummy_acquire_fn)  # Get rid of the keepalive
   227      gc.collect()
   228  
   229      # Check that acquiring again after they're released gives new objects
   230      f3 = first_handle.acquire(sequence.make_acquire_fn())
   231      s3 = second_handle.acquire(sequence.make_acquire_fn())
   232      self.assertEqual('sequence3', f3.get_name())
   233      self.assertEqual('sequence4', s3.get_name())
   234  
   235    def testTagCacheEviction(self):
   236      shared1 = shared.Shared()
   237      shared2 = shared.Shared()
   238  
   239      def acquire_fn_1():
   240        return NamedObject('obj_1')
   241  
   242      def acquire_fn_2():
   243        return NamedObject('obj_2')
   244  
   245      # with no tag, shared handle does not know when to evict objects
   246      p1 = shared1.acquire(acquire_fn_1)
   247      assert p1.get_name() == 'obj_1'
   248      p2 = shared1.acquire(acquire_fn_2)
   249      assert p2.get_name() == 'obj_1'
   250  
   251      # cache eviction can be forced by specifying different tags
   252      p1 = shared2.acquire(acquire_fn_1, tag='1')
   253      assert p1.get_name() == 'obj_1'
   254      p2 = shared2.acquire(acquire_fn_2, tag='2')
   255      assert p2.get_name() == 'obj_2'
   256  
   257    def testTagReturnsCached(self):
   258      sequence = Sequence()
   259      handle = shared.Shared()
   260  
   261      f1 = handle.acquire(sequence.make_acquire_fn(), tag='1')
   262      self.assertEqual('sequence1', f1.get_name())
   263  
   264      # should return cached
   265      f1 = handle.acquire(sequence.make_acquire_fn(), tag='1')
   266      self.assertEqual('sequence1', f1.get_name())
   267  
   268  
   269  if __name__ == '__main__':
   270    unittest.main()