github.com/MetalBlockchain/metalgo@v1.11.9/message/messages.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package message
     5  
     6  import (
     7  	"errors"
     8  	"fmt"
     9  	"time"
    10  
    11  	"github.com/prometheus/client_golang/prometheus"
    12  	"google.golang.org/protobuf/proto"
    13  
    14  	"github.com/MetalBlockchain/metalgo/ids"
    15  	"github.com/MetalBlockchain/metalgo/proto/pb/p2p"
    16  	"github.com/MetalBlockchain/metalgo/utils/compression"
    17  	"github.com/MetalBlockchain/metalgo/utils/constants"
    18  	"github.com/MetalBlockchain/metalgo/utils/logging"
    19  	"github.com/MetalBlockchain/metalgo/utils/timer/mockable"
    20  )
    21  
    22  const (
    23  	typeLabel      = "type"
    24  	opLabel        = "op"
    25  	directionLabel = "direction"
    26  
    27  	compressionLabel   = "compression"
    28  	decompressionLabel = "decompression"
    29  )
    30  
    31  var (
    32  	_ InboundMessage  = (*inboundMessage)(nil)
    33  	_ OutboundMessage = (*outboundMessage)(nil)
    34  
    35  	metricLabels = []string{typeLabel, opLabel, directionLabel}
    36  
    37  	errUnknownCompressionType = errors.New("message is compressed with an unknown compression type")
    38  )
    39  
    40  // InboundMessage represents a set of fields for an inbound message
    41  type InboundMessage interface {
    42  	fmt.Stringer
    43  	// NodeID returns the ID of the node that sent this message
    44  	NodeID() ids.NodeID
    45  	// Op returns the op that describes this message type
    46  	Op() Op
    47  	// Message returns the message that was sent
    48  	Message() fmt.Stringer
    49  	// Expiration returns the time that the sender will have already timed out
    50  	// this request
    51  	Expiration() time.Time
    52  	// OnFinishedHandling must be called one time when this message has been
    53  	// handled by the message handler
    54  	OnFinishedHandling()
    55  	// BytesSavedCompression returns the number of bytes that this message saved
    56  	// due to being compressed
    57  	BytesSavedCompression() int
    58  }
    59  
    60  type inboundMessage struct {
    61  	nodeID                ids.NodeID
    62  	op                    Op
    63  	message               fmt.Stringer
    64  	expiration            time.Time
    65  	onFinishedHandling    func()
    66  	bytesSavedCompression int
    67  }
    68  
    69  func (m *inboundMessage) NodeID() ids.NodeID {
    70  	return m.nodeID
    71  }
    72  
    73  func (m *inboundMessage) Op() Op {
    74  	return m.op
    75  }
    76  
    77  func (m *inboundMessage) Message() fmt.Stringer {
    78  	return m.message
    79  }
    80  
    81  func (m *inboundMessage) Expiration() time.Time {
    82  	return m.expiration
    83  }
    84  
    85  func (m *inboundMessage) OnFinishedHandling() {
    86  	if m.onFinishedHandling != nil {
    87  		m.onFinishedHandling()
    88  	}
    89  }
    90  
    91  func (m *inboundMessage) BytesSavedCompression() int {
    92  	return m.bytesSavedCompression
    93  }
    94  
    95  func (m *inboundMessage) String() string {
    96  	return fmt.Sprintf("%s Op: %s Message: %s",
    97  		m.nodeID, m.op, m.message)
    98  }
    99  
   100  // OutboundMessage represents a set of fields for an outbound message that can
   101  // be serialized into a byte stream
   102  type OutboundMessage interface {
   103  	// BypassThrottling returns true if we should send this message, regardless
   104  	// of any outbound message throttling
   105  	BypassThrottling() bool
   106  	// Op returns the op that describes this message type
   107  	Op() Op
   108  	// Bytes returns the bytes that will be sent
   109  	Bytes() []byte
   110  	// BytesSavedCompression returns the number of bytes that this message saved
   111  	// due to being compressed
   112  	BytesSavedCompression() int
   113  }
   114  
   115  type outboundMessage struct {
   116  	bypassThrottling      bool
   117  	op                    Op
   118  	bytes                 []byte
   119  	bytesSavedCompression int
   120  }
   121  
   122  func (m *outboundMessage) BypassThrottling() bool {
   123  	return m.bypassThrottling
   124  }
   125  
   126  func (m *outboundMessage) Op() Op {
   127  	return m.op
   128  }
   129  
   130  func (m *outboundMessage) Bytes() []byte {
   131  	return m.bytes
   132  }
   133  
   134  func (m *outboundMessage) BytesSavedCompression() int {
   135  	return m.bytesSavedCompression
   136  }
   137  
   138  // TODO: add other compression algorithms with extended interface
   139  type msgBuilder struct {
   140  	log logging.Logger
   141  
   142  	zstdCompressor compression.Compressor
   143  	count          *prometheus.CounterVec // type + op + direction
   144  	duration       *prometheus.GaugeVec   // type + op + direction
   145  
   146  	maxMessageTimeout time.Duration
   147  }
   148  
   149  func newMsgBuilder(
   150  	log logging.Logger,
   151  	metrics prometheus.Registerer,
   152  	maxMessageTimeout time.Duration,
   153  ) (*msgBuilder, error) {
   154  	zstdCompressor, err := compression.NewZstdCompressor(constants.DefaultMaxMessageSize)
   155  	if err != nil {
   156  		return nil, err
   157  	}
   158  
   159  	mb := &msgBuilder{
   160  		log: log,
   161  
   162  		zstdCompressor: zstdCompressor,
   163  		count: prometheus.NewCounterVec(
   164  			prometheus.CounterOpts{
   165  				Name: "codec_compressed_count",
   166  				Help: "number of compressed messages",
   167  			},
   168  			metricLabels,
   169  		),
   170  		duration: prometheus.NewGaugeVec(
   171  			prometheus.GaugeOpts{
   172  				Name: "codec_compressed_duration",
   173  				Help: "time spent handling compressed messages",
   174  			},
   175  			metricLabels,
   176  		),
   177  
   178  		maxMessageTimeout: maxMessageTimeout,
   179  	}
   180  	return mb, errors.Join(
   181  		metrics.Register(mb.count),
   182  		metrics.Register(mb.duration),
   183  	)
   184  }
   185  
   186  func (mb *msgBuilder) marshal(
   187  	uncompressedMsg *p2p.Message,
   188  	compressionType compression.Type,
   189  ) ([]byte, int, Op, error) {
   190  	uncompressedMsgBytes, err := proto.Marshal(uncompressedMsg)
   191  	if err != nil {
   192  		return nil, 0, 0, err
   193  	}
   194  
   195  	op, err := ToOp(uncompressedMsg)
   196  	if err != nil {
   197  		return nil, 0, 0, err
   198  	}
   199  
   200  	// If compression is enabled, we marshal twice:
   201  	// 1. the original message
   202  	// 2. the message with compressed bytes
   203  	//
   204  	// This recursive packing allows us to avoid an extra compression on/off
   205  	// field in the message.
   206  	var (
   207  		startTime     = time.Now()
   208  		compressedMsg p2p.Message
   209  	)
   210  	switch compressionType {
   211  	case compression.TypeNone:
   212  		return uncompressedMsgBytes, 0, op, nil
   213  	case compression.TypeZstd:
   214  		compressedBytes, err := mb.zstdCompressor.Compress(uncompressedMsgBytes)
   215  		if err != nil {
   216  			return nil, 0, 0, err
   217  		}
   218  		compressedMsg = p2p.Message{
   219  			Message: &p2p.Message_CompressedZstd{
   220  				CompressedZstd: compressedBytes,
   221  			},
   222  		}
   223  	default:
   224  		return nil, 0, 0, errUnknownCompressionType
   225  	}
   226  
   227  	compressedMsgBytes, err := proto.Marshal(&compressedMsg)
   228  	if err != nil {
   229  		return nil, 0, 0, err
   230  	}
   231  	compressTook := time.Since(startTime)
   232  
   233  	labels := prometheus.Labels{
   234  		typeLabel:      compressionType.String(),
   235  		opLabel:        op.String(),
   236  		directionLabel: compressionLabel,
   237  	}
   238  	mb.count.With(labels).Inc()
   239  	mb.duration.With(labels).Add(float64(compressTook))
   240  
   241  	bytesSaved := len(uncompressedMsgBytes) - len(compressedMsgBytes)
   242  	return compressedMsgBytes, bytesSaved, op, nil
   243  }
   244  
   245  func (mb *msgBuilder) unmarshal(b []byte) (*p2p.Message, int, Op, error) {
   246  	m := new(p2p.Message)
   247  	if err := proto.Unmarshal(b, m); err != nil {
   248  		return nil, 0, 0, err
   249  	}
   250  
   251  	// Figure out what compression type, if any, was used to compress the message.
   252  	var (
   253  		compressor      compression.Compressor
   254  		compressedBytes []byte
   255  		zstdCompressed  = m.GetCompressedZstd()
   256  	)
   257  	switch {
   258  	case len(zstdCompressed) > 0:
   259  		compressor = mb.zstdCompressor
   260  		compressedBytes = zstdCompressed
   261  	default:
   262  		// The message wasn't compressed
   263  		op, err := ToOp(m)
   264  		return m, 0, op, err
   265  	}
   266  
   267  	startTime := time.Now()
   268  
   269  	decompressed, err := compressor.Decompress(compressedBytes)
   270  	if err != nil {
   271  		return nil, 0, 0, err
   272  	}
   273  	bytesSavedCompression := len(decompressed) - len(compressedBytes)
   274  
   275  	if err := proto.Unmarshal(decompressed, m); err != nil {
   276  		return nil, 0, 0, err
   277  	}
   278  	decompressTook := time.Since(startTime)
   279  
   280  	// Record decompression time metric
   281  	op, err := ToOp(m)
   282  	if err != nil {
   283  		return nil, 0, 0, err
   284  	}
   285  
   286  	labels := prometheus.Labels{
   287  		typeLabel:      compression.TypeZstd.String(),
   288  		opLabel:        op.String(),
   289  		directionLabel: decompressionLabel,
   290  	}
   291  	mb.count.With(labels).Inc()
   292  	mb.duration.With(labels).Add(float64(decompressTook))
   293  
   294  	return m, bytesSavedCompression, op, nil
   295  }
   296  
   297  func (mb *msgBuilder) createOutbound(m *p2p.Message, compressionType compression.Type, bypassThrottling bool) (*outboundMessage, error) {
   298  	b, saved, op, err := mb.marshal(m, compressionType)
   299  	if err != nil {
   300  		return nil, err
   301  	}
   302  
   303  	return &outboundMessage{
   304  		bypassThrottling:      bypassThrottling,
   305  		op:                    op,
   306  		bytes:                 b,
   307  		bytesSavedCompression: saved,
   308  	}, nil
   309  }
   310  
   311  func (mb *msgBuilder) parseInbound(
   312  	bytes []byte,
   313  	nodeID ids.NodeID,
   314  	onFinishedHandling func(),
   315  ) (*inboundMessage, error) {
   316  	m, bytesSavedCompression, op, err := mb.unmarshal(bytes)
   317  	if err != nil {
   318  		return nil, err
   319  	}
   320  
   321  	msg, err := Unwrap(m)
   322  	if err != nil {
   323  		return nil, err
   324  	}
   325  
   326  	expiration := mockable.MaxTime
   327  	if deadline, ok := GetDeadline(msg); ok {
   328  		deadline = min(deadline, mb.maxMessageTimeout)
   329  		expiration = time.Now().Add(deadline)
   330  	}
   331  
   332  	return &inboundMessage{
   333  		nodeID:                nodeID,
   334  		op:                    op,
   335  		message:               msg,
   336  		expiration:            expiration,
   337  		onFinishedHandling:    onFinishedHandling,
   338  		bytesSavedCompression: bytesSavedCompression,
   339  	}, nil
   340  }