github.com/cloudwego/kitex@v0.9.0/pkg/remote/trans/default_server_handler.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package trans
    18  
    19  import (
    20  	"context"
    21  	"errors"
    22  	"fmt"
    23  	"net"
    24  	"runtime/debug"
    25  
    26  	"github.com/cloudwego/kitex/pkg/endpoint"
    27  	"github.com/cloudwego/kitex/pkg/kerrors"
    28  	"github.com/cloudwego/kitex/pkg/klog"
    29  	"github.com/cloudwego/kitex/pkg/remote"
    30  	"github.com/cloudwego/kitex/pkg/rpcinfo"
    31  	"github.com/cloudwego/kitex/pkg/serviceinfo"
    32  	"github.com/cloudwego/kitex/pkg/stats"
    33  )
    34  
    35  // NewDefaultSvrTransHandler to provide default impl of svrTransHandler, it can be reused in netpoll, shm-ipc, framework-sdk extensions
    36  func NewDefaultSvrTransHandler(opt *remote.ServerOption, ext Extension) (remote.ServerTransHandler, error) {
    37  	svrHdlr := &svrTransHandler{
    38  		opt:           opt,
    39  		codec:         opt.Codec,
    40  		svcSearchMap:  opt.SvcSearchMap,
    41  		targetSvcInfo: opt.TargetSvcInfo,
    42  		ext:           ext,
    43  	}
    44  	if svrHdlr.opt.TracerCtl == nil {
    45  		// init TraceCtl when it is nil, or it will lead some unit tests panic
    46  		svrHdlr.opt.TracerCtl = &rpcinfo.TraceController{}
    47  	}
    48  	return svrHdlr, nil
    49  }
    50  
    51  type svrTransHandler struct {
    52  	opt           *remote.ServerOption
    53  	svcSearchMap  map[string]*serviceinfo.ServiceInfo
    54  	targetSvcInfo *serviceinfo.ServiceInfo
    55  	inkHdlFunc    endpoint.Endpoint
    56  	codec         remote.Codec
    57  	transPipe     *remote.TransPipeline
    58  	ext           Extension
    59  }
    60  
    61  // Write implements the remote.ServerTransHandler interface.
    62  func (t *svrTransHandler) Write(ctx context.Context, conn net.Conn, sendMsg remote.Message) (nctx context.Context, err error) {
    63  	var bufWriter remote.ByteBuffer
    64  	ri := sendMsg.RPCInfo()
    65  	rpcinfo.Record(ctx, ri, stats.WriteStart, nil)
    66  	defer func() {
    67  		t.ext.ReleaseBuffer(bufWriter, err)
    68  		rpcinfo.Record(ctx, ri, stats.WriteFinish, err)
    69  	}()
    70  
    71  	svcInfo := sendMsg.ServiceInfo()
    72  	if svcInfo != nil {
    73  		if methodInfo, _ := GetMethodInfo(ri, svcInfo); methodInfo != nil {
    74  			if methodInfo.OneWay() {
    75  				return ctx, nil
    76  			}
    77  		}
    78  	}
    79  
    80  	bufWriter = t.ext.NewWriteByteBuffer(ctx, conn, sendMsg)
    81  	err = t.codec.Encode(ctx, sendMsg, bufWriter)
    82  	if err != nil {
    83  		return ctx, err
    84  	}
    85  	return ctx, bufWriter.Flush()
    86  }
    87  
    88  // Read implements the remote.ServerTransHandler interface.
    89  func (t *svrTransHandler) Read(ctx context.Context, conn net.Conn, recvMsg remote.Message) (nctx context.Context, err error) {
    90  	var bufReader remote.ByteBuffer
    91  	defer func() {
    92  		t.ext.ReleaseBuffer(bufReader, err)
    93  		rpcinfo.Record(ctx, recvMsg.RPCInfo(), stats.ReadFinish, err)
    94  	}()
    95  	rpcinfo.Record(ctx, recvMsg.RPCInfo(), stats.ReadStart, nil)
    96  
    97  	bufReader = t.ext.NewReadByteBuffer(ctx, conn, recvMsg)
    98  	if codec, ok := t.codec.(remote.MetaDecoder); ok {
    99  		if err = codec.DecodeMeta(ctx, recvMsg, bufReader); err == nil {
   100  			if t.opt.Profiler != nil && t.opt.ProfilerTransInfoTagging != nil && recvMsg.TransInfo() != nil {
   101  				var tags []string
   102  				ctx, tags = t.opt.ProfilerTransInfoTagging(ctx, recvMsg)
   103  				ctx = t.opt.Profiler.Tag(ctx, tags...)
   104  			}
   105  			err = codec.DecodePayload(ctx, recvMsg, bufReader)
   106  		}
   107  	} else {
   108  		err = t.codec.Decode(ctx, recvMsg, bufReader)
   109  	}
   110  	if err != nil {
   111  		recvMsg.Tags()[remote.ReadFailed] = true
   112  		return ctx, err
   113  	}
   114  	return ctx, nil
   115  }
   116  
   117  func (t *svrTransHandler) newCtxWithRPCInfo(ctx context.Context, conn net.Conn) (context.Context, rpcinfo.RPCInfo) {
   118  	if rpcinfo.PoolEnabled() { // reuse per-connection rpcinfo
   119  		return ctx, rpcinfo.GetRPCInfo(ctx)
   120  		// delayed reinitialize for faster response
   121  	}
   122  	// new rpcinfo if reuse is disabled
   123  	ri := t.opt.InitOrResetRPCInfoFunc(nil, conn.RemoteAddr())
   124  	return rpcinfo.NewCtxWithRPCInfo(ctx, ri), ri
   125  }
   126  
   127  // OnRead implements the remote.ServerTransHandler interface.
   128  // The connection should be closed after returning error.
   129  func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error) {
   130  	ctx, ri := t.newCtxWithRPCInfo(ctx, conn)
   131  	t.ext.SetReadTimeout(ctx, conn, ri.Config(), remote.Server)
   132  	var recvMsg remote.Message
   133  	var sendMsg remote.Message
   134  	closeConnOutsideIfErr := true
   135  	defer func() {
   136  		panicErr := recover()
   137  		var wrapErr error
   138  		if panicErr != nil {
   139  			stack := string(debug.Stack())
   140  			if conn != nil {
   141  				ri := rpcinfo.GetRPCInfo(ctx)
   142  				rService, rAddr := getRemoteInfo(ri, conn)
   143  				klog.CtxErrorf(ctx, "KITEX: panic happened, remoteAddress=%s, remoteService=%s, error=%v\nstack=%s", rAddr, rService, panicErr, stack)
   144  			} else {
   145  				klog.CtxErrorf(ctx, "KITEX: panic happened, error=%v\nstack=%s", panicErr, stack)
   146  			}
   147  			if err != nil {
   148  				wrapErr = kerrors.ErrPanic.WithCauseAndStack(fmt.Errorf("[happened in OnRead] %s, last error=%s", panicErr, err.Error()), stack)
   149  			} else {
   150  				wrapErr = kerrors.ErrPanic.WithCauseAndStack(fmt.Errorf("[happened in OnRead] %s", panicErr), stack)
   151  			}
   152  		}
   153  		t.finishTracer(ctx, ri, err, panicErr)
   154  		t.finishProfiler(ctx)
   155  		remote.RecycleMessage(recvMsg)
   156  		remote.RecycleMessage(sendMsg)
   157  		// reset rpcinfo for reuse
   158  		if rpcinfo.PoolEnabled() {
   159  			t.opt.InitOrResetRPCInfoFunc(ri, conn.RemoteAddr())
   160  		}
   161  		if wrapErr != nil {
   162  			err = wrapErr
   163  		}
   164  		if err != nil && !closeConnOutsideIfErr {
   165  			err = nil
   166  		}
   167  	}()
   168  	ctx = t.startTracer(ctx, ri)
   169  	ctx = t.startProfiler(ctx)
   170  	recvMsg = remote.NewMessageWithNewer(t.targetSvcInfo, t.svcSearchMap, ri, remote.Call, remote.Server, t.opt.RefuseTrafficWithoutServiceName)
   171  	recvMsg.SetPayloadCodec(t.opt.PayloadCodec)
   172  	ctx, err = t.transPipe.Read(ctx, conn, recvMsg)
   173  	if err != nil {
   174  		t.writeErrorReplyIfNeeded(ctx, recvMsg, conn, err, ri, true)
   175  		t.OnError(ctx, err, conn)
   176  		return err
   177  	}
   178  
   179  	svcInfo := recvMsg.ServiceInfo()
   180  	// heartbeat processing
   181  	// recvMsg.MessageType would be set to remote.Heartbeat in previous Read procedure
   182  	// if specified codec support heartbeat
   183  	if recvMsg.MessageType() == remote.Heartbeat {
   184  		sendMsg = remote.NewMessage(nil, svcInfo, ri, remote.Heartbeat, remote.Server)
   185  	} else {
   186  		// reply processing
   187  		var methodInfo serviceinfo.MethodInfo
   188  		if methodInfo, err = GetMethodInfo(ri, svcInfo); err != nil {
   189  			// it won't be err, because the method has been checked in decode, err check here just do defensive inspection
   190  			t.writeErrorReplyIfNeeded(ctx, recvMsg, conn, err, ri, true)
   191  			// for proxy case, need read actual remoteAddr, error print must exec after writeErrorReplyIfNeeded
   192  			t.OnError(ctx, err, conn)
   193  			return err
   194  		}
   195  		if methodInfo.OneWay() {
   196  			sendMsg = remote.NewMessage(nil, svcInfo, ri, remote.Reply, remote.Server)
   197  		} else {
   198  			sendMsg = remote.NewMessage(methodInfo.NewResult(), svcInfo, ri, remote.Reply, remote.Server)
   199  		}
   200  
   201  		ctx, err = t.transPipe.OnMessage(ctx, recvMsg, sendMsg)
   202  		if err != nil {
   203  			// error cannot be wrapped to print here, so it must exec before NewTransError
   204  			t.OnError(ctx, err, conn)
   205  			err = remote.NewTransError(remote.InternalError, err)
   206  			if closeConn := t.writeErrorReplyIfNeeded(ctx, recvMsg, conn, err, ri, false); closeConn {
   207  				return err
   208  			}
   209  			// connection don't need to be closed when the error is return by the server handler
   210  			closeConnOutsideIfErr = false
   211  			return
   212  		}
   213  	}
   214  
   215  	remote.FillSendMsgFromRecvMsg(recvMsg, sendMsg)
   216  	if ctx, err = t.transPipe.Write(ctx, conn, sendMsg); err != nil {
   217  		t.OnError(ctx, err, conn)
   218  		return err
   219  	}
   220  	return
   221  }
   222  
   223  // OnMessage implements the remote.ServerTransHandler interface.
   224  // msg is the decoded instance, such as Arg and Result.
   225  func (t *svrTransHandler) OnMessage(ctx context.Context, args, result remote.Message) (context.Context, error) {
   226  	err := t.inkHdlFunc(ctx, args.Data(), result.Data())
   227  	return ctx, err
   228  }
   229  
   230  // OnActive implements the remote.ServerTransHandler interface.
   231  func (t *svrTransHandler) OnActive(ctx context.Context, conn net.Conn) (context.Context, error) {
   232  	// init rpcinfo
   233  	ri := t.opt.InitOrResetRPCInfoFunc(nil, conn.RemoteAddr())
   234  	return rpcinfo.NewCtxWithRPCInfo(ctx, ri), nil
   235  }
   236  
   237  // OnInactive implements the remote.ServerTransHandler interface.
   238  func (t *svrTransHandler) OnInactive(ctx context.Context, conn net.Conn) {
   239  	// recycle rpcinfo
   240  	rpcinfo.PutRPCInfo(rpcinfo.GetRPCInfo(ctx))
   241  }
   242  
   243  // OnError implements the remote.ServerTransHandler interface.
   244  func (t *svrTransHandler) OnError(ctx context.Context, err error, conn net.Conn) {
   245  	ri := rpcinfo.GetRPCInfo(ctx)
   246  	rService, rAddr := getRemoteInfo(ri, conn)
   247  	if t.ext.IsRemoteClosedErr(err) {
   248  		// it should not regard error which cause by remote connection closed as server error
   249  		if ri == nil {
   250  			return
   251  		}
   252  		remote := rpcinfo.AsMutableEndpointInfo(ri.From())
   253  		remote.SetTag(rpcinfo.RemoteClosedTag, "1")
   254  	} else {
   255  		var de *kerrors.DetailedError
   256  		if ok := errors.As(err, &de); ok && de.Stack() != "" {
   257  			klog.CtxErrorf(ctx, "KITEX: processing request error, remoteService=%s, remoteAddr=%v, error=%s\nstack=%s", rService, rAddr, err.Error(), de.Stack())
   258  		} else {
   259  			klog.CtxErrorf(ctx, "KITEX: processing request error, remoteService=%s, remoteAddr=%v, error=%s", rService, rAddr, err.Error())
   260  		}
   261  	}
   262  }
   263  
   264  // SetInvokeHandleFunc implements the remote.InvokeHandleFuncSetter interface.
   265  func (t *svrTransHandler) SetInvokeHandleFunc(inkHdlFunc endpoint.Endpoint) {
   266  	t.inkHdlFunc = inkHdlFunc
   267  }
   268  
   269  // SetPipeline implements the remote.ServerTransHandler interface.
   270  func (t *svrTransHandler) SetPipeline(p *remote.TransPipeline) {
   271  	t.transPipe = p
   272  }
   273  
   274  func (t *svrTransHandler) writeErrorReplyIfNeeded(
   275  	ctx context.Context, recvMsg remote.Message, conn net.Conn, err error, ri rpcinfo.RPCInfo, doOnMessage bool,
   276  ) (shouldCloseConn bool) {
   277  	if cn, ok := conn.(remote.IsActive); ok && !cn.IsActive() {
   278  		// conn is closed, no need reply
   279  		return
   280  	}
   281  	svcInfo := recvMsg.ServiceInfo()
   282  	if svcInfo != nil {
   283  		if methodInfo, _ := GetMethodInfo(ri, svcInfo); methodInfo != nil {
   284  			if methodInfo.OneWay() {
   285  				return
   286  			}
   287  		}
   288  	}
   289  
   290  	transErr, isTransErr := err.(*remote.TransError)
   291  	if !isTransErr {
   292  		return
   293  	}
   294  	errMsg := remote.NewMessage(transErr, svcInfo, ri, remote.Exception, remote.Server)
   295  	remote.FillSendMsgFromRecvMsg(recvMsg, errMsg)
   296  	if doOnMessage {
   297  		// if error happen before normal OnMessage, exec it to transfer header trans info into rpcinfo
   298  		t.transPipe.OnMessage(ctx, recvMsg, errMsg)
   299  	}
   300  	ctx, err = t.transPipe.Write(ctx, conn, errMsg)
   301  	if err != nil {
   302  		klog.CtxErrorf(ctx, "KITEX: write error reply failed, remote=%s, error=%s", conn.RemoteAddr(), err.Error())
   303  		return true
   304  	}
   305  	return
   306  }
   307  
   308  func (t *svrTransHandler) startTracer(ctx context.Context, ri rpcinfo.RPCInfo) context.Context {
   309  	c := t.opt.TracerCtl.DoStart(ctx, ri)
   310  	return c
   311  }
   312  
   313  func (t *svrTransHandler) finishTracer(ctx context.Context, ri rpcinfo.RPCInfo, err error, panicErr interface{}) {
   314  	rpcStats := rpcinfo.AsMutableRPCStats(ri.Stats())
   315  	if rpcStats == nil {
   316  		return
   317  	}
   318  	if panicErr != nil {
   319  		rpcStats.SetPanicked(panicErr)
   320  	}
   321  	if err != nil && t.ext.IsRemoteClosedErr(err) {
   322  		// it should not regard the error which caused by remote connection closed as server error
   323  		err = nil
   324  	}
   325  	t.opt.TracerCtl.DoFinish(ctx, ri, err)
   326  	// for server side, rpcinfo is reused on connection, clear the rpc stats info but keep the level config
   327  	sl := ri.Stats().Level()
   328  	rpcStats.Reset()
   329  	rpcStats.SetLevel(sl)
   330  }
   331  
   332  func (t *svrTransHandler) startProfiler(ctx context.Context) context.Context {
   333  	if t.opt.Profiler == nil {
   334  		return ctx
   335  	}
   336  	return t.opt.Profiler.Prepare(ctx)
   337  }
   338  
   339  func (t *svrTransHandler) finishProfiler(ctx context.Context) {
   340  	if t.opt.Profiler == nil {
   341  		return
   342  	}
   343  	t.opt.Profiler.Untag(ctx)
   344  }
   345  
   346  func getRemoteInfo(ri rpcinfo.RPCInfo, conn net.Conn) (string, net.Addr) {
   347  	rAddr := conn.RemoteAddr()
   348  	if ri == nil {
   349  		return "", rAddr
   350  	}
   351  	if rAddr != nil && rAddr.Network() == "unix" {
   352  		if ri.From().Address() != nil {
   353  			rAddr = ri.From().Address()
   354  		}
   355  	}
   356  	return ri.From().ServiceName(), rAddr
   357  }