github.com/go-asm/go@v1.21.1-0.20240213172139-40c5ead50c48/zstd/huff.go (about) 1 // Copyright 2023 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package zstd 6 7 import ( 8 "io" 9 "math/bits" 10 ) 11 12 // maxHuffmanBits is the largest possible Huffman table bits. 13 const maxHuffmanBits = 11 14 15 // readHuff reads Huffman table from data starting at off into table. 16 // Each entry in a Huffman table is a pair of bytes. 17 // The high byte is the encoded value. The low byte is the number 18 // of bits used to encode that value. We index into the table 19 // with a value of size tableBits. A value that requires fewer bits 20 // appear in the table multiple times. 21 // This returns the number of bits in the Huffman table and the new offset. 22 // RFC 4.2.1. 23 func (r *Reader) readHuff(data block, off int, table []uint16) (tableBits, roff int, err error) { 24 if off >= len(data) { 25 return 0, 0, r.makeEOFError(off) 26 } 27 28 hdr := data[off] 29 off++ 30 31 var weights [256]uint8 32 var count int 33 if hdr < 128 { 34 // The table is compressed using an FSE. RFC 4.2.1.2. 35 if len(r.fseScratch) < 1<<6 { 36 r.fseScratch = make([]fseEntry, 1<<6) 37 } 38 fseBits, noff, err := r.readFSE(data, off, 255, 6, r.fseScratch) 39 if err != nil { 40 return 0, 0, err 41 } 42 fseTable := r.fseScratch 43 44 if off+int(hdr) > len(data) { 45 return 0, 0, r.makeEOFError(off) 46 } 47 48 rbr, err := r.makeReverseBitReader(data, off+int(hdr)-1, noff) 49 if err != nil { 50 return 0, 0, err 51 } 52 53 state1, err := rbr.val(uint8(fseBits)) 54 if err != nil { 55 return 0, 0, err 56 } 57 58 state2, err := rbr.val(uint8(fseBits)) 59 if err != nil { 60 return 0, 0, err 61 } 62 63 // There are two independent FSE streams, tracked by 64 // state1 and state2. We decode them alternately. 65 66 for { 67 pt := &fseTable[state1] 68 if !rbr.fetch(pt.bits) { 69 if count >= 254 { 70 return 0, 0, rbr.makeError("Huffman count overflow") 71 } 72 weights[count] = pt.sym 73 weights[count+1] = fseTable[state2].sym 74 count += 2 75 break 76 } 77 78 v, err := rbr.val(pt.bits) 79 if err != nil { 80 return 0, 0, err 81 } 82 state1 = uint32(pt.base) + v 83 84 if count >= 255 { 85 return 0, 0, rbr.makeError("Huffman count overflow") 86 } 87 88 weights[count] = pt.sym 89 count++ 90 91 pt = &fseTable[state2] 92 93 if !rbr.fetch(pt.bits) { 94 if count >= 254 { 95 return 0, 0, rbr.makeError("Huffman count overflow") 96 } 97 weights[count] = pt.sym 98 weights[count+1] = fseTable[state1].sym 99 count += 2 100 break 101 } 102 103 v, err = rbr.val(pt.bits) 104 if err != nil { 105 return 0, 0, err 106 } 107 state2 = uint32(pt.base) + v 108 109 if count >= 255 { 110 return 0, 0, rbr.makeError("Huffman count overflow") 111 } 112 113 weights[count] = pt.sym 114 count++ 115 } 116 117 off += int(hdr) 118 } else { 119 // The table is not compressed. Each weight is 4 bits. 120 121 count = int(hdr) - 127 122 if off+((count+1)/2) >= len(data) { 123 return 0, 0, io.ErrUnexpectedEOF 124 } 125 for i := 0; i < count; i += 2 { 126 b := data[off] 127 off++ 128 weights[i] = b >> 4 129 weights[i+1] = b & 0xf 130 } 131 } 132 133 // RFC 4.2.1.3. 134 135 var weightMark [13]uint32 136 weightMask := uint32(0) 137 for _, w := range weights[:count] { 138 if w > 12 { 139 return 0, 0, r.makeError(off, "Huffman weight overflow") 140 } 141 weightMark[w]++ 142 if w > 0 { 143 weightMask += 1 << (w - 1) 144 } 145 } 146 if weightMask == 0 { 147 return 0, 0, r.makeError(off, "bad Huffman weights") 148 } 149 150 tableBits = 32 - bits.LeadingZeros32(weightMask) 151 if tableBits > maxHuffmanBits { 152 return 0, 0, r.makeError(off, "bad Huffman weights") 153 } 154 155 if len(table) < 1<<tableBits { 156 return 0, 0, r.makeError(off, "Huffman table too small") 157 } 158 159 // Work out the last weight value, which is omitted because 160 // the weights must sum to a power of two. 161 left := (uint32(1) << tableBits) - weightMask 162 if left == 0 { 163 return 0, 0, r.makeError(off, "bad Huffman weights") 164 } 165 highBit := 31 - bits.LeadingZeros32(left) 166 if uint32(1)<<highBit != left { 167 return 0, 0, r.makeError(off, "bad Huffman weights") 168 } 169 if count >= 256 { 170 return 0, 0, r.makeError(off, "Huffman weight overflow") 171 } 172 weights[count] = uint8(highBit + 1) 173 count++ 174 weightMark[highBit+1]++ 175 176 if weightMark[1] < 2 || weightMark[1]&1 != 0 { 177 return 0, 0, r.makeError(off, "bad Huffman weights") 178 } 179 180 // Change weightMark from a count of weights to the index of 181 // the first symbol for that weight. We shift the indexes to 182 // also store how many we have seen so far, 183 next := uint32(0) 184 for i := 0; i < tableBits; i++ { 185 cur := next 186 next += weightMark[i+1] << i 187 weightMark[i+1] = cur 188 } 189 190 for i, w := range weights[:count] { 191 if w == 0 { 192 continue 193 } 194 length := uint32(1) << (w - 1) 195 tval := uint16(i)<<8 | (uint16(tableBits) + 1 - uint16(w)) 196 start := weightMark[w] 197 for j := uint32(0); j < length; j++ { 198 table[start+j] = tval 199 } 200 weightMark[w] += length 201 } 202 203 return tableBits, off, nil 204 }