github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/textio_test.py

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Tests for textio module."""
    19  # pytype: skip-file
    20  
    21  import bz2
    22  import glob
    23  import gzip
    24  import logging
    25  import os
    26  import shutil
    27  import tempfile
    28  import unittest
    29  import zlib
    30  
    31  import apache_beam as beam
    32  from apache_beam import coders
    33  from apache_beam.io import iobase
    34  from apache_beam.io import source_test_utils
    35  from apache_beam.io.filesystem import CompressionTypes
    36  from apache_beam.io.filesystems import FileSystems
    37  # Import the following private classes for testing purposes.
    38  from apache_beam.io.textio import _TextSink as TextSink
    39  from apache_beam.io.textio import _TextSource as TextSource
    40  from apache_beam.io.textio import ReadAllFromText
    41  from apache_beam.io.textio import ReadAllFromTextContinuously
    42  from apache_beam.io.textio import ReadFromText
    43  from apache_beam.io.textio import ReadFromTextWithFilename
    44  from apache_beam.io.textio import WriteToText
    45  from apache_beam.testing.test_pipeline import TestPipeline
    46  from apache_beam.testing.test_utils import TempDir
    47  from apache_beam.testing.util import assert_that
    48  from apache_beam.testing.util import equal_to
    49  from apache_beam.transforms.core import Create
    50  from apache_beam.transforms.userstate import CombiningValueStateSpec
    51  from apache_beam.utils.timestamp import Timestamp
    52  
    53  
    54  class DummyCoder(coders.Coder):
    55    def encode(self, x):
    56      raise ValueError
    57  
    58    def decode(self, x):
    59      return (x * 2).decode('utf-8')
    60  
    61    def to_type_hint(self):
    62      return str
    63  
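# A quick reading of DummyCoder (illustrative note): encode raises on purpose,
# so a pure read path that accidentally tried to encode would fail loudly,
# while decode doubles the raw bytes before decoding them; the coder test
# below (test_read_from_text_single_file_with_coder) relies on this by
# expecting each record twice-concatenated.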
    64  
    65  class EOL(object):
    66    LF = 1
    67    CRLF = 2
    68    MIXED = 3
    69    LF_WITH_NOTHING_AT_LAST_LINE = 4
    70    CUSTOM_DELIMITER = 5
    71  
    72  
    73  def write_data(
    74      num_lines,
    75      no_data=False,
    76      directory=None,
    77      prefix=tempfile.template,
    78      eol=EOL.LF,
    79      custom_delimiter=None,
    80      line_value=b'line'):
    81    """Writes test data to a temporary file.
    82  
    83    Args:
    84      num_lines (int): The number of lines to write.
    85      no_data (bool): If :data:`True`, empty lines will be written; otherwise
    86        each line will contain ``line_value`` followed by the line number.
    87      directory (str): The name of the directory to create the temporary file in.
    88      prefix (str): The prefix to use for the temporary file.
    89      eol (int): The line ending to use when writing.
    90        :class:`~apache_beam.io.textio_test.EOL` exposes attributes that can be
    91        used here to define the eol.
    92      custom_delimiter (bytes): Delimiter used when eol is EOL.CUSTOM_DELIMITER.
    93      line_value (bytes): The value written on each line; defaults to b'line'.
    94  
    95    Returns:
    96      Tuple[str, List[str]]: A tuple of the filename and a list of the
    97        utf-8 decoded written data.
    98    """
    99    all_data = []
   100    with tempfile.NamedTemporaryFile(delete=False, dir=directory,
   101                                     prefix=prefix) as f:
   102      sep_values = [b'\n', b'\r\n']
   103      for i in range(num_lines):
   104        data = b'' if no_data else line_value + str(i).encode()
   105        all_data.append(data)
   106  
   107        if eol == EOL.LF:
   108          sep = sep_values[0]
   109        elif eol == EOL.CRLF:
   110          sep = sep_values[1]
   111        elif eol == EOL.MIXED:
   112          sep = sep_values[i % len(sep_values)]
   113        elif eol == EOL.LF_WITH_NOTHING_AT_LAST_LINE:
   114          sep = b'' if i == (num_lines - 1) else sep_values[0]
   115        elif eol == EOL.CUSTOM_DELIMITER:
   116          if custom_delimiter is None or len(custom_delimiter) == 0:
   117            raise ValueError('custom_delimiter must not be None or empty.')
   118          else:
   119            sep = custom_delimiter
   120        else:
   121          raise ValueError('Received unknown value %s for eol.' % eol)
   122  
   123        f.write(data + sep)
   124  
   125      return f.name, [line.decode('utf-8') for line in all_data]
   126  
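# Illustrative usage of write_data (a sketch, not used by the tests directly):
# write_data(3, eol=EOL.CRLF) creates a temporary file containing
# b'line0\r\nline1\r\nline2\r\n' and returns its path together with the decoded
# payloads ['line0', 'line1', 'line2']; separators are never included in the
# returned data.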
   127  
   128  def write_pattern(lines_per_file, no_data=False, return_filenames=False):
   129    """Writes a pattern of temporary files.
   130  
   131    Args:
   132      lines_per_file (List[int]): The number of lines to write per file.
   133      no_data (bool): If :data:`True`, empty lines will be written, otherwise
   134        each line will contain a concatenation of b'line' and the line number.
   135      return_filenames (bool): If True, the returned list will contain
   136        (filename, data) pairs.
   137  
   138    Returns:
   139      Tuple[str, List[Union[str, Tuple[str, str]]]]: A tuple of the filename
   140        pattern and a list of the utf-8 decoded data or (filename, data) pairs.
   141    """
   142    temp_dir = tempfile.mkdtemp()
   143  
   144    all_data = []
   145    file_name = None
   146    start_index = 0
   147    for i in range(len(lines_per_file)):
   148      file_name, data = write_data(lines_per_file[i], no_data=no_data,
   149                                   directory=temp_dir, prefix='mytemp')
   150      if return_filenames:
   151        all_data.extend(zip([file_name] * len(data), data))
   152      else:
   153        all_data.extend(data)
   154      start_index += lines_per_file[i]
   155  
   156    assert file_name
   157    return (
   158        file_name[:file_name.rfind(os.path.sep)] + os.path.sep + 'mytemp*',
   159        all_data)
   160  
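# Illustrative usage of write_pattern (a sketch): write_pattern([2, 1]) writes
# two files named 'mytemp*' into a fresh temporary directory and returns a glob
# such as '<tempdir>/mytemp*' plus the concatenated decoded lines
# ['line0', 'line1', 'line0']; with return_filenames=True the list holds
# (filename, line) pairs instead.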
   161  
   162  class TextSourceTest(unittest.TestCase):
   163  
   164    # Number of records that will be written by most tests.
   165    DEFAULT_NUM_RECORDS = 100
   166  
   167    def _run_read_test(
   168        self,
   169        file_or_pattern,
   170        expected_data,
   171        buffer_size=DEFAULT_NUM_RECORDS,
   172        compression=CompressionTypes.UNCOMPRESSED,
   173        delimiter=None,
   174        escapechar=None):
   175      # Since each record usually takes more than one byte, the default buffer
   176      # size here is smaller than the total size of the file. This increases
   177      # test coverage for cases that hit the buffer boundary.
   178      kwargs = {}
   179      if delimiter:
   180        kwargs['delimiter'] = delimiter
   181      if escapechar:
   182        kwargs['escapechar'] = escapechar
   183      source = TextSource(
   184          file_or_pattern,
   185          0,
   186          compression,
   187          True,
   188          coders.StrUtf8Coder(),
   189          buffer_size,
   190          **kwargs)
   191      range_tracker = source.get_range_tracker(None, None)
   192      read_data = list(source.read(range_tracker))
   193      self.assertCountEqual(expected_data, read_data)
   194  
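  # Note on the positional TextSource(...) arguments used throughout this file
  # (informal, matching the keyword calls further below): the helper above
  # passes (file_pattern, min_bundle_size=0, compression_type,
  # strip_trailing_newlines=True, coder, buffer_size), so 0 disables the
  # minimum bundle size and True strips trailing newlines from records.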
   195    def test_read_single_file(self):
   196      file_name, expected_data = write_data(TextSourceTest.DEFAULT_NUM_RECORDS)
   197      assert len(expected_data) == TextSourceTest.DEFAULT_NUM_RECORDS
   198      self._run_read_test(file_name, expected_data)
   199  
   200    def test_read_single_file_smaller_than_default_buffer(self):
   201      file_name, expected_data = write_data(TextSourceTest.DEFAULT_NUM_RECORDS)
   202      self._run_read_test(
   203          file_name,
   204          expected_data,
   205          buffer_size=TextSource.DEFAULT_READ_BUFFER_SIZE)
   206  
   207    def test_read_single_file_larger_than_default_buffer(self):
   208      file_name, expected_data = write_data(TextSource.DEFAULT_READ_BUFFER_SIZE)
   209      self._run_read_test(
   210          file_name,
   211          expected_data,
   212          buffer_size=TextSource.DEFAULT_READ_BUFFER_SIZE)
   213  
   214    def test_read_file_pattern(self):
   215      pattern, expected_data = write_pattern(
   216          [TextSourceTest.DEFAULT_NUM_RECORDS * 5,
   217           TextSourceTest.DEFAULT_NUM_RECORDS * 3,
   218           TextSourceTest.DEFAULT_NUM_RECORDS * 12,
   219           TextSourceTest.DEFAULT_NUM_RECORDS * 8,
   220           TextSourceTest.DEFAULT_NUM_RECORDS * 8,
   221           TextSourceTest.DEFAULT_NUM_RECORDS * 4])
   222      assert len(expected_data) == TextSourceTest.DEFAULT_NUM_RECORDS * 40
   223      self._run_read_test(pattern, expected_data)
   224  
   225    def test_read_single_file_windows_eol(self):
   226      file_name, expected_data = write_data(TextSourceTest.DEFAULT_NUM_RECORDS,
   227                                            eol=EOL.CRLF)
   228      assert len(expected_data) == TextSourceTest.DEFAULT_NUM_RECORDS
   229      self._run_read_test(file_name, expected_data)
   230  
   231    def test_read_single_file_mixed_eol(self):
   232      file_name, expected_data = write_data(TextSourceTest.DEFAULT_NUM_RECORDS,
   233                                            eol=EOL.MIXED)
   234      assert len(expected_data) == TextSourceTest.DEFAULT_NUM_RECORDS
   235      self._run_read_test(file_name, expected_data)
   236  
   237    def test_read_single_file_last_line_no_eol(self):
   238      file_name, expected_data = write_data(
   239          TextSourceTest.DEFAULT_NUM_RECORDS,
   240          eol=EOL.LF_WITH_NOTHING_AT_LAST_LINE)
   241      assert len(expected_data) == TextSourceTest.DEFAULT_NUM_RECORDS
   242      self._run_read_test(file_name, expected_data)
   243  
   244    def test_read_single_file_single_line_no_eol(self):
   245      file_name, expected_data = write_data(
   246          1, eol=EOL.LF_WITH_NOTHING_AT_LAST_LINE)
   247  
   248      assert len(expected_data) == 1
   249      self._run_read_test(file_name, expected_data)
   250  
   251    def test_read_empty_single_file(self):
   252      file_name, written_data = write_data(
   253          1, no_data=True, eol=EOL.LF_WITH_NOTHING_AT_LAST_LINE)
   254  
   255      assert len(written_data) == 1
   256      # The written data has a single entry with an empty string. Reading the
   257      # source should not produce anything, since we only wrote a single empty
   258      # string without an end-of-line character.
   259      self._run_read_test(file_name, [])
   260  
   261    def test_read_single_file_last_line_no_eol_gzip(self):
   262      file_name, expected_data = write_data(
   263          TextSourceTest.DEFAULT_NUM_RECORDS,
   264          eol=EOL.LF_WITH_NOTHING_AT_LAST_LINE)
   265  
   266      gzip_file_name = file_name + '.gz'
   267      with open(file_name, 'rb') as src, gzip.open(gzip_file_name, 'wb') as dst:
   268        dst.writelines(src)
   269  
   270      assert len(expected_data) == TextSourceTest.DEFAULT_NUM_RECORDS
   271      self._run_read_test(
   272          gzip_file_name, expected_data, compression=CompressionTypes.GZIP)
   273  
   274    def test_read_single_file_single_line_no_eol_gzip(self):
   275      file_name, expected_data = write_data(
   276          1, eol=EOL.LF_WITH_NOTHING_AT_LAST_LINE)
   277  
   278      gzip_file_name = file_name + '.gz'
   279      with open(file_name, 'rb') as src, gzip.open(gzip_file_name, 'wb') as dst:
   280        dst.writelines(src)
   281  
   282      assert len(expected_data) == 1
   283      self._run_read_test(
   284          gzip_file_name, expected_data, compression=CompressionTypes.GZIP)
   285  
   286    def test_read_empty_single_file_no_eol_gzip(self):
   287      file_name, written_data = write_data(
   288          1, no_data=True, eol=EOL.LF_WITH_NOTHING_AT_LAST_LINE)
   289  
   290      gzip_file_name = file_name + '.gz'
   291      with open(file_name, 'rb') as src, gzip.open(gzip_file_name, 'wb') as dst:
   292        dst.writelines(src)
   293  
   294      assert len(written_data) == 1
   295      # The written data has a single entry with an empty string. Reading the
   296      # source should not produce anything, since we only wrote a single empty
   297      # string without an end-of-line character.
   298      self._run_read_test(gzip_file_name, [], compression=CompressionTypes.GZIP)
   299  
   300    def test_read_single_file_with_empty_lines(self):
   301      file_name, expected_data = write_data(
   302          TextSourceTest.DEFAULT_NUM_RECORDS, no_data=True, eol=EOL.LF)
   303  
   304      assert len(expected_data) == TextSourceTest.DEFAULT_NUM_RECORDS
   305      assert not expected_data[0]
   306  
   307      self._run_read_test(file_name, expected_data)
   308  
   309    def test_read_single_file_without_stripping_eol_lf(self):
   310      file_name, written_data = write_data(TextSourceTest.DEFAULT_NUM_RECORDS,
   311                                           eol=EOL.LF)
   312      assert len(written_data) == TextSourceTest.DEFAULT_NUM_RECORDS
   313      source = TextSource(
   314          file_name,
   315          0,
   316          CompressionTypes.UNCOMPRESSED,
   317          False,
   318          coders.StrUtf8Coder())
   319  
   320      range_tracker = source.get_range_tracker(None, None)
   321      read_data = list(source.read(range_tracker))
   322      self.assertCountEqual([line + '\n' for line in written_data], read_data)
   323  
   324    def test_read_single_file_without_stripping_eol_crlf(self):
   325      file_name, written_data = write_data(TextSourceTest.DEFAULT_NUM_RECORDS,
   326                                           eol=EOL.CRLF)
   327      assert len(written_data) == TextSourceTest.DEFAULT_NUM_RECORDS
   328      source = TextSource(
   329          file_name,
   330          0,
   331          CompressionTypes.UNCOMPRESSED,
   332          False,
   333          coders.StrUtf8Coder())
   334  
   335      range_tracker = source.get_range_tracker(None, None)
   336      read_data = list(source.read(range_tracker))
   337      self.assertCountEqual([line + '\r\n' for line in written_data], read_data)
   338  
   339    def test_read_file_pattern_with_empty_files(self):
   340      pattern, expected_data = write_pattern(
   341          [5 * TextSourceTest.DEFAULT_NUM_RECORDS,
   342           3 * TextSourceTest.DEFAULT_NUM_RECORDS,
   343           12 * TextSourceTest.DEFAULT_NUM_RECORDS,
   344           8 * TextSourceTest.DEFAULT_NUM_RECORDS,
   345           8 * TextSourceTest.DEFAULT_NUM_RECORDS,
   346           4 * TextSourceTest.DEFAULT_NUM_RECORDS],
   347          no_data=True)
   348      assert len(expected_data) == 40 * TextSourceTest.DEFAULT_NUM_RECORDS
   349      assert not expected_data[0]
   350      self._run_read_test(pattern, expected_data)
   351  
   352    def test_read_after_splitting(self):
   353      file_name, expected_data = write_data(10)
   354      assert len(expected_data) == 10
   355      source = TextSource(
   356          file_name,
   357          0,
   358          CompressionTypes.UNCOMPRESSED,
   359          True,
   360          coders.StrUtf8Coder())
   361      splits = list(source.split(desired_bundle_size=33))
   362  
   363      reference_source_info = (source, None, None)
   364      sources_info = ([(split.source, split.start_position, split.stop_position)
   365                       for split in splits])
   366      source_test_utils.assert_sources_equal_reference_source(
   367          reference_source_info, sources_info)
   368  
   369    def test_header_processing(self):
   370      file_name, expected_data = write_data(10)
   371      assert len(expected_data) == 10
   372  
   373      def header_matcher(line):
   374        return line in expected_data[:5]
   375  
   376      header_lines = []
   377  
   378      def store_header(lines):
   379        for line in lines:
   380          header_lines.append(line)
   381  
   382      source = TextSource(
   383          file_name,
   384          0,
   385          CompressionTypes.UNCOMPRESSED,
   386          True,
   387          coders.StrUtf8Coder(),
   388          header_processor_fns=(header_matcher, store_header))
   389      splits = list(source.split(desired_bundle_size=100000))
   390      assert len(splits) == 1
   391      range_tracker = splits[0].source.get_range_tracker(
   392          splits[0].start_position, splits[0].stop_position)
   393      read_data = list(source.read_records(file_name, range_tracker))
   394  
   395      self.assertCountEqual(expected_data[:5], header_lines)
   396      self.assertCountEqual(expected_data[5:], read_data)
   397  
   398    def test_progress(self):
   399      file_name, expected_data = write_data(10)
   400      assert len(expected_data) == 10
   401      source = TextSource(
   402          file_name,
   403          0,
   404          CompressionTypes.UNCOMPRESSED,
   405          True,
   406          coders.StrUtf8Coder())
   407      splits = list(source.split(desired_bundle_size=100000))
   408      assert len(splits) == 1
   409      fraction_consumed_report = []
   410      split_points_report = []
   411      range_tracker = splits[0].source.get_range_tracker(
   412          splits[0].start_position, splits[0].stop_position)
   413      for _ in splits[0].source.read(range_tracker):
   414        fraction_consumed_report.append(range_tracker.fraction_consumed())
   415        split_points_report.append(range_tracker.split_points())
   416  
   417      self.assertEqual([float(i) / 10 for i in range(0, 10)],
   418                       fraction_consumed_report)
   419      expected_split_points_report = [((i - 1),
   420                                       iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)
   421                                      for i in range(1, 10)]
   422  
   423      # At the last split point, the remaining-split-points callback returns 1,
   424      # since the expected position of the next record equals the stop position.
   425      expected_split_points_report.append((9, 1))
   426  
   427      self.assertEqual(expected_split_points_report, split_points_report)
   428  
   429    def test_read_reentrant_without_splitting(self):
   430      file_name, expected_data = write_data(10)
   431      assert len(expected_data) == 10
   432      source = TextSource(
   433          file_name,
   434          0,
   435          CompressionTypes.UNCOMPRESSED,
   436          True,
   437          coders.StrUtf8Coder())
   438      source_test_utils.assert_reentrant_reads_succeed((source, None, None))
   439  
   440    def test_read_reentrant_after_splitting(self):
   441      file_name, expected_data = write_data(10)
   442      assert len(expected_data) == 10
   443      source = TextSource(
   444          file_name,
   445          0,
   446          CompressionTypes.UNCOMPRESSED,
   447          True,
   448          coders.StrUtf8Coder())
   449      splits = list(source.split(desired_bundle_size=100000))
   450      assert len(splits) == 1
   451      source_test_utils.assert_reentrant_reads_succeed(
   452          (splits[0].source, splits[0].start_position, splits[0].stop_position))
   453  
   454    def test_dynamic_work_rebalancing(self):
   455      file_name, expected_data = write_data(5)
   456      assert len(expected_data) == 5
   457      source = TextSource(
   458          file_name,
   459          0,
   460          CompressionTypes.UNCOMPRESSED,
   461          True,
   462          coders.StrUtf8Coder())
   463      splits = list(source.split(desired_bundle_size=100000))
   464      assert len(splits) == 1
   465      source_test_utils.assert_split_at_fraction_exhaustive(
   466          splits[0].source, splits[0].start_position, splits[0].stop_position)
   467  
   468    def test_dynamic_work_rebalancing_windows_eol(self):
   469      file_name, expected_data = write_data(15, eol=EOL.CRLF)
   470      assert len(expected_data) == 15
   471      source = TextSource(
   472          file_name,
   473          0,
   474          CompressionTypes.UNCOMPRESSED,
   475          True,
   476          coders.StrUtf8Coder())
   477      splits = list(source.split(desired_bundle_size=100000))
   478      assert len(splits) == 1
   479      source_test_utils.assert_split_at_fraction_exhaustive(
   480          splits[0].source,
   481          splits[0].start_position,
   482          splits[0].stop_position,
   483          perform_multi_threaded_test=False)
   484  
   485    def test_dynamic_work_rebalancing_mixed_eol(self):
   486      file_name, expected_data = write_data(5, eol=EOL.MIXED)
   487      assert len(expected_data) == 5
   488      source = TextSource(
   489          file_name,
   490          0,
   491          CompressionTypes.UNCOMPRESSED,
   492          True,
   493          coders.StrUtf8Coder())
   494      splits = list(source.split(desired_bundle_size=100000))
   495      assert len(splits) == 1
   496      source_test_utils.assert_split_at_fraction_exhaustive(
   497          splits[0].source,
   498          splits[0].start_position,
   499          splits[0].stop_position,
   500          perform_multi_threaded_test=False)
   501  
   502    def test_read_from_text_single_file(self):
   503      file_name, expected_data = write_data(5)
   504      assert len(expected_data) == 5
   505      with TestPipeline() as pipeline:
   506        pcoll = pipeline | 'Read' >> ReadFromText(file_name)
   507        assert_that(pcoll, equal_to(expected_data))
   508  
   509    def test_read_from_text_with_file_name_single_file(self):
   510      file_name, data = write_data(5)
   511      expected_data = [(file_name, el) for el in data]
   512      assert len(expected_data) == 5
   513      with TestPipeline() as pipeline:
   514        pcoll = pipeline | 'Read' >> ReadFromTextWithFilename(file_name)
   515        assert_that(pcoll, equal_to(expected_data))
   516  
   517    def test_read_all_single_file(self):
   518      file_name, expected_data = write_data(5)
   519      assert len(expected_data) == 5
   520      with TestPipeline() as pipeline:
   521        pcoll = pipeline | 'Create' >> Create(
   522            [file_name]) | 'ReadAll' >> ReadAllFromText()
   523        assert_that(pcoll, equal_to(expected_data))
   524  
   525    def test_read_all_many_single_files(self):
   526      file_name1, expected_data1 = write_data(5)
   527      assert len(expected_data1) == 5
   528      file_name2, expected_data2 = write_data(10)
   529      assert len(expected_data2) == 10
   530      file_name3, expected_data3 = write_data(15)
   531      assert len(expected_data3) == 15
   532      expected_data = []
   533      expected_data.extend(expected_data1)
   534      expected_data.extend(expected_data2)
   535      expected_data.extend(expected_data3)
   536      with TestPipeline() as pipeline:
   537        pcoll = pipeline | 'Create' >> Create([
   538            file_name1, file_name2, file_name3
   539        ]) | 'ReadAll' >> ReadAllFromText()
   540        assert_that(pcoll, equal_to(expected_data))
   541  
   542    def test_read_all_unavailable_files_ignored(self):
   543      file_name1, expected_data1 = write_data(5)
   544      assert len(expected_data1) == 5
   545      file_name2, expected_data2 = write_data(10)
   546      assert len(expected_data2) == 10
   547      file_name3, expected_data3 = write_data(15)
   548      assert len(expected_data3) == 15
   549      file_name4 = "/unavailable_file"
   550      expected_data = []
   551      expected_data.extend(expected_data1)
   552      expected_data.extend(expected_data2)
   553      expected_data.extend(expected_data3)
   554      with TestPipeline() as pipeline:
   555        pcoll = (
   556            pipeline
   557            | 'Create' >> Create([file_name1, file_name2, file_name3, file_name4])
   558            | 'ReadAll' >> ReadAllFromText())
   559        assert_that(pcoll, equal_to(expected_data))
   560  
   561    class _WriteFilesFn(beam.DoFn):
   562      """Writes a couple of files while the pipeline is already running."""
   563      COUNT_STATE = CombiningValueStateSpec('count', combine_fn=sum)
   564  
   565      def __init__(self, temp_path):
   566        self.temp_path = temp_path
   567  
   568      def process(self, element, count_state=beam.DoFn.StateParam(COUNT_STATE)):
   569        counter = count_state.read()
   570        if counter == 0:
   571          count_state.add(1)
   572          with open(FileSystems.join(self.temp_path, 'file1'), 'w') as f:
   573            f.write('second A\nsecond B')
   574          with open(FileSystems.join(self.temp_path, 'file2'), 'w') as f:
   575            f.write('first')
   576        # Strip the dumb key and emit (basename, content) pairs.
   577        basename = FileSystems.split(element[1][0])[1]
   578        content = element[1][1]
   579        yield basename, content
   580  
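  # How the two "continuously" tests below use _WriteFilesFn (a sketch): the
  # first element to arrive flips the counter state and writes 'file1' (an
  # update) and 'file2' (a new file) under temp_path, giving
  # ReadAllFromTextContinuously something fresh to match on a later polling
  # interval; the DoFn then emits (basename, content) pairs for the asserts.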
   581    def test_read_all_continuously_new(self):
   582      with TempDir() as tempdir, TestPipeline() as pipeline:
   583        temp_path = tempdir.get_path()
   584        # create a temp file at the beginning
   585        with open(FileSystems.join(temp_path, 'file1'), 'w') as f:
   586          f.write('first')
   587        match_pattern = FileSystems.join(temp_path, '*')
   588        interval = 0.5
   589        last = 2
   590        p_read_once = (
   591            pipeline
   592            | 'Continuously read new files' >> ReadAllFromTextContinuously(
   593                match_pattern,
   594                with_filename=True,
   595                start_timestamp=Timestamp.now(),
   596                interval=interval,
   597                stop_timestamp=Timestamp.now() + last,
   598                match_updated_files=False)
   599            | 'add dumb key' >> beam.Map(lambda x: (0, x))
   600            |
   601            'Write files on-the-fly' >> beam.ParDo(self._WriteFilesFn(temp_path)))
   602        assert_that(
   603            p_read_once,
   604            equal_to([('file1', 'first'), ('file2', 'first')]),
   605            label='assert read new files results')
   606  
   607    def test_read_all_continuously_update(self):
   608      with TempDir() as tempdir, TestPipeline() as pipeline:
   609        temp_path = tempdir.get_path()
   610        # create a temp file at the beginning
   611        with open(FileSystems.join(temp_path, 'file1'), 'w') as f:
   612          f.write('first')
   613        match_pattern = FileSystems.join(temp_path, '*')
   614        interval = 0.5
   615        last = 2
   616        p_read_upd = (
   617            pipeline
   618            | 'Continuously read updated files' >> ReadAllFromTextContinuously(
   619                match_pattern,
   620                with_filename=True,
   621                start_timestamp=Timestamp.now(),
   622                interval=interval,
   623                stop_timestamp=Timestamp.now() + last,
   624                match_updated_files=True)
   625            | 'add dumb key' >> beam.Map(lambda x: (0, x))
   626            |
   627            'Write files on-the-fly' >> beam.ParDo(self._WriteFilesFn(temp_path)))
   628        assert_that(
   629            p_read_upd,
   630            equal_to([('file1', 'first'), ('file1', 'second A'),
   631                      ('file1', 'second B'), ('file2', 'first')]),
   632            label='assert read updated files results')
   633  
   634    def test_read_from_text_single_file_with_coder(self):
   635      file_name, expected_data = write_data(5)
   636      assert len(expected_data) == 5
   637      with TestPipeline() as pipeline:
   638        pcoll = pipeline | 'Read' >> ReadFromText(file_name, coder=DummyCoder())
   639        assert_that(pcoll, equal_to([record * 2 for record in expected_data]))
   640  
   641    def test_read_from_text_file_pattern(self):
   642      pattern, expected_data = write_pattern([5, 3, 12, 8, 8, 4])
   643      assert len(expected_data) == 40
   644      with TestPipeline() as pipeline:
   645        pcoll = pipeline | 'Read' >> ReadFromText(pattern)
   646        assert_that(pcoll, equal_to(expected_data))
   647  
   648    def test_read_from_text_with_file_name_file_pattern(self):
   649      pattern, expected_data = write_pattern(
   650          lines_per_file=[5, 5], return_filenames=True)
   651      assert len(expected_data) == 10
   652      with TestPipeline() as pipeline:
   653        pcoll = pipeline | 'Read' >> ReadFromTextWithFilename(pattern)
   654        assert_that(pcoll, equal_to(expected_data))
   655  
   656    def test_read_all_file_pattern(self):
   657      pattern, expected_data = write_pattern([5, 3, 12, 8, 8, 4])
   658      assert len(expected_data) == 40
   659      with TestPipeline() as pipeline:
   660        pcoll = (
   661            pipeline
   662            | 'Create' >> Create([pattern])
   663            | 'ReadAll' >> ReadAllFromText())
   664        assert_that(pcoll, equal_to(expected_data))
   665  
   666    def test_read_all_many_file_patterns(self):
   667      pattern1, expected_data1 = write_pattern([5, 3, 12, 8, 8, 4])
   668      assert len(expected_data1) == 40
   669      pattern2, expected_data2 = write_pattern([3, 7, 9])
   670      assert len(expected_data2) == 19
   671      pattern3, expected_data3 = write_pattern([11, 20, 5, 5])
   672      assert len(expected_data3) == 41
   673      expected_data = []
   674      expected_data.extend(expected_data1)
   675      expected_data.extend(expected_data2)
   676      expected_data.extend(expected_data3)
   677      with TestPipeline() as pipeline:
   678        pcoll = pipeline | 'Create' >> Create(
   679            [pattern1, pattern2, pattern3]) | 'ReadAll' >> ReadAllFromText()
   680        assert_that(pcoll, equal_to(expected_data))
   681  
   682    def test_read_all_with_filename(self):
   683      pattern, expected_data = write_pattern([5, 3], return_filenames=True)
   684      assert len(expected_data) == 8
   685  
   686      with TestPipeline() as pipeline:
   687        pcoll = (
   688            pipeline
   689            | 'Create' >> Create([pattern])
   690            | 'ReadAll' >> ReadAllFromText(with_filename=True))
   691        assert_that(pcoll, equal_to(expected_data))
   692  
   693    def test_read_auto_bzip2(self):
   694      _, lines = write_data(15)
   695      with TempDir() as tempdir:
   696        file_name = tempdir.create_temp_file(suffix='.bz2')
   697        with bz2.BZ2File(file_name, 'wb') as f:
   698          f.write('\n'.join(lines).encode('utf-8'))
   699  
   700        with TestPipeline() as pipeline:
   701          pcoll = pipeline | 'Read' >> ReadFromText(file_name)
   702          assert_that(pcoll, equal_to(lines))
   703  
   704    def test_read_auto_deflate(self):
   705      _, lines = write_data(15)
   706      with TempDir() as tempdir:
   707        file_name = tempdir.create_temp_file(suffix='.deflate')
   708        with open(file_name, 'wb') as f:
   709          f.write(zlib.compress('\n'.join(lines).encode('utf-8')))
   710  
   711        with TestPipeline() as pipeline:
   712          pcoll = pipeline | 'Read' >> ReadFromText(file_name)
   713          assert_that(pcoll, equal_to(lines))
   714  
   715    def test_read_auto_gzip(self):
   716      _, lines = write_data(15)
   717      with TempDir() as tempdir:
   718        file_name = tempdir.create_temp_file(suffix='.gz')
   719  
   720        with gzip.GzipFile(file_name, 'wb') as f:
   721          f.write('\n'.join(lines).encode('utf-8'))
   722  
   723        with TestPipeline() as pipeline:
   724          pcoll = pipeline | 'Read' >> ReadFromText(file_name)
   725          assert_that(pcoll, equal_to(lines))
   726  
   727    def test_read_bzip2(self):
   728      _, lines = write_data(15)
   729      with TempDir() as tempdir:
   730        file_name = tempdir.create_temp_file()
   731        with bz2.BZ2File(file_name, 'wb') as f:
   732          f.write('\n'.join(lines).encode('utf-8'))
   733  
   734        with TestPipeline() as pipeline:
   735          pcoll = pipeline | 'Read' >> ReadFromText(
   736              file_name, compression_type=CompressionTypes.BZIP2)
   737          assert_that(pcoll, equal_to(lines))
   738  
   739    def test_read_corrupted_bzip2_fails(self):
   740      _, lines = write_data(15)
   741      with TempDir() as tempdir:
   742        file_name = tempdir.create_temp_file()
   743        with bz2.BZ2File(file_name, 'wb') as f:
   744          f.write('\n'.join(lines).encode('utf-8'))
   745  
   746        with open(file_name, 'wb') as f:
   747          f.write(b'corrupt')
   748  
   749        with self.assertRaises(Exception):
   750          with TestPipeline() as pipeline:
   751            pcoll = pipeline | 'Read' >> ReadFromText(
   752                file_name, compression_type=CompressionTypes.BZIP2)
   753            assert_that(pcoll, equal_to(lines))
   754  
   755    def test_read_bzip2_concat(self):
   756      with TempDir() as tempdir:
   757        bzip2_file_name1 = tempdir.create_temp_file()
   758        lines = ['a', 'b', 'c']
   759        with bz2.BZ2File(bzip2_file_name1, 'wb') as dst:
   760          data = '\n'.join(lines) + '\n'
   761          dst.write(data.encode('utf-8'))
   762  
   763        bzip2_file_name2 = tempdir.create_temp_file()
   764        lines = ['p', 'q', 'r']
   765        with bz2.BZ2File(bzip2_file_name2, 'wb') as dst:
   766          data = '\n'.join(lines) + '\n'
   767          dst.write(data.encode('utf-8'))
   768  
   769        bzip2_file_name3 = tempdir.create_temp_file()
   770        lines = ['x', 'y', 'z']
   771        with bz2.BZ2File(bzip2_file_name3, 'wb') as dst:
   772          data = '\n'.join(lines) + '\n'
   773          dst.write(data.encode('utf-8'))
   774  
   775        final_bzip2_file = tempdir.create_temp_file()
   776        with open(bzip2_file_name1, 'rb') as src, open(
   777            final_bzip2_file, 'wb') as dst:
   778          dst.writelines(src.readlines())
   779  
   780        with open(bzip2_file_name2, 'rb') as src, open(
   781            final_bzip2_file, 'ab') as dst:
   782          dst.writelines(src.readlines())
   783  
   784        with open(bzip2_file_name3, 'rb') as src, open(
   785            final_bzip2_file, 'ab') as dst:
   786          dst.writelines(src.readlines())
   787  
   788        with TestPipeline() as pipeline:
   789          lines = pipeline | 'ReadFromText' >> beam.io.ReadFromText(
   790              final_bzip2_file,
   791              compression_type=beam.io.filesystem.CompressionTypes.BZIP2)
   792  
   793          expected = ['a', 'b', 'c', 'p', 'q', 'r', 'x', 'y', 'z']
   794          assert_that(lines, equal_to(expected))
   795  
   796    def test_read_deflate(self):
   797      _, lines = write_data(15)
   798      with TempDir() as tempdir:
   799        file_name = tempdir.create_temp_file()
   800        with open(file_name, 'wb') as f:
   801          f.write(zlib.compress('\n'.join(lines).encode('utf-8')))
   802  
   803        with TestPipeline() as pipeline:
   804          pcoll = pipeline | 'Read' >> ReadFromText(
   805              file_name, 0, CompressionTypes.DEFLATE, True, coders.StrUtf8Coder())
   806          assert_that(pcoll, equal_to(lines))
   807  
   808    def test_read_corrupted_deflate_fails(self):
   809      _, lines = write_data(15)
   810      with TempDir() as tempdir:
   811        file_name = tempdir.create_temp_file()
   812        with open(file_name, 'wb') as f:
   813          f.write(zlib.compress('\n'.join(lines).encode('utf-8')))
   814  
   815        with open(file_name, 'wb') as f:
   816          f.write(b'corrupt')
   817  
   818        with self.assertRaises(Exception):
   819          with TestPipeline() as pipeline:
   820            pcoll = pipeline | 'Read' >> ReadFromText(
   821                file_name,
   822                0,
   823                CompressionTypes.DEFLATE,
   824                True,
   825                coders.StrUtf8Coder())
   826            assert_that(pcoll, equal_to(lines))
   827  
   828    def test_read_deflate_concat(self):
   829      with TempDir() as tempdir:
   830        deflate_file_name1 = tempdir.create_temp_file()
   831        lines = ['a', 'b', 'c']
   832        with open(deflate_file_name1, 'wb') as dst:
   833          data = '\n'.join(lines) + '\n'
   834          dst.write(zlib.compress(data.encode('utf-8')))
   835  
   836        deflate_file_name2 = tempdir.create_temp_file()
   837        lines = ['p', 'q', 'r']
   838        with open(deflate_file_name2, 'wb') as dst:
   839          data = '\n'.join(lines) + '\n'
   840          dst.write(zlib.compress(data.encode('utf-8')))
   841  
   842        deflate_file_name3 = tempdir.create_temp_file()
   843        lines = ['x', 'y', 'z']
   844        with open(deflate_file_name3, 'wb') as dst:
   845          data = '\n'.join(lines) + '\n'
   846          dst.write(zlib.compress(data.encode('utf-8')))
   847  
   848        final_deflate_file = tempdir.create_temp_file()
   849        with open(deflate_file_name1, 'rb') as src, \
   850                open(final_deflate_file, 'wb') as dst:
   851          dst.writelines(src.readlines())
   852  
   853        with open(deflate_file_name2, 'rb') as src, \
   854                open(final_deflate_file, 'ab') as dst:
   855          dst.writelines(src.readlines())
   856  
   857        with open(deflate_file_name3, 'rb') as src, \
   858                open(final_deflate_file, 'ab') as dst:
   859          dst.writelines(src.readlines())
   860  
   861        with TestPipeline() as pipeline:
   862          lines = pipeline | 'ReadFromText' >> beam.io.ReadFromText(
   863              final_deflate_file,
   864              compression_type=beam.io.filesystem.CompressionTypes.DEFLATE)
   865  
   866          expected = ['a', 'b', 'c', 'p', 'q', 'r', 'x', 'y', 'z']
   867          assert_that(lines, equal_to(expected))
   868  
   869    def test_read_gzip(self):
   870      _, lines = write_data(15)
   871      with TempDir() as tempdir:
   872        file_name = tempdir.create_temp_file()
   873        with gzip.GzipFile(file_name, 'wb') as f:
   874          f.write('\n'.join(lines).encode('utf-8'))
   875  
   876        with TestPipeline() as pipeline:
   877          pcoll = pipeline | 'Read' >> ReadFromText(
   878              file_name, 0, CompressionTypes.GZIP, True, coders.StrUtf8Coder())
   879          assert_that(pcoll, equal_to(lines))
   880  
   881    def test_read_corrupted_gzip_fails(self):
   882      _, lines = write_data(15)
   883      with TempDir() as tempdir:
   884        file_name = tempdir.create_temp_file()
   885        with gzip.GzipFile(file_name, 'wb') as f:
   886          f.write('\n'.join(lines).encode('utf-8'))
   887  
   888        with open(file_name, 'wb') as f:
   889          f.write(b'corrupt')
   890  
   891        with self.assertRaises(Exception):
   892          with TestPipeline() as pipeline:
   893            pcoll = pipeline | 'Read' >> ReadFromText(
   894                file_name, 0, CompressionTypes.GZIP, True, coders.StrUtf8Coder())
   895            assert_that(pcoll, equal_to(lines))
   896  
   897    def test_read_gzip_concat(self):
   898      with TempDir() as tempdir:
   899        gzip_file_name1 = tempdir.create_temp_file()
   900        lines = ['a', 'b', 'c']
   901        with gzip.open(gzip_file_name1, 'wb') as dst:
   902          data = '\n'.join(lines) + '\n'
   903          dst.write(data.encode('utf-8'))
   904  
   905        gzip_file_name2 = tempdir.create_temp_file()
   906        lines = ['p', 'q', 'r']
   907        with gzip.open(gzip_file_name2, 'wb') as dst:
   908          data = '\n'.join(lines) + '\n'
   909          dst.write(data.encode('utf-8'))
   910  
   911        gzip_file_name3 = tempdir.create_temp_file()
   912        lines = ['x', 'y', 'z']
   913        with gzip.open(gzip_file_name3, 'wb') as dst:
   914          data = '\n'.join(lines) + '\n'
   915          dst.write(data.encode('utf-8'))
   916  
   917        final_gzip_file = tempdir.create_temp_file()
   918        with open(gzip_file_name1, 'rb') as src, \
   919             open(final_gzip_file, 'wb') as dst:
   920          dst.writelines(src.readlines())
   921  
   922        with open(gzip_file_name2, 'rb') as src, \
   923             open(final_gzip_file, 'ab') as dst:
   924          dst.writelines(src.readlines())
   925  
   926        with open(gzip_file_name3, 'rb') as src, \
   927             open(final_gzip_file, 'ab') as dst:
   928          dst.writelines(src.readlines())
   929  
   930        with TestPipeline() as pipeline:
   931          lines = pipeline | 'ReadFromText' >> beam.io.ReadFromText(
   932              final_gzip_file,
   933              compression_type=beam.io.filesystem.CompressionTypes.GZIP)
   934  
   935          expected = ['a', 'b', 'c', 'p', 'q', 'r', 'x', 'y', 'z']
   936          assert_that(lines, equal_to(expected))
   937  
   938    def test_read_all_gzip(self):
   939      _, lines = write_data(100)
   940      with TempDir() as tempdir:
   941        file_name = tempdir.create_temp_file()
   942        with gzip.GzipFile(file_name, 'wb') as f:
   943          f.write('\n'.join(lines).encode('utf-8'))
   944        with TestPipeline() as pipeline:
   945          pcoll = (
   946              pipeline
   947              | Create([file_name])
   948              | 'ReadAll' >>
   949              ReadAllFromText(compression_type=CompressionTypes.GZIP))
   950          assert_that(pcoll, equal_to(lines))
   951  
   952    def test_read_gzip_large(self):
   953      _, lines = write_data(10000)
   954      with TempDir() as tempdir:
   955        file_name = tempdir.create_temp_file()
   956  
   957        with gzip.GzipFile(file_name, 'wb') as f:
   958          f.write('\n'.join(lines).encode('utf-8'))
   959  
   960        with TestPipeline() as pipeline:
   961          pcoll = pipeline | 'Read' >> ReadFromText(
   962              file_name, 0, CompressionTypes.GZIP, True, coders.StrUtf8Coder())
   963          assert_that(pcoll, equal_to(lines))
   964  
   965    def test_read_gzip_large_after_splitting(self):
   966      _, lines = write_data(10000)
   967      with TempDir() as tempdir:
   968        file_name = tempdir.create_temp_file()
   969        with gzip.GzipFile(file_name, 'wb') as f:
   970          f.write('\n'.join(lines).encode('utf-8'))
   971  
   972        source = TextSource(
   973            file_name, 0, CompressionTypes.GZIP, True, coders.StrUtf8Coder())
   974        splits = list(source.split(desired_bundle_size=1000))
   975  
   976        if len(splits) > 1:
   977          raise ValueError(
   978              'FileBasedSource generated more than one initial '
   979              'split for a compressed file.')
   980  
   981        reference_source_info = (source, None, None)
   982        sources_info = ([
   983            (split.source, split.start_position, split.stop_position)
   984            for split in splits
   985        ])
   986        source_test_utils.assert_sources_equal_reference_source(
   987            reference_source_info, sources_info)
   988  
   989    def test_read_gzip_empty_file(self):
   990      with TempDir() as tempdir:
   991        file_name = tempdir.create_temp_file()
   992        with TestPipeline() as pipeline:
   993          pcoll = pipeline | 'Read' >> ReadFromText(
   994              file_name, 0, CompressionTypes.GZIP, True, coders.StrUtf8Coder())
   995          assert_that(pcoll, equal_to([]))
   996  
   997    def _remove_lines(self, lines, sublist_lengths, num_to_remove):
   998      """Utility function to remove num_to_remove lines from each sublist.
   999  
  1000      Args:
  1001        lines: list of items.
  1002        sublist_lengths: list of integers representing length of sublist
  1003          corresponding to each source file.
  1004        num_to_remove: number of lines to remove from each sublist.
  1005      Returns:
  1006        remaining lines.
  1007      """
  1008      curr = 0
  1009      result = []
  1010      for offset in sublist_lengths:
  1011        end = curr + offset
  1012        start = min(curr + num_to_remove, end)
  1013        result += lines[start:end]
  1014        curr += offset
  1015      return result
  1016  
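  # Worked example for _remove_lines (illustration only):
  #   _remove_lines(['a', 'b', 'c', 'd', 'e'], sublist_lengths=[2, 3],
  #                 num_to_remove=1) returns ['b', 'd', 'e'], dropping the
  #   first line of each per-file sublist.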
  1017    def _read_skip_header_lines(self, file_or_pattern, skip_header_lines):
  1018      """Instantiates a TextSource with the given args and reads all records."""
  1019      source = TextSource(
  1020          file_or_pattern,
  1021          0,
  1022          CompressionTypes.UNCOMPRESSED,
  1023          True,
  1024          coders.StrUtf8Coder(),
  1025          skip_header_lines=skip_header_lines)
  1026  
  1027      range_tracker = source.get_range_tracker(None, None)
  1028      return list(source.read(range_tracker))
  1029  
  1030    def test_read_skip_header_single(self):
  1031      file_name, expected_data = write_data(TextSourceTest.DEFAULT_NUM_RECORDS)
  1032      assert len(expected_data) == TextSourceTest.DEFAULT_NUM_RECORDS
  1033      skip_header_lines = 1
  1034      expected_data = self._remove_lines(
  1035          expected_data, [TextSourceTest.DEFAULT_NUM_RECORDS], skip_header_lines)
  1036      read_data = self._read_skip_header_lines(file_name, skip_header_lines)
  1037      self.assertEqual(len(expected_data), len(read_data))
  1038      self.assertCountEqual(expected_data, read_data)
  1039  
  1040    def test_read_skip_header_pattern(self):
  1041      line_counts = [
  1042          TextSourceTest.DEFAULT_NUM_RECORDS * 5,
  1043          TextSourceTest.DEFAULT_NUM_RECORDS * 3,
  1044          TextSourceTest.DEFAULT_NUM_RECORDS * 12,
  1045          TextSourceTest.DEFAULT_NUM_RECORDS * 8,
  1046          TextSourceTest.DEFAULT_NUM_RECORDS * 8,
  1047          TextSourceTest.DEFAULT_NUM_RECORDS * 4
  1048      ]
  1049      skip_header_lines = 2
  1050      pattern, data = write_pattern(line_counts)
  1051  
  1052      expected_data = self._remove_lines(data, line_counts, skip_header_lines)
  1053      read_data = self._read_skip_header_lines(pattern, skip_header_lines)
  1054      self.assertEqual(len(expected_data), len(read_data))
  1055      self.assertCountEqual(expected_data, read_data)
  1056  
  1057    def test_read_skip_header_pattern_insufficient_lines(self):
  1058      line_counts = [
  1059          5,
  1060          3,  # Fewer lines in file than we want to skip
  1061          12,
  1062          8,
  1063          8,
  1064          4
  1065      ]
  1066      skip_header_lines = 4
  1067      pattern, data = write_pattern(line_counts)
  1068  
  1069      data = self._remove_lines(data, line_counts, skip_header_lines)
  1070      read_data = self._read_skip_header_lines(pattern, skip_header_lines)
  1071      self.assertEqual(len(data), len(read_data))
  1072      self.assertCountEqual(data, read_data)
  1073  
  1074    def test_read_gzip_with_skip_lines(self):
  1075      _, lines = write_data(15)
  1076      with TempDir() as tempdir:
  1077        file_name = tempdir.create_temp_file()
  1078        with gzip.GzipFile(file_name, 'wb') as f:
  1079          f.write('\n'.join(lines).encode('utf-8'))
  1080  
  1081        with TestPipeline() as pipeline:
  1082          pcoll = pipeline | 'Read' >> ReadFromText(
  1083              file_name,
  1084              0,
  1085              CompressionTypes.GZIP,
  1086              True,
  1087              coders.StrUtf8Coder(),
  1088              skip_header_lines=2)
  1089          assert_that(pcoll, equal_to(lines[2:]))
  1090  
  1091    def test_read_after_splitting_skip_header(self):
  1092      file_name, expected_data = write_data(100)
  1093      assert len(expected_data) == 100
  1094      source = TextSource(
  1095          file_name,
  1096          0,
  1097          CompressionTypes.UNCOMPRESSED,
  1098          True,
  1099          coders.StrUtf8Coder(),
  1100          skip_header_lines=2)
  1101      splits = list(source.split(desired_bundle_size=33))
  1102  
  1103      reference_source_info = (source, None, None)
  1104      sources_info = ([(split.source, split.start_position, split.stop_position)
  1105                       for split in splits])
  1106      self.assertGreater(len(sources_info), 1)
  1107      reference_lines = source_test_utils.read_from_source(*reference_source_info)
  1108      split_lines = []
  1109      for source_info in sources_info:
  1110        split_lines.extend(source_test_utils.read_from_source(*source_info))
  1111  
  1112      self.assertEqual(expected_data[2:], reference_lines)
  1113      self.assertEqual(reference_lines, split_lines)
  1114  
  1115    def test_custom_delimiter_read_from_text(self):
  1116      file_name, expected_data = write_data(
  1117        5, eol=EOL.CUSTOM_DELIMITER, custom_delimiter=b'@#')
  1118      assert len(expected_data) == 5
  1119      with TestPipeline() as pipeline:
  1120        pcoll = pipeline | 'Read' >> ReadFromText(file_name, delimiter=b'@#')
  1121        assert_that(pcoll, equal_to(expected_data))
  1122  
  1123    def test_custom_delimiter_read_all_single_file(self):
  1124      file_name, expected_data = write_data(
  1125        5, eol=EOL.CUSTOM_DELIMITER, custom_delimiter=b'@#')
  1126      assert len(expected_data) == 5
  1127      with TestPipeline() as pipeline:
  1128        pcoll = pipeline | 'Create' >> Create(
  1129            [file_name]) | 'ReadAll' >> ReadAllFromText(delimiter=b'@#')
  1130        assert_that(pcoll, equal_to(expected_data))
  1131  
  1132    def test_invalid_delimiters_are_rejected(self):
  1133      file_name, _ = write_data(1)
  1134      for delimiter in (b'', '', '\r\n', 'a', 1):
  1135        with self.assertRaises(
  1136            ValueError, msg='Delimiter must be a non-empty bytes sequence.'):
  1137          _ = TextSource(
  1138              file_pattern=file_name,
  1139              min_bundle_size=0,
  1140              buffer_size=6,
  1141              compression_type=CompressionTypes.UNCOMPRESSED,
  1142              strip_trailing_newlines=True,
  1143              coder=coders.StrUtf8Coder(),
  1144              delimiter=delimiter,
  1145          )
  1146  
  1147    def test_non_self_overlapping_delimiter_is_accepted(self):
  1148      file_name, _ = write_data(1)
  1149      for delimiter in (b'\n', b'\r\n', b'*', b'abc', b'cabdab', b'abcabd'):
  1150        _ = TextSource(
  1151            file_pattern=file_name,
  1152            min_bundle_size=0,
  1153            buffer_size=6,
  1154            compression_type=CompressionTypes.UNCOMPRESSED,
  1155            strip_trailing_newlines=True,
  1156            coder=coders.StrUtf8Coder(),
  1157            delimiter=delimiter,
  1158        )
  1159  
  1160    def test_self_overlapping_delimiter_is_rejected(self):
  1161      file_name, _ = write_data(1)
  1162      for delimiter in (b'||', b'***', b'aba', b'abcab'):
  1163        with self.assertRaises(ValueError,
  1164                               msg='Delimiter must not self-overlap.'):
  1165          _ = TextSource(
  1166              file_pattern=file_name,
  1167              min_bundle_size=0,
  1168              buffer_size=6,
  1169              compression_type=CompressionTypes.UNCOMPRESSED,
  1170              strip_trailing_newlines=True,
  1171              coder=coders.StrUtf8Coder(),
  1172              delimiter=delimiter,
  1173          )
  1174  
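  # Informal note inferred from the two delimiter-validation tests above: a
  # delimiter is treated as self-overlapping when one of its proper prefixes
  # is also a suffix of it, e.g. b'aba' (overlap 'a') or b'abcab' (overlap
  # 'ab'); such delimiters are rejected, while b'abc' or b'abcabd' have no
  # overlap and are accepted.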
  1175    def test_read_with_custom_delimiter(self):
  1176      delimiters = [
  1177          b'\n',
  1178          b'\r\n',
  1179          b'*|',
  1180          b'*',
  1181          b'*=-',
  1182      ]
  1183  
  1184      for delimiter in delimiters:
  1185        file_name, expected_data = write_data(
  1186          10,
  1187          eol=EOL.CUSTOM_DELIMITER,
  1188          custom_delimiter=delimiter)
  1189  
  1190        assert len(expected_data) == 10
  1191        source = TextSource(
  1192            file_pattern=file_name,
  1193            min_bundle_size=0,
  1194            compression_type=CompressionTypes.UNCOMPRESSED,
  1195            strip_trailing_newlines=True,
  1196            coder=coders.StrUtf8Coder(),
  1197            delimiter=delimiter)
  1198        range_tracker = source.get_range_tracker(None, None)
  1199        read_data = list(source.read(range_tracker))
  1200  
  1201        self.assertEqual(read_data, expected_data)
  1202  
  1203    def test_read_with_custom_delimiter_around_split_point(self):
  1204      for delimiter in (b'\n', b'\r\n', b'@#', b'abc'):
  1205        file_name, expected_data = write_data(
  1206          20,
  1207          eol=EOL.CUSTOM_DELIMITER,
  1208          custom_delimiter=delimiter)
  1209        assert len(expected_data) == 20
  1210        for desired_bundle_size in (4, 5, 6, 7):
  1211          source = TextSource(
  1212              file_name,
  1213              0,
  1214              CompressionTypes.UNCOMPRESSED,
  1215              True,
  1216              coders.StrUtf8Coder(),
  1217              delimiter=delimiter)
  1218          splits = list(source.split(desired_bundle_size=desired_bundle_size))
  1219  
  1220          reference_source_info = (source, None, None)
  1221          sources_info = ([
  1222              (split.source, split.start_position, split.stop_position)
  1223              for split in splits
  1224          ])
  1225          source_test_utils.assert_sources_equal_reference_source(
  1226              reference_source_info, sources_info)
  1227  
  1228    def test_read_with_custom_delimiter_truncated(self):
  1229      """
  1230      Corner case: the delimiter is truncated at the end of the file.
  1231      A delimiter of length 3, buffer_size = 6 and a line_value of
  1232      length 4 are used so that the delimiter gets split across buffer
  1233      boundaries.
  1234      """
  1235      delimiter = b'@$*'
  1236  
  1237      file_name, expected_data = write_data(
  1238        10,
  1239        eol=EOL.CUSTOM_DELIMITER,
  1240        line_value=b'a' * 4,
  1241        custom_delimiter=delimiter)
  1242  
  1243      assert len(expected_data) == 10
  1244      source = TextSource(
  1245          file_pattern=file_name,
  1246          min_bundle_size=0,
  1247          buffer_size=6,
  1248          compression_type=CompressionTypes.UNCOMPRESSED,
  1249          strip_trailing_newlines=True,
  1250          coder=coders.StrUtf8Coder(),
  1251          delimiter=delimiter,
  1252      )
  1253      range_tracker = source.get_range_tracker(None, None)
  1254      read_data = list(source.read(range_tracker))
  1255  
  1256      self.assertEqual(read_data, expected_data)
  1257  
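  # Layout sketch for the test above (illustration): each record on disk is
  # b'aaaa' + str(i).encode() + b'@$*', i.e. 8-9 bytes, while the read buffer
  # holds only 6 bytes, so the three-byte delimiter regularly straddles two
  # consecutive buffer fills and has to be reassembled by the reader.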
  1258    def test_read_with_custom_delimiter_over_buffer_size(self):
  1259      """
  1260      Corner case: the delimiter straddles the buffer boundary.
  1261      """
  1262      file_name, expected_data = write_data(3, eol=EOL.CRLF, line_value=b'\rline')
  1263      assert len(expected_data) == 3
  1264      self._run_read_test(
  1265          file_name, expected_data, buffer_size=7, delimiter=b'\r\n')
  1266  
  1267    def test_read_with_custom_delimiter_truncated_and_not_equal(self):
  1268      """
  1269      Corner case: the delimiter is truncated at the end of the file and
  1270      only part of the delimiter matches the end of the buffer.
  1271  
  1272      A delimiter of length 3, buffer_size = 6 and a line_value of
  1273      length 4 are used so that the delimiter gets split across buffer
  1274      boundaries.
  1275      """
  1276  
  1277      write_delimiter = b'@$'
  1278      read_delimiter = b'@$*'
  1279  
  1280      file_name, expected_data = write_data(
  1281        10,
  1282        eol=EOL.CUSTOM_DELIMITER,
  1283        line_value=b'a' * 4,
  1284        custom_delimiter=write_delimiter)
  1285  
  1286      # The record must not be split: the whole file is read back as one line.
  1287      write_delimiter_str = write_delimiter.decode('utf-8')
  1288      expected_data_str = [
  1289          write_delimiter_str.join(expected_data) + write_delimiter_str
  1290      ]
  1291  
  1292      source = TextSource(
  1293          file_pattern=file_name,
  1294          min_bundle_size=0,
  1295          buffer_size=6,
  1296          compression_type=CompressionTypes.UNCOMPRESSED,
  1297          strip_trailing_newlines=True,
  1298          coder=coders.StrUtf8Coder(),
  1299          delimiter=read_delimiter,
  1300      )
  1301      range_tracker = source.get_range_tracker(None, None)
  1302  
  1303      read_data = list(source.read(range_tracker))
  1304  
  1305      self.assertEqual(read_data, expected_data_str)
  1306  
  1307    def test_read_crlf_split_by_buffer(self):
  1308      file_name, expected_data = write_data(3, eol=EOL.CRLF)
  1309      assert len(expected_data) == 3
  1310      self._run_read_test(file_name, expected_data, buffer_size=6)
  1311  
  1312    def test_read_escaped_lf(self):
  1313      file_name, expected_data = write_data(
  1314        self.DEFAULT_NUM_RECORDS, eol=EOL.LF, line_value=b'li\\\nne')
  1315      assert len(expected_data) == self.DEFAULT_NUM_RECORDS
  1316      self._run_read_test(file_name, expected_data, escapechar=b'\\')
  1317  
  1318    def test_read_escaped_crlf(self):
  1319      file_name, expected_data = write_data(
  1320        TextSource.DEFAULT_READ_BUFFER_SIZE,
  1321        eol=EOL.CRLF,
  1322        line_value=b'li\\\r\\\nne')
  1323      assert len(expected_data) == TextSource.DEFAULT_READ_BUFFER_SIZE
  1324      self._run_read_test(file_name, expected_data, escapechar=b'\\')
  1325  
  1326    def test_read_escaped_cr_before_not_escaped_lf(self):
  1327      file_name, expected_data_temp = write_data(
  1328        self.DEFAULT_NUM_RECORDS, eol=EOL.CRLF, line_value=b'li\\\r\nne')
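            # Only the '\r' is escaped; the unescaped '\n' that follows still
            # acts as a delimiter, so each written record is read back as two.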
  1329      expected_data = []
  1330      for line in expected_data_temp:
  1331        expected_data += line.split('\n')
  1332      assert len(expected_data) == self.DEFAULT_NUM_RECORDS * 2
  1333      self._run_read_test(file_name, expected_data, escapechar=b'\\')
  1334  
  1335    def test_read_escaped_custom_delimiter_crlf(self):
  1336      file_name, expected_data = write_data(
  1337        self.DEFAULT_NUM_RECORDS, eol=EOL.CRLF, line_value=b'li\\\r\nne')
  1338      assert len(expected_data) == self.DEFAULT_NUM_RECORDS
  1339      self._run_read_test(
  1340          file_name, expected_data, delimiter=b'\r\n', escapechar=b'\\')
  1341  
  1342    def test_read_escaped_custom_delimiter(self):
  1343      file_name, expected_data = write_data(
  1344        TextSource.DEFAULT_READ_BUFFER_SIZE,
  1345        eol=EOL.CUSTOM_DELIMITER,
  1346        custom_delimiter=b'*|',
  1347        line_value=b'li\\*|ne')
  1348      assert len(expected_data) == TextSource.DEFAULT_READ_BUFFER_SIZE
  1349      self._run_read_test(
  1350          file_name, expected_data, delimiter=b'*|', escapechar=b'\\')
  1351  
  1352    def test_read_escaped_lf_at_buffer_edge(self):
  1353      file_name, expected_data = write_data(3, eol=EOL.LF, line_value=b'line\\\n')
  1354      assert len(expected_data) == 3
  1355      self._run_read_test(
  1356          file_name, expected_data, buffer_size=5, escapechar=b'\\')
  1357  
  1358    def test_read_escaped_crlf_split_by_buffer(self):
  1359      file_name, expected_data = write_data(
  1360        3, eol=EOL.CRLF, line_value=b'line\\\r\n')
  1361      assert len(expected_data) == 3
  1362      self._run_read_test(
  1363          file_name,
  1364          expected_data,
  1365          buffer_size=6,
  1366          delimiter=b'\r\n',
  1367          escapechar=b'\\')
  1368  
  1369    def test_read_escaped_lf_after_splitting(self):
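            # A split may land inside an escaped newline; the split sources
            # must still collectively yield the same records as the unsplit
            # source.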
  1370      file_name, expected_data = write_data(3, line_value=b'line\\\n')
  1371      assert len(expected_data) == 3
  1372      source = TextSource(
  1373          file_name,
  1374          0,
  1375          CompressionTypes.UNCOMPRESSED,
  1376          True,
  1377          coders.StrUtf8Coder(),
  1378          escapechar=b'\\')
  1379      splits = list(source.split(desired_bundle_size=6))
  1380  
  1381      reference_source_info = (source, None, None)
  1382      sources_info = ([(split.source, split.start_position, split.stop_position)
  1383                       for split in splits])
  1384      source_test_utils.assert_sources_equal_reference_source(
  1385          reference_source_info, sources_info)
  1386  
  1387    def test_read_escaped_lf_after_splitting_many(self):
  1388      file_name, expected_data = write_data(
  1389        3, line_value=b'\\\\\\\\\\\n')  # 5 escapes
  1390      assert len(expected_data) == 3
  1391      source = TextSource(
  1392          file_name,
  1393          0,
  1394          CompressionTypes.UNCOMPRESSED,
  1395          True,
  1396          coders.StrUtf8Coder(),
  1397          escapechar=b'\\')
  1398      splits = list(source.split(desired_bundle_size=6))
  1399  
  1400      reference_source_info = (source, None, None)
  1401      sources_info = ([(split.source, split.start_position, split.stop_position)
  1402                       for split in splits])
  1403      source_test_utils.assert_sources_equal_reference_source(
  1404          reference_source_info, sources_info)
  1405  
  1406    def test_read_escaped_escapechar_after_splitting(self):
  1407      file_name, expected_data = write_data(3, line_value=b'line\\\\*|')
  1408      assert len(expected_data) == 3
  1409      source = TextSource(
  1410          file_name,
  1411          0,
  1412          CompressionTypes.UNCOMPRESSED,
  1413          True,
  1414          coders.StrUtf8Coder(),
  1415          delimiter=b'*|',
  1416          escapechar=b'\\')
  1417      splits = list(source.split(desired_bundle_size=8))
  1418  
  1419      reference_source_info = (source, None, None)
  1420      sources_info = ([(split.source, split.start_position, split.stop_position)
  1421                       for split in splits])
  1422      source_test_utils.assert_sources_equal_reference_source(
  1423          reference_source_info, sources_info)
  1424  
  1425    def test_read_escaped_escapechar_after_splitting_many(self):
  1426      file_name, expected_data = write_data(
  1427        3, line_value=b'\\\\\\\\\\\\*|')  # 6 escapes
  1428      assert len(expected_data) == 3
  1429      source = TextSource(
  1430          file_name,
  1431          0,
  1432          CompressionTypes.UNCOMPRESSED,
  1433          True,
  1434          coders.StrUtf8Coder(),
  1435          delimiter=b'*|',
  1436          escapechar=b'\\')
  1437      splits = list(source.split(desired_bundle_size=8))
  1438  
  1439      reference_source_info = (source, None, None)
  1440      sources_info = ([(split.source, split.start_position, split.stop_position)
  1441                       for split in splits])
  1442      source_test_utils.assert_sources_equal_reference_source(
  1443          reference_source_info, sources_info)
  1444  
  1445  
  1446  class TextSinkTest(unittest.TestCase):
  1447    def setUp(self):
  1448      super().setUp()
  1449      self.lines = [b'Line %d' % d for d in range(100)]
  1450      self.tempdir = tempfile.mkdtemp()
  1451      self.path = self._create_temp_file()
  1452  
  1453    def tearDown(self):
  1454      if os.path.exists(self.tempdir):
  1455        shutil.rmtree(self.tempdir)
  1456  
  1457    def _create_temp_file(self, name='', suffix=''):
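            # NamedTemporaryFile(delete=True) removes the file as soon as the
            # handle is garbage collected; only its unique name is kept and
            # used as the output path for the sink under test.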
  1458      if not name:
  1459        name = tempfile.template
  1460      file_name = tempfile.NamedTemporaryFile(
  1461          delete=True, prefix=name, dir=self.tempdir, suffix=suffix).name
  1462      return file_name
  1463  
  1464    def _write_lines(self, sink, lines):
  1465      f = sink.open(self.path)
  1466      for line in lines:
  1467        sink.write_record(f, line)
  1468      sink.close(f)
  1469  
  1470    def test_write_text_file(self):
  1471      sink = TextSink(self.path)
  1472      self._write_lines(sink, self.lines)
  1473  
  1474      with open(self.path, 'rb') as f:
  1475        self.assertEqual(f.read().splitlines(), self.lines)
  1476  
  1477    def test_write_text_file_empty(self):
  1478      sink = TextSink(self.path)
  1479      self._write_lines(sink, [])
  1480  
  1481      with open(self.path, 'rb') as f:
  1482        self.assertEqual(f.read().splitlines(), [])
  1483  
  1484    def test_write_bzip2_file(self):
  1485      sink = TextSink(self.path, compression_type=CompressionTypes.BZIP2)
  1486      self._write_lines(sink, self.lines)
  1487  
  1488      with bz2.BZ2File(self.path, 'rb') as f:
  1489        self.assertEqual(f.read().splitlines(), self.lines)
  1490  
  1491    def test_write_bzip2_file_auto(self):
  1492      self.path = self._create_temp_file(suffix='.bz2')
  1493      sink = TextSink(self.path)
  1494      self._write_lines(sink, self.lines)
  1495  
  1496      with bz2.BZ2File(self.path, 'rb') as f:
  1497        self.assertEqual(f.read().splitlines(), self.lines)
  1498  
  1499    def test_write_gzip_file(self):
  1500      sink = TextSink(self.path, compression_type=CompressionTypes.GZIP)
  1501      self._write_lines(sink, self.lines)
  1502  
  1503      with gzip.GzipFile(self.path, 'rb') as f:
  1504        self.assertEqual(f.read().splitlines(), self.lines)
  1505  
  1506    def test_write_gzip_file_auto(self):
  1507      self.path = self._create_temp_file(suffix='.gz')
  1508      sink = TextSink(self.path)
  1509      self._write_lines(sink, self.lines)
  1510  
  1511      with gzip.GzipFile(self.path, 'rb') as f:
  1512        self.assertEqual(f.read().splitlines(), self.lines)
  1513  
  1514    def test_write_gzip_file_empty(self):
  1515      sink = TextSink(self.path, compression_type=CompressionTypes.GZIP)
  1516      self._write_lines(sink, [])
  1517  
  1518      with gzip.GzipFile(self.path, 'rb') as f:
  1519        self.assertEqual(f.read().splitlines(), [])
  1520  
  1521    def test_write_deflate_file(self):
  1522      sink = TextSink(self.path, compression_type=CompressionTypes.DEFLATE)
  1523      self._write_lines(sink, self.lines)
  1524  
  1525      with open(self.path, 'rb') as f:
  1526        self.assertEqual(zlib.decompress(f.read()).splitlines(), self.lines)
  1527  
  1528    def test_write_deflate_file_auto(self):
  1529      self.path = self._create_temp_file(suffix='.deflate')
  1530      sink = TextSink(self.path)
  1531      self._write_lines(sink, self.lines)
  1532  
  1533      with open(self.path, 'rb') as f:
  1534        self.assertEqual(zlib.decompress(f.read()).splitlines(), self.lines)
  1535  
  1536    def test_write_deflate_file_empty(self):
  1537      sink = TextSink(self.path, compression_type=CompressionTypes.DEFLATE)
  1538      self._write_lines(sink, [])
  1539  
  1540      with open(self.path, 'rb') as f:
  1541        self.assertEqual(zlib.decompress(f.read()).splitlines(), [])
  1542  
  1543    def test_write_text_file_with_header(self):
  1544      header = b'header1\nheader2'
  1545      sink = TextSink(self.path, header=header)
  1546      self._write_lines(sink, self.lines)
  1547  
  1548      with open(self.path, 'rb') as f:
  1549        self.assertEqual(f.read().splitlines(), header.splitlines() + self.lines)
  1550  
  1551    def test_write_text_file_with_footer(self):
  1552      footer = b'footer1\nfooter2'
  1553      sink = TextSink(self.path, footer=footer)
  1554      self._write_lines(sink, self.lines)
  1555  
  1556      with open(self.path, 'rb') as f:
  1557        self.assertEqual(f.read().splitlines(), self.lines + footer.splitlines())
  1558  
  1559    def test_write_text_file_empty_with_header(self):
  1560      header = b'header1\nheader2'
  1561      sink = TextSink(self.path, header=header)
  1562      self._write_lines(sink, [])
  1563  
  1564      with open(self.path, 'rb') as f:
  1565        self.assertEqual(f.read().splitlines(), header.splitlines())
  1566  
  1567    def test_write_pipeline(self):
  1568      with TestPipeline() as pipeline:
  1569        pcoll = pipeline | beam.core.Create(self.lines)
  1570        pcoll | 'Write' >> WriteToText(self.path)  # pylint: disable=expression-not-assigned
  1571  
  1572      read_result = []
  1573      for file_name in glob.glob(self.path + '*'):
  1574        with open(file_name, 'rb') as f:
  1575          read_result.extend(f.read().splitlines())
  1576  
  1577      self.assertEqual(sorted(read_result), sorted(self.lines))
  1578  
  1579    def test_write_pipeline_non_globalwindow_input(self):
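            # Non-globally-windowed input should be written out correctly too.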
  1580      with TestPipeline() as p:
  1581        _ = (
  1582            p
  1583            | beam.core.Create(self.lines)
  1584            | beam.WindowInto(beam.transforms.window.FixedWindows(1))
  1585            | 'Write' >> WriteToText(self.path))
  1586  
  1587      read_result = []
  1588      for file_name in glob.glob(self.path + '*'):
  1589        with open(file_name, 'rb') as f:
  1590          read_result.extend(f.read().splitlines())
  1591  
  1592      self.assertEqual(sorted(read_result), sorted(self.lines))
  1593  
  1594    def test_write_pipeline_auto_compression(self):
  1595      with TestPipeline() as pipeline:
  1596        pcoll = pipeline | beam.core.Create(self.lines)
  1597        pcoll | 'Write' >> WriteToText(self.path, file_name_suffix='.gz')  # pylint: disable=expression-not-assigned
  1598  
  1599      read_result = []
  1600      for file_name in glob.glob(self.path + '*'):
  1601        with gzip.GzipFile(file_name, 'rb') as f:
  1602          read_result.extend(f.read().splitlines())
  1603  
  1604      self.assertEqual(sorted(read_result), sorted(self.lines))
  1605  
  1606    def test_write_pipeline_auto_compression_unsharded(self):
  1607      with TestPipeline() as pipeline:
  1608        pcoll = pipeline | 'Create' >> beam.core.Create(self.lines)
  1609        pcoll | 'Write' >> WriteToText(  # pylint: disable=expression-not-assigned
  1610            self.path + '.gz',
  1611            shard_name_template='')
  1612  
  1613      read_result = []
  1614      for file_name in glob.glob(self.path + '*'):
  1615        with gzip.GzipFile(file_name, 'rb') as f:
  1616          read_result.extend(f.read().splitlines())
  1617  
  1618      self.assertEqual(sorted(read_result), sorted(self.lines))
  1619  
  1620    def test_write_pipeline_header(self):
  1621      with TestPipeline() as pipeline:
  1622        pcoll = pipeline | 'Create' >> beam.core.Create(self.lines)
  1623        header_text = 'foo'
  1624        pcoll | 'Write' >> WriteToText(  # pylint: disable=expression-not-assigned
  1625            self.path + '.gz',
  1626            shard_name_template='',
  1627            header=header_text)
  1628  
  1629      read_result = []
  1630      for file_name in glob.glob(self.path + '*'):
  1631        with gzip.GzipFile(file_name, 'rb') as f:
  1632          read_result.extend(f.read().splitlines())
  1633      # header_text is automatically encoded in WriteToText
  1634      self.assertEqual(read_result[0], header_text.encode('utf-8'))
  1635      self.assertEqual(sorted(read_result[1:]), sorted(self.lines))
  1636  
  1637    def test_write_pipeline_footer(self):
  1638      with TestPipeline() as pipeline:
  1639        footer_text = 'footer'
  1640        pcoll = pipeline | beam.core.Create(self.lines)
  1641        pcoll | 'Write' >> WriteToText(   # pylint: disable=expression-not-assigned
  1642          self.path,
  1643          footer=footer_text)
  1644  
  1645      read_result = []
  1646      for file_name in glob.glob(self.path + '*'):
  1647        with open(file_name, 'rb') as f:
  1648          read_result.extend(f.read().splitlines())
  1649  
  1650      self.assertEqual(sorted(read_result[:-1]), sorted(self.lines))
  1651      self.assertEqual(read_result[-1], footer_text.encode('utf-8'))
  1652  
  1653    def test_write_empty(self):
  1654      with TestPipeline() as p:
  1655        # pylint: disable=expression-not-assigned
  1656        p | beam.core.Create([]) | WriteToText(self.path)
  1657  
  1658      outputs = glob.glob(self.path + '*')
  1659      self.assertEqual(len(outputs), 1)
  1660      with open(outputs[0], 'rb') as f:
  1661        self.assertEqual(list(f.read().splitlines()), [])
  1662  
  1663    def test_write_empty_skipped(self):
  1664      with TestPipeline() as p:
  1665        # pylint: disable=expression-not-assigned
  1666        p | beam.core.Create([]) | WriteToText(self.path, skip_if_empty=True)
  1667  
  1668      outputs = list(glob.glob(self.path + '*'))
  1669      self.assertEqual(outputs, [])
  1670  
  1671    def test_write_max_records_per_shard(self):
  1672      records_per_shard = 13
  1673      lines = [str(i).encode('utf-8') for i in range(100)]
  1674      with TestPipeline() as p:
  1675        # pylint: disable=expression-not-assigned
  1676        p | beam.core.Create(lines) | WriteToText(
  1677            self.path, max_records_per_shard=records_per_shard)
  1678  
  1679      read_result = []
  1680      for file_name in glob.glob(self.path + '*'):
  1681        with open(file_name, 'rb') as f:
  1682          shard_lines = list(f.read().splitlines())
  1683          self.assertLessEqual(len(shard_lines), records_per_shard)
  1684          read_result.extend(shard_lines)
  1685      self.assertEqual(sorted(read_result), sorted(lines))
  1686  
  1687    def test_write_max_bytes_per_shard(self):
  1688      bytes_per_shard = 300
  1689      max_len = 100
  1690      lines = [b'x' * i for i in range(max_len)]
  1691      header = b'a' * 20
  1692      footer = b'b' * 30
  1693      with TestPipeline() as p:
  1694        # pylint: disable=expression-not-assigned
  1695        p | beam.core.Create(lines) | WriteToText(
  1696            self.path,
  1697            header=header,
  1698            footer=footer,
  1699            max_bytes_per_shard=bytes_per_shard)
  1700  
  1701      read_result = []
  1702      for file_name in glob.glob(self.path + '*'):
  1703        with open(file_name, 'rb') as f:
  1704          contents = f.read()
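                # max_bytes_per_shard is a soft limit: a shard may exceed it
                # by at most one record (up to max_len bytes) plus the footer
                # and trailing newlines, hence the slack in the bound below.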
  1705          self.assertLessEqual(
  1706              len(contents), bytes_per_shard + max_len + len(footer) + 2)
  1707          shard_lines = list(contents.splitlines())
  1708          self.assertEqual(shard_lines[0], header)
  1709          self.assertEqual(shard_lines[-1], footer)
  1710          read_result.extend(shard_lines[1:-1])
  1711      self.assertEqual(sorted(read_result), sorted(lines))
  1712  
  1713  
  1714  class CsvTest(unittest.TestCase):
  1715    def test_csv_read_write(self):
  1716      records = [beam.Row(a='str', b=ix) for ix in range(3)]
  1717      with tempfile.TemporaryDirectory() as dest:
  1718        with TestPipeline() as p:
  1719          # pylint: disable=expression-not-assigned
  1720          p | beam.Create(records) | beam.io.WriteToCsv(os.path.join(dest, 'out'))
  1721        with TestPipeline() as p:
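                # ReadFromCsv yields named-tuple-like rows; map them back to
                # beam.Row so they compare equal to the original records.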
  1722          pcoll = (
  1723              p
  1724              | beam.io.ReadFromCsv(os.path.join(dest, 'out*'))
  1725              | beam.Map(lambda t: beam.Row(**dict(zip(type(t)._fields, t)))))
  1726  
  1727          assert_that(pcoll, equal_to(records))
  1728  
  1729  
  1730  class JsonTest(unittest.TestCase):
  1731    def test_json_read_write(self):
  1732      records = [beam.Row(a='str', b=ix) for ix in range(3)]
  1733      with tempfile.TemporaryDirectory() as dest:
  1734        with TestPipeline() as p:
  1735          # pylint: disable=expression-not-assigned
  1736          p | beam.Create(records) | beam.io.WriteToJson(
  1737              os.path.join(dest, 'out'))
  1738        with TestPipeline() as p:
  1739          pcoll = (
  1740              p
  1741              | beam.io.ReadFromJson(os.path.join(dest, 'out*'))
  1742              | beam.Map(lambda t: beam.Row(**dict(zip(type(t)._fields, t)))))
  1743  
  1744          assert_that(pcoll, equal_to(records))
  1745  
  1746  
  1747  if __name__ == '__main__':
  1748    logging.getLogger().setLevel(logging.INFO)
  1749    unittest.main()