github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/dataflow/native_io/iobase_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Tests corresponding to Dataflow's iobase module."""
    19  
    20  # pytype: skip-file
    21  
    22  import unittest
    23  
    24  from apache_beam import Create
    25  from apache_beam import error
    26  from apache_beam import pvalue
    27  from apache_beam.runners.dataflow.native_io.iobase import ConcatPosition
    28  from apache_beam.runners.dataflow.native_io.iobase import DynamicSplitRequest
    29  from apache_beam.runners.dataflow.native_io.iobase import DynamicSplitResultWithPosition
    30  from apache_beam.runners.dataflow.native_io.iobase import NativeSink
    31  from apache_beam.runners.dataflow.native_io.iobase import NativeSinkWriter
    32  from apache_beam.runners.dataflow.native_io.iobase import NativeSource
    33  from apache_beam.runners.dataflow.native_io.iobase import ReaderPosition
    34  from apache_beam.runners.dataflow.native_io.iobase import ReaderProgress
    35  from apache_beam.runners.dataflow.native_io.iobase import _dict_printable_fields
    36  from apache_beam.runners.dataflow.native_io.iobase import _NativeWrite
    37  from apache_beam.testing.test_pipeline import TestPipeline
    38  
    39  
    40  class TestHelperFunctions(unittest.TestCase):
    41    def test_dict_printable_fields(self):
    42      dict_object = {
    43          'key_alpha': '1',
    44          'key_beta': None,
    45          'key_charlie': [],
    46          'key_delta': 2.0,
    47          'key_echo': 'skip_me',
    48          'key_fox': 0
    49      }
    50      skip_fields = [
    51          'key_echo',
    52      ]
    53      self.assertEqual(
    54          sorted(_dict_printable_fields(dict_object, skip_fields)),
    55          ["key_alpha='1'", 'key_delta=2.0', 'key_fox=0'])
    56  
    57  
    58  class TestNativeSource(unittest.TestCase):
    59    def test_reader_method(self):
    60      native_source = NativeSource()
    61      self.assertRaises(NotImplementedError, native_source.reader)
    62  
    63    def test_repr_method(self):
    64      class FakeSource(NativeSource):
    65        """A fake source modeled after BigQuerySource, which inherits from
    66        NativeSource."""
    67        def __init__(
    68            self,
    69            table=None,
    70            dataset=None,
    71            project=None,
    72            query=None,
    73            validate=False,
    74            coder=None,
    75            use_std_sql=False,
    76            flatten_results=True):
    77          self.validate = validate
    78  
    79      fake_source = FakeSource()
    80      self.assertEqual(fake_source.__repr__(), '<FakeSource validate=False>')
    81  
    82  
    83  class TestReaderProgress(unittest.TestCase):
    84    def test_out_of_bounds_percent_complete(self):
    85      with self.assertRaises(ValueError):
    86        ReaderProgress(percent_complete=-0.1)
    87      with self.assertRaises(ValueError):
    88        ReaderProgress(percent_complete=1.1)
    89  
    90    def test_position_property(self):
    91      reader_progress = ReaderProgress(position=ReaderPosition())
    92      self.assertEqual(type(reader_progress.position), ReaderPosition)
    93  
    94    def test_percent_complete_property(self):
    95      reader_progress = ReaderProgress(percent_complete=0.5)
    96      self.assertEqual(reader_progress.percent_complete, 0.5)
    97  
    98  
    99  class TestReaderPosition(unittest.TestCase):
   100    def test_invalid_concat_position_type(self):
   101      with self.assertRaises(AssertionError):
   102        ReaderPosition(concat_position=1)
   103  
   104    def test_valid_concat_position_type(self):
   105      ReaderPosition(concat_position=ConcatPosition(None, None))
   106  
   107  
   108  class TestConcatPosition(unittest.TestCase):
   109    def test_invalid_position_type(self):
   110      with self.assertRaises(AssertionError):
   111        ConcatPosition(None, position=1)
   112  
   113    def test_valid_position_type(self):
   114      ConcatPosition(None, position=ReaderPosition())
   115  
   116  
   117  class TestDynamicSplitRequest(unittest.TestCase):
   118    def test_invalid_progress_type(self):
   119      with self.assertRaises(AssertionError):
   120        DynamicSplitRequest(progress=1)
   121  
   122    def test_valid_progress_type(self):
   123      DynamicSplitRequest(progress=ReaderProgress())
   124  
   125  
   126  class TestDynamicSplitResultWithPosition(unittest.TestCase):
   127    def test_invalid_stop_position_type(self):
   128      with self.assertRaises(AssertionError):
   129        DynamicSplitResultWithPosition(stop_position=1)
   130  
   131    def test_valid_stop_position_type(self):
   132      DynamicSplitResultWithPosition(stop_position=ReaderPosition())
   133  
   134  
   135  class TestNativeSink(unittest.TestCase):
   136    def test_writer_method(self):
   137      native_sink = NativeSink()
   138      self.assertRaises(NotImplementedError, native_sink.writer)
   139  
   140    def test_repr_method(self):
   141      class FakeSink(NativeSink):
   142        """A fake sink modeled after BigQuerySink, which inherits from
   143        NativeSink."""
   144        def __init__(
   145            self,
   146            validate=False,
   147            dataset=None,
   148            project=None,
   149            schema=None,
   150            create_disposition='create',
   151            write_disposition=None,
   152            coder=None):
   153          self.validate = validate
   154  
   155      fake_sink = FakeSink()
   156      self.assertEqual(fake_sink.__repr__(), "<FakeSink ['validate=False']>")
   157  
   158    def test_on_direct_runner(self):
   159      class FakeSink(NativeSink):
   160        """A fake sink outputing a number of elements."""
   161        def __init__(self):
   162          self.written_values = []
   163          self.writer_instance = FakeSinkWriter(self.written_values)
   164  
   165        def writer(self):
   166          return self.writer_instance
   167  
   168      class FakeSinkWriter(NativeSinkWriter):
   169        """A fake sink writer for testing."""
   170        def __init__(self, written_values):
   171          self.written_values = written_values
   172  
   173        def __enter__(self):
   174          return self
   175  
   176        def __exit__(self, *unused_args):
   177          pass
   178  
   179        def Write(self, value):
   180          self.written_values.append(value)
   181  
   182      with TestPipeline() as p:
   183        sink = FakeSink()
   184        p | Create(['a', 'b', 'c']) | _NativeWrite(sink)  # pylint: disable=expression-not-assigned
   185  
   186      self.assertEqual(['a', 'b', 'c'], sorted(sink.written_values))
   187  
   188  
   189  class Test_NativeWrite(unittest.TestCase):
   190    def setUp(self):
   191      self.native_sink = NativeSink()
   192      self.native_write = _NativeWrite(self.native_sink)
   193  
   194    def test_expand_method_pcollection_errors(self):
   195      with self.assertRaises(error.TransformError):
   196        self.native_write.expand(None)
   197      with self.assertRaises(error.TransformError):
   198        pcoll = pvalue.PCollection(pipeline=None)
   199        self.native_write.expand(pcoll)
   200  
   201  
   202  if __name__ == '__main__':
   203    unittest.main()