github.com/cloudwego/kitex@v0.9.0/pkg/remote/trans/nphttp2/server_handler.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package nphttp2
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"errors"
    23  	"fmt"
    24  	"net"
    25  	"runtime/debug"
    26  	"strings"
    27  	"sync"
    28  	"time"
    29  
    30  	"github.com/cloudwego/netpoll"
    31  
    32  	"github.com/cloudwego/kitex/pkg/endpoint"
    33  	"github.com/cloudwego/kitex/pkg/gofunc"
    34  	"github.com/cloudwego/kitex/pkg/kerrors"
    35  	"github.com/cloudwego/kitex/pkg/klog"
    36  	"github.com/cloudwego/kitex/pkg/remote"
    37  	"github.com/cloudwego/kitex/pkg/remote/codec"
    38  	"github.com/cloudwego/kitex/pkg/remote/codec/grpc"
    39  	"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes"
    40  	grpcTransport "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc"
    41  	"github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/status"
    42  	"github.com/cloudwego/kitex/pkg/rpcinfo"
    43  	"github.com/cloudwego/kitex/pkg/serviceinfo"
    44  	"github.com/cloudwego/kitex/pkg/stats"
    45  	"github.com/cloudwego/kitex/pkg/streaming"
    46  	"github.com/cloudwego/kitex/transport"
    47  )
    48  
// svrTransHandlerFactory is the stateless remote.ServerTransHandlerFactory
// implementation handed out by NewSvrTransHandlerFactory; its NewTransHandler
// method builds a svrTransHandler for the given server options.
type svrTransHandlerFactory struct{}
    50  
    51  // NewSvrTransHandlerFactory ...
    52  func NewSvrTransHandlerFactory() remote.ServerTransHandlerFactory {
    53  	return &svrTransHandlerFactory{}
    54  }
    55  
    56  func (f *svrTransHandlerFactory) NewTransHandler(opt *remote.ServerOption) (remote.ServerTransHandler, error) {
    57  	return newSvrTransHandler(opt)
    58  }
    59  
    60  func newSvrTransHandler(opt *remote.ServerOption) (*svrTransHandler, error) {
    61  	return &svrTransHandler{
    62  		opt:          opt,
    63  		svcSearchMap: opt.SvcSearchMap,
    64  		codec:        grpc.NewGRPCCodec(grpc.WithThriftCodec(opt.PayloadCodec)),
    65  	}, nil
    66  }
    67  
    68  var _ remote.ServerTransHandler = &svrTransHandler{}
    69  
// svrTransHandler is the server-side trans handler for gRPC traffic over
// HTTP/2 in this package.
type svrTransHandler struct {
	// opt holds the server options the handler was created with.
	opt *remote.ServerOption
	// svcSearchMap is copied from opt.SvcSearchMap; OnRead looks services up
	// in it with the key built by remote.BuildMultiServiceKey.
	svcSearchMap map[string]*serviceinfo.ServiceInfo
	// inkHdlFunc is the endpoint invoked for each stream; it is installed by
	// SetInvokeHandleFunc.
	inkHdlFunc endpoint.Endpoint
	// codec encodes and decodes messages in Write and Read.
	codec remote.Codec
}
    76  
    77  var prefaceReadAtMost = func() int {
    78  	// min(len(ClientPreface), len(flagBuf))
    79  	// len(flagBuf) = 2 * codec.Size32
    80  	if 2*codec.Size32 < grpcTransport.ClientPrefaceLen {
    81  		return 2 * codec.Size32
    82  	}
    83  	return grpcTransport.ClientPrefaceLen
    84  }()
    85  
    86  func (t *svrTransHandler) ProtocolMatch(ctx context.Context, conn net.Conn) (err error) {
    87  	// Check the validity of client preface.
    88  	npReader := conn.(interface{ Reader() netpoll.Reader }).Reader()
    89  	// read at most avoid block
    90  	preface, err := npReader.Peek(prefaceReadAtMost)
    91  	if err != nil {
    92  		return err
    93  	}
    94  	if bytes.Equal(preface[:prefaceReadAtMost], grpcTransport.ClientPreface[:prefaceReadAtMost]) {
    95  		return nil
    96  	}
    97  	return errors.New("error protocol not match")
    98  }
    99  
   100  func (t *svrTransHandler) Write(ctx context.Context, conn net.Conn, msg remote.Message) (nctx context.Context, err error) {
   101  	buf := newBuffer(conn.(*serverConn))
   102  	defer buf.Release(err)
   103  
   104  	if err = t.codec.Encode(ctx, msg, buf); err != nil {
   105  		return ctx, err
   106  	}
   107  	return ctx, buf.Flush()
   108  }
   109  
   110  func (t *svrTransHandler) Read(ctx context.Context, conn net.Conn, msg remote.Message) (nctx context.Context, err error) {
   111  	buf := newBuffer(conn.(*serverConn))
   112  	defer buf.Release(err)
   113  
   114  	err = t.codec.Decode(ctx, msg, buf)
   115  	return ctx, err
   116  }
   117  
   118  // 只 return write err
   119  func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) error {
   120  	svrTrans := ctx.Value(ctxKeySvrTransport).(*SvrTrans)
   121  	tr := svrTrans.tr
   122  
   123  	tr.HandleStreams(func(s *grpcTransport.Stream) {
   124  		gofunc.GoFunc(ctx, func() {
   125  			ri := svrTrans.pool.Get().(rpcinfo.RPCInfo)
   126  			rCtx := rpcinfo.NewCtxWithRPCInfo(s.Context(), ri)
   127  			defer func() {
   128  				// reset rpcinfo for performance (PR #584)
   129  				if rpcinfo.PoolEnabled() {
   130  					ri = t.opt.InitOrResetRPCInfoFunc(ri, conn.RemoteAddr())
   131  					svrTrans.pool.Put(ri)
   132  				}
   133  			}()
   134  
   135  			// set grpc transport flag before execute metahandler
   136  			rpcinfo.AsMutableRPCConfig(ri.Config()).SetTransportProtocol(transport.GRPC)
   137  			var err error
   138  			for _, shdlr := range t.opt.StreamingMetaHandlers {
   139  				rCtx, err = shdlr.OnReadStream(rCtx)
   140  				if err != nil {
   141  					tr.WriteStatus(s, convertStatus(err))
   142  					return
   143  				}
   144  			}
   145  			rCtx = t.startTracer(rCtx, ri)
   146  			defer func() {
   147  				panicErr := recover()
   148  				if panicErr != nil {
   149  					if conn != nil {
   150  						klog.CtxErrorf(rCtx, "KITEX: gRPC panic happened, close conn, remoteAddress=%s, error=%s\nstack=%s", conn.RemoteAddr(), panicErr, string(debug.Stack()))
   151  					} else {
   152  						klog.CtxErrorf(rCtx, "KITEX: gRPC panic happened, error=%v\nstack=%s", panicErr, string(debug.Stack()))
   153  					}
   154  				}
   155  				t.finishTracer(rCtx, ri, err, panicErr)
   156  			}()
   157  
   158  			ink := ri.Invocation().(rpcinfo.InvocationSetter)
   159  			sm := s.Method()
   160  			if sm != "" && sm[0] == '/' {
   161  				sm = sm[1:]
   162  			}
   163  			pos := strings.LastIndex(sm, "/")
   164  			if pos == -1 {
   165  				errDesc := fmt.Sprintf("malformed method name, method=%q", s.Method())
   166  				tr.WriteStatus(s, status.New(codes.Internal, errDesc))
   167  				return
   168  			}
   169  			methodName := sm[pos+1:]
   170  			ink.SetMethodName(methodName)
   171  
   172  			if mutableTo := rpcinfo.AsMutableEndpointInfo(ri.To()); mutableTo != nil {
   173  				if err = mutableTo.SetMethod(methodName); err != nil {
   174  					errDesc := fmt.Sprintf("setMethod failed in streaming, method=%s, error=%s", methodName, err.Error())
   175  					_ = tr.WriteStatus(s, status.New(codes.Internal, errDesc))
   176  					return
   177  				}
   178  			}
   179  
   180  			var serviceName string
   181  			idx := strings.LastIndex(sm[:pos], ".")
   182  			if idx == -1 {
   183  				ink.SetPackageName("")
   184  				serviceName = sm[0:pos]
   185  			} else {
   186  				ink.SetPackageName(sm[:idx])
   187  				serviceName = sm[idx+1 : pos]
   188  			}
   189  			ink.SetServiceName(serviceName)
   190  
   191  			// set recv grpc compressor at server to decode the pack from client
   192  			remote.SetRecvCompressor(ri, s.RecvCompress())
   193  			// set send grpc compressor at server to encode reply pack
   194  			remote.SetSendCompressor(ri, s.SendCompress())
   195  
   196  			svcInfo := t.svcSearchMap[remote.BuildMultiServiceKey(serviceName, methodName)]
   197  			var methodInfo serviceinfo.MethodInfo
   198  			if svcInfo != nil {
   199  				methodInfo = svcInfo.MethodInfo(methodName)
   200  			}
   201  
   202  			rawStream := NewStream(rCtx, svcInfo, newServerConn(tr, s), t)
   203  			st := newStreamWithMiddleware(rawStream, t.opt.RecvEndpoint, t.opt.SendEndpoint)
   204  
   205  			// bind stream into ctx, in order to let user set header and trailer by provided api in meta_api.go
   206  			rCtx = streaming.NewCtxWithStream(rCtx, st)
   207  
   208  			if methodInfo == nil {
   209  				unknownServiceHandlerFunc := t.opt.GRPCUnknownServiceHandler
   210  				if unknownServiceHandlerFunc != nil {
   211  					rpcinfo.Record(rCtx, ri, stats.ServerHandleStart, nil)
   212  					err = unknownServiceHandlerFunc(rCtx, methodName, st)
   213  					if err != nil {
   214  						err = kerrors.ErrBiz.WithCause(err)
   215  					}
   216  				} else {
   217  					if svcInfo == nil {
   218  						err = remote.NewTransErrorWithMsg(remote.UnknownService, fmt.Sprintf("unknown service %s", serviceName))
   219  					} else {
   220  						err = remote.NewTransErrorWithMsg(remote.UnknownMethod, fmt.Sprintf("unknown method %s", methodName))
   221  					}
   222  				}
   223  			} else {
   224  				if streaming.UnaryCompatibleMiddleware(methodInfo.StreamingMode(), t.opt.CompatibleMiddlewareForUnary) {
   225  					// making streaming unary APIs capable of using the same server middleware as non-streaming APIs
   226  					// note: rawStream skips recv/send middleware for unary API requests to avoid confusion
   227  					err = invokeStreamUnaryHandler(rCtx, rawStream, methodInfo, t.inkHdlFunc, ri)
   228  				} else {
   229  					err = t.inkHdlFunc(rCtx, &streaming.Args{Stream: st}, nil)
   230  				}
   231  			}
   232  
   233  			if err != nil {
   234  				tr.WriteStatus(s, convertStatus(err))
   235  				t.OnError(rCtx, err, conn)
   236  				return
   237  			}
   238  			if bizStatusErr := ri.Invocation().BizStatusErr(); bizStatusErr != nil {
   239  				var st *status.Status
   240  				if sterr, ok := bizStatusErr.(status.Iface); ok {
   241  					st = sterr.GRPCStatus()
   242  				} else {
   243  					st = status.New(codes.Internal, bizStatusErr.BizMessage())
   244  				}
   245  				s.SetBizStatusErr(bizStatusErr)
   246  				tr.WriteStatus(s, st)
   247  				return
   248  			}
   249  			tr.WriteStatus(s, status.New(codes.OK, ""))
   250  		})
   251  	}, func(ctx context.Context, method string) context.Context {
   252  		return ctx
   253  	})
   254  	return nil
   255  }
   256  
   257  // invokeStreamUnaryHandler allows unary APIs over HTTP2 to use the same server middleware as non-streaming APIs.
   258  // For thrift unary APIs over HTTP2, it's enabled by default.
   259  // For grpc(protobuf) unary APIs, it's disabled by default to keep backward compatibility.
   260  func invokeStreamUnaryHandler(ctx context.Context, st streaming.Stream, mi serviceinfo.MethodInfo,
   261  	handler endpoint.Endpoint, ri rpcinfo.RPCInfo,
   262  ) (err error) {
   263  	realArgs, realResp := mi.NewArgs(), mi.NewResult()
   264  	if err = st.RecvMsg(realArgs); err != nil {
   265  		return err
   266  	}
   267  	if err = handler(ctx, realArgs, realResp); err != nil {
   268  		return err
   269  	}
   270  	if ri != nil && ri.Invocation().BizStatusErr() != nil {
   271  		// BizError: do not send the message
   272  		return nil
   273  	}
   274  	return st.SendMsg(realResp)
   275  }
   276  
// OnMessage is meant to receive the decoded message instances (e.g. Arg or
// Result) and trigger upper-layer processing, for async and server-side
// handling. This handler does not implement it: calling OnMessage always
// panics with "unimplemented".
func (t *svrTransHandler) OnMessage(ctx context.Context, args, result remote.Message) (context.Context, error) {
	panic("unimplemented")
}
   281  
// svrTransKey is the private context-key type used by this file, so the key
// cannot collide with context keys defined by other packages.
type svrTransKey int

// ctxKeySvrTransport is the context key under which OnActive stores the
// connection's *SvrTrans; OnRead and OnInactive read it back.
const ctxKeySvrTransport svrTransKey = 1
   285  
// SvrTrans groups the state OnActive creates for one connection and stores in
// the context under ctxKeySvrTransport.
type SvrTrans struct {
	// tr is the gRPC server transport created for the connection.
	tr grpcTransport.ServerTransport
	// pool hands out rpcinfo.RPCInfo values for the streams of the connection.
	pool *sync.Pool // value is rpcInfo
}
   290  
   291  // 新连接建立时触发,主要用于服务端,对应 netpoll onPrepare
   292  func (t *svrTransHandler) OnActive(ctx context.Context, conn net.Conn) (context.Context, error) {
   293  	// set readTimeout to infinity to avoid streaming break
   294  	// use keepalive to check the health of connection
   295  	if npConn, ok := conn.(netpoll.Connection); ok {
   296  		npConn.SetReadTimeout(grpcTransport.Infinity)
   297  	} else {
   298  		conn.SetReadDeadline(time.Now().Add(grpcTransport.Infinity))
   299  	}
   300  
   301  	tr, err := grpcTransport.NewServerTransport(ctx, conn, t.opt.GRPCCfg)
   302  	if err != nil {
   303  		return nil, err
   304  	}
   305  	pool := &sync.Pool{
   306  		New: func() interface{} {
   307  			// init rpcinfo
   308  			ri := t.opt.InitOrResetRPCInfoFunc(nil, conn.RemoteAddr())
   309  			return ri
   310  		},
   311  	}
   312  	ctx = context.WithValue(ctx, ctxKeySvrTransport, &SvrTrans{tr: tr, pool: pool})
   313  	return ctx, nil
   314  }
   315  
   316  // 连接关闭时回调
   317  func (t *svrTransHandler) OnInactive(ctx context.Context, conn net.Conn) {
   318  	tr := ctx.Value(ctxKeySvrTransport).(*SvrTrans).tr
   319  	tr.Close()
   320  }
   321  
   322  // 传输层 error 回调
   323  func (t *svrTransHandler) OnError(ctx context.Context, err error, conn net.Conn) {
   324  	var de *kerrors.DetailedError
   325  	if ok := errors.As(err, &de); ok && de.Stack() != "" {
   326  		klog.Errorf("KITEX: processing gRPC request error, remoteAddr=%s, error=%s\nstack=%s", conn.RemoteAddr(), err.Error(), de.Stack())
   327  	} else {
   328  		klog.Errorf("KITEX: processing gRPC request error, remoteAddr=%s, error=%s", conn.RemoteAddr(), err.Error())
   329  	}
   330  }
   331  
// SetInvokeHandleFunc installs the endpoint that OnRead invokes for each
// incoming stream.
func (t *svrTransHandler) SetInvokeHandleFunc(inkHdlFunc endpoint.Endpoint) {
	t.inkHdlFunc = inkHdlFunc
}
   335  
// SetPipeline is a no-op: this handler does not keep a reference to the
// trans pipeline.
func (t *svrTransHandler) SetPipeline(p *remote.TransPipeline) {
}
   338  
   339  func (t *svrTransHandler) startTracer(ctx context.Context, ri rpcinfo.RPCInfo) context.Context {
   340  	c := t.opt.TracerCtl.DoStart(ctx, ri)
   341  	return c
   342  }
   343  
   344  func (t *svrTransHandler) finishTracer(ctx context.Context, ri rpcinfo.RPCInfo, err error, panicErr interface{}) {
   345  	rpcStats := rpcinfo.AsMutableRPCStats(ri.Stats())
   346  	if rpcStats == nil {
   347  		return
   348  	}
   349  	if panicErr != nil {
   350  		rpcStats.SetPanicked(panicErr)
   351  	}
   352  	t.opt.TracerCtl.DoFinish(ctx, ri, err)
   353  	rpcStats.Reset()
   354  }