github.com/apache/beam/sdks/v2@v2.48.2/python/apache_beam/runners/sdf_utils_test.py (about)

     1  #
     2  # Licensed to the Apache Software Foundation (ASF) under one or more
     3  # contributor license agreements.  See the NOTICE file distributed with
     4  # this work for additional information regarding copyright ownership.
     5  # The ASF licenses this file to You under the Apache License, Version 2.0
     6  # (the "License"); you may not use this file except in compliance with
     7  # the License.  You may obtain a copy of the License at
     8  #
     9  #    http://www.apache.org/licenses/LICENSE-2.0
    10  #
    11  # Unless required by applicable law or agreed to in writing, software
    12  # distributed under the License is distributed on an "AS IS" BASIS,
    13  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14  # See the License for the specific language governing permissions and
    15  # limitations under the License.
    16  #
    17  
    18  """Unit tests for classes in sdf_utils.py."""
    19  
    20  # pytype: skip-file
    21  
    22  import time
    23  import unittest
    24  
    25  from apache_beam.io.concat_source_test import RangeSource
    26  from apache_beam.io.restriction_trackers import OffsetRange
    27  from apache_beam.io.restriction_trackers import OffsetRestrictionTracker
    28  from apache_beam.io.watermark_estimators import ManualWatermarkEstimator
    29  from apache_beam.runners.sdf_utils import RestrictionTrackerView
    30  from apache_beam.runners.sdf_utils import ThreadsafeRestrictionTracker
    31  from apache_beam.runners.sdf_utils import ThreadsafeWatermarkEstimator
    32  from apache_beam.utils import timestamp
    33  
    34  
    35  class ThreadsafeRestrictionTrackerTest(unittest.TestCase):
    36    def test_initialization(self):
    37      with self.assertRaises(ValueError):
    38        ThreadsafeRestrictionTracker(RangeSource(0, 1))
    39  
    40    def test_defer_remainder_with_wrong_time_type(self):
    41      threadsafe_tracker = ThreadsafeRestrictionTracker(
    42          OffsetRestrictionTracker(OffsetRange(0, 10)))
    43      with self.assertRaises(ValueError):
    44        threadsafe_tracker.defer_remainder(10)
    45  
    46    def test_self_checkpoint_immediately(self):
    47      restriction_tracker = OffsetRestrictionTracker(OffsetRange(0, 10))
    48      threadsafe_tracker = ThreadsafeRestrictionTracker(restriction_tracker)
    49      threadsafe_tracker.defer_remainder()
    50      deferred_residual, deferred_time = threadsafe_tracker.deferred_status()
    51      expected_residual = OffsetRange(0, 10)
    52      self.assertEqual(deferred_residual, expected_residual)
    53      self.assertTrue(isinstance(deferred_time, timestamp.Duration))
    54      self.assertEqual(deferred_time, 0)
    55  
    56    def test_self_checkpoint_with_relative_time(self):
    57      threadsafe_tracker = ThreadsafeRestrictionTracker(
    58          OffsetRestrictionTracker(OffsetRange(0, 10)))
    59      threadsafe_tracker.defer_remainder(timestamp.Duration(100))
    60      time.sleep(2)
    61      _, deferred_time = threadsafe_tracker.deferred_status()
    62      self.assertTrue(isinstance(deferred_time, timestamp.Duration))
    63      # The expectation = 100 - 2 - some_delta
    64      self.assertTrue(deferred_time <= 98)
    65  
    66    def test_self_checkpoint_with_absolute_time(self):
    67      threadsafe_tracker = ThreadsafeRestrictionTracker(
    68          OffsetRestrictionTracker(OffsetRange(0, 10)))
    69      now = timestamp.Timestamp.now()
    70      schedule_time = now + timestamp.Duration(100)
    71      self.assertTrue(isinstance(schedule_time, timestamp.Timestamp))
    72      threadsafe_tracker.defer_remainder(schedule_time)
    73      time.sleep(2)
    74      _, deferred_time = threadsafe_tracker.deferred_status()
    75      self.assertTrue(isinstance(deferred_time, timestamp.Duration))
    76      # The expectation =
    77      # schedule_time - the time when deferred_status is called - some_delta
    78      self.assertTrue(deferred_time <= 98)
    79  
    80  
    81  class RestrictionTrackerViewTest(unittest.TestCase):
    82    def test_initialization(self):
    83      with self.assertRaises(ValueError):
    84        RestrictionTrackerView(OffsetRestrictionTracker(OffsetRange(0, 10)))
    85  
    86    def test_api_expose(self):
    87      threadsafe_tracker = ThreadsafeRestrictionTracker(
    88          OffsetRestrictionTracker(OffsetRange(0, 10)))
    89      tracker_view = RestrictionTrackerView(threadsafe_tracker)
    90      current_restriction = tracker_view.current_restriction()
    91      self.assertEqual(current_restriction, OffsetRange(0, 10))
    92      self.assertTrue(tracker_view.try_claim(0))
    93      tracker_view.defer_remainder()
    94      deferred_remainder, deferred_watermark = (
    95          threadsafe_tracker.deferred_status())
    96      self.assertEqual(deferred_remainder, OffsetRange(1, 10))
    97      self.assertEqual(deferred_watermark, timestamp.Duration())
    98  
    99    def test_non_expose_apis(self):
   100      threadsafe_tracker = ThreadsafeRestrictionTracker(
   101          OffsetRestrictionTracker(OffsetRange(0, 10)))
   102      tracker_view = RestrictionTrackerView(threadsafe_tracker)
   103      with self.assertRaises(AttributeError):
   104        tracker_view.check_done()
   105      with self.assertRaises(AttributeError):
   106        tracker_view.current_progress()
   107      with self.assertRaises(AttributeError):
   108        tracker_view.try_split()
   109      with self.assertRaises(AttributeError):
   110        tracker_view.deferred_status()
   111  
   112  
   113  class ThreadsafeWatermarkEstimatorTest(unittest.TestCase):
   114    def test_initialization(self):
   115      with self.assertRaises(ValueError):
   116        ThreadsafeWatermarkEstimator(None)
   117  
   118    def test_get_estimator_state(self):
   119      estimator = ThreadsafeWatermarkEstimator(ManualWatermarkEstimator(None))
   120      self.assertIsNone(estimator.get_estimator_state())
   121      estimator.set_watermark(timestamp.Timestamp(10))
   122      self.assertEqual(estimator.get_estimator_state(), timestamp.Timestamp(10))
   123  
   124    def test_track_timestamp(self):
   125      estimator = ThreadsafeWatermarkEstimator(ManualWatermarkEstimator(None))
   126      estimator.observe_timestamp(timestamp.Timestamp(10))
   127      self.assertIsNone(estimator.current_watermark())
   128      estimator.set_watermark(timestamp.Timestamp(20))
   129      self.assertEqual(estimator.current_watermark(), timestamp.Timestamp(20))
   130  
   131    def test_non_exsited_attr(self):
   132      estimator = ThreadsafeWatermarkEstimator(ManualWatermarkEstimator(None))
   133      with self.assertRaises(AttributeError):
   134        estimator.non_existed_call()
   135  
   136  
if __name__ == '__main__':
  # Run this module's tests when it is executed directly as a script.
  unittest.main()