github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/portability/expansion_service.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """A PipelineExpansion service.
    19  """
    20  # pytype: skip-file
    21  
    22  import traceback
    23  
    24  from apache_beam import pipeline as beam_pipeline
    25  from apache_beam.portability import python_urns
    26  from apache_beam.portability.api import beam_expansion_api_pb2
    27  from apache_beam.portability.api import beam_expansion_api_pb2_grpc
    28  from apache_beam.runners import pipeline_context
    29  from apache_beam.runners.portability import portable_runner
    30  from apache_beam.transforms import external
    31  from apache_beam.transforms import ptransform
    32  
    33  
    34  class ExpansionServiceServicer(
    35      beam_expansion_api_pb2_grpc.ExpansionServiceServicer):
    36    def __init__(self, options=None):
    37      self._options = options or beam_pipeline.PipelineOptions(
    38          environment_type=python_urns.EMBEDDED_PYTHON, sdk_location='container')
    39      self._default_environment = (
    40          portable_runner.PortableRunner._create_environment(self._options))
    41  
    42    def Expand(self, request, context=None):
    43      try:
    44        pipeline = beam_pipeline.Pipeline(options=self._options)
    45  
    46        def with_pipeline(component, pcoll_id=None):
    47          component.pipeline = pipeline
    48          if pcoll_id:
    49            component.producer, component.tag = producers[pcoll_id]
    50            # We need the lookup to resolve back to this id.
    51            context.pcollections._obj_to_id[component] = pcoll_id
    52          return component
    53  
    54        context = pipeline_context.PipelineContext(
    55            request.components,
    56            default_environment=self._default_environment,
    57            namespace=request.namespace)
    58        producers = {
    59            pcoll_id: (context.transforms.get_by_id(t_id), pcoll_tag)
    60            for t_id,
    61            t_proto in request.components.transforms.items() for pcoll_tag,
    62            pcoll_id in t_proto.outputs.items()
    63        }
    64        transform = with_pipeline(
    65            ptransform.PTransform.from_runner_api(request.transform, context))
    66        if len(request.output_coder_requests) == 1:
    67          output_coder = {
    68              k: context.element_type_from_coder_id(v)
    69              for k,
    70              v in request.output_coder_requests.items()
    71          }
    72          transform = transform.with_output_types(list(output_coder.values())[0])
    73        elif len(request.output_coder_requests) > 1:
    74          raise ValueError(
    75              'type annotation for multiple outputs is not allowed yet: %s' %
    76              request.output_coder_requests)
    77        inputs = transform._pvaluish_from_dict({
    78            tag:
    79            with_pipeline(context.pcollections.get_by_id(pcoll_id), pcoll_id)
    80            for tag,
    81            pcoll_id in request.transform.inputs.items()
    82        })
    83        if not inputs:
    84          inputs = pipeline
    85        with external.ExternalTransform.outer_namespace(request.namespace):
    86          result = pipeline.apply(
    87              transform, inputs, request.transform.unique_name)
    88        expanded_transform = pipeline._root_transform().parts[-1]
    89        # TODO(BEAM-1833): Use named outputs internally.
    90        if isinstance(result, dict):
    91          expanded_transform.outputs = result
    92        pipeline_proto = pipeline.to_runner_api(context=context)
    93        # TODO(BEAM-1833): Use named inputs internally.
    94        expanded_transform_id = context.transforms.get_id(expanded_transform)
    95        expanded_transform_proto = pipeline_proto.components.transforms.pop(
    96            expanded_transform_id)
    97        expanded_transform_proto.inputs.clear()
    98        expanded_transform_proto.inputs.update(request.transform.inputs)
    99        for transform_id in pipeline_proto.root_transform_ids:
   100          del pipeline_proto.components.transforms[transform_id]
   101        return beam_expansion_api_pb2.ExpansionResponse(
   102            components=pipeline_proto.components,
   103            transform=expanded_transform_proto,
   104            requirements=pipeline_proto.requirements)
   105  
   106      except Exception:  # pylint: disable=broad-except
   107        return beam_expansion_api_pb2.ExpansionResponse(
   108            error=traceback.format_exc())