# Origin: github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/interactive/caching/write_cache_test.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Tests for write_cache."""
# pytype: skip-file

import unittest
from unittest.mock import patch

import apache_beam as beam
from apache_beam.runners.interactive import augmented_pipeline as ap
from apache_beam.runners.interactive import interactive_beam as ib
from apache_beam.runners.interactive import interactive_environment as ie
from apache_beam.runners.interactive.caching import write_cache
from apache_beam.runners.interactive.testing.pipeline_assertion import assert_pipeline_proto_equal
from apache_beam.runners.interactive.testing.test_cache_manager import InMemoryCache


class WriteCacheTest(unittest.TestCase):
  """Checks that WriteCache rewrites a pipeline proto with a cache sink."""
  def setUp(self):
    """Start every test from a new interactive environment."""
    ie.new_env()

  @patch(
      'apache_beam.runners.interactive.interactive_environment'
      '.InteractiveEnvironment.get_cache_manager')
  def test_write_cache(self, mocked_get_cache_manager):
    """Compare WriteCache's proto rewrite with applying the sink by hand.

    The proto produced by ``WriteCache.write_cache()`` must equal the proto of
    the same pipeline after ``_WriteCacheTransform`` is applied directly to
    ``pcoll``, and some transform consuming ``pcoll`` must have 'sink' in its
    unique name.
    """
    # NOTE: keep these three statements, their order and the local names
    # as-is. ib.watch(locals()) is handed this frame's locals at exactly this
    # point, so adding, renaming or reordering locals changes what it sees.
    p = beam.Pipeline()
    pcoll = p | beam.Create([1, 2, 3])
    ib.watch(locals())

    # The patched get_cache_manager hands out an in-memory cache manager.
    mocked_get_cache_manager.return_value = InMemoryCache()
    augmented = ap.AugmentedPipeline(p)
    key = repr(augmented._cacheables[pcoll].to_key())

    # Actual: snapshot the pipeline as a proto, then let WriteCache operate on
    # that proto.
    actual_pipeline = p.to_runner_api()
    cache_writer = write_cache.WriteCache(
        actual_pipeline,
        augmented._context,
        augmented._cache_manager,
        augmented._cacheables[pcoll])
    cache_writer.write_cache()

    # Expected: apply the write-cache transform directly on the pipeline
    # instance and convert the result to a proto.
    sink_transform = write_cache._WriteCacheTransform(
        augmented._cache_manager, key)
    sink_label = 'sink_cache_' + key
    _ = pcoll | sink_label >> sink_transform
    expected_pipeline = p.to_runner_api()

    assert_pipeline_proto_equal(self, expected_pipeline, actual_pipeline)

    # Look up the first transform in actual_pipeline that takes pcoll as an
    # input; it is expected to be a write (sink) transform.
    pcoll_id = augmented._context.pcollections.get_id(pcoll)
    transforms = actual_pipeline.components.transforms
    write_transform_id = next(
        (
            transform_id for transform_id,
            candidate in transforms.items()
            if pcoll_id in candidate.inputs.values()),
        None)
    self.assertIsNotNone(write_transform_id)
    self.assertIn('sink', transforms[write_transform_id].unique_name)


if __name__ == '__main__':
  unittest.main()