github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/testing/synthetic_pipeline_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Tests for apache_beam.testing.synthetic_pipeline."""
    19  
    20  # pytype: skip-file
    21  
    22  import glob
    23  import json
    24  import logging
    25  import tempfile
    26  import time
    27  import unittest
    28  
    29  import apache_beam as beam
    30  from apache_beam.io import source_test_utils
    31  from apache_beam.io.restriction_trackers import OffsetRange
    32  from apache_beam.testing import synthetic_pipeline
    33  from apache_beam.testing.util import assert_that
    34  from apache_beam.testing.util import equal_to
    35  
    36  try:
    37    import numpy  # pylint: disable=unused-import
    38  except ImportError:
    39    NP_INSTALLED = False
    40  else:
    41    NP_INSTALLED = True
    42  
    43  
    44  def input_spec(
    45      num_records,
    46      key_size,
    47      value_size,
    48      bundle_size_distribution_type='const',
    49      bundle_size_distribution_param=0,
    50      force_initial_num_bundles=0):
    51    return {
    52        'numRecords': num_records,
    53        'keySizeBytes': key_size,
    54        'valueSizeBytes': value_size,
    55        'bundleSizeDistribution': {
    56            'type': bundle_size_distribution_type,
    57            'param': bundle_size_distribution_param
    58        },
    59        'forceNumInitialBundles': force_initial_num_bundles,
    60    }
    61  
    62  
    63  @unittest.skipIf(
    64      not NP_INSTALLED, 'Synthetic source dependencies are not installed')
    65  class SyntheticPipelineTest(unittest.TestCase):
    66  
    67    # pylint: disable=expression-not-assigned
    68  
    69    def test_synthetic_step_multiplies_output_elements_count(self):
    70      with beam.Pipeline() as p:
    71        pcoll = p | beam.Create(list(range(10))) | beam.ParDo(
    72            synthetic_pipeline.SyntheticStep(0, 0, 10))
    73        assert_that(pcoll | beam.combiners.Count.Globally(), equal_to([100]))
    74  
    75    def test_minimal_runtime_with_synthetic_step_delay(self):
    76      start = time.time()
    77      with beam.Pipeline() as p:
    78        p | beam.Create(list(range(10))) | beam.ParDo(
    79            synthetic_pipeline.SyntheticStep(0, 0.5, 10))
    80  
    81      elapsed = time.time() - start
    82      self.assertGreaterEqual(elapsed, 0.5, elapsed)
    83  
    84    def test_synthetic_sdf_step_multiplies_output_elements_count(self):
    85      with beam.Pipeline() as p:
    86        pcoll = p | beam.Create(list(range(10))) | beam.ParDo(
    87            synthetic_pipeline.get_synthetic_sdf_step(0, 0, 10))
    88        assert_that(pcoll | beam.combiners.Count.Globally(), equal_to([100]))
    89  
    90    def test_minimal_runtime_with_synthetic_sdf_step_bundle_delay(self):
    91      start = time.time()
    92      with beam.Pipeline() as p:
    93        p | beam.Create(list(range(10))) | beam.ParDo(
    94            synthetic_pipeline.get_synthetic_sdf_step(0, 0.5, 10))
    95  
    96      elapsed = time.time() - start
    97      self.assertGreaterEqual(elapsed, 0.5, elapsed)
    98  
    99    def test_synthetic_step_split_provider(self):
   100      provider = synthetic_pipeline.SyntheticSDFStepRestrictionProvider(
   101          5, 2, False, False, None)
   102  
   103      self.assertEqual(
   104          list(provider.split('ab', OffsetRange(2, 15))),
   105          [OffsetRange(2, 8), OffsetRange(8, 15)])
   106      self.assertEqual(
   107          list(provider.split('ab', OffsetRange(0, 8))),
   108          [OffsetRange(0, 4), OffsetRange(4, 8)])
   109      self.assertEqual(list(provider.split('ab', OffsetRange(0, 0))), [])
   110      self.assertEqual(
   111          list(provider.split('ab', OffsetRange(2, 3))), [OffsetRange(2, 3)])
   112  
   113      provider = synthetic_pipeline.SyntheticSDFStepRestrictionProvider(
   114          10, 1, False, False, None)
   115      self.assertEqual(
   116          list(provider.split('ab', OffsetRange(1, 10))), [OffsetRange(1, 10)])
   117      self.assertEqual(provider.restriction_size('ab', OffsetRange(1, 10)), 9 * 2)
   118  
   119      provider = synthetic_pipeline.SyntheticSDFStepRestrictionProvider(
   120          10, 3, False, False, None)
   121      self.assertEqual(
   122          list(provider.split('ab', OffsetRange(1, 10))),
   123          [OffsetRange(1, 4), OffsetRange(4, 7), OffsetRange(7, 10)])
   124      self.assertEqual(provider.initial_restriction('a'), OffsetRange(0, 10))
   125  
   126      provider = synthetic_pipeline.SyntheticSDFStepRestrictionProvider(
   127          10, 3, False, False, 45)
   128      self.assertEqual(provider.restriction_size('ab', OffsetRange(1, 3)), 45)
   129  
   130      tracker = provider.create_tracker(OffsetRange(1, 6))
   131      tracker.try_claim(1)  # Claim to allow splitting.
   132      self.assertEqual(
   133          tracker.try_split(.5), (OffsetRange(1, 3), OffsetRange(3, 6)))
   134  
   135    def verify_random_splits(self, provider, restriction, bundles):
   136      ranges = list(provider.split('ab', restriction))
   137  
   138      prior_stop = restriction.start
   139      for r in ranges:
   140        self.assertEqual(r.start, prior_stop)
   141        prior_stop = r.stop
   142      self.assertEqual(prior_stop, restriction.stop)
   143      self.assertEqual(len(ranges), bundles)
   144  
   145    def testSyntheticStepSplitProviderUnevenChunks(self):
   146      bundles = 4
   147      provider = synthetic_pipeline.SyntheticSDFStepRestrictionProvider(
   148          5, bundles, True, False, None)
   149      self.verify_random_splits(provider, OffsetRange(4, 10), bundles)
   150      self.verify_random_splits(provider, OffsetRange(4, 4), 0)
   151      self.verify_random_splits(provider, OffsetRange(0, 1), 1)
   152      self.verify_random_splits(provider, OffsetRange(0, bundles - 2), bundles)
   153  
   154    def test_synthetic_step_split_provider_no_liquid_sharding(self):
   155      # Verify Liquid Sharding Works
   156      provider = synthetic_pipeline.SyntheticSDFStepRestrictionProvider(
   157          5, 5, True, False, None)
   158      tracker = provider.create_tracker(OffsetRange(1, 6))
   159      tracker.try_claim(2)
   160      self.assertEqual(
   161          tracker.try_split(.5), (OffsetRange(1, 4), OffsetRange(4, 6)))
   162  
   163      # Verify No Liquid Sharding
   164      provider = synthetic_pipeline.SyntheticSDFStepRestrictionProvider(
   165          5, 5, True, True, None)
   166      tracker = provider.create_tracker(OffsetRange(1, 6))
   167      tracker.try_claim(2)
   168      self.assertEqual(tracker.try_split(3), None)
   169  
   170    def test_synthetic_source(self):
   171      def assert_size(element, expected_size):
   172        assert len(element) == expected_size
   173  
   174      with beam.Pipeline() as p:
   175        pcoll = (
   176            p | beam.io.Read(
   177                synthetic_pipeline.SyntheticSource(input_spec(300, 5, 15))))
   178        (pcoll | beam.Map(lambda elm: elm[0]) | 'key' >> beam.Map(assert_size, 5))
   179        (
   180            pcoll
   181            | beam.Map(lambda elm: elm[1]) | 'value' >> beam.Map(assert_size, 15))
   182        assert_that(pcoll | beam.combiners.Count.Globally(), equal_to([300]))
   183  
   184    def test_synthetic_source_split_even(self):
   185      source = synthetic_pipeline.SyntheticSource(
   186          input_spec(1000, 1, 1, 'const', 0))
   187      splits = source.split(100)
   188      sources_info = [(split.source, split.start_position, split.stop_position)
   189                      for split in splits]
   190      self.assertEqual(20, len(sources_info))
   191      source_test_utils.assert_sources_equal_reference_source(
   192          (source, None, None), sources_info)
   193  
   194    def test_synthetic_source_split_uneven(self):
   195      source = synthetic_pipeline.SyntheticSource(
   196          input_spec(1000, 1, 1, 'zipf', 3, 10))
   197      splits = source.split(100)
   198      sources_info = [(split.source, split.start_position, split.stop_position)
   199                      for split in splits]
   200      self.assertEqual(10, len(sources_info))
   201      source_test_utils.assert_sources_equal_reference_source(
   202          (source, None, None), sources_info)
   203  
   204    def test_split_at_fraction(self):
   205      source = synthetic_pipeline.SyntheticSource(input_spec(10, 1, 1))
   206      source_test_utils.assert_split_at_fraction_exhaustive(source)
   207      source_test_utils.assert_split_at_fraction_fails(source, 5, 0.3)
   208      source_test_utils.assert_split_at_fraction_succeeds_and_consistent(
   209          source, 1, 0.3)
   210  
   211    def run_pipeline(self, barrier, writes_output=True):
   212      steps = [{
   213          'per_element_delay': 1
   214      }, {
   215          'per_element_delay': 1, 'splittable': True
   216      }]
   217      args = [
   218          '--barrier=%s' % barrier,
   219          '--runner=DirectRunner',
   220          '--steps=%s' % json.dumps(steps),
   221          '--input=%s' % json.dumps(input_spec(10, 1, 1))
   222      ]
   223      if writes_output:
   224        output_location = tempfile.NamedTemporaryFile().name
   225        args.append('--output=%s' % output_location)
   226  
   227      synthetic_pipeline.run(args, save_main_session=False)
   228  
   229      # Verify output
   230      if writes_output:
   231        read_output = []
   232        for file_name in glob.glob(output_location + '*'):
   233          with open(file_name, 'rb') as f:
   234            read_output.extend(f.read().splitlines())
   235  
   236        self.assertEqual(10, len(read_output))
   237  
   238    def test_pipeline_shuffle(self):
   239      self.run_pipeline('shuffle')
   240  
   241    def test_pipeline_side_input(self):
   242      self.run_pipeline('side-input')
   243  
   244    def test_pipeline_expand_gbk(self):
   245      self.run_pipeline('expand-gbk', False)
   246  
   247    def test_pipeline_expand_side_output(self):
   248      self.run_pipeline('expand-second-output', False)
   249  
   250    def test_pipeline_merge_gbk(self):
   251      self.run_pipeline('merge-gbk')
   252  
   253    def test_pipeline_merge_side_input(self):
   254      self.run_pipeline('merge-side-input')
   255  
   256  
   257  if __name__ == '__main__':
   258    logging.getLogger().setLevel(logging.INFO)
   259    unittest.main()