github.com/metacubex/gvisor@v0.0.0-20240320004321-933faba989ec/pkg/compressio/compressio.go

// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package compressio provides parallel compression and decompression, as well
// as optional keyed SHA-256 (HMAC) hashing. It also provides another storage
// variant (nocompressio) that does not compress data but tracks its integrity.
//
// The stream format is defined as follows.
//
// /------------------------------------------------------\
// |                 chunk size (4-bytes)                 |
// +------------------------------------------------------+
// |              (optional) hash (32-bytes)              |
// +------------------------------------------------------+
// |           compressed data size (4-bytes)             |
// +------------------------------------------------------+
// |                   compressed data                    |
// +------------------------------------------------------+
// |              (optional) hash (32-bytes)              |
// +------------------------------------------------------+
// |           compressed data size (4-bytes)             |
// +------------------------------------------------------+
// |                       ......                         |
// \------------------------------------------------------/
//
// where each subsequent hash is calculated from the following items, in order:
//
//	compressed data
//	compressed data size
//	previous hash
//
// so the stream integrity cannot be compromised by switching and mixing
// compressed chunks.
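//
// Concretely, when hashing is enabled with key k, the chain reduces to (a
// sketch in pseudo-notation):
//
//	hash[-1] = HMAC-SHA256(k, chunk size)
//	hash[i]  = HMAC-SHA256(k, compressed data[i] || compressed data size[i] || hash[i-1])
//
// where hash[-1] is the hash written immediately after the chunk size.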
package compressio

import (
	"bytes"
	"compress/flate"
	"crypto/hmac"
	"crypto/sha256"
	"encoding/binary"
	"errors"
	"hash"
	"io"
	"runtime"

	"github.com/metacubex/gvisor/pkg/sync"
)

var bufPool = sync.Pool{
	New: func() any {
		return bytes.NewBuffer(nil)
	},
}

var chunkPool = sync.Pool{
	New: func() any {
		return new(chunk)
	},
}

// chunk is a unit of work.
type chunk struct {
	// compressed is compressed data.
	//
	// This will always be returned to the bufPool directly when work has
	// finished (in schedule) and therefore must be allocated.
	compressed *bytes.Buffer

	// uncompressed is the uncompressed data.
	//
	// This is not returned to the bufPool automatically, since it may
	// correspond to an inline slice (provided directly to Read or Write).
	uncompressed *bytes.Buffer

	// The current hash object. Only used in compress mode.
	h hash.Hash

	// The hash from previous chunks. Only used in uncompress mode.
	lastSum []byte

	// The expected hash after the current chunk. Only used in uncompress
	// mode.
	sum []byte
}

// newChunk allocates a new chunk object (or pulls one from the pool). Buffers
// will be allocated if nil is provided for compressed or uncompressed.
func newChunk(lastSum []byte, sum []byte, compressed *bytes.Buffer, uncompressed *bytes.Buffer) *chunk {
	c := chunkPool.Get().(*chunk)
	c.lastSum = lastSum
	c.sum = sum
	if compressed != nil {
		c.compressed = compressed
	} else {
		c.compressed = bufPool.Get().(*bytes.Buffer)
	}
	if uncompressed != nil {
		c.uncompressed = uncompressed
	} else {
		c.uncompressed = bufPool.Get().(*bytes.Buffer)
	}
	return c
}

// result is the result of some work; it includes the original chunk.
type result struct {
	*chunk
	err error
}

// worker is a compression/decompression worker.
//
// The associated worker goroutine reads in uncompressed buffers from input and
// writes compressed buffers to its output. Alternatively, the worker reads
// compressed buffers from input and writes uncompressed buffers to its output.
//
// The goroutine will exit when input is closed, and the goroutine will close
// output.
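//
// Schematically (an illustrative sketch):
//
//	input (chunks) --> [flate encode/decode + optional HMAC] --> output (results)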
type worker struct {
	hashPool *hashPool
	input    chan *chunk
	output   chan result

	// scratch is a temporary buffer used for marshalling. This is declared
	// up front here to avoid reallocation.
	scratch [4]byte
}

// work is the main work routine; see worker.
func (w *worker) work(compress bool, level int) {
	defer close(w.output)

	var h hash.Hash

	for c := range w.input {
		if h == nil && w.hashPool != nil {
			h = w.hashPool.getHash()
		}
		if compress {
			mw := io.Writer(c.compressed)
			if h != nil {
				mw = io.MultiWriter(mw, h)
			}

			// Create the compressor.
			fw, err := flate.NewWriter(mw, level)
			if err != nil {
				w.output <- result{c, err}
				continue
			}

			// Encode the input.
			if _, err := io.CopyN(fw, c.uncompressed, int64(c.uncompressed.Len())); err != nil {
				w.output <- result{c, err}
				continue
			}
			if err := fw.Close(); err != nil {
				w.output <- result{c, err}
				continue
			}

			// Write the hash, if enabled.
			if h != nil {
				binary.BigEndian.PutUint32(w.scratch[:], uint32(c.compressed.Len()))
				h.Write(w.scratch[:4])
				c.h = h
				h = nil
			}
		} else {
			// Check the hash of the compressed contents.
			if h != nil {
				h.Write(c.compressed.Bytes())
				binary.BigEndian.PutUint32(w.scratch[:], uint32(c.compressed.Len()))
				h.Write(w.scratch[:4])
				// Chain in the previous hash; hash writes cannot fail.
				io.CopyN(h, bytes.NewReader(c.lastSum), int64(len(c.lastSum)))

				sum := h.Sum(nil)
				h.Reset()
				if !hmac.Equal(c.sum, sum) {
					w.output <- result{c, ErrHashMismatch}
					continue
				}
			}

			// Create the decompressor.
			fr := flate.NewReader(c.compressed)

			// Decode the input.
			if _, err := io.Copy(c.uncompressed, fr); err != nil {
				w.output <- result{c, err}
				continue
			}
		}

		// Send the output.
		w.output <- result{c, nil}
	}
}

type hashPool struct {
	// mu protects the hash list.
	mu sync.Mutex

	// key is the key used to create hash objects.
	key []byte

	// hashes is the hash object free list. Note that this cannot be
	// globally shared across readers or writers, as it is key-specific.
	hashes []hash.Hash
}

// getHash gets a hash object from the pool. It should only be called when the
// pool key is non-nil.
func (p *hashPool) getHash() hash.Hash {
	p.mu.Lock()
	defer p.mu.Unlock()

	if len(p.hashes) == 0 {
		return hmac.New(sha256.New, p.key)
	}

	h := p.hashes[len(p.hashes)-1]
	p.hashes = p.hashes[:len(p.hashes)-1]
	return h
}

func (p *hashPool) putHash(h hash.Hash) {
	h.Reset()

	p.mu.Lock()
	defer p.mu.Unlock()

	p.hashes = append(p.hashes, h)
}

// pool is common functionality for readers/writers.
type pool struct {
	// workers are the compression/decompression workers.
	workers []worker

	// chunkSize is the chunk size. This is the first four bytes in the
	// stream and is shared across both the reader and writer.
	chunkSize uint32

	// mu protects below; it is generally the responsibility of users to
	// acquire this mutex before calling any methods on the pool.
	mu sync.Mutex

	// nextInput is the next worker for input (scheduling).
	nextInput int

	// nextOutput is the next worker for output (result).
	nextOutput int

	// buf is the current active buffer; the exact semantics of this buffer
	// depend on whether this is a reader or a writer.
	buf *bytes.Buffer

	// lastSum records the hash of the last chunk processed.
	lastSum []byte

	// hashPool is the hash object pool. It cannot be embedded into pool
	// itself as worker refers to it and that would stop pool from being
	// GCed.
	hashPool *hashPool
}

// init initializes the worker pool.
//
// This should only be called once.
func (p *pool) init(key []byte, workers int, compress bool, level int) {
	if key != nil {
		p.hashPool = &hashPool{key: key}
	}
	p.workers = make([]worker, workers)
	for i := 0; i < len(p.workers); i++ {
		p.workers[i] = worker{
			hashPool: p.hashPool,
			input:    make(chan *chunk, 1),
			output:   make(chan result, 1),
		}
		go p.workers[i].work(compress, level) // S/R-SAFE: In save path only.
	}
	runtime.SetFinalizer(p, (*pool).stop)
}

// stop stops all workers.
func (p *pool) stop() {
	for i := 0; i < len(p.workers); i++ {
		close(p.workers[i].input)
	}
	p.workers = nil
	p.hashPool = nil
}

// handleResult calls the callback.
func handleResult(r result, callback func(*chunk) error) error {
	defer func() {
		r.chunk.compressed.Reset()
		bufPool.Put(r.chunk.compressed)
		chunkPool.Put(r.chunk)
	}()
	if r.err != nil {
		return r.err
	}
	return callback(r.chunk)
}

// schedule schedules the given buffers.
//
// If c is non-nil, then it will return as soon as the chunk is scheduled. If c
// is nil, then it will return only when no more work is left to do.
//
// If no callback function is provided, then the output channel will be
// ignored. You must be sure that the input is schedulable in this case.
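//
// For example, the Writer's Close path follows this pattern (a simplified
// sketch; Close itself checks both errors):
//
//	_ = w.schedule(c, w.flush)   // enqueue the final chunk
//	_ = w.schedule(nil, w.flush) // drain until all results are handled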
func (p *pool) schedule(c *chunk, callback func(*chunk) error) error {
	for {
		var (
			inputChan  chan *chunk
			outputChan chan result
		)
		if c != nil && len(p.workers) != 0 {
			inputChan = p.workers[(p.nextInput+1)%len(p.workers)].input
		}
		if callback != nil && p.nextOutput != p.nextInput && len(p.workers) != 0 {
			outputChan = p.workers[(p.nextOutput+1)%len(p.workers)].output
		}
		if inputChan == nil && outputChan == nil {
			return nil
		}

		select {
		case inputChan <- c:
			p.nextInput++
			return nil
		case r := <-outputChan:
			p.nextOutput++
			if err := handleResult(r, callback); err != nil {
				return err
			}
		}
	}
}

// Reader is a compressed reader.
type Reader struct {
	pool

	// in is the source.
	in io.Reader

	// scratch is a temporary buffer used for marshalling. This is declared
	// up front here to avoid reallocation.
	scratch [4]byte
}

var _ io.Reader = (*Reader)(nil)

// NewReader returns a new compressed reader. If key is non-nil, the data stream
// is assumed to contain expected hash values, which will be compared against
// hash values computed from the compressed bytes. See package comments for
// details.
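//
// A minimal usage sketch (illustrative only; src and key are assumed to be
// provided by the caller, and error handling is elided):
//
//	r, err := NewReader(src, key)
//	if err != nil {
//		return err
//	}
//	data, err := io.ReadAll(r)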
func NewReader(in io.Reader, key []byte) (*Reader, error) {
	r := &Reader{
		in: in,
	}

	// Use double buffering for read.
	r.init(key, 2*runtime.GOMAXPROCS(0), false, 0)

	if _, err := io.ReadFull(in, r.scratch[:4]); err != nil {
		return nil, err
	}
	r.chunkSize = binary.BigEndian.Uint32(r.scratch[:4])

	if r.hashPool != nil {
		h := r.hashPool.getHash()
		binary.BigEndian.PutUint32(r.scratch[:], r.chunkSize)
		h.Write(r.scratch[:4])
		r.lastSum = h.Sum(nil)
		r.hashPool.putHash(h)
		sum := make([]byte, len(r.lastSum))
		if _, err := io.ReadFull(r.in, sum); err != nil {
			return nil, err
		}
		if !hmac.Equal(r.lastSum, sum) {
			return nil, ErrHashMismatch
		}
	}

	return r, nil
}

// errNewBuffer is returned when a new buffer is completed.
var errNewBuffer = errors.New("buffer ready")

// ErrHashMismatch is returned if the hash does not match.
var ErrHashMismatch = errors.New("hash mismatch")

// ReadByte implements wire.Reader.ReadByte.
func (r *Reader) ReadByte() (byte, error) {
	var p [1]byte
	n, err := r.Read(p[:])
	if n != 1 {
		return p[0], err
	}
	// Suppress any EOF: the byte was successfully read.
	return p[0], nil
}

// Read implements io.Reader.Read.
func (r *Reader) Read(p []byte) (int, error) {
	r.mu.Lock()
	defer r.mu.Unlock()

	// Total bytes completed; this is declared up front because it must be
	// adjustable by the callback below.
	done := 0

	// Total bytes pending in the asynchronous workers for buffers. This is
	// used to process the proper regions of the input as inline buffers.
	var (
		pendingPre    = r.nextInput - r.nextOutput
		pendingInline = 0
	)

	// Define our callback for completed work.
	callback := func(c *chunk) error {
		// Check for an inline buffer.
		if pendingPre == 0 && pendingInline > 0 {
			pendingInline--
			done += c.uncompressed.Len()
			return nil
		}

		// Copy the resulting buffer to our intermediate one, and
		// return errNewBuffer to ensure that we aren't called a second
		// time. This error code is handled specially below.
		//
		// The buffer will be returned to the pool once it has been
		// fully drained by the loop below.
		if pendingPre > 0 {
			pendingPre--
		}
		r.buf = c.uncompressed
		return errNewBuffer
	}

	for done < len(p) {
		// Do we have buffered data available?
		if r.buf != nil {
			n, err := r.buf.Read(p[done:])
			done += n
			if err == io.EOF {
				// This is the uncompressed buffer, it can be
				// returned to the pool at this point.
				r.buf.Reset()
				bufPool.Put(r.buf)
				r.buf = nil
			} else if err != nil {
				// Should never happen.
				defer r.stop()
				return done, err
			}
			continue
		}

		// Read the length of the next chunk and reset the
		// reader. The length is used to limit the reader.
		//
		// See Writer.flush.
		if _, err := io.ReadFull(r.in, r.scratch[:4]); err != nil {
			// This is generally okay as long as there
			// are still buffers outstanding. We actually
			// just wait for completion of those buffers here
			// and continue our loop.
			if err := r.schedule(nil, callback); err == nil {
				// We've actually finished all buffers; this is
				// the normal EOF exit path.
				defer r.stop()
				return done, io.EOF
			} else if err == errNewBuffer {
				// A new buffer is now available.
				continue
			} else {
				// Some other error occurred; we cannot
				// process any further.
				defer r.stop()
				return done, err
			}
		}
		l := binary.BigEndian.Uint32(r.scratch[:4])

		// Read this chunk and schedule decompression.
		compressed := bufPool.Get().(*bytes.Buffer)
		if _, err := io.CopyN(compressed, r.in, int64(l)); err != nil {
			// Some other error occurred; see above.
			if err == io.EOF {
				err = io.ErrUnexpectedEOF
			}
			return done, err
		}

		var sum []byte
		if r.hashPool != nil {
			sum = make([]byte, len(r.lastSum))
			if _, err := io.ReadFull(r.in, sum); err != nil {
				if err == io.EOF {
					err = io.ErrUnexpectedEOF
				}
				return done, err
			}
		}

		// Are we doing inline decoding?
		//
		// Note that we need to check the length here against
		// bytes.MinRead, since the bytes library will choose to grow
		// the slice if the available capacity is not at least
		// bytes.MinRead. This limits inline decoding to chunkSizes
		// that are at least bytes.MinRead (which is not unreasonable).
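		//
		// For example (illustrative numbers): with a 1 MiB chunkSize
		// and an 8 MiB p, whole chunks are decompressed directly into
		// p, while any p smaller than bytes.MinRead (512 bytes) always
		// goes through a pooled intermediate buffer.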
		var c *chunk
		start := done + ((pendingPre + pendingInline) * int(r.chunkSize))
		if len(p) >= start+int(r.chunkSize) && len(p) >= start+bytes.MinRead {
			c = newChunk(r.lastSum, sum, compressed, bytes.NewBuffer(p[start:start]))
			pendingInline++
		} else {
			c = newChunk(r.lastSum, sum, compressed, nil)
		}
		r.lastSum = sum
		if err := r.schedule(c, callback); err == errNewBuffer {
			// A new buffer was completed while we were reading.
			// That's great, but we need to force schedule the
			// current buffer so that it does not get lost.
			//
			// It is safe to pass nil as an output function here,
			// because we know that we just freed up a slot above.
			r.schedule(c, nil)
		} else if err != nil {
			// Some other error occurred; see above.
			defer r.stop()
			return done, err
		}
	}

	// Make sure that everything has been decoded successfully, otherwise
	// parts of p may not actually have completed.
	for pendingInline > 0 {
		if err := r.schedule(nil, func(c *chunk) error {
			if err := callback(c); err != nil {
				return err
			}
			// A nil return from the callback means that an inline
			// buffer has completed; the callback has already
			// decremented pendingInline, so we just return an
			// error to check the top of the loop again.
			return errNewBuffer
		}); err != errNewBuffer {
			// Some other error occurred; see above.
			return done, err
		}
	}

	// Need to return done here, since it may have been adjusted by the
	// callback to compensate for partial reads on some inline buffer.
	return done, nil
}

// Writer is a compressed writer.
type Writer struct {
	pool

	// out is the underlying writer.
	out io.Writer

	// closed indicates whether the file has been closed.
	closed bool

	// scratch is a temporary buffer used for marshalling. This is declared
	// up front here to avoid reallocation.
	scratch [4]byte
}

var _ io.Writer = (*Writer)(nil)

// NewWriter returns a new compressed writer. If key is non-nil, hash values are
// generated and written out for compressed bytes. See package comments for
// details.
//
// The recommended chunkSize is on the order of 1M. Extra memory may be
// buffered (in the form of read-ahead, or buffered writes), and is limited to
// O(chunkSize * [1+GOMAXPROCS]).
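//
// A minimal usage sketch (illustrative only; dst, key, and payload are
// assumed to be provided by the caller):
//
//	w, err := NewWriter(dst, key, 1<<20, flate.DefaultCompression)
//	if err != nil {
//		return err
//	}
//	if _, err := w.Write(payload); err != nil {
//		return err
//	}
//	return w.Close()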
func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (*Writer, error) {
	w := &Writer{
		pool: pool{
			chunkSize: chunkSize,
			buf:       bufPool.Get().(*bytes.Buffer),
		},
		out: out,
	}
	w.init(key, 1+runtime.GOMAXPROCS(0), true, level)

	binary.BigEndian.PutUint32(w.scratch[:], chunkSize)
	if _, err := w.out.Write(w.scratch[:4]); err != nil {
		return nil, err
	}

	if w.hashPool != nil {
		h := w.hashPool.getHash()
		binary.BigEndian.PutUint32(w.scratch[:], chunkSize)
		h.Write(w.scratch[:4])
		w.lastSum = h.Sum(nil)
		w.hashPool.putHash(h)
		if _, err := io.CopyN(w.out, bytes.NewReader(w.lastSum), int64(len(w.lastSum))); err != nil {
			return nil, err
		}
	}

	return w, nil
}

// flush writes a single buffer.
func (w *Writer) flush(c *chunk) error {
	// Prefix each chunk with a length; this allows the reader to safely
	// limit reads while buffering.
	l := uint32(c.compressed.Len())

	binary.BigEndian.PutUint32(w.scratch[:], l)
	if _, err := w.out.Write(w.scratch[:4]); err != nil {
		return err
	}

	// Write out to the stream.
	if _, err := io.CopyN(w.out, c.compressed, int64(c.compressed.Len())); err != nil {
		return err
	}

	if w.hashPool != nil {
		// Chain in the previous hash; hash writes cannot fail.
		io.CopyN(c.h, bytes.NewReader(w.lastSum), int64(len(w.lastSum)))
		sum := c.h.Sum(nil)
		w.hashPool.putHash(c.h)
		c.h = nil
		if _, err := io.CopyN(w.out, bytes.NewReader(sum), int64(len(sum))); err != nil {
			return err
		}
		w.lastSum = sum
	}

	return nil
}

// WriteByte implements wire.Writer.WriteByte.
//
// Note that this implementation is necessary on the object itself, as an
// interface-based dispatch cannot tell whether the array backing the slice
// escapes; therefore all bytes written would generate an escape.
func (w *Writer) WriteByte(b byte) error {
	var p [1]byte
	p[0] = b
	n, err := w.Write(p[:])
	if n != 1 {
		return err
	}
	return nil
}

// Write implements io.Writer.Write.
func (w *Writer) Write(p []byte) (int, error) {
	w.mu.Lock()
	defer w.mu.Unlock()

	// Did we close already?
	if w.closed {
		return 0, io.ErrUnexpectedEOF
	}

	// See above; we need to track in the same way.
	var (
		pendingPre    = w.nextInput - w.nextOutput
		pendingInline = 0
	)
	callback := func(c *chunk) error {
		if pendingPre > 0 {
			pendingPre--
			err := w.flush(c)
			c.uncompressed.Reset()
			bufPool.Put(c.uncompressed)
			return err
		}
		if pendingInline > 0 {
			pendingInline--
			return w.flush(c)
		}
		panic("both pendingPre and pendingInline exhausted")
	}

	for done := 0; done < len(p); {
		// Construct an inline buffer if we're doing an inline
		// encoding; see above regarding the bytes.MinRead constraint.
		inline := false
		if w.buf.Len() == 0 && len(p) >= done+int(w.chunkSize) && len(p) >= done+bytes.MinRead {
			bufPool.Put(w.buf) // Return to the pool; never scheduled.
			w.buf = bytes.NewBuffer(p[done : done+int(w.chunkSize)])
			done += int(w.chunkSize)
			pendingInline++
			inline = true
		}

		// Do we need to flush w.buf? Note that this case should be hit
		// immediately following the inline case above.
		left := int(w.chunkSize) - w.buf.Len()
		if left == 0 {
			if err := w.schedule(newChunk(nil, nil, nil, w.buf), callback); err != nil {
				return done, err
			}
			if !inline {
				pendingPre++
			}
			// Reset the buffer, since this has now been scheduled
			// for compression. Note that this may be trampled
			// immediately by the bufPool.Put(w.buf) above if the
			// next buffer happens to be inline, but that's okay.
			w.buf = bufPool.Get().(*bytes.Buffer)
			continue
		}

		// Read from p into w.buf.
		toWrite := len(p) - done
		if toWrite > left {
			toWrite = left
		}
		n, err := w.buf.Write(p[done : done+toWrite])
		done += n
		if err != nil {
			return done, err
		}
	}

	// Make sure that everything has been flushed; we can't return until
	// all of the contents of p have been used.
	for pendingInline > 0 {
		if err := w.schedule(nil, func(c *chunk) error {
			if err := callback(c); err != nil {
				return err
			}
			// The flush was successful, return errNewBuffer here
			// to break from the loop and check the condition
			// again.
			return errNewBuffer
		}); err != errNewBuffer {
			return len(p), err
		}
	}

	return len(p), nil
}

// Close implements io.Closer.Close.
func (w *Writer) Close() error {
	w.mu.Lock()
	defer w.mu.Unlock()

	// Did we already close? After the call to Close, we always mark as
	// closed, regardless of whether the flush is successful.
	if w.closed {
		return io.ErrUnexpectedEOF
	}
	w.closed = true
	defer w.stop()

	// Schedule any remaining partial buffer; we pass w.flush directly here
	// because the final buffer is guaranteed not to be an inline buffer.
	if w.buf.Len() > 0 {
		if err := w.schedule(newChunk(nil, nil, nil, w.buf), w.flush); err != nil {
			return err
		}
	}

	// Flush all scheduled buffers; see above.
	if err := w.schedule(nil, w.flush); err != nil {
		return err
	}

	// Close the underlying writer (if necessary).
	if closer, ok := w.out.(io.Closer); ok {
		return closer.Close()
	}
	return nil
}