github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/dataframe/io_test.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import importlib
import math
import os
import platform
import shutil
import tempfile
import typing
import unittest
from datetime import datetime
from io import BytesIO
from io import StringIO

import mock
import pandas as pd
import pandas.testing
import pyarrow
import pytest
from pandas.testing import assert_frame_equal
from parameterized import parameterized

import apache_beam as beam
import apache_beam.io.gcp.bigquery
from apache_beam.dataframe import convert
from apache_beam.dataframe import io
from apache_beam.io import fileio
from apache_beam.io import restriction_trackers
from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper
from apache_beam.io.gcp.internal.clients import bigquery
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to

try:
  from apitools.base.py.exceptions import HttpError
except ImportError:
  HttpError = None

# Get major, minor version
PD_VERSION = tuple(map(int, pd.__version__.split('.')[0:2]))
PYARROW_VERSION = tuple(map(int, pyarrow.__version__.split('.')[0:2]))


class SimpleRow(typing.NamedTuple):
  value: int


class MyRow(typing.NamedTuple):
  timestamp: int
  value: int


@unittest.skipIf(
    platform.system() == 'Windows',
    'https://github.com/apache/beam/issues/20642')
class IOTest(unittest.TestCase):
  def setUp(self):
    self._temp_roots = []

  def tearDown(self):
    for root in self._temp_roots:
      shutil.rmtree(root)

  def temp_dir(self, files=None):
    dir = tempfile.mkdtemp(prefix='beam-test')
    self._temp_roots.append(dir)
    if files:
      for name, contents in files.items():
        with open(os.path.join(dir, name), 'w') as fout:
          fout.write(contents)
    return dir + os.path.sep

  def read_all_lines(self, pattern, delete=False):
    for path in glob.glob(pattern):
      with open(path) as fin:
        # TODO(Py3): yield from
        for line in fin:
          yield line.rstrip('\n')
      if delete:
        os.remove(path)

  def test_read_fwf(self):
    input = self.temp_dir(
        {'all.fwf': '''
A     B
11a   0
37a   1
389a  2
    '''.strip()})
    with beam.Pipeline() as p:
      df = p | io.read_fwf(input + 'all.fwf')
      rows = convert.to_pcollection(df) | beam.Map(tuple)
      assert_that(rows, equal_to([('11a', 0), ('37a', 1), ('389a', 2)]))

  def test_read_write_csv(self):
    input = self.temp_dir({'1.csv': 'a,b\n1,2\n', '2.csv': 'a,b\n3,4\n'})
    output = self.temp_dir()
    with beam.Pipeline() as p:
      df = p | io.read_csv(input + '*.csv')
      df['c'] = df.a + df.b
      df.to_csv(output + 'out.csv', index=False)
    self.assertCountEqual(['a,b,c', '1,2,3', '3,4,7'],
                          set(self.read_all_lines(output + 'out.csv*')))

  def test_sharding_parameters(self):
    data = pd.DataFrame({'label': ['11a', '37a', '389a'], 'rank': [0, 1, 2]})
    output = self.temp_dir()
    with beam.Pipeline() as p:
      df = convert.to_dataframe(p | beam.Create([data]), proxy=data[:0])
      df.to_csv(
          output,
          num_shards=1,
          file_naming=fileio.single_file_naming('out.csv'))
    self.assertEqual(glob.glob(output + '*'), [output + 'out.csv'])

  @pytest.mark.uses_pyarrow
  @unittest.skipIf(
      PD_VERSION >= (1, 4) and PYARROW_VERSION < (1, 0),
      "pandas 1.4 requires at least pyarrow 1.0.1")
  def test_read_write_parquet(self):
    self._run_read_write_test(
        'parquet', {}, {}, dict(check_index=False), ['pyarrow'])

  @parameterized.expand([
      ('csv', dict(index_col=0)),
      ('csv', dict(index_col=0, splittable=True)),
      ('json', dict(orient='index'), dict(orient='index')),
      ('json', dict(orient='columns'), dict(orient='columns')),
      ('json', dict(orient='split'), dict(orient='split')),
      (
          'json',
          dict(orient='values'),
          dict(orient='values'),
          dict(check_index=False, check_names=False)),
      (
          'json',
          dict(orient='records'),
          dict(orient='records'),
          dict(check_index=False)),
      (
          'json',
          dict(orient='records', lines=True),
          dict(orient='records', lines=True),
          dict(check_index=False)),
      ('html', dict(index_col=0), {}, {}, ['lxml']),
      ('excel', dict(index_col=0), {}, {}, ['openpyxl', 'xlrd']),
  ])
  # pylint: disable=dangerous-default-value
  def test_read_write(
      self,
      format,
      read_kwargs={},
      write_kwargs={},
      check_options={},
      requires=()):
    self._run_read_write_test(
        format, read_kwargs, write_kwargs, check_options, requires)

  # pylint: disable=dangerous-default-value
  def _run_read_write_test(
      self,
      format,
      read_kwargs={},
      write_kwargs={},
      check_options={},
      requires=()):

    for module in requires:
      try:
        importlib.import_module(module)
      except ImportError:
        raise unittest.SkipTest('Missing dependency: %s' % module)
    small = pd.DataFrame({'label': ['11a', '37a', '389a'], 'rank': [0, 1, 2]})
    big = pd.DataFrame({'number': list(range(1000))})
    big['float'] = big.number.map(math.sqrt)
    big['text'] = big.number.map(lambda n: 'f' + 'o' * n)

    def frame_equal_to(expected_, check_index=True, check_names=True):
      def check(actual):
        expected = expected_
        try:
          actual = pd.concat(actual)
          if not check_index:
            expected = expected.sort_values(list(
                expected.columns)).reset_index(drop=True)
            actual = actual.sort_values(list(
                actual.columns)).reset_index(drop=True)
          if not check_names:
            actual = actual.rename(
                columns=dict(zip(actual.columns, expected.columns)))
          return assert_frame_equal(expected, actual, check_like=True)
        except:
          print("EXPECTED")
          print(expected)
          print("ACTUAL")
          print(actual)
          raise

      return check

    for df in (small, big):
      with tempfile.TemporaryDirectory() as dir:
        dest = os.path.join(dir, 'out')
        try:
          with beam.Pipeline() as p:
            deferred_df = convert.to_dataframe(
                p | beam.Create([df[::3], df[1::3], df[2::3]]), proxy=df[:0])
            # This does the write.
            getattr(deferred_df, 'to_%s' % format)(dest, **write_kwargs)
          with beam.Pipeline() as p:
            # Now do the read.
            # TODO(robertwb): Allow reading from pcoll of paths to do it all in
            # one pipeline.

            result = convert.to_pcollection(
                p | getattr(io, 'read_%s' % format)(dest + '*', **read_kwargs),
                yield_elements='pandas')
            assert_that(result, frame_equal_to(df, **check_options))
        except:
          os.system('head -n 100 ' + dest + '*')
          raise

  def _run_truncating_file_handle_test(
      self, s, splits, delim=' ', chunk_size=10):
    # Reads s through a _TruncatingFileHandle, requesting a split of the
    # remaining range at each given fraction, and returns the text consumed
    # for each resulting sub-range.
    split_results = []
    next_range = restriction_trackers.OffsetRange(0, len(s))
    for split in list(splits) + [None]:
      tracker = restriction_trackers.OffsetRestrictionTracker(next_range)
      handle = io._TruncatingFileHandle(
          StringIO(s), tracker, splitter=io._DelimSplitter(delim, chunk_size))
      data = ''
      chunk = handle.read(1)
      if split is not None:
        _, next_range = tracker.try_split(split)
      while chunk:
        data += chunk
        chunk = handle.read(7)
      split_results.append(data)
    return split_results

  def test_truncating_filehandle(self):
    self.assertEqual(
        self._run_truncating_file_handle_test('a b c d e', [0.5]),
        ['a b c ', 'd e'])
    self.assertEqual(
        self._run_truncating_file_handle_test('aaaaaaaaaaaaaaXaaa b', [0.5]),
        ['aaaaaaaaaaaaaaXaaa ', 'b'])
    self.assertEqual(
        self._run_truncating_file_handle_test(
            'aa bbbbbbbbbbbbbbbbbbbbbbbbbb ccc ', [0.01, 0.5]),
        ['aa ', 'bbbbbbbbbbbbbbbbbbbbbbbbbb ', 'ccc '])

    numbers = 'x'.join(str(k) for k in range(1000))
    splits = self._run_truncating_file_handle_test(
        numbers, [0.1] * 20, delim='x')
    self.assertEqual(numbers, ''.join(splits))
    # Every non-final chunk should end at the delimiter.
    self.assertTrue(all(s.endswith('x') for s in splits[:-1]))
    self.assertLess(max(len(s) for s in splits), len(numbers) * 0.9 + 10)
    self.assertGreater(
        min(len(s) for s in splits), len(numbers) * 0.9**20 * 0.1)

  def _run_truncating_file_handle_iter_test(self, s, delim=' ', chunk_size=10):
    tracker = restriction_trackers.OffsetRestrictionTracker(
        restriction_trackers.OffsetRange(0, len(s)))
    handle = io._TruncatingFileHandle(
        StringIO(s), tracker, splitter=io._DelimSplitter(delim, chunk_size))
    self.assertEqual(s, ''.join(list(handle)))

  def test_truncating_filehandle_iter(self):
    self._run_truncating_file_handle_iter_test('a b c')
    self._run_truncating_file_handle_iter_test('aaaaaaaaaaaaaaaaaaaa b ccc')
    self._run_truncating_file_handle_iter_test('aaa b cccccccccccccccccccc')
    self._run_truncating_file_handle_iter_test('aaa b ccccccccccccccccc ')

  @parameterized.expand([
      ('defaults', {}),
      ('header', dict(header=1)),
      ('multi_header', dict(header=[0, 1])),
      ('multi_header', dict(header=[0, 1, 4])),
      ('names', dict(names=('m', 'n', 'o'))),
      ('names_and_header', dict(names=('m', 'n', 'o'), header=0)),
      ('skip_blank_lines', dict(header=4, skip_blank_lines=True)),
      ('skip_blank_lines', dict(header=4, skip_blank_lines=False)),
      ('comment', dict(comment='X', header=4)),
      ('comment', dict(comment='X', header=[0, 3])),
      ('skiprows', dict(skiprows=0, header=[0, 1])),
      ('skiprows', dict(skiprows=[1], header=[0, 3], skip_blank_lines=False)),
      ('skiprows', dict(skiprows=[0, 1], header=[0, 1], comment='X')),
  ])
  def test_csv_splitter(self, name, kwargs):
    def assert_frame_equal(expected, actual):
      try:
        pandas.testing.assert_frame_equal(expected, actual)
      except AssertionError:
        print("Expected:\n", expected)
        print("Actual:\n", actual)
        raise

    def read_truncated_csv(start, stop):
      return pd.read_csv(
          io._TruncatingFileHandle(
              BytesIO(contents.encode('ascii')),
              restriction_trackers.OffsetRestrictionTracker(
                  restriction_trackers.OffsetRange(start, stop)),
              splitter=io._TextFileSplitter((), kwargs, read_chunk_size=7)),
          index_col=0,
          **kwargs)

    contents = '''
a0, a1, a2
b0, b1, b2

X , c1, c2
e0, e1, e2
f0, f1, f2
w, daaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaata, w
x, daaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaata, x
y, daaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaata, y
z, daaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaata, z
    '''.strip()
    expected = pd.read_csv(StringIO(contents), index_col=0, **kwargs)

    one_shard = read_truncated_csv(0, len(contents))
    assert_frame_equal(expected, one_shard)

    equal_shards = pd.concat([
        read_truncated_csv(0, len(contents) // 2),
        read_truncated_csv(len(contents) // 2, len(contents)),
    ])
    assert_frame_equal(expected, equal_shards)

    three_shards = pd.concat([
        read_truncated_csv(0, len(contents) // 3),
        read_truncated_csv(len(contents) // 3, len(contents) * 2 // 3),
        read_truncated_csv(len(contents) * 2 // 3, len(contents)),
    ])
    assert_frame_equal(expected, three_shards)

    # https://github.com/pandas-dev/pandas/issues/38292
    if not isinstance(kwargs.get('header'), list):
      split_in_header = pd.concat([
          read_truncated_csv(0, 1),
          read_truncated_csv(1, len(contents)),
      ])
      assert_frame_equal(expected, split_in_header)

    if not kwargs:
      # Make sure we're correct as we cross the header boundary.
      # We don't need to do this for every permutation.
      header_end = contents.index('a2') + 3
      for split in range(header_end - 2, header_end + 2):
        split_at_header = pd.concat([
            read_truncated_csv(0, split),
            read_truncated_csv(split, len(contents)),
        ])
        assert_frame_equal(expected, split_at_header)

  def test_file_not_found(self):
    with self.assertRaisesRegex(FileNotFoundError, r'/tmp/fake_dir/\*\*'):
      with beam.Pipeline() as p:
        _ = p | io.read_csv('/tmp/fake_dir/**')

  def test_windowed_write(self):
    output = self.temp_dir()
    with beam.Pipeline() as p:
      pc = (
          p | beam.Create([MyRow(timestamp=i, value=i % 3) for i in range(20)])
          | beam.Map(lambda v: beam.window.TimestampedValue(v, v.timestamp)).
          with_output_types(MyRow)
          | beam.WindowInto(
              beam.window.FixedWindows(10)).with_output_types(MyRow))

      deferred_df = convert.to_dataframe(pc)
      deferred_df.to_csv(output + 'out.csv', index=False)

    first_window_files = (
        f'{output}out.csv-'
        f'{datetime.utcfromtimestamp(0).isoformat()}*')
    self.assertCountEqual(
        ['timestamp,value'] + [f'{i},{i % 3}' for i in range(10)],
        set(self.read_all_lines(first_window_files, delete=True)))

    second_window_files = (
        f'{output}out.csv-'
        f'{datetime.utcfromtimestamp(10).isoformat()}*')
    self.assertCountEqual(
        ['timestamp,value'] + [f'{i},{i%3}' for i in range(10, 20)],
        set(self.read_all_lines(second_window_files, delete=True)))

    # Check that we've read (and removed) every output file
    self.assertEqual(len(glob.glob(f'{output}out.csv*')), 0)

  def test_double_write(self):
    output = self.temp_dir()
    with beam.Pipeline() as p:
      pc1 = p | 'create pc1' >> beam.Create(
          [SimpleRow(value=i) for i in [1, 2]])
      pc2 = p | 'create pc2' >> beam.Create(
          [SimpleRow(value=i) for i in [3, 4]])

      deferred_df1 = convert.to_dataframe(pc1)
      deferred_df2 = convert.to_dataframe(pc2)

      deferred_df1.to_csv(
          f'{output}out1.csv',
          transform_label="Writing to csv PC1",
          index=False)
      deferred_df2.to_csv(
          f'{output}out2.csv',
          transform_label="Writing to csv PC2",
          index=False)

    self.assertCountEqual(['value', '1', '2'],
                          set(self.read_all_lines(output + 'out1.csv*')))
    self.assertCountEqual(['value', '3', '4'],
                          set(self.read_all_lines(output + 'out2.csv*')))


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class ReadGbqTransformTests(unittest.TestCase):
  @mock.patch.object(BigQueryWrapper, 'get_table')
  def test_bad_schema_public_api_direct_read(self, get_table):
    try:
      bigquery.TableFieldSchema
    except AttributeError:
      raise ValueError('Please install GCP Dependencies.')
    fields = [
        bigquery.TableFieldSchema(name='stn', type='DOUBLE', mode="NULLABLE"),
        bigquery.TableFieldSchema(name='temp', type='FLOAT64', mode="REPEATED"),
        bigquery.TableFieldSchema(name='count', type='INTEGER', mode=None)
    ]
    schema = bigquery.TableSchema(fields=fields)
    table = apache_beam.io.gcp.internal.clients.bigquery. \
        bigquery_v2_messages.Table(
        schema=schema)
    get_table.return_value = table

    with self.assertRaisesRegex(ValueError,
                                "Encountered an unsupported type: 'DOUBLE'"):
      p = apache_beam.Pipeline()
      _ = p | apache_beam.dataframe.io.read_gbq(
          table="dataset.sample_table", use_bqstorage_api=True)

  def test_unsupported_callable(self):
    def filterTable(table):
      if table is not None:
        return table

    res = filterTable
    with self.assertRaisesRegex(TypeError,
                                'ReadFromBigQuery: table must be of type string'
                                '; got a callable instead'):
      p = beam.Pipeline()
      _ = p | beam.dataframe.io.read_gbq(table=res)

  def test_ReadGbq_unsupported_param(self):
    with self.assertRaisesRegex(ValueError,
                                r"""Encountered unsupported parameter\(s\) """
                                r"""in read_gbq: dict_keys\(\['reauth']\)"""):
      p = beam.Pipeline()
      _ = p | beam.dataframe.io.read_gbq(
          table="table", use_bqstorage_api=False, reauth="true_config")


if __name__ == '__main__':
  unittest.main()