github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/filebasedsource_test.py

     1  # Licensed to the Apache Software Foundation (ASF) under one or more
     2  # contributor license agreements.  See the NOTICE file distributed with
     3  # this work for additional information regarding copyright ownership.
     4  # The ASF licenses this file to You under the Apache License, Version 2.0
     5  # (the "License"); you may not use this file except in compliance with
     6  # the License.  You may obtain a copy of the License at
     7  #
     8  #    http://www.apache.org/licenses/LICENSE-2.0
     9  #
    10  # Unless required by applicable law or agreed to in writing, software
    11  # distributed under the License is distributed on an "AS IS" BASIS,
    12  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  # See the License for the specific language governing permissions and
    14  # limitations under the License.
    15  #
    16  
    17  # pytype: skip-file
    18  
    19  import bz2
    20  import gzip
    21  import io
    22  import logging
    23  import math
    24  import os
    25  import random
    26  import tempfile
    27  import unittest
    28  
    29  import hamcrest as hc
    30  
    31  import apache_beam as beam
    32  from apache_beam.io import filebasedsource
    33  from apache_beam.io import iobase
    34  from apache_beam.io import range_trackers
    35  # Importing the following private classes for testing.
    36  from apache_beam.io.concat_source import ConcatSource
    37  from apache_beam.io.filebasedsource import _SingleFileSource as SingleFileSource
    38  from apache_beam.io.filebasedsource import FileBasedSource
    39  from apache_beam.io.filesystem import CompressionTypes
    40  from apache_beam.options.value_provider import RuntimeValueProvider
    41  from apache_beam.options.value_provider import StaticValueProvider
    42  from apache_beam.testing.test_pipeline import TestPipeline
    43  from apache_beam.testing.util import assert_that
    44  from apache_beam.testing.util import equal_to
    45  from apache_beam.transforms.display import DisplayData
    46  from apache_beam.transforms.display_test import DisplayDataItemMatcher
    47  
    48  
    49  class LineSource(FileBasedSource):
    50    def read_records(self, file_name, range_tracker):
    51      f = self.open_file(file_name)
    52      try:
    53        start = range_tracker.start_position()
    54        if start > 0:
    55          # A line that begins before 'start' belongs to the previous bundle, so
    56          # it must not be emitted here. Seeking to (start - 1) and skipping one
    57          # line moves the current position to the start of the first line that
    58          # belongs to the current bundle.
    59          start -= 1
    60          f.seek(start)
    61          line = f.readline()
    62          start += len(line)
    63        current = start
    64        line = f.readline()
    65        while range_tracker.try_claim(current):
    66          # When the source is unsplittable, try_claim alone cannot detect the end
    67          # of the file, so also stop once readline() returns no data.
    68          if not line:
    69            return
    70          yield line.rstrip(b'\n')
    71          current += len(line)
    72          line = f.readline()
    73      finally:
    74        f.close()
    75  
    76  
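        # Editor's illustration (not part of the upstream test suite): a minimal
        # sketch of how the seek-back logic in LineSource.read_records() maps a
        # byte range onto whole lines. With a file produced by write_data(10)
        # below (ten 6-byte records b'line0\n' .. b'line9\n'), a reader handed
        # the byte range [0, 20) claims the lines starting at offsets 0, 6, 12
        # and 18 (four lines in total): a line belongs to the bundle whose range
        # contains its first byte, even if it runs past the range end. This
        # mirrors test_read_range_at_beginning further below. The helper name is
        # invented for illustration only.
        def _demo_line_source_bundling(file_name, start, stop):
          source = LineSource(file_name, validate=False)
          tracker = range_trackers.OffsetRangeTracker(start, stop)
          return [record for record in source.read_records(file_name, tracker)]

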
    77  class EOL(object):
    78    LF = 1
    79    CRLF = 2
    80    MIXED = 3
    81    LF_WITH_NOTHING_AT_LAST_LINE = 4
    82  
    83  
    84  def write_data(
    85      num_lines,
    86      no_data=False,
    87      directory=None,
    88      prefix=tempfile.template,
    89      eol=EOL.LF):
    90    """Writes test data to a temporary file.
    91  
    92    Args:
    93      num_lines (int): The number of lines to write.
    94      no_data (bool): If :data:`True`, empty lines will be written, otherwise
    95        each line will contain a concatenation of b'line' and the line number.
    96      directory (str): The name of the directory to create the temporary file in.
    97      prefix (str): The prefix to use for the temporary file.
    98      eol (int): The line ending to use when writing.
    99        :class:`~apache_beam.io.filebasedsource_test.EOL` exposes attributes that
   100        can be used here to define the eol.
   101  
   102    Returns:
   103      Tuple[str, List[bytes]]: A tuple of the filename and a list of the written
   104        data.
   105    """
   106    all_data = []
   107    with tempfile.NamedTemporaryFile(delete=False, dir=directory,
   108                                     prefix=prefix) as f:
   109      sep_values = [b'\n', b'\r\n']
   110      for i in range(num_lines):
   111        data = b'' if no_data else b'line' + str(i).encode()
   112        all_data.append(data)
   113  
   114        if eol == EOL.LF:
   115          sep = sep_values[0]
   116        elif eol == EOL.CRLF:
   117          sep = sep_values[1]
   118        elif eol == EOL.MIXED:
   119          sep = sep_values[i % len(sep_values)]
   120        elif eol == EOL.LF_WITH_NOTHING_AT_LAST_LINE:
   121          sep = b'' if i == (num_lines - 1) else sep_values[0]
   122        else:
   123          raise ValueError('Received unknown value %s for eol.' % eol)
   124  
   125        f.write(data + sep)
   126  
   127      return f.name, all_data
   128  
   129  
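        # Editor's illustration (not part of the upstream module): with the
        # default arguments, write_data(10) writes records b'line0' .. b'line9',
        # each stored as 6 bytes ('line' + one digit + '\n'). The size
        # arithmetic in the tests below (for example 10 * 6 in
        # test_estimate_size_of_file) relies on this layout. The helper name is
        # invented for illustration only.
        def _demo_write_data_layout():
          file_name, data = write_data(10)
          assert data == [b'line%d' % i for i in range(10)]
          assert os.path.getsize(file_name) == 10 * 6
          return file_name

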
   130  def _write_prepared_data(
   131      data, directory=None, prefix=tempfile.template, suffix=''):
   132    with tempfile.NamedTemporaryFile(delete=False,
   133                                     dir=directory,
   134                                     prefix=prefix,
   135                                     suffix=suffix) as f:
   136      f.write(data)
   137      return f.name
   138  
   139  
   140  def write_prepared_pattern(data, suffixes=None):
   141    assert data, 'Data (%s) seems to be empty' % data
   142    if suffixes is None:
   143      suffixes = [''] * len(data)
   144    temp_dir = tempfile.mkdtemp()
   145    for i, d in enumerate(data):
   146      file_name = _write_prepared_data(
   147          d, temp_dir, prefix='mytemp', suffix=suffixes[i])
   148    return file_name[:file_name.rfind(os.path.sep)] + os.path.sep + 'mytemp*'
   149  
   150  
   151  def write_pattern(lines_per_file, no_data=False):
   152    """Writes a pattern of temporary files.
   153  
   154    Args:
   155      lines_per_file (List[int]): The number of lines to write per file.
   156      no_data (bool): If :data:`True`, empty lines will be written, otherwise
   157        each line will contain a concatenation of b'line' and the line number.
   158  
   159    Returns:
   160      Tuple[str, List[bytes]]: A tuple of the filename pattern and a list of the
   161        written data.
   162    """
   163    temp_dir = tempfile.mkdtemp()
   164  
   165    all_data = []
   166    file_name = None
   167    start_index = 0
   168    for i in range(len(lines_per_file)):
   169      file_name, data = write_data(lines_per_file[i], no_data=no_data,
   170                                   directory=temp_dir, prefix='mytemp')
   171      all_data.extend(data)
   172      start_index += lines_per_file[i]
   173  
   174    assert file_name
   175    return (
   176        file_name[:file_name.rfind(os.path.sep)] + os.path.sep + 'mytemp*',
   177        all_data)
   178  
   179  
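        # Editor's illustration (not part of the upstream module): write_pattern
        # returns a glob such as '<temp dir>/mytemp*' together with the union of
        # the lines written to every file; line numbering restarts at b'line0'
        # in each file, which keeps per-file record sizes predictable for the
        # estimate_size tests below. The helper name is invented for
        # illustration only.
        def _demo_write_pattern_layout():
          pattern, data = write_pattern([5, 3])
          assert pattern.endswith('mytemp*')
          assert len(data) == 5 + 3
          assert data.count(b'line0') == 2  # numbering restarts per file
          return pattern

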
   180  class TestConcatSource(unittest.TestCase):
   181    class DummySource(iobase.BoundedSource):
   182      def __init__(self, values):
   183        self._values = values
   184  
   185      def split(
   186          self, desired_bundle_size, start_position=None, stop_position=None):
   187        # Simply divides the values into two bundles.
   188        middle = len(self._values) // 2
   189        yield iobase.SourceBundle(
   190            0.5, TestConcatSource.DummySource(self._values[:middle]), None, None)
   191        yield iobase.SourceBundle(
   192            0.5, TestConcatSource.DummySource(self._values[middle:]), None, None)
   193  
   194      def get_range_tracker(self, start_position, stop_position):
   195        if start_position is None:
   196          start_position = 0
   197        if stop_position is None:
   198          stop_position = len(self._values)
   199  
   200        return range_trackers.OffsetRangeTracker(start_position, stop_position)
   201  
   202      def read(self, range_tracker):
   203        for index, value in enumerate(self._values):
   204          if not range_tracker.try_claim(index):
   205            return
   206  
   207          yield value
   208  
   209      def estimate_size(self):
   210        return len(self._values)  # Assume each value is 1 byte.
   211  
   212    def setUp(self):
   213      # Reduce the size of thread pools. Without this, test execution may fail
   214      # in environments with a limited amount of resources.
   215      filebasedsource.MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 2
   216  
   217    def test_read(self):
   218      sources = [
   219          TestConcatSource.DummySource(range(start, start + 10))
   220          for start in [0, 10, 20]
   221      ]
   222      concat = ConcatSource(sources)
   223      range_tracker = concat.get_range_tracker(None, None)
   224      read_data = [value for value in concat.read(range_tracker)]
   225      self.assertCountEqual(list(range(30)), read_data)
   226  
   227    def test_split(self):
   228      sources = [
   229          TestConcatSource.DummySource(list(range(start, start + 10)))
   230          for start in [0, 10, 20]
   231      ]
   232      concat = ConcatSource(sources)
   233      splits = [split for split in concat.split()]
   234      self.assertEqual(6, len(splits))
   235  
   236      # Reading all splits
   237      read_data = []
   238      for split in splits:
   239        range_tracker_for_split = split.source.get_range_tracker(
   240            split.start_position, split.stop_position)
   241        read_data.extend(
   242            [value for value in split.source.read(range_tracker_for_split)])
   243      self.assertCountEqual(list(range(30)), read_data)
   244  
   245    def test_estimate_size(self):
   246      sources = [
   247          TestConcatSource.DummySource(range(start, start + 10))
   248          for start in [0, 10, 20]
   249      ]
   250      concat = ConcatSource(sources)
   251      self.assertEqual(30, concat.estimate_size())
   252  
   253  
   254  class TestFileBasedSource(unittest.TestCase):
   255    def setUp(self):
   256      # Reduce the size of thread pools. Without this, test execution may fail
   257      # in environments with a limited amount of resources.
   258      filebasedsource.MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 2
   259  
   260    def test_string_or_value_provider_only(self):
   261      str_file_pattern = tempfile.NamedTemporaryFile(delete=False).name
   262      self.assertEqual(
   263          str_file_pattern, FileBasedSource(str_file_pattern)._pattern.value)
   264  
   265      static_vp_file_pattern = StaticValueProvider(
   266          value_type=str, value=str_file_pattern)
   267      self.assertEqual(
   268          static_vp_file_pattern,
   269          FileBasedSource(static_vp_file_pattern)._pattern)
   270  
   271      runtime_vp_file_pattern = RuntimeValueProvider(
   272          option_name='arg', value_type=str, default_value=str_file_pattern)
   273      self.assertEqual(
   274          runtime_vp_file_pattern,
   275          FileBasedSource(runtime_vp_file_pattern)._pattern)
   276      # Reset runtime options to avoid side-effects in other tests.
   277      RuntimeValueProvider.set_runtime_options(None)
   278  
   279      invalid_file_pattern = 123
   280      with self.assertRaises(TypeError):
   281        FileBasedSource(invalid_file_pattern)
   282  
   283    def test_validation_file_exists(self):
   284      file_name, _ = write_data(10)
   285      LineSource(file_name)
   286  
   287    def test_validation_directory_non_empty(self):
   288      temp_dir = tempfile.mkdtemp()
   289      file_name, _ = write_data(10, directory=temp_dir)
   290      LineSource(file_name)
   291  
   292    def test_validation_failing(self):
   293      no_files_found_error = 'No files found based on the file pattern*'
   294      with self.assertRaisesRegex(IOError, no_files_found_error):
   295        LineSource('dummy_pattern')
   296      with self.assertRaisesRegex(IOError, no_files_found_error):
   297        temp_dir = tempfile.mkdtemp()
   298        LineSource(os.path.join(temp_dir, '*'))
   299  
   300    def test_validation_file_missing_verification_disabled(self):
   301      LineSource('dummy_pattern', validate=False)
   302  
   303    def test_fully_read_single_file(self):
   304      file_name, expected_data = write_data(10)
   305      assert len(expected_data) == 10
   306      fbs = LineSource(file_name)
   307      range_tracker = fbs.get_range_tracker(None, None)
   308      read_data = [record for record in fbs.read(range_tracker)]
   309      self.assertCountEqual(expected_data, read_data)
   310  
   311    def test_single_file_display_data(self):
   312      file_name, _ = write_data(10)
   313      fbs = LineSource(file_name)
   314      dd = DisplayData.create_from(fbs)
   315      expected_items = [
   316          DisplayDataItemMatcher('file_pattern', file_name),
   317          DisplayDataItemMatcher('compression', 'auto')
   318      ]
   319      hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
   320  
   321    def test_fully_read_file_pattern(self):
   322      pattern, expected_data = write_pattern([5, 3, 12, 8, 8, 4])
   323      assert len(expected_data) == 40
   324      fbs = LineSource(pattern)
   325      range_tracker = fbs.get_range_tracker(None, None)
   326      read_data = [record for record in fbs.read(range_tracker)]
   327      self.assertCountEqual(expected_data, read_data)
   328  
   329    def test_fully_read_file_pattern_with_empty_files(self):
   330      pattern, expected_data = write_pattern([5, 0, 12, 0, 8, 0])
   331      assert len(expected_data) == 25
   332      fbs = LineSource(pattern)
   333      range_tracker = fbs.get_range_tracker(None, None)
   334      read_data = [record for record in fbs.read(range_tracker)]
   335      self.assertCountEqual(expected_data, read_data)
   336  
   337    def test_estimate_size_of_file(self):
   338      file_name, expected_data = write_data(10)
   339      assert len(expected_data) == 10
   340      fbs = LineSource(file_name)
   341      self.assertEqual(10 * 6, fbs.estimate_size())
   342  
   343    def test_estimate_size_of_pattern(self):
   344      pattern, expected_data = write_pattern([5, 3, 10, 8, 8, 4])
   345      assert len(expected_data) == 38
   346      fbs = LineSource(pattern)
   347      self.assertEqual(38 * 6, fbs.estimate_size())
   348  
   349      pattern, expected_data = write_pattern([5, 3, 9])
   350      assert len(expected_data) == 17
   351      fbs = LineSource(pattern)
   352      self.assertEqual(17 * 6, fbs.estimate_size())
   353  
   354    def test_estimate_size_with_sampling_same_size(self):
   355      num_files = 2 * FileBasedSource.MIN_NUMBER_OF_FILES_TO_STAT
   356      pattern, _ = write_pattern([10] * num_files)
   357      # Each line is 6 bytes long since write_pattern() writes
   358      # ('line' + a single-digit line number + '\n') as data.
   359      self.assertEqual(
   360          6 * 10 * num_files, FileBasedSource(pattern).estimate_size())
   361  
   362    def test_estimate_size_with_sampling_different_sizes(self):
   363      num_files = 2 * FileBasedSource.MIN_NUMBER_OF_FILES_TO_STAT
   364  
   365      # Most lines are 8 bytes long since write_pattern() writes
   366      # ('line' + a mostly three-digit line number + '\n') as data.
   367      base_size = 500
   368      variance = 5
   369  
   370      sizes = []
   371      for _ in range(num_files):
   372        sizes.append(
   373            int(random.uniform(base_size - variance, base_size + variance)))
   374      pattern, _ = write_pattern(sizes)
   375      tolerance = 0.05
   376      self.assertAlmostEqual(
   377          base_size * 8 * num_files,
   378          FileBasedSource(pattern).estimate_size(),
   379          delta=base_size * 8 * num_files * tolerance)
   380  
   381    def test_splits_into_subranges(self):
   382      pattern, expected_data = write_pattern([5, 9, 6])
   383      assert len(expected_data) == 20
   384      fbs = LineSource(pattern)
   385      splits = [split for split in fbs.split(desired_bundle_size=15)]
   386      expected_num_splits = (
   387          math.ceil(float(6 * 5) / 15) + math.ceil(float(6 * 9) / 15) +
   388          math.ceil(float(6 * 6) / 15))
   389      assert len(splits) == expected_num_splits
   390  
   391    def test_read_splits_single_file(self):
   392      file_name, expected_data = write_data(100)
   393      assert len(expected_data) == 100
   394      fbs = LineSource(file_name)
   395      splits = [split for split in fbs.split(desired_bundle_size=33)]
   396  
   397      # Reading all splits
   398      read_data = []
   399      for split in splits:
   400        source = split.source
   401        range_tracker = source.get_range_tracker(
   402            split.start_position, split.stop_position)
   403        data_from_split = [data for data in source.read(range_tracker)]
   404        read_data.extend(data_from_split)
   405  
   406      self.assertCountEqual(expected_data, read_data)
   407  
   408    def test_read_splits_file_pattern(self):
   409      pattern, expected_data = write_pattern([34, 66, 40, 24, 24, 12])
   410      assert len(expected_data) == 200
   411      fbs = LineSource(pattern)
   412      splits = [split for split in fbs.split(desired_bundle_size=50)]
   413  
   414      # Reading all splits
   415      read_data = []
   416      for split in splits:
   417        source = split.source
   418        range_tracker = source.get_range_tracker(
   419            split.start_position, split.stop_position)
   420        data_from_split = [data for data in source.read(range_tracker)]
   421        read_data.extend(data_from_split)
   422  
   423      self.assertCountEqual(expected_data, read_data)
   424  
   425    def _run_source_test(self, pattern, expected_data, splittable=True):
   426      with TestPipeline() as pipeline:
   427        pcoll = pipeline | 'Read' >> beam.io.Read(
   428            LineSource(pattern, splittable=splittable))
   429        assert_that(pcoll, equal_to(expected_data))
   430  
   431    def test_source_file(self):
   432      file_name, expected_data = write_data(100)
   433      assert len(expected_data) == 100
   434      self._run_source_test(file_name, expected_data)
   435  
   436    def test_source_pattern(self):
   437      pattern, expected_data = write_pattern([34, 66, 40, 24, 24, 12])
   438      assert len(expected_data) == 200
   439      self._run_source_test(pattern, expected_data)
   440  
   441    def test_unsplittable_does_not_split(self):
   442      pattern, expected_data = write_pattern([5, 9, 6])
   443      assert len(expected_data) == 20
   444      fbs = LineSource(pattern, splittable=False)
   445      splits = [split for split in fbs.split(desired_bundle_size=15)]
   446      self.assertEqual(3, len(splits))
   447  
   448    def test_source_file_unsplittable(self):
   449      file_name, expected_data = write_data(100)
   450      assert len(expected_data) == 100
   451      self._run_source_test(file_name, expected_data, False)
   452  
   453    def test_source_pattern_unsplittable(self):
   454      pattern, expected_data = write_pattern([34, 66, 40, 24, 24, 12])
   455      assert len(expected_data) == 200
   456      self._run_source_test(pattern, expected_data, False)
   457  
   458    def test_read_file_bzip2(self):
   459      _, lines = write_data(10)
   460      filename = tempfile.NamedTemporaryFile(
   461          delete=False, prefix=tempfile.template).name
   462      with bz2.BZ2File(filename, 'wb') as f:
   463        f.write(b'\n'.join(lines))
   464  
   465      with TestPipeline() as pipeline:
   466        pcoll = pipeline | 'Read' >> beam.io.Read(
   467            LineSource(
   468                filename,
   469                splittable=False,
   470                compression_type=CompressionTypes.BZIP2))
   471        assert_that(pcoll, equal_to(lines))
   472  
   473    def test_read_file_gzip(self):
   474      _, lines = write_data(10)
   475      filename = tempfile.NamedTemporaryFile(
   476          delete=False, prefix=tempfile.template).name
   477      with gzip.GzipFile(filename, 'wb') as f:
   478        f.write(b'\n'.join(lines))
   479  
   480      with TestPipeline() as pipeline:
   481        pcoll = pipeline | 'Read' >> beam.io.Read(
   482            LineSource(
   483                filename,
   484                splittable=False,
   485                compression_type=CompressionTypes.GZIP))
   486        assert_that(pcoll, equal_to(lines))
   487  
   488    def test_read_pattern_bzip2(self):
   489      _, lines = write_data(200)
   490      splits = [0, 34, 100, 140, 164, 188, 200]
   491      chunks = [lines[splits[i - 1]:splits[i]] for i in range(1, len(splits))]
   492      compressed_chunks = []
   493      for c in chunks:
   494        compressobj = bz2.BZ2Compressor()
   495        compressed_chunks.append(
   496            compressobj.compress(b'\n'.join(c)) + compressobj.flush())
   497      file_pattern = write_prepared_pattern(compressed_chunks)
   498      with TestPipeline() as pipeline:
   499        pcoll = pipeline | 'Read' >> beam.io.Read(
   500            LineSource(
   501                file_pattern,
   502                splittable=False,
   503                compression_type=CompressionTypes.BZIP2))
   504        assert_that(pcoll, equal_to(lines))
   505  
   506    def test_read_pattern_gzip(self):
   507      _, lines = write_data(200)
   508      splits = [0, 34, 100, 140, 164, 188, 200]
   509      chunks = [lines[splits[i - 1]:splits[i]] for i in range(1, len(splits))]
   510      compressed_chunks = []
   511      for c in chunks:
   512        out = io.BytesIO()
   513        with gzip.GzipFile(fileobj=out, mode="wb") as f:
   514          f.write(b'\n'.join(c))
   515        compressed_chunks.append(out.getvalue())
   516      file_pattern = write_prepared_pattern(compressed_chunks)
   517      with TestPipeline() as pipeline:
   518        pcoll = pipeline | 'Read' >> beam.io.Read(
   519            LineSource(
   520                file_pattern,
   521                splittable=False,
   522                compression_type=CompressionTypes.GZIP))
   523        assert_that(pcoll, equal_to(lines))
   524  
   525    def test_read_auto_single_file_bzip2(self):
   526      _, lines = write_data(10)
   527      filename = tempfile.NamedTemporaryFile(
   528          delete=False, prefix=tempfile.template, suffix='.bz2').name
   529      with bz2.BZ2File(filename, 'wb') as f:
   530        f.write(b'\n'.join(lines))
   531  
   532      with TestPipeline() as pipeline:
   533        pcoll = pipeline | 'Read' >> beam.io.Read(
   534            LineSource(filename, compression_type=CompressionTypes.AUTO))
   535        assert_that(pcoll, equal_to(lines))
   536  
   537    def test_read_auto_single_file_gzip(self):
   538      _, lines = write_data(10)
   539      filename = tempfile.NamedTemporaryFile(
   540          delete=False, prefix=tempfile.template, suffix='.gz').name
   541      with gzip.GzipFile(filename, 'wb') as f:
   542        f.write(b'\n'.join(lines))
   543  
   544      with TestPipeline() as pipeline:
   545        pcoll = pipeline | 'Read' >> beam.io.Read(
   546            LineSource(filename, compression_type=CompressionTypes.AUTO))
   547        assert_that(pcoll, equal_to(lines))
   548  
   549    def test_read_auto_pattern(self):
   550      _, lines = write_data(200)
   551      splits = [0, 34, 100, 140, 164, 188, 200]
   552      chunks = [lines[splits[i - 1]:splits[i]] for i in range(1, len(splits))]
   553      compressed_chunks = []
   554      for c in chunks:
   555        out = io.BytesIO()
   556        with gzip.GzipFile(fileobj=out, mode="wb") as f:
   557          f.write(b'\n'.join(c))
   558        compressed_chunks.append(out.getvalue())
   559      file_pattern = write_prepared_pattern(
   560          compressed_chunks, suffixes=['.gz'] * len(chunks))
   561      with TestPipeline() as pipeline:
   562        pcoll = pipeline | 'Read' >> beam.io.Read(
   563            LineSource(file_pattern, compression_type=CompressionTypes.AUTO))
   564        assert_that(pcoll, equal_to(lines))
   565  
   566    def test_read_auto_pattern_compressed_and_uncompressed(self):
   567      _, lines = write_data(200)
   568      splits = [0, 34, 100, 140, 164, 188, 200]
   569      chunks = [lines[splits[i - 1]:splits[i]] for i in range(1, len(splits))]
   570      chunks_to_write = []
   571      for i, c in enumerate(chunks):
   572        if i % 2 == 0:
   573          out = io.BytesIO()
   574          with gzip.GzipFile(fileobj=out, mode="wb") as f:
   575            f.write(b'\n'.join(c))
   576          chunks_to_write.append(out.getvalue())
   577        else:
   578          chunks_to_write.append(b'\n'.join(c))
   579      file_pattern = write_prepared_pattern(
   580          chunks_to_write, suffixes=(['.gz', ''] * 3))
   581      with TestPipeline() as pipeline:
   582        pcoll = pipeline | 'Read' >> beam.io.Read(
   583            LineSource(file_pattern, compression_type=CompressionTypes.AUTO))
   584        assert_that(pcoll, equal_to(lines))
   585  
   586    def test_splits_get_coder_from_fbs(self):
   587      class DummyCoder(object):
   588        val = 12345
   589  
   590      class FileBasedSourceWithCoder(LineSource):
   591        def default_output_coder(self):
   592          return DummyCoder()
   593  
   594      pattern, expected_data = write_pattern([34, 66, 40, 24, 24, 12])
   595      self.assertEqual(200, len(expected_data))
   596      fbs = FileBasedSourceWithCoder(pattern)
   597      splits = [split for split in fbs.split(desired_bundle_size=50)]
   598      self.assertTrue(len(splits))
   599      for split in splits:
   600        self.assertEqual(DummyCoder.val, split.source.default_output_coder().val)
   601  
   602  
   603  class TestSingleFileSource(unittest.TestCase):
   604    def setUp(self):
   605      # Reduce the size of thread pools. Without this, test execution may fail
   606      # in environments with a limited amount of resources.
   607      filebasedsource.MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 2
   608  
   609    def test_source_creation_fails_for_non_number_offsets(self):
   610      start_not_a_number_error = 'start_offset must be a number*'
   611      stop_not_a_number_error = 'stop_offset must be a number*'
   612      file_name = 'dummy_pattern'
   613      fbs = LineSource(file_name, validate=False)
   614  
   615      with self.assertRaisesRegex(TypeError, start_not_a_number_error):
   616        SingleFileSource(
   617            fbs, file_name='dummy_file', start_offset='aaa', stop_offset='bbb')
   618      with self.assertRaisesRegex(TypeError, start_not_a_number_error):
   619        SingleFileSource(
   620            fbs, file_name='dummy_file', start_offset='aaa', stop_offset=100)
   621      with self.assertRaisesRegex(TypeError, stop_not_a_number_error):
   622        SingleFileSource(
   623            fbs, file_name='dummy_file', start_offset=100, stop_offset='bbb')
   624      with self.assertRaisesRegex(TypeError, stop_not_a_number_error):
   625        SingleFileSource(
   626            fbs, file_name='dummy_file', start_offset=100, stop_offset=None)
   627      with self.assertRaisesRegex(TypeError, start_not_a_number_error):
   628        SingleFileSource(
   629            fbs, file_name='dummy_file', start_offset=None, stop_offset=100)
   630  
   631    def test_source_creation_display_data(self):
   632      file_name = 'dummy_pattern'
   633      fbs = LineSource(file_name, validate=False)
   634      dd = DisplayData.create_from(fbs)
   635      expected_items = [
   636          DisplayDataItemMatcher('compression', 'auto'),
   637          DisplayDataItemMatcher('file_pattern', file_name)
   638      ]
   639      hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
   640  
   641    def test_source_creation_fails_if_start_lg_stop(self):
   642      start_larger_than_stop_error = (
   643          'start_offset must be smaller than stop_offset*')
   644      fbs = LineSource('dummy_pattern', validate=False)
   645      SingleFileSource(
   646          fbs, file_name='dummy_file', start_offset=99, stop_offset=100)
   647      with self.assertRaisesRegex(ValueError, start_larger_than_stop_error):
   648        SingleFileSource(
   649            fbs, file_name='dummy_file', start_offset=100, stop_offset=99)
   650      with self.assertRaisesRegex(ValueError, start_larger_than_stop_error):
   651        SingleFileSource(
   652            fbs, file_name='dummy_file', start_offset=100, stop_offset=100)
   653  
   654    def test_estimates_size(self):
   655      fbs = LineSource('dummy_pattern', validate=False)
   656  
   657      # Should simply return stop_offset - start_offset
   658      source = SingleFileSource(
   659          fbs, file_name='dummy_file', start_offset=0, stop_offset=100)
   660      self.assertEqual(100, source.estimate_size())
   661  
   662      source = SingleFileSource(
   663          fbs, file_name='dummy_file', start_offset=10, stop_offset=100)
   664      self.assertEqual(90, source.estimate_size())
   665  
   666    def test_read_range_at_beginning(self):
   667      fbs = LineSource('dummy_pattern', validate=False)
   668  
   669      file_name, expected_data = write_data(10)
   670      assert len(expected_data) == 10
   671  
   672      source = SingleFileSource(fbs, file_name, 0, 10 * 6)
   673      range_tracker = source.get_range_tracker(0, 20)
   674      read_data = [value for value in source.read(range_tracker)]
   675      self.assertCountEqual(expected_data[:4], read_data)
   676  
   677    def test_read_range_at_end(self):
   678      fbs = LineSource('dummy_pattern', validate=False)
   679  
   680      file_name, expected_data = write_data(10)
   681      assert len(expected_data) == 10
   682  
   683      source = SingleFileSource(fbs, file_name, 0, 10 * 6)
   684      range_tracker = source.get_range_tracker(40, 60)
   685      read_data = [value for value in source.read(range_tracker)]
   686      self.assertCountEqual(expected_data[-3:], read_data)
   687  
   688    def test_read_range_at_middle(self):
   689      fbs = LineSource('dummy_pattern', validate=False)
   690  
   691      file_name, expected_data = write_data(10)
   692      assert len(expected_data) == 10
   693  
   694      source = SingleFileSource(fbs, file_name, 0, 10 * 6)
   695      range_tracker = source.get_range_tracker(20, 40)
   696      read_data = [value for value in source.read(range_tracker)]
   697      self.assertCountEqual(expected_data[4:7], read_data)
   698  
   699    def test_produces_splits_desiredsize_larger_than_size(self):
   700      fbs = LineSource('dummy_pattern', validate=False)
   701  
   702      file_name, expected_data = write_data(10)
   703      assert len(expected_data) == 10
   704      source = SingleFileSource(fbs, file_name, 0, 10 * 6)
   705      splits = [split for split in source.split(desired_bundle_size=100)]
   706      self.assertEqual(1, len(splits))
   707      self.assertEqual(60, splits[0].weight)
   708      self.assertEqual(0, splits[0].start_position)
   709      self.assertEqual(60, splits[0].stop_position)
   710  
   711      range_tracker = splits[0].source.get_range_tracker(None, None)
   712      read_data = [value for value in splits[0].source.read(range_tracker)]
   713      self.assertCountEqual(expected_data, read_data)
   714  
   715    def test_produces_splits_desiredsize_smaller_than_size(self):
   716      fbs = LineSource('dummy_pattern', validate=False)
   717  
   718      file_name, expected_data = write_data(10)
   719      assert len(expected_data) == 10
   720      source = SingleFileSource(fbs, file_name, 0, 10 * 6)
   721      splits = [split for split in source.split(desired_bundle_size=25)]
   722      self.assertEqual(3, len(splits))
   723  
   724      read_data = []
   725      for split in splits:
   726        source = split.source
   727        range_tracker = source.get_range_tracker(
   728            split.start_position, split.stop_position)
   729        data_from_split = [data for data in source.read(range_tracker)]
   730        read_data.extend(data_from_split)
   731      self.assertCountEqual(expected_data, read_data)
   732  
   733    def test_produce_split_with_start_and_end_positions(self):
   734      fbs = LineSource('dummy_pattern', validate=False)
   735  
   736      file_name, expected_data = write_data(10)
   737      assert len(expected_data) == 10
   738      source = SingleFileSource(fbs, file_name, 0, 10 * 6)
   739      splits = [
   740          split for split in source.split(
   741              desired_bundle_size=15, start_offset=10, stop_offset=50)
   742      ]
   743      self.assertEqual(3, len(splits))
   744  
   745      read_data = []
   746      for split in splits:
   747        source = split.source
   748        range_tracker = source.get_range_tracker(
   749            split.start_position, split.stop_position)
   750        data_from_split = [data for data in source.read(range_tracker)]
   751        read_data.extend(data_from_split)
   752      self.assertCountEqual(expected_data[2:9], read_data)
   753  
   754  
   755  if __name__ == '__main__':
   756    logging.getLogger().setLevel(logging.INFO)
   757    unittest.main()