github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/parquetio_it_test.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pytype: skip-file

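"""End-to-end integration test for the Parquet IO.

Writes generated records to Parquet files, reads them back with
ReadAllFromParquet, and verifies the global sum and the per-key counts
before deleting the files.
"""
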
import logging
import string
import unittest
from collections import Counter

import pytest

from apache_beam import Create
from apache_beam import DoFn
from apache_beam import FlatMap
from apache_beam import Flatten
from apache_beam import Map
from apache_beam import ParDo
from apache_beam import Reshuffle
from apache_beam.io.filesystems import FileSystems
from apache_beam.io.parquetio import ReadAllFromParquet
from apache_beam.io.parquetio import WriteToParquet
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import BeamAssertException
from apache_beam.transforms import CombineGlobally
from apache_beam.transforms.combiners import Count

# PyArrow is an optional dependency; when it is not installed, skip these
# tests instead of failing at import time.
try:
  import pyarrow as pa
except ImportError:
  pa = None


@unittest.skipIf(pa is None, "PyArrow is not installed.")
class TestParquetIT(unittest.TestCase):
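  """End-to-end test of WriteToParquet and ReadAllFromParquet on a runner."""
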
  @pytest.mark.it_postcommit
  def test_parquetio_it(self):
    file_prefix = "parquet_it_test"
    init_size = 10
    data_size = 20000
    with TestPipeline(is_integration_test=True) as p:
      # Write init_size * data_size records to Parquet files, then read them
      # back and check both the global sum and the per-key counts.
      pcol = self._generate_data(p, file_prefix, init_size, data_size)
      self._verify_data(pcol, init_size, data_size)

  @staticmethod
  def _sum_verifier(init_size, data_size, x):
    # Each of the init_size seed elements produces the numbers
    # 0..data_size-1, so the global sum must be
    # init_size * sum(range(data_size)).
    expected = sum(range(data_size)) * init_size
    if x != expected:
      raise BeamAssertException(
          "incorrect sum: expected(%d) actual(%d)" % (expected, x))
    return []

  @staticmethod
  def _count_verifier(init_size, data_size, x):
    # Names advance four letters per record through a cyclic A..Z alphabet,
    # so a 4-letter name is fully determined by its first letter. Rebuild the
    # expected first-letter distribution and compare the per-key counts.
    name, count = x[0].decode('utf-8'), x[1]
    counter = Counter(
        [string.ascii_uppercase[x % 26] for x in range(0, data_size * 4, 4)])
    expected_count = counter[name[0]] * init_size
    if count != expected_count:
      raise BeamAssertException(
          "incorrect count(%s): expected(%d) actual(%d)" %
          (name, expected_count, count))
    return []

  def _verify_data(self, pcol, init_size, data_size):
    read = pcol | 'read' >> ReadAllFromParquet()
    # Branch 1: sum all 'number' values and check the total.
    v1 = (
        read
        | 'get_number' >> Map(lambda x: x['number'])
        | 'sum_globally' >> CombineGlobally(sum)
        | 'validate_number' >>
        FlatMap(lambda x: TestParquetIT._sum_verifier(init_size, data_size, x)))
    # Branch 2: count records per 'name' and check each count.
    v2 = (
        read
        | 'make_pair' >> Map(lambda x: (x['name'], x['number']))
        | 'count_per_key' >> Count.PerKey()
        | 'validate_name' >> FlatMap(
            lambda x: TestParquetIT._count_verifier(init_size, data_size, x)))
    # The verifier branches emit no elements; flattening them with the
    # file-name collection ties cleanup to the verification stages so the
    # generated Parquet files are deleted only after the checks run.
    _ = ((v1, v2, pcol)
         | 'flatten' >> Flatten()
         | 'reshuffle' >> Reshuffle()
         | 'cleanup' >> Map(lambda x: FileSystems.delete([x])))

  def _generate_data(self, p, output_prefix, init_size, data_size):
    # Fan out init_size seed elements, each of which produces data_size
    # records of the form {'name': <4 letters>, 'number': <int>}.
    init_data = list(range(init_size))

    lines = (
        p
        | 'create' >> Create(init_data)
        | 'produce' >> ParDo(ProducerFn(data_size)))

    schema = pa.schema([('name', pa.binary()), ('number', pa.int64())])

    # WriteToParquet returns the names of the files it wrote; the
    # verification stage reads them back and finally deletes them.
    files = lines | 'write' >> WriteToParquet(
        output_prefix, schema, codec='snappy', file_name_suffix='.parquet')

    return files


class ProducerFn(DoFn):
  """Emits `number` records with cycling 4-letter names and increasing ints."""
  def __init__(self, number):
    super().__init__()
    self._number = number
    self._string_index = 0
    self._number_index = 0

  def process(self, element):
    # Reset the cursors for each input element so every producer emits the
    # same deterministic sequence of records.
    self._string_index = 0
    self._number_index = 0
    for _ in range(self._number):
      yield {'name': self.get_string(4), 'number': self.get_int()}

  def get_string(self, length):
    # Take the next `length` letters from a cyclic A..Z sequence.
    s = []
    for _ in range(length):
      s.append(string.ascii_uppercase[self._string_index])
      self._string_index = (self._string_index + 1) % 26
    return ''.join(s)

  def get_int(self):
    i = self._number_index
    self._number_index += 1
    return i


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()
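
# A sketch of how this post-commit suite might be invoked locally; the exact
# runner and pipeline options depend on the environment, and the bucket below
# is a hypothetical placeholder:
#
#   pytest -m it_postcommit apache_beam/io/parquetio_it_test.py \
#     --test-pipeline-options="--runner=TestDirectRunner \
#         --temp_location=gs://<bucket>/tmp"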