github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/worker/log_handler_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  # pytype: skip-file
    19  
    20  import logging
    21  import re
    22  import unittest
    23  
    24  import grpc
    25  
    26  from apache_beam.portability.api import beam_fn_api_pb2
    27  from apache_beam.portability.api import beam_fn_api_pb2_grpc
    28  from apache_beam.portability.api import endpoints_pb2
    29  from apache_beam.runners.common import NameContext
    30  from apache_beam.runners.worker import log_handler
    31  from apache_beam.runners.worker import statesampler
    32  from apache_beam.utils import thread_pool_executor
    33  
    34  _LOGGER = logging.getLogger(__name__)
    35  
    36  
    37  class BeamFnLoggingServicer(beam_fn_api_pb2_grpc.BeamFnLoggingServicer):
    38    def __init__(self):
    39      self.log_records_received = []
    40  
    41    def Logging(self, request_iterator, context):
    42  
    43      for log_record in request_iterator:
    44        self.log_records_received.append(log_record)
    45  
    46      yield beam_fn_api_pb2.LogControl()
    47  
    48  
    49  class FnApiLogRecordHandlerTest(unittest.TestCase):
    50    def setUp(self):
    51      self.test_logging_service = BeamFnLoggingServicer()
    52      self.server = grpc.server(thread_pool_executor.shared_unbounded_instance())
    53      beam_fn_api_pb2_grpc.add_BeamFnLoggingServicer_to_server(
    54          self.test_logging_service, self.server)
    55      self.test_port = self.server.add_insecure_port('[::]:0')
    56      self.server.start()
    57  
    58      self.logging_service_descriptor = endpoints_pb2.ApiServiceDescriptor()
    59      self.logging_service_descriptor.url = 'localhost:%s' % self.test_port
    60      self.fn_log_handler = log_handler.FnApiLogRecordHandler(
    61          self.logging_service_descriptor)
    62      logging.getLogger().setLevel(logging.INFO)
    63      logging.getLogger().addHandler(self.fn_log_handler)
    64  
    65    def tearDown(self):
    66      # wait upto 5 seconds.
    67      self.server.stop(5)
    68  
    69    def _verify_fn_log_handler(self, num_log_entries):
    70      msg = 'Testing fn logging'
    71      _LOGGER.debug('Debug Message 1')
    72      for idx in range(num_log_entries):
    73        _LOGGER.info('%s: %s', msg, idx)
    74      _LOGGER.debug('Debug Message 2')
    75  
    76      # Wait for logs to be sent to server.
    77      self.fn_log_handler.close()
    78  
    79      num_received_log_entries = 0
    80      for outer in self.test_logging_service.log_records_received:
    81        for log_entry in outer.log_entries:
    82          self.assertEqual(
    83              beam_fn_api_pb2.LogEntry.Severity.INFO, log_entry.severity)
    84          self.assertEqual(
    85              '%s: %s' % (msg, num_received_log_entries), log_entry.message)
    86          self.assertTrue(
    87              re.match(r'.*log_handler_test.py:\d+', log_entry.log_location),
    88              log_entry.log_location)
    89          self.assertGreater(log_entry.timestamp.seconds, 0)
    90          self.assertGreaterEqual(log_entry.timestamp.nanos, 0)
    91          num_received_log_entries += 1
    92  
    93      self.assertEqual(num_received_log_entries, num_log_entries)
    94  
    95    def assertContains(self, haystack, needle):
    96      self.assertTrue(
    97          needle in haystack, 'Expected %r to contain %r.' % (haystack, needle))
    98  
    99    def test_exc_info(self):
   100      try:
   101        raise ValueError('some message')
   102      except ValueError:
   103        _LOGGER.error('some error', exc_info=True)
   104  
   105      self.fn_log_handler.close()
   106  
   107      log_entry = self.test_logging_service.log_records_received[0].log_entries[0]
   108      self.assertContains(log_entry.message, 'some error')
   109      self.assertContains(log_entry.trace, 'some message')
   110      self.assertContains(log_entry.trace, 'log_handler_test.py')
   111  
   112    def test_format_bad_message(self):
   113      # We specifically emit to the handler directly since we don't want to emit
   114      # to all handlers in general since we know that this record will raise an
   115      # exception during formatting.
   116      self.fn_log_handler.emit(
   117          logging.LogRecord(
   118              'name',
   119              logging.ERROR,
   120              'pathname',
   121              777,
   122              'TestLog %d', (None, ),
   123              exc_info=None))
   124      self.fn_log_handler.close()
   125      log_entry = self.test_logging_service.log_records_received[0].log_entries[0]
   126      self.assertContains(
   127          log_entry.message,
   128          "Failed to format 'TestLog %d' with args '(None,)' during logging.")
   129  
   130    def test_context(self):
   131      try:
   132        with statesampler.instruction_id('A'):
   133          tracker = statesampler.for_test()
   134          with tracker.scoped_state(NameContext('name', 'tid'), 'stage'):
   135            _LOGGER.info('message a')
   136        with statesampler.instruction_id('B'):
   137          _LOGGER.info('message b')
   138        _LOGGER.info('message c')
   139  
   140        self.fn_log_handler.close()
   141        a, b, c = sum(
   142            [list(logs.log_entries)
   143             for logs in self.test_logging_service.log_records_received], [])
   144  
   145        self.assertEqual(a.instruction_id, 'A')
   146        self.assertEqual(b.instruction_id, 'B')
   147        self.assertEqual(c.instruction_id, '')
   148  
   149        self.assertEqual(a.transform_id, 'tid')
   150        self.assertEqual(b.transform_id, '')
   151        self.assertEqual(c.transform_id, '')
   152  
   153      finally:
   154        statesampler.set_current_tracker(None)
   155  
   156  
   157  # Test cases.
   158  data = {
   159      'one_batch': log_handler.FnApiLogRecordHandler._MAX_BATCH_SIZE - 47,
   160      'exact_multiple': log_handler.FnApiLogRecordHandler._MAX_BATCH_SIZE,
   161      'multi_batch': log_handler.FnApiLogRecordHandler._MAX_BATCH_SIZE * 3 + 47
   162  }
   163  
   164  
   165  def _create_test(name, num_logs):
   166    setattr(
   167        FnApiLogRecordHandlerTest,
   168        'test_%s' % name,
   169        lambda self: self._verify_fn_log_handler(num_logs))
   170  
   171  
   172  for test_name, num_logs_entries in data.items():
   173    _create_test(test_name, num_logs_entries)
   174  
   175  if __name__ == '__main__':
   176    unittest.main()