github.com/xmidt-org/webpa-common@v1.11.9/secure/handler/authorizationHandler.go (about)

     1  package handler
     2  
     3  import (
     4  	"errors"
     5  	"net/http"
     6  
     7  	"github.com/SermoDigital/jose/jws"
     8  	"github.com/go-kit/kit/log"
     9  	"github.com/go-kit/kit/log/level"
    10  	"github.com/xmidt-org/webpa-common/logging"
    11  	"github.com/xmidt-org/webpa-common/secure"
    12  	"github.com/xmidt-org/webpa-common/xhttp"
    13  )
    14  
    15  const (
    16  	// The Content-Type value for JSON
    17  	JsonContentType string = "application/json; charset=UTF-8"
    18  
    19  	// The Content-Type header
    20  	ContentTypeHeader string = "Content-Type"
    21  
    22  	// The X-Content-Type-Options header
    23  	ContentTypeOptionsHeader string = "X-Content-Type-Options"
    24  
    25  	// NoSniff is the value used for content options for errors written by this package
    26  	NoSniff string = "nosniff"
    27  )
    28  
    29  // AuthorizationHandler provides decoration for http.Handler instances and will
    30  // ensure that requests pass the validator.  Note that secure.Validators is a Validator
    31  // implementation that allows chaining validators together via logical OR.
    32  type AuthorizationHandler struct {
    33  	HeaderName          string
    34  	ForbiddenStatusCode int
    35  	Validator           secure.Validator
    36  	Logger              log.Logger
    37  	measures            *secure.JWTValidationMeasures
    38  }
    39  
    40  // headerName returns the authorization header to use, either a.HeaderName
    41  // or secure.AuthorizationHeader if no header is supplied
    42  func (a AuthorizationHandler) headerName() string {
    43  	if len(a.HeaderName) > 0 {
    44  		return a.HeaderName
    45  	}
    46  
    47  	return secure.AuthorizationHeader
    48  }
    49  
    50  // forbiddenStatusCode returns a.ForbiddenStatusCode if supplied, otherwise
    51  // http.StatusForbidden is returned
    52  func (a AuthorizationHandler) forbiddenStatusCode() int {
    53  	if a.ForbiddenStatusCode > 0 {
    54  		return a.ForbiddenStatusCode
    55  	}
    56  
    57  	return http.StatusForbidden
    58  }
    59  
    60  func (a AuthorizationHandler) logger() log.Logger {
    61  	if a.Logger != nil {
    62  		return a.Logger
    63  	}
    64  
    65  	return logging.DefaultLogger()
    66  }
    67  
    68  // Decorate provides an Alice-compatible constructor that validates requests
    69  // using the configuration specified.
    70  func (a AuthorizationHandler) Decorate(delegate http.Handler) http.Handler {
    71  	// if there is no validator, there's no point in decorating anything
    72  	if a.Validator == nil {
    73  		return delegate
    74  	}
    75  
    76  	var (
    77  		headerName          = a.headerName()
    78  		forbiddenStatusCode = a.forbiddenStatusCode()
    79  		logger              = a.logger()
    80  		errorLog            = logging.Error(logger)
    81  	)
    82  
    83  	return http.HandlerFunc(func(response http.ResponseWriter, request *http.Request) {
    84  		headerValue := request.Header.Get(headerName)
    85  		if len(headerValue) == 0 {
    86  			errorLog.Log(logging.MessageKey(), "missing header", "name", headerName)
    87  			xhttp.WriteErrorf(response, forbiddenStatusCode, "missing header: %s", headerName)
    88  
    89  			if a.measures != nil {
    90  				a.measures.ValidationReason.With("reason", "missing_header").Add(1)
    91  			}
    92  			return
    93  		}
    94  
    95  		token, err := secure.ParseAuthorization(headerValue)
    96  		if err != nil {
    97  			errorLog.Log(logging.MessageKey(), "invalid authorization header", "name", headerName, logging.ErrorKey(), err)
    98  			xhttp.WriteErrorf(response, forbiddenStatusCode, "Invalid authorization header [%s]: %s", headerName, err.Error())
    99  
   100  			if a.measures != nil {
   101  				a.measures.ValidationReason.With("reason", "invalid_header").Add(1)
   102  			}
   103  			return
   104  		}
   105  
   106  		contextValues := &ContextValues{
   107  			Method: request.Method,
   108  			Path:   request.URL.Path,
   109  			Trust:  secure.Untrusted, // trust isn't set on the token until validation (ugh)
   110  		}
   111  
   112  		sharedContext := NewContextWithValue(request.Context(), contextValues)
   113  
   114  		valid, err := a.Validator.Validate(sharedContext, token)
   115  		if err == nil && valid {
   116  			if err := populateContextValues(token, contextValues); err != nil {
   117  				logger.Log(level.Key(), level.ErrorValue(), logging.MessageKey(), "unable to populate context", logging.ErrorKey(), err)
   118  			}
   119  
   120  			// this is absolutely horrible, but it's the only way we can do it for now.
   121  			// TODO: address this in a redesign
   122  			contextValues.Trust = token.Trust()
   123  			delegate.ServeHTTP(response, request.WithContext(sharedContext))
   124  			return
   125  		}
   126  
   127  		errorLog.Log(
   128  			logging.MessageKey(), "request denied",
   129  			"validator-response", valid,
   130  			"validator-error", err,
   131  			"sat-client-id", contextValues.SatClientID,
   132  			"method", request.Method,
   133  			"url", request.URL,
   134  			"user-agent", request.Header.Get("User-Agent"),
   135  			"content-length", request.ContentLength,
   136  			"remoteAddress", request.RemoteAddr,
   137  		)
   138  
   139  		xhttp.WriteError(response, forbiddenStatusCode, "request denied")
   140  	})
   141  }
   142  
   143  //DefineMeasures facilitates clients to define authHandler metrics tools
   144  func (a *AuthorizationHandler) DefineMeasures(m *secure.JWTValidationMeasures) {
   145  	a.measures = m
   146  }
   147  
   148  func populateContextValues(token *secure.Token, values *ContextValues) error {
   149  	values.SatClientID = "N/A"
   150  
   151  	if token.Type() != secure.Bearer {
   152  		return nil
   153  	}
   154  
   155  	jwsToken, err := secure.DefaultJWSParser.ParseJWS(token)
   156  	if err != nil {
   157  		return err
   158  	}
   159  
   160  	claims, ok := jwsToken.Payload().(jws.Claims)
   161  	if !ok {
   162  		return errors.New("no claims")
   163  	}
   164  
   165  	if sub, ok := claims.Get("sub").(string); ok {
   166  		values.SatClientID = sub
   167  	}
   168  
   169  	if allowedResources, ok := claims.Get("allowedResources").(map[string]interface{}); ok {
   170  		if allowedPartners, ok := allowedResources["allowedPartners"].([]interface{}); ok {
   171  			values.PartnerIDs = make([]string, 0, len(allowedPartners))
   172  			for i := 0; i < len(allowedPartners); i++ {
   173  				if value, ok := allowedPartners[i].(string); ok {
   174  					values.PartnerIDs = append(values.PartnerIDs, value)
   175  				}
   176  			}
   177  		}
   178  	}
   179  
   180  	return nil
   181  }