github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/parquetio_test.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pytype: skip-file

import json
import logging
import os
import shutil
import tempfile
import unittest
from tempfile import TemporaryDirectory

import hamcrest as hc
import pandas
import pytest
from parameterized import param
from parameterized import parameterized

from apache_beam import Create
from apache_beam import Map
from apache_beam.io import filebasedsource
from apache_beam.io import source_test_utils
from apache_beam.io.iobase import RangeTracker
from apache_beam.io.parquetio import ReadAllFromParquet
from apache_beam.io.parquetio import ReadAllFromParquetBatched
from apache_beam.io.parquetio import ReadFromParquet
from apache_beam.io.parquetio import ReadFromParquetBatched
from apache_beam.io.parquetio import WriteToParquet
from apache_beam.io.parquetio import WriteToParquetBatched
from apache_beam.io.parquetio import _create_parquet_sink
from apache_beam.io.parquetio import _create_parquet_source
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms.display import DisplayData
from apache_beam.transforms.display_test import DisplayDataItemMatcher

try:
  import pyarrow as pa
  import pyarrow.lib as pl
  import pyarrow.parquet as pq
except ImportError:
  pa = None
  pl = None
  pq = None

# Guarded so that the module still imports (and the tests below are skipped)
# when pyarrow is not installed.
if pa is not None:
  ARROW_MAJOR_VERSION, _, _ = map(int, pa.__version__.split('.'))
else:
  ARROW_MAJOR_VERSION = None


@unittest.skipIf(pa is None, "PyArrow is not installed.")
@pytest.mark.uses_pyarrow
class TestParquet(unittest.TestCase):
  def setUp(self):
    # Reducing the size of thread pools. Without this, test execution may
    # fail in environments with a limited amount of resources.
    filebasedsource.MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 2
    self.temp_dir = tempfile.mkdtemp()

    self.RECORDS = [{
        'name': 'Thomas', 'favorite_number': 1, 'favorite_color': 'blue'
    }, {
        'name': 'Henry', 'favorite_number': 3, 'favorite_color': 'green'
    }, {
        'name': 'Toby', 'favorite_number': 7, 'favorite_color': 'brown'
    }, {
        'name': 'Gordon', 'favorite_number': 4, 'favorite_color': 'blue'
    }, {
        'name': 'Emily', 'favorite_number': -1, 'favorite_color': 'Red'
    }, {
        'name': 'Percy', 'favorite_number': 6, 'favorite_color': 'Green'
    }, {
        'name': 'Peter', 'favorite_number': 3, 'favorite_color': None
    }]

    self.SCHEMA = pa.schema([('name', pa.string(), False),
                             ('favorite_number', pa.int64(), False),
                             ('favorite_color', pa.string())])

    self.SCHEMA96 = pa.schema([('name', pa.string(), False),
                               ('favorite_number', pa.timestamp('ns'), False),
                               ('favorite_color', pa.string())])

    self.RECORDS_NESTED = [{
        'items': [
            {
                'name': 'Thomas',
                'favorite_number': 1,
                'favorite_color': 'blue'
            },
            {
                'name': 'Henry',
                'favorite_number': 3,
                'favorite_color': 'green'
            },
        ]
    }, {
        'items': [
            {
                'name': 'Toby',
                'favorite_number': 7,
                'favorite_color': 'brown'
            },
        ]
    }]

    self.SCHEMA_NESTED = pa.schema([(
        'items',
        pa.list_(
            pa.struct([('name', pa.string(), False),
                       ('favorite_number', pa.int64(), False),
                       ('favorite_color', pa.string())])))])

  def tearDown(self):
    shutil.rmtree(self.temp_dir)

  def _record_to_columns(self, records, schema):
    col_list = []
    for n in schema.names:
      column = []
      for r in records:
        column.append(r[n])

      col_list.append(column)
    return col_list

  def _records_as_arrow(self, schema=None, count=None):
    if schema is None:
      schema = self.SCHEMA

    if count is None:
      count = len(self.RECORDS)

    len_records = len(self.RECORDS)
    data = []
    for i in range(count):
      data.append(self.RECORDS[i % len_records])
    col_data = self._record_to_columns(data, schema)
    col_array = [pa.array(c, schema.types[cn]) for cn, c in enumerate(col_data)]
    return pa.Table.from_arrays(col_array, schema=schema)

  def _write_data(
      self,
      directory=None,
      schema=None,
      prefix=tempfile.template,
      row_group_size=1000,
      codec='none',
      count=None):
    if directory is None:
      directory = self.temp_dir

    with tempfile.NamedTemporaryFile(delete=False, dir=directory,
                                     prefix=prefix) as f:
      table = self._records_as_arrow(schema, count)
      pq.write_table(
          table,
          f,
          row_group_size=row_group_size,
          compression=codec,
          use_deprecated_int96_timestamps=True)

      return f.name

  def _write_pattern(self, num_files, with_filename=False):
    assert num_files > 0
    temp_dir = tempfile.mkdtemp(dir=self.temp_dir)

    file_list = []
    for _ in range(num_files):
      file_list.append(self._write_data(directory=temp_dir, prefix='mytemp'))

    if with_filename:
      return (temp_dir + os.path.sep + 'mytemp*', file_list)
    return temp_dir + os.path.sep + 'mytemp*'

  def _run_parquet_test(
      self,
      pattern,
      columns,
      desired_bundle_size,
      perform_splitting,
      expected_result):
    source = _create_parquet_source(pattern, columns=columns)
    if perform_splitting:
      assert desired_bundle_size
      sources_info = [
          (split.source, split.start_position, split.stop_position)
          for split in source.split(desired_bundle_size=desired_bundle_size)
      ]
      if len(sources_info) < 2:
        raise ValueError(
            'Test is trivial. Please adjust it so that at least '
            'two splits get generated')

      source_test_utils.assert_sources_equal_reference_source(
          (source, None, None), sources_info)
    else:
      read_records = source_test_utils.read_from_source(source, None, None)
      self.assertCountEqual(expected_result, read_records)

  def test_read_without_splitting(self):
    file_name = self._write_data()
    expected_result = [self._records_as_arrow()]
    self._run_parquet_test(file_name, None, None, False, expected_result)

  def test_read_with_splitting(self):
    file_name = self._write_data()
    expected_result = [self._records_as_arrow()]
    self._run_parquet_test(file_name, None, 100, True, expected_result)

  def test_source_display_data(self):
    file_name = 'some_parquet_source'
    source = \
        _create_parquet_source(
            file_name,
            validate=False
        )
    dd = DisplayData.create_from(source)

    expected_items = [
        DisplayDataItemMatcher('compression', 'auto'),
        DisplayDataItemMatcher('file_pattern', file_name)
    ]
    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

  def test_read_display_data(self):
    file_name = 'some_parquet_source'
    read = \
        ReadFromParquet(
            file_name,
            validate=False)
    read_batched = \
        ReadFromParquetBatched(
            file_name,
            validate=False)

    expected_items = [
        DisplayDataItemMatcher('compression', 'auto'),
        DisplayDataItemMatcher('file_pattern', file_name)
    ]

    hc.assert_that(
        DisplayData.create_from(read).items,
        hc.contains_inanyorder(*expected_items))
    hc.assert_that(
        DisplayData.create_from(read_batched).items,
        hc.contains_inanyorder(*expected_items))

  def test_sink_display_data(self):
    file_name = 'some_parquet_sink'
    sink = _create_parquet_sink(
        file_name,
        self.SCHEMA,
        'none',
        False,
        False,
        '.end',
        0,
        None,
        'application/x-parquet')
    dd = DisplayData.create_from(sink)
    expected_items = [
        DisplayDataItemMatcher('schema', str(self.SCHEMA)),
        DisplayDataItemMatcher(
            'file_pattern',
            'some_parquet_sink-%(shard_num)05d-of-%(num_shards)05d.end'),
        DisplayDataItemMatcher('codec', 'none'),
        DisplayDataItemMatcher('compression', 'uncompressed')
    ]
    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

  def test_write_display_data(self):
    file_name = 'some_parquet_sink'
    write = WriteToParquet(file_name, self.SCHEMA)
    dd = DisplayData.create_from(write)

    expected_items = [
        DisplayDataItemMatcher('codec', 'none'),
        DisplayDataItemMatcher('schema', str(self.SCHEMA)),
        DisplayDataItemMatcher('row_group_buffer_size', str(64 * 1024 * 1024)),
        DisplayDataItemMatcher(
            'file_pattern',
            'some_parquet_sink-%(shard_num)05d-of-%(num_shards)05d'),
        DisplayDataItemMatcher('compression', 'uncompressed')
    ]
    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

  def test_write_batched_display_data(self):
    file_name = 'some_parquet_sink'
    write = WriteToParquetBatched(file_name, self.SCHEMA)
    dd = DisplayData.create_from(write)

    expected_items = [
        DisplayDataItemMatcher('codec', 'none'),
        DisplayDataItemMatcher('schema', str(self.SCHEMA)),
        DisplayDataItemMatcher(
            'file_pattern',
            'some_parquet_sink-%(shard_num)05d-of-%(num_shards)05d'),
        DisplayDataItemMatcher('compression', 'uncompressed')
    ]
    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

  def test_sink_transform_int96(self):
    with tempfile.NamedTemporaryFile() as dst:
      path = dst.name
      # pylint: disable=c-extension-no-member
      with self.assertRaises(pl.ArrowInvalid):
        # Should throw an error "ArrowInvalid: Casting from timestamp[ns] to
        # timestamp[us] would lose data"
        with TestPipeline() as p:
          _ = p \
              | Create(self.RECORDS) \
              | WriteToParquet(
                  path, self.SCHEMA96, num_shards=1, shard_name_template='')

  def test_sink_transform(self):
    with TemporaryDirectory() as tmp_dirname:
      path = os.path.join(tmp_dirname + "tmp_filename")
      with TestPipeline() as p:
        _ = p \
            | Create(self.RECORDS) \
            | WriteToParquet(
                path, self.SCHEMA, num_shards=1, shard_name_template='')
      with TestPipeline() as p:
        # json used for stable sortability
        readback = \
            p \
            | ReadFromParquet(path) \
            | Map(json.dumps)
        assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS]))

  def test_sink_transform_batched(self):
    with TemporaryDirectory() as tmp_dirname:
      path = os.path.join(tmp_dirname + "tmp_filename")
      with TestPipeline() as p:
        _ = p \
            | Create([self._records_as_arrow()]) \
            | WriteToParquetBatched(
                path, self.SCHEMA, num_shards=1, shard_name_template='')
      with TestPipeline() as p:
        # json used for stable sortability
        readback = \
            p \
            | ReadFromParquet(path) \
            | Map(json.dumps)
        assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS]))

  def test_sink_transform_compliant_nested_type(self):
    if ARROW_MAJOR_VERSION < 4:
      return unittest.skip(
          'Writing with compliant nested type is only '
          'supported in pyarrow 4.x and above')
    with TemporaryDirectory() as tmp_dirname:
      path = os.path.join(tmp_dirname + 'tmp_filename')
      with TestPipeline() as p:
        _ = p \
            | Create(self.RECORDS_NESTED) \
            | WriteToParquet(
                path, self.SCHEMA_NESTED, num_shards=1,
                shard_name_template='', use_compliant_nested_type=True)
      with TestPipeline() as p:
        # json used for stable sortability
        readback = \
            p \
            | ReadFromParquet(path) \
            | Map(json.dumps)
        assert_that(
            readback, equal_to([json.dumps(r) for r in self.RECORDS_NESTED]))

  def test_batched_read(self):
    with TemporaryDirectory() as tmp_dirname:
      path = os.path.join(tmp_dirname + "tmp_filename")
      with TestPipeline() as p:
        _ = p \
            | Create(self.RECORDS, reshuffle=False) \
            | WriteToParquet(
                path, self.SCHEMA, num_shards=1, shard_name_template='')
      with TestPipeline() as p:
        readback = \
            p \
            | ReadFromParquetBatched(path)
        assert_that(readback, equal_to([self._records_as_arrow()]))

  @parameterized.expand([
      param(compression_type='snappy'),
      param(compression_type='gzip'),
      param(compression_type='brotli'),
      param(compression_type='lz4'),
      param(compression_type='zstd')
  ])
  def test_sink_transform_compressed(self, compression_type):
    if compression_type == 'lz4' and ARROW_MAJOR_VERSION == 1:
      return unittest.skip(
          "Writing with LZ4 compression is not supported in "
          "pyarrow 1.x")
    with TemporaryDirectory() as tmp_dirname:
      path = os.path.join(tmp_dirname + "tmp_filename")
      with TestPipeline() as p:
        _ = p \
            | Create(self.RECORDS) \
            | WriteToParquet(
                path, self.SCHEMA, codec=compression_type,
                num_shards=1, shard_name_template='')
      with TestPipeline() as p:
        # json used for stable sortability
        readback = \
            p \
            | ReadFromParquet(path + '*') \
            | Map(json.dumps)
        assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS]))

  def test_read_reentrant(self):
    file_name = self._write_data(count=6, row_group_size=3)
    source = _create_parquet_source(file_name)
    source_test_utils.assert_reentrant_reads_succeed((source, None, None))

  def test_read_without_splitting_multiple_row_group(self):
    file_name = self._write_data(count=12000, row_group_size=1000)
    # We expect 12000 elements, split into batches of 1000 elements. Create
    # a list of pa.Table instances to model this expectation.
    expected_result = [
        pa.Table.from_batches([batch]) for batch in self._records_as_arrow(
            count=12000).to_batches(max_chunksize=1000)
    ]
    self._run_parquet_test(file_name, None, None, False, expected_result)

  def test_read_with_splitting_multiple_row_group(self):
    file_name = self._write_data(count=12000, row_group_size=1000)
    # We expect 12000 elements, split into batches of 1000 elements. Create
    # a list of pa.Table instances to model this expectation.
    expected_result = [
        pa.Table.from_batches([batch]) for batch in self._records_as_arrow(
            count=12000).to_batches(max_chunksize=1000)
    ]
    self._run_parquet_test(file_name, None, 10000, True, expected_result)

  def test_dynamic_work_rebalancing(self):
    file_name = self._write_data(count=120, row_group_size=20)
    source = _create_parquet_source(file_name)

    splits = [split for split in source.split(desired_bundle_size=float('inf'))]
    assert len(splits) == 1

    source_test_utils.assert_split_at_fraction_exhaustive(
        splits[0].source, splits[0].start_position, splits[0].stop_position)

  def test_min_bundle_size(self):
    file_name = self._write_data(count=120, row_group_size=20)

    source = _create_parquet_source(
        file_name, min_bundle_size=100 * 1024 * 1024)
    splits = [split for split in source.split(desired_bundle_size=1)]
    self.assertEqual(len(splits), 1)

    source = _create_parquet_source(file_name, min_bundle_size=0)
    splits = [split for split in source.split(desired_bundle_size=1)]
    self.assertNotEqual(len(splits), 1)

  def _convert_to_timestamped_record(self, record):
    timestamped_record = record.copy()
    timestamped_record['favorite_number'] =\
        pandas.Timestamp(timestamped_record['favorite_number'])
    return timestamped_record

  def test_int96_type_conversion(self):
    file_name = self._write_data(
        count=120, row_group_size=20, schema=self.SCHEMA96)
    orig = self._records_as_arrow(count=120, schema=self.SCHEMA96)
    expected_result = [
        pa.Table.from_batches([batch], schema=self.SCHEMA96)
        for batch in orig.to_batches(max_chunksize=20)
    ]
    self._run_parquet_test(file_name, None, None, False, expected_result)

  def test_split_points(self):
    file_name = self._write_data(count=12000, row_group_size=3000)
    source = _create_parquet_source(file_name)

    splits = [split for split in source.split(desired_bundle_size=float('inf'))]
    assert len(splits) == 1

    range_tracker = splits[0].source.get_range_tracker(
        splits[0].start_position, splits[0].stop_position)

    split_points_report = []

    for _ in splits[0].source.read(range_tracker):
      split_points_report.append(range_tracker.split_points())

    # There are a total of four row groups. Each row group has 3000 records.

    # When reading records of the first group, range_tracker.split_points()
    # should return (0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)
    self.assertEqual(
        split_points_report,
        [
            (0, RangeTracker.SPLIT_POINTS_UNKNOWN),
            (1, RangeTracker.SPLIT_POINTS_UNKNOWN),
            (2, RangeTracker.SPLIT_POINTS_UNKNOWN),
            (3, 1),
        ])

  def test_selective_columns(self):
    file_name = self._write_data()
    orig = self._records_as_arrow()
    name_column = self.SCHEMA.field('name')
    expected_result = [
        pa.Table.from_arrays(
            [orig.column('name')],
            schema=pa.schema([('name', name_column.type, name_column.nullable)
                              ]))
    ]
    self._run_parquet_test(file_name, ['name'], None, False, expected_result)

  def test_sink_transform_multiple_row_group(self):
    with TemporaryDirectory() as tmp_dirname:
      path = os.path.join(tmp_dirname + "tmp_filename")
      with TestPipeline() as p:
        # writing 623200 bytes of data
        _ = p \
            | Create(self.RECORDS * 4000) \
            | WriteToParquet(
                path, self.SCHEMA, num_shards=1, codec='none',
                shard_name_template='', row_group_buffer_size=250000)
      self.assertEqual(pq.read_metadata(path).num_row_groups, 3)

  def test_read_all_from_parquet_single_file(self):
    path = self._write_data()
    with TestPipeline() as p:
      assert_that(
          p \
          | Create([path]) \
          | ReadAllFromParquet(),
          equal_to(self.RECORDS))

    with TestPipeline() as p:
      assert_that(
          p \
          | Create([path]) \
          | ReadAllFromParquetBatched(),
          equal_to([self._records_as_arrow()]))

  def test_read_all_from_parquet_many_single_files(self):
    path1 = self._write_data()
    path2 = self._write_data()
    path3 = self._write_data()
    with TestPipeline() as p:
      assert_that(
          p \
          | Create([path1, path2, path3]) \
          | ReadAllFromParquet(),
          equal_to(self.RECORDS * 3))
    with TestPipeline() as p:
      assert_that(
          p \
          | Create([path1, path2, path3]) \
          | ReadAllFromParquetBatched(),
          equal_to([self._records_as_arrow()] * 3))

  def test_read_all_from_parquet_file_pattern(self):
    file_pattern = self._write_pattern(5)
    with TestPipeline() as p:
      assert_that(
          p \
          | Create([file_pattern]) \
          | ReadAllFromParquet(),
          equal_to(self.RECORDS * 5))
    with TestPipeline() as p:
      assert_that(
          p \
          | Create([file_pattern]) \
          | ReadAllFromParquetBatched(),
          equal_to([self._records_as_arrow()] * 5))

  def test_read_all_from_parquet_many_file_patterns(self):
    file_pattern1 = self._write_pattern(5)
    file_pattern2 = self._write_pattern(2)
    file_pattern3 = self._write_pattern(3)
    with TestPipeline() as p:
      assert_that(
          p \
          | Create([file_pattern1, file_pattern2, file_pattern3]) \
          | ReadAllFromParquet(),
          equal_to(self.RECORDS * 10))
    with TestPipeline() as p:
      assert_that(
          p \
          | Create([file_pattern1, file_pattern2, file_pattern3]) \
          | ReadAllFromParquetBatched(),
          equal_to([self._records_as_arrow()] * 10))

  def test_read_all_from_parquet_with_filename(self):
    file_pattern, file_paths = self._write_pattern(3, with_filename=True)
    result = [(path, record) for path in file_paths for record in self.RECORDS]
    with TestPipeline() as p:
      assert_that(
          p \
          | Create([file_pattern]) \
          | ReadAllFromParquet(with_filename=True),
          equal_to(result))
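

# Illustrative usage sketch (not part of the test suite above, never invoked):
# a hypothetical helper showing the basic WriteToParquet / ReadFromParquet
# round trip that the tests in TestParquet exercise in detail. `records` is
# assumed to be a list of dicts matching `schema`, and `path` a writable
# local path; json.dumps is used for stable sortability, as in the tests.
def _example_parquet_roundtrip(records, schema, path):
  # Write all records to a single shard located exactly at `path`.
  with TestPipeline() as p:
    _ = p \
        | Create(records) \
        | WriteToParquet(
            path, schema, num_shards=1, shard_name_template='')
  # Read the file back and verify the records survived the round trip.
  with TestPipeline() as p:
    readback = p | ReadFromParquet(path) | Map(json.dumps)
    assert_that(readback, equal_to([json.dumps(r) for r in records]))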


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()
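
# Running this module directly executes the whole suite. Under pytest, the
# `uses_pyarrow` marker applied to TestParquet can also be used to select
# these tests (assuming the marker is registered in the project's pytest
# configuration), e.g.:
#
#   pytest -m uses_pyarrow apache_beam/io/parquetio_test.py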