github.com/bir3/gocompiler@v0.9.2202/extra/compress/fse/compress.go (about) 1 // Copyright 2018 Klaus Post. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 // Based on work Copyright (c) 2013, Yann Collet, released under BSD License. 5 6 package fse 7 8 import ( 9 "errors" 10 "fmt" 11 ) 12 13 // Compress the input bytes. Input must be < 2GB. 14 // Provide a Scratch buffer to avoid memory allocations. 15 // Note that the output is also kept in the scratch buffer. 16 // If input is too hard to compress, ErrIncompressible is returned. 17 // If input is a single byte value repeated ErrUseRLE is returned. 18 func Compress(in []byte, s *Scratch) ([]byte, error) { 19 if len(in) <= 1 { 20 return nil, ErrIncompressible 21 } 22 if len(in) > (2<<30)-1 { 23 return nil, errors.New("input too big, must be < 2GB") 24 } 25 s, err := s.prepare(in) 26 if err != nil { 27 return nil, err 28 } 29 30 // Create histogram, if none was provided. 31 maxCount := s.maxCount 32 if maxCount == 0 { 33 maxCount = s.countSimple(in) 34 } 35 // Reset for next run. 36 s.clearCount = true 37 s.maxCount = 0 38 if maxCount == len(in) { 39 // One symbol, use RLE 40 return nil, ErrUseRLE 41 } 42 if maxCount == 1 || maxCount < (len(in)>>7) { 43 // Each symbol present maximum once or too well distributed. 44 return nil, ErrIncompressible 45 } 46 s.optimalTableLog() 47 err = s.normalizeCount() 48 if err != nil { 49 return nil, err 50 } 51 err = s.writeCount() 52 if err != nil { 53 return nil, err 54 } 55 56 if false { 57 err = s.validateNorm() 58 if err != nil { 59 return nil, err 60 } 61 } 62 63 err = s.buildCTable() 64 if err != nil { 65 return nil, err 66 } 67 err = s.compress(in) 68 if err != nil { 69 return nil, err 70 } 71 s.Out = s.bw.out 72 // Check if we compressed. 73 if len(s.Out) >= len(in) { 74 return nil, ErrIncompressible 75 } 76 return s.Out, nil 77 } 78 79 // cState contains the compression state of a stream. 80 type cState struct { 81 bw *bitWriter 82 stateTable []uint16 83 state uint16 84 } 85 86 // init will initialize the compression state to the first symbol of the stream. 87 func (c *cState) init(bw *bitWriter, ct *cTable, tableLog uint8, first symbolTransform) { 88 c.bw = bw 89 c.stateTable = ct.stateTable 90 91 nbBitsOut := (first.deltaNbBits + (1 << 15)) >> 16 92 im := int32((nbBitsOut << 16) - first.deltaNbBits) 93 lu := (im >> nbBitsOut) + first.deltaFindState 94 c.state = c.stateTable[lu] 95 } 96 97 // encode the output symbol provided and write it to the bitstream. 98 func (c *cState) encode(symbolTT symbolTransform) { 99 nbBitsOut := (uint32(c.state) + symbolTT.deltaNbBits) >> 16 100 dstState := int32(c.state>>(nbBitsOut&15)) + symbolTT.deltaFindState 101 c.bw.addBits16NC(c.state, uint8(nbBitsOut)) 102 c.state = c.stateTable[dstState] 103 } 104 105 // encode the output symbol provided and write it to the bitstream. 106 func (c *cState) encodeZero(symbolTT symbolTransform) { 107 nbBitsOut := (uint32(c.state) + symbolTT.deltaNbBits) >> 16 108 dstState := int32(c.state>>(nbBitsOut&15)) + symbolTT.deltaFindState 109 c.bw.addBits16ZeroNC(c.state, uint8(nbBitsOut)) 110 c.state = c.stateTable[dstState] 111 } 112 113 // flush will write the tablelog to the output and flush the remaining full bytes. 114 func (c *cState) flush(tableLog uint8) { 115 c.bw.flush32() 116 c.bw.addBits16NC(c.state, tableLog) 117 c.bw.flush() 118 } 119 120 // compress is the main compression loop that will encode the input from the last byte to the first. 121 func (s *Scratch) compress(src []byte) error { 122 if len(src) <= 2 { 123 return errors.New("compress: src too small") 124 } 125 tt := s.ct.symbolTT[:256] 126 s.bw.reset(s.Out) 127 128 // Our two states each encodes every second byte. 129 // Last byte encoded (first byte decoded) will always be encoded by c1. 130 var c1, c2 cState 131 132 // Encode so remaining size is divisible by 4. 133 ip := len(src) 134 if ip&1 == 1 { 135 c1.init(&s.bw, &s.ct, s.actualTableLog, tt[src[ip-1]]) 136 c2.init(&s.bw, &s.ct, s.actualTableLog, tt[src[ip-2]]) 137 c1.encodeZero(tt[src[ip-3]]) 138 ip -= 3 139 } else { 140 c2.init(&s.bw, &s.ct, s.actualTableLog, tt[src[ip-1]]) 141 c1.init(&s.bw, &s.ct, s.actualTableLog, tt[src[ip-2]]) 142 ip -= 2 143 } 144 if ip&2 != 0 { 145 c2.encodeZero(tt[src[ip-1]]) 146 c1.encodeZero(tt[src[ip-2]]) 147 ip -= 2 148 } 149 src = src[:ip] 150 151 // Main compression loop. 152 switch { 153 case !s.zeroBits && s.actualTableLog <= 8: 154 // We can encode 4 symbols without requiring a flush. 155 // We do not need to check if any output is 0 bits. 156 for ; len(src) >= 4; src = src[:len(src)-4] { 157 s.bw.flush32() 158 v3, v2, v1, v0 := src[len(src)-4], src[len(src)-3], src[len(src)-2], src[len(src)-1] 159 c2.encode(tt[v0]) 160 c1.encode(tt[v1]) 161 c2.encode(tt[v2]) 162 c1.encode(tt[v3]) 163 } 164 case !s.zeroBits: 165 // We do not need to check if any output is 0 bits. 166 for ; len(src) >= 4; src = src[:len(src)-4] { 167 s.bw.flush32() 168 v3, v2, v1, v0 := src[len(src)-4], src[len(src)-3], src[len(src)-2], src[len(src)-1] 169 c2.encode(tt[v0]) 170 c1.encode(tt[v1]) 171 s.bw.flush32() 172 c2.encode(tt[v2]) 173 c1.encode(tt[v3]) 174 } 175 case s.actualTableLog <= 8: 176 // We can encode 4 symbols without requiring a flush 177 for ; len(src) >= 4; src = src[:len(src)-4] { 178 s.bw.flush32() 179 v3, v2, v1, v0 := src[len(src)-4], src[len(src)-3], src[len(src)-2], src[len(src)-1] 180 c2.encodeZero(tt[v0]) 181 c1.encodeZero(tt[v1]) 182 c2.encodeZero(tt[v2]) 183 c1.encodeZero(tt[v3]) 184 } 185 default: 186 for ; len(src) >= 4; src = src[:len(src)-4] { 187 s.bw.flush32() 188 v3, v2, v1, v0 := src[len(src)-4], src[len(src)-3], src[len(src)-2], src[len(src)-1] 189 c2.encodeZero(tt[v0]) 190 c1.encodeZero(tt[v1]) 191 s.bw.flush32() 192 c2.encodeZero(tt[v2]) 193 c1.encodeZero(tt[v3]) 194 } 195 } 196 197 // Flush final state. 198 // Used to initialize state when decoding. 199 c2.flush(s.actualTableLog) 200 c1.flush(s.actualTableLog) 201 202 return s.bw.close() 203 } 204 205 // writeCount will write the normalized histogram count to header. 206 // This is read back by readNCount. 207 func (s *Scratch) writeCount() error { 208 var ( 209 tableLog = s.actualTableLog 210 tableSize = 1 << tableLog 211 previous0 bool 212 charnum uint16 213 214 maxHeaderSize = ((int(s.symbolLen) * int(tableLog)) >> 3) + 3 215 216 // Write Table Size 217 bitStream = uint32(tableLog - minTablelog) 218 bitCount = uint(4) 219 remaining = int16(tableSize + 1) /* +1 for extra accuracy */ 220 threshold = int16(tableSize) 221 nbBits = uint(tableLog + 1) 222 ) 223 if cap(s.Out) < maxHeaderSize { 224 s.Out = make([]byte, 0, s.br.remain()+maxHeaderSize) 225 } 226 outP := uint(0) 227 out := s.Out[:maxHeaderSize] 228 229 // stops at 1 230 for remaining > 1 { 231 if previous0 { 232 start := charnum 233 for s.norm[charnum] == 0 { 234 charnum++ 235 } 236 for charnum >= start+24 { 237 start += 24 238 bitStream += uint32(0xFFFF) << bitCount 239 out[outP] = byte(bitStream) 240 out[outP+1] = byte(bitStream >> 8) 241 outP += 2 242 bitStream >>= 16 243 } 244 for charnum >= start+3 { 245 start += 3 246 bitStream += 3 << bitCount 247 bitCount += 2 248 } 249 bitStream += uint32(charnum-start) << bitCount 250 bitCount += 2 251 if bitCount > 16 { 252 out[outP] = byte(bitStream) 253 out[outP+1] = byte(bitStream >> 8) 254 outP += 2 255 bitStream >>= 16 256 bitCount -= 16 257 } 258 } 259 260 count := s.norm[charnum] 261 charnum++ 262 max := (2*threshold - 1) - remaining 263 if count < 0 { 264 remaining += count 265 } else { 266 remaining -= count 267 } 268 count++ // +1 for extra accuracy 269 if count >= threshold { 270 count += max // [0..max[ [max..threshold[ (...) [threshold+max 2*threshold[ 271 } 272 bitStream += uint32(count) << bitCount 273 bitCount += nbBits 274 if count < max { 275 bitCount-- 276 } 277 278 previous0 = count == 1 279 if remaining < 1 { 280 return errors.New("internal error: remaining<1") 281 } 282 for remaining < threshold { 283 nbBits-- 284 threshold >>= 1 285 } 286 287 if bitCount > 16 { 288 out[outP] = byte(bitStream) 289 out[outP+1] = byte(bitStream >> 8) 290 outP += 2 291 bitStream >>= 16 292 bitCount -= 16 293 } 294 } 295 296 out[outP] = byte(bitStream) 297 out[outP+1] = byte(bitStream >> 8) 298 outP += (bitCount + 7) / 8 299 300 if charnum > s.symbolLen { 301 return errors.New("internal error: charnum > s.symbolLen") 302 } 303 s.Out = out[:outP] 304 return nil 305 } 306 307 // symbolTransform contains the state transform for a symbol. 308 type symbolTransform struct { 309 deltaFindState int32 310 deltaNbBits uint32 311 } 312 313 // String prints values as a human readable string. 314 func (s symbolTransform) String() string { 315 return fmt.Sprintf("dnbits: %08x, fs:%d", s.deltaNbBits, s.deltaFindState) 316 } 317 318 // cTable contains tables used for compression. 319 type cTable struct { 320 tableSymbol []byte 321 stateTable []uint16 322 symbolTT []symbolTransform 323 } 324 325 // allocCtable will allocate tables needed for compression. 326 // If existing tables a re big enough, they are simply re-used. 327 func (s *Scratch) allocCtable() { 328 tableSize := 1 << s.actualTableLog 329 // get tableSymbol that is big enough. 330 if cap(s.ct.tableSymbol) < tableSize { 331 s.ct.tableSymbol = make([]byte, tableSize) 332 } 333 s.ct.tableSymbol = s.ct.tableSymbol[:tableSize] 334 335 ctSize := tableSize 336 if cap(s.ct.stateTable) < ctSize { 337 s.ct.stateTable = make([]uint16, ctSize) 338 } 339 s.ct.stateTable = s.ct.stateTable[:ctSize] 340 341 if cap(s.ct.symbolTT) < 256 { 342 s.ct.symbolTT = make([]symbolTransform, 256) 343 } 344 s.ct.symbolTT = s.ct.symbolTT[:256] 345 } 346 347 // buildCTable will populate the compression table so it is ready to be used. 348 func (s *Scratch) buildCTable() error { 349 tableSize := uint32(1 << s.actualTableLog) 350 highThreshold := tableSize - 1 351 var cumul [maxSymbolValue + 2]int16 352 353 s.allocCtable() 354 tableSymbol := s.ct.tableSymbol[:tableSize] 355 // symbol start positions 356 { 357 cumul[0] = 0 358 for ui, v := range s.norm[:s.symbolLen-1] { 359 u := byte(ui) // one less than reference 360 if v == -1 { 361 // Low proba symbol 362 cumul[u+1] = cumul[u] + 1 363 tableSymbol[highThreshold] = u 364 highThreshold-- 365 } else { 366 cumul[u+1] = cumul[u] + v 367 } 368 } 369 // Encode last symbol separately to avoid overflowing u 370 u := int(s.symbolLen - 1) 371 v := s.norm[s.symbolLen-1] 372 if v == -1 { 373 // Low proba symbol 374 cumul[u+1] = cumul[u] + 1 375 tableSymbol[highThreshold] = byte(u) 376 highThreshold-- 377 } else { 378 cumul[u+1] = cumul[u] + v 379 } 380 if uint32(cumul[s.symbolLen]) != tableSize { 381 return fmt.Errorf("internal error: expected cumul[s.symbolLen] (%d) == tableSize (%d)", cumul[s.symbolLen], tableSize) 382 } 383 cumul[s.symbolLen] = int16(tableSize) + 1 384 } 385 // Spread symbols 386 s.zeroBits = false 387 { 388 step := tableStep(tableSize) 389 tableMask := tableSize - 1 390 var position uint32 391 // if any symbol > largeLimit, we may have 0 bits output. 392 largeLimit := int16(1 << (s.actualTableLog - 1)) 393 for ui, v := range s.norm[:s.symbolLen] { 394 symbol := byte(ui) 395 if v > largeLimit { 396 s.zeroBits = true 397 } 398 for nbOccurrences := int16(0); nbOccurrences < v; nbOccurrences++ { 399 tableSymbol[position] = symbol 400 position = (position + step) & tableMask 401 for position > highThreshold { 402 position = (position + step) & tableMask 403 } /* Low proba area */ 404 } 405 } 406 407 // Check if we have gone through all positions 408 if position != 0 { 409 return errors.New("position!=0") 410 } 411 } 412 413 // Build table 414 table := s.ct.stateTable 415 { 416 tsi := int(tableSize) 417 for u, v := range tableSymbol { 418 // TableU16 : sorted by symbol order; gives next state value 419 table[cumul[v]] = uint16(tsi + u) 420 cumul[v]++ 421 } 422 } 423 424 // Build Symbol Transformation Table 425 { 426 total := int16(0) 427 symbolTT := s.ct.symbolTT[:s.symbolLen] 428 tableLog := s.actualTableLog 429 tl := (uint32(tableLog) << 16) - (1 << tableLog) 430 for i, v := range s.norm[:s.symbolLen] { 431 switch v { 432 case 0: 433 case -1, 1: 434 symbolTT[i].deltaNbBits = tl 435 symbolTT[i].deltaFindState = int32(total - 1) 436 total++ 437 default: 438 maxBitsOut := uint32(tableLog) - highBits(uint32(v-1)) 439 minStatePlus := uint32(v) << maxBitsOut 440 symbolTT[i].deltaNbBits = (maxBitsOut << 16) - minStatePlus 441 symbolTT[i].deltaFindState = int32(total - v) 442 total += v 443 } 444 } 445 if total != int16(tableSize) { 446 return fmt.Errorf("total mismatch %d (got) != %d (want)", total, tableSize) 447 } 448 } 449 return nil 450 } 451 452 // countSimple will create a simple histogram in s.count. 453 // Returns the biggest count. 454 // Does not update s.clearCount. 455 func (s *Scratch) countSimple(in []byte) (max int) { 456 for _, v := range in { 457 s.count[v]++ 458 } 459 m, symlen := uint32(0), s.symbolLen 460 for i, v := range s.count[:] { 461 if v == 0 { 462 continue 463 } 464 if v > m { 465 m = v 466 } 467 symlen = uint16(i) + 1 468 } 469 s.symbolLen = symlen 470 return int(m) 471 } 472 473 // minTableLog provides the minimum logSize to safely represent a distribution. 474 func (s *Scratch) minTableLog() uint8 { 475 minBitsSrc := highBits(uint32(s.br.remain()-1)) + 1 476 minBitsSymbols := highBits(uint32(s.symbolLen-1)) + 2 477 if minBitsSrc < minBitsSymbols { 478 return uint8(minBitsSrc) 479 } 480 return uint8(minBitsSymbols) 481 } 482 483 // optimalTableLog calculates and sets the optimal tableLog in s.actualTableLog 484 func (s *Scratch) optimalTableLog() { 485 tableLog := s.TableLog 486 minBits := s.minTableLog() 487 maxBitsSrc := uint8(highBits(uint32(s.br.remain()-1))) - 2 488 if maxBitsSrc < tableLog { 489 // Accuracy can be reduced 490 tableLog = maxBitsSrc 491 } 492 if minBits > tableLog { 493 tableLog = minBits 494 } 495 // Need a minimum to safely represent all symbol values 496 if tableLog < minTablelog { 497 tableLog = minTablelog 498 } 499 if tableLog > maxTableLog { 500 tableLog = maxTableLog 501 } 502 s.actualTableLog = tableLog 503 } 504 505 var rtbTable = [...]uint32{0, 473195, 504333, 520860, 550000, 700000, 750000, 830000} 506 507 // normalizeCount will normalize the count of the symbols so 508 // the total is equal to the table size. 509 func (s *Scratch) normalizeCount() error { 510 var ( 511 tableLog = s.actualTableLog 512 scale = 62 - uint64(tableLog) 513 step = (1 << 62) / uint64(s.br.remain()) 514 vStep = uint64(1) << (scale - 20) 515 stillToDistribute = int16(1 << tableLog) 516 largest int 517 largestP int16 518 lowThreshold = (uint32)(s.br.remain() >> tableLog) 519 ) 520 521 for i, cnt := range s.count[:s.symbolLen] { 522 // already handled 523 // if (count[s] == s.length) return 0; /* rle special case */ 524 525 if cnt == 0 { 526 s.norm[i] = 0 527 continue 528 } 529 if cnt <= lowThreshold { 530 s.norm[i] = -1 531 stillToDistribute-- 532 } else { 533 proba := (int16)((uint64(cnt) * step) >> scale) 534 if proba < 8 { 535 restToBeat := vStep * uint64(rtbTable[proba]) 536 v := uint64(cnt)*step - (uint64(proba) << scale) 537 if v > restToBeat { 538 proba++ 539 } 540 } 541 if proba > largestP { 542 largestP = proba 543 largest = i 544 } 545 s.norm[i] = proba 546 stillToDistribute -= proba 547 } 548 } 549 550 if -stillToDistribute >= (s.norm[largest] >> 1) { 551 // corner case, need another normalization method 552 return s.normalizeCount2() 553 } 554 s.norm[largest] += stillToDistribute 555 return nil 556 } 557 558 // Secondary normalization method. 559 // To be used when primary method fails. 560 func (s *Scratch) normalizeCount2() error { 561 const notYetAssigned = -2 562 var ( 563 distributed uint32 564 total = uint32(s.br.remain()) 565 tableLog = s.actualTableLog 566 lowThreshold = total >> tableLog 567 lowOne = (total * 3) >> (tableLog + 1) 568 ) 569 for i, cnt := range s.count[:s.symbolLen] { 570 if cnt == 0 { 571 s.norm[i] = 0 572 continue 573 } 574 if cnt <= lowThreshold { 575 s.norm[i] = -1 576 distributed++ 577 total -= cnt 578 continue 579 } 580 if cnt <= lowOne { 581 s.norm[i] = 1 582 distributed++ 583 total -= cnt 584 continue 585 } 586 s.norm[i] = notYetAssigned 587 } 588 toDistribute := (1 << tableLog) - distributed 589 590 if (total / toDistribute) > lowOne { 591 // risk of rounding to zero 592 lowOne = (total * 3) / (toDistribute * 2) 593 for i, cnt := range s.count[:s.symbolLen] { 594 if (s.norm[i] == notYetAssigned) && (cnt <= lowOne) { 595 s.norm[i] = 1 596 distributed++ 597 total -= cnt 598 continue 599 } 600 } 601 toDistribute = (1 << tableLog) - distributed 602 } 603 if distributed == uint32(s.symbolLen)+1 { 604 // all values are pretty poor; 605 // probably incompressible data (should have already been detected); 606 // find max, then give all remaining points to max 607 var maxV int 608 var maxC uint32 609 for i, cnt := range s.count[:s.symbolLen] { 610 if cnt > maxC { 611 maxV = i 612 maxC = cnt 613 } 614 } 615 s.norm[maxV] += int16(toDistribute) 616 return nil 617 } 618 619 if total == 0 { 620 // all of the symbols were low enough for the lowOne or lowThreshold 621 for i := uint32(0); toDistribute > 0; i = (i + 1) % (uint32(s.symbolLen)) { 622 if s.norm[i] > 0 { 623 toDistribute-- 624 s.norm[i]++ 625 } 626 } 627 return nil 628 } 629 630 var ( 631 vStepLog = 62 - uint64(tableLog) 632 mid = uint64((1 << (vStepLog - 1)) - 1) 633 rStep = (((1 << vStepLog) * uint64(toDistribute)) + mid) / uint64(total) // scale on remaining 634 tmpTotal = mid 635 ) 636 for i, cnt := range s.count[:s.symbolLen] { 637 if s.norm[i] == notYetAssigned { 638 var ( 639 end = tmpTotal + uint64(cnt)*rStep 640 sStart = uint32(tmpTotal >> vStepLog) 641 sEnd = uint32(end >> vStepLog) 642 weight = sEnd - sStart 643 ) 644 if weight < 1 { 645 return errors.New("weight < 1") 646 } 647 s.norm[i] = int16(weight) 648 tmpTotal = end 649 } 650 } 651 return nil 652 } 653 654 // validateNorm validates the normalized histogram table. 655 func (s *Scratch) validateNorm() (err error) { 656 var total int 657 for _, v := range s.norm[:s.symbolLen] { 658 if v >= 0 { 659 total += int(v) 660 } else { 661 total -= int(v) 662 } 663 } 664 defer func() { 665 if err == nil { 666 return 667 } 668 fmt.Printf("selected TableLog: %d, Symbol length: %d\n", s.actualTableLog, s.symbolLen) 669 for i, v := range s.norm[:s.symbolLen] { 670 fmt.Printf("%3d: %5d -> %4d \n", i, s.count[i], v) 671 } 672 }() 673 if total != (1 << s.actualTableLog) { 674 return fmt.Errorf("warning: Total == %d != %d", total, 1<<s.actualTableLog) 675 } 676 for i, v := range s.count[s.symbolLen:] { 677 if v != 0 { 678 return fmt.Errorf("warning: Found symbol out of range, %d after cut", i) 679 } 680 } 681 return nil 682 }