github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/portability/expansion_service_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  # pytype: skip-file
    18  
    19  import argparse
    20  import logging
    21  import pickle
    22  import signal
    23  import sys
    24  import typing
    25  
    26  import grpc
    27  
    28  import apache_beam as beam
    29  import apache_beam.transforms.combiners as combine
    30  from apache_beam.coders import RowCoder
    31  from apache_beam.pipeline import PipelineOptions
    32  from apache_beam.portability.api import beam_artifact_api_pb2_grpc
    33  from apache_beam.portability.api import beam_expansion_api_pb2_grpc
    34  from apache_beam.portability.api import external_transforms_pb2
    35  from apache_beam.runners.portability import artifact_service
    36  from apache_beam.runners.portability import expansion_service
    37  from apache_beam.runners.portability.stager import Stager
    38  from apache_beam.transforms import fully_qualified_named_transform
    39  from apache_beam.transforms import ptransform
    40  from apache_beam.transforms.environments import PyPIArtifactRegistry
    41  from apache_beam.transforms.external import ImplicitSchemaPayloadBuilder
    42  from apache_beam.utils import thread_pool_executor
    43  
# This script provides an expansion service and example ptransforms for running
# external transform test cases. See external_test.py for details.

# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)

# URNs under which the example transforms defined below are registered.
TEST_PREFIX_URN = "beam:transforms:xlang:test:prefix"
TEST_MULTI_URN = "beam:transforms:xlang:test:multi"
TEST_GBK_URN = "beam:transforms:xlang:test:gbk"
TEST_CGBK_URN = "beam:transforms:xlang:test:cgbk"
TEST_COMGL_URN = "beam:transforms:xlang:test:comgl"
TEST_COMPK_URN = "beam:transforms:xlang:test:compk"
TEST_FLATTEN_URN = "beam:transforms:xlang:test:flatten"
TEST_PARTITION_URN = "beam:transforms:xlang:test:partition"
TEST_PYTHON_BS4_URN = "beam:transforms:xlang:test:python_bs4"

# A transform that does not produce an output.
TEST_NO_OUTPUT_URN = "beam:transforms:xlang:test:nooutput"
    61  
    62  
    63  @ptransform.PTransform.register_urn('beam:transforms:xlang:count', None)
    64  class CountPerElementTransform(ptransform.PTransform):
    65    def expand(self, pcoll):
    66      return pcoll | combine.Count.PerElement()
    67  
    68    def to_runner_api_parameter(self, unused_context):
    69      return 'beam:transforms:xlang:count', None
    70  
    71    @staticmethod
    72    def from_runner_api_parameter(
    73        unused_ptransform, unused_parameter, unused_context):
    74      return CountPerElementTransform()
    75  
    76  
    77  @ptransform.PTransform.register_urn(
    78      'beam:transforms:xlang:filter_less_than_eq', bytes)
    79  class FilterLessThanTransform(ptransform.PTransform):
    80    def __init__(self, payload):
    81      self._payload = payload
    82  
    83    def expand(self, pcoll):
    84      return (
    85          pcoll | beam.Filter(
    86              lambda elem, target: elem <= target, int(ord(self._payload[0]))))
    87  
    88    def to_runner_api_parameter(self, unused_context):
    89      return (
    90          'beam:transforms:xlang:filter_less_than', self._payload.encode('utf8'))
    91  
    92    @staticmethod
    93    def from_runner_api_parameter(unused_ptransform, payload, unused_context):
    94      return FilterLessThanTransform(payload.decode('utf8'))
    95  
    96  
    97  @ptransform.PTransform.register_urn(TEST_PREFIX_URN, None)
    98  @beam.typehints.with_output_types(str)
    99  class PrefixTransform(ptransform.PTransform):
   100    def __init__(self, payload):
   101      self._payload = payload
   102  
   103    def expand(self, pcoll):
   104      return pcoll | 'TestLabel' >> beam.Map(
   105          lambda x: '{}{}'.format(self._payload, x))
   106  
   107    def to_runner_api_parameter(self, unused_context):
   108      return TEST_PREFIX_URN, ImplicitSchemaPayloadBuilder(
   109          {'data': self._payload}).payload()
   110  
   111    @staticmethod
   112    def from_runner_api_parameter(unused_ptransform, payload, unused_context):
   113      return PrefixTransform(parse_string_payload(payload)['data'])
   114  
   115  
   116  @ptransform.PTransform.register_urn(TEST_MULTI_URN, None)
   117  class MutltiTransform(ptransform.PTransform):
   118    def expand(self, pcolls):
   119      return {
   120          'main': (pcolls['main1'], pcolls['main2'])
   121          | beam.Flatten()
   122          | beam.Map(lambda x, s: x + s, beam.pvalue.AsSingleton(
   123              pcolls['side'])).with_output_types(str),
   124          'side': pcolls['side']
   125          | beam.Map(lambda x: x + x).with_output_types(str),
   126      }
   127  
   128    def to_runner_api_parameter(self, unused_context):
   129      return TEST_MULTI_URN, None
   130  
   131    @staticmethod
   132    def from_runner_api_parameter(
   133        unused_ptransform, unused_parameter, unused_context):
   134      return MutltiTransform()
   135  
   136  
   137  @ptransform.PTransform.register_urn(TEST_GBK_URN, None)
   138  class GBKTransform(ptransform.PTransform):
   139    def expand(self, pcoll):
   140      return pcoll | 'TestLabel' >> beam.GroupByKey()
   141  
   142    def to_runner_api_parameter(self, unused_context):
   143      return TEST_GBK_URN, None
   144  
   145    @staticmethod
   146    def from_runner_api_parameter(
   147        unused_ptransform, unused_parameter, unused_context):
   148      return GBKTransform()
   149  
   150  
   151  @ptransform.PTransform.register_urn(TEST_CGBK_URN, None)
   152  class CoGBKTransform(ptransform.PTransform):
   153    class ConcatFn(beam.DoFn):
   154      def process(self, element):
   155        (k, v) = element
   156        return [(k, v['col1'] + v['col2'])]
   157  
   158    def expand(self, pcoll):
   159      return pcoll \
   160             | beam.CoGroupByKey() \
   161             | beam.ParDo(self.ConcatFn()).with_output_types(
   162                 typing.Tuple[int, typing.Iterable[str]])
   163  
   164    def to_runner_api_parameter(self, unused_context):
   165      return TEST_CGBK_URN, None
   166  
   167    @staticmethod
   168    def from_runner_api_parameter(
   169        unused_ptransform, unused_parameter, unused_context):
   170      return CoGBKTransform()
   171  
   172  
   173  @ptransform.PTransform.register_urn(TEST_COMGL_URN, None)
   174  class CombineGloballyTransform(ptransform.PTransform):
   175    def expand(self, pcoll):
   176      return pcoll \
   177             | beam.CombineGlobally(sum).with_output_types(int)
   178  
   179    def to_runner_api_parameter(self, unused_context):
   180      return TEST_COMGL_URN, None
   181  
   182    @staticmethod
   183    def from_runner_api_parameter(
   184        unused_ptransform, unused_parameter, unused_context):
   185      return CombineGloballyTransform()
   186  
   187  
   188  @ptransform.PTransform.register_urn(TEST_COMPK_URN, None)
   189  class CombinePerKeyTransform(ptransform.PTransform):
   190    def expand(self, pcoll):
   191      output = pcoll \
   192             | beam.CombinePerKey(sum)
   193      # TODO: Use `with_output_types` instead of explicitly
   194      #  assigning to `.element_type` after fixing BEAM-12872
   195      output.element_type = beam.typehints.Tuple[str, int]
   196      return output
   197  
   198    def to_runner_api_parameter(self, unused_context):
   199      return TEST_COMPK_URN, None
   200  
   201    @staticmethod
   202    def from_runner_api_parameter(
   203        unused_ptransform, unused_parameter, unused_context):
   204      return CombinePerKeyTransform()
   205  
   206  
   207  @ptransform.PTransform.register_urn(TEST_FLATTEN_URN, None)
   208  class FlattenTransform(ptransform.PTransform):
   209    def expand(self, pcoll):
   210      return pcoll.values() | beam.Flatten().with_output_types(int)
   211  
   212    def to_runner_api_parameter(self, unused_context):
   213      return TEST_FLATTEN_URN, None
   214  
   215    @staticmethod
   216    def from_runner_api_parameter(
   217        unused_ptransform, unused_parameter, unused_context):
   218      return FlattenTransform()
   219  
   220  
   221  @ptransform.PTransform.register_urn(TEST_PARTITION_URN, None)
   222  class PartitionTransform(ptransform.PTransform):
   223    def expand(self, pcoll):
   224      col1, col2 = pcoll | beam.Partition(
   225          lambda elem, n: 0 if elem % 2 == 0 else 1, 2)
   226      typed_col1 = col1 | beam.Map(lambda x: x).with_output_types(int)
   227      typed_col2 = col2 | beam.Map(lambda x: x).with_output_types(int)
   228      return {'0': typed_col1, '1': typed_col2}
   229  
   230    def to_runner_api_parameter(self, unused_context):
   231      return TEST_PARTITION_URN, None
   232  
   233    @staticmethod
   234    def from_runner_api_parameter(
   235        unused_ptransform, unused_parameter, unused_context):
   236      return PartitionTransform()
   237  
   238  
   239  class ExtractHtmlTitleDoFn(beam.DoFn):
   240    def process(self, element):
   241      from bs4 import BeautifulSoup
   242      soup = BeautifulSoup(element, 'html.parser')
   243      return [soup.title.string]
   244  
   245  
   246  @ptransform.PTransform.register_urn(TEST_PYTHON_BS4_URN, None)
   247  class ExtractHtmlTitleTransform(ptransform.PTransform):
   248    def expand(self, pcoll):
   249      return pcoll | beam.ParDo(ExtractHtmlTitleDoFn()).with_output_types(str)
   250  
   251    def to_runner_api_parameter(self, unused_context):
   252      return TEST_PYTHON_BS4_URN, None
   253  
   254    @staticmethod
   255    def from_runner_api_parameter(
   256        unused_ptransform, unused_parameter, unused_context):
   257      return ExtractHtmlTitleTransform()
   258  
   259  
   260  @ptransform.PTransform.register_urn('payload', bytes)
   261  class PayloadTransform(ptransform.PTransform):
   262    def __init__(self, payload):
   263      self._payload = payload
   264  
   265    def expand(self, pcoll):
   266      return pcoll | beam.Map(lambda x, s: x + s, self._payload)
   267  
   268    def to_runner_api_parameter(self, unused_context):
   269      return b'payload', self._payload.encode('ascii')
   270  
   271    @staticmethod
   272    def from_runner_api_parameter(unused_ptransform, payload, unused_context):
   273      return PayloadTransform(payload.decode('ascii'))
   274  
   275  
   276  @ptransform.PTransform.register_urn('map_to_union_types', None)
   277  class MapToUnionTypesTransform(ptransform.PTransform):
   278    class CustomDoFn(beam.DoFn):
   279      def process(self, element):
   280        if element == 1:
   281          return ['1']
   282        elif element == 2:
   283          return [2]
   284        else:
   285          return [3.0]
   286  
   287    def expand(self, pcoll):
   288      return pcoll | beam.ParDo(self.CustomDoFn())
   289  
   290    def to_runner_api_parameter(self, unused_context):
   291      return b'map_to_union_types', None
   292  
   293    @staticmethod
   294    def from_runner_api_parameter(
   295        unused_ptransform, unused_payload, unused_context):
   296      return MapToUnionTypesTransform()
   297  
   298  
   299  @ptransform.PTransform.register_urn('fib', bytes)
   300  class FibTransform(ptransform.PTransform):
   301    def __init__(self, level):
   302      self._level = level
   303  
   304    def expand(self, p):
   305      if self._level <= 2:
   306        return p | beam.Create([1])
   307      else:
   308        a = p | 'A' >> beam.ExternalTransform(
   309            'fib',
   310            str(self._level - 1).encode('ascii'),
   311            expansion_service.ExpansionServiceServicer())
   312        b = p | 'B' >> beam.ExternalTransform(
   313            'fib',
   314            str(self._level - 2).encode('ascii'),
   315            expansion_service.ExpansionServiceServicer())
   316        return ((a, b)
   317                | beam.Flatten()
   318                | beam.CombineGlobally(sum).without_defaults())
   319  
   320    def to_runner_api_parameter(self, unused_context):
   321      return 'fib', str(self._level).encode('ascii')
   322  
   323    @staticmethod
   324    def from_runner_api_parameter(unused_ptransform, level, unused_context):
   325      return FibTransform(int(level.decode('ascii')))
   326  
   327  
   328  @ptransform.PTransform.register_urn(TEST_NO_OUTPUT_URN, None)
   329  class NoOutputTransform(ptransform.PTransform):
   330    def expand(self, pcoll):
   331      def log_val(val):
   332        logging.debug('Got value: %r', val)
   333  
   334      # Logging without returning anything
   335      _ = (pcoll | 'TestLabel' >> beam.ParDo(log_val))
   336  
   337    def to_runner_api_parameter(self, unused_context):
   338      return TEST_NO_OUTPUT_URN, None
   339  
   340    @staticmethod
   341    def from_runner_api_parameter(unused_ptransform, payload, unused_context):
   342      return NoOutputTransform(parse_string_payload(payload)['data'])
   343  
   344  
   345  def parse_string_payload(input_byte):
   346    payload = external_transforms_pb2.ExternalConfigurationPayload()
   347    payload.ParseFromString(input_byte)
   348  
   349    return RowCoder(payload.schema).decode(payload.payload)._asdict()
   350  
   351  
   352  def create_test_sklearn_model(file_name):
   353    from sklearn import svm
   354    x = [[0, 0], [1, 1]]
   355    y = [0, 1]
   356    model = svm.SVC()
   357    model.fit(x, y)
   358    with open(file_name, 'wb') as file:
   359      pickle.dump(model, file)
   360  
   361  
   362  def update_sklearn_model_dependency(env):
   363    model_file = "/tmp/sklearn_test_model"
   364    staged_name = "sklearn_model"
   365    create_test_sklearn_model(model_file)
   366    env._artifacts.append(
   367        Stager._create_file_stage_to_artifact(model_file, staged_name))
   368  
   369  
   370  server = None
   371  
   372  
   373  def cleanup(unused_signum, unused_frame):
   374    _LOGGER.info('Shutting down expansion service.')
   375    server.stop(None)
   376  
   377  
   378  def main(unused_argv):
   379    # TODO: use the regular expansion service (expansion_service_main) instead of
   380    # this custom service for testing.
   381    PyPIArtifactRegistry.register_artifact('beautifulsoup4', '>=4.9,<5.0')
   382    parser = argparse.ArgumentParser()
   383    parser.add_argument(
   384        '-p', '--port', type=int, help='port on which to serve the job api')
   385    parser.add_argument('--fully_qualified_name_glob', default=None)
   386    options = parser.parse_args()
   387  
   388    global server
   389    with fully_qualified_named_transform.FullyQualifiedNamedTransform.with_filter(
   390        options.fully_qualified_name_glob):
   391      server = grpc.server(thread_pool_executor.shared_unbounded_instance())
   392      expansion_servicer = expansion_service.ExpansionServiceServicer(
   393          PipelineOptions([
   394              "--experiments",
   395              "beam_fn_api",
   396              "--sdk_location",
   397              "container",
   398              "--pickle_library",
   399              "cloudpickle"
   400          ]))
   401      update_sklearn_model_dependency(expansion_servicer._default_environment)
   402      beam_expansion_api_pb2_grpc.add_ExpansionServiceServicer_to_server(
   403          expansion_servicer, server)
   404      beam_artifact_api_pb2_grpc.add_ArtifactRetrievalServiceServicer_to_server(
   405          artifact_service.ArtifactRetrievalService(
   406              artifact_service.BeamFilesystemHandler(None).file_reader),
   407          server)
   408      server.add_insecure_port('localhost:{}'.format(options.port))
   409      server.start()
   410      _LOGGER.info('Listening for expansion requests at %d', options.port)
   411  
   412      signal.signal(signal.SIGTERM, cleanup)
   413      signal.signal(signal.SIGINT, cleanup)
   414      # blocking main thread forever.
   415      signal.pause()
   416  
   417  
# Script entry point: raise the root logger to INFO, then start the service.
if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  main(sys.argv)