#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for the write transform."""

# pytype: skip-file

import logging
import unittest

import apache_beam as beam
from apache_beam.io import iobase
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import is_empty
from apache_beam.transforms.ptransform import PTransform


class _TestSink(iobase.Sink):
  """Sink stub used to drive ``beam.io.Write`` in these tests.

  The two constructor flags control whether ``initialize_write`` and the
  writers created by ``open_writer`` return a value or ``None``.
  """
  TEST_INIT_RESULT = 'test_init_result'

  def __init__(self, return_init_result=True, return_write_results=True):
    """Store the flags that decide which hooks return a result.

    Args:
      return_init_result: if true, ``initialize_write`` returns
        ``TEST_INIT_RESULT``; otherwise it returns ``None``.
      return_write_results: forwarded to every ``_TestWriter`` created by
        ``open_writer``.
    """
    self.return_init_result = return_init_result
    self.return_write_results = return_write_results

  def initialize_write(self):
    """Return ``TEST_INIT_RESULT`` when enabled, ``None`` otherwise."""
    if self.return_init_result:
      return _TestSink.TEST_INIT_RESULT
    return None

  def pre_finalize(self, init_result, writer_results):
    """Do nothing; this sink has no pre-finalize work."""
    pass

  def finalize_write(
      self, init_result, writer_results, unused_pre_finalize_result):
    """Record the values this hook was called with on the sink instance."""
    self.init_result_at_finalize = init_result
    self.write_results_at_finalize = writer_results

  def open_writer(self, init_result, uid):
    """Return a new ``_TestWriter`` for the given init result and uid."""
    return _TestWriter(init_result, uid, self.return_write_results)


class _TestWriter(iobase.Writer):
  """Writer stub that checks its own call ordering with assertions.

  The writer moves through ``STATE_UNSTARTED`` -> ``STATE_WRITTEN`` ->
  ``STATE_CLOSED``; ``write`` and ``close`` assert that they are invoked in
  a state consistent with that progression.
  """
  STATE_UNSTARTED, STATE_WRITTEN, STATE_CLOSED = 0, 1, 2
  TEST_WRITE_RESULT = 'test_write_result'

  def __init__(self, init_result, uid, return_write_results=True):
    """Create an unstarted writer with an empty output buffer.

    Args:
      init_result: value handed over by the sink's ``open_writer``.
      uid: identifier handed over by the sink's ``open_writer``.
      return_write_results: if true, ``close`` returns
        ``TEST_WRITE_RESULT``; otherwise it returns ``None``.
    """
    self.state = _TestWriter.STATE_UNSTARTED
    self.init_result = init_result
    self.uid = uid
    self.write_output = []
    self.return_write_results = return_write_results

  def close(self):
    """Mark the writer closed and return the write result, if enabled.

    Asserts that the writer has not been closed before (it must be either
    unstarted or written).
    """
    assert self.state in (
        _TestWriter.STATE_WRITTEN, _TestWriter.STATE_UNSTARTED)
    self.state = _TestWriter.STATE_CLOSED
    if self.return_write_results:
      return _TestWriter.TEST_WRITE_RESULT
    return None

  def write(self, value):
    """Append ``value`` to ``write_output`` after checking the state.

    If something was already written the writer must be in
    ``STATE_WRITTEN``; otherwise it must still be in ``STATE_UNSTARTED``.
    """
    if self.write_output:
      assert self.state == _TestWriter.STATE_WRITTEN
    else:
      assert self.state == _TestWriter.STATE_UNSTARTED

    self.state = _TestWriter.STATE_WRITTEN
    self.write_output.append(value)


class WriteToTestSink(PTransform):
  """Composite transform that writes its input to a fresh ``_TestSink``.

  The sink built during the most recent ``expand`` call is kept in
  ``last_sink`` so a test can check that expansion took place.
  """
  def __init__(self, return_init_result=True, return_write_results=True):
    """Store the sink flags and register the transform label.

    Args:
      return_init_result: forwarded to ``_TestSink``.
      return_write_results: forwarded to ``_TestSink``.
    """
    # Run the base-class initializer instead of only assigning ``label`` by
    # hand, so any state PTransform sets up in its constructor is present.
    super().__init__(label='write_to_test_sink')
    self.return_init_result = return_init_result
    self.return_write_results = return_write_results
    self.last_sink = None

  def expand(self, pcoll):
    """Create a new ``_TestSink`` and apply ``beam.io.Write`` with it."""
    self.last_sink = _TestSink(
        return_init_result=self.return_init_result,
        return_write_results=self.return_write_results)
    return pcoll | beam.io.Write(self.last_sink)


class WriteTest(unittest.TestCase):
  """Runs ``beam.io.Write`` against ``_TestSink`` in several configurations."""
  DATA = ['some data', 'more data', 'another data', 'yet another data']

  def _run_write_test(
      self, data, return_init_result=True, return_write_results=True):
    """Write ``data`` through ``WriteToTestSink`` and check the outcome.

    Asserts that the PCollection produced by the write step (mapped through
    ``list``) is empty, and that the transform created a sink while the
    pipeline was being built.
    """
    write_to_test_sink = WriteToTestSink(
        return_init_result, return_write_results)
    with TestPipeline() as p:
      result = p | beam.Create(data) | write_to_test_sink | beam.Map(list)

      assert_that(result, is_empty())

    sink = write_to_test_sink.last_sink
    self.assertIsNotNone(sink)

  def test_write(self):
    self._run_write_test(WriteTest.DATA)

  def test_write_with_empty_pcollection(self):
    data = []
    self._run_write_test(data)

  def test_write_no_init_result(self):
    self._run_write_test(WriteTest.DATA, return_init_result=False)

  def test_write_no_write_results(self):
    self._run_write_test(WriteTest.DATA, return_write_results=False)


if __name__ == '__main__':
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()