github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/dataframe/io_test.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

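"""Tests for the DataFrame sources and sinks in apache_beam.dataframe.io."""
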
import glob
import importlib
import math
import os
import platform
import shutil
import tempfile
import typing
import unittest
from datetime import datetime
from io import BytesIO
from io import StringIO

import mock
import pandas as pd
import pandas.testing
import pyarrow
import pytest
from pandas.testing import assert_frame_equal
from parameterized import parameterized

import apache_beam as beam
import apache_beam.io.gcp.bigquery
from apache_beam.dataframe import convert
from apache_beam.dataframe import io
from apache_beam.io import fileio
from apache_beam.io import restriction_trackers
from apache_beam.io.gcp.bigquery_tools import BigQueryWrapper
from apache_beam.io.gcp.internal.clients import bigquery
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to

try:
  from apitools.base.py.exceptions import HttpError
except ImportError:
  HttpError = None

# Get major, minor version
PD_VERSION = tuple(map(int, pd.__version__.split('.')[0:2]))
PYARROW_VERSION = tuple(map(int, pyarrow.__version__.split('.')[0:2]))


class SimpleRow(typing.NamedTuple):
  value: int


class MyRow(typing.NamedTuple):
  timestamp: int
  value: int


@unittest.skipIf(
    platform.system() == 'Windows',
    'https://github.com/apache/beam/issues/20642')
class IOTest(unittest.TestCase):
  def setUp(self):
    self._temp_roots = []

  def tearDown(self):
    for root in self._temp_roots:
      shutil.rmtree(root)

  def temp_dir(self, files=None):
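    """Creates a temporary directory that is removed in tearDown.

    If `files` is given, it maps file names to text contents that are
    written into the new directory. Returns the directory path with a
    trailing path separator so callers can append file names directly.
    """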
    dir = tempfile.mkdtemp(prefix='beam-test')
    self._temp_roots.append(dir)
    if files:
      for name, contents in files.items():
        with open(os.path.join(dir, name), 'w') as fout:
          fout.write(contents)
    return dir + os.path.sep

  def read_all_lines(self, pattern, delete=False):
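    """Yields the newline-stripped lines of every file matching `pattern`,
    optionally deleting each file after it has been read."""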
    for path in glob.glob(pattern):
      with open(path) as fin:
        yield from (line.rstrip('\n') for line in fin)
      if delete:
        os.remove(path)

  def test_read_fwf(self):
    input = self.temp_dir(
        {'all.fwf': '''
A     B
11a   0
37a   1
389a  2
    '''.strip()})
    with beam.Pipeline() as p:
      df = p | io.read_fwf(input + 'all.fwf')
      rows = convert.to_pcollection(df) | beam.Map(tuple)
      assert_that(rows, equal_to([('11a', 0), ('37a', 1), ('389a', 2)]))

  def test_read_write_csv(self):
    input = self.temp_dir({'1.csv': 'a,b\n1,2\n', '2.csv': 'a,b\n3,4\n'})
    output = self.temp_dir()
    with beam.Pipeline() as p:
      df = p | io.read_csv(input + '*.csv')
      df['c'] = df.a + df.b
      df.to_csv(output + 'out.csv', index=False)
    self.assertCountEqual(['a,b,c', '1,2,3', '3,4,7'],
                          set(self.read_all_lines(output + 'out.csv*')))

  def test_sharding_parameters(self):
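    """num_shards and file_naming should be forwarded to the underlying file
    sink, so a single-shard write produces exactly one file with the
    requested name."""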
    data = pd.DataFrame({'label': ['11a', '37a', '389a'], 'rank': [0, 1, 2]})
    output = self.temp_dir()
    with beam.Pipeline() as p:
      df = convert.to_dataframe(p | beam.Create([data]), proxy=data[:0])
      df.to_csv(
          output,
          num_shards=1,
          file_naming=fileio.single_file_naming('out.csv'))
    self.assertEqual(glob.glob(output + '*'), [output + 'out.csv'])

  @pytest.mark.uses_pyarrow
  @unittest.skipIf(
      PD_VERSION >= (1, 4) and PYARROW_VERSION < (1, 0),
      "pandas 1.4 requires at least pyarrow 1.0.1")
  def test_read_write_parquet(self):
    self._run_read_write_test(
        'parquet', {}, {}, dict(check_index=False), ['pyarrow'])

  @parameterized.expand([
      ('csv', dict(index_col=0)),
      ('csv', dict(index_col=0, splittable=True)),
      ('json', dict(orient='index'), dict(orient='index')),
      ('json', dict(orient='columns'), dict(orient='columns')),
      ('json', dict(orient='split'), dict(orient='split')),
      (
          'json',
          dict(orient='values'),
          dict(orient='values'),
          dict(check_index=False, check_names=False)),
      (
          'json',
          dict(orient='records'),
          dict(orient='records'),
          dict(check_index=False)),
      (
          'json',
          dict(orient='records', lines=True),
          dict(orient='records', lines=True),
          dict(check_index=False)),
      ('html', dict(index_col=0), {}, {}, ['lxml']),
      ('excel', dict(index_col=0), {}, {}, ['openpyxl', 'xlrd']),
  ])
  # pylint: disable=dangerous-default-value
  def test_read_write(
      self,
      format,
      read_kwargs={},
      write_kwargs={},
      check_options={},
      requires=()):
    self._run_read_write_test(
        format, read_kwargs, write_kwargs, check_options, requires)

  # pylint: disable=dangerous-default-value
  def _run_read_write_test(
      self,
      format,
      read_kwargs={},
      write_kwargs={},
      check_options={},
      requires=()):
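    """Round-trips a small and a large DataFrame through to_<format> and
    read_<format>, asserting that the data read back matches what was
    written, up to the given check_options."""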

    for module in requires:
      try:
        importlib.import_module(module)
      except ImportError:
        raise unittest.SkipTest('Missing dependency: %s' % module)
    small = pd.DataFrame({'label': ['11a', '37a', '389a'], 'rank': [0, 1, 2]})
    big = pd.DataFrame({'number': list(range(1000))})
    big['float'] = big.number.map(math.sqrt)
    big['text'] = big.number.map(lambda n: 'f' + 'o' * n)

    def frame_equal_to(expected_, check_index=True, check_names=True):
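      """Returns an assert_that matcher that concatenates the actual
      DataFrame partitions and compares the result to `expected_`,
      optionally ignoring the index and/or the column names."""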
      def check(actual):
        expected = expected_
        try:
          actual = pd.concat(actual)
          if not check_index:
            expected = expected.sort_values(list(
                expected.columns)).reset_index(drop=True)
            actual = actual.sort_values(list(
                actual.columns)).reset_index(drop=True)
          if not check_names:
            actual = actual.rename(
                columns=dict(zip(actual.columns, expected.columns)))
          return assert_frame_equal(expected, actual, check_like=True)
        except:
          print("EXPECTED")
          print(expected)
          print("ACTUAL")
          print(actual)
          raise

      return check

    for df in (small, big):
      with tempfile.TemporaryDirectory() as dir:
        dest = os.path.join(dir, 'out')
        try:
          with beam.Pipeline() as p:
            deferred_df = convert.to_dataframe(
                p | beam.Create([df[::3], df[1::3], df[2::3]]), proxy=df[:0])
            # This does the write.
            getattr(deferred_df, 'to_%s' % format)(dest, **write_kwargs)
          with beam.Pipeline() as p:
            # Now do the read.
            # TODO(robertwb): Allow reading from pcoll of paths to do it all in
            # one pipeline.

            result = convert.to_pcollection(
                p | getattr(io, 'read_%s' % format)(dest + '*', **read_kwargs),
                yield_elements='pandas')
            assert_that(result, frame_equal_to(df, **check_options))
        except:
          os.system('head -n 100 ' + dest + '*')
          raise

  def _run_truncating_file_handle_test(
      self, s, splits, delim=' ', chunk_size=10):
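    """Reads `s` through a _TruncatingFileHandle, applying the given dynamic
    split fractions one at a time, and returns the text produced for each
    resulting restriction. Concatenating the results should reproduce `s`."""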
    split_results = []
    next_range = restriction_trackers.OffsetRange(0, len(s))
    for split in list(splits) + [None]:
      tracker = restriction_trackers.OffsetRestrictionTracker(next_range)
      handle = io._TruncatingFileHandle(
          StringIO(s), tracker, splitter=io._DelimSplitter(delim, chunk_size))
      data = ''
      chunk = handle.read(1)
      if split is not None:
        _, next_range = tracker.try_split(split)
      while chunk:
        data += chunk
        chunk = handle.read(7)
      split_results.append(data)
    return split_results

  def test_truncating_filehandle(self):
    self.assertEqual(
        self._run_truncating_file_handle_test('a b c d e', [0.5]),
        ['a b c ', 'd e'])
    self.assertEqual(
        self._run_truncating_file_handle_test('aaaaaaaaaaaaaaXaaa b', [0.5]),
        ['aaaaaaaaaaaaaaXaaa ', 'b'])
    self.assertEqual(
        self._run_truncating_file_handle_test(
            'aa bbbbbbbbbbbbbbbbbbbbbbbbbb ccc ', [0.01, 0.5]),
        ['aa ', 'bbbbbbbbbbbbbbbbbbbbbbbbbb ', 'ccc '])

    numbers = 'x'.join(str(k) for k in range(1000))
    splits = self._run_truncating_file_handle_test(
        numbers, [0.1] * 20, delim='x')
    self.assertEqual(numbers, ''.join(splits))
    self.assertTrue(all(s.endswith('x') for s in splits[:-1]))
    self.assertLess(max(len(s) for s in splits), len(numbers) * 0.9 + 10)
    self.assertGreater(
        min(len(s) for s in splits), len(numbers) * 0.9**20 * 0.1)

  def _run_truncating_file_handle_iter_test(self, s, delim=' ', chunk_size=10):
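    """Iterating over a _TruncatingFileHandle with an unsplit restriction
    should yield the complete contents of the underlying file."""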
    tracker = restriction_trackers.OffsetRestrictionTracker(
        restriction_trackers.OffsetRange(0, len(s)))
    handle = io._TruncatingFileHandle(
        StringIO(s), tracker, splitter=io._DelimSplitter(delim, chunk_size))
    self.assertEqual(s, ''.join(list(handle)))

  def test_truncating_filehandle_iter(self):
    self._run_truncating_file_handle_iter_test('a b c')
    self._run_truncating_file_handle_iter_test('aaaaaaaaaaaaaaaaaaaa b ccc')
    self._run_truncating_file_handle_iter_test('aaa b cccccccccccccccccccc')
    self._run_truncating_file_handle_iter_test('aaa b ccccccccccccccccc ')

  @parameterized.expand([
      ('defaults', {}),
      ('header', dict(header=1)),
      ('multi_header', dict(header=[0, 1])),
      ('multi_header', dict(header=[0, 1, 4])),
      ('names', dict(names=('m', 'n', 'o'))),
      ('names_and_header', dict(names=('m', 'n', 'o'), header=0)),
      ('skip_blank_lines', dict(header=4, skip_blank_lines=True)),
      ('skip_blank_lines', dict(header=4, skip_blank_lines=False)),
      ('comment', dict(comment='X', header=4)),
      ('comment', dict(comment='X', header=[0, 3])),
      ('skiprows', dict(skiprows=0, header=[0, 1])),
      ('skiprows', dict(skiprows=[1], header=[0, 3], skip_blank_lines=False)),
      ('skiprows', dict(skiprows=[0, 1], header=[0, 1], comment='X')),
  ])
  def test_csv_splitter(self, name, kwargs):
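    """Verifies that _TextFileSplitter picks safe split points for CSV: the
    input parsed as one, two, or three shards (including splits in or near
    the header) must reconstruct the same DataFrame that pandas produces
    from the unsplit input."""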
    def assert_frame_equal(expected, actual):
      try:
        pandas.testing.assert_frame_equal(expected, actual)
      except AssertionError:
        print("Expected:\n", expected)
        print("Actual:\n", actual)
        raise

    def read_truncated_csv(start, stop):
      return pd.read_csv(
          io._TruncatingFileHandle(
              BytesIO(contents.encode('ascii')),
              restriction_trackers.OffsetRestrictionTracker(
                  restriction_trackers.OffsetRange(start, stop)),
              splitter=io._TextFileSplitter((), kwargs, read_chunk_size=7)),
          index_col=0,
          **kwargs)

    contents = '''
    a0, a1, a2
    b0, b1, b2

X     , c1, c2
    e0, e1, e2
    f0, f1, f2
    w, daaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaata, w
    x, daaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaata, x
    y, daaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaata, y
    z, daaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaata, z
    '''.strip()
    expected = pd.read_csv(StringIO(contents), index_col=0, **kwargs)

    one_shard = read_truncated_csv(0, len(contents))
    assert_frame_equal(expected, one_shard)

    equal_shards = pd.concat([
        read_truncated_csv(0, len(contents) // 2),
        read_truncated_csv(len(contents) // 2, len(contents)),
    ])
    assert_frame_equal(expected, equal_shards)

    three_shards = pd.concat([
        read_truncated_csv(0, len(contents) // 3),
        read_truncated_csv(len(contents) // 3, len(contents) * 2 // 3),
        read_truncated_csv(len(contents) * 2 // 3, len(contents)),
    ])
    assert_frame_equal(expected, three_shards)

    # https://github.com/pandas-dev/pandas/issues/38292
    if not isinstance(kwargs.get('header'), list):
      split_in_header = pd.concat([
          read_truncated_csv(0, 1),
          read_truncated_csv(1, len(contents)),
      ])
      assert_frame_equal(expected, split_in_header)

    if not kwargs:
      # Make sure we're correct as we cross the header boundary.
      # We don't need to do this for every permutation.
      header_end = contents.index('a2') + 3
      for split in range(header_end - 2, header_end + 2):
        split_at_header = pd.concat([
            read_truncated_csv(0, split),
            read_truncated_csv(split, len(contents)),
        ])
        assert_frame_equal(expected, split_at_header)

  def test_file_not_found(self):
    with self.assertRaisesRegex(FileNotFoundError, r'/tmp/fake_dir/\*\*'):
      with beam.Pipeline() as p:
        _ = p | io.read_csv('/tmp/fake_dir/**')

  def test_windowed_write(self):
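    """Writes a windowed PCollection to CSV and checks that each fixed
    window produces its own output files, whose names include the window
    start timestamp."""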
    output = self.temp_dir()
    with beam.Pipeline() as p:
      pc = (
          p | beam.Create([MyRow(timestamp=i, value=i % 3) for i in range(20)])
          | beam.Map(lambda v: beam.window.TimestampedValue(v, v.timestamp)).
          with_output_types(MyRow)
          | beam.WindowInto(
              beam.window.FixedWindows(10)).with_output_types(MyRow))

      deferred_df = convert.to_dataframe(pc)
      deferred_df.to_csv(output + 'out.csv', index=False)

    first_window_files = (
        f'{output}out.csv-'
        f'{datetime.utcfromtimestamp(0).isoformat()}*')
    self.assertCountEqual(
        ['timestamp,value'] + [f'{i},{i % 3}' for i in range(10)],
        set(self.read_all_lines(first_window_files, delete=True)))

    second_window_files = (
        f'{output}out.csv-'
        f'{datetime.utcfromtimestamp(10).isoformat()}*')
    self.assertCountEqual(
        ['timestamp,value'] + [f'{i},{i % 3}' for i in range(10, 20)],
        set(self.read_all_lines(second_window_files, delete=True)))

    # Check that we've read (and removed) every output file
    self.assertEqual(len(glob.glob(f'{output}out.csv*')), 0)

  def test_double_write(self):
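    """Two to_csv writes in the same pipeline need distinct transform labels
    and should produce independent outputs."""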
    output = self.temp_dir()
    with beam.Pipeline() as p:
      pc1 = p | 'create pc1' >> beam.Create(
          [SimpleRow(value=i) for i in [1, 2]])
      pc2 = p | 'create pc2' >> beam.Create(
          [SimpleRow(value=i) for i in [3, 4]])

      deferred_df1 = convert.to_dataframe(pc1)
      deferred_df2 = convert.to_dataframe(pc2)

      deferred_df1.to_csv(
          f'{output}out1.csv',
          transform_label="Writing to csv PC1",
          index=False)
      deferred_df2.to_csv(
          f'{output}out2.csv',
          transform_label="Writing to csv PC2",
          index=False)

    self.assertCountEqual(['value', '1', '2'],
                          set(self.read_all_lines(output + 'out1.csv*')))
    self.assertCountEqual(['value', '3', '4'],
                          set(self.read_all_lines(output + 'out2.csv*')))


@unittest.skipIf(HttpError is None, 'GCP dependencies are not installed')
class ReadGbqTransformTests(unittest.TestCase):
  @mock.patch.object(BigQueryWrapper, 'get_table')
  def test_bad_schema_public_api_direct_read(self, get_table):
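    """read_gbq with use_bqstorage_api=True should reject a schema containing
    the unsupported legacy type DOUBLE."""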
    try:
      bigquery.TableFieldSchema
    except AttributeError:
      raise ValueError('Please install GCP Dependencies.')
    fields = [
        bigquery.TableFieldSchema(name='stn', type='DOUBLE', mode="NULLABLE"),
        bigquery.TableFieldSchema(name='temp', type='FLOAT64', mode="REPEATED"),
        bigquery.TableFieldSchema(name='count', type='INTEGER', mode=None)
    ]
    schema = bigquery.TableSchema(fields=fields)
    table = apache_beam.io.gcp.internal.clients.bigquery. \
        bigquery_v2_messages.Table(schema=schema)
    get_table.return_value = table

    with self.assertRaisesRegex(ValueError,
                                "Encountered an unsupported type: 'DOUBLE'"):
      p = apache_beam.Pipeline()
      _ = p | apache_beam.dataframe.io.read_gbq(
          table="dataset.sample_table", use_bqstorage_api=True)

  def test_unsupported_callable(self):
    def filterTable(table):
      if table is not None:
        return table

    res = filterTable
    with self.assertRaisesRegex(TypeError,
                                'ReadFromBigQuery: table must be of type string'
                                '; got a callable instead'):
      p = beam.Pipeline()
      _ = p | beam.dataframe.io.read_gbq(table=res)

  def test_ReadGbq_unsupported_param(self):
    with self.assertRaisesRegex(ValueError,
                                r"""Encountered unsupported parameter\(s\) """
                                r"""in read_gbq: dict_keys\(\['reauth']\)"""):
      p = beam.Pipeline()
      _ = p | beam.dataframe.io.read_gbq(
          table="table", use_bqstorage_api=False, reauth="true_config")


if __name__ == '__main__':
  unittest.main()