#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Tests for read_cache."""
# pytype: skip-file

import unittest
from unittest.mock import patch

import apache_beam as beam
from apache_beam.runners.interactive import augmented_pipeline as ap
from apache_beam.runners.interactive import interactive_beam as ib
from apache_beam.runners.interactive import interactive_environment as ie
from apache_beam.runners.interactive.caching import read_cache
from apache_beam.runners.interactive.testing.pipeline_assertion import assert_pipeline_proto_equal
from apache_beam.runners.interactive.testing.test_cache_manager import InMemoryCache


class ReadCacheTest(unittest.TestCase):
  """Tests for read_cache.ReadCache and read_cache._ReadCacheTransform."""
  def setUp(self):
    # Start every test from a fresh interactive environment.
    ie.new_env()

  @patch(
      'apache_beam.runners.interactive.interactive_environment'
      '.InteractiveEnvironment.get_cache_manager')
  def test_read_cache(self, mocked_get_cache_manager):
    """ReadCache on a proto matches applying _ReadCacheTransform directly.

    Also checks that, after ReadCache rewrites the pipeline proto, the
    transform that consumed the original PCollection now reads from the
    cache output instead.
    """
    p = beam.Pipeline()
    pcoll = p | beam.Create([1, 2, 3])
    consumer_transform = beam.Map(lambda x: x * x)
    _ = pcoll | consumer_transform
    ib.watch(locals())

    # Create the cache in memory. The patched get_cache_manager returns it,
    # presumably so the AugmentedPipeline built below picks it up.
    cache_manager = InMemoryCache()
    mocked_get_cache_manager.return_value = cache_manager
    aug_p = ap.AugmentedPipeline(p)
    key = repr(aug_p._cacheables[pcoll].to_key())
    cache_manager.write('test', 'full', key)

    # Capture the applied transform of the consumer_transform: the first
    # transform in the proto that takes pcoll as one of its inputs.
    pcoll_id = aug_p._context.pcollections.get_id(pcoll)
    pipeline_proto = p.to_runner_api()
    transforms = pipeline_proto.components.transforms
    consumer_transform_id = next(
        (transform_id for transform_id, transform in transforms.items()
         if pcoll_id in transform.inputs.values()),
        None)
    self.assertIsNotNone(consumer_transform_id)

    # Read cache on the pipeline proto. ReadCache modifies pipeline_proto in
    # place; the assertions below rely on that.
    _, cache_id = read_cache.ReadCache(
        pipeline_proto, aug_p._context, aug_p._cache_manager,
        aug_p._cacheables[pcoll]).read_cache()
    actual_pipeline = pipeline_proto

    # Read cache directly on the pipeline instance to build the expected
    # proto.
    transform = read_cache._ReadCacheTransform(aug_p._cache_manager, key)
    _ = p | 'source_cache_' + key >> transform
    expected_pipeline = p.to_runner_api()

    # This roughly checks the equivalence between two protos, not detailed
    # wiring in sub transforms under top level transforms.
    assert_pipeline_proto_equal(self, expected_pipeline, actual_pipeline)

    # Check if the actual_pipeline uses cache as input of the
    # consumer_transform instead of the original pcoll from source.
    inputs = actual_pipeline.components.transforms[consumer_transform_id].inputs
    self.assertIn(cache_id, inputs.values())
    self.assertNotIn(pcoll_id, inputs.values())


if __name__ == '__main__':
  unittest.main()