github.com/pf-qiu/concourse/v6@v6.7.3-0.20201207032516-1f455d73275f/atc/api/auth/web_auth_handler.go (about)

     1  package auth
     2  
     3  import (
     4  	"context"
     5  	"net/http"
     6  
     7  	"github.com/pf-qiu/concourse/v6/skymarshal/token"
     8  )
     9  
    10  //go:generate counterfeiter net/http.Handler
    11  
    12  func NewResponseWrapper(w http.ResponseWriter, m token.Middleware) *responseWrapper {
    13  	return &responseWrapper{w, m}
    14  }
    15  
    16  type responseWrapper struct {
    17  	http.ResponseWriter
    18  	token.Middleware
    19  }
    20  
    21  func (r *responseWrapper) WriteHeader(statusCode int) {
    22  
    23  	// we need to unset cookies before writing the header
    24  	if statusCode == http.StatusUnauthorized {
    25  		r.Middleware.UnsetAuthToken(r.ResponseWriter)
    26  		r.Middleware.UnsetCSRFToken(r.ResponseWriter)
    27  	}
    28  
    29  	r.ResponseWriter.WriteHeader(statusCode)
    30  }
    31  
    32  func (r *responseWrapper) Flush() {
    33  	r.ResponseWriter.(http.Flusher).Flush()
    34  }
    35  
    36  type WebAuthHandler struct {
    37  	Handler    http.Handler
    38  	Middleware token.Middleware
    39  }
    40  
    41  func (handler WebAuthHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    42  
    43  	tokenString := handler.Middleware.GetAuthToken(r)
    44  	if tokenString != "" {
    45  		ctx := context.WithValue(r.Context(), CSRFRequiredKey, handler.isCSRFRequired(r))
    46  		r = r.WithContext(ctx)
    47  
    48  		if r.Header.Get("Authorization") == "" {
    49  			r.Header.Set("Authorization", tokenString)
    50  		}
    51  
    52  		wrapper := NewResponseWrapper(w, handler.Middleware)
    53  		handler.Handler.ServeHTTP(wrapper, r)
    54  	} else {
    55  		handler.Handler.ServeHTTP(w, r)
    56  	}
    57  }
    58  
    59  // We don't validate CSRF token for GET requests
    60  // since they are not changing the state
    61  func (handler WebAuthHandler) isCSRFRequired(r *http.Request) bool {
    62  	return (r.Method != http.MethodGet && r.Method != http.MethodHead && r.Method != http.MethodOptions)
    63  }
    64  
    65  func IsCSRFRequired(r *http.Request) bool {
    66  	required, ok := r.Context().Value(CSRFRequiredKey).(bool)
    67  	return ok && required
    68  }