github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/create_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Unit tests for the Create and _CreateSource classes."""
    19  # pytype: skip-file
    20  
    21  import logging
    22  import unittest
    23  
    24  from apache_beam import Create
    25  from apache_beam import coders
    26  from apache_beam.coders import FastPrimitivesCoder
    27  from apache_beam.internal import pickler
    28  from apache_beam.io import source_test_utils
    29  from apache_beam.testing.test_pipeline import TestPipeline
    30  from apache_beam.testing.util import assert_that
    31  from apache_beam.testing.util import equal_to
    32  
    33  
    34  class CreateTest(unittest.TestCase):
    35    def setUp(self):
    36      self.coder = FastPrimitivesCoder()
    37  
    38    def test_create_transform(self):
    39      with TestPipeline() as p:
    40        assert_that(p | 'Empty' >> Create([]), equal_to([]), label='empty')
    41        assert_that(p | 'One' >> Create([None]), equal_to([None]), label='one')
    42        assert_that(p | Create(list(range(10))), equal_to(list(range(10))))
    43  
    44    def test_create_source_read(self):
    45      self.check_read([], self.coder)
    46      self.check_read([1], self.coder)
    47      # multiple values.
    48      self.check_read(list(range(10)), self.coder)
    49  
    50    def check_read(self, values, coder):
    51      source = Create._create_source_from_iterable(values, coder)
    52      read_values = source_test_utils.read_from_source(source)
    53      self.assertEqual(sorted(values), sorted(read_values))
    54  
    55    def test_create_source_read_with_initial_splits(self):
    56      self.check_read_with_initial_splits([], self.coder, num_splits=2)
    57      self.check_read_with_initial_splits([1], self.coder, num_splits=2)
    58      values = list(range(8))
    59      # multiple values with a single split.
    60      self.check_read_with_initial_splits(values, self.coder, num_splits=1)
    61      # multiple values with a single split with a large desired bundle size
    62      self.check_read_with_initial_splits(values, self.coder, num_splits=0.5)
    63      # multiple values with many splits.
    64      self.check_read_with_initial_splits(values, self.coder, num_splits=3)
    65      # multiple values with uneven sized splits.
    66      self.check_read_with_initial_splits(values, self.coder, num_splits=4)
    67      # multiple values with num splits equal to num values.
    68      self.check_read_with_initial_splits(
    69          values, self.coder, num_splits=len(values))
    70      # multiple values with num splits greater than to num values.
    71      self.check_read_with_initial_splits(values, self.coder, num_splits=30)
    72  
    73    def check_read_with_initial_splits(self, values, coder, num_splits):
    74      """A test that splits the given source into `num_splits` and verifies that
    75      the data read from original source is equal to the union of the data read
    76      from the split sources.
    77      """
    78      source = Create._create_source_from_iterable(values, coder)
    79      desired_bundle_size = source._total_size // num_splits
    80      splits = source.split(desired_bundle_size)
    81      splits_info = [(split.source, split.start_position, split.stop_position)
    82                     for split in splits]
    83      source_test_utils.assert_sources_equal_reference_source(
    84          (source, None, None), splits_info)
    85  
    86    def test_create_source_read_reentrant(self):
    87      source = Create._create_source_from_iterable(range(9), self.coder)
    88      source_test_utils.assert_reentrant_reads_succeed((source, None, None))
    89  
    90    def test_create_source_read_reentrant_with_initial_splits(self):
    91      source = Create._create_source_from_iterable(range(24), self.coder)
    92      for split in source.split(desired_bundle_size=5):
    93        source_test_utils.assert_reentrant_reads_succeed(
    94            (split.source, split.start_position, split.stop_position))
    95  
    96    def test_create_source_dynamic_splitting(self):
    97      # 2 values
    98      source = Create._create_source_from_iterable(range(2), self.coder)
    99      source_test_utils.assert_split_at_fraction_exhaustive(source)
   100      # Multiple values.
   101      source = Create._create_source_from_iterable(range(11), self.coder)
   102      source_test_utils.assert_split_at_fraction_exhaustive(
   103          source, perform_multi_threaded_test=True)
   104  
   105    def test_create_source_progress(self):
   106      num_values = 10
   107      source = Create._create_source_from_iterable(range(num_values), self.coder)
   108      splits = [split for split in source.split(desired_bundle_size=100)]
   109      assert len(splits) == 1
   110      fraction_consumed_report = []
   111      split_points_report = []
   112      range_tracker = splits[0].source.get_range_tracker(
   113          splits[0].start_position, splits[0].stop_position)
   114      for _ in splits[0].source.read(range_tracker):
   115        fraction_consumed_report.append(range_tracker.fraction_consumed())
   116        split_points_report.append(range_tracker.split_points())
   117  
   118      self.assertEqual([float(i) / num_values for i in range(num_values)],
   119                       fraction_consumed_report)
   120  
   121      expected_split_points_report = [((i - 1), num_values - (i - 1))
   122                                      for i in range(1, num_values + 1)]
   123  
   124      self.assertEqual(expected_split_points_report, split_points_report)
   125  
   126    def test_create_uses_coder_for_pickling(self):
   127      coders.registry.register_coder(_Unpicklable, _UnpicklableCoder)
   128      create = Create([_Unpicklable(1), _Unpicklable(2), _Unpicklable(3)])
   129      unpickled_create = pickler.loads(pickler.dumps(create))
   130      self.assertEqual(
   131          sorted(create.values, key=lambda v: v.value),
   132          sorted(unpickled_create.values, key=lambda v: v.value))
   133  
   134      with self.assertRaises(NotImplementedError):
   135        # As there is no special coder for Union types, this will fall back to
   136        # FastPrimitivesCoder, which in turn falls back to pickling.
   137        create_mixed_types = Create([_Unpicklable(1), 2])
   138        pickler.dumps(create_mixed_types)
   139  
   140  
   141  class _Unpicklable(object):
   142    def __init__(self, value):
   143      self.value = value
   144  
   145    def __eq__(self, other):
   146      return self.value == other.value
   147  
   148    def __getstate__(self):
   149      raise NotImplementedError()
   150  
   151    def __setstate__(self, state):
   152      raise NotImplementedError()
   153  
   154  
   155  class _UnpicklableCoder(coders.Coder):
   156    def encode(self, value):
   157      return str(value.value).encode()
   158  
   159    def decode(self, encoded):
   160      return _Unpicklable(int(encoded.decode()))
   161  
   162    def to_type_hint(self):
   163      return _Unpicklable
   164  
   165  
   166  if __name__ == '__main__':
   167    logging.getLogger().setLevel(logging.INFO)
   168    unittest.main()