#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pytype: skip-file

import logging
import re
import unittest

import grpc

from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_fn_api_pb2_grpc
from apache_beam.portability.api import endpoints_pb2
from apache_beam.runners.common import NameContext
from apache_beam.runners.worker import log_handler
from apache_beam.runners.worker import statesampler
from apache_beam.utils import thread_pool_executor

_LOGGER = logging.getLogger(__name__)


class BeamFnLoggingServicer(beam_fn_api_pb2_grpc.BeamFnLoggingServicer):
  """In-process Logging service that records every request it receives."""
  def __init__(self):
    # One element per request message read from the client stream; the tests
    # below inspect each element's ``log_entries``.
    self.log_records_received = []

  def Logging(self, request_iterator, context):
    """Drain the client stream, recording each request, then reply once."""
    for log_record in request_iterator:
      self.log_records_received.append(log_record)

    yield beam_fn_api_pb2.LogControl()


class FnApiLogRecordHandlerTest(unittest.TestCase):
  """Tests for ``log_handler.FnApiLogRecordHandler`` against a local server."""
  def setUp(self):
    # Start a gRPC server hosting the recording Logging service on a free port.
    self.test_logging_service = BeamFnLoggingServicer()
    self.server = grpc.server(thread_pool_executor.shared_unbounded_instance())
    beam_fn_api_pb2_grpc.add_BeamFnLoggingServicer_to_server(
        self.test_logging_service, self.server)
    self.test_port = self.server.add_insecure_port('[::]:0')
    self.server.start()

    self.logging_service_descriptor = endpoints_pb2.ApiServiceDescriptor()
    self.logging_service_descriptor.url = 'localhost:%s' % self.test_port
    self.fn_log_handler = log_handler.FnApiLogRecordHandler(
        self.logging_service_descriptor)

    root_logger = logging.getLogger()
    # Remember the current root level so tearDown can restore it; lowering it
    # to INFO lets the _LOGGER.info calls in the tests reach the handler.
    self._original_root_level = root_logger.level
    root_logger.setLevel(logging.INFO)
    root_logger.addHandler(self.fn_log_handler)

  def tearDown(self):
    # Detach the handler added in setUp and restore the root level. The tests
    # only close() the handler, and logging.Handler.close() does not remove a
    # handler from the loggers it is attached to, so without this every test
    # would leave one more handler on the root logger.
    root_logger = logging.getLogger()
    root_logger.removeHandler(self.fn_log_handler)
    root_logger.setLevel(self._original_root_level)
    # Wait up to 5 seconds.
    self.server.stop(5)

  def _verify_fn_log_handler(self, num_log_entries):
    """Log ``num_log_entries`` INFO messages and check the server got them.

    DEBUG messages are emitted before and after the INFO ones; with the root
    level at INFO they must not reach the server, so every received entry is
    expected to be INFO and the entries must arrive in emission order.
    """
    msg = 'Testing fn logging'
    _LOGGER.debug('Debug Message 1')
    for idx in range(num_log_entries):
      _LOGGER.info('%s: %s', msg, idx)
    _LOGGER.debug('Debug Message 2')

    # Wait for logs to be sent to server.
    self.fn_log_handler.close()

    num_received_log_entries = 0
    for outer in self.test_logging_service.log_records_received:
      for log_entry in outer.log_entries:
        self.assertEqual(
            beam_fn_api_pb2.LogEntry.Severity.INFO, log_entry.severity)
        self.assertEqual(
            '%s: %s' % (msg, num_received_log_entries), log_entry.message)
        # Escape the dot so only a literal ``.py`` suffix matches.
        self.assertTrue(
            re.match(r'.*log_handler_test\.py:\d+', log_entry.log_location),
            log_entry.log_location)
        self.assertGreater(log_entry.timestamp.seconds, 0)
        self.assertGreaterEqual(log_entry.timestamp.nanos, 0)
        num_received_log_entries += 1

    self.assertEqual(num_received_log_entries, num_log_entries)

  def assertContains(self, haystack, needle):
    """Assert that ``needle`` occurs in ``haystack``."""
    self.assertTrue(
        needle in haystack, 'Expected %r to contain %r.' % (haystack, needle))

  def test_exc_info(self):
    """exc_info=True puts the traceback into the entry's ``trace`` field."""
    try:
      raise ValueError('some message')
    except ValueError:
      _LOGGER.error('some error', exc_info=True)

    self.fn_log_handler.close()

    log_entry = self.test_logging_service.log_records_received[0].log_entries[0]
    self.assertContains(log_entry.message, 'some error')
    self.assertContains(log_entry.trace, 'some message')
    self.assertContains(log_entry.trace, 'log_handler_test.py')

  def test_format_bad_message(self):
    """A record whose args do not fit its format string is reported."""
    # We specifically emit to the handler directly since we don't want to emit
    # to all handlers in general since we know that this record will raise an
    # exception during formatting.
    self.fn_log_handler.emit(
        logging.LogRecord(
            'name',
            logging.ERROR,
            'pathname',
            777,
            'TestLog %d', (None, ),
            exc_info=None))
    self.fn_log_handler.close()
    log_entry = self.test_logging_service.log_records_received[0].log_entries[0]
    self.assertContains(
        log_entry.message,
        "Failed to format 'TestLog %d' with args '(None,)' during logging.")

  def test_context(self):
    """Entries carry the instruction/transform ids expected at log time."""
    try:
      with statesampler.instruction_id('A'):
        tracker = statesampler.for_test()
        with tracker.scoped_state(NameContext('name', 'tid'), 'stage'):
          _LOGGER.info('message a')
          with statesampler.instruction_id('B'):
            _LOGGER.info('message b')
      _LOGGER.info('message c')

      self.fn_log_handler.close()
      # Flatten all received requests into one list of entries; unpacking
      # into exactly three names also asserts that three entries arrived.
      a, b, c = [
          entry
          for logs in self.test_logging_service.log_records_received
          for entry in logs.log_entries
      ]

      self.assertEqual(a.instruction_id, 'A')
      self.assertEqual(b.instruction_id, 'B')
      self.assertEqual(c.instruction_id, '')

      self.assertEqual(a.transform_id, 'tid')
      self.assertEqual(b.transform_id, '')
      self.assertEqual(c.transform_id, '')

    finally:
      statesampler.set_current_tracker(None)


# Test cases.
# Number of INFO entries to log per generated test, relative to the handler's
# _MAX_BATCH_SIZE: fewer than one batch, exactly one batch, and several full
# batches plus a partial one.
data = {
    'one_batch': log_handler.FnApiLogRecordHandler._MAX_BATCH_SIZE - 47,
    'exact_multiple': log_handler.FnApiLogRecordHandler._MAX_BATCH_SIZE,
    'multi_batch': log_handler.FnApiLogRecordHandler._MAX_BATCH_SIZE * 3 + 47
}


def _create_test(name, num_logs):
  """Attach ``test_<name>`` to FnApiLogRecordHandlerTest.

  The generated method calls ``_verify_fn_log_handler(num_logs)``. A named
  inner function is used instead of a lambda so tracebacks and reprs show
  ``test_<name>`` rather than ``<lambda>``.
  """
  method_name = 'test_%s' % name

  def _generated_test(self):
    self._verify_fn_log_handler(num_logs)

  _generated_test.__name__ = method_name
  _generated_test.__qualname__ = '%s.%s' % (
      FnApiLogRecordHandlerTest.__name__, method_name)
  setattr(FnApiLogRecordHandlerTest, method_name, _generated_test)


for test_name, num_logs_entries in data.items():
  _create_test(test_name, num_logs_entries)

if __name__ == '__main__':
  unittest.main()