#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pytype: skip-file

import contextlib
import threading
from typing import TYPE_CHECKING
from typing import Dict
from typing import NamedTuple
from typing import Optional
from typing import Union

from apache_beam.runners import common
from apache_beam.utils.counters import Counter
from apache_beam.utils.counters import CounterFactory
from apache_beam.utils.counters import CounterName

# Prefer the compiled sampler implementation; fall back to the pure-Python
# module when the compiled one cannot be imported.
try:
  from apache_beam.runners.worker import statesampler_fast as statesampler_impl  # type: ignore
  FAST_SAMPLER = True
except ImportError:
  from apache_beam.runners.worker import statesampler_slow as statesampler_impl
  FAST_SAMPLER = False

if TYPE_CHECKING:
  from apache_beam.metrics.execution import MetricsContainer

# Per-thread storage for the currently active StateSampler.
_STATE_SAMPLERS = threading.local()


def set_current_tracker(tracker):
  """Registers ``tracker`` as the current tracker of the calling thread.

  Args:
    tracker: the sampler to register, or None to clear the registration.
  """
  _STATE_SAMPLERS.tracker = tracker


def get_current_tracker():
  """Returns the calling thread's current tracker, or None if none was set."""
  return getattr(_STATE_SAMPLERS, 'tracker', None)


# Per-thread storage for the instruction id currently being processed.
_INSTRUCTION_IDS = threading.local()


def get_current_instruction_id():
  """Returns the calling thread's instruction id, or None if none was set."""
  return getattr(_INSTRUCTION_IDS, 'instruction_id', None)


@contextlib.contextmanager
def instruction_id(id):
  """Context manager that sets the calling thread's instruction id.

  On exit (normal or exceptional) the id that was active when the context
  was entered is restored, so nested uses on the same thread do not wipe
  out the enclosing id. When no id was active on entry, the id is reset to
  None on exit.

  Args:
    id: the instruction id to expose through get_current_instruction_id()
      while the context is active. (The name shadows the builtin but is
      kept for compatibility with keyword callers.)
  """
  previous_id = get_current_instruction_id()
  _INSTRUCTION_IDS.instruction_id = id
  try:
    yield
  finally:
    _INSTRUCTION_IDS.instruction_id = previous_id


def for_test():
  """Installs a fresh 'test' StateSampler as the current tracker.

  Returns:
    The newly installed StateSampler.
  """
  set_current_tracker(StateSampler('test', CounterFactory()))
  return get_current_tracker()


# Snapshot of a sampler's current state, as produced by
# StateSampler.get_info().
StateSamplerInfo = NamedTuple(
    'StateSamplerInfo',
    [('state_name', CounterName), ('transition_count', int),
     ('time_since_transition', int),
     ('tracked_thread', Optional[threading.Thread])])

# Default period for sampling current state of pipeline execution.
DEFAULT_SAMPLING_PERIOD_MS = 200


class StateSampler(statesampler_impl.StateSampler):
  """State sampler bound to a stage prefix and a counter factory.

  Wraps the selected ``statesampler_impl.StateSampler`` and adds per-thread
  tracker registration, a cache of scoped states keyed by counter name, and
  the ability to flush sampled times into output counters.
  """
  def __init__(self,
               prefix,  # type: str
               counter_factory,
               sampling_period_ms=DEFAULT_SAMPLING_PERIOD_MS):
    """Initializes the sampler.

    Args:
      prefix: str. Used as the stage name of every counter this sampler
        creates.
      counter_factory: factory whose get_counter() supplies output counters.
      sampling_period_ms: sampling period in milliseconds, forwarded to the
        underlying implementation.
    """
    self._prefix = prefix
    self._counter_factory = counter_factory
    self._states_by_name = {
    }  # type: Dict[CounterName, statesampler_impl.ScopedState]
    self.sampling_period_ms = sampling_period_ms
    self.tracked_thread = None  # type: Optional[threading.Thread]
    # NOTE(review): `finished` is never set to True in this module;
    # presumably the underlying implementation updates it on stop - confirm.
    self.finished = False
    self.started = False
    super().__init__(sampling_period_ms)

  @property
  def stage_name(self):
    # type: () -> str

    """The prefix this sampler was constructed with."""
    return self._prefix

  def stop(self):
    # type: () -> None

    """Clears the calling thread's tracker and stops the underlying sampler."""
    set_current_tracker(None)
    super().stop()

  def stop_if_still_running(self):
    # type: () -> None

    """Stops the sampler only if it was started and is not yet finished."""
    if self.started and not self.finished:
      self.stop()

  def start(self):
    # type: () -> None

    """Starts sampling on behalf of the calling thread.

    Records the calling thread as the tracked thread and registers this
    sampler as that thread's current tracker before starting the underlying
    sampler.
    """
    self.tracked_thread = threading.current_thread()
    set_current_tracker(self)
    super().start()
    self.started = True

  def get_info(self):
    # type: () -> StateSamplerInfo

    """Returns StateSamplerInfo with transition statistics."""
    # current_state(), state_transition_count and time_since_transition are
    # not defined in this module; presumably they are provided by the
    # underlying implementation.
    return StateSamplerInfo(
        self.current_state().name,
        self.state_transition_count,
        self.time_since_transition,
        self.tracked_thread)

  def scoped_state(self,
                   name_context,  # type: Union[str, common.NameContext]
                   state_name,  # type: str
                   io_target=None,
                   metrics_container=None  # type: Optional[MetricsContainer]
                   ):
    # type: (...) -> statesampler_impl.ScopedState

    """Returns a ScopedState object associated to a Step and a State.

    States are cached by counter name, so repeated calls with the same step,
    state name and io_target return the same object. For such repeated
    calls the metrics_container argument is not consulted again.

    Args:
      name_context: common.NameContext. It is the step name information.
        Any other value is wrapped in a common.NameContext.
      state_name: str. It is the state name (e.g. process / start / finish).
      io_target: passed through to the CounterName of the state's counter.
      metrics_container: MetricsContainer. The step's metrics container.

    Returns:
      A ScopedState that keeps the execution context and is able to switch it
      for the execution thread.
    """
    if not isinstance(name_context, common.NameContext):
      name_context = common.NameContext(name_context)

    counter_name = CounterName(
        state_name + '-msecs',
        stage_name=self._prefix,
        step_name=name_context.metrics_name(),
        io_target=io_target)
    if counter_name not in self._states_by_name:
      # First request for this counter: create its output counter and the
      # scoped state, then cache the state for later calls.
      output_counter = self._counter_factory.get_counter(
          counter_name, Counter.SUM)
      self._states_by_name[counter_name] = super()._scoped_state(
          counter_name, name_context, output_counter, metrics_container)
    return self._states_by_name[counter_name]

  def commit_counters(self):
    # type: () -> None

    """Updates output counters with latest state statistics."""
    for state in self._states_by_name.values():
      # Sampled time is kept in nanoseconds; counters report milliseconds.
      state_msecs = int(1e-6 * state.nsecs)
      # Apply only the delta so the counter ends up equal to state_msecs.
      state.counter.update(state_msecs - state.counter.value())