github.com/MetalBlockchain/metalgo@v1.11.9/message/messages.go (about) 1 // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved. 2 // See the file LICENSE for licensing terms. 3 4 package message 5 6 import ( 7 "errors" 8 "fmt" 9 "time" 10 11 "github.com/prometheus/client_golang/prometheus" 12 "google.golang.org/protobuf/proto" 13 14 "github.com/MetalBlockchain/metalgo/ids" 15 "github.com/MetalBlockchain/metalgo/proto/pb/p2p" 16 "github.com/MetalBlockchain/metalgo/utils/compression" 17 "github.com/MetalBlockchain/metalgo/utils/constants" 18 "github.com/MetalBlockchain/metalgo/utils/logging" 19 "github.com/MetalBlockchain/metalgo/utils/timer/mockable" 20 ) 21 22 const ( 23 typeLabel = "type" 24 opLabel = "op" 25 directionLabel = "direction" 26 27 compressionLabel = "compression" 28 decompressionLabel = "decompression" 29 ) 30 31 var ( 32 _ InboundMessage = (*inboundMessage)(nil) 33 _ OutboundMessage = (*outboundMessage)(nil) 34 35 metricLabels = []string{typeLabel, opLabel, directionLabel} 36 37 errUnknownCompressionType = errors.New("message is compressed with an unknown compression type") 38 ) 39 40 // InboundMessage represents a set of fields for an inbound message 41 type InboundMessage interface { 42 fmt.Stringer 43 // NodeID returns the ID of the node that sent this message 44 NodeID() ids.NodeID 45 // Op returns the op that describes this message type 46 Op() Op 47 // Message returns the message that was sent 48 Message() fmt.Stringer 49 // Expiration returns the time that the sender will have already timed out 50 // this request 51 Expiration() time.Time 52 // OnFinishedHandling must be called one time when this message has been 53 // handled by the message handler 54 OnFinishedHandling() 55 // BytesSavedCompression returns the number of bytes that this message saved 56 // due to being compressed 57 BytesSavedCompression() int 58 } 59 60 type inboundMessage struct { 61 nodeID ids.NodeID 62 op Op 63 message fmt.Stringer 64 expiration time.Time 65 onFinishedHandling func() 66 bytesSavedCompression int 67 } 68 69 func (m *inboundMessage) NodeID() ids.NodeID { 70 return m.nodeID 71 } 72 73 func (m *inboundMessage) Op() Op { 74 return m.op 75 } 76 77 func (m *inboundMessage) Message() fmt.Stringer { 78 return m.message 79 } 80 81 func (m *inboundMessage) Expiration() time.Time { 82 return m.expiration 83 } 84 85 func (m *inboundMessage) OnFinishedHandling() { 86 if m.onFinishedHandling != nil { 87 m.onFinishedHandling() 88 } 89 } 90 91 func (m *inboundMessage) BytesSavedCompression() int { 92 return m.bytesSavedCompression 93 } 94 95 func (m *inboundMessage) String() string { 96 return fmt.Sprintf("%s Op: %s Message: %s", 97 m.nodeID, m.op, m.message) 98 } 99 100 // OutboundMessage represents a set of fields for an outbound message that can 101 // be serialized into a byte stream 102 type OutboundMessage interface { 103 // BypassThrottling returns true if we should send this message, regardless 104 // of any outbound message throttling 105 BypassThrottling() bool 106 // Op returns the op that describes this message type 107 Op() Op 108 // Bytes returns the bytes that will be sent 109 Bytes() []byte 110 // BytesSavedCompression returns the number of bytes that this message saved 111 // due to being compressed 112 BytesSavedCompression() int 113 } 114 115 type outboundMessage struct { 116 bypassThrottling bool 117 op Op 118 bytes []byte 119 bytesSavedCompression int 120 } 121 122 func (m *outboundMessage) BypassThrottling() bool { 123 return m.bypassThrottling 124 } 125 126 func (m *outboundMessage) Op() Op { 127 return m.op 128 } 129 130 func (m *outboundMessage) Bytes() []byte { 131 return m.bytes 132 } 133 134 func (m *outboundMessage) BytesSavedCompression() int { 135 return m.bytesSavedCompression 136 } 137 138 // TODO: add other compression algorithms with extended interface 139 type msgBuilder struct { 140 log logging.Logger 141 142 zstdCompressor compression.Compressor 143 count *prometheus.CounterVec // type + op + direction 144 duration *prometheus.GaugeVec // type + op + direction 145 146 maxMessageTimeout time.Duration 147 } 148 149 func newMsgBuilder( 150 log logging.Logger, 151 metrics prometheus.Registerer, 152 maxMessageTimeout time.Duration, 153 ) (*msgBuilder, error) { 154 zstdCompressor, err := compression.NewZstdCompressor(constants.DefaultMaxMessageSize) 155 if err != nil { 156 return nil, err 157 } 158 159 mb := &msgBuilder{ 160 log: log, 161 162 zstdCompressor: zstdCompressor, 163 count: prometheus.NewCounterVec( 164 prometheus.CounterOpts{ 165 Name: "codec_compressed_count", 166 Help: "number of compressed messages", 167 }, 168 metricLabels, 169 ), 170 duration: prometheus.NewGaugeVec( 171 prometheus.GaugeOpts{ 172 Name: "codec_compressed_duration", 173 Help: "time spent handling compressed messages", 174 }, 175 metricLabels, 176 ), 177 178 maxMessageTimeout: maxMessageTimeout, 179 } 180 return mb, errors.Join( 181 metrics.Register(mb.count), 182 metrics.Register(mb.duration), 183 ) 184 } 185 186 func (mb *msgBuilder) marshal( 187 uncompressedMsg *p2p.Message, 188 compressionType compression.Type, 189 ) ([]byte, int, Op, error) { 190 uncompressedMsgBytes, err := proto.Marshal(uncompressedMsg) 191 if err != nil { 192 return nil, 0, 0, err 193 } 194 195 op, err := ToOp(uncompressedMsg) 196 if err != nil { 197 return nil, 0, 0, err 198 } 199 200 // If compression is enabled, we marshal twice: 201 // 1. the original message 202 // 2. the message with compressed bytes 203 // 204 // This recursive packing allows us to avoid an extra compression on/off 205 // field in the message. 206 var ( 207 startTime = time.Now() 208 compressedMsg p2p.Message 209 ) 210 switch compressionType { 211 case compression.TypeNone: 212 return uncompressedMsgBytes, 0, op, nil 213 case compression.TypeZstd: 214 compressedBytes, err := mb.zstdCompressor.Compress(uncompressedMsgBytes) 215 if err != nil { 216 return nil, 0, 0, err 217 } 218 compressedMsg = p2p.Message{ 219 Message: &p2p.Message_CompressedZstd{ 220 CompressedZstd: compressedBytes, 221 }, 222 } 223 default: 224 return nil, 0, 0, errUnknownCompressionType 225 } 226 227 compressedMsgBytes, err := proto.Marshal(&compressedMsg) 228 if err != nil { 229 return nil, 0, 0, err 230 } 231 compressTook := time.Since(startTime) 232 233 labels := prometheus.Labels{ 234 typeLabel: compressionType.String(), 235 opLabel: op.String(), 236 directionLabel: compressionLabel, 237 } 238 mb.count.With(labels).Inc() 239 mb.duration.With(labels).Add(float64(compressTook)) 240 241 bytesSaved := len(uncompressedMsgBytes) - len(compressedMsgBytes) 242 return compressedMsgBytes, bytesSaved, op, nil 243 } 244 245 func (mb *msgBuilder) unmarshal(b []byte) (*p2p.Message, int, Op, error) { 246 m := new(p2p.Message) 247 if err := proto.Unmarshal(b, m); err != nil { 248 return nil, 0, 0, err 249 } 250 251 // Figure out what compression type, if any, was used to compress the message. 252 var ( 253 compressor compression.Compressor 254 compressedBytes []byte 255 zstdCompressed = m.GetCompressedZstd() 256 ) 257 switch { 258 case len(zstdCompressed) > 0: 259 compressor = mb.zstdCompressor 260 compressedBytes = zstdCompressed 261 default: 262 // The message wasn't compressed 263 op, err := ToOp(m) 264 return m, 0, op, err 265 } 266 267 startTime := time.Now() 268 269 decompressed, err := compressor.Decompress(compressedBytes) 270 if err != nil { 271 return nil, 0, 0, err 272 } 273 bytesSavedCompression := len(decompressed) - len(compressedBytes) 274 275 if err := proto.Unmarshal(decompressed, m); err != nil { 276 return nil, 0, 0, err 277 } 278 decompressTook := time.Since(startTime) 279 280 // Record decompression time metric 281 op, err := ToOp(m) 282 if err != nil { 283 return nil, 0, 0, err 284 } 285 286 labels := prometheus.Labels{ 287 typeLabel: compression.TypeZstd.String(), 288 opLabel: op.String(), 289 directionLabel: decompressionLabel, 290 } 291 mb.count.With(labels).Inc() 292 mb.duration.With(labels).Add(float64(decompressTook)) 293 294 return m, bytesSavedCompression, op, nil 295 } 296 297 func (mb *msgBuilder) createOutbound(m *p2p.Message, compressionType compression.Type, bypassThrottling bool) (*outboundMessage, error) { 298 b, saved, op, err := mb.marshal(m, compressionType) 299 if err != nil { 300 return nil, err 301 } 302 303 return &outboundMessage{ 304 bypassThrottling: bypassThrottling, 305 op: op, 306 bytes: b, 307 bytesSavedCompression: saved, 308 }, nil 309 } 310 311 func (mb *msgBuilder) parseInbound( 312 bytes []byte, 313 nodeID ids.NodeID, 314 onFinishedHandling func(), 315 ) (*inboundMessage, error) { 316 m, bytesSavedCompression, op, err := mb.unmarshal(bytes) 317 if err != nil { 318 return nil, err 319 } 320 321 msg, err := Unwrap(m) 322 if err != nil { 323 return nil, err 324 } 325 326 expiration := mockable.MaxTime 327 if deadline, ok := GetDeadline(msg); ok { 328 deadline = min(deadline, mb.maxMessageTimeout) 329 expiration = time.Now().Add(deadline) 330 } 331 332 return &inboundMessage{ 333 nodeID: nodeID, 334 op: op, 335 message: msg, 336 expiration: expiration, 337 onFinishedHandling: onFinishedHandling, 338 bytesSavedCompression: bytesSavedCompression, 339 }, nil 340 }