github.com/MetalBlockchain/metalgo@v1.11.9/utils/compression/compressor_test.go (about) 1 // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. 2 // See the file LICENSE for licensing terms. 3 4 package compression 5 6 import ( 7 "fmt" 8 "math" 9 "runtime" 10 "testing" 11 12 "github.com/stretchr/testify/require" 13 14 _ "embed" 15 16 "github.com/MetalBlockchain/metalgo/utils" 17 "github.com/MetalBlockchain/metalgo/utils/units" 18 ) 19 20 const maxMessageSize = 2 * units.MiB // Max message size. Can't import due to cycle. 21 22 var ( 23 newCompressorFuncs = map[Type]func(maxSize int64) (Compressor, error){ 24 TypeNone: func(int64) (Compressor, error) { //nolint:unparam // an error is needed to be returned to compile 25 return NewNoCompressor(), nil 26 }, 27 TypeZstd: NewZstdCompressor, 28 } 29 30 //go:embed zstd_zip_bomb.bin 31 zstdZipBomb []byte 32 33 zipBombs = map[Type][]byte{ 34 TypeZstd: zstdZipBomb, 35 } 36 ) 37 38 func TestDecompressZipBombs(t *testing.T) { 39 for compressionType, zipBomb := range zipBombs { 40 // Make sure that the hardcoded zip bomb would be a valid message. 41 require.Less(t, len(zipBomb), maxMessageSize) 42 43 newCompressorFunc := newCompressorFuncs[compressionType] 44 45 t.Run(compressionType.String(), func(t *testing.T) { 46 require := require.New(t) 47 48 compressor, err := newCompressorFunc(maxMessageSize) 49 require.NoError(err) 50 51 var ( 52 beforeDecompressionStats runtime.MemStats 53 afterDecompressionStats runtime.MemStats 54 ) 55 runtime.ReadMemStats(&beforeDecompressionStats) 56 _, err = compressor.Decompress(zipBomb) 57 runtime.ReadMemStats(&afterDecompressionStats) 58 59 require.ErrorIs(err, ErrDecompressedMsgTooLarge) 60 61 // Make sure that we didn't allocate significantly more memory than 62 // the max message size. 63 bytesAllocatedDuringDecompression := afterDecompressionStats.TotalAlloc - beforeDecompressionStats.TotalAlloc 64 require.Less(bytesAllocatedDuringDecompression, uint64(10*maxMessageSize)) 65 }) 66 } 67 } 68 69 func TestCompressDecompress(t *testing.T) { 70 for compressionType, newCompressorFunc := range newCompressorFuncs { 71 t.Run(compressionType.String(), func(t *testing.T) { 72 require := require.New(t) 73 74 data := utils.RandomBytes(4096) 75 data2 := utils.RandomBytes(4096) 76 77 compressor, err := newCompressorFunc(maxMessageSize) 78 require.NoError(err) 79 80 dataCompressed, err := compressor.Compress(data) 81 require.NoError(err) 82 83 data2Compressed, err := compressor.Compress(data2) 84 require.NoError(err) 85 86 dataDecompressed, err := compressor.Decompress(dataCompressed) 87 require.NoError(err) 88 require.Equal(data, dataDecompressed) 89 90 data2Decompressed, err := compressor.Decompress(data2Compressed) 91 require.NoError(err) 92 require.Equal(data2, data2Decompressed) 93 94 dataDecompressed, err = compressor.Decompress(dataCompressed) 95 require.NoError(err) 96 require.Equal(data, dataDecompressed) 97 98 maxMessage := utils.RandomBytes(maxMessageSize) 99 maxMessageCompressed, err := compressor.Compress(maxMessage) 100 require.NoError(err) 101 102 maxMessageDecompressed, err := compressor.Decompress(maxMessageCompressed) 103 require.NoError(err) 104 105 require.Equal(maxMessage, maxMessageDecompressed) 106 }) 107 } 108 } 109 110 func TestSizeLimiting(t *testing.T) { 111 for compressionType, compressorFunc := range newCompressorFuncs { 112 if compressionType == TypeNone { 113 continue 114 } 115 t.Run(compressionType.String(), func(t *testing.T) { 116 require := require.New(t) 117 118 compressor, err := compressorFunc(maxMessageSize) 119 require.NoError(err) 120 121 data := make([]byte, maxMessageSize+1) 122 _, err = compressor.Compress(data) // should be too large 123 require.ErrorIs(err, ErrMsgTooLarge) 124 125 compressor2, err := compressorFunc(2 * maxMessageSize) 126 require.NoError(err) 127 128 dataCompressed, err := compressor2.Compress(data) 129 require.NoError(err) 130 131 _, err = compressor.Decompress(dataCompressed) // should be too large 132 require.ErrorIs(err, ErrDecompressedMsgTooLarge) 133 }) 134 } 135 } 136 137 // Attempts to create a compressor with math.MaxInt64 138 // which leads to undefined decompress behavior due to integer overflow 139 // in limit reader creation. 140 func TestNewCompressorWithInvalidLimit(t *testing.T) { 141 for compressionType, compressorFunc := range newCompressorFuncs { 142 if compressionType == TypeNone { 143 continue 144 } 145 t.Run(compressionType.String(), func(t *testing.T) { 146 _, err := compressorFunc(math.MaxInt64) 147 require.ErrorIs(t, err, ErrInvalidMaxSizeCompressor) 148 }) 149 } 150 } 151 152 func FuzzZstdCompressor(f *testing.F) { 153 fuzzHelper(f, TypeZstd) 154 } 155 156 func fuzzHelper(f *testing.F, compressionType Type) { 157 var ( 158 compressor Compressor 159 err error 160 ) 161 switch compressionType { 162 case TypeZstd: 163 compressor, err = NewZstdCompressor(maxMessageSize) 164 require.NoError(f, err) 165 default: 166 require.FailNow(f, "Unknown compression type") 167 } 168 169 f.Fuzz(func(t *testing.T, data []byte) { 170 require := require.New(t) 171 172 if len(data) > maxMessageSize { 173 _, err := compressor.Compress(data) 174 require.ErrorIs(err, ErrMsgTooLarge) 175 } 176 177 compressed, err := compressor.Compress(data) 178 require.NoError(err) 179 180 decompressed, err := compressor.Decompress(compressed) 181 require.NoError(err) 182 183 require.Equal(data, decompressed) 184 }) 185 } 186 187 func BenchmarkCompress(b *testing.B) { 188 sizes := []int{ 189 0, 190 256, 191 units.KiB, 192 units.MiB, 193 maxMessageSize, 194 } 195 for compressionType, newCompressorFunc := range newCompressorFuncs { 196 if compressionType == TypeNone { 197 continue 198 } 199 for _, size := range sizes { 200 b.Run(fmt.Sprintf("%s_%d", compressionType, size), func(b *testing.B) { 201 require := require.New(b) 202 203 bytes := utils.RandomBytes(size) 204 compressor, err := newCompressorFunc(maxMessageSize) 205 require.NoError(err) 206 for n := 0; n < b.N; n++ { 207 _, err := compressor.Compress(bytes) 208 require.NoError(err) 209 } 210 }) 211 } 212 } 213 } 214 215 func BenchmarkDecompress(b *testing.B) { 216 sizes := []int{ 217 0, 218 256, 219 units.KiB, 220 units.MiB, 221 maxMessageSize, 222 } 223 for compressionType, newCompressorFunc := range newCompressorFuncs { 224 if compressionType == TypeNone { 225 continue 226 } 227 for _, size := range sizes { 228 b.Run(fmt.Sprintf("%s_%d", compressionType, size), func(b *testing.B) { 229 require := require.New(b) 230 231 bytes := utils.RandomBytes(size) 232 compressor, err := newCompressorFunc(maxMessageSize) 233 require.NoError(err) 234 235 compressedBytes, err := compressor.Compress(bytes) 236 require.NoError(err) 237 238 for n := 0; n < b.N; n++ { 239 _, err := compressor.Decompress(compressedBytes) 240 require.NoError(err) 241 } 242 }) 243 } 244 } 245 }