github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/avroio_test.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pytype: skip-file

import json
import logging
import math
import os
import tempfile
import unittest
from typing import List

import hamcrest as hc

from fastavro.schema import parse_schema
from fastavro import writer

import apache_beam as beam
from apache_beam import Create
from apache_beam.io import avroio
from apache_beam.io import filebasedsource
from apache_beam.io import iobase
from apache_beam.io import source_test_utils
from apache_beam.io.avroio import _create_avro_sink  # For testing
from apache_beam.io.avroio import _create_avro_source  # For testing
from apache_beam.io.filesystems import FileSystems
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms.display import DisplayData
from apache_beam.transforms.display_test import DisplayDataItemMatcher
from apache_beam.transforms.userstate import CombiningValueStateSpec
from apache_beam.utils.timestamp import Timestamp

# Import snappy optionally; some tests will be skipped when import fails.
try:
  import snappy  # pylint: disable=import-error
except ImportError:
  snappy = None  # pylint: disable=invalid-name
  logging.warning(
      'python-snappy is not installed; some tests will be skipped.')

RECORDS = [{
    'name': 'Thomas', 'favorite_number': 1, 'favorite_color': 'blue'
}, {
    'name': 'Henry', 'favorite_number': 3, 'favorite_color': 'green'
}, {
    'name': 'Toby', 'favorite_number': 7, 'favorite_color': 'brown'
}, {
    'name': 'Gordon', 'favorite_number': 4, 'favorite_color': 'blue'
}, {
    'name': 'Emily', 'favorite_number': -1, 'favorite_color': 'Red'
}, {
    'name': 'Percy', 'favorite_number': 6, 'favorite_color': 'Green'
}]


class AvroBase(object):

  _temp_files = []  # type: List[str]

  def __init__(self, methodName='runTest'):
    super().__init__(methodName)
    self.RECORDS = RECORDS
    self.SCHEMA_STRING = '''
    {"namespace": "example.avro",
     "type": "record",
     "name": "User",
     "fields": [
         {"name": "name", "type": "string"},
         {"name": "favorite_number", "type": ["int", "null"]},
         {"name": "favorite_color", "type": ["string", "null"]}
     ]
    }
    '''

  def setUp(self):
    # Reducing the size of thread pools. Without this, test execution may
    # fail in environments with a limited amount of resources.
    filebasedsource.MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 2

  def tearDown(self):
    for path in self._temp_files:
      if os.path.exists(path):
        os.remove(path)
    self._temp_files = []

  def _write_data(
      self,
      directory=None,
      prefix=None,
      codec=None,
      count=None,
      sync_interval=None):
    raise NotImplementedError

  def _write_pattern(self, num_files, return_filenames=False):
    assert num_files > 0
    temp_dir = tempfile.mkdtemp()

    file_name = None
    file_list = []
    for _ in range(num_files):
      file_name = self._write_data(directory=temp_dir, prefix='mytemp')
      file_list.append(file_name)

    assert file_name
    file_name_prefix = file_name[:file_name.rfind(os.path.sep)]
    if return_filenames:
      return (file_name_prefix + os.path.sep + 'mytemp*', file_list)
    return file_name_prefix + os.path.sep + 'mytemp*'

  def _run_avro_test(
      self, pattern, desired_bundle_size, perform_splitting, expected_result):
    source = _create_avro_source(pattern)

    if perform_splitting:
      assert desired_bundle_size
      splits = [
          split
          for split in source.split(desired_bundle_size=desired_bundle_size)
      ]
      if len(splits) < 2:
        raise ValueError(
            'Test is trivial. Please adjust it so that at least '
            'two splits get generated')

      sources_info = [(split.source, split.start_position, split.stop_position)
                      for split in splits]
      source_test_utils.assert_sources_equal_reference_source(
          (source, None, None), sources_info)
    else:
      read_records = source_test_utils.read_from_source(source, None, None)
      self.assertCountEqual(expected_result, read_records)

  def test_read_without_splitting(self):
    file_name = self._write_data()
    expected_result = self.RECORDS
    self._run_avro_test(file_name, None, False, expected_result)

  def test_read_with_splitting(self):
    file_name = self._write_data()
    expected_result = self.RECORDS
    self._run_avro_test(file_name, 100, True, expected_result)

  def test_source_display_data(self):
    file_name = 'some_avro_source'
    source = \
        _create_avro_source(
            file_name,
            validate=False,
        )
    dd = DisplayData.create_from(source)

    # No extra avro parameters for AvroSource.
    expected_items = [
        DisplayDataItemMatcher('compression', 'auto'),
        DisplayDataItemMatcher('file_pattern', file_name)
    ]
    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

  def test_read_display_data(self):
    file_name = 'some_avro_source'
    read = \
        avroio.ReadFromAvro(
            file_name,
            validate=False)
    dd = DisplayData.create_from(read)

    # No extra avro parameters for AvroSource.
    expected_items = [
        DisplayDataItemMatcher('compression', 'auto'),
        DisplayDataItemMatcher('file_pattern', file_name)
    ]
    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

  def test_sink_display_data(self):
    file_name = 'some_avro_sink'
    sink = _create_avro_sink(
        file_name, self.SCHEMA, 'null', '.end', 0, None, 'application/x-avro')
    dd = DisplayData.create_from(sink)

    expected_items = [
        DisplayDataItemMatcher('schema', str(self.SCHEMA)),
        DisplayDataItemMatcher(
            'file_pattern',
            'some_avro_sink-%(shard_num)05d-of-%(num_shards)05d.end'),
        DisplayDataItemMatcher('codec', 'null'),
        DisplayDataItemMatcher('compression', 'uncompressed')
    ]

    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

  def test_write_display_data(self):
    file_name = 'some_avro_sink'
    write = avroio.WriteToAvro(file_name, self.SCHEMA)
    dd = DisplayData.create_from(write)
    expected_items = [
        DisplayDataItemMatcher('schema', str(self.SCHEMA)),
        DisplayDataItemMatcher(
            'file_pattern',
            'some_avro_sink-%(shard_num)05d-of-%(num_shards)05d'),
        DisplayDataItemMatcher('codec', 'deflate'),
        DisplayDataItemMatcher('compression', 'uncompressed')
    ]
    hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))

  def test_read_reentrant_without_splitting(self):
    file_name = self._write_data()
    source = _create_avro_source(file_name)
    source_test_utils.assert_reentrant_reads_succeed((source, None, None))

  def test_read_reentrant_with_splitting(self):
    file_name = self._write_data()
    source = _create_avro_source(file_name)
    splits = [split for split in source.split(desired_bundle_size=100000)]
    assert len(splits) == 1
    source_test_utils.assert_reentrant_reads_succeed(
        (splits[0].source, splits[0].start_position, splits[0].stop_position))

  def test_read_without_splitting_multiple_blocks(self):
    file_name = self._write_data(count=12000)
    expected_result = self.RECORDS * 2000
    self._run_avro_test(file_name, None, False, expected_result)

  def test_read_with_splitting_multiple_blocks(self):
    file_name = self._write_data(count=12000)
    expected_result = self.RECORDS * 2000
    self._run_avro_test(file_name, 10000, True, expected_result)

  def test_split_points(self):
    num_records = 12000
    sync_interval = 16000
    file_name = self._write_data(
        count=num_records, sync_interval=sync_interval)

    source = _create_avro_source(file_name)

    splits = [
        split for split in source.split(desired_bundle_size=float('inf'))
    ]
    assert len(splits) == 1
    range_tracker = splits[0].source.get_range_tracker(
        splits[0].start_position, splits[0].stop_position)

    split_points_report = []

    for _ in splits[0].source.read(range_tracker):
      split_points_report.append(range_tracker.split_points())
    # There will be a total of num_blocks blocks in the generated test file,
    # proportional to the number of records in the file divided by the
    # synchronization interval used by avro during write. Each block has more
    # than 10 records.
    num_blocks = int(math.ceil(14.5 * num_records / sync_interval))
    assert num_blocks > 1
    # When reading records of the first block, range_tracker.split_points()
    # should return (0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN).
    self.assertEqual(
        split_points_report[:10],
        [(0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)] * 10)

    # When reading records of the last block, range_tracker.split_points()
    # should return (num_blocks - 1, 1).
    self.assertEqual(split_points_report[-10:], [(num_blocks - 1, 1)] * 10)

  def test_read_without_splitting_compressed_deflate(self):
    file_name = self._write_data(codec='deflate')
    expected_result = self.RECORDS
    self._run_avro_test(file_name, None, False, expected_result)

  def test_read_with_splitting_compressed_deflate(self):
    file_name = self._write_data(codec='deflate')
    expected_result = self.RECORDS
    self._run_avro_test(file_name, 100, True, expected_result)

  @unittest.skipIf(snappy is None, 'python-snappy not installed.')
  def test_read_without_splitting_compressed_snappy(self):
    file_name = self._write_data(codec='snappy')
    expected_result = self.RECORDS
    self._run_avro_test(file_name, None, False, expected_result)

  @unittest.skipIf(snappy is None, 'python-snappy not installed.')
  def test_read_with_splitting_compressed_snappy(self):
    file_name = self._write_data(codec='snappy')
    expected_result = self.RECORDS
    self._run_avro_test(file_name, 100, True, expected_result)

  def test_read_without_splitting_pattern(self):
    pattern = self._write_pattern(3)
    expected_result = self.RECORDS * 3
    self._run_avro_test(pattern, None, False, expected_result)

  def test_read_with_splitting_pattern(self):
    pattern = self._write_pattern(3)
    expected_result = self.RECORDS * 3
    self._run_avro_test(pattern, 100, True, expected_result)

  def test_dynamic_work_rebalancing_exhaustive(self):
    def compare_split_points(file_name):
      source = _create_avro_source(file_name)
      splits = [
          split for split in source.split(desired_bundle_size=float('inf'))
      ]
      assert len(splits) == 1
      source_test_utils.assert_split_at_fraction_exhaustive(splits[0].source)

    # Adjusting block size so that we can perform an exhaustive dynamic
    # work rebalancing test that completes within an acceptable amount of time.
    file_name = self._write_data(count=5, sync_interval=2)

    compare_split_points(file_name)

  def test_corrupted_file(self):
    file_name = self._write_data()
    with open(file_name, 'rb') as f:
      data = f.read()

    # Corrupt the last character of the file, which is also the last character
    # of the last sync_marker.
    # https://avro.apache.org/docs/current/spec.html#Object+Container+Files
    corrupted_data = bytearray(data)
    corrupted_data[-1] = (corrupted_data[-1] + 1) % 256
    with tempfile.NamedTemporaryFile(delete=False,
                                     prefix=tempfile.template) as f:
      f.write(corrupted_data)
      corrupted_file_name = f.name

    source = _create_avro_source(corrupted_file_name)
    with self.assertRaisesRegex(ValueError, r'expected sync marker'):
      source_test_utils.read_from_source(source, None, None)

  def test_read_from_avro(self):
    path = self._write_data()
    with TestPipeline() as p:
      assert_that(p | avroio.ReadFromAvro(path), equal_to(self.RECORDS))

  def test_read_all_from_avro_single_file(self):
    path = self._write_data()
    with TestPipeline() as p:
      assert_that(
          p \
          | Create([path]) \
          | avroio.ReadAllFromAvro(),
          equal_to(self.RECORDS))

  def test_read_all_from_avro_many_single_files(self):
    path1 = self._write_data()
    path2 = self._write_data()
    path3 = self._write_data()
    with TestPipeline() as p:
      assert_that(
          p \
          | Create([path1, path2, path3]) \
          | avroio.ReadAllFromAvro(),
          equal_to(self.RECORDS * 3))

  def test_read_all_from_avro_file_pattern(self):
    file_pattern = self._write_pattern(5)
    with TestPipeline() as p:
      assert_that(
          p \
          | Create([file_pattern]) \
          | avroio.ReadAllFromAvro(),
          equal_to(self.RECORDS * 5))

  def test_read_all_from_avro_many_file_patterns(self):
    file_pattern1 = self._write_pattern(5)
    file_pattern2 = self._write_pattern(2)
    file_pattern3 = self._write_pattern(3)
    with TestPipeline() as p:
      assert_that(
          p \
          | Create([file_pattern1, file_pattern2, file_pattern3]) \
          | avroio.ReadAllFromAvro(),
          equal_to(self.RECORDS * 10))

  def test_read_all_from_avro_with_filename(self):
    file_pattern, file_paths = self._write_pattern(3, return_filenames=True)
    result = [(path, record) for path in file_paths for record in self.RECORDS]
    with TestPipeline() as p:
      assert_that(
          p \
          | Create([file_pattern]) \
          | avroio.ReadAllFromAvro(with_filename=True),
          equal_to(result))

  class _WriteFilesFn(beam.DoFn):
    """writes a couple of files with deferral."""

    COUNT_STATE = CombiningValueStateSpec('count', combine_fn=sum)

    def __init__(self, SCHEMA, RECORDS, tempdir):
      self._thread = None
      self.SCHEMA = SCHEMA
      self.RECORDS = RECORDS
      self.tempdir = tempdir

    def get_expect(self, match_updated_files):
      results_file1 = [('file1', x) for x in self.gen_records(1)]
      results_file2 = [('file2', x) for x in self.gen_records(3)]
      if match_updated_files:
        results_file1 += [('file1', x) for x in self.gen_records(2)]
      return results_file1 + results_file2

    def gen_records(self, count):
      return self.RECORDS * (count // len(self.RECORDS)) + self.RECORDS[:(
          count % len(self.RECORDS))]

    def process(self, element, count_state=beam.DoFn.StateParam(COUNT_STATE)):
      counter = count_state.read()
      if counter == 0:
        count_state.add(1)
        with open(FileSystems.join(self.tempdir, 'file1'), 'wb') as f:
          writer(f, self.SCHEMA, self.gen_records(2))
        with open(FileSystems.join(self.tempdir, 'file2'), 'wb') as f:
          writer(f, self.SCHEMA, self.gen_records(3))
      # convert dumb key to basename in output
      basename = FileSystems.split(element[1][0])[1]
      content = element[1][1]
      yield basename, content

  def test_read_all_continuously_new(self):
    with TestPipeline() as pipeline:
      tempdir = tempfile.mkdtemp()
      writer_fn = self._WriteFilesFn(self.SCHEMA, self.RECORDS, tempdir)
      with open(FileSystems.join(tempdir, 'file1'), 'wb') as f:
        writer(f, writer_fn.SCHEMA, writer_fn.gen_records(1))
      match_pattern = FileSystems.join(tempdir, '*')
      interval = 0.5
      last = 2

      p_read_once = (
          pipeline
          | 'Continuously read new files' >>
          avroio.ReadAllFromAvroContinuously(
              match_pattern,
              with_filename=True,
              start_timestamp=Timestamp.now(),
              interval=interval,
              stop_timestamp=Timestamp.now() + last,
              match_updated_files=False)
          | 'add dumb key' >> beam.Map(lambda x: (0, x))
          | 'Write files on-the-fly' >> beam.ParDo(writer_fn))
      assert_that(
          p_read_once,
          equal_to(writer_fn.get_expect(match_updated_files=False)),
          label='assert read new files results')

  def test_read_all_continuously_update(self):
    with TestPipeline() as pipeline:
      tempdir = tempfile.mkdtemp()
      writer_fn = self._WriteFilesFn(self.SCHEMA, self.RECORDS, tempdir)
      with open(FileSystems.join(tempdir, 'file1'), 'wb') as f:
        writer(f, writer_fn.SCHEMA, writer_fn.gen_records(1))
      match_pattern = FileSystems.join(tempdir, '*')
      interval = 0.5
      last = 2

      p_read_upd = (
          pipeline
          | 'Continuously read updated files' >>
          avroio.ReadAllFromAvroContinuously(
              match_pattern,
              with_filename=True,
              start_timestamp=Timestamp.now(),
              interval=interval,
              stop_timestamp=Timestamp.now() + last,
              match_updated_files=True)
          | 'add dumb key' >> beam.Map(lambda x: (0, x))
          | 'Write files on-the-fly' >> beam.ParDo(writer_fn))
      assert_that(
          p_read_upd,
          equal_to(writer_fn.get_expect(match_updated_files=True)),
          label='assert read updated files results')

  def test_sink_transform(self):
    with tempfile.NamedTemporaryFile() as dst:
      path = dst.name
      with TestPipeline() as p:
        # pylint: disable=expression-not-assigned
        p \
        | beam.Create(self.RECORDS) \
        | avroio.WriteToAvro(path, self.SCHEMA)
      with TestPipeline() as p:
        # json used for stable sortability
        readback = \
            p \
            | avroio.ReadFromAvro(path + '*') \
            | beam.Map(json.dumps)
        assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS]))

  @unittest.skipIf(snappy is None, 'python-snappy not installed.')
  def test_sink_transform_snappy(self):
    with tempfile.NamedTemporaryFile() as dst:
      path = dst.name
      with TestPipeline() as p:
        # pylint: disable=expression-not-assigned
        p \
        | beam.Create(self.RECORDS) \
        | avroio.WriteToAvro(
            path,
            self.SCHEMA,
            codec='snappy')
      with TestPipeline() as p:
        # json used for stable sortability
        readback = \
            p \
            | avroio.ReadFromAvro(path + '*') \
            | beam.Map(json.dumps)
        assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS]))

  def test_writer_open_and_close(self):
    # Create and then close a temp file so we can manually open it later.
    dst = tempfile.NamedTemporaryFile(delete=False)
    dst.close()

    schema = parse_schema(json.loads(self.SCHEMA_STRING))
    sink = _create_avro_sink(
        'some_avro_sink', schema, 'null', '.end', 0, None, 'application/x-avro')

    w = sink.open(dst.name)

    sink.close(w)

    os.unlink(dst.name)


class TestFastAvro(AvroBase, unittest.TestCase):
  def __init__(self, methodName='runTest'):
    super().__init__(methodName)
    self.SCHEMA = parse_schema(json.loads(self.SCHEMA_STRING))

  def _write_data(
      self,
      directory=None,
      prefix=tempfile.template,
      codec='null',
      count=len(RECORDS),
      **kwargs):
    all_records = self.RECORDS * \
        (count // len(self.RECORDS)) + self.RECORDS[:(count % len(self.RECORDS))]
    with tempfile.NamedTemporaryFile(delete=False,
                                     dir=directory,
                                     prefix=prefix,
                                     mode='w+b') as f:
      writer(f, self.SCHEMA, all_records, codec=codec, **kwargs)
      self._temp_files.append(f.name)
    return f.name


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()