github.com/openimsdk/tools@v0.0.49/mw/rpc_client_interceptor.go (about)

     1  // Copyright © 2023 OpenIM. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package mw
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"strings"
    21  
    22  	"github.com/openimsdk/protocol/constant"
    23  	"github.com/openimsdk/protocol/errinfo"
    24  	"github.com/openimsdk/tools/errs"
    25  	"github.com/openimsdk/tools/log"
    26  	"google.golang.org/grpc"
    27  	"google.golang.org/grpc/metadata"
    28  	"google.golang.org/grpc/status"
    29  )
    30  
    31  func GrpcClient() grpc.DialOption {
    32  	return grpc.WithChainUnaryInterceptor(RpcClientInterceptor)
    33  }
    34  
    35  func RpcClientInterceptor(ctx context.Context, method string, req, resp any, cc *grpc.ClientConn,
    36  	invoker grpc.UnaryInvoker, opts ...grpc.CallOption) (err error) {
    37  	if ctx == nil {
    38  		return errs.ErrInternalServer.WrapMsg("call rpc request context is nil")
    39  	}
    40  	ctx, err = getRpcContext(ctx, method)
    41  	if err != nil {
    42  		return err
    43  	}
    44  	log.ZDebug(ctx, fmt.Sprintf("RPC Client Request - %s", extractFunctionName(method)), "funcName", method, "req", req, "conn target", cc.Target())
    45  	err = invoker(ctx, method, req, resp, cc, opts...)
    46  	if err == nil {
    47  		log.ZInfo(ctx, fmt.Sprintf("RPC Client Response Success - %s", extractFunctionName(method)), "funcName", method, "resp", resp)
    48  		return nil
    49  	}
    50  	log.ZError(ctx, fmt.Sprintf("RPC Client Response Error - %s", extractFunctionName(method)), err, "funcName", method)
    51  	rpcErr, ok := err.(interface{ GRPCStatus() *status.Status })
    52  	if !ok {
    53  		return errs.ErrInternalServer.WrapMsg(err.Error())
    54  	}
    55  	sta := rpcErr.GRPCStatus()
    56  	if sta.Code() == 0 {
    57  		return errs.NewCodeError(errs.ServerInternalError, err.Error()).Wrap()
    58  	}
    59  	if details := sta.Details(); len(details) > 0 {
    60  		errInfo, ok := details[0].(*errinfo.ErrorInfo)
    61  		if ok {
    62  			s := strings.Join(errInfo.Warp, "->") + errInfo.Cause
    63  			return errs.NewCodeError(int(sta.Code()), sta.Message()).WithDetail(s).Wrap()
    64  		}
    65  	}
    66  	return errs.NewCodeError(int(sta.Code()), sta.Message()).Wrap()
    67  }
    68  
    69  func getRpcContext(ctx context.Context, method string) (context.Context, error) {
    70  	md := metadata.Pairs()
    71  	if keys, _ := ctx.Value(constant.RpcCustomHeader).([]string); len(keys) > 0 {
    72  		for _, key := range keys {
    73  			val, ok := ctx.Value(key).([]string)
    74  			if !ok {
    75  				return nil, errs.ErrInternalServer.WrapMsg("ctx missing key", "key", key)
    76  			}
    77  			if len(val) == 0 {
    78  				return nil, errs.ErrInternalServer.WrapMsg("ctx key value is empty", "key", key)
    79  			}
    80  			md.Set(key, val...)
    81  		}
    82  		md.Set(constant.RpcCustomHeader, keys...)
    83  	}
    84  	operationID, ok := ctx.Value(constant.OperationID).(string)
    85  	if !ok {
    86  		log.ZWarn(ctx, "ctx missing operationID", errs.New("ctx missing operationID"), "funcName", method)
    87  		return nil, errs.ErrArgs.WrapMsg("ctx missing operationID")
    88  	}
    89  	md.Set(constant.OperationID, operationID)
    90  	// var checkArgs []string
    91  	// checkArgs = append(checkArgs, constant.OperationID, operationID)
    92  	opUserID, ok := ctx.Value(constant.OpUserID).(string)
    93  	if ok {
    94  		md.Set(constant.OpUserID, opUserID)
    95  		// checkArgs = append(checkArgs, constant.OpUserID, opUserID)
    96  	}
    97  	opUserIDPlatformID, ok := ctx.Value(constant.OpUserPlatform).(string)
    98  	if ok {
    99  		md.Set(constant.OpUserPlatform, opUserIDPlatformID)
   100  	}
   101  	connID, ok := ctx.Value(constant.ConnID).(string)
   102  	if ok {
   103  		md.Set(constant.ConnID, connID)
   104  	}
   105  	return metadata.NewOutgoingContext(ctx, md), nil
   106  }
   107  
   108  func extractFunctionName(funcName string) string {
   109  	parts := strings.Split(funcName, "/")
   110  	if len(parts) > 0 {
   111  		return parts[len(parts)-1]
   112  	}
   113  	return ""
   114  }