github.com/matrixorigin/matrixone@v1.2.0/pkg/common/morpc/handler.go (about)

     1  // Copyright 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package morpc
    16  
    17  import (
    18  	"context"
    19  	"sync"
    20  
    21  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    22  	"github.com/matrixorigin/matrixone/pkg/util/trace"
    23  	"go.uber.org/zap"
    24  )
    25  
    26  type pool[REQ, RESP MethodBasedMessage] struct {
    27  	request  sync.Pool
    28  	response sync.Pool
    29  }
    30  
    31  // NewMessagePool create message pool
    32  func NewMessagePool[REQ, RESP MethodBasedMessage](
    33  	requestFactory func() REQ,
    34  	responseFactory func() RESP) MessagePool[REQ, RESP] {
    35  	return &pool[REQ, RESP]{
    36  		request:  sync.Pool{New: func() any { return requestFactory() }},
    37  		response: sync.Pool{New: func() any { return responseFactory() }},
    38  	}
    39  }
    40  
    41  func (p *pool[REQ, RESP]) AcquireRequest() REQ {
    42  	return p.request.Get().(REQ)
    43  }
    44  func (p *pool[REQ, RESP]) ReleaseRequest(request REQ) {
    45  	request.Reset()
    46  	p.request.Put(request)
    47  }
    48  func (p *pool[REQ, RESP]) AcquireResponse() RESP {
    49  	return p.response.Get().(RESP)
    50  }
    51  func (p *pool[REQ, RESP]) ReleaseResponse(resp RESP) {
    52  	resp.Reset()
    53  	p.response.Put(resp)
    54  }
    55  
    56  type handleFuncCtx[REQ, RESP MethodBasedMessage] struct {
    57  	handleFunc HandleFunc[REQ, RESP]
    58  	async      bool
    59  }
    60  
    61  func (c *handleFuncCtx[REQ, RESP]) call(
    62  	ctx context.Context,
    63  	req REQ,
    64  	resp RESP) {
    65  	if err := c.handleFunc(ctx, req, resp); err != nil {
    66  		resp.WrapError(err)
    67  	}
    68  	if getLogger().Enabled(zap.DebugLevel) {
    69  		getLogger().Debug("handle request completed",
    70  			zap.String("response", resp.DebugString()))
    71  	}
    72  }
    73  
    74  type handler[REQ, RESP MethodBasedMessage] struct {
    75  	cfg      *Config
    76  	rpc      RPCServer
    77  	pool     MessagePool[REQ, RESP]
    78  	handlers map[uint32]handleFuncCtx[REQ, RESP]
    79  
    80  	// respReleaseFunc is the function to release response.
    81  	respReleaseFunc func(Message)
    82  
    83  	options struct {
    84  		filter func(REQ) bool
    85  	}
    86  }
    87  
    88  // WithHandleMessageFilter set filter func. Requests can be modified or filtered out by the filter
    89  // before they are processed by the handler.
    90  func WithHandleMessageFilter[REQ, RESP MethodBasedMessage](filter func(REQ) bool) HandlerOption[REQ, RESP] {
    91  	return func(s *handler[REQ, RESP]) {
    92  		s.options.filter = filter
    93  	}
    94  }
    95  
    96  // WithHandlerRespReleaseFunc sets the respReleaseFunc of the handler.
    97  func WithHandlerRespReleaseFunc[REQ, RESP MethodBasedMessage](f func(message Message)) HandlerOption[REQ, RESP] {
    98  	return func(s *handler[REQ, RESP]) {
    99  		s.respReleaseFunc = f
   100  	}
   101  }
   102  
   103  // NewMessageHandler create a message handler.
   104  func NewMessageHandler[REQ, RESP MethodBasedMessage](
   105  	name string,
   106  	address string,
   107  	cfg Config,
   108  	pool MessagePool[REQ, RESP],
   109  	opts ...HandlerOption[REQ, RESP]) (MessageHandler[REQ, RESP], error) {
   110  	s := &handler[REQ, RESP]{
   111  		cfg:      &cfg,
   112  		pool:     pool,
   113  		handlers: make(map[uint32]handleFuncCtx[REQ, RESP]),
   114  	}
   115  	s.cfg.Adjust()
   116  	for _, opt := range opts {
   117  		opt(s)
   118  	}
   119  
   120  	if s.respReleaseFunc == nil {
   121  		s.respReleaseFunc = func(m Message) {
   122  			pool.ReleaseResponse(m.(RESP))
   123  		}
   124  	}
   125  
   126  	rpc, err := s.cfg.NewServer(
   127  		name,
   128  		address,
   129  		getLogger().RawLogger(),
   130  		func() Message { return pool.AcquireRequest() },
   131  		s.respReleaseFunc,
   132  		WithServerDisableAutoCancelContext())
   133  	if err != nil {
   134  		return nil, err
   135  	}
   136  	rpc.RegisterRequestHandler(s.onMessage)
   137  	s.rpc = rpc
   138  	return s, nil
   139  }
   140  
   141  func (s *handler[REQ, RESP]) Start() error {
   142  	return s.rpc.Start()
   143  }
   144  
   145  func (s *handler[REQ, RESP]) Close() error {
   146  	return s.rpc.Close()
   147  }
   148  
   149  func (s *handler[REQ, RESP]) RegisterHandleFunc(
   150  	method uint32,
   151  	h HandleFunc[REQ, RESP],
   152  	async bool) MessageHandler[REQ, RESP] {
   153  	s.handlers[method] = handleFuncCtx[REQ, RESP]{handleFunc: h, async: async}
   154  	return s
   155  }
   156  
   157  func (s *handler[REQ, RESP]) Handle(
   158  	ctx context.Context,
   159  	req REQ) RESP {
   160  	resp := s.pool.AcquireResponse()
   161  	if handlerCtx, ok := s.getHandler(ctx, req, resp); ok {
   162  		handlerCtx.call(ctx, req, resp)
   163  	}
   164  	return resp
   165  }
   166  
   167  func (s *handler[REQ, RESP]) onMessage(
   168  	ctx context.Context,
   169  	request RPCMessage,
   170  	sequence uint64,
   171  	cs ClientSession) error {
   172  	ctx, span := trace.Debug(ctx, "lockservice.server.handle")
   173  	defer span.End()
   174  	req, ok := request.Message.(REQ)
   175  	if !ok {
   176  		getLogger().Fatal("received invalid message",
   177  			zap.Any("message", request))
   178  	}
   179  
   180  	resp := s.pool.AcquireResponse()
   181  	handlerCtx, ok := s.getHandler(ctx, req, resp)
   182  	if !ok {
   183  		s.pool.ReleaseRequest(req)
   184  		return cs.Write(ctx, resp)
   185  	}
   186  
   187  	fn := func(request RPCMessage) error {
   188  		defer request.Cancel()
   189  		req, ok := request.Message.(REQ)
   190  		if !ok {
   191  			getLogger().Fatal("received invalid message",
   192  				zap.Any("message", request))
   193  		}
   194  
   195  		defer s.pool.ReleaseRequest(req)
   196  		handlerCtx.call(ctx, req, resp)
   197  		return cs.Write(ctx, resp)
   198  	}
   199  
   200  	if handlerCtx.async {
   201  		// TODO: make a goroutine pool
   202  		go fn(request)
   203  		return nil
   204  	}
   205  	return fn(request)
   206  }
   207  
   208  func (s *handler[REQ, RESP]) getHandler(
   209  	ctx context.Context,
   210  	req REQ,
   211  	resp RESP) (handleFuncCtx[REQ, RESP], bool) {
   212  	if getLogger().Enabled(zap.DebugLevel) {
   213  		getLogger().Debug("received a request",
   214  			zap.String("request", req.DebugString()))
   215  	}
   216  
   217  	select {
   218  	case <-ctx.Done():
   219  		if getLogger().Enabled(zap.DebugLevel) {
   220  			getLogger().Debug("skip request by timeout",
   221  				zap.String("request", req.DebugString()))
   222  		}
   223  		resp.WrapError(ctx.Err())
   224  		return handleFuncCtx[REQ, RESP]{}, false
   225  	default:
   226  	}
   227  
   228  	if s.options.filter != nil &&
   229  		!s.options.filter(req) {
   230  		if getLogger().Enabled(zap.DebugLevel) {
   231  			getLogger().Debug("skip request by filter",
   232  				zap.String("request", req.DebugString()))
   233  		}
   234  		resp.WrapError(moerr.NewInvalidInputNoCtx("skip request by filter"))
   235  		return handleFuncCtx[REQ, RESP]{}, false
   236  	}
   237  
   238  	resp.SetID(req.GetID())
   239  	resp.SetMethod(req.Method())
   240  	handlerCtx, ok := s.handlers[req.Method()]
   241  	if !ok {
   242  		resp.WrapError(moerr.NewNotSupportedNoCtx("%d not support in current service",
   243  			req.Method()))
   244  	}
   245  	return handlerCtx, ok
   246  }