github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/sources_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Unit tests for the sources framework."""
    19  # pytype: skip-file
    20  
    21  import logging
    22  import os
    23  import tempfile
    24  import unittest
    25  
    26  import apache_beam as beam
    27  from apache_beam import coders
    28  from apache_beam.io import iobase
    29  from apache_beam.io import range_trackers
    30  from apache_beam.testing.test_pipeline import TestPipeline
    31  from apache_beam.testing.util import assert_that
    32  from apache_beam.testing.util import equal_to
    33  
    34  
    35  class LineSource(iobase.BoundedSource):
    36    """A simple source that reads lines from a given file."""
    37  
    38    TEST_BUNDLE_SIZE = 10
    39  
    40    def __init__(self, file_name):
    41      self._file_name = file_name
    42  
    43    def read(self, range_tracker):
    44      with open(self._file_name, 'rb') as f:
    45        start = range_tracker.start_position()
    46        f.seek(start)
    47        if start > 0:
    48          f.seek(-1, os.SEEK_CUR)
    49          start -= 1
    50          start += len(f.readline())
    51        current = start
    52        line = f.readline()
    53        while range_tracker.try_claim(current):
    54          if not line:
    55            return
    56          yield line.rstrip(b'\n')
    57          current += len(line)
    58          line = f.readline()
    59  
    60    def split(self, desired_bundle_size, start_position=None, stop_position=None):
    61      assert start_position is None
    62      assert stop_position is None
    63      size = self.estimate_size()
    64  
    65      bundle_start = 0
    66      while bundle_start < size:
    67        bundle_stop = min(bundle_start + LineSource.TEST_BUNDLE_SIZE, size)
    68        yield iobase.SourceBundle(
    69            bundle_stop - bundle_start, self, bundle_start, bundle_stop)
    70        bundle_start = bundle_stop
    71  
    72    def get_range_tracker(self, start_position, stop_position):
    73      if start_position is None:
    74        start_position = 0
    75      if stop_position is None:
    76        stop_position = self._get_file_size()
    77      return range_trackers.OffsetRangeTracker(start_position, stop_position)
    78  
    79    def default_output_coder(self):
    80      return coders.BytesCoder()
    81  
    82    def estimate_size(self):
    83      return self._get_file_size()
    84  
    85    def _get_file_size(self):
    86      with open(self._file_name, 'rb') as f:
    87        f.seek(0, os.SEEK_END)
    88        return f.tell()
    89  
    90  
    91  class SourcesTest(unittest.TestCase):
    92    def _create_temp_file(self, contents):
    93      with tempfile.NamedTemporaryFile(delete=False) as f:
    94        f.write(contents)
    95        return f.name
    96  
    97    def test_read_from_source(self):
    98      file_name = self._create_temp_file(b'aaaa\nbbbb\ncccc\ndddd')
    99  
   100      source = LineSource(file_name)
   101      range_tracker = source.get_range_tracker(None, None)
   102      result = [line for line in source.read(range_tracker)]
   103  
   104      self.assertCountEqual([b'aaaa', b'bbbb', b'cccc', b'dddd'], result)
   105      self.assertTrue(
   106          range_tracker.last_attempted_record_start >=
   107          range_tracker.stop_position())
   108  
   109    def test_source_estimated_size(self):
   110      file_name = self._create_temp_file(b'aaaa\n')
   111  
   112      source = LineSource(file_name)
   113      self.assertEqual(5, source.estimate_size())
   114  
   115    def test_run_direct(self):
   116      file_name = self._create_temp_file(b'aaaa\nbbbb\ncccc\ndddd')
   117      with TestPipeline() as pipeline:
   118        pcoll = pipeline | beam.io.Read(LineSource(file_name))
   119        assert_that(pcoll, equal_to([b'aaaa', b'bbbb', b'cccc', b'dddd']))
   120  
   121  
   122  if __name__ == '__main__':
   123    logging.getLogger().setLevel(logging.INFO)
   124    unittest.main()