github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/examples/snippets/snippets_test.py (about)

     1  # coding=utf-8
     2  #
     3  # Licensed to the Apache Software Foundation (ASF) under one or more
     4  # contributor license agreements.  See the NOTICE file distributed with
     5  # this work for additional information regarding copyright ownership.
     6  # The ASF licenses this file to You under the Apache License, Version 2.0
     7  # (the "License"); you may not use this file except in compliance with
     8  # the License.  You may obtain a copy of the License at
     9  #
    10  #    http://www.apache.org/licenses/LICENSE-2.0
    11  #
    12  # Unless required by applicable law or agreed to in writing, software
    13  # distributed under the License is distributed on an "AS IS" BASIS,
    14  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    15  # See the License for the specific language governing permissions and
    16  # limitations under the License.
    17  #
    18  
    19  """Tests for all code snippets used in public docs."""
    20  # pytype: skip-file
    21  
    22  import gc
    23  import glob
    24  import gzip
    25  import logging
    26  import math
    27  import os
    28  import sys
    29  import tempfile
    30  import time
    31  import unittest
    32  import uuid
    33  
    34  import mock
    35  import parameterized
    36  
    37  import apache_beam as beam
    38  from apache_beam import WindowInto
    39  from apache_beam import coders
    40  from apache_beam import pvalue
    41  from apache_beam import typehints
    42  from apache_beam.coders.coders import ToBytesCoder
    43  from apache_beam.examples.snippets import snippets
    44  from apache_beam.metrics import Metrics
    45  from apache_beam.metrics.metric import MetricsFilter
    46  from apache_beam.options.pipeline_options import GoogleCloudOptions
    47  from apache_beam.options.pipeline_options import PipelineOptions
    48  from apache_beam.options.pipeline_options import StandardOptions
    49  from apache_beam.testing.test_pipeline import TestPipeline
    50  from apache_beam.testing.test_stream import TestStream
    51  from apache_beam.testing.util import assert_that
    52  from apache_beam.testing.util import equal_to
    53  from apache_beam.transforms import combiners
    54  from apache_beam.transforms.trigger import AccumulationMode
    55  from apache_beam.transforms.trigger import AfterAny
    56  from apache_beam.transforms.trigger import AfterCount
    57  from apache_beam.transforms.trigger import AfterProcessingTime
    58  from apache_beam.transforms.trigger import AfterWatermark
    59  from apache_beam.transforms.trigger import Repeatedly
    60  from apache_beam.transforms.window import FixedWindows
    61  from apache_beam.transforms.window import TimestampedValue
    62  from apache_beam.utils.windowed_value import WindowedValue
    63  
    64  # Protect against environments where apitools library is not available.
    65  # pylint: disable=wrong-import-order, wrong-import-position
    66  try:
    67    from apitools.base.py import base_api
    68  except ImportError:
    69    base_api = None
    70  # pylint: enable=wrong-import-order, wrong-import-position
    71  
    72  # Protect against environments where datastore library is not available.
    73  # pylint: disable=wrong-import-order, wrong-import-position
    74  try:
    75    from google.cloud.datastore import client as datastore_client
    76  except ImportError:
    77    datastore_client = None
    78  # pylint: enable=wrong-import-order, wrong-import-position
    79  
    80  # Protect against environments where the PubSub library is not available.
    81  # pylint: disable=wrong-import-order, wrong-import-position
    82  try:
    83    from google.cloud import pubsub
    84  except ImportError:
    85    pubsub = None
    86  # pylint: enable=wrong-import-order, wrong-import-position
    87  
    88  
class ParDoTest(unittest.TestCase):
  """Tests for model/par-do.

  Most tests apply transforms directly to plain Python lists (eager
  execution) and compare the results as sets. Code between the
  ``[START ...]``/``[END ...]`` markers is used verbatim in the public docs,
  so it must not be edited for style.
  """
  def test_pardo(self):
    """A DoFn whose process() returns a list yields one length per word."""
    # Note: "words" and "ComputeWordLengthFn" are referenced by name in
    # the text of the doc.

    words = ['aa', 'bbb', 'c']

    # [START model_pardo_pardo]
    class ComputeWordLengthFn(beam.DoFn):
      def process(self, element):
        return [len(element)]

    # [END model_pardo_pardo]

    # [START model_pardo_apply]
    # Apply a ParDo to the PCollection "words" to compute lengths for each word.
    word_lengths = words | beam.ParDo(ComputeWordLengthFn())
    # [END model_pardo_apply]
    self.assertEqual({2, 3, 1}, set(word_lengths))

  def test_pardo_yield(self):
    """Like test_pardo, but process() yields its output instead."""
    words = ['aa', 'bbb', 'c']

    # [START model_pardo_yield]
    class ComputeWordLengthFn(beam.DoFn):
      def process(self, element):
        yield len(element)

    # [END model_pardo_yield]

    word_lengths = words | beam.ParDo(ComputeWordLengthFn())
    self.assertEqual({2, 3, 1}, set(word_lengths))

  def test_pardo_using_map(self):
    """beam.Map(len) computes the same word lengths as the DoFn."""
    words = ['aa', 'bbb', 'c']
    # [START model_pardo_using_map]
    word_lengths = words | beam.Map(len)
    # [END model_pardo_using_map]

    self.assertEqual({2, 3, 1}, set(word_lengths))

  def test_pardo_using_flatmap(self):
    """beam.FlatMap with a lambda returning a one-element list."""
    words = ['aa', 'bbb', 'c']
    # [START model_pardo_using_flatmap]
    word_lengths = words | beam.FlatMap(lambda word: [len(word)])
    # [END model_pardo_using_flatmap]

    self.assertEqual({2, 3, 1}, set(word_lengths))

  def test_pardo_using_flatmap_yield(self):
    """beam.FlatMap over a generator that yields only uppercase letters."""
    words = ['aA', 'bbb', 'C']

    # [START model_pardo_using_flatmap_yield]
    def capitals(word):
      for letter in word:
        if 'A' <= letter <= 'Z':
          yield letter

    all_capitals = words | beam.FlatMap(capitals)
    # [END model_pardo_using_flatmap_yield]

    self.assertEqual({'A', 'C'}, set(all_capitals))

  def test_pardo_with_label(self):
    """A transform can be named with the ``'label' >> transform`` form."""
    words = ['aa', 'bbc', 'defg']
    # [START model_pardo_with_label]
    result = words | 'CountUniqueLetters' >> beam.Map(
        lambda word: len(set(word)))
    # [END model_pardo_with_label]

    self.assertEqual({1, 2, 4}, set(result))

  def test_pardo_side_input(self):
    """Side inputs to FlatMap inside a TestPipeline.

    Covers constant extra arguments, a deferred AsSingleton of the mean word
    length, and a mix of both.
    """
    # pylint: disable=line-too-long
    with TestPipeline() as p:
      words = p | 'start' >> beam.Create(['a', 'bb', 'ccc', 'dddd'])

      # [START model_pardo_side_input]
      # Callable takes additional arguments.
      def filter_using_length(word, lower_bound, upper_bound=float('inf')):
        if lower_bound <= len(word) <= upper_bound:
          yield word

      # Construct a deferred side input.
      avg_word_len = (
          words
          | beam.Map(len)
          | beam.CombineGlobally(beam.combiners.MeanCombineFn()))

      # Call with explicit side inputs.
      small_words = words | 'small' >> beam.FlatMap(filter_using_length, 0, 3)

      # A single deferred side input.
      larger_than_average = (
          words | 'large' >> beam.FlatMap(
              filter_using_length, lower_bound=pvalue.AsSingleton(avg_word_len))
      )

      # Mix and match.
      small_but_nontrivial = words | beam.FlatMap(
          filter_using_length,
          lower_bound=2,
          upper_bound=pvalue.AsSingleton(avg_word_len))
      # [END model_pardo_side_input]

      assert_that(small_words, equal_to(['a', 'bb', 'ccc']))
      assert_that(
          larger_than_average,
          equal_to(['ccc', 'dddd']),
          label='larger_than_average')
      assert_that(
          small_but_nontrivial, equal_to(['bb']), label='small_but_not_trivial')

  def test_pardo_side_input_dofn(self):
    """Extra positional ParDo arguments reach DoFn.process() as parameters."""
    words = ['a', 'bb', 'ccc', 'dddd']

    # [START model_pardo_side_input_dofn]
    class FilterUsingLength(beam.DoFn):
      def process(self, element, lower_bound, upper_bound=float('inf')):
        if lower_bound <= len(element) <= upper_bound:
          yield element

    small_words = words | beam.ParDo(FilterUsingLength(), 0, 3)
    # [END model_pardo_side_input_dofn]
    self.assertEqual({'a', 'bb', 'ccc'}, set(small_words))

  def test_pardo_with_tagged_outputs(self):
    """Main and tagged outputs of a ParDo.

    The outputs are read by attribute, by indexing (for a tag containing a
    space) and by tuple unpacking.
    """
    # [START model_pardo_emitting_values_on_tagged_outputs]
    class ProcessWords(beam.DoFn):
      def process(self, element, cutoff_length, marker):
        if len(element) <= cutoff_length:
          # Emit this short word to the main output.
          yield element
        else:
          # Emit this word's long length to the 'above_cutoff_lengths' output.
          yield pvalue.TaggedOutput('above_cutoff_lengths', len(element))
        if element.startswith(marker):
          # Emit this word to a different output with the 'marked strings' tag.
          yield pvalue.TaggedOutput('marked strings', element)

    # [END model_pardo_emitting_values_on_tagged_outputs]

    words = ['a', 'an', 'the', 'music', 'xyz']

    # [START model_pardo_with_tagged_outputs]
    results = (
        words
        | beam.ParDo(ProcessWords(), cutoff_length=2, marker='x').with_outputs(
            'above_cutoff_lengths',
            'marked strings',
            main='below_cutoff_strings'))
    below = results.below_cutoff_strings
    above = results.above_cutoff_lengths
    marked = results['marked strings']  # indexing works as well
    # [END model_pardo_with_tagged_outputs]

    self.assertEqual({'a', 'an'}, set(below))
    self.assertEqual({3, 5}, set(above))
    self.assertEqual({'xyz'}, set(marked))

    # [START model_pardo_with_tagged_outputs_iter]
    below, above, marked = (words
                            | beam.ParDo(
                                ProcessWords(), cutoff_length=2, marker='x')
                            .with_outputs('above_cutoff_lengths',
                                          'marked strings',
                                          main='below_cutoff_strings'))
    # [END model_pardo_with_tagged_outputs_iter]

    self.assertEqual({'a', 'an'}, set(below))
    self.assertEqual({3, 5}, set(above))
    self.assertEqual({'xyz'}, set(marked))

  def test_pardo_with_undeclared_outputs(self):
    """Tagged outputs that are not declared in with_outputs().

    Runs in a TestPipeline (see the note below on eager execution). The test
    reads results.even, results.odd and the untagged main output
    results[None].
    """
    # Note: the use of undeclared outputs is currently not supported in eager
    # execution mode.
    with TestPipeline() as p:
      numbers = p | beam.Create([1, 2, 3, 4, 5, 10, 20])

      # [START model_pardo_with_undeclared_outputs]
      def even_odd(x):
        yield pvalue.TaggedOutput('odd' if x % 2 else 'even', x)
        if x % 10 == 0:
          yield x

      results = numbers | beam.FlatMap(even_odd).with_outputs()

      evens = results.even
      odds = results.odd
      tens = results[None]  # the undeclared main output
      # [END model_pardo_with_undeclared_outputs]

      assert_that(evens, equal_to([2, 4, 10, 20]), label='assert_even')
      assert_that(odds, equal_to([1, 3, 5]), label='assert_odds')
      assert_that(tens, equal_to([10, 20]), label='assert_tens')
   285  
   286  
class TypeHintsTest(unittest.TestCase):
  """Tests for the type-hint snippets.

  Without hints, type mismatches only show up as TypeError when a pipeline
  runs. With hints (decorators, with_*_types, annotations) or with runtime
  type checking enabled, they raise typehints.TypeCheckError instead.
  """
  def test_bad_types(self):
    """Type errors without hints surface at run time; with hints, earlier.

    Filtering the str elements ``'1', '2', '3'`` with ``x % 2`` raises
    TypeError only when the pipeline runs. with_input_types(int) or a
    decorated DoFn raises TypeCheckError while the pipeline is being built.
    """
    # [START type_hints_missing_define_numbers]
    p = TestPipeline()

    numbers = p | beam.Create(['1', '2', '3'])
    # [END type_hints_missing_define_numbers]

    # Consider the following code.
    # pylint: disable=expression-not-assigned
    # pylint: disable=unused-variable
    # [START type_hints_missing_apply]
    evens = numbers | beam.Filter(lambda x: x % 2 == 0)
    # [END type_hints_missing_apply]

    # Now suppose numbers was defined as [snippet above].
    # When running this pipeline, you'd get a runtime error,
    # possibly on a remote machine, possibly very late.

    with self.assertRaises(TypeError):
      p.run()

    # To catch this early, we can assert what types we expect.
    with self.assertRaises(typehints.TypeCheckError):
      # [START type_hints_takes]
      evens = numbers | beam.Filter(lambda x: x % 2 == 0).with_input_types(int)
      # [END type_hints_takes]

    # Type hints can be declared on DoFns and callables as well, rather
    # than where they're used, to be more self contained.
    with self.assertRaises(typehints.TypeCheckError):
      # [START type_hints_do_fn]
      @beam.typehints.with_input_types(int)
      class FilterEvensDoFn(beam.DoFn):
        def process(self, element):
          if element % 2 == 0:
            yield element

      evens = numbers | beam.ParDo(FilterEvensDoFn())
      # [END type_hints_do_fn]

    words = p | 'words' >> beam.Create(['a', 'bb', 'c'])
    # One can assert outputs and apply them to transforms as well.
    # Helps document the contract and checks it at pipeline construction time.
    # [START type_hints_transform]
    from typing import Tuple, TypeVar

    T = TypeVar('T')

    @beam.typehints.with_input_types(T)
    @beam.typehints.with_output_types(Tuple[int, T])
    class MyTransform(beam.PTransform):
      def expand(self, pcoll):
        return pcoll | beam.Map(lambda x: (len(x), x))

    words_with_lens = words | MyTransform()
    # [END type_hints_transform]

    # Given an input of str, the inferred output type would be Tuple[int, str].
    self.assertEqual(typehints.Tuple[int, str], words_with_lens.element_type)

    # pylint: disable=expression-not-assigned
    with self.assertRaises(typehints.TypeCheckError):
      words_with_lens | beam.Map(lambda x: x).with_input_types(Tuple[int, int])

  def test_bad_types_annotations(self):
    """Like test_bad_types, but with hints taken from Python annotations.

    Annotations on DoFn.process(), on a function passed to beam.Map and on
    PTransform.expand() are each expected to raise TypeCheckError for the
    str-valued ``numbers`` input. The pipeline is created with
    pipeline_type_check=True.
    """
    p = TestPipeline(options=PipelineOptions(pipeline_type_check=True))

    numbers = p | beam.Create(['1', '2', '3'])

    # Consider the following code.
    # pylint: disable=expression-not-assigned
    # pylint: disable=unused-variable
    class FilterEvensDoFn(beam.DoFn):
      def process(self, element):
        if element % 2 == 0:
          yield element

    evens = numbers | 'Untyped Filter' >> beam.ParDo(FilterEvensDoFn())

    # Now suppose numbers was defined as [snippet above].
    # When running this pipeline, you'd get a runtime error,
    # possibly on a remote machine, possibly very late.

    with self.assertRaises(TypeError):
      p.run()

    # To catch this early, we can annotate process() with the expected types.
    # Beam will then use these as type hints and perform type checking before
    # the pipeline starts.
    with self.assertRaises(typehints.TypeCheckError):
      # [START type_hints_do_fn_annotations]
      from typing import Iterable

      class TypedFilterEvensDoFn(beam.DoFn):
        def process(self, element: int) -> Iterable[int]:
          if element % 2 == 0:
            yield element

      evens = numbers | 'filter_evens' >> beam.ParDo(TypedFilterEvensDoFn())
      # [END type_hints_do_fn_annotations]

    # Another example, using a list output type. Notice that the output
    # annotation has an additional Optional for the else clause.
    with self.assertRaises(typehints.TypeCheckError):
      # [START type_hints_do_fn_annotations_optional]
      from typing import List, Optional

      class FilterEvensDoubleDoFn(beam.DoFn):
        def process(self, element: int) -> Optional[List[int]]:
          if element % 2 == 0:
            return [element, element]
          return None

      evens = numbers | 'double_evens' >> beam.ParDo(FilterEvensDoubleDoFn())
      # [END type_hints_do_fn_annotations_optional]

    # Example using an annotated function.
    with self.assertRaises(typehints.TypeCheckError):
      # [START type_hints_map_annotations]
      def my_fn(element: int) -> str:
        return 'id_' + str(element)

      ids = numbers | 'to_id' >> beam.Map(my_fn)
      # [END type_hints_map_annotations]

    # Example using an annotated PTransform.
    with self.assertRaises(typehints.TypeCheckError):
      # [START type_hints_ptransforms]
      from apache_beam.pvalue import PCollection

      class IntToStr(beam.PTransform):
        def expand(self, pcoll: PCollection[int]) -> PCollection[str]:
          return pcoll | beam.Map(lambda elem: str(elem))

      ids = numbers | 'convert to str' >> IntToStr()
      # [END type_hints_ptransforms]

  def test_runtime_checks_off(self):
    """Builds, but never runs, a pipeline whose output type hint is wrong."""
    # We do not run the following pipeline, as it has incorrect type
    # information, and may fail with obscure errors, depending on the runner
    # implementation.

    # pylint: disable=expression-not-assigned
    # [START type_hints_runtime_off]
    p = TestPipeline()
    p | beam.Create(['a']) | beam.Map(lambda x: 3).with_output_types(str)
    # [END type_hints_runtime_off]

  def test_runtime_checks_on(self):
    """With runtime_type_check=True, a wrong output hint fails in run()."""
    # pylint: disable=expression-not-assigned
    with self.assertRaises(typehints.TypeCheckError):
      # [START type_hints_runtime_on]
      p = TestPipeline(options=PipelineOptions(runtime_type_check=True))
      p | beam.Create(['a']) | beam.Map(lambda x: 3).with_output_types(str)
      p.run()
      # [END type_hints_runtime_on]

  def test_deterministic_key(self):
    """A user class can be a CombinePerKey key given a deterministic coder."""
    with TestPipeline() as p:
      lines = (
          p | beam.Create([
              'banana,fruit,3',
              'kiwi,fruit,2',
              'kiwi,fruit,2',
              'zucchini,veg,3'
          ]))

      # For pickling: declaring Player global makes the `class Player`
      # statement below bind a module-level name rather than a local one.
      global Player  # pylint: disable=global-variable-not-assigned

      # [START type_hints_deterministic_key]
      from typing import Tuple

      class Player(object):
        def __init__(self, team, name):
          self.team = team
          self.name = name

      class PlayerCoder(beam.coders.Coder):
        def encode(self, player):
          return ('%s:%s' % (player.team, player.name)).encode('utf-8')

        def decode(self, s):
          return Player(*s.decode('utf-8').split(':'))

        def is_deterministic(self):
          return True

      beam.coders.registry.register_coder(Player, PlayerCoder)

      def parse_player_and_score(csv):
        name, team, score = csv.split(',')
        return Player(team, name), int(score)

      totals = (
          lines
          | beam.Map(parse_player_and_score)
          | beam.CombinePerKey(sum).with_input_types(Tuple[Player, int]))
      # [END type_hints_deterministic_key]

      assert_that(
          totals | beam.Map(lambda k_v: (k_v[0].name, k_v[1])),
          equal_to([('banana', 3), ('kiwi', 4), ('zucchini', 3)]))
   491  
   492  
   493  class SnippetsTest(unittest.TestCase):
   494    # Replacing text read/write transforms with dummy transforms for testing.
   495  
   496    class DummyReadTransform(beam.PTransform):
   497      """A transform that will replace iobase.ReadFromText.
   498  
   499      To be used for testing.
   500      """
   501      def __init__(self, file_to_read=None, compression_type=None):
   502        self.file_to_read = file_to_read
   503        self.compression_type = compression_type
   504  
   505      class ReadDoFn(beam.DoFn):
   506        def __init__(self, file_to_read, compression_type):
   507          self.file_to_read = file_to_read
   508          self.compression_type = compression_type
   509          self.coder = coders.StrUtf8Coder()
   510  
   511        def process(self, element):
   512          pass
   513  
   514        def finish_bundle(self):
   515          from apache_beam.transforms import window
   516  
   517          assert self.file_to_read
   518          for file_name in glob.glob(self.file_to_read):
   519            if self.compression_type is None:
   520              with open(file_name, 'rb') as file:
   521                for record in file:
   522                  value = self.coder.decode(record.rstrip(b'\n'))
   523                  yield WindowedValue(value, -1, [window.GlobalWindow()])
   524            else:
   525              with gzip.open(file_name, 'rb') as file:
   526                for record in file:
   527                  value = self.coder.decode(record.rstrip(b'\n'))
   528                  yield WindowedValue(value, -1, [window.GlobalWindow()])
   529  
   530      def expand(self, pcoll):
   531        return pcoll | beam.Create([None]) | 'DummyReadForTesting' >> beam.ParDo(
   532            SnippetsTest.DummyReadTransform.ReadDoFn(
   533                self.file_to_read, self.compression_type))
   534  
   535    class DummyWriteTransform(beam.PTransform):
   536      """A transform that will replace iobase.WriteToText.
   537  
   538      To be used for testing.
   539      """
   540      def __init__(self, file_to_write=None, file_name_suffix=''):
   541        self.file_to_write = file_to_write
   542  
   543      class WriteDoFn(beam.DoFn):
   544        def __init__(self, file_to_write):
   545          self.file_to_write = file_to_write
   546          self.file_obj = None
   547          self.coder = ToBytesCoder()
   548  
   549        def start_bundle(self):
   550          assert self.file_to_write
   551          # Appending a UUID to create a unique file object per invocation.
   552          self.file_obj = open(self.file_to_write + str(uuid.uuid4()), 'wb')
   553  
   554        def process(self, element):
   555          assert self.file_obj
   556          self.file_obj.write(self.coder.encode(element) + b'\n')
   557  
   558        def finish_bundle(self):
   559          assert self.file_obj
   560          self.file_obj.close()
   561  
   562      def expand(self, pcoll):
   563        return pcoll | 'DummyWriteForTesting' >> beam.ParDo(
   564            SnippetsTest.DummyWriteTransform.WriteDoFn(self.file_to_write))
   565  
   566    def setUp(self):
   567      self.old_read_from_text = beam.io.ReadFromText
   568      self.old_write_to_text = beam.io.WriteToText
   569  
   570      # Monkey patching to allow testing pipelines defined in snippets.py using
   571      # real data.
   572      beam.io.ReadFromText = SnippetsTest.DummyReadTransform
   573      beam.io.WriteToText = SnippetsTest.DummyWriteTransform
   574      self.temp_files = []
   575  
   576    def tearDown(self):
   577      beam.io.ReadFromText = self.old_read_from_text
   578      beam.io.WriteToText = self.old_write_to_text
   579      # Cleanup all the temporary files created in the test.
   580      map(os.remove, self.temp_files)
   581      # Ensure that PipelineOptions subclasses have been cleaned up between tests
   582      gc.collect()
   583  
   584    def create_temp_file(self, contents=''):
   585      with tempfile.NamedTemporaryFile(delete=False) as f:
   586        f.write(contents.encode('utf-8'))
   587        self.temp_files.append(f.name)
   588        return f.name
   589  
   590    def get_output(self, path, sorted_output=True, suffix=''):
   591      all_lines = []
   592      for file_name in glob.glob(path + '*'):
   593        with open(file_name) as f:
   594          lines = f.readlines()
   595          all_lines.extend([s.rstrip('\n') for s in lines])
   596  
   597      if sorted_output:
   598        return sorted(s.rstrip('\n') for s in all_lines)
   599      return all_lines
   600  
   601    def test_model_pipelines(self):
   602      temp_path = self.create_temp_file('aa bb cc\n bb cc\n cc')
   603      result_path = temp_path + '.result'
   604      test_argv = [
   605          "unused_argv[0]",
   606          f"--input-file={temp_path}*",
   607          f"--output-path={result_path}",
   608      ]
   609      with mock.patch.object(sys, 'argv', test_argv):
   610        snippets.model_pipelines()
   611      self.assertEqual(
   612          self.get_output(result_path),
   613          [str(s) for s in [(u'aa', 1), (u'bb', 2), (u'cc', 3)]])
   614  
   615    def test_model_pcollection(self):
   616      temp_path = self.create_temp_file()
   617      snippets.model_pcollection(temp_path)
   618      self.assertEqual(
   619          self.get_output(temp_path),
   620          [
   621              'Or to take arms against a sea of troubles, ',
   622              'The slings and arrows of outrageous fortune, ',
   623              'To be, or not to be: that is the question: ',
   624              'Whether \'tis nobler in the mind to suffer ',
   625          ])
   626  
   627    def test_construct_pipeline(self):
   628      temp_path = self.create_temp_file('abc def ghi\n jkl mno pqr\n stu vwx yz')
   629      result_path = self.create_temp_file()
   630      snippets.construct_pipeline({'read': temp_path, 'write': result_path})
   631      self.assertEqual(
   632          self.get_output(result_path),
   633          ['cba', 'fed', 'ihg', 'lkj', 'onm', 'rqp', 'uts', 'xwv', 'zy'])
   634  
   635    def test_model_custom_source(self):
   636      snippets.model_custom_source(100)
   637  
   638    def test_model_custom_sink(self):
   639      tempdir_name = tempfile.mkdtemp()
   640  
   641      class SimpleKV(object):
   642        def __init__(self, tmp_dir):
   643          self._dummy_token = 'dummy_token'
   644          self._tmp_dir = tmp_dir
   645  
   646        def connect(self, url):
   647          return self._dummy_token
   648  
   649        def open_table(self, access_token, table_name):
   650          assert access_token == self._dummy_token
   651          file_name = self._tmp_dir + os.sep + table_name
   652          assert not os.path.exists(file_name)
   653          open(file_name, 'wb').close()
   654          return table_name
   655  
   656        def write_to_table(self, access_token, table_name, key, value):
   657          assert access_token == self._dummy_token
   658          file_name = self._tmp_dir + os.sep + table_name
   659          assert os.path.exists(file_name)
   660          with open(file_name, 'ab') as f:
   661            content = (key + ':' + value + os.linesep).encode('utf-8')
   662            f.write(content)
   663  
   664        def rename_table(self, access_token, old_name, new_name):
   665          assert access_token == self._dummy_token
   666          old_file_name = self._tmp_dir + os.sep + old_name
   667          new_file_name = self._tmp_dir + os.sep + new_name
   668          assert os.path.isfile(old_file_name)
   669          assert not os.path.exists(new_file_name)
   670  
   671          os.rename(old_file_name, new_file_name)
   672  
   673      snippets.model_custom_sink(
   674          SimpleKV(tempdir_name),
   675          [('key' + str(i), 'value' + str(i)) for i in range(100)],
   676          'final_table_no_ptransform',
   677          'final_table_with_ptransform')
   678  
   679      expected_output = [
   680          'key' + str(i) + ':' + 'value' + str(i) for i in range(100)
   681      ]
   682  
   683      glob_pattern = tempdir_name + os.sep + 'final_table_no_ptransform*'
   684      output_files = glob.glob(glob_pattern)
   685      assert output_files
   686  
   687      received_output = []
   688      for file_name in output_files:
   689        with open(file_name) as f:
   690          for line in f:
   691            received_output.append(line.rstrip(os.linesep))
   692  
   693      self.assertCountEqual(expected_output, received_output)
   694  
   695      glob_pattern = tempdir_name + os.sep + 'final_table_with_ptransform*'
   696      output_files = glob.glob(glob_pattern)
   697      assert output_files
   698  
   699      received_output = []
   700      for file_name in output_files:
   701        with open(file_name) as f:
   702          for line in f:
   703            received_output.append(line.rstrip(os.linesep))
   704  
   705      self.assertCountEqual(expected_output, received_output)
   706  
   707    def test_model_textio(self):
   708      temp_path = self.create_temp_file('aa bb cc\n bb cc\n cc')
   709      result_path = temp_path + '.result'
   710      snippets.model_textio({'read': temp_path, 'write': result_path})
   711      self.assertEqual(['aa', 'bb', 'bb', 'cc', 'cc', 'cc'],
   712                       self.get_output(result_path, suffix='.csv'))
   713  
   714    def test_model_textio_compressed(self):
   715      temp_path = self.create_temp_file('aa\nbb\ncc')
   716      gzip_file_name = temp_path + '.gz'
   717      with open(temp_path, 'rb') as src, gzip.open(gzip_file_name, 'wb') as dst:
   718        dst.writelines(src)
   719        # Add the temporary gzip file to be cleaned up as well.
   720        self.temp_files.append(gzip_file_name)
   721      snippets.model_textio_compressed({'read': gzip_file_name},
   722                                       ['aa', 'bb', 'cc'])
   723  
   724    @unittest.skipIf(
   725        datastore_client is None, 'GCP dependencies are not installed')
   726    def test_model_datastoreio(self):
   727      # We cannot test DatastoreIO functionality in unit tests, therefore we limit
   728      # ourselves to making sure the pipeline containing Datastore read and write
   729      # transforms can be built.
   730      # TODO(vikasrk): Expore using Datastore Emulator.
   731      snippets.model_datastoreio()
   732  
   733    @unittest.skipIf(base_api is None, 'GCP dependencies are not installed')
   734    def test_model_bigqueryio(self):
   735      # We cannot test BigQueryIO functionality in unit tests, therefore we limit
   736      # ourselves to making sure the pipeline containing BigQuery sources and
   737      # sinks can be built.
   738      #
   739      # To run locally, set `run_locally` to `True`. You will also have to set
   740      # `project`, `dataset` and `table` to the BigQuery table the test will write
   741      # to.
   742      run_locally = False
   743      if run_locally:
   744        project = 'my-project'
   745        dataset = 'samples'  # this must already exist
   746        table = 'model_bigqueryio'  # this will be created if needed
   747  
   748        options = PipelineOptions().view_as(GoogleCloudOptions)
   749        options.project = project
   750        with beam.Pipeline(options=options) as p:
   751          snippets.model_bigqueryio(p, project, dataset, table)
   752      else:
   753        p = TestPipeline()
   754        p.options.view_as(GoogleCloudOptions).temp_location = 'gs://mylocation'
   755        snippets.model_bigqueryio(p)
   756  
   757    def _run_test_pipeline_for_options(self, fn):
   758      temp_path = self.create_temp_file('aa\nbb\ncc')
   759      result_path = temp_path + '.result'
   760      test_argv = [
   761          "unused_argv[0]",
   762          f"--input={temp_path}*",
   763          f"--output={result_path}",
   764      ]
   765      with mock.patch.object(sys, 'argv', test_argv):
   766        fn()
   767      self.assertEqual(['aa', 'bb', 'cc'], self.get_output(result_path))
   768  
   769    def test_pipeline_options_local(self):
   770      self._run_test_pipeline_for_options(snippets.pipeline_options_local)
   771  
   772    def test_pipeline_options_remote(self):
   773      self._run_test_pipeline_for_options(snippets.pipeline_options_remote)
   774  
   775    def test_pipeline_options_command_line(self):
   776      self._run_test_pipeline_for_options(snippets.pipeline_options_command_line)
   777  
   778    def test_pipeline_logging(self):
   779      result_path = self.create_temp_file()
   780      lines = [
   781          'we found love right where we are',
   782          'we found love right from the start',
   783          'we found love in a hopeless place'
   784      ]
   785      snippets.pipeline_logging(lines, result_path)
   786      self.assertEqual(
   787          sorted(' '.join(lines).split(' ')), self.get_output(result_path))
   788  
   789    @parameterized.parameterized.expand([
   790        [snippets.examples_wordcount_minimal],
   791        [snippets.examples_wordcount_wordcount],
   792        [snippets.pipeline_monitoring],
   793        [snippets.examples_wordcount_templated],
   794    ])
   795    def test_examples_wordcount(self, pipeline):
   796      temp_path = self.create_temp_file('abc def ghi\n abc jkl')
   797      result_path = self.create_temp_file()
   798      test_argv = [
   799          "unused_argv[0]",
   800          f"--input-file={temp_path}*",
   801          f"--output-path={result_path}",
   802      ]
   803      with mock.patch.object(sys, 'argv', test_argv):
   804        pipeline()
   805      self.assertEqual(
   806          self.get_output(result_path), ['abc: 2', 'def: 1', 'ghi: 1', 'jkl: 1'])
   807  
   808    def test_examples_ptransforms_templated(self):
   809      pipelines = [snippets.examples_ptransforms_templated]
   810  
   811      for pipeline in pipelines:
   812        temp_path = self.create_temp_file('1\n 2\n 3')
   813        result_path = self.create_temp_file()
   814        pipeline({'read': temp_path, 'write': result_path})
   815        self.assertEqual(self.get_output(result_path), ['11', '12', '13'])
   816  
   817    def test_examples_wordcount_debugging(self):
   818      temp_path = self.create_temp_file(
   819          'Flourish Flourish Flourish stomach abc def')
   820      result_path = self.create_temp_file()
   821      snippets.examples_wordcount_debugging({
   822          'read': temp_path, 'write': result_path
   823      })
   824      self.assertEqual(
   825          self.get_output(result_path), ['Flourish: 3', 'stomach: 1'])
   826  
  @unittest.skipIf(pubsub is None, 'GCP dependencies are not installed')
  @mock.patch('apache_beam.io.ReadFromPubSub')
  @mock.patch('apache_beam.io.WriteToPubSub')
  def test_examples_wordcount_streaming(self, *unused_mocks):
    """Runs the streaming wordcount snippet against faked Pub/Sub I/O.

    The ``mock.patch`` decorators patch ``apache_beam.io.ReadFromPubSub`` and
    ``apache_beam.io.WriteToPubSub`` for the duration of the test. The body
    then overwrites those same attributes with the fakes defined below, and
    the patchers restore the originals when the test exits.

    The snippet runs twice: first reading from an input topic, then reading
    from an input subscription.
    """
    def FakeReadFromPubSub(topic=None, subscription=None, values=None):
      # Builds a stand-in for ReadFromPubSub. The stand-in asserts it receives
      # the expected topic/subscription, then replays ``values`` through a
      # TestStream.
      expected_topic = topic
      expected_subscription = subscription

      def _inner(topic=None, subscription=None):
        assert topic == expected_topic
        assert subscription == expected_subscription
        return TestStream().add_elements(values)

      return _inner

    class AssertTransform(beam.PTransform):
      # Sink stand-in: asserts that the incoming PCollection satisfies
      # ``matcher`` instead of writing anywhere.
      def __init__(self, matcher):
        self.matcher = matcher

      def expand(self, pcoll):
        assert_that(pcoll, self.matcher)

    def FakeWriteToPubSub(topic=None, values=None):
      # Builds a stand-in for WriteToPubSub. The stand-in asserts it receives
      # the expected output topic, and the written elements must equal
      # ``values``. ``subscription`` is accepted by ``_inner`` but never
      # checked.
      expected_topic = topic

      def _inner(topic=None, subscription=None):
        assert topic == expected_topic
        return AssertTransform(equal_to(values))

      return _inner

    # Test basic execution.
    input_topic = 'projects/fake-beam-test-project/topic/intopic'
    # Input arrives as UTF-8 encoded bytes with explicit event timestamps.
    input_values = [
        TimestampedValue(b'a a b', 1),
        TimestampedValue(u'🤷 ¯\\_(ツ)_/¯ b b '.encode('utf-8'), 12),
        TimestampedValue(b'a b c c c', 20)
    ]
    output_topic = 'projects/fake-beam-test-project/topic/outtopic'
    output_values = [b'a: 1', b'a: 2', b'b: 1', b'b: 3', b'c: 3']
    beam.io.ReadFromPubSub = (
        FakeReadFromPubSub(topic=input_topic, values=input_values))
    beam.io.WriteToPubSub = (
        FakeWriteToPubSub(topic=output_topic, values=output_values))
    test_argv = [
        'unused_argv[0]',
        '--input_topic',
        'projects/fake-beam-test-project/topic/intopic',
        '--output_topic',
        'projects/fake-beam-test-project/topic/outtopic'
    ]
    with mock.patch.object(sys, 'argv', test_argv):
      snippets.examples_wordcount_streaming()

    # Test with custom subscription.
    # Only the reader is swapped. The WriteToPubSub fake from the first run
    # stays in place and expects the same output.
    input_sub = 'projects/fake-beam-test-project/subscriptions/insub'
    beam.io.ReadFromPubSub = FakeReadFromPubSub(
        subscription=input_sub, values=input_values)
    test_argv = [
        'unused_argv[0]',
        '--input_subscription',
        'projects/fake-beam-test-project/subscriptions/insub',
        '--output_topic',
        'projects/fake-beam-test-project/topic/outtopic'
    ]
    with mock.patch.object(sys, 'argv', test_argv):
      snippets.examples_wordcount_streaming()
   894  
   895    def test_model_composite_transform_example(self):
   896      contents = ['aa bb cc', 'bb cc', 'cc']
   897      result_path = self.create_temp_file()
   898      snippets.model_composite_transform_example(contents, result_path)
   899      self.assertEqual(['aa: 1', 'bb: 2', 'cc: 3'], self.get_output(result_path))
   900  
   901    def test_model_multiple_pcollections_flatten(self):
   902      contents = ['a', 'b', 'c', 'd', 'e', 'f']
   903      result_path = self.create_temp_file()
   904      snippets.model_multiple_pcollections_flatten(contents, result_path)
   905      self.assertEqual(contents, self.get_output(result_path))
   906  
   907    def test_model_multiple_pcollections_partition(self):
   908      contents = [17, 42, 64, 32, 0, 99, 53, 89]
   909      result_path = self.create_temp_file()
   910      snippets.model_multiple_pcollections_partition(contents, result_path)
   911      self.assertEqual(['0', '17', '32', '42', '53', '64', '89', '99'],
   912                       self.get_output(result_path))
   913  
   914    def test_model_group_by_key(self):
   915      contents = ['a bb ccc bb bb a']
   916      result_path = self.create_temp_file()
   917      snippets.model_group_by_key(contents, result_path)
   918      expected = [('a', 2), ('bb', 3), ('ccc', 1)]
   919      self.assertEqual([str(s) for s in expected], self.get_output(result_path))
   920  
  def test_model_co_group_by_key_tuple(self):
    """CoGroupByKey joins per-name emails and phones from two PCollections.

    The input lists and both expected-output lists are published as doc
    snippets, so they are spelled out literally and then cross-checked.
    """
    with TestPipeline() as p:
      # [START model_group_by_key_cogroupbykey_tuple_inputs]
      emails_list = [
          ('amy', 'amy@example.com'),
          ('carl', 'carl@example.com'),
          ('julia', 'julia@example.com'),
          ('carl', 'carl@email.com'),
      ]
      phones_list = [
          ('amy', '111-222-3333'),
          ('james', '222-333-4444'),
          ('amy', '333-444-5555'),
          ('carl', '444-555-6666'),
      ]

      emails = p | 'CreateEmails' >> beam.Create(emails_list)
      phones = p | 'CreatePhones' >> beam.Create(phones_list)
      # [END model_group_by_key_cogroupbykey_tuple_inputs]

      # The snippet writes its formatted join results to result_path; the
      # pipeline runs when this ``with`` block exits.
      result_path = self.create_temp_file()
      snippets.model_co_group_by_key_tuple(emails, phones, result_path)

    # [START model_group_by_key_cogroupbykey_tuple_outputs]
    results = [
        (
            'amy',
            {
                'emails': ['amy@example.com'],
                'phones': ['111-222-3333', '333-444-5555']
            }),
        (
            'carl',
            {
                'emails': ['carl@email.com', 'carl@example.com'],
                'phones': ['444-555-6666']
            }),
        ('james', {
            'emails': [], 'phones': ['222-333-4444']
        }),
        ('julia', {
            'emails': ['julia@example.com'], 'phones': []
        }),
    ]
    # [END model_group_by_key_cogroupbykey_tuple_outputs]
    # [START model_group_by_key_cogroupbykey_tuple_formatted_outputs]
    formatted_results = [
        "amy; ['amy@example.com']; ['111-222-3333', '333-444-5555']",
        "carl; ['carl@email.com', 'carl@example.com']; ['444-555-6666']",
        "james; []; ['222-333-4444']",
        "julia; ['julia@example.com']; []",
    ]
    # [END model_group_by_key_cogroupbykey_tuple_formatted_outputs]
    # Check that the published formatted outputs agree with the published
    # grouped results, then compare them with what the pipeline wrote.
    expected_results = [
        '%s; %s; %s' % (name, info['emails'], info['phones']) for name,
        info in results
    ]
    self.assertEqual(expected_results, formatted_results)
    self.assertEqual(formatted_results, self.get_output(result_path))
   980  
   981    def test_model_use_and_query_metrics(self):
   982      """DebuggingWordCount example snippets."""
   983  
   984      import re
   985  
   986      p = TestPipeline()  # Use TestPipeline for testing.
   987      words = p | beam.Create(
   988          ['albert', 'sam', 'mark', 'sarah', 'swati', 'daniel', 'andrea'])
   989  
   990      # pylint: disable=unused-variable
   991      # [START metrics_usage_example]
   992      class FilterTextFn(beam.DoFn):
   993        """A DoFn that filters for a specific key based on a regex."""
   994        def __init__(self, pattern):
   995          self.pattern = pattern
   996          # A custom metric can track values in your pipeline as it runs. Create
   997          # custom metrics to count unmatched words, and know the distribution of
   998          # word lengths in the input PCollection.
   999          self.word_len_dist = Metrics.distribution(
  1000              self.__class__, 'word_len_dist')
  1001          self.unmatched_words = Metrics.counter(
  1002              self.__class__, 'unmatched_words')
  1003  
  1004        def process(self, element):
  1005          word = element
  1006          self.word_len_dist.update(len(word))
  1007          if re.match(self.pattern, word):
  1008            yield element
  1009          else:
  1010            self.unmatched_words.inc()
  1011  
  1012      filtered_words = (words | 'FilterText' >> beam.ParDo(FilterTextFn('s.*')))
  1013      # [END metrics_usage_example]
  1014      # pylint: enable=unused-variable
  1015  
  1016      # [START metrics_check_values_example]
  1017      result = p.run()
  1018      result.wait_until_finish()
  1019  
  1020      custom_distribution = result.metrics().query(
  1021          MetricsFilter().with_name('word_len_dist'))['distributions']
  1022      custom_counter = result.metrics().query(
  1023          MetricsFilter().with_name('unmatched_words'))['counters']
  1024  
  1025      if custom_distribution:
  1026        logging.info(
  1027            'The average word length was %d',
  1028            custom_distribution[0].committed.mean)
  1029      if custom_counter:
  1030        logging.info(
  1031            'There were %d words that did not match the filter.',
  1032            custom_counter[0].committed)
  1033      # [END metrics_check_values_example]
  1034  
  1035      # There should be 4 words that did not match
  1036      self.assertEqual(custom_counter[0].committed, 4)
  1037      # The shortest word is 3 characters, the longest is 6
  1038      self.assertEqual(custom_distribution[0].committed.min, 3)
  1039      self.assertEqual(custom_distribution[0].committed.max, 6)
  1040  
  def test_model_join_using_side_inputs(self):
    """Joins names with emails and phones supplied as side inputs."""
    names = ['a', 'b']
    emails = [['a', 'a@example.com'], ['b', 'b@example.com']]
    phones = [['a', 'x4312'], ['b', 'x8452']]
    result_path = self.create_temp_file()
    snippets.model_join_using_side_inputs(names, emails, phones, result_path)
    self.assertEqual(['a; a@example.com; x4312', 'b; b@example.com; x8452'],
                     self.get_output(result_path))
  1050  
  def test_model_early_late_triggers(self):
    """AfterWatermark with early and late firings, in streaming mode.

    Elements are grouped into 15-second fixed windows with 20 seconds of
    allowed lateness, using DISCARDING accumulation. The expected output has
    ('a', 4) and ('b', 2), plus a separate ('a', 1) from the element added
    after the watermark advances to 20.
    """
    pipeline_options = PipelineOptions()
    pipeline_options.view_as(StandardOptions).streaming = True

    with TestPipeline(options=pipeline_options) as p:
      test_stream = (
          TestStream().advance_watermark_to(10).add_elements([
              'a', 'a', 'a', 'b', 'b'
          ]).add_elements([
              TimestampedValue('a', 10)
          ]).advance_watermark_to(20).advance_processing_time(60).add_elements(
              [TimestampedValue('a', 10)]))
      # The trigger expression below is published as a doc snippet.
      trigger = (
          # [START model_early_late_triggers]
          AfterWatermark(
              early=AfterProcessingTime(delay=1 * 60), late=AfterCount(1))
          # [END model_early_late_triggers]
      )
      counts = (
          p
          | test_stream
          | 'pair_with_one' >> beam.Map(lambda x: (x, 1))
          | WindowInto(
              FixedWindows(15),
              trigger=trigger,
              allowed_lateness=20,
              accumulation_mode=AccumulationMode.DISCARDING)
          | 'group' >> beam.GroupByKey()
          | 'count' >>
          beam.Map(lambda word_ones: (word_ones[0], sum(word_ones[1]))))
      assert_that(counts, equal_to([('a', 4), ('b', 2), ('a', 1)]))
  1082  
  def test_model_setting_trigger(self):
    """A 10-minute processing-time trigger on 1-minute fixed windows.

    Processing time is advanced by 600 seconds, matching the trigger's
    ``10 * 60`` delay. Each key is then expected to be emitted once with its
    full count. ``--allow_unsafe_triggers`` is passed, presumably because this
    trigger does not repeat.
    """
    flags = ['--streaming', '--allow_unsafe_triggers']
    with TestPipeline(options=PipelineOptions(flags=flags)) as p:
      events = (
          TestStream()
          .advance_watermark_to(10)
          .add_elements(['a', 'a', 'a', 'b', 'b'])
          .advance_watermark_to(70)
          .advance_processing_time(600))
      windowing = WindowInto(
          FixedWindows(1 * 60),
          trigger=AfterProcessingTime(10 * 60),
          accumulation_mode=AccumulationMode.DISCARDING)
      keyed = p | events | 'pair_with_one' >> beam.Map(lambda x: (x, 1))
      counts = (
          keyed
          | windowing
          | 'group' >> beam.GroupByKey()
          | 'count' >> beam.Map(lambda kv: (kv[0], sum(kv[1]))))
      assert_that(counts, equal_to([('a', 3), ('b', 2)]))
  1106  
  def test_model_composite_triggers(self):
    """AfterWatermark with a late processing-time firing, in streaming mode.

    Elements timestamped 10 are added after the watermark has moved to 70,
    i.e. past the end of the first 60-second window. The expected output
    includes separate ('a', 2) and ('c', 2) results for those late elements.
    """
    pipeline_options = PipelineOptions()
    pipeline_options.view_as(StandardOptions).streaming = True

    with TestPipeline(options=pipeline_options) as p:
      test_stream = (
          TestStream().advance_watermark_to(10).add_elements(
              ['a', 'a', 'a', 'b', 'b']).advance_watermark_to(70).add_elements([
                  TimestampedValue('a', 10),
                  TimestampedValue('a', 10),
                  TimestampedValue('c', 10),
                  TimestampedValue('c', 10)
              ]).advance_processing_time(600))
      pcollection = (
          p
          | test_stream
          | 'pair_with_one' >> beam.Map(lambda x: (x, 1)))

      # The WindowInto application below is published as a doc snippet.
      counts = (
          # [START model_composite_triggers]
          pcollection | WindowInto(
              FixedWindows(1 * 60),
              trigger=AfterWatermark(late=AfterProcessingTime(10 * 60)),
              allowed_lateness=10,
              accumulation_mode=AccumulationMode.DISCARDING)
          # [END model_composite_triggers]
          | 'group' >> beam.GroupByKey()
          | 'count' >>
          beam.Map(lambda word_ones: (word_ones[0], sum(word_ones[1]))))
      assert_that(counts, equal_to([('a', 3), ('b', 2), ('a', 2), ('c', 2)]))
  1137  
  def test_model_other_composite_triggers(self):
    """Repeatedly(AfterAny(AfterCount(100), AfterProcessingTime(60))).

    The first five elements arrive and processing time advances 60 seconds.
    Then 100 more 'a' elements arrive. The expected output has ('a', 3) and
    ('b', 2) from the first group, then a separate ('a', 100).
    """
    pipeline_options = PipelineOptions(
        flags=['--streaming', '--allow_unsafe_triggers'])

    with TestPipeline(options=pipeline_options) as p:
      test_stream = (
          TestStream().advance_watermark_to(10).add_elements(
              ['a', 'a']).add_elements(
                  ['a', 'b',
                   'b']).advance_processing_time(60).add_elements(['a'] * 100))
      pcollection = (
          p
          | test_stream
          | 'pair_with_one' >> beam.Map(lambda x: (x, 1)))

      # The WindowInto application below is published as a doc snippet.
      counts = (
          # [START model_other_composite_triggers]
          pcollection | WindowInto(
              FixedWindows(1 * 60),
              trigger=Repeatedly(
                  AfterAny(AfterCount(100), AfterProcessingTime(1 * 60))),
              accumulation_mode=AccumulationMode.DISCARDING)
          # [END model_other_composite_triggers]
          | 'group' >> beam.GroupByKey()
          | 'count' >>
          beam.Map(lambda word_ones: (word_ones[0], sum(word_ones[1]))))
      assert_that(counts, equal_to([('a', 3), ('b', 2), ('a', 100)]))
  1165  
  1166  
  1167  class CombineTest(unittest.TestCase):
  1168    """Tests for model/combine."""
  1169    def test_global_sum(self):
  1170      pc = [1, 2, 3]
  1171      # [START global_sum]
  1172      result = pc | beam.CombineGlobally(sum)
  1173      # [END global_sum]
  1174      self.assertEqual([6], result)
  1175  
  1176    def test_combine_values(self):
  1177      occurences = [('cat', 1), ('cat', 5), ('cat', 9), ('dog', 5), ('dog', 2)]
  1178      # [START combine_values]
  1179      first_occurences = occurences | beam.GroupByKey() | beam.CombineValues(min)
  1180      # [END combine_values]
  1181      self.assertEqual({('cat', 1), ('dog', 2)}, set(first_occurences))
  1182  
  1183    def test_combine_per_key(self):
  1184      player_accuracies = [('cat', 1), ('cat', 5), ('cat', 9), ('cat', 1),
  1185                           ('dog', 5), ('dog', 2)]
  1186      # [START combine_per_key]
  1187      avg_accuracy_per_player = (
  1188          player_accuracies
  1189          | beam.CombinePerKey(beam.combiners.MeanCombineFn()))
  1190      # [END combine_per_key]
  1191      self.assertEqual({('cat', 4.0), ('dog', 3.5)}, set(avg_accuracy_per_player))
  1192  
  1193    def test_combine_concat(self):
  1194      pc = ['a', 'b']
  1195  
  1196      # [START combine_concat]
  1197      def concat(values, separator=', '):
  1198        return separator.join(values)
  1199  
  1200      with_commas = pc | beam.CombineGlobally(concat)
  1201      with_dashes = pc | beam.CombineGlobally(concat, separator='-')
  1202      # [END combine_concat]
  1203      self.assertEqual(1, len(with_commas))
  1204      self.assertTrue(with_commas[0] in {'a, b', 'b, a'})
  1205      self.assertEqual(1, len(with_dashes))
  1206      self.assertTrue(with_dashes[0] in {'a-b', 'b-a'})
  1207  
  1208    def test_bounded_sum(self):
  1209      # [START combine_bounded_sum]
  1210      pc = [1, 10, 100, 1000]
  1211  
  1212      def bounded_sum(values, bound=500):
  1213        return min(sum(values), bound)
  1214  
  1215      small_sum = pc | beam.CombineGlobally(bounded_sum)  # [500]
  1216      large_sum = pc | beam.CombineGlobally(bounded_sum, bound=5000)  # [1111]
  1217      # [END combine_bounded_sum]
  1218      self.assertEqual([500], small_sum)
  1219      self.assertEqual([1111], large_sum)
  1220  
  1221    def test_combine_reduce(self):
  1222      factors = [2, 3, 5, 7]
  1223      # [START combine_reduce]
  1224      import functools
  1225      import operator
  1226      product = factors | beam.CombineGlobally(
  1227          functools.partial(functools.reduce, operator.mul), 1)
  1228      # [END combine_reduce]
  1229      self.assertEqual([210], product)
  1230  
  1231    def test_custom_average(self):
  1232      pc = [2, 3, 5, 7]
  1233  
  1234      # [START combine_custom_average_define]
  1235      class AverageFn(beam.CombineFn):
  1236        def create_accumulator(self):
  1237          return (0.0, 0)
  1238  
  1239        def add_input(self, sum_count, input):
  1240          (sum, count) = sum_count
  1241          return sum + input, count + 1
  1242  
  1243        def merge_accumulators(self, accumulators):
  1244          sums, counts = zip(*accumulators)
  1245          return sum(sums), sum(counts)
  1246  
  1247        def extract_output(self, sum_count):
  1248          (sum, count) = sum_count
  1249          return sum / count if count else float('NaN')
  1250  
  1251      # [END combine_custom_average_define]
  1252      # [START combine_custom_average_execute]
  1253      average = pc | beam.CombineGlobally(AverageFn())
  1254      # [END combine_custom_average_execute]
  1255      self.assertEqual([4.25], average)
  1256  
  1257    def test_keys(self):
  1258      occurrences = [('cat', 1), ('cat', 5), ('dog', 5), ('cat', 9), ('dog', 2)]
  1259      unique_keys = occurrences | snippets.Keys()
  1260      self.assertEqual({'cat', 'dog'}, set(unique_keys))
  1261  
  1262    def test_count(self):
  1263      occurrences = ['cat', 'dog', 'cat', 'cat', 'dog']
  1264      perkey_counts = occurrences | snippets.Count()
  1265      self.assertEqual({('cat', 3), ('dog', 2)}, set(perkey_counts))
  1266  
  1267    def test_setting_fixed_windows(self):
  1268      with TestPipeline() as p:
  1269        unkeyed_items = p | beam.Create([22, 33, 55, 100, 115, 120])
  1270        items = (
  1271            unkeyed_items
  1272            | 'key' >>
  1273            beam.Map(lambda x: beam.window.TimestampedValue(('k', x), x)))
  1274        # [START setting_fixed_windows]
  1275        from apache_beam import window
  1276        fixed_windowed_items = (
  1277            items | 'window' >> beam.WindowInto(window.FixedWindows(60)))
  1278        # [END setting_fixed_windows]
  1279        summed = (
  1280            fixed_windowed_items
  1281            | 'group' >> beam.GroupByKey()
  1282            | 'combine' >> beam.CombineValues(sum))
  1283        unkeyed = summed | 'unkey' >> beam.Map(lambda x: x[1])
  1284        assert_that(unkeyed, equal_to([110, 215, 120]))
  1285  
  1286    def test_setting_sliding_windows(self):
  1287      with TestPipeline() as p:
  1288        unkeyed_items = p | beam.Create([2, 16, 23])
  1289        items = (
  1290            unkeyed_items
  1291            | 'key' >>
  1292            beam.Map(lambda x: beam.window.TimestampedValue(('k', x), x)))
  1293        # [START setting_sliding_windows]
  1294        from apache_beam import window
  1295        sliding_windowed_items = (
  1296            items | 'window' >> beam.WindowInto(window.SlidingWindows(30, 5)))
  1297        # [END setting_sliding_windows]
  1298        summed = (
  1299            sliding_windowed_items
  1300            | 'group' >> beam.GroupByKey()
  1301            | 'combine' >> beam.CombineValues(sum))
  1302        unkeyed = summed | 'unkey' >> beam.Map(lambda x: x[1])
  1303        assert_that(unkeyed, equal_to([2, 2, 2, 18, 23, 39, 39, 39, 41, 41]))
  1304  
  1305    def test_setting_session_windows(self):
  1306      with TestPipeline() as p:
  1307        unkeyed_items = p | beam.Create([2, 11, 16, 27])
  1308        items = (
  1309            unkeyed_items
  1310            | 'key' >>
  1311            beam.Map(lambda x: beam.window.TimestampedValue(('k', x), x * 60)))
  1312        # [START setting_session_windows]
  1313        from apache_beam import window
  1314        session_windowed_items = (
  1315            items | 'window' >> beam.WindowInto(window.Sessions(10 * 60)))
  1316        # [END setting_session_windows]
  1317        summed = (
  1318            session_windowed_items
  1319            | 'group' >> beam.GroupByKey()
  1320            | 'combine' >> beam.CombineValues(sum))
  1321        unkeyed = summed | 'unkey' >> beam.Map(lambda x: x[1])
  1322        assert_that(unkeyed, equal_to([29, 27]))
  1323  
  1324    def test_setting_global_window(self):
  1325      with TestPipeline() as p:
  1326        unkeyed_items = p | beam.Create([2, 11, 16, 27])
  1327        items = (
  1328            unkeyed_items
  1329            | 'key' >>
  1330            beam.Map(lambda x: beam.window.TimestampedValue(('k', x), x)))
  1331        # [START setting_global_window]
  1332        from apache_beam import window
  1333        global_windowed_items = (
  1334            items | 'window' >> beam.WindowInto(window.GlobalWindows()))
  1335        # [END setting_global_window]
  1336        summed = (
  1337            global_windowed_items
  1338            | 'group' >> beam.GroupByKey()
  1339            | 'combine' >> beam.CombineValues(sum))
  1340        unkeyed = summed | 'unkey' >> beam.Map(lambda x: x[1])
  1341        assert_that(unkeyed, equal_to([56]))
  1342  
  1343    def test_setting_timestamp(self):
  1344      with TestPipeline() as p:
  1345        unkeyed_items = p | beam.Create([12, 30, 60, 61, 66])
  1346        items = (unkeyed_items | 'key' >> beam.Map(lambda x: ('k', x)))
  1347  
  1348        def extract_timestamp_from_log_entry(entry):
  1349          return entry[1]
  1350  
  1351        # [START setting_timestamp]
  1352        class AddTimestampDoFn(beam.DoFn):
  1353          def process(self, element):
  1354            # Extract the numeric Unix seconds-since-epoch timestamp to be
  1355            # associated with the current log entry.
  1356            unix_timestamp = extract_timestamp_from_log_entry(element)
  1357            # Wrap and emit the current entry and new timestamp in a
  1358            # TimestampedValue.
  1359            yield beam.window.TimestampedValue(element, unix_timestamp)
  1360  
  1361        timestamped_items = items | 'timestamp' >> beam.ParDo(AddTimestampDoFn())
  1362        # [END setting_timestamp]
  1363        fixed_windowed_items = (
  1364            timestamped_items
  1365            | 'window' >> beam.WindowInto(beam.window.FixedWindows(60)))
  1366        summed = (
  1367            fixed_windowed_items
  1368            | 'group' >> beam.GroupByKey()
  1369            | 'combine' >> beam.CombineValues(sum))
  1370        unkeyed = summed | 'unkey' >> beam.Map(lambda x: x[1])
  1371        assert_that(unkeyed, equal_to([42, 187]))
  1372  
  1373  
class PTransformTest(unittest.TestCase):
  """Tests for composite PTransform snippets."""
  def test_composite(self):
    """Applies the ComputeWordLengths composite transform from the snippet."""

    # [START model_composite_transform]
    class ComputeWordLengths(beam.PTransform):
      def expand(self, pcoll):
        # Transform logic goes here.
        return pcoll | beam.Map(lambda x: len(x))

    # [END model_composite_transform]

    # The composite transform maps each string to its length.
    with TestPipeline() as p:
      lengths = p | beam.Create(["a", "ab", "abc"]) | ComputeWordLengths()
      assert_that(lengths, equal_to([1, 2, 3]))
  1389  
  1390  
  1391  class SlowlyChangingSideInputsTest(unittest.TestCase):
  1392    """Tests for PTransform."""
  1393    def test_side_input_slow_update(self):
  1394      temp_file = tempfile.NamedTemporaryFile(delete=True)
  1395      src_file_pattern = temp_file.name
  1396      temp_file.close()
  1397  
  1398      first_ts = math.floor(time.time()) - 30
  1399      interval = 5
  1400      main_input_windowing_interval = 7
  1401  
  1402      # aligning timestamp to get persistent results
  1403      first_ts = first_ts - (
  1404          first_ts % (interval * main_input_windowing_interval))
  1405      last_ts = first_ts + 45
  1406  
  1407      for i in range(-1, 10, 1):
  1408        count = i + 2
  1409        idstr = str(first_ts + interval * i)
  1410        with open(src_file_pattern + idstr, "w") as f:
  1411          for j in range(count):
  1412            f.write('f' + idstr + 'a' + str(j) + '\n')
  1413  
  1414      sample_main_input_elements = ([first_ts - 2, # no output due to no SI
  1415                                     first_ts + 1,  # First window
  1416                                     first_ts + 8,  # Second window
  1417                                     first_ts + 15,  # Third window
  1418                                     first_ts + 22,  # Fourth window
  1419                                     ])
  1420  
  1421      pipeline, pipeline_result = snippets.side_input_slow_update(
  1422        src_file_pattern, first_ts, last_ts, interval,
  1423        sample_main_input_elements, main_input_windowing_interval)
  1424  
  1425      try:
  1426        with pipeline:
  1427          pipeline_result = (
  1428              pipeline_result
  1429              | 'AddKey' >> beam.Map(lambda v: ('key', v))
  1430              | combiners.Count.PerKey())
  1431  
  1432          assert_that(
  1433              pipeline_result,
  1434              equal_to([('key', 3), ('key', 4), ('key', 6), ('key', 7)]))
  1435      finally:
  1436        for i in range(-1, 10, 1):
  1437          os.unlink(src_file_pattern + str(first_ts + interval * i))
  1438  
  1439  
  1440  if __name__ == '__main__':
  1441    logging.getLogger().setLevel(logging.INFO)
  1442    unittest.main()