#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Tests for state sampler."""
# pytype: skip-file

import logging
import time
import unittest

from tenacity import retry
from tenacity import stop_after_attempt

from apache_beam.runners.worker import statesampler
from apache_beam.utils.counters import CounterFactory
from apache_beam.utils.counters import CounterName

_LOGGER = logging.getLogger(__name__)


class StateSamplerTest(unittest.TestCase):
  """Tests for statesampler.StateSampler timings and transition overhead."""

  # Due to somewhat non-deterministic nature of state sampling and sleep,
  # this test is flaky when state duration is low.
  # Since increasing state duration significantly would also slow down
  # the test suite, we are retrying twice on failure as a mitigation.
  @retry(reraise=True, stop=stop_after_attempt(3))
  def test_basic_sampler(self):
    """Sampled msecs counters track the time spent in each nested state.

    The workload spends 1x, 2x and 3x ``state_duration_ms`` in ``statea``,
    ``stateb`` and ``statec`` respectively (time in a nested state is not
    charged to its parent). Each committed counter must lie within
    ``margin_of_error`` of its expected value.
    """
    # Set up state sampler.
    counter_factory = CounterFactory()
    sampler = statesampler.StateSampler(
        'basic', counter_factory, sampling_period_ms=1)

    # Duration of the fastest state. Total test duration is 6 times longer.
    state_duration_ms = 1000
    margin_of_error = 0.25
    # Run basic workload transitioning between 3 states.
    sampler.start()
    with sampler.scoped_state('step1', 'statea'):
      time.sleep(state_duration_ms / 1000)
      self.assertEqual(
          sampler.current_state().name,
          CounterName('statea-msecs', step_name='step1', stage_name='basic'))
      with sampler.scoped_state('step1', 'stateb'):
        time.sleep(state_duration_ms / 1000)
        self.assertEqual(
            sampler.current_state().name,
            CounterName('stateb-msecs', step_name='step1', stage_name='basic'))
        with sampler.scoped_state('step1', 'statec'):
          time.sleep(3 * state_duration_ms / 1000)
          self.assertEqual(
              sampler.current_state().name,
              CounterName(
                  'statec-msecs', step_name='step1', stage_name='basic'))
        # Back in stateb after statec exits; this brings stateb's total to 2x.
        time.sleep(state_duration_ms / 1000)

    sampler.stop()
    sampler.commit_counters()

    if not statesampler.FAST_SAMPLER:
      # The slow sampler does not implement sampling, so we won't test it.
      return

    # Test that sampled state timings are close to their expected values.
    # yapf: disable
    expected_counter_values = {
        CounterName('statea-msecs', step_name='step1', stage_name='basic'):
            state_duration_ms,
        CounterName('stateb-msecs', step_name='step1', stage_name='basic'):
            2 * state_duration_ms,
        CounterName('statec-msecs', step_name='step1', stage_name='basic'):
            3 * state_duration_ms,
    }
    # yapf: enable
    counters = list(counter_factory.get_counters())
    # Require exactly the expected counters: no unexpected ones, and none
    # missing. Without this, an empty counter list would make the value
    # checks below pass vacuously.
    self.assertEqual({counter.name
                      for counter in counters},
                     set(expected_counter_values))
    for counter in counters:
      expected_value = expected_counter_values[counter.name]
      actual_value = counter.value()
      deviation = float(abs(actual_value - expected_value)) / expected_value
      _LOGGER.info('Sampling deviation from expectation: %f', deviation)
      self.assertGreater(actual_value, expected_value * (1.0 - margin_of_error))
      self.assertLess(actual_value, expected_value * (1.0 + margin_of_error))

  # TODO: This test is flaky when it is run under load. A better solution
  # would be to change the test structure to not depend on specific timings.
  @retry(reraise=True, stop=stop_after_attempt(3))
  def test_sampler_transition_overhead(self):
    """Average wall time per state transition stays under a loose bound."""
    # Set up state sampler.
    counter_factory = CounterFactory()
    sampler = statesampler.StateSampler(
        'overhead-', counter_factory, sampling_period_ms=10)

    # Run basic workload transitioning between 3 states.
    state_a = sampler.scoped_state('step1', 'statea')
    state_b = sampler.scoped_state('step1', 'stateb')
    state_c = sampler.scoped_state('step1', 'statec')
    # perf_counter() is monotonic and high-resolution; time.time() can jump
    # if the system clock is adjusted, which would corrupt the measurement.
    start_time = time.perf_counter()
    sampler.start()
    for _ in range(100000):
      with state_a:
        with state_b:
          for _ in range(10):
            with state_c:
              pass
    sampler.stop()
    elapsed_time = time.perf_counter() - start_time
    state_transition_count = sampler.get_info().transition_count
    overhead_us = 1000000.0 * elapsed_time / state_transition_count

    _LOGGER.info('Overhead per transition: %fus', overhead_us)
    # Conservative upper bound on overhead in microseconds (we expect this to
    # take 0.17us when compiled in opt mode or 0.48 us when compiled in
    # debug mode).
    self.assertLess(overhead_us, 20.0)


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()