github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/engine/pkg/dm/message_agent_test.go (about)

     1  // Copyright 2022 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package dm
    15  
    16  import (
    17  	"context"
    18  	"encoding/json"
    19  	"fmt"
    20  	"sync"
    21  	"testing"
    22  	"time"
    23  
    24  	"github.com/pingcap/log"
    25  	"github.com/pingcap/tiflow/engine/framework"
    26  	"github.com/pingcap/tiflow/engine/pkg/p2p"
    27  	"github.com/pingcap/tiflow/pkg/errors"
    28  	"github.com/stretchr/testify/mock"
    29  	"github.com/stretchr/testify/require"
    30  )
    31  
    32  func TestAllocID(t *testing.T) {
    33  	t.Parallel()
    34  
    35  	messageMatcher := newMessageMatcher()
    36  	require.Equal(t, messageID(1), messageMatcher.allocID())
    37  	require.Equal(t, messageID(2), messageMatcher.allocID())
    38  	require.Equal(t, messageID(3), messageMatcher.allocID())
    39  
    40  	var wg sync.WaitGroup
    41  	for i := 0; i < 10; i++ {
    42  		wg.Add(1)
    43  		go func() {
    44  			defer wg.Done()
    45  			for i := 0; i < 100; i++ {
    46  				messageMatcher.allocID()
    47  			}
    48  		}()
    49  	}
    50  	wg.Wait()
    51  	require.Equal(t, messageID(1004), messageMatcher.allocID())
    52  }
    53  
    54  func TestMessageMatcher(t *testing.T) {
    55  	t.Parallel()
    56  
    57  	messageMatcher := newMessageMatcher()
    58  	mockClient := &MockClient{}
    59  	ctx, cancel := context.WithCancel(context.Background())
    60  	defer cancel()
    61  
    62  	clientCtx := context.Background()
    63  
    64  	messageErr := errors.New("message error")
    65  	// synchronous send
    66  	mockClient.On("SendMessage").Return(messageErr).Once()
    67  	resp, err := messageMatcher.sendRequest(ctx, clientCtx, "topic", "command", "request", mockClient)
    68  	require.EqualError(t, err, messageErr.Error())
    69  	require.Nil(t, resp)
    70  	// deadline exceeded
    71  	mockClient.On("SendMessage").Return(nil).Once()
    72  	ctx2, cancel2 := context.WithTimeout(ctx, 100*time.Millisecond)
    73  	defer cancel2()
    74  	resp, err = messageMatcher.sendRequest(ctx2, clientCtx, "topic", "command", "request", mockClient)
    75  	require.EqualError(t, err, context.DeadlineExceeded.Error())
    76  	require.Nil(t, resp)
    77  	// late response
    78  	require.EqualError(t, messageMatcher.onResponse(2, "response"), "request 2 not found")
    79  
    80  	resp2 := "response"
    81  	go func() {
    82  		mockClient.On("SendMessage").Return(nil).Once()
    83  		ctx3, cancel3 := context.WithTimeout(ctx, 5*time.Second)
    84  		defer cancel3()
    85  		resp3, err := messageMatcher.sendRequest(ctx3, clientCtx, "request-topic", "command", "request", mockClient)
    86  		require.NoError(t, err)
    87  		require.Equal(t, "response", resp3)
    88  	}()
    89  
    90  	// send response
    91  	time.Sleep(time.Second)
    92  	mockClient.On("SendMessage").Return(nil).Once()
    93  	require.NoError(t, messageMatcher.sendResponse(ctx, "response-topic", 3, "command", resp2, mockClient))
    94  	require.NoError(t, messageMatcher.onResponse(3, resp2))
    95  
    96  	// duplicate response
    97  	require.Eventually(t, func() bool {
    98  		err := messageMatcher.onResponse(3, resp2)
    99  		return err != nil && err.Error() == fmt.Sprintf("request %d not found", 3)
   100  	}, 5*time.Second, 100*time.Millisecond)
   101  }
   102  
   103  func TestUpdateClient(t *testing.T) {
   104  	messageAgent := NewMessageAgentImpl("", nil, p2p.NewMockMessageHandlerManager(), log.L()).(*MessageAgentImpl)
   105  	workerHandle1 := &framework.MockHandle{WorkerID: "worker1"}
   106  	workerHandle2 := &framework.MockHandle{WorkerID: "worker2"}
   107  
   108  	// add client
   109  	messageAgent.UpdateClient("task1", workerHandle1.Unwrap())
   110  	require.Len(t, messageAgent.clients.clients, 1)
   111  	client, err := messageAgent.getClient("task1")
   112  	require.NoError(t, err)
   113  	require.Equal(t, client, workerHandle1.Unwrap())
   114  	client, err = messageAgent.getClient("task2")
   115  	require.EqualError(t, err, "client task2 not found")
   116  	require.Equal(t, client, nil)
   117  	messageAgent.UpdateClient("task2", workerHandle2.Unwrap())
   118  	require.Len(t, messageAgent.clients.clients, 2)
   119  	client, err = messageAgent.getClient("task1")
   120  	require.NoError(t, err)
   121  	require.Equal(t, client, workerHandle1.Unwrap())
   122  	client, err = messageAgent.getClient("task2")
   123  	require.NoError(t, err)
   124  	require.Equal(t, client, workerHandle2.Unwrap())
   125  
   126  	// remove client
   127  	messageAgent.UpdateClient("task3", nil)
   128  	require.Len(t, messageAgent.clients.clients, 2)
   129  	client, err = messageAgent.getClient("task1")
   130  	require.NoError(t, err)
   131  	require.Equal(t, client, workerHandle1.Unwrap())
   132  	client, err = messageAgent.getClient("task2")
   133  	require.NoError(t, err)
   134  	require.Equal(t, client, workerHandle2.Unwrap())
   135  	messageAgent.RemoveClient("task2")
   136  	require.Len(t, messageAgent.clients.clients, 1)
   137  	client, err = messageAgent.getClient("task1")
   138  	require.NoError(t, err)
   139  	require.Equal(t, client, workerHandle1.Unwrap())
   140  }
   141  
   142  func TestMessageAgent(t *testing.T) {
   143  	messageAgent := NewMessageAgentImpl("id", nil, p2p.NewMockMessageHandlerManager(), log.L()).(*MessageAgentImpl)
   144  	clientID := "client-id"
   145  	mockClient := &MockClient{}
   146  	messageAgent.UpdateClient(clientID, mockClient)
   147  
   148  	require.Error(t, messageAgent.SendMessage(context.Background(), "wrong-id", "command", "msg"), "client wrong-id not found")
   149  	require.Error(t, messageAgent.sendResponse(context.Background(), "wrong-id", 1, "command", "resp"), "client wrong-id not found")
   150  	ret, err := messageAgent.SendRequest(context.Background(), "wrong-id", "command", "request")
   151  	require.EqualError(t, err, "client wrong-id not found")
   152  	require.Nil(t, ret)
   153  
   154  	mockClient.On("SendMessage").Return(nil).Once()
   155  	require.NoError(t, messageAgent.SendMessage(context.Background(), clientID, "command", "msg"))
   156  
   157  	resp := "response"
   158  	go func() {
   159  		mockClient.On("SendMessage").Return(nil).Once()
   160  		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
   161  		defer cancel()
   162  		resp2, err := messageAgent.SendRequest(ctx, clientID, "command", "request")
   163  		require.NoError(t, err)
   164  		require.Equal(t, resp, resp2)
   165  	}()
   166  
   167  	time.Sleep(time.Second)
   168  	// send response
   169  	mockClient.On("SendMessage").Return(nil).Once()
   170  	require.NoError(t, messageAgent.sendResponse(context.Background(), clientID, 2, "command", "response"))
   171  	messageAgent.onMessage(generateTopic("Client", "Receiver"), message{ID: 2, Type: responseTp, Payload: resp})
   172  }
   173  
   174  func TestMessageHandler(t *testing.T) {
   175  	var (
   176  		clientID           = "client-id"
   177  		receiveID          = "receiver-id"
   178  		topic              = generateTopic(clientID, receiveID)
   179  		msg                = message{ID: 0, Type: messageTp, Command: messageAPI, Payload: &MessageAPIMessage{Msg: "msg"}}
   180  		req                = message{ID: 1, Type: requestTp, Command: requestAPI, Payload: &RequestAPIRequest{Req: "req"}}
   181  		resp               = message{ID: 1, Type: responseTp, Command: requestAPI, Payload: &RequestAPIResponse{Resp: "resp"}}
   182  		wrongMsg           = message{ID: 0, Type: messageTp, Command: wrongAPI, Payload: &MessageAPIMessage{Msg: "msg"}}
   183  		wrongReq           = message{ID: 0, Type: requestTp, Command: wrongAPI, Payload: &MessageAPIMessage{Msg: "msg"}}
   184  		wrongResp          = message{ID: 0, Type: responseTp, Command: wrongAPI, Payload: &MessageAPIMessage{Msg: "msg"}}
   185  		serializeMsg       = &message{}
   186  		serializeReq       = &message{}
   187  		serializeResp      = &message{}
   188  		serializeWrongMsg  = &message{}
   189  		serializeWrongReq  = &message{}
   190  		serializeWrongResp = &message{}
   191  	)
   192  	// mock serializeMessage
   193  	serialize(t, msg, serializeMsg)
   194  	serialize(t, req, serializeReq)
   195  	serialize(t, resp, serializeResp)
   196  	serialize(t, wrongMsg, serializeWrongMsg)
   197  	serialize(t, wrongReq, serializeWrongReq)
   198  	serialize(t, wrongResp, serializeWrongResp)
   199  
   200  	// mock no handler
   201  	messageAgent := NewMessageAgentImpl("id", &MockNothing{}, p2p.NewMockMessageHandlerManager(), log.L()).(*MessageAgentImpl)
   202  	require.EqualError(t, messageAgent.onMessage(topic, serializeMsg), "message handler for command MessageAPI not found")
   203  	require.EqualError(t, messageAgent.onMessage(topic, serializeReq), "request handler for command RequestAPI not found")
   204  	require.EqualError(t, messageAgent.onMessage(topic, serializeResp), "response handler for command RequestAPI not found")
   205  
   206  	// mock has handler
   207  	mockHandler := &MockHanlder{}
   208  	messageAgent = NewMessageAgentImpl("id", mockHandler, p2p.NewMockMessageHandlerManager(), log.L()).(*MessageAgentImpl)
   209  	mockClient := &MockClient{}
   210  	messageAgent.UpdateClient(clientID, mockClient)
   211  	mockClient.On("SendMessage").Return(nil).Once()
   212  	// wrong handler type
   213  	require.Error(t, messageAgent.onMessage(topic, serializeWrongMsg), "wrong message handler type for command WrongAPI")
   214  	require.Error(t, messageAgent.onMessage(topic, serializeWrongReq), "wrong request handler type for command WrongAPI")
   215  	require.Error(t, messageAgent.onMessage(topic, serializeWrongResp), "wrong response handler type for command WrongAPI")
   216  
   217  	// handle message
   218  	mockHandler.On(messageAPI).Return(nil).Once()
   219  	require.NoError(t, messageAgent.onMessage(topic, serializeMsg))
   220  	mockHandler.On(messageAPI).Return(errors.New("error")).Once()
   221  	require.EqualError(t, messageAgent.onMessage(topic, serializeMsg), "error")
   222  	// handle request
   223  	mockHandler.On(requestAPI).Return(&RequestAPIResponse{}, nil).Once()
   224  	require.NoError(t, messageAgent.onMessage(topic, serializeReq))
   225  	mockHandler.On(requestAPI).Return(&RequestAPIResponse{}, errors.New("error")).Once()
   226  	require.EqualError(t, messageAgent.onMessage(topic, serializeReq), "error")
   227  	// handle response
   228  	require.EqualError(t, messageAgent.onMessage(topic, serializeResp), "request 1 not found")
   229  }
   230  
   231  func TestMessageHandlerLifeCycle(t *testing.T) {
   232  	messageAgent := NewMessageAgentImpl("id", nil, p2p.NewMockMessageHandlerManager(), log.L())
   233  	messageAgent.Tick(context.Background())
   234  	messageAgent.UpdateClient("client-id", &framework.MockWorkerHandler{})
   235  	messageAgent.UpdateClient("client-id", nil)
   236  	messageAgent.Close(context.Background())
   237  }
   238  
   239  func serialize(t *testing.T, m message, mPtr *message) {
   240  	bytes, err := json.Marshal(m)
   241  	require.NoError(t, err)
   242  	require.NoError(t, json.Unmarshal(bytes, mPtr))
   243  }
   244  
   245  const (
   246  	wrongAPI   p2p.Topic = "WrongAPI"
   247  	messageAPI p2p.Topic = "MessageAPI"
   248  	requestAPI p2p.Topic = "RequestAPI"
   249  )
   250  
   251  type (
   252  	WrongAPIMessage struct {
   253  		Msg string
   254  	}
   255  	MessageAPIMessage struct {
   256  		Msg string
   257  	}
   258  	RequestAPIRequest struct {
   259  		Req string
   260  	}
   261  	RequestAPIResponse struct {
   262  		Resp string
   263  	}
   264  )
   265  
// MockNothing is a command handler with no methods. TestMessageHandler passes
// it to the agent to exercise the "handler ... not found" error paths.
type MockNothing struct{}
   267  
   268  type MockHanlder struct {
   269  	sync.Mutex
   270  	mock.Mock
   271  }
   272  
   273  func (m *MockHanlder) WrongAPI() {}
   274  
   275  func (m *MockHanlder) MessageAPI(ctx context.Context, msg *MessageAPIMessage) error {
   276  	m.Lock()
   277  	defer m.Unlock()
   278  	args := m.Called()
   279  	return args.Error(0)
   280  }
   281  
   282  func (m *MockHanlder) RequestAPI(ctx context.Context, req *RequestAPIRequest) (*RequestAPIResponse, error) {
   283  	m.Lock()
   284  	defer m.Unlock()
   285  	args := m.Called()
   286  	return args.Get(0).(*RequestAPIResponse), args.Error(1)
   287  }
   288  
   289  type MockClient struct {
   290  	sync.Mutex
   291  	mock.Mock
   292  }
   293  
   294  func (s *MockClient) SendMessage(ctx context.Context, topic p2p.Topic, message interface{}, nonblocking bool) error {
   295  	s.Lock()
   296  	defer s.Unlock()
   297  	args := s.Called()
   298  	return args.Error(0)
   299  }