github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/engine/pkg/dm/message_agent.go (about)

     1  // Copyright 2022 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package dm
    15  
    16  import (
    17  	"context"
    18  	"encoding/hex"
    19  	"encoding/json"
    20  	"path"
    21  	"reflect"
    22  	"strings"
    23  	"sync"
    24  	"time"
    25  
    26  	"github.com/pingcap/tiflow/engine/framework"
    27  	"github.com/pingcap/tiflow/engine/pkg/p2p"
    28  	"github.com/pingcap/tiflow/pkg/errors"
    29  	"github.com/pingcap/tiflow/pkg/logutil"
    30  	"github.com/pingcap/tiflow/pkg/workerpool"
    31  	"go.uber.org/atomic"
    32  	"go.uber.org/zap"
    33  )
    34  
    35  var (
    36  	defaultMessageTimeOut  = time.Second * 10
    37  	defaultRequestTimeOut  = time.Second * 30
    38  	defaultResponseTimeOut = time.Second * 10
    39  	defaultHandlerTimeOut  = time.Second * 30
    40  
    41  	// NewMessageAgent creates a new MessageAgent instance.
    42  	NewMessageAgent = NewMessageAgentImpl
    43  )
    44  
    45  // generateTopic generate dm message topic with hex encoding.
    46  func generateTopic(senderID string, receiverID string) string {
    47  	hexKeys := []string{"DM"}
    48  	hexKeys = append(hexKeys, hex.EncodeToString([]byte(senderID)))
    49  	hexKeys = append(hexKeys, hex.EncodeToString([]byte(receiverID)))
    50  	ret := path.Join(hexKeys...)
    51  	return ret
    52  }
    53  
    54  // extractTopic extract dm message topic with hex decoding.
    55  // TODO: handle error.
    56  func extractTopic(topic string) (string, string) {
    57  	v := strings.Split(strings.TrimPrefix(topic, "DM"), "/")
    58  	// nolint:errcheck
    59  	senderID, _ := hex.DecodeString(v[1])
    60  	// nolint:errcheck
    61  	receiverID, _ := hex.DecodeString(v[2])
    62  	return string(senderID), string(receiverID)
    63  }
    64  
    65  type messageID uint64
    66  
    67  type messageType int
    68  
    69  const (
    70  	messageTp messageType = iota + 1
    71  	requestTp
    72  	responseTp
    73  )
    74  
    75  // message use for asynchronous message and synchronous request/response.
    76  type message struct {
    77  	ID      messageID
    78  	Type    messageType
    79  	Command string
    80  	Payload interface{}
    81  }
    82  
    83  // client defines an interface that supports send message
    84  type client interface {
    85  	SendMessage(ctx context.Context, topic p2p.Topic, message interface{}, nonblocking bool) error
    86  }
    87  
    88  // messageMatcher implement a simple synchronous request/response message matcher since the lib currently only support asynchronous message.
    89  type messageMatcher struct {
    90  	// messageID -> response channel
    91  	// TODO: limit the MaxPendingMessageCount if needed.
    92  	pendings sync.Map
    93  	id       atomic.Uint64
    94  }
    95  
    96  // newMessageMatcher creates a new messageMatcher instance
    97  func newMessageMatcher() *messageMatcher {
    98  	return &messageMatcher{}
    99  }
   100  
   101  func (m *messageMatcher) allocID() messageID {
   102  	return messageID(m.id.Add(1))
   103  }
   104  
   105  // sendRequest sends a request message and wait for response.
   106  func (m *messageMatcher) sendRequest(
   107  	ctx context.Context,
   108  	clientCtx context.Context,
   109  	topic p2p.Topic,
   110  	command string,
   111  	req interface{},
   112  	client client,
   113  ) (interface{}, error) {
   114  	msg := message{ID: m.allocID(), Type: requestTp, Command: command, Payload: req}
   115  	respCh := make(chan interface{}, 1)
   116  	m.pendings.Store(msg.ID, respCh)
   117  	defer m.pendings.Delete(msg.ID)
   118  
   119  	if err := client.SendMessage(ctx, topic, msg, false /* nonblock */); err != nil {
   120  		return nil, err
   121  	}
   122  
   123  	select {
   124  	case <-ctx.Done():
   125  		return nil, ctx.Err()
   126  	case <-clientCtx.Done():
   127  		// todo: it's quite normal to see worker finished during the request, should we return nil, nil here?
   128  		return nil, errors.New("client is finished before receiving response")
   129  	case resp := <-respCh:
   130  		return resp, nil
   131  	}
   132  }
   133  
   134  // sendResponse sends a response with message ID.
   135  func (m *messageMatcher) sendResponse(ctx context.Context, topic p2p.Topic, id messageID, command string, resp interface{}, client client) error {
   136  	msg := message{ID: id, Type: responseTp, Command: command, Payload: resp}
   137  	return client.SendMessage(ctx, topic, msg, false /* nonblock */)
   138  }
   139  
   140  // onResponse receives and pairs a response message.
   141  func (m *messageMatcher) onResponse(id messageID, resp interface{}) error {
   142  	respCh, ok := m.pendings.Load(id)
   143  	if !ok {
   144  		return errors.Errorf("request %d not found", id)
   145  	}
   146  
   147  	select {
   148  	case respCh.(chan interface{}) <- resp:
   149  		return nil
   150  	default:
   151  	}
   152  	return errors.Errorf("duplicated response of request %d, and the last response is not consumed", id)
   153  }
   154  
   155  // MessageAgent defines interface for message communication.
   156  type MessageAgent interface {
   157  	Tick(ctx context.Context) error
   158  	Close(ctx context.Context) error
   159  	// UpdateClient updates the client status.
   160  	// When client online, caller should use this method to with not nil client.
   161  	// When client offline temporary, caller should use this method with nil client.
   162  	UpdateClient(clientID string, client client) error
   163  	// RemoveClient is used when client is offline permanently, or the new client
   164  	// with this clientID should be treated as a different client.
   165  	RemoveClient(clientID string) error
   166  	SendMessage(ctx context.Context, clientID string, command string, msg interface{}) error
   167  	SendRequest(ctx context.Context, clientID string, command string, req interface{}) (interface{}, error)
   168  }
   169  
   170  type clientGroup struct {
   171  	mu sync.RWMutex
   172  	// key is client-id
   173  	clients map[string]client
   174  	ctxs    map[string]context.Context
   175  	cancels map[string]context.CancelFunc
   176  }
   177  
   178  // MessageAgentImpl implements the message processing mechanism.
   179  type MessageAgentImpl struct {
   180  	ctx                   context.Context
   181  	cancel                context.CancelFunc
   182  	logger                *zap.Logger
   183  	messageMatcher        *messageMatcher
   184  	messageHandlerManager p2p.MessageHandlerManager
   185  	pool                  workerpool.AsyncPool
   186  	messageRouter         *framework.MessageRouter
   187  	wg                    sync.WaitGroup
   188  	clients               clientGroup
   189  	// when receive message/request/response,
   190  	// the corresponding processing method of commandHandler will be called according to the command name.
   191  	commandHandler interface{}
   192  	id             string
   193  }
   194  
   195  // NewMessageAgentImpl creates a new MessageAgent instance.
   196  // message agent will call the method of commandHandler by command name automatically.
   197  // The type of method of commandHandler should follow one of below:
   198  // MessageFuncType: func(ctx context.Context, msg *interface{}) error {}
   199  // RequestFuncType(1): func(ctx context.Context, req *interface{}) (resp *interface{}, err error) {}
   200  // RequestFuncType(2): func(ctx context.Context, req *interface{}) (resp *interface{}) {}
   201  func NewMessageAgentImpl(id string, commandHandler interface{}, messageHandlerManager p2p.MessageHandlerManager, pLogger *zap.Logger) MessageAgent {
   202  	agent := &MessageAgentImpl{
   203  		messageMatcher: newMessageMatcher(),
   204  		clients: clientGroup{
   205  			clients: map[string]client{},
   206  			ctxs:    map[string]context.Context{},
   207  			cancels: map[string]context.CancelFunc{},
   208  		},
   209  		commandHandler:        commandHandler,
   210  		messageHandlerManager: messageHandlerManager,
   211  		pool:                  workerpool.NewDefaultAsyncPool(10),
   212  		id:                    id,
   213  		logger:                pLogger.With(zap.String("component", "message-agent")),
   214  	}
   215  	agent.messageRouter = framework.NewMessageRouter(agent.id, agent.pool, 100,
   216  		func(topic p2p.Topic, msg p2p.MessageValue) error {
   217  			err := agent.onMessage(topic, msg)
   218  			if err != nil {
   219  				// Todo: handle error
   220  				agent.logger.Error("failed to handle message", logutil.ShortError(err))
   221  			}
   222  			return err
   223  		},
   224  	)
   225  	agent.ctx, agent.cancel = context.WithCancel(context.Background())
   226  	agent.wg.Add(1)
   227  	go func() {
   228  		defer agent.wg.Done()
   229  		err := agent.pool.Run(agent.ctx)
   230  		agent.logger.Info("workerpool exited", zap.Error(err))
   231  	}()
   232  	return agent
   233  }
   234  
   235  // Tick implements MessageAgent.Tick
   236  func (agent *MessageAgentImpl) Tick(ctx context.Context) error {
   237  	return agent.messageRouter.Tick(ctx)
   238  }
   239  
   240  // Close closes message agent.
   241  func (agent *MessageAgentImpl) Close(ctx context.Context) error {
   242  	if agent.cancel != nil {
   243  		agent.cancel()
   244  	}
   245  	agent.wg.Wait()
   246  
   247  	return nil
   248  }
   249  
   250  // UpdateClient implements MessageAgent.UpdateClient.
   251  func (agent *MessageAgentImpl) UpdateClient(clientID string, client client) error {
   252  	agent.clients.mu.Lock()
   253  	defer agent.clients.mu.Unlock()
   254  
   255  	_, ok := agent.clients.clients[clientID]
   256  	if client == nil && ok {
   257  		// delete client
   258  		if err := agent.unregisterTopic(agent.ctx, clientID); err != nil {
   259  			return err
   260  		}
   261  		delete(agent.clients.clients, clientID)
   262  	} else if client != nil && !ok {
   263  		// add client
   264  		if err := agent.registerTopic(agent.ctx, clientID); err != nil {
   265  			return err
   266  		}
   267  		agent.clients.clients[clientID] = client
   268  
   269  		// don't overwrite existing context, we allow multiple worker share same topic
   270  		if _, ok := agent.clients.ctxs[clientID]; !ok {
   271  			ctx, cancel := context.WithCancel(agent.ctx)
   272  			agent.clients.ctxs[clientID] = ctx
   273  			agent.clients.cancels[clientID] = cancel
   274  		}
   275  
   276  	}
   277  	return nil
   278  }
   279  
   280  // RemoveClient implements MessageAgent.RemoveClient.
   281  func (agent *MessageAgentImpl) RemoveClient(clientID string) error {
   282  	agent.clients.mu.Lock()
   283  	defer agent.clients.mu.Unlock()
   284  
   285  	if err := agent.unregisterTopic(agent.ctx, clientID); err != nil {
   286  		return err
   287  	}
   288  
   289  	delete(agent.clients.clients, clientID)
   290  	delete(agent.clients.ctxs, clientID)
   291  	cancel, ok := agent.clients.cancels[clientID]
   292  	if ok {
   293  		cancel()
   294  		delete(agent.clients.cancels, clientID)
   295  	}
   296  	return nil
   297  }
   298  
   299  func (agent *MessageAgentImpl) getClient(clientID string) (client, error) {
   300  	agent.clients.mu.RLock()
   301  	defer agent.clients.mu.RUnlock()
   302  
   303  	client, ok := agent.clients.clients[clientID]
   304  	if !ok {
   305  		return nil, errors.Errorf("client %s not found", clientID)
   306  	}
   307  	return client, nil
   308  }
   309  
   310  // SendMessage send message asynchronously.
   311  func (agent *MessageAgentImpl) SendMessage(ctx context.Context, clientID string, command string, msg interface{}) error {
   312  	client, err := agent.getClient(clientID)
   313  	if err != nil {
   314  		return err
   315  	}
   316  	ctx2, cancel := context.WithTimeout(ctx, defaultMessageTimeOut)
   317  	defer cancel()
   318  	agent.logger.Debug("send message", zap.String("client-id", clientID), zap.String("command", command), zap.Any("msg", msg))
   319  	return client.SendMessage(ctx2, generateTopic(agent.id, clientID), message{ID: 0, Type: messageTp, Command: command, Payload: msg}, false /* nonblock */)
   320  }
   321  
   322  // SendRequest send request synchronously.
   323  // caller should add its own retry mechanism if needed.
   324  // caller should persist the request itself if needed.
   325  func (agent *MessageAgentImpl) SendRequest(ctx context.Context, clientID string, command string, req interface{}) (interface{}, error) {
   326  	agent.clients.mu.RLock()
   327  	client, ok := agent.clients.clients[clientID]
   328  	clientCtx, ok2 := agent.clients.ctxs[clientID]
   329  	agent.clients.mu.RUnlock()
   330  
   331  	if !ok {
   332  		return nil, errors.Errorf("client %s not found", clientID)
   333  	}
   334  	if !ok2 {
   335  		return nil, errors.Errorf("client %s context not found, this should not happen", clientID)
   336  	}
   337  	ctx2, cancel := context.WithTimeout(ctx, defaultRequestTimeOut)
   338  	defer cancel()
   339  	agent.logger.Debug("send request", zap.String("client-id", clientID), zap.String("command", command), zap.Any("req", req))
   340  	return agent.messageMatcher.sendRequest(ctx2, clientCtx, generateTopic(agent.id, clientID), command, req, client)
   341  }
   342  
   343  // sendResponse send response asynchronously.
   344  func (agent *MessageAgentImpl) sendResponse(ctx context.Context, clientID string, msgID messageID, command string, resp interface{}) error {
   345  	client, err := agent.getClient(clientID)
   346  	if err != nil {
   347  		return err
   348  	}
   349  	ctx2, cancel := context.WithTimeout(ctx, defaultResponseTimeOut)
   350  	defer cancel()
   351  	agent.logger.Debug("send response", zap.String("client-id", clientID), zap.String("command", command), zap.Any("resp", resp))
   352  	return agent.messageMatcher.sendResponse(ctx2, generateTopic(agent.id, clientID), msgID, command, resp, client)
   353  }
   354  
   355  // onMessage receive message/request/response.
   356  // Forward the response to the corresponding request.
   357  // According to the command, the corresponding message processing function of commandHandler will be called.
   358  // According to the command, the corresponding request processing function of commandHandler will be called, and send the response to caller.
   359  func (agent *MessageAgentImpl) onMessage(topic string, msg interface{}) error {
   360  	agent.logger.Debug("on message", zap.String("topic", topic), zap.Any("msg", msg))
   361  	m, ok := msg.(*message)
   362  	if !ok {
   363  		return errors.Errorf("unknown message type of topic %s", topic)
   364  	}
   365  
   366  	switch m.Type {
   367  	case responseTp:
   368  		return agent.handleResponse(m.ID, m.Command, m.Payload)
   369  	case requestTp:
   370  		clientID, _ := extractTopic(topic)
   371  		return agent.handleRequest(clientID, m.ID, m.Command, m.Payload)
   372  	default:
   373  		return agent.handleMessage(m.Command, m.Payload)
   374  	}
   375  }
   376  
   377  // handleResponse receive response.
   378  func (agent *MessageAgentImpl) handleResponse(id messageID, command string, resp interface{}) error {
   379  	handler := reflect.ValueOf(agent.commandHandler).MethodByName(command)
   380  	if !handler.IsValid() {
   381  		return errors.Errorf("response handler for command %s not found", command)
   382  	}
   383  	handlerType := handler.Type()
   384  	if handlerType.NumOut() != 1 && handlerType.NumOut() != 2 {
   385  		return errors.Errorf("wrong response handler type for command %s", command)
   386  	}
   387  	ret := reflect.New(handlerType.Out(0).Elem())
   388  	if bytes, err := json.Marshal(resp); err != nil {
   389  		return err
   390  	} else if err := json.Unmarshal(bytes, ret.Interface()); err != nil {
   391  		return err
   392  	}
   393  	return agent.messageMatcher.onResponse(id, ret.Interface())
   394  }
   395  
   396  // handleRequest receive request, call request handler and send response.
   397  func (agent *MessageAgentImpl) handleRequest(clientID string, msgID messageID, command string, req interface{}) error {
   398  	handler := reflect.ValueOf(agent.commandHandler).MethodByName(command)
   399  	if !handler.IsValid() {
   400  		return errors.Errorf("request handler for command %s not found", command)
   401  	}
   402  	handlerType := handler.Type()
   403  	if handlerType.NumIn() != 2 || (handlerType.NumOut() != 1 && handlerType.NumOut() != 2) {
   404  		return errors.Errorf("wrong request handler type for command %s", command)
   405  	}
   406  	arg := reflect.New(handlerType.In(1).Elem())
   407  	if bytes, err := json.Marshal(req); err != nil {
   408  		return err
   409  	} else if err := json.Unmarshal(bytes, arg.Interface()); err != nil {
   410  		return err
   411  	}
   412  
   413  	// call request handler
   414  	ctx, cancel := context.WithTimeout(agent.ctx, defaultHandlerTimeOut)
   415  	defer cancel()
   416  	params := []reflect.Value{reflect.ValueOf(ctx), arg}
   417  	rets := handler.Call(params)
   418  	if len(rets) == 2 && rets[1].Interface() != nil {
   419  		return rets[1].Interface().(error)
   420  	}
   421  	// send response
   422  	ctx2, cancel2 := context.WithTimeout(agent.ctx, defaultResponseTimeOut)
   423  	defer cancel2()
   424  	return agent.sendResponse(ctx2, clientID, msgID, command, rets[0].Interface())
   425  }
   426  
   427  // handle message receive message and call message handler.
   428  func (agent *MessageAgentImpl) handleMessage(command string, msg interface{}) error {
   429  	handler := reflect.ValueOf(agent.commandHandler).MethodByName(command)
   430  	if !handler.IsValid() {
   431  		return errors.Errorf("message handler for command %s not found", command)
   432  	}
   433  	handlerType := handler.Type()
   434  	if handlerType.NumIn() != 2 || handlerType.NumOut() != 1 {
   435  		return errors.Errorf("wrong message handler type for command %s", command)
   436  	}
   437  	arg := reflect.New(handlerType.In(1).Elem())
   438  	if bytes, err := json.Marshal(msg); err != nil {
   439  		return err
   440  	} else if err := json.Unmarshal(bytes, arg.Interface()); err != nil {
   441  		return err
   442  	}
   443  
   444  	// call message handler
   445  	ctx, cancel := context.WithTimeout(agent.ctx, defaultHandlerTimeOut)
   446  	defer cancel()
   447  	params := []reflect.Value{reflect.ValueOf(ctx), arg}
   448  	err := handler.Call(params)[0].Interface()
   449  	if err == nil {
   450  		return nil
   451  	}
   452  	return err.(error)
   453  }
   454  
   455  // registerTopic register p2p topic.
   456  func (agent *MessageAgentImpl) registerTopic(ctx context.Context, clientID string) error {
   457  	topic := generateTopic(clientID, agent.id)
   458  	agent.logger.Debug("register topic", zap.String("topic", topic))
   459  	_, err := agent.messageHandlerManager.RegisterHandler(
   460  		ctx,
   461  		topic,
   462  		&message{},
   463  		func(client p2p.NodeID, msg p2p.MessageValue) error {
   464  			agent.messageRouter.AppendMessage(topic, msg)
   465  			return nil
   466  		},
   467  	)
   468  	return err
   469  }
   470  
   471  // unregisterTopic unregister p2p topic.
   472  func (agent *MessageAgentImpl) unregisterTopic(ctx context.Context, clientID string) error {
   473  	agent.logger.Debug("unregister topic", zap.String("topic", generateTopic(clientID, agent.id)))
   474  	_, err := agent.messageHandlerManager.UnregisterHandler(ctx, generateTopic(clientID, agent.id))
   475  	return err
   476  }