github.com/cloudwego/kitex@v0.9.0/pkg/remote/trans/netpollmux/server_handler.go

/*
 * Copyright 2021 CloudWeGo Authors
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package netpollmux

import (
	"context"
	"errors"
	"fmt"
	"net"
	"runtime/debug"
	"sync"
	"time"

	"github.com/cloudwego/netpoll"

	"github.com/cloudwego/kitex/pkg/endpoint"
	"github.com/cloudwego/kitex/pkg/gofunc"
	"github.com/cloudwego/kitex/pkg/kerrors"
	"github.com/cloudwego/kitex/pkg/klog"
	"github.com/cloudwego/kitex/pkg/remote"
	"github.com/cloudwego/kitex/pkg/remote/trans"
	np "github.com/cloudwego/kitex/pkg/remote/trans/netpoll"
	"github.com/cloudwego/kitex/pkg/remote/transmeta"
	"github.com/cloudwego/kitex/pkg/rpcinfo"
	"github.com/cloudwego/kitex/pkg/serviceinfo"
	"github.com/cloudwego/kitex/pkg/stats"
	"github.com/cloudwego/kitex/transport"
)

const defaultExitWaitGracefulShutdownTime = 1 * time.Second

type svrTransHandlerFactory struct{}

// NewSvrTransHandlerFactory creates a default netpollmux remote.ServerTransHandlerFactory.
func NewSvrTransHandlerFactory() remote.ServerTransHandlerFactory {
	return &svrTransHandlerFactory{}
}
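
// Note: this factory is wired in when the server enables the multiplexed
// transport, typically via the server.WithMuxTransport() option; the exact
// option name may differ across Kitex versions.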

// MuxEnabled returns true to mark svrTransHandlerFactory as a mux server factory.
func (f *svrTransHandlerFactory) MuxEnabled() bool {
	return true
}

// NewTransHandler implements the remote.ServerTransHandlerFactory interface.
// TODO: use object pool?
func (f *svrTransHandlerFactory) NewTransHandler(opt *remote.ServerOption) (remote.ServerTransHandler, error) {
	return newSvrTransHandler(opt)
}

func newSvrTransHandler(opt *remote.ServerOption) (*svrTransHandler, error) {
	svrHdlr := &svrTransHandler{
		opt:           opt,
		codec:         opt.Codec,
		svcSearchMap:  opt.SvcSearchMap,
		targetSvcInfo: opt.TargetSvcInfo,
		ext:           np.NewNetpollConnExtension(),
	}
	if svrHdlr.opt.TracerCtl == nil {
		// init TracerCtl when it is nil, otherwise some unit tests will panic
		svrHdlr.opt.TracerCtl = &rpcinfo.TraceController{}
	}
	svrHdlr.funcPool.New = func() interface{} {
		fs := make([]func(), 0, 64) // 64 is an arbitrary initial capacity, no special meaning
		return &fs
	}
	return svrHdlr, nil
}

var _ remote.ServerTransHandler = &svrTransHandler{}

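// svrTransHandler is the server-side transport handler for the multiplexed
// transport. funcPool caches slices of request tasks to cut allocations in
// OnRead, conns tracks live connections for graceful shutdown, and tasks
// counts in-flight requests so shutdown can wait for them to drain.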
type svrTransHandler struct {
	opt           *remote.ServerOption
	svcSearchMap  map[string]*serviceinfo.ServiceInfo
	targetSvcInfo *serviceinfo.ServiceInfo
	inkHdlFunc    endpoint.Endpoint
	codec         remote.Codec
	transPipe     *remote.TransPipeline
	ext           trans.Extension
	funcPool      sync.Pool
	conns         sync.Map
	tasks         sync.WaitGroup
}

// Write implements the remote.ServerTransHandler interface.
func (t *svrTransHandler) Write(ctx context.Context, conn net.Conn, sendMsg remote.Message) (nctx context.Context, err error) {
	ri := rpcinfo.GetRPCInfo(ctx)
	rpcinfo.Record(ctx, ri, stats.WriteStart, nil)
	defer func() {
		rpcinfo.Record(ctx, ri, stats.WriteFinish, nil)
	}()

	svcInfo := sendMsg.ServiceInfo()
	if svcInfo != nil {
		if methodInfo, _ := trans.GetMethodInfo(ri, svcInfo); methodInfo != nil {
			if methodInfo.OneWay() {
				return ctx, nil
			}
		}
	}

	wbuf := netpoll.NewLinkBuffer()
	bufWriter := np.NewWriterByteBuffer(wbuf)
	err = t.codec.Encode(ctx, sendMsg, bufWriter)
	bufWriter.Release(err)
	if err != nil {
		return ctx, err
	}
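	// The connection is shared by many in-flight requests, so the encoded
	// buffer is queued via Put and flushed asynchronously by the mux
	// connection's shared writer queue rather than written here directly.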
	conn.(*muxSvrConn).Put(func() (buf netpoll.Writer, isNil bool) {
		return wbuf, false
	})
	return ctx, nil
}

// Read implements the remote.ServerTransHandler interface.
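// In the multiplexed transport each frame is decoded in OnRead/task, so Read
// is a no-op here.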
func (t *svrTransHandler) Read(ctx context.Context, conn net.Conn, msg remote.Message) (nctx context.Context, err error) {
	return ctx, nil
}

func (t *svrTransHandler) readWithByteBuffer(ctx context.Context, bufReader remote.ByteBuffer, msg remote.Message) (err error) {
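	// On a decode error, skip whatever is left of this frame so the shared
	// reader is not left positioned mid-frame, then release the buffer.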
	defer func() {
		if bufReader != nil {
			if err != nil {
				bufReader.Skip(bufReader.ReadableLen())
			}
			bufReader.Release(err)
		}
		rpcinfo.Record(ctx, msg.RPCInfo(), stats.ReadFinish, err)
	}()
	rpcinfo.Record(ctx, msg.RPCInfo(), stats.ReadStart, nil)

	err = t.codec.Decode(ctx, msg, bufReader)
	if err != nil {
		msg.Tags()[remote.ReadFailed] = true
		return err
	}
	return nil
}

// OnRead implements the remote.ServerTransHandler interface.
// It returns only the header-parse error; per-request handling errors are
// dealt with asynchronously and are not returned here.
func (t *svrTransHandler) OnRead(muxSvrConnCtx context.Context, conn net.Conn) error {
	defer t.tryRecover(muxSvrConnCtx, conn)
	connection := conn.(netpoll.Connection)
	r := connection.Reader()

	fs := *t.funcPool.Get().(*[]func())
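	// Drain the input buffer: each iteration peels one length-prefixed frame
	// off the reader and enqueues a task that decodes and handles it.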
	for total := r.Len(); total > 0; total = r.Len() {
		// protocol header check
		length, _, err := parseHeader(r)
		if err != nil {
			err = fmt.Errorf("%w: addr(%s)", err, connection.RemoteAddr())
			klog.Errorf("KITEX: error=%s", err.Error())
			connection.Close()
			return err
		}
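		// The full frame has not arrived yet, and Slice below may block
		// waiting for it; dispatch the tasks collected so far first so they
		// are not delayed behind the network.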
		if total < length && len(fs) > 0 {
			go t.batchGoTasks(fs)
			fs = *t.funcPool.Get().(*[]func())
		}
		reader, err := r.Slice(length)
		if err != nil {
			err = fmt.Errorf("%w: addr(%s)", err, connection.RemoteAddr())
			klog.Errorf("KITEX: error=%s", err.Error())
			connection.Close()
			return nil
		}
		fs = append(fs, func() {
			t.task(muxSvrConnCtx, conn, reader)
		})
	}
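	// Dispatch whatever tasks remain once the reader has been drained.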
	go t.batchGoTasks(fs)
	return nil
}

// batchGoTasks centrally creates goroutines to execute tasks.
func (t *svrTransHandler) batchGoTasks(fs []func()) {
	for n := range fs {
		gofunc.GoFunc(context.Background(), fs[n])
	}
	fs = fs[:0]
	t.funcPool.Put(&fs)
}

// task runs the complete lifecycle of one request: decode -> handle -> write response.
func (t *svrTransHandler) task(muxSvrConnCtx context.Context, conn net.Conn, reader netpoll.Reader) {
	t.tasks.Add(1)
	defer t.tasks.Done()

	// rpcInfoCtx is a pooled ctx carrying an initialized, reusable RPCInfo;
	// it is recycled in the defer below.
	muxSvrConn, _ := muxSvrConnCtx.Value(ctxKeyMuxSvrConn{}).(*muxSvrConn)
	rpcInfo := muxSvrConn.pool.Get().(rpcinfo.RPCInfo)
	rpcInfoCtx := rpcinfo.NewCtxWithRPCInfo(muxSvrConnCtx, rpcInfo)

	// This is the request-level, one-shot ctx.
	// It mainly carries the tracer, so it must not be recycled.
	ctx := t.startTracer(rpcInfoCtx, rpcInfo)
	var err error
	var recvMsg remote.Message
	var sendMsg remote.Message
	var closeConn bool
	defer func() {
		panicErr := recover()
		if panicErr != nil {
			if conn != nil {
				ri := rpcinfo.GetRPCInfo(ctx)
				rService, rAddr := getRemoteInfo(ri, conn)
				klog.Errorf("KITEX: panic happened, close conn, remoteAddress=%s remoteService=%s error=%s\nstack=%s", rAddr, rService, panicErr, string(debug.Stack()))
				closeConn = true
			} else {
				klog.Errorf("KITEX: panic happened, error=%s\nstack=%s", panicErr, string(debug.Stack()))
			}
		}
		if closeConn && conn != nil {
			conn.Close()
		}
		t.finishTracer(ctx, rpcInfo, err, panicErr)
		remote.RecycleMessage(recvMsg)
		remote.RecycleMessage(sendMsg)
		// reset rpcinfo for reuse
		if rpcinfo.PoolEnabled() {
			rpcInfo = t.opt.InitOrResetRPCInfoFunc(rpcInfo, conn.RemoteAddr())
			muxSvrConn.pool.Put(rpcInfo)
		}
	}()

	// read
	recvMsg = remote.NewMessageWithNewer(t.targetSvcInfo, t.svcSearchMap, rpcInfo, remote.Call, remote.Server, t.opt.RefuseTrafficWithoutServiceName)
	bufReader := np.NewReaderByteBuffer(reader)
	err = t.readWithByteBuffer(ctx, bufReader, recvMsg)
	if err != nil {
		// In the mux case a read failure does not require closing the connection,
		// because the frame has already been fully consumed; the conn is closed
		// only if writing the error reply fails.
		closeConn = t.writeErrorReplyIfNeeded(ctx, recvMsg, muxSvrConn, rpcInfo, err, true)
		// For the proxy case the actual remoteAddr must be read first, so the
		// error log must be printed after writeErrorReplyIfNeeded.
		t.OnError(ctx, err, muxSvrConn)
		return
	}

	svcInfo := recvMsg.ServiceInfo()
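	// A heartbeat frame gets an empty heartbeat reply; regular calls go
	// through the full message pipeline below.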
	if recvMsg.MessageType() == remote.Heartbeat {
		sendMsg = remote.NewMessage(nil, svcInfo, rpcInfo, remote.Heartbeat, remote.Server)
	} else {
		var methodInfo serviceinfo.MethodInfo
		if methodInfo, err = trans.GetMethodInfo(rpcInfo, svcInfo); err != nil {
			closeConn = t.writeErrorReplyIfNeeded(ctx, recvMsg, muxSvrConn, rpcInfo, err, true)
			t.OnError(ctx, err, muxSvrConn)
			return
		}
		if methodInfo.OneWay() {
			sendMsg = remote.NewMessage(nil, svcInfo, rpcInfo, remote.Reply, remote.Server)
		} else {
			sendMsg = remote.NewMessage(methodInfo.NewResult(), svcInfo, rpcInfo, remote.Reply, remote.Server)
		}

		ctx, err = t.transPipe.OnMessage(ctx, recvMsg, sendMsg)
		if err != nil {
			// OnError must run before NewTransError wraps the error, otherwise
			// the original error can no longer be printed here.
			t.OnError(ctx, err, muxSvrConn)
			err = remote.NewTransError(remote.InternalError, err)
			closeConn = t.writeErrorReplyIfNeeded(ctx, recvMsg, muxSvrConn, rpcInfo, err, false)
			return
		}
	}

	remote.FillSendMsgFromRecvMsg(recvMsg, sendMsg)
	if ctx, err = t.transPipe.Write(ctx, muxSvrConn, sendMsg); err != nil {
		t.OnError(ctx, err, muxSvrConn)
		closeConn = true
		return
	}
}

// OnMessage implements the remote.ServerTransHandler interface.
// msg is the decoded instance, such as Arg or Result.
// OnMessage notifies the higher level to process. It's used in async and server-side logic.
func (t *svrTransHandler) OnMessage(ctx context.Context, args, result remote.Message) (context.Context, error) {
	err := t.inkHdlFunc(ctx, args.Data(), result.Data())
	return ctx, err
}

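// ctxKeyMuxSvrConn is the context key under which OnActive stores the
// muxSvrConn; task retrieves it to reach the connection's RPCInfo pool and
// its shared writer.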
type ctxKeyMuxSvrConn struct{}

// OnActive implements the remote.ServerTransHandler interface.
// The sync.Pool for RPCInfo is set up here.
func (t *svrTransHandler) OnActive(ctx context.Context, conn net.Conn) (context.Context, error) {
	connection := conn.(netpoll.Connection)

	// 1. set the read/write timeout
	connection.SetReadTimeout(t.opt.ReadWriteTimeout)

	// 2. set up the mux server conn
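	// Each connection gets its own RPCInfo pool: the remote address is fixed
	// per connection, so pooled RPCInfo objects can be reset cheaply between
	// requests on the same conn.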
	pool := &sync.Pool{
		New: func() interface{} {
			// init rpcinfo
			ri := t.opt.InitOrResetRPCInfoFunc(nil, connection.RemoteAddr())
			return ri
		},
	}
	muxConn := newMuxSvrConn(connection, pool)
	t.conns.Store(conn, muxConn)
	return context.WithValue(context.Background(), ctxKeyMuxSvrConn{}, muxConn), nil
}

// OnInactive implements the remote.ServerTransHandler interface.
func (t *svrTransHandler) OnInactive(ctx context.Context, conn net.Conn) {
	t.conns.Delete(conn)
}

func (t *svrTransHandler) GracefulShutdown(ctx context.Context) error {
	// Send a control frame with sequence ID 0 to notify the remote
	// end to close the connection or prevent further operation on it.
	iv := rpcinfo.NewInvocation("none", "none")
	iv.SetSeqID(0)
	ri := rpcinfo.NewRPCInfo(nil, nil, iv, nil, nil)
	data := NewControlFrame()
	svcInfo := t.getSvcInfo()
	msg := remote.NewMessage(data, svcInfo, ri, remote.Reply, remote.Server)
	msg.SetProtocolInfo(remote.NewProtocolInfo(transport.TTHeader, serviceinfo.Thrift))
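	// The "connection ready to reset" header tells mux clients to stop
	// issuing new requests on this connection.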
	msg.TransInfo().TransStrInfo()[transmeta.HeaderConnectionReadyToReset] = "1"

	// wait until all notifications are sent and clients stop using those connections
	done := make(chan struct{})
	gofunc.GoFunc(context.Background(), func() {
		// 1. write control frames to all connections
		t.conns.Range(func(k, v interface{}) bool {
			sconn := v.(*muxSvrConn)
			if !sconn.IsActive() {
				return true
			}
			wbuf := netpoll.NewLinkBuffer()
			bufWriter := np.NewWriterByteBuffer(wbuf)
			err := t.codec.Encode(ctx, msg, bufWriter)
			bufWriter.Release(err)
			if err == nil {
				sconn.Put(func() (buf netpoll.Writer, isNil bool) {
					return wbuf, false
				})
			} else {
				klog.Warn("KITEX: signal connection closing error:",
					err.Error(), sconn.LocalAddr().String(), "=>", sconn.RemoteAddr().String())
			}
			return true
		})
		// 2. wait for all in-flight tasks to finish
		t.tasks.Wait()
		// 3. wait for all connections to be shut down gracefully
		t.conns.Range(func(k, v interface{}) bool {
			sconn := v.(*muxSvrConn)
			if sconn.IsActive() {
				sconn.GracefulShutdown()
			}
			return true
		})
		// 4. wait for clients to receive all conn-reset control packets
		time.Sleep(defaultExitWaitGracefulShutdownTime)
		close(done)
	})
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-done:
		return nil
	}
}

// OnError implements the remote.ServerTransHandler interface.
func (t *svrTransHandler) OnError(ctx context.Context, err error, conn net.Conn) {
	ri := rpcinfo.GetRPCInfo(ctx)
	rService, rAddr := getRemoteInfo(ri, conn)
	if t.ext.IsRemoteClosedErr(err) {
		// errors caused by the remote side closing the connection should not
		// be regarded as server errors
		if ri == nil {
			return
		}
		remoteInfo := rpcinfo.AsMutableEndpointInfo(ri.From())
		remoteInfo.SetTag(rpcinfo.RemoteClosedTag, "1")
	} else {
		var de *kerrors.DetailedError
		if ok := errors.As(err, &de); ok && de.Stack() != "" {
			klog.CtxErrorf(ctx, "KITEX: processing request error, remoteService=%s, remoteAddr=%v, error=%s\nstack=%s", rService, rAddr, err.Error(), de.Stack())
		} else {
			klog.CtxErrorf(ctx, "KITEX: processing request error, remoteService=%s, remoteAddr=%v, error=%s", rService, rAddr, err.Error())
		}
	}
}

// SetInvokeHandleFunc implements the remote.InvokeHandleFuncSetter interface.
func (t *svrTransHandler) SetInvokeHandleFunc(inkHdlFunc endpoint.Endpoint) {
	t.inkHdlFunc = inkHdlFunc
}

// SetPipeline implements the remote.ServerTransHandler interface.
func (t *svrTransHandler) SetPipeline(p *remote.TransPipeline) {
	t.transPipe = p
}

func (t *svrTransHandler) writeErrorReplyIfNeeded(
	ctx context.Context, recvMsg remote.Message, conn net.Conn, ri rpcinfo.RPCInfo, err error, doOnMessage bool,
) (shouldCloseConn bool) {
	svcInfo := recvMsg.ServiceInfo()
	if svcInfo != nil {
		if methodInfo, _ := trans.GetMethodInfo(ri, svcInfo); methodInfo != nil {
			if methodInfo.OneWay() {
				return
			}
		}
	}
	transErr, isTransErr := err.(*remote.TransError)
	if !isTransErr {
		return
	}
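	// Only a remote.TransError can be encoded as an Exception reply; other
	// error kinds were filtered out above and produce no reply.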
	errMsg := remote.NewMessage(transErr, svcInfo, ri, remote.Exception, remote.Server)
	remote.FillSendMsgFromRecvMsg(recvMsg, errMsg)
	if doOnMessage {
		// if the error happened before the normal OnMessage, execute it here so
		// the header trans info is transferred into the rpcinfo
		t.transPipe.OnMessage(ctx, recvMsg, errMsg)
	}
	ctx, err = t.transPipe.Write(ctx, conn, errMsg)
	if err != nil {
		klog.CtxErrorf(ctx, "KITEX: write error reply failed, remote=%s, error=%s", conn.RemoteAddr(), err.Error())
		return true
	}
	return
}

func (t *svrTransHandler) tryRecover(ctx context.Context, conn net.Conn) {
	if err := recover(); err != nil {
		// rpcStat := internal.AsMutableRPCStats(t.rpcinfo.Stats())
		// rpcStat.SetPanicked(err)
		// t.opt.TracerCtl.DoFinish(ctx, klog)
		// There is no need to reset rpcStats here: the connection is about to
		// be closed, so the RPCInfo will be recycled directly.

		if conn != nil {
			conn.Close()
			klog.CtxErrorf(ctx, "KITEX: panic happened, close conn[%s], %s\n%s", conn.RemoteAddr(), err, string(debug.Stack()))
		} else {
			klog.CtxErrorf(ctx, "KITEX: panic happened, %s\n%s", err, string(debug.Stack()))
		}
	}
}

func (t *svrTransHandler) startTracer(ctx context.Context, ri rpcinfo.RPCInfo) context.Context {
	c := t.opt.TracerCtl.DoStart(ctx, ri)
	return c
}

func (t *svrTransHandler) finishTracer(ctx context.Context, ri rpcinfo.RPCInfo, err error, panicErr interface{}) {
	rpcStats := rpcinfo.AsMutableRPCStats(ri.Stats())
	if rpcStats == nil {
		return
	}
	if panicErr != nil {
		rpcStats.SetPanicked(panicErr)
	}
	if errors.Is(err, netpoll.ErrConnClosed) {
		// errors caused by the remote side closing the connection should not
		// be regarded as server errors
		err = nil
	}
	t.opt.TracerCtl.DoFinish(ctx, ri, err)
	// on the server side the rpcinfo is reused across requests on a connection:
	// clear the rpc stats but keep the stats level config
	sl := ri.Stats().Level()
	rpcStats.Reset()
	rpcStats.SetLevel(sl)
}

// getSvcInfo returns the target ServiceInfo if one is configured; otherwise it
// falls back to an arbitrary registered service.
func (t *svrTransHandler) getSvcInfo() *serviceinfo.ServiceInfo {
	if t.targetSvcInfo != nil {
		return t.targetSvcInfo
	}
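	// Map iteration order is random, so this picks an arbitrary service; that
	// is sufficient for encoding the graceful-shutdown control frame.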
	for _, svcInfo := range t.svcSearchMap {
		return svcInfo
	}
	return nil
}

func getRemoteInfo(ri rpcinfo.RPCInfo, conn net.Conn) (string, net.Addr) {
	rAddr := conn.RemoteAddr()
	if ri == nil {
		return "", rAddr
	}
	if rAddr.Network() == "unix" {
		if ri.From().Address() != nil {
			rAddr = ri.From().Address()
		}
	}
	return ri.From().ServiceName(), rAddr
}