github.com/go-asm/go@v1.21.1-0.20240213172139-40c5ead50c48/zstd/fse.go (about)

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package zstd
     6  
     7  import (
     8  	"math/bits"
     9  )
    10  
    11  // fseEntry is one entry in an FSE table.
    12  type fseEntry struct {
    13  	sym  uint8  // value that this entry records
    14  	bits uint8  // number of bits to read to determine next state
    15  	base uint16 // add those bits to this state to get the next state
    16  }
    17  
    18  // readFSE reads an FSE table from data starting at off.
    19  // maxSym is the maximum symbol value.
    20  // maxBits is the maximum number of bits permitted for symbols in the table.
    21  // The FSE is written into table, which must be at least 1<<maxBits in size.
    22  // This returns the number of bits in the FSE table and the new offset.
    23  // RFC 4.1.1.
    24  func (r *Reader) readFSE(data block, off, maxSym, maxBits int, table []fseEntry) (tableBits, roff int, err error) {
    25  	br := r.makeBitReader(data, off)
    26  	if err := br.moreBits(); err != nil {
    27  		return 0, 0, err
    28  	}
    29  
    30  	accuracyLog := int(br.val(4)) + 5
    31  	if accuracyLog > maxBits {
    32  		return 0, 0, br.makeError("FSE accuracy log too large")
    33  	}
    34  
    35  	// The number of remaining probabilities, plus 1.
    36  	// This determines the number of bits to be read for the next value.
    37  	remaining := (1 << accuracyLog) + 1
    38  
    39  	// The current difference between small and large values,
    40  	// which depends on the number of remaining values.
    41  	// Small values use 1 less bit.
    42  	threshold := 1 << accuracyLog
    43  
    44  	// The number of bits needed to compute threshold.
    45  	bitsNeeded := accuracyLog + 1
    46  
    47  	// The next character value.
    48  	sym := 0
    49  
    50  	// Whether the last count was 0.
    51  	prev0 := false
    52  
    53  	var norm [256]int16
    54  
    55  	for remaining > 1 && sym <= maxSym {
    56  		if err := br.moreBits(); err != nil {
    57  			return 0, 0, err
    58  		}
    59  
    60  		if prev0 {
    61  			// Previous count was 0, so there is a 2-bit
    62  			// repeat flag. If the 2-bit flag is 0b11,
    63  			// it adds 3 and then there is another repeat flag.
    64  			zsym := sym
    65  			for (br.bits & 0xfff) == 0xfff {
    66  				zsym += 3 * 6
    67  				br.bits >>= 12
    68  				br.cnt -= 12
    69  				if err := br.moreBits(); err != nil {
    70  					return 0, 0, err
    71  				}
    72  			}
    73  			for (br.bits & 3) == 3 {
    74  				zsym += 3
    75  				br.bits >>= 2
    76  				br.cnt -= 2
    77  				if err := br.moreBits(); err != nil {
    78  					return 0, 0, err
    79  				}
    80  			}
    81  
    82  			// We have at least 14 bits here,
    83  			// no need to call moreBits
    84  
    85  			zsym += int(br.val(2))
    86  
    87  			if zsym > maxSym {
    88  				return 0, 0, br.makeError("FSE symbol index overflow")
    89  			}
    90  
    91  			for ; sym < zsym; sym++ {
    92  				norm[uint8(sym)] = 0
    93  			}
    94  
    95  			prev0 = false
    96  			continue
    97  		}
    98  
    99  		max := (2*threshold - 1) - remaining
   100  		var count int
   101  		if int(br.bits&uint32(threshold-1)) < max {
   102  			// A small value.
   103  			count = int(br.bits & uint32((threshold - 1)))
   104  			br.bits >>= bitsNeeded - 1
   105  			br.cnt -= uint32(bitsNeeded - 1)
   106  		} else {
   107  			// A large value.
   108  			count = int(br.bits & uint32((2*threshold - 1)))
   109  			if count >= threshold {
   110  				count -= max
   111  			}
   112  			br.bits >>= bitsNeeded
   113  			br.cnt -= uint32(bitsNeeded)
   114  		}
   115  
   116  		count--
   117  		if count >= 0 {
   118  			remaining -= count
   119  		} else {
   120  			remaining--
   121  		}
   122  		if sym >= 256 {
   123  			return 0, 0, br.makeError("FSE sym overflow")
   124  		}
   125  		norm[uint8(sym)] = int16(count)
   126  		sym++
   127  
   128  		prev0 = count == 0
   129  
   130  		for remaining < threshold {
   131  			bitsNeeded--
   132  			threshold >>= 1
   133  		}
   134  	}
   135  
   136  	if remaining != 1 {
   137  		return 0, 0, br.makeError("too many symbols in FSE table")
   138  	}
   139  
   140  	for ; sym <= maxSym; sym++ {
   141  		norm[uint8(sym)] = 0
   142  	}
   143  
   144  	br.backup()
   145  
   146  	if err := r.buildFSE(off, norm[:maxSym+1], table, accuracyLog); err != nil {
   147  		return 0, 0, err
   148  	}
   149  
   150  	return accuracyLog, int(br.off), nil
   151  }
   152  
   153  // buildFSE builds an FSE decoding table from a list of probabilities.
   154  // The probabilities are in norm. next is scratch space. The number of bits
   155  // in the table is tableBits.
   156  func (r *Reader) buildFSE(off int, norm []int16, table []fseEntry, tableBits int) error {
   157  	tableSize := 1 << tableBits
   158  	highThreshold := tableSize - 1
   159  
   160  	var next [256]uint16
   161  
   162  	for i, n := range norm {
   163  		if n >= 0 {
   164  			next[uint8(i)] = uint16(n)
   165  		} else {
   166  			table[highThreshold].sym = uint8(i)
   167  			highThreshold--
   168  			next[uint8(i)] = 1
   169  		}
   170  	}
   171  
   172  	pos := 0
   173  	step := (tableSize >> 1) + (tableSize >> 3) + 3
   174  	mask := tableSize - 1
   175  	for i, n := range norm {
   176  		for j := 0; j < int(n); j++ {
   177  			table[pos].sym = uint8(i)
   178  			pos = (pos + step) & mask
   179  			for pos > highThreshold {
   180  				pos = (pos + step) & mask
   181  			}
   182  		}
   183  	}
   184  	if pos != 0 {
   185  		return r.makeError(off, "FSE count error")
   186  	}
   187  
   188  	for i := 0; i < tableSize; i++ {
   189  		sym := table[i].sym
   190  		nextState := next[sym]
   191  		next[sym]++
   192  
   193  		if nextState == 0 {
   194  			return r.makeError(off, "FSE state error")
   195  		}
   196  
   197  		highBit := 15 - bits.LeadingZeros16(nextState)
   198  
   199  		bits := tableBits - highBit
   200  		table[i].bits = uint8(bits)
   201  		table[i].base = (nextState << bits) - uint16(tableSize)
   202  	}
   203  
   204  	return nil
   205  }
   206  
   207  // fseBaselineEntry is an entry in an FSE baseline table.
   208  // We use these for literal/match/length values.
   209  // Those require mapping the symbol to a baseline value,
   210  // and then reading zero or more bits and adding the value to the baseline.
   211  // Rather than looking these up in separate tables,
   212  // we convert the FSE table to an FSE baseline table.
   213  type fseBaselineEntry struct {
   214  	baseline uint32 // baseline for value that this entry represents
   215  	basebits uint8  // number of bits to read to add to baseline
   216  	bits     uint8  // number of bits to read to determine next state
   217  	base     uint16 // add the bits to this base to get the next state
   218  }
   219  
   220  // Given a literal length code, we need to read a number of bits and
   221  // add that to a baseline. For states 0 to 15 the baseline is the
   222  // state and the number of bits is zero. RFC 3.1.1.3.2.1.1.
   223  
   224  const literalLengthOffset = 16
   225  
   226  var literalLengthBase = []uint32{
   227  	16 | (1 << 24),
   228  	18 | (1 << 24),
   229  	20 | (1 << 24),
   230  	22 | (1 << 24),
   231  	24 | (2 << 24),
   232  	28 | (2 << 24),
   233  	32 | (3 << 24),
   234  	40 | (3 << 24),
   235  	48 | (4 << 24),
   236  	64 | (6 << 24),
   237  	128 | (7 << 24),
   238  	256 | (8 << 24),
   239  	512 | (9 << 24),
   240  	1024 | (10 << 24),
   241  	2048 | (11 << 24),
   242  	4096 | (12 << 24),
   243  	8192 | (13 << 24),
   244  	16384 | (14 << 24),
   245  	32768 | (15 << 24),
   246  	65536 | (16 << 24),
   247  }
   248  
   249  // makeLiteralBaselineFSE converts the literal length fseTable to baselineTable.
   250  func (r *Reader) makeLiteralBaselineFSE(off int, fseTable []fseEntry, baselineTable []fseBaselineEntry) error {
   251  	for i, e := range fseTable {
   252  		be := fseBaselineEntry{
   253  			bits: e.bits,
   254  			base: e.base,
   255  		}
   256  		if e.sym < literalLengthOffset {
   257  			be.baseline = uint32(e.sym)
   258  			be.basebits = 0
   259  		} else {
   260  			if e.sym > 35 {
   261  				return r.makeError(off, "FSE baseline symbol overflow")
   262  			}
   263  			idx := e.sym - literalLengthOffset
   264  			basebits := literalLengthBase[idx]
   265  			be.baseline = basebits & 0xffffff
   266  			be.basebits = uint8(basebits >> 24)
   267  		}
   268  		baselineTable[i] = be
   269  	}
   270  	return nil
   271  }
   272  
   273  // makeOffsetBaselineFSE converts the offset length fseTable to baselineTable.
   274  func (r *Reader) makeOffsetBaselineFSE(off int, fseTable []fseEntry, baselineTable []fseBaselineEntry) error {
   275  	for i, e := range fseTable {
   276  		be := fseBaselineEntry{
   277  			bits: e.bits,
   278  			base: e.base,
   279  		}
   280  		if e.sym > 31 {
   281  			return r.makeError(off, "FSE offset symbol overflow")
   282  		}
   283  
   284  		// The simple way to write this is
   285  		//     be.baseline = 1 << e.sym
   286  		//     be.basebits = e.sym
   287  		// That would give us an offset value that corresponds to
   288  		// the one described in the RFC. However, for offsets > 3
   289  		// we have to subtract 3. And for offset values 1, 2, 3
   290  		// we use a repeated offset.
   291  		//
   292  		// The baseline is always a power of 2, and is never 0,
   293  		// so for those low values we will see one entry that is
   294  		// baseline 1, basebits 0, and one entry that is baseline 2,
   295  		// basebits 1. All other entries will have baseline >= 4
   296  		// basebits >= 2.
   297  		//
   298  		// So we can check for RFC offset <= 3 by checking for
   299  		// basebits <= 1. That means that we can subtract 3 here
   300  		// and not worry about doing it in the hot loop.
   301  
   302  		be.baseline = 1 << e.sym
   303  		if e.sym >= 2 {
   304  			be.baseline -= 3
   305  		}
   306  		be.basebits = e.sym
   307  		baselineTable[i] = be
   308  	}
   309  	return nil
   310  }
   311  
   312  // Given a match length code, we need to read a number of bits and add
   313  // that to a baseline. For states 0 to 31 the baseline is state+3 and
   314  // the number of bits is zero. RFC 3.1.1.3.2.1.1.
   315  
   316  const matchLengthOffset = 32
   317  
   318  var matchLengthBase = []uint32{
   319  	35 | (1 << 24),
   320  	37 | (1 << 24),
   321  	39 | (1 << 24),
   322  	41 | (1 << 24),
   323  	43 | (2 << 24),
   324  	47 | (2 << 24),
   325  	51 | (3 << 24),
   326  	59 | (3 << 24),
   327  	67 | (4 << 24),
   328  	83 | (4 << 24),
   329  	99 | (5 << 24),
   330  	131 | (7 << 24),
   331  	259 | (8 << 24),
   332  	515 | (9 << 24),
   333  	1027 | (10 << 24),
   334  	2051 | (11 << 24),
   335  	4099 | (12 << 24),
   336  	8195 | (13 << 24),
   337  	16387 | (14 << 24),
   338  	32771 | (15 << 24),
   339  	65539 | (16 << 24),
   340  }
   341  
   342  // makeMatchBaselineFSE converts the match length fseTable to baselineTable.
   343  func (r *Reader) makeMatchBaselineFSE(off int, fseTable []fseEntry, baselineTable []fseBaselineEntry) error {
   344  	for i, e := range fseTable {
   345  		be := fseBaselineEntry{
   346  			bits: e.bits,
   347  			base: e.base,
   348  		}
   349  		if e.sym < matchLengthOffset {
   350  			be.baseline = uint32(e.sym) + 3
   351  			be.basebits = 0
   352  		} else {
   353  			if e.sym > 52 {
   354  				return r.makeError(off, "FSE baseline symbol overflow")
   355  			}
   356  			idx := e.sym - matchLengthOffset
   357  			basebits := matchLengthBase[idx]
   358  			be.baseline = basebits & 0xffffff
   359  			be.basebits = uint8(basebits >> 24)
   360  		}
   361  		baselineTable[i] = be
   362  	}
   363  	return nil
   364  }
   365  
// predefinedLiteralTable is the predefined table to use for literal lengths.
// Generated from table in RFC 3.1.1.3.2.2.1.
// Checked by TestPredefinedTables.
//
// The table has 64 entries (a 6-bit table), and entry i is slot i of the
// decoding table. Each element is written positionally in fseBaselineEntry
// field order: {baseline, basebits, bits, base}. It presumably equals what
// buildFSE followed by makeLiteralBaselineFSE would produce from the RFC's
// default literal length distribution; TestPredefinedTables is the check.
var predefinedLiteralTable = [...]fseBaselineEntry{
	{0, 0, 4, 0}, {0, 0, 4, 16}, {1, 0, 5, 32},
	{3, 0, 5, 0}, {4, 0, 5, 0}, {6, 0, 5, 0},
	{7, 0, 5, 0}, {9, 0, 5, 0}, {10, 0, 5, 0},
	{12, 0, 5, 0}, {14, 0, 6, 0}, {16, 1, 5, 0},
	{20, 1, 5, 0}, {22, 1, 5, 0}, {28, 2, 5, 0},
	{32, 3, 5, 0}, {48, 4, 5, 0}, {64, 6, 5, 32},
	{128, 7, 5, 0}, {256, 8, 6, 0}, {1024, 10, 6, 0},
	{4096, 12, 6, 0}, {0, 0, 4, 32}, {1, 0, 4, 0},
	{2, 0, 5, 0}, {4, 0, 5, 32}, {5, 0, 5, 0},
	{7, 0, 5, 32}, {8, 0, 5, 0}, {10, 0, 5, 32},
	{11, 0, 5, 0}, {13, 0, 6, 0}, {16, 1, 5, 32},
	{18, 1, 5, 0}, {22, 1, 5, 32}, {24, 2, 5, 0},
	{32, 3, 5, 32}, {40, 3, 5, 0}, {64, 6, 4, 0},
	{64, 6, 4, 16}, {128, 7, 5, 32}, {512, 9, 6, 0},
	{2048, 11, 6, 0}, {0, 0, 4, 48}, {1, 0, 4, 16},
	{2, 0, 5, 32}, {3, 0, 5, 32}, {5, 0, 5, 32},
	{6, 0, 5, 32}, {8, 0, 5, 32}, {9, 0, 5, 32},
	{11, 0, 5, 32}, {12, 0, 5, 32}, {15, 0, 6, 0},
	{18, 1, 5, 32}, {20, 1, 5, 32}, {24, 2, 5, 32},
	{28, 2, 5, 32}, {40, 3, 5, 32}, {48, 4, 5, 32},
	{65536, 16, 6, 0}, {32768, 15, 6, 0}, {16384, 14, 6, 0},
	{8192, 13, 6, 0},
}
   393  
// predefinedOffsetTable is the predefined table to use for offsets.
// Generated from table in RFC 3.1.1.3.2.2.3.
// Checked by TestPredefinedTables.
//
// The table has 32 entries (a 5-bit table), and entry i is slot i of the
// decoding table. Each element is written positionally in fseBaselineEntry
// field order: {baseline, basebits, bits, base]. As in
// makeOffsetBaselineFSE, basebits equals the offset code, and baselines for
// codes >= 2 already have 3 subtracted (for example, code 6 has baseline
// 64-3 = 61).
var predefinedOffsetTable = [...]fseBaselineEntry{
	{1, 0, 5, 0}, {61, 6, 4, 0}, {509, 9, 5, 0},
	{32765, 15, 5, 0}, {2097149, 21, 5, 0}, {5, 3, 5, 0},
	{125, 7, 4, 0}, {4093, 12, 5, 0}, {262141, 18, 5, 0},
	{8388605, 23, 5, 0}, {29, 5, 5, 0}, {253, 8, 4, 0},
	{16381, 14, 5, 0}, {1048573, 20, 5, 0}, {1, 2, 5, 0},
	{125, 7, 4, 16}, {2045, 11, 5, 0}, {131069, 17, 5, 0},
	{4194301, 22, 5, 0}, {13, 4, 5, 0}, {253, 8, 4, 16},
	{8189, 13, 5, 0}, {524285, 19, 5, 0}, {2, 1, 5, 0},
	{61, 6, 4, 16}, {1021, 10, 5, 0}, {65533, 16, 5, 0},
	{268435453, 28, 5, 0}, {134217725, 27, 5, 0}, {67108861, 26, 5, 0},
	{33554429, 25, 5, 0}, {16777213, 24, 5, 0},
}
   410  
// predefinedMatchTable is the predefined table to use for match lengths.
// Generated from table in RFC 3.1.1.3.2.2.2.
// Checked by TestPredefinedTables.
//
// The table has 64 entries (a 6-bit table), and entry i is slot i of the
// decoding table. Each element is written positionally in fseBaselineEntry
// field order: {baseline, basebits, bits, base}. It presumably equals what
// buildFSE followed by makeMatchBaselineFSE would produce from the RFC's
// default match length distribution; TestPredefinedTables is the check.
var predefinedMatchTable = [...]fseBaselineEntry{
	{3, 0, 6, 0}, {4, 0, 4, 0}, {5, 0, 5, 32},
	{6, 0, 5, 0}, {8, 0, 5, 0}, {9, 0, 5, 0},
	{11, 0, 5, 0}, {13, 0, 6, 0}, {16, 0, 6, 0},
	{19, 0, 6, 0}, {22, 0, 6, 0}, {25, 0, 6, 0},
	{28, 0, 6, 0}, {31, 0, 6, 0}, {34, 0, 6, 0},
	{37, 1, 6, 0}, {41, 1, 6, 0}, {47, 2, 6, 0},
	{59, 3, 6, 0}, {83, 4, 6, 0}, {131, 7, 6, 0},
	{515, 9, 6, 0}, {4, 0, 4, 16}, {5, 0, 4, 0},
	{6, 0, 5, 32}, {7, 0, 5, 0}, {9, 0, 5, 32},
	{10, 0, 5, 0}, {12, 0, 6, 0}, {15, 0, 6, 0},
	{18, 0, 6, 0}, {21, 0, 6, 0}, {24, 0, 6, 0},
	{27, 0, 6, 0}, {30, 0, 6, 0}, {33, 0, 6, 0},
	{35, 1, 6, 0}, {39, 1, 6, 0}, {43, 2, 6, 0},
	{51, 3, 6, 0}, {67, 4, 6, 0}, {99, 5, 6, 0},
	{259, 8, 6, 0}, {4, 0, 4, 32}, {4, 0, 4, 48},
	{5, 0, 4, 16}, {7, 0, 5, 32}, {8, 0, 5, 32},
	{10, 0, 5, 32}, {11, 0, 5, 32}, {14, 0, 6, 0},
	{17, 0, 6, 0}, {20, 0, 6, 0}, {23, 0, 6, 0},
	{26, 0, 6, 0}, {29, 0, 6, 0}, {32, 0, 6, 0},
	{65539, 16, 6, 0}, {32771, 15, 6, 0}, {16387, 14, 6, 0},
	{8195, 13, 6, 0}, {4099, 12, 6, 0}, {2051, 11, 6, 0},
	{1027, 10, 6, 0},
}