github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/transforms/ptransform_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Unit tests for the PTransform and descendants."""
    19  
    20  # pytype: skip-file
    21  
    22  import collections
    23  import operator
    24  import os
    25  import pickle
    26  import random
    27  import re
    28  import typing
    29  import unittest
    30  from functools import reduce
    31  from typing import Optional
    32  from unittest.mock import patch
    33  
    34  import hamcrest as hc
    35  import numpy as np
    36  import pytest
    37  from parameterized import parameterized_class
    38  
    39  import apache_beam as beam
    40  import apache_beam.transforms.combiners as combine
    41  from apache_beam import pvalue
    42  from apache_beam import typehints
    43  from apache_beam.io.iobase import Read
    44  from apache_beam.metrics import Metrics
    45  from apache_beam.metrics.metric import MetricsFilter
    46  from apache_beam.options.pipeline_options import PipelineOptions
    47  from apache_beam.options.pipeline_options import TypeOptions
    48  from apache_beam.portability import common_urns
    49  from apache_beam.testing.test_pipeline import TestPipeline
    50  from apache_beam.testing.test_stream import TestStream
    51  from apache_beam.testing.util import SortLists
    52  from apache_beam.testing.util import assert_that
    53  from apache_beam.testing.util import equal_to
    54  from apache_beam.transforms import WindowInto
    55  from apache_beam.transforms import trigger
    56  from apache_beam.transforms import window
    57  from apache_beam.transforms.display import DisplayData
    58  from apache_beam.transforms.display import DisplayDataItem
    59  from apache_beam.transforms.ptransform import PTransform
    60  from apache_beam.transforms.window import TimestampedValue
    61  from apache_beam.typehints import with_input_types
    62  from apache_beam.typehints import with_output_types
    63  from apache_beam.typehints.typehints_test import TypeHintTestCase
    64  from apache_beam.utils.timestamp import Timestamp
    65  from apache_beam.utils.windowed_value import WindowedValue
    66  
    67  # Disable frequent lint warning due to pipe operator for chaining transforms.
    68  # pylint: disable=expression-not-assigned
    69  
    70  
    71  class PTransformTest(unittest.TestCase):
    72    def assertStartswith(self, msg, prefix):
    73      self.assertTrue(
    74          msg.startswith(prefix), '"%s" does not start with "%s"' % (msg, prefix))
    75  
    76    def test_str(self):
    77      self.assertEqual(
    78          '<PTransform(PTransform) label=[PTransform]>', str(PTransform()))
    79  
    80      pa = TestPipeline()
    81      res = pa | 'ALabel' >> beam.Impulse()
    82      self.assertEqual('AppliedPTransform(ALabel, Impulse)', str(res.producer))
    83  
    84      pc = TestPipeline()
    85      res = pc | beam.Impulse()
    86      inputs_tr = res.producer.transform
    87      inputs_tr.inputs = ('ci', )
    88      self.assertEqual(
    89          "<Impulse(PTransform) label=[Impulse] inputs=('ci',)>", str(inputs_tr))
    90  
    91      pd = TestPipeline()
    92      res = pd | beam.Impulse()
    93      side_tr = res.producer.transform
    94      side_tr.side_inputs = (4, )
    95      self.assertEqual(
    96          '<Impulse(PTransform) label=[Impulse] side_inputs=(4,)>', str(side_tr))
    97  
    98      inputs_tr.side_inputs = ('cs', )
    99      self.assertEqual(
   100          """<Impulse(PTransform) label=[Impulse] """
   101          """inputs=('ci',) side_inputs=('cs',)>""",
   102          str(inputs_tr))
   103  
   104    def test_do_with_do_fn(self):
   105      class AddNDoFn(beam.DoFn):
   106        def process(self, element, addon):
   107          return [element + addon]
   108  
   109      with TestPipeline() as pipeline:
   110        pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
   111        result = pcoll | 'Do' >> beam.ParDo(AddNDoFn(), 10)
   112        assert_that(result, equal_to([11, 12, 13]))
   113  
   114    def test_do_with_unconstructed_do_fn(self):
   115      class MyDoFn(beam.DoFn):
   116        def process(self):
   117          pass
   118  
   119      with self.assertRaises(ValueError):
   120        with TestPipeline() as pipeline:
   121          pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
   122          pcoll | 'Do' >> beam.ParDo(MyDoFn)  # Note the lack of ()'s
   123  
   124    def test_do_with_callable(self):
   125      with TestPipeline() as pipeline:
   126        pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
   127        result = pcoll | 'Do' >> beam.FlatMap(lambda x, addon: [x + addon], 10)
   128        assert_that(result, equal_to([11, 12, 13]))
   129  
   130    def test_do_with_side_input_as_arg(self):
   131      with TestPipeline() as pipeline:
   132        side = pipeline | 'Side' >> beam.Create([10])
   133        pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
   134        result = pcoll | 'Do' >> beam.FlatMap(
   135            lambda x, addon: [x + addon], pvalue.AsSingleton(side))
   136        assert_that(result, equal_to([11, 12, 13]))
   137  
   138    def test_do_with_side_input_as_keyword_arg(self):
   139      with TestPipeline() as pipeline:
   140        side = pipeline | 'Side' >> beam.Create([10])
   141        pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
   142        result = pcoll | 'Do' >> beam.FlatMap(
   143            lambda x, addon: [x + addon], addon=pvalue.AsSingleton(side))
   144        assert_that(result, equal_to([11, 12, 13]))
   145  
   146    def test_do_with_do_fn_returning_string_raises_warning(self):
   147      with self.assertRaises(typehints.TypeCheckError) as cm:
   148        with TestPipeline() as pipeline:
   149          pipeline._options.view_as(TypeOptions).runtime_type_check = True
   150          pcoll = pipeline | 'Start' >> beam.Create(['2', '9', '3'])
   151          pcoll | 'Do' >> beam.FlatMap(lambda x: x + '1')
   152  
   153          # Since the DoFn directly returns a string we should get an
   154          # error warning us when the pipeliene runs.
   155  
   156      expected_error_prefix = (
   157          'Returning a str from a ParDo or FlatMap '
   158          'is discouraged.')
   159      self.assertStartswith(cm.exception.args[0], expected_error_prefix)
   160  
   161    def test_do_with_do_fn_returning_dict_raises_warning(self):
   162      with self.assertRaises(typehints.TypeCheckError) as cm:
   163        with TestPipeline() as pipeline:
   164          pipeline._options.view_as(TypeOptions).runtime_type_check = True
   165          pcoll = pipeline | 'Start' >> beam.Create(['2', '9', '3'])
   166          pcoll | 'Do' >> beam.FlatMap(lambda x: {x: '1'})
   167  
   168          # Since the DoFn directly returns a dict we should get an error warning
   169          # us when the pipeliene runs.
   170  
   171      expected_error_prefix = (
   172          'Returning a dict from a ParDo or FlatMap '
   173          'is discouraged.')
   174      self.assertStartswith(cm.exception.args[0], expected_error_prefix)
   175  
   176    def test_do_with_multiple_outputs_maintains_unique_name(self):
   177      with TestPipeline() as pipeline:
   178        pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
   179        r1 = pcoll | 'A' >> beam.FlatMap(lambda x: [x + 1]).with_outputs(main='m')
   180        r2 = pcoll | 'B' >> beam.FlatMap(lambda x: [x + 2]).with_outputs(main='m')
   181        assert_that(r1.m, equal_to([2, 3, 4]), label='r1')
   182        assert_that(r2.m, equal_to([3, 4, 5]), label='r2')
   183  
   184    @pytest.mark.it_validatesrunner
   185    def test_impulse(self):
   186      with TestPipeline() as pipeline:
   187        result = pipeline | beam.Impulse() | beam.Map(lambda _: 0)
   188        assert_that(result, equal_to([0]))
   189  
   190    # TODO(BEAM-3544): Disable this test in streaming temporarily.
   191    # Remove sickbay-streaming tag after it's resolved.
   192    @pytest.mark.no_sickbay_streaming
   193    @pytest.mark.it_validatesrunner
   194    def test_read_metrics(self):
   195      from apache_beam.io.utils import CountingSource
   196  
   197      class CounterDoFn(beam.DoFn):
   198        def __init__(self):
   199          # This counter is unused.
   200          self.received_records = Metrics.counter(
   201              self.__class__, 'receivedRecords')
   202  
   203        def process(self, element):
   204          self.received_records.inc()
   205  
   206      pipeline = TestPipeline()
   207      (pipeline | Read(CountingSource(100)) | beam.ParDo(CounterDoFn()))
   208      res = pipeline.run()
   209      res.wait_until_finish()
   210      # This counter is defined in utils.CountingSource.
   211      metric_results = res.metrics().query(
   212          MetricsFilter().with_name('recordsRead'))
   213      outputs_counter = metric_results['counters'][0]
   214      self.assertStartswith(outputs_counter.key.step, 'Read')
   215      self.assertEqual(outputs_counter.key.metric.name, 'recordsRead')
   216      self.assertEqual(outputs_counter.committed, 100)
   217      self.assertEqual(outputs_counter.attempted, 100)
   218  
   219    @pytest.mark.it_validatesrunner
   220    def test_par_do_with_multiple_outputs_and_using_yield(self):
   221      class SomeDoFn(beam.DoFn):
   222        """A custom DoFn using yield."""
   223        def process(self, element):
   224          yield element
   225          if element % 2 == 0:
   226            yield pvalue.TaggedOutput('even', element)
   227          else:
   228            yield pvalue.TaggedOutput('odd', element)
   229  
   230      with TestPipeline() as pipeline:
   231        nums = pipeline | 'Some Numbers' >> beam.Create([1, 2, 3, 4])
   232        results = nums | 'ClassifyNumbers' >> beam.ParDo(SomeDoFn()).with_outputs(
   233            'odd', 'even', main='main')
   234        assert_that(results.main, equal_to([1, 2, 3, 4]))
   235        assert_that(results.odd, equal_to([1, 3]), label='assert:odd')
   236        assert_that(results.even, equal_to([2, 4]), label='assert:even')
   237  
   238    @pytest.mark.it_validatesrunner
   239    def test_par_do_with_multiple_outputs_and_using_return(self):
   240      def some_fn(v):
   241        if v % 2 == 0:
   242          return [v, pvalue.TaggedOutput('even', v)]
   243        return [v, pvalue.TaggedOutput('odd', v)]
   244  
   245      with TestPipeline() as pipeline:
   246        nums = pipeline | 'Some Numbers' >> beam.Create([1, 2, 3, 4])
   247        results = nums | 'ClassifyNumbers' >> beam.FlatMap(some_fn).with_outputs(
   248            'odd', 'even', main='main')
   249        assert_that(results.main, equal_to([1, 2, 3, 4]))
   250        assert_that(results.odd, equal_to([1, 3]), label='assert:odd')
   251        assert_that(results.even, equal_to([2, 4]), label='assert:even')
   252  
   253    @pytest.mark.it_validatesrunner
   254    def test_undeclared_outputs(self):
   255      with TestPipeline() as pipeline:
   256        nums = pipeline | 'Some Numbers' >> beam.Create([1, 2, 3, 4])
   257        results = nums | 'ClassifyNumbers' >> beam.FlatMap(
   258            lambda x: [
   259                x,
   260                pvalue.TaggedOutput('even' if x % 2 == 0 else 'odd', x),
   261                pvalue.TaggedOutput('extra', x)
   262            ]).with_outputs()
   263        assert_that(results[None], equal_to([1, 2, 3, 4]))
   264        assert_that(results.odd, equal_to([1, 3]), label='assert:odd')
   265        assert_that(results.even, equal_to([2, 4]), label='assert:even')
   266  
   267    @pytest.mark.it_validatesrunner
   268    def test_multiple_empty_outputs(self):
   269      with TestPipeline() as pipeline:
   270        nums = pipeline | 'Some Numbers' >> beam.Create([1, 3, 5])
   271        results = nums | 'ClassifyNumbers' >> beam.FlatMap(
   272            lambda x:
   273            [x, pvalue.TaggedOutput('even'
   274                                    if x % 2 == 0 else 'odd', x)]).with_outputs()
   275        assert_that(results[None], equal_to([1, 3, 5]))
   276        assert_that(results.odd, equal_to([1, 3, 5]), label='assert:odd')
   277        assert_that(results.even, equal_to([]), label='assert:even')
   278  
   279    def test_do_requires_do_fn_returning_iterable(self):
   280      # This function is incorrect because it returns an object that isn't an
   281      # iterable.
   282      def incorrect_par_do_fn(x):
   283        return x + 5
   284  
   285      with self.assertRaises(typehints.TypeCheckError) as cm:
   286        with TestPipeline() as pipeline:
   287          pipeline._options.view_as(TypeOptions).runtime_type_check = True
   288          pcoll = pipeline | 'Start' >> beam.Create([2, 9, 3])
   289          pcoll | 'Do' >> beam.FlatMap(incorrect_par_do_fn)
   290          # It's a requirement that all user-defined functions to a ParDo return
   291          # an iterable.
   292  
   293      expected_error_prefix = 'FlatMap and ParDo must return an iterable.'
   294      self.assertStartswith(cm.exception.args[0], expected_error_prefix)
   295  
   296    def test_do_fn_with_finish(self):
   297      class MyDoFn(beam.DoFn):
   298        def process(self, element):
   299          pass
   300  
   301        def finish_bundle(self):
   302          yield WindowedValue('finish', -1, [window.GlobalWindow()])
   303  
   304      with TestPipeline() as pipeline:
   305        pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
   306        result = pcoll | 'Do' >> beam.ParDo(MyDoFn())
   307  
   308        # May have many bundles, but each has a start and finish.
   309        def matcher():
   310          def match(actual):
   311            equal_to(['finish'])(list(set(actual)))
   312            equal_to([1])([actual.count('finish')])
   313  
   314          return match
   315  
   316        assert_that(result, matcher())
   317  
   318    def test_do_fn_with_windowing_in_finish_bundle(self):
   319      windowfn = window.FixedWindows(2)
   320  
   321      class MyDoFn(beam.DoFn):
   322        def process(self, element):
   323          yield TimestampedValue('process' + str(element), 5)
   324  
   325        def finish_bundle(self):
   326          yield WindowedValue('finish', 1, [windowfn])
   327  
   328      with TestPipeline() as pipeline:
   329        result = (
   330            pipeline
   331            | 'Start' >> beam.Create([1])
   332            | beam.ParDo(MyDoFn())
   333            | WindowInto(windowfn)
   334            | 'create tuple' >> beam.Map(
   335                lambda v,
   336                t=beam.DoFn.TimestampParam,
   337                w=beam.DoFn.WindowParam: (v, t, w.start, w.end)))
   338        expected_process = [
   339            ('process1', Timestamp(5), Timestamp(4), Timestamp(6))
   340        ]
   341        expected_finish = [('finish', Timestamp(1), Timestamp(0), Timestamp(2))]
   342  
   343        assert_that(result, equal_to(expected_process + expected_finish))
   344  
   345    def test_do_fn_with_start(self):
   346      class MyDoFn(beam.DoFn):
   347        def __init__(self):
   348          self.state = 'init'
   349  
   350        def start_bundle(self):
   351          self.state = 'started'
   352  
   353        def process(self, element):
   354          if self.state == 'started':
   355            yield 'started'
   356          self.state = 'process'
   357  
   358      with TestPipeline() as pipeline:
   359        pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3])
   360        result = pcoll | 'Do' >> beam.ParDo(MyDoFn())
   361  
   362        # May have many bundles, but each has a start and finish.
   363        def matcher():
   364          def match(actual):
   365            equal_to(['started'])(list(set(actual)))
   366            equal_to([1])([actual.count('started')])
   367  
   368          return match
   369  
   370        assert_that(result, matcher())
   371  
   372    def test_do_fn_with_start_error(self):
   373      class MyDoFn(beam.DoFn):
   374        def start_bundle(self):
   375          return [1]
   376  
   377        def process(self, element):
   378          pass
   379  
   380      with self.assertRaises(RuntimeError):
   381        with TestPipeline() as p:
   382          p | 'Start' >> beam.Create([1, 2, 3]) | 'Do' >> beam.ParDo(MyDoFn())
   383  
   384    def test_map_builtin(self):
   385      with TestPipeline() as pipeline:
   386        pcoll = pipeline | 'Start' >> beam.Create([[1, 2], [1], [1, 2, 3]])
   387        result = pcoll | beam.Map(len)
   388        assert_that(result, equal_to([1, 2, 3]))
   389  
   390    def test_flatmap_builtin(self):
   391      with TestPipeline() as pipeline:
   392        pcoll = pipeline | 'Start' >> beam.Create([
   393            [np.array([1, 2, 3])] * 3, [np.array([5, 4, 3]), np.array([5, 6, 7])]
   394        ])
   395        result = pcoll | beam.FlatMap(sum)
   396        assert_that(result, equal_to([3, 6, 9, 10, 10, 10]))
   397  
   398    def test_filter_builtin(self):
   399      with TestPipeline() as pipeline:
   400        pcoll = pipeline | 'Start' >> beam.Create([[], [2], [], [4]])
   401        result = pcoll | 'Filter' >> beam.Filter(len)
   402        assert_that(result, equal_to([[2], [4]]))
   403  
   404    def test_filter(self):
   405      with TestPipeline() as pipeline:
   406        pcoll = pipeline | 'Start' >> beam.Create([1, 2, 3, 4])
   407        result = pcoll | 'Filter' >> beam.Filter(lambda x: x % 2 == 0)
   408        assert_that(result, equal_to([2, 4]))
   409  
   410    class _MeanCombineFn(beam.CombineFn):
   411      def create_accumulator(self):
   412        return (0, 0)
   413  
   414      def add_input(self, sum_count, element):
   415        (sum_, count) = sum_count
   416        return sum_ + element, count + 1
   417  
   418      def merge_accumulators(self, accumulators):
   419        sums, counts = zip(*accumulators)
   420        return sum(sums), sum(counts)
   421  
   422      def extract_output(self, sum_count):
   423        (sum_, count) = sum_count
   424        if not count:
   425          return float('nan')
   426        return sum_ / float(count)
   427  
   428    def test_combine_with_combine_fn(self):
   429      vals = [1, 2, 3, 4, 5, 6, 7]
   430      with TestPipeline() as pipeline:
   431        pcoll = pipeline | 'Start' >> beam.Create(vals)
   432        result = pcoll | 'Mean' >> beam.CombineGlobally(self._MeanCombineFn())
   433        assert_that(result, equal_to([sum(vals) // len(vals)]))
   434  
   435    def test_combine_with_callable(self):
   436      vals = [1, 2, 3, 4, 5, 6, 7]
   437      with TestPipeline() as pipeline:
   438        pcoll = pipeline | 'Start' >> beam.Create(vals)
   439        result = pcoll | beam.CombineGlobally(sum)
   440        assert_that(result, equal_to([sum(vals)]))
   441  
   442    def test_combine_with_side_input_as_arg(self):
   443      values = [1, 2, 3, 4, 5, 6, 7]
   444      with TestPipeline() as pipeline:
   445        pcoll = pipeline | 'Start' >> beam.Create(values)
   446        divisor = pipeline | 'Divisor' >> beam.Create([2])
   447        result = pcoll | 'Max' >> beam.CombineGlobally(
   448            # Multiples of divisor only.
   449            lambda vals,
   450            d: max(v for v in vals if v % d == 0),
   451            pvalue.AsSingleton(divisor)).without_defaults()
   452        filt_vals = [v for v in values if v % 2 == 0]
   453        assert_that(result, equal_to([max(filt_vals)]))
   454  
   455    def test_combine_per_key_with_combine_fn(self):
   456      vals_1 = [1, 2, 3, 4, 5, 6, 7]
   457      vals_2 = [2, 4, 6, 8, 10, 12, 14]
   458      with TestPipeline() as pipeline:
   459        pcoll = pipeline | 'Start' >> beam.Create(
   460            ([('a', x) for x in vals_1] + [('b', x) for x in vals_2]))
   461        result = pcoll | 'Mean' >> beam.CombinePerKey(self._MeanCombineFn())
   462        assert_that(
   463            result,
   464            equal_to([('a', sum(vals_1) // len(vals_1)),
   465                      ('b', sum(vals_2) // len(vals_2))]))
   466  
   467    def test_combine_per_key_with_callable(self):
   468      vals_1 = [1, 2, 3, 4, 5, 6, 7]
   469      vals_2 = [2, 4, 6, 8, 10, 12, 14]
   470      with TestPipeline() as pipeline:
   471        pcoll = pipeline | 'Start' >> beam.Create(
   472            ([('a', x) for x in vals_1] + [('b', x) for x in vals_2]))
   473        result = pcoll | beam.CombinePerKey(sum)
   474        assert_that(result, equal_to([('a', sum(vals_1)), ('b', sum(vals_2))]))
   475  
   476    def test_combine_per_key_with_side_input_as_arg(self):
   477      vals_1 = [1, 2, 3, 4, 5, 6, 7]
   478      vals_2 = [2, 4, 6, 8, 10, 12, 14]
   479      with TestPipeline() as pipeline:
   480        pcoll = pipeline | 'Start' >> beam.Create(
   481            ([('a', x) for x in vals_1] + [('b', x) for x in vals_2]))
   482        divisor = pipeline | 'Divisor' >> beam.Create([2])
   483        result = pcoll | beam.CombinePerKey(
   484            lambda vals,
   485            d: max(v for v in vals if v % d == 0),
   486            pvalue.AsSingleton(divisor))  # Multiples of divisor only.
   487        m_1 = max(v for v in vals_1 if v % 2 == 0)
   488        m_2 = max(v for v in vals_2 if v % 2 == 0)
   489        assert_that(result, equal_to([('a', m_1), ('b', m_2)]))
   490  
   491    def test_group_by_key(self):
   492      with TestPipeline() as pipeline:
   493        pcoll = pipeline | 'start' >> beam.Create([(1, 1), (2, 1), (3, 1), (1, 2),
   494                                                   (2, 2), (1, 3)])
   495        result = pcoll | 'Group' >> beam.GroupByKey() | SortLists
   496        assert_that(result, equal_to([(1, [1, 2, 3]), (2, [1, 2]), (3, [1])]))
   497  
   498    def test_group_by_key_unbounded_global_default_trigger(self):
   499      test_options = PipelineOptions()
   500      test_options.view_as(TypeOptions).allow_unsafe_triggers = False
   501      with self.assertRaisesRegex(
   502          ValueError,
   503          'GroupByKey cannot be applied to an unbounded PCollection with ' +
   504          'global windowing and a default trigger'):
   505        with TestPipeline(options=test_options) as pipeline:
   506          pipeline | TestStream() | beam.GroupByKey()
   507  
   508    def test_group_by_key_unsafe_trigger(self):
   509      test_options = PipelineOptions()
   510      test_options.view_as(TypeOptions).allow_unsafe_triggers = False
   511      with self.assertRaisesRegex(ValueError, 'Unsafe trigger'):
   512        with TestPipeline(options=test_options) as pipeline:
   513          _ = (
   514              pipeline
   515              | beam.Create([(None, None)])
   516              | WindowInto(
   517                  window.GlobalWindows(),
   518                  trigger=trigger.AfterCount(5),
   519                  accumulation_mode=trigger.AccumulationMode.ACCUMULATING)
   520              | beam.GroupByKey())
   521  
   522    def test_group_by_key_allow_unsafe_triggers(self):
   523      test_options = PipelineOptions(flags=['--allow_unsafe_triggers'])
   524      with TestPipeline(options=test_options) as pipeline:
   525        pcoll = (
   526            pipeline
   527            | beam.Create([(1, 1), (1, 2), (1, 3), (1, 4)])
   528            | WindowInto(
   529                window.GlobalWindows(),
   530                trigger=trigger.AfterCount(4),
   531                accumulation_mode=trigger.AccumulationMode.ACCUMULATING)
   532            | beam.GroupByKey())
   533        assert_that(pcoll, equal_to([(1, [1, 2, 3, 4])]))
   534  
   535    def test_group_by_key_reiteration(self):
   536      class MyDoFn(beam.DoFn):
   537        def process(self, gbk_result):
   538          key, value_list = gbk_result
   539          sum_val = 0
   540          # Iterate the GBK result for multiple times.
   541          for _ in range(0, 17):
   542            sum_val += sum(value_list)
   543          return [(key, sum_val)]
   544  
   545      with TestPipeline() as pipeline:
   546        pcoll = pipeline | 'start' >> beam.Create([(1, 1), (1, 2), (1, 3),
   547                                                   (1, 4)])
   548        result = (
   549            pcoll | 'Group' >> beam.GroupByKey()
   550            | 'Reiteration-Sum' >> beam.ParDo(MyDoFn()))
   551        assert_that(result, equal_to([(1, 170)]))
   552  
  def test_group_by_key_deterministic_coder(self):
    """GroupByKey/CombinePerKey group equal keys despite a random key coder.

    MyObjectCoder mixes ``random.random()`` into every encoding, so equal keys
    do not encode to equal bytes. The expected groupings below only hold if
    the deterministic variant returned by ``as_deterministic_coder``
    (MydeterministicObjectCoder) is used for the keys instead.
    """
    # pylint: disable=global-variable-not-assigned
    global MyObject  # for pickling of the class instance
    # Because of the `global` declaration, the class statement below binds
    # MyObject at module level, so instances can be pickled by reference.

    class MyObject:
      def __init__(self, value):
        self.value = value

      def __eq__(self, other):
        return self.value == other.value

      def __hash__(self):
        return hash(self.value)

    class MyObjectCoder(beam.coders.Coder):
      # Non-deterministic on purpose: a fresh random value is appended to
      # every encoding.
      def encode(self, o):
        return pickle.dumps((o.value, random.random()))

      def decode(self, encoded):
        return MyObject(pickle.loads(encoded)[0])

      def as_deterministic_coder(self, *args):
        # MydeterministicObjectCoder is defined below; the name is looked up
        # when this method is called, not when it is defined.
        return MydeterministicObjectCoder()

      def to_type_hint(self):
        return MyObject

    class MydeterministicObjectCoder(beam.coders.Coder):
      # Encodes only the value, so equal keys produce equal bytes.
      def encode(self, o):
        return pickle.dumps(o.value)

      def decode(self, encoded):
        return MyObject(pickle.loads(encoded))

      def is_deterministic(self):
        return True

    # NOTE(review): this registration goes into the global registry and is
    # never undone (compare test_group_by_key_fake_deterministic_coder, which
    # patches in a fresh registry).
    beam.coders.registry.register_coder(MyObject, MyObjectCoder)

    with TestPipeline() as pipeline:
      # Keys alternate between MyObject(0) and MyObject(1).
      pcoll = pipeline | beam.Create([(MyObject(k % 2), k) for k in range(10)])
      grouped = pcoll | beam.GroupByKey() | beam.MapTuple(
          lambda k, vs: (k.value, sorted(vs)))
      combined = pcoll | beam.CombinePerKey(sum) | beam.MapTuple(
          lambda k, v: (k.value, v))
      assert_that(
          grouped,
          equal_to([(0, [0, 2, 4, 6, 8]), (1, [1, 3, 5, 7, 9])]),
          'CheckGrouped')
      assert_that(combined, equal_to([(0, 20), (1, 25)]), 'CheckCombined')
   603  
   604    def test_group_by_key_non_deterministic_coder(self):
   605      with self.assertRaisesRegex(Exception, r'deterministic'):
   606        with TestPipeline() as pipeline:
   607          _ = (
   608              pipeline
   609              | beam.Create([(PickledObject(10), None)])
   610              | beam.GroupByKey()
   611              | beam.MapTuple(lambda k, v: list(v)))
   612  
   613    def test_group_by_key_allow_non_deterministic_coder(self):
   614      with TestPipeline() as pipeline:
   615        # The GroupByKey below would fail without this option.
   616        pipeline._options.view_as(
   617            TypeOptions).allow_non_deterministic_key_coders = True
   618        grouped = (
   619            pipeline
   620            | beam.Create([(PickledObject(10), None)])
   621            | beam.GroupByKey()
   622            | beam.MapTuple(lambda k, v: list(v)))
   623        assert_that(grouped, equal_to([[None]]))
   624  
   625    def test_group_by_key_fake_deterministic_coder(self):
   626      fresh_registry = beam.coders.typecoders.CoderRegistry()
   627      with patch.object(
   628          beam.coders, 'registry', fresh_registry), patch.object(
   629          beam.coders.typecoders, 'registry', fresh_registry):
   630        with TestPipeline() as pipeline:
   631          # The GroupByKey below would fail without this registration.
   632          beam.coders.registry.register_fallback_coder(
   633              beam.coders.coders.FakeDeterministicFastPrimitivesCoder())
   634          grouped = (
   635              pipeline
   636              | beam.Create([(PickledObject(10), None)])
   637              | beam.GroupByKey()
   638              | beam.MapTuple(lambda k, v: list(v)))
   639          assert_that(grouped, equal_to([[None]]))
   640  
   641    def test_partition_with_partition_fn(self):
   642      class SomePartitionFn(beam.PartitionFn):
   643        def partition_for(self, element, num_partitions, offset):
   644          return (element % 3) + offset
   645  
   646      with TestPipeline() as pipeline:
   647        pcoll = pipeline | 'Start' >> beam.Create([0, 1, 2, 3, 4, 5, 6, 7, 8])
   648        # Attempt nominal partition operation.
   649        partitions = pcoll | 'Part 1' >> beam.Partition(SomePartitionFn(), 4, 1)
   650        assert_that(partitions[0], equal_to([]))
   651        assert_that(partitions[1], equal_to([0, 3, 6]), label='p1')
   652        assert_that(partitions[2], equal_to([1, 4, 7]), label='p2')
   653        assert_that(partitions[3], equal_to([2, 5, 8]), label='p3')
   654  
   655      # Check that a bad partition label will yield an error. For the
   656      # DirectRunner, this error manifests as an exception.
   657      with self.assertRaises(ValueError):
   658        with TestPipeline() as pipeline:
   659          pcoll = pipeline | 'Start' >> beam.Create([0, 1, 2, 3, 4, 5, 6, 7, 8])
   660          partitions = pcoll | beam.Partition(SomePartitionFn(), 4, 10000)
   661  
   662    def test_partition_with_callable(self):
   663      with TestPipeline() as pipeline:
   664        pcoll = pipeline | 'Start' >> beam.Create([0, 1, 2, 3, 4, 5, 6, 7, 8])
   665        partitions = (
   666            pcoll |
   667            'part' >> beam.Partition(lambda e, n, offset: (e % 3) + offset, 4, 1))
   668        assert_that(partitions[0], equal_to([]))
   669        assert_that(partitions[1], equal_to([0, 3, 6]), label='p1')
   670        assert_that(partitions[2], equal_to([1, 4, 7]), label='p2')
   671        assert_that(partitions[3], equal_to([2, 5, 8]), label='p3')
   672  
   673    def test_partition_with_callable_and_side_input(self):
   674      with TestPipeline() as pipeline:
   675        pcoll = pipeline | 'Start' >> beam.Create([0, 1, 2, 3, 4, 5, 6, 7, 8])
   676        side_input = pipeline | 'Side Input' >> beam.Create([100, 1000])
   677        partitions = (
   678            pcoll | 'part' >> beam.Partition(
   679                lambda e,
   680                n,
   681                offset,
   682                si_list: ((e + len(si_list)) % 3) + offset,
   683                4,
   684                1,
   685                pvalue.AsList(side_input)))
   686        assert_that(partitions[0], equal_to([]))
   687        assert_that(partitions[1], equal_to([1, 4, 7]), label='p1')
   688        assert_that(partitions[2], equal_to([2, 5, 8]), label='p2')
   689        assert_that(partitions[3], equal_to([0, 3, 6]), label='p3')
   690  
   691    def test_partition_followed_by_flatten_and_groupbykey(self):
   692      """Regression test for an issue with how partitions are handled."""
   693      with TestPipeline() as pipeline:
   694        contents = [('aa', 1), ('bb', 2), ('aa', 2)]
   695        created = pipeline | 'A' >> beam.Create(contents)
   696        partitioned = created | 'B' >> beam.Partition(lambda x, n: len(x) % n, 3)
   697        flattened = partitioned | 'C' >> beam.Flatten()
   698        grouped = flattened | 'D' >> beam.GroupByKey() | SortLists
   699        assert_that(grouped, equal_to([('aa', [1, 2]), ('bb', [2])]))
   700  
   701    @pytest.mark.it_validatesrunner
   702    def test_flatten_pcollections(self):
   703      with TestPipeline() as pipeline:
   704        pcoll_1 = pipeline | 'Start 1' >> beam.Create([0, 1, 2, 3])
   705        pcoll_2 = pipeline | 'Start 2' >> beam.Create([4, 5, 6, 7])
   706        result = (pcoll_1, pcoll_2) | 'Flatten' >> beam.Flatten()
   707        assert_that(result, equal_to([0, 1, 2, 3, 4, 5, 6, 7]))
   708  
   709    def test_flatten_no_pcollections(self):
   710      with TestPipeline() as pipeline:
   711        with self.assertRaises(ValueError):
   712          () | 'PipelineArgMissing' >> beam.Flatten()
   713        result = () | 'Empty' >> beam.Flatten(pipeline=pipeline)
   714        assert_that(result, equal_to([]))
   715  
   716    @pytest.mark.it_validatesrunner
   717    def test_flatten_one_single_pcollection(self):
   718      with TestPipeline() as pipeline:
   719        input = [0, 1, 2, 3]
   720        pcoll = pipeline | 'Input' >> beam.Create(input)
   721        result = (pcoll, ) | 'Single Flatten' >> beam.Flatten()
   722        assert_that(result, equal_to(input))
   723  
   724    # TODO(https://github.com/apache/beam/issues/20067): Does not work in
   725    # streaming mode on Dataflow.
   726    @pytest.mark.no_sickbay_streaming
   727    @pytest.mark.it_validatesrunner
   728    def test_flatten_same_pcollections(self):
   729      with TestPipeline() as pipeline:
   730        pc = pipeline | beam.Create(['a', 'b'])
   731        assert_that((pc, pc, pc) | beam.Flatten(), equal_to(['a', 'b'] * 3))
   732  
   733    def test_flatten_pcollections_in_iterable(self):
   734      with TestPipeline() as pipeline:
   735        pcoll_1 = pipeline | 'Start 1' >> beam.Create([0, 1, 2, 3])
   736        pcoll_2 = pipeline | 'Start 2' >> beam.Create([4, 5, 6, 7])
   737        result = [pcoll for pcoll in (pcoll_1, pcoll_2)] | beam.Flatten()
   738        assert_that(result, equal_to([0, 1, 2, 3, 4, 5, 6, 7]))
   739  
   740    @pytest.mark.it_validatesrunner
   741    def test_flatten_a_flattened_pcollection(self):
   742      with TestPipeline() as pipeline:
   743        pcoll_1 = pipeline | 'Start 1' >> beam.Create([0, 1, 2, 3])
   744        pcoll_2 = pipeline | 'Start 2' >> beam.Create([4, 5, 6, 7])
   745        pcoll_3 = pipeline | 'Start 3' >> beam.Create([8, 9])
   746        pcoll_12 = (pcoll_1, pcoll_2) | 'Flatten' >> beam.Flatten()
   747        pcoll_123 = (pcoll_12, pcoll_3) | 'Flatten again' >> beam.Flatten()
   748        assert_that(pcoll_123, equal_to([x for x in range(10)]))
   749  
   750    def test_flatten_input_type_must_be_iterable(self):
   751      # Inputs to flatten *must* be an iterable.
   752      with self.assertRaises(ValueError):
   753        4 | beam.Flatten()
   754  
   755    def test_flatten_input_type_must_be_iterable_of_pcolls(self):
   756      # Inputs to flatten *must* be an iterable of PCollections.
   757      with self.assertRaises(TypeError):
   758        {'l': 'test'} | beam.Flatten()
   759      with self.assertRaises(TypeError):
   760        set([1, 2, 3]) | beam.Flatten()
   761  
   762    @pytest.mark.it_validatesrunner
   763    def test_flatten_multiple_pcollections_having_multiple_consumers(self):
   764      with TestPipeline() as pipeline:
   765        input = pipeline | 'Start' >> beam.Create(['AA', 'BBB', 'CC'])
   766  
   767        def split_even_odd(element):
   768          tag = 'even_length' if len(element) % 2 == 0 else 'odd_length'
   769          return pvalue.TaggedOutput(tag, element)
   770  
   771        even_length, odd_length = (input | beam.Map(split_even_odd)
   772                                   .with_outputs('even_length', 'odd_length'))
   773        merged = (even_length, odd_length) | 'Flatten' >> beam.Flatten()
   774  
   775        assert_that(merged, equal_to(['AA', 'BBB', 'CC']))
   776        assert_that(even_length, equal_to(['AA', 'CC']), label='assert:even')
   777        assert_that(odd_length, equal_to(['BBB']), label='assert:odd')
   778  
   779    def test_group_by_key_input_must_be_kv_pairs(self):
   780      with self.assertRaises(typehints.TypeCheckError) as e:
   781        with TestPipeline() as pipeline:
   782          pcolls = pipeline | 'A' >> beam.Create([1, 2, 3, 4, 5])
   783          pcolls | 'D' >> beam.GroupByKey()
   784  
   785      self.assertStartswith(
   786          e.exception.args[0],
   787          'Input type hint violation at D: expected '
   788          'Tuple[TypeVariable[K], TypeVariable[V]]')
   789  
   790    def test_group_by_key_only_input_must_be_kv_pairs(self):
   791      with self.assertRaises(typehints.TypeCheckError) as cm:
   792        with TestPipeline() as pipeline:
   793          pcolls = pipeline | 'A' >> beam.Create(['a', 'b', 'f'])
   794          pcolls | 'D' >> beam.GroupByKey()
   795  
   796      expected_error_prefix = (
   797          'Input type hint violation at D: expected '
   798          'Tuple[TypeVariable[K], TypeVariable[V]]')
   799      self.assertStartswith(cm.exception.args[0], expected_error_prefix)
   800  
   801    def test_keys_and_values(self):
   802      with TestPipeline() as pipeline:
   803        pcoll = pipeline | 'Start' >> beam.Create([(3, 1), (2, 1), (1, 1), (3, 2),
   804                                                   (2, 2), (3, 3)])
   805        keys = pcoll.apply(beam.Keys('keys'))
   806        vals = pcoll.apply(beam.Values('vals'))
   807        assert_that(keys, equal_to([1, 2, 2, 3, 3, 3]), label='assert:keys')
   808        assert_that(vals, equal_to([1, 1, 1, 2, 2, 3]), label='assert:vals')
   809  
   810    def test_kv_swap(self):
   811      with TestPipeline() as pipeline:
   812        pcoll = pipeline | 'Start' >> beam.Create([(6, 3), (1, 2), (7, 1), (5, 2),
   813                                                   (3, 2)])
   814        result = pcoll.apply(beam.KvSwap(), label='swap')
   815        assert_that(result, equal_to([(1, 7), (2, 1), (2, 3), (2, 5), (3, 6)]))
   816  
   817    def test_distinct(self):
   818      with TestPipeline() as pipeline:
   819        pcoll = pipeline | 'Start' >> beam.Create(
   820            [6, 3, 1, 1, 9, 'pleat', 'pleat', 'kazoo', 'navel'])
   821        result = pcoll.apply(beam.Distinct())
   822        assert_that(result, equal_to([1, 3, 6, 9, 'pleat', 'kazoo', 'navel']))
   823  
   824    def test_chained_ptransforms(self):
   825      with TestPipeline() as pipeline:
   826        t = (
   827            beam.Map(lambda x: (x, 1))
   828            | beam.GroupByKey()
   829            | beam.Map(lambda x_ones: (x_ones[0], sum(x_ones[1]))))
   830        result = pipeline | 'Start' >> beam.Create(['a', 'a', 'b']) | t
   831        assert_that(result, equal_to([('a', 2), ('b', 1)]))
   832  
   833    def test_apply_to_list(self):
   834      self.assertCountEqual([1, 2, 3],
   835                            [0, 1, 2] | 'AddOne' >> beam.Map(lambda x: x + 1))
   836      self.assertCountEqual([1],
   837                            [0, 1, 2] | 'Odd' >> beam.Filter(lambda x: x % 2))
   838      self.assertCountEqual([1, 2, 100, 3], ([1, 2, 3], [100]) | beam.Flatten())
   839      join_input = ([('k', 'a')], [('k', 'b'), ('k', 'c')])
   840      self.assertCountEqual([('k', (['a'], ['b', 'c']))],
   841                            join_input | beam.CoGroupByKey() | SortLists)
   842  
   843    def test_multi_input_ptransform(self):
   844      class DisjointUnion(PTransform):
   845        def expand(self, pcollections):
   846          return (
   847              pcollections
   848              | beam.Flatten()
   849              | beam.Map(lambda x: (x, None))
   850              | beam.GroupByKey()
   851              | beam.Map(lambda kv: kv[0]))
   852  
   853      self.assertEqual([1, 2, 3], sorted(([1, 2], [2, 3]) | DisjointUnion()))
   854  
   855    def test_apply_to_crazy_pvaluish(self):
   856      class NestedFlatten(PTransform):
   857        """A PTransform taking and returning nested PValueish.
   858  
   859        Takes as input a list of dicts, and returns a dict with the corresponding
   860        values flattened.
   861        """
   862        def _extract_input_pvalues(self, pvalueish):
   863          pvalueish = list(pvalueish)
   864          return pvalueish, sum([list(p.values()) for p in pvalueish], [])
   865  
   866        def expand(self, pcoll_dicts):
   867          keys = reduce(operator.or_, [set(p.keys()) for p in pcoll_dicts])
   868          res = {}
   869          for k in keys:
   870            res[k] = [p[k] for p in pcoll_dicts if k in p] | k >> beam.Flatten()
   871          return res
   872  
   873      res = [{
   874          'a': [1, 2, 3]
   875      }, {
   876          'a': [4, 5, 6], 'b': ['x', 'y', 'z']
   877      }, {
   878          'a': [7, 8], 'b': ['x', 'y'], 'c': []
   879      }] | NestedFlatten()
   880      self.assertEqual(3, len(res))
   881      self.assertEqual([1, 2, 3, 4, 5, 6, 7, 8], sorted(res['a']))
   882      self.assertEqual(['x', 'x', 'y', 'y', 'z'], sorted(res['b']))
   883      self.assertEqual([], sorted(res['c']))
   884  
   885    def test_named_tuple(self):
   886      MinMax = collections.namedtuple('MinMax', ['min', 'max'])
   887  
   888      class MinMaxTransform(PTransform):
   889        def expand(self, pcoll):
   890          return MinMax(
   891              min=pcoll | beam.CombineGlobally(min).without_defaults(),
   892              max=pcoll | beam.CombineGlobally(max).without_defaults())
   893  
   894      res = [1, 2, 4, 8] | MinMaxTransform()
   895      self.assertIsInstance(res, MinMax)
   896      self.assertEqual(res, MinMax(min=[1], max=[8]))
   897  
   898      flat = res | beam.Flatten()
   899      self.assertEqual(sorted(flat), [1, 8])
   900  
   901    def test_tuple_twice(self):
   902      class Duplicate(PTransform):
   903        def expand(self, pcoll):
   904          return pcoll, pcoll
   905  
   906      res1, res2 = [1, 2, 4, 8] | Duplicate()
   907      self.assertEqual(sorted(res1), [1, 2, 4, 8])
   908      self.assertEqual(sorted(res2), [1, 2, 4, 8])
   909  
   910    def test_resource_hint_application_is_additive(self):
   911      t = beam.Map(lambda x: x + 1).with_resource_hints(
   912          accelerator='gpu').with_resource_hints(min_ram=1).with_resource_hints(
   913              accelerator='tpu')
   914      self.assertEqual(
   915          t.get_resource_hints(),
   916          {
   917              common_urns.resource_hints.ACCELERATOR.urn: b'tpu',
   918              common_urns.resource_hints.MIN_RAM_BYTES.urn: b'1'
   919          })
   920  
   921  
class TestGroupBy(unittest.TestCase):
  """Tests for beam.GroupBy keyed by callables, field names, or both.

  Several pipelines below apply more than one GroupBy/MapTuple; the later
  ones carry explicit labels ('G2', 'n2', ...). Presumably this keeps the
  applied labels unique, so keep those labels when editing these tests.
  """
  def test_lambdas(self):
    # Canonicalizes a (key, values) group for comparison: tuple-typed keys
    # become plain tuples and the grouped values are sorted.
    def normalize(key, values):
      return tuple(key) if isinstance(key, tuple) else key, sorted(values)

    with TestPipeline() as p:
      pcoll = p | beam.Create(range(6))
      # No key function: everything lands in one group keyed by ().
      assert_that(
          pcoll | beam.GroupBy() | beam.MapTuple(normalize),
          equal_to([((), [0, 1, 2, 3, 4, 5])]),
          'GroupAll')
      # A single positional key function yields bare (non-tuple) keys.
      assert_that(
          pcoll | beam.GroupBy(lambda x: x % 2)
          | 'n2' >> beam.MapTuple(normalize),
          equal_to([(0, [0, 2, 4]), (1, [1, 3, 5])]),
          'GroupOne')
      # force_tuple_keys() wraps even a single key in a 1-tuple.
      assert_that(
          pcoll | 'G2' >> beam.GroupBy(lambda x: x % 2).force_tuple_keys()
          | 'n3' >> beam.MapTuple(normalize),
          equal_to([((0, ), [0, 2, 4]), ((1, ), [1, 3, 5])]),
          'GroupOneTuple')
      # Two named key functions produce a composite (a, b) key.
      assert_that(
          pcoll | beam.GroupBy(a=lambda x: x % 2, b=lambda x: x < 4)
          | 'n4' >> beam.MapTuple(normalize),
          equal_to([((0, True), [0, 2]), ((1, True), [1, 3]), ((0, False), [4]),
                    ((1, False), [5])]),
          'GroupTwo')

  def test_fields(self):
    # Converts tuple-typed keys (named tuples, judging by the `_fields`
    # access) into beam.Row so they compare equal to the expected Rows, and
    # returns the sorted `value` fields of the grouped elements.
    def normalize(key, values):
      if isinstance(key, tuple):
        key = beam.Row(
            **{name: value
               for name, value in zip(type(key)._fields, key)})
      return key, sorted(v.value for v in values)

    with TestPipeline() as p:
      pcoll = p | beam.Create(range(-2, 3)) | beam.Map(int) | beam.Map(
          lambda x: beam.Row(
              value=x, square=x * x, sign=x // abs(x) if x else 0))
      # A single field name yields bare field values as keys.
      assert_that(
          pcoll | beam.GroupBy('square') | beam.MapTuple(normalize),
          equal_to([
              (0, [0]),
              (1, [-1, 1]),
              (4, [-2, 2]),
          ]),
          'GroupSquare')
      # force_tuple_keys() keeps the single field as a one-field tuple key.
      assert_that(
          pcoll | 'G2' >> beam.GroupBy('square').force_tuple_keys()
          | 'n2' >> beam.MapTuple(normalize),
          equal_to([
              (beam.Row(square=0), [0]),
              (beam.Row(square=1), [-1, 1]),
              (beam.Row(square=4), [-2, 2]),
          ]),
          'GroupSquareTupleKey')
      # Two field names give a two-field key.
      assert_that(
          pcoll | beam.GroupBy('square', 'sign')
          | 'n3' >> beam.MapTuple(normalize),
          equal_to([
              (beam.Row(square=0, sign=0), [0]),
              (beam.Row(square=1, sign=1), [1]),
              (beam.Row(square=4, sign=1), [2]),
              (beam.Row(square=1, sign=-1), [-1]),
              (beam.Row(square=4, sign=-1), [-2]),
          ]),
          'GroupSquareSign')
      # Field names and named key functions can be mixed.
      assert_that(
          pcoll | beam.GroupBy('square', big=lambda x: x.value > 1)
          | 'n4' >> beam.MapTuple(normalize),
          equal_to([
              (beam.Row(square=0, big=False), [0]),
              (beam.Row(square=1, big=False), [-1, 1]),
              (beam.Row(square=4, big=False), [-2]),
              (beam.Row(square=4, big=True), [2]),
          ]),
          'GroupSquareNonzero')

  def test_aggregate(self):
    # Rebuilds a named-tuple result as a beam.Row with the same field names.
    def named_tuple_to_row(t):
      return beam.Row(
          **{name: value
             for name, value in zip(type(t)._fields, t)})

    with TestPipeline() as p:
      pcoll = p | beam.Create(range(-2, 3)) | beam.Map(
          lambda x: beam.Row(
              value=x, square=x * x, sign=x // abs(x) if x else 0))

      # Per (square, big) group: `sum` adds up the `value` fields, and
      # `positive` is all() over whether each element's sign == 1.
      # Trailing comments list each group's input values.
      assert_that(
          pcoll
          | beam.GroupBy('square', big=lambda x: x.value > 1)
            .aggregate_field('value', sum, 'sum')
            .aggregate_field(lambda x: x.sign == 1, all, 'positive')
          | beam.Map(named_tuple_to_row),
          equal_to([
              beam.Row(square=0, big=False, sum=0, positive=False),   # [0]
              beam.Row(square=1, big=False, sum=0, positive=False),   # [-1, 1]
              beam.Row(square=4, big=False, sum=-2, positive=False),  # [-2]
              beam.Row(square=4, big=True, sum=2, positive=True),     # [2]
          ]))

  def test_pickled_field(self):
    # key1 wraps each string in a PickledObject (not defined in this view).
    # The test checks that GroupBy can group on it, and that `.value` on the
    # grouped key recovers the original string.
    with TestPipeline() as p:
      assert_that(
          p
          | beam.Create(['a', 'a', 'b'])
          | beam.Map(
              lambda s: beam.Row(
                  key1=PickledObject(s), key2=s.upper(), value=0))
          | beam.GroupBy('key1', 'key2')
          | beam.MapTuple(lambda k, vs: (k.key1.value, k.key2, len(list(vs)))),
          equal_to([('a', 'A', 2), ('b', 'B', 1)]))
  1036  
  1037  
  1038  class SelectTest(unittest.TestCase):
  1039    def test_simple(self):
  1040      with TestPipeline() as p:
  1041        rows = (
  1042            p | beam.Create([1, 2, 10])
  1043            | beam.Select(a=lambda x: x * x, b=lambda x: -x))
  1044  
  1045        assert_that(
  1046            rows,
  1047            equal_to([
  1048                beam.Row(a=1, b=-1),
  1049                beam.Row(a=4, b=-2),
  1050                beam.Row(a=100, b=-10),
  1051            ]),
  1052            label='CheckFromLambdas')
  1053  
  1054        from_attr = rows | beam.Select('b', z='a')
  1055        assert_that(
  1056            from_attr,
  1057            equal_to([
  1058                beam.Row(b=-1, z=1),
  1059                beam.Row(b=-2, z=4),
  1060                beam.Row(
  1061                    b=-10,
  1062                    z=100,
  1063                ),
  1064            ]),
  1065            label='CheckFromAttrs')
  1066  
  1067  
  1068  @beam.ptransform_fn
  1069  def SamplePTransform(pcoll):
  1070    """Sample transform using the @ptransform_fn decorator."""
  1071    map_transform = 'ToPairs' >> beam.Map(lambda v: (v, None))
  1072    combine_transform = 'Group' >> beam.CombinePerKey(lambda vs: None)
  1073    keys_transform = 'Distinct' >> beam.Keys()
  1074    return pcoll | map_transform | combine_transform | keys_transform
  1075  
  1076  
class PTransformLabelsTest(unittest.TestCase):
  """Tests for the labels recorded in `pipeline.applied_labels`."""

  # A composite transform that applies a single FlatMap labelled '*Do*' and
  # keeps that FlatMap on `self.pardo`.
  class CustomTransform(beam.PTransform):

    pardo = None  # type: Optional[beam.PTransform]

    def expand(self, pcoll):
      # Adds 1 to every element.
      self.pardo = '*Do*' >> beam.FlatMap(lambda x: [x + 1])
      return pcoll | self.pardo

  def test_chained_ptransforms(self):
    """Tests that chaining gets proper nesting.

    Composing `map1 | gbk | map2` yields a composite labelled
    'Map1|Gbk|Map2', with each part nested beneath it.
    """
    with TestPipeline() as pipeline:
      map1 = 'Map1' >> beam.Map(lambda x: (x, 1))
      gbk = 'Gbk' >> beam.GroupByKey()
      map2 = 'Map2' >> beam.Map(lambda x_ones2: (x_ones2[0], sum(x_ones2[1])))
      t = (map1 | gbk | map2)
      result = pipeline | 'Start' >> beam.Create(['a', 'a', 'b']) | t
      self.assertTrue('Map1|Gbk|Map2/Map1' in pipeline.applied_labels)
      self.assertTrue('Map1|Gbk|Map2/Gbk' in pipeline.applied_labels)
      self.assertTrue('Map1|Gbk|Map2/Map2' in pipeline.applied_labels)
      assert_that(result, equal_to([('a', 2), ('b', 1)]))

  def test_apply_custom_transform_without_label(self):
    # Without an explicit label, the composite is labelled with its class
    # name and its inner FlatMap is nested under it.
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'PColl' >> beam.Create([1, 2, 3])
      custom = PTransformLabelsTest.CustomTransform()
      result = pipeline.apply(custom, pcoll)
      self.assertTrue('CustomTransform' in pipeline.applied_labels)
      self.assertTrue('CustomTransform/*Do*' in pipeline.applied_labels)
      assert_that(result, equal_to([2, 3, 4]))

  def test_apply_custom_transform_with_label(self):
    # A label passed to the constructor replaces the class-name label.
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'PColl' >> beam.Create([1, 2, 3])
      custom = PTransformLabelsTest.CustomTransform('*Custom*')
      result = pipeline.apply(custom, pcoll)
      self.assertTrue('*Custom*' in pipeline.applied_labels)
      self.assertTrue('*Custom*/*Do*' in pipeline.applied_labels)
      assert_that(result, equal_to([2, 3, 4]))

  def test_combine_without_label(self):
    # CombineGlobally's default label names the combining callable.
    vals = [1, 2, 3, 4, 5, 6, 7]
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Start' >> beam.Create(vals)
      combine = beam.CombineGlobally(sum)
      result = pcoll | combine
      self.assertTrue('CombineGlobally(sum)' in pipeline.applied_labels)
      assert_that(result, equal_to([sum(vals)]))

  def test_apply_ptransform_using_decorator(self):
    # Only labels are checked here; the pipeline is never run.
    pipeline = TestPipeline()
    pcoll = pipeline | 'PColl' >> beam.Create([1, 2, 3])
    _ = pcoll | '*Sample*' >> SamplePTransform()
    self.assertTrue('*Sample*' in pipeline.applied_labels)
    self.assertTrue('*Sample*/ToPairs' in pipeline.applied_labels)
    self.assertTrue('*Sample*/Group' in pipeline.applied_labels)
    self.assertTrue('*Sample*/Distinct' in pipeline.applied_labels)

  def test_combine_with_label(self):
    # An explicit label replaces CombineGlobally's default label.
    vals = [1, 2, 3, 4, 5, 6, 7]
    with TestPipeline() as pipeline:
      pcoll = pipeline | 'Start' >> beam.Create(vals)
      combine = '*Sum*' >> beam.CombineGlobally(sum)
      result = pcoll | combine
      self.assertTrue('*Sum*' in pipeline.applied_labels)
      assert_that(result, equal_to([sum(vals)]))

  def check_label(self, ptransform, expected_label):
    """Asserts the label `ptransform` receives when applied.

    Applies `ptransform` after a 'Start' Create (the pipeline is never run)
    and takes the lexicographically first applied label that does not start
    with 'Start'. Before comparing with `expected_label`, each run of three or
    more digits (e.g. a source line number) is replaced with '#'.
    """
    pipeline = TestPipeline()
    pipeline | 'Start' >> beam.Create([('a', 1)]) | ptransform
    actual_label = sorted(
        label for label in pipeline.applied_labels
        if not label.startswith('Start'))[0]
    self.assertEqual(expected_label, re.sub(r'\d{3,}', '#', actual_label))

  def test_default_labels(self):
    # Unlabelled transforms get a default label built from the wrapped
    # callable's name. Lambdas are identified by file and line, and
    # check_label masks the line number as '#'.
    def my_function(*args):
      pass

    self.check_label(beam.Map(len), 'Map(len)')
    self.check_label(beam.Map(my_function), 'Map(my_function)')
    self.check_label(
        beam.Map(lambda x: x), 'Map(<lambda at ptransform_test.py:#>)')
    self.check_label(beam.FlatMap(list), 'FlatMap(list)')
    self.check_label(beam.FlatMap(my_function), 'FlatMap(my_function)')
    self.check_label(beam.Filter(sum), 'Filter(sum)')
    self.check_label(beam.CombineGlobally(sum), 'CombineGlobally(sum)')
    self.check_label(beam.CombinePerKey(sum), 'CombinePerKey(sum)')

    class MyDoFn(beam.DoFn):
      def process(self, unused_element):
        pass

    self.check_label(beam.ParDo(MyDoFn()), 'ParDo(MyDoFn)')

  def test_label_propogation(self):
    # An explicit `'Label' >> transform` overrides the default label for
    # every kind of transform checked in test_default_labels.
    self.check_label('TestMap' >> beam.Map(len), 'TestMap')
    self.check_label('TestLambda' >> beam.Map(lambda x: x), 'TestLambda')
    self.check_label('TestFlatMap' >> beam.FlatMap(list), 'TestFlatMap')
    self.check_label('TestFilter' >> beam.Filter(sum), 'TestFilter')
    self.check_label('TestCG' >> beam.CombineGlobally(sum), 'TestCG')
    self.check_label('TestCPK' >> beam.CombinePerKey(sum), 'TestCPK')

    class MyDoFn(beam.DoFn):
      def process(self, unused_element):
        pass

    self.check_label('TestParDo' >> beam.ParDo(MyDoFn()), 'TestParDo')
  1185  
  1186  
  1187  class PTransformTestDisplayData(unittest.TestCase):
  1188    def test_map_named_function(self):
  1189      tr = beam.Map(len)
  1190      dd = DisplayData.create_from(tr)
  1191      nspace = 'apache_beam.transforms.core.CallableWrapperDoFn'
  1192      expected_item = DisplayDataItem(
  1193          'len', key='fn', label='Transform Function', namespace=nspace)
  1194      hc.assert_that(dd.items, hc.has_item(expected_item))
  1195  
  1196    def test_map_anonymous_function(self):
  1197      tr = beam.Map(lambda x: x)
  1198      dd = DisplayData.create_from(tr)
  1199      nspace = 'apache_beam.transforms.core.CallableWrapperDoFn'
  1200      expected_item = DisplayDataItem(
  1201          '<lambda>', key='fn', label='Transform Function', namespace=nspace)
  1202      hc.assert_that(dd.items, hc.has_item(expected_item))
  1203  
  1204    def test_flatmap_named_function(self):
  1205      tr = beam.FlatMap(list)
  1206      dd = DisplayData.create_from(tr)
  1207      nspace = 'apache_beam.transforms.core.CallableWrapperDoFn'
  1208      expected_item = DisplayDataItem(
  1209          'list', key='fn', label='Transform Function', namespace=nspace)
  1210      hc.assert_that(dd.items, hc.has_item(expected_item))
  1211  
  1212    def test_flatmap_anonymous_function(self):
  1213      tr = beam.FlatMap(lambda x: [x])
  1214      dd = DisplayData.create_from(tr)
  1215      nspace = 'apache_beam.transforms.core.CallableWrapperDoFn'
  1216      expected_item = DisplayDataItem(
  1217          '<lambda>', key='fn', label='Transform Function', namespace=nspace)
  1218      hc.assert_that(dd.items, hc.has_item(expected_item))
  1219  
  1220    def test_filter_named_function(self):
  1221      tr = beam.Filter(sum)
  1222      dd = DisplayData.create_from(tr)
  1223      nspace = 'apache_beam.transforms.core.CallableWrapperDoFn'
  1224      expected_item = DisplayDataItem(
  1225          'sum', key='fn', label='Transform Function', namespace=nspace)
  1226      hc.assert_that(dd.items, hc.has_item(expected_item))
  1227  
  1228    def test_filter_anonymous_function(self):
  1229      tr = beam.Filter(lambda x: x // 30)
  1230      dd = DisplayData.create_from(tr)
  1231      nspace = 'apache_beam.transforms.core.CallableWrapperDoFn'
  1232      expected_item = DisplayDataItem(
  1233          '<lambda>', key='fn', label='Transform Function', namespace=nspace)
  1234      hc.assert_that(dd.items, hc.has_item(expected_item))
  1235  
  1236  
  1237  class PTransformTypeCheckTestCase(TypeHintTestCase):
  def assertStartswith(self, msg, prefix):
    """Asserts that the string ``msg`` begins with ``prefix``."""
    has_prefix = msg.startswith(prefix)
    failure_message = '"%s" does not start with "%s"' % (msg, prefix)
    self.assertTrue(has_prefix, failure_message)
  1241  
  def setUp(self):
    # A fresh pipeline for every test; some tests adjust its options first.
    pipeline = TestPipeline()
    self.p = pipeline
  1244  
  def test_do_fn_pipeline_pipeline_type_check_satisfied(self):
    # Int elements plus an int side argument match the DoFn's hints.
    @with_input_types(int, int)
    @with_output_types(int)
    class AddWithFive(beam.DoFn):
      def process(self, element, five):
        return [element + five]

    ints = self.p | 'T' >> beam.Create([1, 2, 3]).with_output_types(int)
    plus_five = ints | 'Add' >> beam.ParDo(AddWithFive(), 5)

    assert_that(plus_five, equal_to([6, 7, 8]))
    self.p.run()
  1259  
  def test_do_fn_pipeline_pipeline_type_check_violated(self):
    @with_input_types(str, str)
    @with_output_types(str)
    class ToUpperCaseWithPrefix(beam.DoFn):
      def process(self, element, prefix):
        return [prefix + element.upper()]

    # Applying a str-hinted DoFn to int elements must raise even though the
    # pipeline is never run.
    with self.assertRaises(typehints.TypeCheckError) as e:
      ints = self.p | 'T' >> beam.Create([1, 2, 3]).with_output_types(int)
      ints | 'Upper' >> beam.ParDo(ToUpperCaseWithPrefix(), 'hello')

    expected_prefix = (
        "Type hint violation for 'Upper': "
        "requires {} but got {} for element".format(str, int))
    self.assertStartswith(e.exception.args[0], expected_prefix)
  1277  
  def test_do_fn_pipeline_runtime_type_check_satisfied(self):
    type_options = self.p._options.view_as(TypeOptions)
    type_options.runtime_type_check = True

    @with_input_types(int, int)
    @with_output_types(int)
    class AddWithNum(beam.DoFn):
      def process(self, element, num):
        return [element + num]

    ints = self.p | 'T' >> beam.Create([1, 2, 3]).with_output_types(int)
    sums = ints | 'Add' >> beam.ParDo(AddWithNum(), 5)

    assert_that(sums, equal_to([6, 7, 8]))
    self.p.run()
  1294  
  def test_do_fn_pipeline_runtime_type_check_violated(self):
    type_options = self.p._options.view_as(TypeOptions)
    type_options.runtime_type_check = True

    @with_input_types(int, int)
    @with_output_types(int)
    class AddWithNum(beam.DoFn):
      def process(self, element, num):
        return [element + num]

    # String elements reach an int-hinted DoFn; the pipeline is run inside
    # the assertRaises block.
    with self.assertRaises(typehints.TypeCheckError) as e:
      strings = (
          self.p
          | 'T' >> beam.Create(['1', '2', '3']).with_output_types(str))
      strings | 'Add' >> beam.ParDo(AddWithNum(), 5)
      self.p.run()

    expected_prefix = (
        "Type hint violation for 'Add': "
        "requires {} but got {} for element".format(int, str))
    self.assertStartswith(e.exception.args[0], expected_prefix)
  1315  
  def test_pardo_does_not_type_check_using_type_hint_decorators(self):
    @with_input_types(a=int)
    @with_output_types(typing.List[str])
    def int_to_str(a):
      return [str(a)]

    # int_to_str is hinted to take an int for `a`, but the upstream Create is
    # hinted to produce str, so applying it must raise.
    with self.assertRaises(typehints.TypeCheckError) as e:
      letters = (
          self.p
          | 'S' >> beam.Create(['b', 'a', 'r']).with_output_types(str))
      letters | 'ToStr' >> beam.FlatMap(int_to_str)

    expected_prefix = (
        "Type hint violation for 'ToStr': "
        "requires {} but got {} for a".format(int, str))
    self.assertStartswith(e.exception.args[0], expected_prefix)
  1334  
  def test_pardo_properly_type_checks_using_type_hint_decorators(self):
    @with_input_types(a=str)
    @with_output_types(typing.List[str])
    def to_all_upper_case(a):
      return [a.upper()]

    # If this type-checks then no error should be raised.
    letters = (
        self.p
        | 'T' >> beam.Create(['t', 'e', 's', 't']).with_output_types(str))
    d = letters | 'Case' >> beam.FlatMap(to_all_upper_case)
    assert_that(d, equal_to(['T', 'E', 'S', 'T']))
    self.p.run()

    # FlatMap unwraps the List[str] output hint, so the element type is str
    # rather than List[str].
    self.assertEqual(str, d.element_type)
  1352  
  def test_pardo_does_not_type_check_using_type_hint_methods(self):
    # 'Score' is declared to emit ints, but 'Upper' declares it consumes str.
    with self.assertRaises(typehints.TypeCheckError) as e:
      strings = (
          self.p
          | 'S' >> beam.Create(['t', 'e', 's', 't']).with_output_types(str))
      scores = strings | 'Score' >> beam.FlatMap(
          lambda x: [1] if x == 't' else [2]).with_input_types(
              str).with_output_types(int)
      scores | 'Upper' >> beam.FlatMap(lambda x: [x.upper()]).with_input_types(
          str).with_output_types(str)

    expected_prefix = (
        "Type hint violation for 'Upper': "
        "requires {} but got {} for x".format(str, int))
    self.assertStartswith(e.exception.args[0], expected_prefix)
  1371  
  def test_pardo_properly_type_checks_using_type_hint_methods(self):
    # Pipeline should be created successfully without an error: every stage
    # is hinted str -> str.
    strings = (
        self.p
        | 'S' >> beam.Create(['t', 'e', 's', 't']).with_output_types(str))
    doubled = strings | 'Dup' >> beam.FlatMap(
        lambda x: [x + x]).with_input_types(str).with_output_types(str)
    d = doubled | 'Upper' >> beam.FlatMap(
        lambda x: [x.upper()]).with_input_types(str).with_output_types(str)

    assert_that(d, equal_to(['TT', 'EE', 'SS', 'TT']))
    self.p.run()
  1384  
  def test_map_does_not_type_check_using_type_hints_methods(self):
    # Create is hinted to emit ints while the downstream Map declares a str
    # input, so applying 'Upper' must be rejected at construction time.
    with self.assertRaises(typehints.TypeCheckError) as ctx:
      numbers = (
          self.p | 'S' >> beam.Create([1, 2, 3, 4]).with_output_types(int))
      _ = numbers | 'Upper' >> beam.Map(
          lambda x: x.upper()).with_input_types(str).with_output_types(str)

    expected_prefix = (
        "Type hint violation for 'Upper': "
        "requires {} but got {} for x".format(str, int))
    self.assertStartswith(ctx.exception.args[0], expected_prefix)
  1399  
  def test_map_properly_type_checks_using_type_hints_methods(self):
    # An int -> str Map fed by an int-hinted Create is consistent, so the
    # pipeline must build and produce the stringified numbers.
    numbers = (
        self.p | 'S' >> beam.Create([1, 2, 3, 4]).with_output_types(int))
    as_text = numbers | 'ToStr' >> beam.Map(
        lambda x: str(x)).with_input_types(int).with_output_types(str)
    assert_that(as_text, equal_to(['1', '2', '3', '4']))
    self.p.run()
  1409  
  def test_map_does_not_type_check_using_type_hints_decorator(self):
    @with_input_types(s=str)
    @with_output_types(str)
    def upper(s):
      return s.upper()

    # 'upper' is decorated to accept a str, but Create advertises int
    # elements; Map must catch the mismatch while the pipeline is built.
    with self.assertRaises(typehints.TypeCheckError) as ctx:
      numbers = (
          self.p | 'S' >> beam.Create([1, 2, 3, 4]).with_output_types(int))
      _ = numbers | 'Upper' >> beam.Map(upper)

    expected_prefix = (
        "Type hint violation for 'Upper': "
        "requires {} but got {} for s".format(str, int))
    self.assertStartswith(ctx.exception.args[0], expected_prefix)
  1428  
  def test_map_properly_type_checks_using_type_hints_decorator(self):
    # A bool -> int decorated function applied to a bool-hinted PCollection
    # must build and run without a type error.
    @with_input_types(a=bool)
    @with_output_types(int)
    def bool_to_int(a):
      return int(a)

    # If this type-checks then no error should be raised.
    d = (
        self.p
        | 'Bools' >> beam.Create([True, False, True]).with_output_types(bool)
        | 'ToInts' >> beam.Map(bool_to_int))
    assert_that(d, equal_to([1, 0, 1]))
    self.p.run()
  1442  
  def test_filter_does_not_type_check_using_type_hints_method(self):
    # 'Below 3' declares an int input, yet the upstream 'Lower' Map is hinted
    # to produce strs, so applying the Filter must fail.
    with self.assertRaises(typehints.TypeCheckError) as ctx:
      raw = ['1', '2', '3', '4', '5']
      digits = self.p | 'Strs' >> beam.Create(raw).with_output_types(str)
      lowered = digits | 'Lower' >> beam.Map(
          lambda x: x.lower()).with_input_types(str).with_output_types(str)
      _ = lowered | 'Below 3' >> beam.Filter(
          lambda x: x < 3).with_input_types(int)

    expected_prefix = (
        "Type hint violation for 'Below 3': "
        "requires {} but got {} for x".format(int, str))
    self.assertStartswith(ctx.exception.args[0], expected_prefix)
  1459  
  def test_filter_type_checks_using_type_hints_method(self):
    # str -> int -> Filter(int) is consistent end to end; nothing should raise.
    digits = self.p | beam.Create(
        ['1', '2', '3', '4', '5']).with_output_types(str)
    numbers = digits | 'ToInt' >> beam.Map(
        lambda x: int(x)).with_input_types(str).with_output_types(int)
    small = numbers | 'Below 3' >> beam.Filter(
        lambda x: x < 3).with_input_types(int)
    assert_that(small, equal_to([1, 2]))
    self.p.run()
  1470  
  def test_filter_does_not_type_check_using_type_hints_decorator(self):
    @with_input_types(a=float)
    def more_than_half(a):
      return a > 0.50

    # The predicate is restricted to floats, but Create is hinted as int.
    with self.assertRaises(typehints.TypeCheckError) as ctx:
      ints = (
          self.p | 'Ints' >> beam.Create([1, 2, 3, 4]).with_output_types(int))
      _ = ints | 'Half' >> beam.Filter(more_than_half)

    expected_prefix = (
        "Type hint violation for 'Half': "
        "requires {} but got {} for a".format(float, int))
    self.assertStartswith(ctx.exception.args[0], expected_prefix)
  1487  
  def test_filter_type_checks_using_type_hints_decorator(self):
    @with_input_types(b=int)
    def half(b):
      return bool(random.choice([0, 1]))

    # Filter's output type should be inferred as its (int) input type, which
    # is what lets the int-consuming 'ToBool' Map type-check. The pipeline is
    # only constructed here, never run.
    ints = self.p | 'Str' >> beam.Create(range(5)).with_output_types(int)
    kept = ints | 'Half' >> beam.Filter(half)
    _ = kept | 'ToBool' >> beam.Map(
        lambda x: bool(x)).with_input_types(int).with_output_types(bool)
  1500  
  def test_pardo_like_inheriting_output_types_from_annotation(self):
    def fn1(x: str) -> int:
      return 1

    def fn1_flat(x: str) -> typing.List[int]:
      return [1]

    def fn2(x: int, y: str) -> str:
      return y

    def fn2_flat(x: int, y: str) -> typing.List[str]:
      return [y]

    def output_hints(transform):
      # Only the first positional output hint is of interest here.
      return transform.default_type_hints().output_types[0][0]

    # Each transform is built from its function and checked in turn; the
    # Flat* variants should unwrap the List[...] return annotation.
    cases = [
        (int, beam.Map, fn1),
        (int, beam.FlatMap, fn1_flat),
        (str, beam.MapTuple, fn2),
        (str, beam.FlatMapTuple, fn2_flat),
    ]
    for expected, make_transform, fn in cases:
      self.assertEqual(expected, output_hints(make_transform(fn)))

    def add(a: typing.Iterable[int]) -> int:
      return sum(a)

    self.assertCompatible(
        typing.Tuple[typing.TypeVar('K'), int],
        output_hints(beam.CombinePerKey(add)))
  1530  
  def test_group_by_key_only_output_type_deduction(self):
    # The 'Pair' lambda emits (x, ord(x)); ord() returns an int, so the pair
    # type is declared as Tuple[str, int] to match what is actually produced.
    d = (
        self.p
        | 'Str' >> beam.Create(['t', 'e', 's', 't']).with_output_types(str)
        | (
            'Pair' >> beam.Map(lambda x: (x, ord(x))).with_output_types(
                typing.Tuple[str, int]))
        | beam.GroupByKey())

    # Output type should correctly be deduced.
    # GBK-only should deduce that Tuple[A, B] is turned into
    # Tuple[A, Iterable[B]].
    self.assertCompatible(
        typing.Tuple[str, typing.Iterable[int]], d.element_type)
  1545  
  def test_group_by_key_output_type_deduction(self):
    ints = self.p | 'Str' >> beam.Create(range(20)).with_output_types(int)
    keyed = ints | 'PairNegative' >> beam.Map(
        lambda x: (x % 5, -x)).with_output_types(typing.Tuple[int, int])
    grouped = keyed | beam.GroupByKey()

    # GroupByKey should map an element type of Tuple[K, V] to
    # Tuple[K, Iterable[V]].
    self.assertCompatible(
        typing.Tuple[int, typing.Iterable[int]], grouped.element_type)
  1559  
  def test_group_by_key_only_does_not_type_check(self):
    # The input to GroupByKey is declared as plain int rather than a
    # Tuple[K, V] key/value pair, which must be rejected.
    with self.assertRaises(typehints.TypeCheckError) as ctx:
      ints = self.p | beam.Create([1, 2, 3]).with_output_types(int)
      _ = ints | 'F' >> beam.GroupByKey()

    expected_prefix = (
        "Input type hint violation at F: "
        "expected Tuple[TypeVariable[K], TypeVariable[V]], "
        "got {}".format(int))
    self.assertStartswith(ctx.exception.args[0], expected_prefix)
  1573  
  def test_group_by_does_not_type_check(self):
    # Create is hinted to output Iterable[int] elements rather than the
    # Tuple[K, V] key/value pairs that GroupByKey requires, so applying the
    # GroupByKey must fail at construction time.
    with self.assertRaises(typehints.TypeCheckError) as e:
      (
          self.p
          | (beam.Create([[1], [2]]).with_output_types(typing.Iterable[int]))
          | 'T' >> beam.GroupByKey())

    self.assertStartswith(
        e.exception.args[0],
        "Input type hint violation at T: "
        "expected Tuple[TypeVariable[K], TypeVariable[V]], "
        "got Iterable[<class 'int'>]")
  1588  
  def test_pipeline_checking_pardo_insufficient_type_information(self):
    type_options = self.p._options.view_as(TypeOptions)
    type_options.type_check_strictness = 'ALL_REQUIRED'

    # With ALL_REQUIRED strictness every transform needs an output hint, and
    # the unhinted 'Nums' Create feeding the ParDo does not provide one.
    with self.assertRaises(typehints.TypeCheckError) as ctx:
      nums = self.p | 'Nums' >> beam.Create(range(5))
      _ = nums | 'ModDup' >> beam.FlatMap(lambda x: (x % 2, x))

    self.assertEqual(
        'Pipeline type checking is enabled, however no output '
        'type-hint was found for the PTransform Create(Nums)',
        ctx.exception.args[0])
  1604  
  def test_pipeline_checking_gbk_insufficient_type_information(self):
    type_options = self.p._options.view_as(TypeOptions)
    type_options.type_check_strictness = 'ALL_REQUIRED'

    # 'ModDup' carries no output hint, so under ALL_REQUIRED strictness the
    # GroupByKey downstream of it cannot be checked and must raise.
    with self.assertRaises(typehints.TypeCheckError) as ctx:
      nums = self.p | 'Nums' >> beam.Create(range(5)).with_output_types(int)
      paired = nums | 'ModDup' >> beam.Map(lambda x: (x % 2, x))
      _ = paired | beam.GroupByKey()

    self.assertEqual(
        'Pipeline type checking is enabled, however no output '
        'type-hint was found for the PTransform '
        'ParDo(ModDup)',
        ctx.exception.args[0])
  1621  
  def test_disable_pipeline_type_check(self):
    type_options = self.p._options.view_as(TypeOptions)
    type_options.pipeline_type_check = False

    # Feeding ints into a Map hinted to take str would normally be rejected
    # at construction; with pipeline type checking switched off above,
    # building it must succeed. The pipeline is never run.
    ints = self.p | 'T' >> beam.Create([1, 2, 3]).with_output_types(int)
    _ = ints | 'Lower' >> beam.Map(
        lambda x: x.lower()).with_input_types(str).with_output_types(str)
  1632  
  def test_run_time_type_checking_enabled_type_violation(self):
    type_options = self.p._options.view_as(TypeOptions)
    type_options.pipeline_type_check = False
    type_options.runtime_type_check = True

    @with_output_types(str)
    @with_input_types(x=int)
    def int_to_string(x):
      return str(x)

    # Construction succeeds (static checking is off), but at execution the
    # int-only function is handed the str produced by Create.
    strings = self.p | 'T' >> beam.Create(['some_string'])
    _ = strings | 'ToStr' >> beam.Map(int_to_string)
    with self.assertRaises(typehints.TypeCheckError) as ctx:
      self.p.run()

    expected_prefix = (
        "Runtime type violation detected within ParDo(ToStr): "
        "Type-hint for argument: 'x' violated. "
        "Expected an instance of {}, "
        "instead found some_string, an instance of {}.".format(int, str))
    self.assertStartswith(ctx.exception.args[0], expected_prefix)
  1657  
  def test_run_time_type_checking_enabled_types_satisfied(self):
    type_options = self.p._options.view_as(TypeOptions)
    type_options.pipeline_type_check = False
    type_options.runtime_type_check = True

    @with_output_types(typing.Tuple[int, str])
    @with_input_types(x=str)
    def group_with_upper_ord(x):
      return (ord(x.upper()) % 5, x)

    # Static checking is off; the decorated function honours its hints for
    # every element, so the runtime checks must pass.
    letters = self.p | 'T' >> beam.Create(
        ['t', 'e', 's', 't', 'i', 'n', 'g']).with_output_types(str)
    keyed = letters | 'GenKeys' >> beam.Map(group_with_upper_ord)
    result = keyed | 'O' >> beam.GroupByKey() | SortLists

    assert_that(
        result,
        equal_to([(1, ['g']), (3, ['i', 'n', 's']), (4, ['e', 't', 't'])]))
    self.p.run()
  1681  
  def test_pipeline_checking_satisfied_but_run_time_types_violate(self):
    type_options = self.p._options.view_as(TypeOptions)
    type_options.pipeline_type_check = False
    type_options.runtime_type_check = True

    @with_output_types(typing.Tuple[bool, int])
    @with_input_types(a=int)
    def is_even_as_key(a):
      # Deliberate bug: the key is a % 2 (an int) rather than a % 2 == 0 (a
      # bool), so the declared Tuple[bool, int] output is violated.
      return (a % 2, a)

    nums = self.p | 'Nums' >> beam.Create(range(5)).with_output_types(int)
    keyed = nums | 'IsEven' >> beam.Map(is_even_as_key)
    _ = keyed | 'Parity' >> beam.GroupByKey()

    # The declared types line up at construction time; only runtime checking
    # sees that 'is_even_as_key' really returns Tuple[int, int].
    with self.assertRaises(typehints.TypeCheckError) as ctx:
      self.p.run()

    self.assertStartswith(
        ctx.exception.args[0],
        "Runtime type violation detected within ParDo(IsEven): "
        "Tuple[<class 'bool'>, <class 'int'>] hint type-constraint violated. "
        "The type of element #0 in the passed tuple is incorrect. "
        "Expected an instance of type <class 'bool'>, "
        "instead received an instance of type int.")
  1712  
  def test_pipeline_checking_satisfied_run_time_checking_satisfied(self):
    type_options = self.p._options.view_as(TypeOptions)
    type_options.pipeline_type_check = False
    # NOTE(review): runtime_type_check is not switched on in this method;
    # confirm whether setUp or class parameterization enables it.

    @with_output_types(typing.Tuple[bool, int])
    @with_input_types(a=int)
    def is_even_as_key(a):
      # The key is a genuine bool, matching the Tuple[bool, int] hint.
      return (a % 2 == 0, a)

    nums = self.p | 'Nums' >> beam.Create(range(5)).with_output_types(int)
    result = (
        nums
        | 'IsEven' >> beam.Map(is_even_as_key)
        | 'Parity' >> beam.GroupByKey()
        | SortLists)

    assert_that(result, equal_to([(False, [1, 3]), (True, [0, 2, 4])]))
    self.p.run()
  1732  
  def test_pipeline_runtime_checking_violation_simple_type_input(self):
    type_options = self.p._options.view_as(TypeOptions)
    type_options.runtime_type_check = True
    type_options.pipeline_type_check = False

    # 'ToInt' declares a str input via with_input_types(), but Create emits
    # ints; with static checking off, only the runtime check can catch this.
    with self.assertRaises(typehints.TypeCheckError) as ctx:
      ones = self.p | beam.Create([1, 1, 1])
      _ = ones | 'ToInt' >> beam.FlatMap(
          lambda x: [int(x)]).with_input_types(str).with_output_types(int)
      self.p.run()

    expected_prefix = (
        "Runtime type violation detected within ParDo(ToInt): "
        "Type-hint for argument: 'x' violated. "
        "Expected an instance of {}, "
        "instead found 1, an instance of {}.".format(str, int))
    self.assertStartswith(ctx.exception.args[0], expected_prefix)
  1755  
  def test_pipeline_runtime_checking_violation_composite_type_input(self):
    type_options = self.p._options.view_as(TypeOptions)
    type_options.runtime_type_check = True
    type_options.pipeline_type_check = False

    # 'Add' claims to take Tuple[int, int], but the second member of every
    # created pair is a float.
    with self.assertRaises(typehints.TypeCheckError) as ctx:
      pairs = self.p | beam.Create([(1, 3.0), (2, 4.9), (3, 9.5)])
      _ = pairs | 'Add' >> beam.FlatMap(
          lambda x_y: [x_y[0] + x_y[1]]).with_input_types(
              typing.Tuple[int, int]).with_output_types(int)
      self.p.run()

    self.assertStartswith(
        ctx.exception.args[0],
        "Runtime type violation detected within ParDo(Add): "
        "Type-hint for argument: 'x_y' violated: "
        "Tuple[<class 'int'>, <class 'int'>] hint type-constraint violated. "
        "The type of element #1 in the passed tuple is incorrect. "
        "Expected an instance of type <class 'int'>, instead received an "
        "instance of type float.")
  1778  
  def test_pipeline_runtime_checking_violation_simple_type_output(self):
    self.p._options.view_as(TypeOptions).runtime_type_check = True
    self.p._options.view_as(TypeOptions).pipeline_type_check = False

    # The type hint applied via with_output_types() says the ParDo emits int,
    # but the lambda emits a float, which only runtime checking can detect.
    with self.assertRaises(typehints.TypeCheckError) as e:
      (
          self.p
          | beam.Create([1, 1, 1])
          | (
              'ToInt' >> beam.FlatMap(lambda x: [float(x)]).with_input_types(
                  int).with_output_types(int)))
      self.p.run()

    # The expected message wording depends on which runtime checker is active.
    if self.p._options.view_as(TypeOptions).runtime_type_check:
      self.assertStartswith(
          e.exception.args[0],
          "Runtime type violation detected within "
          "ParDo(ToInt): "
          "According to type-hint expected output should be "
          "of type {}. Instead, received '1.0', "
          "an instance of type {}.".format(int, float))

    if self.p._options.view_as(TypeOptions).performance_runtime_type_check:
      self.assertStartswith(
          e.exception.args[0],
          "Runtime type violation detected within ToInt: "
          "Type-hint for argument: 'x' violated. "
          "Expected an instance of {}, "
          "instead found 1.0, an instance of {}".format(int, float))
  1816  
  def test_pipeline_runtime_checking_violation_composite_type_output(self):
    self.p._options.view_as(TypeOptions).runtime_type_check = True
    self.p._options.view_as(TypeOptions).pipeline_type_check = False

    # The type hint applied via with_output_types() says the ParDo emits
    # Tuple[float, int]. The lambda instead emits the sum of each pair's
    # members (an int plus a float, i.e. a float), which is not a tuple.
    with self.assertRaises(typehints.TypeCheckError) as e:
      (
          self.p
          | beam.Create([(1, 3.0), (2, 4.9), (3, 9.5)])
          | (
              'Swap' >>
              beam.FlatMap(lambda x_y1: [x_y1[0] + x_y1[1]]).with_input_types(
                  typing.Tuple[int, float]).with_output_types(
                      typing.Tuple[float, int])))
      self.p.run()

    # The expected message wording depends on which runtime checker is active.
    if self.p._options.view_as(TypeOptions).runtime_type_check:
      self.assertStartswith(
          e.exception.args[0],
          "Runtime type violation detected within "
          "ParDo(Swap): Tuple type constraint violated. "
          "Valid object instance must be of type 'tuple'. Instead, "
          "an instance of 'float' was received.")

    if self.p._options.view_as(TypeOptions).performance_runtime_type_check:
      self.assertStartswith(
          e.exception.args[0],
          "Runtime type violation detected within "
          "Swap: Type-hint for argument: 'x_y1' violated: "
          "Tuple type constraint violated. "
          "Valid object instance must be of type 'tuple'. ")
  1850  
  def test_pipeline_runtime_checking_violation_with_side_inputs_decorator(self):
    type_options = self.p._options.view_as(TypeOptions)
    type_options.pipeline_type_check = False
    type_options.runtime_type_check = True

    @with_output_types(int)
    @with_input_types(a=int, b=int)
    def add(a, b):
      return a + b

    # The extra Map argument 1.0 is bound to 'b', which is hinted as int.
    with self.assertRaises(typehints.TypeCheckError) as ctx:
      nums = self.p | beam.Create([1, 2, 3, 4])
      _ = nums | 'Add 1' >> beam.Map(add, 1.0)
      self.p.run()

    expected_prefix = (
        "Runtime type violation detected within ParDo(Add 1): "
        "Type-hint for argument: 'b' violated. "
        "Expected an instance of {}, "
        "instead found 1.0, an instance of {}.".format(int, float))
    self.assertStartswith(ctx.exception.args[0], expected_prefix)
  1870  
  def test_pipeline_runtime_checking_violation_with_side_inputs_via_method(self):  # pylint: disable=line-too-long
    type_options = self.p._options.view_as(TypeOptions)
    type_options.runtime_type_check = True
    type_options.pipeline_type_check = False

    # with_input_types(int, int) hints both lambda arguments as int, so the
    # extra float argument bound to 'one' violates its hint at run time.
    with self.assertRaises(typehints.TypeCheckError) as ctx:
      nums = self.p | beam.Create([1, 2, 3, 4])
      _ = nums | 'Add 1' >> beam.Map(
          lambda x, one: x + one, 1.0).with_input_types(
              int, int).with_output_types(float)
      self.p.run()

    expected_prefix = (
        "Runtime type violation detected within ParDo(Add 1): "
        "Type-hint for argument: 'one' violated. "
        "Expected an instance of {}, "
        "instead found 1.0, an instance of {}.".format(int, float))
    self.assertStartswith(ctx.exception.args[0], expected_prefix)
  1890  
  def test_combine_properly_pipeline_type_checks_using_decorator(self):
    @with_output_types(int)
    @with_input_types(ints=typing.Iterable[int])
    def sum_ints(ints):
      return sum(ints)

    nums = self.p | 'T' >> beam.Create([1, 2, 3]).with_output_types(int)
    total = nums | 'Sum' >> beam.CombineGlobally(sum_ints)

    # The decorated int output hint should become the combined element type.
    self.assertEqual(int, total.element_type)
    assert_that(total, equal_to([6]))
    self.p.run()
  1905  
  def test_combine_properly_pipeline_type_checks_without_decorator(self):
    def sum_ints(ints):
      return sum(ints)

    # sum_ints carries no hints; the assertion below expects the str key type
    # to survive CombinePerKey while the value type falls back to Any.
    nums = self.p | beam.Create([1, 2, 3])
    keyed = nums | beam.Map(lambda x: ('key', x))
    totals = keyed | beam.CombinePerKey(sum_ints)

    self.assertEqual(typehints.Tuple[str, typehints.Any], totals.element_type)
    self.p.run()
  1918  
  def test_combine_func_type_hint_does_not_take_iterable_using_decorator(self):
    # A combine function must accept an Iterable; this one is hinted to take a
    # single int, so CombineGlobally must reject it at construction time.
    @with_output_types(int)
    @with_input_types(a=int)
    def bad_combine(a):
      # Not executed by this test: the pipeline is never run.
      return 5 + a

    with self.assertRaises(typehints.TypeCheckError) as e:
      (
          self.p
          | 'M' >> beam.Create([1, 2, 3]).with_output_types(int)
          | 'Add' >> beam.CombineGlobally(bad_combine))

    self.assertEqual(
        "All functions for a Combine PTransform must accept a "
        "single argument compatible with: Iterable[Any]. "
        "Instead a function with input type: {} was received.".format(int),
        e.exception.args[0])
  1936  
  def test_combine_pipeline_type_propagation_using_decorators(self):
    @with_output_types(int)
    @with_input_types(ints=typing.Iterable[int])
    def sum_ints(ints):
      return sum(ints)

    @with_output_types(typing.List[int])
    @with_input_types(n=int)
    def range_from_zero(n):
      return list(range(n + 1))

    # Sum's int output must feed 'Range', whose List[int] result the ParDo
    # flattens, so the final element type is int.
    nums = self.p | 'T' >> beam.Create([1, 2, 3]).with_output_types(int)
    total = nums | 'Sum' >> beam.CombineGlobally(sum_ints)
    counted = total | 'Range' >> beam.ParDo(range_from_zero)

    self.assertEqual(int, counted.element_type)
    assert_that(counted, equal_to([0, 1, 2, 3, 4, 5, 6]))
    self.p.run()
  1957  
  def test_combine_runtime_type_check_satisfied_using_decorators(self):
    type_options = self.p._options.view_as(TypeOptions)
    type_options.pipeline_type_check = False
    # NOTE(review): this method does not itself set runtime_type_check = True;
    # confirm whether setUp/parameterization enables it or it is an omission.

    @with_output_types(int)
    @with_input_types(ints=typing.Iterable[int])
    def iter_mul(ints):
      return reduce(operator.mul, ints, 1)

    fives = self.p | 'K' >> beam.Create([5, 5, 5, 5]).with_output_types(int)
    product = fives | 'Mul' >> beam.CombineGlobally(iter_mul)

    assert_that(product, equal_to([625]))
    self.p.run()
  1973  
  def test_combine_runtime_type_check_violation_using_decorators(self):
    type_options = self.p._options.view_as(TypeOptions)
    type_options.pipeline_type_check = False
    type_options.runtime_type_check = True

    # The combine fn is hinted to return int but actually returns a str.
    @with_output_types(int)
    @with_input_types(ints=typing.Iterable[int])
    def iter_mul(ints):
      return str(reduce(operator.mul, ints, 1))

    with self.assertRaises(typehints.TypeCheckError) as ctx:
      fives = (
          self.p | 'K' >> beam.Create([5, 5, 5, 5]).with_output_types(int))
      _ = fives | 'Mul' >> beam.CombineGlobally(iter_mul)
      self.p.run()

    expected_prefix = (
        "Runtime type violation detected within "
        "Mul/CombinePerKey: "
        "Type-hint for return type violated. "
        "Expected an instance of {}, instead found".format(int))
    self.assertStartswith(ctx.exception.args[0], expected_prefix)
  1997  
  def test_combine_pipeline_type_check_using_methods(self):
    letters = (
        self.p | beam.Create(['t', 'e', 's', 't']).with_output_types(str))
    joined = letters | 'concat' >> beam.CombineGlobally(
        lambda s: ''.join(s)).with_input_types(str).with_output_types(str)

    def matcher(expected):
      # The single combined output is a string; its characters are compared
      # against `expected` with equal_to, which ignores ordering.
      def match(actual):
        equal_to(expected)(list(actual[0]))

      return match

    assert_that(joined, matcher('estt'))
    self.p.run()
  2014  
  def test_combine_runtime_type_check_using_methods(self):
    type_options = self.p._options.view_as(TypeOptions)
    type_options.pipeline_type_check = False
    type_options.runtime_type_check = True

    # int in, int out for the summing lambda; runtime checks must pass.
    nums = self.p | beam.Create(range(5)).with_output_types(int)
    total = nums | 'Sum' >> beam.CombineGlobally(
        lambda s: sum(s)).with_input_types(int).with_output_types(int)

    assert_that(total, equal_to([10]))
    self.p.run()
  2028  
  def test_combine_pipeline_type_check_violation_using_methods(self):
    # 'SortJoin' is declared to take str elements, but Create is hinted int.
    with self.assertRaises(typehints.TypeCheckError) as ctx:
      nums = self.p | beam.Create(range(3)).with_output_types(int)
      sort_join = beam.CombineGlobally(lambda s: ''.join(sorted(s)))
      _ = nums | 'SortJoin' >> sort_join.with_input_types(
          str).with_output_types(str)

    expected_prefix = (
        "Input type hint violation at SortJoin: "
        "expected {}, got {}".format(str, int))
    self.assertStartswith(ctx.exception.args[0], expected_prefix)
  2042  
  def test_combine_runtime_type_check_violation_using_methods(self):
    type_options = self.p._options.view_as(TypeOptions)
    type_options.pipeline_type_check = False
    type_options.runtime_type_check = True

    # With static checking off, the int element only trips the str hint once
    # it reaches the combine's internal KeyWithVoid step at run time.
    with self.assertRaises(typehints.TypeCheckError) as ctx:
      zeros = self.p | beam.Create([0]).with_output_types(int)
      sort_join = beam.CombineGlobally(lambda s: ''.join(sorted(s)))
      _ = zeros | 'SortJoin' >> sort_join.with_input_types(
          str).with_output_types(str)
      self.p.run()

    expected_prefix = (
        "Runtime type violation detected within "
        "ParDo(SortJoin/KeyWithVoid): "
        "Type-hint for argument: 'v' violated. "
        "Expected an instance of {}, "
        "instead found 0, an instance of {}.".format(str, int))
    self.assertStartswith(ctx.exception.args[0], expected_prefix)
  2063  
  def test_combine_insufficient_type_hint_information(self):
    type_options = self.p._options.view_as(TypeOptions)
    type_options.type_check_strictness = 'ALL_REQUIRED'

    # The unhinted combine lambda leaves a step under SortJoin/CombinePerKey
    # without an output hint, which ALL_REQUIRED strictness rejects.
    with self.assertRaises(typehints.TypeCheckError) as ctx:
      nums = self.p | 'E' >> beam.Create(range(3)).with_output_types(int)
      joined = nums | 'SortJoin' >> beam.CombineGlobally(
          lambda s: ''.join(sorted(s)))
      _ = joined | 'F' >> beam.Map(lambda x: x + 1)

    expected_prefix = (
        'Pipeline type checking is enabled, '
        'however no output type-hint was found for the PTransform '
        'ParDo('
        'SortJoin/CombinePerKey/')
    self.assertStartswith(ctx.exception.args[0], expected_prefix)
  2080  
  def test_mean_globally_pipeline_checking_satisfied(self):
    nums = self.p | 'C' >> beam.Create(range(5)).with_output_types(int)
    mean = nums | 'Mean' >> combine.Mean.Globally()

    # The mean of int inputs is expected to be typed, and computed, as float.
    self.assertEqual(float, mean.element_type)
    assert_that(mean, equal_to([2.0]))
    self.p.run()
  2090  
  def test_mean_globally_pipeline_checking_violated(self):
    """Construction-time checking rejects Mean.Globally over strings."""
    with self.assertRaises(typehints.TypeCheckError) as ctx:
      strings = self.p | 'C' >> beam.Create(['test']).with_output_types(str)
      _ = strings | 'Mean' >> combine.Mean.Globally()

    expected_msg = (
        "Type hint violation for 'CombinePerKey': "
        "requires Tuple[TypeVariable[K], Union[<class 'float'>, <class 'int'>, "
        "<class 'numpy.float64'>, <class 'numpy.int64'>]] "
        "but got Tuple[None, <class 'str'>] for element")
    self.assertStartswith(ctx.exception.args[0], expected_msg)
  2105  
  def test_mean_globally_runtime_checking_satisfied(self):
    """Mean.Globally over ints also succeeds with runtime type checking on."""
    self.p._options.view_as(TypeOptions).runtime_type_check = True

    ints = self.p | 'C' >> beam.Create(range(5)).with_output_types(int)
    means = ints | 'Mean' >> combine.Mean.Globally()

    self.assertEqual(float, means.element_type)
    assert_that(means, equal_to([2.0]))
    self.p.run()
  2117  
  def test_mean_globally_runtime_checking_violated(self):
    """Runtime checking rejects string elements fed into Mean.Globally.

    Construction-time checking is switched off, so the violation can only
    surface while the pipeline runs.
    """
    self.p._options.view_as(TypeOptions).pipeline_type_check = False
    self.p._options.view_as(TypeOptions).runtime_type_check = True

    with self.assertRaises(typehints.TypeCheckError) as e:
      (
          self.p
          | 'C' >> beam.Create(['t', 'e', 's', 't']).with_output_types(str)
          | 'Mean' >> combine.Mean.Globally())
      self.p.run()

    # The message check must sit outside the ``with`` block. ``self.p.run()``
    # raises, so any statement after it inside the block never executes.
    # Only the prefix shared by runtime type-check failures is asserted. The
    # exact step label and constraint text depend on how Mean.Globally
    # expands.
    self.assertStartswith(
        e.exception.args[0], "Runtime type violation detected within ")
  2139  
  def test_mean_per_key_pipeline_checking_satisfied(self):
    """Mean.PerKey on (bool, int) pairs type-checks and averages per key."""
    numbers = self.p | beam.Create(range(5)).with_output_types(int)
    by_parity = numbers | 'EvenGroup' >> beam.Map(
        lambda x: (not x % 2, x)).with_output_types(typing.Tuple[bool, int])
    means = by_parity | 'EvenMean' >> combine.Mean.PerKey()

    self.assertCompatible(typing.Tuple[bool, float], means.element_type)
    assert_that(means, equal_to([(False, 2.0), (True, 2.0)]))
    self.p.run()
  2152  
  def test_mean_per_key_pipeline_checking_violated(self):
    """Construction-time checking rejects Mean.PerKey over (str, str) pairs."""
    with self.assertRaises(typehints.TypeCheckError) as ctx:
      strings = self.p | beam.Create(map(str, range(5))).with_output_types(str)
      pairs = strings | 'UpperPair' >> beam.Map(
          lambda x: (x.upper(), x)).with_output_types(typing.Tuple[str, str])
      _ = pairs | 'EvenMean' >> combine.Mean.PerKey()
      # Expected to be unreachable: the pipeline type check should already
      # have raised while 'EvenMean' was applied.
      self.p.run()

    expected_msg = (
        "Type hint violation for 'CombinePerKey(MeanCombineFn)': "
        "requires Tuple[TypeVariable[K], Union[<class 'float'>, <class 'int'>, "
        "<class 'numpy.float64'>, <class 'numpy.int64'>]] "
        "but got Tuple[<class 'str'>, <class 'str'>] for element")
    self.assertStartswith(ctx.exception.args[0], expected_msg)
  2172  
  def test_mean_per_key_runtime_checking_satisfied(self):
    """Mean.PerKey on (bool, int) pairs also succeeds under runtime checking."""
    self.p._options.view_as(TypeOptions).runtime_type_check = True

    numbers = self.p | beam.Create(range(5)).with_output_types(int)
    by_oddness = numbers | 'OddGroup' >> beam.Map(
        lambda x: (bool(x % 2), x)).with_output_types(typing.Tuple[bool, int])
    means = by_oddness | 'OddMean' >> combine.Mean.PerKey()

    self.assertCompatible(typing.Tuple[bool, float], means.element_type)
    assert_that(means, equal_to([(False, 2.0), (True, 2.0)]))
    self.p.run()
  2188  
  def test_mean_per_key_runtime_checking_violated(self):
    """Runtime checking rejects str values fed into Mean.PerKey."""
    type_options = self.p._options.view_as(TypeOptions)
    type_options.pipeline_type_check = False
    type_options.runtime_type_check = True

    with self.assertRaises(typehints.TypeCheckError) as ctx:
      numbers = self.p | beam.Create(range(5)).with_output_types(int)
      pairs = numbers | 'OddGroup' >> beam.Map(
          lambda x: (x, str(bool(x % 2)))).with_output_types(
              typing.Tuple[int, str])
      _ = pairs | 'OddMean' >> combine.Mean.PerKey()
      self.p.run()

    expected_msg = (
        "Runtime type violation detected within "
        "OddMean/CombinePerKey(MeanCombineFn): "
        "Type-hint for argument: 'element' violated: "
        "Union[<class 'float'>, <class 'int'>, <class 'numpy.float64'>, <class "
        "'numpy.int64'>] type-constraint violated. "
        "Expected an instance of one of: (\"<class 'float'>\", \"<class "
        "'int'>\", \"<class 'numpy.float64'>\", \"<class 'numpy.int64'>\"), "
        "received str instead")
    self.assertStartswith(ctx.exception.args[0], expected_msg)
  2215  
  def test_count_globally_pipeline_type_checking_satisfied(self):
    """Count.Globally over ints is typed int and counts every element."""
    ints = self.p | 'P' >> beam.Create(range(5)).with_output_types(int)
    total = ints | 'CountInt' >> combine.Count.Globally()

    self.assertEqual(int, total.element_type)
    assert_that(total, equal_to([5]))
    self.p.run()
  2225  
  def test_count_globally_runtime_type_checking_satisfied(self):
    """Count.Globally over ints also succeeds under runtime type checking."""
    self.p._options.view_as(TypeOptions).runtime_type_check = True

    ints = self.p | 'P' >> beam.Create(range(5)).with_output_types(int)
    total = ints | 'CountInt' >> combine.Count.Globally()

    self.assertEqual(int, total.element_type)
    assert_that(total, equal_to([5]))
    self.p.run()
  2237  
  def test_count_perkey_pipeline_type_checking_satisfied(self):
    """Count.PerKey on (bool, int) pairs is typed (bool, int) and counts keys."""
    numbers = self.p | beam.Create(range(5)).with_output_types(int)
    by_parity = numbers | 'EvenGroup' >> beam.Map(
        lambda x: (not x % 2, x)).with_output_types(typing.Tuple[bool, int])
    counts = by_parity | 'CountInt' >> combine.Count.PerKey()

    self.assertCompatible(typing.Tuple[bool, int], counts.element_type)
    assert_that(counts, equal_to([(False, 2), (True, 3)]))
    self.p.run()
  2249  
  def test_count_perkey_pipeline_type_checking_violated(self):
    """Count.PerKey rejects plain (non key-value) int input at construction."""
    with self.assertRaises(typehints.TypeCheckError) as ctx:
      ints = self.p | beam.Create(range(5)).with_output_types(int)
      _ = ints | 'CountInt' >> combine.Count.PerKey()

    self.assertStartswith(
        ctx.exception.args[0], 'Input type hint violation at CountInt')
  2259  
  def test_count_perkey_runtime_type_checking_satisfied(self):
    """Count.PerKey on (str, str) pairs succeeds under runtime type checking."""
    self.p._options.view_as(TypeOptions).runtime_type_check = True

    letters = self.p | beam.Create(['t', 'e', 's', 't']).with_output_types(str)
    self_keyed = letters | 'DupKey' >> beam.Map(
        lambda x: (x, x)).with_output_types(typing.Tuple[str, str])
    counts = self_keyed | 'CountDups' >> combine.Count.PerKey()

    self.assertCompatible(typing.Tuple[str, int], counts.element_type)
    assert_that(counts, equal_to([('e', 1), ('s', 1), ('t', 2)]))
    self.p.run()
  2273  
  def test_count_perelement_pipeline_type_checking_satisfied(self):
    """Count.PerElement over ints yields (element, count) typed (int, int)."""
    ints = self.p | beam.Create([1, 1, 2, 3]).with_output_types(int)
    counts = ints | 'CountElems' >> combine.Count.PerElement()

    self.assertCompatible(typing.Tuple[int, int], counts.element_type)
    assert_that(counts, equal_to([(1, 2), (2, 1), (3, 1)]))
    self.p.run()
  2283  
  def test_count_perelement_pipeline_type_checking_violated(self):
    """ALL_REQUIRED strictness rejects a Create that has no output type hint."""
    self.p._options.view_as(TypeOptions).type_check_strictness = 'ALL_REQUIRED'

    with self.assertRaises(typehints.TypeCheckError) as ctx:
      unhinted = self.p | 'f' >> beam.Create([1, 1, 2, 3])
      _ = unhinted | 'CountElems' >> combine.Count.PerElement()

    expected_msg = (
        'Pipeline type checking is enabled, however no output '
        'type-hint was found for the PTransform '
        'Create(f)')
    self.assertEqual(expected_msg, ctx.exception.args[0])
  2298  
  def test_count_perelement_runtime_type_checking_satisfied(self):
    """Count.PerElement over bools succeeds under runtime type checking."""
    self.p._options.view_as(TypeOptions).runtime_type_check = True

    flags = [True, True, False, True, True]
    bools = self.p | beam.Create(flags).with_output_types(bool)
    counts = bools | 'CountElems' >> combine.Count.PerElement()

    self.assertCompatible(typing.Tuple[bool, int], counts.element_type)
    assert_that(counts, equal_to([(False, 1), (True, 4)]))
    self.p.run()
  2310  
  def test_top_of_pipeline_checking_satisfied(self):
    """Top.Of(3) over ints is typed Iterable[int] and keeps the 3 largest."""
    ints = self.p | beam.Create(range(5, 11)).with_output_types(int)
    top_three = ints | 'Top 3' >> combine.Top.Of(3)

    self.assertCompatible(typing.Iterable[int], top_three.element_type)
    assert_that(top_three, equal_to([[10, 9, 8]]))
    self.p.run()
  2320  
  def test_top_of_runtime_checking_satisfied(self):
    """Top.Of(3) over single-char strings succeeds under runtime checking."""
    self.p._options.view_as(TypeOptions).runtime_type_check = True

    chars = self.p | beam.Create(list('testing')).with_output_types(str)
    top_three = chars | 'AciiTop' >> combine.Top.Of(3)

    self.assertCompatible(typing.Iterable[str], top_three.element_type)
    assert_that(top_three, equal_to([['t', 't', 's']]))
    self.p.run()
  2332  
  def test_per_key_pipeline_checking_violated(self):
    """Top.PerKey rejects plain int input at construction time."""
    with self.assertRaises(typehints.TypeCheckError) as ctx:
      ints = self.p | beam.Create(range(100)).with_output_types(int)
      incremented = ints | 'Num + 1' >> beam.Map(
          lambda x: x + 1).with_output_types(int)
      _ = incremented | 'TopMod' >> combine.Top.PerKey(1)

    expected_prefix = (
        "Input type hint violation at TopMod: expected Tuple[TypeVariable[K], "
        "TypeVariable[V]], got {}".format(int))
    self.assertStartswith(ctx.exception.args[0], expected_prefix)
  2345  
  def test_per_key_pipeline_checking_satisfied(self):
    """Top.PerKey(1) on (int, int) pairs keeps the largest value per key."""
    ints = self.p | beam.Create(range(100)).with_output_types(int)
    by_residue = ints | 'GroupMod 3' >> beam.Map(
        lambda x: (x % 3, x)).with_output_types(typing.Tuple[int, int])
    top_per_key = by_residue | 'TopMod' >> combine.Top.PerKey(1)

    self.assertCompatible(
        typing.Tuple[int, typing.Iterable[int]], top_per_key.element_type)
    assert_that(top_per_key, equal_to([(0, [99]), (1, [97]), (2, [98])]))
    self.p.run()
  2359  
  def test_per_key_runtime_checking_satisfied(self):
    """Top.PerKey(1) keeps the per-key maximum under runtime type checking."""
    self.p._options.view_as(TypeOptions).runtime_type_check = True

    numbers = self.p | beam.Create(range(21))
    by_residue = numbers | 'GroupMod 3' >> beam.Map(
        lambda x: (x % 3, x)).with_output_types(typing.Tuple[int, int])
    top_per_key = by_residue | 'TopMod' >> combine.Top.PerKey(1)

    self.assertCompatible(
        typing.Tuple[int, typing.Iterable[int]], top_per_key.element_type)
    assert_that(top_per_key, equal_to([(0, [18]), (1, [19]), (2, [20])]))
    self.p.run()
  2375  
  def test_sample_globally_pipeline_satisfied(self):
    """FixedSizeGlobally(3) is typed Iterable[int] and emits a 3-item sample."""
    ints = self.p | beam.Create([2, 2, 3, 3]).with_output_types(int)
    sampled = ints | 'Sample' >> combine.Sample.FixedSizeGlobally(3)

    self.assertCompatible(typing.Iterable[int], sampled.element_type)

    def has_sample_of_size(size):
      # The sampled values are random, so only the sample length is checked.
      def check(actual):
        equal_to([size])([len(actual[0])])

      return check

    assert_that(sampled, has_sample_of_size(3))
    self.p.run()
  2392  
  def test_sample_globally_runtime_satisfied(self):
    """FixedSizeGlobally(2) emits a 2-item sample under runtime type checking."""
    self.p._options.view_as(TypeOptions).runtime_type_check = True

    ints = self.p | beam.Create([2, 2, 3, 3]).with_output_types(int)
    sampled = ints | 'Sample' >> combine.Sample.FixedSizeGlobally(2)

    self.assertCompatible(typing.Iterable[int], sampled.element_type)

    def has_sample_of_size(size):
      # The sampled values are random, so only the sample length is checked.
      def check(actual):
        equal_to([size])([len(actual[0])])

      return check

    assert_that(sampled, has_sample_of_size(2))
    self.p.run()
  2411  
  def test_sample_per_key_pipeline_satisfied(self):
    """FixedSizePerKey(2) yields (key, sample) pairs with 2-item samples."""
    pairs = [(1, 2), (1, 2), (2, 3), (2, 3)]
    keyed = self.p | beam.Create(pairs).with_output_types(
        typing.Tuple[int, int])
    sampled = keyed | 'Sample' >> combine.Sample.FixedSizePerKey(2)

    self.assertCompatible(
        typing.Tuple[int, typing.Iterable[int]], sampled.element_type)

    def every_sample_has_size(size):
      # Only each sample's length is checked; its members are random.
      def check(actual):
        for _, sample in actual:
          equal_to([size])([len(sample)])

      return check

    assert_that(sampled, every_sample_has_size(2))
    self.p.run()
  2432  
  def test_sample_per_key_runtime_satisfied(self):
    """FixedSizePerKey(1) yields 1-item samples under runtime type checking."""
    self.p._options.view_as(TypeOptions).runtime_type_check = True

    pairs = [(1, 2), (1, 2), (2, 3), (2, 3)]
    keyed = self.p | beam.Create(pairs).with_output_types(
        typing.Tuple[int, int])
    sampled = keyed | 'Sample' >> combine.Sample.FixedSizePerKey(1)

    self.assertCompatible(
        typing.Tuple[int, typing.Iterable[int]], sampled.element_type)

    def every_sample_has_size(size):
      # Only each sample's length is checked; its members are random.
      def check(actual):
        for _, sample in actual:
          equal_to([size])([len(sample)])

      return check

    assert_that(sampled, every_sample_has_size(1))
    self.p.run()
  2455  
  def test_to_list_pipeline_check_satisfied(self):
    """ToList over ints is typed List[int] and gathers every element."""
    ints = self.p | beam.Create((1, 2, 3, 4)).with_output_types(int)
    gathered = ints | combine.ToList()

    self.assertCompatible(typing.List[int], gathered.element_type)

    def list_contains(expected):
      # The single emitted list is compared with equal_to, which ignores order.
      def check(actual):
        equal_to(expected)(actual[0])

      return check

    assert_that(gathered, list_contains([1, 2, 3, 4]))
    self.p.run()
  2472  
  def test_to_list_runtime_check_satisfied(self):
    """ToList over strings succeeds under runtime type checking."""
    self.p._options.view_as(TypeOptions).runtime_type_check = True

    chars = self.p | beam.Create(list('test')).with_output_types(str)
    gathered = chars | combine.ToList()

    self.assertCompatible(typing.List[str], gathered.element_type)

    def list_contains(expected):
      # The single emitted list is compared with equal_to, which ignores order.
      def check(actual):
        equal_to(expected)(actual[0])

      return check

    assert_that(gathered, list_contains(['e', 's', 't', 't']))
    self.p.run()
  2491  
  def test_to_dict_pipeline_check_violated(self):
    """ToDict rejects plain (non key-value) int input at construction time."""
    with self.assertRaises(typehints.TypeCheckError) as ctx:
      ints = self.p | beam.Create([1, 2, 3, 4]).with_output_types(int)
      _ = ints | combine.ToDict()

    expected_prefix = (
        "Input type hint violation at ToDict: expected Tuple[TypeVariable[K], "
        "TypeVariable[V]], got {}".format(int))
    self.assertStartswith(ctx.exception.args[0], expected_prefix)
  2503  
  def test_to_dict_pipeline_check_satisfied(self):
    """ToDict over (int, int) pairs is typed Dict[int, int]."""
    pairs = [(1, 2), (3, 4)]
    keyed = self.p | beam.Create(pairs).with_output_types(
        typing.Tuple[int, int])
    as_dict = keyed | combine.ToDict()

    self.assertCompatible(typing.Dict[int, int], as_dict.element_type)
    assert_that(as_dict, equal_to([{1: 2, 3: 4}]))
    self.p.run()
  2514  
  def test_to_dict_runtime_check_satisfied(self):
    """ToDict over (str, int) pairs succeeds under runtime type checking."""
    self.p._options.view_as(TypeOptions).runtime_type_check = True

    pairs = [('1', 2), ('3', 4)]
    keyed = self.p | beam.Create(pairs).with_output_types(
        typing.Tuple[str, int])
    as_dict = keyed | combine.ToDict()

    self.assertCompatible(typing.Dict[str, int], as_dict.element_type)
    assert_that(as_dict, equal_to([{'1': 2, '3': 4}]))
    self.p.run()
  2528  
  def test_runtime_type_check_python_type_error(self):
    """A plain Python TypeError from user code is not reported as a type hint.

    With runtime type checking on, calling ``len`` on an int inside the Map
    must surface as the ordinary ``TypeError`` that Python raises, not as
    Beam's ``TypeCheckError``.
    """
    self.p._options.view_as(TypeOptions).runtime_type_check = True

    with self.assertRaises(TypeError) as e:
      (
          self.p
          | beam.Create([1, 2, 3]).with_output_types(int)
          | 'Len' >> beam.Map(lambda x: len(x)).with_output_types(int))
      self.p.run()

    # Our special type-checking related TypeError shouldn't have been raised.
    # Instead the above pipeline should have triggered a regular Python runtime
    # TypeError.
    self.assertEqual(
        "object of type 'int' has no len() [while running 'Len']",
        e.exception.args[0])
    # Check the captured exception, not ``e``. ``e`` is the assertRaises
    # context manager, so ``isinstance(e, ...)`` was always False and
    # verified nothing.
    self.assertNotIsInstance(e.exception, typehints.TypeCheckError)
  2546  
  def test_pardo_type_inference(self):
    """Filter keeps its input type; Map infers a tuple type from its lambda."""
    keep_nothing = beam.Filter(lambda x: False)
    self.assertEqual(int, keep_nothing.infer_output_type(int))

    pair_with_one = beam.Map(lambda x: (x, 1))
    self.assertEqual(
        typehints.Tuple[str, int], pair_with_one.infer_output_type(str))
  2552  
  def test_gbk_type_inference(self):
    """GroupByKey infers Tuple[K, Iterable[V]] from a KV[K, V] input."""
    inferred = beam.GroupByKey().infer_output_type(typehints.KV[str, int])
    self.assertEqual(typehints.Tuple[str, typehints.Iterable[int]], inferred)
  2557  
  def test_pipeline_inference(self):
    """Element types are inferred through Create, a keying Map and GroupByKey."""
    letters = self.p | beam.Create(['a', 'b', 'c'])
    keyed = letters | 'pair with 1' >> beam.Map(lambda x: (x, 1))
    grouped = keyed | beam.GroupByKey()

    expectations = [
        (str, letters),
        (typehints.KV[str, int], keyed),
        (typehints.KV[str, typehints.Iterable[int]], grouped),
    ]
    for expected_type, pcoll in expectations:
      self.assertEqual(expected_type, pcoll.element_type)
  2566  
  def test_inferred_bad_kv_type(self):
    """GroupByKey rejects an inferred 3-tuple that is not a key-value pair."""
    with self.assertRaises(typehints.TypeCheckError) as ctx:
      letters = self.p | beam.Create(['a', 'b', 'c'])
      triples = letters | 'Ungroupable' >> beam.Map(lambda x: (x, 0, 1.0))
      _ = triples | beam.GroupByKey()

    expected_prefix = (
        "Input type hint violation at GroupByKey: "
        "expected Tuple[TypeVariable[K], TypeVariable[V]], "
        "got Tuple[<class 'str'>, <class 'int'>, <class 'float'>]")
    self.assertStartswith(ctx.exception.args[0], expected_prefix)
  2580  
  def test_type_inference_command_line_flag_toggle(self):
    """Element types are inferred only while pipeline_type_check is enabled."""
    type_options = self.p._options.view_as(TypeOptions)

    type_options.pipeline_type_check = False
    unchecked = self.p | 'C1' >> beam.Create([1, 2, 3, 4])
    self.assertIsNone(unchecked.element_type)

    type_options.pipeline_type_check = True
    checked = self.p | 'C2' >> beam.Create([1, 2, 3, 4])
    self.assertEqual(int, checked.element_type)
  2589  
  def test_eager_execution(self):
    """Piping a plain list into a transform runs it eagerly and returns a list."""
    double = beam.Map(lambda x: 2 * x)
    self.assertEqual([2, 4, 6, 8], [1, 2, 3, 4] | double)
  2593  
  def test_eager_execution_tagged_outputs(self):
    """Eager execution exposes declared tagged outputs as attributes."""
    tag_doubled = beam.Map(lambda x: pvalue.TaggedOutput('bar', 2 * x))
    result = [1, 2, 3, 4] | tag_doubled.with_outputs('bar')
    self.assertEqual([2, 4, 6, 8], result.bar)

    # Accessing an undeclared tag is expected to raise KeyError. ``msg`` is
    # only the failure message reported if no KeyError is raised; it is not
    # matched against the exception text.
    with self.assertRaises(KeyError,
                           msg='Tag \'foo\' is not a defined output tag'):
      _ = result.foo
  2601  
  2602  
  2603  @parameterized_class([{'use_subprocess': False}, {'use_subprocess': True}])
  2604  class DeadLettersTest(unittest.TestCase):
  2605    @classmethod
  2606    def die(cls, x):
  2607      if cls.use_subprocess:
  2608        os._exit(x)
  2609      else:
  2610        raise ValueError(x)
  2611  
  2612    @classmethod
  2613    def die_if_negative(cls, x):
  2614      if x < 0:
  2615        cls.die(x)
  2616      else:
  2617        return x
  2618  
  2619    @classmethod
  2620    def exception_if_negative(cls, x):
  2621      if x < 0:
  2622        raise ValueError(x)
  2623      else:
  2624        return x
  2625  
  2626    @classmethod
  2627    def die_if_less(cls, x, bound=0):
  2628      if x < bound:
  2629        cls.die(x)
  2630      else:
  2631        return x, bound
  2632  
  2633    def test_error_messages(self):
  2634      with TestPipeline() as p:
  2635        good, bad = (
  2636            p
  2637            | beam.Create([-1, 10, -100, 2, 0])
  2638            | beam.Map(self.exception_if_negative).with_exception_handling())
  2639        assert_that(good, equal_to([0, 2, 10]), label='CheckGood')
  2640        assert_that(
  2641            bad |
  2642            beam.MapTuple(lambda e, exc_info: (e, exc_info[1].replace(',', ''))),
  2643            equal_to([(-1, 'ValueError(-1)'), (-100, 'ValueError(-100)')]),
  2644            label='CheckBad')
  2645  
  2646    def test_filters_exceptions(self):
  2647      with TestPipeline() as p:
  2648        good, _ = (
  2649            p
  2650            | beam.Create([-1, 10, -100, 2, 0])
  2651            | beam.Map(self.exception_if_negative).with_exception_handling(
  2652                use_subprocess=self.use_subprocess,
  2653                exc_class=(ValueError, TypeError)))
  2654        assert_that(good, equal_to([0, 2, 10]), label='CheckGood')
  2655  
  2656      with self.assertRaises(Exception):
  2657        with TestPipeline() as p:
  2658          good, _ = (
  2659              p
  2660              | beam.Create([-1, 10, -100, 2, 0])
  2661              | beam.Map(self.die_if_negative).with_exception_handling(
  2662                  use_subprocess=self.use_subprocess,
  2663                  exc_class=TypeError))
  2664  
  2665    def test_tuples(self):
  2666  
  2667      with TestPipeline() as p:
  2668        good, _ = (
  2669            p
  2670            | beam.Create([(1, 2), (3, 2), (1, -10)])
  2671            | beam.MapTuple(self.die_if_less).with_exception_handling(
  2672                use_subprocess=self.use_subprocess))
  2673        assert_that(good, equal_to([(3, 2), (1, -10)]), label='CheckGood')
  2674  
  2675    def test_side_inputs(self):
  2676  
  2677      with TestPipeline() as p:
  2678        input = p | beam.Create([-1, 10, 100])
  2679  
  2680        assert_that((
  2681            input
  2682            | 'Default' >> beam.Map(self.die_if_less).with_exception_handling(
  2683                use_subprocess=self.use_subprocess)).good,
  2684                    equal_to([(10, 0), (100, 0)]),
  2685                    label='CheckDefault')
  2686        assert_that((
  2687            input
  2688            | 'Pos' >> beam.Map(self.die_if_less, 20).with_exception_handling(
  2689                use_subprocess=self.use_subprocess)).good,
  2690                    equal_to([(100, 20)]),
  2691                    label='PosSideInput')
  2692        assert_that((
  2693            input
  2694            |
  2695            'Key' >> beam.Map(self.die_if_less, bound=30).with_exception_handling(
  2696                use_subprocess=self.use_subprocess)).good,
  2697                    equal_to([(100, 30)]),
  2698                    label='KeySideInput')
  2699  
  2700    def test_multiple_outputs(self):
  2701      die = type(self).die
  2702  
  2703      def die_on_negative_even_odd(x):
  2704        if x < 0:
  2705          die(x)
  2706        elif x % 2 == 0:
  2707          return pvalue.TaggedOutput('even', x)
  2708        elif x % 2 == 1:
  2709          return pvalue.TaggedOutput('odd', x)
  2710  
  2711      with TestPipeline() as p:
  2712        results = (
  2713            p
  2714            | beam.Create([1, -1, 2, -2, 3])
  2715            | beam.Map(die_on_negative_even_odd).with_exception_handling(
  2716                use_subprocess=self.use_subprocess))
  2717        assert_that(results.even, equal_to([2]), label='CheckEven')
  2718        assert_that(results.odd, equal_to([1, 3]), label='CheckOdd')
  2719  
  2720    def test_params(self):
  2721      die = type(self).die
  2722  
  2723      def die_if_negative_with_timestamp(x, ts=beam.DoFn.TimestampParam):
  2724        if x < 0:
  2725          die(x)
  2726        else:
  2727          return x, ts
  2728  
  2729      with TestPipeline() as p:
  2730        good, _ = (
  2731            p
  2732            | beam.Create([-1, 0, 1])
  2733            | beam.Map(lambda x: TimestampedValue(x, x))
  2734            | beam.Map(die_if_negative_with_timestamp).with_exception_handling(
  2735                use_subprocess=self.use_subprocess))
  2736        assert_that(good, equal_to([(0, Timestamp(0)), (1, Timestamp(1))]))
  2737  
  2738    def test_timeout(self):
  2739      import time
  2740      timeout = 1 if self.use_subprocess else .1
  2741  
  2742      with TestPipeline() as p:
  2743        good, bad = (
  2744            p
  2745            | beam.Create('records starting with lowercase S are slow'.split())
  2746            | beam.Map(
  2747                lambda x: time.sleep(2.5 * timeout) if x.startswith('s') else x)
  2748            .with_exception_handling(
  2749                use_subprocess=self.use_subprocess, timeout=timeout))
  2750        assert_that(
  2751            good,
  2752            equal_to(['records', 'with', 'lowercase', 'S', 'are']),
  2753            label='CheckGood')
  2754        assert_that(
  2755            bad |
  2756            beam.MapTuple(lambda e, exc_info: (e, exc_info[1].replace(',', ''))),
  2757            equal_to([('starting', 'TimeoutError()'),
  2758                      ('slow', 'TimeoutError()')]),
  2759            label='CheckBad')
  2760  
  2761    def test_lifecycle(self):
  2762      die = type(self).die
  2763  
  2764      class MyDoFn(beam.DoFn):
  2765        state = None
  2766  
  2767        def setup(self):
  2768          assert self.state is None
  2769          self.state = 'setup'
  2770  
  2771        def start_bundle(self):
  2772          assert self.state in ('setup', 'finish_bundle'), self.state
  2773          self.state = 'start_bundle'
  2774  
  2775        def finish_bundle(self):
  2776          assert self.state in ('start_bundle', ), self.state
  2777          self.state = 'finish_bundle'
  2778  
  2779        def teardown(self):
  2780          assert self.state in ('setup', 'finish_bundle'), self.state
  2781          self.state = 'teardown'
  2782  
  2783        def process(self, x):
  2784          if x < 0:
  2785            die(x)
  2786          else:
  2787            yield self.state
  2788  
  2789      with TestPipeline() as p:
  2790        good, _ = (
  2791            p
  2792            | beam.Create([-1, 0, 1, -10, 10])
  2793            | beam.ParDo(MyDoFn()).with_exception_handling(
  2794                use_subprocess=self.use_subprocess))
  2795        assert_that(good, equal_to(['start_bundle'] * 3))
  2796  
  2797    def test_partial(self):
  2798      if self.use_subprocess:
  2799        self.skipTest('Subprocess and partial mutally exclusive.')
  2800  
  2801      def die_if_negative_iter(elements):
  2802        for element in elements:
  2803          if element < 0:
  2804            raise ValueError(element)
  2805          yield element
  2806  
  2807      with TestPipeline() as p:
  2808        input = p | beam.Create([(-1, 1, 11), (2, -2, 22), (3, 33, -3), (4, 44)])
  2809  
  2810        assert_that((
  2811            input
  2812            | 'Partial' >> beam.FlatMap(
  2813                die_if_negative_iter).with_exception_handling(partial=True)).good,
  2814                    equal_to([2, 3, 33, 4, 44]),
  2815                    'CheckPartial')
  2816  
  2817        assert_that((
  2818            input
  2819            | 'Complete' >> beam.FlatMap(die_if_negative_iter).
  2820            with_exception_handling(partial=False)).good,
  2821                    equal_to([4, 44]),
  2822                    'CheckComplete')
  2823  
  2824    def test_threshold(self):
  2825      # The threshold is high enough.
  2826      with TestPipeline() as p:
  2827        _ = (
  2828            p
  2829            | beam.Create([-1, -2, 0, 1, 2, 3, 4, 5])
  2830            | beam.Map(self.die_if_negative).with_exception_handling(
  2831                threshold=0.5, use_subprocess=self.use_subprocess))
  2832  
  2833      # The threshold is too low enough.
  2834      with self.assertRaisesRegex(Exception, "2 / 8 = 0.25 > 0.1"):
  2835        with TestPipeline() as p:
  2836          _ = (
  2837              p
  2838              | beam.Create([-1, -2, 0, 1, 2, 3, 4, 5])
  2839              | beam.Map(self.die_if_negative).with_exception_handling(
  2840                  threshold=0.1, use_subprocess=self.use_subprocess))
  2841  
  2842      # The threshold is too low per window.
  2843      with self.assertRaisesRegex(Exception, "2 / 2 = 1.0 > 0.5"):
  2844        with TestPipeline() as p:
  2845          _ = (
  2846              p
  2847              | beam.Create([-1, -2, 0, 1, 2, 3, 4, 5])
  2848              | beam.Map(lambda x: TimestampedValue(x, x))
  2849              | beam.Map(self.die_if_negative).with_exception_handling(
  2850                  threshold=0.5,
  2851                  threshold_windowing=window.FixedWindows(10),
  2852                  use_subprocess=self.use_subprocess))
  2853  
  2854  
  2855  class TestPTransformFn(TypeHintTestCase):
  2856    def test_type_checking_fail(self):
  2857      @beam.ptransform_fn
  2858      def MyTransform(pcoll):
  2859        return pcoll | beam.ParDo(lambda x: [x]).with_output_types(str)
  2860  
  2861      p = TestPipeline()
  2862      with self.assertRaisesRegex(beam.typehints.TypeCheckError,
  2863                                  r'expected.*int.*got.*str'):
  2864        _ = (p | beam.Create([1, 2]) | MyTransform().with_output_types(int))
  2865  
  2866    def test_type_checking_success(self):
  2867      @beam.ptransform_fn
  2868      def MyTransform(pcoll):
  2869        return pcoll | beam.ParDo(lambda x: [x]).with_output_types(int)
  2870  
  2871      with TestPipeline() as p:
  2872        _ = (p | beam.Create([1, 2]) | MyTransform().with_output_types(int))
  2873  
  2874    def test_type_hints_arg(self):
  2875      # Tests passing type hints via the magic 'type_hints' argument name.
  2876      @beam.ptransform_fn
  2877      def MyTransform(pcoll, type_hints, test_arg):
  2878        self.assertEqual(test_arg, 'test')
  2879        return (
  2880            pcoll
  2881            | beam.ParDo(lambda x: [x]).with_output_types(
  2882                type_hints.output_types[0][0]))
  2883  
  2884      with TestPipeline() as p:
  2885        _ = (p | beam.Create([1, 2]) | MyTransform('test').with_output_types(int))
  2886  
  2887  
  2888  class PickledObject(object):
  2889    def __init__(self, value):
  2890      self.value = value
  2891  
  2892  
if __name__ == '__main__':
  # Run this module's tests with unittest's runner when executed directly.
  unittest.main()