github.com/lingyao2333/mo-zero@v1.4.1/rest/token/tokenparser.go (about)

     1  package token
     2  
     3  import (
     4  	"net/http"
     5  	"sync"
     6  	"sync/atomic"
     7  	"time"
     8  
     9  	"github.com/golang-jwt/jwt/v4"
    10  	"github.com/golang-jwt/jwt/v4/request"
    11  	"github.com/lingyao2333/mo-zero/core/timex"
    12  )
    13  
    14  const claimHistoryResetDuration = time.Hour * 24
    15  
    16  type (
    17  	// ParseOption defines the method to customize a TokenParser.
    18  	ParseOption func(parser *TokenParser)
    19  
    20  	// A TokenParser is used to parse tokens.
    21  	TokenParser struct {
    22  		resetTime     time.Duration
    23  		resetDuration time.Duration
    24  		history       sync.Map
    25  	}
    26  )
    27  
    28  // NewTokenParser returns a TokenParser.
    29  func NewTokenParser(opts ...ParseOption) *TokenParser {
    30  	parser := &TokenParser{
    31  		resetTime:     timex.Now(),
    32  		resetDuration: claimHistoryResetDuration,
    33  	}
    34  
    35  	for _, opt := range opts {
    36  		opt(parser)
    37  	}
    38  
    39  	return parser
    40  }
    41  
    42  // ParseToken parses token from given r, with passed in secret and prevSecret.
    43  func (tp *TokenParser) ParseToken(r *http.Request, secret, prevSecret string) (*jwt.Token, error) {
    44  	var token *jwt.Token
    45  	var err error
    46  
    47  	if len(prevSecret) > 0 {
    48  		count := tp.loadCount(secret)
    49  		prevCount := tp.loadCount(prevSecret)
    50  
    51  		var first, second string
    52  		if count > prevCount {
    53  			first = secret
    54  			second = prevSecret
    55  		} else {
    56  			first = prevSecret
    57  			second = secret
    58  		}
    59  
    60  		token, err = tp.doParseToken(r, first)
    61  		if err != nil {
    62  			token, err = tp.doParseToken(r, second)
    63  			if err != nil {
    64  				return nil, err
    65  			}
    66  
    67  			tp.incrementCount(second)
    68  		} else {
    69  			tp.incrementCount(first)
    70  		}
    71  	} else {
    72  		token, err = tp.doParseToken(r, secret)
    73  		if err != nil {
    74  			return nil, err
    75  		}
    76  	}
    77  
    78  	return token, nil
    79  }
    80  
    81  func (tp *TokenParser) doParseToken(r *http.Request, secret string) (*jwt.Token, error) {
    82  	return request.ParseFromRequest(r, request.AuthorizationHeaderExtractor,
    83  		func(token *jwt.Token) (interface{}, error) {
    84  			return []byte(secret), nil
    85  		}, request.WithParser(newParser()))
    86  }
    87  
    88  func (tp *TokenParser) incrementCount(secret string) {
    89  	now := timex.Now()
    90  	if tp.resetTime+tp.resetDuration < now {
    91  		tp.history.Range(func(key, value interface{}) bool {
    92  			tp.history.Delete(key)
    93  			return true
    94  		})
    95  	}
    96  
    97  	value, ok := tp.history.Load(secret)
    98  	if ok {
    99  		atomic.AddUint64(value.(*uint64), 1)
   100  	} else {
   101  		var count uint64 = 1
   102  		tp.history.Store(secret, &count)
   103  	}
   104  }
   105  
   106  func (tp *TokenParser) loadCount(secret string) uint64 {
   107  	value, ok := tp.history.Load(secret)
   108  	if ok {
   109  		return *value.(*uint64)
   110  	}
   111  
   112  	return 0
   113  }
   114  
   115  // WithResetDuration returns a func to customize a TokenParser with reset duration.
   116  func WithResetDuration(duration time.Duration) ParseOption {
   117  	return func(parser *TokenParser) {
   118  		parser.resetDuration = duration
   119  	}
   120  }
   121  
   122  func newParser() *jwt.Parser {
   123  	return jwt.NewParser(jwt.WithJSONNumber())
   124  }