github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/worker/data_plane_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Tests for apache_beam.runners.worker.data_plane."""
    19  
    20  # pytype: skip-file
    21  
    22  import itertools
    23  import logging
    24  import time
    25  import unittest
    26  
    27  import grpc
    28  
    29  from apache_beam.portability.api import beam_fn_api_pb2
    30  from apache_beam.portability.api import beam_fn_api_pb2_grpc
    31  from apache_beam.runners.worker import data_plane
    32  from apache_beam.runners.worker.worker_id_interceptor import WorkerIdInterceptor
    33  from apache_beam.utils import thread_pool_executor
    34  
    35  
    36  class DataChannelTest(unittest.TestCase):
    37    def test_grpc_data_channel(self):
    38      self._grpc_data_channel_test()
    39  
    40    def test_time_based_flush_grpc_data_channel(self):
    41      self._grpc_data_channel_test(True)
    42  
    43    def _grpc_data_channel_test(self, time_based_flush=False):
    44      if time_based_flush:
    45        data_servicer = data_plane.BeamFnDataServicer(
    46            data_buffer_time_limit_ms=100)
    47      else:
    48        data_servicer = data_plane.BeamFnDataServicer()
    49      worker_id = 'worker_0'
    50      data_channel_service = \
    51        data_servicer.get_conn_by_worker_id(worker_id)
    52  
    53      server = grpc.server(thread_pool_executor.shared_unbounded_instance())
    54      beam_fn_api_pb2_grpc.add_BeamFnDataServicer_to_server(data_servicer, server)
    55      test_port = server.add_insecure_port('[::]:0')
    56      server.start()
    57  
    58      grpc_channel = grpc.insecure_channel('localhost:%s' % test_port)
    59      # Add workerId to the grpc channel
    60      grpc_channel = grpc.intercept_channel(
    61          grpc_channel, WorkerIdInterceptor(worker_id))
    62      data_channel_stub = beam_fn_api_pb2_grpc.BeamFnDataStub(grpc_channel)
    63      if time_based_flush:
    64        data_channel_client = data_plane.GrpcClientDataChannel(
    65            data_channel_stub, data_buffer_time_limit_ms=100)
    66      else:
    67        data_channel_client = data_plane.GrpcClientDataChannel(data_channel_stub)
    68  
    69      try:
    70        self._data_channel_test(
    71            data_channel_service, data_channel_client, time_based_flush)
    72      finally:
    73        data_channel_client.close()
    74        data_channel_service.close()
    75        data_channel_client.wait()
    76        data_channel_service.wait()
    77  
    78    def test_in_memory_data_channel(self):
    79      channel = data_plane.InMemoryDataChannel()
    80      self._data_channel_test(channel, channel.inverse())
    81  
    82    def _data_channel_test(self, server, client, time_based_flush=False):
    83      self._data_channel_test_one_direction(server, client, time_based_flush)
    84      self._data_channel_test_one_direction(client, server, time_based_flush)
    85  
    86    def _data_channel_test_one_direction(
    87        self, from_channel, to_channel, time_based_flush):
    88      transform_1 = '1'
    89      transform_2 = '2'
    90  
    91      # Single write.
    92      stream01 = from_channel.output_stream('0', transform_1)
    93      stream01.write(b'abc')
    94      if not time_based_flush:
    95        stream01.close()
    96      self.assertEqual(
    97          list(
    98              itertools.islice(to_channel.input_elements('0', [transform_1]), 1)),
    99          [
   100              beam_fn_api_pb2.Elements.Data(
   101                  instruction_id='0', transform_id=transform_1, data=b'abc')
   102          ])
   103  
   104      # Multiple interleaved writes to multiple instructions.
   105      stream11 = from_channel.output_stream('1', transform_1)
   106      stream11.write(b'abc')
   107      stream21 = from_channel.output_stream('2', transform_1)
   108      stream21.write(b'def')
   109      if not time_based_flush:
   110        stream11.close()
   111      self.assertEqual(
   112          list(
   113              itertools.islice(to_channel.input_elements('1', [transform_1]), 1)),
   114          [
   115              beam_fn_api_pb2.Elements.Data(
   116                  instruction_id='1', transform_id=transform_1, data=b'abc')
   117          ])
   118      if time_based_flush:
   119        # Wait to ensure stream21 is flushed before stream22.
   120        # Because the flush callback is invoked periodically starting from when a
   121        # stream is constructed, there is no guarantee that one stream's callback
   122        # is called before the other.
   123        time.sleep(0.1)
   124      else:
   125        stream21.close()
   126      stream22 = from_channel.output_stream('2', transform_2)
   127      stream22.write(b'ghi')
   128      if not time_based_flush:
   129        stream22.close()
   130      self.assertEqual(
   131          list(
   132              itertools.islice(
   133                  to_channel.input_elements('2', [transform_1, transform_2]), 2)),
   134          [
   135              beam_fn_api_pb2.Elements.Data(
   136                  instruction_id='2', transform_id=transform_1, data=b'def'),
   137              beam_fn_api_pb2.Elements.Data(
   138                  instruction_id='2', transform_id=transform_2, data=b'ghi')
   139          ])
   140  
   141  
   142  if __name__ == '__main__':
   143    logging.getLogger().setLevel(logging.INFO)
   144    unittest.main()