#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Tests for textio module."""
# pytype: skip-file

import bz2
import glob
import gzip
import logging
import os
import shutil
import tempfile
import unittest
import zlib

import apache_beam as beam
from apache_beam import coders
from apache_beam.io import iobase
from apache_beam.io import source_test_utils
from apache_beam.io.filesystem import CompressionTypes
from apache_beam.io.filesystems import FileSystems
# Importing the following private classes for testing.
from apache_beam.io.textio import _TextSink as TextSink
from apache_beam.io.textio import _TextSource as TextSource
from apache_beam.io.textio import ReadAllFromText
from apache_beam.io.textio import ReadAllFromTextContinuously
from apache_beam.io.textio import ReadFromText
from apache_beam.io.textio import ReadFromTextWithFilename
from apache_beam.io.textio import WriteToText
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.test_utils import TempDir
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms.core import Create
from apache_beam.transforms.userstate import CombiningValueStateSpec
from apache_beam.utils.timestamp import Timestamp


class DummyCoder(coders.Coder):
  """A coder that fails on encode and doubles the bytes on decode, used to
  verify that reading a text file only ever decodes records."""
  def encode(self, x):
    raise ValueError

  def decode(self, x):
    return (x * 2).decode('utf-8')

  def to_type_hint(self):
    return str


class EOL(object):
  LF = 1
  CRLF = 2
  MIXED = 3
  LF_WITH_NOTHING_AT_LAST_LINE = 4
  CUSTOM_DELIMITER = 5


def write_data(
    num_lines,
    no_data=False,
    directory=None,
    prefix=tempfile.template,
    eol=EOL.LF,
    custom_delimiter=None,
    line_value=b'line'):
  """Writes test data to a temporary file.

  Args:
    num_lines (int): The number of lines to write.
    no_data (bool): If :data:`True`, empty lines will be written, otherwise
      each line will contain a concatenation of ``line_value`` and the line
      number.
    directory (str): The name of the directory to create the temporary file in.
    prefix (str): The prefix to use for the temporary file.
    eol (int): The line ending to use when writing.
      :class:`~apache_beam.io.textio_test.EOL` exposes attributes that can be
      used here to define the eol.
    custom_delimiter (bytes): The custom delimiter.
    line_value (bytes): The value to use for each line, defaults to b'line'.

  Returns:
    Tuple[str, List[str]]: A tuple of the filename and a list of the
      utf-8 decoded written data.
  """
  all_data = []
  with tempfile.NamedTemporaryFile(delete=False, dir=directory,
                                   prefix=prefix) as f:
    sep_values = [b'\n', b'\r\n']
    for i in range(num_lines):
      data = b'' if no_data else line_value + str(i).encode()
      all_data.append(data)

      if eol == EOL.LF:
        sep = sep_values[0]
      elif eol == EOL.CRLF:
        sep = sep_values[1]
      elif eol == EOL.MIXED:
        sep = sep_values[i % len(sep_values)]
      elif eol == EOL.LF_WITH_NOTHING_AT_LAST_LINE:
        sep = b'' if i == (num_lines - 1) else sep_values[0]
      elif eol == EOL.CUSTOM_DELIMITER:
        if custom_delimiter is None or len(custom_delimiter) == 0:
          raise ValueError('delimiter can not be None or empty')
        else:
          sep = custom_delimiter
      else:
        raise ValueError('Received unknown value %s for eol.' % eol)

      f.write(data + sep)

    return f.name, [line.decode('utf-8') for line in all_data]
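

# A minimal usage sketch of write_data (illustrative values only; the actual
# file name is generated by tempfile.NamedTemporaryFile):
#
#   file_name, expected = write_data(3)
#   # expected == ['line0', 'line1', 'line2']
#   # and the file on disk contains b'line0\nline1\nline2\n'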


def write_pattern(lines_per_file, no_data=False, return_filenames=False):
  """Writes a pattern of temporary files.

  Args:
    lines_per_file (List[int]): The number of lines to write per file.
    no_data (bool): If :data:`True`, empty lines will be written, otherwise
      each line will contain a concatenation of b'line' and the line number.
    return_filenames (bool): If True, returned list will contain
      (filename, data) pairs.

  Returns:
    Tuple[str, List[Union[str, Tuple[str, str]]]]: A tuple of the filename
      pattern and a list of the utf-8 decoded written data or
      (filename, data) pairs.
  """
  temp_dir = tempfile.mkdtemp()

  all_data = []
  file_name = None
  start_index = 0
  for i in range(len(lines_per_file)):
    file_name, data = write_data(lines_per_file[i], no_data=no_data,
                                 directory=temp_dir, prefix='mytemp')
    if return_filenames:
      all_data.extend(zip([file_name] * len(data), data))
    else:
      all_data.extend(data)
    start_index += lines_per_file[i]

  assert file_name
  return (
      file_name[:file_name.rfind(os.path.sep)] + os.path.sep + 'mytemp*',
      all_data)
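

# Likewise for write_pattern (sketch; the directory is created by
# tempfile.mkdtemp, so the pattern below is only an example shape):
#
#   pattern, expected = write_pattern([2, 1])
#   # pattern  ~ '/tmp/<random dir>/mytemp*'
#   # expected == ['line0', 'line1', 'line0']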


class TextSourceTest(unittest.TestCase):

  # Number of records that will be written by most tests.
  DEFAULT_NUM_RECORDS = 100

  def _run_read_test(
      self,
      file_or_pattern,
      expected_data,
      buffer_size=DEFAULT_NUM_RECORDS,
      compression=CompressionTypes.UNCOMPRESSED,
      delimiter=None,
      escapechar=None):
    # Since each record usually takes more than 1 byte, default buffer size is
    # smaller than the total size of the file. This is done to
    # increase test coverage for cases that hit the buffer boundary.
    kwargs = {}
    if delimiter:
      kwargs['delimiter'] = delimiter
    if escapechar:
      kwargs['escapechar'] = escapechar
    source = TextSource(
        file_or_pattern,
        0,
        compression,
        True,
        coders.StrUtf8Coder(),
        buffer_size,
        **kwargs)
    range_tracker = source.get_range_tracker(None, None)
    read_data = list(source.read(range_tracker))
    self.assertCountEqual(expected_data, read_data)

  def test_read_single_file(self):
    file_name, expected_data = write_data(TextSourceTest.DEFAULT_NUM_RECORDS)
    assert len(expected_data) == TextSourceTest.DEFAULT_NUM_RECORDS
    self._run_read_test(file_name, expected_data)

  def test_read_single_file_smaller_than_default_buffer(self):
    file_name, expected_data = write_data(TextSourceTest.DEFAULT_NUM_RECORDS)
    self._run_read_test(
        file_name,
        expected_data,
        buffer_size=TextSource.DEFAULT_READ_BUFFER_SIZE)

  def test_read_single_file_larger_than_default_buffer(self):
    file_name, expected_data = write_data(TextSource.DEFAULT_READ_BUFFER_SIZE)
    self._run_read_test(
        file_name,
        expected_data,
        buffer_size=TextSource.DEFAULT_READ_BUFFER_SIZE)

  def test_read_file_pattern(self):
    pattern, expected_data = write_pattern(
        [TextSourceTest.DEFAULT_NUM_RECORDS * 5,
         TextSourceTest.DEFAULT_NUM_RECORDS * 3,
         TextSourceTest.DEFAULT_NUM_RECORDS * 12,
         TextSourceTest.DEFAULT_NUM_RECORDS * 8,
         TextSourceTest.DEFAULT_NUM_RECORDS * 8,
         TextSourceTest.DEFAULT_NUM_RECORDS * 4])
    assert len(expected_data) == TextSourceTest.DEFAULT_NUM_RECORDS * 40
    self._run_read_test(pattern, expected_data)

  def test_read_single_file_windows_eol(self):
    file_name, expected_data = write_data(TextSourceTest.DEFAULT_NUM_RECORDS,
                                          eol=EOL.CRLF)
    assert len(expected_data) == TextSourceTest.DEFAULT_NUM_RECORDS
    self._run_read_test(file_name, expected_data)

  def test_read_single_file_mixed_eol(self):
    file_name, expected_data = write_data(TextSourceTest.DEFAULT_NUM_RECORDS,
                                          eol=EOL.MIXED)
    assert len(expected_data) == TextSourceTest.DEFAULT_NUM_RECORDS
    self._run_read_test(file_name, expected_data)

  def test_read_single_file_last_line_no_eol(self):
    file_name, expected_data = write_data(
        TextSourceTest.DEFAULT_NUM_RECORDS,
        eol=EOL.LF_WITH_NOTHING_AT_LAST_LINE)
    assert len(expected_data) == TextSourceTest.DEFAULT_NUM_RECORDS
    self._run_read_test(file_name, expected_data)

  def test_read_single_file_single_line_no_eol(self):
    file_name, expected_data = write_data(
        1, eol=EOL.LF_WITH_NOTHING_AT_LAST_LINE)

    assert len(expected_data) == 1
    self._run_read_test(file_name, expected_data)

  def test_read_empty_single_file(self):
    file_name, written_data = write_data(
        1, no_data=True, eol=EOL.LF_WITH_NOTHING_AT_LAST_LINE)

    assert len(written_data) == 1
    # written data has a single entry with an empty string. Reading the source
    # should not produce anything since we only wrote a single empty string
    # without an end of line character.
    self._run_read_test(file_name, [])

  def test_read_single_file_last_line_no_eol_gzip(self):
    file_name, expected_data = write_data(
        TextSourceTest.DEFAULT_NUM_RECORDS,
        eol=EOL.LF_WITH_NOTHING_AT_LAST_LINE)

    gzip_file_name = file_name + '.gz'
    with open(file_name, 'rb') as src, gzip.open(gzip_file_name, 'wb') as dst:
      dst.writelines(src)

    assert len(expected_data) == TextSourceTest.DEFAULT_NUM_RECORDS
    self._run_read_test(
        gzip_file_name, expected_data, compression=CompressionTypes.GZIP)

  def test_read_single_file_single_line_no_eol_gzip(self):
    file_name, expected_data = write_data(
        1, eol=EOL.LF_WITH_NOTHING_AT_LAST_LINE)

    gzip_file_name = file_name + '.gz'
    with open(file_name, 'rb') as src, gzip.open(gzip_file_name, 'wb') as dst:
      dst.writelines(src)

    assert len(expected_data) == 1
    self._run_read_test(
        gzip_file_name, expected_data, compression=CompressionTypes.GZIP)

  def test_read_empty_single_file_no_eol_gzip(self):
    file_name, written_data = write_data(
        1, no_data=True, eol=EOL.LF_WITH_NOTHING_AT_LAST_LINE)

    gzip_file_name = file_name + '.gz'
    with open(file_name, 'rb') as src, gzip.open(gzip_file_name, 'wb') as dst:
      dst.writelines(src)

    assert len(written_data) == 1
    # written data has a single entry with an empty string. Reading the source
    # should not produce anything since we only wrote a single empty string
    # without an end of line character.
    self._run_read_test(gzip_file_name, [], compression=CompressionTypes.GZIP)

  def test_read_single_file_with_empty_lines(self):
    file_name, expected_data = write_data(
        TextSourceTest.DEFAULT_NUM_RECORDS, no_data=True, eol=EOL.LF)

    assert len(expected_data) == TextSourceTest.DEFAULT_NUM_RECORDS
    assert not expected_data[0]

    self._run_read_test(file_name, expected_data)

  def test_read_single_file_without_striping_eol_lf(self):
    file_name, written_data = write_data(TextSourceTest.DEFAULT_NUM_RECORDS,
                                         eol=EOL.LF)
    assert len(written_data) == TextSourceTest.DEFAULT_NUM_RECORDS
    source = TextSource(
        file_name,
        0,
        CompressionTypes.UNCOMPRESSED,
        False,
        coders.StrUtf8Coder())

    range_tracker = source.get_range_tracker(None, None)
    read_data = list(source.read(range_tracker))
    self.assertCountEqual([line + '\n' for line in written_data], read_data)

  def test_read_single_file_without_striping_eol_crlf(self):
    file_name, written_data = write_data(TextSourceTest.DEFAULT_NUM_RECORDS,
                                         eol=EOL.CRLF)
    assert len(written_data) == TextSourceTest.DEFAULT_NUM_RECORDS
    source = TextSource(
        file_name,
        0,
        CompressionTypes.UNCOMPRESSED,
        False,
        coders.StrUtf8Coder())

    range_tracker = source.get_range_tracker(None, None)
    read_data = list(source.read(range_tracker))
    self.assertCountEqual([line + '\r\n' for line in written_data], read_data)

  def test_read_file_pattern_with_empty_files(self):
    pattern, expected_data = write_pattern(
        [5 * TextSourceTest.DEFAULT_NUM_RECORDS,
         3 * TextSourceTest.DEFAULT_NUM_RECORDS,
         12 * TextSourceTest.DEFAULT_NUM_RECORDS,
         8 * TextSourceTest.DEFAULT_NUM_RECORDS,
         8 * TextSourceTest.DEFAULT_NUM_RECORDS,
         4 * TextSourceTest.DEFAULT_NUM_RECORDS],
        no_data=True)
    assert len(expected_data) == 40 * TextSourceTest.DEFAULT_NUM_RECORDS
    assert not expected_data[0]
    self._run_read_test(pattern, expected_data)

  def test_read_after_splitting(self):
    file_name, expected_data = write_data(10)
    assert len(expected_data) == 10
    source = TextSource(
        file_name,
        0,
        CompressionTypes.UNCOMPRESSED,
        True,
        coders.StrUtf8Coder())
    splits = list(source.split(desired_bundle_size=33))

    reference_source_info = (source, None, None)
    sources_info = ([(split.source, split.start_position, split.stop_position)
                     for split in splits])
    source_test_utils.assert_sources_equal_reference_source(
        reference_source_info, sources_info)

  def test_header_processing(self):
    file_name, expected_data = write_data(10)
    assert len(expected_data) == 10

    def header_matcher(line):
      return line in expected_data[:5]

    header_lines = []

    def store_header(lines):
      for line in lines:
        header_lines.append(line)

    source = TextSource(
        file_name,
        0,
        CompressionTypes.UNCOMPRESSED,
        True,
        coders.StrUtf8Coder(),
        header_processor_fns=(header_matcher, store_header))
    splits = list(source.split(desired_bundle_size=100000))
    assert len(splits) == 1
    range_tracker = splits[0].source.get_range_tracker(
        splits[0].start_position, splits[0].stop_position)
    read_data = list(source.read_records(file_name, range_tracker))

    self.assertCountEqual(expected_data[:5], header_lines)
    self.assertCountEqual(expected_data[5:], read_data)

  def test_progress(self):
    file_name, expected_data = write_data(10)
    assert len(expected_data) == 10
    source = TextSource(
        file_name,
        0,
        CompressionTypes.UNCOMPRESSED,
        True,
        coders.StrUtf8Coder())
    splits = list(source.split(desired_bundle_size=100000))
    assert len(splits) == 1
    fraction_consumed_report = []
    split_points_report = []
    range_tracker = splits[0].source.get_range_tracker(
        splits[0].start_position, splits[0].stop_position)
    for _ in splits[0].source.read(range_tracker):
      fraction_consumed_report.append(range_tracker.fraction_consumed())
      split_points_report.append(range_tracker.split_points())

    self.assertEqual([float(i) / 10 for i in range(0, 10)],
                     fraction_consumed_report)
    expected_split_points_report = [((i - 1),
                                     iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)
                                    for i in range(1, 10)]

    # At the last split point, the remaining-split-points callback returns 1
    # since the expected position of the next record becomes equal to the stop
    # position.
    expected_split_points_report.append((9, 1))

    self.assertEqual(expected_split_points_report, split_points_report)

  def test_read_reentrant_without_splitting(self):
    file_name, expected_data = write_data(10)
    assert len(expected_data) == 10
    source = TextSource(
        file_name,
        0,
        CompressionTypes.UNCOMPRESSED,
        True,
        coders.StrUtf8Coder())
    source_test_utils.assert_reentrant_reads_succeed((source, None, None))

  def test_read_reentrant_after_splitting(self):
    file_name, expected_data = write_data(10)
    assert len(expected_data) == 10
    source = TextSource(
        file_name,
        0,
        CompressionTypes.UNCOMPRESSED,
        True,
        coders.StrUtf8Coder())
    splits = list(source.split(desired_bundle_size=100000))
    assert len(splits) == 1
    source_test_utils.assert_reentrant_reads_succeed(
        (splits[0].source, splits[0].start_position, splits[0].stop_position))

  def test_dynamic_work_rebalancing(self):
    file_name, expected_data = write_data(5)
    assert len(expected_data) == 5
    source = TextSource(
        file_name,
        0,
        CompressionTypes.UNCOMPRESSED,
        True,
        coders.StrUtf8Coder())
    splits = list(source.split(desired_bundle_size=100000))
    assert len(splits) == 1
    source_test_utils.assert_split_at_fraction_exhaustive(
        splits[0].source, splits[0].start_position, splits[0].stop_position)

  def test_dynamic_work_rebalancing_windows_eol(self):
    file_name, expected_data = write_data(15, eol=EOL.CRLF)
    assert len(expected_data) == 15
    source = TextSource(
        file_name,
        0,
        CompressionTypes.UNCOMPRESSED,
        True,
        coders.StrUtf8Coder())
    splits = list(source.split(desired_bundle_size=100000))
    assert len(splits) == 1
    source_test_utils.assert_split_at_fraction_exhaustive(
        splits[0].source,
        splits[0].start_position,
        splits[0].stop_position,
        perform_multi_threaded_test=False)

  def test_dynamic_work_rebalancing_mixed_eol(self):
    file_name, expected_data = write_data(5, eol=EOL.MIXED)
    assert len(expected_data) == 5
    source = TextSource(
        file_name,
        0,
        CompressionTypes.UNCOMPRESSED,
        True,
        coders.StrUtf8Coder())
    splits = list(source.split(desired_bundle_size=100000))
    assert len(splits) == 1
    source_test_utils.assert_split_at_fraction_exhaustive(
        splits[0].source,
        splits[0].start_position,
        splits[0].stop_position,
        perform_multi_threaded_test=False)

  def test_read_from_text_single_file(self):
    file_name, expected_data = write_data(5)
    assert len(expected_data) == 5
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Read' >> ReadFromText(file_name)
      assert_that(pcoll, equal_to(expected_data))

  def test_read_from_text_with_file_name_single_file(self):
    file_name, data = write_data(5)
    expected_data = [(file_name, el) for el in data]
    assert len(expected_data) == 5
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Read' >> ReadFromTextWithFilename(file_name)
      assert_that(pcoll, equal_to(expected_data))

  def test_read_all_single_file(self):
    file_name, expected_data = write_data(5)
    assert len(expected_data) == 5
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Create' >> Create(
          [file_name]) | 'ReadAll' >> ReadAllFromText()
      assert_that(pcoll, equal_to(expected_data))

  def test_read_all_many_single_files(self):
    file_name1, expected_data1 = write_data(5)
    assert len(expected_data1) == 5
    file_name2, expected_data2 = write_data(10)
    assert len(expected_data2) == 10
    file_name3, expected_data3 = write_data(15)
    assert len(expected_data3) == 15
    expected_data = []
    expected_data.extend(expected_data1)
    expected_data.extend(expected_data2)
    expected_data.extend(expected_data3)
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Create' >> Create([
          file_name1, file_name2, file_name3
      ]) | 'ReadAll' >> ReadAllFromText()
      assert_that(pcoll, equal_to(expected_data))

  def test_read_all_unavailable_files_ignored(self):
    file_name1, expected_data1 = write_data(5)
    assert len(expected_data1) == 5
    file_name2, expected_data2 = write_data(10)
    assert len(expected_data2) == 10
    file_name3, expected_data3 = write_data(15)
    assert len(expected_data3) == 15
    file_name4 = "/unavailable_file"
    expected_data = []
    expected_data.extend(expected_data1)
    expected_data.extend(expected_data2)
    expected_data.extend(expected_data3)
    with TestPipeline() as pipeline:
      pcoll = (
          pipeline
          | 'Create' >> Create([file_name1, file_name2, file_name3, file_name4])
          | 'ReadAll' >> ReadAllFromText())
      assert_that(pcoll, equal_to(expected_data))

  class _WriteFilesFn(beam.DoFn):
    """Writes a couple of files with deferral."""
    COUNT_STATE = CombiningValueStateSpec('count', combine_fn=sum)

    def __init__(self, temp_path):
      self.temp_path = temp_path

    def process(self, element, count_state=beam.DoFn.StateParam(COUNT_STATE)):
      counter = count_state.read()
      if counter == 0:
        count_state.add(1)
        with open(FileSystems.join(self.temp_path, 'file1'), 'w') as f:
          f.write('second A\nsecond B')
        with open(FileSystems.join(self.temp_path, 'file2'), 'w') as f:
          f.write('first')
      # convert dumb key to basename in output
      basename = FileSystems.split(element[1][0])[1]
      content = element[1][1]
      yield basename, content
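
  # The two tests below exercise ReadAllFromTextContinuously: the pipeline
  # re-matches match_pattern every `interval` seconds between start_timestamp
  # and stop_timestamp, while _WriteFilesFn above rewrites 'file1' and adds
  # 'file2' once the first element arrives. With match_updated_files=False a
  # file is read only once, when it is first matched; with
  # match_updated_files=True the updated contents of 'file1' are read again
  # as well.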

  def test_read_all_continuously_new(self):
    with TempDir() as tempdir, TestPipeline() as pipeline:
      temp_path = tempdir.get_path()
      # create a temp file at the beginning
      with open(FileSystems.join(temp_path, 'file1'), 'w') as f:
        f.write('first')
      match_pattern = FileSystems.join(temp_path, '*')
      interval = 0.5
      last = 2
      p_read_once = (
          pipeline
          | 'Continuously read new files' >> ReadAllFromTextContinuously(
              match_pattern,
              with_filename=True,
              start_timestamp=Timestamp.now(),
              interval=interval,
              stop_timestamp=Timestamp.now() + last,
              match_updated_files=False)
          | 'add dumb key' >> beam.Map(lambda x: (0, x))
          |
          'Write files on-the-fly' >> beam.ParDo(self._WriteFilesFn(temp_path)))
      assert_that(
          p_read_once,
          equal_to([('file1', 'first'), ('file2', 'first')]),
          label='assert read new files results')

  def test_read_all_continuously_update(self):
    with TempDir() as tempdir, TestPipeline() as pipeline:
      temp_path = tempdir.get_path()
      # create a temp file at the beginning
      with open(FileSystems.join(temp_path, 'file1'), 'w') as f:
        f.write('first')
      match_pattern = FileSystems.join(temp_path, '*')
      interval = 0.5
      last = 2
      p_read_upd = (
          pipeline
          | 'Continuously read updated files' >> ReadAllFromTextContinuously(
              match_pattern,
              with_filename=True,
              start_timestamp=Timestamp.now(),
              interval=interval,
              stop_timestamp=Timestamp.now() + last,
              match_updated_files=True)
          | 'add dumb key' >> beam.Map(lambda x: (0, x))
          |
          'Write files on-the-fly' >> beam.ParDo(self._WriteFilesFn(temp_path)))
      assert_that(
          p_read_upd,
          equal_to([('file1', 'first'), ('file1', 'second A'),
                    ('file1', 'second B'), ('file2', 'first')]),
          label='assert read updated files results')

  def test_read_from_text_single_file_with_coder(self):
    file_name, expected_data = write_data(5)
    assert len(expected_data) == 5
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Read' >> ReadFromText(file_name, coder=DummyCoder())
      assert_that(pcoll, equal_to([record * 2 for record in expected_data]))

  def test_read_from_text_file_pattern(self):
    pattern, expected_data = write_pattern([5, 3, 12, 8, 8, 4])
    assert len(expected_data) == 40
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Read' >> ReadFromText(pattern)
      assert_that(pcoll, equal_to(expected_data))

  def test_read_from_text_with_file_name_file_pattern(self):
    pattern, expected_data = write_pattern(
        lines_per_file=[5, 5], return_filenames=True)
    assert len(expected_data) == 10
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Read' >> ReadFromTextWithFilename(pattern)
      assert_that(pcoll, equal_to(expected_data))

  def test_read_all_file_pattern(self):
    pattern, expected_data = write_pattern([5, 3, 12, 8, 8, 4])
    assert len(expected_data) == 40
    with TestPipeline() as pipeline:
      pcoll = (
          pipeline
          | 'Create' >> Create([pattern])
          | 'ReadAll' >> ReadAllFromText())
      assert_that(pcoll, equal_to(expected_data))

  def test_read_all_many_file_patterns(self):
    pattern1, expected_data1 = write_pattern([5, 3, 12, 8, 8, 4])
    assert len(expected_data1) == 40
    pattern2, expected_data2 = write_pattern([3, 7, 9])
    assert len(expected_data2) == 19
    pattern3, expected_data3 = write_pattern([11, 20, 5, 5])
    assert len(expected_data3) == 41
    expected_data = []
    expected_data.extend(expected_data1)
    expected_data.extend(expected_data2)
    expected_data.extend(expected_data3)
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Create' >> Create(
          [pattern1, pattern2, pattern3]) | 'ReadAll' >> ReadAllFromText()
      assert_that(pcoll, equal_to(expected_data))

  def test_read_all_with_filename(self):
    pattern, expected_data = write_pattern([5, 3], return_filenames=True)
    assert len(expected_data) == 8

    with TestPipeline() as pipeline:
      pcoll = (
          pipeline
          | 'Create' >> Create([pattern])
          | 'ReadAll' >> ReadAllFromText(with_filename=True))
      assert_that(pcoll, equal_to(expected_data))

  def test_read_auto_bzip2(self):
    _, lines = write_data(15)
    with TempDir() as tempdir:
      file_name = tempdir.create_temp_file(suffix='.bz2')
      with bz2.BZ2File(file_name, 'wb') as f:
        f.write('\n'.join(lines).encode('utf-8'))

      with TestPipeline() as pipeline:
        pcoll = pipeline | 'Read' >> ReadFromText(file_name)
        assert_that(pcoll, equal_to(lines))

  def test_read_auto_deflate(self):
    _, lines = write_data(15)
    with TempDir() as tempdir:
      file_name = tempdir.create_temp_file(suffix='.deflate')
      with open(file_name, 'wb') as f:
        f.write(zlib.compress('\n'.join(lines).encode('utf-8')))

      with TestPipeline() as pipeline:
        pcoll = pipeline | 'Read' >> ReadFromText(file_name)
        assert_that(pcoll, equal_to(lines))

  def test_read_auto_gzip(self):
    _, lines = write_data(15)
    with TempDir() as tempdir:
      file_name = tempdir.create_temp_file(suffix='.gz')

      with gzip.GzipFile(file_name, 'wb') as f:
        f.write('\n'.join(lines).encode('utf-8'))

      with TestPipeline() as pipeline:
        pcoll = pipeline | 'Read' >> ReadFromText(file_name)
        assert_that(pcoll, equal_to(lines))

  def test_read_bzip2(self):
    _, lines = write_data(15)
    with TempDir() as tempdir:
      file_name = tempdir.create_temp_file()
      with bz2.BZ2File(file_name, 'wb') as f:
        f.write('\n'.join(lines).encode('utf-8'))

      with TestPipeline() as pipeline:
        pcoll = pipeline | 'Read' >> ReadFromText(
            file_name, compression_type=CompressionTypes.BZIP2)
        assert_that(pcoll, equal_to(lines))

  def test_read_corrupted_bzip2_fails(self):
    _, lines = write_data(15)
    with TempDir() as tempdir:
      file_name = tempdir.create_temp_file()
      with bz2.BZ2File(file_name, 'wb') as f:
        f.write('\n'.join(lines).encode('utf-8'))

      with open(file_name, 'wb') as f:
        f.write(b'corrupt')

      with self.assertRaises(Exception):
        with TestPipeline() as pipeline:
          pcoll = pipeline | 'Read' >> ReadFromText(
              file_name, compression_type=CompressionTypes.BZIP2)
          assert_that(pcoll, equal_to(lines))

  def test_read_bzip2_concat(self):
    with TempDir() as tempdir:
      bzip2_file_name1 = tempdir.create_temp_file()
      lines = ['a', 'b', 'c']
      with bz2.BZ2File(bzip2_file_name1, 'wb') as dst:
        data = '\n'.join(lines) + '\n'
        dst.write(data.encode('utf-8'))

      bzip2_file_name2 = tempdir.create_temp_file()
      lines = ['p', 'q', 'r']
      with bz2.BZ2File(bzip2_file_name2, 'wb') as dst:
        data = '\n'.join(lines) + '\n'
        dst.write(data.encode('utf-8'))

      bzip2_file_name3 = tempdir.create_temp_file()
      lines = ['x', 'y', 'z']
      with bz2.BZ2File(bzip2_file_name3, 'wb') as dst:
        data = '\n'.join(lines) + '\n'
        dst.write(data.encode('utf-8'))

      final_bzip2_file = tempdir.create_temp_file()
      with open(bzip2_file_name1, 'rb') as src, open(
          final_bzip2_file, 'wb') as dst:
        dst.writelines(src.readlines())

      with open(bzip2_file_name2, 'rb') as src, open(
          final_bzip2_file, 'ab') as dst:
        dst.writelines(src.readlines())

      with open(bzip2_file_name3, 'rb') as src, open(
          final_bzip2_file, 'ab') as dst:
        dst.writelines(src.readlines())

      with TestPipeline() as pipeline:
        lines = pipeline | 'ReadFromText' >> beam.io.ReadFromText(
            final_bzip2_file,
            compression_type=beam.io.filesystem.CompressionTypes.BZIP2)

        expected = ['a', 'b', 'c', 'p', 'q', 'r', 'x', 'y', 'z']
        assert_that(lines, equal_to(expected))

  def test_read_deflate(self):
    _, lines = write_data(15)
    with TempDir() as tempdir:
      file_name = tempdir.create_temp_file()
      with open(file_name, 'wb') as f:
        f.write(zlib.compress('\n'.join(lines).encode('utf-8')))

      with TestPipeline() as pipeline:
        pcoll = pipeline | 'Read' >> ReadFromText(
            file_name, 0, CompressionTypes.DEFLATE, True, coders.StrUtf8Coder())
        assert_that(pcoll, equal_to(lines))

  def test_read_corrupted_deflate_fails(self):
    _, lines = write_data(15)
    with TempDir() as tempdir:
      file_name = tempdir.create_temp_file()
      with open(file_name, 'wb') as f:
        f.write(zlib.compress('\n'.join(lines).encode('utf-8')))

      with open(file_name, 'wb') as f:
        f.write(b'corrupt')

      with self.assertRaises(Exception):
        with TestPipeline() as pipeline:
          pcoll = pipeline | 'Read' >> ReadFromText(
              file_name,
              0,
              CompressionTypes.DEFLATE,
              True,
              coders.StrUtf8Coder())
          assert_that(pcoll, equal_to(lines))

  def test_read_deflate_concat(self):
    with TempDir() as tempdir:
      deflate_file_name1 = tempdir.create_temp_file()
      lines = ['a', 'b', 'c']
      with open(deflate_file_name1, 'wb') as dst:
        data = '\n'.join(lines) + '\n'
        dst.write(zlib.compress(data.encode('utf-8')))

      deflate_file_name2 = tempdir.create_temp_file()
      lines = ['p', 'q', 'r']
      with open(deflate_file_name2, 'wb') as dst:
        data = '\n'.join(lines) + '\n'
        dst.write(zlib.compress(data.encode('utf-8')))

      deflate_file_name3 = tempdir.create_temp_file()
      lines = ['x', 'y', 'z']
      with open(deflate_file_name3, 'wb') as dst:
        data = '\n'.join(lines) + '\n'
        dst.write(zlib.compress(data.encode('utf-8')))

      final_deflate_file = tempdir.create_temp_file()
      with open(deflate_file_name1, 'rb') as src, \
           open(final_deflate_file, 'wb') as dst:
        dst.writelines(src.readlines())

      with open(deflate_file_name2, 'rb') as src, \
           open(final_deflate_file, 'ab') as dst:
        dst.writelines(src.readlines())

      with open(deflate_file_name3, 'rb') as src, \
           open(final_deflate_file, 'ab') as dst:
        dst.writelines(src.readlines())

      with TestPipeline() as pipeline:
        lines = pipeline | 'ReadFromText' >> beam.io.ReadFromText(
            final_deflate_file,
            compression_type=beam.io.filesystem.CompressionTypes.DEFLATE)

        expected = ['a', 'b', 'c', 'p', 'q', 'r', 'x', 'y', 'z']
        assert_that(lines, equal_to(expected))

  def test_read_gzip(self):
    _, lines = write_data(15)
    with TempDir() as tempdir:
      file_name = tempdir.create_temp_file()
      with gzip.GzipFile(file_name, 'wb') as f:
        f.write('\n'.join(lines).encode('utf-8'))

      with TestPipeline() as pipeline:
        pcoll = pipeline | 'Read' >> ReadFromText(
            file_name, 0, CompressionTypes.GZIP, True, coders.StrUtf8Coder())
        assert_that(pcoll, equal_to(lines))

  def test_read_corrupted_gzip_fails(self):
    _, lines = write_data(15)
    with TempDir() as tempdir:
      file_name = tempdir.create_temp_file()
      with gzip.GzipFile(file_name, 'wb') as f:
        f.write('\n'.join(lines).encode('utf-8'))

      with open(file_name, 'wb') as f:
        f.write(b'corrupt')

      with self.assertRaises(Exception):
        with TestPipeline() as pipeline:
          pcoll = pipeline | 'Read' >> ReadFromText(
              file_name, 0, CompressionTypes.GZIP, True, coders.StrUtf8Coder())
          assert_that(pcoll, equal_to(lines))

  def test_read_gzip_concat(self):
    with TempDir() as tempdir:
      gzip_file_name1 = tempdir.create_temp_file()
      lines = ['a', 'b', 'c']
      with gzip.open(gzip_file_name1, 'wb') as dst:
        data = '\n'.join(lines) + '\n'
        dst.write(data.encode('utf-8'))

      gzip_file_name2 = tempdir.create_temp_file()
      lines = ['p', 'q', 'r']
      with gzip.open(gzip_file_name2, 'wb') as dst:
        data = '\n'.join(lines) + '\n'
        dst.write(data.encode('utf-8'))

      gzip_file_name3 = tempdir.create_temp_file()
      lines = ['x', 'y', 'z']
      with gzip.open(gzip_file_name3, 'wb') as dst:
        data = '\n'.join(lines) + '\n'
        dst.write(data.encode('utf-8'))

      final_gzip_file = tempdir.create_temp_file()
      with open(gzip_file_name1, 'rb') as src, \
           open(final_gzip_file, 'wb') as dst:
        dst.writelines(src.readlines())

      with open(gzip_file_name2, 'rb') as src, \
           open(final_gzip_file, 'ab') as dst:
        dst.writelines(src.readlines())

      with open(gzip_file_name3, 'rb') as src, \
           open(final_gzip_file, 'ab') as dst:
        dst.writelines(src.readlines())

      with TestPipeline() as pipeline:
        lines = pipeline | 'ReadFromText' >> beam.io.ReadFromText(
            final_gzip_file,
            compression_type=beam.io.filesystem.CompressionTypes.GZIP)

        expected = ['a', 'b', 'c', 'p', 'q', 'r', 'x', 'y', 'z']
        assert_that(lines, equal_to(expected))

  def test_read_all_gzip(self):
    _, lines = write_data(100)
    with TempDir() as tempdir:
      file_name = tempdir.create_temp_file()
      with gzip.GzipFile(file_name, 'wb') as f:
        f.write('\n'.join(lines).encode('utf-8'))
      with TestPipeline() as pipeline:
        pcoll = (
            pipeline
            | Create([file_name])
            | 'ReadAll' >>
            ReadAllFromText(compression_type=CompressionTypes.GZIP))
        assert_that(pcoll, equal_to(lines))

  def test_read_gzip_large(self):
    _, lines = write_data(10000)
    with TempDir() as tempdir:
      file_name = tempdir.create_temp_file()

      with gzip.GzipFile(file_name, 'wb') as f:
        f.write('\n'.join(lines).encode('utf-8'))

      with TestPipeline() as pipeline:
        pcoll = pipeline | 'Read' >> ReadFromText(
            file_name, 0, CompressionTypes.GZIP, True, coders.StrUtf8Coder())
        assert_that(pcoll, equal_to(lines))

  def test_read_gzip_large_after_splitting(self):
    _, lines = write_data(10000)
    with TempDir() as tempdir:
      file_name = tempdir.create_temp_file()
      with gzip.GzipFile(file_name, 'wb') as f:
        f.write('\n'.join(lines).encode('utf-8'))

      source = TextSource(
          file_name, 0, CompressionTypes.GZIP, True, coders.StrUtf8Coder())
      splits = list(source.split(desired_bundle_size=1000))

      if len(splits) > 1:
        raise ValueError(
            'FileBasedSource generated more than one initial '
            'split for a compressed file.')

      reference_source_info = (source, None, None)
      sources_info = ([
          (split.source, split.start_position, split.stop_position)
          for split in splits
      ])
      source_test_utils.assert_sources_equal_reference_source(
          reference_source_info, sources_info)

  def test_read_gzip_empty_file(self):
    with TempDir() as tempdir:
      file_name = tempdir.create_temp_file()
      with TestPipeline() as pipeline:
        pcoll = pipeline | 'Read' >> ReadFromText(
            file_name, 0, CompressionTypes.GZIP, True, coders.StrUtf8Coder())
        assert_that(pcoll, equal_to([]))

  def _remove_lines(self, lines, sublist_lengths, num_to_remove):
    """Utility function to remove num_to_remove lines from each sublist.

    Args:
      lines: list of items.
      sublist_lengths: list of integers representing length of sublist
        corresponding to each source file.
      num_to_remove: number of lines to remove from each sublist.
    Returns:
      remaining lines.
    """
    curr = 0
    result = []
    for offset in sublist_lengths:
      end = curr + offset
      start = min(curr + num_to_remove, end)
      result += lines[start:end]
      curr += offset
    return result

  def _read_skip_header_lines(self, file_or_pattern, skip_header_lines):
    """Simple wrapper function for instantiating TextSource."""
    source = TextSource(
        file_or_pattern,
        0,
        CompressionTypes.UNCOMPRESSED,
        True,
        coders.StrUtf8Coder(),
        skip_header_lines=skip_header_lines)

    range_tracker = source.get_range_tracker(None, None)
    return list(source.read(range_tracker))

  def test_read_skip_header_single(self):
    file_name, expected_data = write_data(TextSourceTest.DEFAULT_NUM_RECORDS)
    assert len(expected_data) == TextSourceTest.DEFAULT_NUM_RECORDS
    skip_header_lines = 1
    expected_data = self._remove_lines(
        expected_data, [TextSourceTest.DEFAULT_NUM_RECORDS], skip_header_lines)
    read_data = self._read_skip_header_lines(file_name, skip_header_lines)
    self.assertEqual(len(expected_data), len(read_data))
    self.assertCountEqual(expected_data, read_data)

  def test_read_skip_header_pattern(self):
    line_counts = [
        TextSourceTest.DEFAULT_NUM_RECORDS * 5,
        TextSourceTest.DEFAULT_NUM_RECORDS * 3,
        TextSourceTest.DEFAULT_NUM_RECORDS * 12,
        TextSourceTest.DEFAULT_NUM_RECORDS * 8,
        TextSourceTest.DEFAULT_NUM_RECORDS * 8,
        TextSourceTest.DEFAULT_NUM_RECORDS * 4
    ]
    skip_header_lines = 2
    pattern, data = write_pattern(line_counts)

    expected_data = self._remove_lines(data, line_counts, skip_header_lines)
    read_data = self._read_skip_header_lines(pattern, skip_header_lines)
    self.assertEqual(len(expected_data), len(read_data))
    self.assertCountEqual(expected_data, read_data)

  def test_read_skip_header_pattern_insufficient_lines(self):
    line_counts = [
        5,
        3,  # Fewer lines in file than we want to skip
        12,
        8,
        8,
        4
    ]
    skip_header_lines = 4
    pattern, data = write_pattern(line_counts)

    data = self._remove_lines(data, line_counts, skip_header_lines)
    read_data = self._read_skip_header_lines(pattern, skip_header_lines)
    self.assertEqual(len(data), len(read_data))
    self.assertCountEqual(data, read_data)

  def test_read_gzip_with_skip_lines(self):
    _, lines = write_data(15)
    with TempDir() as tempdir:
      file_name = tempdir.create_temp_file()
      with gzip.GzipFile(file_name, 'wb') as f:
        f.write('\n'.join(lines).encode('utf-8'))

      with TestPipeline() as pipeline:
        pcoll = pipeline | 'Read' >> ReadFromText(
            file_name,
            0,
            CompressionTypes.GZIP,
            True,
            coders.StrUtf8Coder(),
            skip_header_lines=2)
        assert_that(pcoll, equal_to(lines[2:]))

  def test_read_after_splitting_skip_header(self):
    file_name, expected_data = write_data(100)
    assert len(expected_data) == 100
    source = TextSource(
        file_name,
        0,
        CompressionTypes.UNCOMPRESSED,
        True,
        coders.StrUtf8Coder(),
        skip_header_lines=2)
    splits = list(source.split(desired_bundle_size=33))

    reference_source_info = (source, None, None)
    sources_info = ([(split.source, split.start_position, split.stop_position)
                     for split in splits])
    self.assertGreater(len(sources_info), 1)
    reference_lines = source_test_utils.read_from_source(*reference_source_info)
    split_lines = []
    for source_info in sources_info:
      split_lines.extend(source_test_utils.read_from_source(*source_info))

    self.assertEqual(expected_data[2:], reference_lines)
    self.assertEqual(reference_lines, split_lines)

  def test_custom_delimiter_read_from_text(self):
    file_name, expected_data = write_data(
        5, eol=EOL.CUSTOM_DELIMITER, custom_delimiter=b'@#')
    assert len(expected_data) == 5
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Read' >> ReadFromText(file_name, delimiter=b'@#')
      assert_that(pcoll, equal_to(expected_data))

  def test_custom_delimiter_read_all_single_file(self):
    file_name, expected_data = write_data(
        5, eol=EOL.CUSTOM_DELIMITER, custom_delimiter=b'@#')
    assert len(expected_data) == 5
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Create' >> Create(
          [file_name]) | 'ReadAll' >> ReadAllFromText(delimiter=b'@#')
      assert_that(pcoll, equal_to(expected_data))

  def test_invalid_delimiters_are_rejected(self):
    file_name, _ = write_data(1)
    for delimiter in (b'', '', '\r\n', 'a', 1):
      with self.assertRaises(
          ValueError, msg='Delimiter must be a non-empty bytes sequence.'):
        _ = TextSource(
            file_pattern=file_name,
            min_bundle_size=0,
            buffer_size=6,
            compression_type=CompressionTypes.UNCOMPRESSED,
            strip_trailing_newlines=True,
            coder=coders.StrUtf8Coder(),
            delimiter=delimiter,
        )
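
  # A delimiter "self-overlaps" when one of its proper prefixes is also a
  # suffix (e.g. b'||', b'aba', b'abcab'); TextSource rejects such delimiters.
  # A minimal sketch of the property the next two tests rely on (the helper
  # name is hypothetical, not part of textio):
  #
  #   def self_overlaps(delim):
  #     return any(delim[:i] == delim[-i:] for i in range(1, len(delim)))
  #
  #   self_overlaps(b'aba')     # True:  prefix b'a' is also the suffix
  #   self_overlaps(b'cabdab')  # False: no proper prefix equals a suffix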

  def test_non_self_overlapping_delimiter_is_accepted(self):
    file_name, _ = write_data(1)
    for delimiter in (b'\n', b'\r\n', b'*', b'abc', b'cabdab', b'abcabd'):
      _ = TextSource(
          file_pattern=file_name,
          min_bundle_size=0,
          buffer_size=6,
          compression_type=CompressionTypes.UNCOMPRESSED,
          strip_trailing_newlines=True,
          coder=coders.StrUtf8Coder(),
          delimiter=delimiter,
      )

  def test_self_overlapping_delimiter_is_rejected(self):
    file_name, _ = write_data(1)
    for delimiter in (b'||', b'***', b'aba', b'abcab'):
      with self.assertRaises(ValueError,
                             msg='Delimiter must not self-overlap.'):
        _ = TextSource(
            file_pattern=file_name,
            min_bundle_size=0,
            buffer_size=6,
            compression_type=CompressionTypes.UNCOMPRESSED,
            strip_trailing_newlines=True,
            coder=coders.StrUtf8Coder(),
            delimiter=delimiter,
        )

  def test_read_with_customer_delimiter(self):
    delimiters = [
        b'\n',
        b'\r\n',
        b'*|',
        b'*',
        b'*=-',
    ]

    for delimiter in delimiters:
      file_name, expected_data = write_data(
          10,
          eol=EOL.CUSTOM_DELIMITER,
          custom_delimiter=delimiter)

      assert len(expected_data) == 10
      source = TextSource(
          file_pattern=file_name,
          min_bundle_size=0,
          compression_type=CompressionTypes.UNCOMPRESSED,
          strip_trailing_newlines=True,
          coder=coders.StrUtf8Coder(),
          delimiter=delimiter)
      range_tracker = source.get_range_tracker(None, None)
      read_data = list(source.read(range_tracker))

      self.assertEqual(read_data, expected_data)

  def test_read_with_custom_delimiter_around_split_point(self):
    for delimiter in (b'\n', b'\r\n', b'@#', b'abc'):
      file_name, expected_data = write_data(
          20,
          eol=EOL.CUSTOM_DELIMITER,
          custom_delimiter=delimiter)
      assert len(expected_data) == 20
      for desired_bundle_size in (4, 5, 6, 7):
        source = TextSource(
            file_name,
            0,
            CompressionTypes.UNCOMPRESSED,
            True,
            coders.StrUtf8Coder(),
            delimiter=delimiter)
        splits = list(source.split(desired_bundle_size=desired_bundle_size))

        reference_source_info = (source, None, None)
        sources_info = ([
            (split.source, split.start_position, split.stop_position)
            for split in splits
        ])
        source_test_utils.assert_sources_equal_reference_source(
            reference_source_info, sources_info)

  def test_read_with_customer_delimiter_truncated(self):
    """
    Corner case: the delimiter is truncated at the end of the read buffer.
    Use a delimiter of length 3, buffer_size = 6 and a line_value of length 4
    so that the delimiter is split across buffer boundaries.
    """
    delimiter = b'@$*'

    file_name, expected_data = write_data(
        10,
        eol=EOL.CUSTOM_DELIMITER,
        line_value=b'a' * 4,
        custom_delimiter=delimiter)

    assert len(expected_data) == 10
    source = TextSource(
        file_pattern=file_name,
        min_bundle_size=0,
        buffer_size=6,
        compression_type=CompressionTypes.UNCOMPRESSED,
        strip_trailing_newlines=True,
        coder=coders.StrUtf8Coder(),
        delimiter=delimiter,
    )
    range_tracker = source.get_range_tracker(None, None)
    read_data = list(source.read(range_tracker))

    self.assertEqual(read_data, expected_data)

  def test_read_with_customer_delimiter_over_buffer_size(self):
    """
    Corner case: the delimiter straddles the buffer boundary.
    """
    file_name, expected_data = write_data(3, eol=EOL.CRLF, line_value=b'\rline')
    assert len(expected_data) == 3
    self._run_read_test(
        file_name, expected_data, buffer_size=7, delimiter=b'\r\n')

  def test_read_with_customer_delimiter_truncated_and_not_equal(self):
    """
    Corner case: the delimiter is truncated at the end of the file and only
    part of the delimiter matches at the end of the buffer.

    Use a delimiter of length 3, buffer_size = 6 and a line_value of length 4
    so that the delimiter is split across buffer boundaries.
    """

    write_delimiter = b'@$'
    read_delimiter = b'@$*'

    file_name, expected_data = write_data(
        10,
        eol=EOL.CUSTOM_DELIMITER,
        line_value=b'a' * 4,
        custom_delimiter=write_delimiter)

    # In this case, check that the line won't be split.
    write_delimiter_encode = write_delimiter.decode('utf-8')
    expected_data_str = [
        write_delimiter_encode.join(expected_data) + write_delimiter_encode
    ]

    source = TextSource(
        file_pattern=file_name,
        min_bundle_size=0,
        buffer_size=6,
        compression_type=CompressionTypes.UNCOMPRESSED,
        strip_trailing_newlines=True,
        coder=coders.StrUtf8Coder(),
        delimiter=read_delimiter,
    )
    range_tracker = source.get_range_tracker(None, None)

    read_data = list(source.read(range_tracker))

    self.assertEqual(read_data, expected_data_str)

  def test_read_crlf_split_by_buffer(self):
    file_name, expected_data = write_data(3, eol=EOL.CRLF)
    assert len(expected_data) == 3
    self._run_read_test(file_name, expected_data, buffer_size=6)

  def test_read_escaped_lf(self):
    file_name, expected_data = write_data(
        self.DEFAULT_NUM_RECORDS, eol=EOL.LF, line_value=b'li\\\nne')
    assert len(expected_data) == self.DEFAULT_NUM_RECORDS
    self._run_read_test(file_name, expected_data, escapechar=b'\\')

  def test_read_escaped_crlf(self):
    file_name, expected_data = write_data(
        TextSource.DEFAULT_READ_BUFFER_SIZE,
        eol=EOL.CRLF,
        line_value=b'li\\\r\\\nne')
    assert len(expected_data) == TextSource.DEFAULT_READ_BUFFER_SIZE
    self._run_read_test(file_name, expected_data, escapechar=b'\\')

  def test_read_escaped_cr_before_not_escaped_lf(self):
    file_name, expected_data_temp = write_data(
        self.DEFAULT_NUM_RECORDS, eol=EOL.CRLF, line_value=b'li\\\r\nne')
    expected_data = []
    for line in expected_data_temp:
      expected_data += line.split("\n")
    assert len(expected_data) == self.DEFAULT_NUM_RECORDS * 2
    self._run_read_test(file_name, expected_data, escapechar=b'\\')

  def test_read_escaped_custom_delimiter_crlf(self):
    file_name, expected_data = write_data(
        self.DEFAULT_NUM_RECORDS, eol=EOL.CRLF, line_value=b'li\\\r\nne')
    assert len(expected_data) == self.DEFAULT_NUM_RECORDS
    self._run_read_test(
        file_name, expected_data, delimiter=b'\r\n', escapechar=b'\\')

  def test_read_escaped_custom_delimiter(self):
    file_name, expected_data = write_data(
        TextSource.DEFAULT_READ_BUFFER_SIZE,
        eol=EOL.CUSTOM_DELIMITER,
        custom_delimiter=b'*|',
        line_value=b'li\\*|ne')
    assert len(expected_data) == TextSource.DEFAULT_READ_BUFFER_SIZE
    self._run_read_test(
        file_name, expected_data, delimiter=b'*|', escapechar=b'\\')

  def test_read_escaped_lf_at_buffer_edge(self):
    file_name, expected_data = write_data(3, eol=EOL.LF, line_value=b'line\\\n')
    assert len(expected_data) == 3
    self._run_read_test(
        file_name, expected_data, buffer_size=5, escapechar=b'\\')

  def test_read_escaped_crlf_split_by_buffer(self):
    file_name, expected_data = write_data(
        3, eol=EOL.CRLF, line_value=b'line\\\r\n')
    assert len(expected_data) == 3
    self._run_read_test(
        file_name,
        expected_data,
        buffer_size=6,
        delimiter=b'\r\n',
        escapechar=b'\\')

  def test_read_escaped_lf_after_splitting(self):
    file_name, expected_data = write_data(3, line_value=b'line\\\n')
    assert len(expected_data) == 3
    source = TextSource(
        file_name,
        0,
        CompressionTypes.UNCOMPRESSED,
        True,
        coders.StrUtf8Coder(),
        escapechar=b'\\')
    splits = list(source.split(desired_bundle_size=6))

    reference_source_info = (source, None, None)
    sources_info = ([(split.source, split.start_position, split.stop_position)
                     for split in splits])
    source_test_utils.assert_sources_equal_reference_source(
        reference_source_info, sources_info)

  def test_read_escaped_lf_after_splitting_many(self):
    file_name, expected_data = write_data(
        3, line_value=b'\\\\\\\\\\\n')  # 5 escapes
    assert len(expected_data) == 3
    source = TextSource(
        file_name,
        0,
        CompressionTypes.UNCOMPRESSED,
        True,
        coders.StrUtf8Coder(),
        escapechar=b'\\')
    splits = list(source.split(desired_bundle_size=6))

    reference_source_info = (source, None, None)
    sources_info = ([(split.source, split.start_position, split.stop_position)
                     for split in splits])
    source_test_utils.assert_sources_equal_reference_source(
        reference_source_info, sources_info)

  def test_read_escaped_escapechar_after_splitting(self):
    file_name, expected_data = write_data(3, line_value=b'line\\\\*|')
    assert len(expected_data) == 3
    source = TextSource(
        file_name,
        0,
        CompressionTypes.UNCOMPRESSED,
        True,
        coders.StrUtf8Coder(),
        delimiter=b'*|',
        escapechar=b'\\')
    splits = list(source.split(desired_bundle_size=8))

    reference_source_info = (source, None, None)
    sources_info = ([(split.source, split.start_position, split.stop_position)
                     for split in splits])
    source_test_utils.assert_sources_equal_reference_source(
        reference_source_info, sources_info)

  def test_read_escaped_escapechar_after_splitting_many(self):
    file_name, expected_data = write_data(
        3, line_value=b'\\\\\\\\\\\\*|')  # 6 escapes
    assert len(expected_data) == 3
    source = TextSource(
        file_name,
        0,
        CompressionTypes.UNCOMPRESSED,
        True,
        coders.StrUtf8Coder(),
        delimiter=b'*|',
        escapechar=b'\\')
    splits = list(source.split(desired_bundle_size=8))

    reference_source_info = (source, None, None)
    sources_info = ([(split.source, split.start_position, split.stop_position)
                     for split in splits])
    source_test_utils.assert_sources_equal_reference_source(
        reference_source_info, sources_info)


class TextSinkTest(unittest.TestCase):
  def setUp(self):
    super().setUp()
    self.lines = [b'Line %d' % d for d in range(100)]
    self.tempdir = tempfile.mkdtemp()
    self.path = self._create_temp_file()

  def tearDown(self):
    if os.path.exists(self.tempdir):
      shutil.rmtree(self.tempdir)

  def _create_temp_file(self, name='', suffix=''):
    if not name:
      name = tempfile.template
    file_name = tempfile.NamedTemporaryFile(
        delete=True, prefix=name, dir=self.tempdir, suffix=suffix).name
    return file_name

  def _write_lines(self, sink, lines):
    f = sink.open(self.path)
    for line in lines:
      sink.write_record(f, line)
    sink.close(f)

  def test_write_text_file(self):
    sink = TextSink(self.path)
    self._write_lines(sink, self.lines)

    with open(self.path, 'rb') as f:
      self.assertEqual(f.read().splitlines(), self.lines)

  def test_write_text_file_empty(self):
    sink = TextSink(self.path)
    self._write_lines(sink, [])

    with open(self.path, 'rb') as f:
      self.assertEqual(f.read().splitlines(), [])

  def test_write_bzip2_file(self):
    sink = TextSink(self.path, compression_type=CompressionTypes.BZIP2)
    self._write_lines(sink, self.lines)

    with bz2.BZ2File(self.path, 'rb') as f:
      self.assertEqual(f.read().splitlines(), self.lines)

  def test_write_bzip2_file_auto(self):
    self.path = self._create_temp_file(suffix='.bz2')
    sink = TextSink(self.path)
    self._write_lines(sink, self.lines)

    with bz2.BZ2File(self.path, 'rb') as f:
      self.assertEqual(f.read().splitlines(), self.lines)

  def test_write_gzip_file(self):
    sink = TextSink(self.path, compression_type=CompressionTypes.GZIP)
    self._write_lines(sink, self.lines)

    with gzip.GzipFile(self.path, 'rb') as f:
      self.assertEqual(f.read().splitlines(), self.lines)

  def test_write_gzip_file_auto(self):
    self.path = self._create_temp_file(suffix='.gz')
    sink = TextSink(self.path)
    self._write_lines(sink, self.lines)

    with gzip.GzipFile(self.path, 'rb') as f:
      self.assertEqual(f.read().splitlines(), self.lines)

  def test_write_gzip_file_empty(self):
    sink = TextSink(self.path, compression_type=CompressionTypes.GZIP)
    self._write_lines(sink, [])

    with gzip.GzipFile(self.path, 'rb') as f:
      self.assertEqual(f.read().splitlines(), [])

  def test_write_deflate_file(self):
    sink = TextSink(self.path, compression_type=CompressionTypes.DEFLATE)
    self._write_lines(sink, self.lines)

    with open(self.path, 'rb') as f:
      self.assertEqual(zlib.decompress(f.read()).splitlines(), self.lines)

  def test_write_deflate_file_auto(self):
    self.path = self._create_temp_file(suffix='.deflate')
    sink = TextSink(self.path)
    self._write_lines(sink, self.lines)

    with open(self.path, 'rb') as f:
      self.assertEqual(zlib.decompress(f.read()).splitlines(), self.lines)

  def test_write_deflate_file_empty(self):
    sink = TextSink(self.path, compression_type=CompressionTypes.DEFLATE)
    self._write_lines(sink, [])

    with open(self.path, 'rb') as f:
      self.assertEqual(zlib.decompress(f.read()).splitlines(), [])

  def test_write_text_file_with_header(self):
    header = b'header1\nheader2'
    sink = TextSink(self.path, header=header)
    self._write_lines(sink, self.lines)

    with open(self.path, 'rb') as f:
      self.assertEqual(f.read().splitlines(), header.splitlines() + self.lines)

  def test_write_text_file_with_footer(self):
    footer = b'footer1\nfooter2'
    sink = TextSink(self.path, footer=footer)
    self._write_lines(sink, self.lines)

    with open(self.path, 'rb') as f:
      self.assertEqual(f.read().splitlines(), self.lines + footer.splitlines())

  def test_write_text_file_empty_with_header(self):
    header = b'header1\nheader2'
    sink = TextSink(self.path, header=header)
    self._write_lines(sink, [])

    with open(self.path, 'rb') as f:
      self.assertEqual(f.read().splitlines(), header.splitlines())

  def test_write_pipeline(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | beam.core.Create(self.lines)
      pcoll | 'Write' >> WriteToText(self.path)  # pylint: disable=expression-not-assigned

    read_result = []
    for file_name in glob.glob(self.path + '*'):
      with open(file_name, 'rb') as f:
        read_result.extend(f.read().splitlines())

    self.assertEqual(sorted(read_result), sorted(self.lines))

  def test_write_pipeline_non_globalwindow_input(self):
    with TestPipeline() as p:
      _ = (
          p
          | beam.core.Create(self.lines)
          | beam.WindowInto(beam.transforms.window.FixedWindows(1))
          | 'Write' >> WriteToText(self.path))

    read_result = []
    for file_name in glob.glob(self.path + '*'):
      with open(file_name, 'rb') as f:
        read_result.extend(f.read().splitlines())

    self.assertEqual(sorted(read_result), sorted(self.lines))

  def test_write_pipeline_auto_compression(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | beam.core.Create(self.lines)
      pcoll | 'Write' >> WriteToText(self.path, file_name_suffix='.gz')  # pylint: disable=expression-not-assigned

    read_result = []
    for file_name in glob.glob(self.path + '*'):
      with gzip.GzipFile(file_name, 'rb') as f:
        read_result.extend(f.read().splitlines())

    self.assertEqual(sorted(read_result), sorted(self.lines))

  def test_write_pipeline_auto_compression_unsharded(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Create' >> beam.core.Create(self.lines)
      pcoll | 'Write' >> WriteToText(  # pylint: disable=expression-not-assigned
          self.path + '.gz',
          shard_name_template='')

    read_result = []
    for file_name in glob.glob(self.path + '*'):
      with gzip.GzipFile(file_name, 'rb') as f:
        read_result.extend(f.read().splitlines())

    self.assertEqual(sorted(read_result), sorted(self.lines))

  def test_write_pipeline_header(self):
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Create' >> beam.core.Create(self.lines)
      header_text = 'foo'
      pcoll | 'Write' >> WriteToText(  # pylint: disable=expression-not-assigned
          self.path + '.gz',
          shard_name_template='',
          header=header_text)

    read_result = []
    for file_name in glob.glob(self.path + '*'):
      with gzip.GzipFile(file_name, 'rb') as f:
        read_result.extend(f.read().splitlines())
    # header_text is automatically encoded in WriteToText
    self.assertEqual(read_result[0], header_text.encode('utf-8'))
    self.assertEqual(sorted(read_result[1:]), sorted(self.lines))

  def test_write_pipeline_footer(self):
    with TestPipeline() as pipeline:
      footer_text = 'footer'
      pcoll = pipeline | beam.core.Create(self.lines)
      pcoll | 'Write' >> WriteToText(  # pylint: disable=expression-not-assigned
          self.path,
          footer=footer_text)

    read_result = []
    for file_name in glob.glob(self.path + '*'):
      with open(file_name, 'rb') as f:
        read_result.extend(f.read().splitlines())

    self.assertEqual(sorted(read_result[:-1]), sorted(self.lines))
    self.assertEqual(read_result[-1], footer_text.encode('utf-8'))

  def test_write_empty(self):
    with TestPipeline() as p:
      # pylint: disable=expression-not-assigned
      p | beam.core.Create([]) | WriteToText(self.path)

    outputs = glob.glob(self.path + '*')
    self.assertEqual(len(outputs), 1)
    with open(outputs[0], 'rb') as f:
      self.assertEqual(list(f.read().splitlines()), [])

  def test_write_empty_skipped(self):
    with TestPipeline() as p:
      # pylint: disable=expression-not-assigned
      p | beam.core.Create([]) | WriteToText(self.path, skip_if_empty=True)

    outputs = list(glob.glob(self.path + '*'))
    self.assertEqual(outputs, [])

  def test_write_max_records_per_shard(self):
    records_per_shard = 13
    lines = [str(i).encode('utf-8') for i in range(100)]
    with TestPipeline() as p:
      # pylint: disable=expression-not-assigned
      p | beam.core.Create(lines) | WriteToText(
          self.path, max_records_per_shard=records_per_shard)

    read_result = []
    for file_name in glob.glob(self.path + '*'):
      with open(file_name, 'rb') as f:
        shard_lines = list(f.read().splitlines())
        self.assertLessEqual(len(shard_lines), records_per_shard)
        read_result.extend(shard_lines)
    self.assertEqual(sorted(read_result), sorted(lines))

  def test_write_max_bytes_per_shard(self):
    bytes_per_shard = 300
    max_len = 100
    lines = [b'x' * i for i in range(max_len)]
    header = b'a' * 20
    footer = b'b' * 30
    with TestPipeline() as p:
      # pylint: disable=expression-not-assigned
      p | beam.core.Create(lines) | WriteToText(
          self.path,
          header=header,
          footer=footer,
          max_bytes_per_shard=bytes_per_shard)

    read_result = []
    for file_name in glob.glob(self.path + '*'):
      with open(file_name, 'rb') as f:
        contents = f.read()
        self.assertLessEqual(
            len(contents), bytes_per_shard + max_len + len(footer) + 2)
        shard_lines = list(contents.splitlines())
        self.assertEqual(shard_lines[0], header)
        self.assertEqual(shard_lines[-1], footer)
        read_result.extend(shard_lines[1:-1])
    self.assertEqual(sorted(read_result), sorted(lines))


class CsvTest(unittest.TestCase):
  def test_csv_read_write(self):
    records = [beam.Row(a='str', b=ix) for ix in range(3)]
    with tempfile.TemporaryDirectory() as dest:
      with TestPipeline() as p:
        # pylint: disable=expression-not-assigned
        p | beam.Create(records) | beam.io.WriteToCsv(os.path.join(dest, 'out'))
      with TestPipeline() as p:
        pcoll = (
            p
            | beam.io.ReadFromCsv(os.path.join(dest, 'out*'))
            | beam.Map(lambda t: beam.Row(**dict(zip(type(t)._fields, t)))))

        assert_that(pcoll, equal_to(records))


class JsonTest(unittest.TestCase):
  def test_json_read_write(self):
    records = [beam.Row(a='str', b=ix) for ix in range(3)]
    with tempfile.TemporaryDirectory() as dest:
      with TestPipeline() as p:
        # pylint: disable=expression-not-assigned
        p | beam.Create(records) | beam.io.WriteToJson(
            os.path.join(dest, 'out'))
      with TestPipeline() as p:
        pcoll = (
            p
            | beam.io.ReadFromJson(os.path.join(dest, 'out*'))
            | beam.Map(lambda t: beam.Row(**dict(zip(type(t)._fields, t)))))

        assert_that(pcoll, equal_to(records))


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()