github.com/google/fleetspeak@v0.1.15-0.20240426164851-4f31f62c1aea/fleetspeak_python/fleetspeak/server_connector/connector_test.py (about)

     1  # Copyright 2017 Google Inc.
     2  #
     3  # Licensed under the Apache License, Version 2.0 (the "License");
     4  # you may not use this file except in compliance with the License.
     5  # You may obtain a copy of the License at
     6  #
     7  #     https://www.apache.org/licenses/LICENSE-2.0
     8  #
     9  # Unless required by applicable law or agreed to in writing, software
    10  # distributed under the License is distributed on an "AS IS" BASIS,
    11  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  # See the License for the specific language governing permissions and
    13  # limitations under the License.
    14  
    15  """Tests for fleetspeak.server_connector.connector."""
    16  
    17  import datetime
    18  import threading
    19  import time
    20  from unittest import mock
    21  
    22  from absl.testing import absltest
    23  from fleetspeak.server_connector import connector
    24  from fleetspeak.src.common.proto.fleetspeak import common_pb2
    25  from fleetspeak.src.server.proto.fleetspeak_server import admin_pb2
    26  from fleetspeak.src.server.proto.fleetspeak_server import admin_pb2_grpc
    27  import grpc
    28  import grpc_testing
    29  
    30  
# Keep a reference to the real datetime.datetime class: some RetryLoopTest
# cases patch datetime.datetime with a (wrapping) mock, and their fake clocks
# still need the genuine class to build real datetime values.
orig_datetime_datetime = datetime.datetime
    34  
    35  
    36  class RetryLoopTest(absltest.TestCase):
    37  
    38    @mock.patch.object(time, "sleep")
    39    @mock.patch.object(time, "time", return_value=0)
    40    def testNotSleepingOnFirstSuccessfulCall(self, time_mock, sleep_mock):
    41      func = mock.Mock(return_value=42)
    42  
    43      result = connector.RetryLoop(
    44          func,
    45          timeout=datetime.timedelta(seconds=10.5),
    46          single_try_timeout=datetime.timedelta(seconds=1),
    47      )
    48  
    49      func.assert_called_once()
    50      sleep_mock.assert_not_called()
    51  
    52      self.assertEqual(result, 42)
    53  
    54    @mock.patch.object(time, "sleep")
    55    @mock.patch.object(datetime, "datetime", wraps=orig_datetime_datetime)
    56    def testSingleTryTimeoutIsUsedForCalls(self, datetime_mock, sleep_mock):
    57      cur_time = 0
    58  
    59      def SleepMock(v: float) -> None:
    60        nonlocal cur_time
    61        cur_time += v
    62  
    63      sleep_mock.side_effect = SleepMock
    64      datetime_mock.now = mock.Mock(
    65          side_effect=lambda: orig_datetime_datetime.fromtimestamp(cur_time)
    66      )
    67  
    68      def Func(timeout: datetime.timedelta) -> None:
    69        nonlocal cur_time
    70        cur_time += timeout.total_seconds()
    71        raise grpc.RpcError("error")
    72  
    73      func = mock.Mock(wraps=Func)
    74  
    75      with self.assertRaises(grpc.RpcError):
    76        connector.RetryLoop(
    77            func,
    78            timeout=datetime.timedelta(seconds=10.5),
    79            single_try_timeout=datetime.timedelta(seconds=1),
    80        )
    81  
    82      # Expected timeline:
    83      # 0:    func(1)
    84      # 1:    sleep(1)
    85      # 2:    func(1)
    86      # 3:    sleep(2)
    87      # 5:    func(1):
    88      # 6:    sleep(4)
    89      # 10:   func(0.5)
    90      # 10.5: -> done
    91      self.assertListEqual(
    92          [c.args[0].total_seconds() for c in func.call_args_list], [1, 1, 1, 0.5]
    93      )
    94  
    95    @mock.patch.object(time, "sleep")
    96    @mock.patch.object(datetime, "datetime", wraps=orig_datetime_datetime)
    97    def testDefaultSingleTryTimeoutIsEqualToDefaultTimeout(
    98        self, datetime_mock, sleep_mock
    99    ):
   100      cur_time = 0
   101  
   102      def SleepMock(v: float) -> None:
   103        nonlocal cur_time
   104        cur_time += v
   105  
   106      sleep_mock.side_effect = SleepMock
   107      datetime_mock.now = mock.Mock(
   108          side_effect=lambda: orig_datetime_datetime.fromtimestamp(cur_time)
   109      )
   110  
   111      def Func(timeout: datetime.timedelta) -> None:
   112        nonlocal cur_time
   113        cur_time += timeout.total_seconds()
   114        raise grpc.RpcError("error")
   115  
   116      func = mock.Mock(wraps=Func)
   117  
   118      with self.assertRaises(grpc.RpcError):
   119        connector.RetryLoop(func, timeout=datetime.timedelta(seconds=10))
   120  
   121      # Expected timeline:
   122      # 0:  func(10)
   123      # 10: -> done
   124      func.assert_called_once_with(datetime.timedelta(seconds=10))
   125  
   126  
   127  class ClientTest(absltest.TestCase):
   128  
   129    def _fakeStub(self):
   130      return mock.create_autospec(
   131          admin_pb2_grpc.AdminStub(
   132              grpc_testing.channel(
   133                  [],
   134                  grpc_testing.strict_real_time(),
   135              )
   136          )
   137      )
   138  
   139    def testKeepAlive(self):
   140      event = threading.Event()
   141  
   142      t = self._fakeStub()
   143      t.KeepAlive.side_effect = lambda *args, **kwargs: event.set()
   144  
   145      s = connector.OutgoingConnection(None, "test", t)
   146      self.assertTrue(event.wait(10))
   147  
   148      s.Shutdown()
   149  
   150    def testInsertMessageIsDelegatedToStub(self):
   151      t = self._fakeStub()
   152      s = connector.OutgoingConnection(None, "test", t)
   153      s.InsertMessage(common_pb2.Message())
   154  
   155      t.InsertMessage.assert_called_once()
   156      message = t.InsertMessage.call_args.args[0]
   157      self.assertEqual(message.source.service_name, "test")
   158  
   159    def testInsertMessageIsRetried(self):
   160      t = self._fakeStub()
   161      t.InsertMessage.side_effect = [grpc.RpcError("error"), mock.DEFAULT]
   162  
   163      s = connector.OutgoingConnection(None, "test", t)
   164      s.InsertMessage(common_pb2.Message())
   165  
   166      self.assertEqual(t.InsertMessage.call_count, 2)
   167  
   168    def testDeletePendingMessagesIsDelegatedToStub(self):
   169      t = self._fakeStub()
   170      s = connector.OutgoingConnection(None, "test", t)
   171      s.DeletePendingMessages(admin_pb2.DeletePendingMessagesRequest())
   172  
   173      t.DeletePendingMessages.assert_called_once()
   174  
   175    def testDeletePendingMessagesIsRetried(self):
   176      t = self._fakeStub()
   177      t.DeletePendingMessages.side_effect = [grpc.RpcError("error"), mock.DEFAULT]
   178  
   179      s = connector.OutgoingConnection(None, "test", t)
   180      s.DeletePendingMessages(admin_pb2.DeletePendingMessagesRequest())
   181  
   182      self.assertEqual(t.DeletePendingMessages.call_count, 2)
   183  
   184    def testGetPendingMessagesIsDelegatedToStub(self):
   185      t = self._fakeStub()
   186      s = connector.OutgoingConnection(None, "test", t)
   187      s.GetPendingMessages(admin_pb2.GetPendingMessagesRequest())
   188  
   189      t.GetPendingMessages.assert_called_once()
   190  
   191    def testGetPendingMessagesIsRetried(self):
   192      t = self._fakeStub()
   193      t.GetPendingMessages.side_effect = [grpc.RpcError("error"), mock.DEFAULT]
   194  
   195      s = connector.OutgoingConnection(None, "test", t)
   196      s.GetPendingMessages(admin_pb2.GetPendingMessagesRequest())
   197  
   198      self.assertEqual(t.GetPendingMessages.call_count, 2)
   199  
   200    def testGetPendingMessageCountIsDelegatedToStub(self):
   201      t = self._fakeStub()
   202      s = connector.OutgoingConnection(None, "test", t)
   203      s.GetPendingMessageCount(admin_pb2.GetPendingMessageCountRequest())
   204  
   205      t.GetPendingMessageCount.assert_called_once()
   206  
   207    def testGetPendingMessageCountIsRetried(self):
   208      t = self._fakeStub()
   209      t.GetPendingMessageCount.side_effect = [
   210          grpc.RpcError("error"),
   211          mock.DEFAULT,
   212      ]
   213  
   214      s = connector.OutgoingConnection(None, "test", t)
   215      s.GetPendingMessageCount(admin_pb2.GetPendingMessageCountRequest())
   216  
   217      self.assertEqual(t.GetPendingMessageCount.call_count, 2)
   218  
   219  
if __name__ == "__main__":
  # Run this module's test cases through absltest's entry point when the file
  # is executed directly.
  absltest.main()