github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/dataflow/native_io/iobase_test.py

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Tests corresponding to Dataflow's iobase module."""

# pytype: skip-file

import unittest

from apache_beam import Create
from apache_beam import error
from apache_beam import pvalue
from apache_beam.runners.dataflow.native_io.iobase import ConcatPosition
from apache_beam.runners.dataflow.native_io.iobase import DynamicSplitRequest
from apache_beam.runners.dataflow.native_io.iobase import DynamicSplitResultWithPosition
from apache_beam.runners.dataflow.native_io.iobase import NativeSink
from apache_beam.runners.dataflow.native_io.iobase import NativeSinkWriter
from apache_beam.runners.dataflow.native_io.iobase import NativeSource
from apache_beam.runners.dataflow.native_io.iobase import ReaderPosition
from apache_beam.runners.dataflow.native_io.iobase import ReaderProgress
from apache_beam.runners.dataflow.native_io.iobase import _dict_printable_fields
from apache_beam.runners.dataflow.native_io.iobase import _NativeWrite
from apache_beam.testing.test_pipeline import TestPipeline


class TestHelperFunctions(unittest.TestCase):
  def test_dict_printable_fields(self):
    dict_object = {
        'key_alpha': '1',
        'key_beta': None,
        'key_charlie': [],
        'key_delta': 2.0,
        'key_echo': 'skip_me',
        'key_fox': 0
    }
    skip_fields = [
        'key_echo',
    ]
    self.assertEqual(
        sorted(_dict_printable_fields(dict_object, skip_fields)),
        ["key_alpha='1'", 'key_delta=2.0', 'key_fox=0'])


class TestNativeSource(unittest.TestCase):
  def test_reader_method(self):
    native_source = NativeSource()
    self.assertRaises(NotImplementedError, native_source.reader)

  def test_repr_method(self):
    class FakeSource(NativeSource):
      """A fake source modeled after BigQuerySource, which inherits from
      NativeSource."""
      def __init__(
          self,
          table=None,
          dataset=None,
          project=None,
          query=None,
          validate=False,
          coder=None,
          use_std_sql=False,
          flatten_results=True):
        self.validate = validate

    fake_source = FakeSource()
    self.assertEqual(fake_source.__repr__(), '<FakeSource validate=False>')


class TestReaderProgress(unittest.TestCase):
  def test_out_of_bounds_percent_complete(self):
    with self.assertRaises(ValueError):
      ReaderProgress(percent_complete=-0.1)
    with self.assertRaises(ValueError):
      ReaderProgress(percent_complete=1.1)

  def test_position_property(self):
    reader_progress = ReaderProgress(position=ReaderPosition())
    self.assertEqual(type(reader_progress.position), ReaderPosition)

  def test_percent_complete_property(self):
    reader_progress = ReaderProgress(percent_complete=0.5)
    self.assertEqual(reader_progress.percent_complete, 0.5)


class TestReaderPosition(unittest.TestCase):
  def test_invalid_concat_position_type(self):
    with self.assertRaises(AssertionError):
      ReaderPosition(concat_position=1)

  def test_valid_concat_position_type(self):
    ReaderPosition(concat_position=ConcatPosition(None, None))


class TestConcatPosition(unittest.TestCase):
  def test_invalid_position_type(self):
    with self.assertRaises(AssertionError):
      ConcatPosition(None, position=1)

  def test_valid_position_type(self):
    ConcatPosition(None, position=ReaderPosition())


class TestDynamicSplitRequest(unittest.TestCase):
  def test_invalid_progress_type(self):
    with self.assertRaises(AssertionError):
      DynamicSplitRequest(progress=1)

  def test_valid_progress_type(self):
    DynamicSplitRequest(progress=ReaderProgress())


class TestDynamicSplitResultWithPosition(unittest.TestCase):
  def test_invalid_stop_position_type(self):
    with self.assertRaises(AssertionError):
      DynamicSplitResultWithPosition(stop_position=1)

  def test_valid_stop_position_type(self):
    DynamicSplitResultWithPosition(stop_position=ReaderPosition())


class TestNativeSink(unittest.TestCase):
  def test_writer_method(self):
    native_sink = NativeSink()
    self.assertRaises(NotImplementedError, native_sink.writer)

  def test_repr_method(self):
    class FakeSink(NativeSink):
      """A fake sink modeled after BigQuerySink, which inherits from
      NativeSink."""
      def __init__(
          self,
          validate=False,
          dataset=None,
          project=None,
          schema=None,
          create_disposition='create',
          write_disposition=None,
          coder=None):
        self.validate = validate

    fake_sink = FakeSink()
    self.assertEqual(fake_sink.__repr__(), "<FakeSink ['validate=False']>")

  def test_on_direct_runner(self):
    class FakeSink(NativeSink):
      """A fake sink outputting a number of elements."""
      def __init__(self):
        self.written_values = []
        self.writer_instance = FakeSinkWriter(self.written_values)

      def writer(self):
        return self.writer_instance

    class FakeSinkWriter(NativeSinkWriter):
      """A fake sink writer for testing."""
      def __init__(self, written_values):
        self.written_values = written_values

      def __enter__(self):
        return self

      def __exit__(self, *unused_args):
        pass

      def Write(self, value):
        self.written_values.append(value)

    with TestPipeline() as p:
      sink = FakeSink()
      p | Create(['a', 'b', 'c']) | _NativeWrite(sink)  # pylint: disable=expression-not-assigned

    self.assertEqual(['a', 'b', 'c'], sorted(sink.written_values))


class Test_NativeWrite(unittest.TestCase):
  def setUp(self):
    self.native_sink = NativeSink()
    self.native_write = _NativeWrite(self.native_sink)

  def test_expand_method_pcollection_errors(self):
    with self.assertRaises(error.TransformError):
      self.native_write.expand(None)
    with self.assertRaises(error.TransformError):
      pcoll = pvalue.PCollection(pipeline=None)
      self.native_write.expand(pcoll)


if __name__ == '__main__':
  unittest.main()
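
# ---------------------------------------------------------------------------
# Illustrative note (not part of the original test module): the repr
# assertions in TestNativeSource and TestNativeSink rely on the same
# _dict_printable_fields helper exercised in TestHelperFunctions above, which
# drops None values, empty lists, and any explicitly skipped names while
# keeping falsy-but-meaningful values such as 0 or False. A minimal sketch of
# calling the helper directly; the `settings` dict here is hypothetical:
#
#   settings = {'validate': False, 'table': None, 'coder': None}
#   _dict_printable_fields(settings, [])
#   # expected, given the behaviour asserted above: ['validate=False']
# ---------------------------------------------------------------------------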