github.com/godaddy-x/freego@v1.0.156/rpcx/filter.go (about)

     1  package rpcx
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"github.com/godaddy-x/freego/cache/limiter"
     8  	"github.com/godaddy-x/freego/ex"
     9  	"github.com/godaddy-x/freego/utils"
    10  	"github.com/godaddy-x/freego/utils/jwt"
    11  	"github.com/godaddy-x/freego/zlog"
    12  	"google.golang.org/grpc"
    13  	"google.golang.org/grpc/metadata"
    14  	"google.golang.org/grpc/status"
    15  )
    16  
    17  const (
    18  	token          = "token"
    19  	limiterKey     = "grpc:limiter:"
    20  	timeDifference = 2400
    21  )
    22  
    23  var (
    24  	unauthorizedUrl = []string{"/pub_worker.PubWorker/Authorize", "/pub_worker.PubWorker/PublicKey"}
    25  )
    26  
    27  var defaultLimiter = rate.NewRateLimiter(rate.Option{
    28  	Limit:       10,
    29  	Bucket:      100,
    30  	Distributed: true,
    31  })
    32  
    33  func (self *GRPCManager) getRateOption(method string) (rate.Option, error) {
    34  	if rateLimiterCall == nil {
    35  		return rate.Option{}, errors.New("rateLimiterCall function is nil")
    36  	}
    37  	return rateLimiterCall(method)
    38  }
    39  
    40  func (self *GRPCManager) rateLimit(method string) error {
    41  	option, err := self.getRateOption(method)
    42  	if err != nil {
    43  		return err
    44  	}
    45  	var limiter rate.RateLimiter
    46  	if option.Limit == 0 || option.Bucket == 0 {
    47  		limiter = defaultLimiter
    48  	} else {
    49  		limiter = rate.NewRateLimiter(option)
    50  	}
    51  	if b := limiter.Allow(limiterKey + method); !b {
    52  		return errors.New(fmt.Sprintf("the method [%s] request is full", method))
    53  	}
    54  	return nil
    55  }
    56  
    57  func (self *GRPCManager) checkToken(ctx context.Context, method string) error {
    58  	if !self.authenticate {
    59  		return nil
    60  	}
    61  	if utils.CheckStr(method, unauthorizedUrl...) {
    62  		return nil
    63  	}
    64  	md, ok := metadata.FromIncomingContext(ctx)
    65  	if !ok {
    66  		return errors.New("rpc context key/value is nil")
    67  	}
    68  	token, b := md[token]
    69  	if !b || len(token) == 0 {
    70  		return errors.New("rpc context token is nil")
    71  	}
    72  	if len(jwtConfig.TokenKey) == 0 {
    73  		return errors.New("rpc context jwt is nil")
    74  	}
    75  	subject := &jwt.Subject{}
    76  	if err := subject.Verify(token[0], jwtConfig.TokenKey, false); err != nil {
    77  		return err
    78  	}
    79  	return nil
    80  }
    81  
    82  func (self *GRPCManager) createToken(ctx context.Context, method string) (context.Context, error) {
    83  	if len(accessToken) == 0 {
    84  		return ctx, nil
    85  	}
    86  	if utils.CheckStr(method, unauthorizedUrl...) {
    87  		return ctx, nil
    88  	}
    89  	md := metadata.New(map[string]string{token: accessToken})
    90  	ctx = metadata.NewOutgoingContext(ctx, md)
    91  	return ctx, nil
    92  }
    93  
    94  func (self *GRPCManager) ServerInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
    95  	//if err := self.rateLimit(info.FullMethod); err != nil {
    96  	//	return nil, err
    97  	//}
    98  	if err := self.checkToken(ctx, info.FullMethod); err != nil {
    99  		return nil, status.Error(ex.BIZ, err.Error())
   100  	}
   101  	res, err := handler(ctx, req)
   102  	if err != nil {
   103  		return nil, status.Error(ex.GRPC, err.Error())
   104  	}
   105  	return res, nil
   106  }
   107  
   108  func (self *GRPCManager) ClientInterceptor(ctx context.Context, method string, req, reply interface{}, conn *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) (err error) {
   109  	//if err := self.rateLimit(method); err != nil {
   110  	//	return err
   111  	//}
   112  	ctx, err = self.createToken(ctx, method)
   113  	if err != nil {
   114  		return err
   115  	}
   116  	start := utils.UnixMilli()
   117  	if err := invoker(ctx, method, req, reply, conn, opts...); err != nil {
   118  		//rpcErr := status.Convert(err)
   119  		//zlog.Error("grpc call failed", start, zlog.String("service", method), zlog.AddError(rpcErr.Err()))
   120  		return utils.Error(status.Convert(err).Message())
   121  	}
   122  	cost := utils.UnixMilli() - start
   123  	if self.consul != nil && self.consul.Config.SlowQuery > 0 && cost > self.consul.Config.SlowQuery {
   124  		l := self.consul.GetSlowLog()
   125  		if l != nil {
   126  			l.Warn("grpc call slow query", zlog.Int64("cost", cost), zlog.Any("service", method))
   127  		}
   128  	}
   129  	return nil
   130  }