github.com/prysmaticlabs/prysm@v1.4.4/beacon-chain/p2p/encoder/ssz_test.go (about)

     1  package encoder_test
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"io"
     8  	"math"
     9  	"testing"
    10  
    11  	gogo "github.com/gogo/protobuf/proto"
    12  	"github.com/prysmaticlabs/prysm/beacon-chain/p2p/encoder"
    13  	pb "github.com/prysmaticlabs/prysm/proto/beacon/p2p/v1"
    14  	"github.com/prysmaticlabs/prysm/shared/params"
    15  	"github.com/prysmaticlabs/prysm/shared/testutil"
    16  	"github.com/prysmaticlabs/prysm/shared/testutil/assert"
    17  	"github.com/prysmaticlabs/prysm/shared/testutil/require"
    18  	"google.golang.org/protobuf/proto"
    19  )
    20  
    21  func TestSszNetworkEncoder_RoundTrip(t *testing.T) {
    22  	e := &encoder.SszNetworkEncoder{}
    23  	testRoundTripWithLength(t, e)
    24  	testRoundTripWithGossip(t, e)
    25  }
    26  
    27  func TestSszNetworkEncoder_FailsSnappyLength(t *testing.T) {
    28  	e := &encoder.SszNetworkEncoder{}
    29  	att := &pb.Fork{}
    30  	data := make([]byte, 32)
    31  	binary.PutUvarint(data, encoder.MaxGossipSize+32)
    32  	err := e.DecodeGossip(data, att)
    33  	require.ErrorContains(t, "snappy message exceeds max size", err)
    34  }
    35  
    36  func testRoundTripWithLength(t *testing.T, e *encoder.SszNetworkEncoder) {
    37  	buf := new(bytes.Buffer)
    38  	msg := &pb.Fork{
    39  		PreviousVersion: []byte("fooo"),
    40  		CurrentVersion:  []byte("barr"),
    41  		Epoch:           9001,
    42  	}
    43  	_, err := e.EncodeWithMaxLength(buf, msg)
    44  	require.NoError(t, err)
    45  	decoded := &pb.Fork{}
    46  	require.NoError(t, e.DecodeWithMaxLength(buf, decoded))
    47  	if !proto.Equal(decoded, msg) {
    48  		t.Logf("decoded=%+v\n", decoded)
    49  		t.Error("Decoded message is not the same as original")
    50  	}
    51  }
    52  
    53  func testRoundTripWithGossip(t *testing.T, e *encoder.SszNetworkEncoder) {
    54  	buf := new(bytes.Buffer)
    55  	msg := &pb.Fork{
    56  		PreviousVersion: []byte("fooo"),
    57  		CurrentVersion:  []byte("barr"),
    58  		Epoch:           9001,
    59  	}
    60  	_, err := e.EncodeGossip(buf, msg)
    61  	require.NoError(t, err)
    62  	decoded := &pb.Fork{}
    63  	require.NoError(t, e.DecodeGossip(buf.Bytes(), decoded))
    64  	if !proto.Equal(decoded, msg) {
    65  		t.Logf("decoded=%+v\n", decoded)
    66  		t.Error("Decoded message is not the same as original")
    67  	}
    68  }
    69  
    70  func TestSszNetworkEncoder_EncodeWithMaxLength(t *testing.T) {
    71  	buf := new(bytes.Buffer)
    72  	msg := &pb.Fork{
    73  		PreviousVersion: []byte("fooo"),
    74  		CurrentVersion:  []byte("barr"),
    75  		Epoch:           9001,
    76  	}
    77  	e := &encoder.SszNetworkEncoder{}
    78  	params.SetupTestConfigCleanup(t)
    79  	c := params.BeaconNetworkConfig()
    80  	c.MaxChunkSize = uint64(5)
    81  	params.OverrideBeaconNetworkConfig(c)
    82  	_, err := e.EncodeWithMaxLength(buf, msg)
    83  	wanted := fmt.Sprintf("which is larger than the provided max limit of %d", params.BeaconNetworkConfig().MaxChunkSize)
    84  	assert.ErrorContains(t, wanted, err)
    85  }
    86  
    87  func TestSszNetworkEncoder_DecodeWithMaxLength(t *testing.T) {
    88  	buf := new(bytes.Buffer)
    89  	msg := &pb.Fork{
    90  		PreviousVersion: []byte("fooo"),
    91  		CurrentVersion:  []byte("barr"),
    92  		Epoch:           4242,
    93  	}
    94  	e := &encoder.SszNetworkEncoder{}
    95  	params.SetupTestConfigCleanup(t)
    96  	c := params.BeaconNetworkConfig()
    97  	maxChunkSize := uint64(5)
    98  	c.MaxChunkSize = maxChunkSize
    99  	params.OverrideBeaconNetworkConfig(c)
   100  	_, err := e.EncodeGossip(buf, msg)
   101  	require.NoError(t, err)
   102  	decoded := &pb.Fork{}
   103  	err = e.DecodeWithMaxLength(buf, decoded)
   104  	wanted := fmt.Sprintf("goes over the provided max limit of %d", maxChunkSize)
   105  	assert.ErrorContains(t, wanted, err)
   106  }
   107  
   108  func TestSszNetworkEncoder_DecodeWithMultipleFrames(t *testing.T) {
   109  	buf := new(bytes.Buffer)
   110  	st, _ := testutil.DeterministicGenesisState(t, 100)
   111  	e := &encoder.SszNetworkEncoder{}
   112  	params.SetupTestConfigCleanup(t)
   113  	c := params.BeaconNetworkConfig()
   114  	// 4 * 1 Mib
   115  	maxChunkSize := uint64(1 << 22)
   116  	c.MaxChunkSize = maxChunkSize
   117  	params.OverrideBeaconNetworkConfig(c)
   118  	_, err := e.EncodeWithMaxLength(buf, st.InnerStateUnsafe())
   119  	require.NoError(t, err)
   120  	// Max snappy block size
   121  	if buf.Len() <= 76490 {
   122  		t.Errorf("buffer smaller than expected, wanted > %d but got %d", 76490, buf.Len())
   123  	}
   124  	decoded := new(pb.BeaconState)
   125  	err = e.DecodeWithMaxLength(buf, decoded)
   126  	assert.NoError(t, err)
   127  }
   128  func TestSszNetworkEncoder_NegativeMaxLength(t *testing.T) {
   129  	e := &encoder.SszNetworkEncoder{}
   130  	length, err := e.MaxLength(0xfffffffffff)
   131  
   132  	assert.Equal(t, 0, length, "Received non zero length on bad message length")
   133  	assert.ErrorContains(t, "max encoded length is negative", err)
   134  }
   135  
   136  func TestSszNetworkEncoder_MaxInt64(t *testing.T) {
   137  	e := &encoder.SszNetworkEncoder{}
   138  	length, err := e.MaxLength(math.MaxInt64 + 1)
   139  
   140  	assert.Equal(t, 0, length, "Received non zero length on bad message length")
   141  	assert.ErrorContains(t, "invalid length provided", err)
   142  }
   143  
   144  func TestSszNetworkEncoder_DecodeWithBadSnappyStream(t *testing.T) {
   145  	st := newBadSnappyStream()
   146  	e := &encoder.SszNetworkEncoder{}
   147  	decoded := new(pb.Fork)
   148  	err := e.DecodeWithMaxLength(st, decoded)
   149  	assert.ErrorContains(t, io.EOF.Error(), err)
   150  }
   151  
   152  type badSnappyStream struct {
   153  	varint []byte
   154  	header []byte
   155  	repeat []byte
   156  	i      int
   157  	// count how many times it was read
   158  	counter int
   159  	// count bytes read so far
   160  	total int
   161  }
   162  
   163  func newBadSnappyStream() *badSnappyStream {
   164  	const (
   165  		magicBody  = "sNaPpY"
   166  		magicChunk = "\xff\x06\x00\x00" + magicBody
   167  	)
   168  
   169  	header := make([]byte, len(magicChunk))
   170  	// magicChunk == chunkTypeStreamIdentifier byte ++ 3 byte little endian len(magic body) ++ 6 byte magic body
   171  
   172  	// header is a special chunk type, with small fixed length, to add some magic to claim it's really snappy.
   173  	copy(header, magicChunk) // snappy library constants help us construct the common header chunk easily.
   174  
   175  	payload := make([]byte, 4)
   176  
   177  	// byte 0 is chunk type
   178  	// Exploit any fancy ignored chunk type
   179  	//   Section 4.4 Padding (chunk type 0xfe).
   180  	//   Section 4.6. Reserved skippable chunks (chunk types 0x80-0xfd).
   181  	payload[0] = 0xfe
   182  
   183  	// byte 1,2,3 are chunk length (little endian)
   184  	payload[1] = 0
   185  	payload[2] = 0
   186  	payload[3] = 0
   187  
   188  	return &badSnappyStream{
   189  		varint:  gogo.EncodeVarint(1000),
   190  		header:  header,
   191  		repeat:  payload,
   192  		i:       0,
   193  		counter: 0,
   194  		total:   0,
   195  	}
   196  }
   197  
   198  func (b *badSnappyStream) Read(p []byte) (n int, err error) {
   199  	// Stream out varint bytes first to make test happy.
   200  	if len(b.varint) > 0 {
   201  		copy(p, b.varint[:1])
   202  		b.varint = b.varint[1:]
   203  		return 1, nil
   204  	}
   205  	defer func() {
   206  		b.counter += 1
   207  		b.total += n
   208  	}()
   209  	if len(b.repeat) == 0 {
   210  		panic("no bytes to repeat")
   211  	}
   212  	if len(b.header) > 0 {
   213  		n = copy(p, b.header)
   214  		b.header = b.header[n:]
   215  		return
   216  	}
   217  	for n < len(p) {
   218  		n += copy(p[n:], b.repeat[b.i:])
   219  		b.i = (b.i + n) % len(b.repeat)
   220  	}
   221  	return
   222  }