github.com/dhax/go-base@v0.0.0-20231004214136-8be7e5c1972b/auth/jwt/tokenauth.go (about) 1 package jwt 2 3 import ( 4 "crypto/rand" 5 "encoding/json" 6 "net/http" 7 "time" 8 9 "github.com/go-chi/jwtauth/v5" 10 "github.com/spf13/viper" 11 ) 12 13 // TokenAuth implements JWT authentication flow. 14 type TokenAuth struct { 15 JwtAuth *jwtauth.JWTAuth 16 JwtExpiry time.Duration 17 JwtRefreshExpiry time.Duration 18 } 19 20 // NewTokenAuth configures and returns a JWT authentication instance. 21 func NewTokenAuth() (*TokenAuth, error) { 22 secret := viper.GetString("auth_jwt_secret") 23 if secret == "random" { 24 secret = randStringBytes(32) 25 } 26 27 a := &TokenAuth{ 28 JwtAuth: jwtauth.New("HS256", []byte(secret), nil), 29 JwtExpiry: viper.GetDuration("auth_jwt_expiry"), 30 JwtRefreshExpiry: viper.GetDuration("auth_jwt_refresh_expiry"), 31 } 32 33 return a, nil 34 } 35 36 // Verifier http middleware will verify a jwt string from a http request. 37 func (a *TokenAuth) Verifier() func(http.Handler) http.Handler { 38 return jwtauth.Verifier(a.JwtAuth) 39 } 40 41 // GenTokenPair returns both an access token and a refresh token. 42 func (a *TokenAuth) GenTokenPair(accessClaims AppClaims, refreshClaims RefreshClaims) (string, string, error) { 43 access, err := a.CreateJWT(accessClaims) 44 if err != nil { 45 return "", "", err 46 } 47 refresh, err := a.CreateRefreshJWT(refreshClaims) 48 if err != nil { 49 return "", "", err 50 } 51 return access, refresh, nil 52 } 53 54 // CreateJWT returns an access token for provided account claims. 
func (a *TokenAuth) CreateJWT(c AppClaims) (string, error) {
	// Take a single timestamp so iat and exp derive from the same instant.
	now := time.Now()
	c.IssuedAt = now.Unix()
	c.ExpiresAt = now.Add(a.JwtExpiry).Unix()

	claims, err := ParseStructToMap(c)
	if err != nil {
		return "", err
	}

	_, tokenString, err := a.JwtAuth.Encode(claims)
	return tokenString, err
}

// ParseStructToMap converts c into a generic claims map by round-tripping it
// through encoding/json, so the map keys follow c's json struct tags. JSON
// numbers decode as float64 in the returned map.
//
// It returns an error if c cannot be marshaled to JSON, or if the resulting
// JSON is not an object (for example when c is a slice or a scalar).
func ParseStructToMap(c interface{}) (map[string]interface{}, error) {
	raw, err := json.Marshal(c)
	if err != nil {
		return nil, err
	}

	var claims map[string]interface{}
	if err := json.Unmarshal(raw, &claims); err != nil {
		return nil, err
	}
	return claims, nil
}

// CreateRefreshJWT returns a refresh token for provided token Claims.
// It overwrites c.IssuedAt and c.ExpiresAt (Unix seconds) so the token
// expires a.JwtRefreshExpiry after issuance, then signs the claims with
// a.JwtAuth.
func (a *TokenAuth) CreateRefreshJWT(c RefreshClaims) (string, error) {
	// Take a single timestamp so iat and exp derive from the same instant.
	now := time.Now()
	c.IssuedAt = now.Unix()
	c.ExpiresAt = now.Add(a.JwtRefreshExpiry).Unix()

	claims, err := ParseStructToMap(c)
	if err != nil {
		return "", err
	}

	_, tokenString, err := a.JwtAuth.Encode(claims)
	return tokenString, err
}

const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"

// randStringBytes returns an n-character string whose characters are drawn
// uniformly from letterBytes using crypto/rand. It returns "" for n <= 0 and
// panics if the system random source fails.
func randStringBytes(n int) string {
	if n <= 0 {
		return ""
	}

	// 256 is not a multiple of len(letterBytes), so a plain modulo would make
	// the first 256%len(letterBytes) characters more likely than the rest.
	// Rejecting bytes at or above the largest multiple keeps the draw uniform.
	const maxUnbiased = 256 - 256%len(letterBytes)

	out := make([]byte, 0, n)
	buf := make([]byte, n)
	for len(out) < n {
		if _, err := rand.Read(buf); err != nil {
			panic(err)
		}
		for _, b := range buf {
			if int(b) >= maxUnbiased {
				continue
			}
			out = append(out, letterBytes[int(b)%len(letterBytes)])
			if len(out) == n {
				break
			}
		}
	}
	return string(out)
}