github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/yaml/yaml_transform.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import collections
import json
import logging
import pprint
import re
import uuid
from typing import Iterable
from typing import Mapping

import yaml
from yaml.loader import SafeLoader

import apache_beam as beam
from apache_beam.transforms.fully_qualified_named_transform import FullyQualifiedNamedTransform
from apache_beam.yaml import yaml_provider

__all__ = ["YamlTransform"]

_LOGGER = logging.getLogger(__name__)
yaml_provider.fix_pycallable()


def memoize_method(func):
  def wrapper(self, *args):
    if not hasattr(self, '_cache'):
      self._cache = {}
    key = func.__name__, args
    if key not in self._cache:
      self._cache[key] = func(self, *args)
    return self._cache[key]

  return wrapper


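# Illustrative sketch (not part of the module): memoize_method caches results
# per instance, keyed by method name and positional args, so that e.g.
# Scope.compute_outputs expands each transform only once.
def _example_memoize_method():
  class Doubler:
    calls = 0

    @memoize_method
    def double(self, x):
      Doubler.calls += 1
      return 2 * x

  d = Doubler()
  assert d.double(3) == 6 and d.double(3) == 6 and Doubler.calls == 1

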
def only_element(xs):
  x, = xs
  return x


class SafeLineLoader(SafeLoader):
  """A yaml loader that attaches line information to mappings and strings."""
  class TaggedString(str):
    """A string class to which we can attach metadata.

    This is primarily used to trace a string's origin back to its place in a
    yaml file.
    """
    def __reduce__(self):
      # Pickle as an ordinary string.
      return str, (str(self), )

  def construct_scalar(self, node):
    value = super().construct_scalar(node)
    if isinstance(value, str):
      value = SafeLineLoader.TaggedString(value)
      value._line_ = node.start_mark.line + 1
    return value

  def construct_mapping(self, node, deep=False):
    mapping = super().construct_mapping(node, deep=deep)
    mapping['__line__'] = node.start_mark.line + 1
    mapping['__uuid__'] = self.create_uuid()
    return mapping

  @classmethod
  def create_uuid(cls):
    return str(uuid.uuid4())

  @classmethod
  def strip_metadata(cls, spec, tagged_str=True):
    if isinstance(spec, Mapping):
      return {
          key: cls.strip_metadata(value, tagged_str)
          for key, value in spec.items()
          if key not in ('__line__', '__uuid__')
      }
    elif isinstance(spec, Iterable) and not isinstance(spec, (str, bytes)):
      return [cls.strip_metadata(value, tagged_str) for value in spec]
    elif isinstance(spec, SafeLineLoader.TaggedString) and tagged_str:
      return str(spec)
    else:
      return spec

  @staticmethod
  def get_line(obj):
    if isinstance(obj, dict):
      return obj.get('__line__', 'unknown')
    else:
      return getattr(obj, '_line_', 'unknown')


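# Illustrative sketch (not part of the module): parsing with SafeLineLoader
# tags every mapping with __line__/__uuid__ and every string with a _line_
# attribute, so later errors can point back into the original yaml;
# strip_metadata removes the tags before values reach real transforms.
# 'PyMap' is just a placeholder transform type.
def _example_safe_line_loader():
  spec = yaml.load('type: PyMap\nfn: "lambda x: x"', Loader=SafeLineLoader)
  assert spec['__line__'] == 1
  assert spec['type']._line_ == 1
  assert SafeLineLoader.strip_metadata(spec) == {
      'type': 'PyMap', 'fn': 'lambda x: x'
  }

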
class LightweightScope(object):
  def __init__(self, transforms):
    self._transforms = transforms
    self._transforms_by_uuid = {t['__uuid__']: t for t in self._transforms}
    self._uuid_by_name = collections.defaultdict(list)
    for spec in self._transforms:
      if 'name' in spec:
        self._uuid_by_name[spec['name']].append(spec['__uuid__'])
      if 'type' in spec:
        self._uuid_by_name[spec['type']].append(spec['__uuid__'])

  def get_transform_id_and_output_name(self, name):
    if '.' in name:
      transform_name, output = name.rsplit('.', 1)
    else:
      transform_name, output = name, None
    return self.get_transform_id(transform_name), output

  def get_transform_id(self, transform_name):
    if transform_name in self._transforms_by_uuid:
      return transform_name
    else:
      candidates = self._uuid_by_name[transform_name]
      if not candidates:
        raise ValueError(
            f'Unknown transform at line '
            f'{SafeLineLoader.get_line(transform_name)}: {transform_name}')
      elif len(candidates) > 1:
        raise ValueError(
            f'Ambiguous transform at line '
            f'{SafeLineLoader.get_line(transform_name)}: {transform_name}')
      else:
        return only_element(candidates)


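# Illustrative sketch (not part of the module): transforms may be referenced
# by explicit name, by type (when unambiguous), or directly by __uuid__, and
# 'Name.tag' selects a particular output of a transform.
def _example_lightweight_scope():
  transforms = [
      {'type': 'PyMap', 'name': 'MyMap', '__uuid__': 'u1'},
      {'type': 'Flatten', '__uuid__': 'u2'},
  ]
  scope = LightweightScope(transforms)
  assert scope.get_transform_id('MyMap') == 'u1'
  assert scope.get_transform_id_and_output_name('Flatten.out') == ('u2', 'out')

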
class Scope(LightweightScope):
  """Looks up PCollections (typically outputs of prior transforms) by name."""
  def __init__(self, root, inputs, transforms, providers):
    super().__init__(transforms)
    self.root = root
    self._inputs = inputs
    self.providers = providers
    self._seen_names = set()

  def compute_all(self):
    for transform_id in self._transforms_by_uuid.keys():
      self.compute_outputs(transform_id)

  def get_pcollection(self, name):
    if name in self._inputs:
      return self._inputs[name]
    elif '.' in name:
      transform, output = name.rsplit('.', 1)
      outputs = self.get_outputs(transform)
      if output in outputs:
        return outputs[output]
      else:
        raise ValueError(
            f'Unknown output {repr(output)} '
            f'at line {SafeLineLoader.get_line(name)}: '
            f'{transform} only has outputs {list(outputs.keys())}')
    else:
      outputs = self.get_outputs(name)
      if len(outputs) == 1:
        return only_element(outputs.values())
      else:
        raise ValueError(
            f'Ambiguous output at line {SafeLineLoader.get_line(name)}: '
            f'{name} has outputs {list(outputs.keys())}')

  def get_outputs(self, transform_name):
    return self.compute_outputs(self.get_transform_id(transform_name))

  @memoize_method
  def compute_outputs(self, transform_id):
    return expand_transform(self._transforms_by_uuid[transform_id], self)

  # A method on scope as providers may be scoped...
  def create_ptransform(self, spec):
    if 'type' not in spec:
      raise ValueError(f'Missing transform type: {identify_object(spec)}')

    if spec['type'] not in self.providers:
      raise ValueError(
          'Unknown transform type %r at %s' %
          (spec['type'], identify_object(spec)))

    for provider in self.providers.get(spec['type']):
      if provider.available():
        break
    else:
      raise ValueError(
          'No available provider for type %r at %s' %
          (spec['type'], identify_object(spec)))

    if 'args' in spec:
      args = spec['args']
      if not isinstance(args, dict):
        raise ValueError(
            'Arguments for transform at %s must be a mapping.' %
            identify_object(spec))
    else:
      args = {
          key: value
          for (key, value) in spec.items()
          if key not in ('type', 'name', 'input', 'output')
      }
    real_args = SafeLineLoader.strip_metadata(args)
    try:
      # pylint: disable=undefined-loop-variable
      ptransform = provider.create_transform(spec['type'], real_args)
      # TODO(robertwb): Should we have a better API for adding annotations
      # than this?
      annotations = dict(
          yaml_type=spec['type'],
          yaml_args=json.dumps(real_args),
          yaml_provider=json.dumps(provider.to_json()),
          **ptransform.annotations())
      ptransform.annotations = lambda: annotations
      return ptransform
    except Exception as exn:
      if isinstance(exn, TypeError):
        # Create a slightly more generic error message for argument errors.
        msg = str(exn).replace('positional', '').replace('keyword', '')
        msg = re.sub(r'\S+lambda\S+', '', msg)
        msg = re.sub('  +', ' ', msg).strip()
      else:
        msg = str(exn)
      raise ValueError(
          f'Invalid transform specification at {identify_object(spec)}: {msg}'
      ) from exn

  def unique_name(self, spec, ptransform, strictness=0):
    if 'name' in spec:
      name = spec['name']
      strictness += 1
    else:
      name = ptransform.label
    if name in self._seen_names:
      if strictness >= 2:
        raise ValueError(f'Duplicate name at {identify_object(spec)}: {name}')
      else:
        name = f'{name}@{SafeLineLoader.get_line(spec)}'
    self._seen_names.add(name)
    return name


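# Illustrative sketch (not part of the module): duplicate names are
# disambiguated by appending the yaml line number, unless the clash is
# between two explicitly given names (strictness >= 2), which is an error.
def _example_unique_name():
  scope = Scope(None, {}, transforms=[], providers={})
  spec = {'name': 'M', '__line__': 3}
  assert scope.unique_name(spec, None) == 'M'
  assert scope.unique_name(spec, None) == 'M@3'

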
def expand_transform(spec, scope):
  if 'type' not in spec:
    raise TypeError(
        f'Missing type parameter for transform at {identify_object(spec)}')
  spec_type = spec['type']
  if spec_type == 'composite':
    return expand_composite_transform(spec, scope)
  else:
    return expand_leaf_transform(spec, scope)


def expand_leaf_transform(spec, scope):
  spec = normalize_inputs_outputs(spec)
  inputs_dict = {
      key: scope.get_pcollection(value)
      for (key, value) in spec['input'].items()
  }
  input_type = spec.get('input_type', 'default')
  if input_type == 'list':
    inputs = tuple(inputs_dict.values())
  elif input_type == 'map':
    inputs = inputs_dict
  else:
    if len(inputs_dict) == 0:
      inputs = scope.root
    elif len(inputs_dict) == 1:
      inputs = next(iter(inputs_dict.values()))
    else:
      inputs = inputs_dict
  _LOGGER.info("Expanding %s", identify_object(spec))
  ptransform = scope.create_ptransform(spec)
  try:
    # TODO: Move validation to construction?
    with FullyQualifiedNamedTransform.with_filter('*'):
      outputs = inputs | scope.unique_name(spec, ptransform) >> ptransform
  except Exception as exn:
    raise ValueError(
        f"Error applying transform {identify_object(spec)}: {exn}") from exn
  if isinstance(outputs, dict):
    # TODO: Handle (or at least reject) nested case.
    return outputs
  elif isinstance(outputs, (tuple, list)):
    return {f'out{ix}': pcoll for (ix, pcoll) in enumerate(outputs)}
  elif isinstance(outputs, beam.PCollection):
    return {'out': outputs}
  else:
    raise ValueError(
        f'Transform {identify_object(spec)} returned an unexpected type '
        f'{type(outputs)}')


def expand_composite_transform(spec, scope):
  spec = normalize_inputs_outputs(normalize_source_sink(spec))

  inner_scope = Scope(
      scope.root, {
          key: scope.get_pcollection(value)
          for key, value in spec['input'].items()
      },
      spec['transforms'],
      yaml_provider.merge_providers(
          yaml_provider.parse_providers(spec.get('providers', [])),
          scope.providers))

  class CompositePTransform(beam.PTransform):
    @staticmethod
    def expand(inputs):
      inner_scope.compute_all()
      return {
          key: inner_scope.get_pcollection(value)
          for (key, value) in spec['output'].items()
      }

  if 'name' not in spec:
    spec['name'] = 'Composite'
  if spec['name'] is None:  # top-level pipeline, don't nest
    return CompositePTransform.expand(None)
  else:
    _LOGGER.info("Expanding %s", identify_object(spec))
    return ({
        key: scope.get_pcollection(value)
        for key, value in spec['input'].items()
    } or scope.root) | scope.unique_name(spec, None) >> CompositePTransform()


def expand_chain_transform(spec, scope):
  return expand_composite_transform(chain_as_composite(spec), scope)


def chain_as_composite(spec):
  # A chain is simply a composite transform where all inputs and outputs
  # are implicit.
  spec = normalize_source_sink(spec)
  if 'transforms' not in spec:
    raise TypeError(
        f"Chain at {identify_object(spec)} missing transforms property.")
  has_explicit_outputs = 'output' in spec
  composite_spec = normalize_inputs_outputs(spec)
  new_transforms = []
  for ix, transform in enumerate(composite_spec['transforms']):
    if any(io in transform for io in ('input', 'output')):
      raise ValueError(
          f'Transform {identify_object(transform)} is part of a chain '
          'and must have implicit inputs and outputs.')
    if ix == 0:
      transform['input'] = {key: key for key in composite_spec['input'].keys()}
    else:
      transform['input'] = new_transforms[-1]['__uuid__']
    new_transforms.append(transform)
  composite_spec['transforms'] = new_transforms

  last_transform = new_transforms[-1]['__uuid__']
  if has_explicit_outputs:
    composite_spec['output'] = {
        key: f'{last_transform}.{value}'
        for (key, value) in composite_spec['output'].items()
    }
  else:
    composite_spec['output'] = last_transform
  if 'name' not in composite_spec:
    composite_spec['name'] = 'Chain'
  composite_spec['type'] = 'composite'
  return composite_spec


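# Illustrative sketch (not part of the module): a chain parsed with
# SafeLineLoader is rewritten into a composite whose transforms are wired
# together implicitly, each consuming the previous transform's __uuid__.
# 'PyMap' here is just a placeholder transform type.
def _example_chain_as_composite():
  spec = yaml.load(
      '''
      type: chain
      transforms:
        - type: PyMap
          fn: "lambda x: x + 1"
        - type: PyMap
          fn: "lambda x: x * 2"
      ''',
      Loader=SafeLineLoader)
  composite = chain_as_composite(spec)
  assert composite['type'] == 'composite'
  first, second = composite['transforms']
  assert second['input'] == first['__uuid__']
  assert composite['output'] == second['__uuid__']

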
def preprocess_chain(spec):
  if spec['type'] == 'chain':
    return chain_as_composite(spec)
  else:
    return spec


def pipeline_as_composite(spec):
  if isinstance(spec, list):
    return {
        'type': 'composite',
        'name': None,
        'transforms': spec,
        '__line__': spec[0]['__line__'],
        '__uuid__': SafeLineLoader.create_uuid(),
    }
  else:
    return dict(spec, name=None, type=spec.get('type', 'composite'))


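# Illustrative sketch (not part of the module): a bare list of transforms is
# treated as an unnamed top-level composite; name=None suppresses nesting
# when the composite is expanded.
def _example_pipeline_as_composite():
  transforms = [{'type': 'PyMap', '__line__': 2, '__uuid__': 'u1'}]
  spec = pipeline_as_composite(transforms)
  assert spec['type'] == 'composite' and spec['name'] is None

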
def normalize_source_sink(spec):
  if 'source' not in spec and 'sink' not in spec:
    return spec
  spec = dict(spec)
  spec['transforms'] = list(spec.get('transforms', []))
  if 'source' in spec:
    spec['transforms'].insert(0, spec.pop('source'))
  if 'sink' in spec:
    spec['transforms'].append(spec.pop('sink'))
  return spec


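# Illustrative sketch (not part of the module): source and sink are syntactic
# sugar, folded into the transforms list as the first and last transforms
# respectively. Transform types are placeholders.
def _example_normalize_source_sink():
  spec = {
      'source': {'type': 'ReadFromText'},
      'transforms': [{'type': 'PyMap'}],
      'sink': {'type': 'WriteToText'},
  }
  assert [t['type'] for t in normalize_source_sink(spec)['transforms']] == [
      'ReadFromText', 'PyMap', 'WriteToText'
  ]

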
def preprocess_source_sink(spec):
  if spec['type'] in ('chain', 'composite'):
    return normalize_source_sink(spec)
  else:
    return spec


def normalize_inputs_outputs(spec):
  spec = dict(spec)

  def normalize_io(tag):
    io = spec.get(tag, {})
    if isinstance(io, (str, list)):
      return {tag: io}
    else:
      return SafeLineLoader.strip_metadata(io, tagged_str=False)

  return dict(spec, input=normalize_io('input'), output=normalize_io('output'))


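# Illustrative sketch (not part of the module): a bare string or list
# input/output is promoted to a single-entry mapping keyed by the tag itself,
# so downstream code can always iterate over .items().
def _example_normalize_inputs_outputs():
  normalized = normalize_inputs_outputs({'type': 'PyMap', 'input': 'Create'})
  assert normalized['input'] == {'input': 'Create'}
  assert normalized['output'] == {}

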
def identify_object(spec):
  line = SafeLineLoader.get_line(spec)
  name = extract_name(spec)
  if name:
    return f'"{name}" at line {line}'
  else:
    return f'at line {line}'


def extract_name(spec):
  if 'name' in spec:
    return spec['name']
  elif 'id' in spec:
    return spec['id']
  elif 'type' in spec:
    return spec['type']
  elif len(spec) == 1:
    return extract_name(next(iter(spec.values())))
  else:
    return ''


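# Illustrative sketch (not part of the module): error messages locate a spec
# by its name (falling back to id or type) and the line recorded by
# SafeLineLoader.
def _example_identify_object():
  assert identify_object({'type': 'PyMap', '__line__': 7}) == (
      '"PyMap" at line 7')

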
def push_windowing_to_roots(spec):
  scope = LightweightScope(spec['transforms'])
  consumed_outputs_by_transform = collections.defaultdict(set)
  for transform in spec['transforms']:
    for _, input_ref in transform['input'].items():
      try:
        transform_id, output = scope.get_transform_id_and_output_name(input_ref)
        consumed_outputs_by_transform[transform_id].add(output)
      except ValueError:
        # Could be an input or an ambiguity we'll raise later.
        pass

  for transform in spec['transforms']:
    if not transform['input'] and 'windowing' not in transform:
      transform['windowing'] = spec['windowing']
      transform['__consumed_outputs'] = consumed_outputs_by_transform[
          transform['__uuid__']]

  return spec


def preprocess_windowing(spec):
  if spec['type'] == 'WindowInto':
    # This is the transform where it is actually applied.
    return spec
  elif 'windowing' not in spec:
    # Nothing to do.
    return spec

  if spec['type'] == 'composite':
    # Apply the windowing to any reads, creates, etc. in this transform.
    # TODO(robertwb): Better handle the case where a read is followed by a
    # setting of the timestamps. We should be careful of sliding windows
    # in particular.
    spec = push_windowing_to_roots(spec)

  windowing = spec.pop('windowing')
  if spec['input']:
    # Apply the windowing to all inputs by wrapping it in a transform that
    # first applies windowing and then applies the original transform.
    original_inputs = spec['input']
    windowing_transforms = [{
        'type': 'WindowInto',
        'name': f'WindowInto[{key}]',
        'windowing': windowing,
        'input': key,
        '__line__': spec['__line__'],
        '__uuid__': SafeLineLoader.create_uuid(),
    } for key in original_inputs.keys()]
    windowed_inputs = {
        key: t['__uuid__']
        for (key, t) in zip(original_inputs.keys(), windowing_transforms)
    }
    modified_spec = dict(
        spec, input=windowed_inputs, __uuid__=SafeLineLoader.create_uuid())
    return {
        'type': 'composite',
        'name': spec.get('name', None) or spec['type'],
        'transforms': [modified_spec] + windowing_transforms,
        'input': spec['input'],
        'output': modified_spec['__uuid__'],
        '__line__': spec['__line__'],
        '__uuid__': spec['__uuid__'],
    }

  elif spec['type'] == 'composite':
    # Pushing the windowing down was sufficient.
    return spec

  else:
    # No inputs, apply the windowing to all outputs.
    consumed_outputs = list(spec.pop('__consumed_outputs', {None}))
    modified_spec = dict(spec, __uuid__=SafeLineLoader.create_uuid())
    windowing_transforms = [{
        'type': 'WindowInto',
        'name': f'WindowInto[{out}]',
        'windowing': windowing,
        'input': modified_spec['__uuid__'] + ('.' + out if out else ''),
        '__line__': spec['__line__'],
        '__uuid__': SafeLineLoader.create_uuid(),
    } for out in consumed_outputs]
    if consumed_outputs == [None]:
      windowed_outputs = only_element(windowing_transforms)['__uuid__']
    else:
      windowed_outputs = {
          out: t['__uuid__']
          for (out, t) in zip(consumed_outputs, windowing_transforms)
      }
    return {
        'type': 'composite',
        'name': spec.get('name', None) or spec['type'],
        'transforms': [modified_spec] + windowing_transforms,
        'output': windowed_outputs,
        '__line__': spec['__line__'],
        '__uuid__': spec['__uuid__'],
    }


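# Illustrative sketch (not part of the module): a 'windowing' clause on a
# transform with inputs becomes a composite that first applies a WindowInto
# to each input. 'Combine' is just a placeholder transform type.
def _example_preprocess_windowing():
  spec = yaml.load(
      '''
      type: Combine
      input: A
      windowing:
        type: fixed
        size: 60
      ''',
      Loader=SafeLineLoader)
  result = preprocess_windowing(normalize_inputs_outputs(spec))
  assert result['type'] == 'composite'
  assert any(t['type'] == 'WindowInto' for t in result['transforms'])

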
def preprocess_flattened_inputs(spec):
  if spec['type'] != 'composite':
    return spec

  # Prefer to add the flattens as sibling operations rather than nesting
  # to keep graph shape consistent when the number of inputs goes from
  # one to multiple.
  new_transforms = []
  for t in spec['transforms']:
    if t['type'] == 'Flatten':
      # Don't flatten before an explicit flatten, but we do have to expand
      # list inputs into singleton inputs.
      def all_inputs(t):
        for key, values in t.get('input', {}).items():
          if isinstance(values, list):
            for ix, value in enumerate(values):
              yield f'{key}{ix}', value
          else:
            yield key, values

      inputs_dict = {}
      for key, value in all_inputs(t):
        while key in inputs_dict:
          key += '_'
        inputs_dict[key] = value
      t = dict(t, input=inputs_dict)
    else:
      replaced_inputs = {}
      for key, values in t.get('input', {}).items():
        if isinstance(values, list):
          flatten_id = SafeLineLoader.create_uuid()
          new_transforms.append({
              'type': 'Flatten',
              'name': '%s-Flatten[%s]' % (t.get('name', t['type']), key),
              'input': {
                  f'input{ix}': value
                  for (ix, value) in enumerate(values)
              },
              '__line__': spec['__line__'],
              '__uuid__': flatten_id,
          })
          replaced_inputs[key] = flatten_id
      if replaced_inputs:
        t = dict(t, input={**t['input'], **replaced_inputs})
    new_transforms.append(t)
  return dict(spec, transforms=new_transforms)


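# Illustrative sketch (not part of the module): after input normalization, a
# list-valued input is rewritten to consume an explicit sibling Flatten.
def _example_preprocess_flattened_inputs():
  spec = {
      'type': 'composite',
      '__line__': 1,
      'transforms': [{
          'type': 'PyMap',
          'input': {'input': ['A', 'B']},
          '__uuid__': 'u1',
      }],
  }
  flatten, pymap = preprocess_flattened_inputs(spec)['transforms']
  assert flatten['type'] == 'Flatten'
  assert pymap['input'] == {'input': flatten['__uuid__']}

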
def preprocess(spec, verbose=False):
  if verbose:
    pprint.pprint(spec)

  def apply(phase, spec):
    spec = phase(spec)
    if spec['type'] in {'composite', 'chain'}:
      spec = dict(
          spec, transforms=[apply(phase, t) for t in spec['transforms']])
    return spec

  for phase in [preprocess_source_sink,
                preprocess_chain,
                normalize_inputs_outputs,
                preprocess_flattened_inputs,
                preprocess_windowing]:
    spec = apply(phase, spec)
    if verbose:
      print('=' * 20, phase, '=' * 20)
      pprint.pprint(spec)
  return spec


class YamlTransform(beam.PTransform):
  def __init__(self, spec, providers={}):  # pylint: disable=dangerous-default-value
    if isinstance(spec, str):
      spec = yaml.load(spec, Loader=SafeLineLoader)
    self._spec = preprocess(spec)
    self._providers = yaml_provider.merge_providers(
        {
            key: yaml_provider.as_provider_list(key, value)
            for (key, value) in providers.items()
        },
        yaml_provider.standard_providers())

  def expand(self, pcolls):
    if isinstance(pcolls, beam.pvalue.PBegin):
      root = pcolls
      pcolls = {}
    elif isinstance(pcolls, beam.PCollection):
      root = pcolls.pipeline
      pcolls = {'input': pcolls}
    else:
      root = next(iter(pcolls.values())).pipeline
    result = expand_transform(
        self._spec,
        Scope(root, pcolls, transforms=[], providers=self._providers))
    if len(result) == 1:
      return only_element(result.values())
    else:
      return result


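# Illustrative usage sketch (not a tested pipeline): YamlTransform can be
# applied mid-pipeline; a single upstream PCollection is bound to the name
# 'input' in the spec. 'PyMap' is a placeholder transform type that must be
# available from the configured providers.
def _example_yaml_transform():
  with beam.Pipeline() as p:
    _ = (
        p
        | beam.Create(['a', 'b'])
        | YamlTransform(
            '''
            type: chain
            input: input
            transforms:
              - type: PyMap
                fn: "str.upper"
            '''))

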
def expand_pipeline(pipeline, pipeline_spec, providers=None):
  if isinstance(pipeline_spec, str):
    pipeline_spec = yaml.load(pipeline_spec, Loader=SafeLineLoader)
  # Calling expand directly to avoid an outer layer of nesting.
  return YamlTransform(
      pipeline_as_composite(pipeline_spec['pipeline']),
      {
          **yaml_provider.parse_providers(pipeline_spec.get('providers', [])),
          **{
              key: yaml_provider.as_provider_list(key, value)
              for (key, value) in (providers or {}).items()
          }
      }).expand(beam.pvalue.PBegin(pipeline))
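

# Illustrative usage sketch (not a tested pipeline): the top-level entry
# point takes a yaml document with a 'pipeline' key (a transform list or a
# composite mapping) plus optional 'providers'. Transform types and paths
# here are placeholders.
def _example_expand_pipeline():
  pipeline_yaml = '''
      pipeline:
        type: chain
        transforms:
          - type: ReadFromText
            path: /tmp/input*.txt
          - type: PyMap
            fn: "lambda line: line.upper()"
          - type: WriteToText
            path: /tmp/output.txt
      '''
  with beam.Pipeline() as p:
    expand_pipeline(p, pipeline_yaml)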