github.com/matrixorigin/matrixone@v1.2.0/pkg/common/morpc/handler.go (about) 1 // Copyright 2022 Matrix Origin 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package morpc 16 17 import ( 18 "context" 19 "sync" 20 21 "github.com/matrixorigin/matrixone/pkg/common/moerr" 22 "github.com/matrixorigin/matrixone/pkg/util/trace" 23 "go.uber.org/zap" 24 ) 25 26 type pool[REQ, RESP MethodBasedMessage] struct { 27 request sync.Pool 28 response sync.Pool 29 } 30 31 // NewMessagePool create message pool 32 func NewMessagePool[REQ, RESP MethodBasedMessage]( 33 requestFactory func() REQ, 34 responseFactory func() RESP) MessagePool[REQ, RESP] { 35 return &pool[REQ, RESP]{ 36 request: sync.Pool{New: func() any { return requestFactory() }}, 37 response: sync.Pool{New: func() any { return responseFactory() }}, 38 } 39 } 40 41 func (p *pool[REQ, RESP]) AcquireRequest() REQ { 42 return p.request.Get().(REQ) 43 } 44 func (p *pool[REQ, RESP]) ReleaseRequest(request REQ) { 45 request.Reset() 46 p.request.Put(request) 47 } 48 func (p *pool[REQ, RESP]) AcquireResponse() RESP { 49 return p.response.Get().(RESP) 50 } 51 func (p *pool[REQ, RESP]) ReleaseResponse(resp RESP) { 52 resp.Reset() 53 p.response.Put(resp) 54 } 55 56 type handleFuncCtx[REQ, RESP MethodBasedMessage] struct { 57 handleFunc HandleFunc[REQ, RESP] 58 async bool 59 } 60 61 func (c *handleFuncCtx[REQ, RESP]) call( 62 ctx context.Context, 63 req REQ, 64 resp RESP) { 65 if err := c.handleFunc(ctx, req, resp); err != nil { 66 resp.WrapError(err) 67 } 68 if getLogger().Enabled(zap.DebugLevel) { 69 getLogger().Debug("handle request completed", 70 zap.String("response", resp.DebugString())) 71 } 72 } 73 74 type handler[REQ, RESP MethodBasedMessage] struct { 75 cfg *Config 76 rpc RPCServer 77 pool MessagePool[REQ, RESP] 78 handlers map[uint32]handleFuncCtx[REQ, RESP] 79 80 // respReleaseFunc is the function to release response. 81 respReleaseFunc func(Message) 82 83 options struct { 84 filter func(REQ) bool 85 } 86 } 87 88 // WithHandleMessageFilter set filter func. Requests can be modified or filtered out by the filter 89 // before they are processed by the handler. 90 func WithHandleMessageFilter[REQ, RESP MethodBasedMessage](filter func(REQ) bool) HandlerOption[REQ, RESP] { 91 return func(s *handler[REQ, RESP]) { 92 s.options.filter = filter 93 } 94 } 95 96 // WithHandlerRespReleaseFunc sets the respReleaseFunc of the handler. 97 func WithHandlerRespReleaseFunc[REQ, RESP MethodBasedMessage](f func(message Message)) HandlerOption[REQ, RESP] { 98 return func(s *handler[REQ, RESP]) { 99 s.respReleaseFunc = f 100 } 101 } 102 103 // NewMessageHandler create a message handler. 104 func NewMessageHandler[REQ, RESP MethodBasedMessage]( 105 name string, 106 address string, 107 cfg Config, 108 pool MessagePool[REQ, RESP], 109 opts ...HandlerOption[REQ, RESP]) (MessageHandler[REQ, RESP], error) { 110 s := &handler[REQ, RESP]{ 111 cfg: &cfg, 112 pool: pool, 113 handlers: make(map[uint32]handleFuncCtx[REQ, RESP]), 114 } 115 s.cfg.Adjust() 116 for _, opt := range opts { 117 opt(s) 118 } 119 120 if s.respReleaseFunc == nil { 121 s.respReleaseFunc = func(m Message) { 122 pool.ReleaseResponse(m.(RESP)) 123 } 124 } 125 126 rpc, err := s.cfg.NewServer( 127 name, 128 address, 129 getLogger().RawLogger(), 130 func() Message { return pool.AcquireRequest() }, 131 s.respReleaseFunc, 132 WithServerDisableAutoCancelContext()) 133 if err != nil { 134 return nil, err 135 } 136 rpc.RegisterRequestHandler(s.onMessage) 137 s.rpc = rpc 138 return s, nil 139 } 140 141 func (s *handler[REQ, RESP]) Start() error { 142 return s.rpc.Start() 143 } 144 145 func (s *handler[REQ, RESP]) Close() error { 146 return s.rpc.Close() 147 } 148 149 func (s *handler[REQ, RESP]) RegisterHandleFunc( 150 method uint32, 151 h HandleFunc[REQ, RESP], 152 async bool) MessageHandler[REQ, RESP] { 153 s.handlers[method] = handleFuncCtx[REQ, RESP]{handleFunc: h, async: async} 154 return s 155 } 156 157 func (s *handler[REQ, RESP]) Handle( 158 ctx context.Context, 159 req REQ) RESP { 160 resp := s.pool.AcquireResponse() 161 if handlerCtx, ok := s.getHandler(ctx, req, resp); ok { 162 handlerCtx.call(ctx, req, resp) 163 } 164 return resp 165 } 166 167 func (s *handler[REQ, RESP]) onMessage( 168 ctx context.Context, 169 request RPCMessage, 170 sequence uint64, 171 cs ClientSession) error { 172 ctx, span := trace.Debug(ctx, "lockservice.server.handle") 173 defer span.End() 174 req, ok := request.Message.(REQ) 175 if !ok { 176 getLogger().Fatal("received invalid message", 177 zap.Any("message", request)) 178 } 179 180 resp := s.pool.AcquireResponse() 181 handlerCtx, ok := s.getHandler(ctx, req, resp) 182 if !ok { 183 s.pool.ReleaseRequest(req) 184 return cs.Write(ctx, resp) 185 } 186 187 fn := func(request RPCMessage) error { 188 defer request.Cancel() 189 req, ok := request.Message.(REQ) 190 if !ok { 191 getLogger().Fatal("received invalid message", 192 zap.Any("message", request)) 193 } 194 195 defer s.pool.ReleaseRequest(req) 196 handlerCtx.call(ctx, req, resp) 197 return cs.Write(ctx, resp) 198 } 199 200 if handlerCtx.async { 201 // TODO: make a goroutine pool 202 go fn(request) 203 return nil 204 } 205 return fn(request) 206 } 207 208 func (s *handler[REQ, RESP]) getHandler( 209 ctx context.Context, 210 req REQ, 211 resp RESP) (handleFuncCtx[REQ, RESP], bool) { 212 if getLogger().Enabled(zap.DebugLevel) { 213 getLogger().Debug("received a request", 214 zap.String("request", req.DebugString())) 215 } 216 217 select { 218 case <-ctx.Done(): 219 if getLogger().Enabled(zap.DebugLevel) { 220 getLogger().Debug("skip request by timeout", 221 zap.String("request", req.DebugString())) 222 } 223 resp.WrapError(ctx.Err()) 224 return handleFuncCtx[REQ, RESP]{}, false 225 default: 226 } 227 228 if s.options.filter != nil && 229 !s.options.filter(req) { 230 if getLogger().Enabled(zap.DebugLevel) { 231 getLogger().Debug("skip request by filter", 232 zap.String("request", req.DebugString())) 233 } 234 resp.WrapError(moerr.NewInvalidInputNoCtx("skip request by filter")) 235 return handleFuncCtx[REQ, RESP]{}, false 236 } 237 238 resp.SetID(req.GetID()) 239 resp.SetMethod(req.Method()) 240 handlerCtx, ok := s.handlers[req.Method()] 241 if !ok { 242 resp.WrapError(moerr.NewNotSupportedNoCtx("%d not support in current service", 243 req.Method())) 244 } 245 return handlerCtx, ok 246 }