#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pytype: skip-file

"""Unit tests for apache_beam.utils.multi_process_shared."""

import logging
import threading
import unittest

from apache_beam.utils import multi_process_shared


class CallableCounter:
  """Counter whose current value is read by calling the instance."""
  def __init__(self, start=0):
    self.running = start
    self.lock = threading.Lock()

  def __call__(self):
    """Return the current running total."""
    return self.running

  def increment(self, value=1):
    """Add ``value`` to the running total and return the new total."""
    # The read-modify-write is done while holding the lock.
    with self.lock:
      self.running += value
      return self.running

  def error(self, msg):
    """Always raise ``RuntimeError(msg)``; used to test error propagation."""
    raise RuntimeError(msg)


class Counter:
  """Counter whose current value is read through ``get()``."""
  def __init__(self, start=0):
    self.running = start
    self.lock = threading.Lock()

  def get(self):
    """Return the current running total."""
    return self.running

  def increment(self, value=1):
    """Add ``value`` to the running total and return the new total."""
    # The read-modify-write is done while holding the lock.
    with self.lock:
      self.running += value
      return self.running

  def error(self, msg):
    """Always raise ``RuntimeError(msg)``; used to test error propagation."""
    raise RuntimeError(msg)


class MultiProcessSharedTest(unittest.TestCase):
  """Exercises objects obtained through ``MultiProcessShared.acquire()``."""
  @classmethod
  def setUpClass(cls):
    """Acquire one ``Counter`` and one ``CallableCounter`` for the class.

    Both are created with ``always_proxy=True`` and are shared by every test
    method, so state changed by one test is visible to the others.
    """
    cls.shared = multi_process_shared.MultiProcessShared(
        Counter, tag='basic', always_proxy=True).acquire()
    cls.sharedCallable = multi_process_shared.MultiProcessShared(
        CallableCounter, tag='callable', always_proxy=True).acquire()

  def test_call(self):
    """Methods accept positional and keyword arguments and keep state."""
    self.assertEqual(self.shared.get(), 0)
    self.assertEqual(self.shared.increment(), 1)
    self.assertEqual(self.shared.increment(10), 11)
    self.assertEqual(self.shared.increment(value=10), 21)
    self.assertEqual(self.shared.get(), 21)

  def test_call_callable(self):
    """The acquired object can itself be called when the class is callable."""
    self.assertEqual(self.sharedCallable(), 0)
    self.assertEqual(self.sharedCallable.increment(), 1)
    self.assertEqual(self.sharedCallable.increment(10), 11)
    self.assertEqual(self.sharedCallable.increment(value=10), 21)
    self.assertEqual(self.sharedCallable(), 21)

  def test_error(self):
    """An exception raised by the shared object reaches the caller."""
    with self.assertRaisesRegex(Exception, 'something bad'):
      self.shared.error('something bad')

  def test_no_method(self):
    """Calling a missing method raises an error naming that method."""
    with self.assertRaisesRegex(Exception, 'no_such_method'):
      self.shared.no_such_method()

  def test_connect(self):
    """Two handles created with the same tag operate on one counter."""
    first = multi_process_shared.MultiProcessShared(
        Counter, tag='counter').acquire()
    second = multi_process_shared.MultiProcessShared(
        Counter, tag='counter').acquire()
    self.assertEqual(first.get(), 0)
    self.assertEqual(first.increment(), 1)

    self.assertEqual(second.get(), 1)
    self.assertEqual(second.increment(), 2)

    self.assertEqual(first.get(), 2)
    self.assertEqual(first.increment(), 3)

  def test_release(self):
    """Released objects become unusable; a later acquire starts from zero."""
    shared1 = multi_process_shared.MultiProcessShared(
        Counter, tag='test_release')
    shared2 = multi_process_shared.MultiProcessShared(
        Counter, tag='test_release')

    counter1 = shared1.acquire()
    counter2 = shared2.acquire()
    self.assertEqual(counter1.increment(), 1)
    self.assertEqual(counter2.increment(), 2)

    counter1again = shared1.acquire()
    self.assertEqual(counter1again.increment(), 3)

    shared1.release(counter1)
    shared2.release(counter2)

    with self.assertRaisesRegex(Exception, 'released'):
      counter1.get()
    with self.assertRaisesRegex(Exception, 'released'):
      counter2.get()

    # An object that has not been released is still expected to work.
    self.assertEqual(counter1again.get(), 3)

    shared1.release(counter1again)

    # With every acquired object released, a new acquire is expected to see
    # a counter back at its starting value.
    counter1New = shared1.acquire()
    self.assertEqual(counter1New.get(), 0)

    with self.assertRaisesRegex(Exception, 'released'):
      counter1.get()


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()