#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Tests for augmented_pipeline module."""

# pytest: skip-file

import unittest

import apache_beam as beam
from apache_beam.runners.interactive import augmented_pipeline as ap
from apache_beam.runners.interactive import interactive_beam as ib
from apache_beam.runners.interactive import interactive_environment as ie


class CacheableTest(unittest.TestCase):
  """Tests for which PCollections AugmentedPipeline reports as cacheable."""
  def setUp(self):
    # Reset to a new interactive environment before each test so variables
    # watched by one test are not visible to another.
    ie.new_env()

  def test_find_all_cacheables(self):
    """All watched PCollections of the pipeline are reported as cacheable."""
    p = beam.Pipeline()
    cacheable_pcoll_1 = p | beam.Create([1, 2, 3])
    cacheable_pcoll_2 = cacheable_pcoll_1 | beam.Map(lambda x: x * x)
    # Register the local variables (including both PCollections) with the
    # interactive environment.
    ib.watch(locals())

    aug_p = ap.AugmentedPipeline(p)
    cacheables = aug_p.cacheables()
    self.assertIn(cacheable_pcoll_1, cacheables)
    self.assertIn(cacheable_pcoll_2, cacheables)

  def test_ignore_cacheables(self):
    """An explicit tuple of PCollections restricts cacheables to it."""
    p = beam.Pipeline()
    cacheable_pcoll_1 = p | 'cacheable_pcoll_1' >> beam.Create([1, 2, 3])
    cacheable_pcoll_2 = p | 'cacheable_pcoll_2' >> beam.Create([4, 5, 6])
    ib.watch(locals())

    # Only cacheable_pcoll_1 is passed in, so cacheable_pcoll_2 must be
    # excluded even though it is watched and belongs to the same pipeline.
    aug_p = ap.AugmentedPipeline(p, (cacheable_pcoll_1, ))
    cacheables = aug_p.cacheables()
    self.assertIn(cacheable_pcoll_1, cacheables)
    self.assertNotIn(cacheable_pcoll_2, cacheables)

  def test_ignore_pcoll_from_other_pipeline(self):
    """A watched PCollection of a different pipeline is not cacheable."""
    p = beam.Pipeline()
    p2 = beam.Pipeline()
    cacheable_from_p2 = p2 | beam.Create([1, 2, 3])
    ib.watch(locals())

    aug_p = ap.AugmentedPipeline(p)
    cacheables = aug_p.cacheables()
    self.assertNotIn(cacheable_from_p2, cacheables)


class AugmentTest(unittest.TestCase):
  """Tests for constructing an AugmentedPipeline."""
  def setUp(self):
    # Reset to a new interactive environment before each test.
    ie.new_env()

  def test_error_when_pcolls_from_mixed_pipelines(self):
    """PCollections from different pipelines raise AssertionError."""
    p = beam.Pipeline()
    cacheable_from_p = p | beam.Create([1, 2, 3])
    p2 = beam.Pipeline()
    cacheable_from_p2 = p2 | beam.Create([1, 2, 3])
    ib.watch(locals())

    # Only the constructor call is inside the context manager, so an
    # AssertionError raised by the setup above cannot satisfy the check.
    with self.assertRaises(AssertionError):
      ap.AugmentedPipeline(p, (cacheable_from_p, cacheable_from_p2))


if __name__ == '__main__':
  unittest.main()