github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/interactive/caching/expression_cache_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  import unittest
    19  
    20  import apache_beam as beam
    21  from apache_beam.dataframe import expressions
    22  from apache_beam.runners.interactive.caching.expression_cache import ExpressionCache
    23  
    24  
    25  class ExpressionCacheTest(unittest.TestCase):
    26    def setUp(self):
    27      self._pcollection_cache = {}
    28      self._computed_cache = set()
    29      self._pipeline = beam.Pipeline()
    30      self.cache = ExpressionCache(self._pcollection_cache, self._computed_cache)
    31  
    32    def create_trace(self, expr):
    33      trace = [expr]
    34      for input in expr.args():
    35        trace += self.create_trace(input)
    36      return trace
    37  
    38    def mock_cache(self, expr):
    39      pcoll = beam.PCollection(self._pipeline)
    40      self._pcollection_cache[expr._id] = pcoll
    41      self._computed_cache.add(pcoll)
    42  
    43    def assertTraceTypes(self, expr, expected):
    44      actual_types = [type(e).__name__ for e in self.create_trace(expr)]
    45      expected_types = [e.__name__ for e in expected]
    46      self.assertListEqual(actual_types, expected_types)
    47  
    48    def test_only_replaces_cached(self):
    49      in_expr = expressions.ConstantExpression(0)
    50      comp_expr = expressions.ComputedExpression('test', lambda x: x, [in_expr])
    51  
    52      # Expect that no replacement of expressions is performed.
    53      expected_trace = [
    54          expressions.ComputedExpression, expressions.ConstantExpression
    55      ]
    56      self.assertTraceTypes(comp_expr, expected_trace)
    57  
    58      self.cache.replace_with_cached(comp_expr)
    59  
    60      self.assertTraceTypes(comp_expr, expected_trace)
    61  
    62      # Now "cache" the expression and assert that the cached expression was
    63      # replaced with a placeholder.
    64      self.mock_cache(in_expr)
    65  
    66      replaced = self.cache.replace_with_cached(comp_expr)
    67  
    68      expected_trace = [
    69          expressions.ComputedExpression, expressions.PlaceholderExpression
    70      ]
    71      self.assertTraceTypes(comp_expr, expected_trace)
    72      self.assertIn(in_expr._id, replaced)
    73  
    74    def test_only_replaces_inputs(self):
    75      arg_0_expr = expressions.ConstantExpression(0)
    76      ident_val = expressions.ComputedExpression(
    77          'ident', lambda x: x, [arg_0_expr])
    78  
    79      arg_1_expr = expressions.ConstantExpression(1)
    80      comp_expr = expressions.ComputedExpression(
    81          'add', lambda x, y: x + y, [ident_val, arg_1_expr])
    82  
    83      self.mock_cache(ident_val)
    84  
    85      replaced = self.cache.replace_with_cached(comp_expr)
    86  
    87      # Assert that ident_val was replaced and that its arguments were removed
    88      # from the expression tree.
    89      expected_trace = [
    90          expressions.ComputedExpression,
    91          expressions.PlaceholderExpression,
    92          expressions.ConstantExpression
    93      ]
    94      self.assertTraceTypes(comp_expr, expected_trace)
    95      self.assertIn(ident_val._id, replaced)
    96      self.assertNotIn(arg_0_expr, self.create_trace(comp_expr))
    97  
    98    def test_only_caches_same_input(self):
    99      arg_0_expr = expressions.ConstantExpression(0)
   100      ident_val = expressions.ComputedExpression(
   101          'ident', lambda x: x, [arg_0_expr])
   102      comp_expr = expressions.ComputedExpression(
   103          'add', lambda x, y: x + y, [ident_val, arg_0_expr])
   104  
   105      self.mock_cache(arg_0_expr)
   106  
   107      replaced = self.cache.replace_with_cached(comp_expr)
   108  
   109      # Assert that arg_0_expr, being an input to two computations, was replaced
   110      # with the same placeholder expression.
   111      expected_trace = [
   112          expressions.ComputedExpression,
   113          expressions.ComputedExpression,
   114          expressions.PlaceholderExpression,
   115          expressions.PlaceholderExpression
   116      ]
   117      actual_trace = self.create_trace(comp_expr)
   118      unique_placeholders = set(
   119          t for t in actual_trace
   120          if isinstance(t, expressions.PlaceholderExpression))
   121      self.assertTraceTypes(comp_expr, expected_trace)
   122      self.assertTrue(
   123          all(e == replaced[arg_0_expr._id] for e in unique_placeholders))
   124      self.assertIn(arg_0_expr._id, replaced)
   125  
   126  
   127  if __name__ == '__main__':
   128    unittest.main()