github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/interactive/user_pipeline_tracker_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  import unittest
    19  
    20  import apache_beam as beam
    21  from apache_beam.runners.interactive.user_pipeline_tracker import UserPipelineTracker
    22  
    23  
    24  class UserPipelineTrackerTest(unittest.TestCase):
    25    def test_getting_unknown_pid_returns_none(self):
    26      ut = UserPipelineTracker()
    27  
    28      p = beam.Pipeline()
    29  
    30      self.assertIsNone(ut.get_pipeline(str(id(p))))
    31  
    32    def test_getting_unknown_pipeline_returns_none(self):
    33      ut = UserPipelineTracker()
    34  
    35      p = beam.Pipeline()
    36  
    37      self.assertIsNone(ut.get_user_pipeline(p))
    38  
    39    def test_no_parent_returns_none(self):
    40      ut = UserPipelineTracker()
    41  
    42      user = beam.Pipeline()
    43      derived = beam.Pipeline()
    44      orphan = beam.Pipeline()
    45  
    46      ut.add_derived_pipeline(user, derived)
    47  
    48      self.assertIsNone(ut.get_user_pipeline(orphan))
    49  
    50    def test_get_user_pipeline_is_same(self):
    51      ut = UserPipelineTracker()
    52  
    53      p = beam.Pipeline()
    54      ut.add_user_pipeline(p)
    55  
    56      self.assertIs(ut.get_user_pipeline(p), p)
    57  
    58    def test_can_add_derived(self):
    59      ut = UserPipelineTracker()
    60  
    61      user = beam.Pipeline()
    62      derived = beam.Pipeline()
    63  
    64      ut.add_derived_pipeline(user, derived)
    65  
    66      self.assertIs(ut.get_user_pipeline(derived), user)
    67  
    68    def test_can_add_multiple_derived(self):
    69      """Tests that there can be many user pipelines with many derived
    70      pipelines.
    71      """
    72      ut = UserPipelineTracker()
    73  
    74      # Add the first set of user and derived pipelines.
    75      user1 = beam.Pipeline()
    76      derived11 = beam.Pipeline()
    77      derived12 = beam.Pipeline()
    78  
    79      ut.add_derived_pipeline(user1, derived11)
    80      ut.add_derived_pipeline(user1, derived12)
    81  
    82      # Add the second set of user and derived pipelines.
    83      user2 = beam.Pipeline()
    84      derived21 = beam.Pipeline()
    85      derived22 = beam.Pipeline()
    86  
    87      ut.add_derived_pipeline(user2, derived21)
    88      ut.add_derived_pipeline(user2, derived22)
    89  
    90      # Assert that the user pipelines are correct.
    91      self.assertIs(ut.get_user_pipeline(derived11), user1)
    92      self.assertIs(ut.get_user_pipeline(derived12), user1)
    93      self.assertIs(ut.get_user_pipeline(derived21), user2)
    94      self.assertIs(ut.get_user_pipeline(derived22), user2)
    95  
    96    def test_cannot_have_multiple_parents(self):
    97      ut = UserPipelineTracker()
    98  
    99      user1 = beam.Pipeline()
   100      user2 = beam.Pipeline()
   101      derived = beam.Pipeline()
   102  
   103      ut.add_derived_pipeline(user1, derived)
   104  
   105      with self.assertRaises(AssertionError):
   106        ut.add_derived_pipeline(user2, derived)
   107  
   108      self.assertIs(ut.get_user_pipeline(derived), user1)
   109  
   110    def test_adding_derived_with_derived_gets_user_pipeline(self):
   111      """Tests that one can correctly add a derived pipeline from a derived
   112      pipeline and still get the correct user pipeline.
   113      """
   114      ut = UserPipelineTracker()
   115  
   116      user = beam.Pipeline()
   117      derived1 = beam.Pipeline()
   118      derived2 = beam.Pipeline()
   119  
   120      # Add the first derived pipeline to the user pipelne.
   121      ut.add_derived_pipeline(user, derived1)
   122  
   123      # Add the second derived pipeline to the first derived pipeline. This should
   124      # get the user pipeline of the first and add the second to it.
   125      ut.add_derived_pipeline(derived1, derived2)
   126  
   127      # Asserts that both derived pipelines are under the same user pipeline.
   128      self.assertIs(ut.get_user_pipeline(derived1), user)
   129      self.assertIs(ut.get_user_pipeline(derived2), user)
   130  
   131    def test_can_get_pipeline_from_id(self):
   132      """Tests the pid -> pipeline memoization."""
   133      ut = UserPipelineTracker()
   134  
   135      user = beam.Pipeline()
   136      derived = beam.Pipeline()
   137  
   138      ut.add_user_pipeline(user)
   139      ut.add_derived_pipeline(user, derived)
   140  
   141      self.assertIs(ut.get_pipeline(str(id(user))), user)
   142      self.assertIs(ut.get_pipeline(str(id(derived))), derived)
   143  
   144    def test_clear(self):
   145      ut = UserPipelineTracker()
   146  
   147      user = beam.Pipeline()
   148      derived = beam.Pipeline()
   149  
   150      ut.add_derived_pipeline(user, derived)
   151  
   152      self.assertIs(ut.get_user_pipeline(derived), user)
   153  
   154      ut.clear()
   155  
   156      self.assertIsNone(ut.get_user_pipeline(user))
   157      self.assertIsNone(ut.get_user_pipeline(derived))
   158  
   159    def test_can_iterate(self):
   160      ut = UserPipelineTracker()
   161  
   162      user1 = beam.Pipeline()
   163      derived11 = beam.Pipeline()
   164      derived12 = beam.Pipeline()
   165  
   166      ut.add_derived_pipeline(user1, derived11)
   167      ut.add_derived_pipeline(user1, derived12)
   168  
   169      user2 = beam.Pipeline()
   170      derived21 = beam.Pipeline()
   171      derived22 = beam.Pipeline()
   172  
   173      ut.add_derived_pipeline(user2, derived21)
   174      ut.add_derived_pipeline(user2, derived22)
   175  
   176      user_pipelines = set(p for p in ut)
   177      self.assertSetEqual(set([user1, user2]), user_pipelines)
   178  
   179    def test_can_evict_user_pipeline(self):
   180      ut = UserPipelineTracker()
   181  
   182      user1 = beam.Pipeline()
   183      derived11 = beam.Pipeline()
   184      derived12 = beam.Pipeline()
   185  
   186      ut.add_derived_pipeline(user1, derived11)
   187      ut.add_derived_pipeline(user1, derived12)
   188  
   189      user2 = beam.Pipeline()
   190      derived21 = beam.Pipeline()
   191      derived22 = beam.Pipeline()
   192  
   193      ut.add_derived_pipeline(user2, derived21)
   194      ut.add_derived_pipeline(user2, derived22)
   195  
   196      ut.evict(user1)
   197  
   198      self.assertIsNone(ut.get_user_pipeline(user1))
   199      self.assertIsNone(ut.get_user_pipeline(derived11))
   200      self.assertIsNone(ut.get_user_pipeline(derived12))
   201  
   202      self.assertIs(user2, ut.get_user_pipeline(derived21))
   203      self.assertIs(user2, ut.get_user_pipeline(derived22))
   204  
   205  
   206  if __name__ == '__main__':
   207    unittest.main()