github.com/blend/go-sdk@v1.20220411.3/sanitize/request.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package sanitize
     9  
    10  import (
    11  	"net/http"
    12  	"strings"
    13  )
    14  
    15  // Request sanitizes a given request.
    16  func Request(r *http.Request, opts ...RequestOption) *http.Request {
    17  	return NewRequestSanitizer(opts...).Sanitize(r)
    18  }
    19  
    20  // NewRequestSanitizer creates a new request sanitizer.
    21  func NewRequestSanitizer(opts ...RequestOption) RequestSanitizer {
    22  	r := RequestSanitizer{
    23  		DisallowedHeaders:     DefaultSanitizationDisallowedHeaders,
    24  		DisallowedQueryParams: DefaultSanitizationDisallowedQueryParams,
    25  		KeyValuesSanitizer:    KeyValuesSanitizerFunc(DefaultKeyValuesSanitizerFunc),
    26  		PathSanitizer:         PathSanitizerFunc(DefaultPathSanitizerFunc),
    27  	}
    28  	for _, opt := range opts {
    29  		opt(&r)
    30  	}
    31  	return r
    32  }
    33  
// RequestOption is a function that mutates sanitization options.
//
// NewRequestSanitizer applies each option, in the order given, to a
// RequestSanitizer that has already been populated with the package defaults.
type RequestOption func(*RequestSanitizer)
    36  
    37  // OptRequestAddDisallowedHeaders adds disallowed request headers, augmenting defaults.
    38  func OptRequestAddDisallowedHeaders(headers ...string) RequestOption {
    39  	return func(ro *RequestSanitizer) {
    40  		ro.DisallowedHeaders = append(ro.DisallowedHeaders, headers...)
    41  	}
    42  }
    43  
    44  // OptRequestSetDisallowedHeaders sets the disallowed request headers, overwriting defaults.
    45  func OptRequestSetDisallowedHeaders(headers ...string) RequestOption {
    46  	return func(ro *RequestSanitizer) {
    47  		ro.DisallowedHeaders = headers
    48  	}
    49  }
    50  
    51  // OptRequestAddDisallowedQueryParams adds disallowed request query params, augmenting defaults.
    52  func OptRequestAddDisallowedQueryParams(queryParams ...string) RequestOption {
    53  	return func(rs *RequestSanitizer) {
    54  		rs.DisallowedQueryParams = append(rs.DisallowedQueryParams, queryParams...)
    55  	}
    56  }
    57  
    58  // OptRequestSetDisallowedQueryParams sets the disallowed request query params, overwriting defaults.
    59  func OptRequestSetDisallowedQueryParams(queryParams ...string) RequestOption {
    60  	return func(rs *RequestSanitizer) {
    61  		rs.DisallowedQueryParams = queryParams
    62  	}
    63  }
    64  
    65  // OptRequestKeyValuesSanitizer sets the request key values sanitizer.
    66  func OptRequestKeyValuesSanitizer(valueSanitizer KeyValuesSanitizer) RequestOption {
    67  	return func(rs *RequestSanitizer) {
    68  		rs.KeyValuesSanitizer = valueSanitizer
    69  	}
    70  }
    71  
    72  // OptRequestPathSanitizer sets the request path sanitizer.
    73  func OptRequestPathSanitizer(pathSanitizer PathSanitizer) RequestOption {
    74  	return func(rs *RequestSanitizer) {
    75  		rs.PathSanitizer = pathSanitizer
    76  	}
    77  }
    78  
// RequestSanitizer holds the options used to sanitize http requests.
//
// NewRequestSanitizer populates every field with the package defaults before
// applying any RequestOption.
type RequestSanitizer struct {
	// DisallowedHeaders lists header names whose values Sanitize rewrites via
	// KeyValuesSanitizer; matching is case-insensitive (see IsHeaderDisallowed).
	DisallowedHeaders     []string
	// DisallowedQueryParams lists query parameter names whose values Sanitize
	// rewrites via KeyValuesSanitizer; matching is case-insensitive (see
	// IsQueryParamDisallowed).
	DisallowedQueryParams []string
	// KeyValuesSanitizer receives a disallowed key and its values and returns
	// the replacement values.
	KeyValuesSanitizer    KeyValuesSanitizer
	// PathSanitizer is applied to the request URL's Path.
	PathSanitizer         PathSanitizer
}
    86  
    87  // Sanitize applies sanitization options to a given request.
    88  func (rs RequestSanitizer) Sanitize(r *http.Request) *http.Request {
    89  	if r == nil {
    90  		return nil
    91  	}
    92  
    93  	copy := r.Clone(r.Context())
    94  	for header, values := range copy.Header {
    95  		if rs.IsHeaderDisallowed(header) {
    96  			copy.Header[header] = rs.KeyValuesSanitizer.SanitizeKeyValues(header, values...)
    97  		}
    98  	}
    99  	if copy.URL != nil {
   100  		queryParams := copy.URL.Query()
   101  		for queryParam, values := range queryParams {
   102  			if rs.IsQueryParamDisallowed(queryParam) {
   103  				queryParams[queryParam] = rs.KeyValuesSanitizer.SanitizeKeyValues(queryParam, values...)
   104  			}
   105  		}
   106  		copy.URL.RawQuery = queryParams.Encode()
   107  
   108  		// also sanitize the path
   109  		copy.URL.Path = rs.PathSanitizer.SanitizePath(copy.URL.Path)
   110  	}
   111  	return copy
   112  }
   113  
   114  // IsHeaderDisallowed returns if a header is in the disallowed list.
   115  func (rs RequestSanitizer) IsHeaderDisallowed(header string) bool {
   116  	for _, disallowedHeader := range rs.DisallowedHeaders {
   117  		if strings.EqualFold(disallowedHeader, header) {
   118  			return true
   119  		}
   120  	}
   121  	return false
   122  }
   123  
   124  // IsQueryParamDisallowed returns if a query param is in the disallowed list.
   125  func (rs RequestSanitizer) IsQueryParamDisallowed(queryParam string) bool {
   126  	for _, disallowedQueryParam := range rs.DisallowedQueryParams {
   127  		if strings.EqualFold(disallowedQueryParam, queryParam) {
   128  			return true
   129  		}
   130  	}
   131  	return false
   132  }