github.com/bir3/gocompiler@v0.9.2202/extra/compress/zstd/framedec.go (about)

     1  // Copyright 2019+ Klaus Post. All rights reserved.
     2  // License information can be found in the LICENSE file.
     3  // Based on work by Yann Collet, released under BSD License.
     4  
     5  package zstd
     6  
     7  import (
     8  	"encoding/binary"
     9  	"encoding/hex"
    10  	"errors"
    11  	"io"
    12  
    13  	"github.com/bir3/gocompiler/extra/compress/zstd/internal/xxhash"
    14  )
    15  
// frameDec holds the state needed to decode a single zstd frame:
// the parsed frame header fields, the input being read, and the
// history shared between the blocks of the frame.
type frameDec struct {
	// o holds the decoder options (size limits, low-memory mode,
	// checksum handling) this frame decoder was created with.
	o decoderOptions
	// crc is created lazily by reset for frames carrying a checksum
	// and is reset at the start of every such frame.
	crc *xxhash.Digest

	// WindowSize is the window size of the current frame, either taken
	// from the Window_Descriptor or derived from FrameContentSize for
	// single-segment frames.
	WindowSize uint64

	// Frame history passed between blocks
	history history

	// rawInput is the input the frame header was read from; reset stores
	// it so that blocks and the trailing checksum can be read from it.
	rawInput byteBuffer

	// Byte buffer that can be reused for small input blocks.
	bBuf byteBuf

	// FrameContentSize is the decoded size declared in the frame header,
	// or fcsUnknown when the header does not carry one.
	FrameContentSize uint64

	// DictionaryID is the dictionary ID from the frame header, 0 if absent.
	DictionaryID uint32
	// HasCheckSum reports whether bit 2 of the Frame_Header_Descriptor is set.
	HasCheckSum bool
	// SingleSegment reports whether bit 5 of the Frame_Header_Descriptor is set,
	// in which case the header has no Window_Descriptor.
	SingleSegment bool
}
    36  
// Window size limits.
const (
	// MinWindowSize is the minimum Window Size, which is 1 KB.
	// Frames resolving to a smaller window are rejected by the frame decoder.
	MinWindowSize = 1 << 10

	// MaxWindowSize is the maximum encoder window size
	// and the default decoder maximum window size (512 MB).
	MaxWindowSize = 1 << 29
)
    45  
// Magic numbers, stored as the raw bytes that appear in the stream.
const (
	// frameMagic is the 4-byte signature that starts a zstd frame.
	frameMagic = "\x28\xb5\x2f\xfd"
	// skippableFrameMagic holds the last 3 bytes of a skippable frame
	// signature; the first byte only needs its upper nibble to be 0x5.
	skippableFrameMagic = "\x2a\x4d\x18"
)
    50  
    51  func newFrameDec(o decoderOptions) *frameDec {
    52  	if o.maxWindowSize > o.maxDecodedSize {
    53  		o.maxWindowSize = o.maxDecodedSize
    54  	}
    55  	d := frameDec{
    56  		o: o,
    57  	}
    58  	return &d
    59  }
    60  
    61  // reset will read the frame header and prepare for block decoding.
    62  // If nothing can be read from the input, io.EOF will be returned.
    63  // Any other error indicated that the stream contained data, but
    64  // there was a problem.
    65  func (d *frameDec) reset(br byteBuffer) error {
    66  	d.HasCheckSum = false
    67  	d.WindowSize = 0
    68  	var signature [4]byte
    69  	for {
    70  		var err error
    71  		// Check if we can read more...
    72  		b, err := br.readSmall(1)
    73  		switch err {
    74  		case io.EOF, io.ErrUnexpectedEOF:
    75  			return io.EOF
    76  		default:
    77  			return err
    78  		case nil:
    79  			signature[0] = b[0]
    80  		}
    81  		// Read the rest, don't allow io.ErrUnexpectedEOF
    82  		b, err = br.readSmall(3)
    83  		switch err {
    84  		case io.EOF:
    85  			return io.EOF
    86  		default:
    87  			return err
    88  		case nil:
    89  			copy(signature[1:], b)
    90  		}
    91  
    92  		if string(signature[1:4]) != skippableFrameMagic || signature[0]&0xf0 != 0x50 {
    93  			if debugDecoder {
    94  				println("Not skippable", hex.EncodeToString(signature[:]), hex.EncodeToString([]byte(skippableFrameMagic)))
    95  			}
    96  			// Break if not skippable frame.
    97  			break
    98  		}
    99  		// Read size to skip
   100  		b, err = br.readSmall(4)
   101  		if err != nil {
   102  			if debugDecoder {
   103  				println("Reading Frame Size", err)
   104  			}
   105  			return err
   106  		}
   107  		n := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
   108  		println("Skipping frame with", n, "bytes.")
   109  		err = br.skipN(int64(n))
   110  		if err != nil {
   111  			if debugDecoder {
   112  				println("Reading discarded frame", err)
   113  			}
   114  			return err
   115  		}
   116  	}
   117  	if string(signature[:]) != frameMagic {
   118  		if debugDecoder {
   119  			println("Got magic numbers: ", signature, "want:", []byte(frameMagic))
   120  		}
   121  		return ErrMagicMismatch
   122  	}
   123  
   124  	// Read Frame_Header_Descriptor
   125  	fhd, err := br.readByte()
   126  	if err != nil {
   127  		if debugDecoder {
   128  			println("Reading Frame_Header_Descriptor", err)
   129  		}
   130  		return err
   131  	}
   132  	d.SingleSegment = fhd&(1<<5) != 0
   133  
   134  	if fhd&(1<<3) != 0 {
   135  		return errors.New("reserved bit set on frame header")
   136  	}
   137  
   138  	// Read Window_Descriptor
   139  	// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor
   140  	d.WindowSize = 0
   141  	if !d.SingleSegment {
   142  		wd, err := br.readByte()
   143  		if err != nil {
   144  			if debugDecoder {
   145  				println("Reading Window_Descriptor", err)
   146  			}
   147  			return err
   148  		}
   149  		printf("raw: %x, mantissa: %d, exponent: %d\n", wd, wd&7, wd>>3)
   150  		windowLog := 10 + (wd >> 3)
   151  		windowBase := uint64(1) << windowLog
   152  		windowAdd := (windowBase / 8) * uint64(wd&0x7)
   153  		d.WindowSize = windowBase + windowAdd
   154  	}
   155  
   156  	// Read Dictionary_ID
   157  	// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#dictionary_id
   158  	d.DictionaryID = 0
   159  	if size := fhd & 3; size != 0 {
   160  		if size == 3 {
   161  			size = 4
   162  		}
   163  
   164  		b, err := br.readSmall(int(size))
   165  		if err != nil {
   166  			println("Reading Dictionary_ID", err)
   167  			return err
   168  		}
   169  		var id uint32
   170  		switch len(b) {
   171  		case 1:
   172  			id = uint32(b[0])
   173  		case 2:
   174  			id = uint32(b[0]) | (uint32(b[1]) << 8)
   175  		case 4:
   176  			id = uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
   177  		}
   178  		if debugDecoder {
   179  			println("Dict size", size, "ID:", id)
   180  		}
   181  		d.DictionaryID = id
   182  	}
   183  
   184  	// Read Frame_Content_Size
   185  	// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame_content_size
   186  	var fcsSize int
   187  	v := fhd >> 6
   188  	switch v {
   189  	case 0:
   190  		if d.SingleSegment {
   191  			fcsSize = 1
   192  		}
   193  	default:
   194  		fcsSize = 1 << v
   195  	}
   196  	d.FrameContentSize = fcsUnknown
   197  	if fcsSize > 0 {
   198  		b, err := br.readSmall(fcsSize)
   199  		if err != nil {
   200  			println("Reading Frame content", err)
   201  			return err
   202  		}
   203  		switch len(b) {
   204  		case 1:
   205  			d.FrameContentSize = uint64(b[0])
   206  		case 2:
   207  			// When FCS_Field_Size is 2, the offset of 256 is added.
   208  			d.FrameContentSize = uint64(b[0]) | (uint64(b[1]) << 8) + 256
   209  		case 4:
   210  			d.FrameContentSize = uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24)
   211  		case 8:
   212  			d1 := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
   213  			d2 := uint32(b[4]) | (uint32(b[5]) << 8) | (uint32(b[6]) << 16) | (uint32(b[7]) << 24)
   214  			d.FrameContentSize = uint64(d1) | (uint64(d2) << 32)
   215  		}
   216  		if debugDecoder {
   217  			println("Read FCS:", d.FrameContentSize)
   218  		}
   219  	}
   220  
   221  	// Move this to shared.
   222  	d.HasCheckSum = fhd&(1<<2) != 0
   223  	if d.HasCheckSum {
   224  		if d.crc == nil {
   225  			d.crc = xxhash.New()
   226  		}
   227  		d.crc.Reset()
   228  	}
   229  
   230  	if d.WindowSize > d.o.maxWindowSize {
   231  		if debugDecoder {
   232  			printf("window size %d > max %d\n", d.WindowSize, d.o.maxWindowSize)
   233  		}
   234  		return ErrWindowSizeExceeded
   235  	}
   236  
   237  	if d.WindowSize == 0 && d.SingleSegment {
   238  		// We may not need window in this case.
   239  		d.WindowSize = d.FrameContentSize
   240  		if d.WindowSize < MinWindowSize {
   241  			d.WindowSize = MinWindowSize
   242  		}
   243  		if d.WindowSize > d.o.maxDecodedSize {
   244  			if debugDecoder {
   245  				printf("window size %d > max %d\n", d.WindowSize, d.o.maxWindowSize)
   246  			}
   247  			return ErrDecoderSizeExceeded
   248  		}
   249  	}
   250  
   251  	// The minimum Window_Size is 1 KB.
   252  	if d.WindowSize < MinWindowSize {
   253  		if debugDecoder {
   254  			println("got window size: ", d.WindowSize)
   255  		}
   256  		return ErrWindowSizeTooSmall
   257  	}
   258  	d.history.windowSize = int(d.WindowSize)
   259  	if !d.o.lowMem || d.history.windowSize < maxBlockSize {
   260  		// Alloc 2x window size if not low-mem, or window size below 2MB.
   261  		d.history.allocFrameBuffer = d.history.windowSize * 2
   262  	} else {
   263  		if d.o.lowMem {
   264  			// Alloc with 1MB extra.
   265  			d.history.allocFrameBuffer = d.history.windowSize + maxBlockSize/2
   266  		} else {
   267  			// Alloc with 2MB extra.
   268  			d.history.allocFrameBuffer = d.history.windowSize + maxBlockSize
   269  		}
   270  	}
   271  
   272  	if debugDecoder {
   273  		println("Frame: Dict:", d.DictionaryID, "FrameContentSize:", d.FrameContentSize, "singleseg:", d.SingleSegment, "window:", d.WindowSize, "crc:", d.HasCheckSum)
   274  	}
   275  
   276  	// history contains input - maybe we do something
   277  	d.rawInput = br
   278  	return nil
   279  }
   280  
   281  // next will start decoding the next block from stream.
   282  func (d *frameDec) next(block *blockDec) error {
   283  	if debugDecoder {
   284  		println("decoding new block")
   285  	}
   286  	err := block.reset(d.rawInput, d.WindowSize)
   287  	if err != nil {
   288  		println("block error:", err)
   289  		// Signal the frame decoder we have a problem.
   290  		block.sendErr(err)
   291  		return err
   292  	}
   293  	return nil
   294  }
   295  
   296  // checkCRC will check the checksum, assuming the frame has one.
   297  // Will return ErrCRCMismatch if crc check failed, otherwise nil.
   298  func (d *frameDec) checkCRC() error {
   299  	// We can overwrite upper tmp now
   300  	buf, err := d.rawInput.readSmall(4)
   301  	if err != nil {
   302  		println("CRC missing?", err)
   303  		return err
   304  	}
   305  
   306  	want := binary.LittleEndian.Uint32(buf[:4])
   307  	got := uint32(d.crc.Sum64())
   308  
   309  	if got != want {
   310  		if debugDecoder {
   311  			printf("CRC check failed: got %08x, want %08x\n", got, want)
   312  		}
   313  		return ErrCRCMismatch
   314  	}
   315  	if debugDecoder {
   316  		printf("CRC ok %08x\n", got)
   317  	}
   318  	return nil
   319  }
   320  
   321  // consumeCRC skips over the checksum, assuming the frame has one.
   322  func (d *frameDec) consumeCRC() error {
   323  	_, err := d.rawInput.readSmall(4)
   324  	if err != nil {
   325  		println("CRC missing?", err)
   326  	}
   327  	return err
   328  }
   329  
   330  // runDecoder will run the decoder for the remainder of the frame.
   331  func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) {
   332  	saved := d.history.b
   333  
   334  	// We use the history for output to avoid copying it.
   335  	d.history.b = dst
   336  	d.history.ignoreBuffer = len(dst)
   337  	// Store input length, so we only check new data.
   338  	crcStart := len(dst)
   339  	d.history.decoders.maxSyncLen = 0
   340  	if d.o.limitToCap {
   341  		d.history.decoders.maxSyncLen = uint64(cap(dst) - len(dst))
   342  	}
   343  	if d.FrameContentSize != fcsUnknown {
   344  		if !d.o.limitToCap || d.FrameContentSize+uint64(len(dst)) < d.history.decoders.maxSyncLen {
   345  			d.history.decoders.maxSyncLen = d.FrameContentSize + uint64(len(dst))
   346  		}
   347  		if d.history.decoders.maxSyncLen > d.o.maxDecodedSize {
   348  			if debugDecoder {
   349  				println("maxSyncLen:", d.history.decoders.maxSyncLen, "> maxDecodedSize:", d.o.maxDecodedSize)
   350  			}
   351  			return dst, ErrDecoderSizeExceeded
   352  		}
   353  		if debugDecoder {
   354  			println("maxSyncLen:", d.history.decoders.maxSyncLen)
   355  		}
   356  		if !d.o.limitToCap && uint64(cap(dst)) < d.history.decoders.maxSyncLen {
   357  			// Alloc for output
   358  			dst2 := make([]byte, len(dst), d.history.decoders.maxSyncLen+compressedBlockOverAlloc)
   359  			copy(dst2, dst)
   360  			dst = dst2
   361  		}
   362  	}
   363  	var err error
   364  	for {
   365  		err = dec.reset(d.rawInput, d.WindowSize)
   366  		if err != nil {
   367  			break
   368  		}
   369  		if debugDecoder {
   370  			println("next block:", dec)
   371  		}
   372  		err = dec.decodeBuf(&d.history)
   373  		if err != nil {
   374  			break
   375  		}
   376  		if uint64(len(d.history.b)-crcStart) > d.o.maxDecodedSize {
   377  			println("runDecoder: maxDecodedSize exceeded", uint64(len(d.history.b)-crcStart), ">", d.o.maxDecodedSize)
   378  			err = ErrDecoderSizeExceeded
   379  			break
   380  		}
   381  		if d.o.limitToCap && len(d.history.b) > cap(dst) {
   382  			println("runDecoder: cap exceeded", uint64(len(d.history.b)), ">", cap(dst))
   383  			err = ErrDecoderSizeExceeded
   384  			break
   385  		}
   386  		if uint64(len(d.history.b)-crcStart) > d.FrameContentSize {
   387  			println("runDecoder: FrameContentSize exceeded", uint64(len(d.history.b)-crcStart), ">", d.FrameContentSize)
   388  			err = ErrFrameSizeExceeded
   389  			break
   390  		}
   391  		if dec.Last {
   392  			break
   393  		}
   394  		if debugDecoder {
   395  			println("runDecoder: FrameContentSize", uint64(len(d.history.b)-crcStart), "<=", d.FrameContentSize)
   396  		}
   397  	}
   398  	dst = d.history.b
   399  	if err == nil {
   400  		if d.FrameContentSize != fcsUnknown && uint64(len(d.history.b)-crcStart) != d.FrameContentSize {
   401  			err = ErrFrameSizeMismatch
   402  		} else if d.HasCheckSum {
   403  			if d.o.ignoreChecksum {
   404  				err = d.consumeCRC()
   405  			} else {
   406  				d.crc.Write(dst[crcStart:])
   407  				err = d.checkCRC()
   408  			}
   409  		}
   410  	}
   411  	d.history.b = saved
   412  	return dst, err
   413  }