github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/util.py

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Simple utility PTransforms.
    19  """
    20  
    21  # pytype: skip-file
    22  
    23  import collections
    24  import contextlib
    25  import logging
    26  import random
    27  import re
    28  import threading
    29  import time
    30  import uuid
    31  from typing import TYPE_CHECKING
    32  from typing import Any
    33  from typing import Iterable
    34  from typing import List
    35  from typing import Tuple
    36  from typing import TypeVar
    37  from typing import Union
    38  
    39  from apache_beam import coders
    40  from apache_beam import typehints
    41  from apache_beam.metrics import Metrics
    42  from apache_beam.portability import common_urns
    43  from apache_beam.portability.api import beam_runner_api_pb2
    44  from apache_beam.pvalue import AsSideInput
    45  from apache_beam.transforms import window
    46  from apache_beam.transforms.combiners import CountCombineFn
    47  from apache_beam.transforms.core import CombinePerKey
    48  from apache_beam.transforms.core import Create
    49  from apache_beam.transforms.core import DoFn
    50  from apache_beam.transforms.core import FlatMap
    51  from apache_beam.transforms.core import Flatten
    52  from apache_beam.transforms.core import GroupByKey
    53  from apache_beam.transforms.core import Map
    54  from apache_beam.transforms.core import MapTuple
    55  from apache_beam.transforms.core import ParDo
    56  from apache_beam.transforms.core import Windowing
    57  from apache_beam.transforms.ptransform import PTransform
    58  from apache_beam.transforms.ptransform import ptransform_fn
    59  from apache_beam.transforms.timeutil import TimeDomain
    60  from apache_beam.transforms.trigger import AccumulationMode
    61  from apache_beam.transforms.trigger import Always
    62  from apache_beam.transforms.userstate import BagStateSpec
    63  from apache_beam.transforms.userstate import CombiningValueStateSpec
    64  from apache_beam.transforms.userstate import TimerSpec
    65  from apache_beam.transforms.userstate import on_timer
    66  from apache_beam.transforms.window import NonMergingWindowFn
    67  from apache_beam.transforms.window import TimestampCombiner
    68  from apache_beam.transforms.window import TimestampedValue
    69  from apache_beam.typehints import trivial_inference
    70  from apache_beam.typehints.decorators import get_signature
    71  from apache_beam.typehints.sharded_key_type import ShardedKeyType
    72  from apache_beam.utils import windowed_value
    73  from apache_beam.utils.annotations import deprecated
    74  from apache_beam.utils.sharded_key import ShardedKey
    75  
    76  if TYPE_CHECKING:
    77    from apache_beam import pvalue
    78    from apache_beam.runners.pipeline_context import PipelineContext
    79  
    80  __all__ = [
    81      'BatchElements',
    82      'CoGroupByKey',
    83      'Distinct',
    84      'Keys',
    85      'KvSwap',
    86      'LogElements',
    87      'Regex',
    88      'Reify',
    89      'RemoveDuplicates',
    90      'Reshuffle',
    91      'ToString',
    92      'Values',
    93      'WithKeys',
    94      'GroupIntoBatches'
    95  ]
    96  
    97  K = TypeVar('K')
    98  V = TypeVar('V')
    99  T = TypeVar('T')
   100  
   101  
   102  class CoGroupByKey(PTransform):
   103    """Groups results across several PCollections by key.
   104  
   105    Given an input dict mapping serializable keys (called "tags") to 0 or more
   106    PCollections of (key, value) tuples, it creates a single output PCollection
   107    of (key, value) tuples whose keys are the unique input keys from all inputs,
   108    and whose values are dicts mapping each tag to an iterable of whatever values
   109    were under the key in the corresponding PCollection, in this manner::
   110  
   111        ('some key', {'tag1': ['value 1 under "some key" in pcoll1',
   112                               'value 2 under "some key" in pcoll1',
   113                               ...],
   114                      'tag2': ... ,
   115                      ... })
   116  
   117    where `[]` refers to an iterable, not a list.
   118  
   119    For example, given::
   120  
   121        {'tag1': pc1, 'tag2': pc2, 333: pc3}
   122  
   123    where::
   124  
   125        pc1 = beam.Create([(k1, v1)])
   126        pc2 = beam.Create([])
   127        pc3 = beam.Create([(k1, v31), (k1, v32), (k2, v33)])
   128  
   129    The output PCollection would consist of items::
   130  
   131        [(k1, {'tag1': [v1], 'tag2': [], 333: [v31, v32]}),
   132         (k2, {'tag1': [], 'tag2': [], 333: [v33]})]
   133  
   134    where `[]` refers to an iterable, not a list.
   135  
   136    CoGroupByKey also works for tuples, lists, or other flat iterables of
   137    PCollections, in which case the values of the resulting PCollections
   138    will be tuples whose nth value is the iterable of values from the nth
   139    PCollection---conceptually, the "tags" are the indices into the input.
   140    Thus, for this input::
   141  
   142       (pc1, pc2, pc3)
   143  
   144    the output would be::
   145  
   146        [(k1, ([v1], [], [v31, v32])),
   147         (k2, ([], [], [v33]))]
   148  
   149    where, again, `[]` refers to an iterable, not a list.
   150  
   151    Attributes:
   152      **kwargs: Accepts a single named argument "pipeline", which specifies the
   153        pipeline that "owns" this PTransform. Ordinarily CoGroupByKey can obtain
   154        this information from one of the input PCollections, but if there are none
   155        (or if there's a chance there may be none), this argument is the only way
   156        to provide pipeline information, and should be considered mandatory.
   157    """
   158    def __init__(self, *, pipeline=None):
   159      self.pipeline = pipeline
   160  
   161    def _extract_input_pvalues(self, pvalueish):
   162      try:
   163        # If this works, it's a dict.
   164        return pvalueish, tuple(pvalueish.values())
   165      except AttributeError:
   166        # Cast iterables to a tuple so we can do re-iteration.
   167        pcolls = tuple(pvalueish)
   168        return pcolls, pcolls
   169  
   170    def expand(self, pcolls):
   171      if not pcolls:
   172        pcolls = (self.pipeline | Create([]), )
   173      if isinstance(pcolls, dict):
   174        tags = list(pcolls.keys())
   175        if all(isinstance(tag, str) and len(tag) < 10 for tag in tags):
   176          # Small, string tags. Pass them as data.
   177          pcolls_dict = pcolls
   178          restore_tags = None
   179        else:
   180          # Pass the tags in the restore_tags closure.
   181          tags = list(pcolls.keys())
   182          pcolls_dict = {str(ix): pcolls[tag] for (ix, tag) in enumerate(tags)}
   183          restore_tags = lambda vs: {
   184              tag: vs[str(ix)]
   185              for (ix, tag) in enumerate(tags)
   186          }
   187      else:
   188        # Tags are tuple indices.
   189        tags = [str(ix) for ix in range(len(pcolls))]
   190        pcolls_dict = dict(zip(tags, pcolls))
   191        restore_tags = lambda vs: tuple(vs[tag] for tag in tags)
   192  
   193      input_key_types = []
   194      input_value_types = []
   195      for pcoll in pcolls_dict.values():
   196        key_type, value_type = typehints.trivial_inference.key_value_types(
   197            pcoll.element_type)
   198        input_key_types.append(key_type)
   199        input_value_types.append(value_type)
   200      output_key_type = typehints.Union[tuple(input_key_types)]
   201      iterable_input_value_types = tuple(
   202          typehints.Iterable[t] for t in input_value_types)
   203  
   204      output_value_type = typehints.Dict[
   205          str, typehints.Union[iterable_input_value_types or [typehints.Any]]]
   206      result = (
   207          pcolls_dict
   208          | 'CoGroupByKeyImpl' >>
   209          _CoGBKImpl(pipeline=self.pipeline).with_output_types(
   210              typehints.Tuple[output_key_type, output_value_type]))
   211  
   212      if restore_tags:
   213        if isinstance(pcolls, dict):
   214          dict_key_type = typehints.Union[tuple(
   215              trivial_inference.instance_to_type(tag) for tag in tags)]
   216          output_value_type = typehints.Dict[
   217              dict_key_type, typehints.Union[iterable_input_value_types]]
   218        else:
   219          output_value_type = typehints.Tuple[iterable_input_value_types]
   220        result |= 'RestoreTags' >> MapTuple(
   221            lambda k, vs: (k, restore_tags(vs))).with_output_types(
   222                typehints.Tuple[output_key_type, output_value_type])
   223  
   224      return result
   225  
   226  
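        # A minimal usage sketch of CoGroupByKey, assuming the standard
        # `import apache_beam as beam` entry point. The helper name
        # `_example_cogroupbykey` and the sample data are hypothetical and are
        # shown only for illustration.
        def _example_cogroupbykey():
          import apache_beam as beam

          with beam.Pipeline() as p:
            emails = p | 'CreateEmails' >> beam.Create([('amy', 'amy@example.com')])
            phones = p | 'CreatePhones' >> beam.Create(
                [('amy', '555-0100'), ('bob', '555-0101')])
            # Yields elements such as:
            #   ('amy', {'emails': ['amy@example.com'], 'phones': ['555-0100']})
            #   ('bob', {'emails': [], 'phones': ['555-0101']})
            ({'emails': emails, 'phones': phones}
             | beam.CoGroupByKey()
             | beam.Map(print))

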
   227  class _CoGBKImpl(PTransform):
   228    def __init__(self, *, pipeline=None):
   229      self.pipeline = pipeline
   230  
   231    def expand(self, pcolls):
   232      # Check input PCollections for PCollection-ness, and that they all belong
   233      # to the same pipeline.
   234      for pcoll in pcolls.values():
   235        self._check_pcollection(pcoll)
   236        if self.pipeline:
   237          assert pcoll.pipeline == self.pipeline, (
   238              'All input PCollections must belong to the same pipeline.')
   239  
   240      tags = list(pcolls.keys())
   241  
   242      def add_tag(tag):
   243        return lambda k, v: (k, (tag, v))
   244  
   245      def collect_values(key, tagged_values):
   246        grouped_values = {tag: [] for tag in tags}
   247        for tag, value in tagged_values:
   248          grouped_values[tag].append(value)
   249        return key, grouped_values
   250  
   251      return ([
   252          pcoll
   253          | 'Tag[%s]' % tag >> MapTuple(add_tag(tag))
   254          for (tag, pcoll) in pcolls.items()
   255      ]
   256              | Flatten(pipeline=self.pipeline)
   257              | GroupByKey()
   258              | MapTuple(collect_values))
   259  
   260  
   261  @ptransform_fn
   262  @typehints.with_input_types(Tuple[K, V])
   263  @typehints.with_output_types(K)
   264  def Keys(pcoll, label='Keys'):  # pylint: disable=invalid-name
   265    """Produces a PCollection of first elements of 2-tuples in a PCollection."""
   266    return pcoll | label >> MapTuple(lambda k, _: k)
   267  
   268  
   269  @ptransform_fn
   270  @typehints.with_input_types(Tuple[K, V])
   271  @typehints.with_output_types(V)
   272  def Values(pcoll, label='Values'):  # pylint: disable=invalid-name
   273    """Produces a PCollection of second elements of 2-tuples in a PCollection."""
   274    return pcoll | label >> MapTuple(lambda _, v: v)
   275  
   276  
   277  @ptransform_fn
   278  @typehints.with_input_types(Tuple[K, V])
   279  @typehints.with_output_types(Tuple[V, K])
   280  def KvSwap(pcoll, label='KvSwap'):  # pylint: disable=invalid-name
   281    """Produces a PCollection reversing 2-tuples in a PCollection."""
   282    return pcoll | label >> MapTuple(lambda k, v: (v, k))
   283  
   284  
   285  @ptransform_fn
   286  @typehints.with_input_types(T)
   287  @typehints.with_output_types(T)
   288  def Distinct(pcoll):  # pylint: disable=invalid-name
   289    """Produces a PCollection containing distinct elements of a PCollection."""
   290    return (
   291        pcoll
   292        | 'ToPairs' >> Map(lambda v: (v, None))
   293        | 'Group' >> CombinePerKey(lambda vs: None)
   294        | 'Distinct' >> Keys())
   295  
   296  
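        # A quick sketch of the small KV helpers above, assuming
        # `import apache_beam as beam`. The helper name `_example_kv_helpers`
        # and the data are hypothetical, for illustration only.
        def _example_kv_helpers():
          import apache_beam as beam

          with beam.Pipeline() as p:
            kvs = p | beam.Create([('a', 1), ('b', 2), ('a', 3)])
            kvs | 'JustKeys' >> beam.Keys()        # 'a', 'b', 'a'
            kvs | 'JustValues' >> beam.Values()    # 1, 2, 3
            kvs | 'Swapped' >> beam.KvSwap()       # (1, 'a'), (2, 'b'), (3, 'a')
            (kvs
             | 'UniqueKeys' >> beam.Keys()
             | beam.Distinct())                    # 'a', 'b'

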
   297  @deprecated(since='2.12', current='Distinct')
   298  @ptransform_fn
   299  @typehints.with_input_types(T)
   300  @typehints.with_output_types(T)
   301  def RemoveDuplicates(pcoll):
   302    """Produces a PCollection containing distinct elements of a PCollection."""
   303    return pcoll | 'RemoveDuplicates' >> Distinct()
   304  
   305  
   306  class _BatchSizeEstimator(object):
   307    """Estimates the best size for batches given historical timing.
   308    """
   309  
   310    _MAX_DATA_POINTS = 100
   311    _MAX_GROWTH_FACTOR = 2
   312  
   313    def __init__(
   314        self,
   315        min_batch_size=1,
   316        max_batch_size=10000,
   317        target_batch_overhead=.05,
   318        target_batch_duration_secs=10,
   319        target_batch_duration_secs_including_fixed_cost=None,
   320        variance=0.25,
   321        clock=time.time,
   322        ignore_first_n_seen_per_batch_size=0,
   323        record_metrics=True):
   324      if min_batch_size > max_batch_size:
   325        raise ValueError(
   326            "Minimum (%s) must not be greater than maximum (%s)" %
   327            (min_batch_size, max_batch_size))
   328      if target_batch_overhead and not 0 < target_batch_overhead <= 1:
   329        raise ValueError(
   330            "target_batch_overhead (%s) must be between 0 and 1" %
   331            (target_batch_overhead))
   332      if target_batch_duration_secs and target_batch_duration_secs <= 0:
   333        raise ValueError(
   334            "target_batch_duration_secs (%s) must be positive" %
   335            (target_batch_duration_secs))
   336      if (target_batch_duration_secs_including_fixed_cost and
   337          target_batch_duration_secs_including_fixed_cost <= 0):
   338        raise ValueError(
   339            "target_batch_duration_secs_including_fixed_cost "
   340            "(%s) must be positive" %
   341            (target_batch_duration_secs_including_fixed_cost))
   342      if not (target_batch_overhead or target_batch_duration_secs or
   343              target_batch_duration_secs_including_fixed_cost):
   344        raise ValueError(
   345            "At least one of target_batch_overhead or "
   346            "target_batch_duration_secs or "
   347            "target_batch_duration_secs_including_fixed_cost must be positive.")
   348      if ignore_first_n_seen_per_batch_size < 0:
   349        raise ValueError(
   350            'ignore_first_n_seen_per_batch_size (%s) must be non '
   351            'negative' % (ignore_first_n_seen_per_batch_size))
   352      self._min_batch_size = min_batch_size
   353      self._max_batch_size = max_batch_size
   354      self._target_batch_overhead = target_batch_overhead
   355      self._target_batch_duration_secs = target_batch_duration_secs
   356      self._target_batch_duration_secs_including_fixed_cost = (
   357          target_batch_duration_secs_including_fixed_cost)
   358      self._variance = variance
   359      self._clock = clock
   360      self._data = []
   361      self._ignore_next_timing = False
   362      self._ignore_first_n_seen_per_batch_size = (
   363          ignore_first_n_seen_per_batch_size)
   364      self._batch_size_num_seen = {}
   365      self._replay_last_batch_size = None
   366      self._record_metrics = record_metrics
   367      self._element_count = 0
   368      self._batch_count = 0
   369  
   370      if record_metrics:
   371        self._size_distribution = Metrics.distribution(
   372            'BatchElements', 'batch_size')
   373        self._time_distribution = Metrics.distribution(
   374            'BatchElements', 'msec_per_batch')
   375      else:
   376        self._size_distribution = self._time_distribution = None
   377      # Beam distributions only accept integer values, so we use this to
   378      # accumulate under-reported values until they add up to whole milliseconds.
   379      # (Milliseconds are chosen because that's conventionally used elsewhere in
   380      # profiling-style counters.)
   381      self._remainder_msecs = 0
   382  
   383    def ignore_next_timing(self):
   384      """Call to indicate the next timing should be ignored.
   385  
   386      For example, the first emit of a ParDo operation is known to be anomalous
   387      due to setup that may occur.
   388      """
   389      self._ignore_next_timing = True
   390  
   391    @contextlib.contextmanager
   392    def record_time(self, batch_size):
   393      start = self._clock()
   394      yield
   395      elapsed = self._clock() - start
   396      elapsed_msec = 1e3 * elapsed + self._remainder_msecs
   397      if self._record_metrics:
   398        self._size_distribution.update(batch_size)
   399        self._time_distribution.update(int(elapsed_msec))
   400      self._element_count += batch_size
   401      self._batch_count += 1
   402      self._remainder_msecs = elapsed_msec - int(elapsed_msec)
   403      # If this timing is being ignored, replay the same batch size next
   404      # time so that we still get an accurate measurement for it.
   405      if self._ignore_next_timing:
   406        self._ignore_next_timing = False
   407        self._replay_last_batch_size = min(batch_size, self._max_batch_size)
   408      else:
   409        self._data.append((batch_size, elapsed))
   410        if len(self._data) >= self._MAX_DATA_POINTS:
   411          self._thin_data()
   412  
   413    def _thin_data(self):
   414      # Make sure we don't change the parity of len(self._data),
   415      # as it's used to alternate the jitter.
   416      self._data.pop(random.randrange(len(self._data) // 4))
   417      self._data.pop(random.randrange(len(self._data) // 2))
   418  
   419    @staticmethod
   420    def linear_regression_no_numpy(xs, ys):
   421      # Least squares fit for y = a + bx over all points.
   422      n = float(len(xs))
   423      xbar = sum(xs) / n
   424      ybar = sum(ys) / n
   425      if xbar == 0:
   426        return ybar, 0
   427      if all(xs[0] == x for x in xs):
   428        # Simply use the mean if all values in xs are the same.
   429        return 0, ybar / xbar
   430      b = (
   431          sum([(x - xbar) * (y - ybar)
   432               for x, y in zip(xs, ys)]) / sum([(x - xbar)**2 for x in xs]))
   433      a = ybar - b * xbar
   434      return a, b
   435  
   436    @staticmethod
   437    def linear_regression_numpy(xs, ys):
   438      # pylint: disable=wrong-import-order, wrong-import-position
   439      import numpy as np
   440      from numpy import sum
   441      n = len(xs)
   442      if all(xs[0] == x for x in xs):
   443        # If all values of xs are the same, fall back to linear_regression_no_numpy
   444        return _BatchSizeEstimator.linear_regression_no_numpy(xs, ys)
   445      xs = np.asarray(xs, dtype=float)
   446      ys = np.asarray(ys, dtype=float)
   447  
   448      # First do a simple least squares fit for y = a + bx over all points.
   449      b, a = np.polyfit(xs, ys, 1)
   450  
   451      if n < 10:
   452        return a, b
   453      else:
   454        # Refine this by throwing out outliers, according to Cook's distance.
   455        # https://en.wikipedia.org/wiki/Cook%27s_distance
   456        sum_x = sum(xs)
   457        sum_x2 = sum(xs**2)
   458        errs = a + b * xs - ys
   459        s2 = sum(errs**2) / (n - 2)
   460        if s2 == 0:
   461          # It's an exact fit!
   462          return a, b
   463        h = (sum_x2 - 2 * sum_x * xs + n * xs**2) / (n * sum_x2 - sum_x**2)
   464        cook_ds = 0.5 / s2 * errs**2 * (h / (1 - h)**2)
   465  
   466        # Re-compute the regression, excluding those points with Cook's distance
   467        # greater than 0.5, and weighting by the inverse of x to give a more
   468        # stable y-intercept (as small batches have relatively more information
   469        # about the fixed overhead).
   470        weight = (cook_ds <= 0.5) / xs
   471        b, a = np.polyfit(xs, ys, 1, w=weight)
   472        return a, b
   473  
   474    try:
   475      # pylint: disable=wrong-import-order, wrong-import-position
   476      import numpy as np
   477      linear_regression = linear_regression_numpy
   478    except ImportError:
   479      linear_regression = linear_regression_no_numpy
   480  
   481    def _calculate_next_batch_size(self):
   482      if self._min_batch_size == self._max_batch_size:
   483        return self._min_batch_size
   484      elif len(self._data) < 1:
   485        return self._min_batch_size
   486      elif len(self._data) < 2:
   487        # Force some variety so we have distinct batch sizes on which to do
   488        # linear regression below.
   489        return int(
   490            max(
   491                min(
   492                    self._max_batch_size,
   493                    self._min_batch_size * self._MAX_GROWTH_FACTOR),
   494                self._min_batch_size + 1))
   495  
   496      # There tends to be a lot of noise in the top quantile, which also
   497      # has outsized influence on the regression.  If we have enough data,
   498      # simply declare the top 20% to be outliers.
   499      trimmed_data = sorted(self._data)[:max(20, len(self._data) * 4 // 5)]
   500  
   501      # Linear regression for y = a + bx, where x is batch size and y is time.
   502      xs, ys = zip(*trimmed_data)
   503      a, b = self.linear_regression(xs, ys)
   504  
   505      # Avoid nonsensical or division-by-zero errors below due to noise.
   506      a = max(a, 1e-10)
   507      b = max(b, 1e-20)
   508  
   509      last_batch_size = self._data[-1][0]
   510      cap = min(last_batch_size * self._MAX_GROWTH_FACTOR, self._max_batch_size)
   511  
   512      target = self._max_batch_size
   513  
   514      if self._target_batch_duration_secs_including_fixed_cost:
   515        # Solution to
   516        # a + b*x = self._target_batch_duration_secs_including_fixed_cost.
   517        target = min(
   518            target,
   519            (self._target_batch_duration_secs_including_fixed_cost - a) / b)
   520  
   521      if self._target_batch_duration_secs:
   522        # Solution to b*x = self._target_batch_duration_secs.
   523        # We ignore the fixed cost in this computation as it has negligible
   524        # impact when it is small and unhelpfully forces the minimum batch size
   525        # when it is large.
   526        target = min(target, self._target_batch_duration_secs / b)
   527  
   528      if self._target_batch_overhead:
   529        # Solution to a / (a + b*x) = self._target_batch_overhead.
   530        target = min(target, (a / b) * (1 / self._target_batch_overhead - 1))
   531  
   532      # Avoid getting stuck at a single batch size (especially the minimal
   533      # batch size) which would not allow us to extrapolate to other batch
   534      # sizes.
   535      # Jitter alternates between 0 and 1.
   536      jitter = len(self._data) % 2
   537      # Smear our samples across a range centered at the target.
   538      if len(self._data) > 10:
   539        target += int(target * self._variance * 2 * (random.random() - .5))
   540  
   541      return int(max(self._min_batch_size + jitter, min(target, cap)))
   542  
   543    def next_batch_size(self):
   544      # Check if we should replay a previous batch size due to it not being
   545      # recorded.
   546      if self._replay_last_batch_size:
   547        result = self._replay_last_batch_size
   548        self._replay_last_batch_size = None
   549      else:
   550        result = self._calculate_next_batch_size()
   551  
   552      seen_count = self._batch_size_num_seen.get(result, 0) + 1
   553      if seen_count <= self._ignore_first_n_seen_per_batch_size:
   554        self.ignore_next_timing()
   555      self._batch_size_num_seen[result] = seen_count
   556      return result
   557  
   558    def stats(self):
   559      return "element_count=%s batch_count=%s next_batch_size=%s timings=%s" % (
   560          self._element_count,
   561          self._batch_count,
   562          self._calculate_next_batch_size(),
   563          self._data)
   564  
   565  
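        # A rough, self-contained sketch of how _BatchSizeEstimator behaves when
        # each batch costs roughly fixed_cost + per_element_cost * n seconds.
        # The helper name and the cost constants are hypothetical; metrics are
        # disabled so no pipeline is needed.
        def _example_batch_size_estimator():
          fake_now = [0.0]
          estimator = _BatchSizeEstimator(
              min_batch_size=1,
              max_batch_size=10000,
              clock=lambda: fake_now[0],
              record_metrics=False)
          for _ in range(30):
            n = estimator.next_batch_size()
            with estimator.record_time(n):
              # Simulate a batch that takes 50ms fixed cost plus 1ms per element.
              fake_now[0] += 0.05 + 0.001 * n
          # The estimated batch size drifts toward the point where the fixed
          # cost is ~5% of the batch time (the default target_batch_overhead).
          print(estimator.stats())

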
   566  class _GlobalWindowsBatchingDoFn(DoFn):
   567    def __init__(self, batch_size_estimator, element_size_fn):
   568      self._batch_size_estimator = batch_size_estimator
   569      self._element_size_fn = element_size_fn
   570  
   571    def start_bundle(self):
   572      self._batch = []
   573      self._running_batch_size = 0
   574      self._target_batch_size = self._batch_size_estimator.next_batch_size()
   575      # The first emit often involves non-trivial setup.
   576      self._batch_size_estimator.ignore_next_timing()
   577  
   578    def process(self, element):
   579      self._batch.append(element)
   580      self._running_batch_size += self._element_size_fn(element)
   581      if self._running_batch_size >= self._target_batch_size:
   582        with self._batch_size_estimator.record_time(self._running_batch_size):
   583          yield window.GlobalWindows.windowed_value_at_end_of_window(self._batch)
   584        self._batch = []
   585        self._running_batch_size = 0
   586        self._target_batch_size = self._batch_size_estimator.next_batch_size()
   587  
   588    def finish_bundle(self):
   589      if self._batch:
   590        with self._batch_size_estimator.record_time(self._running_batch_size):
   591          yield window.GlobalWindows.windowed_value_at_end_of_window(self._batch)
   592        self._batch = None
   593        self._running_batch_size = 0
   594      self._target_batch_size = self._batch_size_estimator.next_batch_size()
   595      logging.info(
   596          "BatchElements statistics: " + self._batch_size_estimator.stats())
   597  
   598  
   599  class _SizedBatch():
   600    def __init__(self):
   601      self.elements = []
   602      self.size = 0
   603  
   604  
   605  class _WindowAwareBatchingDoFn(DoFn):
   606  
   607    _MAX_LIVE_WINDOWS = 10
   608  
   609    def __init__(self, batch_size_estimator, element_size_fn):
   610      self._batch_size_estimator = batch_size_estimator
   611      self._element_size_fn = element_size_fn
   612  
   613    def start_bundle(self):
   614      self._batches = collections.defaultdict(_SizedBatch)
   615      self._target_batch_size = self._batch_size_estimator.next_batch_size()
   616      # The first emit often involves non-trivial setup.
   617      self._batch_size_estimator.ignore_next_timing()
   618  
   619    def process(self, element, window=DoFn.WindowParam):
   620      batch = self._batches[window]
   621      batch.elements.append(element)
   622      batch.size += self._element_size_fn(element)
   623      if batch.size >= self._target_batch_size:
   624        with self._batch_size_estimator.record_time(batch.size):
   625          yield windowed_value.WindowedValue(
   626              batch.elements, window.max_timestamp(), (window, ))
   627        del self._batches[window]
   628        self._target_batch_size = self._batch_size_estimator.next_batch_size()
   629      elif len(self._batches) > self._MAX_LIVE_WINDOWS:
   630        window, batch = max(
   631            self._batches.items(),
   632            key=lambda window_batch: window_batch[1].size)
   633        with self._batch_size_estimator.record_time(batch.size):
   634          yield windowed_value.WindowedValue(
   635              batch.elements, window.max_timestamp(), (window, ))
   636        del self._batches[window]
   637        self._target_batch_size = self._batch_size_estimator.next_batch_size()
   638  
   639    def finish_bundle(self):
   640      for window, batch in self._batches.items():
   641        if batch:
   642          with self._batch_size_estimator.record_time(batch.size):
   643            yield windowed_value.WindowedValue(
   644                batch.elements, window.max_timestamp(), (window, ))
   645      self._batches = None
   646      self._target_batch_size = self._batch_size_estimator.next_batch_size()
   647  
   648  
   649  @typehints.with_input_types(T)
   650  @typehints.with_output_types(List[T])
   651  class BatchElements(PTransform):
   652    """A Transform that batches elements for amortized processing.
   653  
   654    This transform is designed to precede operations whose processing cost
   655    is of the form
   656  
   657        time = fixed_cost + num_elements * per_element_cost
   658  
   659    where the per element cost is (often significantly) smaller than the fixed
   660    cost and could be amortized over multiple elements.  It consumes a PCollection
   661    of element type T and produces a PCollection of element type List[T].
   662  
   663    This transform attempts to find the best batch size between the minimum
   664    and maximum parameters by profiling the time taken by (fused) downstream
   665    operations. For a fixed batch size, set the min and max to be equal.
   666  
   667    Elements are batched per-window and batches are emitted in the window
   668    corresponding to their contents.
   669  
   670    Args:
   671      min_batch_size: (optional) the smallest size of a batch
   672      max_batch_size: (optional) the largest size of a batch
   673      target_batch_overhead: (optional) a target for fixed_cost / time,
   674          as used in the formula above
   675      target_batch_duration_secs: (optional) a target for total time per bundle,
   676          in seconds, excluding fixed cost
   677      target_batch_duration_secs_including_fixed_cost: (optional) a target for
   678          total time per bundle, in seconds, including fixed cost
   679      element_size_fn: (optional) A mapping of an element to its contribution to
   680          batch size, defaulting to every element having size 1.  When provided,
   681          attempts to provide batches of optimal total size which may consist of
   682          a varying number of elements.
   683      variance: (optional) the permitted (relative) amount of deviation from the
   684          (estimated) ideal batch size used to produce a wider base for
   685          linear interpolation
   686      clock: (optional) an alternative to time.time for measuring the cost of
   687          downstream operations (mostly for testing)
   688      record_metrics: (optional) whether or not to record Beam metrics on
   689          distributions of the batch size. Defaults to True.
   690    """
   691    def __init__(
   692        self,
   693        min_batch_size=1,
   694        max_batch_size=10000,
   695        target_batch_overhead=.05,
   696        target_batch_duration_secs=10,
   697        target_batch_duration_secs_including_fixed_cost=None,
   698        *,
   699        element_size_fn=lambda x: 1,
   700        variance=0.25,
   701        clock=time.time,
   702        record_metrics=True):
   703      self._batch_size_estimator = _BatchSizeEstimator(
   704          min_batch_size=min_batch_size,
   705          max_batch_size=max_batch_size,
   706          target_batch_overhead=target_batch_overhead,
   707          target_batch_duration_secs=target_batch_duration_secs,
   708          target_batch_duration_secs_including_fixed_cost=(
   709              target_batch_duration_secs_including_fixed_cost),
   710          variance=variance,
   711          clock=clock,
   712          record_metrics=record_metrics)
   713      self._element_size_fn = element_size_fn
   714  
   715    def expand(self, pcoll):
   716      if getattr(pcoll.pipeline.runner, 'is_streaming', False):
   717        raise NotImplementedError("Requires stateful processing (BEAM-2687)")
   718      elif pcoll.windowing.is_default():
   719        # This is the same logic as _GlobalWindowsBatchingDoFn, but optimized
   720        # for that simpler case.
   721        return pcoll | ParDo(
   722            _GlobalWindowsBatchingDoFn(
   723                self._batch_size_estimator, self._element_size_fn))
   724      else:
   725        return pcoll | ParDo(
   726            _WindowAwareBatchingDoFn(
   727                self._batch_size_estimator, self._element_size_fn))
   728  
   729  
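        # A usage sketch for BatchElements, assuming `import apache_beam as
        # beam`. The helper name, the bounds, and the downstream Map are
        # hypothetical.
        def _example_batch_elements():
          import apache_beam as beam

          with beam.Pipeline() as p:
            (p
             | beam.Create(range(1000))
             | beam.BatchElements(min_batch_size=10, max_batch_size=200)
             | beam.Map(sum))  # each element is now a list of ints

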
   730  class _IdentityWindowFn(NonMergingWindowFn):
   731    """Windowing function that preserves existing windows.
   732  
   733    To be used internally with the Reshuffle transform.
   734    Will raise an exception when used after DoFns that return TimestampedValue
   735    elements.
   736    """
   737    def __init__(self, window_coder):
   738      """Create a new WindowFn with compatible coder.
   739      To be applied to PCollections with windows that are compatible with the
   740      given coder.
   741  
   742      Arguments:
   743        window_coder: coders.Coder object to be used on windows.
   744      """
   745      super().__init__()
   746      if window_coder is None:
   747        raise ValueError('window_coder should not be None')
   748      self._window_coder = window_coder
   749  
   750    def assign(self, assign_context):
   751      if assign_context.window is None:
   752        raise ValueError(
   753            'assign_context.window should not be None. '
   754            'This might be due to a DoFn returning a TimestampedValue.')
   755      return [assign_context.window]
   756  
   757    def get_window_coder(self):
   758      return self._window_coder
   759  
   760  
   761  @typehints.with_input_types(Tuple[K, V])
   762  @typehints.with_output_types(Tuple[K, V])
   763  class ReshufflePerKey(PTransform):
   764    """PTransform that returns a PCollection equivalent to its input,
   765    but operationally provides some of the side effects of a GroupByKey,
   766    in particular checkpointing, and preventing fusion of the surrounding
   767    transforms.
   768    """
   769    def expand(self, pcoll):
   770      windowing_saved = pcoll.windowing
   771      if windowing_saved.is_default():
   772        # In this (common) case we can use a trivial trigger driver
   773        # and avoid the (expensive) window param.
   774        globally_windowed = window.GlobalWindows.windowed_value(None)
   775        MIN_TIMESTAMP = window.MIN_TIMESTAMP
   776  
   777        def reify_timestamps(element, timestamp=DoFn.TimestampParam):
   778          key, value = element
   779          if timestamp == MIN_TIMESTAMP:
   780            timestamp = None
   781          return key, (value, timestamp)
   782  
   783        def restore_timestamps(element):
   784          key, values = element
   785          return [
   786              globally_windowed.with_value((key, value)) if timestamp is None else
   787              window.GlobalWindows.windowed_value((key, value), timestamp)
   788              for (value, timestamp) in values
   789          ]
   790      else:
   791  
   792        # typing: All conditional function variants must have identical signatures
   793        def reify_timestamps(  # type: ignore[misc]
   794            element, timestamp=DoFn.TimestampParam, window=DoFn.WindowParam):
   795          key, value = element
   796          # Transport the window as part of the value and restore it later.
   797          return key, windowed_value.WindowedValue(value, timestamp, [window])
   798  
   799        def restore_timestamps(element):
   800          key, windowed_values = element
   801          return [wv.with_value((key, wv.value)) for wv in windowed_values]
   802  
   803      ungrouped = pcoll | Map(reify_timestamps).with_output_types(Any)
   804  
   805      # TODO(https://github.com/apache/beam/issues/19785) Using global window as
   806      # one of the standard window. This is to mitigate the Dataflow Java Runner
   807      # Harness limitation to accept only standard coders.
   808      ungrouped._windowing = Windowing(
   809          window.GlobalWindows(),
   810          triggerfn=Always(),
   811          accumulation_mode=AccumulationMode.DISCARDING,
   812          timestamp_combiner=TimestampCombiner.OUTPUT_AT_EARLIEST)
   813      result = (
   814          ungrouped
   815          | GroupByKey()
   816          | FlatMap(restore_timestamps).with_output_types(Any))
   817      result._windowing = windowing_saved
   818      return result
   819  
   820  
   821  @typehints.with_input_types(T)
   822  @typehints.with_output_types(T)
   823  class Reshuffle(PTransform):
   824    """PTransform that returns a PCollection equivalent to its input,
   825    but operationally provides some of the side effects of a GroupByKey,
   826    in particular checkpointing, and preventing fusion of the surrounding
   827    transforms.
   828  
   829    Reshuffle adds a temporary random key to each element, performs a
   830    ReshufflePerKey, and finally removes the temporary key.
   831    """
   832  
   833    # We use the full 32-bit integer range as the default number of buckets.
   834    _DEFAULT_NUM_BUCKETS = 1 << 32
   835  
   836    def __init__(self, num_buckets=None):
   837      """
   838      :param num_buckets: If set, specifies the maximum number of random keys
   839        that will be generated.
   840      """
   841      self.num_buckets = num_buckets if num_buckets else self._DEFAULT_NUM_BUCKETS
   842  
   843      valid_buckets = isinstance(num_buckets, int) and num_buckets > 0
   844      if not (num_buckets is None or valid_buckets):
   845        raise ValueError(
   846            'If `num_buckets` is set, it has to be an '
   847            'integer greater than 0, got %s' % num_buckets)
   848  
   849    def expand(self, pcoll):
   850      # type: (pvalue.PValue) -> pvalue.PCollection
   851      return (
   852          pcoll | 'AddRandomKeys' >>
   853          Map(lambda t: (random.randrange(0, self.num_buckets), t)
   854              ).with_input_types(T).with_output_types(Tuple[int, T])
   855          | ReshufflePerKey()
   856          | 'RemoveRandomKeys' >> Map(lambda t: t[1]).with_input_types(
   857              Tuple[int, T]).with_output_types(T))
   858  
   859    def to_runner_api_parameter(self, unused_context):
   860      # type: (PipelineContext) -> Tuple[str, None]
   861      return common_urns.composites.RESHUFFLE.urn, None
   862  
   863    @staticmethod
   864    @PTransform.register_urn(common_urns.composites.RESHUFFLE.urn, None)
   865    def from_runner_api_parameter(
   866        unused_ptransform, unused_parameter, unused_context):
   867      return Reshuffle()
   868  
   869  
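        # A common usage sketch for Reshuffle, assuming `import apache_beam as
        # beam`: break fusion after a fan-out so the expensive step can be
        # parallelized independently. The helper and step names are hypothetical.
        def _example_reshuffle():
          import apache_beam as beam

          with beam.Pipeline() as p:
            (p
             | beam.Create(['a', 'b'])
             | 'FanOut' >> beam.FlatMap(lambda x: [x] * 1000)
             | beam.Reshuffle()  # checkpoint and prevent fusion here
             | 'ExpensiveStep' >> beam.Map(str.upper))

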
   870  def fn_takes_side_inputs(fn):
   871    fn = getattr(fn, '_argspec_fn', fn)
   872    try:
   873      signature = get_signature(fn)
   874    except TypeError:
   875      # We can't tell; maybe it does.
   876      return True
   877  
   878    return (
   879        len(signature.parameters) > 1 or any(
   880            p.kind == p.VAR_POSITIONAL or p.kind == p.VAR_KEYWORD
   881            for p in signature.parameters.values()))
   882  
   883  
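        # A small sketch of what fn_takes_side_inputs reports; the helper name
        # is hypothetical. It is True whenever the callable could accept
        # arguments beyond the element itself.
        def _example_fn_takes_side_inputs():
          assert not fn_takes_side_inputs(lambda x: x)        # one parameter
          assert fn_takes_side_inputs(lambda x, y: x + y)     # extra positional
          assert fn_takes_side_inputs(lambda x, *args: x)     # *args
          assert fn_takes_side_inputs(lambda x, **kwargs: x)  # **kwargs

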
   884  @ptransform_fn
   885  def WithKeys(pcoll, k, *args, **kwargs):
   886    """PTransform that takes a PCollection, and either a constant key or a
   887    callable, and returns a PCollection of (K, V), where each of the values in
   888    the input PCollection has been paired with either the constant key or a key
   889    computed from the value.  The callable may optionally accept positional or
   890    keyword arguments, which should be passed to WithKeys directly.  These may
   891    be either SideInputs or static (non-PCollection) values, such as ints.
   892    """
   893    if callable(k):
   894      if fn_takes_side_inputs(k):
   895        if all(isinstance(arg, AsSideInput)
   896               for arg in args) and all(isinstance(kwarg, AsSideInput)
   897                                        for kwarg in kwargs.values()):
   898          return pcoll | Map(
   899              lambda v,
   900              *args,
   901              **kwargs: (k(v, *args, **kwargs), v),
   902              *args,
   903              **kwargs)
   904        return pcoll | Map(lambda v: (k(v, *args, **kwargs), v))
   905      return pcoll | Map(lambda v: (k(v), v))
   906    return pcoll | Map(lambda v: (k, v))
   907  
   908  
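        # A usage sketch for WithKeys, assuming `import apache_beam as beam`:
        # a constant key, a computed key, and a key computed with an extra
        # static argument. The helper name and data are hypothetical.
        def _example_with_keys():
          import apache_beam as beam

          with beam.Pipeline() as p:
            words = p | beam.Create(['apple', 'kiwi', 'banana'])
            words | 'ConstKey' >> beam.WithKeys('fruit')   # ('fruit', 'apple'), ...
            words | 'LenKey' >> beam.WithKeys(len)         # (5, 'apple'), ...
            words | 'PrefixKey' >> beam.WithKeys(
                lambda w, n: w[:n], 2)                     # ('ap', 'apple'), ...

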
   909  @typehints.with_input_types(Tuple[K, V])
   910  @typehints.with_output_types(Tuple[K, Iterable[V]])
   911  class GroupIntoBatches(PTransform):
   912    """PTransform that batches the input into desired batch size. Elements are
   913    buffered until they are equal to batch size provided in the argument at which
   914    point they are output to the output Pcollection.
   915  
   916    Windows are preserved (batches will contain elements from the same window)
   917    """
   918    def __init__(
   919        self, batch_size, max_buffering_duration_secs=None, clock=time.time):
   920      """Create a new GroupIntoBatches.
   921  
   922      Arguments:
   923        batch_size: (required) How many elements should be in a batch
   924        max_buffering_duration_secs: (optional) How long in seconds at most an
   925          incomplete batch of elements is allowed to be buffered in the states.
   926          The duration must be a positive second duration and should be given as
   927          an int or float. Setting this parameter to zero effectively means no
   928          buffering limit.
   929        clock: (optional) an alternative to time.time (mostly for testing)
   930      """
   931      self.params = _GroupIntoBatchesParams(
   932          batch_size, max_buffering_duration_secs)
   933      self.clock = clock
   934  
   935    def expand(self, pcoll):
   936      input_coder = coders.registry.get_coder(pcoll)
   937      return pcoll | ParDo(
   938          _pardo_group_into_batches(
   939              input_coder,
   940              self.params.batch_size,
   941              self.params.max_buffering_duration_secs,
   942              self.clock))
   943  
   944    def to_runner_api_parameter(
   945        self,
   946        unused_context  # type: PipelineContext
   947    ):  # type: (...) -> Tuple[str, beam_runner_api_pb2.GroupIntoBatchesPayload]
   948      return (
   949          common_urns.group_into_batches_components.GROUP_INTO_BATCHES.urn,
   950          self.params.get_payload())
   951  
   952    @staticmethod
   953    @PTransform.register_urn(
   954        common_urns.group_into_batches_components.GROUP_INTO_BATCHES.urn,
   955        beam_runner_api_pb2.GroupIntoBatchesPayload)
   956    def from_runner_api_parameter(unused_ptransform, proto, unused_context):
   957      return GroupIntoBatches(*_GroupIntoBatchesParams.parse_payload(proto))
   958  
   959    @typehints.with_input_types(Tuple[K, V])
   960    @typehints.with_output_types(
   961        typehints.Tuple[
   962            ShardedKeyType[typehints.TypeVariable(K)],  # type: ignore[misc]
   963            typehints.Iterable[typehints.TypeVariable(V)]])
   964    class WithShardedKey(PTransform):
   965      """A GroupIntoBatches transform that outputs batched elements associated
   966      with sharded input keys.
   967  
   968      By default, keys are sharded such that input elements with the same key
   969      are spread across all available threads executing the transform. Runners
   970      may override the default sharding to provide better load balancing during
   971      execution.
   972      """
   973      def __init__(
   974          self, batch_size, max_buffering_duration_secs=None, clock=time.time):
   975        """Create a new GroupIntoBatches with sharded output.
   976        See ``GroupIntoBatches`` transform for a description of input parameters.
   977        """
   978        self.params = _GroupIntoBatchesParams(
   979            batch_size, max_buffering_duration_secs)
   980        self.clock = clock
   981  
   982      _shard_id_prefix = uuid.uuid4().bytes
   983  
   984      def expand(self, pcoll):
   985        key_type, value_type = pcoll.element_type.tuple_types
   986        sharded_pcoll = pcoll | Map(
   987            lambda key_value: (
   988                ShardedKey(
   989                    key_value[0],
   990                    # Use [uuid, thread id] as the shard id.
   991                    GroupIntoBatches.WithShardedKey._shard_id_prefix + bytes(
   992                        threading.get_ident().to_bytes(8, 'big'))),
   993                key_value[1])).with_output_types(
   994                    typehints.Tuple[
   995                        ShardedKeyType[key_type],  # type: ignore[misc]
   996                        value_type])
   997        return (
   998            sharded_pcoll
   999            | GroupIntoBatches(
  1000                self.params.batch_size,
  1001                self.params.max_buffering_duration_secs,
  1002                self.clock))
  1003  
  1004      def to_runner_api_parameter(
  1005          self,
  1006          unused_context  # type: PipelineContext
  1007      ):  # type: (...) -> Tuple[str, beam_runner_api_pb2.GroupIntoBatchesPayload]
  1008        return (
  1009            common_urns.composites.GROUP_INTO_BATCHES_WITH_SHARDED_KEY.urn,
  1010            self.params.get_payload())
  1011  
  1012      @staticmethod
  1013      @PTransform.register_urn(
  1014          common_urns.composites.GROUP_INTO_BATCHES_WITH_SHARDED_KEY.urn,
  1015          beam_runner_api_pb2.GroupIntoBatchesPayload)
  1016      def from_runner_api_parameter(unused_ptransform, proto, unused_context):
  1017        return GroupIntoBatches.WithShardedKey(
  1018            *_GroupIntoBatchesParams.parse_payload(proto))
  1019  
  1020  
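        # A usage sketch for GroupIntoBatches, assuming `import apache_beam as
        # beam`. The helper name, key, and constants are hypothetical.
        def _example_group_into_batches():
          import apache_beam as beam

          with beam.Pipeline() as p:
            (p
             | beam.Create([('k', i) for i in range(7)])
             | beam.GroupIntoBatches(batch_size=3, max_buffering_duration_secs=10)
             | beam.Map(print))  # e.g. ('k', [0, 1, 2]), ('k', [3, 4, 5]), ('k', [6])
          # To let runners rebalance hot keys, the sharded variant can be used
          # instead:  ... | beam.GroupIntoBatches.WithShardedKey(batch_size=3) | ...

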
  1021  class _GroupIntoBatchesParams:
  1022    """This class represents the parameters for
  1023    :class:`apache_beam.utils.GroupIntoBatches` transform, used to define how
  1024    elements should be batched.
  1025    """
  1026    def __init__(self, batch_size, max_buffering_duration_secs):
  1027      self.batch_size = batch_size
  1028      self.max_buffering_duration_secs = (
  1029          0
  1030          if max_buffering_duration_secs is None else max_buffering_duration_secs)
  1031      self._validate()
  1032  
  1033    def __eq__(self, other):
  1034      if other is None or not isinstance(other, _GroupIntoBatchesParams):
  1035        return False
  1036      return (
  1037          self.batch_size == other.batch_size and
  1038          self.max_buffering_duration_secs == other.max_buffering_duration_secs)
  1039  
  1040    def _validate(self):
  1041      assert self.batch_size is not None and self.batch_size > 0, (
  1042          'batch_size must be a positive value')
  1043      assert (
  1044          self.max_buffering_duration_secs is not None and
  1045          self.max_buffering_duration_secs >= 0), (
  1046              'max_buffering_duration must be a non-negative value')
  1047  
  1048    def get_payload(self):
  1049      return beam_runner_api_pb2.GroupIntoBatchesPayload(
  1050          batch_size=self.batch_size,
  1051          max_buffering_duration_millis=int(
  1052              self.max_buffering_duration_secs * 1000))
  1053  
  1054    @staticmethod
  1055    def parse_payload(
  1056        proto  # type: beam_runner_api_pb2.GroupIntoBatchesPayload
  1057    ):
  1058      return proto.batch_size, proto.max_buffering_duration_millis / 1000
  1059  
  1060  
  1061  def _pardo_group_into_batches(
  1062      input_coder, batch_size, max_buffering_duration_secs, clock=time.time):
  1063    ELEMENT_STATE = BagStateSpec('values', input_coder)
  1064    COUNT_STATE = CombiningValueStateSpec('count', input_coder, CountCombineFn())
  1065    WINDOW_TIMER = TimerSpec('window_end', TimeDomain.WATERMARK)
  1066    BUFFERING_TIMER = TimerSpec('buffering_end', TimeDomain.REAL_TIME)
  1067  
  1068    class _GroupIntoBatchesDoFn(DoFn):
  1069      def process(
  1070          self,
  1071          element,
  1072          window=DoFn.WindowParam,
  1073          element_state=DoFn.StateParam(ELEMENT_STATE),
  1074          count_state=DoFn.StateParam(COUNT_STATE),
  1075          window_timer=DoFn.TimerParam(WINDOW_TIMER),
  1076          buffering_timer=DoFn.TimerParam(BUFFERING_TIMER)):
  1077        # Allowed lateness not supported in Python SDK
  1078        # https://beam.apache.org/documentation/programming-guide/#watermarks-and-late-data
  1079        window_timer.set(window.end)
  1080        element_state.add(element)
  1081        count_state.add(1)
  1082        count = count_state.read()
  1083        if count == 1 and max_buffering_duration_secs > 0:
  1084          # This is the first element in batch. Start counting buffering time if a
  1085          # limit was set.
  1086          # pylint: disable=deprecated-method
  1087          buffering_timer.set(clock() + max_buffering_duration_secs)
  1088        if count >= batch_size:
  1089          return self.flush_batch(element_state, count_state, buffering_timer)
  1090  
  1091      @on_timer(WINDOW_TIMER)
  1092      def on_window_timer(
  1093          self,
  1094          element_state=DoFn.StateParam(ELEMENT_STATE),
  1095          count_state=DoFn.StateParam(COUNT_STATE),
  1096          buffering_timer=DoFn.TimerParam(BUFFERING_TIMER)):
  1097        return self.flush_batch(element_state, count_state, buffering_timer)
  1098  
  1099      @on_timer(BUFFERING_TIMER)
  1100      def on_buffering_timer(
  1101          self,
  1102          element_state=DoFn.StateParam(ELEMENT_STATE),
  1103          count_state=DoFn.StateParam(COUNT_STATE),
  1104          buffering_timer=DoFn.TimerParam(BUFFERING_TIMER)):
  1105        return self.flush_batch(element_state, count_state, buffering_timer)
  1106  
  1107      def flush_batch(self, element_state, count_state, buffering_timer):
  1108        batch = [element for element in element_state.read()]
  1109        if not batch:
  1110          return
  1111        key, _ = batch[0]
  1112        batch_values = [v for (k, v) in batch]
  1113        element_state.clear()
  1114        count_state.clear()
  1115        buffering_timer.clear()
  1116        yield key, batch_values
  1117  
  1118    return _GroupIntoBatchesDoFn()
  1119  
  1120  
  1121  class ToString(object):
  1122    """
  1123    PTransforms for converting a PCollection element, KV, or iterable to a
  1124    string.
  1125    """
  1126  
  1127    # pylint: disable=invalid-name
  1128    @staticmethod
  1129    def Element():
  1130      """
  1131      Transforms each element of the PCollection to a string.
  1132      """
  1133      return 'ElementToString' >> Map(str)
  1134  
  1135    @staticmethod
  1136    def Iterables(delimiter=None):
  1137      """
  1138      Transforms each iterable in the input PCollection to a string, joining
  1139      its items with the delimiter. There is no trailing delimiter.
  1140      """
  1141      if delimiter is None:
  1142        delimiter = ','
  1143      return (
  1144          'IterablesToString' >>
  1145          Map(lambda xs: delimiter.join(str(x) for x in xs)).with_input_types(
  1146              Iterable[Any]).with_output_types(str))
  1147  
  1148    # An alias for Iterables.
  1149    Kvs = Iterables
  1150  
  1151  
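        # A usage sketch for ToString, assuming `import apache_beam as beam`;
        # the helper name and data are hypothetical.
        def _example_to_string():
          import apache_beam as beam

          with beam.Pipeline() as p:
            nums = p | 'Nums' >> beam.Create([1, 2, 3])
            nums | beam.ToString.Element()       # '1', '2', '3'

            kvs = p | 'Kvs' >> beam.Create([('a', 1), ('b', 2)])
            kvs | beam.ToString.Kvs('=')         # 'a=1', 'b=2'

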
  1152  @typehints.with_input_types(T)
  1153  @typehints.with_output_types(T)
  1154  class LogElements(PTransform):
  1155    """
  1156    PTransform for printing the elements of a PCollection.
  1157    """
  1158    class _LoggingFn(DoFn):
  1159      def __init__(self, prefix='', with_timestamp=False, with_window=False):
  1160        super().__init__()
  1161        self.prefix = prefix
  1162        self.with_timestamp = with_timestamp
  1163        self.with_window = with_window
  1164  
  1165      def process(
  1166          self,
  1167          element,
  1168          timestamp=DoFn.TimestampParam,
  1169          window=DoFn.WindowParam,
  1170          **kwargs):
  1171        log_line = self.prefix + str(element)
  1172  
  1173        if self.with_timestamp:
  1174          log_line += ', timestamp=' + repr(timestamp.to_rfc3339())
  1175  
  1176        if self.with_window:
  1177          log_line += ', window(start=' + window.start.to_rfc3339()
  1178          log_line += ', end=' + window.end.to_rfc3339() + ')'
  1179  
  1180        print(log_line)
  1181        yield element
  1182  
  1183    def __init__(
  1184        self, label=None, prefix='', with_timestamp=False, with_window=False):
  1185      super().__init__(label)
  1186      self.prefix = prefix
  1187      self.with_timestamp = with_timestamp
  1188      self.with_window = with_window
  1189  
  1190    def expand(self, input):
  1191      return input | ParDo(
  1192          self._LoggingFn(self.prefix, self.with_timestamp, self.with_window))
  1193  
  1194  
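        # A usage sketch for LogElements, assuming `import apache_beam as beam`;
        # the helper name and data are hypothetical. Elements are printed and
        # passed through unchanged.
        def _example_log_elements():
          import apache_beam as beam

          with beam.Pipeline() as p:
            (p
             | beam.Create(['a', 'b'])
             | beam.LogElements(prefix='saw: ', with_timestamp=True)
             | 'Continue' >> beam.Map(str.upper))

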
  1195  class Reify(object):
  1196    """PTransforms for converting between explicit and implicit form of various
  1197    Beam values."""
  1198    @typehints.with_input_types(T)
  1199    @typehints.with_output_types(T)
  1200    class Timestamp(PTransform):
  1201      """PTransform to wrap a value in a TimestampedValue with it's
  1202      associated timestamp."""
  1203      @staticmethod
  1204      def add_timestamp_info(element, timestamp=DoFn.TimestampParam):
  1205        yield TimestampedValue(element, timestamp)
  1206  
  1207      def expand(self, pcoll):
  1208        return pcoll | ParDo(self.add_timestamp_info)
  1209  
  1210    @typehints.with_input_types(T)
  1211    @typehints.with_output_types(T)
  1212    class Window(PTransform):
  1213      """PTransform to convert an element in a PCollection into a tuple of
  1214      (element, timestamp, window), wrapped in a TimestampedValue with its
  1215      associated timestamp."""
  1216      @staticmethod
  1217      def add_window_info(
  1218          element, timestamp=DoFn.TimestampParam, window=DoFn.WindowParam):
  1219        yield TimestampedValue((element, timestamp, window), timestamp)
  1220  
  1221      def expand(self, pcoll):
  1222        return pcoll | ParDo(self.add_window_info)
  1223  
  1224    @typehints.with_input_types(Tuple[K, V])
  1225    @typehints.with_output_types(Tuple[K, V])
  1226    class TimestampInValue(PTransform):
  1227      """PTransform to wrap the Value in a KV pair in a TimestampedValue with
  1228      the element's associated timestamp."""
  1229      @staticmethod
  1230      def add_timestamp_info(element, timestamp=DoFn.TimestampParam):
  1231        key, value = element
  1232        yield (key, TimestampedValue(value, timestamp))
  1233  
  1234      def expand(self, pcoll):
  1235        return pcoll | ParDo(self.add_timestamp_info)
  1236  
  1237    @typehints.with_input_types(Tuple[K, V])
  1238    @typehints.with_output_types(Tuple[K, V])
  1239    class WindowInValue(PTransform):
  1240      """PTransform to convert the Value in a KV pair into a tuple of
  1241      (value, timestamp, window), with the whole element being wrapped inside a
  1242      TimestampedValue."""
  1243      @staticmethod
  1244      def add_window_info(
  1245          element, timestamp=DoFn.TimestampParam, window=DoFn.WindowParam):
  1246        key, value = element
  1247        yield TimestampedValue((key, (value, timestamp, window)), timestamp)
  1248  
  1249      def expand(self, pcoll):
  1250        return pcoll | ParDo(self.add_window_info)
  1251  
  1252  
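        # A usage sketch for Reify.Window, assuming `import apache_beam as
        # beam`: it turns each element into an (element, timestamp, window)
        # tuple. The helper name and data are hypothetical.
        def _example_reify_window():
          import apache_beam as beam
          from apache_beam.transforms.window import TimestampedValue

          with beam.Pipeline() as p:
            (p
             | beam.Create(['a', 'b'])
             | beam.Map(lambda x: TimestampedValue(x, 1234567890))
             | beam.Reify.Window()
             | beam.Map(print))  # e.g. ('a', Timestamp(1234567890), GlobalWindow)

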
  1253  class Regex(object):
  1254    """
  1255    PTransforms that use regular expressions to process the elements of a
  1256    PCollection.
  1257    """
  1258  
  1259    ALL = "__regex_all_groups"
  1260  
  1261    @staticmethod
  1262    def _regex_compile(regex):
  1263      """Return re.compile if the regex has a string value"""
  1264      if isinstance(regex, str):
  1265        regex = re.compile(regex)
  1266      return regex
  1267  
  1268    @staticmethod
  1269    @typehints.with_input_types(str)
  1270    @typehints.with_output_types(str)
  1271    @ptransform_fn
  1272    def matches(pcoll, regex, group=0):
  1273      """
  1274      Returns the match (group 0 by default) if zero or more characters at the
  1275      beginning of the string match the regular expression. To match the entire
  1276      string, add a "$" at the end of the regex.
  1277  
  1278      The group can be an integer (group number) or a string (group name).
  1279  
  1280      Args:
  1281        regex: the regular expression string or compiled pattern.
  1282        group: (optional) number or name of the group; it can be an integer or
  1283          a string. Defaults to 0, meaning the entire matched string is
  1284          returned.
  1285      """
  1286      regex = Regex._regex_compile(regex)
  1287  
  1288      def _process(element):
  1289        m = regex.match(element)
  1290        if m:
  1291          yield m.group(group)
  1292  
  1293      return pcoll | FlatMap(_process)
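  # A usage sketch in comments only (not executed as part of this module); the
  # sample inputs and expected outputs are illustrative assumptions:
  #
  #   import apache_beam as beam
  #   with beam.Pipeline() as p:
  #     _ = (p
  #          | beam.Create(['abc123', '123abc'])
  #          | beam.Regex.matches(r'[a-z]+')
  #          | beam.Map(print))
  #   # Prints 'abc'; '123abc' is dropped because it does not match at the
  #   # start of the string.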
  1294  
  1295    @staticmethod
  1296    @typehints.with_input_types(str)
  1297    @typehints.with_output_types(List[str])
  1298    @ptransform_fn
  1299    def all_matches(pcoll, regex):
  1300      """
  1301      Returns all match groups (group 0 through the last group) if zero or more
  1302      characters at the beginning of the string match the regular expression.
  1303  
  1304      Args:
  1305        regex: the regular expression string or compiled pattern.
  1306      """
  1307      regex = Regex._regex_compile(regex)
  1308  
  1309      def _process(element):
  1310        m = regex.match(element)
  1311        if m:
  1312          yield [m.group(ix) for ix in range(m.lastindex + 1)]
  1313  
  1314      return pcoll | FlatMap(_process)
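  # Illustrative sketch in comments only; `beam` and `p` are assumed as in the
  # `matches` sketch above, and the pattern is assumed to have at least one
  # capturing group (m.lastindex is None otherwise):
  #
  #   _ = p | beam.Create(['cat 7']) | beam.Regex.all_matches(r'(\w+) (\d+)')
  #   # -> ['cat 7', 'cat', '7']   (group 0 followed by each captured group)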
  1315  
  1316    @staticmethod
  1317    @typehints.with_input_types(str)
  1318    @typehints.with_output_types(Tuple[str, str])
  1319    @ptransform_fn
  1320    def matches_kv(pcoll, regex, keyGroup, valueGroup=0):
  1321      """
  1322      Returns a KV pair if the string matches the regular expression, deriving
  1323      the key and value from the specified groups of the regular expression.
  1324  
  1325      Args:
  1326        regex: the regular expression string or compiled pattern.
  1327        keyGroup: the regex group to use as the key. Can be int or str.
  1328        valueGroup: (optional) the regex group to use as the value. Can be int
  1329          or str. The default value 0 returns the entire matched string.
  1330      """
  1331      regex = Regex._regex_compile(regex)
  1332  
  1333      def _process(element):
  1334        match = regex.match(element)
  1335        if match:
  1336          yield (match.group(keyGroup), match.group(valueGroup))
  1337  
  1338      return pcoll | FlatMap(_process)
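  # Illustrative sketch in comments only; `beam`, `p`, and the sample input
  # are assumptions as in the `matches` sketch above:
  #
  #   _ = (p | beam.Create(['k1=v1'])
  #          | beam.Regex.matches_kv(r'(\w+)=(\w+)', 1, 2))
  #   # -> ('k1', 'v1')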
  1339  
  1340    @staticmethod
  1341    @typehints.with_input_types(str)
  1342    @typehints.with_output_types(str)
  1343    @ptransform_fn
  1344    def find(pcoll, regex, group=0):
  1345      """
  1346      Returns the first match if a portion of the line matches the regex,
  1347      returning the entire match (group 0) by default. The group can be an
  1348      integer (group number) or a string (group name).
  1349  
  1350      Args:
  1351        regex: the regular expression string or compiled pattern.
  1352        group: (optional) number or name of the group; can be an int or a str.
  1353      """
  1354      regex = Regex._regex_compile(regex)
  1355  
  1356      def _process(element):
  1357        r = regex.search(element)
  1358        if r:
  1359          yield r.group(group)
  1360  
  1361      return pcoll | FlatMap(_process)
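  # Illustrative sketch in comments only; `beam`, `p`, and the sample input
  # are assumptions as in the `matches` sketch above:
  #
  #   _ = p | beam.Create(['the answer is 42']) | beam.Regex.find(r'\d+')
  #   # -> '42'   (searches anywhere in the string, unlike matches)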
  1362  
  1363    @staticmethod
  1364    @typehints.with_input_types(str)
  1365    @typehints.with_output_types(Union[List[str], List[Tuple[str, str]]])
  1366    @ptransform_fn
  1367    def find_all(pcoll, regex, group=0, outputEmpty=True):
  1368      """
  1369      Returns a list of matches per element if portions of the line match the
  1370      regex. By default the list holds group 0 of every match, including empty
  1371      matches. Pass the `Regex.ALL` flag as the `group` parameter to instead get
  1372      (entire match, first group) tuples for every match.
  1373  
  1374      Args:
  1375        regex: the regular expression string or compiled pattern.
  1376        group: (optional) number or name of the group; can be an int or a str.
  1377        outputEmpty: (optional) whether to include empty matches in the output.
  1378          Defaults to True.
  1379      """
  1380      regex = Regex._regex_compile(regex)
  1381  
  1382      def _process(element):
  1383        matches = regex.finditer(element)
  1384        if group == Regex.ALL:
  1385          yield [(m.group(), m.groups()[0]) for m in matches
  1386                 if outputEmpty or m.groups()[0]]
  1387        else:
  1388          yield [m.group(group) for m in matches if outputEmpty or m.group(group)]
  1389  
  1390      return pcoll | FlatMap(_process)
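  # Illustrative sketch in comments only; `beam`, `p`, and the sample input
  # are assumptions as in the `matches` sketch above:
  #
  #   _ = p | beam.Create(['1 apple, 2 pears']) | beam.Regex.find_all(r'\d+')
  #   # -> ['1', '2']   (one list of group-0 matches per input element)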
  1391  
  1392    @staticmethod
  1393    @typehints.with_input_types(str)
  1394    @typehints.with_output_types(Tuple[str, str])
  1395    @ptransform_fn
  1396    def find_kv(pcoll, regex, keyGroup, valueGroup=0):
  1397      """
  1398      Returns a KV pair for every match if portions of the line match the regex,
  1399      using the specified groups as the key and value.
  1400  
  1401      Args:
  1402        regex: the regular expression string or compiled pattern.
  1403        keyGroup: the regex group to use as the key. Can be int or str.
  1404        valueGroup: (optional) the regex group to use as the value. Can be int
  1405          or str. The default value 0 returns the entire matched string.
  1406      """
  1407      regex = Regex._regex_compile(regex)
  1408  
  1409      def _process(element):
  1410        # re.finditer always returns an iterator (possibly empty), so no
  1411        # truthiness check is needed; iterate over the matches directly.
  1412        for match in regex.finditer(element):
  1413          yield (match.group(keyGroup), match.group(valueGroup))
  1414  
  1415      return pcoll | FlatMap(_process)
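  # Illustrative sketch in comments only; `beam`, `p`, and the sample input
  # are assumptions as in the `matches` sketch above:
  #
  #   _ = (p | beam.Create(['a=1 b=2'])
  #          | beam.Regex.find_kv(r'(\w+)=(\d+)', 1, 2))
  #   # -> ('a', '1'), ('b', '2')   (one KV pair per match in the element)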
  1416  
  1417    @staticmethod
  1418    @typehints.with_input_types(str)
  1419    @typehints.with_output_types(str)
  1420    @ptransform_fn
  1421    def replace_all(pcoll, regex, replacement):
  1422      """
  1423      Returns the line with every match of the regex replaced by the replacement
  1424      string.
  1425  
  1426      Args:
  1427        regex: the regular expression string or compiled pattern.
  1428        replacement: the string to be substituted for each match.
  1429      """
  1430      regex = Regex._regex_compile(regex)
  1431      return pcoll | Map(lambda elem: regex.sub(replacement, elem))
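  # Illustrative sketch in comments only; `beam`, `p`, and the sample input
  # are assumptions as in the `matches` sketch above:
  #
  #   _ = p | beam.Create(['a-b-c']) | beam.Regex.replace_all(r'-', '.')
  #   # -> 'a.b.c'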
  1432  
  1433    @staticmethod
  1434    @typehints.with_input_types(str)
  1435    @typehints.with_output_types(str)
  1436    @ptransform_fn
  1437    def replace_first(pcoll, regex, replacement):
  1438      """
  1439      Returns the line with the first match of the regex replaced by the
  1440      replacement string.
  1441  
  1442      Args:
  1443        regex: the regular expression string or compiled pattern.
  1444        replacement: the string to be substituted for the first match.
  1445      """
  1446      regex = Regex._regex_compile(regex)
  1447      return pcoll | Map(lambda elem: regex.sub(replacement, elem, 1))
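  # Illustrative sketch in comments only; `beam`, `p`, and the sample input
  # are assumptions as in the `matches` sketch above:
  #
  #   _ = p | beam.Create(['a-b-c']) | beam.Regex.replace_first(r'-', '.')
  #   # -> 'a.b-c'   (only the first match is replaced)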
  1448  
  1449    @staticmethod
  1450    @typehints.with_input_types(str)
  1451    @typehints.with_output_types(List[str])
  1452    @ptransform_fn
  1453    def split(pcoll, regex, outputEmpty=False):
  1454      """
  1455      Returns the list of strings produced by splitting the element on the
  1456      regular expression. Empty items are not output by default.
  1457  
  1458      Args:
  1459        regex: the regular expression string or compiled pattern.
  1460        outputEmpty: (optional) whether to output empty items. Defaults to
  1461            False.
  1462      """
  1463      regex = Regex._regex_compile(regex)
  1464      outputEmpty = bool(outputEmpty)
  1465  
  1466      def _process(element):
  1467        r = regex.split(element)
  1468        if r and not outputEmpty:
  1469          r = list(filter(None, r))
  1470        yield r
  1471  
  1472      return pcoll | FlatMap(_process)
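  # Illustrative sketch in comments only; `beam`, `p`, and the sample input
  # are assumptions as in the `matches` sketch above:
  #
  #   _ = p | beam.Create(['a,,b,c']) | beam.Regex.split(r',')
  #   # -> ['a', 'b', 'c']   (pass outputEmpty=True to keep the empty item)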