github.com/prysmaticlabs/prysm@v1.4.4/beacon-chain/p2p/encoder/ssz_test.go (about) 1 package encoder_test 2 3 import ( 4 "bytes" 5 "encoding/binary" 6 "fmt" 7 "io" 8 "math" 9 "testing" 10 11 gogo "github.com/gogo/protobuf/proto" 12 "github.com/prysmaticlabs/prysm/beacon-chain/p2p/encoder" 13 pb "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1" 14 "github.com/prysmaticlabs/prysm/shared/params" 15 "github.com/prysmaticlabs/prysm/shared/testutil" 16 "github.com/prysmaticlabs/prysm/shared/testutil/assert" 17 "github.com/prysmaticlabs/prysm/shared/testutil/require" 18 "google.golang.org/protobuf/proto" 19 ) 20 21 func TestSszNetworkEncoder_RoundTrip(t *testing.T) { 22 e := &encoder.SszNetworkEncoder{} 23 testRoundTripWithLength(t, e) 24 testRoundTripWithGossip(t, e) 25 } 26 27 func TestSszNetworkEncoder_FailsSnappyLength(t *testing.T) { 28 e := &encoder.SszNetworkEncoder{} 29 att := &pb.Fork{} 30 data := make([]byte, 32) 31 binary.PutUvarint(data, encoder.MaxGossipSize+32) 32 err := e.DecodeGossip(data, att) 33 require.ErrorContains(t, "snappy message exceeds max size", err) 34 } 35 36 func testRoundTripWithLength(t *testing.T, e *encoder.SszNetworkEncoder) { 37 buf := new(bytes.Buffer) 38 msg := &pb.Fork{ 39 PreviousVersion: []byte("fooo"), 40 CurrentVersion: []byte("barr"), 41 Epoch: 9001, 42 } 43 _, err := e.EncodeWithMaxLength(buf, msg) 44 require.NoError(t, err) 45 decoded := &pb.Fork{} 46 require.NoError(t, e.DecodeWithMaxLength(buf, decoded)) 47 if !proto.Equal(decoded, msg) { 48 t.Logf("decoded=%+v\n", decoded) 49 t.Error("Decoded message is not the same as original") 50 } 51 } 52 53 func testRoundTripWithGossip(t *testing.T, e *encoder.SszNetworkEncoder) { 54 buf := new(bytes.Buffer) 55 msg := &pb.Fork{ 56 PreviousVersion: []byte("fooo"), 57 CurrentVersion: []byte("barr"), 58 Epoch: 9001, 59 } 60 _, err := e.EncodeGossip(buf, msg) 61 require.NoError(t, err) 62 decoded := &pb.Fork{} 63 require.NoError(t, e.DecodeGossip(buf.Bytes(), decoded)) 64 if !proto.Equal(decoded, msg) { 65 t.Logf("decoded=%+v\n", decoded) 66 t.Error("Decoded message is not the same as original") 67 } 68 } 69 70 func TestSszNetworkEncoder_EncodeWithMaxLength(t *testing.T) { 71 buf := new(bytes.Buffer) 72 msg := &pb.Fork{ 73 PreviousVersion: []byte("fooo"), 74 CurrentVersion: []byte("barr"), 75 Epoch: 9001, 76 } 77 e := &encoder.SszNetworkEncoder{} 78 params.SetupTestConfigCleanup(t) 79 c := params.BeaconNetworkConfig() 80 c.MaxChunkSize = uint64(5) 81 params.OverrideBeaconNetworkConfig(c) 82 _, err := e.EncodeWithMaxLength(buf, msg) 83 wanted := fmt.Sprintf("which is larger than the provided max limit of %d", params.BeaconNetworkConfig().MaxChunkSize) 84 assert.ErrorContains(t, wanted, err) 85 } 86 87 func TestSszNetworkEncoder_DecodeWithMaxLength(t *testing.T) { 88 buf := new(bytes.Buffer) 89 msg := &pb.Fork{ 90 PreviousVersion: []byte("fooo"), 91 CurrentVersion: []byte("barr"), 92 Epoch: 4242, 93 } 94 e := &encoder.SszNetworkEncoder{} 95 params.SetupTestConfigCleanup(t) 96 c := params.BeaconNetworkConfig() 97 maxChunkSize := uint64(5) 98 c.MaxChunkSize = maxChunkSize 99 params.OverrideBeaconNetworkConfig(c) 100 _, err := e.EncodeGossip(buf, msg) 101 require.NoError(t, err) 102 decoded := &pb.Fork{} 103 err = e.DecodeWithMaxLength(buf, decoded) 104 wanted := fmt.Sprintf("goes over the provided max limit of %d", maxChunkSize) 105 assert.ErrorContains(t, wanted, err) 106 } 107 108 func TestSszNetworkEncoder_DecodeWithMultipleFrames(t *testing.T) { 109 buf := new(bytes.Buffer) 110 st, _ := testutil.DeterministicGenesisState(t, 100) 111 e := &encoder.SszNetworkEncoder{} 112 params.SetupTestConfigCleanup(t) 113 c := params.BeaconNetworkConfig() 114 // 4 * 1 Mib 115 maxChunkSize := uint64(1 << 22) 116 c.MaxChunkSize = maxChunkSize 117 params.OverrideBeaconNetworkConfig(c) 118 _, err := e.EncodeWithMaxLength(buf, st.InnerStateUnsafe()) 119 require.NoError(t, err) 120 // Max snappy block size 121 if buf.Len() <= 76490 { 122 t.Errorf("buffer smaller than expected, wanted > %d but got %d", 76490, buf.Len()) 123 } 124 decoded := new(pb.BeaconState) 125 err = e.DecodeWithMaxLength(buf, decoded) 126 assert.NoError(t, err) 127 } 128 func TestSszNetworkEncoder_NegativeMaxLength(t *testing.T) { 129 e := &encoder.SszNetworkEncoder{} 130 length, err := e.MaxLength(0xfffffffffff) 131 132 assert.Equal(t, 0, length, "Received non zero length on bad message length") 133 assert.ErrorContains(t, "max encoded length is negative", err) 134 } 135 136 func TestSszNetworkEncoder_MaxInt64(t *testing.T) { 137 e := &encoder.SszNetworkEncoder{} 138 length, err := e.MaxLength(math.MaxInt64 + 1) 139 140 assert.Equal(t, 0, length, "Received non zero length on bad message length") 141 assert.ErrorContains(t, "invalid length provided", err) 142 } 143 144 func TestSszNetworkEncoder_DecodeWithBadSnappyStream(t *testing.T) { 145 st := newBadSnappyStream() 146 e := &encoder.SszNetworkEncoder{} 147 decoded := new(pb.Fork) 148 err := e.DecodeWithMaxLength(st, decoded) 149 assert.ErrorContains(t, io.EOF.Error(), err) 150 } 151 152 type badSnappyStream struct { 153 varint []byte 154 header []byte 155 repeat []byte 156 i int 157 // count how many times it was read 158 counter int 159 // count bytes read so far 160 total int 161 } 162 163 func newBadSnappyStream() *badSnappyStream { 164 const ( 165 magicBody = "sNaPpY" 166 magicChunk = "\xff\x06\x00\x00" + magicBody 167 ) 168 169 header := make([]byte, len(magicChunk)) 170 // magicChunk == chunkTypeStreamIdentifier byte ++ 3 byte little endian len(magic body) ++ 6 byte magic body 171 172 // header is a special chunk type, with small fixed length, to add some magic to claim it's really snappy. 173 copy(header, magicChunk) // snappy library constants help us construct the common header chunk easily. 174 175 payload := make([]byte, 4) 176 177 // byte 0 is chunk type 178 // Exploit any fancy ignored chunk type 179 // Section 4.4 Padding (chunk type 0xfe). 180 // Section 4.6. Reserved skippable chunks (chunk types 0x80-0xfd). 181 payload[0] = 0xfe 182 183 // byte 1,2,3 are chunk length (little endian) 184 payload[1] = 0 185 payload[2] = 0 186 payload[3] = 0 187 188 return &badSnappyStream{ 189 varint: gogo.EncodeVarint(1000), 190 header: header, 191 repeat: payload, 192 i: 0, 193 counter: 0, 194 total: 0, 195 } 196 } 197 198 func (b *badSnappyStream) Read(p []byte) (n int, err error) { 199 // Stream out varint bytes first to make test happy. 200 if len(b.varint) > 0 { 201 copy(p, b.varint[:1]) 202 b.varint = b.varint[1:] 203 return 1, nil 204 } 205 defer func() { 206 b.counter += 1 207 b.total += n 208 }() 209 if len(b.repeat) == 0 { 210 panic("no bytes to repeat") 211 } 212 if len(b.header) > 0 { 213 n = copy(p, b.header) 214 b.header = b.header[n:] 215 return 216 } 217 for n < len(p) { 218 n += copy(p[n:], b.repeat[b.i:]) 219 b.i = (b.i + n) % len(b.repeat) 220 } 221 return 222 }