github.com/cloudwego/hertz@v0.9.3/pkg/protocol/http1/req/header.go (about)

     1  /*
     2   * Copyright 2022 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   *
    16   * The MIT License (MIT)
    17   *
    18   * Copyright (c) 2015-present Aliaksandr Valialkin, VertaMedia, Kirill Danshin, Erik Dubbelboer, FastHTTP Authors
    19   *
    20   * Permission is hereby granted, free of charge, to any person obtaining a copy
    21   * of this software and associated documentation files (the "Software"), to deal
    22   * in the Software without restriction, including without limitation the rights
    23   * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    24   * copies of the Software, and to permit persons to whom the Software is
    25   * furnished to do so, subject to the following conditions:
    26   *
    27   * The above copyright notice and this permission notice shall be included in
    28   * all copies or substantial portions of the Software.
    29   *
    30   * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    31   * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    32   * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    33   * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    34   * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    35   * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    36   * THE SOFTWARE.
    37   *
    38   * This file may have been modified by CloudWeGo authors. All CloudWeGo
    39   * Modifications are Copyright 2022 CloudWeGo Authors.
    40   */
    41  
    42  package req
    43  
    44  import (
    45  	"bytes"
    46  	"errors"
    47  	"fmt"
    48  	"io"
    49  
    50  	"github.com/cloudwego/hertz/internal/bytesconv"
    51  	"github.com/cloudwego/hertz/internal/bytestr"
    52  	errs "github.com/cloudwego/hertz/pkg/common/errors"
    53  	"github.com/cloudwego/hertz/pkg/common/utils"
    54  	"github.com/cloudwego/hertz/pkg/network"
    55  	"github.com/cloudwego/hertz/pkg/protocol"
    56  	"github.com/cloudwego/hertz/pkg/protocol/consts"
    57  	"github.com/cloudwego/hertz/pkg/protocol/http1/ext"
    58  )
    59  
    60  var errEOFReadHeader = errs.NewPublic("error when reading request headers: EOF")
    61  
    62  // Write writes request header to w.
    63  func WriteHeader(h *protocol.RequestHeader, w network.Writer) error {
    64  	header := h.Header()
    65  	_, err := w.WriteBinary(header)
    66  	return err
    67  }
    68  
    69  func ReadHeader(h *protocol.RequestHeader, r network.Reader) error {
    70  	n := 1
    71  	for {
    72  		err := tryRead(h, r, n)
    73  		if err == nil {
    74  			return nil
    75  		}
    76  		if !errors.Is(err, errs.ErrNeedMore) {
    77  			h.ResetSkipNormalize()
    78  			return err
    79  		}
    80  
    81  		// No more data available on the wire, try block peek
    82  		if n == r.Len() {
    83  			n++
    84  			continue
    85  		}
    86  		n = r.Len()
    87  	}
    88  }
    89  
    90  func tryRead(h *protocol.RequestHeader, r network.Reader, n int) error {
    91  	h.ResetSkipNormalize()
    92  	b, err := r.Peek(n)
    93  	if len(b) == 0 {
    94  		if err != io.EOF {
    95  			return err
    96  		}
    97  
    98  		// n == 1 on the first read for the request.
    99  		if n == 1 {
   100  			// We didn't read a single byte.
   101  			return errs.New(errs.ErrNothingRead, errs.ErrorTypePrivate, err)
   102  		}
   103  
   104  		return errEOFReadHeader
   105  	}
   106  	b = ext.MustPeekBuffered(r)
   107  	headersLen, errParse := parse(h, b)
   108  	if errParse != nil {
   109  		return ext.HeaderError("request", err, errParse, b)
   110  	}
   111  	ext.MustDiscard(r, headersLen)
   112  	return nil
   113  }
   114  
   115  func parse(h *protocol.RequestHeader, buf []byte) (int, error) {
   116  	m, err := parseFirstLine(h, buf)
   117  	if err != nil {
   118  		return 0, err
   119  	}
   120  
   121  	rawHeaders, _, err := ext.ReadRawHeaders(h.RawHeaders()[:0], buf[m:])
   122  	h.SetRawHeaders(rawHeaders)
   123  	if err != nil {
   124  		return 0, err
   125  	}
   126  	var n int
   127  	n, err = parseHeaders(h, buf[m:])
   128  	if err != nil {
   129  		return 0, err
   130  	}
   131  	return m + n, nil
   132  }
   133  
   134  func parseFirstLine(h *protocol.RequestHeader, buf []byte) (int, error) {
   135  	bNext := buf
   136  	var b []byte
   137  	var err error
   138  	for len(b) == 0 {
   139  		if b, bNext, err = utils.NextLine(bNext); err != nil {
   140  			return 0, err
   141  		}
   142  	}
   143  
   144  	// parse method
   145  	n := bytes.IndexByte(b, ' ')
   146  	if n <= 0 {
   147  		return 0, fmt.Errorf("cannot find http request method in %q", ext.BufferSnippet(buf))
   148  	}
   149  	h.SetMethodBytes(b[:n])
   150  	b = b[n+1:]
   151  
   152  	// Set default protocol
   153  	h.SetProtocol(consts.HTTP11)
   154  	// parse requestURI
   155  	n = bytes.LastIndexByte(b, ' ')
   156  	if n < 0 {
   157  		h.SetProtocol(consts.HTTP10)
   158  		n = len(b)
   159  	} else if n == 0 {
   160  		return 0, fmt.Errorf("requestURI cannot be empty in %q", buf)
   161  	} else if !bytes.Equal(b[n+1:], bytestr.StrHTTP11) {
   162  		h.SetProtocol(consts.HTTP10)
   163  	}
   164  	h.SetRequestURIBytes(b[:n])
   165  
   166  	return len(buf) - len(bNext), nil
   167  }
   168  
   169  // validHeaderFieldValue is equal to httpguts.ValidHeaderFieldValue(shares the same context)
   170  func validHeaderFieldValue(val []byte) bool {
   171  	for _, v := range val {
   172  		if bytesconv.ValidHeaderFieldValueTable[v] == 0 {
   173  			return false
   174  		}
   175  	}
   176  	return true
   177  }
   178  
   179  func parseHeaders(h *protocol.RequestHeader, buf []byte) (int, error) {
   180  	h.InitContentLengthWithValue(-2)
   181  
   182  	var s ext.HeaderScanner
   183  	s.B = buf
   184  	s.DisableNormalizing = h.IsDisableNormalizing()
   185  	var err error
   186  	for s.Next() {
   187  		if len(s.Key) > 0 {
   188  			// Spaces between the header key and colon are not allowed.
   189  			// See RFC 7230, Section 3.2.4.
   190  			if bytes.IndexByte(s.Key, ' ') != -1 || bytes.IndexByte(s.Key, '\t') != -1 {
   191  				err = fmt.Errorf("invalid header key %q", s.Key)
   192  				return 0, err
   193  			}
   194  
   195  			// Check the invalid chars in header value
   196  			if !validHeaderFieldValue(s.Value) {
   197  				err = fmt.Errorf("invalid header value %q", s.Value)
   198  				return 0, err
   199  			}
   200  
   201  			switch s.Key[0] | 0x20 {
   202  			case 'h':
   203  				if utils.CaseInsensitiveCompare(s.Key, bytestr.StrHost) {
   204  					h.SetHostBytes(s.Value)
   205  					continue
   206  				}
   207  			case 'u':
   208  				if utils.CaseInsensitiveCompare(s.Key, bytestr.StrUserAgent) {
   209  					h.SetUserAgentBytes(s.Value)
   210  					continue
   211  				}
   212  			case 'c':
   213  				if utils.CaseInsensitiveCompare(s.Key, bytestr.StrContentType) {
   214  					h.SetContentTypeBytes(s.Value)
   215  					continue
   216  				}
   217  				if utils.CaseInsensitiveCompare(s.Key, bytestr.StrContentLength) {
   218  					if h.ContentLength() != -1 {
   219  						var nerr error
   220  						var contentLength int
   221  						if contentLength, nerr = protocol.ParseContentLength(s.Value); nerr != nil {
   222  							if err == nil {
   223  								err = nerr
   224  							}
   225  							h.InitContentLengthWithValue(-2)
   226  						} else {
   227  							h.InitContentLengthWithValue(contentLength)
   228  							h.SetContentLengthBytes(s.Value)
   229  						}
   230  					}
   231  					continue
   232  				}
   233  				if utils.CaseInsensitiveCompare(s.Key, bytestr.StrConnection) {
   234  					if bytes.Equal(s.Value, bytestr.StrClose) {
   235  						h.SetConnectionClose(true)
   236  					} else {
   237  						h.SetConnectionClose(false)
   238  						h.AddArgBytes(s.Key, s.Value, protocol.ArgsHasValue)
   239  					}
   240  					continue
   241  				}
   242  			case 't':
   243  				if utils.CaseInsensitiveCompare(s.Key, bytestr.StrTransferEncoding) {
   244  					if !bytes.Equal(s.Value, bytestr.StrIdentity) {
   245  						h.InitContentLengthWithValue(-1)
   246  						h.SetArgBytes(bytestr.StrTransferEncoding, bytestr.StrChunked, protocol.ArgsHasValue)
   247  					}
   248  					continue
   249  				}
   250  				if utils.CaseInsensitiveCompare(s.Key, bytestr.StrTrailer) {
   251  					if nerr := h.Trailer().SetTrailers(s.Value); nerr != nil {
   252  						if err == nil {
   253  							err = nerr
   254  						}
   255  					}
   256  					continue
   257  				}
   258  			}
   259  		}
   260  		h.AddArgBytes(s.Key, s.Value, protocol.ArgsHasValue)
   261  	}
   262  
   263  	if s.Err != nil && err == nil {
   264  		err = s.Err
   265  	}
   266  	if err != nil {
   267  		h.SetConnectionClose(true)
   268  		return 0, err
   269  	}
   270  
   271  	if h.ContentLength() < 0 {
   272  		h.SetContentLengthBytes(h.ContentLengthBytes()[:0])
   273  	}
   274  	if !h.IsHTTP11() && !h.ConnectionClose() {
   275  		// close connection for non-http/1.1 request unless 'Connection: keep-alive' is set.
   276  		v := h.PeekArgBytes(bytestr.StrConnection)
   277  		h.SetConnectionClose(!ext.HasHeaderValue(v, bytestr.StrKeepAlive))
   278  	}
   279  	return s.HLen, nil
   280  }