github.com/AndrienkoAleksandr/go@v0.0.19/src/intern/zstd/block.go (about)

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package zstd
     6  
     7  import (
     8  	"io"
     9  )
    10  
    11  // debug can be set to true in the source to make execSeqs print each decoded sequence (literal, offset, match) using println.
    12  const debug = false
    13  
    14  // compressedBlock decompresses a compressed block, storing the decompressed
    15  // data in r.buffer. The blockSize argument is the compressed size.
    16  // RFC 3.1.1.3.
    17  func (r *Reader) compressedBlock(blockSize int) error {
    18  	if len(r.compressedBuf) >= blockSize {
    19  		r.compressedBuf = r.compressedBuf[:blockSize]
    20  	} else {
    21  		// We know that blockSize <= 128K,
    22  		// so this won't allocate an enormous amount.
    23  		need := blockSize - len(r.compressedBuf)
    24  		r.compressedBuf = append(r.compressedBuf, make([]byte, need)...)
    25  	}
    26  
    27  	if _, err := io.ReadFull(r.r, r.compressedBuf); err != nil {
    28  		return r.wrapNonEOFError(0, err)
    29  	}
    30  
    31  	data := block(r.compressedBuf)
    32  	off := 0
    33  	r.buffer = r.buffer[:0]
    34  
    35  	litoff, litbuf, err := r.readLiterals(data, off, r.literals[:0])
    36  	if err != nil {
    37  		return err
    38  	}
    39  	r.literals = litbuf
    40  
    41  	off = litoff
    42  
    43  	seqCount, off, err := r.initSeqs(data, off)
    44  	if err != nil {
    45  		return err
    46  	}
    47  
    48  	if seqCount == 0 {
    49  		// No sequences, just literals.
    50  		if off < len(data) {
    51  			return r.makeError(off, "extraneous data after no sequences")
    52  		}
    53  		if len(litbuf) == 0 {
    54  			return r.makeError(off, "no sequences and no literals")
    55  		}
    56  		r.buffer = append(r.buffer, litbuf...)
    57  		return nil
    58  	}
    59  
    60  	return r.execSeqs(data, off, litbuf, seqCount)
    61  }
    62  
    63  // seqCode is the kind of sequence codes we have to handle.
    64  type seqCode int
    65  
    66  const (
    67  	seqLiteral seqCode = iota
    68  	seqOffset
    69  	seqMatch
    70  )
    71  
    72  // seqCodeInfoData is the information needed to set up seqTables and
    73  // seqTableBits for a particular kind of sequence code; setSeqTable reads one of these per kind.
    74  type seqCodeInfoData struct {
    75  	predefTable     []fseBaselineEntry // predefined FSE baseline table, installed as-is for Predefined_Mode
    76  	predefTableBits int                // number of bits in predefTable
    77  	maxSym          int                // max symbol value in FSE; passed to readFSE
    78  	maxBits         int                // max bits for FSE; scratch and per-kind table buffers are allocated with 1<<maxBits entries
    79  
    80  	// toBaseline converts from an FSE table to an FSE baseline table. setSeqTable calls it with the Reader, the current block offset, the source FSE entries and the destination baseline table.
    81  	toBaseline func(*Reader, int, []fseEntry, []fseBaselineEntry) error
    82  }
    83  
    84  // seqCodeInfo is the seqCodeInfoData for each kind of sequence code, indexed by seqCode; setSeqTable looks entries up as seqCodeInfo[kind].
    85  var seqCodeInfo = [3]seqCodeInfoData{
    86  	seqLiteral: {
    87  		predefTable:     predefinedLiteralTable[:],
    88  		predefTableBits: 6,
    89  		maxSym:          35,
    90  		maxBits:         9,
    91  		toBaseline:      (*Reader).makeLiteralBaselineFSE,
    92  	},
    93  	seqOffset: {
    94  		predefTable:     predefinedOffsetTable[:],
    95  		predefTableBits: 5,
    96  		maxSym:          31,
    97  		maxBits:         8,
    98  		toBaseline:      (*Reader).makeOffsetBaselineFSE,
    99  	},
   100  	seqMatch: {
   101  		predefTable:     predefinedMatchTable[:],
   102  		predefTableBits: 6,
   103  		maxSym:          52,
   104  		maxBits:         9,
   105  		toBaseline:      (*Reader).makeMatchBaselineFSE,
   106  	},
   107  }
   108  
   109  // initSeqs reads the Sequences_Section_Header and sets up the FSE
   110  // tables used to read the sequence codes. It returns the number of
   111  // sequences and the new offset. RFC 3.1.1.3.2.1.
   112  func (r *Reader) initSeqs(data block, off int) (int, int, error) {
   113  	if off >= len(data) {
   114  		return 0, 0, r.makeEOFError(off)
   115  	}
   116  
   117  	seqHdr := data[off]
   118  	off++
   119  	if seqHdr == 0 {
   120  		return 0, off, nil
   121  	}
   122  
   123  	var seqCount int
   124  	if seqHdr < 128 {
   125  		seqCount = int(seqHdr)
   126  	} else if seqHdr < 255 {
   127  		if off >= len(data) {
   128  			return 0, 0, r.makeEOFError(off)
   129  		}
   130  		seqCount = ((int(seqHdr) - 128) << 8) + int(data[off])
   131  		off++
   132  	} else {
   133  		if off+1 >= len(data) {
   134  			return 0, 0, r.makeEOFError(off)
   135  		}
   136  		seqCount = int(data[off]) + (int(data[off+1]) << 8) + 0x7f00
   137  		off += 2
   138  	}
   139  
   140  	// Read the Symbol_Compression_Modes byte.
   141  
   142  	if off >= len(data) {
   143  		return 0, 0, r.makeEOFError(off)
   144  	}
   145  	symMode := data[off]
   146  	if symMode&3 != 0 {
   147  		return 0, 0, r.makeError(off, "invalid symbol compression mode")
   148  	}
   149  	off++
   150  
   151  	// Set up the FSE tables used to decode the sequence codes.
   152  
   153  	var err error
   154  	off, err = r.setSeqTable(data, off, seqLiteral, (symMode>>6)&3)
   155  	if err != nil {
   156  		return 0, 0, err
   157  	}
   158  
   159  	off, err = r.setSeqTable(data, off, seqOffset, (symMode>>4)&3)
   160  	if err != nil {
   161  		return 0, 0, err
   162  	}
   163  
   164  	off, err = r.setSeqTable(data, off, seqMatch, (symMode>>2)&3)
   165  	if err != nil {
   166  		return 0, 0, err
   167  	}
   168  
   169  	return seqCount, off, nil
   170  }
   171  
   172  // setSeqTable uses the Compression_Mode in mode to set up r.seqTables and
   173  // r.seqTableBits for kind. We store these in the Reader because one of
   174  // the modes simply reuses the value from the last block in the frame.
   175  func (r *Reader) setSeqTable(data block, off int, kind seqCode, mode byte) (int, error) {
   176  	info := &seqCodeInfo[kind]
   177  	switch mode {
   178  	case 0:
   179  		// Predefined_Mode
   180  		r.seqTables[kind] = info.predefTable
   181  		r.seqTableBits[kind] = uint8(info.predefTableBits)
   182  		return off, nil
   183  
   184  	case 1:
   185  		// RLE_Mode
   186  		if off >= len(data) {
   187  			return 0, r.makeEOFError(off)
   188  		}
   189  		rle := data[off]
   190  		off++
   191  
   192  		// Build a simple baseline table that always returns rle.
   193  
   194  		entry := []fseEntry{
   195  			{
   196  				sym:  rle,
   197  				bits: 0,
   198  				base: 0,
   199  			},
   200  		}
   201  		if cap(r.seqTableBuffers[kind]) == 0 {
   202  			r.seqTableBuffers[kind] = make([]fseBaselineEntry, 1<<info.maxBits)
   203  		}
   204  		r.seqTableBuffers[kind] = r.seqTableBuffers[kind][:1]
   205  		if err := info.toBaseline(r, off, entry, r.seqTableBuffers[kind]); err != nil {
   206  			return 0, err
   207  		}
   208  
   209  		r.seqTables[kind] = r.seqTableBuffers[kind]
   210  		r.seqTableBits[kind] = 0
   211  		return off, nil
   212  
   213  	case 2:
   214  		// FSE_Compressed_Mode
   215  		if cap(r.fseScratch) < 1<<info.maxBits {
   216  			r.fseScratch = make([]fseEntry, 1<<info.maxBits)
   217  		}
   218  		r.fseScratch = r.fseScratch[:1<<info.maxBits]
   219  
   220  		tableBits, roff, err := r.readFSE(data, off, info.maxSym, info.maxBits, r.fseScratch)
   221  		if err != nil {
   222  			return 0, err
   223  		}
   224  		r.fseScratch = r.fseScratch[:1<<tableBits]
   225  
   226  		if cap(r.seqTableBuffers[kind]) == 0 {
   227  			r.seqTableBuffers[kind] = make([]fseBaselineEntry, 1<<info.maxBits)
   228  		}
   229  		r.seqTableBuffers[kind] = r.seqTableBuffers[kind][:1<<tableBits]
   230  
   231  		if err := info.toBaseline(r, roff, r.fseScratch, r.seqTableBuffers[kind]); err != nil {
   232  			return 0, err
   233  		}
   234  
   235  		r.seqTables[kind] = r.seqTableBuffers[kind]
   236  		r.seqTableBits[kind] = uint8(tableBits)
   237  		return roff, nil
   238  
   239  	case 3:
   240  		// Repeat_Mode
   241  		if len(r.seqTables[kind]) == 0 {
   242  			return 0, r.makeError(off, "missing repeat sequence FSE table")
   243  		}
   244  		return off, nil
   245  	}
   246  	panic("unreachable")
   247  }
   248  
   249  // execSeqs reads and executes the sequences. RFC 3.1.1.3.2.1.2.
   250  func (r *Reader) execSeqs(data block, off int, litbuf []byte, seqCount int) error {
   251  	// Set up the initial states for the sequence code readers.
   252  
   253  	rbr, err := r.makeReverseBitReader(data, len(data)-1, off)
   254  	if err != nil {
   255  		return err
   256  	}
   257  
   258  	literalState, err := rbr.val(r.seqTableBits[seqLiteral])
   259  	if err != nil {
   260  		return err
   261  	}
   262  
   263  	offsetState, err := rbr.val(r.seqTableBits[seqOffset])
   264  	if err != nil {
   265  		return err
   266  	}
   267  
   268  	matchState, err := rbr.val(r.seqTableBits[seqMatch])
   269  	if err != nil {
   270  		return err
   271  	}
   272  
   273  	// Read and perform all the sequences. RFC 3.1.1.4.
   274  
   275  	seq := 0
   276  	for seq < seqCount {
   277  		if len(r.buffer)+len(litbuf) > 128<<10 {
   278  			return rbr.makeError("uncompressed size too big")
   279  		}
   280  
   281  		ptoffset := &r.seqTables[seqOffset][offsetState]
   282  		ptmatch := &r.seqTables[seqMatch][matchState]
   283  		ptliteral := &r.seqTables[seqLiteral][literalState]
   284  
   285  		add, err := rbr.val(ptoffset.basebits)
   286  		if err != nil {
   287  			return err
   288  		}
   289  		offset := ptoffset.baseline + add
   290  
   291  		add, err = rbr.val(ptmatch.basebits)
   292  		if err != nil {
   293  			return err
   294  		}
   295  		match := ptmatch.baseline + add
   296  
   297  		add, err = rbr.val(ptliteral.basebits)
   298  		if err != nil {
   299  			return err
   300  		}
   301  		literal := ptliteral.baseline + add
   302  
   303  		// Handle repeat offsets. RFC 3.1.1.5.
   304  		// See the comment in makeOffsetBaselineFSE.
   305  		if ptoffset.basebits > 1 {
   306  			r.repeatedOffset3 = r.repeatedOffset2
   307  			r.repeatedOffset2 = r.repeatedOffset1
   308  			r.repeatedOffset1 = offset
   309  		} else {
   310  			if literal == 0 {
   311  				offset++
   312  			}
   313  			switch offset {
   314  			case 1:
   315  				offset = r.repeatedOffset1
   316  			case 2:
   317  				offset = r.repeatedOffset2
   318  				r.repeatedOffset2 = r.repeatedOffset1
   319  				r.repeatedOffset1 = offset
   320  			case 3:
   321  				offset = r.repeatedOffset3
   322  				r.repeatedOffset3 = r.repeatedOffset2
   323  				r.repeatedOffset2 = r.repeatedOffset1
   324  				r.repeatedOffset1 = offset
   325  			case 4:
   326  				offset = r.repeatedOffset1 - 1
   327  				r.repeatedOffset3 = r.repeatedOffset2
   328  				r.repeatedOffset2 = r.repeatedOffset1
   329  				r.repeatedOffset1 = offset
   330  			}
   331  		}
   332  
   333  		seq++
   334  		if seq < seqCount {
   335  			// Update the states.
   336  			add, err = rbr.val(ptliteral.bits)
   337  			if err != nil {
   338  				return err
   339  			}
   340  			literalState = uint32(ptliteral.base) + add
   341  
   342  			add, err = rbr.val(ptmatch.bits)
   343  			if err != nil {
   344  				return err
   345  			}
   346  			matchState = uint32(ptmatch.base) + add
   347  
   348  			add, err = rbr.val(ptoffset.bits)
   349  			if err != nil {
   350  				return err
   351  			}
   352  			offsetState = uint32(ptoffset.base) + add
   353  		}
   354  
   355  		// The next sequence is now in literal, offset, match.
   356  
   357  		if debug {
   358  			println("literal", literal, "offset", offset, "match", match)
   359  		}
   360  
   361  		// Copy literal bytes from litbuf.
   362  		if literal > uint32(len(litbuf)) {
   363  			return rbr.makeError("literal byte overflow")
   364  		}
   365  		if literal > 0 {
   366  			r.buffer = append(r.buffer, litbuf[:literal]...)
   367  			litbuf = litbuf[literal:]
   368  		}
   369  
   370  		if match > 0 {
   371  			if err := r.copyFromWindow(&rbr, offset, match); err != nil {
   372  				return err
   373  			}
   374  		}
   375  	}
   376  
   377  	if len(litbuf) > 0 {
   378  		r.buffer = append(r.buffer, litbuf...)
   379  	}
   380  
   381  	if rbr.cnt != 0 {
   382  		return r.makeError(off, "extraneous data after sequences")
   383  	}
   384  
   385  	return nil
   386  }
   387  
   388  // Copy match bytes from the decoded output, or the window, at offset.
   389  func (r *Reader) copyFromWindow(rbr *reverseBitReader, offset, match uint32) error {
   390  	if offset == 0 {
   391  		return rbr.makeError("invalid zero offset")
   392  	}
   393  
   394  	lenBlock := uint32(len(r.buffer))
   395  	if lenBlock < offset {
   396  		lenWindow := uint32(len(r.window))
   397  		windowOffset := offset - lenBlock
   398  		if windowOffset > lenWindow {
   399  			return rbr.makeError("offset past window")
   400  		}
   401  		from := lenWindow - windowOffset
   402  		if from+match <= lenWindow {
   403  			r.buffer = append(r.buffer, r.window[from:from+match]...)
   404  			return nil
   405  		}
   406  		r.buffer = append(r.buffer, r.window[from:]...)
   407  		copied := lenWindow - from
   408  		offset -= copied
   409  		match -= copied
   410  
   411  		if offset == 0 && match > 0 {
   412  			return rbr.makeError("invalid offset")
   413  		}
   414  	}
   415  
   416  	from := lenBlock - offset
   417  	if offset >= match {
   418  		r.buffer = append(r.buffer, r.buffer[from:from+match]...)
   419  		return nil
   420  	}
   421  
   422  	// We are being asked to copy data that we are adding to the
   423  	// buffer in the same copy.
   424  	for match > 0 {
   425  		var copy uint32
   426  		if offset >= match {
   427  			copy = match
   428  		} else {
   429  			copy = offset
   430  		}
   431  		r.buffer = append(r.buffer, r.buffer[from:from+copy]...)
   432  		match -= copy
   433  		from += copy
   434  	}
   435  	return nil
   436  }