github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/external_java.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Tests for the Java external transforms."""
    19  
    20  import argparse
    21  import logging
    22  import subprocess
    23  import sys
    24  
    25  import grpc
    26  from mock import patch
    27  
    28  import apache_beam as beam
    29  from apache_beam.options.pipeline_options import PipelineOptions
    30  from apache_beam.testing.test_pipeline import TestPipeline
    31  from apache_beam.testing.util import assert_that
    32  from apache_beam.testing.util import equal_to
    33  from apache_beam.transforms.external import ImplicitSchemaPayloadBuilder
    34  
    35  # Protect against environments where apitools library is not available.
    36  # pylint: disable=wrong-import-order, wrong-import-position
    37  try:
    38    from apache_beam.runners.dataflow.internal import apiclient as _apiclient
    39  except ImportError:
    40    apiclient = None
    41  else:
    42    apiclient = _apiclient
    43  # pylint: enable=wrong-import-order, wrong-import-position
    44  
    45  
    46  class JavaExternalTransformTest(object):
    47  
    48    # This will be overwritten if set via a flag.
    49    expansion_service_jar = None  # type: str
    50    expansion_service_port = None  # type: int
    51  
    52    class _RunWithExpansion(object):
    53      def __init__(self):
    54        self._server = None
    55  
    56      def __enter__(self):
    57        if not (JavaExternalTransformTest.expansion_service_jar or
    58                JavaExternalTransformTest.expansion_service_port):
    59          raise RuntimeError('No expansion service jar or port provided.')
    60  
    61        JavaExternalTransformTest.expansion_service_port = (
    62            JavaExternalTransformTest.expansion_service_port or 8091)
    63  
    64        jar = JavaExternalTransformTest.expansion_service_jar
    65        port = JavaExternalTransformTest.expansion_service_port
    66  
    67        # Start the java server and wait for it to be ready.
    68        if jar:
    69          self._server = subprocess.Popen(['java', '-jar', jar, str(port)])
    70  
    71        address = 'localhost:%s' % str(port)
    72  
    73        with grpc.insecure_channel(address) as channel:
    74          grpc.channel_ready_future(channel).result()
    75  
    76      def __exit__(self, type, value, traceback):
    77        if self._server:
    78          self._server.kill()
    79          self._server = None
    80  
    81    @staticmethod
    82    def test_java_expansion_dataflow():
    83      if apiclient is None:
    84        return
    85  
    86      # This test does not actually running the pipeline in Dataflow. It just
    87      # tests the translation to a Dataflow job request.
    88  
    89      with patch.object(apiclient.DataflowApplicationClient,
    90                        'create_job') as mock_create_job:
    91        with JavaExternalTransformTest._RunWithExpansion():
    92          pipeline_options = PipelineOptions([
    93              '--runner=DataflowRunner',
    94              '--project=dummyproject',
    95              '--region=some-region1',
    96              '--experiments=beam_fn_api',
    97              '--temp_location=gs://dummybucket/'
    98          ])
    99  
   100          # Run a simple count-filtered-letters pipeline.
   101          JavaExternalTransformTest.run_pipeline(
   102              pipeline_options,
   103              JavaExternalTransformTest.expansion_service_port,
   104              False)
   105  
   106          mock_args = mock_create_job.call_args_list
   107          assert mock_args
   108          args, kwargs = mock_args[0]
   109          job = args[0]
   110          job_str = '%s' % job
   111          assert 'beam:transforms:xlang:filter_less_than_eq' in job_str
   112  
   113    @staticmethod
   114    def run_pipeline_with_expansion_service(pipeline_options):
   115      with JavaExternalTransformTest._RunWithExpansion():
   116        # Run a simple count-filtered-letters pipeline.
   117        JavaExternalTransformTest.run_pipeline(
   118            pipeline_options,
   119            JavaExternalTransformTest.expansion_service_port,
   120            True)
   121  
   122    @staticmethod
   123    def run_pipeline(pipeline_options, expansion_service, wait_until_finish=True):
   124      # The actual definitions of these transforms is in
   125      # org.apache.beam.runners.core.construction.TestExpansionService.
   126      TEST_COUNT_URN = "beam:transforms:xlang:count"
   127      TEST_FILTER_URN = "beam:transforms:xlang:filter_less_than_eq"
   128  
   129      # Run a simple count-filtered-letters pipeline.
   130      p = TestPipeline(options=pipeline_options)
   131  
   132      if isinstance(expansion_service, int):
   133        # Only the port was specified.
   134        expansion_service = 'localhost:%s' % str(expansion_service)
   135  
   136      res = (
   137          p
   138          | beam.Create(list('aaabccxyyzzz'))
   139          | beam.Map(str)
   140          | beam.ExternalTransform(
   141              TEST_FILTER_URN,
   142              ImplicitSchemaPayloadBuilder({'data': u'middle'}),
   143              expansion_service)
   144          | beam.ExternalTransform(TEST_COUNT_URN, None, expansion_service)
   145          | beam.Map(lambda kv: '%s: %s' % kv))
   146  
   147      assert_that(res, equal_to(['a: 3', 'b: 1', 'c: 2']))
   148  
   149      result = p.run()
   150      if wait_until_finish:
   151        result.wait_until_finish()
   152  
   153  
   154  if __name__ == '__main__':
   155    logging.getLogger().setLevel(logging.INFO)
   156    parser = argparse.ArgumentParser()
   157    parser.add_argument('--expansion_service_jar')
   158    parser.add_argument('--expansion_service_port')
   159    parser.add_argument('--expansion_service_target')
   160    parser.add_argument('--expansion_service_target_appendix')
   161    known_args, pipeline_args = parser.parse_known_args(sys.argv)
   162  
   163    if known_args.expansion_service_jar:
   164      JavaExternalTransformTest.expansion_service_jar = (
   165          known_args.expansion_service_jar)
   166      JavaExternalTransformTest.expansion_service_port = int(
   167          known_args.expansion_service_port)
   168      pipeline_options = PipelineOptions(pipeline_args)
   169      JavaExternalTransformTest.run_pipeline_with_expansion_service(
   170          pipeline_options)
   171    elif known_args.expansion_service_target:
   172      pipeline_options = PipelineOptions(pipeline_args)
   173      JavaExternalTransformTest.run_pipeline(
   174          pipeline_options,
   175          beam.transforms.external.BeamJarExpansionService(
   176              known_args.expansion_service_target,
   177              gradle_appendix=known_args.expansion_service_target_appendix))
   178    else:
   179      raise RuntimeError(
   180          "--expansion_service_jar or --expansion_service_target "
   181          "should be provided.")