github.com/matrixorigin/matrixone@v1.2.0/pkg/common/morpc/codec.go (about) 1 // Copyright 2021 - 2022 Matrix Origin 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package morpc 16 17 import ( 18 "encoding/hex" 19 "io" 20 "sync" 21 22 "github.com/cespare/xxhash/v2" 23 "github.com/fagongzi/goetty/v2/buf" 24 "github.com/fagongzi/goetty/v2/codec" 25 "github.com/fagongzi/goetty/v2/codec/length" 26 "github.com/matrixorigin/matrixone/pkg/common/moerr" 27 "github.com/matrixorigin/matrixone/pkg/common/mpool" 28 "github.com/matrixorigin/matrixone/pkg/txn/clock" 29 "github.com/pierrec/lz4/v4" 30 "go.uber.org/zap" 31 ) 32 33 const ( 34 flagHashPayload byte = 1 << iota 35 flagChecksumEnabled 36 flagHasCustomHeader 37 flagCompressEnabled 38 flagStreamingMessage 39 flagPing 40 flagPong 41 ) 42 43 var ( 44 defaultMaxBodyMessageSize = 1024 * 1024 * 100 45 checksumFieldBytes = 8 46 totalSizeFieldBytes = 4 47 payloadSizeFieldBytes = 4 48 49 approximateHeaderSize = 1024 * 1024 * 10 50 ) 51 52 func GetMessageSize() int { 53 return defaultMaxBodyMessageSize 54 } 55 56 // WithCodecEnableChecksum enable checksum 57 func WithCodecEnableChecksum() CodecOption { 58 return func(c *messageCodec) { 59 c.bc.checksumEnabled = true 60 } 61 } 62 63 // WithCodecPayloadCopyBufferSize set payload copy buffer size, if is a PayloadMessage 64 func WithCodecPayloadCopyBufferSize(value int) CodecOption { 65 return func(c *messageCodec) { 66 c.bc.payloadBufSize = value 67 } 68 } 69 70 // WithCodecIntegrationHLC intrgration hlc 71 func WithCodecIntegrationHLC(clock clock.Clock) CodecOption { 72 return func(c *messageCodec) { 73 c.AddHeaderCodec(&hlcCodec{clock: clock}) 74 } 75 } 76 77 // WithCodecMaxBodySize set rpc max body size 78 func WithCodecMaxBodySize(size int) CodecOption { 79 return func(c *messageCodec) { 80 if size == 0 { 81 size = defaultMaxBodyMessageSize 82 } 83 c.codec = length.NewWithSize(c.bc, 0, 0, 0, size+approximateHeaderSize) 84 c.bc.maxBodySize = size 85 } 86 } 87 88 // WithCodecEnableCompress enable compress body and payload 89 func WithCodecEnableCompress(pool *mpool.MPool) CodecOption { 90 return func(c *messageCodec) { 91 c.bc.compressEnabled = true 92 c.bc.pool = pool 93 } 94 } 95 96 type messageCodec struct { 97 codec codec.Codec 98 bc *baseCodec 99 } 100 101 // NewMessageCodec create message codec. The message encoding format consists of a message header and a message body. 102 // Format: 103 // 1. Size, 4 bytes, required. Inlucde header and body. 104 // 2. Message header 105 // 2.1. Flag, 1 byte, required. 106 // 2.2. Checksum, 8 byte, optional. Set if has a checksun flag 107 // 2.3. PayloadSize, 4 byte, optional. Set if the message is a morpc.PayloadMessage. 108 // 2.4. Streaming sequence, 4 byte, optional. Set if the message is in a streaming. 109 // 2.5. Custom headers, optional. Set if has custom header codecs 110 // 3. Message body 111 // 3.1. message body, required. 112 // 3.2. payload, optional. Set if has paylad flag. 113 func NewMessageCodec(messageFactory func() Message, options ...CodecOption) Codec { 114 bc := &baseCodec{ 115 messageFactory: messageFactory, 116 maxBodySize: defaultMaxBodyMessageSize, 117 } 118 c := &messageCodec{ 119 codec: length.NewWithSize(bc, 0, 0, 0, defaultMaxBodyMessageSize+approximateHeaderSize), 120 bc: bc, 121 } 122 c.AddHeaderCodec(&deadlineContextCodec{}) 123 c.AddHeaderCodec(&traceCodec{}) 124 125 for _, opt := range options { 126 opt(c) 127 } 128 return c 129 } 130 131 func (c *messageCodec) Decode(in *buf.ByteBuf) (any, bool, error) { 132 return c.codec.Decode(in) 133 } 134 135 func (c *messageCodec) Encode(data interface{}, out *buf.ByteBuf, conn io.Writer) error { 136 return c.bc.Encode(data, out, conn) 137 } 138 139 func (c *messageCodec) Valid(msg Message) error { 140 n := msg.Size() 141 if n >= c.bc.maxBodySize { 142 return moerr.NewInternalErrorNoCtx("message body %d is too large, max is %d", 143 n, 144 c.bc.maxBodySize) 145 } 146 return nil 147 } 148 149 func (c *messageCodec) AddHeaderCodec(hc HeaderCodec) { 150 c.bc.headerCodecs = append(c.bc.headerCodecs, hc) 151 } 152 153 type baseCodec struct { 154 pool *mpool.MPool 155 checksumEnabled bool 156 compressEnabled bool 157 payloadBufSize int 158 maxBodySize int 159 messageFactory func() Message 160 headerCodecs []HeaderCodec 161 } 162 163 func (c *baseCodec) Decode(in *buf.ByteBuf) (any, bool, error) { 164 msg := RPCMessage{} 165 offset := 0 166 data := getDecodeData(in) 167 168 // 2.1 169 flag, n := c.readFlag(&msg, data, offset) 170 offset += n 171 172 // 2.2 173 expectChecksum, n := readChecksum(flag, data, offset) 174 offset += n 175 176 // 2.3 177 payloadSize, n := readPayloadSize(flag, data, offset) 178 offset += n 179 180 // 2.4 181 n, err := c.readCustomHeaders(flag, &msg, data, offset) 182 if err != nil { 183 return nil, false, err 184 } 185 offset += n 186 187 // 2.5 188 offset += readStreaming(flag, &msg, data, offset) 189 190 // 3.1 and 3.2 191 if err := c.readMessage(flag, data, offset, expectChecksum, payloadSize, &msg); err != nil { 192 return nil, false, err 193 } 194 195 in.SetReadIndex(in.GetMarkIndex()) 196 in.ClearMark() 197 return msg, true, nil 198 } 199 200 func (c *baseCodec) Encode(data interface{}, out *buf.ByteBuf, conn io.Writer) error { 201 msg, ok := data.(RPCMessage) 202 if !ok { 203 return moerr.NewInternalErrorNoCtx("not support %T %+v", data, data) 204 } 205 206 startWriteOffset := out.GetWriteOffset() 207 totalSize := 0 208 // The total message size cannot be determined at the beginning and needs to wait until all the 209 // dynamic content is determined before the total size can be determined. After the total size is 210 // determined, we need to write the total size data in the location of totalSizeAt 211 totalSizeAt := skip(totalSizeFieldBytes, out) 212 213 // 2.1 flag 214 flag := c.getFlag(msg) 215 out.MustWriteByte(flag) 216 totalSize += 1 217 218 // 2.2 checksum, similar to totalSize, we do not currently know the size of the message body. 219 checksumAt := -1 220 if flag&flagChecksumEnabled != 0 { 221 checksumAt = skip(checksumFieldBytes, out) 222 totalSize += checksumFieldBytes 223 } 224 225 // 2.3 payload 226 var payloadData []byte 227 var compressedPayloadData []byte 228 var payloadMsg PayloadMessage 229 var hasPayload bool 230 231 // skip all written data by this message 232 discardWritten := func() { 233 out.SetWriteIndexByOffset(startWriteOffset) 234 if hasPayload { 235 payloadMsg.SetPayloadField(payloadData) 236 } 237 } 238 239 if payloadMsg, hasPayload = msg.Message.(PayloadMessage); hasPayload { 240 // set payload field to nil to avoid payload being written to the out buffer, and write directly 241 // to the socket afterwards to reduce one payload. 242 payloadData = payloadMsg.GetPayloadField() 243 payloadMsg.SetPayloadField(nil) 244 compressedPayloadData = payloadData 245 246 if c.compressEnabled && len(payloadData) > 0 { 247 v, err := c.compress(payloadData) 248 if err != nil { 249 discardWritten() 250 return err 251 } 252 defer c.pool.Free(v) 253 compressedPayloadData = v 254 } 255 256 out.WriteInt(len(compressedPayloadData)) 257 totalSize += payloadSizeFieldBytes + len(compressedPayloadData) 258 } 259 260 // 2.4 Custom header size 261 n, err := c.encodeCustomHeaders(&msg, out) 262 if err != nil { 263 return err 264 } 265 totalSize += n 266 267 // 2.5 streaming message 268 if msg.stream { 269 out.WriteUint32(msg.streamSequence) 270 totalSize += 4 271 } 272 273 // 3.1 message body 274 body, err := c.writeBody(out, msg.Message) 275 if err != nil { 276 discardWritten() 277 return err 278 } 279 280 // now, header and body are all determined, we need to fill the totalSize and checksum 281 // fill total size 282 totalSize += len(body) 283 writeIntAt(totalSizeAt, out, totalSize) 284 285 // fill checksum 286 if checksumAt != -1 { 287 if err := writeChecksum(checksumAt, out, body, compressedPayloadData); err != nil { 288 discardWritten() 289 return err 290 } 291 } 292 293 // 3.2 payload 294 if hasPayload { 295 // resume payload to payload message 296 payloadMsg.SetPayloadField(payloadData) 297 if err := writePayload(out, compressedPayloadData, conn, c.payloadBufSize); err != nil { 298 return err 299 } 300 } 301 302 return nil 303 } 304 305 func (c *baseCodec) compress(src []byte) ([]byte, error) { 306 n := lz4.CompressBlockBound(len(src)) 307 dst, err := c.pool.Alloc(n) 308 if err != nil { 309 return nil, err 310 } 311 dst, err = c.compressTo(src, dst) 312 if err != nil { 313 c.pool.Free(dst) 314 return nil, err 315 } 316 return dst, nil 317 } 318 319 func (c *baseCodec) uncompress(src []byte) ([]byte, error) { 320 // The lz4 library requires a []byte with a large enough dst when 321 // decompressing, otherwise it will return an ErrInvalidSourceShortBuffer, we 322 // can't confirm how large a dst we need to give initially, so when we encounter 323 // an ErrInvalidSourceShortBuffer, we expand the size and retry. 324 n := len(src) * 2 325 for { 326 dst, err := c.pool.Alloc(n) 327 if err != nil { 328 return nil, err 329 } 330 // the following line of code needs to free dst if it returns an error, but dst 331 // has been reassigned 332 old := dst 333 dst, err = uncompress(src, dst) 334 if err == nil { 335 return dst, nil 336 } 337 338 c.pool.Free(old) 339 if err != lz4.ErrInvalidSourceShortBuffer { 340 return nil, err 341 } 342 n *= 2 343 } 344 } 345 346 func (c *baseCodec) compressTo(src, dst []byte) ([]byte, error) { 347 dst, err := compress(src, dst) 348 if err != nil { 349 return nil, err 350 } 351 return dst, nil 352 } 353 354 func (c *baseCodec) compressBound(size int) int { 355 return lz4.CompressBlockBound(size) 356 } 357 358 func (c *baseCodec) getFlag(msg RPCMessage) byte { 359 flag := byte(0) 360 if c.checksumEnabled { 361 flag |= flagChecksumEnabled 362 } 363 if c.compressEnabled { 364 flag |= flagCompressEnabled 365 } 366 if len(c.headerCodecs) > 0 { 367 flag |= flagHasCustomHeader 368 } 369 if _, ok := msg.Message.(PayloadMessage); ok { 370 flag |= flagHashPayload 371 } 372 if msg.stream { 373 flag |= flagStreamingMessage 374 } 375 if msg.internal { 376 if m, ok := msg.Message.(*flagOnlyMessage); ok { 377 flag |= m.flag 378 } 379 } 380 return flag 381 } 382 383 func (c *baseCodec) encodeCustomHeaders(msg *RPCMessage, out *buf.ByteBuf) (int, error) { 384 if len(c.headerCodecs) == 0 { 385 return 0, nil 386 } 387 388 size := 0 389 for _, hc := range c.headerCodecs { 390 v, err := hc.Encode(msg, out) 391 if err != nil { 392 return 0, err 393 } 394 size += v 395 } 396 return size, nil 397 } 398 399 func (c *baseCodec) readCustomHeaders(flag byte, msg *RPCMessage, data []byte, offset int) (int, error) { 400 if flag&flagHasCustomHeader == 0 { 401 return 0, nil 402 } 403 404 readed := 0 405 for _, hc := range c.headerCodecs { 406 n, err := hc.Decode(msg, data[offset+readed:]) 407 if err != nil { 408 return 0, err 409 } 410 readed += n 411 } 412 return readed, nil 413 } 414 415 func (c *baseCodec) writeBody( 416 out *buf.ByteBuf, 417 msg Message) ([]byte, error) { 418 size := msg.Size() 419 if size == 0 { 420 return nil, nil 421 } 422 if !c.compressEnabled { 423 index, _ := setWriterIndexAfterGow(out, size) 424 data := out.RawSlice(index, index+size) 425 if _, err := msg.MarshalTo(data); err != nil { 426 return nil, err 427 } 428 return data, nil 429 } 430 431 // we use mpool to compress body, then write the dst into the buffer 432 origin, err := c.pool.Alloc(size) 433 if err != nil { 434 return nil, err 435 } 436 defer c.pool.Free(origin) 437 if _, err := msg.MarshalTo(origin); err != nil { 438 return nil, err 439 } 440 441 n := c.compressBound(len(origin)) 442 dst, err := c.pool.Alloc(n) 443 if err != nil { 444 return nil, err 445 } 446 defer c.pool.Free(dst) 447 448 dst, err = compress(origin, dst) 449 if err != nil { 450 return nil, err 451 } 452 453 index := out.GetWriteOffset() 454 out.MustWrite(dst) 455 return out.RawSlice(out.GetReadIndex()+index, out.GetWriteIndex()), nil 456 } 457 458 func (c *baseCodec) readMessage( 459 flag byte, 460 data []byte, 461 offset int, 462 expectChecksum uint64, 463 payloadSize int, 464 msg *RPCMessage) error { 465 if offset == len(data) { 466 return nil 467 } 468 469 // invalid body packet 470 if offset >= len(data)-payloadSize { 471 getLogger().Warn("invalid body packet", 472 zap.Int("offset", offset), 473 zap.Int("len", len(data)), 474 zap.Int("payloadSize", payloadSize), 475 zap.String("data-hex", hex.EncodeToString(data))) 476 return moerr.NewInvalidInputNoCtx("invalid body packet, offset %d, len %d, payload size %d", 477 offset, len(data), payloadSize) 478 } 479 480 body := data[offset : len(data)-payloadSize] 481 payload := data[len(data)-payloadSize:] 482 if flag&flagChecksumEnabled != 0 { 483 if err := validChecksum(body, payload, expectChecksum); err != nil { 484 return err 485 } 486 } 487 488 if flag&flagCompressEnabled != 0 { 489 dstBody, err := c.uncompress(body) 490 if err != nil { 491 return err 492 } 493 defer c.pool.Free(dstBody) 494 body = dstBody 495 496 if payloadSize > 0 { 497 dstPayload, err := c.uncompress(payload) 498 if err != nil { 499 return err 500 } 501 // we cannot free dstPayload here as it will be set to the result for subsequent 502 // modules. 503 // TODO: We currently use copy to avoid memory leaks in mpool, but this introduces 504 // a memory allocation overhead that will need to be optimized away. 505 defer c.pool.Free(dstPayload) 506 payload = clone(dstPayload) 507 } 508 } 509 510 if err := msg.Message.Unmarshal(body); err != nil { 511 return err 512 } 513 514 if payloadSize > 0 { 515 msg.Message.(PayloadMessage).SetPayloadField(payload) 516 } 517 return nil 518 } 519 520 var ( 521 checksumPool = sync.Pool{ 522 New: func() any { 523 return xxhash.New() 524 }, 525 } 526 ) 527 528 func acquireChecksum() *xxhash.Digest { 529 return checksumPool.Get().(*xxhash.Digest) 530 } 531 532 func releaseChecksum(checksum *xxhash.Digest) { 533 checksum.Reset() 534 checksumPool.Put(checksum) 535 } 536 537 func skip(n int, out *buf.ByteBuf) int { 538 _, offset := setWriterIndexAfterGow(out, n) 539 return offset 540 } 541 542 func writeIntAt(offset int, out *buf.ByteBuf, value int) { 543 idx := out.GetReadIndex() + offset 544 buf.Int2BytesTo(value, out.RawSlice(idx, idx+4)) 545 } 546 547 func writeUint64At(offset int, out *buf.ByteBuf, value uint64) { 548 idx := out.GetReadIndex() + offset 549 buf.Uint64ToBytesTo(value, out.RawSlice(idx, idx+8)) 550 } 551 552 func writePayload(out *buf.ByteBuf, payload []byte, conn io.Writer, copyBuffer int) error { 553 if len(payload) == 0 { 554 return nil 555 } 556 557 // reset here to avoid buffer expansion as much as possible 558 defer out.Reset() 559 560 // first, write header and body to socket 561 if _, err := out.WriteTo(conn); err != nil { 562 return err 563 } 564 565 // write payload to socket 566 if err := buf.WriteTo(payload, conn, copyBuffer); err != nil { 567 return err 568 } 569 return nil 570 } 571 572 func writeChecksum(offset int, out *buf.ByteBuf, body, payload []byte) error { 573 checksum := acquireChecksum() 574 defer releaseChecksum(checksum) 575 576 _, err := checksum.Write(body) 577 if err != nil { 578 return err 579 } 580 if len(payload) > 0 { 581 _, err = checksum.Write(payload) 582 if err != nil { 583 return err 584 } 585 } 586 writeUint64At(offset, out, checksum.Sum64()) 587 return nil 588 } 589 590 func getDecodeData(in *buf.ByteBuf) []byte { 591 return in.RawSlice(in.GetReadIndex(), in.GetMarkIndex()) 592 } 593 594 func (c *baseCodec) readFlag(msg *RPCMessage, data []byte, offset int) (byte, int) { 595 flag := data[offset] 596 if flag&flagPing != 0 { 597 msg.Message = &flagOnlyMessage{flag: flagPing} 598 msg.internal = true 599 } else if flag&flagPong != 0 { 600 msg.Message = &flagOnlyMessage{flag: flagPong} 601 msg.internal = true 602 } else { 603 msg.Message = c.messageFactory() 604 } 605 return flag, 1 606 } 607 608 func readChecksum(flag byte, data []byte, offset int) (uint64, int) { 609 if flag&flagChecksumEnabled == 0 { 610 return 0, 0 611 } 612 613 return buf.Byte2Uint64(data[offset:]), checksumFieldBytes 614 } 615 616 func readPayloadSize(flag byte, data []byte, offset int) (int, int) { 617 if flag&flagHashPayload == 0 { 618 return 0, 0 619 } 620 621 return buf.Byte2Int(data[offset:]), payloadSizeFieldBytes 622 } 623 624 func readStreaming(flag byte, msg *RPCMessage, data []byte, offset int) int { 625 if flag&flagStreamingMessage == 0 { 626 return 0 627 } 628 msg.stream = true 629 msg.streamSequence = buf.Byte2Uint32(data[offset:]) 630 return 4 631 } 632 633 func validChecksum(body, payload []byte, expectChecksum uint64) error { 634 checksum := acquireChecksum() 635 defer releaseChecksum(checksum) 636 637 _, err := checksum.Write(body) 638 if err != nil { 639 return err 640 } 641 if len(payload) > 0 { 642 _, err := checksum.Write(payload) 643 if err != nil { 644 return err 645 } 646 } 647 actulChecksum := checksum.Sum64() 648 if actulChecksum != expectChecksum { 649 return moerr.NewInternalErrorNoCtx("checksum mismatch, expect %d, got %d", 650 expectChecksum, 651 actulChecksum) 652 } 653 return nil 654 } 655 656 func setWriterIndexAfterGow(out *buf.ByteBuf, n int) (int, int) { 657 offset := out.Readable() 658 out.Grow(n) 659 out.SetWriteIndex(out.GetReadIndex() + offset + n) 660 return out.GetReadIndex() + offset, offset 661 } 662 663 func clone(src []byte) []byte { 664 v := make([]byte, len(src)) 665 copy(v, src) 666 return v 667 }