github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/net/http2/hpack/huffman.go (about) 1 // Copyright 2014 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package hpack 6 7 import ( 8 "bytes" 9 "errors" 10 "io" 11 "sync" 12 ) 13 14 var bufPool = sync.Pool{ 15 New: func() interface{} { return new(bytes.Buffer) }, 16 } 17 18 // HuffmanDecode decodes the string in v and writes the expanded 19 // result to w, returning the number of bytes written to w and the 20 // Write call's return value. At most one Write call is made. 21 func HuffmanDecode(w io.Writer, v []byte) (int, error) { 22 buf := bufPool.Get().(*bytes.Buffer) 23 buf.Reset() 24 defer bufPool.Put(buf) 25 if err := huffmanDecode(buf, 0, v); err != nil { 26 return 0, err 27 } 28 return w.Write(buf.Bytes()) 29 } 30 31 // HuffmanDecodeToString decodes the string in v. 32 func HuffmanDecodeToString(v []byte) (string, error) { 33 buf := bufPool.Get().(*bytes.Buffer) 34 buf.Reset() 35 defer bufPool.Put(buf) 36 if err := huffmanDecode(buf, 0, v); err != nil { 37 return "", err 38 } 39 return buf.String(), nil 40 } 41 42 // ErrInvalidHuffman is returned for errors found decoding 43 // Huffman-encoded strings. 44 var ErrInvalidHuffman = errors.New("hpack: invalid Huffman-encoded data") 45 46 // huffmanDecode decodes v to buf. 47 // If maxLen is greater than 0, attempts to write more to buf than 48 // maxLen bytes will return ErrStringLength. 49 func huffmanDecode(buf *bytes.Buffer, maxLen int, v []byte) error { 50 rootHuffmanNode := getRootHuffmanNode() 51 n := rootHuffmanNode 52 // cur is the bit buffer that has not been fed into n. 53 // cbits is the number of low order bits in cur that are valid. 54 // sbits is the number of bits of the symbol prefix being decoded. 55 cur, cbits, sbits := uint(0), uint8(0), uint8(0) 56 for _, b := range v { 57 cur = cur<<8 | uint(b) 58 cbits += 8 59 sbits += 8 60 for cbits >= 8 { 61 idx := byte(cur >> (cbits - 8)) 62 n = n.children[idx] 63 if n == nil { 64 return ErrInvalidHuffman 65 } 66 if n.children == nil { 67 if maxLen != 0 && buf.Len() == maxLen { 68 return ErrStringLength 69 } 70 buf.WriteByte(n.sym) 71 cbits -= n.codeLen 72 n = rootHuffmanNode 73 sbits = cbits 74 } else { 75 cbits -= 8 76 } 77 } 78 } 79 for cbits > 0 { 80 n = n.children[byte(cur<<(8-cbits))] 81 if n == nil { 82 return ErrInvalidHuffman 83 } 84 if n.children != nil || n.codeLen > cbits { 85 break 86 } 87 if maxLen != 0 && buf.Len() == maxLen { 88 return ErrStringLength 89 } 90 buf.WriteByte(n.sym) 91 cbits -= n.codeLen 92 n = rootHuffmanNode 93 sbits = cbits 94 } 95 if sbits > 7 { 96 // Either there was an incomplete symbol, or overlong padding. 97 // Both are decoding errors per RFC 7541 section 5.2. 98 return ErrInvalidHuffman 99 } 100 if mask := uint(1<<cbits - 1); cur&mask != mask { 101 // Trailing bits must be a prefix of EOS per RFC 7541 section 5.2. 102 return ErrInvalidHuffman 103 } 104 105 return nil 106 } 107 108 // incomparable is a zero-width, non-comparable type. Adding it to a struct 109 // makes that struct also non-comparable, and generally doesn't add 110 // any size (as long as it's first). 111 type incomparable [0]func() 112 113 type node struct { 114 _ incomparable 115 116 // children is non-nil for internal nodes 117 children *[256]*node 118 119 // The following are only valid if children is nil: 120 codeLen uint8 // number of bits that led to the output of sym 121 sym byte // output symbol 122 } 123 124 func newInternalNode() *node { 125 return &node{children: new([256]*node)} 126 } 127 128 var ( 129 buildRootOnce sync.Once 130 lazyRootHuffmanNode *node 131 ) 132 133 func getRootHuffmanNode() *node { 134 buildRootOnce.Do(buildRootHuffmanNode) 135 return lazyRootHuffmanNode 136 } 137 138 func buildRootHuffmanNode() { 139 if len(huffmanCodes) != 256 { 140 panic("unexpected size") 141 } 142 lazyRootHuffmanNode = newInternalNode() 143 // allocate a leaf node for each of the 256 symbols 144 leaves := new([256]node) 145 146 for sym, code := range huffmanCodes { 147 codeLen := huffmanCodeLen[sym] 148 149 cur := lazyRootHuffmanNode 150 for codeLen > 8 { 151 codeLen -= 8 152 i := uint8(code >> codeLen) 153 if cur.children[i] == nil { 154 cur.children[i] = newInternalNode() 155 } 156 cur = cur.children[i] 157 } 158 shift := 8 - codeLen 159 start, end := int(uint8(code<<shift)), int(1<<shift) 160 161 leaves[sym].sym = byte(sym) 162 leaves[sym].codeLen = codeLen 163 for i := start; i < start+end; i++ { 164 cur.children[i] = &leaves[sym] 165 } 166 } 167 } 168 169 // AppendHuffmanString appends s, as encoded in Huffman codes, to dst 170 // and returns the extended buffer. 171 func AppendHuffmanString(dst []byte, s string) []byte { 172 rembits := uint8(8) 173 174 for i := 0; i < len(s); i++ { 175 if rembits == 8 { 176 dst = append(dst, 0) 177 } 178 dst, rembits = appendByteToHuffmanCode(dst, rembits, s[i]) 179 } 180 181 if rembits < 8 { 182 // special EOS symbol 183 code := uint32(0x3fffffff) 184 nbits := uint8(30) 185 186 t := uint8(code >> (nbits - rembits)) 187 dst[len(dst)-1] |= t 188 } 189 190 return dst 191 } 192 193 // HuffmanEncodeLength returns the number of bytes required to encode 194 // s in Huffman codes. The result is round up to byte boundary. 195 func HuffmanEncodeLength(s string) uint64 { 196 n := uint64(0) 197 for i := 0; i < len(s); i++ { 198 n += uint64(huffmanCodeLen[s[i]]) 199 } 200 return (n + 7) / 8 201 } 202 203 // appendByteToHuffmanCode appends Huffman code for c to dst and 204 // returns the extended buffer and the remaining bits in the last 205 // element. The appending is not byte aligned and the remaining bits 206 // in the last element of dst is given in rembits. 207 func appendByteToHuffmanCode(dst []byte, rembits uint8, c byte) ([]byte, uint8) { 208 code := huffmanCodes[c] 209 nbits := huffmanCodeLen[c] 210 211 for { 212 if rembits > nbits { 213 t := uint8(code << (rembits - nbits)) 214 dst[len(dst)-1] |= t 215 rembits -= nbits 216 break 217 } 218 219 t := uint8(code >> (nbits - rembits)) 220 dst[len(dst)-1] |= t 221 222 nbits -= rembits 223 rembits = 8 224 225 if nbits == 0 { 226 break 227 } 228 229 dst = append(dst, 0) 230 } 231 232 return dst, rembits 233 }