github.com/prysmaticlabs/prysm@v1.4.4/beacon-chain/p2p/encoder/ssz.go (about)

     1  package encoder
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"math"
     7  	"sync"
     8  
     9  	fastssz "github.com/ferranbt/fastssz"
    10  	"github.com/gogo/protobuf/proto"
    11  	"github.com/golang/snappy"
    12  	"github.com/pkg/errors"
    13  	"github.com/prysmaticlabs/prysm/shared/params"
    14  )
    15  
    16  var _ NetworkEncoding = (*SszNetworkEncoder)(nil)
    17  
    18  // MaxGossipSize allowed for gossip messages.
    19  var MaxGossipSize = params.BeaconNetworkConfig().GossipMaxSize // 1 Mib
    20  
    21  // This pool defines the sync pool for our buffered snappy writers, so that they
    22  // can be constantly reused.
    23  var bufWriterPool = new(sync.Pool)
    24  
    25  // This pool defines the sync pool for our buffered snappy readers, so that they
    26  // can be constantly reused.
    27  var bufReaderPool = new(sync.Pool)
    28  
    29  // SszNetworkEncoder supports p2p networking encoding using SimpleSerialize
    30  // with snappy compression (if enabled).
    31  type SszNetworkEncoder struct{}
    32  
    33  // ProtocolSuffixSSZSnappy is the last part of the topic string to identify the encoding protocol.
    34  const ProtocolSuffixSSZSnappy = "ssz_snappy"
    35  
    36  func (e SszNetworkEncoder) doEncode(msg interface{}) ([]byte, error) {
    37  	if v, ok := msg.(fastssz.Marshaler); ok {
    38  		return v.MarshalSSZ()
    39  	}
    40  	return nil, errors.Errorf("non-supported type: %T", msg)
    41  }
    42  
    43  // EncodeGossip the proto gossip message to the io.Writer.
    44  func (e SszNetworkEncoder) EncodeGossip(w io.Writer, msg interface{}) (int, error) {
    45  	if msg == nil {
    46  		return 0, nil
    47  	}
    48  	b, err := e.doEncode(msg)
    49  	if err != nil {
    50  		return 0, err
    51  	}
    52  	if uint64(len(b)) > MaxGossipSize {
    53  		return 0, errors.Errorf("gossip message exceeds max gossip size: %d bytes > %d bytes", len(b), MaxGossipSize)
    54  	}
    55  	b = snappy.Encode(nil /*dst*/, b)
    56  	return w.Write(b)
    57  }
    58  
    59  // EncodeWithMaxLength the proto message to the io.Writer. This encoding prefixes the byte slice with a protobuf varint
    60  // to indicate the size of the message. This checks that the encoded message isn't larger than the provided max limit.
    61  func (e SszNetworkEncoder) EncodeWithMaxLength(w io.Writer, msg interface{}) (int, error) {
    62  	if msg == nil {
    63  		return 0, nil
    64  	}
    65  	b, err := e.doEncode(msg)
    66  	if err != nil {
    67  		return 0, err
    68  	}
    69  	if uint64(len(b)) > params.BeaconNetworkConfig().MaxChunkSize {
    70  		return 0, fmt.Errorf(
    71  			"size of encoded message is %d which is larger than the provided max limit of %d",
    72  			len(b),
    73  			params.BeaconNetworkConfig().MaxChunkSize,
    74  		)
    75  	}
    76  	// write varint first
    77  	_, err = w.Write(proto.EncodeVarint(uint64(len(b))))
    78  	if err != nil {
    79  		return 0, err
    80  	}
    81  	return writeSnappyBuffer(w, b)
    82  }
    83  
    84  func (e SszNetworkEncoder) doDecode(b []byte, to interface{}) error {
    85  	if v, ok := to.(fastssz.Unmarshaler); ok {
    86  		return v.UnmarshalSSZ(b)
    87  	}
    88  	return errors.Errorf("non-supported type: %T", to)
    89  }
    90  
    91  // DecodeGossip decodes the bytes to the protobuf gossip message provided.
    92  func (e SszNetworkEncoder) DecodeGossip(b []byte, to interface{}) error {
    93  	b, err := DecodeSnappy(b, MaxGossipSize)
    94  	if err != nil {
    95  		return err
    96  	}
    97  	return e.doDecode(b, to)
    98  }
    99  
   100  // DecodeSnappy decodes a snappy compressed message.
   101  func DecodeSnappy(msg []byte, maxSize uint64) ([]byte, error) {
   102  	size, err := snappy.DecodedLen(msg)
   103  	if err != nil {
   104  		return nil, err
   105  	}
   106  	if uint64(size) > maxSize {
   107  		return nil, errors.Errorf("snappy message exceeds max size: %d bytes > %d bytes", size, maxSize)
   108  	}
   109  	msg, err = snappy.Decode(nil /*dst*/, msg)
   110  	if err != nil {
   111  		return nil, err
   112  	}
   113  	return msg, nil
   114  }
   115  
   116  // DecodeWithMaxLength the bytes from io.Reader to the protobuf message provided.
   117  // This checks that the decoded message isn't larger than the provided max limit.
   118  func (e SszNetworkEncoder) DecodeWithMaxLength(r io.Reader, to interface{}) error {
   119  	msgLen, err := readVarint(r)
   120  	if err != nil {
   121  		return err
   122  	}
   123  	if msgLen > params.BeaconNetworkConfig().MaxChunkSize {
   124  		return fmt.Errorf(
   125  			"remaining bytes %d goes over the provided max limit of %d",
   126  			msgLen,
   127  			params.BeaconNetworkConfig().MaxChunkSize,
   128  		)
   129  	}
   130  	msgMax, err := e.MaxLength(msgLen)
   131  	if err != nil {
   132  		return err
   133  	}
   134  	limitedRdr := io.LimitReader(r, int64(msgMax))
   135  	r = newBufferedReader(limitedRdr)
   136  	defer bufReaderPool.Put(r)
   137  
   138  	buf := make([]byte, msgLen)
   139  	// Returns an error if less than msgLen bytes
   140  	// are read. This ensures we read exactly the
   141  	// required amount.
   142  	_, err = io.ReadFull(r, buf)
   143  	if err != nil {
   144  		return err
   145  	}
   146  	return e.doDecode(buf, to)
   147  }
   148  
   149  // ProtocolSuffix returns the appropriate suffix for protocol IDs.
   150  func (e SszNetworkEncoder) ProtocolSuffix() string {
   151  	return "/" + ProtocolSuffixSSZSnappy
   152  }
   153  
   154  // MaxLength specifies the maximum possible length of an encoded
   155  // chunk of data.
   156  func (e SszNetworkEncoder) MaxLength(length uint64) (int, error) {
   157  	// Defensive check to prevent potential issues when casting to int64.
   158  	if length > math.MaxInt64 {
   159  		return 0, errors.Errorf("invalid length provided: %d", length)
   160  	}
   161  	maxLen := snappy.MaxEncodedLen(int(length))
   162  	if maxLen < 0 {
   163  		return 0, errors.Errorf("max encoded length is negative: %d", maxLen)
   164  	}
   165  	return maxLen, nil
   166  }
   167  
   168  // Writes a bytes value through a snappy buffered writer.
   169  func writeSnappyBuffer(w io.Writer, b []byte) (int, error) {
   170  	bufWriter := newBufferedWriter(w)
   171  	defer bufWriterPool.Put(bufWriter)
   172  	num, err := bufWriter.Write(b)
   173  	if err != nil {
   174  		// Close buf writer in the event of an error.
   175  		if err := bufWriter.Close(); err != nil {
   176  			return 0, err
   177  		}
   178  		return 0, err
   179  	}
   180  	return num, bufWriter.Close()
   181  }
   182  
   183  // Instantiates a new instance of the snappy buffered reader
   184  // using our sync pool.
   185  func newBufferedReader(r io.Reader) *snappy.Reader {
   186  	rawReader := bufReaderPool.Get()
   187  	if rawReader == nil {
   188  		return snappy.NewReader(r)
   189  	}
   190  	bufR, ok := rawReader.(*snappy.Reader)
   191  	if !ok {
   192  		return snappy.NewReader(r)
   193  	}
   194  	bufR.Reset(r)
   195  	return bufR
   196  }
   197  
   198  // Instantiates a new instance of the snappy buffered writer
   199  // using our sync pool.
   200  func newBufferedWriter(w io.Writer) *snappy.Writer {
   201  	rawBufWriter := bufWriterPool.Get()
   202  	if rawBufWriter == nil {
   203  		return snappy.NewBufferedWriter(w)
   204  	}
   205  	bufW, ok := rawBufWriter.(*snappy.Writer)
   206  	if !ok {
   207  		return snappy.NewBufferedWriter(w)
   208  	}
   209  	bufW.Reset(w)
   210  	return bufW
   211  }