github.com/openimsdk/tools@v0.0.49/mw/rpc_server_interceptor.go (about) 1 // Copyright © 2023 OpenIM. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package mw 16 17 import ( 18 "context" 19 "fmt" 20 "math" 21 "runtime" 22 "strings" 23 24 "github.com/openimsdk/tools/checker" 25 "github.com/pkg/errors" 26 27 "github.com/openimsdk/protocol/constant" 28 "github.com/openimsdk/protocol/errinfo" 29 "github.com/openimsdk/tools/errs" 30 "github.com/openimsdk/tools/log" 31 "github.com/openimsdk/tools/mw/specialerror" 32 "google.golang.org/grpc" 33 "google.golang.org/grpc/codes" 34 "google.golang.org/grpc/metadata" 35 "google.golang.org/grpc/status" 36 ) 37 38 func RpcServerInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { 39 funcName := info.FullMethod 40 md, err := validateMetadata(ctx) 41 if err != nil { 42 return nil, err 43 } 44 ctx, err = enrichContextWithMetadata(ctx, md) 45 if err != nil { 46 return nil, err 47 } 48 log.ZInfo(ctx, fmt.Sprintf("RPC Server Request - %s", extractFunctionName(funcName)), "funcName", funcName, "req", req) 49 if err := checker.Validate(req); err != nil { 50 return nil, err 51 } 52 53 resp, err := handler(ctx, req) 54 if err != nil { 55 return nil, handleError(ctx, funcName, req, err) 56 } 57 log.ZInfo(ctx, fmt.Sprintf("RPC Server Response Success - %s", extractFunctionName(funcName)), "funcName", funcName, "resp", resp) 58 return resp, nil 59 } 60 61 func validateMetadata(ctx context.Context) (metadata.MD, error) { 62 md, ok := metadata.FromIncomingContext(ctx) 63 if !ok { 64 return nil, status.New(codes.InvalidArgument, "missing metadata").Err() 65 } 66 if len(md.Get(constant.OperationID)) != 1 { 67 return nil, status.New(codes.InvalidArgument, "operationID error").Err() 68 } 69 return md, nil 70 } 71 72 func enrichContextWithMetadata(ctx context.Context, md metadata.MD) (context.Context, error) { 73 if keys := md.Get(constant.RpcCustomHeader); len(keys) > 0 { 74 ctx = context.WithValue(ctx, constant.RpcCustomHeader, keys) 75 for _, key := range keys { 76 values := md.Get(key) 77 if len(values) == 0 { 78 return nil, status.New(codes.InvalidArgument, fmt.Sprintf("missing metadata key %s", key)).Err() 79 } 80 ctx = context.WithValue(ctx, key, values) 81 } 82 } 83 ctx = context.WithValue(ctx, constant.OperationID, md.Get(constant.OperationID)[0]) 84 if opts := md.Get(constant.OpUserID); len(opts) == 1 { 85 ctx = context.WithValue(ctx, constant.OpUserID, opts[0]) 86 } 87 if opts := md.Get(constant.OpUserPlatform); len(opts) == 1 { 88 ctx = context.WithValue(ctx, constant.OpUserPlatform, opts[0]) 89 } 90 if opts := md.Get(constant.ConnID); len(opts) == 1 { 91 ctx = context.WithValue(ctx, constant.ConnID, opts[0]) 92 } 93 return ctx, nil 94 } 95 96 func handleError(ctx context.Context, funcName string, req any, err error) error { 97 log.ZWarn(ctx, "rpc server resp WithDetails error", formatError(err), "funcName", funcName) 98 unwrap := errs.Unwrap(err) 99 codeErr := specialerror.ErrCode(unwrap) 100 if codeErr == nil { 101 log.ZError(ctx, "rpc InternalServer error", formatError(err), "funcName", funcName, "req", req) 102 codeErr = errs.ErrInternalServer 103 } 104 code := codeErr.Code() 105 if code <= 0 || int64(code) > int64(math.MaxUint32) { 106 log.ZError(ctx, "rpc UnknownError", formatError(err), "funcName", funcName, "rpc UnknownCode:", int64(code)) 107 code = errs.ServerInternalError 108 } 109 grpcStatus := status.New(codes.Code(code), err.Error()) 110 errInfo := &errinfo.ErrorInfo{Cause: err.Error()} 111 details, err := grpcStatus.WithDetails(errInfo) 112 if err != nil { 113 log.ZWarn(ctx, "rpc server resp WithDetails error", formatError(err), "funcName", funcName) 114 return errs.WrapMsg(err, "rpc server resp WithDetails error", "err", err) 115 } 116 log.ZWarn(ctx, fmt.Sprintf("RPC Server Response Error - %s", extractFunctionName(funcName)), formatError(details.Err()), "funcName", funcName, "req", req, "err", err) 117 return details.Err() 118 } 119 120 func GrpcServer() grpc.ServerOption { 121 return grpc.ChainUnaryInterceptor(RpcServerInterceptor) 122 } 123 func formatError(err error) error { 124 type stackTracer interface { 125 StackTrace() errors.StackTrace 126 } 127 if e, ok := err.(stackTracer); ok { 128 st := e.StackTrace() 129 var sb strings.Builder 130 sb.WriteString("Error: ") 131 sb.WriteString(err.Error()) 132 sb.WriteString(" | Error trace: ") 133 134 var callPath []string 135 for _, f := range st { 136 pc := uintptr(f) - 1 137 fn := runtime.FuncForPC(pc) 138 if fn == nil { 139 continue 140 } 141 if strings.Contains(fn.Name(), "runtime.") { 142 continue 143 } 144 file, line := fn.FileLine(pc) 145 funcName := simplifyFuncName(fn.Name()) 146 callPath = append(callPath, fmt.Sprintf("%s (%s:%d)", funcName, file, line)) 147 } 148 for i := len(callPath) - 1; i >= 0; i-- { 149 if i != len(callPath)-1 { 150 sb.WriteString(" -> ") 151 } 152 sb.WriteString(callPath[i]) 153 } 154 return errors.New(sb.String()) 155 } 156 return err 157 } 158 func simplifyFuncName(fullFuncName string) string { 159 parts := strings.Split(fullFuncName, "/") 160 lastPart := parts[len(parts)-1] 161 parts = strings.Split(lastPart, ".") 162 if len(parts) > 1 { 163 return parts[len(parts)-1] 164 } 165 return lastPart 166 }