github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/external.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Defines Transform whose expansion is implemented elsewhere."""
    19  # pytype: skip-file
    20  
    21  import contextlib
    22  import copy
    23  import functools
    24  import glob
    25  import logging
    26  import threading
    27  from collections import OrderedDict
    28  from collections import namedtuple
    29  from typing import Dict
    30  
    31  import grpc
    32  
    33  from apache_beam import pvalue
    34  from apache_beam.coders import RowCoder
    35  from apache_beam.portability import common_urns
    36  from apache_beam.portability.api import beam_artifact_api_pb2_grpc
    37  from apache_beam.portability.api import beam_expansion_api_pb2
    38  from apache_beam.portability.api import beam_expansion_api_pb2_grpc
    39  from apache_beam.portability.api import beam_runner_api_pb2
    40  from apache_beam.portability.api import external_transforms_pb2
    41  from apache_beam.runners import pipeline_context
    42  from apache_beam.runners.portability import artifact_service
    43  from apache_beam.transforms import ptransform
    44  from apache_beam.typehints import WithTypeHints
    45  from apache_beam.typehints import native_type_compatibility
    46  from apache_beam.typehints import row_type
    47  from apache_beam.typehints.schemas import named_fields_to_schema
    48  from apache_beam.typehints.schemas import named_tuple_from_schema
    49  from apache_beam.typehints.schemas import named_tuple_to_schema
    50  from apache_beam.typehints.trivial_inference import instance_to_type
    51  from apache_beam.typehints.typehints import Union
    52  from apache_beam.typehints.typehints import UnionConstraint
    53  from apache_beam.utils import subprocess_server
    54  
# Expansion service address that ExternalTransform falls back to when the
# caller does not supply an expansion service.
DEFAULT_EXPANSION_SERVICE = 'localhost:8097'
    56  
    57  
    58  def convert_to_typing_type(type_):
    59    if isinstance(type_, row_type.RowTypeConstraint):
    60      return named_tuple_from_schema(named_fields_to_schema(type_._fields))
    61    else:
    62      return native_type_compatibility.convert_to_typing_type(type_)
    63  
    64  
    65  def _is_optional_or_none(typehint):
    66    return (
    67        type(None) in typehint.union_types if isinstance(
    68            typehint, UnionConstraint) else typehint is type(None))
    69  
    70  
    71  def _strip_optional(typehint):
    72    if not _is_optional_or_none(typehint):
    73      return typehint
    74    new_types = typehint.union_types.difference({type(None)})
    75    if len(new_types) == 1:
    76      return list(new_types)[0]
    77    return Union[new_types]
    78  
    79  
    80  def iter_urns(coder, context=None):
    81    yield coder.to_runner_api_parameter(context)[0]
    82    for child in coder._get_component_coders():
    83      for urn in iter_urns(child, context):
    84        yield urn
    85  
    86  
    87  class PayloadBuilder(object):
    88    """
    89    Abstract base class for building payloads to pass to ExternalTransform.
    90    """
    91    def build(self):
    92      """
    93      :return: ExternalConfigurationPayload
    94      """
    95      raise NotImplementedError
    96  
    97    def payload(self):
    98      """
    99      The serialized ExternalConfigurationPayload
   100  
   101      :return: bytes
   102      """
   103      return self.build().SerializeToString()
   104  
   105    def _get_schema_proto_and_payload(self, **kwargs):
   106      named_fields = []
   107      fields_to_values = OrderedDict()
   108  
   109      for key, value in kwargs.items():
   110        if not key:
   111          raise ValueError('Parameter name cannot be empty')
   112        if value is None:
   113          raise ValueError(
   114              'Received value None for key %s. None values are currently not '
   115              'supported' % key)
   116        named_fields.append(
   117            (key, convert_to_typing_type(instance_to_type(value))))
   118        fields_to_values[key] = value
   119  
   120      schema_proto = named_fields_to_schema(named_fields)
   121      row = named_tuple_from_schema(schema_proto)(**fields_to_values)
   122      schema = named_tuple_to_schema(type(row))
   123  
   124      payload = RowCoder(schema).encode(row)
   125      return (schema_proto, payload)
   126  
   127  
   128  class SchemaBasedPayloadBuilder(PayloadBuilder):
   129    """
   130    Base class for building payloads based on a schema that provides
   131    type information for each configuration value to encode.
   132    """
   133    def _get_named_tuple_instance(self):
   134      raise NotImplementedError()
   135  
   136    def build(self):
   137      row = self._get_named_tuple_instance()
   138      schema = named_tuple_to_schema(type(row))
   139      return external_transforms_pb2.ExternalConfigurationPayload(
   140          schema=schema, payload=RowCoder(schema).encode(row))
   141  
   142  
   143  class ImplicitSchemaPayloadBuilder(SchemaBasedPayloadBuilder):
   144    """
   145    Build a payload that generates a schema from the provided values.
   146    """
   147    def __init__(self, values):
   148      self._values = values
   149  
   150    def _get_named_tuple_instance(self):
   151      # omit fields with value=None since we can't infer their type
   152      values = {
   153          key: value
   154          for key, value in self._values.items() if value is not None
   155      }
   156  
   157      schema = named_fields_to_schema([
   158          (key, convert_to_typing_type(instance_to_type(value))) for key,
   159          value in values.items()
   160      ])
   161      return named_tuple_from_schema(schema)(**values)
   162  
   163  
   164  class NamedTupleBasedPayloadBuilder(SchemaBasedPayloadBuilder):
   165    """
   166    Build a payload based on a NamedTuple schema.
   167    """
   168    def __init__(self, tuple_instance):
   169      """
   170      :param tuple_instance: an instance of a typing.NamedTuple
   171      """
   172      super().__init__()
   173      self._tuple_instance = tuple_instance
   174  
   175    def _get_named_tuple_instance(self):
   176      return self._tuple_instance
   177  
   178  
   179  class SchemaTransformPayloadBuilder(PayloadBuilder):
   180    def __init__(self, identifier, **kwargs):
   181      self._identifier = identifier
   182      self._kwargs = kwargs
   183  
   184    def build(self):
   185      schema_proto, payload = self._get_schema_proto_and_payload(**self._kwargs)
   186      payload = external_transforms_pb2.SchemaTransformPayload(
   187          identifier=self._identifier,
   188          configuration_schema=schema_proto,
   189          configuration_row=payload)
   190      return payload
   191  
   192  
   193  class JavaClassLookupPayloadBuilder(PayloadBuilder):
   194    """
   195    Builds a payload for directly instantiating a Java transform using a
   196    constructor and builder methods.
   197    """
   198  
   199    IGNORED_ARG_FORMAT = 'ignore%d'
   200  
   201    def __init__(self, class_name):
   202      """
   203      :param class_name: fully qualified name of the transform class.
   204      """
   205      if not class_name:
   206        raise ValueError('Class name must not be empty')
   207  
   208      self._class_name = class_name
   209      self._constructor_method = None
   210      self._constructor_param_args = None
   211      self._constructor_param_kwargs = None
   212      self._builder_methods_and_params = OrderedDict()
   213  
   214    def _args_to_named_fields(self, args):
   215      next_field_id = 0
   216      named_fields = OrderedDict()
   217      for value in args:
   218        if value is None:
   219          raise ValueError(
   220              'Received value None. None values are currently not supported')
   221        named_fields[(
   222            JavaClassLookupPayloadBuilder.IGNORED_ARG_FORMAT %
   223            next_field_id)] = value
   224        next_field_id += 1
   225      return named_fields
   226  
   227    def build(self):
   228      all_constructor_param_kwargs = self._args_to_named_fields(
   229          self._constructor_param_args)
   230      if self._constructor_param_kwargs:
   231        all_constructor_param_kwargs.update(self._constructor_param_kwargs)
   232      constructor_schema, constructor_payload = (
   233        self._get_schema_proto_and_payload(**all_constructor_param_kwargs))
   234      payload = external_transforms_pb2.JavaClassLookupPayload(
   235          class_name=self._class_name,
   236          constructor_schema=constructor_schema,
   237          constructor_payload=constructor_payload)
   238      if self._constructor_method:
   239        payload.constructor_method = self._constructor_method
   240  
   241      for builder_method_name, params in self._builder_methods_and_params.items():
   242        builder_method_args, builder_method_kwargs = params
   243        all_builder_method_kwargs = self._args_to_named_fields(
   244            builder_method_args)
   245        if builder_method_kwargs:
   246          all_builder_method_kwargs.update(builder_method_kwargs)
   247        builder_method_schema, builder_method_payload = (
   248          self._get_schema_proto_and_payload(**all_builder_method_kwargs))
   249        builder_method = external_transforms_pb2.BuilderMethod(
   250            name=builder_method_name,
   251            schema=builder_method_schema,
   252            payload=builder_method_payload)
   253        builder_method.name = builder_method_name
   254        payload.builder_methods.append(builder_method)
   255      return payload
   256  
   257    def with_constructor(self, *args, **kwargs):
   258      """
   259      Specifies the Java constructor to use.
   260      Arguments provided using args and kwargs will be applied to the Java
   261      transform constructor in the specified order.
   262  
   263      :param args: parameter values of the constructor.
   264      :param kwargs: parameter names and values of the constructor.
   265      """
   266      if self._has_constructor():
   267        raise ValueError(
   268            'Constructor or constructor method can only be specified once')
   269  
   270      self._constructor_param_args = args
   271      self._constructor_param_kwargs = kwargs
   272  
   273    def with_constructor_method(self, method_name, *args, **kwargs):
   274      """
   275      Specifies the Java constructor method to use.
   276      Arguments provided using args and kwargs will be applied to the Java
   277      transform constructor method in the specified order.
   278  
   279      :param method_name: name of the constructor method.
   280      :param args: parameter values of the constructor method.
   281      :param kwargs: parameter names and values of the constructor method.
   282      """
   283      if self._has_constructor():
   284        raise ValueError(
   285            'Constructor or constructor method can only be specified once')
   286  
   287      self._constructor_method = method_name
   288      self._constructor_param_args = args
   289      self._constructor_param_kwargs = kwargs
   290  
   291    def add_builder_method(self, method_name, *args, **kwargs):
   292      """
   293      Specifies a Java builder method to be invoked after instantiating the Java
   294      transform class. Specified builder method will be applied in order.
   295      Arguments provided using args and kwargs will be applied to the Java
   296      transform builder method in the specified order.
   297  
   298      :param method_name: name of the builder method.
   299      :param args: parameter values of the builder method.
   300      :param kwargs:  parameter names and values of the builder method.
   301      """
   302      self._builder_methods_and_params[method_name] = (args, kwargs)
   303  
   304    def _has_constructor(self):
   305      return (
   306          self._constructor_method or self._constructor_param_args or
   307          self._constructor_param_kwargs)
   308  
   309  
   310  # Information regarding a SchemaTransform available in an external SDK.
   311  SchemaTransformsConfig = namedtuple(
   312      'SchemaTransformsConfig',
   313      ['identifier', 'configuration_schema', 'inputs', 'outputs'])
   314  
   315  
   316  class SchemaAwareExternalTransform(ptransform.PTransform):
   317    """A proxy transform for SchemaTransforms implemented in external SDKs.
   318  
   319    This allows Python pipelines to directly use existing SchemaTransforms
   320    available to the expansion service without adding additional code in external
   321    SDKs.
   322  
   323    :param identifier: unique identifier of the SchemaTransform.
   324    :param expansion_service: an expansion service to use. This should already be
   325        available and the Schema-aware transforms to be used must already be
   326        deployed.
   327    :param rearrange_based_on_discovery: if this flag is set, the input kwargs
   328        will be rearranged to match the order of fields in the external
   329        SchemaTransform configuration. A discovery call will be made to fetch
   330        the configuration.
   331    :param classpath: (Optional) A list paths to additional jars to place on the
   332        expansion service classpath.
   333    :kwargs: field name to value mapping for configuring the schema transform.
   334        keys map to the field names of the schema of the SchemaTransform
   335        (in-order).
   336    """
   337    def __init__(
   338        self,
   339        identifier,
   340        expansion_service,
   341        rearrange_based_on_discovery=False,
   342        classpath=None,
   343        **kwargs):
   344      self._expansion_service = expansion_service
   345      self._kwargs = kwargs
   346      self._classpath = classpath
   347  
   348      _kwargs = kwargs
   349      if rearrange_based_on_discovery:
   350        _kwargs = self._rearrange_kwargs(identifier)
   351  
   352      self._payload_builder = SchemaTransformPayloadBuilder(identifier, **_kwargs)
   353  
   354    def _rearrange_kwargs(self, identifier):
   355      # discover and fetch the external SchemaTransform configuration then
   356      # use it to build an appropriate payload
   357      schematransform_config = SchemaAwareExternalTransform.discover_config(
   358          self._expansion_service, identifier)
   359  
   360      external_config_fields = schematransform_config.configuration_schema._fields
   361      ordered_kwargs = OrderedDict()
   362      missing_fields = []
   363  
   364      for field in external_config_fields:
   365        if field not in self._kwargs:
   366          missing_fields.append(field)
   367        else:
   368          ordered_kwargs[field] = self._kwargs[field]
   369  
   370      extra_fields = list(set(self._kwargs.keys()) - set(external_config_fields))
   371      if missing_fields:
   372        raise ValueError(
   373            'Input parameters are missing the following SchemaTransform config '
   374            'fields: %s' % missing_fields)
   375      elif extra_fields:
   376        raise ValueError(
   377            'Input parameters include the following extra fields that are not '
   378            'found in the SchemaTransform config schema: %s' % extra_fields)
   379  
   380      return ordered_kwargs
   381  
   382    def expand(self, pcolls):
   383      # Expand the transform using the expansion service.
   384      return pcolls | ExternalTransform(
   385          common_urns.schematransform_based_expand.urn,
   386          self._payload_builder,
   387          self._expansion_service)
   388  
   389    @staticmethod
   390    def discover(expansion_service):
   391      """Discover all SchemaTransforms available to the given expansion service.
   392  
   393      :return: a list of SchemaTransformsConfigs that represent the discovered
   394          SchemaTransforms.
   395      """
   396  
   397      with ExternalTransform.service(expansion_service) as service:
   398        discover_response = service.DiscoverSchemaTransform(
   399            beam_expansion_api_pb2.DiscoverSchemaTransformRequest())
   400  
   401        for identifier in discover_response.schema_transform_configs:
   402          proto_config = discover_response.schema_transform_configs[identifier]
   403          schema = named_tuple_from_schema(proto_config.config_schema)
   404  
   405          yield SchemaTransformsConfig(
   406              identifier=identifier,
   407              configuration_schema=schema,
   408              inputs=proto_config.input_pcollection_names,
   409              outputs=proto_config.output_pcollection_names)
   410  
   411    @staticmethod
   412    def discover_config(expansion_service, name):
   413      """Discover one SchemaTransform by name in the given expansion service.
   414  
   415      :return: one SchemaTransformsConfig that represents the discovered
   416          SchemaTransform
   417  
   418      :raises:
   419        ValueError: if more than one SchemaTransform is discovered, or if none
   420        are discovered
   421      """
   422  
   423      schematransforms = SchemaAwareExternalTransform.discover(expansion_service)
   424      matched = []
   425  
   426      for st in schematransforms:
   427        if name in st.identifier:
   428          matched.append(st)
   429  
   430      if not matched:
   431        raise ValueError(
   432            "Did not discover any SchemaTransforms resembling the name '%s'" %
   433            name)
   434      elif len(matched) > 1:
   435        raise ValueError(
   436            "Found multiple SchemaTransforms with the name '%s':\n%s\n" %
   437            (name, [st.identifier for st in matched]))
   438  
   439      return matched[0]
   440  
   441  
   442  class JavaExternalTransform(ptransform.PTransform):
   443    """A proxy for Java-implemented external transforms.
   444  
   445    One builds these transforms just as one would in Java, e.g.::
   446  
   447        transform = JavaExternalTransform('fully.qualified.ClassName'
   448            )(contructorArg, ... ).builderMethod(...)
   449  
   450    or::
   451  
   452        JavaExternalTransform('fully.qualified.ClassName').staticConstructor(
   453            ...).builderMethod1(...).builderMethod2(...)
   454  
   455    :param class_name: fully qualified name of the java class
   456    :param expansion_service: (Optional) an expansion service to use.  If none is
   457        provided, a default expansion service will be started.
   458    :param classpath: (Optional) A list paths to additional jars to place on the
   459        expansion service classpath.
   460    """
   461    def __init__(self, class_name, expansion_service=None, classpath=None):
   462      if expansion_service and classpath:
   463        raise ValueError(
   464            f'Only one of expansion_service ({expansion_service}) '
   465            f'or classpath ({classpath}) may be provided.')
   466      self._payload_builder = JavaClassLookupPayloadBuilder(class_name)
   467      self._classpath = classpath
   468      self._expansion_service = expansion_service
   469      # Beam explicitly looks for following attributes. Hence adding
   470      # 'None' values here to prevent '__getattr__' from being called.
   471      self.inputs = None
   472      self._fn_api_payload = None
   473  
   474    def __call__(self, *args, **kwargs):
   475      self._payload_builder.with_constructor(*args, **kwargs)
   476      return self
   477  
   478    def __getattr__(self, name):
   479      # Don't try to emulate special methods.
   480      if name.startswith('__') and name.endswith('__'):
   481        return super().__getattr__(name)
   482      else:
   483        return self[name]
   484  
   485    def __getitem__(self, name):
   486      # Use directly for keywords or attribute conflicts.
   487      def construct(*args, **kwargs):
   488        if self._payload_builder._has_constructor():
   489          builder_method = self._payload_builder.add_builder_method
   490        else:
   491          builder_method = self._payload_builder.with_constructor_method
   492        builder_method(name, *args, **kwargs)
   493        return self
   494  
   495      return construct
   496  
   497    def expand(self, pcolls):
   498      if self._expansion_service is None:
   499        self._expansion_service = BeamJarExpansionService(
   500            ':sdks:java:expansion-service:app:shadowJar',
   501            extra_args=['{{PORT}}', '--javaClassLookupAllowlistFile=*'],
   502            classpath=self._classpath)
   503      return pcolls | ExternalTransform(
   504          common_urns.java_class_lookup.urn,
   505          self._payload_builder,
   506          self._expansion_service)
   507  
   508  
   509  class AnnotationBasedPayloadBuilder(SchemaBasedPayloadBuilder):
   510    """
   511    Build a payload based on an external transform's type annotations.
   512    """
   513    def __init__(self, transform, **values):
   514      """
   515      :param transform: a PTransform instance or class. type annotations will
   516                        be gathered from its __init__ method
   517      :param values: values to encode
   518      """
   519      self._transform = transform
   520      self._values = values
   521  
   522    def _get_named_tuple_instance(self):
   523      schema = named_fields_to_schema([
   524          (k, convert_to_typing_type(v)) for k,
   525          v in self._transform.__init__.__annotations__.items()
   526          if k in self._values
   527      ])
   528      return named_tuple_from_schema(schema)(**self._values)
   529  
   530  
   531  class DataclassBasedPayloadBuilder(SchemaBasedPayloadBuilder):
   532    """
   533    Build a payload based on an external transform that uses dataclasses.
   534    """
   535    def __init__(self, transform):
   536      """
   537      :param transform: a dataclass-decorated PTransform instance from which to
   538                        gather type annotations and values
   539      """
   540      self._transform = transform
   541  
   542    def _get_named_tuple_instance(self):
   543      import dataclasses
   544      schema = named_fields_to_schema([
   545          (field.name, convert_to_typing_type(field.type))
   546          for field in dataclasses.fields(self._transform)
   547      ])
   548      return named_tuple_from_schema(schema)(
   549          **dataclasses.asdict(self._transform))
   550  
   551  
   552  class ExternalTransform(ptransform.PTransform):
   553    """
   554      External provides a cross-language transform via expansion services in
   555      foreign SDKs.
   556    """
   557    _namespace_counter = 0
   558  
   559    # Variable name _namespace conflicts with DisplayData._namespace so we use
   560    # name _external_namespace here.
   561    _external_namespace = threading.local()
   562  
   563    _IMPULSE_PREFIX = 'impulse'
   564  
   565    def __init__(self, urn, payload, expansion_service=None):
   566      """Wrapper for an external transform with the given urn and payload.
   567  
   568      :param urn: the unique beam identifier for this transform
   569      :param payload: the payload, either as a byte string or a PayloadBuilder
   570      :param expansion_service: an expansion service implementing the beam
   571          ExpansionService protocol, either as an object with an Expand method
   572          or an address (as a str) to a grpc server that provides this method.
   573      """
   574      expansion_service = expansion_service or DEFAULT_EXPANSION_SERVICE
   575      if not urn and isinstance(payload, JavaClassLookupPayloadBuilder):
   576        urn = common_urns.java_class_lookup.urn
   577      self._urn = urn
   578      self._payload = (
   579          payload.payload() if isinstance(payload, PayloadBuilder) else payload)
   580      self._expansion_service = expansion_service
   581      self._external_namespace = self._fresh_namespace()
   582      self._inputs = {}  # type: Dict[str, pvalue.PCollection]
   583      self._outputs = {}  # type: Dict[str, pvalue.PCollection]
   584  
   585    def with_output_types(self, *args, **kwargs):
   586      return WithTypeHints.with_output_types(self, *args, **kwargs)
   587  
   588    def replace_named_inputs(self, named_inputs):
   589      self._inputs = named_inputs
   590  
   591    def replace_named_outputs(self, named_outputs):
   592      self._outputs = named_outputs
   593  
   594    def __post_init__(self, expansion_service):
   595      """
   596      This will only be invoked if ExternalTransform is used as a base class
   597      for a class decorated with dataclasses.dataclass
   598      """
   599      ExternalTransform.__init__(
   600          self, self.URN, DataclassBasedPayloadBuilder(self), expansion_service)
   601  
   602    def default_label(self):
   603      return '%s(%s)' % (self.__class__.__name__, self._urn)
   604  
   605    @classmethod
   606    def get_local_namespace(cls):
   607      return getattr(cls._external_namespace, 'value', 'external')
   608  
   609    @classmethod
   610    @contextlib.contextmanager
   611    def outer_namespace(cls, namespace):
   612      prev = cls.get_local_namespace()
   613      cls._external_namespace.value = namespace
   614      yield
   615      cls._external_namespace.value = prev
   616  
   617    @classmethod
   618    def _fresh_namespace(cls):
   619      # type: () -> str
   620      ExternalTransform._namespace_counter += 1
   621      return '%s_%d' % (cls.get_local_namespace(), cls._namespace_counter)
   622  
  def expand(self, pvalueish):
    # type: (pvalue.PCollection) -> pvalue.PCollection
    """Expands this transform by sending a request to the expansion service.

    The inputs are normalized into a tag -> PCollection dict, an
    ExpansionRequest describing this transform is built and sent, and the
    response is recorded on the instance (``_expanded_components``,
    ``_expanded_transform``, ``_expanded_requirements``, ``_outputs``).

    :param pvalueish: a PBegin, a single value, a list/tuple of values
        (tagged by position) or a dict of tag to value.
    :return: the single output PCollection when there is exactly one output,
        otherwise the dict of tag to output PCollection.
    :raises RuntimeError: if the expansion response carries an error.
    """
    # Normalize the input into a dict keyed by tag.
    if isinstance(pvalueish, pvalue.PBegin):
      self._inputs = {}
    elif isinstance(pvalueish, (list, tuple)):
      self._inputs = {str(ix): pvalue for ix, pvalue in enumerate(pvalueish)}
    elif isinstance(pvalueish, dict):
      self._inputs = pvalueish
    else:
      self._inputs = {'input': pvalueish}
    # Take the pipeline from the first input, or from pvalueish itself when
    # there are no inputs.
    pipeline = (
        next(iter(self._inputs.values())).pipeline
        if self._inputs else pvalueish.pipeline)
    context = pipeline_context.PipelineContext(
        component_id_map=pipeline.component_id_map)
    transform_proto = beam_runner_api_pb2.PTransform(
        unique_name=pipeline._current_transform().full_label,
        spec=beam_runner_api_pb2.FunctionSpec(
            urn=self._urn, payload=self._payload))
    for tag, pcoll in self._inputs.items():
      transform_proto.inputs[tag] = context.pcollections.get_id(pcoll)
      # Conversion to/from proto assumes producers.
      # TODO: Possibly loosen this.
      # Register a placeholder impulse transform as the producer of each
      # input PCollection.
      context.transforms.put_proto(
          '%s_%s' % (self._IMPULSE_PREFIX, tag),
          beam_runner_api_pb2.PTransform(
              unique_name='%s_%s' % (self._IMPULSE_PREFIX, tag),
              spec=beam_runner_api_pb2.FunctionSpec(
                  urn=common_urns.primitives.IMPULSE.urn),
              outputs={'out': transform_proto.inputs[tag]}))
    # Request output coders only when output type hints were given:
    # positional hints are keyed by their index, keyword hints by their name.
    output_coders = None
    if self._type_hints.output_types:
      if self._type_hints.output_types[0]:
        output_coders = dict(
            (str(k), context.coder_id_from_element_type(v))
            for (k, v) in enumerate(self._type_hints.output_types[0]))
      elif self._type_hints.output_types[1]:
        output_coders = {
            k: context.coder_id_from_element_type(v)
            for (k, v) in self._type_hints.output_types[1].items()
        }
    components = context.to_runner_api()
    request = beam_expansion_api_pb2.ExpansionRequest(
        components=components,
        namespace=self._external_namespace,
        transform=transform_proto,
        output_coder_requests=output_coders)

    with ExternalTransform.service(self._expansion_service) as service:
      response = service.Expand(request)
      if response.error:
        raise RuntimeError(response.error)
      self._expanded_components = response.components
      # Artifacts are resolved while the service is still open, and only when
      # some environment declares dependencies.
      if any(env.dependencies
             for env in self._expanded_components.environments.values()):
        self._expanded_components = self._resolve_artifacts(
            self._expanded_components,
            service.artifact_service(),
            pipeline.local_tempdir)

    self._expanded_transform = response.transform
    self._expanded_requirements = response.requirements
    result_context = pipeline_context.PipelineContext(response.components)

    def fix_output(pcoll, tag):
      # Attach the output PCollection to the current pipeline and record its
      # tag.
      pcoll.pipeline = pipeline
      pcoll.tag = tag
      return pcoll

    self._outputs = {
        tag: fix_output(result_context.pcollections.get_by_id(pcoll_id), tag)
        for tag,
        pcoll_id in self._expanded_transform.outputs.items()
    }

    return self._output_to_pvalueish(self._outputs)
   699  
   700    @staticmethod
   701    @contextlib.contextmanager
   702    def service(expansion_service):
   703      if isinstance(expansion_service, str):
   704        channel_options = [("grpc.max_receive_message_length", -1),
   705                           ("grpc.max_send_message_length", -1)]
   706        if hasattr(grpc, 'local_channel_credentials'):
   707          # Some environments may not support insecure channels. Hence use a
   708          # secure channel with local credentials here.
   709          # TODO: update this to support secure non-local channels.
   710          channel_factory_fn = functools.partial(
   711              grpc.secure_channel,
   712              expansion_service,
   713              grpc.local_channel_credentials(),
   714              options=channel_options)
   715        else:
   716          # local_channel_credentials is an experimental API which is unsupported
   717          # by older versions of grpc which may be pulled in due to other project
   718          # dependencies.
   719          channel_factory_fn = functools.partial(
   720              grpc.insecure_channel, expansion_service, options=channel_options)
   721        with channel_factory_fn() as channel:
   722          yield ExpansionAndArtifactRetrievalStub(channel)
   723      elif hasattr(expansion_service, 'Expand'):
   724        yield expansion_service
   725      else:
   726        with expansion_service as stub:
   727          yield stub
   728  
   729    def _resolve_artifacts(self, components, service, dest):
   730      for env in components.environments.values():
   731        if env.dependencies:
   732          resolved = list(
   733              artifact_service.resolve_artifacts(env.dependencies, service, dest))
   734          del env.dependencies[:]
   735          env.dependencies.extend(resolved)
   736      return components
   737  
   738    def _output_to_pvalueish(self, output_dict):
   739      if len(output_dict) == 1:
   740        return next(iter(output_dict.values()))
   741      else:
   742        return output_dict
   743  
   744    def to_runner_api_transform(self, context, full_label):
   745      pcoll_renames = {}
   746      renamed_tag_seen = False
   747      for tag, pcoll in self._inputs.items():
   748        if tag not in self._expanded_transform.inputs:
   749          if renamed_tag_seen:
   750            raise RuntimeError(
   751                'Ambiguity due to non-preserved tags: %s vs %s' % (
   752                    sorted(self._expanded_transform.inputs.keys()),
   753                    sorted(self._inputs.keys())))
   754          else:
   755            renamed_tag_seen = True
   756            tag, = self._expanded_transform.inputs.keys()
   757        pcoll_renames[self._expanded_transform.inputs[tag]] = (
   758            context.pcollections.get_id(pcoll))
   759      for tag, pcoll in self._outputs.items():
   760        pcoll_renames[self._expanded_transform.outputs[tag]] = (
   761            context.pcollections.get_id(pcoll))
   762  
   763      def _equivalent(coder1, coder2):
   764        return coder1 == coder2 or _normalize(coder1) == _normalize(coder2)
   765  
   766      def _normalize(coder_proto):
   767        normalized = copy.copy(coder_proto)
   768        normalized.spec.environment_id = ''
   769        # TODO(robertwb): Normalize components as well.
   770        return normalized
   771  
   772      for id, proto in self._expanded_components.coders.items():
   773        if id.startswith(self._external_namespace):
   774          context.coders.put_proto(id, proto)
   775        elif id in context.coders:
   776          if not _equivalent(context.coders._id_to_proto[id], proto):
   777            raise RuntimeError(
   778                'Re-used coder id: %s\n%s\n%s' %
   779                (id, context.coders._id_to_proto[id], proto))
   780        else:
   781          context.coders.put_proto(id, proto)
   782      for id, proto in self._expanded_components.windowing_strategies.items():
   783        if id.startswith(self._external_namespace):
   784          context.windowing_strategies.put_proto(id, proto)
   785      for id, proto in self._expanded_components.environments.items():
   786        if id.startswith(self._external_namespace):
   787          context.environments.put_proto(id, proto)
   788      for id, proto in self._expanded_components.pcollections.items():
   789        id = pcoll_renames.get(id, id)
   790        if id not in context.pcollections._id_to_obj.keys():
   791          context.pcollections.put_proto(id, proto)
   792  
   793      for id, proto in self._expanded_components.transforms.items():
   794        if id.startswith(self._IMPULSE_PREFIX):
   795          # Our fake inputs.
   796          continue
   797        assert id.startswith(
   798            self._external_namespace), (id, self._external_namespace)
   799        new_proto = beam_runner_api_pb2.PTransform(
   800            unique_name=proto.unique_name,
   801            # If URN is not set this is an empty spec.
   802            spec=proto.spec if proto.spec.urn else None,
   803            subtransforms=proto.subtransforms,
   804            inputs={
   805                tag: pcoll_renames.get(pcoll, pcoll)
   806                for tag,
   807                pcoll in proto.inputs.items()
   808            },
   809            outputs={
   810                tag: pcoll_renames.get(pcoll, pcoll)
   811                for tag,
   812                pcoll in proto.outputs.items()
   813            },
   814            display_data=proto.display_data,
   815            environment_id=proto.environment_id)
   816        context.transforms.put_proto(id, new_proto)
   817  
   818      for requirement in self._expanded_requirements:
   819        context.add_requirement(requirement)
   820  
   821      return beam_runner_api_pb2.PTransform(
   822          unique_name=full_label,
   823          spec=self._expanded_transform.spec,
   824          subtransforms=self._expanded_transform.subtransforms,
   825          inputs={
   826              tag: pcoll_renames.get(pcoll, pcoll)
   827              for tag,
   828              pcoll in self._expanded_transform.inputs.items()
   829          },
   830          outputs={
   831              tag: pcoll_renames.get(pcoll, pcoll)
   832              for tag,
   833              pcoll in self._expanded_transform.outputs.items()
   834          },
   835          environment_id=self._expanded_transform.environment_id)
   836  
   837  
   838  class ExpansionAndArtifactRetrievalStub(
   839      beam_expansion_api_pb2_grpc.ExpansionServiceStub):
   840    def __init__(self, channel, **kwargs):
   841      self._channel = channel
   842      self._kwargs = kwargs
   843      super().__init__(channel, **kwargs)
   844  
   845    def artifact_service(self):
   846      return beam_artifact_api_pb2_grpc.ArtifactRetrievalServiceStub(
   847          self._channel, **self._kwargs)
   848  
   849    def ready(self, timeout_sec):
   850      grpc.channel_ready_future(self._channel).result(timeout=timeout_sec)
   851  
   852  
   853  class JavaJarExpansionService(object):
   854    """An expansion service based on an Java Jar file.
   855  
   856    This can be passed into an ExternalTransform as the expansion_service
   857    argument which will spawn a subprocess using this jar to expand the
   858    transform.
   859  
   860    Args:
   861      path_to_jar: the path to a locally available executable jar file to be used
   862        to start up the expansion service.
   863      extra_args: arguments to be provided when starting up the
   864        expansion service using the jar file. These arguments will replace the
   865        default arguments.
   866      classpath: Additional dependencies to be added to the classpath.
   867      append_args: arguments to be provided when starting up the
   868        expansion service using the jar file. These arguments will be appended to
   869        the default arguments.
   870    """
   871    def __init__(
   872        self, path_to_jar, extra_args=None, classpath=None, append_args=None):
   873      if extra_args and append_args:
   874        raise ValueError('Only one of extra_args or append_args may be provided')
   875      self._path_to_jar = path_to_jar
   876      self._extra_args = extra_args
   877      self._classpath = classpath or []
   878      self._service_count = 0
   879      self._append_args = append_args or []
   880  
   881    @staticmethod
   882    def _expand_jars(jar):
   883      if glob.glob(jar):
   884        return glob.glob(jar)
   885      elif isinstance(jar, str) and (jar.startswith('http://') or
   886                                     jar.startswith('https://')):
   887        return [subprocess_server.JavaJarServer.local_jar(jar)]
   888      else:
   889        # If the input JAR is not a local glob, nor an http/https URL, then
   890        # we assume that it's a gradle-style Java artifact in Maven Central,
   891        # in the form group:artifact:version, so we attempt to parse that way.
   892        try:
   893          group_id, artifact_id, version = jar.split(':')
   894        except ValueError:
   895          # If we are not able to find a JAR, nor a JAR artifact, nor a URL for
   896          # a JAR path, we still choose to include it in the path.
   897          logging.warning('Unable to parse %s into group:artifact:version.', jar)
   898          return [jar]
   899        path = subprocess_server.JavaJarServer.local_jar(
   900            subprocess_server.JavaJarServer.path_to_maven_jar(
   901                artifact_id, group_id, version))
   902        return [path]
   903  
   904    def _default_args(self):
   905      """Default arguments to be used by `JavaJarExpansionService`."""
   906  
   907      to_stage = ','.join([self._path_to_jar] + sum((
   908          JavaJarExpansionService._expand_jars(jar)
   909          for jar in self._classpath or []), []))
   910      return ['{{PORT}}', f'--filesToStage={to_stage}']
   911  
   912    def __enter__(self):
   913      if self._service_count == 0:
   914        self._path_to_jar = subprocess_server.JavaJarServer.local_jar(
   915            self._path_to_jar)
   916        if self._extra_args is None:
   917          self._extra_args = self._default_args() + self._append_args
   918        # Consider memoizing these servers (with some timeout).
   919        logging.info(
   920            'Starting a JAR-based expansion service from JAR %s ' + (
   921                'and with classpath: %s' %
   922                self._classpath if self._classpath else ''),
   923            self._path_to_jar)
   924        classpath_urls = [
   925            subprocess_server.JavaJarServer.local_jar(path)
   926            for jar in self._classpath
   927            for path in JavaJarExpansionService._expand_jars(jar)
   928        ]
   929        self._service_provider = subprocess_server.JavaJarServer(
   930            ExpansionAndArtifactRetrievalStub,
   931            self._path_to_jar,
   932            self._extra_args,
   933            classpath=classpath_urls)
   934        self._service = self._service_provider.__enter__()
   935      self._service_count += 1
   936      return self._service
   937  
   938    def __exit__(self, *args):
   939      self._service_count -= 1
   940      if self._service_count == 0:
   941        self._service_provider.__exit__(*args)
   942  
   943  
   944  class BeamJarExpansionService(JavaJarExpansionService):
   945    """An expansion service based on an Beam Java Jar file.
   946  
   947    Attempts to use a locally-built copy of the jar based on the gradle target,
   948    if it exists, otherwise attempts to download and cache the released artifact
   949    corresponding to this version of Beam from the apache maven repository.
   950  
   951    Args:
   952      gradle_target: Beam Gradle target for building an executable jar which will
   953        be used to start the expansion service.
   954      extra_args: arguments to be provided when starting up the
   955        expansion service using the jar file. These arguments will replace the
   956        default arguments.
   957      gradle_appendix: Gradle appendix of the artifact.
   958      classpath: Additional dependencies to be added to the classpath.
   959      append_args: arguments to be provided when starting up the
   960        expansion service using the jar file. These arguments will be appended to
   961        the default arguments.
   962    """
   963    def __init__(
   964        self,
   965        gradle_target,
   966        extra_args=None,
   967        gradle_appendix=None,
   968        classpath=None,
   969        append_args=None):
   970      path_to_jar = subprocess_server.JavaJarServer.path_to_beam_jar(
   971          gradle_target, gradle_appendix)
   972      super().__init__(
   973          path_to_jar, extra_args, classpath=classpath, append_args=append_args)
   974  
   975  
   976  def memoize(func):
   977    cache = {}
   978  
   979    def wrapper(*args):
   980      if args not in cache:
   981        cache[args] = func(*args)
   982      return cache[args]
   983  
   984    return wrapper