github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/source_test_utils_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  # pytype: skip-file
    19  
    20  import logging
    21  import tempfile
    22  import unittest
    23  
    24  from apache_beam.io import source_test_utils
    25  from apache_beam.io.filebasedsource_test import LineSource
    26  
    27  
    28  class SourceTestUtilsTest(unittest.TestCase):
    29    def _create_file_with_data(self, lines):
    30      assert isinstance(lines, list)
    31      with tempfile.NamedTemporaryFile(delete=False) as f:
    32        for line in lines:
    33          f.write(line + b'\n')
    34  
    35        return f.name
    36  
    37    def _create_data(self, num_lines):
    38      return [b'line ' + str(i).encode('latin1') for i in range(num_lines)]
    39  
    40    def _create_source(self, data):
    41      source = LineSource(self._create_file_with_data(data))
    42      # By performing initial splitting, we can get a source for a single file.
    43      # This source, that uses OffsetRangeTracker, is better for testing purposes,
    44      # than using the original source for a file-pattern.
    45      for bundle in source.split(float('inf')):
    46        return bundle.source
    47  
    48    def test_read_from_source(self):
    49      data = self._create_data(100)
    50      source = self._create_source(data)
    51      self.assertCountEqual(
    52          data, source_test_utils.read_from_source(source, None, None))
    53  
    54    def test_source_equals_reference_source(self):
    55      data = self._create_data(100)
    56      reference_source = self._create_source(data)
    57      sources_info = [(split.source, split.start_position, split.stop_position)
    58                      for split in reference_source.split(desired_bundle_size=50)]
    59      if len(sources_info) < 2:
    60        raise ValueError(
    61            'Test is too trivial since splitting only generated %d'
    62            'bundles. Please adjust the test so that at least '
    63            'two splits get generated.' % len(sources_info))
    64  
    65      source_test_utils.assert_sources_equal_reference_source(
    66          (reference_source, None, None), sources_info)
    67  
    68    def test_split_at_fraction_successful(self):
    69      data = self._create_data(100)
    70      source = self._create_source(data)
    71      result1 = source_test_utils.assert_split_at_fraction_behavior(
    72          source,
    73          10,
    74          0.5,
    75          source_test_utils.ExpectedSplitOutcome.MUST_SUCCEED_AND_BE_CONSISTENT)
    76      result2 = source_test_utils.assert_split_at_fraction_behavior(
    77          source,
    78          20,
    79          0.5,
    80          source_test_utils.ExpectedSplitOutcome.MUST_SUCCEED_AND_BE_CONSISTENT)
    81      self.assertEqual(result1, result2)
    82      self.assertEqual(100, result1[0] + result1[1])
    83  
    84      result3 = source_test_utils.assert_split_at_fraction_behavior(
    85          source,
    86          30,
    87          0.8,
    88          source_test_utils.ExpectedSplitOutcome.MUST_SUCCEED_AND_BE_CONSISTENT)
    89      result4 = source_test_utils.assert_split_at_fraction_behavior(
    90          source,
    91          50,
    92          0.8,
    93          source_test_utils.ExpectedSplitOutcome.MUST_SUCCEED_AND_BE_CONSISTENT)
    94      self.assertEqual(result3, result4)
    95      self.assertEqual(100, result3[0] + result4[1])
    96  
    97      self.assertTrue(result1[0] < result3[0])
    98      self.assertTrue(result1[1] > result3[1])
    99  
   100    def test_split_at_fraction_fails(self):
   101      data = self._create_data(100)
   102      source = self._create_source(data)
   103  
   104      result = source_test_utils.assert_split_at_fraction_behavior(
   105          source, 90, 0.1, source_test_utils.ExpectedSplitOutcome.MUST_FAIL)
   106      self.assertEqual(result[0], 100)
   107      self.assertEqual(result[1], -1)
   108  
   109      with self.assertRaises(ValueError):
   110        source_test_utils.assert_split_at_fraction_behavior(
   111            source, 10, 0.5, source_test_utils.ExpectedSplitOutcome.MUST_FAIL)
   112  
   113    def test_split_at_fraction_binary(self):
   114      data = self._create_data(100)
   115      source = self._create_source(data)
   116  
   117      stats = source_test_utils.SplitFractionStatistics([], [])
   118      source_test_utils.assert_split_at_fraction_binary(
   119          source, data, 10, 0.5, None, 0.8, None, stats)
   120  
   121      # These lists should not be empty now.
   122      self.assertTrue(stats.successful_fractions)
   123      self.assertTrue(stats.non_trivial_fractions)
   124  
   125    def test_split_at_fraction_exhaustive(self):
   126      data = self._create_data(10)
   127      source = self._create_source(data)
   128      source_test_utils.assert_split_at_fraction_exhaustive(source)
   129  
   130  
   131  if __name__ == '__main__':
   132    logging.getLogger().setLevel(logging.INFO)
   133    unittest.main()