github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/worker/statesampler.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  # pytype: skip-file
    19  
    20  import contextlib
    21  import threading
    22  from typing import TYPE_CHECKING
    23  from typing import Dict
    24  from typing import NamedTuple
    25  from typing import Optional
    26  from typing import Union
    27  
    28  from apache_beam.runners import common
    29  from apache_beam.utils.counters import Counter
    30  from apache_beam.utils.counters import CounterFactory
    31  from apache_beam.utils.counters import CounterName
    32  
    33  try:
    34    from apache_beam.runners.worker import statesampler_fast as statesampler_impl  # type: ignore
    35    FAST_SAMPLER = True
    36  except ImportError:
    37    from apache_beam.runners.worker import statesampler_slow as statesampler_impl
    38    FAST_SAMPLER = False
    39  
    40  if TYPE_CHECKING:
    41    from apache_beam.metrics.execution import MetricsContainer
    42  
    43  _STATE_SAMPLERS = threading.local()
    44  
    45  
    46  def set_current_tracker(tracker):
    47    _STATE_SAMPLERS.tracker = tracker
    48  
    49  
    50  def get_current_tracker():
    51    try:
    52      return _STATE_SAMPLERS.tracker
    53    except AttributeError:
    54      return None
    55  
    56  
    57  _INSTRUCTION_IDS = threading.local()
    58  
    59  
    60  def get_current_instruction_id():
    61    try:
    62      return _INSTRUCTION_IDS.instruction_id
    63    except AttributeError:
    64      return None
    65  
    66  
    67  @contextlib.contextmanager
    68  def instruction_id(id):
    69    try:
    70      _INSTRUCTION_IDS.instruction_id = id
    71      yield
    72    finally:
    73      _INSTRUCTION_IDS.instruction_id = None
    74  
    75  
    76  def for_test():
    77    set_current_tracker(StateSampler('test', CounterFactory()))
    78    return get_current_tracker()
    79  
    80  
    81  StateSamplerInfo = NamedTuple(
    82      'StateSamplerInfo',
    83      [('state_name', CounterName), ('transition_count', int),
    84       ('time_since_transition', int),
    85       ('tracked_thread', Optional[threading.Thread])])
    86  
    87  # Default period for sampling current state of pipeline execution.
    88  DEFAULT_SAMPLING_PERIOD_MS = 200
    89  
    90  
    91  class StateSampler(statesampler_impl.StateSampler):
    92  
    93    def __init__(self,
    94                 prefix,  # type: str
    95                 counter_factory,
    96                 sampling_period_ms=DEFAULT_SAMPLING_PERIOD_MS):
    97      self._prefix = prefix
    98      self._counter_factory = counter_factory
    99      self._states_by_name = {
   100      }  # type: Dict[CounterName, statesampler_impl.ScopedState]
   101      self.sampling_period_ms = sampling_period_ms
   102      self.tracked_thread = None  # type: Optional[threading.Thread]
   103      self.finished = False
   104      self.started = False
   105      super().__init__(sampling_period_ms)
   106  
   107    @property
   108    def stage_name(self):
   109      # type: () -> str
   110      return self._prefix
   111  
   112    def stop(self):
   113      # type: () -> None
   114      set_current_tracker(None)
   115      super().stop()
   116  
   117    def stop_if_still_running(self):
   118      # type: () -> None
   119      if self.started and not self.finished:
   120        self.stop()
   121  
   122    def start(self):
   123      # type: () -> None
   124      self.tracked_thread = threading.current_thread()
   125      set_current_tracker(self)
   126      super().start()
   127      self.started = True
   128  
   129    def get_info(self):
   130      # type: () -> StateSamplerInfo
   131  
   132      """Returns StateSamplerInfo with transition statistics."""
   133      return StateSamplerInfo(
   134          self.current_state().name,
   135          self.state_transition_count,
   136          self.time_since_transition,
   137          self.tracked_thread)
   138  
   139    def scoped_state(self,
   140                     name_context,  # type: Union[str, common.NameContext]
   141                     state_name,  # type: str
   142                     io_target=None,
   143                     metrics_container=None  # type: Optional[MetricsContainer]
   144                    ):
   145      # type: (...) -> statesampler_impl.ScopedState
   146  
   147      """Returns a ScopedState object associated to a Step and a State.
   148  
   149      Args:
   150        name_context: common.NameContext. It is the step name information.
   151        state_name: str. It is the state name (e.g. process / start / finish).
   152        io_target:
   153        metrics_container: MetricsContainer. The step's metrics container.
   154  
   155      Returns:
   156        A ScopedState that keeps the execution context and is able to switch it
   157        for the execution thread.
   158      """
   159      if not isinstance(name_context, common.NameContext):
   160        name_context = common.NameContext(name_context)
   161  
   162      counter_name = CounterName(
   163          state_name + '-msecs',
   164          stage_name=self._prefix,
   165          step_name=name_context.metrics_name(),
   166          io_target=io_target)
   167      if counter_name in self._states_by_name:
   168        return self._states_by_name[counter_name]
   169      else:
   170        output_counter = self._counter_factory.get_counter(
   171            counter_name, Counter.SUM)
   172        self._states_by_name[counter_name] = super()._scoped_state(
   173            counter_name, name_context, output_counter, metrics_container)
   174        return self._states_by_name[counter_name]
   175  
   176    def commit_counters(self):
   177      # type: () -> None
   178  
   179      """Updates output counters with latest state statistics."""
   180      for state in self._states_by_name.values():
   181        state_msecs = int(1e-6 * state.nsecs)
   182        state.counter.update(state_msecs - state.counter.value())