github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/pkg/p2p/server.go

     1  // Copyright 2021 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package p2p
    15  
    16  import (
    17  	"context"
    18  	"reflect"
    19  	"sync"
    20  	"sync/atomic"
    21  	"time"
    22  
    23  	"github.com/pingcap/errors"
    24  	"github.com/pingcap/failpoint"
    25  	"github.com/pingcap/log"
    26  	cerror "github.com/pingcap/tiflow/pkg/errors"
    27  	"github.com/pingcap/tiflow/pkg/workerpool"
    28  	"github.com/pingcap/tiflow/proto/p2p"
    29  	"github.com/prometheus/client_golang/prometheus"
    30  	"go.uber.org/zap"
    31  	"golang.org/x/sync/errgroup"
    32  	"golang.org/x/time/rate"
    33  	"google.golang.org/grpc/codes"
    34  	gRPCPeer "google.golang.org/grpc/peer"
    35  	"google.golang.org/grpc/status"
    36  )
    37  
    38  const (
    39  	messageServerReportsIndividualMessageSize = true
    40  )
    41  
    42  // MessageServerConfig stores configurations for the MessageServer
    43  type MessageServerConfig struct {
    44  	// The maximum number of entries to be cached for topics with no handler registered
    45  	MaxPendingMessageCountPerTopic int
    46  	// The maximum number of unhandled internal tasks for the main thread.
    47  	MaxPendingTaskCount int
    48  	// The size of the channel for pending messages before sending them to gRPC.
    49  	SendChannelSize int
    50  	// The interval between ACKs.
    51  	AckInterval time.Duration
    52  	// The size of the goroutine pool for running the handlers.
    53  	WorkerPoolSize int
    54  	// The maximum send rate per stream (per peer).
    55  	SendRateLimitPerStream float64
     56  	// The maximum number of peers accepted by this server
    57  	MaxPeerCount int
    58  	// Semver of the server. Empty string means no version check.
    59  	ServerVersion string
    60  	// MaxRecvMsgSize is the maximum message size in bytes TiCDC can receive.
    61  	MaxRecvMsgSize int
    62  
     63  	// After a duration of this time, if the server doesn't see any activity, it
     64  	// pings the client to check whether the transport is still alive.
    65  	KeepAliveTime time.Duration
    66  
     67  	// After having pinged for a keepalive check, the server waits for a duration
     68  	// of KeepAliveTimeout, and if no activity is seen even after that, the
     69  	// connection is closed.
    70  	KeepAliveTimeout time.Duration
    71  
    72  	// The maximum time duration to wait before forcefully removing a handler.
    73  	//
     74  	// WaitUnregisterHandleTimeoutThreshold specifies how long to wait for
     75  	// the topic handler to consume all pending messages before
     76  	// forcefully unregistering the handler.
    77  	// For a correct implementation of the handler, the time it needs
    78  	// to consume these messages is minimal, as the handler is not
    79  	// expected to block on channels, etc.
    80  	WaitUnregisterHandleTimeoutThreshold time.Duration
    81  }
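
For orientation, a minimal sketch of populating this config is shown below; the values are placeholders chosen for readability and are assumptions, not TiCDC's defaults.

```go
// A sketch assumed to live alongside this file; the values are illustrative,
// not TiCDC's defaults.
package p2p

import "time"

func exampleServerConfig() *MessageServerConfig {
	return &MessageServerConfig{
		MaxPendingMessageCountPerTopic:       256,
		MaxPendingTaskCount:                  10240,
		SendChannelSize:                      16,
		AckInterval:                          200 * time.Millisecond,
		WorkerPoolSize:                       4,
		SendRateLimitPerStream:               1024,
		MaxPeerCount:                         1024,
		ServerVersion:                        "", // empty string skips the version check
		MaxRecvMsgSize:                       256 * 1024 * 1024,
		KeepAliveTime:                        30 * time.Second,
		KeepAliveTimeout:                     10 * time.Second,
		WaitUnregisterHandleTimeoutThreshold: 10 * time.Second,
	}
}
```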
    82  
    83  // cdcPeer is used to store information on one connected client.
    84  type cdcPeer struct {
    85  	PeerID string // unique ID of the client
    86  
    87  	// Epoch is increased when the client retries.
    88  	// It is used to avoid two streams from the same client racing with each other.
    89  	// This can happen because the MessageServer might not immediately know that
    90  	// a stream has become stale.
    91  	Epoch int64
    92  
     93  	// sender is the stream handle used to send responses back to this peer.
    94  	sender *streamHandle
    95  
    96  	// valid says whether the peer is valid.
    97  	// Note that it does not need to be thread-safe
    98  	// because it should only be accessed in MessageServer.run().
    99  	valid bool
   100  
   101  	metricsAckCount prometheus.Counter
   102  }
   103  
   104  func newCDCPeer(senderID NodeID, epoch int64, sender *streamHandle) *cdcPeer {
   105  	return &cdcPeer{
   106  		PeerID: senderID,
   107  		Epoch:  epoch,
   108  		sender: sender,
   109  		valid:  true,
   110  		metricsAckCount: serverAckCount.With(prometheus.Labels{
   111  			"to": senderID,
   112  		}),
   113  	}
   114  }
   115  
   116  func (p *cdcPeer) abort(ctx context.Context, err error) {
   117  	if !p.valid {
   118  		log.Panic("p2p: aborting invalid peer", zap.String("peer", p.PeerID))
   119  	}
   120  
   121  	defer func() {
   122  		p.valid = false
   123  	}()
   124  	if sendErr := p.sender.Send(ctx, errorToRPCResponse(err)); sendErr != nil {
   125  		log.Warn("could not send error to peer", zap.Error(err),
   126  			zap.NamedError("sendErr", sendErr))
   127  		return
   128  	}
   129  	log.Debug("sent error to peer", zap.Error(err))
   130  }
   131  
   132  // MessageServer is an implementation of the gRPC server for the peer-to-peer system
   133  type MessageServer struct {
   134  	serverID NodeID
   135  
   136  	// Each topic has at most one registered event handle,
   137  	// registered with a WorkerPool.
   138  	handlers map[Topic]workerpool.EventHandle
   139  
   140  	peerLock sync.RWMutex
   141  	peers    map[string]*cdcPeer // all currently connected clients
   142  
    143  	// pendingMessages stores messages for topics that have NO registered handler.
    144  	// This can happen when the server is slow to register handlers.
   145  	// The upper limit of pending messages is restricted by
   146  	// MaxPendingMessageCountPerTopic in MessageServerConfig.
   147  	pendingMessages map[topicSenderPair][]pendingMessageEntry
   148  
   149  	acks *ackManager
   150  
   151  	// taskQueue is used to store internal tasks MessageServer
   152  	// needs to execute serially.
   153  	taskQueue chan interface{}
   154  
   155  	// The WorkerPool instance used to execute message handlers.
   156  	pool workerpool.WorkerPool
   157  
   158  	isRunning int32 // atomic
   159  	closeCh   chan struct{}
   160  
   161  	config *MessageServerConfig // read only
   162  }
   163  
   164  type taskOnMessageBatch struct {
   165  	// for grpc msgs
   166  	streamMeta     *p2p.StreamMeta
   167  	messageEntries []*p2p.MessageEntry
   168  
   169  	// for internal msgs
   170  	rawMessageEntries []RawMessageEntry
   171  }
   172  
   173  type taskOnRegisterPeer struct {
   174  	sender     *streamHandle
   175  	clientAddr string // for logging
   176  }
   177  
   178  type taskOnDeregisterPeer struct {
   179  	peerID string
   180  }
   181  
   182  type taskOnRegisterHandler struct {
   183  	topic   string
   184  	handler workerpool.EventHandle
   185  	done    chan struct{}
   186  }
   187  
   188  type taskOnDeregisterHandler struct {
   189  	topic string
   190  	done  chan struct{}
   191  }
   192  
   193  // taskDebugDelay is used in unit tests to artificially block the main
   194  // goroutine of the server. It is not used in other places.
   195  type taskDebugDelay struct {
   196  	doneCh chan struct{}
   197  }
   198  
   199  // NewMessageServer creates a new MessageServer
   200  func NewMessageServer(serverID NodeID, config *MessageServerConfig) *MessageServer {
   201  	return &MessageServer{
   202  		serverID:        serverID,
   203  		handlers:        make(map[string]workerpool.EventHandle),
   204  		peers:           make(map[string]*cdcPeer),
   205  		pendingMessages: make(map[topicSenderPair][]pendingMessageEntry),
   206  		acks:            newAckManager(),
   207  		taskQueue:       make(chan interface{}, config.MaxPendingTaskCount),
   208  		pool:            workerpool.NewDefaultWorkerPool(config.WorkerPoolSize),
   209  		closeCh:         make(chan struct{}),
   210  		config:          config,
   211  	}
   212  }
   213  
   214  // Run starts the MessageServer's worker goroutines.
   215  // It must be running to provide the gRPC service.
   216  func (m *MessageServer) Run(ctx context.Context, localCh <-chan RawMessageEntry) error {
   217  	atomic.StoreInt32(&m.isRunning, 1)
   218  	defer func() {
   219  		atomic.StoreInt32(&m.isRunning, 0)
   220  		close(m.closeCh)
   221  	}()
   222  
   223  	errg, ctx := errgroup.WithContext(ctx)
   224  	errg.Go(func() error {
   225  		return errors.Trace(m.run(ctx))
   226  	})
   227  
   228  	errg.Go(func() error {
   229  		return errors.Trace(m.pool.Run(ctx))
   230  	})
   231  
   232  	if localCh != nil {
   233  		errg.Go(func() error {
   234  			return errors.Trace(m.receiveLocalMessage(ctx, localCh))
   235  		})
   236  	}
   237  
   238  	return errg.Wait()
   239  }
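
As a usage sketch (not taken verbatim from TiCDC's own wiring code): the MessageServer is registered with a grpc.Server and Run is driven alongside Serve. The node ID, listen address, and the RegisterCDCPeerToPeerServer call are assumptions based on the standard gRPC code generated for the CDCPeerToPeer service.

```go
package main

import (
	"context"
	"net"
	"time"

	"github.com/pingcap/tiflow/pkg/p2p"
	proto "github.com/pingcap/tiflow/proto/p2p"
	"golang.org/x/sync/errgroup"
	"google.golang.org/grpc"
)

func main() {
	// Minimal config; see the config sketch above for the remaining fields.
	cfg := &p2p.MessageServerConfig{
		MaxPendingMessageCountPerTopic: 256,
		MaxPendingTaskCount:            10240,
		SendChannelSize:                16,
		AckInterval:                    200 * time.Millisecond,
		WorkerPoolSize:                 4,
		SendRateLimitPerStream:         1024,
		MaxPeerCount:                   1024,
	}
	srv := p2p.NewMessageServer("capture-1", cfg) // "capture-1" is a made-up node ID

	grpcServer := grpc.NewServer()
	proto.RegisterCDCPeerToPeerServer(grpcServer, srv)

	errg, ctx := errgroup.WithContext(context.Background())
	errg.Go(func() error {
		return srv.Run(ctx, nil) // nil: no local (in-process) message channel
	})
	errg.Go(func() error {
		lis, err := net.Listen("tcp", "127.0.0.1:8300")
		if err != nil {
			return err
		}
		return grpcServer.Serve(lis)
	})
	if err := errg.Wait(); err != nil {
		panic(err)
	}
}
```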
   240  
   241  func (m *MessageServer) run(ctx context.Context) error {
   242  	ticker := time.NewTicker(m.config.AckInterval)
   243  	defer ticker.Stop()
   244  
   245  	for {
   246  		failpoint.Inject("ServerInjectTaskDelay", func() {
   247  			log.Info("channel size", zap.Int("len", len(m.taskQueue)))
   248  		})
   249  		select {
   250  		case <-ctx.Done():
   251  			return errors.Trace(ctx.Err())
   252  		case <-ticker.C:
   253  			m.tick(ctx)
   254  		case task := <-m.taskQueue:
   255  			switch task := task.(type) {
   256  			case taskOnMessageBatch:
   257  				for _, entry := range task.rawMessageEntries {
   258  					m.handleRawMessage(ctx, entry)
   259  				}
   260  				for _, entry := range task.messageEntries {
   261  					m.handleMessage(ctx, task.streamMeta, entry)
   262  				}
   263  			case taskOnRegisterHandler:
    264  				// FIXME: better error handling here.
    265  				// Note: registering a handler is not expected to fail unless the context is cancelled.
    266  				// The current error handling will cause the server to exit, which is not ideal,
    267  				// but it will not interrupt service, because the `ctx` involved here is not
    268  				// cancelled unless the server is exiting.
   269  				m.registerHandler(ctx, task.topic, task.handler, task.done)
   270  				log.Debug("handler registered", zap.String("topic", task.topic))
   271  			case taskOnDeregisterHandler:
   272  				if handler, ok := m.handlers[task.topic]; ok {
   273  					delete(m.handlers, task.topic)
   274  					go func() {
   275  						err := handler.GracefulUnregister(ctx, m.config.WaitUnregisterHandleTimeoutThreshold)
   276  						if err != nil {
   277  							// This can only happen if `ctx` is cancelled or the workerpool
   278  							// fails to unregister the handle in time, which can be caused
   279  							// by inappropriate blocking inside the handler.
    280  							// Any unexpected blocking inside the handler should be caught
    281  							// in tests; in production we only log a warning here to provide
    282  							// better resilience.
   283  							//
   284  							// Note: Even if `GracefulUnregister` does fail, the handle is still
   285  							// unregistered, only forcefully.
   286  							log.Warn("failed to gracefully unregister handle",
   287  								zap.Error(err))
   288  						}
   289  						log.Debug("handler deregistered", zap.String("topic", task.topic))
   290  						if task.done != nil {
   291  							close(task.done)
   292  						}
   293  					}()
   294  				} else {
   295  					// This is to make deregistering a handler idempotent.
   296  					// Idempotency here will simplify error handling for the callers of this package.
   297  					log.Warn("handler not found", zap.String("topic", task.topic))
   298  					if task.done != nil {
   299  						close(task.done)
   300  					}
   301  				}
   302  			case taskOnRegisterPeer:
   303  				log.Debug("taskOnRegisterPeer",
   304  					zap.String("sender", task.sender.GetStreamMeta().SenderId),
   305  					zap.Int64("epoch", task.sender.GetStreamMeta().Epoch))
   306  				if err := m.registerPeer(ctx, task.sender, task.clientAddr); err != nil {
   307  					if cerror.ErrPeerMessageStaleConnection.Equal(err) || cerror.ErrPeerMessageDuplicateConnection.Equal(err) {
   308  						// These two errors should not affect other peers
   309  						if err1 := task.sender.Send(ctx, errorToRPCResponse(err)); err1 != nil {
   310  							return errors.Trace(err)
   311  						}
    312  						continue // to handle the next task
   313  					}
   314  					return errors.Trace(err)
   315  				}
   316  			case taskOnDeregisterPeer:
   317  				log.Info("taskOnDeregisterPeer", zap.String("peerID", task.peerID))
   318  				m.deregisterPeerByID(ctx, task.peerID)
   319  			case taskDebugDelay:
   320  				log.Info("taskDebugDelay started")
   321  				select {
   322  				case <-ctx.Done():
   323  					log.Info("taskDebugDelay canceled")
   324  					return errors.Trace(ctx.Err())
   325  				case <-task.doneCh:
   326  				}
   327  				log.Info("taskDebugDelay ended")
   328  			}
   329  		}
   330  	}
   331  }
   332  
   333  func (m *MessageServer) tick(ctx context.Context) {
   334  	var peersToDeregister []*cdcPeer
   335  	defer func() {
   336  		for _, peer := range peersToDeregister {
   337  			// err is nil because the peers are gone already, so sending errors will not succeed.
   338  			m.deregisterPeer(ctx, peer, nil)
   339  		}
   340  	}()
   341  
   342  	m.peerLock.RLock()
   343  	defer m.peerLock.RUnlock()
   344  
   345  	for _, peer := range m.peers {
   346  		var acks []*p2p.Ack
   347  		m.acks.Range(peer.PeerID, func(topic Topic, seq Seq) bool {
   348  			acks = append(acks, &p2p.Ack{
   349  				Topic:   topic,
   350  				LastSeq: seq,
   351  			})
   352  			return true
   353  		})
   354  
   355  		if len(acks) == 0 {
   356  			continue
   357  		}
   358  
   359  		peer.metricsAckCount.Inc()
   360  		err := peer.sender.Send(ctx, p2p.SendMessageResponse{
   361  			Ack:        acks,
    362  			ExitReason: p2p.ExitReason_OK, // ExitReason_OK means not exiting
   363  		})
   364  		if err != nil {
   365  			log.Warn("sending response to peer failed", zap.Error(err))
   366  			if cerror.ErrPeerMessageInternalSenderClosed.Equal(err) {
   367  				peersToDeregister = append(peersToDeregister, peer)
   368  			}
   369  		}
   370  	}
   371  }
   372  
   373  func (m *MessageServer) deregisterPeer(ctx context.Context, peer *cdcPeer, err error) {
   374  	log.Info("Deregistering peer",
   375  		zap.String("sender", peer.PeerID),
   376  		zap.Int64("epoch", peer.Epoch),
   377  		zap.Error(err))
   378  
   379  	m.peerLock.Lock()
   380  	// TODO add a tombstone state to facilitate GC'ing the acks records associated with the peer.
   381  	delete(m.peers, peer.PeerID)
   382  	m.peerLock.Unlock()
   383  	if err != nil {
   384  		peer.abort(ctx, err)
   385  	}
   386  }
   387  
   388  func (m *MessageServer) deregisterPeerByID(ctx context.Context, peerID string) {
   389  	m.peerLock.Lock()
   390  	peer, ok := m.peers[peerID]
   391  	m.peerLock.Unlock()
   392  	if !ok {
   393  		log.Warn("peer not found", zap.String("peerID", peerID))
   394  		return
   395  	}
   396  	m.deregisterPeer(ctx, peer, nil)
   397  }
   398  
   399  // ScheduleDeregisterPeerTask schedules a task to deregister a peer.
   400  func (m *MessageServer) ScheduleDeregisterPeerTask(ctx context.Context, peerID string) error {
   401  	return m.scheduleTask(ctx, taskOnDeregisterPeer{peerID: peerID})
   402  }
   403  
    404  // We use an empty interface to carry the type information of the object
    405  // that a message should be deserialized into.
    406  // The caller passes a value of the desired type; `reflect.TypeOf` extracts the type,
    407  // and `reflect.New` is later used to allocate a fresh object of that type
    408  // for each incoming message.
   409  type typeInformation = interface{}
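
A minimal, self-contained sketch of the reflect round-trip described above; exampleMsg and allocateLike are made-up names for illustration.

```go
package main

import (
	"fmt"
	"reflect"
)

type exampleMsg struct{ Value int }

// allocateLike mirrors what the handler does with tpi: capture the type once,
// then allocate a fresh value of that type for every incoming message.
func allocateLike(tpi interface{}) interface{} {
	tp := reflect.TypeOf(tpi) // e.g. *main.exampleMsg
	return reflect.New(tp.Elem()).Interface()
}

func main() {
	m := allocateLike(&exampleMsg{})
	fmt.Printf("%T\n", m) // *main.exampleMsg
}
```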
   410  
   411  // SyncAddHandler registers a handler for messages in a given topic and waits for the operation
   412  // to complete.
   413  func (m *MessageServer) SyncAddHandler(
   414  	ctx context.Context,
   415  	topic string,
   416  	tpi typeInformation,
   417  	fn func(string, interface{}) error,
   418  ) (<-chan error, error) {
   419  	doneCh, errCh, err := m.AddHandler(ctx, topic, tpi, fn)
   420  	if err != nil {
   421  		return nil, errors.Trace(err)
   422  	}
   423  	select {
   424  	case <-ctx.Done():
   425  		return nil, errors.Trace(ctx.Err())
   426  	case <-doneCh:
   427  	case <-m.closeCh:
   428  		return nil, cerror.ErrPeerMessageServerClosed.GenWithStackByArgs()
   429  	}
   430  	return errCh, nil
   431  }
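
A hedged usage sketch of SyncAddHandler, assumed to live in the same package as this file (it reuses the log and zap imports above); the topic name, message type, and JSON tag are illustrative only.

```go
// statusMessage is a hypothetical message type used only for this sketch.
type statusMessage struct {
	OwnerRev int64 `json:"owner-rev"`
}

func registerStatusHandler(ctx context.Context, server *MessageServer) error {
	errCh, err := server.SyncAddHandler(ctx, "status-topic", &statusMessage{},
		func(senderID string, value interface{}) error {
			msg := value.(*statusMessage) // same pointer type as the tpi argument
			log.Info("received status",
				zap.String("from", senderID),
				zap.Int64("ownerRev", msg.OwnerRev))
			return nil
		})
	if err != nil {
		return err
	}
	go func() {
		// Errors returned by the handler surface asynchronously on errCh.
		if handlerErr := <-errCh; handlerErr != nil {
			log.Warn("status handler exited", zap.Error(handlerErr))
		}
	}()
	return nil
}
```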
   432  
   433  // AddHandler registers a handler for messages in a given topic.
   434  func (m *MessageServer) AddHandler(
   435  	ctx context.Context,
   436  	topic string,
   437  	tpi typeInformation,
   438  	fn func(string, interface{}) error,
   439  ) (chan struct{}, <-chan error, error) {
   440  	tp := reflect.TypeOf(tpi)
   441  
   442  	metricsServerRepeatedMessageCount := serverRepeatedMessageCount.MustCurryWith(prometheus.Labels{
   443  		"topic": topic,
   444  	})
   445  
   446  	poolHandle := m.pool.RegisterEvent(func(ctx context.Context, argsI interface{}) error {
   447  		args, ok := argsI.(poolEventArgs)
   448  		if !ok {
   449  			// Handle message from local.
   450  			if err := fn(m.serverID, argsI); err != nil {
   451  				return errors.Trace(err)
   452  			}
   453  			return nil
   454  		}
   455  		sm := args.streamMeta
   456  		entry := args.entry
   457  		e := reflect.New(tp.Elem()).Interface()
   458  
   459  		lastAck := m.acks.Get(sm.SenderId, entry.GetTopic())
   460  		if lastAck >= entry.Sequence {
   461  			metricsServerRepeatedMessageCount.With(prometheus.Labels{
   462  				"from": sm.SenderAdvertisedAddr,
   463  			}).Inc()
   464  
   465  			log.Debug("skipping peer message",
   466  				zap.String("senderID", sm.SenderId),
   467  				zap.String("topic", topic),
   468  				zap.Int64("skippedSeq", entry.Sequence),
   469  				zap.Int64("lastAck", lastAck))
   470  			return nil
   471  		}
   472  		if lastAck != initAck && entry.Sequence > lastAck+1 {
   473  			// We detected a message loss at seq = (lastAck+1).
    474  			// Note that entry.Sequence == lastAck+1 is actually a requirement
   475  			// on the continuity of sequence numbers, which can be guaranteed
   476  			// by the client locally.
   477  
   478  			// A data loss can only happen if the receiver's handler had failed to
    479  			// unregister before the receiver restarted. This is expected to be rare,
    480  			// happening only with extreme system latency or buggy code, and it
    481  			// indicates a problem with the receiver's handler.
   482  			//
   483  			// Reports an error so that the receiver can gracefully exit.
   484  			return cerror.ErrPeerMessageDataLost.GenWithStackByArgs(entry.Topic, lastAck+1)
   485  		}
   486  
   487  		if err := unmarshalMessage(entry.Content, e); err != nil {
   488  			return cerror.WrapError(cerror.ErrPeerMessageDecodeError, err)
   489  		}
   490  
   491  		if err := fn(sm.SenderId, e); err != nil {
   492  			return errors.Trace(err)
   493  		}
   494  
   495  		m.acks.Set(sm.SenderId, entry.GetTopic(), entry.GetSequence())
   496  
   497  		return nil
   498  	}).OnExit(func(err error) {
   499  		log.Warn("topic handler returned error", zap.Error(err))
   500  		_ = m.scheduleTask(ctx, taskOnDeregisterHandler{
   501  			topic: topic,
   502  		})
   503  	})
   504  
   505  	doneCh := make(chan struct{})
   506  
   507  	if err := m.scheduleTask(ctx, taskOnRegisterHandler{
   508  		topic:   topic,
   509  		handler: poolHandle,
   510  		done:    doneCh,
   511  	}); err != nil {
   512  		return nil, nil, errors.Trace(err)
   513  	}
   514  
   515  	return doneCh, poolHandle.ErrCh(), nil
   516  }
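
The per-(sender, topic) sequence checks inside the handler above can be restated as the standalone sketch below; checkSequence is a made-up name, and initAck is assumed to be the package's "nothing acked yet" sentinel.

```go
// checkSequence restates the dedup / loss-detection rules used by the handler.
func checkSequence(lastAck, incoming Seq) (duplicate bool, lost bool) {
	switch {
	case incoming <= lastAck:
		// Already acked: a client retry delivered this message again, drop it.
		return true, false
	case lastAck != initAck && incoming > lastAck+1:
		// A gap: the message with sequence lastAck+1 was never seen.
		return false, true
	default:
		// The expected next message in the sequence.
		return false, false
	}
}
```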
   517  
    518  // SyncRemoveHandler removes the registered handler for the given topic and waits
    519  // for the operation to complete.
   520  func (m *MessageServer) SyncRemoveHandler(ctx context.Context, topic string) error {
   521  	doneCh, err := m.RemoveHandler(ctx, topic)
   522  	if err != nil {
   523  		return errors.Trace(err)
   524  	}
   525  
   526  	select {
   527  	case <-ctx.Done():
   528  		return errors.Trace(ctx.Err())
   529  	case <-doneCh:
   530  	case <-m.closeCh:
   531  		log.Debug("message server is closed while a handler is being removed",
   532  			zap.String("topic", topic))
   533  		return nil
   534  	}
   535  
   536  	return nil
   537  }
   538  
   539  // RemoveHandler removes the registered handler for the given topic.
   540  func (m *MessageServer) RemoveHandler(ctx context.Context, topic string) (chan struct{}, error) {
   541  	doneCh := make(chan struct{})
   542  	if err := m.scheduleTask(ctx, taskOnDeregisterHandler{
   543  		topic: topic,
   544  		done:  doneCh,
   545  	}); err != nil {
   546  		return nil, errors.Trace(err)
   547  	}
   548  
   549  	return doneCh, nil
   550  }
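
For symmetry with the registration sketch above, tearing a handler down might look like this (the topic name is again illustrative):

```go
func unregisterStatusHandler(ctx context.Context, server *MessageServer) error {
	// SyncRemoveHandler blocks until the handler has been deregistered (or the
	// server is closed). Deregistration is idempotent, so calling this for a
	// topic with no handler only logs a warning.
	return server.SyncRemoveHandler(ctx, "status-topic")
}
```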
   551  
   552  func (m *MessageServer) registerHandler(ctx context.Context, topic string, handler workerpool.EventHandle, doneCh chan struct{}) {
   553  	defer close(doneCh)
   554  
   555  	if _, ok := m.handlers[topic]; ok {
    556  		// Allowing the handler to be replaced here would result in behavior that is
    557  		// difficult to define. Continuing the program when there is a risk of duplicate
    558  		// handlers would likely lead to undefined behavior, so we panic here.
   559  		log.Panic("duplicate handlers",
   560  			zap.String("topic", topic))
   561  	}
   562  
   563  	m.handlers[topic] = handler
   564  	m.handlePendingMessages(ctx, topic)
   565  }
   566  
    567  // handlePendingMessages must only be called from the MessageServer's main goroutine (run()).
   568  func (m *MessageServer) handlePendingMessages(ctx context.Context, topic string) {
   569  	for key, entries := range m.pendingMessages {
   570  		if key.Topic != topic {
   571  			continue
   572  		}
   573  
   574  		for _, entry := range entries {
   575  			if entry.StreamMeta != nil {
   576  				m.handleMessage(ctx, entry.StreamMeta, entry.Entry)
   577  			} else {
   578  				m.handleRawMessage(ctx, entry.RawEntry)
   579  			}
   580  		}
   581  
   582  		delete(m.pendingMessages, key)
   583  	}
   584  }
   585  
   586  func (m *MessageServer) registerPeer(
   587  	ctx context.Context,
   588  	sender *streamHandle,
   589  	clientIP string,
   590  ) error {
   591  	streamMeta := sender.GetStreamMeta()
   592  
   593  	log.Info("peer connection received",
   594  		zap.String("senderID", streamMeta.SenderId),
   595  		zap.String("senderAdvertiseAddr", streamMeta.SenderAdvertisedAddr),
   596  		zap.String("addr", clientIP),
   597  		zap.Int64("epoch", streamMeta.Epoch))
   598  
   599  	m.peerLock.Lock()
   600  	peer, ok := m.peers[streamMeta.SenderId]
   601  	if !ok {
   602  		peerCount := len(m.peers)
   603  		if peerCount > m.config.MaxPeerCount {
   604  			m.peerLock.Unlock()
   605  			return cerror.ErrPeerMessageToManyPeers.GenWithStackByArgs(peerCount)
   606  		}
   607  		// no existing peer
   608  		m.peers[streamMeta.SenderId] = newCDCPeer(streamMeta.SenderId, streamMeta.Epoch, sender)
   609  		m.peerLock.Unlock()
   610  	} else {
   611  		m.peerLock.Unlock()
   612  		// there is an existing peer
   613  		if peer.Epoch > streamMeta.Epoch {
   614  			log.Warn("incoming connection is stale",
   615  				zap.String("senderID", streamMeta.SenderId),
   616  				zap.String("addr", clientIP),
   617  				zap.Int64("epoch", streamMeta.Epoch))
   618  
   619  			// the current stream is stale
   620  			return cerror.ErrPeerMessageStaleConnection.GenWithStackByArgs(streamMeta.Epoch /* old */, peer.Epoch /* new */)
   621  		} else if peer.Epoch < streamMeta.Epoch {
   622  			err := cerror.ErrPeerMessageStaleConnection.GenWithStackByArgs(peer.Epoch /* old */, streamMeta.Epoch /* new */)
   623  			m.deregisterPeer(ctx, peer, err)
   624  			m.peerLock.Lock()
   625  			m.peers[streamMeta.SenderId] = newCDCPeer(streamMeta.SenderId, streamMeta.Epoch, sender)
   626  			m.peerLock.Unlock()
   627  		} else {
   628  			log.Warn("incoming connection is duplicate",
   629  				zap.String("senderID", streamMeta.SenderId),
   630  				zap.String("addr", clientIP),
   631  				zap.Int64("epoch", streamMeta.Epoch))
   632  
   633  			return cerror.ErrPeerMessageDuplicateConnection.GenWithStackByArgs(streamMeta.Epoch)
   634  		}
   635  	}
   636  
   637  	return nil
   638  }
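
The epoch comparison in registerPeer can be read as the small decision rule below; this is a restatement for clarity, not code the server uses.

```go
// epochDecision summarizes how registerPeer treats a reconnecting client:
// the stream with the larger epoch wins, and equal epochs indicate a duplicate.
func epochDecision(existingEpoch, incomingEpoch int64) string {
	switch {
	case incomingEpoch < existingEpoch:
		return "reject the incoming stream as stale"
	case incomingEpoch > existingEpoch:
		return "abort the existing peer and accept the incoming stream"
	default:
		return "reject the incoming stream as a duplicate connection"
	}
}
```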
   639  
   640  func (m *MessageServer) scheduleTask(ctx context.Context, task interface{}) error {
   641  	select {
   642  	case <-ctx.Done():
   643  		return errors.Trace(ctx.Err())
   644  	case m.taskQueue <- task:
   645  	default:
   646  		return cerror.ErrPeerMessageTaskQueueCongested.GenWithStackByArgs()
   647  	}
   648  	return nil
   649  }
   650  
   651  func (m *MessageServer) scheduleTaskBlocking(ctx context.Context, task interface{}) error {
   652  	select {
   653  	case <-ctx.Done():
   654  		return errors.Trace(ctx.Err())
   655  	case m.taskQueue <- task:
   656  	}
   657  	return nil
   658  }
   659  
   660  func (m *MessageServer) receiveLocalMessage(ctx context.Context, localCh <-chan RawMessageEntry) error {
   661  	batchRawMessages := []RawMessageEntry{}
   662  	sendTaskBlocking := func() {
   663  		if len(batchRawMessages) == 0 {
   664  			return
   665  		}
   666  		_ = m.scheduleTaskBlocking(ctx, taskOnMessageBatch{
   667  			rawMessageEntries: batchRawMessages,
   668  		})
   669  		batchRawMessages = []RawMessageEntry{}
   670  	}
   671  
   672  	ticker := time.NewTicker(10 * time.Millisecond)
   673  	for {
   674  		select {
   675  		case <-ctx.Done():
   676  			return errors.Trace(ctx.Err())
   677  		case entry, ok := <-localCh:
   678  			if !ok {
   679  				errMsg := "local server stream closed since the channel is closed"
   680  				return cerror.ErrPeerMessageServerClosed.GenWithStackByArgs(errMsg)
   681  			}
   682  			batchRawMessages = append(batchRawMessages, entry)
   683  
   684  			if len(batchRawMessages) >= 1024 {
   685  				sendTaskBlocking()
   686  			}
   687  		case <-ticker.C:
   688  			sendTaskBlocking()
   689  		}
   690  	}
   691  }
   692  
   693  // SendMessage implements the gRPC call SendMessage.
   694  func (m *MessageServer) SendMessage(stream p2p.CDCPeerToPeer_SendMessageServer) error {
   695  	ctx := stream.Context()
   696  	packet, err := stream.Recv()
   697  	if err != nil {
   698  		return errors.Trace(err)
   699  	}
   700  
   701  	if err := m.verifyStreamMeta(packet.Meta); err != nil {
   702  		msg := errorToRPCResponse(err)
   703  		_ = stream.Send(&msg)
   704  		return errors.Trace(err)
   705  	}
   706  
   707  	metricsServerStreamCount := serverStreamCount.With(prometheus.Labels{
   708  		"from": packet.Meta.SenderAdvertisedAddr,
   709  	})
   710  	metricsServerStreamCount.Add(1)
   711  	defer metricsServerStreamCount.Sub(1)
   712  
   713  	sendCh := make(chan p2p.SendMessageResponse, m.config.SendChannelSize)
   714  	streamHandle := newStreamHandle(packet.Meta, sendCh)
   715  	ctx, cancel := context.WithCancel(ctx)
   716  	defer cancel()
   717  	errg, egCtx := errgroup.WithContext(ctx)
   718  
   719  	// receive messages from the sender
   720  	errg.Go(func() error {
   721  		defer streamHandle.Close()
   722  		clientSocketAddr := unknownPeerLabel
   723  		if p, ok := gRPCPeer.FromContext(egCtx); ok {
   724  			clientSocketAddr = p.Addr.String()
   725  		}
   726  		if err := m.receive(egCtx, clientSocketAddr, stream, streamHandle); err != nil {
   727  			log.Warn("peer-to-peer message handler error", zap.Error(err))
   728  			select {
   729  			case <-egCtx.Done():
   730  				log.Warn("error receiving from peer", zap.Error(egCtx.Err()))
   731  				return errors.Trace(egCtx.Err())
   732  			case sendCh <- errorToRPCResponse(err):
   733  			default:
   734  				log.Warn("sendCh congested, could not send error", zap.Error(err))
   735  				return errors.Trace(err)
   736  			}
   737  		}
   738  		return nil
   739  	})
   740  
   741  	// send acks to the sender
   742  	errg.Go(func() error {
   743  		rl := rate.NewLimiter(rate.Limit(m.config.SendRateLimitPerStream), 1)
   744  		for {
   745  			select {
   746  			case <-ctx.Done():
   747  				return errors.Trace(ctx.Err())
   748  			case resp, ok := <-sendCh:
   749  				if !ok {
   750  					log.Info("peer stream handle is closed",
   751  						zap.String("peerAddr", streamHandle.GetStreamMeta().SenderAdvertisedAddr),
   752  						zap.String("peerID", streamHandle.GetStreamMeta().SenderId))
   753  					// cancel the stream when sendCh is closed
   754  					cancel()
   755  					return nil
   756  				}
   757  				if err := rl.Wait(ctx); err != nil {
   758  					return errors.Trace(err)
   759  				}
   760  				if err := stream.Send(&resp); err != nil {
   761  					return errors.Trace(err)
   762  				}
   763  			}
   764  		}
   765  	})
   766  
   767  	// We need to select on `m.closeCh` and `ctx.Done()` to make sure that
   768  	// the request handler returns when we need it to.
   769  	// We cannot allow `Send` and `Recv` to block the handler when it needs to exit,
   770  	// such as when the MessageServer is exiting due to an error.
   771  	select {
   772  	case <-ctx.Done():
   773  		return status.New(codes.Canceled, "context canceled").Err()
   774  	case <-m.closeCh:
   775  		return status.New(codes.Aborted, "message server is closing").Err()
   776  	}
   777  
    778  	// NB: `errg` will NOT be waited on because, due to a limitation of grpc-go, we cannot cancel Send & Recv
    779  	// with contexts, and the only reliable way to cancel these operations is to return from the gRPC call handler,
    780  	// namely this function.
   781  }
   782  
   783  func (m *MessageServer) receive(
   784  	ctx context.Context,
   785  	clientSocketAddr string,
   786  	stream p2p.CDCPeerToPeer_SendMessageServer,
   787  	streamHandle *streamHandle,
   788  ) error {
   789  	// We use scheduleTaskBlocking because blocking here is acceptable.
    790  	// Blocking here will cause grpc-go to back-propagate the pressure
   791  	// to the client, which is what we want.
   792  	if err := m.scheduleTaskBlocking(ctx, taskOnRegisterPeer{
   793  		sender:     streamHandle,
   794  		clientAddr: clientSocketAddr,
   795  	}); err != nil {
   796  		return errors.Trace(err)
   797  	}
   798  
   799  	metricsServerMessageCount := serverMessageCount.With(prometheus.Labels{
   800  		"from": streamHandle.GetStreamMeta().SenderAdvertisedAddr,
   801  	})
   802  	metricsServerMessageBatchHistogram := serverMessageBatchHistogram.With(prometheus.Labels{
   803  		"from": streamHandle.GetStreamMeta().SenderAdvertisedAddr,
   804  	})
   805  	metricsServerMessageBatchBytesHistogram := serverMessageBatchBytesHistogram.With(prometheus.Labels{
   806  		"from": streamHandle.GetStreamMeta().SenderAdvertisedAddr,
   807  	})
   808  	metricsServerMessageBytesHistogram := serverMessageBytesHistogram.With(prometheus.Labels{
   809  		"from": streamHandle.GetStreamMeta().SenderAdvertisedAddr,
   810  	})
   811  
   812  	for {
   813  		failpoint.Inject("ServerInjectServerRestart", func() {
   814  			_ = stream.Send(&p2p.SendMessageResponse{
   815  				ExitReason: p2p.ExitReason_CONGESTED,
   816  			})
   817  			failpoint.Return(errors.Trace(errors.New("injected error")))
   818  		})
   819  
   820  		packet, err := stream.Recv()
   821  		if err != nil {
   822  			return errors.Trace(err)
   823  		}
   824  
   825  		batchSize := len(packet.GetEntries())
   826  		log.Debug("received packet", zap.String("streamHandle", streamHandle.GetStreamMeta().SenderId),
   827  			zap.Int("numEntries", batchSize))
   828  
   829  		batchBytes := packet.Size()
   830  		metricsServerMessageBatchBytesHistogram.Observe(float64(batchBytes))
   831  		metricsServerMessageBatchHistogram.Observe(float64(batchSize))
   832  		metricsServerMessageCount.Add(float64(batchSize))
   833  
   834  		entries := packet.GetEntries()
   835  		if batchSize > 0 {
   836  			if messageServerReportsIndividualMessageSize /* true for now */ {
   837  				// Note that this can be costly if the number of messages is huge.
   838  				// However, the current usage of this package in TiCDC should not
   839  				// cause any problem, as the messages are for metadata only.
   840  				for _, entry := range entries {
   841  					messageWireSize := entry.Size()
   842  					metricsServerMessageBytesHistogram.Observe(float64(messageWireSize))
   843  				}
   844  			}
   845  
   846  			// See the comment above on why use scheduleTaskBlocking.
   847  			if err := m.scheduleTaskBlocking(ctx, taskOnMessageBatch{
   848  				streamMeta:     streamHandle.GetStreamMeta(),
   849  				messageEntries: packet.GetEntries(),
   850  			}); err != nil {
   851  				return errors.Trace(err)
   852  			}
   853  		}
   854  	}
   855  }
   856  
   857  func (m *MessageServer) handleRawMessage(ctx context.Context, entry RawMessageEntry) {
   858  	handler, ok := m.handlers[entry.topic]
   859  	if !ok {
   860  		// handler not found
   861  		pendingMessageKey := topicSenderPair{
   862  			Topic:    entry.topic,
   863  			SenderID: m.serverID,
   864  		}
   865  		pendingEntries := m.pendingMessages[pendingMessageKey]
   866  		m.pendingMessages[pendingMessageKey] = append(pendingEntries, pendingMessageEntry{
   867  			RawEntry: entry,
   868  		})
   869  		if len(m.pendingMessages[pendingMessageKey]) >= m.config.MaxPendingMessageCountPerTopic {
   870  			delete(m.pendingMessages, pendingMessageKey)
   871  			log.Warn("Topic congested because no handler has been registered", zap.Any("topic", pendingMessageKey))
   872  		}
   873  		return
   874  	}
   875  	// handler is found
   876  	if err := handler.AddEvent(ctx, entry.value); err != nil {
   877  		// just ignore the message if handler returns an error
   878  		errMsg := "Failed to process message due to a handler error"
   879  		log.Debug(errMsg, zap.Error(err), zap.String("topic", entry.topic))
   880  	}
   881  }
   882  
   883  func (m *MessageServer) handleMessage(ctx context.Context, streamMeta *p2p.StreamMeta, entry *p2p.MessageEntry) {
   884  	m.peerLock.RLock()
   885  	peer, ok := m.peers[streamMeta.SenderId]
   886  	m.peerLock.RUnlock()
   887  	if !ok || peer.Epoch != streamMeta.GetEpoch() {
   888  		log.Debug("p2p: message without corresponding peer",
   889  			zap.String("topic", entry.GetTopic()),
   890  			zap.Int64("seq", entry.GetSequence()))
   891  		return
   892  	}
   893  
   894  	// Drop messages from invalid peers
   895  	if !peer.valid {
   896  		return
   897  	}
   898  
   899  	topic := entry.GetTopic()
   900  	pendingMessageKey := topicSenderPair{
   901  		Topic:    topic,
   902  		SenderID: streamMeta.SenderId,
   903  	}
   904  	handler, ok := m.handlers[topic]
   905  	if !ok {
   906  		// handler not found
   907  		pendingEntries := m.pendingMessages[pendingMessageKey]
   908  		if len(pendingEntries) > m.config.MaxPendingMessageCountPerTopic {
   909  			log.Warn("Topic congested because no handler has been registered", zap.String("topic", topic))
   910  			delete(m.pendingMessages, pendingMessageKey)
   911  			m.deregisterPeer(ctx, peer, cerror.ErrPeerMessageTopicCongested.FastGenByArgs())
   912  			return
   913  		}
   914  		m.pendingMessages[pendingMessageKey] = append(pendingEntries, pendingMessageEntry{
   915  			StreamMeta: streamMeta,
   916  			Entry:      entry,
   917  		})
   918  
   919  		return
   920  	}
   921  
   922  	// handler is found
   923  	if err := handler.AddEvent(ctx, poolEventArgs{
   924  		streamMeta: streamMeta,
   925  		entry:      entry,
   926  	}); err != nil {
   927  		log.Warn("Failed to process message due to a handler error",
   928  			zap.Error(err), zap.String("topic", topic))
   929  		m.deregisterPeer(ctx, peer, err)
   930  	}
   931  }
   932  
   933  func (m *MessageServer) verifyStreamMeta(streamMeta *p2p.StreamMeta) error {
   934  	if streamMeta == nil {
   935  		return cerror.ErrPeerMessageIllegalMeta.GenWithStackByArgs()
   936  	}
   937  
   938  	if streamMeta.ReceiverId != m.serverID {
   939  		return cerror.ErrPeerMessageReceiverMismatch.GenWithStackByArgs(
   940  			m.serverID,            // expected
   941  			streamMeta.ReceiverId, // actual
   942  		)
   943  	}
   944  
   945  	return nil
   946  }
   947  
   948  type topicSenderPair struct {
   949  	Topic    string
   950  	SenderID string
   951  }
   952  
   953  type pendingMessageEntry struct {
   954  	// for grpc msgs
   955  	StreamMeta *p2p.StreamMeta
   956  	Entry      *p2p.MessageEntry
   957  
   958  	// for local msgs
   959  	RawEntry RawMessageEntry
   960  }
   961  
   962  func errorToRPCResponse(err error) p2p.SendMessageResponse {
   963  	if cerror.ErrPeerMessageTopicCongested.Equal(err) ||
   964  		cerror.ErrPeerMessageTaskQueueCongested.Equal(err) {
   965  
   966  		return p2p.SendMessageResponse{
   967  			ExitReason:   p2p.ExitReason_CONGESTED,
   968  			ErrorMessage: err.Error(),
   969  		}
   970  	} else if cerror.ErrPeerMessageStaleConnection.Equal(err) {
   971  		return p2p.SendMessageResponse{
   972  			ExitReason:   p2p.ExitReason_STALE_CONNECTION,
   973  			ErrorMessage: err.Error(),
   974  		}
   975  	} else if cerror.ErrPeerMessageReceiverMismatch.Equal(err) {
   976  		return p2p.SendMessageResponse{
   977  			ExitReason:   p2p.ExitReason_CAPTURE_ID_MISMATCH,
   978  			ErrorMessage: err.Error(),
   979  		}
   980  	} else {
   981  		return p2p.SendMessageResponse{
   982  			ExitReason:   p2p.ExitReason_UNKNOWN,
   983  			ErrorMessage: err.Error(),
   984  		}
   985  	}
   986  }
   987  
   988  type poolEventArgs struct {
   989  	streamMeta *p2p.StreamMeta
   990  	entry      *p2p.MessageEntry
   991  }