github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/worker/log_handler.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Beam fn API log handler."""

# pytype: skip-file
# mypy: disallow-untyped-defs

import logging
import math
import queue
import sys
import threading
import time
import traceback
from typing import TYPE_CHECKING
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Union
from typing import cast

import grpc

from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_fn_api_pb2_grpc
from apache_beam.runners.worker import statesampler
from apache_beam.runners.worker.channel_factory import GRPCChannelFactory
from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor
from apache_beam.utils.sentinel import Sentinel

if TYPE_CHECKING:
  from apache_beam.portability.api import endpoints_pb2

# Mapping from logging levels to LogEntry levels.
LOG_LEVEL_TO_LOGENTRY_MAP = {
    logging.FATAL: beam_fn_api_pb2.LogEntry.Severity.CRITICAL,
    logging.ERROR: beam_fn_api_pb2.LogEntry.Severity.ERROR,
    logging.WARNING: beam_fn_api_pb2.LogEntry.Severity.WARN,
    logging.INFO: beam_fn_api_pb2.LogEntry.Severity.INFO,
    logging.DEBUG: beam_fn_api_pb2.LogEntry.Severity.DEBUG,
    logging.NOTSET: beam_fn_api_pb2.LogEntry.Severity.UNSPECIFIED,
    -float('inf'): beam_fn_api_pb2.LogEntry.Severity.DEBUG,
}

# Mapping from LogEntry levels to logging levels.
LOGENTRY_TO_LOG_LEVEL_MAP = {
    beam_fn_api_pb2.LogEntry.Severity.CRITICAL: logging.CRITICAL,
    beam_fn_api_pb2.LogEntry.Severity.ERROR: logging.ERROR,
    beam_fn_api_pb2.LogEntry.Severity.WARN: logging.WARNING,
    beam_fn_api_pb2.LogEntry.Severity.NOTICE: logging.INFO + 1,
    beam_fn_api_pb2.LogEntry.Severity.INFO: logging.INFO,
    beam_fn_api_pb2.LogEntry.Severity.DEBUG: logging.DEBUG,
    beam_fn_api_pb2.LogEntry.Severity.TRACE: logging.DEBUG - 1,
    beam_fn_api_pb2.LogEntry.Severity.UNSPECIFIED: logging.NOTSET,
}
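
# Note that NOTICE and TRACE have no direct equivalents in the Python logging
# module, so they map to the synthetic levels INFO + 1 and DEBUG - 1, which
# preserve their relative ordering among the standard levels.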


class FnApiLogRecordHandler(logging.Handler):
  """A handler that writes log records to the fn API."""

  # Maximum number of log entries in a single stream request.
  _MAX_BATCH_SIZE = 1000
  # Used to indicate the end of stream.
  _FINISHED = Sentinel.sentinel
  # Size of the queue used to buffer messages. Once full, messages will be
  # dropped. If the average log size is 1KB this may use up to 10MB of memory.
  _QUEUE_SIZE = 10000

  def __init__(self, log_service_descriptor):
    # type: (endpoints_pb2.ApiServiceDescriptor) -> None
    super().__init__()

    self._alive = True
    self._dropped_logs = 0
    self._log_entry_queue = queue.Queue(
        maxsize=self._QUEUE_SIZE
    )  # type: queue.Queue[Union[beam_fn_api_pb2.LogEntry, Sentinel]]

    ch = GRPCChannelFactory.insecure_channel(log_service_descriptor.url)
    # Make sure the channel is ready to avoid [BEAM-4649]
    grpc.channel_ready_future(ch).result(timeout=60)
    self._log_channel = grpc.intercept_channel(ch, WorkerIdInterceptor())
    self._reader = threading.Thread(
        target=self._read_log_control_messages,
        name='read_log_control_messages')
    self._reader.daemon = True
    self._reader.start()

  def connect(self):
    # type: () -> Iterable
    if hasattr(self, '_logging_stub'):
      del self._logging_stub  # type: ignore[has-type]
    self._logging_stub = beam_fn_api_pb2_grpc.BeamFnLoggingStub(
        self._log_channel)
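    # Logging is a bidirectional stream: the request side is fed by the
    # LogEntry.List batches yielded from _write_log_entries, and the returned
    # iterator yields log control messages from the runner.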
    return self._logging_stub.Logging(self._write_log_entries())

  def map_log_level(self, level):
    # type: (int) -> beam_fn_api_pb2.LogEntry.Severity.Enum.ValueType
    try:
      return LOG_LEVEL_TO_LOGENTRY_MAP[level]
    except KeyError:
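      # Fall back to the most severe Beam level whose Python level does not
      # exceed the requested one; for example, a custom level of 35 (between
      # WARNING=30 and ERROR=40) maps to WARN. The -float('inf') entry in the
      # map guarantees a match for any level.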
      return max(
          beam_level for python_level,
          beam_level in LOG_LEVEL_TO_LOGENTRY_MAP.items()
          if python_level <= level)

  def emit(self, record):
    # type: (logging.LogRecord) -> None
    log_entry = beam_fn_api_pb2.LogEntry()
    log_entry.severity = self.map_log_level(record.levelno)
    try:
      log_entry.message = self.format(record)
    except Exception:
      # record.msg could be an arbitrary object; convert it to a string first.
      log_entry.message = (
          "Failed to format '%s' with args '%s' during logging." %
          (str(record.msg), record.args))
    log_entry.thread = record.threadName
    log_entry.log_location = '%s:%s' % (
        record.pathname or record.module, record.lineno or record.funcName)
    (fraction, seconds) = math.modf(record.created)
    nanoseconds = 1e9 * fraction
    log_entry.timestamp.seconds = int(seconds)
    log_entry.timestamp.nanos = int(nanoseconds)
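    # For example, record.created = 1700000000.25 splits into fraction 0.25
    # and seconds 1700000000.0, giving timestamp.seconds = 1700000000 and
    # timestamp.nanos = 250000000.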
    if record.exc_info:
      log_entry.trace = ''.join(traceback.format_exception(*record.exc_info))
    instruction_id = statesampler.get_current_instruction_id()
    if instruction_id:
      log_entry.instruction_id = instruction_id
    tracker = statesampler.get_current_tracker()
    if tracker:
      current_state = tracker.current_state()
      if (current_state and current_state.name_context and
          current_state.name_context.transform_id):
        log_entry.transform_id = current_state.name_context.transform_id

    try:
      self._log_entry_queue.put(log_entry, block=False)
    except queue.Full:
      self._dropped_logs += 1

  def close(self):
    # type: () -> None

    """Flush out all existing log entries and unregister this handler."""
    try:
      self._alive = False
      # Acquiring the handler lock ensures ``emit`` is not run until the lock
      # is released.
      self.acquire()
      self._log_entry_queue.put(self._FINISHED, timeout=5)
      # Wait for the reader thread to finish, which happens once the server
      # closes the stream.
      self._reader.join()
      self.release()
      # Unregister this handler.
      super().close()
    except Exception:
      # Log rather than raising exceptions, to avoid clobbering
      # underlying errors that may have caused this to close
      # prematurely.
      logging.error("Error closing the logging channel.", exc_info=True)

  def _write_log_entries(self):
    # type: () -> Iterator[beam_fn_api_pb2.LogEntry.List]
    done = False
    while not done:
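      # Block until at least one entry (or the sentinel) is available, then
      # drain up to _MAX_BATCH_SIZE more entries without blocking, so each
      # stream message carries a batch rather than a single record.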
      log_entries = [self._log_entry_queue.get()]
      try:
        for _ in range(self._MAX_BATCH_SIZE):
          log_entries.append(self._log_entry_queue.get_nowait())
      except queue.Empty:
        pass
      if log_entries[-1] is self._FINISHED:
        done = True
        log_entries.pop()
      if log_entries:
        # typing: log_entries was initialized as List[Union[..., Sentinel]],
        # but now that we've popped the sentinel out (above) we can safely
        # cast.
        yield beam_fn_api_pb2.LogEntry.List(
            log_entries=cast(List[beam_fn_api_pb2.LogEntry], log_entries))

  def _read_log_control_messages(self):
    # type: () -> None
    # Only reconnect while we are alive. Logs may be dropped in the unlikely
    # event that the logging connection drops (rather than being closed)
    # during termination while entries are still queued; the chance of
    # reconnecting and transmitting them successfully is small since the
    # process is terminating, so this case is deliberately not handled to
    # avoid unnecessary code complexity.

    alive = True  # Force at least one connection attempt.
    while alive:
      # Loop for reconnection.
      log_control_iterator = self.connect()
      if self._dropped_logs > 0:
        logging.warning(
            "Dropped %d logs while logging client disconnected",
            self._dropped_logs)
        self._dropped_logs = 0
      try:
        for _ in log_control_iterator:
          # Loop for consuming messages from server.
          # TODO(vikasrk): Handle control messages.
          pass
        # The iterator is closed: the server has ended the stream, so we are
        # done.
        return
      except Exception as ex:
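        # The logging channel itself has failed, so report to stderr rather
        # than through the logging system.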
        print(
            "Logging client failed: {}... resetting".format(ex),
            file=sys.stderr)
        # Wait a bit before attempting to reconnect.
        time.sleep(0.5)
      alive = self._alive
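

# A minimal wiring sketch (illustrative only, not part of the worker startup
# path; the URL below is a placeholder and assumes a BeamFnLogging service is
# already listening there):
#
#   from apache_beam.portability.api import endpoints_pb2
#
#   descriptor = endpoints_pb2.ApiServiceDescriptor(url='localhost:50051')
#   handler = FnApiLogRecordHandler(descriptor)
#   logging.getLogger().addHandler(handler)
#   logging.getLogger().info('This record is forwarded over the Fn API.')
#   handler.close()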