github.com/bir3/gocompiler@v0.9.2202/extra/compress/fse/decompress.go (about) 1 package fse 2 3 import ( 4 "errors" 5 "fmt" 6 ) 7 8 const ( 9 tablelogAbsoluteMax = 15 10 ) 11 12 // Decompress a block of data. 13 // You can provide a scratch buffer to avoid allocations. 14 // If nil is provided a temporary one will be allocated. 15 // It is possible, but by no way guaranteed that corrupt data will 16 // return an error. 17 // It is up to the caller to verify integrity of the returned data. 18 // Use a predefined Scrach to set maximum acceptable output size. 19 func Decompress(b []byte, s *Scratch) ([]byte, error) { 20 s, err := s.prepare(b) 21 if err != nil { 22 return nil, err 23 } 24 s.Out = s.Out[:0] 25 err = s.readNCount() 26 if err != nil { 27 return nil, err 28 } 29 err = s.buildDtable() 30 if err != nil { 31 return nil, err 32 } 33 err = s.decompress() 34 if err != nil { 35 return nil, err 36 } 37 38 return s.Out, nil 39 } 40 41 // readNCount will read the symbol distribution so decoding tables can be constructed. 
//
// The header begins with a 4-bit table log (stored minus minTablelog),
// followed by a variable-width normalized count per symbol. After a count of
// 0, a run-length of further zero-count symbols follows: each 16-bit group of
// all ones adds 24 symbols, each 2-bit value of 3 adds 3, and a final 2-bit
// value (0-2) adds the rest. A stored count of -1 stands for a probability of
// 1 ("-1 means +1").
//
// On success s.norm, s.symbolLen and s.actualTableLog are set and s.br is
// advanced past the header.
func (s *Scratch) readNCount() error {
	var (
		charnum   uint16
		previous0 bool
		b         = &s.br
	)
	iend := b.remain()
	if iend < 4 {
		return errors.New("input too small")
	}
	bitStream := b.Uint32()
	nbBits := uint((bitStream & 0xF) + minTablelog) // extract tableLog
	if nbBits > tablelogAbsoluteMax {
		return errors.New("tableLog too large")
	}
	bitStream >>= 4
	bitCount := uint(4)

	s.actualTableLog = uint8(nbBits)
	// remaining is the probability mass still unassigned, plus one so the loop
	// below terminates at 1; threshold is the power-of-two bound that selects
	// how many bits the next count uses.
	remaining := int32((1 << nbBits) + 1)
	threshold := int32(1 << nbBits)
	gotTotal := int32(0)
	nbBits++

	for remaining > 1 {
		if previous0 {
			// The previous symbol had count 0: decode how many of the
			// following symbols are also 0.
			n0 := charnum
			// A 16-bit run of all ones stands for 24 more zero-count symbols.
			for (bitStream & 0xFFFF) == 0xFFFF {
				n0 += 24
				if b.off < iend-5 {
					b.advance(2)
					bitStream = b.Uint32() >> bitCount
				} else {
					bitStream >>= 16
					bitCount += 16
				}
			}
			// Each 2-bit value of 3 adds 3 symbols; the terminating 2-bit value
			// (0-2) adds the remainder.
			for (bitStream & 3) == 3 {
				n0 += 3
				bitStream >>= 2
				bitCount += 2
			}
			n0 += uint16(bitStream & 3)
			bitCount += 2
			if n0 > maxSymbolValue {
				return errors.New("maxSymbolValue too small")
			}
			for charnum < n0 {
				s.norm[charnum&0xff] = 0
				charnum++
			}

			// Reload from the byte reader while enough input remains; otherwise
			// keep consuming the bits already loaded.
			if b.off <= iend-7 || b.off+int(bitCount>>3) <= iend-4 {
				b.advance(bitCount >> 3)
				bitCount &= 7
				bitStream = b.Uint32() >> bitCount
			} else {
				bitStream >>= 2
			}
		}

		// Counts are stored in nbBits-1 bits when small enough, otherwise in
		// nbBits bits with the upper values shifted down by max.
		max := (2*(threshold) - 1) - (remaining)
		var count int32

		if (int32(bitStream) & (threshold - 1)) < max {
			count = int32(bitStream) & (threshold - 1)
			bitCount += nbBits - 1
		} else {
			count = int32(bitStream) & (2*threshold - 1)
			if count >= threshold {
				count -= max
			}
			bitCount += nbBits
		}

		count-- // extra accuracy
		if count < 0 {
			// -1 means +1
			remaining += count
			gotTotal -= count
		} else {
			remaining -= count
			gotTotal += count
		}
		// NOTE(review): charnum&0xff keeps the index within what is presumably
		// a 256-entry s.norm; an over-long distribution wraps here and is only
		// rejected after the loop by the symbolLen check.
		s.norm[charnum&0xff] = int16(count)
		charnum++
		previous0 = count == 0
		// Narrow the count width as the remaining probability mass shrinks.
		for remaining < threshold {
			nbBits--
			threshold >>= 1
		}
		// Advance by whole bytes when possible; near the end of input, pin the
		// reader to the last 4 bytes and compensate in bitCount instead.
		if b.off <= iend-7 || b.off+int(bitCount>>3) <= iend-4 {
			b.advance(bitCount >> 3)
			bitCount &= 7
		} else {
			bitCount -= (uint)(8 * (len(b.b) - 4 - b.off))
			b.off = len(b.b) - 4
		}
		bitStream = b.Uint32() >> (bitCount & 31)
	}
	s.symbolLen = charnum

	// Sanity checks: the decoded distribution must be complete and consistent.
	if s.symbolLen <= 1 {
		return fmt.Errorf("symbolLen (%d) too small", s.symbolLen)
	}
	if s.symbolLen > maxSymbolValue+1 {
		return fmt.Errorf("symbolLen (%d) too big", s.symbolLen)
	}
	if remaining != 1 {
		return fmt.Errorf("corruption detected (remaining %d != 1)", remaining)
	}
	if bitCount > 32 {
		return fmt.Errorf("corruption detected (bitCount %d > 32)", bitCount)
	}
	if gotTotal != 1<<s.actualTableLog {
		return fmt.Errorf("corruption detected (total %d != %d)", gotTotal, 1<<s.actualTableLog)
	}
	// Skip the partially consumed final byte of the header.
	b.advance((bitCount + 7) >> 3)
	return nil
}

// decSymbol contains information about a state entry: the base of the next
// state, the output symbol, and the number of bits to read for the low part
// of the next state.
type decSymbol struct {
	newState uint16 // base of the next state; the nbBits read are added to it
	symbol   uint8  // symbol emitted when in this state
	nbBits   uint8  // number of bits to read to form the next state
}

// allocDtable will allocate decoding tables if they are not big enough.
// s.decTable is resliced to 1<<s.actualTableLog entries, and s.ct.tableSymbol
// and s.ct.stateTable to 256 entries each. Existing backing arrays are reused,
// without clearing, when their capacity suffices.
func (s *Scratch) allocDtable() {
	tableSize := 1 << s.actualTableLog
	if cap(s.decTable) < tableSize {
		s.decTable = make([]decSymbol, tableSize)
	}
	s.decTable = s.decTable[:tableSize]

	if cap(s.ct.tableSymbol) < 256 {
		s.ct.tableSymbol = make([]byte, 256)
	}
	s.ct.tableSymbol = s.ct.tableSymbol[:256]

	if cap(s.ct.stateTable) < 256 {
		s.ct.stateTable = make([]uint16, 256)
	}
	s.ct.stateTable = s.ct.stateTable[:256]
}

// buildDtable will build the decoding table.
192 func (s *Scratch) buildDtable() error { 193 tableSize := uint32(1 << s.actualTableLog) 194 highThreshold := tableSize - 1 195 s.allocDtable() 196 symbolNext := s.ct.stateTable[:256] 197 198 // Init, lay down lowprob symbols 199 s.zeroBits = false 200 { 201 largeLimit := int16(1 << (s.actualTableLog - 1)) 202 for i, v := range s.norm[:s.symbolLen] { 203 if v == -1 { 204 s.decTable[highThreshold].symbol = uint8(i) 205 highThreshold-- 206 symbolNext[i] = 1 207 } else { 208 if v >= largeLimit { 209 s.zeroBits = true 210 } 211 symbolNext[i] = uint16(v) 212 } 213 } 214 } 215 // Spread symbols 216 { 217 tableMask := tableSize - 1 218 step := tableStep(tableSize) 219 position := uint32(0) 220 for ss, v := range s.norm[:s.symbolLen] { 221 for i := 0; i < int(v); i++ { 222 s.decTable[position].symbol = uint8(ss) 223 position = (position + step) & tableMask 224 for position > highThreshold { 225 // lowprob area 226 position = (position + step) & tableMask 227 } 228 } 229 } 230 if position != 0 { 231 // position must reach all cells once, otherwise normalizedCounter is incorrect 232 return errors.New("corrupted input (position != 0)") 233 } 234 } 235 236 // Build Decoding table 237 { 238 tableSize := uint16(1 << s.actualTableLog) 239 for u, v := range s.decTable { 240 symbol := v.symbol 241 nextState := symbolNext[symbol] 242 symbolNext[symbol] = nextState + 1 243 nBits := s.actualTableLog - byte(highBits(uint32(nextState))) 244 s.decTable[u].nbBits = nBits 245 newState := (nextState << nBits) - tableSize 246 if newState >= tableSize { 247 return fmt.Errorf("newState (%d) outside table size (%d)", newState, tableSize) 248 } 249 if newState == uint16(u) && nBits == 0 { 250 // Seems weird that this is possible with nbits > 0. 251 return fmt.Errorf("newState (%d) == oldState (%d) and no bits", newState, u) 252 } 253 s.decTable[u].newState = newState 254 } 255 } 256 return nil 257 } 258 259 // decompress will decompress the bitstream. 
// If the buffer is over-read an error is returned.
//
// Symbols are decoded by two decoder states (s1 and s2) that take turns
// reading from the same bit reader. Decoded bytes are appended to s.Out, and
// an error is returned once len(s.Out) reaches s.DecompressLimit while more
// output is still pending.
// NOTE(review): the trailing symbols (the final tmp flush and the last one or
// two final() bytes) are appended without a further limit check, so the output
// can end up slightly over DecompressLimit.
func (s *Scratch) decompress() error {
	br := &s.bits
	if err := br.init(s.br.unread()); err != nil {
		return err
	}

	var s1, s2 decoder
	// Read the initial state of each decoder from the bitstream.
	s1.init(br, s.decTable, s.actualTableLog)
	s2.init(br, s.decTable, s.actualTableLog)

	// Use temp table to avoid bound checks/append penalty.
	// off is a uint8 index into the 256-byte tmp, so off += 4 wraps to 0
	// exactly when tmp is full.
	var tmp = s.ct.tableSymbol[:256]
	var off uint8

	// Main part: decode four symbols per iteration while br.off >= 8
	// (presumably enough buffered input for two fillFast calls).
	// nextFast is only valid when no state reads 0 bits (s.zeroBits false);
	// otherwise the slower next is used.
	if !s.zeroBits {
		for br.off >= 8 {
			br.fillFast()
			tmp[off+0] = s1.nextFast()
			tmp[off+1] = s2.nextFast()
			br.fillFast()
			tmp[off+2] = s1.nextFast()
			tmp[off+3] = s2.nextFast()
			off += 4
			// When off is 0, we have overflowed and should write.
			if off == 0 {
				s.Out = append(s.Out, tmp...)
				// At least two more symbols always follow, so reaching the
				// limit here already guarantees exceeding it.
				if len(s.Out) >= s.DecompressLimit {
					return fmt.Errorf("output size (%d) > DecompressLimit (%d)", len(s.Out), s.DecompressLimit)
				}
			}
		}
	} else {
		for br.off >= 8 {
			br.fillFast()
			tmp[off+0] = s1.next()
			tmp[off+1] = s2.next()
			br.fillFast()
			tmp[off+2] = s1.next()
			tmp[off+3] = s2.next()
			off += 4
			// When off is 0, we have overflowed and should write.
			if off == 0 {
				s.Out = append(s.Out, tmp...)
				if len(s.Out) >= s.DecompressLimit {
					return fmt.Errorf("output size (%d) > DecompressLimit (%d)", len(s.Out), s.DecompressLimit)
				}
			}
		}
	}
	// Flush whatever is left in tmp.
	s.Out = append(s.Out, tmp[:off]...)

	// Final bits, a bit more expensive check.
	// Decode one symbol at a time, checking for the end of the stream first.
	// When a decoder is finished, both decoders' current symbols are emitted
	// (the finished one first) and decoding stops.
	for {
		if s1.finished() {
			s.Out = append(s.Out, s1.final(), s2.final())
			break
		}
		br.fill()
		s.Out = append(s.Out, s1.next())
		if s2.finished() {
			s.Out = append(s.Out, s2.final(), s1.final())
			break
		}
		s.Out = append(s.Out, s2.next())
		if len(s.Out) >= s.DecompressLimit {
			return fmt.Errorf("output size (%d) > DecompressLimit (%d)", len(s.Out), s.DecompressLimit)
		}
	}
	return br.close()
}

// decoder keeps track of the current state and updates it from the bitstream.
type decoder struct {
	state uint16      // index into dt of the current state
	br    *bitReader  // shared bit reader the next state's low bits come from
	dt    []decSymbol // decoding table built by buildDtable
}

// init will initialize the decoder and read the first state from the stream.
// The first state is read as tableLog bits.
func (d *decoder) init(in *bitReader, dt []decSymbol, tableLog uint8) {
	d.dt = dt
	d.br = in
	d.state = in.getBits(tableLog)
}

// next returns the next symbol and sets the next state.
// At least tablelog bits must be available in the bit reader.
func (d *decoder) next() uint8 {
	n := &d.dt[d.state]
	lowBits := d.br.getBits(n.nbBits)
	d.state = n.newState + lowBits
	return n.symbol
}

// finished returns true if all bits have been read from the bitstream
// and the next state would require reading bits from the input.
func (d *decoder) finished() bool {
	return d.br.finished() && d.dt[d.state].nbBits > 0
}

// final returns the current state symbol without decoding the next.
func (d *decoder) final() uint8 {
	return d.dt[d.state].symbol
}

// nextFast returns the next symbol and sets the next state.
// This can only be used if no symbols are 0 bits.
// At least tablelog bits must be available in the bit reader.
func (d *decoder) nextFast() uint8 {
	n := d.dt[d.state]
	lowBits := d.br.getBitsFast(n.nbBits)
	d.state = n.newState + lowBits
	return n.symbol
}