github.com/dhax/go-base@v0.0.0-20231004214136-8be7e5c1972b/auth/jwt/authenticator.go (about) 1 package jwt 2 3 import ( 4 "context" 5 "net/http" 6 7 "github.com/lestrrat-go/jwx/v2/jwt" 8 9 "github.com/go-chi/jwtauth/v5" 10 "github.com/go-chi/render" 11 12 "github.com/dhax/go-base/logging" 13 ) 14 15 type ctxKey int 16 17 const ( 18 ctxClaims ctxKey = iota 19 ctxRefreshToken 20 ) 21 22 // ClaimsFromCtx retrieves the parsed AppClaims from request context. 23 func ClaimsFromCtx(ctx context.Context) AppClaims { 24 return ctx.Value(ctxClaims).(AppClaims) 25 } 26 27 // RefreshTokenFromCtx retrieves the parsed refresh token from context. 28 func RefreshTokenFromCtx(ctx context.Context) string { 29 return ctx.Value(ctxRefreshToken).(string) 30 } 31 32 // Authenticator is a default authentication middleware to enforce access from the 33 // Verifier middleware request context values. The Authenticator sends a 401 Unauthorized 34 // response for any unverified tokens and passes the good ones through. 35 func Authenticator(next http.Handler) http.Handler { 36 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 37 token, claims, err := jwtauth.FromContext(r.Context()) 38 39 if err != nil { 40 logging.GetLogEntry(r).Warn(err) 41 render.Render(w, r, ErrUnauthorized(ErrTokenUnauthorized)) 42 return 43 } 44 45 if err := jwt.Validate(token); err != nil { 46 render.Render(w, r, ErrUnauthorized(ErrTokenExpired)) 47 return 48 } 49 50 // Token is authenticated, parse claims 51 var c AppClaims 52 err = c.ParseClaims(claims) 53 if err != nil { 54 logging.GetLogEntry(r).Error(err) 55 render.Render(w, r, ErrUnauthorized(ErrInvalidAccessToken)) 56 return 57 } 58 59 // Set AppClaims on context 60 ctx := context.WithValue(r.Context(), ctxClaims, c) 61 next.ServeHTTP(w, r.WithContext(ctx)) 62 }) 63 } 64 65 // AuthenticateRefreshJWT checks validity of refresh tokens and is only used for access token refresh and logout requests. It responds with 401 Unauthorized for invalid or expired refresh tokens. 66 func AuthenticateRefreshJWT(next http.Handler) http.Handler { 67 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 68 token, claims, err := jwtauth.FromContext(r.Context()) 69 if err != nil { 70 logging.GetLogEntry(r).Warn(err) 71 render.Render(w, r, ErrUnauthorized(ErrTokenUnauthorized)) 72 return 73 } 74 75 if err := jwt.Validate(token); err != nil { 76 render.Render(w, r, ErrUnauthorized(ErrTokenExpired)) 77 return 78 } 79 80 // Token is authenticated, parse refresh token string 81 var c RefreshClaims 82 err = c.ParseClaims(claims) 83 if err != nil { 84 logging.GetLogEntry(r).Error(err) 85 render.Render(w, r, ErrUnauthorized(ErrInvalidRefreshToken)) 86 return 87 } 88 // Set refresh token string on context 89 ctx := context.WithValue(r.Context(), ctxRefreshToken, c.Token) 90 next.ServeHTTP(w, r.WithContext(ctx)) 91 }) 92 }