#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Unit tests for apache_beam.runners.interactive.user_pipeline_tracker."""

import unittest

import apache_beam as beam
from apache_beam.runners.interactive.user_pipeline_tracker import UserPipelineTracker


class UserPipelineTrackerTest(unittest.TestCase):
  """Tests how UserPipelineTracker maps derived pipelines to user pipelines."""
  def test_getting_unknown_pid_returns_none(self):
    """An id that was never registered does not resolve to a pipeline."""
    ut = UserPipelineTracker()

    p = beam.Pipeline()

    self.assertIsNone(ut.get_pipeline(str(id(p))))

  def test_getting_unknown_pipeline_returns_none(self):
    """A pipeline that was never added has no user pipeline."""
    ut = UserPipelineTracker()

    p = beam.Pipeline()

    self.assertIsNone(ut.get_user_pipeline(p))

  def test_no_parent_returns_none(self):
    """An untracked pipeline has no user pipeline, even when others exist."""
    ut = UserPipelineTracker()

    user = beam.Pipeline()
    derived = beam.Pipeline()
    orphan = beam.Pipeline()

    ut.add_derived_pipeline(user, derived)

    self.assertIsNone(ut.get_user_pipeline(orphan))

  def test_get_user_pipeline_is_same(self):
    """A registered user pipeline resolves to itself."""
    ut = UserPipelineTracker()

    p = beam.Pipeline()
    ut.add_user_pipeline(p)

    self.assertIs(ut.get_user_pipeline(p), p)

  def test_can_add_derived(self):
    """A derived pipeline resolves to the user pipeline it was added under."""
    ut = UserPipelineTracker()

    user = beam.Pipeline()
    derived = beam.Pipeline()

    ut.add_derived_pipeline(user, derived)

    self.assertIs(ut.get_user_pipeline(derived), user)

  def test_can_add_multiple_derived(self):
    """Tests that there can be many user pipelines with many derived
    pipelines.
    """
    ut = UserPipelineTracker()

    # Add the first set of user and derived pipelines.
    user1 = beam.Pipeline()
    derived11 = beam.Pipeline()
    derived12 = beam.Pipeline()

    ut.add_derived_pipeline(user1, derived11)
    ut.add_derived_pipeline(user1, derived12)

    # Add the second set of user and derived pipelines.
    user2 = beam.Pipeline()
    derived21 = beam.Pipeline()
    derived22 = beam.Pipeline()

    ut.add_derived_pipeline(user2, derived21)
    ut.add_derived_pipeline(user2, derived22)

    # Assert that the user pipelines are correct.
    self.assertIs(ut.get_user_pipeline(derived11), user1)
    self.assertIs(ut.get_user_pipeline(derived12), user1)
    self.assertIs(ut.get_user_pipeline(derived21), user2)
    self.assertIs(ut.get_user_pipeline(derived22), user2)

  def test_cannot_have_multiple_parents(self):
    """Re-parenting a derived pipeline raises and keeps the first parent."""
    ut = UserPipelineTracker()

    user1 = beam.Pipeline()
    user2 = beam.Pipeline()
    derived = beam.Pipeline()

    ut.add_derived_pipeline(user1, derived)

    with self.assertRaises(AssertionError):
      ut.add_derived_pipeline(user2, derived)

    # The failed re-parenting must not have altered the original mapping.
    self.assertIs(ut.get_user_pipeline(derived), user1)

  def test_adding_derived_with_derived_gets_user_pipeline(self):
    """Tests that one can correctly add a derived pipeline from a derived
    pipeline and still get the correct user pipeline.
    """
    ut = UserPipelineTracker()

    user = beam.Pipeline()
    derived1 = beam.Pipeline()
    derived2 = beam.Pipeline()

    # Add the first derived pipeline to the user pipeline.
    ut.add_derived_pipeline(user, derived1)

    # Add the second derived pipeline to the first derived pipeline. This should
    # get the user pipeline of the first and add the second to it.
    ut.add_derived_pipeline(derived1, derived2)

    # Asserts that both derived pipelines are under the same user pipeline.
    self.assertIs(ut.get_user_pipeline(derived1), user)
    self.assertIs(ut.get_user_pipeline(derived2), user)

  def test_can_get_pipeline_from_id(self):
    """Tests the pid -> pipeline memoization.

    Both user and derived pipelines are retrievable by str(id(pipeline)).
    """
    ut = UserPipelineTracker()

    user = beam.Pipeline()
    derived = beam.Pipeline()

    ut.add_user_pipeline(user)
    ut.add_derived_pipeline(user, derived)

    self.assertIs(ut.get_pipeline(str(id(user))), user)
    self.assertIs(ut.get_pipeline(str(id(derived))), derived)

  def test_clear(self):
    """clear() forgets both user and derived pipelines."""
    ut = UserPipelineTracker()

    user = beam.Pipeline()
    derived = beam.Pipeline()

    ut.add_derived_pipeline(user, derived)

    self.assertIs(ut.get_user_pipeline(derived), user)

    ut.clear()

    self.assertIsNone(ut.get_user_pipeline(user))
    self.assertIsNone(ut.get_user_pipeline(derived))

  def test_can_iterate(self):
    """Iterating the tracker yields exactly the user pipelines."""
    ut = UserPipelineTracker()

    user1 = beam.Pipeline()
    derived11 = beam.Pipeline()
    derived12 = beam.Pipeline()

    ut.add_derived_pipeline(user1, derived11)
    ut.add_derived_pipeline(user1, derived12)

    user2 = beam.Pipeline()
    derived21 = beam.Pipeline()
    derived22 = beam.Pipeline()

    ut.add_derived_pipeline(user2, derived21)
    ut.add_derived_pipeline(user2, derived22)

    # Derived pipelines must not appear in the iteration.
    self.assertSetEqual({user1, user2}, set(ut))

  def test_can_evict_user_pipeline(self):
    """Evicting a user pipeline drops it and its derived pipelines only."""
    ut = UserPipelineTracker()

    user1 = beam.Pipeline()
    derived11 = beam.Pipeline()
    derived12 = beam.Pipeline()

    ut.add_derived_pipeline(user1, derived11)
    ut.add_derived_pipeline(user1, derived12)

    user2 = beam.Pipeline()
    derived21 = beam.Pipeline()
    derived22 = beam.Pipeline()

    ut.add_derived_pipeline(user2, derived21)
    ut.add_derived_pipeline(user2, derived22)

    ut.evict(user1)

    self.assertIsNone(ut.get_user_pipeline(user1))
    self.assertIsNone(ut.get_user_pipeline(derived11))
    self.assertIsNone(ut.get_user_pipeline(derived12))

    # The other user pipeline and its derived pipelines are unaffected.
    self.assertIs(user2, ut.get_user_pipeline(derived21))
    self.assertIs(user2, ut.get_user_pipeline(derived22))


if __name__ == '__main__':
  unittest.main()