github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/yaml/yaml_provider.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """This module defines Providers usable from yaml, which is a specification
    19  for where to find and how to invoke services that vend implementations of
    20  various PTransforms."""
    21  
import collections
import hashlib
import json
import os
import shutil
import subprocess
import sys
import uuid
from typing import Any
from typing import Iterable
from typing import Mapping

import yaml
from yaml.loader import SafeLoader

import apache_beam as beam
import apache_beam.dataframe.io
import apache_beam.io
import apache_beam.transforms.util
from apache_beam.portability.api import schema_pb2
from apache_beam.transforms import external
from apache_beam.transforms import window
from apache_beam.transforms.fully_qualified_named_transform import FullyQualifiedNamedTransform
from apache_beam.typehints import schemas
from apache_beam.typehints import trivial_inference
from apache_beam.utils import python_callable
from apache_beam.utils import subprocess_server
from apache_beam.version import __version__ as beam_version
    49  
    50  
    51  class Provider:
    52    """Maps transform types names and args to concrete PTransform instances."""
    53    def available(self) -> bool:
    54      """Returns whether this provider is available to use in this environment."""
    55      raise NotImplementedError(type(self))
    56  
    57    def provided_transforms(self) -> Iterable[str]:
    58      """Returns a list of transform type names this provider can handle."""
    59      raise NotImplementedError(type(self))
    60  
    61    def create_transform(
    62        self, typ: str, args: Mapping[str, Any]) -> beam.PTransform:
    63      """Creates a PTransform instance for the given transform type and arguments.
    64      """
    65      raise NotImplementedError(type(self))
    66  
    67  
    68  def as_provider(name, provider_or_constructor):
    69    if isinstance(provider_or_constructor, Provider):
    70      return provider_or_constructor
    71    else:
    72      return InlineProvider({name: provider_or_constructor})
    73  
    74  
    75  def as_provider_list(name, lst):
    76    if not isinstance(lst, list):
    77      return as_provider_list(name, [lst])
    78    return [as_provider(name, x) for x in lst]
    79  
    80  
    81  class ExternalProvider(Provider):
    82    """A Provider implemented via the cross language transform service."""
    83    def __init__(self, urns, service):
    84      self._urns = urns
    85      self._service = service
    86      self._schema_transforms = None
    87  
    88    def provided_transforms(self):
    89      return self._urns.keys()
    90  
    91    def create_transform(self, type, args):
    92      if callable(self._service):
    93        self._service = self._service()
    94      if self._schema_transforms is None:
    95        try:
    96          self._schema_transforms = [
    97              config.identifier
    98              for config in external.SchemaAwareExternalTransform.discover(
    99                  self._service)
   100          ]
   101        except Exception:
   102          self._schema_transforms = []
   103      urn = self._urns[type]
   104      if urn in self._schema_transforms:
   105        return external.SchemaAwareExternalTransform(urn, self._service, **args)
   106      else:
   107        return type >> self.create_external_transform(urn, args)
   108  
   109    def create_external_transform(self, urn, args):
   110      return external.ExternalTransform(
   111          urn,
   112          external.ImplicitSchemaPayloadBuilder(args).payload(),
   113          self._service)
   114  
   115    @staticmethod
   116    def provider_from_spec(spec):
   117      urns = spec['transforms']
   118      type = spec['type']
   119      if spec.get('version', None) == 'BEAM_VERSION':
   120        spec['version'] = beam_version
   121      if type == 'javaJar':
   122        return ExternalJavaProvider(urns, lambda: spec['jar'])
   123      elif type == 'mavenJar':
   124        return ExternalJavaProvider(
   125            urns,
   126            lambda: subprocess_server.JavaJarServer.path_to_maven_jar(
   127                **{
   128                    key: value
   129                    for (key, value) in spec.items() if key in [
   130                        'artifact_id',
   131                        'group_id',
   132                        'version',
   133                        'repository',
   134                        'classifier',
   135                        'appendix'
   136                    ]
   137                }))
   138      elif type == 'beamJar':
   139        return ExternalJavaProvider(
   140            urns,
   141            lambda: subprocess_server.JavaJarServer.path_to_beam_jar(
   142                **{
   143                    key: value
   144                    for (key, value) in spec.items() if key in
   145                    ['gradle_target', 'version', 'appendix', 'artifact_id']
   146                }))
   147      elif type == 'pythonPackage':
   148        return ExternalPythonProvider(urns, spec['packages'])
   149      elif type == 'remote':
   150        return RemoteProvider(spec['address'])
   151      elif type == 'docker':
   152        raise NotImplementedError()
   153      else:
   154        raise NotImplementedError(f'Unknown provider type: {type}')
   155  
   156  
   157  class RemoteProvider(ExternalProvider):
   158    _is_available = None
   159  
   160    def available(self):
   161      if self._is_available is None:
   162        try:
   163          with external.ExternalTransform.service(self._service) as service:
   164            service.ready(1)
   165            self._is_available = True
   166        except Exception:
   167          self._is_available = False
   168      return self._is_available
   169  
   170  
   171  class ExternalJavaProvider(ExternalProvider):
   172    def __init__(self, urns, jar_provider):
   173      super().__init__(
   174          urns, lambda: external.JavaJarExpansionService(jar_provider()))
   175  
   176    def available(self):
   177      # pylint: disable=subprocess-run-check
   178      return subprocess.run(['which', 'java'],
   179                            capture_output=True).returncode == 0
   180  
   181  
   182  class ExternalPythonProvider(ExternalProvider):
   183    def __init__(self, urns, packages):
   184      super().__init__(urns, PypiExpansionService(packages))
   185  
   186    def available(self):
   187      return True  # If we're running this script, we have Python installed.
   188  
   189    def create_external_transform(self, urn, args):
   190      # Python transforms are "registered" by fully qualified name.
   191      return external.ExternalTransform(
   192          "beam:transforms:python:fully_qualified_named",
   193          external.ImplicitSchemaPayloadBuilder({
   194              'constructor': urn,
   195              'kwargs': args,
   196          }).payload(),
   197          self._service)
   198  
   199  
# This is needed because type inference can't handle *args, **kwargs forwarding.
# TODO(BEAM-24755): Add support for type inference through kwargs calls.
def fix_pycallable():
  """Monkey-patches Beam so PythonCallableWithSource cooperates with inference.

  Replaces, process-wide (by assigning module/class attributes):
    * PythonCallableWithSource.default_label -- derives the label from the
      callable's source text.
    * PythonCallableWithSource._argspec_fn -- a property exposing the wrapped
      callable.
    * trivial_inference.infer_return_type and
      apache_beam.transforms.util.fn_takes_side_inputs -- unwrap a
      PythonCallableWithSource to its underlying callable before delegating
      to the previously installed function.

  NOTE(review): not idempotent -- each call wraps whatever is currently
  installed again. The chained wrappers still delegate correctly, but this is
  presumably intended to be called once. Code that imported the patched
  functions by name before this runs keeps the unpatched versions.
  """
  from apache_beam.transforms.ptransform import label_from_callable

  def default_label(self):
    # Use the final source line as the label when it is unindented and
    # shorter than 72 characters; otherwise fall back to the label Beam
    # derives from the callable itself.
    src = self._source.strip()
    last_line = src.split('\n')[-1]
    # NOTE(review): raises IndexError if the stripped source is empty.
    if last_line[0] != ' ' and len(last_line) < 72:
      return last_line
    return label_from_callable(self._callable)

  def _argspec_fn(self):
    # Presumably lets argspec inspection see the wrapped callable's own
    # signature rather than the wrapper's *args/**kwargs forwarding.
    return self._callable

  python_callable.PythonCallableWithSource.default_label = default_label
  python_callable.PythonCallableWithSource._argspec_fn = property(_argspec_fn)

  # Capture the current implementation before replacing it so the wrapper
  # can delegate to it.
  original_infer_return_type = trivial_inference.infer_return_type

  def infer_return_type(fn, *args, **kwargs):
    if isinstance(fn, python_callable.PythonCallableWithSource):
      fn = fn._callable
    return original_infer_return_type(fn, *args, **kwargs)

  trivial_inference.infer_return_type = infer_return_type

  # Same capture-then-replace pattern for side-input detection.
  original_fn_takes_side_inputs = (
      apache_beam.transforms.util.fn_takes_side_inputs)

  def fn_takes_side_inputs(fn):
    if isinstance(fn, python_callable.PythonCallableWithSource):
      fn = fn._callable
    return original_fn_takes_side_inputs(fn)

  apache_beam.transforms.util.fn_takes_side_inputs = fn_takes_side_inputs
   236  
   237  
   238  class InlineProvider(Provider):
   239    def __init__(self, transform_factories):
   240      self._transform_factories = transform_factories
   241  
   242    def available(self):
   243      return True
   244  
   245    def provided_transforms(self):
   246      return self._transform_factories.keys()
   247  
   248    def create_transform(self, type, args):
   249      return self._transform_factories[type](**args)
   250  
   251    def to_json(self):
   252      return {'type': "InlineProvider"}
   253  
   254  
   255  PRIMITIVE_NAMES_TO_ATOMIC_TYPE = {
   256      py_type.__name__: schema_type
   257      for (py_type, schema_type) in schemas.PRIMITIVE_TO_ATOMIC_TYPE.items()
   258      if py_type.__module__ != 'typing'
   259  }
   260  
   261  
   262  def create_builtin_provider():
   263    def with_schema(**args):
   264      # TODO: This is preliminary.
   265      def parse_type(spec):
   266        if spec in PRIMITIVE_NAMES_TO_ATOMIC_TYPE:
   267          return schema_pb2.FieldType(
   268              atomic_type=PRIMITIVE_NAMES_TO_ATOMIC_TYPE[spec])
   269        elif isinstance(spec, list):
   270          if len(spec) != 1:
   271            raise ValueError("Use single-element lists to denote list types.")
   272          else:
   273            return schema_pb2.FieldType(
   274                iterable_type=schema_pb2.IterableType(
   275                    element_type=parse_type(spec[0])))
   276        elif isinstance(spec, dict):
   277          return schema_pb2.FieldType(
   278              iterable_type=schema_pb2.RowType(schema=parse_schema(spec[0])))
   279        else:
   280          raise ValueError("Unknown schema type: {spec}")
   281  
   282      def parse_schema(spec):
   283        return schema_pb2.Schema(
   284            fields=[
   285                schema_pb2.Field(name=key, type=parse_type(value), id=ix)
   286                for (ix, (key, value)) in enumerate(spec.items())
   287            ],
   288            id=str(uuid.uuid4()))
   289  
   290      named_tuple = schemas.named_tuple_from_schema(parse_schema(args))
   291      names = list(args.keys())
   292  
   293      def extract_field(x, name):
   294        if isinstance(x, dict):
   295          return x[name]
   296        else:
   297          return getattr(x, name)
   298  
   299      return 'WithSchema(%s)' % ', '.join(names) >> beam.Map(
   300          lambda x: named_tuple(*[extract_field(x, name) for name in names])
   301      ).with_output_types(named_tuple)
   302  
   303    # Or should this be posargs, args?
   304    # pylint: disable=dangerous-default-value
   305    def fully_qualified_named_transform(constructor, args=(), kwargs={}):
   306      with FullyQualifiedNamedTransform.with_filter('*'):
   307        return constructor >> FullyQualifiedNamedTransform(
   308            constructor, args, kwargs)
   309  
   310    # This intermediate is needed because there is no way to specify a tuple of
   311    # exactly zero or one PCollection in yaml (as they would be interpreted as
   312    # PBegin and the PCollection itself respectively).
   313    class Flatten(beam.PTransform):
   314      def expand(self, pcolls):
   315        if isinstance(pcolls, beam.PCollection):
   316          pipeline_arg = {}
   317          pcolls = (pcolls, )
   318        elif isinstance(pcolls, dict):
   319          pipeline_arg = {}
   320          pcolls = tuple(pcolls.values())
   321        else:
   322          pipeline_arg = {'pipeline': pcolls.pipeline}
   323          pcolls = ()
   324        return pcolls | beam.Flatten(**pipeline_arg)
   325  
   326    class WindowInto(beam.PTransform):
   327      def __init__(self, windowing):
   328        self._window_transform = self._parse_window_spec(windowing)
   329  
   330      def expand(self, pcoll):
   331        return pcoll | self._window_transform
   332  
   333      @staticmethod
   334      def _parse_window_spec(spec):
   335        spec = dict(spec)
   336        window_type = spec.pop('type')
   337        # TODO: These are in seconds, perhaps parse duration strings meaningfully?
   338        if window_type == 'global':
   339          window_fn = window.GlobalWindows()
   340        elif window_type == 'fixed':
   341          window_fn = window.FixedWindows(spec.pop('size'), spec.pop('offset', 0))
   342        elif window_type == 'sliding':
   343          window_fn = window.SlidingWindows(
   344              spec.pop('size'), spec.pop('period'), spec.pop('offset', 0))
   345        elif window_type == 'sessions':
   346          window_fn = window.FixedWindows(spec.pop('gap'))
   347        if spec:
   348          raise ValueError(f'Unknown parameters {spec.keys()}')
   349        # TODO: Triggering, etc.
   350        return beam.WindowInto(window_fn)
   351  
   352    ios = {
   353        key: getattr(apache_beam.io, key)
   354        for key in dir(apache_beam.io)
   355        if key.startswith('ReadFrom') or key.startswith('WriteTo')
   356    }
   357  
   358    return InlineProvider(
   359        dict({
   360            'Create': lambda elements,
   361            reshuffle=True: beam.Create(elements, reshuffle),
   362            'PyMap': lambda fn: beam.Map(
   363                python_callable.PythonCallableWithSource(fn)),
   364            'PyMapTuple': lambda fn: beam.MapTuple(
   365                python_callable.PythonCallableWithSource(fn)),
   366            'PyFlatMap': lambda fn: beam.FlatMap(
   367                python_callable.PythonCallableWithSource(fn)),
   368            'PyFlatMapTuple': lambda fn: beam.FlatMapTuple(
   369                python_callable.PythonCallableWithSource(fn)),
   370            'PyFilter': lambda keep: beam.Filter(
   371                python_callable.PythonCallableWithSource(keep)),
   372            'PyTransform': fully_qualified_named_transform,
   373            'PyToRow': lambda fields: beam.Select(
   374                **{
   375                    name: python_callable.PythonCallableWithSource(fn)
   376                    for (name, fn) in fields.items()
   377                }),
   378            'WithSchema': with_schema,
   379            'Flatten': Flatten,
   380            'WindowInto': WindowInto,
   381            'GroupByKey': beam.GroupByKey,
   382        },
   383             **ios))
   384  
   385  
   386  class PypiExpansionService:
   387    """Expands transforms by fully qualified name in a virtual environment
   388    with the given dependencies.
   389    """
   390    VENV_CACHE = os.path.expanduser("~/.apache_beam/cache/venvs")
   391  
   392    def __init__(self, packages, base_python=sys.executable):
   393      self._packages = packages
   394      self._base_python = base_python
   395  
   396    def _key(self):
   397      return json.dumps({'binary': self._base_python, 'packages': self._packages})
   398  
   399    def _venv(self):
   400      venv = os.path.join(
   401          self.VENV_CACHE,
   402          hashlib.sha256(self._key().encode('utf-8')).hexdigest())
   403      if not os.path.exists(venv):
   404        python_binary = os.path.join(venv, 'bin', 'python')
   405        subprocess.run([self._base_python, '-m', 'venv', venv], check=True)
   406        subprocess.run([python_binary, '-m', 'ensurepip'], check=True)
   407        subprocess.run([python_binary, '-m', 'pip', 'install'] + self._packages,
   408                       check=True)
   409        with open(venv + '-requirements.txt', 'w') as fout:
   410          fout.write('\n'.join(self._packages))
   411      return venv
   412  
   413    def __enter__(self):
   414      venv = self._venv()
   415      self._service_provider = subprocess_server.SubprocessServer(
   416          external.ExpansionAndArtifactRetrievalStub,
   417          [
   418              os.path.join(venv, 'bin', 'python'),
   419              '-m',
   420              'apache_beam.runners.portability.expansion_service_main',
   421              '--port',
   422              '{{PORT}}',
   423              '--fully_qualified_name_glob=*',
   424              '--pickle_library=cloudpickle',
   425              '--requirements_file=' + os.path.join(venv + '-requirements.txt')
   426          ])
   427      self._service = self._service_provider.__enter__()
   428      return self._service
   429  
   430    def __exit__(self, *args):
   431      self._service_provider.__exit__(*args)
   432      self._service = None
   433  
   434  
   435  def parse_providers(provider_specs):
   436    providers = collections.defaultdict(list)
   437    for provider_spec in provider_specs:
   438      provider = ExternalProvider.provider_from_spec(provider_spec)
   439      for transform_type in provider.provided_transforms():
   440        providers[transform_type].append(provider)
   441        # TODO: Do this better.
   442        provider.to_json = lambda result=provider_spec: result
   443    return providers
   444  
   445  
   446  def merge_providers(*provider_sets):
   447    result = collections.defaultdict(list)
   448    for provider_set in provider_sets:
   449      for transform_type, providers in provider_set.items():
   450        result[transform_type].extend(providers)
   451    return result
   452  
   453  
   454  def standard_providers():
   455    builtin_providers = collections.defaultdict(list)
   456    builtin_provider = create_builtin_provider()
   457    for transform_type in builtin_provider.provided_transforms():
   458      builtin_providers[transform_type].append(builtin_provider)
   459    with open(os.path.join(os.path.dirname(__file__),
   460                           'standard_providers.yaml')) as fin:
   461      standard_providers = yaml.load(fin, Loader=SafeLoader)
   462    return merge_providers(builtin_providers, parse_providers(standard_providers))