github.com/MetalBlockchain/metalgo@v1.11.9/utils/compression/zstd_compressor.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package compression
     5  
     6  import (
     7  	"bytes"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  	"math"
    12  
    13  	"github.com/DataDog/zstd"
    14  )
    15  
    16  var (
    17  	_ Compressor = (*zstdCompressor)(nil)
    18  
    19  	ErrInvalidMaxSizeCompressor = errors.New("invalid compressor max size")
    20  	ErrDecompressedMsgTooLarge  = errors.New("decompressed msg too large")
    21  	ErrMsgTooLarge              = errors.New("msg too large to be compressed")
    22  )
    23  
    24  func NewZstdCompressor(maxSize int64) (Compressor, error) {
    25  	if maxSize == math.MaxInt64 {
    26  		// "Decompress" creates "io.LimitReader" with max size + 1:
    27  		// if the max size + 1 overflows, "io.LimitReader" reads nothing
    28  		// returning 0 byte for the decompress call
    29  		// require max size < math.MaxInt64 to prevent int64 overflows
    30  		return nil, ErrInvalidMaxSizeCompressor
    31  	}
    32  
    33  	return &zstdCompressor{
    34  		maxSize: maxSize,
    35  	}, nil
    36  }
    37  
    38  type zstdCompressor struct {
    39  	maxSize int64
    40  }
    41  
    42  func (z *zstdCompressor) Compress(msg []byte) ([]byte, error) {
    43  	if int64(len(msg)) > z.maxSize {
    44  		return nil, fmt.Errorf("%w: (%d) > (%d)", ErrMsgTooLarge, len(msg), z.maxSize)
    45  	}
    46  	return zstd.Compress(nil, msg)
    47  }
    48  
    49  func (z *zstdCompressor) Decompress(msg []byte) ([]byte, error) {
    50  	reader := zstd.NewReader(bytes.NewReader(msg))
    51  	defer reader.Close()
    52  
    53  	// We allow [io.LimitReader] to read up to [z.maxSize + 1] bytes, so that if
    54  	// the decompressed payload is greater than the maximum size, this function
    55  	// will return the appropriate error instead of an incomplete byte slice.
    56  	limitReader := io.LimitReader(reader, z.maxSize+1)
    57  	decompressed, err := io.ReadAll(limitReader)
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  	if int64(len(decompressed)) > z.maxSize {
    62  		return nil, fmt.Errorf("%w: (%d) > (%d)", ErrDecompressedMsgTooLarge, len(decompressed), z.maxSize)
    63  	}
    64  	return decompressed, nil
    65  }