github.com/bir3/gocompiler@v0.9.2202/extra/compress/zstd/encoder.go (about)

     1  // Copyright 2019+ Klaus Post. All rights reserved.
     2  // License information can be found in the LICENSE file.
     3  // Based on work by Yann Collet, released under BSD License.
     4  
     5  package zstd
     6  
     7  import (
     8  	"crypto/rand"
     9  	"fmt"
    10  	"io"
    11  	"math"
    12  	rdebug "runtime/debug"
    13  	"sync"
    14  
    15  	"github.com/bir3/gocompiler/extra/compress/zstd/internal/xxhash"
    16  )
    17  
// Encoder provides encoding to Zstandard.
// An Encoder can be used for either compressing a stream via the
// io.WriteCloser interface supported by the Encoder or as multiple independent
// tasks via the EncodeAll function.
// Smaller encodes are encouraged to use the EncodeAll function.
// Use NewWriter to create a new instance.
type Encoder struct {
	// o holds the options; NewWriter fills it with defaults and then
	// applies the caller-supplied EOptions.
	o        encoderOptions
	// encoders is the pool of block encoders that EncodeAll borrows from.
	// It is created lazily by initialize.
	encoders chan encoder
	// state is the streaming state used by Write, ReadFrom, Flush and Close.
	state    encoderState
	// init guards the one-time creation of the encoders pool.
	init     sync.Once
}
    30  
// encoder is the set of operations the Encoder requires from a block
// encoder. Instances are obtained from encoderOptions.encoder; the
// implementations are not part of this file.
type encoder interface {
	// Encode encodes src into blk.
	// NOTE(review): presumably this may reference data from earlier
	// blocks (history) - confirm against the implementations.
	Encode(blk *blockEnc, src []byte)
	// EncodeNoHist encodes src into blk. EncodeAll uses it when the whole
	// input fits in one block and no dictionary is configured.
	EncodeNoHist(blk *blockEnc, src []byte)
	// Block returns the block the encoder is currently using for output.
	Block() *blockEnc
	// CRC returns the digest that input bytes are written to when
	// checksums are enabled.
	CRC() *xxhash.Digest
	// AppendCRC appends the checksum bytes to the given slice and returns
	// the extended slice.
	AppendCRC([]byte) []byte
	// WindowSize returns the window size to put in the frame header for
	// content of the given size. The streaming path passes 0 when no
	// content size was set.
	WindowSize(size int64) int32
	// UseBlock hands the encoder a block to use; nextBlock calls it when
	// rotating blocks in concurrent mode.
	UseBlock(*blockEnc)
	// Reset prepares the encoder for a new frame using dictionary d (which
	// may be nil). singleBlock is true when the caller will encode the
	// entire input as one block.
	Reset(d *dict, singleBlock bool)
}
    41  
// encoderState is the per-stream state of an Encoder. It is reset by
// Encoder.Reset and mutated by Write, ReadFrom, nextBlock, Flush and Close.
type encoderState struct {
	// w is the destination of the compressed stream.
	w                io.Writer
	// filling accumulates input until a full block is available.
	filling          []byte
	// current and previous are only allocated in concurrent mode. nextBlock
	// rotates filling -> current -> previous -> filling so that the block
	// being encoded is not overwritten by new input.
	current          []byte
	previous         []byte
	// encoder is the block encoder used for the stream.
	encoder          encoder
	// writing is the block most recently handed to the write goroutine
	// (concurrent mode only).
	writing          *blockEnc
	// err is the first error seen on the encode/synchronous write path.
	err              error
	// writeErr is the error reported by the asynchronous write goroutine.
	writeErr         error
	// nWritten counts bytes sent to w; it is used to size the padding frame.
	nWritten         int64
	// nInput counts uncompressed bytes handed off for encoding.
	nInput           int64
	// frameContentSize is the content size announced via ResetContentSize,
	// or 0 when none was given.
	frameContentSize int64
	// headerWritten is set once the frame header has been emitted (or
	// deliberately skipped).
	headerWritten    bool
	// eofWritten is set once a block flagged as last has been produced.
	eofWritten       bool
	// fullFrameWritten is set when Close emitted the entire frame in one
	// step, so no trailing CRC/padding must be added by Close.
	fullFrameWritten bool

	// This waitgroup indicates an encode is running.
	wg sync.WaitGroup
	// This waitgroup indicates we have a block encoding/writing.
	wWg sync.WaitGroup
}
    63  
    64  // NewWriter will create a new Zstandard encoder.
    65  // If the encoder will be used for encoding blocks a nil writer can be used.
    66  func NewWriter(w io.Writer, opts ...EOption) (*Encoder, error) {
    67  	initPredefined()
    68  	var e Encoder
    69  	e.o.setDefault()
    70  	for _, o := range opts {
    71  		err := o(&e.o)
    72  		if err != nil {
    73  			return nil, err
    74  		}
    75  	}
    76  	if w != nil {
    77  		e.Reset(w)
    78  	}
    79  	return &e, nil
    80  }
    81  
    82  func (e *Encoder) initialize() {
    83  	if e.o.concurrent == 0 {
    84  		e.o.setDefault()
    85  	}
    86  	e.encoders = make(chan encoder, e.o.concurrent)
    87  	for i := 0; i < e.o.concurrent; i++ {
    88  		enc := e.o.encoder()
    89  		e.encoders <- enc
    90  	}
    91  }
    92  
    93  // Reset will re-initialize the writer and new writes will encode to the supplied writer
    94  // as a new, independent stream.
    95  func (e *Encoder) Reset(w io.Writer) {
    96  	s := &e.state
    97  	s.wg.Wait()
    98  	s.wWg.Wait()
    99  	if cap(s.filling) == 0 {
   100  		s.filling = make([]byte, 0, e.o.blockSize)
   101  	}
   102  	if e.o.concurrent > 1 {
   103  		if cap(s.current) == 0 {
   104  			s.current = make([]byte, 0, e.o.blockSize)
   105  		}
   106  		if cap(s.previous) == 0 {
   107  			s.previous = make([]byte, 0, e.o.blockSize)
   108  		}
   109  		s.current = s.current[:0]
   110  		s.previous = s.previous[:0]
   111  		if s.writing == nil {
   112  			s.writing = &blockEnc{lowMem: e.o.lowMem}
   113  			s.writing.init()
   114  		}
   115  		s.writing.initNewEncode()
   116  	}
   117  	if s.encoder == nil {
   118  		s.encoder = e.o.encoder()
   119  	}
   120  	s.filling = s.filling[:0]
   121  	s.encoder.Reset(e.o.dict, false)
   122  	s.headerWritten = false
   123  	s.eofWritten = false
   124  	s.fullFrameWritten = false
   125  	s.w = w
   126  	s.err = nil
   127  	s.nWritten = 0
   128  	s.nInput = 0
   129  	s.writeErr = nil
   130  	s.frameContentSize = 0
   131  }
   132  
   133  // ResetContentSize will reset and set a content size for the next stream.
   134  // If the bytes written does not match the size given an error will be returned
   135  // when calling Close().
   136  // This is removed when Reset is called.
   137  // Sizes <= 0 results in no content size set.
   138  func (e *Encoder) ResetContentSize(w io.Writer, size int64) {
   139  	e.Reset(w)
   140  	if size >= 0 {
   141  		e.state.frameContentSize = size
   142  	}
   143  }
   144  
   145  // Write data to the encoder.
   146  // Input data will be buffered and as the buffer fills up
   147  // content will be compressed and written to the output.
   148  // When done writing, use Close to flush the remaining output
   149  // and write CRC if requested.
   150  func (e *Encoder) Write(p []byte) (n int, err error) {
   151  	s := &e.state
   152  	for len(p) > 0 {
   153  		if len(p)+len(s.filling) < e.o.blockSize {
   154  			if e.o.crc {
   155  				_, _ = s.encoder.CRC().Write(p)
   156  			}
   157  			s.filling = append(s.filling, p...)
   158  			return n + len(p), nil
   159  		}
   160  		add := p
   161  		if len(p)+len(s.filling) > e.o.blockSize {
   162  			add = add[:e.o.blockSize-len(s.filling)]
   163  		}
   164  		if e.o.crc {
   165  			_, _ = s.encoder.CRC().Write(add)
   166  		}
   167  		s.filling = append(s.filling, add...)
   168  		p = p[len(add):]
   169  		n += len(add)
   170  		if len(s.filling) < e.o.blockSize {
   171  			return n, nil
   172  		}
   173  		err := e.nextBlock(false)
   174  		if err != nil {
   175  			return n, err
   176  		}
   177  		if debugAsserts && len(s.filling) > 0 {
   178  			panic(len(s.filling))
   179  		}
   180  	}
   181  	return n, nil
   182  }
   183  
// nextBlock will synchronize and start compressing input in e.state.filling.
// If an error has occurred during encoding it will be returned.
//
// final marks the block as the last one of the frame; Close passes true,
// every other caller passes false.
//
// On the first call of a stream the frame header is emitted first. If that
// first call is also the final one, the whole frame is produced in one
// synchronous step instead (via EncodeAll) and fullFrameWritten is set.
//
// With e.o.concurrent == 1 the block is encoded and written before
// returning. Otherwise the buffers are rotated and encoding/writing happen
// in goroutines; in that case nil is returned and failures are stored in
// s.err / s.writeErr, to be picked up by a later nextBlock, Flush or Close.
func (e *Encoder) nextBlock(final bool) error {
	s := &e.state
	// Wait for current block.
	s.wg.Wait()
	if s.err != nil {
		return s.err
	}
	if len(s.filling) > e.o.blockSize {
		return fmt.Errorf("block > maxStoreBlockSize")
	}
	if !s.headerWritten {
		// If we have a single block encode, do a sync compression.
		if final && len(s.filling) == 0 && !e.o.fullZero {
			// Nothing was ever written and zero frames are not
			// requested: produce no output at all.
			s.headerWritten = true
			s.fullFrameWritten = true
			s.eofWritten = true
			return nil
		}
		if final && len(s.filling) > 0 {
			// All input fits in the first block: let EncodeAll build the
			// complete frame (s.current is reused as scratch output).
			s.current = e.EncodeAll(s.filling, s.current[:0])
			var n2 int
			n2, s.err = s.w.Write(s.current)
			if s.err != nil {
				return s.err
			}
			s.nWritten += int64(n2)
			s.nInput += int64(len(s.filling))
			s.current = s.current[:0]
			s.filling = s.filling[:0]
			s.headerWritten = true
			s.fullFrameWritten = true
			s.eofWritten = true
			return nil
		}

		// Regular streaming: emit the frame header before the first block.
		var tmp [maxHeaderSize]byte
		fh := frameHeader{
			ContentSize:   uint64(s.frameContentSize),
			WindowSize:    uint32(s.encoder.WindowSize(s.frameContentSize)),
			SingleSegment: false,
			Checksum:      e.o.crc,
			DictID:        e.o.dict.ID(),
		}

		dst, err := fh.appendTo(tmp[:0])
		if err != nil {
			return err
		}
		s.headerWritten = true
		// Do not write to s.w while an asynchronous write is in progress.
		s.wWg.Wait()
		var n2 int
		n2, s.err = s.w.Write(dst)
		if s.err != nil {
			return s.err
		}
		s.nWritten += int64(n2)
	}
	if s.eofWritten {
		// Ensure we only write it once.
		final = false
	}

	if len(s.filling) == 0 {
		// Final block, but no data.
		if final {
			// Terminate the frame with an empty raw block flagged as last.
			enc := s.encoder
			blk := enc.Block()
			blk.reset(nil)
			blk.last = true
			blk.encodeRaw(nil)
			s.wWg.Wait()
			_, s.err = s.w.Write(blk.output)
			s.nWritten += int64(len(blk.output))
			s.eofWritten = true
		}
		return s.err
	}

	// SYNC:
	if e.o.concurrent == 1 {
		// Encode and write the block on the calling goroutine.
		src := s.filling
		s.nInput += int64(len(s.filling))
		if debugEncoder {
			println("Adding sync block,", len(src), "bytes, final:", final)
		}
		enc := s.encoder
		blk := enc.Block()
		blk.reset(nil)
		enc.Encode(blk, src)
		blk.last = final
		if final {
			s.eofWritten = true
		}

		s.err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
		if s.err != nil {
			return s.err
		}
		_, s.err = s.w.Write(blk.output)
		s.nWritten += int64(len(blk.output))
		s.filling = s.filling[:0]
		return s.err
	}

	// Move blocks forward.
	// The filled buffer becomes current (the one being encoded), the old
	// current becomes previous, and the old previous is recycled as the
	// new, empty filling buffer.
	s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current
	s.nInput += int64(len(s.current))
	s.wg.Add(1)
	go func(src []byte) {
		if debugEncoder {
			println("Adding block,", len(src), "bytes, final:", final)
		}
		defer func() {
			// Turn a panic in the encoder into a stream error; wg is
			// released either way so waiters are not blocked forever.
			if r := recover(); r != nil {
				s.err = fmt.Errorf("panic while encoding: %v", r)
				rdebug.PrintStack()
			}
			s.wg.Done()
		}()
		enc := s.encoder
		blk := enc.Block()
		enc.Encode(blk, src)
		blk.last = final
		if final {
			s.eofWritten = true
		}
		// Wait for pending writes.
		s.wWg.Wait()
		if s.writeErr != nil {
			s.err = s.writeErr
			return
		}
		// Transfer encoders from previous write block.
		blk.swapEncoders(s.writing)
		// Transfer recent offsets to next.
		enc.UseBlock(s.writing)
		s.writing = blk
		s.wWg.Add(1)
		go func() {
			// Second stage: finish encoding blk and write it out, while
			// the outer goroutine returns and frees wg for the next block.
			defer func() {
				if r := recover(); r != nil {
					s.writeErr = fmt.Errorf("panic while encoding/writing: %v", r)
					rdebug.PrintStack()
				}
				s.wWg.Done()
			}()
			s.writeErr = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
			if s.writeErr != nil {
				return
			}
			_, s.writeErr = s.w.Write(blk.output)
			s.nWritten += int64(len(blk.output))
		}()
	}(s.current)
	return nil
}
   342  
   343  // ReadFrom reads data from r until EOF or error.
   344  // The return value n is the number of bytes read.
   345  // Any error except io.EOF encountered during the read is also returned.
   346  //
   347  // The Copy function uses ReaderFrom if available.
   348  func (e *Encoder) ReadFrom(r io.Reader) (n int64, err error) {
   349  	if debugEncoder {
   350  		println("Using ReadFrom")
   351  	}
   352  
   353  	// Flush any current writes.
   354  	if len(e.state.filling) > 0 {
   355  		if err := e.nextBlock(false); err != nil {
   356  			return 0, err
   357  		}
   358  	}
   359  	e.state.filling = e.state.filling[:e.o.blockSize]
   360  	src := e.state.filling
   361  	for {
   362  		n2, err := r.Read(src)
   363  		if e.o.crc {
   364  			_, _ = e.state.encoder.CRC().Write(src[:n2])
   365  		}
   366  		// src is now the unfilled part...
   367  		src = src[n2:]
   368  		n += int64(n2)
   369  		switch err {
   370  		case io.EOF:
   371  			e.state.filling = e.state.filling[:len(e.state.filling)-len(src)]
   372  			if debugEncoder {
   373  				println("ReadFrom: got EOF final block:", len(e.state.filling))
   374  			}
   375  			return n, nil
   376  		case nil:
   377  		default:
   378  			if debugEncoder {
   379  				println("ReadFrom: got error:", err)
   380  			}
   381  			e.state.err = err
   382  			return n, err
   383  		}
   384  		if len(src) > 0 {
   385  			if debugEncoder {
   386  				println("ReadFrom: got space left in source:", len(src))
   387  			}
   388  			continue
   389  		}
   390  		err = e.nextBlock(false)
   391  		if err != nil {
   392  			return n, err
   393  		}
   394  		e.state.filling = e.state.filling[:e.o.blockSize]
   395  		src = e.state.filling
   396  	}
   397  }
   398  
   399  // Flush will send the currently written data to output
   400  // and block until everything has been written.
   401  // This should only be used on rare occasions where pushing the currently queued data is critical.
   402  func (e *Encoder) Flush() error {
   403  	s := &e.state
   404  	if len(s.filling) > 0 {
   405  		err := e.nextBlock(false)
   406  		if err != nil {
   407  			return err
   408  		}
   409  	}
   410  	s.wg.Wait()
   411  	s.wWg.Wait()
   412  	if s.err != nil {
   413  		return s.err
   414  	}
   415  	return s.writeErr
   416  }
   417  
   418  // Close will flush the final output and close the stream.
   419  // The function will block until everything has been written.
   420  // The Encoder can still be re-used after calling this.
   421  func (e *Encoder) Close() error {
   422  	s := &e.state
   423  	if s.encoder == nil {
   424  		return nil
   425  	}
   426  	err := e.nextBlock(true)
   427  	if err != nil {
   428  		return err
   429  	}
   430  	if s.frameContentSize > 0 {
   431  		if s.nInput != s.frameContentSize {
   432  			return fmt.Errorf("frame content size %d given, but %d bytes was written", s.frameContentSize, s.nInput)
   433  		}
   434  	}
   435  	if e.state.fullFrameWritten {
   436  		return s.err
   437  	}
   438  	s.wg.Wait()
   439  	s.wWg.Wait()
   440  
   441  	if s.err != nil {
   442  		return s.err
   443  	}
   444  	if s.writeErr != nil {
   445  		return s.writeErr
   446  	}
   447  
   448  	// Write CRC
   449  	if e.o.crc && s.err == nil {
   450  		// heap alloc.
   451  		var tmp [4]byte
   452  		_, s.err = s.w.Write(s.encoder.AppendCRC(tmp[:0]))
   453  		s.nWritten += 4
   454  	}
   455  
   456  	// Add padding with content from crypto/rand.Reader
   457  	if s.err == nil && e.o.pad > 0 {
   458  		add := calcSkippableFrame(s.nWritten, int64(e.o.pad))
   459  		frame, err := skippableFrame(s.filling[:0], add, rand.Reader)
   460  		if err != nil {
   461  			return err
   462  		}
   463  		_, s.err = s.w.Write(frame)
   464  	}
   465  	return s.err
   466  }
   467  
   468  // EncodeAll will encode all input in src and append it to dst.
   469  // This function can be called concurrently, but each call will only run on a single goroutine.
   470  // If empty input is given, nothing is returned, unless WithZeroFrames is specified.
   471  // Encoded blocks can be concatenated and the result will be the combined input stream.
   472  // Data compressed with EncodeAll can be decoded with the Decoder,
   473  // using either a stream or DecodeAll.
   474  func (e *Encoder) EncodeAll(src, dst []byte) []byte {
   475  	if len(src) == 0 {
   476  		if e.o.fullZero {
   477  			// Add frame header.
   478  			fh := frameHeader{
   479  				ContentSize:   0,
   480  				WindowSize:    MinWindowSize,
   481  				SingleSegment: true,
   482  				// Adding a checksum would be a waste of space.
   483  				Checksum: false,
   484  				DictID:   0,
   485  			}
   486  			dst, _ = fh.appendTo(dst)
   487  
   488  			// Write raw block as last one only.
   489  			var blk blockHeader
   490  			blk.setSize(0)
   491  			blk.setType(blockTypeRaw)
   492  			blk.setLast(true)
   493  			dst = blk.appendTo(dst)
   494  		}
   495  		return dst
   496  	}
   497  	e.init.Do(e.initialize)
   498  	enc := <-e.encoders
   499  	defer func() {
   500  		// Release encoder reference to last block.
   501  		// If a non-single block is needed the encoder will reset again.
   502  		e.encoders <- enc
   503  	}()
   504  	// Use single segments when above minimum window and below window size.
   505  	single := len(src) <= e.o.windowSize && len(src) > MinWindowSize
   506  	if e.o.single != nil {
   507  		single = *e.o.single
   508  	}
   509  	fh := frameHeader{
   510  		ContentSize:   uint64(len(src)),
   511  		WindowSize:    uint32(enc.WindowSize(int64(len(src)))),
   512  		SingleSegment: single,
   513  		Checksum:      e.o.crc,
   514  		DictID:        e.o.dict.ID(),
   515  	}
   516  
   517  	// If less than 1MB, allocate a buffer up front.
   518  	if len(dst) == 0 && cap(dst) == 0 && len(src) < 1<<20 && !e.o.lowMem {
   519  		dst = make([]byte, 0, len(src))
   520  	}
   521  	dst, err := fh.appendTo(dst)
   522  	if err != nil {
   523  		panic(err)
   524  	}
   525  
   526  	// If we can do everything in one block, prefer that.
   527  	if len(src) <= e.o.blockSize {
   528  		enc.Reset(e.o.dict, true)
   529  		// Slightly faster with no history and everything in one block.
   530  		if e.o.crc {
   531  			_, _ = enc.CRC().Write(src)
   532  		}
   533  		blk := enc.Block()
   534  		blk.last = true
   535  		if e.o.dict == nil {
   536  			enc.EncodeNoHist(blk, src)
   537  		} else {
   538  			enc.Encode(blk, src)
   539  		}
   540  
   541  		// If we got the exact same number of literals as input,
   542  		// assume the literals cannot be compressed.
   543  		oldout := blk.output
   544  		// Output directly to dst
   545  		blk.output = dst
   546  
   547  		err := blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
   548  		if err != nil {
   549  			panic(err)
   550  		}
   551  		dst = blk.output
   552  		blk.output = oldout
   553  	} else {
   554  		enc.Reset(e.o.dict, false)
   555  		blk := enc.Block()
   556  		for len(src) > 0 {
   557  			todo := src
   558  			if len(todo) > e.o.blockSize {
   559  				todo = todo[:e.o.blockSize]
   560  			}
   561  			src = src[len(todo):]
   562  			if e.o.crc {
   563  				_, _ = enc.CRC().Write(todo)
   564  			}
   565  			blk.pushOffsets()
   566  			enc.Encode(blk, todo)
   567  			if len(src) == 0 {
   568  				blk.last = true
   569  			}
   570  			err := blk.encode(todo, e.o.noEntropy, !e.o.allLitEntropy)
   571  			if err != nil {
   572  				panic(err)
   573  			}
   574  			dst = append(dst, blk.output...)
   575  			blk.reset(nil)
   576  		}
   577  	}
   578  	if e.o.crc {
   579  		dst = enc.AppendCRC(dst)
   580  	}
   581  	// Add padding with content from crypto/rand.Reader
   582  	if e.o.pad > 0 {
   583  		add := calcSkippableFrame(int64(len(dst)), int64(e.o.pad))
   584  		dst, err = skippableFrame(dst, add, rand.Reader)
   585  		if err != nil {
   586  			panic(err)
   587  		}
   588  	}
   589  	return dst
   590  }
   591  
   592  // MaxEncodedSize returns the expected maximum
   593  // size of an encoded block or stream.
   594  func (e *Encoder) MaxEncodedSize(size int) int {
   595  	frameHeader := 4 + 2 // magic + frame header & window descriptor
   596  	if e.o.dict != nil {
   597  		frameHeader += 4
   598  	}
   599  	// Frame content size:
   600  	if size < 256 {
   601  		frameHeader++
   602  	} else if size < 65536+256 {
   603  		frameHeader += 2
   604  	} else if size < math.MaxInt32 {
   605  		frameHeader += 4
   606  	} else {
   607  		frameHeader += 8
   608  	}
   609  	// Final crc
   610  	if e.o.crc {
   611  		frameHeader += 4
   612  	}
   613  
   614  	// Max overhead is 3 bytes/block.
   615  	// There cannot be 0 blocks.
   616  	blocks := (size + e.o.blockSize) / e.o.blockSize
   617  
   618  	// Combine, add padding.
   619  	maxSz := frameHeader + 3*blocks + size
   620  	if e.o.pad > 1 {
   621  		maxSz += calcSkippableFrame(int64(maxSz), int64(e.o.pad))
   622  	}
   623  	return maxSz
   624  }