github.com/google/go-safeweb@v0.0.0-20231219055052-64d8cfc90fbb/safehttp/incoming_request.go (about)

     1  // Copyright 2020 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package safehttp
    16  
    17  import (
    18  	"context"
    19  	"crypto/tls"
    20  	"fmt"
    21  	"io"
    22  	"net/http"
    23  	"net/url"
    24  	"strings"
    25  	"sync"
    26  )
    27  
    28  // IncomingRequest represents an HTTP request received by the server.
    29  type IncomingRequest struct {
    30  	// Header is the collection of HTTP headers.
    31  	//
    32  	// The Host header is removed from this struct and can be retrieved using Host()
    33  	Header Header
    34  	// TLS is set just like this TLS field of the net/http.Request. For more information
    35  	// see https://pkg.go.dev/net/http?tab=doc#Request.
    36  	TLS *tls.ConnectionState
    37  	req *http.Request
    38  
    39  	// The fields below are kept as pointers to allow cloning through
    40  	// IncomingRequest.WithContext. Otherwise, we'd need to copy locks.
    41  	postParseOnce      *sync.Once
    42  	multipartParseOnce *sync.Once
    43  }
    44  
    45  // NewIncomingRequest creates an IncomingRequest
    46  // from the underlying http.Request.
    47  func NewIncomingRequest(req *http.Request) *IncomingRequest {
    48  	if req == nil {
    49  		return nil
    50  	}
    51  	req = req.WithContext(context.WithValue(req.Context(),
    52  		flightValuesCtxKey{}, flightValues{m: make(map[interface{}]interface{})}))
    53  	return &IncomingRequest{
    54  		req:                req,
    55  		Header:             NewHeader(req.Header),
    56  		TLS:                req.TLS,
    57  		postParseOnce:      &sync.Once{},
    58  		multipartParseOnce: &sync.Once{},
    59  	}
    60  }
    61  
    62  // Body returns the request body reader. It is always non-nil but will return
    63  // EOF immediately when no body is present.
    64  func (r *IncomingRequest) Body() io.ReadCloser {
    65  	return r.req.Body
    66  }
    67  
    68  // Host returns the host the request is targeted to. This value comes from the
    69  // Host header.
    70  func (r *IncomingRequest) Host() string {
    71  	return r.req.Host
    72  }
    73  
    74  // Method returns the HTTP method of the IncomingRequest.
    75  func (r *IncomingRequest) Method() string {
    76  	return r.req.Method
    77  }
    78  
    79  // PostForm parses the form parameters provided in the body of a POST, PATCH or
    80  // PUT request that does not have Content-Type: multipart/form-data. It returns
    81  // the parsed form parameters as a Form object. If a parsing
    82  // error occurs it will return it, together with a nil Form. Unless we expect
    83  // the header Content-Type: multipart/form-data in a POST request, this method
    84  // should  always be used for forms in POST requests.
    85  func (r *IncomingRequest) PostForm() (*Form, error) {
    86  	var err error
    87  	r.postParseOnce.Do(func() {
    88  		if m := r.req.Method; m != MethodPost && m != MethodPatch && m != MethodPut {
    89  			err = fmt.Errorf("got request method %s, want POST/PATCH/PUT", m)
    90  			return
    91  		}
    92  
    93  		if ct := r.req.Header.Get("Content-Type"); ct != "application/x-www-form-urlencoded" {
    94  			err = fmt.Errorf("invalid method called for Content-Type: %s", ct)
    95  			return
    96  		}
    97  
    98  		err = r.req.ParseForm()
    99  	})
   100  	if err != nil {
   101  		return nil, err
   102  	}
   103  	return &Form{values: r.req.PostForm}, nil
   104  }
   105  
   106  // MultipartForm parses the form parameters provided in the body of a POST,
   107  // PATCH or PUT request that has Content-Type set to multipart/form-data. It
   108  // returns a MultipartForm object containing the parsed form parameters and
   109  // file uploads (if any) or the parsing error together with a nil MultipartForm
   110  // otherwise.
   111  //
   112  // If the parsed request body is larger than maxMemory, up to maxMemory bytes
   113  // will be stored in main memory, with the rest stored on disk in temporary
   114  // files.
   115  func (r *IncomingRequest) MultipartForm(maxMemory int64) (*MultipartForm, error) {
   116  	var err error
   117  	r.multipartParseOnce.Do(func() {
   118  		if m := r.req.Method; m != MethodPost && m != MethodPatch && m != MethodPut {
   119  			err = fmt.Errorf("got request method %s, want POST/PATCH/PUT", m)
   120  			return
   121  		}
   122  
   123  		if ct := r.req.Header.Get("Content-Type"); !strings.HasPrefix(ct, "multipart/form-data") {
   124  			err = fmt.Errorf("invalid method called for Content-Type: %s", ct)
   125  			return
   126  		}
   127  		err = r.req.ParseMultipartForm(maxMemory)
   128  	})
   129  	if err != nil {
   130  		return nil, err
   131  	}
   132  	return newMulipartForm(r.req.MultipartForm), nil
   133  }
   134  
   135  // Cookie returns the named cookie provided in the request or
   136  // net/http.ErrNoCookie if not found. If multiple cookies match the given name,
   137  // only one cookie will be returned.
   138  func (r *IncomingRequest) Cookie(name string) (*Cookie, error) {
   139  	c, err := r.req.Cookie(name)
   140  	if err != nil {
   141  		return nil, err
   142  	}
   143  	return &Cookie{wrapped: c}, nil
   144  }
   145  
   146  // Cookies parses and returns the HTTP cookies sent with the request.
   147  func (r *IncomingRequest) Cookies() []*Cookie {
   148  	cl := r.req.Cookies()
   149  	res := make([]*Cookie, 0, len(cl))
   150  	for _, c := range cl {
   151  		res = append(res, &Cookie{wrapped: c})
   152  	}
   153  	return res
   154  }
   155  
   156  // Context returns the context of a safehttp.IncomingRequest. This is always
   157  // non-nil and will default to the background context. The context of a
   158  // safehttp.IncomingRequest is the context of the underlying http.Request.
   159  //
   160  // The context is cancelled when the client's connection
   161  // closes, the request is canceled (with HTTP/2), or when the ServeHTTP method
   162  // returns.
   163  func (r *IncomingRequest) Context() context.Context {
   164  	return r.req.Context()
   165  }
   166  
   167  // WithContext returns a shallow copy of the request with its context changed to
   168  // ctx. The provided ctx must be non-nil.
   169  //
   170  // This is similar to the net/http.Request.WithContext method.
   171  func (r *IncomingRequest) WithContext(ctx context.Context) *IncomingRequest {
   172  	r2 := new(IncomingRequest)
   173  	*r2 = *r
   174  	r2.req = r2.req.WithContext(ctx)
   175  	return r2
   176  }
   177  
   178  // URL specifies the URL that is parsed from the Request-Line. For most requests,
   179  // only URL.Path() will return a non-empty result. (See RFC 7230, Section 5.3)
   180  func (r *IncomingRequest) URL() *URL {
   181  	return &URL{url: r.req.URL}
   182  }
   183  
   184  // WithStrippedURLPrefix returns a shallow copy of the request with its URL
   185  // stripped of a prefix. The prefix has to match exactly (e.g. escaped and
   186  // unescaped characters are considered different).
   187  func (r *IncomingRequest) WithStrippedURLPrefix(prefix string) (*IncomingRequest, error) {
   188  	req := rawRequest(r)
   189  	if !strings.HasPrefix(req.URL.Path, prefix) {
   190  		return nil, fmt.Errorf("Path %q doesn't have prefix %q", req.URL.Path, prefix)
   191  	}
   192  	if req.URL.RawPath != "" && !strings.HasPrefix(req.URL.RawPath, prefix) {
   193  		return nil, fmt.Errorf("RawPath %q doesn't have prefix %q", req.URL.RawPath, prefix)
   194  	}
   195  
   196  	req2 := new(http.Request)
   197  	*req2 = *req
   198  	req2.URL = new(url.URL)
   199  	*req2.URL = *req.URL
   200  	req2.URL.Path = strings.TrimPrefix(req.URL.Path, prefix)
   201  	req2.URL.RawPath = strings.TrimPrefix(req.URL.RawPath, prefix)
   202  
   203  	r2 := new(IncomingRequest)
   204  	*r2 = *r
   205  	r2.req = req2
   206  
   207  	return r2, nil
   208  }
   209  
   210  func rawRequest(r *IncomingRequest) *http.Request {
   211  	return r.req
   212  }