github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/dataflow/ptransform_overrides.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Ptransform overrides for DataflowRunner."""
    19  
    20  # pytype: skip-file
    21  
    22  from apache_beam.options.pipeline_options import StandardOptions
    23  from apache_beam.pipeline import PTransformOverride
    24  
    25  
    26  class CreatePTransformOverride(PTransformOverride):
    27    """A ``PTransformOverride`` for ``Create`` in streaming mode."""
    28    def matches(self, applied_ptransform):
    29      # Imported here to avoid circular dependencies.
    30      # pylint: disable=wrong-import-order, wrong-import-position
    31      from apache_beam import Create
    32      return isinstance(applied_ptransform.transform, Create)
    33  
    34    def get_replacement_transform_for_applied_ptransform(
    35        self, applied_ptransform):
    36      # Imported here to avoid circular dependencies.
    37      # pylint: disable=wrong-import-order, wrong-import-position
    38      from apache_beam import PTransform
    39  
    40      ptransform = applied_ptransform.transform
    41  
    42      # Return a wrapper rather than ptransform.as_read() directly to
    43      # ensure backwards compatibility of the pipeline structure.
    44      class LegacyCreate(PTransform):
    45        def expand(self, pbegin):
    46          return pbegin | ptransform.as_read()
    47  
    48      return LegacyCreate().with_output_types(ptransform.get_output_type())
    49  
    50  
    51  class ReadPTransformOverride(PTransformOverride):
    52    """A ``PTransformOverride`` for ``Read(BoundedSource)``"""
    53    def matches(self, applied_ptransform):
    54      from apache_beam.io import Read
    55      from apache_beam.io.iobase import BoundedSource
    56      # Only overrides Read(BoundedSource) transform
    57      if (isinstance(applied_ptransform.transform, Read) and
    58          not getattr(applied_ptransform.transform, 'override', False)):
    59        if isinstance(applied_ptransform.transform.source, BoundedSource):
    60          return True
    61      return False
    62  
    63    def get_replacement_transform_for_applied_ptransform(
    64        self, applied_ptransform):
    65  
    66      from apache_beam import pvalue
    67      from apache_beam.io import iobase
    68  
    69      transform = applied_ptransform.transform
    70  
    71      class Read(iobase.Read):
    72        override = True
    73  
    74        def expand(self, pbegin):
    75          return pvalue.PCollection(
    76              self.pipeline, is_bounded=self.source.is_bounded())
    77  
    78      return Read(transform.source).with_output_types(
    79          transform.get_type_hints().simple_output_type('Read'))
    80  
    81  
    82  class CombineValuesPTransformOverride(PTransformOverride):
    83    """A ``PTransformOverride`` for ``CombineValues``.
    84  
    85    The DataflowRunner expects that the CombineValues PTransform acts as a
    86    primitive. So this override replaces the CombineValues with a primitive.
    87    """
    88    def matches(self, applied_ptransform):
    89      # Imported here to avoid circular dependencies.
    90      # pylint: disable=wrong-import-order, wrong-import-position
    91      from apache_beam import CombineValues
    92  
    93      if isinstance(applied_ptransform.transform, CombineValues):
    94        self.transform = applied_ptransform.transform
    95        return True
    96      return False
    97  
    98    def get_replacement_transform(self, ptransform):
    99      # Imported here to avoid circular dependencies.
   100      # pylint: disable=wrong-import-order, wrong-import-position
   101      from apache_beam import PTransform
   102      from apache_beam.pvalue import PCollection
   103  
   104      # The DataflowRunner still needs access to the CombineValues members to
   105      # generate a V1B3 proto representation, so we remember the transform from
   106      # the matches method and forward it here.
   107      class CombineValuesReplacement(PTransform):
   108        def __init__(self, transform):
   109          self.transform = transform
   110  
   111        def expand(self, pcoll):
   112          return PCollection.from_(pcoll)
   113  
   114      return CombineValuesReplacement(self.transform)
   115  
   116  
   117  class NativeReadPTransformOverride(PTransformOverride):
   118    """A ``PTransformOverride`` for ``Read`` using native sources.
   119  
   120    The DataflowRunner expects that the Read PTransform using native sources act
   121    as a primitive. So this override replaces the Read with a primitive.
   122    """
   123    def matches(self, applied_ptransform):
   124      # Imported here to avoid circular dependencies.
   125      # pylint: disable=wrong-import-order, wrong-import-position
   126      from apache_beam.io import Read
   127  
   128      # Consider the native Read to be a primitive for Dataflow by replacing.
   129      return (
   130          isinstance(applied_ptransform.transform, Read) and
   131          not getattr(applied_ptransform.transform, 'override', False) and
   132          hasattr(applied_ptransform.transform.source, 'format'))
   133  
   134    def get_replacement_transform(self, ptransform):
   135      # Imported here to avoid circular dependencies.
   136      # pylint: disable=wrong-import-order, wrong-import-position
   137      from apache_beam import pvalue
   138      from apache_beam.io import iobase
   139  
   140      # This is purposely subclassed from the Read transform to take advantage of
   141      # the existing windowing, typing, and display data.
   142      class Read(iobase.Read):
   143        override = True
   144  
   145        def expand(self, pbegin):
   146          return pvalue.PCollection.from_(pbegin)
   147  
   148      # Use the source's coder type hint as this replacement's output. Otherwise,
   149      # the typing information is not properly forwarded to the DataflowRunner and
   150      # will choose the incorrect coder for this transform.
   151      return Read(ptransform.source).with_output_types(
   152          ptransform.source.coder.to_type_hint())
   153  
   154  
   155  class GroupIntoBatchesWithShardedKeyPTransformOverride(PTransformOverride):
   156    """A ``PTransformOverride`` for ``GroupIntoBatches.WithShardedKey``.
   157  
   158    This override simply returns the original transform but additionally records
   159    the output PCollection in order to append required step properties during
   160    graph translation.
   161    """
   162    def __init__(self, dataflow_runner, options):
   163      self.dataflow_runner = dataflow_runner
   164      self.options = options
   165  
   166    def matches(self, applied_ptransform):
   167      # Imported here to avoid circular dependencies.
   168      # pylint: disable=wrong-import-order, wrong-import-position
   169      from apache_beam import util
   170  
   171      transform = applied_ptransform.transform
   172  
   173      if not isinstance(transform, util.GroupIntoBatches.WithShardedKey):
   174        return False
   175  
   176      # The replacement is only valid for portable Streaming Engine jobs with
   177      # runner v2.
   178      standard_options = self.options.view_as(StandardOptions)
   179      if not standard_options.streaming:
   180        return False
   181  
   182      self.dataflow_runner.add_pcoll_with_auto_sharding(applied_ptransform)
   183      return True
   184  
   185    def get_replacement_transform_for_applied_ptransform(self, ptransform):
   186      return ptransform.transform