github.com/openimsdk/tools@v0.0.49/mw/rpc_client_interceptor.go (about) 1 // Copyright © 2023 OpenIM. All rights reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package mw 16 17 import ( 18 "context" 19 "fmt" 20 "strings" 21 22 "github.com/openimsdk/protocol/constant" 23 "github.com/openimsdk/protocol/errinfo" 24 "github.com/openimsdk/tools/errs" 25 "github.com/openimsdk/tools/log" 26 "google.golang.org/grpc" 27 "google.golang.org/grpc/metadata" 28 "google.golang.org/grpc/status" 29 ) 30 31 func GrpcClient() grpc.DialOption { 32 return grpc.WithChainUnaryInterceptor(RpcClientInterceptor) 33 } 34 35 func RpcClientInterceptor(ctx context.Context, method string, req, resp any, cc *grpc.ClientConn, 36 invoker grpc.UnaryInvoker, opts ...grpc.CallOption) (err error) { 37 if ctx == nil { 38 return errs.ErrInternalServer.WrapMsg("call rpc request context is nil") 39 } 40 ctx, err = getRpcContext(ctx, method) 41 if err != nil { 42 return err 43 } 44 log.ZDebug(ctx, fmt.Sprintf("RPC Client Request - %s", extractFunctionName(method)), "funcName", method, "req", req, "conn target", cc.Target()) 45 err = invoker(ctx, method, req, resp, cc, opts...) 46 if err == nil { 47 log.ZInfo(ctx, fmt.Sprintf("RPC Client Response Success - %s", extractFunctionName(method)), "funcName", method, "resp", resp) 48 return nil 49 } 50 log.ZError(ctx, fmt.Sprintf("RPC Client Response Error - %s", extractFunctionName(method)), err, "funcName", method) 51 rpcErr, ok := err.(interface{ GRPCStatus() *status.Status }) 52 if !ok { 53 return errs.ErrInternalServer.WrapMsg(err.Error()) 54 } 55 sta := rpcErr.GRPCStatus() 56 if sta.Code() == 0 { 57 return errs.NewCodeError(errs.ServerInternalError, err.Error()).Wrap() 58 } 59 if details := sta.Details(); len(details) > 0 { 60 errInfo, ok := details[0].(*errinfo.ErrorInfo) 61 if ok { 62 s := strings.Join(errInfo.Warp, "->") + errInfo.Cause 63 return errs.NewCodeError(int(sta.Code()), sta.Message()).WithDetail(s).Wrap() 64 } 65 } 66 return errs.NewCodeError(int(sta.Code()), sta.Message()).Wrap() 67 } 68 69 func getRpcContext(ctx context.Context, method string) (context.Context, error) { 70 md := metadata.Pairs() 71 if keys, _ := ctx.Value(constant.RpcCustomHeader).([]string); len(keys) > 0 { 72 for _, key := range keys { 73 val, ok := ctx.Value(key).([]string) 74 if !ok { 75 return nil, errs.ErrInternalServer.WrapMsg("ctx missing key", "key", key) 76 } 77 if len(val) == 0 { 78 return nil, errs.ErrInternalServer.WrapMsg("ctx key value is empty", "key", key) 79 } 80 md.Set(key, val...) 81 } 82 md.Set(constant.RpcCustomHeader, keys...) 83 } 84 operationID, ok := ctx.Value(constant.OperationID).(string) 85 if !ok { 86 log.ZWarn(ctx, "ctx missing operationID", errs.New("ctx missing operationID"), "funcName", method) 87 return nil, errs.ErrArgs.WrapMsg("ctx missing operationID") 88 } 89 md.Set(constant.OperationID, operationID) 90 // var checkArgs []string 91 // checkArgs = append(checkArgs, constant.OperationID, operationID) 92 opUserID, ok := ctx.Value(constant.OpUserID).(string) 93 if ok { 94 md.Set(constant.OpUserID, opUserID) 95 // checkArgs = append(checkArgs, constant.OpUserID, opUserID) 96 } 97 opUserIDPlatformID, ok := ctx.Value(constant.OpUserPlatform).(string) 98 if ok { 99 md.Set(constant.OpUserPlatform, opUserIDPlatformID) 100 } 101 connID, ok := ctx.Value(constant.ConnID).(string) 102 if ok { 103 md.Set(constant.ConnID, connID) 104 } 105 return metadata.NewOutgoingContext(ctx, md), nil 106 } 107 108 func extractFunctionName(funcName string) string { 109 parts := strings.Split(funcName, "/") 110 if len(parts) > 0 { 111 return parts[len(parts)-1] 112 } 113 return "" 114 }