#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest

import apache_beam as beam
from apache_beam.dataframe import expressions
from apache_beam.runners.interactive.caching.expression_cache import ExpressionCache


class ExpressionCacheTest(unittest.TestCase):
  """Tests for ExpressionCache.replace_with_cached."""
  def setUp(self):
    """Create empty backing caches, a pipeline, and the cache under test."""
    self._pcollection_cache = {}
    self._computed_cache = set()
    self._pipeline = beam.Pipeline()
    self.cache = ExpressionCache(self._pcollection_cache, self._computed_cache)

  def create_trace(self, expr):
    """Return `expr` followed by the traces of its args, in pre-order.

    An expression that is reachable through several parents appears once per
    path, so the result may contain duplicates.
    """
    trace = [expr]
    # `arg` rather than `input`, so the builtin is not shadowed.
    for arg in expr.args():
      trace += self.create_trace(arg)
    return trace

  def mock_cache(self, expr):
    """Register a new PCollection for `expr` in both backing caches.

    The PCollection is stored in the PCollection cache under `expr._id` and
    added to the computed-cache set.
    """
    pcoll = beam.PCollection(self._pipeline)
    self._pcollection_cache[expr._id] = pcoll
    self._computed_cache.add(pcoll)

  def assertTraceTypes(self, expr, expected):
    """Assert that the class names in `expr`'s trace match `expected`.

    Args:
      expr: root expression whose trace is inspected.
      expected: list of expression classes, in trace (pre-order) order.
    """
    actual_types = [type(e).__name__ for e in self.create_trace(expr)]
    expected_types = [e.__name__ for e in expected]
    self.assertListEqual(actual_types, expected_types)

  def test_only_replaces_cached(self):
    """Inputs are replaced only after they have been mock-cached."""
    in_expr = expressions.ConstantExpression(0)
    comp_expr = expressions.ComputedExpression('test', lambda x: x, [in_expr])

    # Expect that no replacement of expressions is performed.
    expected_trace = [
        expressions.ComputedExpression, expressions.ConstantExpression
    ]
    self.assertTraceTypes(comp_expr, expected_trace)

    self.cache.replace_with_cached(comp_expr)

    self.assertTraceTypes(comp_expr, expected_trace)

    # Now "cache" the expression and assert that the cached expression was
    # replaced with a placeholder.
    self.mock_cache(in_expr)

    replaced = self.cache.replace_with_cached(comp_expr)

    expected_trace = [
        expressions.ComputedExpression, expressions.PlaceholderExpression
    ]
    self.assertTraceTypes(comp_expr, expected_trace)
    self.assertIn(in_expr._id, replaced)

  def test_only_replaces_inputs(self):
    """A cached intermediate expression is replaced along with its subtree."""
    arg_0_expr = expressions.ConstantExpression(0)
    ident_val = expressions.ComputedExpression(
        'ident', lambda x: x, [arg_0_expr])

    arg_1_expr = expressions.ConstantExpression(1)
    comp_expr = expressions.ComputedExpression(
        'add', lambda x, y: x + y, [ident_val, arg_1_expr])

    self.mock_cache(ident_val)

    replaced = self.cache.replace_with_cached(comp_expr)

    # Assert that ident_val was replaced and that its arguments were removed
    # from the expression tree.
    expected_trace = [
        expressions.ComputedExpression,
        expressions.PlaceholderExpression,
        expressions.ConstantExpression
    ]
    self.assertTraceTypes(comp_expr, expected_trace)
    self.assertIn(ident_val._id, replaced)
    self.assertNotIn(arg_0_expr, self.create_trace(comp_expr))

  def test_only_caches_same_input(self):
    """An input shared by two computations maps to one placeholder."""
    arg_0_expr = expressions.ConstantExpression(0)
    ident_val = expressions.ComputedExpression(
        'ident', lambda x: x, [arg_0_expr])
    comp_expr = expressions.ComputedExpression(
        'add', lambda x, y: x + y, [ident_val, arg_0_expr])

    self.mock_cache(arg_0_expr)

    replaced = self.cache.replace_with_cached(comp_expr)

    # Assert that arg_0_expr, being an input to two computations, was replaced
    # with the same placeholder expression.
    expected_trace = [
        expressions.ComputedExpression,
        expressions.ComputedExpression,
        expressions.PlaceholderExpression,
        expressions.PlaceholderExpression
    ]
    self.assertTraceTypes(comp_expr, expected_trace)

    # Check membership before indexing so that a missing entry is reported as
    # an assertion failure rather than a bare KeyError.
    self.assertIn(arg_0_expr._id, replaced)
    expected_placeholder = replaced[arg_0_expr._id]

    unique_placeholders = {
        e
        for e in self.create_trace(comp_expr)
        if isinstance(e, expressions.PlaceholderExpression)
    }
    for placeholder in unique_placeholders:
      self.assertEqual(placeholder, expected_placeholder)


if __name__ == '__main__':
  unittest.main()