github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/testing/synthetic_pipeline.py

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """A set of utilities to write pipelines for performance tests.
    19  
    20  This module offers a way to create pipelines using synthetic sources and steps.
     21  The exact shape of the pipeline and the behavior of sources and steps can be
     22  controlled through arguments. Please see the function 'parse_args()' for more
     23  details about the arguments.
     24  
     25  The shape of the pipeline is primarily controlled through two arguments. The
     26  argument 'steps' defines a list of steps as a JSON string, and the argument
     27  'barrier' describes how these steps are separated from each other. The
     28  'barrier' argument can be used to build a pipeline as a series of steps or as
     29  a tree of steps with a fan-in or a fan-out of size 2.
    30  
    31  Other arguments describe what gets generated by synthetic sources that produce
    32  data for the pipeline.
    33  """
    34  
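# Illustrative example (not part of the original module): the pipeline defined
# here can be driven either from the command line or by calling run() with an
# explicit argument list. All values below are arbitrary examples; see
# parse_args() for the full set of accepted flags.
#
#   run([
#       '--steps=[{"per_bundle_delay_sec": 1, '
#       '"output_records_per_input_record": 2}]',
#       '--input={"numRecords": 100, "keySizeBytes": 10, "valueSizeBytes": 90}',
#       '--barrier=shuffle',
#   ])
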
    35  # pytype: skip-file
    36  
    37  import argparse
    38  import json
    39  import logging
    40  import math
    41  import os
    42  import sys
    43  import time
    44  from random import Random
    45  from typing import Tuple
    46  
    47  import apache_beam as beam
    48  from apache_beam import pvalue
    49  from apache_beam import typehints
    50  from apache_beam.io import WriteToText
    51  from apache_beam.io import iobase
    52  from apache_beam.io import range_trackers
    53  from apache_beam.io import restriction_trackers
    54  from apache_beam.io.restriction_trackers import OffsetRange
    55  from apache_beam.io.restriction_trackers import OffsetRestrictionTracker
    56  from apache_beam.options.pipeline_options import PipelineOptions
    57  from apache_beam.options.pipeline_options import SetupOptions
    58  from apache_beam.testing.test_pipeline import TestPipeline
    59  from apache_beam.transforms import userstate
    60  from apache_beam.transforms.core import RestrictionProvider
    61  
    62  try:
    63    import numpy as np
    64  except ImportError:
    65    np = None
    66  
    67  
    68  class _Random(Random):
    69    """A subclass of `random.Random` from the Python Standard Library that
    70    provides a method returning random bytes of arbitrary length.
    71    """
    72  
     73    # `numpy.random.RandomState` does not provide a `random()` method, so we
     74    # keep this alias for compatibility reasons.
    75    random_sample = Random.random
    76  
    77    def bytes(self, length):
    78      """Returns random bytes.
    79  
    80      Args:
    81        length (int): Number of random bytes.
    82      """
    83      return self.getrandbits(length * 8).to_bytes(length, sys.byteorder)
    84  
    85  
    86  Generator = _Random
    87  
    88  
    89  def parse_byte_size(s):
    90    suffixes = 'BKMGTP'
    91    if s[-1] in suffixes:
    92      return int(float(s[:-1]) * 1024**suffixes.index(s[-1]))
    93  
    94    return int(s)
    95  
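# For example (illustrative values):
#   parse_byte_size('300')  == 300
#   parse_byte_size('2K')   == 2048
#   parse_byte_size('1.5M') == 1572864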
    96  
    97  def div_round_up(a, b):
    98    """Return ceil(a/b)."""
    99    return int(math.ceil(float(a) / b))
   100  
   101  
   102  def rotate_key(element):
   103    """Returns a new key-value pair of the same size but with a different key."""
   104    (key, value) = element
   105    return key[-1:] + key[:-1], value
   106  
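# For example (illustrative): rotate_key((b'abc', b'xyz')) == (b'cab', b'xyz').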
   107  
   108  def initial_splitting_zipf(
   109      start_position,
   110      stop_position,
   111      desired_num_bundles,
   112      distribution_parameter,
   113      num_total_records=None):
    114    """Splits the given range (defined by start_position, stop_position) into
    115       desired_num_bundles bundles using a Zipf distribution with the given
    116       distribution_parameter."""
   117    if not num_total_records:
   118      num_total_records = stop_position - start_position
   119    samples = np.random.zipf(distribution_parameter, desired_num_bundles)
   120    total = sum(samples)
   121    relative_bundle_sizes = [(float(sample) / total) for sample in samples]
   122    bundle_ranges = []
   123    start = start_position
   124    index = 0
   125    while start < stop_position:
   126      if index == desired_num_bundles - 1:
   127        bundle_ranges.append((start, stop_position))
   128        break
   129      stop = start + int(num_total_records * relative_bundle_sizes[index])
   130      bundle_ranges.append((start, stop))
   131      start = stop
   132      index += 1
   133    return bundle_ranges
   134  
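# Illustrative example (not part of the original module): the concrete ranges
# vary between runs because np.random.zipf is sampled without a fixed seed.
# A call such as
#   initial_splitting_zipf(0, 100, 4, 3.0)
# might return, e.g., [(0, 20), (20, 40), (40, 80), (80, 100)]: four bundles
# whose relative sizes follow the drawn Zipf samples, with the last bundle
# always extended to the stop position.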
   135  
   136  class SyntheticStep(beam.DoFn):
    137    """A DoFn whose behavior can be controlled through prespecified parameters.
   138    """
   139    def __init__(
   140        self,
   141        per_element_delay_sec=0,
   142        per_bundle_delay_sec=0,
   143        output_records_per_input_record=1,
   144        output_filter_ratio=0):
   145      if per_element_delay_sec and per_element_delay_sec < 1e-3:
   146        raise ValueError(
   147            'Per element sleep time must be at least 1e-3. '
    148            'Received: %r' %
   149            per_element_delay_sec)
   150      self._per_element_delay_sec = per_element_delay_sec
   151      self._per_bundle_delay_sec = per_bundle_delay_sec
   152      self._output_records_per_input_record = output_records_per_input_record
   153      self._output_filter_ratio = output_filter_ratio
   154  
   155    def start_bundle(self):
   156      self._start_time = time.time()
   157  
   158    def finish_bundle(self):
    159      # The target is for the enclosing stage to take as close as possible to
    160      # the given number of seconds, so we only sleep enough to make up for
    161      # overheads not incurred elsewhere.
   162      to_sleep = self._per_bundle_delay_sec - (time.time() - self._start_time)
   163  
   164      # Ignoring sub-millisecond sleep times.
   165      if to_sleep >= 1e-3:
   166        time.sleep(to_sleep)
   167  
   168    def process(self, element):
   169      if self._per_element_delay_sec >= 1e-3:
   170        time.sleep(self._per_element_delay_sec)
   171      filter_element = False
   172      if self._output_filter_ratio > 0:
   173        if np.random.random() < self._output_filter_ratio:
   174          filter_element = True
   175  
   176      if not filter_element:
   177        for _ in range(self._output_records_per_input_record):
   178          yield element
   179  
   180  
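# Illustrative usage sketch (not part of the original module). The parameter
# values below are arbitrary examples; note that a non-zero per-element delay
# must be at least 1e-3 seconds.
def _example_synthetic_step(p):
  """Applies a SyntheticStep to a small in-memory PCollection."""
  return (
      p
      | 'CreateExampleInput' >> beam.Create([(b'key', b'value')])
      | 'ExampleSyntheticStep' >> beam.ParDo(
          SyntheticStep(
              per_element_delay_sec=0.01,
              per_bundle_delay_sec=1,
              output_records_per_input_record=2,
              output_filter_ratio=0)))
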
   181  class NonLiquidShardingOffsetRangeTracker(OffsetRestrictionTracker):
    182    """An OffsetRestrictionTracker that doesn't allow splitting."""
   183    def try_split(self, split_offset):
   184      pass  # Don't split.
   185  
   186    def checkpoint(self):
   187      pass  # Don't split.
   188  
   189  
   190  class SyntheticSDFStepRestrictionProvider(RestrictionProvider):
   191    """A `RestrictionProvider` for SyntheticSDFStep.
   192  
    193    Its initial_restriction and split operate on num_records and ignore the
    194    source description (element). It splits into initial_splitting_num_bundles
    195    bundles. It returns size_estimate_override as the restriction size, if set;
    196    otherwise it uses the element size.
    197  
    198    If initial_splitting_uneven_chunks is set, it produces unevenly sized chunks.
    199  
    200    """
   201    def __init__(
   202        self,
   203        num_records,
   204        initial_splitting_num_bundles,
   205        initial_splitting_uneven_chunks,
   206        disable_liquid_sharding,
   207        size_estimate_override):
   208      self._num_records = num_records
   209      self._initial_splitting_num_bundles = initial_splitting_num_bundles
   210      self._initial_splitting_uneven_chunks = initial_splitting_uneven_chunks
   211      self._disable_liquid_sharding = disable_liquid_sharding
   212      self._size_estimate_override = size_estimate_override
   213  
   214    def initial_restriction(self, element):
   215      return OffsetRange(0, self._num_records)
   216  
   217    def create_tracker(self, restriction):
   218      if self._disable_liquid_sharding:
   219        return NonLiquidShardingOffsetRangeTracker(restriction)
   220      else:
   221        return OffsetRestrictionTracker(restriction)
   222  
   223    def split(self, element, restriction):
   224      elems = restriction.size()
   225      if (self._initial_splitting_uneven_chunks and
   226          self._initial_splitting_num_bundles > 1 and elems > 1):
   227        bundle_ranges = initial_splitting_zipf(
   228            restriction.start,
   229            restriction.stop,
   230            self._initial_splitting_num_bundles,
   231            3.0)
   232        for start, stop in bundle_ranges:
   233          yield OffsetRange(start, stop)
   234  
   235      else:
   236        offsets_per_split = max(1, (elems // self._initial_splitting_num_bundles))
   237        for split in restriction.split(offsets_per_split, offsets_per_split // 2):
   238          yield split
   239  
   240    def restriction_size(self, element, restriction):
   241      if self._size_estimate_override is not None:
   242        return self._size_estimate_override
   243      element_size = len(element) if isinstance(element, str) else 1
   244      return restriction.size() * element_size
   245  
   246  
   247  def get_synthetic_sdf_step(
   248      per_element_delay_sec=0,
   249      per_bundle_delay_sec=0,
   250      output_records_per_input_record=1,
   251      output_filter_ratio=0,
   252      initial_splitting_num_bundles=8,
   253      initial_splitting_uneven_chunks=False,
   254      disable_liquid_sharding=False,
   255      size_estimate_override=None,
   256  ):
    257    """Returns a SyntheticSDFStep with the given parameters."""
   258    class SyntheticSDFStep(beam.DoFn):
    259      """A splittable DoFn whose behavior can be controlled through
    260         prespecified parameters.
   261      """
   262      def __init__(
   263          self,
   264          per_element_delay_sec_arg,
   265          per_bundle_delay_sec_arg,
   266          output_filter_ratio_arg,
   267          output_records_per_input_record_arg):
   268        if per_element_delay_sec_arg:
   269          per_element_delay_sec_arg = (
    270              per_element_delay_sec_arg / output_records_per_input_record_arg)
   271          if per_element_delay_sec_arg < 1e-3:
   272            raise ValueError(
   273                'Per element sleep time must be at least 1e-3 after being '
   274                'divided among output elements.')
   275        self._per_element_delay_sec = per_element_delay_sec_arg
   276        self._per_bundle_delay_sec = per_bundle_delay_sec_arg
   277        self._output_filter_ratio = output_filter_ratio_arg
   278  
   279      def start_bundle(self):
   280        self._start_time = time.time()
   281  
   282      def finish_bundle(self):
    283        # The target is for the enclosing stage to take as close as possible to
    284        # the given number of seconds, so we only sleep enough to make up for
    285        # overheads not incurred elsewhere.
   286        to_sleep = self._per_bundle_delay_sec - (time.time() - self._start_time)
   287  
   288        # Ignoring sub-millisecond sleep times.
   289        if to_sleep >= 1e-3:
   290          time.sleep(to_sleep)
   291  
   292      def process(
   293          self,
   294          element,
   295          restriction_tracker=beam.DoFn.RestrictionParam(
   296              SyntheticSDFStepRestrictionProvider(
   297                  output_records_per_input_record,
   298                  initial_splitting_num_bundles,
   299                  initial_splitting_uneven_chunks,
   300                  disable_liquid_sharding,
   301                  size_estimate_override))):
   302        filter_element = False
   303        if self._output_filter_ratio > 0:
   304          if np.random.random() < self._output_filter_ratio:
   305            filter_element = True
   306  
   307        current_restriction = restriction_tracker.current_restriction()
   308        for cur in range(current_restriction.start, current_restriction.stop):
   309          if not restriction_tracker.try_claim(cur):
   310            return
   311  
   312          if self._per_element_delay_sec:
   313            time.sleep(self._per_element_delay_sec)
   314  
   315          if not filter_element:
   316            yield element
   317          cur += 1
   318  
   319    return SyntheticSDFStep(
   320        per_element_delay_sec,
   321        per_bundle_delay_sec,
   322        output_filter_ratio,
   323        output_records_per_input_record)
   324  
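# Illustrative usage sketch (not part of the original module). The splittable
# step iterates over a restriction of size output_records_per_input_record for
# every input element; the parameter values below are arbitrary examples.
def _example_synthetic_sdf_step(p):
  """Applies a splittable synthetic step to a small in-memory PCollection."""
  sdf_step = get_synthetic_sdf_step(
      output_records_per_input_record=4,
      initial_splitting_num_bundles=2,
      disable_liquid_sharding=False)
  return (
      p
      | 'CreateExampleInput' >> beam.Create([(b'key', b'value')])
      | 'ExampleSDFStep' >> beam.ParDo(sdf_step))
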
   325  
   326  class SyntheticSource(iobase.BoundedSource):
   327    """A custom source of a specified size.
   328    """
   329    def __init__(self, input_spec):
    330      """Initializes a synthetic source.
   331  
   332      Args:
   333        input_spec: Input specification of the source. See corresponding option in
   334                    function 'parse_args()' below for more details.
   335      Raises:
   336        ValueError: if input parameters are invalid.
   337      """
   338      def maybe_parse_byte_size(s):
   339        return parse_byte_size(s) if isinstance(s, str) else int(s)
   340  
   341      self._num_records = input_spec['numRecords']
   342      self._key_size = maybe_parse_byte_size(input_spec.get('keySizeBytes', 1))
   343      self._hot_key_fraction = input_spec.get('hotKeyFraction', 0)
   344      self._num_hot_keys = input_spec.get('numHotKeys', 0)
   345  
   346      self._value_size = maybe_parse_byte_size(
   347          input_spec.get('valueSizeBytes', 1))
   348      self._total_size = self.element_size * self._num_records
   349      self._initial_splitting = (
   350          input_spec['bundleSizeDistribution']['type']
   351          if 'bundleSizeDistribution' in input_spec else 'const')
   352      if self._initial_splitting != 'const' and self._initial_splitting != 'zipf':
   353        raise ValueError(
   354            'Only const and zipf distributions are supported for determining '
    355            'sizes of bundles produced by initial splitting. Received: %s' %
   356            self._initial_splitting)
   357      self._initial_splitting_num_bundles = (
   358          input_spec['forceNumInitialBundles']
   359          if 'forceNumInitialBundles' in input_spec else 0)
   360      if self._initial_splitting == 'zipf':
   361        self._initial_splitting_distribution_parameter = (
   362            input_spec['bundleSizeDistribution']['param'])
   363        if self._initial_splitting_distribution_parameter < 1:
   364          raise ValueError(
   365              'Parameter for a Zipf distribution must be larger than 1. '
    366              'Received %r.' %
   367              self._initial_splitting_distribution_parameter)
   368      else:
   369        self._initial_splitting_distribution_parameter = 0
   370      self._dynamic_splitting = (
   371          'none' if (
   372              'splitPointFrequencyRecords' in input_spec and
   373              input_spec['splitPointFrequencyRecords'] == 0) else 'perfect')
   374      if 'delayDistribution' in input_spec:
   375        if input_spec['delayDistribution']['type'] != 'const':
   376          raise ValueError(
   377              'SyntheticSource currently only supports delay '
    378              'distributions of type \'const\'. Received %s.' %
   379              input_spec['delayDistribution']['type'])
   380        self._sleep_per_input_record_sec = (
   381            float(input_spec['delayDistribution']['const']) / 1000)
   382        if (self._sleep_per_input_record_sec and
   383            self._sleep_per_input_record_sec < 1e-3):
   384          raise ValueError(
   385              'Sleep time per input record must be at least 1e-3.'
    386              ' Received: %r' %
   387              self._sleep_per_input_record_sec)
   388      else:
   389        self._sleep_per_input_record_sec = 0
   390  
   391    @property
   392    def element_size(self):
   393      return self._key_size + self._value_size
   394  
   395    def estimate_size(self):
   396      return self._total_size
   397  
   398    def split(self, desired_bundle_size, start_position=0, stop_position=None):
   399      # Performs initial splitting of SyntheticSource.
   400      #
    401      # The exact sizes and distribution of the initial splits generated here
    402      # depend on the input specification of the SyntheticSource.
   403  
   404      if stop_position is None:
   405        stop_position = self._num_records
   406      if self._initial_splitting == 'zipf':
   407        desired_num_bundles = self._initial_splitting_num_bundles or math.ceil(
   408            float(self.estimate_size()) / desired_bundle_size)
   409        bundle_ranges = initial_splitting_zipf(
   410            start_position,
   411            stop_position,
   412            desired_num_bundles,
   413            self._initial_splitting_distribution_parameter,
   414            self._num_records)
   415      else:
   416        if self._initial_splitting_num_bundles:
   417          bundle_size_in_elements = max(
   418              1, int(self._num_records / self._initial_splitting_num_bundles))
   419        else:
   420          bundle_size_in_elements = (
   421              max(
   422                  div_round_up(desired_bundle_size, self.element_size),
   423                  int(math.floor(math.sqrt(self._num_records)))))
   424        bundle_ranges = []
   425        for start in range(start_position, stop_position,
   426                           bundle_size_in_elements):
   427          stop = min(start + bundle_size_in_elements, stop_position)
   428          bundle_ranges.append((start, stop))
   429  
   430      for start, stop in bundle_ranges:
   431        yield iobase.SourceBundle(stop - start, self, start, stop)
   432  
   433    def get_range_tracker(self, start_position, stop_position):
   434      if start_position is None:
   435        start_position = 0
   436      if stop_position is None:
   437        stop_position = self._num_records
   438      tracker = range_trackers.OffsetRangeTracker(start_position, stop_position)
   439      if self._dynamic_splitting == 'none':
   440        tracker = range_trackers.UnsplittableRangeTracker(tracker)
   441      return tracker
   442  
   443    def _gen_kv_pair(self, generator, index):
   444      generator.seed(index)
   445      rand = generator.random_sample()
   446  
    447      # Determines whether to generate a hot key or not.
   448      if rand < self._hot_key_fraction:
    449        # Generate a hot key. The hot key id is derived deterministically from
    450        # the record index as index % numHotKeys, so records cycle through the
    451        # hot keys.
   452        generator_hot = Generator(index % self._num_hot_keys)
   453        bytes_ = generator_hot.bytes(self._key_size), generator.bytes(
   454          self._value_size)
   455      else:
   456        bytes_ = generator.bytes(self.element_size)
   457        bytes_ = bytes_[:self._key_size], bytes_[self._key_size:]
   458      return bytes_
   459  
   460    def read(self, range_tracker):
   461      index = range_tracker.start_position()
   462      generator = Generator()
   463      while range_tracker.try_claim(index):
   464        time.sleep(self._sleep_per_input_record_sec)
   465        yield self._gen_kv_pair(generator, index)
   466        index += 1
   467  
   468    def default_output_coder(self):
   469      return beam.coders.TupleCoder(
   470          [beam.coders.BytesCoder(), beam.coders.BytesCoder()])
   471  
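# Illustrative usage sketch (not part of the original module). The keys in the
# input specification below match those read by SyntheticSource.__init__; the
# concrete values are arbitrary examples.
def _example_synthetic_source(p):
  """Reads 1000 synthetic records with 10-byte keys and 100-byte values."""
  example_spec = {
      'numRecords': 1000,
      'keySizeBytes': 10,
      'valueSizeBytes': 100,
      'bundleSizeDistribution': {'type': 'const'},
      'forceNumInitialBundles': 4,
  }
  return p | 'ReadExampleSource' >> beam.io.Read(SyntheticSource(example_spec))
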
   472  
   473  class SyntheticSDFSourceRestrictionProvider(RestrictionProvider):
   474    """A `RestrictionProvider` for SyntheticSDFAsSource.
   475  
    476    In initial_restriction(element) and split(element), the element is the
    477    source description.
    478    A typical element looks like:
   479  
   480      {
   481        'key_size': 1,
   482        'value_size': 1,
   483        'initial_splitting_num_bundles': 8,
   484        'initial_splitting_desired_bundle_size': 2,
   485        'sleep_per_input_record_sec': 0,
   486        'initial_splitting' : 'const'
   487  
   488      }
   489  
   490    """
   491    def initial_restriction(self, element):
   492      return OffsetRange(0, element['num_records'])
   493  
   494    def create_tracker(self, restriction):
   495      return restriction_trackers.OffsetRestrictionTracker(restriction)
   496  
   497    def split(self, element, restriction):
   498      bundle_ranges = []
   499      start_position = restriction.start
   500      stop_position = restriction.stop
   501      element_size = element['key_size'] + element['value_size']
   502      estimate_size = element_size * element['num_records']
   503      if element['initial_splitting'] == 'zipf':
   504        desired_num_bundles = (
   505            element['initial_splitting_num_bundles'] or div_round_up(
   506                estimate_size, element['initial_splitting_desired_bundle_size']))
   507        samples = np.random.zipf(
   508            element['initial_splitting_distribution_parameter'],
   509            desired_num_bundles)
   510        total = sum(samples)
   511        relative_bundle_sizes = [(float(sample) / total) for sample in samples]
   512        start = start_position
   513        index = 0
   514        while start < stop_position:
   515          if index == desired_num_bundles - 1:
   516            bundle_ranges.append(OffsetRange(start, stop_position))
   517            break
   518          stop = start + int(
   519              element['num_records'] * relative_bundle_sizes[index])
   520          bundle_ranges.append(OffsetRange(start, stop))
   521          start = stop
   522          index += 1
   523      else:
   524        if element['initial_splitting_num_bundles']:
   525          bundle_size_in_elements = max(
   526              1,
   527              int(
   528                  element['num_records'] /
   529                  element['initial_splitting_num_bundles']))
   530        else:
   531          bundle_size_in_elements = (
   532              max(
   533                  div_round_up(
   534                      element['initial_splitting_desired_bundle_size'],
   535                      element_size),
   536                  int(math.floor(math.sqrt(element['num_records'])))))
   537        for start in range(start_position, stop_position,
   538                           bundle_size_in_elements):
   539          stop = min(start + bundle_size_in_elements, stop_position)
   540          bundle_ranges.append(OffsetRange(start, stop))
   541      return bundle_ranges
   542  
   543    def restriction_size(self, element, restriction):
   544      return (element['key_size'] + element['value_size']) * restriction.size()
   545  
   546  
   547  class SyntheticSDFAsSource(beam.DoFn):
    548    """An SDF that generates records like a source.
    549  
    550    This SDF accepts a PCollection of record-based source descriptions.
    551    A typical description looks like:
   552  
   553      {
   554        'key_size': 1,
   555        'value_size': 1,
   556        'initial_splitting_num_bundles': 8,
   557        'initial_splitting_desired_bundle_size': 2,
   558        'sleep_per_input_record_sec': 0,
   559        'initial_splitting' : 'const'
   560  
   561      }
   562  
   563    A simple pipeline taking this SDF as a source is like:
   564      p
   565      | beam.Create([description1, description2,...])
   566      | beam.ParDo(SyntheticSDFAsSource())
   567  
   568    NOTE:
    569      The restriction_tracker parameter of SDF.process() has different content
    570      at pipeline-definition time and at runtime.
    571      When defining the SDF, its default value is a `beam.DoFn.RestrictionParam`
    572      constructed from a `RestrictionProvider`.
    573      At runtime, DoFnRunner.process_with_sized_restriction() feeds a
    574      `RestrictionTracker` built from the current restriction to SDF.process().
   575    """
   576    def process(
   577        self,
   578        element,
   579        restriction_tracker=beam.DoFn.RestrictionParam(
   580            SyntheticSDFSourceRestrictionProvider())):
   581      cur = restriction_tracker.current_restriction().start
   582      while restriction_tracker.try_claim(cur):
   583        r = Generator()
   584        r.seed(cur)
   585        time.sleep(element['sleep_per_input_record_sec'])
   586        yield r.bytes(element['key_size']), r.bytes(element['value_size'])
   587        cur += 1
   588  
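# Illustrative usage sketch (not part of the original module). The description
# below contains the keys read by SyntheticSDFSourceRestrictionProvider and
# SyntheticSDFAsSource.process for a 'const' initial split; the concrete
# values are arbitrary examples.
def _example_synthetic_sdf_as_source(p):
  """Generates 100 two-byte records from a single source description."""
  description = {
      'num_records': 100,
      'key_size': 1,
      'value_size': 1,
      'initial_splitting': 'const',
      'initial_splitting_num_bundles': 8,
      'initial_splitting_desired_bundle_size': 2,
      'sleep_per_input_record_sec': 0,
  }
  return (
      p
      | 'CreateDescriptions' >> beam.Create([description])
      | 'GenerateRecords' >> beam.ParDo(SyntheticSDFAsSource()))
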
   589  
   590  class ShuffleBarrier(beam.PTransform):
   591    def expand(self, pc):
   592      return (
   593          pc
   594          | beam.Map(rotate_key)
   595          | beam.GroupByKey()
   596          | 'Ungroup' >> beam.FlatMap(lambda elm: [(elm[0], v) for v in elm[1]]))
   597  
   598  
   599  class SideInputBarrier(beam.PTransform):
   600    def expand(self, pc):
   601      return (
   602          pc
   603          | beam.Map(rotate_key)
   604          | beam.Map(
   605              lambda elem,
   606              ignored: elem,
   607              beam.pvalue.AsIter(pc | beam.FlatMap(lambda elem: None))))
   608  
   609  
   610  def merge_using_gbk(name, pc1, pc2):
   611    """Merges two given PCollections using a CoGroupByKey."""
   612  
   613    pc1_with_key = pc1 | (name + 'AttachKey1') >> beam.Map(lambda x: (x, x))
   614    pc2_with_key = pc2 | (name + 'AttachKey2') >> beam.Map(lambda x: (x, x))
   615  
   616    grouped = ({
   617        'pc1': pc1_with_key, 'pc2': pc2_with_key
   618    } | (name + 'Group') >> beam.CoGroupByKey())
   619    return (
   620        grouped | (name + 'DeDup') >> beam.Map(lambda elm: elm[0])
   621    )  # Ignoring values
   622  
   623  
   624  def merge_using_side_input(name, pc1, pc2):
   625    """Merges two given PCollections using side inputs."""
   626    def join_fn(val, _):  # Ignoring side input
   627      return val
   628  
   629    return pc1 | name >> beam.core.Map(join_fn, beam.pvalue.AsIter(pc2))
   630  
   631  
   632  def expand_using_gbk(name, pc):
   633    """Expands a given PCollection into two copies using GroupByKey."""
   634  
   635    ret = []
   636    ret.append((pc | ('%s.a' % name) >> ShuffleBarrier()))
   637    ret.append((pc | ('%s.b' % name) >> ShuffleBarrier()))
   638    return ret
   639  
   640  
   641  def expand_using_second_output(name, pc):
   642    """Expands a given PCollection into two copies using side outputs."""
   643    class ExpandFn(beam.DoFn):
   644      def process(self, element):
   645        yield beam.pvalue.TaggedOutput('second_out', element)
   646        yield element
   647  
   648    pc1, pc2 = (pc | name >> beam.ParDo(
   649        ExpandFn()).with_outputs('second_out', main='main_out'))
   650    return [pc1, pc2]
   651  
   652  
   653  def _parse_steps(json_str):
   654    """Converts the JSON step description into Python objects.
   655  
   656    See property 'steps' for more details about the JSON step description.
   657  
   658    Args:
   659      json_str: a JSON string that describes the steps.
   660  
   661    Returns:
    662      Information about the steps as a list of dictionaries. Each dictionary may
    663      have the following properties:
   664      (1) per_element_delay - amount of delay for each element in seconds.
   665      (2) per_bundle_delay - minimum amount of delay for a given step in seconds.
   666      (3) output_records_per_input_record - number of output elements generated
   667          for each input element to a step.
   668      (4) output_filter_ratio - the probability at which a step may filter out a
   669          given element by not producing any output for that element.
   670      (5) splittable - if the step should be splittable.
   671      (6) initial_splitting_num_bundles - number of bundles initial split if step
   672          is splittable.
    673      (7) initial_splitting_uneven_chunks - if the bundles should be
    674          unevenly sized.
    675      (8) disable_liquid_sharding - if liquid sharding should be disabled.
    676      (9) size_estimate_override - the size estimate, or None to use the default.
   677    """
   678    all_steps = []
   679    json_data = json.loads(json_str)
   680    for val in json_data:
   681      steps = {}
   682      steps['per_element_delay'] = ((float(val['per_element_delay_msec']) / 1000)
   683                                    if 'per_element_delay_msec' in val else 0)
   684      steps['per_bundle_delay'] = (
   685          float(val['per_bundle_delay_sec'])
   686          if 'per_bundle_delay_sec' in val else 0)
   687      steps['output_records_per_input_record'] = (
   688          int(val['output_records_per_input_record'])
   689          if 'output_records_per_input_record' in val else 1)
   690      steps['output_filter_ratio'] = (
   691          float(val['output_filter_ratio'])
   692          if 'output_filter_ratio' in val else 0)
   693      steps['splittable'] = (
   694          bool(val['splittable']) if 'splittable' in val else False)
   695      steps['initial_splitting_num_bundles'] = (
   696          int(val['initial_splitting_num_bundles'])
   697          if 'initial_splitting_num_bundles' in val else 8)
   698      steps['initial_splitting_uneven_chunks'] = (
   699          bool(val['initial_splitting_uneven_chunks'])
   700          if 'initial_splitting_uneven_chunks' in val else False)
   701      steps['disable_liquid_sharding'] = (
   702          bool(val['disable_liquid_sharding'])
   703          if 'disable_liquid_sharding' in val else False)
   704      steps['size_estimate_override'] = (
   705          int(val['size_estimate_override'])
   706          if 'size_estimate_override' in val else None)
   707      all_steps.append(steps)
   708  
   709    return all_steps
   710  
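# For example (illustrative), the following JSON describes a two-step pipeline
# in which the first step sleeps 100 ms per element and the second step is
# splittable and emits two outputs per input element:
#
#   _parse_steps(
#       '[{"per_element_delay_msec": 100}, '
#       '{"splittable": true, "output_records_per_input_record": 2}]')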
   711  
   712  def parse_args(args):
   713    """Parses a given set of arguments.
   714  
   715    Args:
   716      args: set of arguments to be passed.
   717  
   718    Returns:
   719      a tuple where first item gives the set of arguments defined and parsed
   720      within this method and second item gives the set of unknown arguments.
   721    """
   722  
   723    parser = argparse.ArgumentParser()
   724    parser.add_argument(
   725        '--steps',
   726        dest='steps',
   727        type=_parse_steps,
   728        help='A JSON string that gives a list where each entry of the list is '
   729        'configuration information for a step. Configuration for each step '
   730        'consists of '
   731        '(1) A float "per_bundle_delay_sec" (in seconds). Defaults to 0.'
    732        '(2) A float "per_element_delay_msec" (in milliseconds). '
   733        '    Defaults to 0.'
   734        '(3) An integer "output_records_per_input_record". Defaults to 1.'
    735        '(4) A float "output_filter_ratio" in the range [0, 1]. '
   736        '    Defaults to 0.'
   737        '(5) A bool "splittable" that defaults to false.'
   738        '(6) An integer "initial_splitting_num_bundles". Defaults to 8.')
   739  
   740    parser.add_argument(
   741        '--input',
   742        dest='input',
   743        type=json.loads,
   744        help='A JSON string that describes the properties of the SyntheticSource '
   745        'used by the pipeline. Configuration is similar to Java '
   746        'SyntheticBoundedInput.'
   747        'Currently supports following properties. '
   748        '(1) An integer "numRecords". '
    749        '(2) An integer "keySizeBytes". '
    750        '(3) An integer "valueSizeBytes". '
   751        '(4) A tuple "bundleSizeDistribution" with following values. '
   752        '    A string "type". Allowed values are "const" and "zipf". '
   753        '    An float "param". Only used if "type"=="zipf". Must be '
    754        '    A float "param". Only used if "type"=="zipf". Must be '
   755        '(5) An integer "forceNumInitialBundles". '
   756        '(6) An integer "splitPointFrequencyRecords". '
   757        '(7) A tuple "delayDistribution" with following values. '
   758        '    A string "type". Only allowed value is "const". '
   759        '    An integer "const". ')
   760  
   761    parser.add_argument(
   762        '--barrier',
   763        dest='barrier',
   764        default='shuffle',
   765        choices=[
   766            'shuffle',
   767            'side-input',
   768            'expand-gbk',
   769            'expand-second-output',
   770            'merge-gbk',
   771            'merge-side-input'
   772        ],
    773        help='The type of barrier used to separate consecutive steps: a '
    774        'shuffle, a side input, or a GBK/side-input based expand or merge.')
   775    parser.add_argument(
   776        '--output',
   777        dest='output',
   778        default='',
   779        help='Destination to write output.')
   780  
   781    return parser.parse_known_args(args)
   782  
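# Illustrative usage sketch (not part of the original module): arguments that
# are not recognized here (e.g. --runner) are returned as pipeline arguments.
#
#   known_args, pipeline_args = parse_args([
#       '--steps=[{"per_bundle_delay_sec": 1}]',
#       '--input={"numRecords": 100}',
#       '--runner=DirectRunner',
#   ])
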
   783  
   784  def run(argv=None, save_main_session=True):
   785    """Runs the workflow."""
   786    known_args, pipeline_args = parse_args(argv)
   787  
   788    pipeline_options = PipelineOptions(pipeline_args)
   789    pipeline_options.view_as(SetupOptions).save_main_session = save_main_session
   790  
   791    input_info = known_args.input
   792  
   793    with TestPipeline(options=pipeline_options) as p:
   794      source = SyntheticSource(input_info)
   795  
   796      # pylint: disable=expression-not-assigned
   797      barrier = known_args.barrier
   798  
   799      pc_list = []
   800      num_roots = 2**(len(known_args.steps) - 1) if (
   801          barrier == 'merge-gbk' or barrier == 'merge-side-input') else 1
   802      for read_no in range(num_roots):
   803        pc_list.append((p | ('Read %d' % read_no) >> beam.io.Read(source)))
   804  
   805      for step_no, steps in enumerate(known_args.steps):
   806        if step_no != 0:
   807          new_pc_list = []
   808          for pc_no, pc in enumerate(pc_list):
   809            if barrier == 'shuffle':
   810              new_pc_list.append(
   811                  (pc | ('shuffle %d.%d' % (step_no, pc_no)) >> ShuffleBarrier()))
   812            elif barrier == 'side-input':
   813              new_pc_list.append((
   814                  pc | ('side-input %d.%d' %
   815                        (step_no, pc_no)) >> SideInputBarrier()))
   816            elif barrier == 'expand-gbk':
   817              new_pc_list.extend(
   818                  expand_using_gbk(('expand-gbk %d.%d' % (step_no, pc_no)), pc))
   819            elif barrier == 'expand-second-output':
   820              new_pc_list.extend(
   821                  expand_using_second_output(
   822                      ('expand-second-output %d.%d' % (step_no, pc_no)), pc))
   823            elif barrier == 'merge-gbk':
   824              if pc_no % 2 == 0:
   825                new_pc_list.append(
   826                    merge_using_gbk(('merge-gbk %d.%d' % (step_no, pc_no)),
   827                                    pc,
   828                                    pc_list[pc_no + 1]))
   829              else:
   830                continue
   831            elif barrier == 'merge-side-input':
   832              if pc_no % 2 == 0:
   833                new_pc_list.append(
   834                    merge_using_side_input(
   835                        ('merge-side-input %d.%d' % (step_no, pc_no)),
   836                        pc,
   837                        pc_list[pc_no + 1]))
   838              else:
   839                continue
   840  
   841          pc_list = new_pc_list
   842  
   843        new_pc_list = []
   844        for pc_no, pc in enumerate(pc_list):
   845          if steps['splittable']:
   846            step = get_synthetic_sdf_step(
   847                per_element_delay_sec=steps['per_element_delay'],
   848                per_bundle_delay_sec=steps['per_bundle_delay'],
   849                output_records_per_input_record=steps[
   850                    'output_records_per_input_record'],
   851                output_filter_ratio=steps['output_filter_ratio'],
   852                initial_splitting_num_bundles=steps[
   853                    'initial_splitting_num_bundles'],
   854                initial_splitting_uneven_chunks=steps[
   855                    'initial_splitting_uneven_chunks'],
   856                disable_liquid_sharding=steps['disable_liquid_sharding'],
   857                size_estimate_override=steps['size_estimate_override'])
   858          else:
   859            step = SyntheticStep(
   860                per_element_delay_sec=steps['per_element_delay'],
   861                per_bundle_delay_sec=steps['per_bundle_delay'],
   862                output_records_per_input_record=steps[
   863                    'output_records_per_input_record'],
   864                output_filter_ratio=steps['output_filter_ratio'])
   865          new_pc = pc | 'SyntheticStep %d.%d' % (step_no,
   866                                                 pc_no) >> beam.ParDo(step)
   867          new_pc_list.append(new_pc)
   868        pc_list = new_pc_list
   869  
   870      if known_args.output:
    871        # If an output location is provided, we format and write the output.
   872        if len(pc_list) == 1:
   873          (
   874              pc_list[0]
   875              | 'FormatOutput' >> beam.Map(lambda elm: (elm[0] + elm[1]))
   876              | 'WriteOutput' >> WriteToText(known_args.output))
   877  
   878    logging.info('Pipeline run completed.')
   879  
   880  
   881  if __name__ == '__main__':
   882    logging.getLogger().setLevel(logging.INFO)
   883    run()
   884  
   885  
   886  class StatefulLoadGenerator(beam.PTransform):
    887    """A PTransform for generating random data using the state and timers API."""
   888    def __init__(self, input_options, num_keys=100):
   889      self.num_records = input_options['num_records']
   890      self.key_size = input_options['key_size']
   891      self.value_size = input_options['value_size']
   892      self.num_keys = num_keys
   893  
   894    @typehints.with_output_types(Tuple[bytes, bytes])
   895    class GenerateKeys(beam.DoFn):
   896      def __init__(self, num_keys, key_size):
   897        self.num_keys = num_keys
   898        self.key_size = key_size
   899  
   900      def process(self, impulse):
   901        for _ in range(self.num_keys):
   902          key = os.urandom(self.key_size)
   903          yield key, b''
   904  
   905    class GenerateLoad(beam.DoFn):
   906      state_spec = userstate.CombiningValueStateSpec(
   907          'bundles_remaining', combine_fn=sum)
   908      timer_spec = userstate.TimerSpec('timer', userstate.TimeDomain.WATERMARK)
   909  
   910      def __init__(self, num_records_per_key, value_size, bundle_size=1000):
   911        self.num_records_per_key = num_records_per_key
   912        self.payload = os.urandom(value_size)
   913        self.bundle_size = bundle_size
   914        self.time_fn = time.time
   915  
   916      def process(
   917          self,
   918          _element,
   919          records_remaining=beam.DoFn.StateParam(state_spec),
   920          timer=beam.DoFn.TimerParam(timer_spec)):
   921        records_remaining.add(self.num_records_per_key)
   922        timer.set(0)
   923  
   924      @userstate.on_timer(timer_spec)
   925      def process_timer(
   926          self,
   927          key=beam.DoFn.KeyParam,
   928          records_remaining=beam.DoFn.StateParam(state_spec),
   929          timer=beam.DoFn.TimerParam(timer_spec)):
   930        cur_bundle_size = min(self.bundle_size, records_remaining.read())
   931        for _ in range(cur_bundle_size):
   932          records_remaining.add(-1)
   933          yield key, self.payload
   934        if records_remaining.read() > 0:
   935          timer.set(0)
   936  
   937    def expand(self, pbegin):
   938      assert isinstance(pbegin, pvalue.PBegin), (
   939          'Input to transform must be a PBegin but found %s' % pbegin)
   940      return (
   941          pbegin
   942          | 'Impulse' >> beam.Impulse()
   943          | 'GenerateKeys' >> beam.ParDo(
   944              StatefulLoadGenerator.GenerateKeys(self.num_keys, self.key_size))
   945          | 'GenerateLoad' >> beam.ParDo(
   946              StatefulLoadGenerator.GenerateLoad(
   947                  self.num_records // self.num_keys, self.value_size)))
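

# Illustrative usage sketch (not part of the original module). The keys of the
# 'input_options' dictionary below match those read by StatefulLoadGenerator's
# constructor; the concrete values are arbitrary examples.
def _example_stateful_load_generator(p):
  """Generates 10,000 random records over 100 keys via state and timers."""
  return p | 'GenerateStatefulLoad' >> StatefulLoadGenerator(
      {
          'num_records': 10000,
          'key_size': 10,
          'value_size': 90,
      },
      num_keys=100)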