github.com/google/fleetspeak@v0.1.15-0.20240426164851-4f31f62c1aea/fleetspeak_python/fleetspeak/server_connector/connector_test.py (about) 1 # Copyright 2017 Google Inc. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # https://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 15 """Tests for grpcservice.client.client.""" 16 17 import datetime 18 import threading 19 import time 20 from unittest import mock 21 22 from absl.testing import absltest 23 from fleetspeak.server_connector import connector 24 from fleetspeak.src.common.proto.fleetspeak import common_pb2 25 from fleetspeak.src.server.proto.fleetspeak_server import admin_pb2 26 from fleetspeak.src.server.proto.fleetspeak_server import admin_pb2_grpc 27 import grpc 28 import grpc_testing 29 30 31 # Keeping additional reference to datetime.datetime, as it gets mocked 32 # in some tests below. 
# Unpatched datetime.datetime, used to build real datetime objects while the
# class itself is replaced by a mock in some tests below.
orig_datetime_datetime = datetime.datetime


class RetryLoopTest(absltest.TestCase):
  """Tests for connector.RetryLoop."""

  def _installFakeClock(self, datetime_mock, sleep_mock):
    """Drives datetime.datetime.now() and time.sleep() from one fake clock.

    The fake clock starts at timestamp 0. The patched time.sleep() advances it
    by the requested amount instead of blocking, and the patched
    datetime.datetime.now() reports the current fake time.

    Args:
      datetime_mock: Mock that replaced datetime.datetime (wrapping the real
        class).
      sleep_mock: Mock that replaced time.sleep.

    Returns:
      A callable advancing the fake clock by the given number of seconds, used
      to simulate calls that take time.
    """
    cur_time = 0

    def Advance(seconds: float) -> None:
      nonlocal cur_time
      cur_time += seconds

    sleep_mock.side_effect = Advance
    datetime_mock.now = mock.Mock(
        side_effect=lambda: orig_datetime_datetime.fromtimestamp(cur_time)
    )
    return Advance

  @mock.patch.object(time, "sleep")
  @mock.patch.object(time, "time", return_value=0)
  def testNotSleepingOnFirstSuccessfulCall(self, time_mock, sleep_mock):
    del time_mock  # Unused.
    func = mock.Mock(return_value=42)

    result = connector.RetryLoop(
        func,
        timeout=datetime.timedelta(seconds=10.5),
        single_try_timeout=datetime.timedelta(seconds=1),
    )

    func.assert_called_once()
    sleep_mock.assert_not_called()

    self.assertEqual(result, 42)

  @mock.patch.object(time, "sleep")
  @mock.patch.object(datetime, "datetime", wraps=orig_datetime_datetime)
  def testSingleTryTimeoutIsUsedForCalls(self, datetime_mock, sleep_mock):
    advance = self._installFakeClock(datetime_mock, sleep_mock)

    def Func(timeout: datetime.timedelta) -> None:
      # Each attempt uses up its whole per-try timeout and then fails.
      advance(timeout.total_seconds())
      raise grpc.RpcError("error")

    func = mock.Mock(wraps=Func)

    with self.assertRaises(grpc.RpcError):
      connector.RetryLoop(
          func,
          timeout=datetime.timedelta(seconds=10.5),
          single_try_timeout=datetime.timedelta(seconds=1),
      )

    # Expected timeline:
    # 0: func(1)
    # 1: sleep(1)
    # 2: func(1)
    # 3: sleep(2)
    # 5: func(1)
    # 6: sleep(4)
    # 10: func(0.5)
    # 10.5: -> done
    self.assertListEqual(
        [c.args[0].total_seconds() for c in func.call_args_list], [1, 1, 1, 0.5]
    )

  @mock.patch.object(time, "sleep")
  @mock.patch.object(datetime, "datetime", wraps=orig_datetime_datetime)
  def testDefaultSingleTryTimeoutIsEqualToDefaultTimeout(
      self, datetime_mock, sleep_mock
  ):
    advance = self._installFakeClock(datetime_mock, sleep_mock)

    def Func(timeout: datetime.timedelta) -> None:
      # The single attempt uses up its whole timeout and then fails.
      advance(timeout.total_seconds())
      raise grpc.RpcError("error")

    func = mock.Mock(wraps=Func)

    with self.assertRaises(grpc.RpcError):
      connector.RetryLoop(func, timeout=datetime.timedelta(seconds=10))

    # Expected timeline:
    # 0: func(10)
    # 10: -> done
    func.assert_called_once_with(datetime.timedelta(seconds=10))


class ClientTest(absltest.TestCase):
  """Tests for connector.OutgoingConnection backed by an autospec'd AdminStub.

  NOTE(review): the *IsRetried tests go through the real retry path, which
  (per RetryLoopTest's timeline) sleeps about 1s via time.sleep before the
  second attempt; consider faking the clock if these tests become slow.
  """

  def _fakeStub(self):
    """Returns an autospec'd AdminStub whose methods are all mocks."""
    return mock.create_autospec(
        admin_pb2_grpc.AdminStub(
            grpc_testing.channel(
                [],
                grpc_testing.strict_real_time(),
            )
        )
    )

  def _newConnection(self, stub):
    """Creates an OutgoingConnection for service "test" over `stub`.

    The connection is shut down when the test finishes, whether or not it
    passes.

    Args:
      stub: The (fake) AdminStub the connection delegates to.

    Returns:
      The new connector.OutgoingConnection.
    """
    conn = connector.OutgoingConnection(None, "test", stub)
    # The connection calls KeepAlive on its own after construction (see
    # testKeepAlive), so always shut it down to keep that background work
    # from outliving the test.
    self.addCleanup(conn.Shutdown)
    return conn

  def testKeepAlive(self):
    event = threading.Event()

    t = self._fakeStub()
    t.KeepAlive.side_effect = lambda *args, **kwargs: event.set()

    self._newConnection(t)
    self.assertTrue(event.wait(10))

  def testInsertMessageIsDelegatedToStub(self):
    t = self._fakeStub()
    s = self._newConnection(t)
    s.InsertMessage(common_pb2.Message())

    t.InsertMessage.assert_called_once()
    message = t.InsertMessage.call_args.args[0]
    self.assertEqual(message.source.service_name, "test")

  def testInsertMessageIsRetried(self):
    t = self._fakeStub()
    t.InsertMessage.side_effect = [grpc.RpcError("error"), mock.DEFAULT]

    s = self._newConnection(t)
    s.InsertMessage(common_pb2.Message())

    self.assertEqual(t.InsertMessage.call_count, 2)

  def testDeletePendingMessagesIsDelegatedToStub(self):
    t = self._fakeStub()
    s = self._newConnection(t)
    s.DeletePendingMessages(admin_pb2.DeletePendingMessagesRequest())

    t.DeletePendingMessages.assert_called_once()

  def testDeletePendingMessagesIsRetried(self):
    t = self._fakeStub()
    t.DeletePendingMessages.side_effect = [grpc.RpcError("error"), mock.DEFAULT]

    s = self._newConnection(t)
    s.DeletePendingMessages(admin_pb2.DeletePendingMessagesRequest())

    self.assertEqual(t.DeletePendingMessages.call_count, 2)

  def testGetPendingMessagesIsDelegatedToStub(self):
    t = self._fakeStub()
    s = self._newConnection(t)
    s.GetPendingMessages(admin_pb2.GetPendingMessagesRequest())

    t.GetPendingMessages.assert_called_once()

  def testGetPendingMessagesIsRetried(self):
    t = self._fakeStub()
    t.GetPendingMessages.side_effect = [grpc.RpcError("error"), mock.DEFAULT]

    s = self._newConnection(t)
    s.GetPendingMessages(admin_pb2.GetPendingMessagesRequest())

    self.assertEqual(t.GetPendingMessages.call_count, 2)

  def testGetPendingMessageCountIsDelegatedToStub(self):
    t = self._fakeStub()
    s = self._newConnection(t)
    s.GetPendingMessageCount(admin_pb2.GetPendingMessageCountRequest())

    t.GetPendingMessageCount.assert_called_once()

  def testGetPendingMessageCountIsRetried(self):
    t = self._fakeStub()
    t.GetPendingMessageCount.side_effect = [
        grpc.RpcError("error"),
        mock.DEFAULT,
    ]

    s = self._newConnection(t)
    s.GetPendingMessageCount(admin_pb2.GetPendingMessageCountRequest())

    self.assertEqual(t.GetPendingMessageCount.call_count, 2)


if __name__ == "__main__":
  absltest.main()