github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/batch_dofn_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """UnitTests for Batched DoFn (process_batch) API."""
    19  
    20  # pytype: skip-file
    21  
    22  import unittest
    23  from typing import Iterator
    24  from typing import List
    25  from typing import Tuple
    26  from typing import no_type_check
    27  
    28  from parameterized import parameterized_class
    29  
    30  import apache_beam as beam
    31  
    32  
    33  class ElementDoFn(beam.DoFn):
    34    def process(self, element: int, *args, **kwargs) -> Iterator[float]:
    35      yield element / 2
    36  
    37  
    38  class BatchDoFn(beam.DoFn):
    39    def process_batch(self, batch: List[int], *args,
    40                      **kwargs) -> Iterator[List[float]]:
    41      yield [element / 2 for element in batch]
    42  
    43  
    44  class NoReturnAnnotation(beam.DoFn):
    45    def process_batch(self, batch: List[int], *args, **kwargs):
    46      yield [element * 2 for element in batch]
    47  
    48  
    49  class OverrideTypeInference(beam.DoFn):
    50    def process_batch(self, batch, *args, **kwargs):
    51      yield [element * 2 for element in batch]
    52  
    53    def get_input_batch_type(self, input_element_type):
    54      return List[input_element_type]
    55  
    56    def get_output_batch_type(self, input_element_type):
    57      return List[input_element_type]
    58  
    59  
    60  class EitherDoFn(beam.DoFn):
    61    def process(self, element: int, *args, **kwargs) -> Iterator[float]:
    62      yield element / 2
    63  
    64    def process_batch(self, batch: List[int], *args,
    65                      **kwargs) -> Iterator[List[float]]:
    66      yield [element / 2 for element in batch]
    67  
    68  
    69  class ElementToBatchDoFn(beam.DoFn):
    70    @beam.DoFn.yields_batches
    71    def process(self, element: int, *args, **kwargs) -> Iterator[List[int]]:
    72      yield [element] * element
    73  
    74    def infer_output_type(self, input_element_type):
    75      return input_element_type
    76  
    77  
    78  class BatchToElementDoFn(beam.DoFn):
    79    @beam.DoFn.yields_elements
    80    def process_batch(self, batch: List[int], *args,
    81                      **kwargs) -> Iterator[Tuple[int, int]]:
    82      yield (sum(batch), len(batch))
    83  
    84  
    85  def get_test_class_name(cls, num, params_dict):
    86    return "%s_%s" % (cls.__name__, params_dict['dofn'].__class__.__name__)
    87  
    88  
    89  @parameterized_class([
    90      {
    91          "dofn": ElementDoFn(),
    92          "input_element_type": int,
    93          "expected_process_defined": True,
    94          "expected_process_batch_defined": False,
    95          "expected_input_batch_type": None,
    96          "expected_output_batch_type": None
    97      },
    98      {
    99          "dofn": BatchDoFn(),
   100          "input_element_type": int,
   101          "expected_process_defined": False,
   102          "expected_process_batch_defined": True,
   103          "expected_input_batch_type": beam.typehints.List[int],
   104          "expected_output_batch_type": beam.typehints.List[float]
   105      },
   106      {
   107          "dofn": NoReturnAnnotation(),
   108          "input_element_type": int,
   109          "expected_process_defined": False,
   110          "expected_process_batch_defined": True,
   111          "expected_input_batch_type": beam.typehints.List[int],
   112          "expected_output_batch_type": beam.typehints.List[int]
   113      },
   114      {
   115          "dofn": OverrideTypeInference(),
   116          "input_element_type": int,
   117          "expected_process_defined": False,
   118          "expected_process_batch_defined": True,
   119          "expected_input_batch_type": beam.typehints.List[int],
   120          "expected_output_batch_type": beam.typehints.List[int]
   121      },
   122      {
   123          "dofn": EitherDoFn(),
   124          "input_element_type": int,
   125          "expected_process_defined": True,
   126          "expected_process_batch_defined": True,
   127          "expected_input_batch_type": beam.typehints.List[int],
   128          "expected_output_batch_type": beam.typehints.List[float]
   129      },
   130      {
   131          "dofn": ElementToBatchDoFn(),
   132          "input_element_type": int,
   133          "expected_process_defined": True,
   134          "expected_process_batch_defined": False,
   135          "expected_input_batch_type": None,
   136          "expected_output_batch_type": beam.typehints.List[int]
   137      },
   138      {
   139          "dofn": BatchToElementDoFn(),
   140          "input_element_type": int,
   141          "expected_process_defined": False,
   142          "expected_process_batch_defined": True,
   143          "expected_input_batch_type": beam.typehints.List[int],
   144          "expected_output_batch_type": None,
   145      },
   146  ],
   147                       class_name_func=get_test_class_name)
   148  class BatchDoFnParameterizedTest(unittest.TestCase):
   149    def test_process_defined(self):
   150      self.assertEqual(self.dofn._process_defined, self.expected_process_defined)
   151  
   152    def test_process_batch_defined(self):
   153      self.assertEqual(
   154          self.dofn._process_batch_defined, self.expected_process_batch_defined)
   155  
   156    def test_get_input_batch_type(self):
   157      self.assertEqual(
   158          self.dofn._get_input_batch_type_normalized(self.input_element_type),
   159          self.expected_input_batch_type)
   160  
   161    def test_get_output_batch_type(self):
   162      self.assertEqual(
   163          self.dofn._get_output_batch_type_normalized(self.input_element_type),
   164          self.expected_output_batch_type)
   165  
   166    def test_can_yield_batches(self):
   167      expected = self.expected_output_batch_type is not None
   168      self.assertEqual(self.dofn._can_yield_batches, expected)
   169  
   170  
   171  class NoInputAnnotation(beam.DoFn):
   172    def process_batch(self, batch, *args, **kwargs):
   173      yield [element * 2 for element in batch]
   174  
   175  
   176  class MismatchedBatchProducingDoFn(beam.DoFn):
   177    """A DoFn that produces batches from both process and process_batch, with
   178    mismatched return types (one yields floats, the other ints). Should yield
   179    a construction time error when applied."""
   180    @beam.DoFn.yields_batches
   181    def process(self, element: int, *args, **kwargs) -> Iterator[List[int]]:
   182      yield [element]
   183  
   184    def process_batch(self, batch: List[int], *args,
   185                      **kwargs) -> Iterator[List[float]]:
   186      yield [element / 2 for element in batch]
   187  
   188  
   189  class MismatchedElementProducingDoFn(beam.DoFn):
   190    """A DoFn that produces elements from both process and process_batch, with
   191    mismatched return types (one yields floats, the other ints). Should yield
   192    a construction time error when applied."""
   193    def process(self, element: int, *args, **kwargs) -> Iterator[float]:
   194      yield element / 2
   195  
   196    @beam.DoFn.yields_elements
   197    def process_batch(self, batch: List[int], *args, **kwargs) -> Iterator[int]:
   198      yield batch[0]
   199  
   200  
   201  class NoElementOutputAnnotation(beam.DoFn):
   202    def process_batch(self, batch: List[int], *args,
   203                      **kwargs) -> Iterator[List[int]]:
   204      yield [element * 2 for element in batch]
   205  
   206  
   207  class BatchDoFnTest(unittest.TestCase):
   208    def test_map_pardo(self):
   209      # verify batch dofn accessors work well with beam.Map generated DoFn
   210      # checking this in parameterized test causes a circular reference issue
   211      dofn = beam.Map(lambda x: x * 2).dofn
   212  
   213      self.assertTrue(dofn._process_defined)
   214      self.assertFalse(dofn._process_batch_defined)
   215      self.assertEqual(dofn._get_input_batch_type_normalized(int), None)
   216      self.assertEqual(dofn._get_output_batch_type_normalized(int), None)
   217  
   218    def test_no_input_annotation_raises(self):
   219      p = beam.Pipeline()
   220      pc = p | beam.Create([1, 2, 3])
   221  
   222      with self.assertRaisesRegex(TypeError, r'NoInputAnnotation.process_batch'):
   223        _ = pc | beam.ParDo(NoInputAnnotation())
   224  
   225    def test_unsupported_dofn_param_raises(self):
   226      class BadParam(beam.DoFn):
   227        @no_type_check
   228        def process_batch(self, batch: List[int], key=beam.DoFn.KeyParam):
   229          yield batch * key
   230  
   231      p = beam.Pipeline()
   232      pc = p | beam.Create([1, 2, 3])
   233  
   234      with self.assertRaisesRegex(NotImplementedError, r'BadParam.*KeyParam'):
   235        _ = pc | beam.ParDo(BadParam())
   236  
   237    def test_mismatched_batch_producer_raises(self):
   238      p = beam.Pipeline()
   239      pc = p | beam.Create([1, 2, 3])
   240  
   241      # Note (?ms) makes this a multiline regex, where . matches newlines.
   242      # See (?aiLmsux) at
   243      # https://docs.python.org/3.4/library/re.html#regular-expression-syntax
   244      with self.assertRaisesRegex(
   245          TypeError,
   246          (r'(?ms)MismatchedBatchProducingDoFn.*'
   247           r'process: List\[<class \'int\'>\].*process_batch: '
   248           r'List\[<class \'float\'>\]')):
   249        _ = pc | beam.ParDo(MismatchedBatchProducingDoFn())
   250  
   251    def test_mismatched_element_producer_raises(self):
   252      p = beam.Pipeline()
   253      pc = p | beam.Create([1, 2, 3])
   254  
   255      # Note (?ms) makes this a multiline regex, where . matches newlines.
   256      # See (?aiLmsux) at
   257      # https://docs.python.org/3.4/library/re.html#regular-expression-syntax
   258      with self.assertRaisesRegex(
   259          TypeError,
   260          r'(?ms)MismatchedElementProducingDoFn.*process:.*process_batch:'):
   261        _ = pc | beam.ParDo(MismatchedElementProducingDoFn())
   262  
   263    def test_cant_infer_batchconverter_input_raises(self):
   264      p = beam.Pipeline()
   265      pc = p | beam.Create(['a', 'b', 'c'])
   266  
   267      with self.assertRaisesRegex(
   268          TypeError,
   269          # Error should mention "input", and the name of the DoFn
   270          r'input.*BatchDoFn.*'):
   271        _ = pc | beam.ParDo(BatchDoFn())
   272  
   273    def test_cant_infer_batchconverter_output_raises(self):
   274      p = beam.Pipeline()
   275      pc = p | beam.Create([1, 2, 3])
   276  
   277      with self.assertRaisesRegex(
   278          TypeError,
   279          # Error should mention "output", the name of the DoFn, and suggest
   280          # overriding DoFn.infer_output_type
   281          r'output.*NoElementOutputAnnotation.*DoFn\.infer_output_type'):
   282        _ = pc | beam.ParDo(NoElementOutputAnnotation())
   283  
   284    def test_element_to_batch_dofn_typehint(self):
   285      # Verify that element to batch DoFn sets the correct typehint on the output
   286      # PCollection.
   287  
   288      p = beam.Pipeline()
   289      pc = (p | beam.Create([1, 2, 3]) | beam.ParDo(ElementToBatchDoFn()))
   290  
   291      self.assertEqual(pc.element_type, int)
   292  
   293    def test_batch_to_element_dofn_typehint(self):
   294      # Verify that batch to element DoFn sets the correct typehint on the output
   295      # PCollection.
   296  
   297      p = beam.Pipeline()
   298      pc = (p | beam.Create([1, 2, 3]) | beam.ParDo(BatchToElementDoFn()))
   299  
   300      self.assertEqual(pc.element_type, beam.typehints.Tuple[int, int])
   301  
   302  
   303  if __name__ == '__main__':
   304    unittest.main()