github.com/MetalBlockchain/metalgo@v1.11.9/utils/compression/compressor_test.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package compression
     5  
     6  import (
     7  	"fmt"
     8  	"math"
     9  	"runtime"
    10  	"testing"
    11  
    12  	"github.com/stretchr/testify/require"
    13  
    14  	_ "embed"
    15  
    16  	"github.com/MetalBlockchain/metalgo/utils"
    17  	"github.com/MetalBlockchain/metalgo/utils/units"
    18  )
    19  
    20  const maxMessageSize = 2 * units.MiB // Max message size. Can't import due to cycle.
    21  
    22  var (
    23  	newCompressorFuncs = map[Type]func(maxSize int64) (Compressor, error){
    24  		TypeNone: func(int64) (Compressor, error) { //nolint:unparam // an error is needed to be returned to compile
    25  			return NewNoCompressor(), nil
    26  		},
    27  		TypeZstd: NewZstdCompressor,
    28  	}
    29  
    30  	//go:embed zstd_zip_bomb.bin
    31  	zstdZipBomb []byte
    32  
    33  	zipBombs = map[Type][]byte{
    34  		TypeZstd: zstdZipBomb,
    35  	}
    36  )
    37  
    38  func TestDecompressZipBombs(t *testing.T) {
    39  	for compressionType, zipBomb := range zipBombs {
    40  		// Make sure that the hardcoded zip bomb would be a valid message.
    41  		require.Less(t, len(zipBomb), maxMessageSize)
    42  
    43  		newCompressorFunc := newCompressorFuncs[compressionType]
    44  
    45  		t.Run(compressionType.String(), func(t *testing.T) {
    46  			require := require.New(t)
    47  
    48  			compressor, err := newCompressorFunc(maxMessageSize)
    49  			require.NoError(err)
    50  
    51  			var (
    52  				beforeDecompressionStats runtime.MemStats
    53  				afterDecompressionStats  runtime.MemStats
    54  			)
    55  			runtime.ReadMemStats(&beforeDecompressionStats)
    56  			_, err = compressor.Decompress(zipBomb)
    57  			runtime.ReadMemStats(&afterDecompressionStats)
    58  
    59  			require.ErrorIs(err, ErrDecompressedMsgTooLarge)
    60  
    61  			// Make sure that we didn't allocate significantly more memory than
    62  			// the max message size.
    63  			bytesAllocatedDuringDecompression := afterDecompressionStats.TotalAlloc - beforeDecompressionStats.TotalAlloc
    64  			require.Less(bytesAllocatedDuringDecompression, uint64(10*maxMessageSize))
    65  		})
    66  	}
    67  }
    68  
    69  func TestCompressDecompress(t *testing.T) {
    70  	for compressionType, newCompressorFunc := range newCompressorFuncs {
    71  		t.Run(compressionType.String(), func(t *testing.T) {
    72  			require := require.New(t)
    73  
    74  			data := utils.RandomBytes(4096)
    75  			data2 := utils.RandomBytes(4096)
    76  
    77  			compressor, err := newCompressorFunc(maxMessageSize)
    78  			require.NoError(err)
    79  
    80  			dataCompressed, err := compressor.Compress(data)
    81  			require.NoError(err)
    82  
    83  			data2Compressed, err := compressor.Compress(data2)
    84  			require.NoError(err)
    85  
    86  			dataDecompressed, err := compressor.Decompress(dataCompressed)
    87  			require.NoError(err)
    88  			require.Equal(data, dataDecompressed)
    89  
    90  			data2Decompressed, err := compressor.Decompress(data2Compressed)
    91  			require.NoError(err)
    92  			require.Equal(data2, data2Decompressed)
    93  
    94  			dataDecompressed, err = compressor.Decompress(dataCompressed)
    95  			require.NoError(err)
    96  			require.Equal(data, dataDecompressed)
    97  
    98  			maxMessage := utils.RandomBytes(maxMessageSize)
    99  			maxMessageCompressed, err := compressor.Compress(maxMessage)
   100  			require.NoError(err)
   101  
   102  			maxMessageDecompressed, err := compressor.Decompress(maxMessageCompressed)
   103  			require.NoError(err)
   104  
   105  			require.Equal(maxMessage, maxMessageDecompressed)
   106  		})
   107  	}
   108  }
   109  
   110  func TestSizeLimiting(t *testing.T) {
   111  	for compressionType, compressorFunc := range newCompressorFuncs {
   112  		if compressionType == TypeNone {
   113  			continue
   114  		}
   115  		t.Run(compressionType.String(), func(t *testing.T) {
   116  			require := require.New(t)
   117  
   118  			compressor, err := compressorFunc(maxMessageSize)
   119  			require.NoError(err)
   120  
   121  			data := make([]byte, maxMessageSize+1)
   122  			_, err = compressor.Compress(data) // should be too large
   123  			require.ErrorIs(err, ErrMsgTooLarge)
   124  
   125  			compressor2, err := compressorFunc(2 * maxMessageSize)
   126  			require.NoError(err)
   127  
   128  			dataCompressed, err := compressor2.Compress(data)
   129  			require.NoError(err)
   130  
   131  			_, err = compressor.Decompress(dataCompressed) // should be too large
   132  			require.ErrorIs(err, ErrDecompressedMsgTooLarge)
   133  		})
   134  	}
   135  }
   136  
   137  // Attempts to create a compressor with math.MaxInt64
   138  // which leads to undefined decompress behavior due to integer overflow
   139  // in limit reader creation.
   140  func TestNewCompressorWithInvalidLimit(t *testing.T) {
   141  	for compressionType, compressorFunc := range newCompressorFuncs {
   142  		if compressionType == TypeNone {
   143  			continue
   144  		}
   145  		t.Run(compressionType.String(), func(t *testing.T) {
   146  			_, err := compressorFunc(math.MaxInt64)
   147  			require.ErrorIs(t, err, ErrInvalidMaxSizeCompressor)
   148  		})
   149  	}
   150  }
   151  
   152  func FuzzZstdCompressor(f *testing.F) {
   153  	fuzzHelper(f, TypeZstd)
   154  }
   155  
   156  func fuzzHelper(f *testing.F, compressionType Type) {
   157  	var (
   158  		compressor Compressor
   159  		err        error
   160  	)
   161  	switch compressionType {
   162  	case TypeZstd:
   163  		compressor, err = NewZstdCompressor(maxMessageSize)
   164  		require.NoError(f, err)
   165  	default:
   166  		require.FailNow(f, "Unknown compression type")
   167  	}
   168  
   169  	f.Fuzz(func(t *testing.T, data []byte) {
   170  		require := require.New(t)
   171  
   172  		if len(data) > maxMessageSize {
   173  			_, err := compressor.Compress(data)
   174  			require.ErrorIs(err, ErrMsgTooLarge)
   175  		}
   176  
   177  		compressed, err := compressor.Compress(data)
   178  		require.NoError(err)
   179  
   180  		decompressed, err := compressor.Decompress(compressed)
   181  		require.NoError(err)
   182  
   183  		require.Equal(data, decompressed)
   184  	})
   185  }
   186  
   187  func BenchmarkCompress(b *testing.B) {
   188  	sizes := []int{
   189  		0,
   190  		256,
   191  		units.KiB,
   192  		units.MiB,
   193  		maxMessageSize,
   194  	}
   195  	for compressionType, newCompressorFunc := range newCompressorFuncs {
   196  		if compressionType == TypeNone {
   197  			continue
   198  		}
   199  		for _, size := range sizes {
   200  			b.Run(fmt.Sprintf("%s_%d", compressionType, size), func(b *testing.B) {
   201  				require := require.New(b)
   202  
   203  				bytes := utils.RandomBytes(size)
   204  				compressor, err := newCompressorFunc(maxMessageSize)
   205  				require.NoError(err)
   206  				for n := 0; n < b.N; n++ {
   207  					_, err := compressor.Compress(bytes)
   208  					require.NoError(err)
   209  				}
   210  			})
   211  		}
   212  	}
   213  }
   214  
   215  func BenchmarkDecompress(b *testing.B) {
   216  	sizes := []int{
   217  		0,
   218  		256,
   219  		units.KiB,
   220  		units.MiB,
   221  		maxMessageSize,
   222  	}
   223  	for compressionType, newCompressorFunc := range newCompressorFuncs {
   224  		if compressionType == TypeNone {
   225  			continue
   226  		}
   227  		for _, size := range sizes {
   228  			b.Run(fmt.Sprintf("%s_%d", compressionType, size), func(b *testing.B) {
   229  				require := require.New(b)
   230  
   231  				bytes := utils.RandomBytes(size)
   232  				compressor, err := newCompressorFunc(maxMessageSize)
   233  				require.NoError(err)
   234  
   235  				compressedBytes, err := compressor.Compress(bytes)
   236  				require.NoError(err)
   237  
   238  				for n := 0; n < b.N; n++ {
   239  					_, err := compressor.Decompress(compressedBytes)
   240  					require.NoError(err)
   241  				}
   242  			})
   243  		}
   244  	}
   245  }