github.com/clerkinc/clerk-sdk-go@v1.49.1/clerk/middleware_v2.go (about)

     1  package clerk
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"net"
     7  	"net/http"
     8  	"net/url"
     9  	"regexp"
    10  	"strconv"
    11  	"strings"
    12  
    13  	"github.com/go-jose/go-jose/v3/jwt"
    14  )
    15  
// urlSchemeRe matches, at the start of a string, either a URL scheme followed
// by "//" (e.g. "https://") or a bare "//". isCrossOrigin uses it to strip the
// scheme from the Origin header, leaving host[:port].
var urlSchemeRe = regexp.MustCompile(`(^\w+:|^)\/\/`)
    17  
    18  // RequireSessionV2 will hijack the request and return an HTTP status 403
    19  // if the session is not authenticated.
    20  func RequireSessionV2(client Client, verifyTokenOptions ...VerifyTokenOption) func(handler http.Handler) http.Handler {
    21  	return func(next http.Handler) http.Handler {
    22  		f := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    23  			claims, ok := r.Context().Value(ActiveSessionClaims).(*SessionClaims)
    24  			if !ok || claims == nil {
    25  				w.WriteHeader(http.StatusForbidden)
    26  				return
    27  			}
    28  
    29  			next.ServeHTTP(w, r)
    30  		})
    31  
    32  		return WithSessionV2(client, verifyTokenOptions...)(f)
    33  	}
    34  }
    35  
    36  // SessionFromContext returns the session's (if any) claims, as parsed from the
    37  // token.
    38  func SessionFromContext(ctx context.Context) (*SessionClaims, bool) {
    39  	c, ok := ctx.Value(ActiveSessionClaims).(*SessionClaims)
    40  	return c, ok
    41  }
    42  
// WithSessionV2 is the de-facto authentication middleware and should be
// preferred to WithSession. If the session is authenticated, it adds the corresponding
// session claims found in the JWT to request's context.
//
// Claims are stored under the ActiveSessionClaims context key (readable via
// SessionFromContext). When the request is considered signed out, next is
// called with the request unchanged, i.e. without claims. The middleware
// answers the request itself only in two cases:
//   - the Authorization header carries a decodable token that fails
//     verification: it replies 401 with no body so that Clerk.js can refresh
//     the token and retry;
//   - the cookie state is inconclusive: it renders the Clerk interstitial
//     page via renderInterstitial.
//
// The Authorization header takes precedence. The __session and __client_uat
// cookies are consulted only when that header is absent.
// verifyTokenOptions are passed through unchanged to client.VerifyToken.
func WithSessionV2(client Client, verifyTokenOptions ...VerifyTokenOption) func(handler http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// ****************************************************
			//                                                    *
			//                HEADER AUTHENTICATION               *
			//                                                    *
			// ****************************************************
			// Presence is checked on the header map itself, so an
			// Authorization header with an empty value still takes this path.
			_, authorizationHeaderExists := r.Header["Authorization"]

			if authorizationHeaderExists {
				// Accepts both "Bearer <token>" and a bare token. The prefix
				// match is case-sensitive.
				headerToken := strings.TrimSpace(r.Header.Get("Authorization"))
				headerToken = strings.TrimPrefix(headerToken, "Bearer ")

				// A value that cannot even be decoded as a token is treated
				// as signed out rather than answered with 401.
				_, err := client.DecodeToken(headerToken)
				if err != nil {
					// signed out
					next.ServeHTTP(w, r)
					return
				}

				claims, err := client.VerifyToken(headerToken, verifyTokenOptions...)
				if err == nil { // signed in
					ctx := context.WithValue(r.Context(), ActiveSessionClaims, claims)
					next.ServeHTTP(w, r.WithContext(ctx))
					return
				}

				// Clerk.js should refresh the token and retry
				w.WriteHeader(http.StatusUnauthorized)
				return
			}

			// In development or staging environments only, based on the request User Agent, detect non-browser
			// requests (e.g. scripts). If there is no Authorization header, consider the user as signed out
			// and prevent interstitial rendering
			if isDevelopmentOrStaging(client) && !strings.HasPrefix(r.UserAgent(), "Mozilla/") {
				// signed out
				next.ServeHTTP(w, r)
				return
			}

			// in cross-origin requests the use of Authorization
			// header is mandatory
			if isCrossOrigin(r) {
				// signed out
				next.ServeHTTP(w, r)
				return
			}

			// ****************************************************
			//                                                    *
			//                COOKIE AUTHENTICATION               *
			//                                                    *
			// ****************************************************
			// A missing cookie yields nil here. The error is deliberately
			// ignored; the checks below branch on nil-ness instead.
			cookieToken, _ := r.Cookie("__session")
			clientUat, _ := r.Cookie("__client_uat")

			// NOTE(review): isCrossOrigin(r) is always false at this point,
			// because cross-origin requests already returned above. The
			// condition is therefore effectively "dev/staging and no Referer".
			// Confirm whether a cross-origin *Referer* check was intended.
			if isDevelopmentOrStaging(client) && (r.Referer() == "" || isCrossOrigin(r)) {
				renderInterstitial(client, w)
				return
			}

			// Production request without __client_uat: signed out.
			if isProduction(client) && clientUat == nil {
				next.ServeHTTP(w, r)
				return
			}

			// __client_uat of "0": signed out.
			if clientUat != nil && clientUat.Value == "0" {
				next.ServeHTTP(w, r)
				return
			}

			// Only reachable in development/staging: production requests
			// without __client_uat already returned above.
			if clientUat == nil {
				renderInterstitial(client, w)
				return
			}

			// NOTE(review): a __client_uat value that is not a base-10 int64
			// leaves clientUatTs at 0. The iat comparison below then accepts
			// any token with an iat claim. Confirm this fallback is intended.
			var clientUatTs int64
			ts, err := strconv.ParseInt(clientUat.Value, 10, 64)
			if err == nil {
				clientUatTs = ts
			}

			// __client_uat is present and not "0", but there is no session
			// cookie.
			if cookieToken == nil {
				renderInterstitial(client, w)
				return
			}

			claims, err := client.VerifyToken(cookieToken.Value, verifyTokenOptions...)

			if err == nil {
				// The cookie token is accepted only if it has an iat claim no
				// earlier than the __client_uat timestamp. An older token
				// (presumably stale) gets the interstitial instead.
				if claims.IssuedAt != nil && clientUatTs <= int64(*claims.IssuedAt) {
					ctx := context.WithValue(r.Context(), ActiveSessionClaims, claims)
					next.ServeHTTP(w, r.WithContext(ctx))
					return
				}

				renderInterstitial(client, w)
				return
			}

			// Expired tokens and tokens issued in the future get the
			// interstitial. Any other verification failure is signed out.
			if errors.Is(err, jwt.ErrExpired) || errors.Is(err, jwt.ErrIssuedInTheFuture) {
				renderInterstitial(client, w)
				return
			}

			// signed out
			next.ServeHTTP(w, r)
			return
		})
	}
}
   159  
   160  func isCrossOrigin(r *http.Request) bool {
   161  	// origin contains scheme+host and optionally port (ommitted if 80 or 443)
   162  	// ref. https://www.rfc-editor.org/rfc/rfc6454#section-6.1
   163  	origin := strings.TrimSpace(r.Header.Get("Origin"))
   164  	origin = urlSchemeRe.ReplaceAllString(origin, "") // strip scheme
   165  	if origin == "" {
   166  		return false
   167  	}
   168  
   169  	// parse request's host and port, taking into account reverse proxies
   170  	u := &url.URL{Host: r.Host}
   171  	host := strings.TrimSpace(r.Header.Get("X-Forwarded-Host"))
   172  	if host == "" {
   173  		host = u.Hostname()
   174  	}
   175  	port := strings.TrimSpace(r.Header.Get("X-Forwarded-Port"))
   176  	if port == "" {
   177  		port = u.Port()
   178  	}
   179  
   180  	if port != "" && port != "80" && port != "443" {
   181  		host = net.JoinHostPort(host, port)
   182  	}
   183  
   184  	return origin != host
   185  }
   186  
   187  func isDevelopmentOrStaging(c Client) bool {
   188  	return strings.HasPrefix(c.APIKey(), "test_") || strings.HasPrefix(c.APIKey(), "sk_test_")
   189  }
   190  
   191  func isProduction(c Client) bool {
   192  	return !isDevelopmentOrStaging(c)
   193  }
   194  
   195  func renderInterstitial(c Client, w http.ResponseWriter) {
   196  	w.Header().Set("content-type", "text/html")
   197  	w.WriteHeader(401)
   198  	resp, _ := c.Interstitial()
   199  	w.Write(resp)
   200  }