github.com/dhax/go-base@v0.0.0-20231004214136-8be7e5c1972b/auth/jwt/tokenauth.go (about)

     1  package jwt
     2  
import (
	"crypto/rand"
	"encoding/json"
	"errors"
	"net/http"
	"time"

	"github.com/go-chi/jwtauth/v5"
	"github.com/spf13/viper"
)
    12  
    13  // TokenAuth implements JWT authentication flow.
    14  type TokenAuth struct {
    15  	JwtAuth          *jwtauth.JWTAuth
    16  	JwtExpiry        time.Duration
    17  	JwtRefreshExpiry time.Duration
    18  }
    19  
    20  // NewTokenAuth configures and returns a JWT authentication instance.
    21  func NewTokenAuth() (*TokenAuth, error) {
    22  	secret := viper.GetString("auth_jwt_secret")
    23  	if secret == "random" {
    24  		secret = randStringBytes(32)
    25  	}
    26  
    27  	a := &TokenAuth{
    28  		JwtAuth:          jwtauth.New("HS256", []byte(secret), nil),
    29  		JwtExpiry:        viper.GetDuration("auth_jwt_expiry"),
    30  		JwtRefreshExpiry: viper.GetDuration("auth_jwt_refresh_expiry"),
    31  	}
    32  
    33  	return a, nil
    34  }
    35  
    36  // Verifier http middleware will verify a jwt string from a http request.
    37  func (a *TokenAuth) Verifier() func(http.Handler) http.Handler {
    38  	return jwtauth.Verifier(a.JwtAuth)
    39  }
    40  
    41  // GenTokenPair returns both an access token and a refresh token.
    42  func (a *TokenAuth) GenTokenPair(accessClaims AppClaims, refreshClaims RefreshClaims) (string, string, error) {
    43  	access, err := a.CreateJWT(accessClaims)
    44  	if err != nil {
    45  		return "", "", err
    46  	}
    47  	refresh, err := a.CreateRefreshJWT(refreshClaims)
    48  	if err != nil {
    49  		return "", "", err
    50  	}
    51  	return access, refresh, nil
    52  }
    53  
    54  // CreateJWT returns an access token for provided account claims.
    55  func (a *TokenAuth) CreateJWT(c AppClaims) (string, error) {
    56  	c.IssuedAt = time.Now().Unix()
    57  	c.ExpiresAt = time.Now().Add(a.JwtExpiry).Unix()
    58  
    59  	claims, err := ParseStructToMap(c)
    60  	if err != nil {
    61  		return "", err
    62  	}
    63  
    64  	_, tokenString, err := a.JwtAuth.Encode(claims)
    65  	return tokenString, err
    66  }
    67  
    68  func ParseStructToMap(c interface{}) (map[string]interface{}, error) {
    69  	var claims map[string]interface{}
    70  	inrec, _ := json.Marshal(c)
    71  	err := json.Unmarshal(inrec, &claims)
    72  	if err != nil {
    73  		return nil, err
    74  	}
    75  
    76  	return claims, err
    77  }
    78  
    79  // CreateRefreshJWT returns a refresh token for provided token Claims.
    80  func (a *TokenAuth) CreateRefreshJWT(c RefreshClaims) (string, error) {
    81  	c.IssuedAt = time.Now().Unix()
    82  	c.ExpiresAt = time.Now().Add(a.JwtRefreshExpiry).Unix()
    83  
    84  	claims, err := ParseStructToMap(c)
    85  	if err != nil {
    86  		return "", err
    87  	}
    88  
    89  	_, tokenString, err := a.JwtAuth.Encode(claims)
    90  	return tokenString, err
    91  }
    92  
    93  const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
    94  
    95  func randStringBytes(n int) string {
    96  	buf := make([]byte, n)
    97  	if _, err := rand.Read(buf); err != nil {
    98  		panic(err)
    99  	}
   100  
   101  	for k, v := range buf {
   102  		buf[k] = letterBytes[v%byte(len(letterBytes))]
   103  	}
   104  	return string(buf)
   105  }