github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/worker/logger_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Tests for worker logging utilities."""
    19  
    20  # pytype: skip-file
    21  
    22  import json
    23  import logging
    24  import sys
    25  import threading
    26  import unittest
    27  
    28  from apache_beam.runners.worker import logger
    29  from apache_beam.runners.worker import statesampler
    30  from apache_beam.utils.counters import CounterFactory
    31  
    32  
    33  class PerThreadLoggingContextTest(unittest.TestCase):
    34    def thread_check_attribute(self, name):
    35      self.assertFalse(name in logger.per_thread_worker_data.get_data())
    36      with logger.PerThreadLoggingContext(**{name: 'thread-value'}):
    37        self.assertEqual(
    38            logger.per_thread_worker_data.get_data()[name], 'thread-value')
    39      self.assertFalse(name in logger.per_thread_worker_data.get_data())
    40  
    41    def test_per_thread_attribute(self):
    42      self.assertFalse('xyz' in logger.per_thread_worker_data.get_data())
    43      with logger.PerThreadLoggingContext(xyz='value'):
    44        self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
    45        thread = threading.Thread(
    46            target=self.thread_check_attribute, args=('xyz', ))
    47        thread.start()
    48        thread.join()
    49        self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
    50      self.assertFalse('xyz' in logger.per_thread_worker_data.get_data())
    51  
    52    def test_set_when_undefined(self):
    53      self.assertFalse('xyz' in logger.per_thread_worker_data.get_data())
    54      with logger.PerThreadLoggingContext(xyz='value'):
    55        self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
    56      self.assertFalse('xyz' in logger.per_thread_worker_data.get_data())
    57  
    58    def test_set_when_already_defined(self):
    59      self.assertFalse('xyz' in logger.per_thread_worker_data.get_data())
    60      with logger.PerThreadLoggingContext(xyz='value'):
    61        self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
    62        with logger.PerThreadLoggingContext(xyz='value2'):
    63          self.assertEqual(
    64              logger.per_thread_worker_data.get_data()['xyz'], 'value2')
    65        self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
    66      self.assertFalse('xyz' in logger.per_thread_worker_data.get_data())
    67  
    68  
    69  class JsonLogFormatterTest(unittest.TestCase):
    70  
    71    SAMPLE_RECORD = {
    72        'created': 123456.789,
    73        'msecs': 789.654321,
    74        'msg': '%s:%d:%.2f',
    75        'args': ('xyz', 4, 3.14),
    76        'levelname': 'WARNING',
    77        'process': 'pid',
    78        'thread': 'tid',
    79        'name': 'name',
    80        'filename': 'file',
    81        'funcName': 'func',
    82        'exc_info': None
    83    }
    84  
    85    SAMPLE_OUTPUT = {
    86        'timestamp': {
    87            'seconds': 123456, 'nanos': 789654321
    88        },
    89        'severity': 'WARN',
    90        'message': 'xyz:4:3.14',
    91        'thread': 'pid:tid',
    92        'job': 'jobid',
    93        'worker': 'workerid',
    94        'logger': 'name:file:func'
    95    }
    96  
    97    def create_log_record(self, **kwargs):
    98      class Record(object):
    99        def __init__(self, **kwargs):
   100          for k, v in kwargs.items():
   101            setattr(self, k, v)
   102  
   103      return Record(**kwargs)
   104  
   105    def test_basic_record(self):
   106      formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')
   107      record = self.create_log_record(**self.SAMPLE_RECORD)
   108      self.assertEqual(json.loads(formatter.format(record)), self.SAMPLE_OUTPUT)
   109  
   110    def execute_multiple_cases(self, test_cases):
   111      record = self.SAMPLE_RECORD
   112      output = self.SAMPLE_OUTPUT
   113      formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')
   114  
   115      for case in test_cases:
   116        record['msg'] = case['msg']
   117        record['args'] = case['args']
   118        output['message'] = case['expected']
   119  
   120        self.assertEqual(
   121            json.loads(formatter.format(self.create_log_record(**record))),
   122            output)
   123  
   124    def test_record_with_format_character(self):
   125      test_cases = [
   126          {
   127              'msg': '%A', 'args': (), 'expected': '%A'
   128          },
   129          {
   130              'msg': '%s', 'args': (), 'expected': '%s'
   131          },
   132          {
   133              'msg': '%A%s', 'args': ('xy'), 'expected': '%A%s with args (xy)'
   134          },
   135          {
   136              'msg': '%s%s', 'args': (1), 'expected': '%s%s with args (1)'
   137          },
   138      ]
   139  
   140      self.execute_multiple_cases(test_cases)
   141  
   142    def test_record_with_arbitrary_messages(self):
   143      test_cases = [
   144          {
   145              'msg': ImportError('abc'), 'args': (), 'expected': 'abc'
   146          },
   147          {
   148              'msg': TypeError('abc %s'), 'args': ('def'), 'expected': 'abc def'
   149          },
   150      ]
   151  
   152      self.execute_multiple_cases(test_cases)
   153  
   154    def test_record_with_per_thread_info(self):
   155      self.maxDiff = None
   156      tracker = statesampler.StateSampler('stage', CounterFactory())
   157      statesampler.set_current_tracker(tracker)
   158      formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')
   159      with logger.PerThreadLoggingContext(work_item_id='workitem'):
   160        with tracker.scoped_state('step', 'process'):
   161          record = self.create_log_record(**self.SAMPLE_RECORD)
   162          log_output = json.loads(formatter.format(record))
   163      expected_output = dict(self.SAMPLE_OUTPUT)
   164      expected_output.update({
   165          'work': 'workitem', 'stage': 'stage', 'step': 'step'
   166      })
   167      self.assertEqual(log_output, expected_output)
   168      statesampler.set_current_tracker(None)
   169  
   170    def test_nested_with_per_thread_info(self):
   171      self.maxDiff = None
   172      tracker = statesampler.StateSampler('stage', CounterFactory())
   173      statesampler.set_current_tracker(tracker)
   174      formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')
   175      with logger.PerThreadLoggingContext(work_item_id='workitem'):
   176        with tracker.scoped_state('step1', 'process'):
   177          record = self.create_log_record(**self.SAMPLE_RECORD)
   178          log_output1 = json.loads(formatter.format(record))
   179  
   180          with tracker.scoped_state('step2', 'process'):
   181            record = self.create_log_record(**self.SAMPLE_RECORD)
   182            log_output2 = json.loads(formatter.format(record))
   183  
   184          record = self.create_log_record(**self.SAMPLE_RECORD)
   185          log_output3 = json.loads(formatter.format(record))
   186  
   187      statesampler.set_current_tracker(None)
   188      record = self.create_log_record(**self.SAMPLE_RECORD)
   189      log_output4 = json.loads(formatter.format(record))
   190  
   191      self.assertEqual(
   192          log_output1,
   193          dict(self.SAMPLE_OUTPUT, work='workitem', stage='stage', step='step1'))
   194      self.assertEqual(
   195          log_output2,
   196          dict(self.SAMPLE_OUTPUT, work='workitem', stage='stage', step='step2'))
   197      self.assertEqual(
   198          log_output3,
   199          dict(self.SAMPLE_OUTPUT, work='workitem', stage='stage', step='step1'))
   200      self.assertEqual(log_output4, self.SAMPLE_OUTPUT)
   201  
   202    def test_exception_record(self):
   203      formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')
   204      try:
   205        raise ValueError('Something')
   206      except ValueError:
   207        attribs = dict(self.SAMPLE_RECORD)
   208        attribs.update({'exc_info': sys.exc_info()})
   209        record = self.create_log_record(**attribs)
   210      log_output = json.loads(formatter.format(record))
   211      # Check if exception type, its message, and stack trace information are in.
   212      exn_output = log_output.pop('exception')
   213      self.assertNotEqual(exn_output.find('ValueError: Something'), -1)
   214      self.assertNotEqual(exn_output.find('logger_test.py'), -1)
   215      self.assertEqual(log_output, self.SAMPLE_OUTPUT)
   216  
   217  
if __name__ == '__main__':
  # Show INFO-level log output when this test module is run directly.
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()