github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/write_ptransform_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Unit tests for the write transform."""
    19  
    20  # pytype: skip-file
    21  
    22  import logging
    23  import unittest
    24  
    25  import apache_beam as beam
    26  from apache_beam.io import iobase
    27  from apache_beam.testing.test_pipeline import TestPipeline
    28  from apache_beam.testing.util import assert_that
    29  from apache_beam.testing.util import is_empty
    30  from apache_beam.transforms.ptransform import PTransform
    31  
    32  
    33  class _TestSink(iobase.Sink):
    34    TEST_INIT_RESULT = 'test_init_result'
    35  
    36    def __init__(self, return_init_result=True, return_write_results=True):
    37      self.return_init_result = return_init_result
    38      self.return_write_results = return_write_results
    39  
    40    def initialize_write(self):
    41      if self.return_init_result:
    42        return _TestSink.TEST_INIT_RESULT
    43  
    44    def pre_finalize(self, init_result, writer_results):
    45      pass
    46  
    47    def finalize_write(
    48        self, init_result, writer_results, unused_pre_finalize_result):
    49      self.init_result_at_finalize = init_result
    50      self.write_results_at_finalize = writer_results
    51  
    52    def open_writer(self, init_result, uid):
    53      writer = _TestWriter(init_result, uid, self.return_write_results)
    54      return writer
    55  
    56  
    57  class _TestWriter(iobase.Writer):
    58    STATE_UNSTARTED, STATE_WRITTEN, STATE_CLOSED = 0, 1, 2
    59    TEST_WRITE_RESULT = 'test_write_result'
    60  
    61    def __init__(self, init_result, uid, return_write_results=True):
    62      self.state = _TestWriter.STATE_UNSTARTED
    63      self.init_result = init_result
    64      self.uid = uid
    65      self.write_output = []
    66      self.return_write_results = return_write_results
    67  
    68    def close(self):
    69      assert self.state in (
    70          _TestWriter.STATE_WRITTEN, _TestWriter.STATE_UNSTARTED)
    71      self.state = _TestWriter.STATE_CLOSED
    72      if self.return_write_results:
    73        return _TestWriter.TEST_WRITE_RESULT
    74  
    75    def write(self, value):
    76      if self.write_output:
    77        assert self.state == _TestWriter.STATE_WRITTEN
    78      else:
    79        assert self.state == _TestWriter.STATE_UNSTARTED
    80  
    81      self.state = _TestWriter.STATE_WRITTEN
    82      self.write_output.append(value)
    83  
    84  
    85  class WriteToTestSink(PTransform):
    86    def __init__(self, return_init_result=True, return_write_results=True):
    87      self.return_init_result = return_init_result
    88      self.return_write_results = return_write_results
    89      self.last_sink = None
    90      self.label = 'write_to_test_sink'
    91  
    92    def expand(self, pcoll):
    93      self.last_sink = _TestSink(
    94          return_init_result=self.return_init_result,
    95          return_write_results=self.return_write_results)
    96      return pcoll | beam.io.Write(self.last_sink)
    97  
    98  
    99  class WriteTest(unittest.TestCase):
   100    DATA = ['some data', 'more data', 'another data', 'yet another data']
   101  
   102    def _run_write_test(
   103        self, data, return_init_result=True, return_write_results=True):
   104      write_to_test_sink = WriteToTestSink(
   105          return_init_result, return_write_results)
   106      with TestPipeline() as p:
   107        result = p | beam.Create(data) | write_to_test_sink | beam.Map(list)
   108  
   109        assert_that(result, is_empty())
   110  
   111      sink = write_to_test_sink.last_sink
   112      self.assertIsNotNone(sink)
   113  
   114    def test_write(self):
   115      self._run_write_test(WriteTest.DATA)
   116  
   117    def test_write_with_empty_pcollection(self):
   118      data = []
   119      self._run_write_test(data)
   120  
   121    def test_write_no_init_result(self):
   122      self._run_write_test(WriteTest.DATA, return_init_result=False)
   123  
   124    def test_write_no_write_results(self):
   125      self._run_write_test(WriteTest.DATA, return_write_results=False)
   126  
   127  
   128  if __name__ == '__main__':
   129    logging.getLogger().setLevel(logging.INFO)
   130    unittest.main()