# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# pytype: skip-file

import bz2
import gzip
import io
import logging
import math
import os
import random
import tempfile
import unittest

import hamcrest as hc

import apache_beam as beam
from apache_beam.io import filebasedsource
from apache_beam.io import iobase
from apache_beam.io import range_trackers
# importing following private classes for testing
from apache_beam.io.concat_source import ConcatSource
from apache_beam.io.filebasedsource import _SingleFileSource as SingleFileSource
from apache_beam.io.filebasedsource import FileBasedSource
from apache_beam.io.filesystem import CompressionTypes
from apache_beam.options.value_provider import RuntimeValueProvider
from apache_beam.options.value_provider import StaticValueProvider
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms.display import DisplayData
from apache_beam.transforms.display_test import DisplayDataItemMatcher


class LineSource(FileBasedSource):
  def read_records(self, file_name, range_tracker):
    f = self.open_file(file_name)
    try:
      start = range_tracker.start_position()
      if start > 0:
        # Any line that starts after 'start' does not belong to the current
        # bundle. Seeking to (start - 1) and skipping a line moves the current
        # position to the starting position of the first line that belongs to
        # the current bundle.
        start -= 1
        f.seek(start)
        line = f.readline()
        start += len(line)
      current = start
      line = f.readline()
      while range_tracker.try_claim(current):
        # When the source is unsplittable, try_claim is not enough to determine
        # whether the file has reached the end.
        if not line:
          return
        yield line.rstrip(b'\n')
        current += len(line)
        line = f.readline()
    finally:
      f.close()


class EOL(object):
  LF = 1
  CRLF = 2
  MIXED = 3
  LF_WITH_NOTHING_AT_LAST_LINE = 4


def write_data(
    num_lines,
    no_data=False,
    directory=None,
    prefix=tempfile.template,
    eol=EOL.LF):
  """Writes test data to a temporary file.

  Args:
    num_lines (int): The number of lines to write.
    no_data (bool): If :data:`True`, empty lines will be written, otherwise
      each line will contain a concatenation of b'line' and the line number.
    directory (str): The name of the directory to create the temporary file in.
    prefix (str): The prefix to use for the temporary file.
    eol (int): The line ending to use when writing.
      :class:`~apache_beam.io.filebasedsource_test.EOL` exposes attributes that
      can be used here to define the eol.

  Returns:
    Tuple[str, List[bytes]]: A tuple of the filename and a list of the written
    data.
  """
  all_data = []
  with tempfile.NamedTemporaryFile(delete=False, dir=directory,
                                   prefix=prefix) as f:
    sep_values = [b'\n', b'\r\n']
    for i in range(num_lines):
      data = b'' if no_data else b'line' + str(i).encode()
      all_data.append(data)

      if eol == EOL.LF:
        sep = sep_values[0]
      elif eol == EOL.CRLF:
        sep = sep_values[1]
      elif eol == EOL.MIXED:
        sep = sep_values[i % len(sep_values)]
      elif eol == EOL.LF_WITH_NOTHING_AT_LAST_LINE:
        sep = b'' if i == (num_lines - 1) else sep_values[0]
      else:
        raise ValueError('Received unknown value %s for eol.' % eol)

      f.write(data + sep)

  return f.name, all_data


def _write_prepared_data(
    data, directory=None, prefix=tempfile.template, suffix=''):
  with tempfile.NamedTemporaryFile(delete=False,
                                   dir=directory,
                                   prefix=prefix,
                                   suffix=suffix) as f:
    f.write(data)
  return f.name


def write_prepared_pattern(data, suffixes=None):
  assert data, 'Data (%s) seems to be empty' % data
  if suffixes is None:
    suffixes = [''] * len(data)
  temp_dir = tempfile.mkdtemp()
  for i, d in enumerate(data):
    file_name = _write_prepared_data(
        d, temp_dir, prefix='mytemp', suffix=suffixes[i])
  return file_name[:file_name.rfind(os.path.sep)] + os.path.sep + 'mytemp*'


def write_pattern(lines_per_file, no_data=False):
  """Writes a pattern of temporary files.

  Args:
    lines_per_file (List[int]): The number of lines to write per file.
    no_data (bool): If :data:`True`, empty lines will be written, otherwise
      each line will contain a concatenation of b'line' and the line number.

  Returns:
    Tuple[str, List[bytes]]: A tuple of the filename pattern and a list of the
    written data.
  """
  temp_dir = tempfile.mkdtemp()

  all_data = []
  file_name = None
  start_index = 0
  for i in range(len(lines_per_file)):
    file_name, data = write_data(lines_per_file[i], no_data=no_data,
                                 directory=temp_dir, prefix='mytemp')
    all_data.extend(data)
    start_index += lines_per_file[i]

  assert file_name
  return (
      file_name[:file_name.rfind(os.path.sep)] + os.path.sep + 'mytemp*',
      all_data)


class TestConcatSource(unittest.TestCase):
  class DummySource(iobase.BoundedSource):
    def __init__(self, values):
      self._values = values

    def split(
        self, desired_bundle_size, start_position=None, stop_position=None):
      # simply divides values into two bundles
      middle = len(self._values) // 2
      yield iobase.SourceBundle(
          0.5, TestConcatSource.DummySource(self._values[:middle]), None, None)
      yield iobase.SourceBundle(
          0.5, TestConcatSource.DummySource(self._values[middle:]), None, None)

    def get_range_tracker(self, start_position, stop_position):
      if start_position is None:
        start_position = 0
      if stop_position is None:
        stop_position = len(self._values)

      return range_trackers.OffsetRangeTracker(start_position, stop_position)

    def read(self, range_tracker):
      for index, value in enumerate(self._values):
        if not range_tracker.try_claim(index):
          return

        yield value

    def estimate_size(self):
      return len(self._values)  # Assuming each value to be 1 byte.
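
  # The tests below construct a ConcatSource from three DummySource instances
  # covering the values [0, 10), [10, 20) and [20, 30).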

  def setUp(self):
    # Reducing the size of thread pools. Without this, test execution may fail
    # in environments with a limited amount of resources.
    filebasedsource.MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 2

  def test_read(self):
    sources = [
        TestConcatSource.DummySource(range(start, start + 10))
        for start in [0, 10, 20]
    ]
    concat = ConcatSource(sources)
    range_tracker = concat.get_range_tracker(None, None)
    read_data = [value for value in concat.read(range_tracker)]
    self.assertCountEqual(list(range(30)), read_data)

  def test_split(self):
    sources = [
        TestConcatSource.DummySource(list(range(start, start + 10)))
        for start in [0, 10, 20]
    ]
    concat = ConcatSource(sources)
    splits = [split for split in concat.split()]
    self.assertEqual(6, len(splits))

    # Reading all splits
    read_data = []
    for split in splits:
      range_tracker_for_split = split.source.get_range_tracker(
          split.start_position, split.stop_position)
      read_data.extend(
          [value for value in split.source.read(range_tracker_for_split)])
    self.assertCountEqual(list(range(30)), read_data)

  def test_estimate_size(self):
    sources = [
        TestConcatSource.DummySource(range(start, start + 10))
        for start in [0, 10, 20]
    ]
    concat = ConcatSource(sources)
    self.assertEqual(30, concat.estimate_size())


class TestFileBasedSource(unittest.TestCase):
  def setUp(self):
    # Reducing the size of thread pools. Without this, test execution may fail
    # in environments with a limited amount of resources.
    filebasedsource.MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 2

  def test_string_or_value_provider_only(self):
    str_file_pattern = tempfile.NamedTemporaryFile(delete=False).name
    self.assertEqual(
        str_file_pattern, FileBasedSource(str_file_pattern)._pattern.value)

    static_vp_file_pattern = StaticValueProvider(
        value_type=str, value=str_file_pattern)
    self.assertEqual(
        static_vp_file_pattern,
        FileBasedSource(static_vp_file_pattern)._pattern)

    runtime_vp_file_pattern = RuntimeValueProvider(
        option_name='arg', value_type=str, default_value=str_file_pattern)
    self.assertEqual(
        runtime_vp_file_pattern,
        FileBasedSource(runtime_vp_file_pattern)._pattern)
    # Reset runtime options to avoid side-effects in other tests.
    RuntimeValueProvider.set_runtime_options(None)

    invalid_file_pattern = 123
    with self.assertRaises(TypeError):
      FileBasedSource(invalid_file_pattern)

  def test_validation_file_exists(self):
    file_name, _ = write_data(10)
    LineSource(file_name)

  def test_validation_directory_non_empty(self):
    temp_dir = tempfile.mkdtemp()
    file_name, _ = write_data(10, directory=temp_dir)
    LineSource(file_name)

  def test_validation_failing(self):
    no_files_found_error = 'No files found based on the file pattern*'
    with self.assertRaisesRegex(IOError, no_files_found_error):
      LineSource('dummy_pattern')
    with self.assertRaisesRegex(IOError, no_files_found_error):
      temp_dir = tempfile.mkdtemp()
      LineSource(os.path.join(temp_dir, '*'))

  def test_validation_file_missing_verification_disabled(self):
    LineSource('dummy_pattern', validate=False)

  def test_fully_read_single_file(self):
    file_name, expected_data = write_data(10)
    assert len(expected_data) == 10
    fbs = LineSource(file_name)
    range_tracker = fbs.get_range_tracker(None, None)
    read_data = [record for record in fbs.read(range_tracker)]
    self.assertCountEqual(expected_data, read_data)

  def test_single_file_display_data(self):
    file_name, _ = write_data(10)
    fbs = LineSource(file_name)
    dd = DisplayData.create_from(fbs)
    expected_items = [
        DisplayDataItemMatcher('file_pattern', file_name),
        DisplayDataItemMatcher('compression', 'auto')
    ]
    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

  def test_fully_read_file_pattern(self):
    pattern, expected_data = write_pattern([5, 3, 12, 8, 8, 4])
    assert len(expected_data) == 40
    fbs = LineSource(pattern)
    range_tracker = fbs.get_range_tracker(None, None)
    read_data = [record for record in fbs.read(range_tracker)]
    self.assertCountEqual(expected_data, read_data)

  def test_fully_read_file_pattern_with_empty_files(self):
    pattern, expected_data = write_pattern([5, 0, 12, 0, 8, 0])
    assert len(expected_data) == 25
    fbs = LineSource(pattern)
    range_tracker = fbs.get_range_tracker(None, None)
    read_data = [record for record in fbs.read(range_tracker)]
    self.assertCountEqual(expected_data, read_data)

  def test_estimate_size_of_file(self):
    file_name, expected_data = write_data(10)
    assert len(expected_data) == 10
    fbs = LineSource(file_name)
    self.assertEqual(10 * 6, fbs.estimate_size())

  def test_estimate_size_of_pattern(self):
    pattern, expected_data = write_pattern([5, 3, 10, 8, 8, 4])
    assert len(expected_data) == 38
    fbs = LineSource(pattern)
    self.assertEqual(38 * 6, fbs.estimate_size())

    pattern, expected_data = write_pattern([5, 3, 9])
    assert len(expected_data) == 17
    fbs = LineSource(pattern)
    self.assertEqual(17 * 6, fbs.estimate_size())

  def test_estimate_size_with_sampling_same_size(self):
    num_files = 2 * FileBasedSource.MIN_NUMBER_OF_FILES_TO_STAT
    pattern, _ = write_pattern([10] * num_files)
    # Each line will be of length 6 since write_pattern() uses
    # ('line' + line number + '\n') as data.
    self.assertEqual(
        6 * 10 * num_files, FileBasedSource(pattern).estimate_size())

  def test_estimate_size_with_sampling_different_sizes(self):
    num_files = 2 * FileBasedSource.MIN_NUMBER_OF_FILES_TO_STAT

    # Each line will be of length 8 since write_pattern() uses
    # ('line' + line number + '\n') as data.
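    # (With roughly 500 lines per file, most line numbers are three digits,
    # so a typical line is 4 + 3 + 1 = 8 bytes.)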
    base_size = 500
    variance = 5

    sizes = []
    for _ in range(num_files):
      sizes.append(
          int(random.uniform(base_size - variance, base_size + variance)))
    pattern, _ = write_pattern(sizes)
    tolerance = 0.05
    self.assertAlmostEqual(
        base_size * 8 * num_files,
        FileBasedSource(pattern).estimate_size(),
        delta=base_size * 8 * num_files * tolerance)

  def test_splits_into_subranges(self):
    pattern, expected_data = write_pattern([5, 9, 6])
    assert len(expected_data) == 20
    fbs = LineSource(pattern)
    splits = [split for split in fbs.split(desired_bundle_size=15)]
    expected_num_splits = (
        math.ceil(float(6 * 5) / 15) + math.ceil(float(6 * 9) / 15) +
        math.ceil(float(6 * 6) / 15))
    assert len(splits) == expected_num_splits

  def test_read_splits_single_file(self):
    file_name, expected_data = write_data(100)
    assert len(expected_data) == 100
    fbs = LineSource(file_name)
    splits = [split for split in fbs.split(desired_bundle_size=33)]

    # Reading all splits
    read_data = []
    for split in splits:
      source = split.source
      range_tracker = source.get_range_tracker(
          split.start_position, split.stop_position)
      data_from_split = [data for data in source.read(range_tracker)]
      read_data.extend(data_from_split)

    self.assertCountEqual(expected_data, read_data)

  def test_read_splits_file_pattern(self):
    pattern, expected_data = write_pattern([34, 66, 40, 24, 24, 12])
    assert len(expected_data) == 200
    fbs = LineSource(pattern)
    splits = [split for split in fbs.split(desired_bundle_size=50)]

    # Reading all splits
    read_data = []
    for split in splits:
      source = split.source
      range_tracker = source.get_range_tracker(
          split.start_position, split.stop_position)
      data_from_split = [data for data in source.read(range_tracker)]
      read_data.extend(data_from_split)

    self.assertCountEqual(expected_data, read_data)

  def _run_source_test(self, pattern, expected_data, splittable=True):
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Read' >> beam.io.Read(
          LineSource(pattern, splittable=splittable))
      assert_that(pcoll, equal_to(expected_data))

  def test_source_file(self):
    file_name, expected_data = write_data(100)
    assert len(expected_data) == 100
    self._run_source_test(file_name, expected_data)

  def test_source_pattern(self):
    pattern, expected_data = write_pattern([34, 66, 40, 24, 24, 12])
    assert len(expected_data) == 200
    self._run_source_test(pattern, expected_data)

  def test_unsplittable_does_not_split(self):
    pattern, expected_data = write_pattern([5, 9, 6])
    assert len(expected_data) == 20
    fbs = LineSource(pattern, splittable=False)
    splits = [split for split in fbs.split(desired_bundle_size=15)]
    self.assertEqual(3, len(splits))

  def test_source_file_unsplittable(self):
    file_name, expected_data = write_data(100)
    assert len(expected_data) == 100
    self._run_source_test(file_name, expected_data, False)

  def test_source_pattern_unsplittable(self):
    pattern, expected_data = write_pattern([34, 66, 40, 24, 24, 12])
    assert len(expected_data) == 200
    self._run_source_test(pattern, expected_data, False)

  def test_read_file_bzip2(self):
    _, lines = write_data(10)
    filename = tempfile.NamedTemporaryFile(
        delete=False, prefix=tempfile.template).name
    with bz2.BZ2File(filename, 'wb') as f:
      f.write(b'\n'.join(lines))
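
    # The compressed input is read as a whole, so the source below is marked
    # unsplittable and given an explicit BZIP2 compression type.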
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Read' >> beam.io.Read(
          LineSource(
              filename,
              splittable=False,
              compression_type=CompressionTypes.BZIP2))
      assert_that(pcoll, equal_to(lines))

  def test_read_file_gzip(self):
    _, lines = write_data(10)
    filename = tempfile.NamedTemporaryFile(
        delete=False, prefix=tempfile.template).name
    with gzip.GzipFile(filename, 'wb') as f:
      f.write(b'\n'.join(lines))

    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Read' >> beam.io.Read(
          LineSource(
              filename,
              splittable=False,
              compression_type=CompressionTypes.GZIP))
      assert_that(pcoll, equal_to(lines))

  def test_read_pattern_bzip2(self):
    _, lines = write_data(200)
    splits = [0, 34, 100, 140, 164, 188, 200]
    chunks = [lines[splits[i - 1]:splits[i]] for i in range(1, len(splits))]
    compressed_chunks = []
    for c in chunks:
      compressobj = bz2.BZ2Compressor()
      compressed_chunks.append(
          compressobj.compress(b'\n'.join(c)) + compressobj.flush())
    file_pattern = write_prepared_pattern(compressed_chunks)
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Read' >> beam.io.Read(
          LineSource(
              file_pattern,
              splittable=False,
              compression_type=CompressionTypes.BZIP2))
      assert_that(pcoll, equal_to(lines))

  def test_read_pattern_gzip(self):
    _, lines = write_data(200)
    splits = [0, 34, 100, 140, 164, 188, 200]
    chunks = [lines[splits[i - 1]:splits[i]] for i in range(1, len(splits))]
    compressed_chunks = []
    for c in chunks:
      out = io.BytesIO()
      with gzip.GzipFile(fileobj=out, mode="wb") as f:
        f.write(b'\n'.join(c))
      compressed_chunks.append(out.getvalue())
    file_pattern = write_prepared_pattern(compressed_chunks)
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Read' >> beam.io.Read(
          LineSource(
              file_pattern,
              splittable=False,
              compression_type=CompressionTypes.GZIP))
      assert_that(pcoll, equal_to(lines))

  def test_read_auto_single_file_bzip2(self):
    _, lines = write_data(10)
    filename = tempfile.NamedTemporaryFile(
        delete=False, prefix=tempfile.template, suffix='.bz2').name
    with bz2.BZ2File(filename, 'wb') as f:
      f.write(b'\n'.join(lines))

    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Read' >> beam.io.Read(
          LineSource(filename, compression_type=CompressionTypes.AUTO))
      assert_that(pcoll, equal_to(lines))

  def test_read_auto_single_file_gzip(self):
    _, lines = write_data(10)
    filename = tempfile.NamedTemporaryFile(
        delete=False, prefix=tempfile.template, suffix='.gz').name
    with gzip.GzipFile(filename, 'wb') as f:
      f.write(b'\n'.join(lines))

    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Read' >> beam.io.Read(
          LineSource(filename, compression_type=CompressionTypes.AUTO))
      assert_that(pcoll, equal_to(lines))

  def test_read_auto_pattern(self):
    _, lines = write_data(200)
    splits = [0, 34, 100, 140, 164, 188, 200]
    chunks = [lines[splits[i - 1]:splits[i]] for i in range(1, len(splits))]
    compressed_chunks = []
    for c in chunks:
      out = io.BytesIO()
      with gzip.GzipFile(fileobj=out, mode="wb") as f:
        f.write(b'\n'.join(c))
      compressed_chunks.append(out.getvalue())
    file_pattern = write_prepared_pattern(
        compressed_chunks, suffixes=['.gz'] * len(chunks))
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Read' >> beam.io.Read(
          LineSource(file_pattern,
                     compression_type=CompressionTypes.AUTO))
      assert_that(pcoll, equal_to(lines))

  def test_read_auto_pattern_compressed_and_uncompressed(self):
    _, lines = write_data(200)
    splits = [0, 34, 100, 140, 164, 188, 200]
    chunks = [lines[splits[i - 1]:splits[i]] for i in range(1, len(splits))]
    chunks_to_write = []
    for i, c in enumerate(chunks):
      if i % 2 == 0:
        out = io.BytesIO()
        with gzip.GzipFile(fileobj=out, mode="wb") as f:
          f.write(b'\n'.join(c))
        chunks_to_write.append(out.getvalue())
      else:
        chunks_to_write.append(b'\n'.join(c))
    file_pattern = write_prepared_pattern(
        chunks_to_write, suffixes=(['.gz', ''] * 3))
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Read' >> beam.io.Read(
          LineSource(file_pattern, compression_type=CompressionTypes.AUTO))
      assert_that(pcoll, equal_to(lines))

  def test_splits_get_coder_from_fbs(self):
    class DummyCoder(object):
      val = 12345

    class FileBasedSourceWithCoder(LineSource):
      def default_output_coder(self):
        return DummyCoder()

    pattern, expected_data = write_pattern([34, 66, 40, 24, 24, 12])
    self.assertEqual(200, len(expected_data))
    fbs = FileBasedSourceWithCoder(pattern)
    splits = [split for split in fbs.split(desired_bundle_size=50)]
    self.assertTrue(len(splits))
    for split in splits:
      self.assertEqual(DummyCoder.val, split.source.default_output_coder().val)


class TestSingleFileSource(unittest.TestCase):
  def setUp(self):
    # Reducing the size of thread pools. Without this, test execution may fail
    # in environments with a limited amount of resources.
    filebasedsource.MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 2

  def test_source_creation_fails_for_non_number_offsets(self):
    start_not_a_number_error = 'start_offset must be a number*'
    stop_not_a_number_error = 'stop_offset must be a number*'
    file_name = 'dummy_pattern'
    fbs = LineSource(file_name, validate=False)

    with self.assertRaisesRegex(TypeError, start_not_a_number_error):
      SingleFileSource(
          fbs, file_name='dummy_file', start_offset='aaa', stop_offset='bbb')
    with self.assertRaisesRegex(TypeError, start_not_a_number_error):
      SingleFileSource(
          fbs, file_name='dummy_file', start_offset='aaa', stop_offset=100)
    with self.assertRaisesRegex(TypeError, stop_not_a_number_error):
      SingleFileSource(
          fbs, file_name='dummy_file', start_offset=100, stop_offset='bbb')
    with self.assertRaisesRegex(TypeError, stop_not_a_number_error):
      SingleFileSource(
          fbs, file_name='dummy_file', start_offset=100, stop_offset=None)
    with self.assertRaisesRegex(TypeError, start_not_a_number_error):
      SingleFileSource(
          fbs, file_name='dummy_file', start_offset=None, stop_offset=100)

  def test_source_creation_display_data(self):
    file_name = 'dummy_pattern'
    fbs = LineSource(file_name, validate=False)
    dd = DisplayData.create_from(fbs)
    expected_items = [
        DisplayDataItemMatcher('compression', 'auto'),
        DisplayDataItemMatcher('file_pattern', file_name)
    ]
    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

  def test_source_creation_fails_if_start_lg_stop(self):
    start_larger_than_stop_error = (
        'start_offset must be smaller than stop_offset*')
    fbs = LineSource('dummy_pattern', validate=False)
    SingleFileSource(
        fbs, file_name='dummy_file', start_offset=99, stop_offset=100)
    with self.assertRaisesRegex(ValueError, start_larger_than_stop_error):
      SingleFileSource(
          fbs, file_name='dummy_file', start_offset=100, stop_offset=99)
    with self.assertRaisesRegex(ValueError, start_larger_than_stop_error):
      SingleFileSource(
          fbs, file_name='dummy_file', start_offset=100, stop_offset=100)

  def test_estimates_size(self):
    fbs = LineSource('dummy_pattern', validate=False)

    # Should simply return stop_offset - start_offset
    source = SingleFileSource(
        fbs, file_name='dummy_file', start_offset=0, stop_offset=100)
    self.assertEqual(100, source.estimate_size())

    source = SingleFileSource(
        fbs, file_name='dummy_file', start_offset=10, stop_offset=100)
    self.assertEqual(90, source.estimate_size())

  def test_read_range_at_beginning(self):
    fbs = LineSource('dummy_pattern', validate=False)

    file_name, expected_data = write_data(10)
    assert len(expected_data) == 10

    source = SingleFileSource(fbs, file_name, 0, 10 * 6)
    range_tracker = source.get_range_tracker(0, 20)
    read_data = [value for value in source.read(range_tracker)]
    self.assertCountEqual(expected_data[:4], read_data)

  def test_read_range_at_end(self):
    fbs = LineSource('dummy_pattern', validate=False)

    file_name, expected_data = write_data(10)
    assert len(expected_data) == 10

    source = SingleFileSource(fbs, file_name, 0, 10 * 6)
    range_tracker = source.get_range_tracker(40, 60)
    read_data = [value for value in source.read(range_tracker)]
    self.assertCountEqual(expected_data[-3:], read_data)

  def test_read_range_at_middle(self):
    fbs = LineSource('dummy_pattern', validate=False)

    file_name, expected_data = write_data(10)
    assert len(expected_data) == 10

    source = SingleFileSource(fbs, file_name, 0, 10 * 6)
    range_tracker = source.get_range_tracker(20, 40)
    read_data = [value for value in source.read(range_tracker)]
    self.assertCountEqual(expected_data[4:7], read_data)

  def test_produces_splits_desiredsize_large_than_size(self):
    fbs = LineSource('dummy_pattern', validate=False)

    file_name, expected_data = write_data(10)
    assert len(expected_data) == 10
    source = SingleFileSource(fbs, file_name, 0, 10 * 6)
    splits = [split for split in source.split(desired_bundle_size=100)]
    self.assertEqual(1, len(splits))
    self.assertEqual(60, splits[0].weight)
    self.assertEqual(0, splits[0].start_position)
    self.assertEqual(60, splits[0].stop_position)

    range_tracker = splits[0].source.get_range_tracker(None, None)
    read_data = [value for value in splits[0].source.read(range_tracker)]
    self.assertCountEqual(expected_data, read_data)

  def test_produces_splits_desiredsize_smaller_than_size(self):
    fbs = LineSource('dummy_pattern', validate=False)

    file_name, expected_data = write_data(10)
    assert len(expected_data) == 10
    source = SingleFileSource(fbs, file_name, 0, 10 * 6)
    splits = [split for split in source.split(desired_bundle_size=25)]
    self.assertEqual(3, len(splits))

    read_data = []
    for split in splits:
      source = split.source
      range_tracker = source.get_range_tracker(
          split.start_position, split.stop_position)
      data_from_split = [data for data in source.read(range_tracker)]
      read_data.extend(data_from_split)
    self.assertCountEqual(expected_data, read_data)

  def test_produce_split_with_start_and_end_positions(self):
    fbs = LineSource('dummy_pattern', validate=False)

    file_name, expected_data = write_data(10)
    assert len(expected_data) == 10
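    # Each line is 6 bytes ('line' + digit + '\n'), so the sub-range
    # [10, 50) claims the lines starting at offsets 12 through 48, i.e.
    # expected_data[2:9].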
    source = SingleFileSource(fbs, file_name, 0, 10 * 6)
    splits = [
        split for split in source.split(
            desired_bundle_size=15, start_offset=10, stop_offset=50)
    ]
    self.assertEqual(3, len(splits))

    read_data = []
    for split in splits:
      source = split.source
      range_tracker = source.get_range_tracker(
          split.start_position, split.stop_position)
      data_from_split = [data for data in source.read(range_tracker)]
      read_data.extend(data_from_split)
    self.assertCountEqual(expected_data[2:9], read_data)


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()