github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/engine/pkg/dm/message_agent_test.go (about) 1 // Copyright 2022 PingCAP, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package dm 15 16 import ( 17 "context" 18 "encoding/json" 19 "fmt" 20 "sync" 21 "testing" 22 "time" 23 24 "github.com/pingcap/log" 25 "github.com/pingcap/tiflow/engine/framework" 26 "github.com/pingcap/tiflow/engine/pkg/p2p" 27 "github.com/pingcap/tiflow/pkg/errors" 28 "github.com/stretchr/testify/mock" 29 "github.com/stretchr/testify/require" 30 ) 31 32 func TestAllocID(t *testing.T) { 33 t.Parallel() 34 35 messageMatcher := newMessageMatcher() 36 require.Equal(t, messageID(1), messageMatcher.allocID()) 37 require.Equal(t, messageID(2), messageMatcher.allocID()) 38 require.Equal(t, messageID(3), messageMatcher.allocID()) 39 40 var wg sync.WaitGroup 41 for i := 0; i < 10; i++ { 42 wg.Add(1) 43 go func() { 44 defer wg.Done() 45 for i := 0; i < 100; i++ { 46 messageMatcher.allocID() 47 } 48 }() 49 } 50 wg.Wait() 51 require.Equal(t, messageID(1004), messageMatcher.allocID()) 52 } 53 54 func TestMessageMatcher(t *testing.T) { 55 t.Parallel() 56 57 messageMatcher := newMessageMatcher() 58 mockClient := &MockClient{} 59 ctx, cancel := context.WithCancel(context.Background()) 60 defer cancel() 61 62 clientCtx := context.Background() 63 64 messageErr := errors.New("message error") 65 // synchronous send 66 mockClient.On("SendMessage").Return(messageErr).Once() 67 resp, err := messageMatcher.sendRequest(ctx, clientCtx, "topic", "command", "request", mockClient) 68 require.EqualError(t, err, messageErr.Error()) 69 require.Nil(t, resp) 70 // deadline exceeded 71 mockClient.On("SendMessage").Return(nil).Once() 72 ctx2, cancel2 := context.WithTimeout(ctx, 100*time.Millisecond) 73 defer cancel2() 74 resp, err = messageMatcher.sendRequest(ctx2, clientCtx, "topic", "command", "request", mockClient) 75 require.EqualError(t, err, context.DeadlineExceeded.Error()) 76 require.Nil(t, resp) 77 // late response 78 require.EqualError(t, messageMatcher.onResponse(2, "response"), "request 2 not found") 79 80 resp2 := "response" 81 go func() { 82 mockClient.On("SendMessage").Return(nil).Once() 83 ctx3, cancel3 := context.WithTimeout(ctx, 5*time.Second) 84 defer cancel3() 85 resp3, err := messageMatcher.sendRequest(ctx3, clientCtx, "request-topic", "command", "request", mockClient) 86 require.NoError(t, err) 87 require.Equal(t, "response", resp3) 88 }() 89 90 // send response 91 time.Sleep(time.Second) 92 mockClient.On("SendMessage").Return(nil).Once() 93 require.NoError(t, messageMatcher.sendResponse(ctx, "response-topic", 3, "command", resp2, mockClient)) 94 require.NoError(t, messageMatcher.onResponse(3, resp2)) 95 96 // duplicate response 97 require.Eventually(t, func() bool { 98 err := messageMatcher.onResponse(3, resp2) 99 return err != nil && err.Error() == fmt.Sprintf("request %d not found", 3) 100 }, 5*time.Second, 100*time.Millisecond) 101 } 102 103 func TestUpdateClient(t *testing.T) { 104 messageAgent := NewMessageAgentImpl("", nil, p2p.NewMockMessageHandlerManager(), log.L()).(*MessageAgentImpl) 105 workerHandle1 := &framework.MockHandle{WorkerID: "worker1"} 106 workerHandle2 := &framework.MockHandle{WorkerID: "worker2"} 107 108 // add client 109 messageAgent.UpdateClient("task1", workerHandle1.Unwrap()) 110 require.Len(t, messageAgent.clients.clients, 1) 111 client, err := messageAgent.getClient("task1") 112 require.NoError(t, err) 113 require.Equal(t, client, workerHandle1.Unwrap()) 114 client, err = messageAgent.getClient("task2") 115 require.EqualError(t, err, "client task2 not found") 116 require.Equal(t, client, nil) 117 messageAgent.UpdateClient("task2", workerHandle2.Unwrap()) 118 require.Len(t, messageAgent.clients.clients, 2) 119 client, err = messageAgent.getClient("task1") 120 require.NoError(t, err) 121 require.Equal(t, client, workerHandle1.Unwrap()) 122 client, err = messageAgent.getClient("task2") 123 require.NoError(t, err) 124 require.Equal(t, client, workerHandle2.Unwrap()) 125 126 // remove client 127 messageAgent.UpdateClient("task3", nil) 128 require.Len(t, messageAgent.clients.clients, 2) 129 client, err = messageAgent.getClient("task1") 130 require.NoError(t, err) 131 require.Equal(t, client, workerHandle1.Unwrap()) 132 client, err = messageAgent.getClient("task2") 133 require.NoError(t, err) 134 require.Equal(t, client, workerHandle2.Unwrap()) 135 messageAgent.RemoveClient("task2") 136 require.Len(t, messageAgent.clients.clients, 1) 137 client, err = messageAgent.getClient("task1") 138 require.NoError(t, err) 139 require.Equal(t, client, workerHandle1.Unwrap()) 140 } 141 142 func TestMessageAgent(t *testing.T) { 143 messageAgent := NewMessageAgentImpl("id", nil, p2p.NewMockMessageHandlerManager(), log.L()).(*MessageAgentImpl) 144 clientID := "client-id" 145 mockClient := &MockClient{} 146 messageAgent.UpdateClient(clientID, mockClient) 147 148 require.Error(t, messageAgent.SendMessage(context.Background(), "wrong-id", "command", "msg"), "client wrong-id not found") 149 require.Error(t, messageAgent.sendResponse(context.Background(), "wrong-id", 1, "command", "resp"), "client wrong-id not found") 150 ret, err := messageAgent.SendRequest(context.Background(), "wrong-id", "command", "request") 151 require.EqualError(t, err, "client wrong-id not found") 152 require.Nil(t, ret) 153 154 mockClient.On("SendMessage").Return(nil).Once() 155 require.NoError(t, messageAgent.SendMessage(context.Background(), clientID, "command", "msg")) 156 157 resp := "response" 158 go func() { 159 mockClient.On("SendMessage").Return(nil).Once() 160 ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) 161 defer cancel() 162 resp2, err := messageAgent.SendRequest(ctx, clientID, "command", "request") 163 require.NoError(t, err) 164 require.Equal(t, resp, resp2) 165 }() 166 167 time.Sleep(time.Second) 168 // send response 169 mockClient.On("SendMessage").Return(nil).Once() 170 require.NoError(t, messageAgent.sendResponse(context.Background(), clientID, 2, "command", "response")) 171 messageAgent.onMessage(generateTopic("Client", "Receiver"), message{ID: 2, Type: responseTp, Payload: resp}) 172 } 173 174 func TestMessageHandler(t *testing.T) { 175 var ( 176 clientID = "client-id" 177 receiveID = "receiver-id" 178 topic = generateTopic(clientID, receiveID) 179 msg = message{ID: 0, Type: messageTp, Command: messageAPI, Payload: &MessageAPIMessage{Msg: "msg"}} 180 req = message{ID: 1, Type: requestTp, Command: requestAPI, Payload: &RequestAPIRequest{Req: "req"}} 181 resp = message{ID: 1, Type: responseTp, Command: requestAPI, Payload: &RequestAPIResponse{Resp: "resp"}} 182 wrongMsg = message{ID: 0, Type: messageTp, Command: wrongAPI, Payload: &MessageAPIMessage{Msg: "msg"}} 183 wrongReq = message{ID: 0, Type: requestTp, Command: wrongAPI, Payload: &MessageAPIMessage{Msg: "msg"}} 184 wrongResp = message{ID: 0, Type: responseTp, Command: wrongAPI, Payload: &MessageAPIMessage{Msg: "msg"}} 185 serializeMsg = &message{} 186 serializeReq = &message{} 187 serializeResp = &message{} 188 serializeWrongMsg = &message{} 189 serializeWrongReq = &message{} 190 serializeWrongResp = &message{} 191 ) 192 // mock serializeMessage 193 serialize(t, msg, serializeMsg) 194 serialize(t, req, serializeReq) 195 serialize(t, resp, serializeResp) 196 serialize(t, wrongMsg, serializeWrongMsg) 197 serialize(t, wrongReq, serializeWrongReq) 198 serialize(t, wrongResp, serializeWrongResp) 199 200 // mock no handler 201 messageAgent := NewMessageAgentImpl("id", &MockNothing{}, p2p.NewMockMessageHandlerManager(), log.L()).(*MessageAgentImpl) 202 require.EqualError(t, messageAgent.onMessage(topic, serializeMsg), "message handler for command MessageAPI not found") 203 require.EqualError(t, messageAgent.onMessage(topic, serializeReq), "request handler for command RequestAPI not found") 204 require.EqualError(t, messageAgent.onMessage(topic, serializeResp), "response handler for command RequestAPI not found") 205 206 // mock has handler 207 mockHandler := &MockHanlder{} 208 messageAgent = NewMessageAgentImpl("id", mockHandler, p2p.NewMockMessageHandlerManager(), log.L()).(*MessageAgentImpl) 209 mockClient := &MockClient{} 210 messageAgent.UpdateClient(clientID, mockClient) 211 mockClient.On("SendMessage").Return(nil).Once() 212 // wrong handler type 213 require.Error(t, messageAgent.onMessage(topic, serializeWrongMsg), "wrong message handler type for command WrongAPI") 214 require.Error(t, messageAgent.onMessage(topic, serializeWrongReq), "wrong request handler type for command WrongAPI") 215 require.Error(t, messageAgent.onMessage(topic, serializeWrongResp), "wrong response handler type for command WrongAPI") 216 217 // handle message 218 mockHandler.On(messageAPI).Return(nil).Once() 219 require.NoError(t, messageAgent.onMessage(topic, serializeMsg)) 220 mockHandler.On(messageAPI).Return(errors.New("error")).Once() 221 require.EqualError(t, messageAgent.onMessage(topic, serializeMsg), "error") 222 // handle request 223 mockHandler.On(requestAPI).Return(&RequestAPIResponse{}, nil).Once() 224 require.NoError(t, messageAgent.onMessage(topic, serializeReq)) 225 mockHandler.On(requestAPI).Return(&RequestAPIResponse{}, errors.New("error")).Once() 226 require.EqualError(t, messageAgent.onMessage(topic, serializeReq), "error") 227 // handle response 228 require.EqualError(t, messageAgent.onMessage(topic, serializeResp), "request 1 not found") 229 } 230 231 func TestMessageHandlerLifeCycle(t *testing.T) { 232 messageAgent := NewMessageAgentImpl("id", nil, p2p.NewMockMessageHandlerManager(), log.L()) 233 messageAgent.Tick(context.Background()) 234 messageAgent.UpdateClient("client-id", &framework.MockWorkerHandler{}) 235 messageAgent.UpdateClient("client-id", nil) 236 messageAgent.Close(context.Background()) 237 } 238 239 func serialize(t *testing.T, m message, mPtr *message) { 240 bytes, err := json.Marshal(m) 241 require.NoError(t, err) 242 require.NoError(t, json.Unmarshal(bytes, mPtr)) 243 } 244 245 const ( 246 wrongAPI p2p.Topic = "WrongAPI" 247 messageAPI p2p.Topic = "MessageAPI" 248 requestAPI p2p.Topic = "RequestAPI" 249 ) 250 251 type ( 252 WrongAPIMessage struct { 253 Msg string 254 } 255 MessageAPIMessage struct { 256 Msg string 257 } 258 RequestAPIRequest struct { 259 Req string 260 } 261 RequestAPIResponse struct { 262 Resp string 263 } 264 ) 265 266 type MockNothing struct{} 267 268 type MockHanlder struct { 269 sync.Mutex 270 mock.Mock 271 } 272 273 func (m *MockHanlder) WrongAPI() {} 274 275 func (m *MockHanlder) MessageAPI(ctx context.Context, msg *MessageAPIMessage) error { 276 m.Lock() 277 defer m.Unlock() 278 args := m.Called() 279 return args.Error(0) 280 } 281 282 func (m *MockHanlder) RequestAPI(ctx context.Context, req *RequestAPIRequest) (*RequestAPIResponse, error) { 283 m.Lock() 284 defer m.Unlock() 285 args := m.Called() 286 return args.Get(0).(*RequestAPIResponse), args.Error(1) 287 } 288 289 type MockClient struct { 290 sync.Mutex 291 mock.Mock 292 } 293 294 func (s *MockClient) SendMessage(ctx context.Context, topic p2p.Topic, message interface{}, nonblocking bool) error { 295 s.Lock() 296 defer s.Unlock() 297 args := s.Called() 298 return args.Error(0) 299 }