github.com/cloudwego/hertz@v0.9.3/pkg/protocol/trailer.go (about) 1 /* 2 * Copyright 2023 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package protocol 18 19 import ( 20 "bytes" 21 22 "github.com/cloudwego/hertz/internal/bytestr" 23 errs "github.com/cloudwego/hertz/pkg/common/errors" 24 "github.com/cloudwego/hertz/pkg/common/utils" 25 "github.com/cloudwego/hertz/pkg/protocol/consts" 26 ) 27 28 type Trailer struct { 29 h []argsKV 30 bufKV argsKV 31 disableNormalizing bool 32 } 33 34 // Get returns trailer value for the given key. 35 func (t *Trailer) Get(key string) string { 36 return string(t.Peek(key)) 37 } 38 39 // Peek returns trailer value for the given key. 40 // 41 // Returned value is valid until the next call to Trailer. 42 // Do not store references to returned value. Make copies instead. 43 func (t *Trailer) Peek(key string) []byte { 44 k := getHeaderKeyBytes(&t.bufKV, key, t.disableNormalizing) 45 return peekArgBytes(t.h, k) 46 } 47 48 // Del deletes trailer with the given key. 49 func (t *Trailer) Del(key string) { 50 k := getHeaderKeyBytes(&t.bufKV, key, t.disableNormalizing) 51 t.h = delAllArgsBytes(t.h, k) 52 } 53 54 // VisitAll calls f for each header. 55 func (t *Trailer) VisitAll(f func(key, value []byte)) { 56 visitArgs(t.h, f) 57 } 58 59 // Set sets the given 'key: value' trailer. 60 // 61 // If the key is forbidden by RFC 7230, section 4.1.2, Set will return error 62 func (t *Trailer) Set(key, value string) error { 63 initHeaderKV(&t.bufKV, key, value, t.disableNormalizing) 64 return t.setArgBytes(t.bufKV.key, t.bufKV.value, ArgsHasValue) 65 } 66 67 // Add adds the given 'key: value' trailer. 68 // 69 // Multiple headers with the same key may be added with this function. 70 // Use Set for setting a single header for the given key. 71 // 72 // If the key is forbidden by RFC 7230, section 4.1.2, Add will return error 73 func (t *Trailer) Add(key, value string) error { 74 initHeaderKV(&t.bufKV, key, value, t.disableNormalizing) 75 return t.addArgBytes(t.bufKV.key, t.bufKV.value, ArgsHasValue) 76 } 77 78 func (t *Trailer) addArgBytes(key, value []byte, noValue bool) error { 79 if IsBadTrailer(key) { 80 return errs.NewPublicf("forbidden trailer key: %q", key) 81 } 82 t.h = appendArgBytes(t.h, key, value, noValue) 83 return nil 84 } 85 86 func (t *Trailer) setArgBytes(key, value []byte, noValue bool) error { 87 if IsBadTrailer(key) { 88 return errs.NewPublicf("forbidden trailer key: %q", key) 89 } 90 t.h = setArgBytes(t.h, key, value, noValue) 91 return nil 92 } 93 94 func (t *Trailer) UpdateArgBytes(key, value []byte) error { 95 if IsBadTrailer(key) { 96 return errs.NewPublicf("forbidden trailer key: %q", key) 97 } 98 99 t.h = updateArgBytes(t.h, key, value) 100 return nil 101 } 102 103 func (t *Trailer) GetTrailers() []argsKV { 104 return t.h 105 } 106 107 func (t *Trailer) Empty() bool { 108 return len(t.h) == 0 109 } 110 111 // GetBytes return the 'Trailer' Header which is composed by the Trailer key 112 func (t *Trailer) GetBytes() []byte { 113 var dst []byte 114 for i, n := 0, len(t.h); i < n; i++ { 115 kv := &t.h[i] 116 dst = append(dst, kv.key...) 117 if i+1 < n { 118 dst = append(dst, bytestr.StrCommaSpace...) 119 } 120 } 121 return dst 122 } 123 124 func (t *Trailer) ResetSkipNormalize() { 125 t.h = t.h[:0] 126 } 127 128 func (t *Trailer) Reset() { 129 t.disableNormalizing = false 130 t.ResetSkipNormalize() 131 } 132 133 func (t *Trailer) DisableNormalizing() { 134 t.disableNormalizing = true 135 } 136 137 func (t *Trailer) IsDisableNormalizing() bool { 138 return t.disableNormalizing 139 } 140 141 // CopyTo copies all the trailer to dst. 142 func (t *Trailer) CopyTo(dst *Trailer) { 143 dst.Reset() 144 145 dst.disableNormalizing = t.disableNormalizing 146 dst.h = copyArgs(dst.h, t.h) 147 } 148 149 func (t *Trailer) SetTrailers(trailers []byte) (err error) { 150 t.ResetSkipNormalize() 151 for i := -1; i+1 < len(trailers); { 152 trailers = trailers[i+1:] 153 i = bytes.IndexByte(trailers, ',') 154 if i < 0 { 155 i = len(trailers) 156 } 157 trailerKey := trailers[:i] 158 for len(trailerKey) > 0 && trailerKey[0] == ' ' { 159 trailerKey = trailerKey[1:] 160 } 161 for len(trailerKey) > 0 && trailerKey[len(trailerKey)-1] == ' ' { 162 trailerKey = trailerKey[:len(trailerKey)-1] 163 } 164 165 utils.NormalizeHeaderKey(trailerKey, t.disableNormalizing) 166 err = t.addArgBytes(trailerKey, nilByteSlice, argsNoValue) 167 } 168 return 169 } 170 171 func (t *Trailer) Header() []byte { 172 t.bufKV.value = t.AppendBytes(t.bufKV.value[:0]) 173 return t.bufKV.value 174 } 175 176 func (t *Trailer) AppendBytes(dst []byte) []byte { 177 for i, n := 0, len(t.h); i < n; i++ { 178 kv := &t.h[i] 179 dst = appendHeaderLine(dst, kv.key, kv.value) 180 } 181 182 dst = append(dst, bytestr.StrCRLF...) 183 return dst 184 } 185 186 func IsBadTrailer(key []byte) bool { 187 switch key[0] | 0x20 { 188 case 'a': 189 return utils.CaseInsensitiveCompare(key, bytestr.StrAuthorization) 190 case 'c': 191 if len(key) >= len(consts.HeaderContentType) && utils.CaseInsensitiveCompare(key[:8], bytestr.StrContentType[:8]) { 192 // skip compare prefix 'Content-' 193 return utils.CaseInsensitiveCompare(key[8:], bytestr.StrContentEncoding[8:]) || 194 utils.CaseInsensitiveCompare(key[8:], bytestr.StrContentLength[8:]) || 195 utils.CaseInsensitiveCompare(key[8:], bytestr.StrContentType[8:]) || 196 utils.CaseInsensitiveCompare(key[8:], bytestr.StrContentRange[8:]) 197 } 198 return utils.CaseInsensitiveCompare(key, bytestr.StrConnection) 199 case 'e': 200 return utils.CaseInsensitiveCompare(key, bytestr.StrExpect) 201 case 'h': 202 return utils.CaseInsensitiveCompare(key, bytestr.StrHost) 203 case 'k': 204 return utils.CaseInsensitiveCompare(key, bytestr.StrKeepAlive) 205 case 'm': 206 return utils.CaseInsensitiveCompare(key, bytestr.StrMaxForwards) 207 case 'p': 208 if len(key) >= len(consts.HeaderProxyConnection) && utils.CaseInsensitiveCompare(key[:6], bytestr.StrProxyConnection[:6]) { 209 // skip compare prefix 'Proxy-' 210 return utils.CaseInsensitiveCompare(key[6:], bytestr.StrProxyConnection[6:]) || 211 utils.CaseInsensitiveCompare(key[6:], bytestr.StrProxyAuthenticate[6:]) || 212 utils.CaseInsensitiveCompare(key[6:], bytestr.StrProxyAuthorization[6:]) 213 } 214 case 'r': 215 return utils.CaseInsensitiveCompare(key, bytestr.StrRange) 216 case 't': 217 return utils.CaseInsensitiveCompare(key, bytestr.StrTE) || 218 utils.CaseInsensitiveCompare(key, bytestr.StrTrailer) || 219 utils.CaseInsensitiveCompare(key, bytestr.StrTransferEncoding) 220 case 'w': 221 return utils.CaseInsensitiveCompare(key, bytestr.StrWWWAuthenticate) 222 } 223 return false 224 }