github.com/amazechain/amc@v0.1.3/internal/p2p/encoder/ssz.go (about)

     1  package encoder
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"sync"
     7  
     8  	"github.com/golang/snappy"
     9  	"github.com/pkg/errors"
    10  	fastssz "github.com/prysmaticlabs/fastssz"
    11  )
    12  
// Compile-time assertion that *SszNetworkEncoder implements NetworkEncoding.
var _ NetworkEncoding = (*SszNetworkEncoder)(nil)
    14  
    15  // MaxGossipSize allowed for gossip messages.
    16  var MaxGossipSize = uint64(1 << 20) // 1 MiB
    17  var MaxChunkSize = uint64(1 << 20)  // 1 Mib.
    18  
    19  // This pool defines the sync pool for our buffered snappy writers, so that they
    20  // can be constantly reused.
    21  var bufWriterPool = new(sync.Pool)
    22  
    23  // This pool defines the sync pool for our buffered snappy readers, so that they
    24  // can be constantly reused.
    25  var bufReaderPool = new(sync.Pool)
    26  
// SszNetworkEncoder implements p2p network encoding using SimpleSerialize
// (SSZ) serialization combined with snappy compression. It carries no state.
type SszNetworkEncoder struct{}
    30  
// ProtocolSuffixSSZSnappy is the last part of the topic string to identify
// the encoding protocol (SSZ serialization with snappy compression).
const ProtocolSuffixSSZSnappy = "ssz_snappy"
    33  
    34  // EncodeGossip the proto gossip message to the io.Writer.
    35  func (_ SszNetworkEncoder) EncodeGossip(w io.Writer, msg fastssz.Marshaler) (int, error) {
    36  	if msg == nil {
    37  		return 0, nil
    38  	}
    39  	b, err := msg.MarshalSSZ()
    40  	if err != nil {
    41  		return 0, err
    42  	}
    43  	if uint64(len(b)) > MaxGossipSize {
    44  		return 0, errors.Errorf("gossip message exceeds max gossip size: %d bytes > %d bytes", len(b), MaxGossipSize)
    45  	}
    46  	b = snappy.Encode(nil /*dst*/, b)
    47  	return w.Write(b)
    48  }
    49  
    50  // EncodeWithMaxLength the proto message to the io.Writer. This encoding prefixes the byte slice with a protobuf varint
    51  // to indicate the size of the message. This checks that the encoded message isn't larger than the provided max limit.
    52  func (_ SszNetworkEncoder) EncodeWithMaxLength(w io.Writer, msg fastssz.Marshaler) (int, error) {
    53  	if msg == nil {
    54  		return 0, nil
    55  	}
    56  	b, err := msg.MarshalSSZ()
    57  	if err != nil {
    58  		return 0, err
    59  	}
    60  	if uint64(len(b)) > MaxChunkSize {
    61  		return 0, fmt.Errorf(
    62  			"size of encoded message is %d which is larger than the provided max limit of %d",
    63  			len(b),
    64  			MaxChunkSize,
    65  		)
    66  	}
    67  	// write varint first
    68  	_, err = w.Write(EncodeVarint(uint64(len(b))))
    69  	if err != nil {
    70  		return 0, err
    71  	}
    72  	return writeSnappyBuffer(w, b)
    73  }
    74  
    75  func doDecode(b []byte, to fastssz.Unmarshaler) error {
    76  	return to.UnmarshalSSZ(b)
    77  }
    78  
    79  // DecodeGossip decodes the bytes to the protobuf gossip message provided.
    80  func (_ SszNetworkEncoder) DecodeGossip(b []byte, to fastssz.Unmarshaler) error {
    81  	b, err := DecodeSnappy(b, MaxGossipSize)
    82  	if err != nil {
    83  		return err
    84  	}
    85  	return doDecode(b, to)
    86  }
    87  
    88  // DecodeSnappy decodes a snappy compressed message.
    89  func DecodeSnappy(msg []byte, maxSize uint64) ([]byte, error) {
    90  	size, err := snappy.DecodedLen(msg)
    91  	if err != nil {
    92  		return nil, err
    93  	}
    94  	if uint64(size) > maxSize {
    95  		return nil, errors.Errorf("snappy message exceeds max size: %d bytes > %d bytes", size, maxSize)
    96  	}
    97  	msg, err = snappy.Decode(nil /*dst*/, msg)
    98  	if err != nil {
    99  		return nil, err
   100  	}
   101  	return msg, nil
   102  }
   103  
   104  // DecodeWithMaxLength the bytes from io.Reader to the protobuf message provided.
   105  // This checks that the decoded message isn't larger than the provided max limit.
   106  func (e SszNetworkEncoder) DecodeWithMaxLength(r io.Reader, to fastssz.Unmarshaler) error {
   107  	msgLen, err := readVarint(r)
   108  	if err != nil {
   109  		return err
   110  	}
   111  	if msgLen > MaxChunkSize {
   112  		return fmt.Errorf(
   113  			"remaining bytes %d goes over the provided max limit of %d",
   114  			msgLen,
   115  			MaxChunkSize,
   116  		)
   117  	}
   118  	msgMax, err := e.MaxLength(msgLen)
   119  	if err != nil {
   120  		return err
   121  	}
   122  	limitedRdr := io.LimitReader(r, int64(msgMax))
   123  	r = newBufferedReader(limitedRdr)
   124  	defer bufReaderPool.Put(r)
   125  
   126  	buf := make([]byte, msgLen)
   127  	// Returns an error if less than msgLen bytes
   128  	// are read. This ensures we read exactly the
   129  	// required amount.
   130  	_, err = io.ReadFull(r, buf)
   131  	if err != nil {
   132  		return err
   133  	}
   134  	return doDecode(buf, to)
   135  }
   136  
   137  // ProtocolSuffix returns the appropriate suffix for protocol IDs.
   138  func (_ SszNetworkEncoder) ProtocolSuffix() string {
   139  	return "/" + ProtocolSuffixSSZSnappy
   140  }
   141  
   142  // MaxLength specifies the maximum possible length of an encoded
   143  // chunk of data.
   144  func (_ SszNetworkEncoder) MaxLength(length uint64) (int, error) {
   145  	//il, err := math.Int(length)
   146  	//if err != nil {
   147  	//	return 0, errors.Wrap(err, "invalid length provided")
   148  	//}
   149  	//todo uint64 to int
   150  	maxLen := snappy.MaxEncodedLen(int(length))
   151  	if maxLen < 0 {
   152  		return 0, errors.Errorf("max encoded length is negative: %d", maxLen)
   153  	}
   154  	return maxLen, nil
   155  }
   156  
   157  // Writes a bytes value through a snappy buffered writer.
   158  func writeSnappyBuffer(w io.Writer, b []byte) (int, error) {
   159  	bufWriter := newBufferedWriter(w)
   160  	defer bufWriterPool.Put(bufWriter)
   161  	num, err := bufWriter.Write(b)
   162  	if err != nil {
   163  		// Close buf writer in the event of an error.
   164  		if err := bufWriter.Close(); err != nil {
   165  			return 0, err
   166  		}
   167  		return 0, err
   168  	}
   169  	return num, bufWriter.Close()
   170  }
   171  
   172  // Instantiates a new instance of the snappy buffered reader
   173  // using our sync pool.
   174  func newBufferedReader(r io.Reader) *snappy.Reader {
   175  	rawReader := bufReaderPool.Get()
   176  	if rawReader == nil {
   177  		return snappy.NewReader(r)
   178  	}
   179  	bufR, ok := rawReader.(*snappy.Reader)
   180  	if !ok {
   181  		return snappy.NewReader(r)
   182  	}
   183  	bufR.Reset(r)
   184  	return bufR
   185  }
   186  
   187  // Instantiates a new instance of the snappy buffered writer
   188  // using our sync pool.
   189  func newBufferedWriter(w io.Writer) *snappy.Writer {
   190  	rawBufWriter := bufWriterPool.Get()
   191  	if rawBufWriter == nil {
   192  		return snappy.NewBufferedWriter(w)
   193  	}
   194  	bufW, ok := rawBufWriter.(*snappy.Writer)
   195  	if !ok {
   196  		return snappy.NewBufferedWriter(w)
   197  	}
   198  	bufW.Reset(w)
   199  	return bufW
   200  }