github.com/matrixorigin/matrixone@v0.7.0/pkg/common/morpc/codec.go (about) 1 // Copyright 2021 - 2022 Matrix Origin 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package morpc 16 17 import ( 18 "io" 19 "sync" 20 21 "github.com/cespare/xxhash/v2" 22 "github.com/fagongzi/goetty/v2/buf" 23 "github.com/fagongzi/goetty/v2/codec" 24 "github.com/fagongzi/goetty/v2/codec/length" 25 "github.com/matrixorigin/matrixone/pkg/common/moerr" 26 "github.com/matrixorigin/matrixone/pkg/common/mpool" 27 "github.com/matrixorigin/matrixone/pkg/txn/clock" 28 "github.com/pierrec/lz4/v4" 29 ) 30 31 const ( 32 flagHashPayload byte = 1 << iota 33 flagChecksumEnabled 34 flagHasCustomHeader 35 flagCompressEnabled 36 flagStreamingMessage 37 flagPing 38 flagPong 39 ) 40 41 var ( 42 defaultMaxBodyMessageSize = 1024 * 1024 * 100 43 checksumFieldBytes = 8 44 totalSizeFieldBytes = 4 45 payloadSizeFieldBytes = 4 46 47 approximateHeaderSize = 1024 * 1024 * 10 48 ) 49 50 func GetMessageSize() int { 51 return defaultMaxBodyMessageSize 52 } 53 54 // WithCodecEnableChecksum enable checksum 55 func WithCodecEnableChecksum() CodecOption { 56 return func(c *messageCodec) { 57 c.bc.checksumEnabled = true 58 } 59 } 60 61 // WithCodecPayloadCopyBufferSize set payload copy buffer size, if is a PayloadMessage 62 func WithCodecPayloadCopyBufferSize(value int) CodecOption { 63 return func(c *messageCodec) { 64 c.bc.payloadBufSize = value 65 } 66 } 67 68 // WithCodecIntegrationHLC intrgration hlc 69 func WithCodecIntegrationHLC(clock clock.Clock) CodecOption { 70 return func(c *messageCodec) { 71 c.AddHeaderCodec(&hlcCodec{clock: clock}) 72 } 73 } 74 75 // WithCodecMaxBodySize set rpc max body size 76 func WithCodecMaxBodySize(size int) CodecOption { 77 return func(c *messageCodec) { 78 if size == 0 { 79 size = defaultMaxBodyMessageSize 80 } 81 c.codec = length.NewWithSize(c.bc, 0, 0, 0, size+approximateHeaderSize) 82 c.bc.maxBodySize = size 83 } 84 } 85 86 // WithCodecEnableCompress enable compress body and payload 87 func WithCodecEnableCompress(pool *mpool.MPool) CodecOption { 88 return func(c *messageCodec) { 89 c.bc.compressEnabled = true 90 c.bc.pool = pool 91 } 92 } 93 94 type messageCodec struct { 95 codec codec.Codec 96 bc *baseCodec 97 } 98 99 // NewMessageCodec create message codec. The message encoding format consists of a message header and a message body. 100 // Format: 101 // 1. Size, 4 bytes, required. Inlucde header and body. 102 // 2. Message header 103 // 2.1. Flag, 1 byte, required. 104 // 2.2. Checksum, 8 byte, optional. Set if has a checksun flag 105 // 2.3. PayloadSize, 4 byte, optional. Set if the message is a morpc.PayloadMessage. 106 // 2.4. Streaming sequence, 4 byte, optional. Set if the message is in a streaming. 107 // 2.5. Custom headers, optional. Set if has custom header codecs 108 // 3. Message body 109 // 3.1. message body, required. 110 // 3.2. payload, optional. Set if has paylad flag. 111 func NewMessageCodec(messageFactory func() Message, options ...CodecOption) Codec { 112 bc := &baseCodec{ 113 messageFactory: messageFactory, 114 maxBodySize: defaultMaxBodyMessageSize, 115 } 116 c := &messageCodec{ 117 codec: length.NewWithSize(bc, 0, 0, 0, defaultMaxBodyMessageSize+approximateHeaderSize), 118 bc: bc, 119 } 120 c.AddHeaderCodec(&deadlineContextCodec{}) 121 c.AddHeaderCodec(&traceCodec{}) 122 123 for _, opt := range options { 124 opt(c) 125 } 126 return c 127 } 128 129 func (c *messageCodec) Decode(in *buf.ByteBuf) (any, bool, error) { 130 return c.codec.Decode(in) 131 } 132 133 func (c *messageCodec) Encode(data interface{}, out *buf.ByteBuf, conn io.Writer) error { 134 return c.bc.Encode(data, out, conn) 135 } 136 137 func (c *messageCodec) Valid(msg Message) error { 138 n := msg.Size() 139 if n >= c.bc.maxBodySize { 140 return moerr.NewInternalErrorNoCtx("message body %d is too large, max is %d", 141 n, 142 c.bc.maxBodySize) 143 } 144 return nil 145 } 146 147 func (c *messageCodec) AddHeaderCodec(hc HeaderCodec) { 148 c.bc.headerCodecs = append(c.bc.headerCodecs, hc) 149 } 150 151 type baseCodec struct { 152 pool *mpool.MPool 153 checksumEnabled bool 154 compressEnabled bool 155 payloadBufSize int 156 maxBodySize int 157 messageFactory func() Message 158 headerCodecs []HeaderCodec 159 } 160 161 func (c *baseCodec) Decode(in *buf.ByteBuf) (any, bool, error) { 162 msg := RPCMessage{} 163 offset := 0 164 data := getDecodeData(in) 165 166 // 2.1 167 flag, n := c.readFlag(&msg, data, offset) 168 offset += n 169 170 // 2.2 171 expectChecksum, n := readChecksum(flag, data, offset) 172 offset += n 173 174 // 2.3 175 payloadSize, n := readPayloadSize(flag, data, offset) 176 offset += n 177 178 // 2.4 179 n, err := c.readCustomHeaders(flag, &msg, data, offset) 180 if err != nil { 181 return nil, false, err 182 } 183 offset += n 184 185 // 2.5 186 offset += readStreaming(flag, &msg, data, offset) 187 188 // 3.1 and 3.2 189 if err := c.readMessage(flag, data, offset, expectChecksum, payloadSize, &msg); err != nil { 190 return nil, false, err 191 } 192 193 in.SetReadIndex(in.GetMarkIndex()) 194 in.ClearMark() 195 return msg, true, nil 196 } 197 198 func (c *baseCodec) Encode(data interface{}, out *buf.ByteBuf, conn io.Writer) error { 199 msg, ok := data.(RPCMessage) 200 if !ok { 201 return moerr.NewInternalErrorNoCtx("not support %T %+v", data, data) 202 } 203 204 startWriteOffset := out.GetWriteOffset() 205 totalSize := 0 206 // The total message size cannot be determined at the beginning and needs to wait until all the 207 // dynamic content is determined before the total size can be determined. After the total size is 208 // determined, we need to write the total size data in the location of totalSizeAt 209 totalSizeAt := skip(totalSizeFieldBytes, out) 210 211 // 2.1 flag 212 flag := c.getFlag(msg) 213 out.MustWriteByte(flag) 214 totalSize += 1 215 216 // 2.2 checksum, similar to totalSize, we do not currently know the size of the message body. 217 checksumAt := -1 218 if flag&flagChecksumEnabled != 0 { 219 checksumAt = skip(checksumFieldBytes, out) 220 totalSize += checksumFieldBytes 221 } 222 223 // 2.3 payload 224 var payloadData []byte 225 var compressedPayloadData []byte 226 var payloadMsg PayloadMessage 227 var hasPayload bool 228 229 // skip all written data by this message 230 discardWritten := func() { 231 out.SetWriteIndexByOffset(startWriteOffset) 232 if hasPayload { 233 payloadMsg.SetPayloadField(payloadData) 234 } 235 } 236 237 if payloadMsg, hasPayload = msg.Message.(PayloadMessage); hasPayload { 238 // set payload filed to nil to avoid payload being written to the out buffer, and write directly 239 // to the socket afterwards to reduce one payload. 240 payloadData = payloadMsg.GetPayloadField() 241 payloadMsg.SetPayloadField(nil) 242 compressedPayloadData = payloadData 243 244 if c.compressEnabled && len(payloadData) > 0 { 245 v, err := c.compress(payloadData) 246 if err != nil { 247 discardWritten() 248 return err 249 } 250 defer c.pool.Free(v) 251 compressedPayloadData = v 252 } 253 254 out.WriteInt(len(compressedPayloadData)) 255 totalSize += payloadSizeFieldBytes + len(compressedPayloadData) 256 } 257 258 // 2.4 Custom header size 259 n, err := c.encodeCustomHeaders(&msg, out) 260 if err != nil { 261 return err 262 } 263 totalSize += n 264 265 // 2.5 streaming message 266 if msg.stream { 267 out.WriteUint32(msg.streamSequence) 268 totalSize += 4 269 } 270 271 // 3.1 message body 272 body, err := c.writeBody(out, msg.Message) 273 if err != nil { 274 discardWritten() 275 return err 276 } 277 278 // now, header and body are all determined, we need to fill the totalSize and checksum 279 // fill total size 280 totalSize += len(body) 281 writeIntAt(totalSizeAt, out, totalSize) 282 283 // fill checksum 284 if checksumAt != -1 { 285 if err := writeChecksum(checksumAt, out, body, compressedPayloadData); err != nil { 286 discardWritten() 287 return err 288 } 289 } 290 291 // 3.2 payload 292 if hasPayload { 293 // resume payload to payload message 294 payloadMsg.SetPayloadField(payloadData) 295 if err := writePayload(out, compressedPayloadData, conn, c.payloadBufSize); err != nil { 296 return err 297 } 298 } 299 300 return nil 301 } 302 303 func (c *baseCodec) compress(src []byte) ([]byte, error) { 304 n := lz4.CompressBlockBound(len(src)) 305 dst, err := c.pool.Alloc(n) 306 if err != nil { 307 return nil, err 308 } 309 dst, err = c.compressTo(src, dst) 310 if err != nil { 311 c.pool.Free(dst) 312 return nil, err 313 } 314 return dst, nil 315 } 316 317 func (c *baseCodec) uncompress(src []byte) ([]byte, error) { 318 // The lz4 library requires a []byte with a large enough dst when 319 // decompressing, otherwise it will return an ErrInvalidSourceShortBuffer, we 320 // can't confirm how large a dst we need to give initially, so when we encounter 321 // an ErrInvalidSourceShortBuffer, we expand the size and retry. 322 n := len(src) * 2 323 for { 324 dst, err := c.pool.Alloc(n) 325 if err != nil { 326 return nil, err 327 } 328 dst, err = uncompress(src, dst) 329 if err == nil { 330 return dst, nil 331 } 332 333 c.pool.Free(dst) 334 if err != lz4.ErrInvalidSourceShortBuffer { 335 return nil, err 336 } 337 n *= 2 338 } 339 } 340 341 func (c *baseCodec) compressTo(src, dst []byte) ([]byte, error) { 342 dst, err := compress(src, dst) 343 if err != nil { 344 return nil, err 345 } 346 return dst, nil 347 } 348 349 func (c *baseCodec) compressBound(size int) int { 350 return lz4.CompressBlockBound(size) 351 } 352 353 func (c *baseCodec) getFlag(msg RPCMessage) byte { 354 flag := byte(0) 355 if c.checksumEnabled { 356 flag |= flagChecksumEnabled 357 } 358 if c.compressEnabled { 359 flag |= flagCompressEnabled 360 } 361 if len(c.headerCodecs) > 0 { 362 flag |= flagHasCustomHeader 363 } 364 if _, ok := msg.Message.(PayloadMessage); ok { 365 flag |= flagHashPayload 366 } 367 if msg.stream { 368 flag |= flagStreamingMessage 369 } 370 if msg.internal { 371 if m, ok := msg.Message.(*flagOnlyMessage); ok { 372 flag |= m.flag 373 } 374 } 375 return flag 376 } 377 378 func (c *baseCodec) encodeCustomHeaders(msg *RPCMessage, out *buf.ByteBuf) (int, error) { 379 if len(c.headerCodecs) == 0 { 380 return 0, nil 381 } 382 383 size := 0 384 for _, hc := range c.headerCodecs { 385 v, err := hc.Encode(msg, out) 386 if err != nil { 387 return 0, err 388 } 389 size += v 390 } 391 return size, nil 392 } 393 394 func (c *baseCodec) readCustomHeaders(flag byte, msg *RPCMessage, data []byte, offset int) (int, error) { 395 if flag&flagHasCustomHeader == 0 { 396 return 0, nil 397 } 398 399 readed := 0 400 for _, hc := range c.headerCodecs { 401 n, err := hc.Decode(msg, data[offset+readed:]) 402 if err != nil { 403 return 0, err 404 } 405 readed += n 406 } 407 return readed, nil 408 } 409 410 func (c *baseCodec) writeBody( 411 out *buf.ByteBuf, 412 msg Message) ([]byte, error) { 413 size := msg.Size() 414 if size == 0 { 415 return nil, nil 416 } 417 if !c.compressEnabled { 418 index, _ := setWriterIndexAfterGow(out, size) 419 data := out.RawSlice(index, index+size) 420 if _, err := msg.MarshalTo(data); err != nil { 421 return nil, err 422 } 423 return data, nil 424 } 425 426 // we use mpool to compress body, then write the dst into the buffer 427 origin, err := c.pool.Alloc(size) 428 if err != nil { 429 return nil, err 430 } 431 defer c.pool.Free(origin) 432 if _, err := msg.MarshalTo(origin); err != nil { 433 return nil, err 434 } 435 436 n := c.compressBound(len(origin)) 437 dst, err := c.pool.Alloc(n) 438 if err != nil { 439 return nil, err 440 } 441 defer c.pool.Free(dst) 442 443 dst, err = compress(origin, dst) 444 if err != nil { 445 return nil, err 446 } 447 448 index := out.GetWriteOffset() 449 out.MustWrite(dst) 450 return out.RawSlice(out.GetReadIndex()+index, out.GetWriteIndex()), nil 451 } 452 453 func (c *baseCodec) readMessage(flag byte, data []byte, offset int, expectChecksum uint64, payloadSize int, msg *RPCMessage) error { 454 if offset == len(data) { 455 return nil 456 } 457 458 body := data[offset : len(data)-payloadSize] 459 payload := data[len(data)-payloadSize:] 460 if flag&flagChecksumEnabled != 0 { 461 if err := validChecksum(body, payload, expectChecksum); err != nil { 462 return err 463 } 464 } 465 466 if flag&flagCompressEnabled != 0 { 467 dstBody, err := c.uncompress(body) 468 if err != nil { 469 return err 470 } 471 defer c.pool.Free(dstBody) 472 body = dstBody 473 474 if payloadSize > 0 { 475 dstPayload, err := c.uncompress(payload) 476 if err != nil { 477 return err 478 } 479 defer c.pool.Free(dstPayload) 480 payload = dstPayload 481 } 482 } 483 484 if err := msg.Message.Unmarshal(body); err != nil { 485 return err 486 } 487 488 if payloadSize > 0 { 489 msg.Message.(PayloadMessage).SetPayloadField(payload) 490 } 491 return nil 492 } 493 494 var ( 495 checksumPool = sync.Pool{ 496 New: func() any { 497 return xxhash.New() 498 }, 499 } 500 ) 501 502 func acquireChecksum() *xxhash.Digest { 503 return checksumPool.Get().(*xxhash.Digest) 504 } 505 506 func releaseChecksum(checksum *xxhash.Digest) { 507 checksum.Reset() 508 checksumPool.Put(checksum) 509 } 510 511 func skip(n int, out *buf.ByteBuf) int { 512 _, offset := setWriterIndexAfterGow(out, n) 513 return offset 514 } 515 516 func writeIntAt(offset int, out *buf.ByteBuf, value int) { 517 idx := out.GetReadIndex() + offset 518 buf.Int2BytesTo(value, out.RawSlice(idx, idx+4)) 519 } 520 521 func writeUint64At(offset int, out *buf.ByteBuf, value uint64) { 522 idx := out.GetReadIndex() + offset 523 buf.Uint64ToBytesTo(value, out.RawSlice(idx, idx+8)) 524 } 525 526 func writePayload(out *buf.ByteBuf, payload []byte, conn io.Writer, copyBuffer int) error { 527 if len(payload) == 0 { 528 return nil 529 } 530 531 // reset here to avoid buffer expansion as much as possible 532 defer out.Reset() 533 534 // first, write header and body to socket 535 if _, err := out.WriteTo(conn); err != nil { 536 return err 537 } 538 539 // write payload to socket 540 if err := buf.WriteTo(payload, conn, copyBuffer); err != nil { 541 return err 542 } 543 return nil 544 } 545 546 func writeChecksum(offset int, out *buf.ByteBuf, body, payload []byte) error { 547 checksum := acquireChecksum() 548 defer releaseChecksum(checksum) 549 550 _, err := checksum.Write(body) 551 if err != nil { 552 return err 553 } 554 if len(payload) > 0 { 555 _, err = checksum.Write(payload) 556 if err != nil { 557 return err 558 } 559 } 560 writeUint64At(offset, out, checksum.Sum64()) 561 return nil 562 } 563 564 func getDecodeData(in *buf.ByteBuf) []byte { 565 return in.RawSlice(in.GetReadIndex(), in.GetMarkIndex()) 566 } 567 568 func (c *baseCodec) readFlag(msg *RPCMessage, data []byte, offset int) (byte, int) { 569 flag := data[offset] 570 if flag&flagPing != 0 { 571 msg.Message = &flagOnlyMessage{flag: flagPing} 572 msg.internal = true 573 } else if flag&flagPong != 0 { 574 msg.Message = &flagOnlyMessage{flag: flagPong} 575 msg.internal = true 576 } else { 577 msg.Message = c.messageFactory() 578 } 579 return flag, 1 580 } 581 582 func readChecksum(flag byte, data []byte, offset int) (uint64, int) { 583 if flag&flagChecksumEnabled == 0 { 584 return 0, 0 585 } 586 587 return buf.Byte2Uint64(data[offset:]), checksumFieldBytes 588 } 589 590 func readPayloadSize(flag byte, data []byte, offset int) (int, int) { 591 if flag&flagHashPayload == 0 { 592 return 0, 0 593 } 594 595 return buf.Byte2Int(data[offset:]), payloadSizeFieldBytes 596 } 597 598 func readStreaming(flag byte, msg *RPCMessage, data []byte, offset int) int { 599 if flag&flagStreamingMessage == 0 { 600 return 0 601 } 602 msg.stream = true 603 msg.streamSequence = buf.Byte2Uint32(data[offset:]) 604 return 4 605 } 606 607 func validChecksum(body, payload []byte, expectChecksum uint64) error { 608 checksum := acquireChecksum() 609 defer releaseChecksum(checksum) 610 611 _, err := checksum.Write(body) 612 if err != nil { 613 return err 614 } 615 if len(payload) > 0 { 616 _, err := checksum.Write(payload) 617 if err != nil { 618 return err 619 } 620 } 621 actulChecksum := checksum.Sum64() 622 if actulChecksum != expectChecksum { 623 return moerr.NewInternalErrorNoCtx("checksum mismatch, expect %d, got %d", 624 expectChecksum, 625 actulChecksum) 626 } 627 return nil 628 } 629 630 func setWriterIndexAfterGow(out *buf.ByteBuf, n int) (int, int) { 631 offset := out.Readable() 632 out.Grow(n) 633 out.SetWriteIndex(out.GetReadIndex() + offset + n) 634 return out.GetReadIndex() + offset, offset 635 }