github.com/twelsh-aw/go/src@v0.0.0-20230516233729-a56fe86a7c81/internal/zstd/zstd.go (about)

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package zstd provides a decompressor for zstd streams,
     6  // described in RFC 8878. It does not support dictionaries.
     7  package zstd
     8  
     9  import (
    10  	"encoding/binary"
    11  	"errors"
    12  	"fmt"
    13  	"io"
    14  )
    15  
// fuzzing is a fuzzer hook set to true when fuzzing.
// This is used to reject cases where we don't match zstd.
// In this file it enables an extra limit on the window size
// while the frame header is parsed.
var fuzzing = false
    19  
// Reader implements [io.Reader] to read a zstd compressed stream.
// It also provides a ReadByte method.
type Reader struct {
	// The underlying Reader.
	r io.Reader

	// Whether we have read the frame header.
	// This is of interest when buffer is empty.
	// If true we expect to see a new block.
	// It is cleared again after the last block of a frame.
	sawFrameHeader bool

	// Whether the current frame expects a checksum.
	hasChecksum bool

	// Whether we have read at least one frame.
	// Until this is set, EOF where a frame should start is
	// reported as an unexpected EOF.
	readOneFrame bool

	// True if the frame size is not known.
	frameSizeUnknown bool

	// The number of uncompressed bytes remaining in the current frame.
	// If frameSizeUnknown is true, this is not valid.
	remainingFrameSize uint64

	// The number of bytes read from r up to the start of the current
	// block, for error reporting.
	blockOffset int64

	// Buffered decompressed data.
	buffer []byte
	// Current read offset in buffer.
	// Data in buffer[off:] has not yet been returned to the caller.
	off int

	// The current repeated offsets.
	// They are set to 1, 4, and 8 at the start of each frame.
	repeatedOffset1 uint32
	repeatedOffset2 uint32
	repeatedOffset3 uint32

	// The current Huffman tree used for compressing literals.
	// huffmanTableBits is set to zero at the start of each frame.
	huffmanTable     []uint16
	huffmanTableBits int

	// The window for back references.
	// The window is emptied at the start of each frame.
	windowSize int    // maximum required window size
	window     []byte // window data

	// A buffer available to hold a compressed block.
	compressedBuf []byte

	// A buffer for literals.
	literals []byte

	// Sequence decode FSE tables.
	// The tables are set to nil at the start of each frame.
	seqTables    [3][]fseBaselineEntry
	seqTableBits [3]uint8

	// Buffers for sequence decode FSE tables.
	seqTableBuffers [3][]fseBaselineEntry

	// Scratch space used for small reads, to avoid allocation.
	scratch [16]byte

	// A scratch table for reading an FSE. Only temporarily valid.
	fseScratch []fseEntry

	// For checksum computation.
	// Only the low 32 bits of the digest are compared against the
	// checksum stored at the end of a frame.
	checksum xxhash64
}
    87  
    88  // NewReader creates a new Reader that decompresses data from the given reader.
    89  func NewReader(input io.Reader) *Reader {
    90  	r := new(Reader)
    91  	r.Reset(input)
    92  	return r
    93  }
    94  
    95  // Reset discards the current state and starts reading a new stream from r.
    96  // This permits reusing a Reader rather than allocating a new one.
    97  func (r *Reader) Reset(input io.Reader) {
    98  	r.r = input
    99  
   100  	// Several fields are preserved to avoid allocation.
   101  	// Others are always set before they are used.
   102  	r.sawFrameHeader = false
   103  	r.hasChecksum = false
   104  	r.readOneFrame = false
   105  	r.frameSizeUnknown = false
   106  	r.remainingFrameSize = 0
   107  	r.blockOffset = 0
   108  	// buffer
   109  	r.off = 0
   110  	// repeatedOffset1
   111  	// repeatedOffset2
   112  	// repeatedOffset3
   113  	// huffmanTable
   114  	// huffmanTableBits
   115  	// windowSize
   116  	// window
   117  	// compressedBuf
   118  	// literals
   119  	// seqTables
   120  	// seqTableBits
   121  	// seqTableBuffers
   122  	// scratch
   123  	// fseScratch
   124  }
   125  
   126  // Read implements [io.Reader].
   127  func (r *Reader) Read(p []byte) (int, error) {
   128  	if err := r.refillIfNeeded(); err != nil {
   129  		return 0, err
   130  	}
   131  	n := copy(p, r.buffer[r.off:])
   132  	r.off += n
   133  	return n, nil
   134  }
   135  
   136  // ReadByte implements [io.ByteReader].
   137  func (r *Reader) ReadByte() (byte, error) {
   138  	if err := r.refillIfNeeded(); err != nil {
   139  		return 0, err
   140  	}
   141  	ret := r.buffer[r.off]
   142  	r.off++
   143  	return ret, nil
   144  }
   145  
   146  // refillIfNeeded reads the next block if necessary.
   147  func (r *Reader) refillIfNeeded() error {
   148  	for r.off >= len(r.buffer) {
   149  		if err := r.refill(); err != nil {
   150  			return err
   151  		}
   152  		r.off = 0
   153  	}
   154  	return nil
   155  }
   156  
   157  // refill reads and decompresses the next block.
   158  func (r *Reader) refill() error {
   159  	if !r.sawFrameHeader {
   160  		if err := r.readFrameHeader(); err != nil {
   161  			return err
   162  		}
   163  	}
   164  	return r.readBlock()
   165  }
   166  
   167  // readFrameHeader reads the frame header and prepares to read a block.
   168  func (r *Reader) readFrameHeader() error {
   169  retry:
   170  	relativeOffset := 0
   171  
   172  	// Read magic number. RFC 3.1.1.
   173  	if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
   174  		// We require that the stream contain at least one frame.
   175  		if err == io.EOF && !r.readOneFrame {
   176  			err = io.ErrUnexpectedEOF
   177  		}
   178  		return r.wrapError(relativeOffset, err)
   179  	}
   180  
   181  	if magic := binary.LittleEndian.Uint32(r.scratch[:4]); magic != 0xfd2fb528 {
   182  		if magic >= 0x184d2a50 && magic <= 0x184d2a5f {
   183  			// This is a skippable frame.
   184  			r.blockOffset += int64(relativeOffset) + 4
   185  			if err := r.skipFrame(); err != nil {
   186  				return err
   187  			}
   188  			goto retry
   189  		}
   190  
   191  		return r.makeError(relativeOffset, "invalid magic number")
   192  	}
   193  
   194  	relativeOffset += 4
   195  
   196  	// Read Frame_Header_Descriptor. RFC 3.1.1.1.1.
   197  	if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil {
   198  		return r.wrapNonEOFError(relativeOffset, err)
   199  	}
   200  	descriptor := r.scratch[0]
   201  
   202  	singleSegment := descriptor&(1<<5) != 0
   203  
   204  	fcsFieldSize := 1 << (descriptor >> 6)
   205  	if fcsFieldSize == 1 && !singleSegment {
   206  		fcsFieldSize = 0
   207  	}
   208  
   209  	var windowDescriptorSize int
   210  	if singleSegment {
   211  		windowDescriptorSize = 0
   212  	} else {
   213  		windowDescriptorSize = 1
   214  	}
   215  
   216  	if descriptor&(1<<3) != 0 {
   217  		return r.makeError(relativeOffset, "reserved bit set in frame header descriptor")
   218  	}
   219  
   220  	r.hasChecksum = descriptor&(1<<2) != 0
   221  	if r.hasChecksum {
   222  		r.checksum.reset()
   223  	}
   224  
   225  	if descriptor&3 != 0 {
   226  		return r.makeError(relativeOffset, "dictionaries are not supported")
   227  	}
   228  
   229  	relativeOffset++
   230  
   231  	headerSize := windowDescriptorSize + fcsFieldSize
   232  
   233  	if _, err := io.ReadFull(r.r, r.scratch[:headerSize]); err != nil {
   234  		return r.wrapNonEOFError(relativeOffset, err)
   235  	}
   236  
   237  	// Figure out the maximum amount of data we need to retain
   238  	// for backreferences.
   239  
   240  	if singleSegment {
   241  		// No window required, as all the data is in a single buffer.
   242  		r.windowSize = 0
   243  	} else {
   244  		// Window descriptor. RFC 3.1.1.1.2.
   245  		windowDescriptor := r.scratch[0]
   246  		exponent := uint64(windowDescriptor >> 3)
   247  		mantissa := uint64(windowDescriptor & 7)
   248  		windowLog := exponent + 10
   249  		windowBase := uint64(1) << windowLog
   250  		windowAdd := (windowBase / 8) * mantissa
   251  		windowSize := windowBase + windowAdd
   252  
   253  		// Default zstd sets limits on the window size.
   254  		if fuzzing && (windowLog > 31 || windowSize > 1<<27) {
   255  			return r.makeError(relativeOffset, "windowSize too large")
   256  		}
   257  
   258  		// RFC 8878 permits us to set an 8M max on window size.
   259  		if windowSize > 8<<20 {
   260  			windowSize = 8 << 20
   261  		}
   262  
   263  		r.windowSize = int(windowSize)
   264  	}
   265  
   266  	// Frame_Content_Size. RFC 3.1.1.4.
   267  	r.frameSizeUnknown = false
   268  	r.remainingFrameSize = 0
   269  	fb := r.scratch[windowDescriptorSize:]
   270  	switch fcsFieldSize {
   271  	case 0:
   272  		r.frameSizeUnknown = true
   273  	case 1:
   274  		r.remainingFrameSize = uint64(fb[0])
   275  	case 2:
   276  		r.remainingFrameSize = 256 + uint64(binary.LittleEndian.Uint16(fb))
   277  	case 4:
   278  		r.remainingFrameSize = uint64(binary.LittleEndian.Uint32(fb))
   279  	case 8:
   280  		r.remainingFrameSize = binary.LittleEndian.Uint64(fb)
   281  	default:
   282  		panic("unreachable")
   283  	}
   284  
   285  	relativeOffset += headerSize
   286  
   287  	r.sawFrameHeader = true
   288  	r.readOneFrame = true
   289  	r.blockOffset += int64(relativeOffset)
   290  
   291  	// Prepare to read blocks from the frame.
   292  	r.repeatedOffset1 = 1
   293  	r.repeatedOffset2 = 4
   294  	r.repeatedOffset3 = 8
   295  	r.huffmanTableBits = 0
   296  	r.window = r.window[:0]
   297  	r.seqTables[0] = nil
   298  	r.seqTables[1] = nil
   299  	r.seqTables[2] = nil
   300  
   301  	return nil
   302  }
   303  
   304  // skipFrame skips a skippable frame. RFC 3.1.2.
   305  func (r *Reader) skipFrame() error {
   306  	relativeOffset := 0
   307  
   308  	if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
   309  		return r.wrapNonEOFError(relativeOffset, err)
   310  	}
   311  
   312  	relativeOffset += 4
   313  
   314  	size := binary.LittleEndian.Uint32(r.scratch[:4])
   315  
   316  	if seeker, ok := r.r.(io.Seeker); ok {
   317  		if _, err := seeker.Seek(int64(size), io.SeekCurrent); err != nil {
   318  			return err
   319  		}
   320  		r.blockOffset += int64(relativeOffset) + int64(size)
   321  		return nil
   322  	}
   323  
   324  	var skip []byte
   325  	const chunk = 1 << 20 // 1M
   326  	for size >= chunk {
   327  		if len(skip) == 0 {
   328  			skip = make([]byte, chunk)
   329  		}
   330  		if _, err := io.ReadFull(r.r, skip); err != nil {
   331  			return r.wrapNonEOFError(relativeOffset, err)
   332  		}
   333  		relativeOffset += chunk
   334  		size -= chunk
   335  	}
   336  	if size > 0 {
   337  		if len(skip) == 0 {
   338  			skip = make([]byte, size)
   339  		}
   340  		if _, err := io.ReadFull(r.r, skip); err != nil {
   341  			return r.wrapNonEOFError(relativeOffset, err)
   342  		}
   343  		relativeOffset += int(size)
   344  	}
   345  
   346  	r.blockOffset += int64(relativeOffset)
   347  
   348  	return nil
   349  }
   350  
   351  // readBlock reads the next block from a frame.
   352  func (r *Reader) readBlock() error {
   353  	relativeOffset := 0
   354  
   355  	// Read Block_Header. RFC 3.1.1.2.
   356  	if _, err := io.ReadFull(r.r, r.scratch[:3]); err != nil {
   357  		return r.wrapNonEOFError(relativeOffset, err)
   358  	}
   359  
   360  	relativeOffset += 3
   361  
   362  	header := uint32(r.scratch[0]) | (uint32(r.scratch[1]) << 8) | (uint32(r.scratch[2]) << 16)
   363  
   364  	lastBlock := header&1 != 0
   365  	blockType := (header >> 1) & 3
   366  	blockSize := int(header >> 3)
   367  
   368  	// Maximum block size is smaller of window size and 128K.
   369  	// We don't record the window size for a single segment frame,
   370  	// so just use 128K. RFC 3.1.1.2.3, 3.1.1.2.4.
   371  	if blockSize > 128<<10 || (r.windowSize > 0 && blockSize > r.windowSize) {
   372  		return r.makeError(relativeOffset, "block size too large")
   373  	}
   374  
   375  	// Handle different block types. RFC 3.1.1.2.2.
   376  	switch blockType {
   377  	case 0:
   378  		r.setBufferSize(blockSize)
   379  		if _, err := io.ReadFull(r.r, r.buffer); err != nil {
   380  			return r.wrapNonEOFError(relativeOffset, err)
   381  		}
   382  		relativeOffset += blockSize
   383  		r.blockOffset += int64(relativeOffset)
   384  	case 1:
   385  		r.setBufferSize(blockSize)
   386  		if _, err := io.ReadFull(r.r, r.scratch[:1]); err != nil {
   387  			return r.wrapNonEOFError(relativeOffset, err)
   388  		}
   389  		relativeOffset++
   390  		v := r.scratch[0]
   391  		for i := range r.buffer {
   392  			r.buffer[i] = v
   393  		}
   394  		r.blockOffset += int64(relativeOffset)
   395  	case 2:
   396  		r.blockOffset += int64(relativeOffset)
   397  		if err := r.compressedBlock(blockSize); err != nil {
   398  			return err
   399  		}
   400  		r.blockOffset += int64(blockSize)
   401  	case 3:
   402  		return r.makeError(relativeOffset, "invalid block type")
   403  	}
   404  
   405  	if !r.frameSizeUnknown {
   406  		if uint64(len(r.buffer)) > r.remainingFrameSize {
   407  			return r.makeError(relativeOffset, "too many uncompressed bytes in frame")
   408  		}
   409  		r.remainingFrameSize -= uint64(len(r.buffer))
   410  	}
   411  
   412  	if r.hasChecksum {
   413  		r.checksum.update(r.buffer)
   414  	}
   415  
   416  	if !lastBlock {
   417  		r.saveWindow(r.buffer)
   418  	} else {
   419  		if !r.frameSizeUnknown && r.remainingFrameSize != 0 {
   420  			return r.makeError(relativeOffset, "not enough uncompressed bytes for frame")
   421  		}
   422  		// Check for checksum at end of frame. RFC 3.1.1.
   423  		if r.hasChecksum {
   424  			if _, err := io.ReadFull(r.r, r.scratch[:4]); err != nil {
   425  				return r.wrapNonEOFError(0, err)
   426  			}
   427  
   428  			inputChecksum := binary.LittleEndian.Uint32(r.scratch[:4])
   429  			dataChecksum := uint32(r.checksum.digest())
   430  			if inputChecksum != dataChecksum {
   431  				return r.wrapError(0, fmt.Errorf("invalid checksum: got %#x want %#x", dataChecksum, inputChecksum))
   432  			}
   433  
   434  			r.blockOffset += 4
   435  		}
   436  		r.sawFrameHeader = false
   437  	}
   438  
   439  	return nil
   440  }
   441  
   442  // setBufferSize sets the decompressed buffer size.
   443  // When this is called the buffer is empty.
   444  func (r *Reader) setBufferSize(size int) {
   445  	if cap(r.buffer) < size {
   446  		need := size - cap(r.buffer)
   447  		r.buffer = append(r.buffer[:cap(r.buffer)], make([]byte, need)...)
   448  	}
   449  	r.buffer = r.buffer[:size]
   450  }
   451  
   452  // saveWindow saves bytes in the backreference window.
   453  // TODO: use a circular buffer for less data movement.
   454  func (r *Reader) saveWindow(buf []byte) {
   455  	if r.windowSize == 0 {
   456  		return
   457  	}
   458  
   459  	if len(buf) >= r.windowSize {
   460  		from := len(buf) - r.windowSize
   461  		r.window = append(r.window[:0], buf[from:]...)
   462  		return
   463  	}
   464  
   465  	keep := r.windowSize - len(buf) // must be positive
   466  	if keep < len(r.window) {
   467  		remove := len(r.window) - keep
   468  		copy(r.window[:], r.window[remove:])
   469  	}
   470  
   471  	r.window = append(r.window, buf...)
   472  }
   473  
   474  // zstdError is an error while decompressing.
   475  type zstdError struct {
   476  	offset int64
   477  	err    error
   478  }
   479  
   480  func (ze *zstdError) Error() string {
   481  	return fmt.Sprintf("zstd decompression error at %d: %v", ze.offset, ze.err)
   482  }
   483  
   484  func (ze *zstdError) Unwrap() error {
   485  	return ze.err
   486  }
   487  
   488  func (r *Reader) makeEOFError(off int) error {
   489  	return r.wrapError(off, io.ErrUnexpectedEOF)
   490  }
   491  
   492  func (r *Reader) wrapNonEOFError(off int, err error) error {
   493  	if err == io.EOF {
   494  		err = io.ErrUnexpectedEOF
   495  	}
   496  	return r.wrapError(off, err)
   497  }
   498  
   499  func (r *Reader) makeError(off int, msg string) error {
   500  	return r.wrapError(off, errors.New(msg))
   501  }
   502  
   503  func (r *Reader) wrapError(off int, err error) error {
   504  	if err == io.EOF {
   505  		return err
   506  	}
   507  	return &zstdError{r.blockOffset + int64(off), err}
   508  }