github.com/matrixorigin/matrixone@v0.7.0/pkg/common/morpc/server.go (about)

     1  // Copyright 2021 - 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package morpc
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"sync"
    21  	"time"
    22  
    23  	"github.com/fagongzi/goetty/v2"
    24  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    25  	"github.com/matrixorigin/matrixone/pkg/common/stopper"
    26  	"github.com/matrixorigin/matrixone/pkg/logutil"
    27  	"go.uber.org/zap"
    28  )
    29  
    30  // WithServerLogger set rpc server logger
    31  func WithServerLogger(logger *zap.Logger) ServerOption {
    32  	return func(rs *server) {
    33  		rs.logger = logger
    34  	}
    35  }
    36  
    37  // WithServerSessionBufferSize set the buffer size of the write response chan.
    38  // Default is 16.
    39  func WithServerSessionBufferSize(size int) ServerOption {
    40  	return func(s *server) {
    41  		s.options.bufferSize = size
    42  	}
    43  }
    44  
    45  // WithServerWriteFilter set write filter func. Input ready to send Messages, output
    46  // is really need to be send Messages.
    47  func WithServerWriteFilter(filter func(Message) bool) ServerOption {
    48  	return func(s *server) {
    49  		s.options.filter = filter
    50  	}
    51  }
    52  
    53  // WithServerGoettyOptions set write filter func. Input ready to send Messages, output
    54  // is really need to be send Messages.
    55  func WithServerGoettyOptions(options ...goetty.Option) ServerOption {
    56  	return func(s *server) {
    57  		s.options.goettyOptions = options
    58  	}
    59  }
    60  
    61  // WithServerBatchSendSize set the maximum number of messages to be sent together
    62  // at each batch. Default is 8.
    63  func WithServerBatchSendSize(size int) ServerOption {
    64  	return func(s *server) {
    65  		s.options.batchSendSize = size
    66  	}
    67  }
    68  
    69  // WithServerDisableAutoCancelContext disable automatic cancel messaging for the context.
    70  // The server will receive RPC messages from the client, each message comes with a Context,
    71  // and morpc will call the handler to process it, and when the handler returns, the Context
    72  // will be auto cancel the context. But in some scenarios, the handler is asynchronous,
    73  // so morpc can't directly cancel the context after the handler returns, otherwise many strange
    74  // problems will occur.
    75  func WithServerDisableAutoCancelContext() ServerOption {
    76  	return func(s *server) {
    77  		s.options.disableAutoCancelContext = true
    78  	}
    79  }
    80  
    81  type server struct {
    82  	name        string
    83  	address     string
    84  	logger      *zap.Logger
    85  	codec       Codec
    86  	application goetty.NetApplication
    87  	stopper     *stopper.Stopper
    88  	handler     func(ctx context.Context, request Message, sequence uint64, cs ClientSession) error
    89  	sessions    *sync.Map // session-id => *clientSession
    90  	options     struct {
    91  		goettyOptions            []goetty.Option
    92  		bufferSize               int
    93  		batchSendSize            int
    94  		filter                   func(Message) bool
    95  		disableAutoCancelContext bool
    96  	}
    97  	pool struct {
    98  		futures *sync.Pool
    99  	}
   100  }
   101  
   102  // NewRPCServer create rpc server with options. After the rpc server starts, one link corresponds to two
   103  // goroutines, one read and one write. All messages to be written are first written to a buffer chan and
   104  // sent to the client by the write goroutine.
   105  func NewRPCServer(name, address string, codec Codec, options ...ServerOption) (RPCServer, error) {
   106  	s := &server{
   107  		name:     name,
   108  		address:  address,
   109  		codec:    codec,
   110  		stopper:  stopper.NewStopper(fmt.Sprintf("rpc-server-%s", name)),
   111  		sessions: &sync.Map{},
   112  	}
   113  	for _, opt := range options {
   114  		opt(s)
   115  	}
   116  	s.adjust()
   117  
   118  	s.options.goettyOptions = append(s.options.goettyOptions,
   119  		goetty.WithSessionCodec(codec),
   120  		goetty.WithSessionLogger(s.logger))
   121  
   122  	app, err := goetty.NewApplication(
   123  		s.address,
   124  		s.onMessage,
   125  		goetty.WithAppLogger(s.logger),
   126  		goetty.WithAppSessionOptions(s.options.goettyOptions...),
   127  	)
   128  	if err != nil {
   129  		s.logger.Error("create rpc server failed",
   130  			zap.Error(err))
   131  		return nil, err
   132  	}
   133  	s.application = app
   134  	s.pool.futures = &sync.Pool{
   135  		New: func() interface{} {
   136  			return newFuture(s.releaseFuture)
   137  		},
   138  	}
   139  	return s, nil
   140  }
   141  
   142  func (s *server) Start() error {
   143  	err := s.application.Start()
   144  	if err != nil {
   145  		s.logger.Fatal("start rpcserver failed",
   146  			zap.Error(err))
   147  		return err
   148  	}
   149  	return nil
   150  }
   151  
   152  func (s *server) Close() error {
   153  	s.stopper.Stop()
   154  	err := s.application.Stop()
   155  	if err != nil {
   156  		s.logger.Error("stop rpcserver failed",
   157  			zap.Error(err))
   158  	}
   159  
   160  	return err
   161  }
   162  
   163  func (s *server) RegisterRequestHandler(handler func(ctx context.Context, request Message, sequence uint64, cs ClientSession) error) {
   164  	s.handler = handler
   165  }
   166  
   167  func (s *server) adjust() {
   168  	s.logger = logutil.Adjust(s.logger).With(zap.String("name", s.name))
   169  	if s.options.batchSendSize == 0 {
   170  		s.options.batchSendSize = 8
   171  	}
   172  	if s.options.bufferSize == 0 {
   173  		s.options.bufferSize = 16
   174  	}
   175  	if s.options.filter == nil {
   176  		s.options.filter = func(messages Message) bool {
   177  			return true
   178  		}
   179  	}
   180  }
   181  
   182  func (s *server) onMessage(rs goetty.IOSession, value any, sequence uint64) error {
   183  	cs, err := s.getSession(rs)
   184  	if err != nil {
   185  		return err
   186  	}
   187  	request := value.(RPCMessage)
   188  	if ce := s.logger.Check(zap.DebugLevel, "received request"); ce != nil {
   189  		ce.Write(zap.Uint64("sequence", sequence),
   190  			zap.String("client", rs.RemoteAddress()),
   191  			zap.Uint64("request-id", request.Message.GetID()),
   192  			zap.String("request", request.Message.DebugString()))
   193  	}
   194  
   195  	// Can't be sure that the Context is properly consumed if disableAutoCancelContext is set to
   196  	// true. So we use the pessimistic wait for the context to time out automatically be canceled
   197  	// behavior here, which may cause some resources to be released more slowly.
   198  	// FIXME: Use the CancelFunc pass to let the handler decide to cancel itself
   199  	if !s.options.disableAutoCancelContext && request.cancel != nil {
   200  		defer request.cancel()
   201  	}
   202  	// get requestID here to avoid data race, because the request maybe released in handler
   203  	requestID := request.Message.GetID()
   204  
   205  	if request.stream &&
   206  		!cs.validateStreamRequest(requestID, request.streamSequence) {
   207  		s.logger.Error("failed to handle stream request",
   208  			zap.Uint32("last-sequence", cs.receivedStreamSequences[requestID]),
   209  			zap.Uint32("current-sequence", request.streamSequence),
   210  			zap.String("client", rs.RemoteAddress()))
   211  		cs.cancelWrite()
   212  		return moerr.NewStreamClosedNoCtx()
   213  	}
   214  
   215  	// handle internal message
   216  	if request.internal {
   217  		if m, ok := request.Message.(*flagOnlyMessage); ok {
   218  			switch m.flag {
   219  			case flagPing:
   220  				return cs.Write(request.Ctx, &flagOnlyMessage{flag: flagPong, id: m.id})
   221  			default:
   222  				panic(fmt.Sprintf("invalid internal message, flag %d", m.flag))
   223  			}
   224  		}
   225  	}
   226  
   227  	if err := s.handler(request.Ctx, request.Message, sequence, cs); err != nil {
   228  		s.logger.Error("handle request failed",
   229  			zap.Uint64("sequence", sequence),
   230  			zap.String("client", rs.RemoteAddress()),
   231  			zap.Error(err))
   232  		cs.cancelWrite()
   233  		return err
   234  	}
   235  
   236  	if ce := s.logger.Check(zap.DebugLevel, "handle request completed"); ce != nil {
   237  		ce.Write(zap.Uint64("sequence", sequence),
   238  			zap.String("client", rs.RemoteAddress()),
   239  			zap.Uint64("request-id", requestID))
   240  	}
   241  	return nil
   242  }
   243  
   244  func (s *server) startWriteLoop(cs *clientSession) error {
   245  	return s.stopper.RunTask(func(ctx context.Context) {
   246  		defer s.closeClientSession(cs)
   247  
   248  		responses := make([]*Future, 0, s.options.batchSendSize)
   249  		fetch := func() {
   250  			for i := 0; i < len(responses); i++ {
   251  				responses[i] = nil
   252  			}
   253  			responses = responses[:0]
   254  
   255  			for i := 0; i < s.options.batchSendSize; i++ {
   256  				if len(responses) == 0 {
   257  					select {
   258  					case <-ctx.Done():
   259  						responses = nil
   260  						return
   261  					case <-cs.ctx.Done():
   262  						responses = nil
   263  						return
   264  					case resp := <-cs.c:
   265  						responses = append(responses, resp)
   266  					}
   267  				} else {
   268  					select {
   269  					case <-ctx.Done():
   270  						responses = nil
   271  						return
   272  					case <-cs.ctx.Done():
   273  						responses = nil
   274  						return
   275  					case resp := <-cs.c:
   276  						responses = append(responses, resp)
   277  					default:
   278  						return
   279  					}
   280  				}
   281  			}
   282  		}
   283  
   284  		for {
   285  			select {
   286  			case <-ctx.Done():
   287  				return
   288  			case <-cs.ctx.Done():
   289  				return
   290  			default:
   291  				fetch()
   292  
   293  				if len(responses) > 0 {
   294  					var fields []zap.Field
   295  					ce := s.logger.Check(zap.DebugLevel, "write responses")
   296  					if ce != nil {
   297  						fields = append(fields, zap.String("client", cs.conn.RemoteAddress()))
   298  					}
   299  
   300  					written := 0
   301  					timeout := time.Duration(0)
   302  					for _, f := range responses {
   303  						if !s.options.filter(f.send.Message) {
   304  							f.messageSended(messageSkipped)
   305  							continue
   306  						}
   307  
   308  						if f.send.Timeout() {
   309  							f.messageSended(f.send.Ctx.Err())
   310  							continue
   311  						}
   312  
   313  						v, err := f.send.GetTimeoutFromContext()
   314  						if err != nil {
   315  							f.messageSended(err)
   316  							continue
   317  						}
   318  
   319  						timeout += v
   320  						// Record the information of some responses in advance, because after flush,
   321  						// these responses will be released, thus avoiding causing data race.
   322  						if ce != nil {
   323  							fields = append(fields, zap.Uint64("request-id",
   324  								f.send.Message.GetID()))
   325  							fields = append(fields, zap.String("response",
   326  								f.send.Message.DebugString()))
   327  						}
   328  						if err := cs.conn.Write(f.send, goetty.WriteOptions{}); err != nil {
   329  							s.logger.Error("write response failed",
   330  								zap.Uint64("request-id", f.send.Message.GetID()),
   331  								zap.Error(err))
   332  							f.messageSended(err)
   333  							return
   334  						}
   335  						written++
   336  					}
   337  
   338  					if written > 0 {
   339  						err := cs.conn.Flush(timeout)
   340  						if err != nil {
   341  							if ce != nil {
   342  								fields = append(fields, zap.Error(err))
   343  							}
   344  							for _, f := range responses {
   345  								if !s.options.filter(f.send.Message) {
   346  									id := f.getSendMessageID()
   347  									s.logger.Error("write response failed",
   348  										zap.Uint64("request-id", id),
   349  										zap.Error(err))
   350  									f.messageSended(err)
   351  								}
   352  							}
   353  						}
   354  						if ce != nil {
   355  							ce.Write(fields...)
   356  						}
   357  					}
   358  
   359  					for _, f := range responses {
   360  						f.messageSended(nil)
   361  					}
   362  				}
   363  			}
   364  		}
   365  	})
   366  }
   367  
   368  func (s *server) closeClientSession(cs *clientSession) {
   369  	s.sessions.Delete(cs.conn.ID())
   370  	if err := cs.Close(); err != nil {
   371  		s.logger.Error("close client session failed",
   372  			zap.Error(err))
   373  	}
   374  }
   375  
   376  func (s *server) getSession(rs goetty.IOSession) (*clientSession, error) {
   377  	if v, ok := s.sessions.Load(rs.ID()); ok {
   378  		return v.(*clientSession), nil
   379  	}
   380  
   381  	cs := newClientSession(rs, s.codec, s.newFuture)
   382  	v, loaded := s.sessions.LoadOrStore(rs.ID(), cs)
   383  	if loaded {
   384  		close(cs.c)
   385  		return v.(*clientSession), nil
   386  	}
   387  
   388  	rs.Ref()
   389  	if err := s.startWriteLoop(cs); err != nil {
   390  		s.closeClientSession(cs)
   391  		return nil, err
   392  	}
   393  	return cs, nil
   394  }
   395  
   396  func (s *server) releaseFuture(f *Future) {
   397  	f.reset()
   398  	s.pool.futures.Put(f)
   399  }
   400  
   401  func (s *server) newFuture() *Future {
   402  	return s.pool.futures.Get().(*Future)
   403  }
   404  
   405  type clientSession struct {
   406  	codec         Codec
   407  	conn          goetty.IOSession
   408  	c             chan *Future
   409  	newFutureFunc func() *Future
   410  	// streaming id -> last received sequence, no concurrent, access in io goroutine
   411  	receivedStreamSequences map[uint64]uint32
   412  	// streaming id -> last sent sequence, multi-stream access in multi-goroutines if
   413  	// the tcp connection is shared. But no concurrent in one stream.
   414  	sentStreamSequences sync.Map
   415  	cancel              context.CancelFunc
   416  	ctx                 context.Context
   417  	mu                  struct {
   418  		sync.RWMutex
   419  		closed bool
   420  	}
   421  }
   422  
   423  func newClientSession(
   424  	conn goetty.IOSession,
   425  	codec Codec,
   426  	newFutureFunc func() *Future) *clientSession {
   427  	ctx, cancel := context.WithCancel(context.Background())
   428  	return &clientSession{
   429  		codec:                   codec,
   430  		c:                       make(chan *Future, 32),
   431  		receivedStreamSequences: make(map[uint64]uint32),
   432  		conn:                    conn,
   433  		ctx:                     ctx,
   434  		cancel:                  cancel,
   435  		newFutureFunc:           newFutureFunc,
   436  	}
   437  }
   438  
   439  func (cs *clientSession) Close() error {
   440  	cs.mu.Lock()
   441  	defer cs.mu.Unlock()
   442  	if cs.mu.closed {
   443  		return nil
   444  	}
   445  
   446  	close(cs.c)
   447  	cs.mu.closed = true
   448  	return cs.conn.Close()
   449  }
   450  
   451  func (cs *clientSession) Write(ctx context.Context, response Message) error {
   452  	if err := cs.codec.Valid(response); err != nil {
   453  		return err
   454  	}
   455  
   456  	cs.mu.RLock()
   457  	defer cs.mu.RUnlock()
   458  
   459  	if cs.mu.closed {
   460  		return moerr.NewClientClosedNoCtx()
   461  	}
   462  
   463  	msg := RPCMessage{Ctx: ctx, Message: response}
   464  	id := response.GetID()
   465  	if v, ok := cs.sentStreamSequences.Load(id); ok {
   466  		seq := v.(uint32) + 1
   467  		cs.sentStreamSequences.Store(id, seq)
   468  		msg.stream = true
   469  		msg.streamSequence = seq
   470  	}
   471  
   472  	f := cs.newFutureFunc()
   473  	f.ref()
   474  	f.init(msg)
   475  	defer f.Close()
   476  
   477  	cs.c <- f
   478  	// stream only wait send completed
   479  	return f.waitSendCompleted()
   480  }
   481  
   482  func (cs *clientSession) cancelWrite() {
   483  	cs.cancel()
   484  }
   485  
   486  func (cs *clientSession) validateStreamRequest(id uint64, sequence uint32) bool {
   487  	expectSequence := cs.receivedStreamSequences[id] + 1
   488  	if sequence != expectSequence {
   489  		return false
   490  	}
   491  	cs.receivedStreamSequences[id] = sequence
   492  	if sequence == 1 {
   493  		cs.sentStreamSequences.Store(id, uint32(0))
   494  	}
   495  	return true
   496  }