// Source: github.com/bir3/gocompiler@v0.9.2202/extra/compress/zstd/fse_encoder.go

// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
// Based on work by Yann Collet, released under BSD License.

package zstd

import (
	"errors"
	"fmt"
	"math"
)

// Limits for the FSE encoder tables.
// NOTE(review): maxEncTablesize and maxEncTableMask are derived from
// maxTableLog (declared elsewhere), not from maxEncTableLog - confirm intent.
const (
	// For encoding we only support table logs up to maxEncTableLog.
	maxEncTableLog    = 8
	maxEncTablesize   = 1 << maxTableLog
	maxEncTableMask   = (1 << maxTableLog) - 1
	minEncTablelog    = 5
	maxEncSymbolValue = maxMatchLengthSymbol
)

// fseEncoder holds the histogram, the normalized counts and the compression
// tables needed to FSE-encode one stream of symbols.
type fseEncoder struct {
	symbolLen      uint16 // Length of active part of the symbol table.
	actualTableLog uint8  // Selected tablelog.
	ct             cTable // Compression tables.
	maxCount       int    // count of the most probable symbol
	zeroBits       bool   // Set by buildCTable when a normalized count exceeds half the table size.
	clearCount     bool   // Set by HistogramFinished when maxCount != 0.
	useRLE         bool   // This encoder is for RLE
	preDefined     bool   // This encoder is predefined.
	reUsed         bool   // Set to know when the encoder has been reused.
	rleVal         uint8  // RLE Symbol
	maxBits        uint8  // Maximum output bits after transform.

	// TODO: Technically zstd should be fine with 64 bytes.
	count [256]uint32 // Histogram: occurrences of each symbol.
	norm  [256]int16  // Normalized counts; -1 marks a low probability symbol.
}

// cTable contains tables used for compression.
type cTable struct {
	tableSymbol []byte            // Symbol stored at each table position (filled by buildCTable).
	stateTable  []uint16          // Next-state values, sorted by symbol order.
	symbolTT    []symbolTransform // Per-symbol state transforms, indexed by symbol.
}

// symbolTransform contains the state transform for a symbol.
type symbolTransform struct {
	deltaNbBits    uint32 // Bit count in the upper 16 bits minus a state offset; see buildCTable.
	deltaFindState int16  // Offset added when looking up the next state in stateTable.
	outBits        uint8  // Extra output bits for the symbol, assigned by setBits.
}

// String prints values as a human readable string.
56 func (s symbolTransform) String() string { 57 return fmt.Sprintf("{deltabits: %08x, findstate:%d outbits:%d}", s.deltaNbBits, s.deltaFindState, s.outBits) 58 } 59 60 // Histogram allows to populate the histogram and skip that step in the compression, 61 // It otherwise allows to inspect the histogram when compression is done. 62 // To indicate that you have populated the histogram call HistogramFinished 63 // with the value of the highest populated symbol, as well as the number of entries 64 // in the most populated entry. These are accepted at face value. 65 func (s *fseEncoder) Histogram() *[256]uint32 { 66 return &s.count 67 } 68 69 // HistogramFinished can be called to indicate that the histogram has been populated. 70 // maxSymbol is the index of the highest set symbol of the next data segment. 71 // maxCount is the number of entries in the most populated entry. 72 // These are accepted at face value. 73 func (s *fseEncoder) HistogramFinished(maxSymbol uint8, maxCount int) { 74 s.maxCount = maxCount 75 s.symbolLen = uint16(maxSymbol) + 1 76 s.clearCount = maxCount != 0 77 } 78 79 // allocCtable will allocate tables needed for compression. 80 // If existing tables a re big enough, they are simply re-used. 81 func (s *fseEncoder) allocCtable() { 82 tableSize := 1 << s.actualTableLog 83 // get tableSymbol that is big enough. 84 if cap(s.ct.tableSymbol) < tableSize { 85 s.ct.tableSymbol = make([]byte, tableSize) 86 } 87 s.ct.tableSymbol = s.ct.tableSymbol[:tableSize] 88 89 ctSize := tableSize 90 if cap(s.ct.stateTable) < ctSize { 91 s.ct.stateTable = make([]uint16, ctSize) 92 } 93 s.ct.stateTable = s.ct.stateTable[:ctSize] 94 95 if cap(s.ct.symbolTT) < 256 { 96 s.ct.symbolTT = make([]symbolTransform, 256) 97 } 98 s.ct.symbolTT = s.ct.symbolTT[:256] 99 } 100 101 // buildCTable will populate the compression table so it is ready to be used. 
// It reads s.norm[:s.symbolLen] and s.actualTableLog and fills
// s.ct.tableSymbol, s.ct.stateTable and s.ct.symbolTT. It also sets s.zeroBits.
// An error is returned when the normalized counts are inconsistent with the
// table size.
func (s *fseEncoder) buildCTable() error {
	tableSize := uint32(1 << s.actualTableLog)
	// Low probability symbols (-1) are placed from the end of the table downwards.
	highThreshold := tableSize - 1
	var cumul [256]int16

	s.allocCtable()
	tableSymbol := s.ct.tableSymbol[:tableSize]
	// symbol start positions
	{
		cumul[0] = 0
		for ui, v := range s.norm[:s.symbolLen-1] {
			u := byte(ui) // one less than reference
			if v == -1 {
				// Low proba symbol
				cumul[u+1] = cumul[u] + 1
				tableSymbol[highThreshold] = u
				highThreshold--
			} else {
				cumul[u+1] = cumul[u] + v
			}
		}
		// Encode last symbol separately to avoid overflowing u
		u := int(s.symbolLen - 1)
		v := s.norm[s.symbolLen-1]
		if v == -1 {
			// Low proba symbol
			cumul[u+1] = cumul[u] + 1
			tableSymbol[highThreshold] = byte(u)
			highThreshold--
		} else {
			cumul[u+1] = cumul[u] + v
		}
		// The running total over all symbols must fill the table exactly.
		if uint32(cumul[s.symbolLen]) != tableSize {
			return fmt.Errorf("internal error: expected cumul[s.symbolLen] (%d) == tableSize (%d)", cumul[s.symbolLen], tableSize)
		}
		cumul[s.symbolLen] = int16(tableSize) + 1
	}
	// Spread symbols
	s.zeroBits = false
	{
		step := tableStep(tableSize)
		tableMask := tableSize - 1
		var position uint32
		// if any symbol > largeLimit, we may have 0 bits output.
		largeLimit := int16(1 << (s.actualTableLog - 1))
		for ui, v := range s.norm[:s.symbolLen] {
			symbol := byte(ui)
			if v > largeLimit {
				s.zeroBits = true
			}
			// Symbols with v <= 0 (absent or low probability) are skipped by this loop.
			for nbOccurrences := int16(0); nbOccurrences < v; nbOccurrences++ {
				tableSymbol[position] = symbol
				position = (position + step) & tableMask
				for position > highThreshold {
					position = (position + step) & tableMask
				} /* Low proba area */
			}
		}

		// Check if we have gone through all positions
		if position != 0 {
			return errors.New("position!=0")
		}
	}

	// Build table
	table := s.ct.stateTable
	{
		tsi := int(tableSize)
		for u, v := range tableSymbol {
			// TableU16 : sorted by symbol order; gives next state value
			table[cumul[v]] = uint16(tsi + u)
			cumul[v]++
		}
	}

	// Build Symbol Transformation Table
	{
		total := int16(0)
		symbolTT := s.ct.symbolTT[:s.symbolLen]
		tableLog := s.actualTableLog
		// Transform shared by every symbol that occupies a single table slot.
		tl := (uint32(tableLog) << 16) - (1 << tableLog)
		for i, v := range s.norm[:s.symbolLen] {
			switch v {
			case 0:
				// Symbol not present: no transform.
			case -1, 1:
				symbolTT[i].deltaNbBits = tl
				symbolTT[i].deltaFindState = total - 1
				total++
			default:
				maxBitsOut := uint32(tableLog) - highBit(uint32(v-1))
				minStatePlus := uint32(v) << maxBitsOut
				symbolTT[i].deltaNbBits = (maxBitsOut << 16) - minStatePlus
				symbolTT[i].deltaFindState = total - v
				total += v
			}
		}
		if total != int16(tableSize) {
			return fmt.Errorf("total mismatch %d (got) != %d (want)", total, tableSize)
		}
	}
	return nil
}

// rtbTable holds the rounding thresholds used by normalizeCount, indexed by
// a truncated probability below 8. The entry is scaled by vStep and the
// probability is rounded up when the fractional remainder exceeds it.
var rtbTable = [...]uint32{0, 473195, 504333, 520860, 550000, 700000, 750000, 830000}

// setRLE configures the encoder to encode the single repeated symbol val.
// The table log is set to 0, the state table is cut to one entry and the
// transform for val is zeroed.
func (s *fseEncoder) setRLE(val byte) {
	// Tables are allocated before the table log is reset to 0.
	s.allocCtable()
	s.actualTableLog = 0
	s.ct.stateTable = s.ct.stateTable[:1]
	s.ct.symbolTT[val] = symbolTransform{
		deltaFindState: 0,
		deltaNbBits:    0,
	}
	if debugEncoder {
		println("setRLE: val", val, "symbolTT", s.ct.symbolTT[val])
	}
	s.rleVal = val
	s.useRLE = true
}

// setBits will set output bits for the transform.
// if nil is provided, the number of bits is equal to the index.
// Nothing is changed for reused or predefined encoders.
// s.maxBits is updated to the largest number of bits assigned.
func (s *fseEncoder) setBits(transform []byte) {
	if s.reUsed || s.preDefined {
		return
	}
	if s.useRLE {
		// Only the RLE symbol needs its output bits set.
		if transform == nil {
			s.ct.symbolTT[s.rleVal].outBits = s.rleVal
			s.maxBits = s.rleVal
			return
		}
		s.maxBits = transform[s.rleVal]
		s.ct.symbolTT[s.rleVal].outBits = s.maxBits
		return
	}
	if transform == nil {
		for i := range s.ct.symbolTT[:s.symbolLen] {
			s.ct.symbolTT[i].outBits = uint8(i)
		}
		s.maxBits = uint8(s.symbolLen - 1)
		return
	}
	s.maxBits = 0
	for i, v := range transform[:s.symbolLen] {
		s.ct.symbolTT[i].outBits = v
		if v > s.maxBits {
			// We could assume bits always going up, but we play safe.
			s.maxBits = v
		}
	}
}

// normalizeCount will normalize the count of the symbols so
// the total is equal to the table size.
// If successful, compression tables will also be made ready.
// length is the total number of symbols counted in s.count.
// Reused encoders are left untouched. When a single symbol accounts for the
// whole input (s.maxCount == length) the encoder is only flagged as RLE and
// no table is built.
func (s *fseEncoder) normalizeCount(length int) error {
	if s.reUsed {
		return nil
	}
	s.optimalTableLog(length)
	var (
		tableLog = s.actualTableLog
		// Probabilities are computed in fixed point with 62 bits of precision.
		scale = 62 - uint64(tableLog)
		// NOTE(review): length == 0 would divide by zero here - confirm callers
		// never pass an empty input.
		step              = (1 << 62) / uint64(length)
		vStep             = uint64(1) << (scale - 20)
		stillToDistribute = int16(1 << tableLog)
		largest           int
		largestP          int16
		lowThreshold      = (uint32)(length >> tableLog)
	)
	if s.maxCount == length {
		s.useRLE = true
		return nil
	}
	s.useRLE = false
	for i, cnt := range s.count[:s.symbolLen] {
		// already handled
		// if (count[s] == s.length) return 0; /* rle special case */

		if cnt == 0 {
			s.norm[i] = 0
			continue
		}
		if cnt <= lowThreshold {
			// Low probability symbol: marked -1, consumes one table slot.
			s.norm[i] = -1
			stillToDistribute--
		} else {
			proba := (int16)((uint64(cnt) * step) >> scale)
			if proba < 8 {
				// Small probabilities are rounded up when the remainder
				// exceeds the threshold from rtbTable.
				restToBeat := vStep * uint64(rtbTable[proba])
				v := uint64(cnt)*step - (uint64(proba) << scale)
				if v > restToBeat {
					proba++
				}
			}
			if proba > largestP {
				largestP = proba
				largest = i
			}
			s.norm[i] = proba
			stillToDistribute -= proba
		}
	}

	if -stillToDistribute >= (s.norm[largest] >> 1) {
		// corner case, need another normalization method
		err := s.normalizeCount2(length)
		if err != nil {
			return err
		}
		if debugAsserts {
			err = s.validateNorm()
			if err != nil {
				return err
			}
		}
		return s.buildCTable()
	}
	// The rounding surplus or deficit is absorbed by the most probable symbol.
	s.norm[largest] += stillToDistribute
	if debugAsserts {
		err := s.validateNorm()
		if err != nil {
			return err
		}
	}
	return s.buildCTable()
}

// Secondary normalization method.
// To be used when primary method fails.
334 func (s *fseEncoder) normalizeCount2(length int) error { 335 const notYetAssigned = -2 336 var ( 337 distributed uint32 338 total = uint32(length) 339 tableLog = s.actualTableLog 340 lowThreshold = total >> tableLog 341 lowOne = (total * 3) >> (tableLog + 1) 342 ) 343 for i, cnt := range s.count[:s.symbolLen] { 344 if cnt == 0 { 345 s.norm[i] = 0 346 continue 347 } 348 if cnt <= lowThreshold { 349 s.norm[i] = -1 350 distributed++ 351 total -= cnt 352 continue 353 } 354 if cnt <= lowOne { 355 s.norm[i] = 1 356 distributed++ 357 total -= cnt 358 continue 359 } 360 s.norm[i] = notYetAssigned 361 } 362 toDistribute := (1 << tableLog) - distributed 363 364 if (total / toDistribute) > lowOne { 365 // risk of rounding to zero 366 lowOne = (total * 3) / (toDistribute * 2) 367 for i, cnt := range s.count[:s.symbolLen] { 368 if (s.norm[i] == notYetAssigned) && (cnt <= lowOne) { 369 s.norm[i] = 1 370 distributed++ 371 total -= cnt 372 continue 373 } 374 } 375 toDistribute = (1 << tableLog) - distributed 376 } 377 if distributed == uint32(s.symbolLen)+1 { 378 // all values are pretty poor; 379 // probably incompressible data (should have already been detected); 380 // find max, then give all remaining points to max 381 var maxV int 382 var maxC uint32 383 for i, cnt := range s.count[:s.symbolLen] { 384 if cnt > maxC { 385 maxV = i 386 maxC = cnt 387 } 388 } 389 s.norm[maxV] += int16(toDistribute) 390 return nil 391 } 392 393 if total == 0 { 394 // all of the symbols were low enough for the lowOne or lowThreshold 395 for i := uint32(0); toDistribute > 0; i = (i + 1) % (uint32(s.symbolLen)) { 396 if s.norm[i] > 0 { 397 toDistribute-- 398 s.norm[i]++ 399 } 400 } 401 return nil 402 } 403 404 var ( 405 vStepLog = 62 - uint64(tableLog) 406 mid = uint64((1 << (vStepLog - 1)) - 1) 407 rStep = (((1 << vStepLog) * uint64(toDistribute)) + mid) / uint64(total) // scale on remaining 408 tmpTotal = mid 409 ) 410 for i, cnt := range s.count[:s.symbolLen] { 411 if s.norm[i] == 
notYetAssigned { 412 var ( 413 end = tmpTotal + uint64(cnt)*rStep 414 sStart = uint32(tmpTotal >> vStepLog) 415 sEnd = uint32(end >> vStepLog) 416 weight = sEnd - sStart 417 ) 418 if weight < 1 { 419 return errors.New("weight < 1") 420 } 421 s.norm[i] = int16(weight) 422 tmpTotal = end 423 } 424 } 425 return nil 426 } 427 428 // optimalTableLog calculates and sets the optimal tableLog in s.actualTableLog 429 func (s *fseEncoder) optimalTableLog(length int) { 430 tableLog := uint8(maxEncTableLog) 431 minBitsSrc := highBit(uint32(length)) + 1 432 minBitsSymbols := highBit(uint32(s.symbolLen-1)) + 2 433 minBits := uint8(minBitsSymbols) 434 if minBitsSrc < minBitsSymbols { 435 minBits = uint8(minBitsSrc) 436 } 437 438 maxBitsSrc := uint8(highBit(uint32(length-1))) - 2 439 if maxBitsSrc < tableLog { 440 // Accuracy can be reduced 441 tableLog = maxBitsSrc 442 } 443 if minBits > tableLog { 444 tableLog = minBits 445 } 446 // Need a minimum to safely represent all symbol values 447 if tableLog < minEncTablelog { 448 tableLog = minEncTablelog 449 } 450 if tableLog > maxEncTableLog { 451 tableLog = maxEncTableLog 452 } 453 s.actualTableLog = tableLog 454 } 455 456 // validateNorm validates the normalized histogram table. 
// It checks that the absolute values of s.norm[:s.symbolLen] add up to the
// table size and that s.count holds no entries at or beyond s.symbolLen.
// On failure the table log, symbol length and every count/norm pair are
// printed to standard output before the error is returned.
func (s *fseEncoder) validateNorm() (err error) {
	var total int
	for _, v := range s.norm[:s.symbolLen] {
		if v >= 0 {
			total += int(v)
		} else {
			total -= int(v)
		}
	}
	// The named result lets the deferred function see whether an error is returned.
	defer func() {
		if err == nil {
			return
		}
		fmt.Printf("selected TableLog: %d, Symbol length: %d\n", s.actualTableLog, s.symbolLen)
		for i, v := range s.norm[:s.symbolLen] {
			fmt.Printf("%3d: %5d -> %4d \n", i, s.count[i], v)
		}
	}()
	if total != (1 << s.actualTableLog) {
		return fmt.Errorf("warning: Total == %d != %d", total, 1<<s.actualTableLog)
	}
	for i, v := range s.count[s.symbolLen:] {
		if v != 0 {
			return fmt.Errorf("warning: Found symbol out of range, %d after cut", i)
		}
	}
	return nil
}

// writeCount will write the normalized histogram count to header.
// This is read back by readNCount.
// The header is appended to out and the extended slice is returned.
// RLE encoders append only the RLE symbol; predefined and reused encoders
// append nothing.
func (s *fseEncoder) writeCount(out []byte) ([]byte, error) {
	if s.useRLE {
		return append(out, s.rleVal), nil
	}
	if s.preDefined || s.reUsed {
		// Never write predefined.
		return out, nil
	}

	var (
		tableLog  = s.actualTableLog
		tableSize = 1 << tableLog
		previous0 bool
		charnum   uint16

		// maximum header size plus 2 extra bytes for final output if bitCount == 0.
		maxHeaderSize = ((int(s.symbolLen) * int(tableLog)) >> 3) + 3 + 2

		// Write Table Size
		bitStream = uint32(tableLog - minEncTablelog)
		bitCount  = uint(4)
		remaining = int16(tableSize + 1) /* +1 for extra accuracy */
		threshold = int16(tableSize)
		nbBits    = uint(tableLog + 1)
		outP      = len(out)
	)
	// Make sure out has room for the largest possible header, then extend
	// its length so the header bytes can be written by index.
	if cap(out) < outP+maxHeaderSize {
		out = append(out, make([]byte, maxHeaderSize*3)...)
		out = out[:len(out)-maxHeaderSize*3]
	}
	out = out[:outP+maxHeaderSize]

	// stops at 1
	for remaining > 1 {
		if previous0 {
			// The previous symbol had a zero count: emit the length of the
			// run of zero counts that follows it.
			start := charnum
			for s.norm[charnum] == 0 {
				charnum++
			}
			// Each group of 24 zero counts is written as 16 set bits.
			for charnum >= start+24 {
				start += 24
				bitStream += uint32(0xFFFF) << bitCount
				out[outP] = byte(bitStream)
				out[outP+1] = byte(bitStream >> 8)
				outP += 2
				bitStream >>= 16
			}
			// Each group of 3 zero counts is written as the 2-bit value 3.
			for charnum >= start+3 {
				start += 3
				bitStream += 3 << bitCount
				bitCount += 2
			}
			// The remainder (0..2) is written in 2 bits.
			bitStream += uint32(charnum-start) << bitCount
			bitCount += 2
			if bitCount > 16 {
				out[outP] = byte(bitStream)
				out[outP+1] = byte(bitStream >> 8)
				outP += 2
				bitStream >>= 16
				bitCount -= 16
			}
		}

		count := s.norm[charnum]
		charnum++
		max := (2*threshold - 1) - remaining
		// remaining drops by the absolute value of the count.
		if count < 0 {
			remaining += count
		} else {
			remaining -= count
		}
		count++ // +1 for extra accuracy
		if count >= threshold {
			count += max // [0..max[ [max..threshold[ (...) [threshold+max 2*threshold[
		}
		bitStream += uint32(count) << bitCount
		bitCount += nbBits
		// Values below max are stored with one bit less.
		if count < max {
			bitCount--
		}

		previous0 = count == 1
		if remaining < 1 {
			return nil, errors.New("internal error: remaining < 1")
		}
		// Fewer bits are needed as the remaining total shrinks.
		for remaining < threshold {
			nbBits--
			threshold >>= 1
		}

		if bitCount > 16 {
			out[outP] = byte(bitStream)
			out[outP+1] = byte(bitStream >> 8)
			outP += 2
			bitStream >>= 16
			bitCount -= 16
		}
	}

	if outP+2 > len(out) {
		return nil, fmt.Errorf("internal error: %d > %d, maxheader: %d, sl: %d, tl: %d, normcount: %v", outP+2, len(out), maxHeaderSize, s.symbolLen, int(tableLog), s.norm[:s.symbolLen])
	}
	// Flush the bits still pending and keep only the bytes actually used.
	out[outP] = byte(bitStream)
	out[outP+1] = byte(bitStream >> 8)
	outP += int((bitCount + 7) / 8)

	if charnum > s.symbolLen {
		return nil, errors.New("internal error: charnum > s.symbolLen")
	}
	return out[:outP], nil
}

// Approximate symbol cost, as fractional value, using fixed-point format (accuracyLog fractional bits)
// note 1 : assume symbolValue is valid (<= maxSymbolValue)
// note 2 : if freq[symbolValue]==0, @return a fake cost of tableLog+1 bits *
// The cost is derived from the symbol's deltaNbBits entry in s.ct.symbolTT.
// When debugAsserts is set, violated preconditions cause a panic.
func (s *fseEncoder) bitCost(symbolValue uint8, accuracyLog uint32) uint32 {
	minNbBits := s.ct.symbolTT[symbolValue].deltaNbBits >> 16
	threshold := (minNbBits + 1) << 16
	if debugAsserts {
		if !(s.actualTableLog < 16) {
			panic("!s.actualTableLog < 16")
		}
		// ensure enough room for renormalization double shift
		if !(uint8(accuracyLog) < 31-s.actualTableLog) {
			panic("!uint8(accuracyLog) < 31-s.actualTableLog")
		}
	}
	tableSize := uint32(1) << s.actualTableLog
	deltaFromThreshold := threshold - (s.ct.symbolTT[symbolValue].deltaNbBits + tableSize)
	// linear interpolation (very approximate)
	normalizedDeltaFromThreshold := (deltaFromThreshold << accuracyLog) >> s.actualTableLog
	bitMultiplier := uint32(1) << accuracyLog
	if debugAsserts {
		if s.ct.symbolTT[symbolValue].deltaNbBits+tableSize > threshold {
			panic("s.ct.symbolTT[symbolValue].deltaNbBits+tableSize > threshold")
		}
		if normalizedDeltaFromThreshold > bitMultiplier {
			panic("normalizedDeltaFromThreshold > bitMultiplier")
		}
	}
	return (minNbBits+1)*bitMultiplier - normalizedDeltaFromThreshold
}

// Returns the cost in bits of encoding the distribution in count using ctable.
// Histogram should only be up to the last non-zero symbol.
// Returns math.MaxUint32 if ctable cannot represent all the symbols in count,
// if this is an RLE encoder, or if a symbol costs more than tableLog+1 bits.
func (s *fseEncoder) approxSize(hist []uint32) uint32 {
	if int(s.symbolLen) < len(hist) {
		// More symbols than we have.
		return math.MaxUint32
	}
	if s.useRLE {
		// We will never reuse RLE encoders.
		return math.MaxUint32
	}
	const kAccuracyLog = 8
	badCost := (uint32(s.actualTableLog) + 1) << kAccuracyLog
	var cost uint32
	for i, v := range hist {
		if v == 0 {
			continue
		}
		// A symbol that occurs in hist but has no table entry cannot be encoded.
		if s.norm[i] == 0 {
			return math.MaxUint32
		}
		bitCost := s.bitCost(uint8(i), kAccuracyLog)
		if bitCost > badCost {
			return math.MaxUint32
		}
		cost += v * bitCost
	}
	// Drop the fractional bits of the fixed-point sum.
	return cost >> kAccuracyLog
}

// maxHeaderSize returns the maximum header size in bits.
// This is not exact size, but we want a penalty for new tables anyway.
// Predefined encoders report 0 and RLE encoders report 8.
func (s *fseEncoder) maxHeaderSize() uint32 {
	if s.preDefined {
		return 0
	}
	if s.useRLE {
		return 8
	}
	return (((uint32(s.symbolLen) * uint32(s.actualTableLog)) >> 3) + 3) * 8
}

// cState contains the compression state of a stream.
type cState struct {
	bw         *bitWriter // Destination for the emitted bits.
	stateTable []uint16   // Next-state table taken from the cTable passed to init.
	state      uint16     // Current state.
}

// init will initialize the compression state to the first symbol of the stream.
682 func (c *cState) init(bw *bitWriter, ct *cTable, first symbolTransform) { 683 c.bw = bw 684 c.stateTable = ct.stateTable 685 if len(c.stateTable) == 1 { 686 // RLE 687 c.stateTable[0] = uint16(0) 688 c.state = 0 689 return 690 } 691 nbBitsOut := (first.deltaNbBits + (1 << 15)) >> 16 692 im := int32((nbBitsOut << 16) - first.deltaNbBits) 693 lu := (im >> nbBitsOut) + int32(first.deltaFindState) 694 c.state = c.stateTable[lu] 695 } 696 697 // flush will write the tablelog to the output and flush the remaining full bytes. 698 func (c *cState) flush(tableLog uint8) { 699 c.bw.flush32() 700 c.bw.addBits16NC(c.state, tableLog) 701 }