gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/compressio/compressio.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package compressio provides parallel compression and decompression, as well 16 // as optional SHA-256 hashing. It also provides another storage variant 17 // (nocompressio) that does not compress data but tracks its integrity. 18 // 19 // The stream format is defined as follows. 20 // 21 // /------------------------------------------------------\ 22 // | chunk size (4-bytes) | 23 // +------------------------------------------------------+ 24 // | (optional) hash (32-bytes) | 25 // +------------------------------------------------------+ 26 // | compressed data size (4-bytes) | 27 // +------------------------------------------------------+ 28 // | compressed data | 29 // +------------------------------------------------------+ 30 // | (optional) hash (32-bytes) | 31 // +------------------------------------------------------+ 32 // | compressed data size (4-bytes) | 33 // +------------------------------------------------------+ 34 // | ...... | 35 // \------------------------------------------------------/ 36 // 37 // where each subsequent hash is calculated from the following items in order 38 // 39 // compressed data 40 // compressed data size 41 // previous hash 42 // 43 // so the stream integrity cannot be compromised by switching and mixing 44 // compressed chunks. 45 package compressio 46 47 import ( 48 "bytes" 49 "compress/flate" 50 "crypto/hmac" 51 "crypto/sha256" 52 "encoding/binary" 53 "errors" 54 "hash" 55 "io" 56 "runtime" 57 58 "gvisor.dev/gvisor/pkg/sync" 59 ) 60 61 var bufPool = sync.Pool{ 62 New: func() any { 63 return bytes.NewBuffer(nil) 64 }, 65 } 66 67 var chunkPool = sync.Pool{ 68 New: func() any { 69 return new(chunk) 70 }, 71 } 72 73 // chunk is a unit of work. 74 type chunk struct { 75 // compressed is compressed data. 76 // 77 // This will always be returned to the bufPool directly when work has 78 // finished (in schedule) and therefore must be allocated. 79 compressed *bytes.Buffer 80 81 // uncompressed is the uncompressed data. 82 // 83 // This is not returned to the bufPool automatically, since it may 84 // correspond to a inline slice (provided directly to Read or Write). 85 uncompressed *bytes.Buffer 86 87 // The current hash object. Only used in compress mode. 88 h hash.Hash 89 90 // The hash from previous chunks. Only used in uncompress mode. 91 lastSum []byte 92 93 // The expected hash after current chunk. Only used in uncompress mode. 94 sum []byte 95 } 96 97 // newChunk allocates a new chunk object (or pulls one from the pool). Buffers 98 // will be allocated if nil is provided for compressed or uncompressed. 99 func newChunk(lastSum []byte, sum []byte, compressed *bytes.Buffer, uncompressed *bytes.Buffer) *chunk { 100 c := chunkPool.Get().(*chunk) 101 c.lastSum = lastSum 102 c.sum = sum 103 if compressed != nil { 104 c.compressed = compressed 105 } else { 106 c.compressed = bufPool.Get().(*bytes.Buffer) 107 } 108 if uncompressed != nil { 109 c.uncompressed = uncompressed 110 } else { 111 c.uncompressed = bufPool.Get().(*bytes.Buffer) 112 } 113 return c 114 } 115 116 // result is the result of some work; it includes the original chunk. 117 type result struct { 118 *chunk 119 err error 120 } 121 122 // worker is a compression/decompression worker. 123 // 124 // The associated worker goroutine reads in uncompressed buffers from input and 125 // writes compressed buffers to its output. Alternatively, the worker reads 126 // compressed buffers from input and writes uncompressed buffers to its output. 127 // 128 // The goroutine will exit when input is closed, and the goroutine will close 129 // output. 130 type worker struct { 131 hashPool *hashPool 132 input chan *chunk 133 output chan result 134 135 // scratch is a temporary buffer used for marshalling. This is declared 136 // unfront here to avoid reallocation. 137 scratch [4]byte 138 } 139 140 // work is the main work routine; see worker. 141 func (w *worker) work(compress bool, level int) { 142 defer close(w.output) 143 144 var h hash.Hash 145 146 for c := range w.input { 147 if h == nil && w.hashPool != nil { 148 h = w.hashPool.getHash() 149 } 150 if compress { 151 mw := io.Writer(c.compressed) 152 if h != nil { 153 mw = io.MultiWriter(mw, h) 154 } 155 156 // Encode this slice. 157 fw, err := flate.NewWriter(mw, level) 158 if err != nil { 159 w.output <- result{c, err} 160 continue 161 } 162 163 // Encode the input. 164 if _, err := io.CopyN(fw, c.uncompressed, int64(c.uncompressed.Len())); err != nil { 165 w.output <- result{c, err} 166 continue 167 } 168 if err := fw.Close(); err != nil { 169 w.output <- result{c, err} 170 continue 171 } 172 173 // Write the hash, if enabled. 174 if h != nil { 175 binary.BigEndian.PutUint32(w.scratch[:], uint32(c.compressed.Len())) 176 h.Write(w.scratch[:4]) 177 c.h = h 178 h = nil 179 } 180 } else { 181 // Check the hash of the compressed contents. 182 if h != nil { 183 h.Write(c.compressed.Bytes()) 184 binary.BigEndian.PutUint32(w.scratch[:], uint32(c.compressed.Len())) 185 h.Write(w.scratch[:4]) 186 io.CopyN(h, bytes.NewReader(c.lastSum), int64(len(c.lastSum))) 187 188 sum := h.Sum(nil) 189 h.Reset() 190 if !hmac.Equal(c.sum, sum) { 191 w.output <- result{c, ErrHashMismatch} 192 continue 193 } 194 } 195 196 // Decode this slice. 197 fr := flate.NewReader(c.compressed) 198 199 // Decode the input. 200 if _, err := io.Copy(c.uncompressed, fr); err != nil { 201 w.output <- result{c, err} 202 continue 203 } 204 } 205 206 // Send the output. 207 w.output <- result{c, nil} 208 } 209 } 210 211 type hashPool struct { 212 // mu protects the hash list. 213 mu sync.Mutex 214 215 // key is the key used to create hash objects. 216 key []byte 217 218 // hashes is the hash object free list. Note that this cannot be 219 // globally shared across readers or writers, as it is key-specific. 220 hashes []hash.Hash 221 } 222 223 // getHash gets a hash object for the pool. It should only be called when the 224 // pool key is non-nil. 225 func (p *hashPool) getHash() hash.Hash { 226 p.mu.Lock() 227 defer p.mu.Unlock() 228 229 if len(p.hashes) == 0 { 230 return hmac.New(sha256.New, p.key) 231 } 232 233 h := p.hashes[len(p.hashes)-1] 234 p.hashes = p.hashes[:len(p.hashes)-1] 235 return h 236 } 237 238 func (p *hashPool) putHash(h hash.Hash) { 239 h.Reset() 240 241 p.mu.Lock() 242 defer p.mu.Unlock() 243 244 p.hashes = append(p.hashes, h) 245 } 246 247 // pool is common functionality for reader/writers. 248 type pool struct { 249 // workers are the compression/decompression workers. 250 workers []worker 251 252 // chunkSize is the chunk size. This is the first four bytes in the 253 // stream and is shared across both the reader and writer. 254 chunkSize uint32 255 256 // mu protects below; it is generally the responsibility of users to 257 // acquire this mutex before calling any methods on the pool. 258 mu sync.Mutex 259 260 // nextInput is the next worker for input (scheduling). 261 nextInput int 262 263 // nextOutput is the next worker for output (result). 264 nextOutput int 265 266 // buf is the current active buffer; the exact semantics of this buffer 267 // depending on whether this is a reader or a writer. 268 buf *bytes.Buffer 269 270 // lasSum records the hash of the last chunk processed. 271 lastSum []byte 272 273 // hashPool is the hash object pool. It cannot be embedded into pool 274 // itself as worker refers to it and that would stop pool from being 275 // GCed. 276 hashPool *hashPool 277 } 278 279 // init initializes the worker pool. 280 // 281 // This should only be called once. 282 func (p *pool) init(key []byte, workers int, compress bool, level int) { 283 if key != nil { 284 p.hashPool = &hashPool{key: key} 285 } 286 p.workers = make([]worker, workers) 287 for i := 0; i < len(p.workers); i++ { 288 p.workers[i] = worker{ 289 hashPool: p.hashPool, 290 input: make(chan *chunk, 1), 291 output: make(chan result, 1), 292 } 293 go p.workers[i].work(compress, level) // S/R-SAFE: In save path only. 294 } 295 runtime.SetFinalizer(p, (*pool).stop) 296 } 297 298 // stop stops all workers. 299 func (p *pool) stop() { 300 for i := 0; i < len(p.workers); i++ { 301 close(p.workers[i].input) 302 } 303 p.workers = nil 304 p.hashPool = nil 305 } 306 307 // handleResult calls the callback. 308 func handleResult(r result, callback func(*chunk) error) error { 309 defer func() { 310 r.chunk.compressed.Reset() 311 bufPool.Put(r.chunk.compressed) 312 chunkPool.Put(r.chunk) 313 }() 314 if r.err != nil { 315 return r.err 316 } 317 return callback(r.chunk) 318 } 319 320 // schedule schedules the given buffers. 321 // 322 // If c is non-nil, then it will return as soon as the chunk is scheduled. If c 323 // is nil, then it will return only when no more work is left to do. 324 // 325 // If no callback function is provided, then the output channel will be 326 // ignored. You must be sure that the input is schedulable in this case. 327 func (p *pool) schedule(c *chunk, callback func(*chunk) error) error { 328 for { 329 var ( 330 inputChan chan *chunk 331 outputChan chan result 332 ) 333 if c != nil && len(p.workers) != 0 { 334 inputChan = p.workers[(p.nextInput+1)%len(p.workers)].input 335 } 336 if callback != nil && p.nextOutput != p.nextInput && len(p.workers) != 0 { 337 outputChan = p.workers[(p.nextOutput+1)%len(p.workers)].output 338 } 339 if inputChan == nil && outputChan == nil { 340 return nil 341 } 342 343 select { 344 case inputChan <- c: 345 p.nextInput++ 346 return nil 347 case r := <-outputChan: 348 p.nextOutput++ 349 if err := handleResult(r, callback); err != nil { 350 return err 351 } 352 } 353 } 354 } 355 356 // Reader is a compressed reader. 357 type Reader struct { 358 pool 359 360 // in is the source. 361 in io.Reader 362 363 // scratch is a temporary buffer used for marshalling. This is declared 364 // unfront here to avoid reallocation. 365 scratch [4]byte 366 } 367 368 var _ io.Reader = (*Reader)(nil) 369 370 // NewReader returns a new compressed reader. If key is non-nil, the data stream 371 // is assumed to contain expected hash values, which will be compared against 372 // hash values computed from the compressed bytes. See package comments for 373 // details. 374 func NewReader(in io.Reader, key []byte) (*Reader, error) { 375 r := &Reader{ 376 in: in, 377 } 378 379 // Use double buffering for read. 380 r.init(key, 2*runtime.GOMAXPROCS(0), false, 0) 381 382 if _, err := io.ReadFull(in, r.scratch[:4]); err != nil { 383 return nil, err 384 } 385 r.chunkSize = binary.BigEndian.Uint32(r.scratch[:4]) 386 387 if r.hashPool != nil { 388 h := r.hashPool.getHash() 389 binary.BigEndian.PutUint32(r.scratch[:], r.chunkSize) 390 h.Write(r.scratch[:4]) 391 r.lastSum = h.Sum(nil) 392 r.hashPool.putHash(h) 393 sum := make([]byte, len(r.lastSum)) 394 if _, err := io.ReadFull(r.in, sum); err != nil { 395 return nil, err 396 } 397 if !hmac.Equal(r.lastSum, sum) { 398 return nil, ErrHashMismatch 399 } 400 } 401 402 return r, nil 403 } 404 405 // errNewBuffer is returned when a new buffer is completed. 406 var errNewBuffer = errors.New("buffer ready") 407 408 // ErrHashMismatch is returned if the hash does not match. 409 var ErrHashMismatch = errors.New("hash mismatch") 410 411 // Read implements io.Reader.Read. 412 func (r *Reader) Read(p []byte) (int, error) { 413 r.mu.Lock() 414 defer r.mu.Unlock() 415 416 // Total bytes completed; this is declared up front because it must be 417 // adjustable by the callback below. 418 done := 0 419 420 // Total bytes pending in the asynchronous workers for buffers. This is 421 // used to process the proper regions of the input as inline buffers. 422 var ( 423 pendingPre = r.nextInput - r.nextOutput 424 pendingInline = 0 425 ) 426 427 // Define our callback for completed work. 428 callback := func(c *chunk) error { 429 // Check for an inline buffer. 430 if pendingPre == 0 && pendingInline > 0 { 431 pendingInline-- 432 done += c.uncompressed.Len() 433 return nil 434 } 435 436 // Copy the resulting buffer to our intermediate one, and 437 // return errNewBuffer to ensure that we aren't called a second 438 // time. This error code is handled specially below. 439 // 440 // c.buf will be freed and return to the pool when it is done. 441 if pendingPre > 0 { 442 pendingPre-- 443 } 444 r.buf = c.uncompressed 445 return errNewBuffer 446 } 447 448 for done < len(p) { 449 // Do we have buffered data available? 450 if r.buf != nil { 451 n, err := r.buf.Read(p[done:]) 452 done += n 453 if err == io.EOF { 454 // This is the uncompressed buffer, it can be 455 // returned to the pool at this point. 456 r.buf.Reset() 457 bufPool.Put(r.buf) 458 r.buf = nil 459 } else if err != nil { 460 // Should never happen. 461 defer r.stop() 462 return done, err 463 } 464 continue 465 } 466 467 // Read the length of the next chunk and reset the 468 // reader. The length is used to limit the reader. 469 // 470 // See writer.flush. 471 if _, err := io.ReadFull(r.in, r.scratch[:4]); err != nil { 472 // This is generally okay as long as there 473 // are still buffers outstanding. We actually 474 // just wait for completion of those buffers here 475 // and continue our loop. 476 if err := r.schedule(nil, callback); err == nil { 477 // We've actually finished all buffers; this is 478 // the normal EOF exit path. 479 defer r.stop() 480 return done, io.EOF 481 } else if err == errNewBuffer { 482 // A new buffer is now available. 483 continue 484 } else { 485 // Some other error occurred; we cannot 486 // process any further. 487 defer r.stop() 488 return done, err 489 } 490 } 491 l := binary.BigEndian.Uint32(r.scratch[:4]) 492 493 // Read this chunk and schedule decompression. 494 compressed := bufPool.Get().(*bytes.Buffer) 495 if _, err := io.CopyN(compressed, r.in, int64(l)); err != nil { 496 // Some other error occurred; see above. 497 if err == io.EOF { 498 err = io.ErrUnexpectedEOF 499 } 500 return done, err 501 } 502 503 var sum []byte 504 if r.hashPool != nil { 505 sum = make([]byte, len(r.lastSum)) 506 if _, err := io.ReadFull(r.in, sum); err != nil { 507 if err == io.EOF { 508 err = io.ErrUnexpectedEOF 509 } 510 return done, err 511 } 512 } 513 514 // Are we doing inline decoding? 515 // 516 // Note that we need to check the length here against 517 // bytes.MinRead, since the bytes library will choose to grow 518 // the slice if the available capacity is not at least 519 // bytes.MinRead. This limits inline decoding to chunkSizes 520 // that are at least bytes.MinRead (which is not unreasonable). 521 var c *chunk 522 start := done + ((pendingPre + pendingInline) * int(r.chunkSize)) 523 if len(p) >= start+int(r.chunkSize) && len(p) >= start+bytes.MinRead { 524 c = newChunk(r.lastSum, sum, compressed, bytes.NewBuffer(p[start:start])) 525 pendingInline++ 526 } else { 527 c = newChunk(r.lastSum, sum, compressed, nil) 528 } 529 r.lastSum = sum 530 if err := r.schedule(c, callback); err == errNewBuffer { 531 // A new buffer was completed while we were reading. 532 // That's great, but we need to force schedule the 533 // current buffer so that it does not get lost. 534 // 535 // It is safe to pass nil as an output function here, 536 // because we know that we just freed up a slot above. 537 r.schedule(c, nil) 538 } else if err != nil { 539 // Some other error occurred; see above. 540 defer r.stop() 541 return done, err 542 } 543 } 544 545 // Make sure that everything has been decoded successfully, otherwise 546 // parts of p may not actually have completed. 547 for pendingInline > 0 { 548 if err := r.schedule(nil, func(c *chunk) error { 549 if err := callback(c); err != nil { 550 return err 551 } 552 // The nil case means that an inline buffer has 553 // completed. The callback will have already removed 554 // the inline buffer from the map, so we just return an 555 // error to check the top of the loop again. 556 return errNewBuffer 557 }); err != errNewBuffer { 558 // Some other error occurred; see above. 559 return done, err 560 } 561 } 562 563 // Need to return done here, since it may have been adjusted by the 564 // callback to compensation for partial reads on some inline buffer. 565 return done, nil 566 } 567 568 // Writer is a compressed writer. 569 type Writer struct { 570 pool 571 572 // out is the underlying writer. 573 out io.Writer 574 575 // closed indicates whether the file has been closed. 576 closed bool 577 578 // scratch is a temporary buffer used for marshalling. This is declared 579 // unfront here to avoid reallocation. 580 scratch [4]byte 581 } 582 583 var _ io.Writer = (*Writer)(nil) 584 585 // NewWriter returns a new compressed writer. If key is non-nil, hash values are 586 // generated and written out for compressed bytes. See package comments for 587 // details. 588 // 589 // The recommended chunkSize is on the order of 1M. Extra memory may be 590 // buffered (in the form of read-ahead, or buffered writes), and is limited to 591 // O(chunkSize * [1+GOMAXPROCS]). 592 func NewWriter(out io.Writer, key []byte, chunkSize uint32, level int) (*Writer, error) { 593 w := &Writer{ 594 pool: pool{ 595 chunkSize: chunkSize, 596 buf: bufPool.Get().(*bytes.Buffer), 597 }, 598 out: out, 599 } 600 w.init(key, 1+runtime.GOMAXPROCS(0), true, level) 601 602 binary.BigEndian.PutUint32(w.scratch[:], chunkSize) 603 if _, err := w.out.Write(w.scratch[:4]); err != nil { 604 return nil, err 605 } 606 607 if w.hashPool != nil { 608 h := w.hashPool.getHash() 609 binary.BigEndian.PutUint32(w.scratch[:], chunkSize) 610 h.Write(w.scratch[:4]) 611 w.lastSum = h.Sum(nil) 612 w.hashPool.putHash(h) 613 if _, err := io.CopyN(w.out, bytes.NewReader(w.lastSum), int64(len(w.lastSum))); err != nil { 614 return nil, err 615 } 616 } 617 618 return w, nil 619 } 620 621 // flush writes a single buffer. 622 func (w *Writer) flush(c *chunk) error { 623 // Prefix each chunk with a length; this allows the reader to safely 624 // limit reads while buffering. 625 l := uint32(c.compressed.Len()) 626 627 binary.BigEndian.PutUint32(w.scratch[:], l) 628 if _, err := w.out.Write(w.scratch[:4]); err != nil { 629 return err 630 } 631 632 // Write out to the stream. 633 if _, err := io.CopyN(w.out, c.compressed, int64(c.compressed.Len())); err != nil { 634 return err 635 } 636 637 if w.hashPool != nil { 638 io.CopyN(c.h, bytes.NewReader(w.lastSum), int64(len(w.lastSum))) 639 sum := c.h.Sum(nil) 640 w.hashPool.putHash(c.h) 641 c.h = nil 642 if _, err := io.CopyN(w.out, bytes.NewReader(sum), int64(len(sum))); err != nil { 643 return err 644 } 645 w.lastSum = sum 646 } 647 648 return nil 649 } 650 651 // Write implements io.Writer.Write. 652 func (w *Writer) Write(p []byte) (int, error) { 653 w.mu.Lock() 654 defer w.mu.Unlock() 655 656 // Did we close already? 657 if w.closed { 658 return 0, io.ErrUnexpectedEOF 659 } 660 661 // See above; we need to track in the same way. 662 var ( 663 pendingPre = w.nextInput - w.nextOutput 664 pendingInline = 0 665 ) 666 callback := func(c *chunk) error { 667 if pendingPre > 0 { 668 pendingPre-- 669 err := w.flush(c) 670 c.uncompressed.Reset() 671 bufPool.Put(c.uncompressed) 672 return err 673 } 674 if pendingInline > 0 { 675 pendingInline-- 676 return w.flush(c) 677 } 678 panic("both pendingPre and pendingInline exhausted") 679 } 680 681 for done := 0; done < len(p); { 682 // Construct an inline buffer if we're doing an inline 683 // encoding; see above regarding the bytes.MinRead constraint. 684 inline := false 685 if w.buf.Len() == 0 && len(p) >= done+int(w.chunkSize) && len(p) >= done+bytes.MinRead { 686 bufPool.Put(w.buf) // Return to the pool; never scheduled. 687 w.buf = bytes.NewBuffer(p[done : done+int(w.chunkSize)]) 688 done += int(w.chunkSize) 689 pendingInline++ 690 inline = true 691 } 692 693 // Do we need to flush w.buf? Note that this case should be hit 694 // immediately following the inline case above. 695 left := int(w.chunkSize) - w.buf.Len() 696 if left == 0 { 697 if err := w.schedule(newChunk(nil, nil, nil, w.buf), callback); err != nil { 698 return done, err 699 } 700 if !inline { 701 pendingPre++ 702 } 703 // Reset the buffer, since this has now been scheduled 704 // for compression. Note that this may be trampled 705 // immediately by the bufPool.Put(w.buf) above if the 706 // next buffer happens to be inline, but that's okay. 707 w.buf = bufPool.Get().(*bytes.Buffer) 708 continue 709 } 710 711 // Read from p into w.buf. 712 toWrite := len(p) - done 713 if toWrite > left { 714 toWrite = left 715 } 716 n, err := w.buf.Write(p[done : done+toWrite]) 717 done += n 718 if err != nil { 719 return done, err 720 } 721 } 722 723 // Make sure that everything has been flushed, we can't return until 724 // all the contents from p have been used. 725 for pendingInline > 0 { 726 if err := w.schedule(nil, func(c *chunk) error { 727 if err := callback(c); err != nil { 728 return err 729 } 730 // The flush was successful, return errNewBuffer here 731 // to break from the loop and check the condition 732 // again. 733 return errNewBuffer 734 }); err != errNewBuffer { 735 return len(p), err 736 } 737 } 738 739 return len(p), nil 740 } 741 742 // Close implements io.Closer.Close. 743 func (w *Writer) Close() error { 744 w.mu.Lock() 745 defer w.mu.Unlock() 746 747 // Did we already close? After the call to Close, we always mark as 748 // closed, regardless of whether the flush is successful. 749 if w.closed { 750 return io.ErrUnexpectedEOF 751 } 752 w.closed = true 753 defer w.stop() 754 755 // Schedule any remaining partial buffer; we pass w.flush directly here 756 // because the final buffer is guaranteed to not be an inline buffer. 757 if w.buf.Len() > 0 { 758 if err := w.schedule(newChunk(nil, nil, nil, w.buf), w.flush); err != nil { 759 return err 760 } 761 } 762 763 // Flush all scheduled buffers; see above. 764 if err := w.schedule(nil, w.flush); err != nil { 765 return err 766 } 767 768 // Close the underlying writer (if necessary). 769 if closer, ok := w.out.(io.Closer); ok { 770 return closer.Close() 771 } 772 return nil 773 }