github.com/Serizao/go-winio@v0.0.0-20230906082528-f02f7f4ad6e8/wim/lzx/lzx.go

// Package lzx implements a decompressor for the WIM variant of the
// LZX compression algorithm.
//
// The LZX algorithm is an earlier variant of LZX DELTA, which is documented
// at https://msdn.microsoft.com/en-us/library/cc483133(v=exchg.80).aspx.
package lzx

import (
	"bytes"
	"encoding/binary"
	"errors"
	"io"
)

const (
	maincodecount = 496
	maincodesplit = 256
	lencodecount  = 249
	lenshift      = 9
	codemask      = 0x1ff
	tablebits     = 9
	tablesize     = 1 << tablebits

	maxBlockSize = 32768
	windowSize   = 32768

	maxTreePathLen = 16

	e8filesize  = 12000000
	maxe8offset = 0x3fffffff

	verbatimBlock      = 1
	alignedOffsetBlock = 2
	uncompressedBlock  = 3
)

var footerBits = [...]byte{
	0, 0, 0, 0, 1, 1, 2, 2,
	3, 3, 4, 4, 5, 5, 6, 6,
	7, 7, 8, 8, 9, 9, 10, 10,
	11, 11, 12, 12, 13, 13, 14,
}

var basePosition = [...]uint16{
	0, 1, 2, 3, 4, 6, 8, 12,
	16, 24, 32, 48, 64, 96, 128, 192,
	256, 384, 512, 768, 1024, 1536, 2048, 3072,
	4096, 6144, 8192, 12288, 16384, 24576, 32768,
}

var (
	errCorrupt = errors.New("LZX data corrupt")
)

// Reader is an interface used by the decompressor to access
// the input stream. If the provided io.Reader does not implement
// Reader, then a bufio.Reader is used.
type Reader interface {
	io.Reader
	io.ByteReader
}

type decompressor struct {
	r            io.Reader
	err          error
	unaligned    bool
	nbits        byte
	c            uint32
	lru          [3]uint16
	uncompressed int
	windowReader *bytes.Reader
	mainlens     [maincodecount]byte
	lenlens      [lencodecount]byte
	window       [windowSize]byte
	b            []byte
	bv           int
	bo           int
}

//go:noinline
func (f *decompressor) fail(err error) {
	if f.err == nil {
		f.err = err
	}
	f.bo = 0
	f.bv = 0
}

func (f *decompressor) ensureAtLeast(n int) error {
	if f.bv-f.bo >= n {
		return nil
	}

	if f.err != nil {
		return f.err
	}

	if f.bv != f.bo {
		copy(f.b[:f.bv-f.bo], f.b[f.bo:f.bv])
	}
	n, err := io.ReadAtLeast(f.r, f.b[f.bv-f.bo:], n)
	if err != nil {
		if err == io.EOF { //nolint:errorlint
			err = io.ErrUnexpectedEOF
		} else {
			f.fail(err)
		}
		return err
	}
	f.bv = f.bv - f.bo + n
	f.bo = 0
	return nil
}

// feed retrieves another 16-bit word from the stream and merges it into
// f.c. It returns false if there are no more bytes available; any other
// error is recorded in f.err.
func (f *decompressor) feed() bool {
	err := f.ensureAtLeast(2)
	if err == io.ErrUnexpectedEOF { //nolint:errorlint // returns io.ErrUnexpectedEOF by contract
		return false
	}
	f.c |= (uint32(f.b[f.bo+1])<<8 | uint32(f.b[f.bo])) << (16 - f.nbits)
	f.nbits += 16
	f.bo += 2
	return true
}

// getBits retrieves the next n bits from the byte stream. n
// must be <= 16. It sets f.err on error.
func (f *decompressor) getBits(n byte) uint16 {
	if f.nbits < n {
		if !f.feed() {
			f.fail(io.ErrUnexpectedEOF)
		}
	}
	c := uint16(f.c >> (32 - n))
	f.c <<= n
	f.nbits -= n
	return c
}

type huffman struct {
	extra   [][]uint16
	maxbits byte
	table   [tablesize]uint16
}
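
// For illustration of the decoding tables built below: given code lengths
// {a: 1, b: 2, c: 2}, the canonical codes are a=0, b=10 and c=11. With
// tablebits=2 the lookup table would hold one entry per 2-bit prefix,
// [a, a, b, c], so any prefix starting with 0 decodes to a and only the
// 1-bit length of a's code is consumed from the bit stream.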

// buildTable builds a huffman decoding table from a slice of code lengths,
// one per code, in order. Each code length must be <= maxTreePathLen.
// See https://en.wikipedia.org/wiki/Canonical_Huffman_code.
func buildTable(codelens []byte) *huffman {
	// Determine the number of codes of each length, and the
	// maximum length.
	var count [maxTreePathLen + 1]uint
	var max byte
	for _, cl := range codelens {
		count[cl]++
		if max < cl {
			max = cl
		}
	}

	if max == 0 {
		return &huffman{}
	}

	// Determine the first code of each length.
	var first [maxTreePathLen + 1]uint
	code := uint(0)
	for i := byte(1); i <= max; i++ {
		code <<= 1
		first[i] = code
		code += count[i]
	}

	if code != 1<<max {
		return nil
	}

	// Build a table for code lookup. For code sizes < max,
	// put all possible suffixes for the code into the table, too.
	// For max > tablebits, split long codes into additional tables
	// of suffixes of max-tablebits length.
	h := &huffman{maxbits: max}
	if max > tablebits {
		core := first[tablebits+1] / 2 // Number of codes that fit without extra tables
		nextra := 1<<tablebits - core  // Number of extra entries
		h.extra = make([][]uint16, nextra)
		for code := core; code < 1<<tablebits; code++ {
			h.table[code] = uint16(code - core)
			h.extra[code-core] = make([]uint16, 1<<(max-tablebits))
		}
	}

	for i, cl := range codelens {
		if cl != 0 {
			code := first[cl]
			first[cl]++
			v := uint16(cl)<<lenshift | uint16(i)
			if cl <= tablebits {
				extendedCode := code << (tablebits - cl)
				for j := uint(0); j < 1<<(tablebits-cl); j++ {
					h.table[extendedCode+j] = v
				}
			} else {
				prefix := code >> (cl - tablebits)
				suffix := code & (1<<(cl-tablebits) - 1)
				extendedCode := suffix << (max - cl)
				for j := uint(0); j < 1<<(max-cl); j++ {
					h.extra[h.table[prefix]][extendedCode+j] = v
				}
			}
		}
	}

	return h
}

// getCode retrieves the next code using the provided
// huffman tree. It sets f.err on error.
func (f *decompressor) getCode(h *huffman) uint16 {
	if h.maxbits > 0 {
		if f.nbits < maxTreePathLen {
			f.feed()
		}

		// For codes with length < tablebits, it doesn't matter
		// what the remainder of the bits used for table lookup
		// are, since entries with all possible suffixes were
		// added to the table.
		c := h.table[f.c>>(32-tablebits)]
		if c < 1<<lenshift {
			// The entry is not a code itself but an index into
			// the extra table holding the long-code suffixes.
			c = h.extra[c][f.c<<tablebits>>(32-(h.maxbits-tablebits))]
		}

		n := byte(c >> lenshift)
		if f.nbits >= n {
			// Only consume the length of the code, not the maximum
			// code length.
			f.c <<= n
			f.nbits -= n
			return c & codemask
		}

		f.fail(io.ErrUnexpectedEOF)
		return 0
	}

	// This is an empty tree. It should not be used.
	f.fail(errCorrupt)
	return 0
}

// readTree updates the huffman tree path lengths in lens by
// reading and decoding lengths from the byte stream. lens
// should be prepopulated with the previous block's tree's path
// lengths. For the first block, lens should be zero.
func (f *decompressor) readTree(lens []byte) error {
	// Get the pre-tree for the main tree.
	var pretreeLen [20]byte
	for i := range pretreeLen {
		pretreeLen[i] = byte(f.getBits(4))
	}
	if f.err != nil {
		return f.err
	}
	h := buildTable(pretreeLen[:])
	if h == nil {
		return errCorrupt
	}

	// The lengths are encoded as a series of huffman codes
	// encoded by the pre-tree.
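	// For example: if a symbol's length from the previous block was 5 and
	// the pre-tree code read here is 2, its new length becomes
	// (5 + 17 - 2) % 17 = 3. Codes 17 and 18 encode runs of zero lengths,
	// and code 19 repeats a single decoded length over a short run.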
	for i := 0; i < len(lens); {
		c := byte(f.getCode(h))
		if f.err != nil {
			return f.err
		}
		switch {
		case c <= 16: // length is delta from previous length
			lens[i] = (lens[i] + 17 - c) % 17
			i++
		case c == 17: // next n + 4 lengths are zero
			zeroes := int(f.getBits(4)) + 4
			if i+zeroes > len(lens) {
				return errCorrupt
			}
			for j := 0; j < zeroes; j++ {
				lens[i+j] = 0
			}
			i += zeroes
		case c == 18: // next n + 20 lengths are zero
			zeroes := int(f.getBits(5)) + 20
			if i+zeroes > len(lens) {
				return errCorrupt
			}
			for j := 0; j < zeroes; j++ {
				lens[i+j] = 0
			}
			i += zeroes
		case c == 19: // next n + 4 lengths all have the same value
			same := int(f.getBits(1)) + 4
			if i+same > len(lens) {
				return errCorrupt
			}
			c = byte(f.getCode(h))
			if c > 16 {
				return errCorrupt
			}
			l := (lens[i] + 17 - c) % 17
			for j := 0; j < same; j++ {
				lens[i+j] = l
			}
			i += same
		default:
			return errCorrupt
		}
	}

	if f.err != nil {
		return f.err
	}
	return nil
}

func (f *decompressor) readBlockHeader() (byte, uint16, error) {
	// If the previous block was an unaligned uncompressed block, restore
	// 2-byte alignment.
	if f.unaligned {
		err := f.ensureAtLeast(1)
		if err != nil {
			return 0, 0, err
		}
		f.bo++
		f.unaligned = false
	}

	blockType := f.getBits(3)
	full := f.getBits(1)
	var blockSize uint16
	if full != 0 {
		blockSize = maxBlockSize
	} else {
		blockSize = f.getBits(16)
		if blockSize > maxBlockSize {
			return 0, 0, errCorrupt
		}
	}

	if f.err != nil {
		return 0, 0, f.err
	}

	switch blockType {
	case verbatimBlock, alignedOffsetBlock:
		// The caller will read the huffman trees.
	case uncompressedBlock:
		if f.nbits > 16 {
			panic("impossible: more than one 16-bit word remains")
		}

		// Drop the remaining bits in the current 16-bit word.
		// If there are no bits left, discard a full 16-bit word.
		n := f.nbits
		if n == 0 {
			n = 16
		}

		f.getBits(n)

		// Read the LRU values for the next block.
		err := f.ensureAtLeast(12)
		if err != nil {
			return 0, 0, err
		}

		f.lru[0] = uint16(binary.LittleEndian.Uint32(f.b[f.bo : f.bo+4]))
		f.lru[1] = uint16(binary.LittleEndian.Uint32(f.b[f.bo+4 : f.bo+8]))
		f.lru[2] = uint16(binary.LittleEndian.Uint32(f.b[f.bo+8 : f.bo+12]))
		f.bo += 12

	default:
		return 0, 0, errCorrupt
	}

	return byte(blockType), blockSize, nil
}

// readTrees reads the two or three huffman trees for the current block.
// readAligned specifies whether to read the aligned offset tree.
func (f *decompressor) readTrees(readAligned bool) (main *huffman, length *huffman, aligned *huffman, err error) {
	// Aligned offset blocks start with a small aligned offset tree.
	if readAligned {
		var alignedLen [8]byte
		for i := range alignedLen {
			alignedLen[i] = byte(f.getBits(3))
		}
		aligned = buildTable(alignedLen[:])
		if aligned == nil {
			return main, length, aligned, errors.New("corrupt")
		}
	}

	// The main tree is encoded in two parts.
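	// The first maincodesplit (256) codes are literal bytes; the remaining
	// codes carry a position slot and length header for matches, so the two
	// halves are transmitted as separate path-length lists.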
	err = f.readTree(f.mainlens[:maincodesplit])
	if err != nil {
		return main, length, aligned, err
	}
	err = f.readTree(f.mainlens[maincodesplit:])
	if err != nil {
		return main, length, aligned, err
	}

	main = buildTable(f.mainlens[:])
	if main == nil {
		return main, length, aligned, errors.New("corrupt")
	}

	// The length tree is encoded in a single part.
	err = f.readTree(f.lenlens[:])
	if err != nil {
		return main, length, aligned, err
	}

	length = buildTable(f.lenlens[:])
	if length == nil {
		return main, length, aligned, errors.New("corrupt")
	}

	return main, length, aligned, f.err
}

// readCompressedBlock decodes a compressed block, writing into the window
// starting at start and ending at end, and using the provided huffman trees.
func (f *decompressor) readCompressedBlock(start, end uint16, hmain, hlength, haligned *huffman) (int, error) {
	i := start
	for i < end {
		main := f.getCode(hmain)
		if f.err != nil {
			break
		}
		if main < 256 {
			// Literal byte.
			f.window[i] = byte(main)
			i++
			continue
		}

		// This is a match backward in the window. Determine
		// the offset and length.
		matchlen := (main - 256) % 8
		slot := (main - 256) / 8

		// The length is the low 3 bits of the match code; if this
		// is 7, the remainder is encoded with the length tree.
		if matchlen == 7 {
			matchlen += f.getCode(hlength)
		}
		matchlen += 2

		var matchoffset uint16
		if slot < 3 { //nolint:nestif // todo: simplify nested complexity
			// The offset is one of the LRU values.
			matchoffset = f.lru[slot]
			f.lru[slot] = f.lru[0]
			f.lru[0] = matchoffset
		} else {
			// The offset is encoded as a combination of the
			// slot and more bits from the bit stream.
			offsetbits := footerBits[slot]
			var verbatimbits, alignedbits uint16
			if offsetbits > 0 {
				if haligned != nil && offsetbits >= 3 {
					// This is an aligned offset block. Combine
					// the bits written verbatim with the aligned
					// offset tree code.
					verbatimbits = f.getBits(offsetbits-3) * 8
					alignedbits = f.getCode(haligned)
				} else {
					// There are no aligned offset bits to read,
					// only verbatim bits.
					verbatimbits = f.getBits(offsetbits)
					alignedbits = 0
				}
			}
			matchoffset = basePosition[slot] + verbatimbits + alignedbits - 2
			// Update the LRU cache.
			f.lru[2] = f.lru[1]
			f.lru[1] = f.lru[0]
			f.lru[0] = matchoffset
		}

		if matchoffset > i || matchlen > end-i {
			f.fail(errCorrupt)
			break
		}
		copyend := i + matchlen
		for ; i < copyend; i++ {
			f.window[i] = f.window[i-matchoffset]
		}
	}
	return int(i - start), f.err
}

// readBlock decodes the current block and returns the number of uncompressed bytes.
func (f *decompressor) readBlock(start uint16) (int, error) {
	blockType, size, err := f.readBlockHeader()
	if err != nil {
		return 0, err
	}

	if blockType == uncompressedBlock {
		if size%2 == 1 {
			// Remember to realign the byte stream at the next block.
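			// Odd-sized uncompressed blocks are followed by a padding byte,
			// which readBlockHeader skips when f.unaligned is set.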
			f.unaligned = true
		}
		copied := 0
		if f.bo < f.bv {
			copied = int(size)
			s := int(start)
			if copied > f.bv-f.bo {
				copied = f.bv - f.bo
			}
			copy(f.window[s:s+copied], f.b[f.bo:f.bo+copied])
			f.bo += copied
		}
		n, err := io.ReadFull(f.r, f.window[start+uint16(copied):start+size])
		return copied + n, err
	}

	hmain, hlength, haligned, err := f.readTrees(blockType == alignedOffsetBlock)
	if err != nil {
		return 0, err
	}

	return f.readCompressedBlock(start, start+size, hmain, hlength, haligned)
}

// decodeE8 reverses the 0xe8 x86 instruction encoding that was performed
// on the uncompressed data before it was compressed.
func decodeE8(b []byte, off int64) {
	if off > maxe8offset || len(b) < 10 {
		return
	}
	for i := 0; i < len(b)-10; i++ {
		if b[i] == 0xe8 {
			currentPtr := int32(off) + int32(i)
			abs := int32(binary.LittleEndian.Uint32(b[i+1 : i+5]))
			if abs >= -currentPtr && abs < e8filesize {
				var rel int32
				if abs >= 0 {
					rel = abs - currentPtr
				} else {
					rel = abs + e8filesize
				}
				binary.LittleEndian.PutUint32(b[i+1:i+5], uint32(rel))
			}
			i += 4
		}
	}
}

func (f *decompressor) Read(b []byte) (int, error) {
	// Read and uncompress everything.
	if f.windowReader == nil {
		n := 0
		for n < f.uncompressed {
			k, err := f.readBlock(uint16(n))
			if err != nil {
				return 0, err
			}
			n += k
		}
		decodeE8(f.window[:f.uncompressed], 0)
		f.windowReader = bytes.NewReader(f.window[:f.uncompressed])
	}

	// Just read directly from the window.
	return f.windowReader.Read(b)
}

func (*decompressor) Close() error {
	return nil
}

// NewReader returns a new io.ReadCloser that decompresses a
// WIM LZX stream until uncompressedSize bytes have been returned.
func NewReader(r io.Reader, uncompressedSize int) (io.ReadCloser, error) {
	if uncompressedSize > windowSize {
		return nil, errors.New("uncompressed size is limited to 32KB")
	}
	f := &decompressor{
		lru:          [3]uint16{1, 1, 1},
		uncompressed: uncompressedSize,
		b:            make([]byte, 4096),
		r:            r,
	}
	return f, nil
}
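
// decompressChunk is an illustrative sketch, not part of the original package
// API. It shows one plausible way a caller could use NewReader to decompress a
// single WIM chunk, assuming compressed holds the raw LZX stream for the chunk
// and uncompressedSize (at most 32KB) comes from the caller's chunk table.
func decompressChunk(compressed []byte, uncompressedSize int) ([]byte, error) {
	r, err := NewReader(bytes.NewReader(compressed), uncompressedSize)
	if err != nil {
		return nil, err
	}
	defer r.Close() //nolint:errcheck // Close never fails for this reader.

	// Read exactly the expected number of uncompressed bytes.
	buf := make([]byte, uncompressedSize)
	if _, err := io.ReadFull(r, buf); err != nil {
		return nil, err
	}
	return buf, nil
}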