github.com/cloudwego/hertz@v0.9.3/pkg/protocol/trailer.go (about)

     1  /*
     2   * Copyright 2023 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package protocol
    18  
    19  import (
    20  	"bytes"
    21  
    22  	"github.com/cloudwego/hertz/internal/bytestr"
    23  	errs "github.com/cloudwego/hertz/pkg/common/errors"
    24  	"github.com/cloudwego/hertz/pkg/common/utils"
    25  	"github.com/cloudwego/hertz/pkg/protocol/consts"
    26  )
    27  
// Trailer holds the trailer fields of a message together with the scratch
// state its methods reuse between calls.
type Trailer struct {
	// h stores the trailer entries as key/value pairs.
	h                  []argsKV
	// bufKV is scratch space reused by the key/value helpers
	// (getHeaderKeyBytes, initHeaderKV, SetTrailers) and by Header, so the
	// slices it holds are overwritten by later calls.
	bufKV              argsKV
	// disableNormalizing is forwarded to the key helpers and to
	// utils.NormalizeHeaderKey; when true, keys are expected to be kept as
	// given (TODO confirm exact normalization rules in utils).
	disableNormalizing bool
}
    33  
    34  // Get returns trailer value for the given key.
    35  func (t *Trailer) Get(key string) string {
    36  	return string(t.Peek(key))
    37  }
    38  
    39  // Peek returns trailer value for the given key.
    40  //
    41  // Returned value is valid until the next call to Trailer.
    42  // Do not store references to returned value. Make copies instead.
    43  func (t *Trailer) Peek(key string) []byte {
    44  	k := getHeaderKeyBytes(&t.bufKV, key, t.disableNormalizing)
    45  	return peekArgBytes(t.h, k)
    46  }
    47  
    48  // Del deletes trailer with the given key.
    49  func (t *Trailer) Del(key string) {
    50  	k := getHeaderKeyBytes(&t.bufKV, key, t.disableNormalizing)
    51  	t.h = delAllArgsBytes(t.h, k)
    52  }
    53  
    54  // VisitAll calls f for each header.
    55  func (t *Trailer) VisitAll(f func(key, value []byte)) {
    56  	visitArgs(t.h, f)
    57  }
    58  
    59  // Set sets the given 'key: value' trailer.
    60  //
    61  // If the key is forbidden by RFC 7230, section 4.1.2, Set will return error
    62  func (t *Trailer) Set(key, value string) error {
    63  	initHeaderKV(&t.bufKV, key, value, t.disableNormalizing)
    64  	return t.setArgBytes(t.bufKV.key, t.bufKV.value, ArgsHasValue)
    65  }
    66  
    67  // Add adds the given 'key: value' trailer.
    68  //
    69  // Multiple headers with the same key may be added with this function.
    70  // Use Set for setting a single header for the given key.
    71  //
    72  // If the key is forbidden by RFC 7230, section 4.1.2, Add will return error
    73  func (t *Trailer) Add(key, value string) error {
    74  	initHeaderKV(&t.bufKV, key, value, t.disableNormalizing)
    75  	return t.addArgBytes(t.bufKV.key, t.bufKV.value, ArgsHasValue)
    76  }
    77  
    78  func (t *Trailer) addArgBytes(key, value []byte, noValue bool) error {
    79  	if IsBadTrailer(key) {
    80  		return errs.NewPublicf("forbidden trailer key: %q", key)
    81  	}
    82  	t.h = appendArgBytes(t.h, key, value, noValue)
    83  	return nil
    84  }
    85  
    86  func (t *Trailer) setArgBytes(key, value []byte, noValue bool) error {
    87  	if IsBadTrailer(key) {
    88  		return errs.NewPublicf("forbidden trailer key: %q", key)
    89  	}
    90  	t.h = setArgBytes(t.h, key, value, noValue)
    91  	return nil
    92  }
    93  
    94  func (t *Trailer) UpdateArgBytes(key, value []byte) error {
    95  	if IsBadTrailer(key) {
    96  		return errs.NewPublicf("forbidden trailer key: %q", key)
    97  	}
    98  
    99  	t.h = updateArgBytes(t.h, key, value)
   100  	return nil
   101  }
   102  
   103  func (t *Trailer) GetTrailers() []argsKV {
   104  	return t.h
   105  }
   106  
   107  func (t *Trailer) Empty() bool {
   108  	return len(t.h) == 0
   109  }
   110  
   111  // GetBytes return the 'Trailer' Header which is composed by the Trailer key
   112  func (t *Trailer) GetBytes() []byte {
   113  	var dst []byte
   114  	for i, n := 0, len(t.h); i < n; i++ {
   115  		kv := &t.h[i]
   116  		dst = append(dst, kv.key...)
   117  		if i+1 < n {
   118  			dst = append(dst, bytestr.StrCommaSpace...)
   119  		}
   120  	}
   121  	return dst
   122  }
   123  
   124  func (t *Trailer) ResetSkipNormalize() {
   125  	t.h = t.h[:0]
   126  }
   127  
   128  func (t *Trailer) Reset() {
   129  	t.disableNormalizing = false
   130  	t.ResetSkipNormalize()
   131  }
   132  
   133  func (t *Trailer) DisableNormalizing() {
   134  	t.disableNormalizing = true
   135  }
   136  
   137  func (t *Trailer) IsDisableNormalizing() bool {
   138  	return t.disableNormalizing
   139  }
   140  
   141  // CopyTo copies all the trailer to dst.
   142  func (t *Trailer) CopyTo(dst *Trailer) {
   143  	dst.Reset()
   144  
   145  	dst.disableNormalizing = t.disableNormalizing
   146  	dst.h = copyArgs(dst.h, t.h)
   147  }
   148  
   149  func (t *Trailer) SetTrailers(trailers []byte) (err error) {
   150  	t.ResetSkipNormalize()
   151  	for i := -1; i+1 < len(trailers); {
   152  		trailers = trailers[i+1:]
   153  		i = bytes.IndexByte(trailers, ',')
   154  		if i < 0 {
   155  			i = len(trailers)
   156  		}
   157  		trailerKey := trailers[:i]
   158  		for len(trailerKey) > 0 && trailerKey[0] == ' ' {
   159  			trailerKey = trailerKey[1:]
   160  		}
   161  		for len(trailerKey) > 0 && trailerKey[len(trailerKey)-1] == ' ' {
   162  			trailerKey = trailerKey[:len(trailerKey)-1]
   163  		}
   164  
   165  		utils.NormalizeHeaderKey(trailerKey, t.disableNormalizing)
   166  		err = t.addArgBytes(trailerKey, nilByteSlice, argsNoValue)
   167  	}
   168  	return
   169  }
   170  
   171  func (t *Trailer) Header() []byte {
   172  	t.bufKV.value = t.AppendBytes(t.bufKV.value[:0])
   173  	return t.bufKV.value
   174  }
   175  
   176  func (t *Trailer) AppendBytes(dst []byte) []byte {
   177  	for i, n := 0, len(t.h); i < n; i++ {
   178  		kv := &t.h[i]
   179  		dst = appendHeaderLine(dst, kv.key, kv.value)
   180  	}
   181  
   182  	dst = append(dst, bytestr.StrCRLF...)
   183  	return dst
   184  }
   185  
   186  func IsBadTrailer(key []byte) bool {
   187  	switch key[0] | 0x20 {
   188  	case 'a':
   189  		return utils.CaseInsensitiveCompare(key, bytestr.StrAuthorization)
   190  	case 'c':
   191  		if len(key) >= len(consts.HeaderContentType) && utils.CaseInsensitiveCompare(key[:8], bytestr.StrContentType[:8]) {
   192  			// skip compare prefix 'Content-'
   193  			return utils.CaseInsensitiveCompare(key[8:], bytestr.StrContentEncoding[8:]) ||
   194  				utils.CaseInsensitiveCompare(key[8:], bytestr.StrContentLength[8:]) ||
   195  				utils.CaseInsensitiveCompare(key[8:], bytestr.StrContentType[8:]) ||
   196  				utils.CaseInsensitiveCompare(key[8:], bytestr.StrContentRange[8:])
   197  		}
   198  		return utils.CaseInsensitiveCompare(key, bytestr.StrConnection)
   199  	case 'e':
   200  		return utils.CaseInsensitiveCompare(key, bytestr.StrExpect)
   201  	case 'h':
   202  		return utils.CaseInsensitiveCompare(key, bytestr.StrHost)
   203  	case 'k':
   204  		return utils.CaseInsensitiveCompare(key, bytestr.StrKeepAlive)
   205  	case 'm':
   206  		return utils.CaseInsensitiveCompare(key, bytestr.StrMaxForwards)
   207  	case 'p':
   208  		if len(key) >= len(consts.HeaderProxyConnection) && utils.CaseInsensitiveCompare(key[:6], bytestr.StrProxyConnection[:6]) {
   209  			// skip compare prefix 'Proxy-'
   210  			return utils.CaseInsensitiveCompare(key[6:], bytestr.StrProxyConnection[6:]) ||
   211  				utils.CaseInsensitiveCompare(key[6:], bytestr.StrProxyAuthenticate[6:]) ||
   212  				utils.CaseInsensitiveCompare(key[6:], bytestr.StrProxyAuthorization[6:])
   213  		}
   214  	case 'r':
   215  		return utils.CaseInsensitiveCompare(key, bytestr.StrRange)
   216  	case 't':
   217  		return utils.CaseInsensitiveCompare(key, bytestr.StrTE) ||
   218  			utils.CaseInsensitiveCompare(key, bytestr.StrTrailer) ||
   219  			utils.CaseInsensitiveCompare(key, bytestr.StrTransferEncoding)
   220  	case 'w':
   221  		return utils.CaseInsensitiveCompare(key, bytestr.StrWWWAuthenticate)
   222  	}
   223  	return false
   224  }