github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/avroio_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  # pytype: skip-file
    18  
import json
import logging
import math
import os
import shutil
import tempfile
import unittest
from typing import List

import hamcrest as hc

from fastavro.schema import parse_schema
from fastavro import writer

import apache_beam as beam
from apache_beam import Create
from apache_beam.io import avroio
from apache_beam.io import filebasedsource
from apache_beam.io import iobase
from apache_beam.io import source_test_utils
from apache_beam.io.avroio import _create_avro_sink  # For testing
from apache_beam.io.avroio import _create_avro_source  # For testing
from apache_beam.io.filesystems import FileSystems
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms.display import DisplayData
from apache_beam.transforms.display_test import DisplayDataItemMatcher
from apache_beam.transforms.userstate import CombiningValueStateSpec
from apache_beam.utils.timestamp import Timestamp
    48  
    49  # Import snappy optionally; some tests will be skipped when import fails.
    50  try:
    51    import snappy  # pylint: disable=import-error
    52  except ImportError:
    53    snappy = None  # pylint: disable=invalid-name
    54    logging.warning('python-snappy is not installed; some tests will be skipped.')
    55  
    56  RECORDS = [{
    57      'name': 'Thomas', 'favorite_number': 1, 'favorite_color': 'blue'
    58  }, {
    59      'name': 'Henry', 'favorite_number': 3, 'favorite_color': 'green'
    60  }, {
    61      'name': 'Toby', 'favorite_number': 7, 'favorite_color': 'brown'
    62  }, {
    63      'name': 'Gordon', 'favorite_number': 4, 'favorite_color': 'blue'
    64  }, {
    65      'name': 'Emily', 'favorite_number': -1, 'favorite_color': 'Red'
    66  }, {
    67      'name': 'Percy', 'favorite_number': 6, 'favorite_color': 'Green'
    68  }]
    69  
    70  
    71  class AvroBase(object):
    72  
    73    _temp_files = []  # type: List[str]
    74  
    75    def __init__(self, methodName='runTest'):
    76      super().__init__(methodName)
    77      self.RECORDS = RECORDS
    78      self.SCHEMA_STRING = '''
    79            {"namespace": "example.avro",
    80             "type": "record",
    81             "name": "User",
    82             "fields": [
    83                 {"name": "name", "type": "string"},
    84                 {"name": "favorite_number",  "type": ["int", "null"]},
    85                 {"name": "favorite_color", "type": ["string", "null"]}
    86             ]
    87            }
    88            '''
    89  
    90    def setUp(self):
    91      # Reducing the size of thread pools. Without this test execution may fail in
    92      # environments with limited amount of resources.
    93      filebasedsource.MAX_NUM_THREADS_FOR_SIZE_ESTIMATION = 2
    94  
    95    def tearDown(self):
    96      for path in self._temp_files:
    97        if os.path.exists(path):
    98          os.remove(path)
    99      self._temp_files = []
   100  
   101    def _write_data(
   102        self,
   103        directory=None,
   104        prefix=None,
   105        codec=None,
   106        count=None,
   107        sync_interval=None):
   108      raise NotImplementedError
   109  
   110    def _write_pattern(self, num_files, return_filenames=False):
   111      assert num_files > 0
   112      temp_dir = tempfile.mkdtemp()
   113  
   114      file_name = None
   115      file_list = []
   116      for _ in range(num_files):
   117        file_name = self._write_data(directory=temp_dir, prefix='mytemp')
   118        file_list.append(file_name)
   119  
   120      assert file_name
   121      file_name_prefix = file_name[:file_name.rfind(os.path.sep)]
   122      if return_filenames:
   123        return (file_name_prefix + os.path.sep + 'mytemp*', file_list)
   124      return file_name_prefix + os.path.sep + 'mytemp*'
   125  
   126    def _run_avro_test(
   127        self, pattern, desired_bundle_size, perform_splitting, expected_result):
   128      source = _create_avro_source(pattern)
   129  
   130      if perform_splitting:
   131        assert desired_bundle_size
   132        splits = [
   133            split
   134            for split in source.split(desired_bundle_size=desired_bundle_size)
   135        ]
   136        if len(splits) < 2:
   137          raise ValueError(
   138              'Test is trivial. Please adjust it so that at least '
   139              'two splits get generated')
   140  
   141        sources_info = [(split.source, split.start_position, split.stop_position)
   142                        for split in splits]
   143        source_test_utils.assert_sources_equal_reference_source(
   144            (source, None, None), sources_info)
   145      else:
   146        read_records = source_test_utils.read_from_source(source, None, None)
   147        self.assertCountEqual(expected_result, read_records)
   148  
   149    def test_read_without_splitting(self):
   150      file_name = self._write_data()
   151      expected_result = self.RECORDS
   152      self._run_avro_test(file_name, None, False, expected_result)
   153  
   154    def test_read_with_splitting(self):
   155      file_name = self._write_data()
   156      expected_result = self.RECORDS
   157      self._run_avro_test(file_name, 100, True, expected_result)
   158  
   159    def test_source_display_data(self):
   160      file_name = 'some_avro_source'
   161      source = \
   162          _create_avro_source(
   163              file_name,
   164              validate=False,
   165          )
   166      dd = DisplayData.create_from(source)
   167  
   168      # No extra avro parameters for AvroSource.
   169      expected_items = [
   170          DisplayDataItemMatcher('compression', 'auto'),
   171          DisplayDataItemMatcher('file_pattern', file_name)
   172      ]
   173      hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
   174  
   175    def test_read_display_data(self):
   176      file_name = 'some_avro_source'
   177      read = \
   178          avroio.ReadFromAvro(
   179              file_name,
   180              validate=False)
   181      dd = DisplayData.create_from(read)
   182  
   183      # No extra avro parameters for AvroSource.
   184      expected_items = [
   185          DisplayDataItemMatcher('compression', 'auto'),
   186          DisplayDataItemMatcher('file_pattern', file_name)
   187      ]
   188      hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
   189  
   190    def test_sink_display_data(self):
   191      file_name = 'some_avro_sink'
   192      sink = _create_avro_sink(
   193          file_name, self.SCHEMA, 'null', '.end', 0, None, 'application/x-avro')
   194      dd = DisplayData.create_from(sink)
   195  
   196      expected_items = [
   197          DisplayDataItemMatcher('schema', str(self.SCHEMA)),
   198          DisplayDataItemMatcher(
   199              'file_pattern',
   200              'some_avro_sink-%(shard_num)05d-of-%(num_shards)05d.end'),
   201          DisplayDataItemMatcher('codec', 'null'),
   202          DisplayDataItemMatcher('compression', 'uncompressed')
   203      ]
   204  
   205      hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
   206  
   207    def test_write_display_data(self):
   208      file_name = 'some_avro_sink'
   209      write = avroio.WriteToAvro(file_name, self.SCHEMA)
   210      dd = DisplayData.create_from(write)
   211      expected_items = [
   212          DisplayDataItemMatcher('schema', str(self.SCHEMA)),
   213          DisplayDataItemMatcher(
   214              'file_pattern',
   215              'some_avro_sink-%(shard_num)05d-of-%(num_shards)05d'),
   216          DisplayDataItemMatcher('codec', 'deflate'),
   217          DisplayDataItemMatcher('compression', 'uncompressed')
   218      ]
   219      hc.assert_that(dd.items, hc.contains_inanyorder(*expected_items))
   220  
   221    def test_read_reentrant_without_splitting(self):
   222      file_name = self._write_data()
   223      source = _create_avro_source(file_name)
   224      source_test_utils.assert_reentrant_reads_succeed((source, None, None))
   225  
   226    def test_read_reantrant_with_splitting(self):
   227      file_name = self._write_data()
   228      source = _create_avro_source(file_name)
   229      splits = [split for split in source.split(desired_bundle_size=100000)]
   230      assert len(splits) == 1
   231      source_test_utils.assert_reentrant_reads_succeed(
   232          (splits[0].source, splits[0].start_position, splits[0].stop_position))
   233  
   234    def test_read_without_splitting_multiple_blocks(self):
   235      file_name = self._write_data(count=12000)
   236      expected_result = self.RECORDS * 2000
   237      self._run_avro_test(file_name, None, False, expected_result)
   238  
   239    def test_read_with_splitting_multiple_blocks(self):
   240      file_name = self._write_data(count=12000)
   241      expected_result = self.RECORDS * 2000
   242      self._run_avro_test(file_name, 10000, True, expected_result)
   243  
   244    def test_split_points(self):
   245      num_records = 12000
   246      sync_interval = 16000
   247      file_name = self._write_data(count=num_records, sync_interval=sync_interval)
   248  
   249      source = _create_avro_source(file_name)
   250  
   251      splits = [split for split in source.split(desired_bundle_size=float('inf'))]
   252      assert len(splits) == 1
   253      range_tracker = splits[0].source.get_range_tracker(
   254          splits[0].start_position, splits[0].stop_position)
   255  
   256      split_points_report = []
   257  
   258      for _ in splits[0].source.read(range_tracker):
   259        split_points_report.append(range_tracker.split_points())
   260      # There will be a total of num_blocks in the generated test file,
   261      # proportional to number of records in the file divided by syncronization
   262      # interval used by avro during write. Each block has more than 10 records.
   263      num_blocks = int(math.ceil(14.5 * num_records / sync_interval))
   264      assert num_blocks > 1
   265      # When reading records of the first block, range_tracker.split_points()
   266      # should return (0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)
   267      self.assertEqual(
   268          split_points_report[:10],
   269          [(0, iobase.RangeTracker.SPLIT_POINTS_UNKNOWN)] * 10)
   270  
   271      # When reading records of last block, range_tracker.split_points() should
   272      # return (num_blocks - 1, 1)
   273      self.assertEqual(split_points_report[-10:], [(num_blocks - 1, 1)] * 10)
   274  
   275    def test_read_without_splitting_compressed_deflate(self):
   276      file_name = self._write_data(codec='deflate')
   277      expected_result = self.RECORDS
   278      self._run_avro_test(file_name, None, False, expected_result)
   279  
   280    def test_read_with_splitting_compressed_deflate(self):
   281      file_name = self._write_data(codec='deflate')
   282      expected_result = self.RECORDS
   283      self._run_avro_test(file_name, 100, True, expected_result)
   284  
   285    @unittest.skipIf(snappy is None, 'python-snappy not installed.')
   286    def test_read_without_splitting_compressed_snappy(self):
   287      file_name = self._write_data(codec='snappy')
   288      expected_result = self.RECORDS
   289      self._run_avro_test(file_name, None, False, expected_result)
   290  
   291    @unittest.skipIf(snappy is None, 'python-snappy not installed.')
   292    def test_read_with_splitting_compressed_snappy(self):
   293      file_name = self._write_data(codec='snappy')
   294      expected_result = self.RECORDS
   295      self._run_avro_test(file_name, 100, True, expected_result)
   296  
   297    def test_read_without_splitting_pattern(self):
   298      pattern = self._write_pattern(3)
   299      expected_result = self.RECORDS * 3
   300      self._run_avro_test(pattern, None, False, expected_result)
   301  
   302    def test_read_with_splitting_pattern(self):
   303      pattern = self._write_pattern(3)
   304      expected_result = self.RECORDS * 3
   305      self._run_avro_test(pattern, 100, True, expected_result)
   306  
   307    def test_dynamic_work_rebalancing_exhaustive(self):
   308      def compare_split_points(file_name):
   309        source = _create_avro_source(file_name)
   310        splits = [
   311            split for split in source.split(desired_bundle_size=float('inf'))
   312        ]
   313        assert len(splits) == 1
   314        source_test_utils.assert_split_at_fraction_exhaustive(splits[0].source)
   315  
   316      # Adjusting block size so that we can perform a exhaustive dynamic
   317      # work rebalancing test that completes within an acceptable amount of time.
   318      file_name = self._write_data(count=5, sync_interval=2)
   319  
   320      compare_split_points(file_name)
   321  
   322    def test_corrupted_file(self):
   323      file_name = self._write_data()
   324      with open(file_name, 'rb') as f:
   325        data = f.read()
   326  
   327      # Corrupt the last character of the file which is also the last character of
   328      # the last sync_marker.
   329      # https://avro.apache.org/docs/current/spec.html#Object+Container+Files
   330      corrupted_data = bytearray(data)
   331      corrupted_data[-1] = (corrupted_data[-1] + 1) % 256
   332      with tempfile.NamedTemporaryFile(delete=False,
   333                                       prefix=tempfile.template) as f:
   334        f.write(corrupted_data)
   335        corrupted_file_name = f.name
   336  
   337      source = _create_avro_source(corrupted_file_name)
   338      with self.assertRaisesRegex(ValueError, r'expected sync marker'):
   339        source_test_utils.read_from_source(source, None, None)
   340  
   341    def test_read_from_avro(self):
   342      path = self._write_data()
   343      with TestPipeline() as p:
   344        assert_that(p | avroio.ReadFromAvro(path), equal_to(self.RECORDS))
   345  
   346    def test_read_all_from_avro_single_file(self):
   347      path = self._write_data()
   348      with TestPipeline() as p:
   349        assert_that(
   350            p \
   351            | Create([path]) \
   352            | avroio.ReadAllFromAvro(),
   353            equal_to(self.RECORDS))
   354  
   355    def test_read_all_from_avro_many_single_files(self):
   356      path1 = self._write_data()
   357      path2 = self._write_data()
   358      path3 = self._write_data()
   359      with TestPipeline() as p:
   360        assert_that(
   361            p \
   362            | Create([path1, path2, path3]) \
   363            | avroio.ReadAllFromAvro(),
   364            equal_to(self.RECORDS * 3))
   365  
   366    def test_read_all_from_avro_file_pattern(self):
   367      file_pattern = self._write_pattern(5)
   368      with TestPipeline() as p:
   369        assert_that(
   370            p \
   371            | Create([file_pattern]) \
   372            | avroio.ReadAllFromAvro(),
   373            equal_to(self.RECORDS * 5))
   374  
   375    def test_read_all_from_avro_many_file_patterns(self):
   376      file_pattern1 = self._write_pattern(5)
   377      file_pattern2 = self._write_pattern(2)
   378      file_pattern3 = self._write_pattern(3)
   379      with TestPipeline() as p:
   380        assert_that(
   381            p \
   382            | Create([file_pattern1, file_pattern2, file_pattern3]) \
   383            | avroio.ReadAllFromAvro(),
   384            equal_to(self.RECORDS * 10))
   385  
   386    def test_read_all_from_avro_with_filename(self):
   387      file_pattern, file_paths = self._write_pattern(3, return_filenames=True)
   388      result = [(path, record) for path in file_paths for record in self.RECORDS]
   389      with TestPipeline() as p:
   390        assert_that(
   391            p \
   392            | Create([file_pattern]) \
   393            | avroio.ReadAllFromAvro(with_filename=True),
   394            equal_to(result))
   395  
   396    class _WriteFilesFn(beam.DoFn):
   397      """writes a couple of files with deferral."""
   398  
   399      COUNT_STATE = CombiningValueStateSpec('count', combine_fn=sum)
   400  
   401      def __init__(self, SCHEMA, RECORDS, tempdir):
   402        self._thread = None
   403        self.SCHEMA = SCHEMA
   404        self.RECORDS = RECORDS
   405        self.tempdir = tempdir
   406  
   407      def get_expect(self, match_updated_files):
   408        results_file1 = [('file1', x) for x in self.gen_records(1)]
   409        results_file2 = [('file2', x) for x in self.gen_records(3)]
   410        if match_updated_files:
   411          results_file1 += [('file1', x) for x in self.gen_records(2)]
   412        return results_file1 + results_file2
   413  
   414      def gen_records(self, count):
   415        return self.RECORDS * (count // len(self.RECORDS)) + self.RECORDS[:(
   416            count % len(self.RECORDS))]
   417  
   418      def process(self, element, count_state=beam.DoFn.StateParam(COUNT_STATE)):
   419        counter = count_state.read()
   420        if counter == 0:
   421          count_state.add(1)
   422          with open(FileSystems.join(self.tempdir, 'file1'), 'wb') as f:
   423            writer(f, self.SCHEMA, self.gen_records(2))
   424          with open(FileSystems.join(self.tempdir, 'file2'), 'wb') as f:
   425            writer(f, self.SCHEMA, self.gen_records(3))
   426        # convert dumb key to basename in output
   427        basename = FileSystems.split(element[1][0])[1]
   428        content = element[1][1]
   429        yield basename, content
   430  
   431    def test_read_all_continuously_new(self):
   432      with TestPipeline() as pipeline:
   433        tempdir = tempfile.mkdtemp()
   434        writer_fn = self._WriteFilesFn(self.SCHEMA, self.RECORDS, tempdir)
   435        with open(FileSystems.join(tempdir, 'file1'), 'wb') as f:
   436          writer(f, writer_fn.SCHEMA, writer_fn.gen_records(1))
   437        match_pattern = FileSystems.join(tempdir, '*')
   438        interval = 0.5
   439        last = 2
   440  
   441        p_read_once = (
   442            pipeline
   443            | 'Continuously read new files' >> avroio.ReadAllFromAvroContinuously(
   444                match_pattern,
   445                with_filename=True,
   446                start_timestamp=Timestamp.now(),
   447                interval=interval,
   448                stop_timestamp=Timestamp.now() + last,
   449                match_updated_files=False)
   450            | 'add dumb key' >> beam.Map(lambda x: (0, x))
   451            | 'Write files on-the-fly' >> beam.ParDo(writer_fn))
   452        assert_that(
   453            p_read_once,
   454            equal_to(writer_fn.get_expect(match_updated_files=False)),
   455            label='assert read new files results')
   456  
   457    def test_read_all_continuously_update(self):
   458      with TestPipeline() as pipeline:
   459        tempdir = tempfile.mkdtemp()
   460        writer_fn = self._WriteFilesFn(self.SCHEMA, self.RECORDS, tempdir)
   461        with open(FileSystems.join(tempdir, 'file1'), 'wb') as f:
   462          writer(f, writer_fn.SCHEMA, writer_fn.gen_records(1))
   463        match_pattern = FileSystems.join(tempdir, '*')
   464        interval = 0.5
   465        last = 2
   466  
   467        p_read_upd = (
   468            pipeline
   469            | 'Continuously read updated files' >>
   470            avroio.ReadAllFromAvroContinuously(
   471                match_pattern,
   472                with_filename=True,
   473                start_timestamp=Timestamp.now(),
   474                interval=interval,
   475                stop_timestamp=Timestamp.now() + last,
   476                match_updated_files=True)
   477            | 'add dumb key' >> beam.Map(lambda x: (0, x))
   478            | 'Write files on-the-fly' >> beam.ParDo(writer_fn))
   479        assert_that(
   480            p_read_upd,
   481            equal_to(writer_fn.get_expect(match_updated_files=True)),
   482            label='assert read updated files results')
   483  
   484    def test_sink_transform(self):
   485      with tempfile.NamedTemporaryFile() as dst:
   486        path = dst.name
   487        with TestPipeline() as p:
   488          # pylint: disable=expression-not-assigned
   489          p \
   490          | beam.Create(self.RECORDS) \
   491          | avroio.WriteToAvro(path, self.SCHEMA,)
   492        with TestPipeline() as p:
   493          # json used for stable sortability
   494          readback = \
   495              p \
   496              | avroio.ReadFromAvro(path + '*', ) \
   497              | beam.Map(json.dumps)
   498          assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS]))
   499  
   500    @unittest.skipIf(snappy is None, 'python-snappy not installed.')
   501    def test_sink_transform_snappy(self):
   502      with tempfile.NamedTemporaryFile() as dst:
   503        path = dst.name
   504        with TestPipeline() as p:
   505          # pylint: disable=expression-not-assigned
   506          p \
   507          | beam.Create(self.RECORDS) \
   508          | avroio.WriteToAvro(
   509              path,
   510              self.SCHEMA,
   511              codec='snappy')
   512        with TestPipeline() as p:
   513          # json used for stable sortability
   514          readback = \
   515              p \
   516              | avroio.ReadFromAvro(path + '*') \
   517              | beam.Map(json.dumps)
   518          assert_that(readback, equal_to([json.dumps(r) for r in self.RECORDS]))
   519  
   520    def test_writer_open_and_close(self):
   521      # Create and then close a temp file so we can manually open it later
   522      dst = tempfile.NamedTemporaryFile(delete=False)
   523      dst.close()
   524  
   525      schema = parse_schema(json.loads(self.SCHEMA_STRING))
   526      sink = _create_avro_sink(
   527          'some_avro_sink', schema, 'null', '.end', 0, None, 'application/x-avro')
   528  
   529      w = sink.open(dst.name)
   530  
   531      sink.close(w)
   532  
   533      os.unlink(dst.name)
   534  
   535  
   536  class TestFastAvro(AvroBase, unittest.TestCase):
   537    def __init__(self, methodName='runTest'):
   538      super().__init__(methodName)
   539      self.SCHEMA = parse_schema(json.loads(self.SCHEMA_STRING))
   540  
   541    def _write_data(
   542        self,
   543        directory=None,
   544        prefix=tempfile.template,
   545        codec='null',
   546        count=len(RECORDS),
   547        **kwargs):
   548      all_records = self.RECORDS * \
   549        (count // len(self.RECORDS)) + self.RECORDS[:(count % len(self.RECORDS))]
   550      with tempfile.NamedTemporaryFile(delete=False,
   551                                       dir=directory,
   552                                       prefix=prefix,
   553                                       mode='w+b') as f:
   554        writer(f, self.SCHEMA, all_records, codec=codec, **kwargs)
   555        self._temp_files.append(f.name)
   556      return f.name
   557  
   558  
   559  if __name__ == '__main__':
   560    logging.getLogger().setLevel(logging.INFO)
   561    unittest.main()