// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package compressio provides parallel compression and decompression, as well
// as optional SHA-256 hashing.
//
// The stream format is defined as follows.
//
// /------------------------------------------------------\
// |                 chunk size (4-bytes)                 |
// +------------------------------------------------------+
// |              (optional) hash (32-bytes)              |
// +------------------------------------------------------+
// |           compressed data size (4-bytes)             |
// +------------------------------------------------------+
// |                   compressed data                    |
// +------------------------------------------------------+
// |              (optional) hash (32-bytes)              |
// +------------------------------------------------------+
// |           compressed data size (4-bytes)             |
// +------------------------------------------------------+
// |                       ......                         |
// \------------------------------------------------------/
//
// where each subsequent hash is calculated from the following items in order
//
//     compressed data
//     compressed data size
//     previous hash
//
// so the stream integrity cannot be compromised by switching and mixing
// compressed chunks.
44 package compressio 45 46 import ( 47 "bytes" 48 "compress/flate" 49 "crypto/hmac" 50 "crypto/sha256" 51 "encoding/binary" 52 "errors" 53 "hash" 54 "io" 55 "runtime" 56 57 "github.com/SagerNet/gvisor/pkg/sync" 58 ) 59 60 var bufPool = sync.Pool{ 61 New: func() interface{} { 62 return bytes.NewBuffer(nil) 63 }, 64 } 65 66 var chunkPool = sync.Pool{ 67 New: func() interface{} { 68 return new(chunk) 69 }, 70 } 71 72 // chunk is a unit of work. 73 type chunk struct { 74 // compressed is compressed data. 75 // 76 // This will always be returned to the bufPool directly when work has 77 // finished (in schedule) and therefore must be allocated. 78 compressed *bytes.Buffer 79 80 // uncompressed is the uncompressed data. 81 // 82 // This is not returned to the bufPool automatically, since it may 83 // correspond to a inline slice (provided directly to Read or Write). 84 uncompressed *bytes.Buffer 85 86 // The current hash object. Only used in compress mode. 87 h hash.Hash 88 89 // The hash from previous chunks. Only used in uncompress mode. 90 lastSum []byte 91 92 // The expected hash after current chunk. Only used in uncompress mode. 93 sum []byte 94 } 95 96 // newChunk allocates a new chunk object (or pulls one from the pool). Buffers 97 // will be allocated if nil is provided for compressed or uncompressed. 98 func newChunk(lastSum []byte, sum []byte, compressed *bytes.Buffer, uncompressed *bytes.Buffer) *chunk { 99 c := chunkPool.Get().(*chunk) 100 c.lastSum = lastSum 101 c.sum = sum 102 if compressed != nil { 103 c.compressed = compressed 104 } else { 105 c.compressed = bufPool.Get().(*bytes.Buffer) 106 } 107 if uncompressed != nil { 108 c.uncompressed = uncompressed 109 } else { 110 c.uncompressed = bufPool.Get().(*bytes.Buffer) 111 } 112 return c 113 } 114 115 // result is the result of some work; it includes the original chunk. 116 type result struct { 117 *chunk 118 err error 119 } 120 121 // worker is a compression/decompression worker. 
//
// The associated worker goroutine reads in uncompressed buffers from input and
// writes compressed buffers to its output. Alternatively, the worker reads
// compressed buffers from input and writes uncompressed buffers to its output.
//
// The goroutine will exit when input is closed, and the goroutine will close
// output.
type worker struct {
	hashPool *hashPool
	input    chan *chunk
	output   chan result

	// scratch is a temporary buffer used for marshalling. This is declared
	// up front here to avoid reallocation.
	scratch [4]byte
}

// work is the main work routine; see worker.
//
// In compress mode, each chunk's uncompressed buffer is flate-encoded into
// its compressed buffer; if hashing is enabled, the hash state is primed with
// the compressed bytes and length and handed off via c.h (flush finalizes
// it). In uncompress mode, the chunk's hash chain is verified before the
// compressed buffer is decoded into the uncompressed buffer.
func (w *worker) work(compress bool, level int) {
	defer close(w.output)

	// h is lazily acquired per chunk when hashing is enabled.
	var h hash.Hash

	for c := range w.input {
		if h == nil && w.hashPool != nil {
			h = w.hashPool.getHash()
		}
		if compress {
			// Tee the compressed output into the hash as it is
			// produced, if hashing is enabled.
			mw := io.Writer(c.compressed)
			if h != nil {
				mw = io.MultiWriter(mw, h)
			}

			// Encode this slice.
			fw, err := flate.NewWriter(mw, level)
			if err != nil {
				w.output <- result{c, err}
				continue
			}

			// Encode the input.
			if _, err := io.CopyN(fw, c.uncompressed, int64(c.uncompressed.Len())); err != nil {
				w.output <- result{c, err}
				continue
			}
			if err := fw.Close(); err != nil {
				w.output <- result{c, err}
				continue
			}

			// Fold the compressed length into the hash and hand
			// the hash off to the chunk; flush will chain in the
			// previous sum and emit the final value.
			if h != nil {
				binary.BigEndian.PutUint32(w.scratch[:], uint32(c.compressed.Len()))
				h.Write(w.scratch[:4])
				c.h = h
				h = nil
			}
		} else {
			// Check the hash of the compressed contents. The
			// order (data, length, previous sum) matches the
			// chain described in the package comment.
			if h != nil {
				h.Write(c.compressed.Bytes())
				binary.BigEndian.PutUint32(w.scratch[:], uint32(c.compressed.Len()))
				h.Write(w.scratch[:4])
				// Error ignored: hash.Hash.Write is documented
				// to never return an error.
				io.CopyN(h, bytes.NewReader(c.lastSum), int64(len(c.lastSum)))

				sum := h.Sum(nil)
				h.Reset()
				if !hmac.Equal(c.sum, sum) {
					w.output <- result{c, ErrHashMismatch}
					continue
				}
			}

			// Decode this slice.
			fr := flate.NewReader(c.compressed)

			// Decode the input.
			if _, err := io.Copy(c.uncompressed, fr); err != nil {
				w.output <- result{c, err}
				continue
			}
		}

		// Send the output.
		w.output <- result{c, nil}
	}
}

// hashPool is a free list of keyed hash objects.
type hashPool struct {
	// mu protects the hash list.
	mu sync.Mutex

	// key is the key used to create hash objects.
	key []byte

	// hashes is the hash object free list. Note that this cannot be
	// globally shared across readers or writers, as it is key-specific.
	hashes []hash.Hash
}

// getHash gets a hash object for the pool. It should only be called when the
// pool key is non-nil.
func (p *hashPool) getHash() hash.Hash {
	p.mu.Lock()
	defer p.mu.Unlock()

	if len(p.hashes) == 0 {
		return hmac.New(sha256.New, p.key)
	}

	h := p.hashes[len(p.hashes)-1]
	p.hashes = p.hashes[:len(p.hashes)-1]
	return h
}

// putHash returns a hash object to the pool after resetting it.
func (p *hashPool) putHash(h hash.Hash) {
	h.Reset()

	p.mu.Lock()
	defer p.mu.Unlock()

	p.hashes = append(p.hashes, h)
}

// pool is common functionality for reader/writers.
type pool struct {
	// workers are the compression/decompression workers.
	workers []worker

	// chunkSize is the chunk size. This is the first four bytes in the
	// stream and is shared across both the reader and writer.
	chunkSize uint32

	// mu protects below; it is generally the responsibility of users to
	// acquire this mutex before calling any methods on the pool.
	mu sync.Mutex

	// nextInput is the next worker for input (scheduling).
	nextInput int

	// nextOutput is the next worker for output (result).
	nextOutput int

	// buf is the current active buffer; the exact semantics of this buffer
	// depend on whether this is a reader or a writer.
	buf *bytes.Buffer

	// lastSum records the hash of the last chunk processed.
	lastSum []byte

	// hashPool is the hash object pool. It cannot be embedded into pool
	// itself as worker refers to it and that would stop pool from being
	// GCed.
	hashPool *hashPool
}

// init initializes the worker pool.
//
// This should only be called once.
func (p *pool) init(key []byte, workers int, compress bool, level int) {
	if key != nil {
		p.hashPool = &hashPool{key: key}
	}
	p.workers = make([]worker, workers)
	for i := 0; i < len(p.workers); i++ {
		p.workers[i] = worker{
			hashPool: p.hashPool,
			input:    make(chan *chunk, 1),
			output:   make(chan result, 1),
		}
		go p.workers[i].work(compress, level) // S/R-SAFE: In save path only.
	}
	// Ensure the worker goroutines are shut down (by closing their input
	// channels) if the pool is garbage collected without stop being
	// called explicitly.
	runtime.SetFinalizer(p, (*pool).stop)
}

// stop stops all workers.
func (p *pool) stop() {
	for i := 0; i < len(p.workers); i++ {
		close(p.workers[i].input)
	}
	p.workers = nil
	p.hashPool = nil
}

// handleResult calls the callback, then returns the chunk's compressed buffer
// and the chunk itself to their respective pools.
func handleResult(r result, callback func(*chunk) error) error {
	defer func() {
		r.chunk.compressed.Reset()
		bufPool.Put(r.chunk.compressed)
		chunkPool.Put(r.chunk)
	}()
	if r.err != nil {
		return r.err
	}
	return callback(r.chunk)
}

// schedule schedules the given buffers.
//
// If c is non-nil, then it will return as soon as the chunk is scheduled. If c
// is nil, then it will return only when no more work is left to do.
//
// If no callback function is provided, then the output channel will be
// ignored.
// You must be sure that the input is schedulable in this case.
func (p *pool) schedule(c *chunk, callback func(*chunk) error) error {
	for {
		// Pick the next input and output workers in round-robin order;
		// results are consumed in the same order work was submitted.
		var (
			inputChan  chan *chunk
			outputChan chan result
		)
		if c != nil && len(p.workers) != 0 {
			inputChan = p.workers[(p.nextInput+1)%len(p.workers)].input
		}
		if callback != nil && p.nextOutput != p.nextInput && len(p.workers) != 0 {
			outputChan = p.workers[(p.nextOutput+1)%len(p.workers)].output
		}
		if inputChan == nil && outputChan == nil {
			return nil
		}

		select {
		case inputChan <- c:
			p.nextInput++
			return nil
		case r := <-outputChan:
			p.nextOutput++
			if err := handleResult(r, callback); err != nil {
				return err
			}
		}
	}
}

// Reader is a compressed reader.
type Reader struct {
	pool

	// in is the source.
	in io.Reader

	// scratch is a temporary buffer used for marshalling. This is declared
	// up front here to avoid reallocation.
	scratch [4]byte
}

var _ io.Reader = (*Reader)(nil)

// NewReader returns a new compressed reader. If key is non-nil, the data stream
// is assumed to contain expected hash values, which will be compared against
// hash values computed from the compressed bytes. See package comments for
// details.
func NewReader(in io.Reader, key []byte) (*Reader, error) {
	r := &Reader{
		in: in,
	}

	// Use double buffering for read.
	r.init(key, 2*runtime.GOMAXPROCS(0), false, 0)

	// Read and record the chunk size header.
	if _, err := io.ReadFull(in, r.scratch[:4]); err != nil {
		return nil, err
	}
	r.chunkSize = binary.BigEndian.Uint32(r.scratch[:4])

	// If hashing is enabled, compute the expected hash of the chunk size
	// header and verify it against the stream; this seeds the hash chain.
	if r.hashPool != nil {
		h := r.hashPool.getHash()
		binary.BigEndian.PutUint32(r.scratch[:], r.chunkSize)
		h.Write(r.scratch[:4])
		r.lastSum = h.Sum(nil)
		r.hashPool.putHash(h)
		sum := make([]byte, len(r.lastSum))
		if _, err := io.ReadFull(r.in, sum); err != nil {
			return nil, err
		}
		if !hmac.Equal(r.lastSum, sum) {
			return nil, ErrHashMismatch
		}
	}

	return r, nil
}

// errNewBuffer is returned when a new buffer is completed.
var errNewBuffer = errors.New("buffer ready")

// ErrHashMismatch is returned if the hash does not match.
var ErrHashMismatch = errors.New("hash mismatch")

// ReadByte implements wire.Reader.ReadByte.
func (r *Reader) ReadByte() (byte, error) {
	var p [1]byte
	n, err := r.Read(p[:])
	if n != 1 {
		return p[0], err
	}
	// Suppress EOF: a byte was successfully read.
	return p[0], nil
}

// Read implements io.Reader.Read.
func (r *Reader) Read(p []byte) (int, error) {
	r.mu.Lock()
	defer r.mu.Unlock()

	// Total bytes completed; this is declared up front because it must be
	// adjustable by the callback below.
	done := 0

	// Total bytes pending in the asynchronous workers for buffers. This is
	// used to process the proper regions of the input as inline buffers.
	var (
		pendingPre    = r.nextInput - r.nextOutput
		pendingInline = 0
	)

	// Define our callback for completed work.
	callback := func(c *chunk) error {
		// Check for an inline buffer: once all pre-existing chunks
		// have drained, completed chunks decoded directly into p just
		// advance done.
		if pendingPre == 0 && pendingInline > 0 {
			pendingInline--
			done += c.uncompressed.Len()
			return nil
		}

		// Copy the resulting buffer to our intermediate one, and
		// return errNewBuffer to ensure that we aren't called a second
		// time. This error code is handled specially below.
		//
		// c.buf will be freed and return to the pool when it is done.
		if pendingPre > 0 {
			pendingPre--
		}
		r.buf = c.uncompressed
		return errNewBuffer
	}

	for done < len(p) {
		// Do we have buffered data available?
		if r.buf != nil {
			n, err := r.buf.Read(p[done:])
			done += n
			if err == io.EOF {
				// This is the uncompressed buffer, it can be
				// returned to the pool at this point.
				r.buf.Reset()
				bufPool.Put(r.buf)
				r.buf = nil
			} else if err != nil {
				// Should never happen.
				defer r.stop()
				return done, err
			}
			continue
		}

		// Read the length of the next chunk and reset the
		// reader. The length is used to limit the reader.
		//
		// See writer.flush.
		if _, err := io.ReadFull(r.in, r.scratch[:4]); err != nil {
			// This is generally okay as long as there
			// are still buffers outstanding. We actually
			// just wait for completion of those buffers here
			// and continue our loop.
			if err := r.schedule(nil, callback); err == nil {
				// We've actually finished all buffers; this is
				// the normal EOF exit path.
				defer r.stop()
				return done, io.EOF
			} else if err == errNewBuffer {
				// A new buffer is now available.
				continue
			} else {
				// Some other error occurred; we cannot
				// process any further.
				defer r.stop()
				return done, err
			}
		}
		l := binary.BigEndian.Uint32(r.scratch[:4])

		// Read this chunk and schedule decompression.
		compressed := bufPool.Get().(*bytes.Buffer)
		if _, err := io.CopyN(compressed, r.in, int64(l)); err != nil {
			// Some other error occurred; see above. A truncated
			// chunk body is an unexpected EOF, not a clean end.
			if err == io.EOF {
				err = io.ErrUnexpectedEOF
			}
			return done, err
		}

		// Read the chunk's expected hash, if hashing is enabled.
		var sum []byte
		if r.hashPool != nil {
			sum = make([]byte, len(r.lastSum))
			if _, err := io.ReadFull(r.in, sum); err != nil {
				if err == io.EOF {
					err = io.ErrUnexpectedEOF
				}
				return done, err
			}
		}

		// Are we doing inline decoding?
		//
		// Note that we need to check the length here against
		// bytes.MinRead, since the bytes library will choose to grow
		// the slice if the available capacity is not at least
		// bytes.MinRead. This limits inline decoding to chunkSizes
		// that are at least bytes.MinRead (which is not unreasonable).
		var c *chunk
		start := done + ((pendingPre + pendingInline) * int(r.chunkSize))
		if len(p) >= start+int(r.chunkSize) && len(p) >= start+bytes.MinRead {
			// Decode straight into the caller's slice.
			c = newChunk(r.lastSum, sum, compressed, bytes.NewBuffer(p[start:start]))
			pendingInline++
		} else {
			c = newChunk(r.lastSum, sum, compressed, nil)
		}
		r.lastSum = sum
		if err := r.schedule(c, callback); err == errNewBuffer {
			// A new buffer was completed while we were reading.
			// That's great, but we need to force schedule the
			// current buffer so that it does not get lost.
			//
			// It is safe to pass nil as an output function here,
			// because we know that we just freed up a slot above.
			r.schedule(c, nil)
		} else if err != nil {
			// Some other error occurred; see above.
			defer r.stop()
			return done, err
		}
	}

	// Make sure that everything has been decoded successfully, otherwise
	// parts of p may not actually have completed.
	for pendingInline > 0 {
		if err := r.schedule(nil, func(c *chunk) error {
			if err := callback(c); err != nil {
				return err
			}
			// The nil case means that an inline buffer has
			// completed. The callback has already adjusted the
			// inline accounting (pendingInline and done), so we
			// just return an error to check the top of the loop
			// again.
			return errNewBuffer
		}); err != errNewBuffer {
			// Some other error occurred; see above.
			return done, err
		}
	}

	// Need to return done here, since it may have been adjusted by the
	// callback to compensate for partial reads on some inline buffer.
	return done, nil
}

// Writer is a compressed writer.
type Writer struct {
	pool

	// out is the underlying writer.
	out io.Writer

	// closed indicates whether the file has been closed.
	closed bool

	// scratch is a temporary buffer used for marshalling. This is declared
	// up front here to avoid reallocation.
	scratch [4]byte
}

var _ io.Writer = (*Writer)(nil)

// NewWriter returns a new compressed writer. If key is non-nil, hash values are
// generated and written out for compressed bytes. See package comments for
// details.
//
// The recommended chunkSize is on the order of 1M. Extra memory may be
// buffered (in the form of read-ahead, or buffered writes), and is limited to
// O(chunkSize * [1+GOMAXPROCS]).
func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (*Writer, error) {
	w := &Writer{
		pool: pool{
			chunkSize: chunkSize,
			buf:       bufPool.Get().(*bytes.Buffer),
		},
		out: out,
	}
	w.init(key, 1+runtime.GOMAXPROCS(0), true, level)

	// Emit the chunk size header.
	binary.BigEndian.PutUint32(w.scratch[:], chunkSize)
	if _, err := w.out.Write(w.scratch[:4]); err != nil {
		return nil, err
	}

	// If hashing is enabled, emit the hash of the chunk size header; this
	// seeds the hash chain (see NewReader, which verifies it).
	if w.hashPool != nil {
		h := w.hashPool.getHash()
		binary.BigEndian.PutUint32(w.scratch[:], chunkSize)
		h.Write(w.scratch[:4])
		w.lastSum = h.Sum(nil)
		w.hashPool.putHash(h)
		if _, err := io.CopyN(w.out, bytes.NewReader(w.lastSum), int64(len(w.lastSum))); err != nil {
			return nil, err
		}
	}

	return w, nil
}

// flush writes a single buffer.
func (w *Writer) flush(c *chunk) error {
	// Prefix each chunk with a length; this allows the reader to safely
	// limit reads while buffering.
	l := uint32(c.compressed.Len())

	binary.BigEndian.PutUint32(w.scratch[:], l)
	if _, err := w.out.Write(w.scratch[:4]); err != nil {
		return err
	}

	// Write out to the stream.
	if _, err := io.CopyN(w.out, c.compressed, int64(c.compressed.Len())); err != nil {
		return err
	}

	// Chain the hash: c.h was primed by the worker with the compressed
	// data and length; fold in the previous sum, emit this chunk's sum,
	// and remember it for the next chunk. The CopyN error is ignored
	// because hash.Hash.Write is documented to never return an error.
	if w.hashPool != nil {
		io.CopyN(c.h, bytes.NewReader(w.lastSum), int64(len(w.lastSum)))
		sum := c.h.Sum(nil)
		w.hashPool.putHash(c.h)
		c.h = nil
		if _, err := io.CopyN(w.out, bytes.NewReader(sum), int64(len(sum))); err != nil {
			return err
		}
		w.lastSum = sum
	}

	return nil
}

// WriteByte implements wire.Writer.WriteByte.
//
// Note that this implementation is necessary on the object itself, as an
// interface-based dispatch cannot tell whether the array backing the slice
// escapes, therefore all bytes written will generate an escape.
func (w *Writer) WriteByte(b byte) error {
	var p [1]byte
	p[0] = b
	n, err := w.Write(p[:])
	if n != 1 {
		return err
	}
	return nil
}

// Write implements io.Writer.Write.
func (w *Writer) Write(p []byte) (int, error) {
	w.mu.Lock()
	defer w.mu.Unlock()

	// Did we close already?
	if w.closed {
		return 0, io.ErrUnexpectedEOF
	}

	// See Reader.Read above; we need to track in the same way.
	var (
		pendingPre    = w.nextInput - w.nextOutput
		pendingInline = 0
	)
	callback := func(c *chunk) error {
		// Inline buffers wrap p directly, so only flush them; pooled
		// buffers are additionally reset and returned to bufPool.
		if pendingPre == 0 && pendingInline > 0 {
			pendingInline--
			return w.flush(c)
		}
		if pendingPre > 0 {
			pendingPre--
		}
		err := w.flush(c)
		c.uncompressed.Reset()
		bufPool.Put(c.uncompressed)
		return err
	}

	for done := 0; done < len(p); {
		// Construct an inline buffer if we're doing an inline
		// encoding; see above regarding the bytes.MinRead constraint.
		if w.buf.Len() == 0 && len(p) >= done+int(w.chunkSize) && len(p) >= done+bytes.MinRead {
			bufPool.Put(w.buf) // Return to the pool; never scheduled.
			w.buf = bytes.NewBuffer(p[done : done+int(w.chunkSize)])
			done += int(w.chunkSize)
			pendingInline++
		}

		// Do we need to flush w.buf? Note that this case should be hit
		// immediately following the inline case above.
		left := int(w.chunkSize) - w.buf.Len()
		if left == 0 {
			if err := w.schedule(newChunk(nil, nil, nil, w.buf), callback); err != nil {
				return done, err
			}
			// Reset the buffer, since this has now been scheduled
			// for compression. Note that this may be trampled
			// immediately by the bufPool.Put(w.buf) above if the
			// next buffer happens to be inline, but that's okay.
			w.buf = bufPool.Get().(*bytes.Buffer)
			continue
		}

		// Read from p into w.buf, up to the remaining chunk capacity.
		toWrite := len(p) - done
		if toWrite > left {
			toWrite = left
		}
		n, err := w.buf.Write(p[done : done+toWrite])
		done += n
		if err != nil {
			return done, err
		}
	}

	// Make sure that everything has been flushed, we can't return until
	// all the contents from p have been used.
	for pendingInline > 0 {
		if err := w.schedule(nil, func(c *chunk) error {
			if err := callback(c); err != nil {
				return err
			}
			// The flush was successful, return errNewBuffer here
			// to break from the loop and check the condition
			// again.
			return errNewBuffer
		}); err != errNewBuffer {
			return len(p), err
		}
	}

	return len(p), nil
}

// Close implements io.Closer.Close.
func (w *Writer) Close() error {
	w.mu.Lock()
	defer w.mu.Unlock()

	// Did we already close? After the call to Close, we always mark as
	// closed, regardless of whether the flush is successful.
	if w.closed {
		return io.ErrUnexpectedEOF
	}
	w.closed = true
	defer w.stop()

	// Schedule any remaining partial buffer; we pass w.flush directly here
	// because the final buffer is guaranteed to not be an inline buffer.
	if w.buf.Len() > 0 {
		if err := w.schedule(newChunk(nil, nil, nil, w.buf), w.flush); err != nil {
			return err
		}
	}

	// Flush all scheduled buffers; see above.
	if err := w.schedule(nil, w.flush); err != nil {
		return err
	}

	// Close the underlying writer (if necessary).
	if closer, ok := w.out.(io.Closer); ok {
		return closer.Close()
	}
	return nil
}