github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/utils/multi_process_shared_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  # pytype: skip-file
    18  
    19  import logging
    20  import threading
    21  import unittest
    22  
    23  from apache_beam.utils import multi_process_shared
    24  
    25  
    26  class CallableCounter(object):
    27    def __init__(self, start=0):
    28      self.running = start
    29      self.lock = threading.Lock()
    30  
    31    def __call__(self):
    32      return self.running
    33  
    34    def increment(self, value=1):
    35      with self.lock:
    36        self.running += value
    37        return self.running
    38  
    39    def error(self, msg):
    40      raise RuntimeError(msg)
    41  
    42  
    43  class Counter(object):
    44    def __init__(self, start=0):
    45      self.running = start
    46      self.lock = threading.Lock()
    47  
    48    def get(self):
    49      return self.running
    50  
    51    def increment(self, value=1):
    52      with self.lock:
    53        self.running += value
    54        return self.running
    55  
    56    def error(self, msg):
    57      raise RuntimeError(msg)
    58  
    59  
    60  class MultiProcessSharedTest(unittest.TestCase):
    61    @classmethod
    62    def setUpClass(cls):
    63      cls.shared = multi_process_shared.MultiProcessShared(
    64          Counter, tag='basic', always_proxy=True).acquire()
    65      cls.sharedCallable = multi_process_shared.MultiProcessShared(
    66          CallableCounter, tag='callable', always_proxy=True).acquire()
    67  
    68    def test_call(self):
    69      self.assertEqual(self.shared.get(), 0)
    70      self.assertEqual(self.shared.increment(), 1)
    71      self.assertEqual(self.shared.increment(10), 11)
    72      self.assertEqual(self.shared.increment(value=10), 21)
    73      self.assertEqual(self.shared.get(), 21)
    74  
    75    def test_call_callable(self):
    76      self.assertEqual(self.sharedCallable(), 0)
    77      self.assertEqual(self.sharedCallable.increment(), 1)
    78      self.assertEqual(self.sharedCallable.increment(10), 11)
    79      self.assertEqual(self.sharedCallable.increment(value=10), 21)
    80      self.assertEqual(self.sharedCallable(), 21)
    81  
    82    def test_error(self):
    83      with self.assertRaisesRegex(Exception, 'something bad'):
    84        self.shared.error('something bad')
    85  
    86    def test_no_method(self):
    87      with self.assertRaisesRegex(Exception, 'no_such_method'):
    88        self.shared.no_such_method()
    89  
    90    def test_connect(self):
    91      first = multi_process_shared.MultiProcessShared(
    92          Counter, tag='counter').acquire()
    93      second = multi_process_shared.MultiProcessShared(
    94          Counter, tag='counter').acquire()
    95      self.assertEqual(first.get(), 0)
    96      self.assertEqual(first.increment(), 1)
    97  
    98      self.assertEqual(second.get(), 1)
    99      self.assertEqual(second.increment(), 2)
   100  
   101      self.assertEqual(first.get(), 2)
   102      self.assertEqual(first.increment(), 3)
   103  
   104    def test_release(self):
   105      shared1 = multi_process_shared.MultiProcessShared(
   106          Counter, tag='test_release')
   107      shared2 = multi_process_shared.MultiProcessShared(
   108          Counter, tag='test_release')
   109  
   110      counter1 = shared1.acquire()
   111      counter2 = shared2.acquire()
   112      self.assertEqual(counter1.increment(), 1)
   113      self.assertEqual(counter2.increment(), 2)
   114  
   115      counter1again = shared1.acquire()
   116      self.assertEqual(counter1again.increment(), 3)
   117  
   118      shared1.release(counter1)
   119      shared2.release(counter2)
   120  
   121      with self.assertRaisesRegex(Exception, 'released'):
   122        counter1.get()
   123      with self.assertRaisesRegex(Exception, 'released'):
   124        counter2.get()
   125  
   126      self.assertEqual(counter1again.get(), 3)
   127  
   128      shared1.release(counter1again)
   129  
   130      counter1New = shared1.acquire()
   131      self.assertEqual(counter1New.get(), 0)
   132  
   133      with self.assertRaisesRegex(Exception, 'released'):
   134        counter1.get()
   135  
   136  
   137  if __name__ == '__main__':
   138    logging.getLogger().setLevel(logging.INFO)
   139    unittest.main()