# Source: github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/worker/data_plane_test.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Tests for apache_beam.runners.worker.data_plane."""

# pytype: skip-file

import itertools
import logging
import time
import unittest

import grpc

from apache_beam.portability.api import beam_fn_api_pb2
from apache_beam.portability.api import beam_fn_api_pb2_grpc
from apache_beam.runners.worker import data_plane
from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor
from apache_beam.utils import thread_pool_executor

# Flush interval passed as data_buffer_time_limit_ms in the time-based-flush
# variant; the ordering sleep in _data_channel_test_one_direction is derived
# from it so the two cannot drift apart.
_DATA_BUFFER_TIME_LIMIT_MS = 100


class DataChannelTest(unittest.TestCase):
  """Round-trips Elements.Data through data channels in both directions."""

  def test_grpc_data_channel(self):
    self._grpc_data_channel_test()

  def test_time_based_flush_grpc_data_channel(self):
    self._grpc_data_channel_test(True)

  def _grpc_data_channel_test(self, time_based_flush=False):
    """Runs the shared data-channel checks over a local gRPC connection.

    Starts a BeamFnDataServicer on an ephemeral port, connects a
    GrpcClientDataChannel to it through a channel tagged with a worker id
    (via WorkerIdInterceptor), and exchanges data in both directions.

    Args:
      time_based_flush: if True, both ends are built with
        data_buffer_time_limit_ms, and the checks rely on the periodic flush
        instead of closing each output stream.
    """
    if time_based_flush:
      data_servicer = data_plane.BeamFnDataServicer(
          data_buffer_time_limit_ms=_DATA_BUFFER_TIME_LIMIT_MS)
    else:
      data_servicer = data_plane.BeamFnDataServicer()
    worker_id = 'worker_0'
    data_channel_service = data_servicer.get_conn_by_worker_id(worker_id)

    server = grpc.server(thread_pool_executor.shared_unbounded_instance())
    beam_fn_api_pb2_grpc.add_BeamFnDataServicer_to_server(data_servicer, server)
    test_port = server.add_insecure_port('[::]:0')
    server.start()
    # Previously the server was never stopped, so it (and its port) outlived
    # the test. Cleanups run after the test body, including the finally
    # clause below, in LIFO order: the channel is closed, then the server
    # is stopped.
    self.addCleanup(server.stop, None)

    grpc_channel = grpc.insecure_channel('localhost:%s' % test_port)
    # Add workerId to the grpc channel
    grpc_channel = grpc.intercept_channel(
        grpc_channel, WorkerIdInterceptor(worker_id))
    self.addCleanup(grpc_channel.close)
    data_channel_stub = beam_fn_api_pb2_grpc.BeamFnDataStub(grpc_channel)
    if time_based_flush:
      data_channel_client = data_plane.GrpcClientDataChannel(
          data_channel_stub,
          data_buffer_time_limit_ms=_DATA_BUFFER_TIME_LIMIT_MS)
    else:
      data_channel_client = data_plane.GrpcClientDataChannel(data_channel_stub)

    try:
      self._data_channel_test(
          data_channel_service, data_channel_client, time_based_flush)
    finally:
      data_channel_client.close()
      data_channel_service.close()
      data_channel_client.wait()
      data_channel_service.wait()

  def test_in_memory_data_channel(self):
    channel = data_plane.InMemoryDataChannel()
    self._data_channel_test(channel, channel.inverse())

  def _data_channel_test(self, server, client, time_based_flush=False):
    """Checks data flowing server -> client, then client -> server."""
    self._data_channel_test_one_direction(server, client, time_based_flush)
    self._data_channel_test_one_direction(client, server, time_based_flush)

  def _data_channel_test_one_direction(
      self, from_channel, to_channel, time_based_flush):
    """Writes via from_channel and asserts to_channel yields the same data.

    Args:
      from_channel: channel whose output_stream() is written to.
      to_channel: channel whose input_elements() is read and compared.
      time_based_flush: if True, streams are left open and delivery depends
        on the periodic flush; otherwise each stream is closed after writing.
    """
    transform_1 = '1'
    transform_2 = '2'

    # Single write.
    stream01 = from_channel.output_stream('0', transform_1)
    stream01.write(b'abc')
    if not time_based_flush:
      stream01.close()
    self.assertEqual(
        list(
            itertools.islice(to_channel.input_elements('0', [transform_1]), 1)),
        [
            beam_fn_api_pb2.Elements.Data(
                instruction_id='0', transform_id=transform_1, data=b'abc')
        ])

    # Multiple interleaved writes to multiple instructions.
    stream11 = from_channel.output_stream('1', transform_1)
    stream11.write(b'abc')
    stream21 = from_channel.output_stream('2', transform_1)
    stream21.write(b'def')
    if not time_based_flush:
      stream11.close()
    self.assertEqual(
        list(
            itertools.islice(to_channel.input_elements('1', [transform_1]), 1)),
        [
            beam_fn_api_pb2.Elements.Data(
                instruction_id='1', transform_id=transform_1, data=b'abc')
        ])
    if time_based_flush:
      # Wait to ensure stream21 is flushed before stream22.
      # Because the flush callback is invoked periodically starting from when a
      # stream is constructed, there is no guarantee that one stream's callback
      # is called before the other.
      time.sleep(_DATA_BUFFER_TIME_LIMIT_MS / 1000)
    else:
      stream21.close()
    stream22 = from_channel.output_stream('2', transform_2)
    stream22.write(b'ghi')
    if not time_based_flush:
      stream22.close()
    self.assertEqual(
        list(
            itertools.islice(
                to_channel.input_elements('2', [transform_1, transform_2]), 2)),
        [
            beam_fn_api_pb2.Elements.Data(
                instruction_id='2', transform_id=transform_1, data=b'def'),
            beam_fn_api_pb2.Elements.Data(
                instruction_id='2', transform_id=transform_2, data=b'ghi')
        ])


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()