github.com/dhax/go-base@v0.0.0-20231004214136-8be7e5c1972b/auth/jwt/authenticator.go (about)

     1  package jwt
     2  
     3  import (
     4  	"context"
     5  	"net/http"
     6  
     7  	"github.com/lestrrat-go/jwx/v2/jwt"
     8  
     9  	"github.com/go-chi/jwtauth/v5"
    10  	"github.com/go-chi/render"
    11  
    12  	"github.com/dhax/go-base/logging"
    13  )
    14  
    15  type ctxKey int
    16  
    17  const (
    18  	ctxClaims ctxKey = iota
    19  	ctxRefreshToken
    20  )
    21  
    22  // ClaimsFromCtx retrieves the parsed AppClaims from request context.
    23  func ClaimsFromCtx(ctx context.Context) AppClaims {
    24  	return ctx.Value(ctxClaims).(AppClaims)
    25  }
    26  
    27  // RefreshTokenFromCtx retrieves the parsed refresh token from context.
    28  func RefreshTokenFromCtx(ctx context.Context) string {
    29  	return ctx.Value(ctxRefreshToken).(string)
    30  }
    31  
    32  // Authenticator is a default authentication middleware to enforce access from the
    33  // Verifier middleware request context values. The Authenticator sends a 401 Unauthorized
    34  // response for any unverified tokens and passes the good ones through.
    35  func Authenticator(next http.Handler) http.Handler {
    36  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    37  		token, claims, err := jwtauth.FromContext(r.Context())
    38  
    39  		if err != nil {
    40  			logging.GetLogEntry(r).Warn(err)
    41  			render.Render(w, r, ErrUnauthorized(ErrTokenUnauthorized))
    42  			return
    43  		}
    44  
    45  		if err := jwt.Validate(token); err != nil {
    46  			render.Render(w, r, ErrUnauthorized(ErrTokenExpired))
    47  			return
    48  		}
    49  
    50  		// Token is authenticated, parse claims
    51  		var c AppClaims
    52  		err = c.ParseClaims(claims)
    53  		if err != nil {
    54  			logging.GetLogEntry(r).Error(err)
    55  			render.Render(w, r, ErrUnauthorized(ErrInvalidAccessToken))
    56  			return
    57  		}
    58  
    59  		// Set AppClaims on context
    60  		ctx := context.WithValue(r.Context(), ctxClaims, c)
    61  		next.ServeHTTP(w, r.WithContext(ctx))
    62  	})
    63  }
    64  
    65  // AuthenticateRefreshJWT checks validity of refresh tokens and is only used for access token refresh and logout requests. It responds with 401 Unauthorized for invalid or expired refresh tokens.
    66  func AuthenticateRefreshJWT(next http.Handler) http.Handler {
    67  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    68  		token, claims, err := jwtauth.FromContext(r.Context())
    69  		if err != nil {
    70  			logging.GetLogEntry(r).Warn(err)
    71  			render.Render(w, r, ErrUnauthorized(ErrTokenUnauthorized))
    72  			return
    73  		}
    74  
    75  		if err := jwt.Validate(token); err != nil {
    76  			render.Render(w, r, ErrUnauthorized(ErrTokenExpired))
    77  			return
    78  		}
    79  
    80  		// Token is authenticated, parse refresh token string
    81  		var c RefreshClaims
    82  		err = c.ParseClaims(claims)
    83  		if err != nil {
    84  			logging.GetLogEntry(r).Error(err)
    85  			render.Render(w, r, ErrUnauthorized(ErrInvalidRefreshToken))
    86  			return
    87  		}
    88  		// Set refresh token string on context
    89  		ctx := context.WithValue(r.Context(), ctxRefreshToken, c.Token)
    90  		next.ServeHTTP(w, r.WithContext(ctx))
    91  	})
    92  }