github.com/cloudwego/kitex@v0.9.0/pkg/remote/trans/nphttp2/server_handler.go (about) 1 /* 2 * Copyright 2021 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package nphttp2 18 19 import ( 20 "bytes" 21 "context" 22 "errors" 23 "fmt" 24 "net" 25 "runtime/debug" 26 "strings" 27 "sync" 28 "time" 29 30 "github.com/cloudwego/netpoll" 31 32 "github.com/cloudwego/kitex/pkg/endpoint" 33 "github.com/cloudwego/kitex/pkg/gofunc" 34 "github.com/cloudwego/kitex/pkg/kerrors" 35 "github.com/cloudwego/kitex/pkg/klog" 36 "github.com/cloudwego/kitex/pkg/remote" 37 "github.com/cloudwego/kitex/pkg/remote/codec" 38 "github.com/cloudwego/kitex/pkg/remote/codec/grpc" 39 "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/codes" 40 grpcTransport "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/grpc" 41 "github.com/cloudwego/kitex/pkg/remote/trans/nphttp2/status" 42 "github.com/cloudwego/kitex/pkg/rpcinfo" 43 "github.com/cloudwego/kitex/pkg/serviceinfo" 44 "github.com/cloudwego/kitex/pkg/stats" 45 "github.com/cloudwego/kitex/pkg/streaming" 46 "github.com/cloudwego/kitex/transport" 47 ) 48 49 type svrTransHandlerFactory struct{} 50 51 // NewSvrTransHandlerFactory returns a remote.ServerTransHandlerFactory whose handlers serve gRPC over HTTP/2 (they are built by newSvrTransHandler with the gRPC codec).
52 func NewSvrTransHandlerFactory() remote.ServerTransHandlerFactory { 53 return &svrTransHandlerFactory{} 54 } 55 56 func (f *svrTransHandlerFactory) NewTransHandler(opt *remote.ServerOption) (remote.ServerTransHandler, error) { 57 return newSvrTransHandler(opt) 58 } 59 60 func newSvrTransHandler(opt *remote.ServerOption) (*svrTransHandler, error) { 61 return &svrTransHandler{ 62 opt: opt, 63 svcSearchMap: opt.SvcSearchMap, 64 codec: grpc.NewGRPCCodec(grpc.WithThriftCodec(opt.PayloadCodec)), 65 }, nil 66 } 67 68 var _ remote.ServerTransHandler = &svrTransHandler{} 69 70 type svrTransHandler struct { 71 opt *remote.ServerOption 72 svcSearchMap map[string]*serviceinfo.ServiceInfo 73 inkHdlFunc endpoint.Endpoint 74 codec remote.Codec 75 } 76 77 var prefaceReadAtMost = func() int { 78 // min(len(ClientPreface), len(flagBuf)) 79 // len(flagBuf) = 2 * codec.Size32 80 if 2*codec.Size32 < grpcTransport.ClientPrefaceLen { 81 return 2 * codec.Size32 82 } 83 return grpcTransport.ClientPrefaceLen 84 }() 85 86 func (t *svrTransHandler) ProtocolMatch(ctx context.Context, conn net.Conn) (err error) { 87 // Check the validity of client preface. 
88 npReader := conn.(interface{ Reader() netpoll.Reader }).Reader() 89 // read at most avoid block 90 preface, err := npReader.Peek(prefaceReadAtMost) 91 if err != nil { 92 return err 93 } 94 if bytes.Equal(preface[:prefaceReadAtMost], grpcTransport.ClientPreface[:prefaceReadAtMost]) { 95 return nil 96 } 97 return errors.New("error protocol not match") 98 } 99 100 func (t *svrTransHandler) Write(ctx context.Context, conn net.Conn, msg remote.Message) (nctx context.Context, err error) { 101 buf := newBuffer(conn.(*serverConn)) 102 defer buf.Release(err) 103 104 if err = t.codec.Encode(ctx, msg, buf); err != nil { 105 return ctx, err 106 } 107 return ctx, buf.Flush() 108 } 109 110 func (t *svrTransHandler) Read(ctx context.Context, conn net.Conn, msg remote.Message) (nctx context.Context, err error) { 111 buf := newBuffer(conn.(*serverConn)) 112 defer buf.Release(err) 113 114 err = t.codec.Decode(ctx, msg, buf) 115 return ctx, err 116 } 117 118 // 只 return write err 119 func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) error { 120 svrTrans := ctx.Value(ctxKeySvrTransport).(*SvrTrans) 121 tr := svrTrans.tr 122 123 tr.HandleStreams(func(s *grpcTransport.Stream) { 124 gofunc.GoFunc(ctx, func() { 125 ri := svrTrans.pool.Get().(rpcinfo.RPCInfo) 126 rCtx := rpcinfo.NewCtxWithRPCInfo(s.Context(), ri) 127 defer func() { 128 // reset rpcinfo for performance (PR #584) 129 if rpcinfo.PoolEnabled() { 130 ri = t.opt.InitOrResetRPCInfoFunc(ri, conn.RemoteAddr()) 131 svrTrans.pool.Put(ri) 132 } 133 }() 134 135 // set grpc transport flag before execute metahandler 136 rpcinfo.AsMutableRPCConfig(ri.Config()).SetTransportProtocol(transport.GRPC) 137 var err error 138 for _, shdlr := range t.opt.StreamingMetaHandlers { 139 rCtx, err = shdlr.OnReadStream(rCtx) 140 if err != nil { 141 tr.WriteStatus(s, convertStatus(err)) 142 return 143 } 144 } 145 rCtx = t.startTracer(rCtx, ri) 146 defer func() { 147 panicErr := recover() 148 if panicErr != nil { 149 if conn != nil { 
150 klog.CtxErrorf(rCtx, "KITEX: gRPC panic happened, close conn, remoteAddress=%s, error=%s\nstack=%s", conn.RemoteAddr(), panicErr, string(debug.Stack())) 151 } else { 152 klog.CtxErrorf(rCtx, "KITEX: gRPC panic happened, error=%v\nstack=%s", panicErr, string(debug.Stack())) 153 } 154 } 155 t.finishTracer(rCtx, ri, err, panicErr) 156 }() 157 158 ink := ri.Invocation().(rpcinfo.InvocationSetter) 159 sm := s.Method() 160 if sm != "" && sm[0] == '/' { 161 sm = sm[1:] 162 } 163 pos := strings.LastIndex(sm, "/") 164 if pos == -1 { 165 errDesc := fmt.Sprintf("malformed method name, method=%q", s.Method()) 166 tr.WriteStatus(s, status.New(codes.Internal, errDesc)) 167 return 168 } 169 methodName := sm[pos+1:] 170 ink.SetMethodName(methodName) 171 172 if mutableTo := rpcinfo.AsMutableEndpointInfo(ri.To()); mutableTo != nil { 173 if err = mutableTo.SetMethod(methodName); err != nil { 174 errDesc := fmt.Sprintf("setMethod failed in streaming, method=%s, error=%s", methodName, err.Error()) 175 _ = tr.WriteStatus(s, status.New(codes.Internal, errDesc)) 176 return 177 } 178 } 179 180 var serviceName string 181 idx := strings.LastIndex(sm[:pos], ".") 182 if idx == -1 { 183 ink.SetPackageName("") 184 serviceName = sm[0:pos] 185 } else { 186 ink.SetPackageName(sm[:idx]) 187 serviceName = sm[idx+1 : pos] 188 } 189 ink.SetServiceName(serviceName) 190 191 // set recv grpc compressor at server to decode the pack from client 192 remote.SetRecvCompressor(ri, s.RecvCompress()) 193 // set send grpc compressor at server to encode reply pack 194 remote.SetSendCompressor(ri, s.SendCompress()) 195 196 svcInfo := t.svcSearchMap[remote.BuildMultiServiceKey(serviceName, methodName)] 197 var methodInfo serviceinfo.MethodInfo 198 if svcInfo != nil { 199 methodInfo = svcInfo.MethodInfo(methodName) 200 } 201 202 rawStream := NewStream(rCtx, svcInfo, newServerConn(tr, s), t) 203 st := newStreamWithMiddleware(rawStream, t.opt.RecvEndpoint, t.opt.SendEndpoint) 204 205 // bind stream into ctx, in 
order to let user set header and trailer by provided api in meta_api.go 206 rCtx = streaming.NewCtxWithStream(rCtx, st) 207 208 if methodInfo == nil { 209 unknownServiceHandlerFunc := t.opt.GRPCUnknownServiceHandler 210 if unknownServiceHandlerFunc != nil { 211 rpcinfo.Record(rCtx, ri, stats.ServerHandleStart, nil) 212 err = unknownServiceHandlerFunc(rCtx, methodName, st) 213 if err != nil { 214 err = kerrors.ErrBiz.WithCause(err) 215 } 216 } else { 217 if svcInfo == nil { 218 err = remote.NewTransErrorWithMsg(remote.UnknownService, fmt.Sprintf("unknown service %s", serviceName)) 219 } else { 220 err = remote.NewTransErrorWithMsg(remote.UnknownMethod, fmt.Sprintf("unknown method %s", methodName)) 221 } 222 } 223 } else { 224 if streaming.UnaryCompatibleMiddleware(methodInfo.StreamingMode(), t.opt.CompatibleMiddlewareForUnary) { 225 // making streaming unary APIs capable of using the same server middleware as non-streaming APIs 226 // note: rawStream skips recv/send middleware for unary API requests to avoid confusion 227 err = invokeStreamUnaryHandler(rCtx, rawStream, methodInfo, t.inkHdlFunc, ri) 228 } else { 229 err = t.inkHdlFunc(rCtx, &streaming.Args{Stream: st}, nil) 230 } 231 } 232 233 if err != nil { 234 tr.WriteStatus(s, convertStatus(err)) 235 t.OnError(rCtx, err, conn) 236 return 237 } 238 if bizStatusErr := ri.Invocation().BizStatusErr(); bizStatusErr != nil { 239 var st *status.Status 240 if sterr, ok := bizStatusErr.(status.Iface); ok { 241 st = sterr.GRPCStatus() 242 } else { 243 st = status.New(codes.Internal, bizStatusErr.BizMessage()) 244 } 245 s.SetBizStatusErr(bizStatusErr) 246 tr.WriteStatus(s, st) 247 return 248 } 249 tr.WriteStatus(s, status.New(codes.OK, "")) 250 }) 251 }, func(ctx context.Context, method string) context.Context { 252 return ctx 253 }) 254 return nil 255 } 256 257 // invokeStreamUnaryHandler allows unary APIs over HTTP2 to use the same server middleware as non-streaming APIs. 
258 // For thrift unary APIs over HTTP2, it's enabled by default. 259 // For grpc(protobuf) unary APIs, it's disabled by default to keep backward compatibility. 260 func invokeStreamUnaryHandler(ctx context.Context, st streaming.Stream, mi serviceinfo.MethodInfo, 261 handler endpoint.Endpoint, ri rpcinfo.RPCInfo, 262 ) (err error) { 263 realArgs, realResp := mi.NewArgs(), mi.NewResult() 264 if err = st.RecvMsg(realArgs); err != nil { 265 return err 266 } 267 if err = handler(ctx, realArgs, realResp); err != nil { 268 return err 269 } 270 if ri != nil && ri.Invocation().BizStatusErr() != nil { 271 // BizError: do not send the message 272 return nil 273 } 274 return st.SendMsg(realResp) 275 } 276 277 // msg 是解码后的实例,如 Arg 或 Result, 触发上层处理,用于异步 和 服务端处理 278 func (t *svrTransHandler) OnMessage(ctx context.Context, args, result remote.Message) (context.Context, error) { 279 panic("unimplemented") 280 } 281 282 type svrTransKey int 283 284 const ctxKeySvrTransport svrTransKey = 1 285 286 type SvrTrans struct { 287 tr grpcTransport.ServerTransport 288 pool *sync.Pool // value is rpcInfo 289 } 290 291 // 新连接建立时触发,主要用于服务端,对应 netpoll onPrepare 292 func (t *svrTransHandler) OnActive(ctx context.Context, conn net.Conn) (context.Context, error) { 293 // set readTimeout to infinity to avoid streaming break 294 // use keepalive to check the health of connection 295 if npConn, ok := conn.(netpoll.Connection); ok { 296 npConn.SetReadTimeout(grpcTransport.Infinity) 297 } else { 298 conn.SetReadDeadline(time.Now().Add(grpcTransport.Infinity)) 299 } 300 301 tr, err := grpcTransport.NewServerTransport(ctx, conn, t.opt.GRPCCfg) 302 if err != nil { 303 return nil, err 304 } 305 pool := &sync.Pool{ 306 New: func() interface{} { 307 // init rpcinfo 308 ri := t.opt.InitOrResetRPCInfoFunc(nil, conn.RemoteAddr()) 309 return ri 310 }, 311 } 312 ctx = context.WithValue(ctx, ctxKeySvrTransport, &SvrTrans{tr: tr, pool: pool}) 313 return ctx, nil 314 } 315 316 // 连接关闭时回调 317 func (t 
*svrTransHandler) OnInactive(ctx context.Context, conn net.Conn) { 318 tr := ctx.Value(ctxKeySvrTransport).(*SvrTrans).tr 319 tr.Close() 320 } 321 322 // 传输层 error 回调 323 func (t *svrTransHandler) OnError(ctx context.Context, err error, conn net.Conn) { 324 var de *kerrors.DetailedError 325 if ok := errors.As(err, &de); ok && de.Stack() != "" { 326 klog.Errorf("KITEX: processing gRPC request error, remoteAddr=%s, error=%s\nstack=%s", conn.RemoteAddr(), err.Error(), de.Stack()) 327 } else { 328 klog.Errorf("KITEX: processing gRPC request error, remoteAddr=%s, error=%s", conn.RemoteAddr(), err.Error()) 329 } 330 } 331 332 func (t *svrTransHandler) SetInvokeHandleFunc(inkHdlFunc endpoint.Endpoint) { 333 t.inkHdlFunc = inkHdlFunc 334 } 335 336 func (t *svrTransHandler) SetPipeline(p *remote.TransPipeline) { 337 } 338 339 func (t *svrTransHandler) startTracer(ctx context.Context, ri rpcinfo.RPCInfo) context.Context { 340 c := t.opt.TracerCtl.DoStart(ctx, ri) 341 return c 342 } 343 344 func (t *svrTransHandler) finishTracer(ctx context.Context, ri rpcinfo.RPCInfo, err error, panicErr interface{}) { 345 rpcStats := rpcinfo.AsMutableRPCStats(ri.Stats()) 346 if rpcStats == nil { 347 return 348 } 349 if panicErr != nil { 350 rpcStats.SetPanicked(panicErr) 351 } 352 t.opt.TracerCtl.DoFinish(ctx, ri, err) 353 rpcStats.Reset() 354 }