# Source: github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/io/parquetio_it_test.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pytype: skip-file

"""Integration test: write generated records with WriteToParquet, read them
back with ReadAllFromParquet, verify them, then delete the written files."""

import logging
import string
import unittest
from collections import Counter

import pytest

from apache_beam import Create
from apache_beam import DoFn
from apache_beam import FlatMap
from apache_beam import Flatten
from apache_beam import Map
from apache_beam import ParDo
from apache_beam import Reshuffle
from apache_beam.io.filesystems import FileSystems
from apache_beam.io.parquetio import ReadAllFromParquet
from apache_beam.io.parquetio import WriteToParquet
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import BeamAssertException
from apache_beam.transforms import CombineGlobally
from apache_beam.transforms.combiners import Count

try:
  import pyarrow as pa
except ImportError:
  pa = None


@unittest.skipIf(pa is None, "PyArrow is not installed.")
class TestParquetIT(unittest.TestCase):
  """Round-trips generated records through Parquet files and checks them.

  ``init_size`` seed elements each make ProducerFn emit ``data_size``
  records. The records are written with WriteToParquet and read back with
  ReadAllFromParquet. The test then checks the global sum of ``number`` and
  the per-``name`` record counts, and finally deletes the written files.
  """

  @pytest.mark.it_postcommit
  def test_parquetio_it(self):
    file_prefix = "parquet_it_test"
    init_size = 10
    data_size = 20000
    with TestPipeline(is_integration_test=True) as p:
      pcol = self._generate_data(p, file_prefix, init_size, data_size)
      self._verify_data(pcol, init_size, data_size)

  @staticmethod
  def _sum_verifier(init_size, data_size, x):
    """Check the globally combined sum of the ``number`` column.

    Each of the ``init_size`` ProducerFn invocations emits the numbers
    ``0 .. data_size - 1``. The expected total is therefore
    ``sum(range(data_size)) * init_size``.

    Args:
      init_size: number of seed elements fed to ProducerFn.
      data_size: number of records ProducerFn emits per seed element.
      x: the combined sum produced by the pipeline.

    Returns:
      An empty list, so this can be used as a FlatMap body.

    Raises:
      BeamAssertException: if ``x`` differs from the expected sum.
    """
    expected = sum(range(data_size)) * init_size
    if x != expected:
      raise BeamAssertException(
          "incorrect sum: expected(%d) actual(%d)" % (expected, x))
    return []

  @staticmethod
  def _count_verifier(init_size, data_size, x):
    """Check the record count for one ``(name, count)`` pair.

    Args:
      init_size: number of seed elements fed to ProducerFn.
      data_size: number of records ProducerFn emits per seed element.
      x: a ``(name, count)`` pair from Count.PerKey. ``name`` is bytes
        because the column is declared ``pa.binary()`` when writing.

    Returns:
      An empty list, so this can be used as a FlatMap body.

    Raises:
      BeamAssertException: if the count differs from the expected count.
    """
    name, count = x[0].decode('utf-8'), x[1]
    # ProducerFn takes 4 consecutive letters per record and wraps at 26. The
    # i-th record's name therefore starts at ascii_uppercase[(4 * i) % 26],
    # and that first letter determines the whole name. A distinct loop
    # variable avoids shadowing the ``x`` parameter.
    counter = Counter(
        string.ascii_uppercase[i % 26] for i in range(0, data_size * 4, 4))
    expected_count = counter[name[0]] * init_size
    if count != expected_count:
      raise BeamAssertException(
          "incorrect count(%s): expected(%d) actual(%d)" %
          (name, expected_count, count))
    return []

  def _verify_data(self, pcol, init_size, data_size):
    """Read back the written files, verify them, then delete the files.

    Args:
      pcol: PCollection of written file names, as returned by WriteToParquet.
      init_size: number of seed elements fed to ProducerFn.
      data_size: number of records ProducerFn emits per seed element.
    """
    read = pcol | 'read' >> ReadAllFromParquet()
    v1 = (
        read
        | 'get_number' >> Map(lambda x: x['number'])
        | 'sum_globally' >> CombineGlobally(sum)
        | 'validate_number' >>
        FlatMap(lambda x: TestParquetIT._sum_verifier(init_size, data_size, x)))
    v2 = (
        read
        | 'make_pair' >> Map(lambda x: (x['name'], x['number']))
        | 'count_per_key' >> Count.PerKey()
        | 'validate_name' >> FlatMap(
            lambda x: TestParquetIT._count_verifier(init_size, data_size, x)))
    # The verifiers emit nothing, so only the file names from ``pcol`` reach
    # the cleanup step. Flattening them in with v1/v2 before the Reshuffle
    # appears intended to make deletion wait for verification.
    # NOTE(review): the ordering relies on the runner's grouping semantics
    # (batch); confirm before reusing this pattern elsewhere.
    _ = ((v1, v2, pcol)
         | 'flatten' >> Flatten()
         | 'reshuffle' >> Reshuffle()
         | 'cleanup' >> Map(lambda x: FileSystems.delete([x])))

  def _generate_data(self, p, output_prefix, init_size, data_size):
    """Generate ``init_size * data_size`` records and write them to Parquet.

    Args:
      p: the pipeline to attach transforms to.
      output_prefix: file name prefix for WriteToParquet.
      init_size: number of seed elements fed to ProducerFn.
      data_size: number of records ProducerFn emits per seed element.

    Returns:
      The PCollection returned by WriteToParquet (the written file names,
      as consumed by ReadAllFromParquet in _verify_data).
    """
    init_data = list(range(init_size))

    lines = (
        p
        | 'create' >> Create(init_data)
        | 'produce' >> ParDo(ProducerFn(data_size)))

    schema = pa.schema([('name', pa.binary()), ('number', pa.int64())])

    files = lines | 'write' >> WriteToParquet(
        output_prefix, schema, codec='snappy', file_name_suffix='.parquet')

    return files


class ProducerFn(DoFn):
  """Emit ``number`` deterministic records for every input element.

  Each record is ``{'name': <4 uppercase letters>, 'number': <int>}``. Names
  cycle through the alphabet 4 letters at a time, and numbers count up from
  0. Both counters are reset at the start of every ``process`` call, so
  every input element yields the same sequence.
  """
  def __init__(self, number):
    super().__init__()
    self._number = number
    self._string_index = 0
    self._number_index = 0

  def process(self, element):
    self._string_index = 0
    self._number_index = 0
    for _ in range(self._number):
      yield {'name': self.get_string(4), 'number': self.get_int()}

  def get_string(self, length):
    """Return the next ``length`` letters of the cyclic A-Z sequence."""
    s = []
    for _ in range(length):
      s.append(string.ascii_uppercase[self._string_index])
      self._string_index = (self._string_index + 1) % 26
    return ''.join(s)

  def get_int(self):
    """Return the next integer of the 0, 1, 2, ... sequence."""
    i = self._number_index
    self._number_index = self._number_index + 1
    return i


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()