#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Tests for worker logging utilities."""

# pytype: skip-file

import json
import logging
import sys
import threading
import unittest

from apache_beam.runners.worker import logger
from apache_beam.runners.worker import statesampler
from apache_beam.utils.counters import CounterFactory


class PerThreadLoggingContextTest(unittest.TestCase):
  """Tests that PerThreadLoggingContext scopes attributes per thread.

  Each test checks three phases: the attribute is absent before entering the
  context, present with the given value inside it, and absent again after
  exiting.
  """
  def thread_check_attribute(self, name):
    # Runs on a separate thread: the attribute set by the main thread's
    # context must NOT be visible here, proving the data is thread-local.
    self.assertFalse(name in logger.per_thread_worker_data.get_data())
    with logger.PerThreadLoggingContext(**{name: 'thread-value'}):
      self.assertEqual(
          logger.per_thread_worker_data.get_data()[name], 'thread-value')
    self.assertFalse(name in logger.per_thread_worker_data.get_data())

  def test_per_thread_attribute(self):
    self.assertFalse('xyz' in logger.per_thread_worker_data.get_data())
    with logger.PerThreadLoggingContext(xyz='value'):
      self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
      # Spawn a thread while the context is active; it sets its own value
      # for the same key, which must not leak into this thread's data.
      thread = threading.Thread(
          target=self.thread_check_attribute, args=('xyz', ))
      thread.start()
      thread.join()
      # Still the main thread's value after the other thread exited.
      self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
    self.assertFalse('xyz' in logger.per_thread_worker_data.get_data())

  def test_set_when_undefined(self):
    # Entering the context defines a previously-unset attribute; exiting
    # removes it again.
    self.assertFalse('xyz' in logger.per_thread_worker_data.get_data())
    with logger.PerThreadLoggingContext(xyz='value'):
      self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
    self.assertFalse('xyz' in logger.per_thread_worker_data.get_data())

  def test_set_when_already_defined(self):
    # Nested contexts shadow the outer value and restore it on exit.
    self.assertFalse('xyz' in logger.per_thread_worker_data.get_data())
    with logger.PerThreadLoggingContext(xyz='value'):
      self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
      with logger.PerThreadLoggingContext(xyz='value2'):
        self.assertEqual(
            logger.per_thread_worker_data.get_data()['xyz'], 'value2')
      # Inner context exited: the outer value is restored, not deleted.
      self.assertEqual(logger.per_thread_worker_data.get_data()['xyz'], 'value')
    self.assertFalse('xyz' in logger.per_thread_worker_data.get_data())


class JsonLogFormatterTest(unittest.TestCase):
  """Tests for logger.JsonLogFormatter's JSON output structure."""

  # Minimal stand-in for a logging.LogRecord's attributes, fed to the
  # formatter via create_log_record below.
  SAMPLE_RECORD = {
      'created': 123456.789,
      'msecs': 789.654321,
      'msg': '%s:%d:%.2f',
      'args': ('xyz', 4, 3.14),
      'levelname': 'WARNING',
      'process': 'pid',
      'thread': 'tid',
      'name': 'name',
      'filename': 'file',
      'funcName': 'func',
      'exc_info': None
  }

  # The JSON payload expected for SAMPLE_RECORD: note 'WARNING' is mapped
  # to 'WARN', process/thread are joined as 'pid:tid', and
  # name/filename/funcName are joined as the 'logger' field.
  SAMPLE_OUTPUT = {
      'timestamp': {
          'seconds': 123456, 'nanos': 789654321
      },
      'severity': 'WARN',
      'message': 'xyz:4:3.14',
      'thread': 'pid:tid',
      'job': 'jobid',
      'worker': 'workerid',
      'logger': 'name:file:func'
  }

  def create_log_record(self, **kwargs):
    """Builds a duck-typed log record object from keyword attributes."""
    class Record(object):
      def __init__(self, **kwargs):
        for k, v in kwargs.items():
          setattr(self, k, v)

    return Record(**kwargs)

  def test_basic_record(self):
    formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')
    record = self.create_log_record(**self.SAMPLE_RECORD)
    self.assertEqual(json.loads(formatter.format(record)), self.SAMPLE_OUTPUT)

  def execute_multiple_cases(self, test_cases):
    """Formats each case's msg/args and checks the resulting 'message'.

    Each case is a dict with 'msg', 'args', and 'expected' keys.
    """
    # NOTE(review): this binds (and mutates) the shared class-level dicts in
    # place rather than copying them; fine as long as every test using them
    # goes through this path or overrides msg/args/message itself.
    record = self.SAMPLE_RECORD
    output = self.SAMPLE_OUTPUT
    formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')

    for case in test_cases:
      record['msg'] = case['msg']
      record['args'] = case['args']
      output['message'] = case['expected']

      self.assertEqual(
          json.loads(formatter.format(self.create_log_record(**record))),
          output)

  def test_record_with_format_character(self):
    # Exercises the formatter's fallback when %-interpolation fails.
    # Note ('xy') and (1) are deliberately NOT tuples — they are the bare
    # string 'xy' and int 1 — which the expected outputs ('... with args
    # (xy)') depend on; do not "fix" them by adding trailing commas.
    test_cases = [
        {
            'msg': '%A', 'args': (), 'expected': '%A'
        },
        {
            'msg': '%s', 'args': (), 'expected': '%s'
        },
        {
            'msg': '%A%s', 'args': ('xy'), 'expected': '%A%s with args (xy)'
        },
        {
            'msg': '%s%s', 'args': (1), 'expected': '%s%s with args (1)'
        },
    ]

    self.execute_multiple_cases(test_cases)

  def test_record_with_arbitrary_messages(self):
    # msg may be a non-string object (e.g. an exception instance); the
    # formatter stringifies it, then applies %-args when they fit.
    test_cases = [
        {
            'msg': ImportError('abc'), 'args': (), 'expected': 'abc'
        },
        {
            'msg': TypeError('abc %s'), 'args': ('def'), 'expected': 'abc def'
        },
    ]

    self.execute_multiple_cases(test_cases)

  def test_record_with_per_thread_info(self):
    # With a state sampler tracker and a PerThreadLoggingContext active, the
    # formatter adds 'work', 'stage', and 'step' fields to the output.
    self.maxDiff = None
    tracker = statesampler.StateSampler('stage', CounterFactory())
    statesampler.set_current_tracker(tracker)
    formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')
    with logger.PerThreadLoggingContext(work_item_id='workitem'):
      with tracker.scoped_state('step', 'process'):
        record = self.create_log_record(**self.SAMPLE_RECORD)
        log_output = json.loads(formatter.format(record))
    expected_output = dict(self.SAMPLE_OUTPUT)
    expected_output.update({
        'work': 'workitem', 'stage': 'stage', 'step': 'step'
    })
    self.assertEqual(log_output, expected_output)
    statesampler.set_current_tracker(None)

  def test_nested_with_per_thread_info(self):
    # Nested scoped_state contexts: 'step' tracks the innermost active
    # scope, reverts to the outer scope on exit, and disappears entirely
    # once the tracker and logging context are gone.
    self.maxDiff = None
    tracker = statesampler.StateSampler('stage', CounterFactory())
    statesampler.set_current_tracker(tracker)
    formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')
    with logger.PerThreadLoggingContext(work_item_id='workitem'):
      with tracker.scoped_state('step1', 'process'):
        record = self.create_log_record(**self.SAMPLE_RECORD)
        log_output1 = json.loads(formatter.format(record))

        with tracker.scoped_state('step2', 'process'):
          record = self.create_log_record(**self.SAMPLE_RECORD)
          log_output2 = json.loads(formatter.format(record))

        # Back in step1's scope after step2 exits.
        record = self.create_log_record(**self.SAMPLE_RECORD)
        log_output3 = json.loads(formatter.format(record))

    statesampler.set_current_tracker(None)
    record = self.create_log_record(**self.SAMPLE_RECORD)
    log_output4 = json.loads(formatter.format(record))

    self.assertEqual(
        log_output1,
        dict(self.SAMPLE_OUTPUT, work='workitem', stage='stage', step='step1'))
    self.assertEqual(
        log_output2,
        dict(self.SAMPLE_OUTPUT, work='workitem', stage='stage', step='step2'))
    self.assertEqual(
        log_output3,
        dict(self.SAMPLE_OUTPUT, work='workitem', stage='stage', step='step1'))
    # No tracker/context: output is back to the plain sample shape.
    self.assertEqual(log_output4, self.SAMPLE_OUTPUT)

  def test_exception_record(self):
    # A record carrying exc_info gains an 'exception' field containing the
    # formatted traceback; the rest of the payload is unchanged.
    formatter = logger.JsonLogFormatter(job_id='jobid', worker_id='workerid')
    try:
      raise ValueError('Something')
    except ValueError:
      attribs = dict(self.SAMPLE_RECORD)
      attribs.update({'exc_info': sys.exc_info()})
      record = self.create_log_record(**attribs)
      log_output = json.loads(formatter.format(record))

    # Check if exception type, its message, and stack trace information
    # are in.
    exn_output = log_output.pop('exception')
    self.assertNotEqual(exn_output.find('ValueError: Something'), -1)
    self.assertNotEqual(exn_output.find('logger_test.py'), -1)
    self.assertEqual(log_output, self.SAMPLE_OUTPUT)


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()