github.com/matrixorigin/matrixone@v1.2.0/pkg/txn/rpc/server.go (about)

     1  // Copyright 2021 - 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package rpc
    16  
    17  import (
    18  	"context"
    19  	"encoding/hex"
    20  	"sync"
    21  	"time"
    22  
    23  	"github.com/fagongzi/goetty/v2"
    24  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    25  	"github.com/matrixorigin/matrixone/pkg/common/morpc"
    26  	"github.com/matrixorigin/matrixone/pkg/common/mpool"
    27  	"github.com/matrixorigin/matrixone/pkg/common/runtime"
    28  	"github.com/matrixorigin/matrixone/pkg/common/stopper"
    29  	"github.com/matrixorigin/matrixone/pkg/defines"
    30  	"github.com/matrixorigin/matrixone/pkg/pb/txn"
    31  	v2 "github.com/matrixorigin/matrixone/pkg/util/metric/v2"
    32  	"github.com/matrixorigin/matrixone/pkg/util/trace"
    33  	"go.uber.org/zap"
    34  )
    35  
    36  var methodVersions = map[txn.TxnMethod]int64{
    37  	txn.TxnMethod_Read:   defines.MORPCVersion1,
    38  	txn.TxnMethod_Write:  defines.MORPCVersion1,
    39  	txn.TxnMethod_Commit: defines.MORPCVersion1,
    40  
    41  	txn.TxnMethod_Prepare:         defines.MORPCVersion1,
    42  	txn.TxnMethod_CommitTNShard:   defines.MORPCVersion1,
    43  	txn.TxnMethod_RollbackTNShard: defines.MORPCVersion1,
    44  	txn.TxnMethod_GetStatus:       defines.MORPCVersion1,
    45  
    46  	txn.TxnMethod_DEBUG: defines.MORPCVersion1,
    47  }
    48  
    49  // WithServerMaxMessageSize set max rpc message size
    50  func WithServerMaxMessageSize(maxMessageSize int) ServerOption {
    51  	return func(s *server) {
    52  		s.options.maxMessageSize = maxMessageSize
    53  	}
    54  }
    55  
    56  // WithServerEnableCompress enable compress
    57  func WithServerEnableCompress(enable bool) ServerOption {
    58  	return func(s *server) {
    59  		s.options.enableCompress = enable
    60  	}
    61  }
    62  
    63  // set filter func. Requests can be modified or filtered out by the filter
    64  // before they are processed by the handler.
    65  func WithServerMessageFilter(filter func(*txn.TxnRequest) bool) ServerOption {
    66  	return func(s *server) {
    67  		s.options.filter = filter
    68  	}
    69  }
    70  
    71  // WithServerQueueBufferSize set queue buffer size
    72  func WithServerQueueBufferSize(value int) ServerOption {
    73  	return func(s *server) {
    74  		s.options.maxChannelBufferSize = value
    75  	}
    76  }
    77  
    78  // WithServerQueueWorkers set worker number
    79  func WithServerQueueWorkers(value int) ServerOption {
    80  	return func(s *server) {
    81  		s.options.workers = value
    82  	}
    83  }
    84  
// server is the TxnServer implementation. It decodes txn requests received
// over morpc, dispatches them by method to the registered handlers on a pool
// of worker goroutines, and writes the handlers' responses back to the client
// session.
type server struct {
	rt  runtime.Runtime
	rpc morpc.RPCServer
	// handlers is keyed by request method and filled by RegisterMethodHandler.
	handlers map[txn.TxnMethod]TxnRequestHandleFunc

	// pool recycles request and response objects between rpc calls.
	pool struct {
		requests  sync.Pool
		responses sync.Pool
	}

	// options are written by the ServerOption functions given to NewTxnServer.
	options struct {
		// filter, when set, returning false drops the request.
		filter func(*txn.TxnRequest) bool
		// maxMessageSize is used as the codec's max body size.
		maxMessageSize int
		// enableCompress turns on codec compression.
		enableCompress bool
		// maxChannelBufferSize is the capacity of queue.
		maxChannelBufferSize int
		// workers is the number of goroutines Start launches to drain queue.
		workers int
	}

	// queue decouples network reads from request execution: onMessage, which
	// runs on a connection's read goroutine, pushes each request onto this
	// buffered channel (capacity options.maxChannelBufferSize), and
	// options.workers goroutines started by Start consume it concurrently.
	queue chan executor
	// stopper owns the worker goroutines; Close stops it.
	stopper *stopper.Stopper
}
   109  
   110  // NewTxnServer create a txn server. One DNStore corresponds to one TxnServer
   111  func NewTxnServer(
   112  	address string,
   113  	rt runtime.Runtime,
   114  	opts ...ServerOption) (TxnServer, error) {
   115  	s := &server{
   116  		rt:       rt,
   117  		handlers: make(map[txn.TxnMethod]TxnRequestHandleFunc),
   118  		stopper: stopper.NewStopper("txn rpc server",
   119  			stopper.WithLogger(rt.Logger().RawLogger())),
   120  	}
   121  	s.pool.requests = sync.Pool{
   122  		New: func() any {
   123  			return &txn.TxnRequest{}
   124  		},
   125  	}
   126  	s.pool.responses = sync.Pool{
   127  		New: func() any {
   128  			return &txn.TxnResponse{}
   129  		},
   130  	}
   131  	for _, opt := range opts {
   132  		opt(s)
   133  	}
   134  
   135  	var codecOpts []morpc.CodecOption
   136  	codecOpts = append(codecOpts,
   137  		morpc.WithCodecIntegrationHLC(rt.Clock()),
   138  		morpc.WithCodecEnableChecksum(),
   139  		morpc.WithCodecPayloadCopyBufferSize(16*1024),
   140  		morpc.WithCodecMaxBodySize(s.options.maxMessageSize))
   141  	if s.options.enableCompress {
   142  		mp, err := mpool.NewMPool("txn-server", 0, mpool.NoFixed)
   143  		if err != nil {
   144  			return nil, err
   145  		}
   146  		codecOpts = append(codecOpts, morpc.WithCodecEnableCompress(mp))
   147  	}
   148  	rpc, err := morpc.NewRPCServer("txn-server", address,
   149  		morpc.NewMessageCodec(s.acquireRequest, codecOpts...),
   150  		morpc.WithServerLogger(s.rt.Logger().RawLogger()),
   151  		morpc.WithServerDisableAutoCancelContext(),
   152  		morpc.WithServerGoettyOptions(goetty.WithSessionReleaseMsgFunc(func(v interface{}) {
   153  			m := v.(morpc.RPCMessage)
   154  			if !m.InternalMessage() {
   155  				s.releaseResponse(m.Message.(*txn.TxnResponse))
   156  			}
   157  		})))
   158  	if err != nil {
   159  		return nil, err
   160  	}
   161  
   162  	rpc.RegisterRequestHandler(s.onMessage)
   163  	s.rpc = rpc
   164  	return s, nil
   165  }
   166  
   167  func (s *server) Start() error {
   168  	s.queue = make(chan executor, s.options.maxChannelBufferSize)
   169  	s.startProcessors()
   170  	return s.rpc.Start()
   171  }
   172  
   173  func (s *server) Close() error {
   174  	s.stopper.Stop()
   175  	return s.rpc.Close()
   176  }
   177  
   178  func (s *server) RegisterMethodHandler(m txn.TxnMethod, h TxnRequestHandleFunc) {
   179  	s.handlers[m] = h
   180  }
   181  
   182  func (s *server) startProcessors() {
   183  	for i := 0; i < s.options.workers; i++ {
   184  		if err := s.stopper.RunTask(s.handleTxnRequest); err != nil {
   185  			panic(err)
   186  		}
   187  	}
   188  }
   189  
   190  // onMessage a client connection has a separate read goroutine. The onMessage invoked in this read goroutine.
   191  func (s *server) onMessage(
   192  	ctx context.Context,
   193  	msg morpc.RPCMessage,
   194  	sequence uint64,
   195  	cs morpc.ClientSession) error {
   196  	ctx, span := trace.Debug(ctx, "server.onMessage")
   197  	defer span.End()
   198  
   199  	m, ok := msg.Message.(*txn.TxnRequest)
   200  	if !ok {
   201  		s.rt.Logger().Fatal("received invalid message", zap.Any("message", msg))
   202  	}
   203  	if s.options.filter != nil && !s.options.filter(m) {
   204  		s.releaseRequest(m)
   205  		msg.Cancel()
   206  		return nil
   207  	}
   208  	if err := checkMethodVersion(ctx, m); err != nil {
   209  		s.releaseRequest(m)
   210  		msg.Cancel()
   211  		return err
   212  	}
   213  	handler, ok := s.handlers[m.Method]
   214  	if !ok {
   215  		return moerr.NewNotSupported(ctx, "unknown txn request method: %s", m.Method.String())
   216  	}
   217  
   218  	select {
   219  	case <-ctx.Done():
   220  		s.releaseRequest(m)
   221  		msg.Cancel()
   222  		return nil
   223  	default:
   224  	}
   225  
   226  	t := time.Now()
   227  	s.queue <- executor{
   228  		t:       t,
   229  		ctx:     ctx,
   230  		cancel:  msg.Cancel,
   231  		req:     m,
   232  		cs:      cs,
   233  		handler: handler,
   234  		s:       s,
   235  	}
   236  	n := len(s.queue)
   237  	v2.TxnCommitQueueSizeGauge.Set(float64(n))
   238  	if n > s.options.maxChannelBufferSize/2 {
   239  		s.rt.Logger().Warn("txn request handle channel is busy",
   240  			zap.Int("size", n),
   241  			zap.Int("max", s.options.maxChannelBufferSize))
   242  	}
   243  	return nil
   244  }
   245  
   246  func (s *server) acquireResponse() *txn.TxnResponse {
   247  	return s.pool.responses.Get().(*txn.TxnResponse)
   248  }
   249  
   250  func (s *server) releaseResponse(resp *txn.TxnResponse) {
   251  	resp.Reset()
   252  	s.pool.responses.Put(resp)
   253  }
   254  
   255  func (s *server) acquireRequest() morpc.Message {
   256  	return s.pool.requests.Get().(*txn.TxnRequest)
   257  }
   258  
   259  func (s *server) releaseRequest(req *txn.TxnRequest) {
   260  	req.Reset()
   261  	s.pool.requests.Put(req)
   262  }
   263  
   264  func (s *server) handleTxnRequest(ctx context.Context) {
   265  	for {
   266  		select {
   267  		case <-ctx.Done():
   268  			return
   269  		case req := <-s.queue:
   270  			if txnID, err := req.exec(); err != nil {
   271  				if s.rt.Logger().Enabled(zap.DebugLevel) {
   272  					s.rt.Logger().Error("handle txn request failed",
   273  						zap.String("txn-id", hex.EncodeToString(txnID)),
   274  						zap.Error(err))
   275  				}
   276  			}
   277  		}
   278  	}
   279  }
   280  
// executor is one queued unit of work: a decoded request plus everything a
// worker needs to run its handler and send the reply.
type executor struct {
	// t is the time onMessage enqueued the request; exec does not read it.
	t time.Time
	// ctx is the request context, carrying onMessage's trace span.
	ctx context.Context
	// cancel is msg.Cancel; exec calls it when finished.
	cancel context.CancelFunc
	// req is a pooled request that exec returns to the pool.
	req *txn.TxnRequest
	// cs is the client session the response is written to.
	cs morpc.ClientSession
	// handler is the handler registered for req.Method.
	handler TxnRequestHandleFunc
	s       *server
}
   290  
   291  func (r executor) exec() ([]byte, error) {
   292  	defer r.cancel()
   293  	defer r.s.releaseRequest(r.req)
   294  	resp := r.s.acquireResponse()
   295  	if err := r.handler(r.ctx, r.req, resp); err != nil {
   296  		r.s.releaseResponse(resp)
   297  		return r.req.Txn.ID, err
   298  	}
   299  	resp.RequestID = r.req.RequestID
   300  	txnID := r.req.Txn.ID
   301  	err := r.cs.Write(r.ctx, resp)
   302  	return txnID, err
   303  }
   304  
   305  func checkMethodVersion(ctx context.Context, req *txn.TxnRequest) error {
   306  	return runtime.CheckMethodVersion(ctx, methodVersions, req)
   307  }