github.com/cloudwego/kitex@v0.9.0/pkg/remote/trans/netpollmux/server_handler.go (about) 1 /* 2 * Copyright 2021 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package netpollmux 18 19 import ( 20 "context" 21 "errors" 22 "fmt" 23 "net" 24 "runtime/debug" 25 "sync" 26 "time" 27 28 "github.com/cloudwego/netpoll" 29 30 "github.com/cloudwego/kitex/pkg/endpoint" 31 "github.com/cloudwego/kitex/pkg/gofunc" 32 "github.com/cloudwego/kitex/pkg/kerrors" 33 "github.com/cloudwego/kitex/pkg/klog" 34 "github.com/cloudwego/kitex/pkg/remote" 35 "github.com/cloudwego/kitex/pkg/remote/trans" 36 np "github.com/cloudwego/kitex/pkg/remote/trans/netpoll" 37 "github.com/cloudwego/kitex/pkg/remote/transmeta" 38 "github.com/cloudwego/kitex/pkg/rpcinfo" 39 "github.com/cloudwego/kitex/pkg/serviceinfo" 40 "github.com/cloudwego/kitex/pkg/stats" 41 "github.com/cloudwego/kitex/transport" 42 ) 43 44 const defaultExitWaitGracefulShutdownTime = 1 * time.Second 45 46 type svrTransHandlerFactory struct{} 47 48 // NewSvrTransHandlerFactory creates a default netpollmux remote.ServerTransHandlerFactory. 49 func NewSvrTransHandlerFactory() remote.ServerTransHandlerFactory { 50 return &svrTransHandlerFactory{} 51 } 52 53 // MuxEnabled returns true to mark svrTransHandlerFactory as a mux server factory. 54 func (f *svrTransHandlerFactory) MuxEnabled() bool { 55 return true 56 } 57 58 // NewTransHandler implements the remote.ServerTransHandlerFactory interface. 59 // TODO: use object pool? 60 func (f *svrTransHandlerFactory) NewTransHandler(opt *remote.ServerOption) (remote.ServerTransHandler, error) { 61 return newSvrTransHandler(opt) 62 } 63 64 func newSvrTransHandler(opt *remote.ServerOption) (*svrTransHandler, error) { 65 svrHdlr := &svrTransHandler{ 66 opt: opt, 67 codec: opt.Codec, 68 svcSearchMap: opt.SvcSearchMap, 69 targetSvcInfo: opt.TargetSvcInfo, 70 ext: np.NewNetpollConnExtension(), 71 } 72 if svrHdlr.opt.TracerCtl == nil { 73 // init TraceCtl when it is nil, or it will lead some unit tests panic 74 svrHdlr.opt.TracerCtl = &rpcinfo.TraceController{} 75 } 76 svrHdlr.funcPool.New = func() interface{} { 77 fs := make([]func(), 0, 64) // 64 is defined casually, no special meaning 78 return &fs 79 } 80 return svrHdlr, nil 81 } 82 83 var _ remote.ServerTransHandler = &svrTransHandler{} 84 85 type svrTransHandler struct { 86 opt *remote.ServerOption 87 svcSearchMap map[string]*serviceinfo.ServiceInfo 88 targetSvcInfo *serviceinfo.ServiceInfo 89 inkHdlFunc endpoint.Endpoint 90 codec remote.Codec 91 transPipe *remote.TransPipeline 92 ext trans.Extension 93 funcPool sync.Pool 94 conns sync.Map 95 tasks sync.WaitGroup 96 } 97 98 // Write implements the remote.ServerTransHandler interface. 99 func (t *svrTransHandler) Write(ctx context.Context, conn net.Conn, sendMsg remote.Message) (nctx context.Context, err error) { 100 ri := rpcinfo.GetRPCInfo(ctx) 101 rpcinfo.Record(ctx, ri, stats.WriteStart, nil) 102 defer func() { 103 rpcinfo.Record(ctx, ri, stats.WriteFinish, nil) 104 }() 105 106 svcInfo := sendMsg.ServiceInfo() 107 if svcInfo != nil { 108 if methodInfo, _ := trans.GetMethodInfo(ri, svcInfo); methodInfo != nil { 109 if methodInfo.OneWay() { 110 return ctx, nil 111 } 112 } 113 } 114 115 wbuf := netpoll.NewLinkBuffer() 116 bufWriter := np.NewWriterByteBuffer(wbuf) 117 err = t.codec.Encode(ctx, sendMsg, bufWriter) 118 bufWriter.Release(err) 119 if err != nil { 120 return ctx, err 121 } 122 conn.(*muxSvrConn).Put(func() (buf netpoll.Writer, isNil bool) { 123 return wbuf, false 124 }) 125 return ctx, nil 126 } 127 128 // Read implements the remote.ServerTransHandler interface. 129 func (t *svrTransHandler) Read(ctx context.Context, conn net.Conn, msg remote.Message) (nctx context.Context, err error) { 130 return ctx, nil 131 } 132 133 func (t *svrTransHandler) readWithByteBuffer(ctx context.Context, bufReader remote.ByteBuffer, msg remote.Message) (err error) { 134 defer func() { 135 if bufReader != nil { 136 if err != nil { 137 bufReader.Skip(bufReader.ReadableLen()) 138 } 139 bufReader.Release(err) 140 } 141 rpcinfo.Record(ctx, msg.RPCInfo(), stats.ReadFinish, err) 142 }() 143 rpcinfo.Record(ctx, msg.RPCInfo(), stats.ReadStart, nil) 144 145 err = t.codec.Decode(ctx, msg, bufReader) 146 if err != nil { 147 msg.Tags()[remote.ReadFailed] = true 148 return err 149 } 150 return nil 151 } 152 153 // OnRead implements the remote.ServerTransHandler interface. 154 // Returns write err only. 155 func (t *svrTransHandler) OnRead(muxSvrConnCtx context.Context, conn net.Conn) error { 156 defer t.tryRecover(muxSvrConnCtx, conn) 157 connection := conn.(netpoll.Connection) 158 r := connection.Reader() 159 160 fs := *t.funcPool.Get().(*[]func()) 161 for total := r.Len(); total > 0; total = r.Len() { 162 // protocol header check 163 length, _, err := parseHeader(r) 164 if err != nil { 165 err = fmt.Errorf("%w: addr(%s)", err, connection.RemoteAddr()) 166 klog.Errorf("KITEX: error=%s", err.Error()) 167 connection.Close() 168 return err 169 } 170 if total < length && len(fs) > 0 { 171 go t.batchGoTasks(fs) 172 fs = *t.funcPool.Get().(*[]func()) 173 } 174 reader, err := r.Slice(length) 175 if err != nil { 176 err = fmt.Errorf("%w: addr(%s)", err, connection.RemoteAddr()) 177 klog.Errorf("KITEX: error=%s", err.Error()) 178 connection.Close() 179 return nil 180 } 181 fs = append(fs, func() { 182 t.task(muxSvrConnCtx, conn, reader) 183 }) 184 } 185 go t.batchGoTasks(fs) 186 return nil 187 } 188 189 // batchGoTasks centrally creates goroutines to execute tasks. 190 func (t *svrTransHandler) batchGoTasks(fs []func()) { 191 for n := range fs { 192 gofunc.GoFunc(context.Background(), fs[n]) 193 } 194 fs = fs[:0] 195 t.funcPool.Put(&fs) 196 } 197 198 // task contains a complete process about decoding request -> handling -> writing response 199 func (t *svrTransHandler) task(muxSvrConnCtx context.Context, conn net.Conn, reader netpoll.Reader) { 200 t.tasks.Add(1) 201 defer t.tasks.Done() 202 203 // rpcInfoCtx is a pooled ctx with inited RPCInfo which can be reused. 204 // it's recycled in defer. 205 muxSvrConn, _ := muxSvrConnCtx.Value(ctxKeyMuxSvrConn{}).(*muxSvrConn) 206 rpcInfo := muxSvrConn.pool.Get().(rpcinfo.RPCInfo) 207 rpcInfoCtx := rpcinfo.NewCtxWithRPCInfo(muxSvrConnCtx, rpcInfo) 208 209 // This is the request-level, one-shot ctx. 210 // It adds the tracer principally, thus do not recycle. 211 ctx := t.startTracer(rpcInfoCtx, rpcInfo) 212 var err error 213 var recvMsg remote.Message 214 var sendMsg remote.Message 215 var closeConn bool 216 defer func() { 217 panicErr := recover() 218 if panicErr != nil { 219 if conn != nil { 220 ri := rpcinfo.GetRPCInfo(ctx) 221 rService, rAddr := getRemoteInfo(ri, conn) 222 klog.Errorf("KITEX: panic happened, close conn, remoteAddress=%s remoteService=%s error=%s\nstack=%s", rAddr, rService, panicErr, string(debug.Stack())) 223 closeConn = true 224 } else { 225 klog.Errorf("KITEX: panic happened, error=%s\nstack=%s", panicErr, string(debug.Stack())) 226 } 227 } 228 if closeConn && conn != nil { 229 conn.Close() 230 } 231 t.finishTracer(ctx, rpcInfo, err, panicErr) 232 remote.RecycleMessage(recvMsg) 233 remote.RecycleMessage(sendMsg) 234 // reset rpcinfo for reuse 235 if rpcinfo.PoolEnabled() { 236 rpcInfo = t.opt.InitOrResetRPCInfoFunc(rpcInfo, conn.RemoteAddr()) 237 muxSvrConn.pool.Put(rpcInfo) 238 } 239 }() 240 241 // read 242 recvMsg = remote.NewMessageWithNewer(t.targetSvcInfo, t.svcSearchMap, rpcInfo, remote.Call, remote.Server, t.opt.RefuseTrafficWithoutServiceName) 243 bufReader := np.NewReaderByteBuffer(reader) 244 err = t.readWithByteBuffer(ctx, bufReader, recvMsg) 245 if err != nil { 246 // No need to close the connection when read failed in mux case, because it had finished reads. 247 // But still need to close conn if write failed 248 closeConn = t.writeErrorReplyIfNeeded(ctx, recvMsg, muxSvrConn, rpcInfo, err, true) 249 // for proxy case, need read actual remoteAddr, error print must exec after writeErrorReplyIfNeeded 250 t.OnError(ctx, err, muxSvrConn) 251 return 252 } 253 254 svcInfo := recvMsg.ServiceInfo() 255 t.targetSvcInfo = svcInfo 256 if recvMsg.MessageType() == remote.Heartbeat { 257 sendMsg = remote.NewMessage(nil, svcInfo, rpcInfo, remote.Heartbeat, remote.Server) 258 } else { 259 var methodInfo serviceinfo.MethodInfo 260 if methodInfo, err = trans.GetMethodInfo(rpcInfo, svcInfo); err != nil { 261 closeConn = t.writeErrorReplyIfNeeded(ctx, recvMsg, muxSvrConn, rpcInfo, err, true) 262 t.OnError(ctx, err, muxSvrConn) 263 return 264 } 265 if methodInfo.OneWay() { 266 sendMsg = remote.NewMessage(nil, svcInfo, rpcInfo, remote.Reply, remote.Server) 267 } else { 268 sendMsg = remote.NewMessage(methodInfo.NewResult(), svcInfo, rpcInfo, remote.Reply, remote.Server) 269 } 270 271 ctx, err = t.transPipe.OnMessage(ctx, recvMsg, sendMsg) 272 if err != nil { 273 // error cannot be wrapped to print here, so it must exec before NewTransError 274 t.OnError(ctx, err, muxSvrConn) 275 err = remote.NewTransError(remote.InternalError, err) 276 closeConn = t.writeErrorReplyIfNeeded(ctx, recvMsg, muxSvrConn, rpcInfo, err, false) 277 return 278 } 279 } 280 281 remote.FillSendMsgFromRecvMsg(recvMsg, sendMsg) 282 if ctx, err = t.transPipe.Write(ctx, muxSvrConn, sendMsg); err != nil { 283 t.OnError(ctx, err, muxSvrConn) 284 closeConn = true 285 return 286 } 287 } 288 289 // OnMessage implements the remote.ServerTransHandler interface. 290 // msg is the decoded instance, such as Arg or Result. 291 // OnMessage notifies the higher level to process. It's used in async and server-side logic. 292 func (t *svrTransHandler) OnMessage(ctx context.Context, args, result remote.Message) (context.Context, error) { 293 err := t.inkHdlFunc(ctx, args.Data(), result.Data()) 294 return ctx, err 295 } 296 297 type ctxKeyMuxSvrConn struct{} 298 299 // OnActive implements the remote.ServerTransHandler interface. 300 // sync.Pool for RPCInfo is setup here. 301 func (t *svrTransHandler) OnActive(ctx context.Context, conn net.Conn) (context.Context, error) { 302 connection := conn.(netpoll.Connection) 303 304 // 1. set readwrite timeout 305 connection.SetReadTimeout(t.opt.ReadWriteTimeout) 306 307 // 2. set mux server conn 308 pool := &sync.Pool{ 309 New: func() interface{} { 310 // init rpcinfo 311 ri := t.opt.InitOrResetRPCInfoFunc(nil, connection.RemoteAddr()) 312 return ri 313 }, 314 } 315 muxConn := newMuxSvrConn(connection, pool) 316 t.conns.Store(conn, muxConn) 317 return context.WithValue(context.Background(), ctxKeyMuxSvrConn{}, muxConn), nil 318 } 319 320 // OnInactive implements the remote.ServerTransHandler interface. 321 func (t *svrTransHandler) OnInactive(ctx context.Context, conn net.Conn) { 322 t.conns.Delete(conn) 323 } 324 325 func (t *svrTransHandler) GracefulShutdown(ctx context.Context) error { 326 // Send a control frame with sequence ID 0 to notify the remote 327 // end to close the connection or prevent further operation on it. 328 iv := rpcinfo.NewInvocation("none", "none") 329 iv.SetSeqID(0) 330 ri := rpcinfo.NewRPCInfo(nil, nil, iv, nil, nil) 331 data := NewControlFrame() 332 svcInfo := t.getSvcInfo() 333 msg := remote.NewMessage(data, svcInfo, ri, remote.Reply, remote.Server) 334 msg.SetProtocolInfo(remote.NewProtocolInfo(transport.TTHeader, serviceinfo.Thrift)) 335 msg.TransInfo().TransStrInfo()[transmeta.HeaderConnectionReadyToReset] = "1" 336 337 // wait until all notifications are sent and clients stop using those connections 338 done := make(chan struct{}) 339 gofunc.GoFunc(context.Background(), func() { 340 // 1. write control frames to all connections 341 t.conns.Range(func(k, v interface{}) bool { 342 sconn := v.(*muxSvrConn) 343 if !sconn.IsActive() { 344 return true 345 } 346 wbuf := netpoll.NewLinkBuffer() 347 bufWriter := np.NewWriterByteBuffer(wbuf) 348 err := t.codec.Encode(ctx, msg, bufWriter) 349 bufWriter.Release(err) 350 if err == nil { 351 sconn.Put(func() (buf netpoll.Writer, isNil bool) { 352 return wbuf, false 353 }) 354 } else { 355 klog.Warn("KITEX: signal connection closing error:", 356 err.Error(), sconn.LocalAddr().String(), "=>", sconn.RemoteAddr().String()) 357 } 358 return true 359 }) 360 // 2. waiting for all tasks finished 361 t.tasks.Wait() 362 // 3. waiting for all connections have been shutdown gracefully 363 t.conns.Range(func(k, v interface{}) bool { 364 sconn := v.(*muxSvrConn) 365 if sconn.IsActive() { 366 sconn.GracefulShutdown() 367 } 368 return true 369 }) 370 // 4. waiting all crrst packets received by client 371 time.Sleep(defaultExitWaitGracefulShutdownTime) 372 close(done) 373 }) 374 select { 375 case <-ctx.Done(): 376 return ctx.Err() 377 case <-done: 378 return nil 379 } 380 } 381 382 // OnError implements the remote.ServerTransHandler interface. 383 func (t *svrTransHandler) OnError(ctx context.Context, err error, conn net.Conn) { 384 ri := rpcinfo.GetRPCInfo(ctx) 385 rService, rAddr := getRemoteInfo(ri, conn) 386 if t.ext.IsRemoteClosedErr(err) { 387 // it should not regard error which cause by remote connection closed as server error 388 if ri == nil { 389 return 390 } 391 remote := rpcinfo.AsMutableEndpointInfo(ri.From()) 392 remote.SetTag(rpcinfo.RemoteClosedTag, "1") 393 } else { 394 var de *kerrors.DetailedError 395 if ok := errors.As(err, &de); ok && de.Stack() != "" { 396 klog.CtxErrorf(ctx, "KITEX: processing request error, remoteService=%s, remoteAddr=%v, error=%s\nstack=%s", rService, rAddr, err.Error(), de.Stack()) 397 } else { 398 klog.CtxErrorf(ctx, "KITEX: processing request error, remoteService=%s, remoteAddr=%v, error=%s", rService, rAddr, err.Error()) 399 } 400 } 401 } 402 403 // SetInvokeHandleFunc implements the remote.InvokeHandleFuncSetter interface. 404 func (t *svrTransHandler) SetInvokeHandleFunc(inkHdlFunc endpoint.Endpoint) { 405 t.inkHdlFunc = inkHdlFunc 406 } 407 408 // SetPipeline implements the remote.ServerTransHandler interface. 409 func (t *svrTransHandler) SetPipeline(p *remote.TransPipeline) { 410 t.transPipe = p 411 } 412 413 func (t *svrTransHandler) writeErrorReplyIfNeeded( 414 ctx context.Context, recvMsg remote.Message, conn net.Conn, ri rpcinfo.RPCInfo, err error, doOnMessage bool, 415 ) (shouldCloseConn bool) { 416 svcInfo := recvMsg.ServiceInfo() 417 if svcInfo != nil { 418 if methodInfo, _ := trans.GetMethodInfo(ri, svcInfo); methodInfo != nil { 419 if methodInfo.OneWay() { 420 return 421 } 422 } 423 } 424 transErr, isTransErr := err.(*remote.TransError) 425 if !isTransErr { 426 return 427 } 428 errMsg := remote.NewMessage(transErr, svcInfo, ri, remote.Exception, remote.Server) 429 remote.FillSendMsgFromRecvMsg(recvMsg, errMsg) 430 if doOnMessage { 431 // if error happen before normal OnMessage, exec it to transfer header trans info into rpcinfo 432 t.transPipe.OnMessage(ctx, recvMsg, errMsg) 433 } 434 ctx, err = t.transPipe.Write(ctx, conn, errMsg) 435 if err != nil { 436 klog.CtxErrorf(ctx, "KITEX: write error reply failed, remote=%s, error=%s", conn.RemoteAddr(), err.Error()) 437 return true 438 } 439 return 440 } 441 442 func (t *svrTransHandler) tryRecover(ctx context.Context, conn net.Conn) { 443 if err := recover(); err != nil { 444 // rpcStat := internal.AsMutableRPCStats(t.rpcinfo.Stats()) 445 // rpcStat.SetPanicked(err) 446 // t.opt.TracerCtl.DoFinish(ctx, klog) 447 // 这里不需要 Reset rpcStats 因为连接会关闭,会直接把 RPCInfo 进行 Recycle 448 449 if conn != nil { 450 conn.Close() 451 klog.CtxErrorf(ctx, "KITEX: panic happened, close conn[%s], %s\n%s", conn.RemoteAddr(), err, string(debug.Stack())) 452 } else { 453 klog.CtxErrorf(ctx, "KITEX: panic happened, %s\n%s", err, string(debug.Stack())) 454 } 455 } 456 } 457 458 func (t *svrTransHandler) startTracer(ctx context.Context, ri rpcinfo.RPCInfo) context.Context { 459 c := t.opt.TracerCtl.DoStart(ctx, ri) 460 return c 461 } 462 463 func (t *svrTransHandler) finishTracer(ctx context.Context, ri rpcinfo.RPCInfo, err error, panicErr interface{}) { 464 rpcStats := rpcinfo.AsMutableRPCStats(ri.Stats()) 465 if rpcStats == nil { 466 return 467 } 468 if panicErr != nil { 469 rpcStats.SetPanicked(panicErr) 470 } 471 if errors.Is(err, netpoll.ErrConnClosed) { 472 // it should not regard error which cause by remote connection closed as server error 473 err = nil 474 } 475 t.opt.TracerCtl.DoFinish(ctx, ri, err) 476 // for server side, rpcinfo is reused on connection, clear the rpc stats info but keep the level config 477 sl := ri.Stats().Level() 478 rpcStats.Reset() 479 rpcStats.SetLevel(sl) 480 } 481 482 // getSvcInfo is used to get one ServiceInfo 483 func (t *svrTransHandler) getSvcInfo() *serviceinfo.ServiceInfo { 484 if t.targetSvcInfo != nil { 485 return t.targetSvcInfo 486 } 487 for _, svcInfo := range t.svcSearchMap { 488 return svcInfo 489 } 490 return nil 491 } 492 493 func getRemoteInfo(ri rpcinfo.RPCInfo, conn net.Conn) (string, net.Addr) { 494 rAddr := conn.RemoteAddr() 495 if ri == nil { 496 return "", rAddr 497 } 498 if rAddr.Network() == "unix" { 499 if ri.From().Address() != nil { 500 rAddr = ri.From().Address() 501 } 502 } 503 return ri.From().ServiceName(), rAddr 504 }