github.com/matrixorigin/matrixone@v1.2.0/pkg/common/morpc/codec.go (about)

     1  // Copyright 2021 - 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package morpc
    16  
    17  import (
    18  	"encoding/hex"
    19  	"io"
    20  	"sync"
    21  
    22  	"github.com/cespare/xxhash/v2"
    23  	"github.com/fagongzi/goetty/v2/buf"
    24  	"github.com/fagongzi/goetty/v2/codec"
    25  	"github.com/fagongzi/goetty/v2/codec/length"
    26  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    27  	"github.com/matrixorigin/matrixone/pkg/common/mpool"
    28  	"github.com/matrixorigin/matrixone/pkg/txn/clock"
    29  	"github.com/pierrec/lz4/v4"
    30  	"go.uber.org/zap"
    31  )
    32  
// Flag bits stored in the one-byte flag field of the message header. Every
// constant occupies its own bit, so several flags can be OR-ed into one byte.
const (
	// flagHashPayload is set when the message implements PayloadMessage; the
	// payload size field is only read when this bit is present.
	flagHashPayload byte = 1 << iota
	// flagChecksumEnabled marks that the checksum field follows the flag byte.
	flagChecksumEnabled
	// flagHasCustomHeader marks that data written by the header codecs is present.
	flagHasCustomHeader
	// flagCompressEnabled marks that body and payload have to be uncompressed
	// when the message is read.
	flagCompressEnabled
	// flagStreamingMessage marks that a 4-byte stream sequence is present.
	flagStreamingMessage
	// flagPing marks an internal ping; it is decoded into a flagOnlyMessage.
	flagPing
	// flagPong marks an internal pong; it is decoded into a flagOnlyMessage.
	flagPong
)
    42  
// Sizes, in bytes, used by the codec.
var (
	// defaultMaxBodyMessageSize (100 MiB) is the body limit used when no other
	// limit is configured.
	defaultMaxBodyMessageSize = 1024 * 1024 * 100
	// checksumFieldBytes is the width of the checksum header field.
	checksumFieldBytes        = 8
	// totalSizeFieldBytes is the width of the leading total-size field.
	totalSizeFieldBytes       = 4
	// payloadSizeFieldBytes is the width of the payload-size header field.
	payloadSizeFieldBytes     = 4

	// approximateHeaderSize (10 MiB) is added on top of the max body size when
	// the length-based frame codec is created.
	approximateHeaderSize = 1024 * 1024 * 10
)
    51  
// GetMessageSize returns the default maximum message body size in bytes.
func GetMessageSize() int {
	return defaultMaxBodyMessageSize
}
    55  
    56  // WithCodecEnableChecksum enable checksum
    57  func WithCodecEnableChecksum() CodecOption {
    58  	return func(c *messageCodec) {
    59  		c.bc.checksumEnabled = true
    60  	}
    61  }
    62  
    63  // WithCodecPayloadCopyBufferSize set payload copy buffer size, if is a PayloadMessage
    64  func WithCodecPayloadCopyBufferSize(value int) CodecOption {
    65  	return func(c *messageCodec) {
    66  		c.bc.payloadBufSize = value
    67  	}
    68  }
    69  
    70  // WithCodecIntegrationHLC intrgration hlc
    71  func WithCodecIntegrationHLC(clock clock.Clock) CodecOption {
    72  	return func(c *messageCodec) {
    73  		c.AddHeaderCodec(&hlcCodec{clock: clock})
    74  	}
    75  }
    76  
    77  // WithCodecMaxBodySize set rpc max body size
    78  func WithCodecMaxBodySize(size int) CodecOption {
    79  	return func(c *messageCodec) {
    80  		if size == 0 {
    81  			size = defaultMaxBodyMessageSize
    82  		}
    83  		c.codec = length.NewWithSize(c.bc, 0, 0, 0, size+approximateHeaderSize)
    84  		c.bc.maxBodySize = size
    85  	}
    86  }
    87  
    88  // WithCodecEnableCompress enable compress body and payload
    89  func WithCodecEnableCompress(pool *mpool.MPool) CodecOption {
    90  	return func(c *messageCodec) {
    91  		c.bc.compressEnabled = true
    92  		c.bc.pool = pool
    93  	}
    94  }
    95  
// messageCodec is the Codec returned by NewMessageCodec. Decoding goes through
// the length-based frame codec, encoding goes directly to the base codec.
type messageCodec struct {
	// codec is the length-based codec wrapping bc; it is used for Decode.
	codec codec.Codec
	// bc does the actual message encoding/decoding and holds all options.
	bc    *baseCodec
}
   100  
// NewMessageCodec creates a message codec. The message encoding format consists of a message header and a message body.
// Format:
//  1. Size, 4 bytes, required. Includes header and body.
//  2. Message header
//     2.1. Flag, 1 byte, required.
//     2.2. Checksum, 8 bytes, optional. Set if the checksum flag is set.
//     2.3. PayloadSize, 4 bytes, optional. Set if the message is a morpc.PayloadMessage.
//     2.4. Custom headers, optional. Set if there are custom header codecs.
//     2.5. Streaming sequence, 4 bytes, optional. Set if the message is in a stream.
//  3. Message body
//     3.1. message body, required.
//     3.2. payload, optional. Set if the payload flag is set.
   113  func NewMessageCodec(messageFactory func() Message, options ...CodecOption) Codec {
   114  	bc := &baseCodec{
   115  		messageFactory: messageFactory,
   116  		maxBodySize:    defaultMaxBodyMessageSize,
   117  	}
   118  	c := &messageCodec{
   119  		codec: length.NewWithSize(bc, 0, 0, 0, defaultMaxBodyMessageSize+approximateHeaderSize),
   120  		bc:    bc,
   121  	}
   122  	c.AddHeaderCodec(&deadlineContextCodec{})
   123  	c.AddHeaderCodec(&traceCodec{})
   124  
   125  	for _, opt := range options {
   126  		opt(c)
   127  	}
   128  	return c
   129  }
   130  
// Decode reads one message from in by delegating to the length-based frame
// codec and returns that codec's results unchanged.
func (c *messageCodec) Decode(in *buf.ByteBuf) (any, bool, error) {
	return c.codec.Decode(in)
}
   134  
   135  func (c *messageCodec) Encode(data interface{}, out *buf.ByteBuf, conn io.Writer) error {
   136  	return c.bc.Encode(data, out, conn)
   137  }
   138  
   139  func (c *messageCodec) Valid(msg Message) error {
   140  	n := msg.Size()
   141  	if n >= c.bc.maxBodySize {
   142  		return moerr.NewInternalErrorNoCtx("message body %d is too large, max is %d",
   143  			n,
   144  			c.bc.maxBodySize)
   145  	}
   146  	return nil
   147  }
   148  
   149  func (c *messageCodec) AddHeaderCodec(hc HeaderCodec) {
   150  	c.bc.headerCodecs = append(c.bc.headerCodecs, hc)
   151  }
   152  
// baseCodec encodes and decodes RPCMessage values and holds the codec options.
type baseCodec struct {
	// pool supplies the buffers used for compressing and uncompressing.
	pool            *mpool.MPool
	// checksumEnabled makes Encode set the checksum flag and write a checksum.
	checksumEnabled bool
	// compressEnabled makes Encode compress body and payload.
	compressEnabled bool
	// payloadBufSize is the copy buffer size handed to writePayload.
	payloadBufSize  int
	// maxBodySize is the body size limit checked by messageCodec.Valid.
	maxBodySize     int
	// messageFactory creates the message a decoded non-ping/pong frame is
	// unmarshalled into.
	messageFactory  func() Message
	// headerCodecs are applied in slice order when encoding and decoding.
	headerCodecs    []HeaderCodec
}
   162  
   163  func (c *baseCodec) Decode(in *buf.ByteBuf) (any, bool, error) {
   164  	msg := RPCMessage{}
   165  	offset := 0
   166  	data := getDecodeData(in)
   167  
   168  	// 2.1
   169  	flag, n := c.readFlag(&msg, data, offset)
   170  	offset += n
   171  
   172  	// 2.2
   173  	expectChecksum, n := readChecksum(flag, data, offset)
   174  	offset += n
   175  
   176  	// 2.3
   177  	payloadSize, n := readPayloadSize(flag, data, offset)
   178  	offset += n
   179  
   180  	// 2.4
   181  	n, err := c.readCustomHeaders(flag, &msg, data, offset)
   182  	if err != nil {
   183  		return nil, false, err
   184  	}
   185  	offset += n
   186  
   187  	// 2.5
   188  	offset += readStreaming(flag, &msg, data, offset)
   189  
   190  	// 3.1 and 3.2
   191  	if err := c.readMessage(flag, data, offset, expectChecksum, payloadSize, &msg); err != nil {
   192  		return nil, false, err
   193  	}
   194  
   195  	in.SetReadIndex(in.GetMarkIndex())
   196  	in.ClearMark()
   197  	return msg, true, nil
   198  }
   199  
   200  func (c *baseCodec) Encode(data interface{}, out *buf.ByteBuf, conn io.Writer) error {
   201  	msg, ok := data.(RPCMessage)
   202  	if !ok {
   203  		return moerr.NewInternalErrorNoCtx("not support %T %+v", data, data)
   204  	}
   205  
   206  	startWriteOffset := out.GetWriteOffset()
   207  	totalSize := 0
   208  	// The total message size cannot be determined at the beginning and needs to wait until all the
   209  	// dynamic content is determined before the total size can be determined. After the total size is
   210  	// determined, we need to write the total size data in the location of totalSizeAt
   211  	totalSizeAt := skip(totalSizeFieldBytes, out)
   212  
   213  	// 2.1 flag
   214  	flag := c.getFlag(msg)
   215  	out.MustWriteByte(flag)
   216  	totalSize += 1
   217  
   218  	// 2.2 checksum, similar to totalSize, we do not currently know the size of the message body.
   219  	checksumAt := -1
   220  	if flag&flagChecksumEnabled != 0 {
   221  		checksumAt = skip(checksumFieldBytes, out)
   222  		totalSize += checksumFieldBytes
   223  	}
   224  
   225  	// 2.3 payload
   226  	var payloadData []byte
   227  	var compressedPayloadData []byte
   228  	var payloadMsg PayloadMessage
   229  	var hasPayload bool
   230  
   231  	// skip all written data by this message
   232  	discardWritten := func() {
   233  		out.SetWriteIndexByOffset(startWriteOffset)
   234  		if hasPayload {
   235  			payloadMsg.SetPayloadField(payloadData)
   236  		}
   237  	}
   238  
   239  	if payloadMsg, hasPayload = msg.Message.(PayloadMessage); hasPayload {
   240  		// set payload field to nil to avoid payload being written to the out buffer, and write directly
   241  		// to the socket afterwards to reduce one payload.
   242  		payloadData = payloadMsg.GetPayloadField()
   243  		payloadMsg.SetPayloadField(nil)
   244  		compressedPayloadData = payloadData
   245  
   246  		if c.compressEnabled && len(payloadData) > 0 {
   247  			v, err := c.compress(payloadData)
   248  			if err != nil {
   249  				discardWritten()
   250  				return err
   251  			}
   252  			defer c.pool.Free(v)
   253  			compressedPayloadData = v
   254  		}
   255  
   256  		out.WriteInt(len(compressedPayloadData))
   257  		totalSize += payloadSizeFieldBytes + len(compressedPayloadData)
   258  	}
   259  
   260  	// 2.4 Custom header size
   261  	n, err := c.encodeCustomHeaders(&msg, out)
   262  	if err != nil {
   263  		return err
   264  	}
   265  	totalSize += n
   266  
   267  	// 2.5 streaming message
   268  	if msg.stream {
   269  		out.WriteUint32(msg.streamSequence)
   270  		totalSize += 4
   271  	}
   272  
   273  	// 3.1 message body
   274  	body, err := c.writeBody(out, msg.Message)
   275  	if err != nil {
   276  		discardWritten()
   277  		return err
   278  	}
   279  
   280  	// now, header and body are all determined, we need to fill the totalSize and checksum
   281  	// fill total size
   282  	totalSize += len(body)
   283  	writeIntAt(totalSizeAt, out, totalSize)
   284  
   285  	// fill checksum
   286  	if checksumAt != -1 {
   287  		if err := writeChecksum(checksumAt, out, body, compressedPayloadData); err != nil {
   288  			discardWritten()
   289  			return err
   290  		}
   291  	}
   292  
   293  	// 3.2 payload
   294  	if hasPayload {
   295  		// resume payload to payload message
   296  		payloadMsg.SetPayloadField(payloadData)
   297  		if err := writePayload(out, compressedPayloadData, conn, c.payloadBufSize); err != nil {
   298  			return err
   299  		}
   300  	}
   301  
   302  	return nil
   303  }
   304  
   305  func (c *baseCodec) compress(src []byte) ([]byte, error) {
   306  	n := lz4.CompressBlockBound(len(src))
   307  	dst, err := c.pool.Alloc(n)
   308  	if err != nil {
   309  		return nil, err
   310  	}
   311  	dst, err = c.compressTo(src, dst)
   312  	if err != nil {
   313  		c.pool.Free(dst)
   314  		return nil, err
   315  	}
   316  	return dst, nil
   317  }
   318  
   319  func (c *baseCodec) uncompress(src []byte) ([]byte, error) {
   320  	// The lz4 library requires a []byte with a large enough dst when
   321  	// decompressing, otherwise it will return an ErrInvalidSourceShortBuffer, we
   322  	// can't confirm how large a dst we need to give initially, so when we encounter
   323  	// an ErrInvalidSourceShortBuffer, we expand the size and retry.
   324  	n := len(src) * 2
   325  	for {
   326  		dst, err := c.pool.Alloc(n)
   327  		if err != nil {
   328  			return nil, err
   329  		}
   330  		// the following line of code needs to free dst if it returns an error, but dst
   331  		// has been reassigned
   332  		old := dst
   333  		dst, err = uncompress(src, dst)
   334  		if err == nil {
   335  			return dst, nil
   336  		}
   337  
   338  		c.pool.Free(old)
   339  		if err != lz4.ErrInvalidSourceShortBuffer {
   340  			return nil, err
   341  		}
   342  		n *= 2
   343  	}
   344  }
   345  
   346  func (c *baseCodec) compressTo(src, dst []byte) ([]byte, error) {
   347  	dst, err := compress(src, dst)
   348  	if err != nil {
   349  		return nil, err
   350  	}
   351  	return dst, nil
   352  }
   353  
// compressBound returns the lz4 upper bound for the compressed size of an
// input of the given size; it is used to size compression buffers.
func (c *baseCodec) compressBound(size int) int {
	return lz4.CompressBlockBound(size)
}
   357  
   358  func (c *baseCodec) getFlag(msg RPCMessage) byte {
   359  	flag := byte(0)
   360  	if c.checksumEnabled {
   361  		flag |= flagChecksumEnabled
   362  	}
   363  	if c.compressEnabled {
   364  		flag |= flagCompressEnabled
   365  	}
   366  	if len(c.headerCodecs) > 0 {
   367  		flag |= flagHasCustomHeader
   368  	}
   369  	if _, ok := msg.Message.(PayloadMessage); ok {
   370  		flag |= flagHashPayload
   371  	}
   372  	if msg.stream {
   373  		flag |= flagStreamingMessage
   374  	}
   375  	if msg.internal {
   376  		if m, ok := msg.Message.(*flagOnlyMessage); ok {
   377  			flag |= m.flag
   378  		}
   379  	}
   380  	return flag
   381  }
   382  
   383  func (c *baseCodec) encodeCustomHeaders(msg *RPCMessage, out *buf.ByteBuf) (int, error) {
   384  	if len(c.headerCodecs) == 0 {
   385  		return 0, nil
   386  	}
   387  
   388  	size := 0
   389  	for _, hc := range c.headerCodecs {
   390  		v, err := hc.Encode(msg, out)
   391  		if err != nil {
   392  			return 0, err
   393  		}
   394  		size += v
   395  	}
   396  	return size, nil
   397  }
   398  
   399  func (c *baseCodec) readCustomHeaders(flag byte, msg *RPCMessage, data []byte, offset int) (int, error) {
   400  	if flag&flagHasCustomHeader == 0 {
   401  		return 0, nil
   402  	}
   403  
   404  	readed := 0
   405  	for _, hc := range c.headerCodecs {
   406  		n, err := hc.Decode(msg, data[offset+readed:])
   407  		if err != nil {
   408  			return 0, err
   409  		}
   410  		readed += n
   411  	}
   412  	return readed, nil
   413  }
   414  
   415  func (c *baseCodec) writeBody(
   416  	out *buf.ByteBuf,
   417  	msg Message) ([]byte, error) {
   418  	size := msg.Size()
   419  	if size == 0 {
   420  		return nil, nil
   421  	}
   422  	if !c.compressEnabled {
   423  		index, _ := setWriterIndexAfterGow(out, size)
   424  		data := out.RawSlice(index, index+size)
   425  		if _, err := msg.MarshalTo(data); err != nil {
   426  			return nil, err
   427  		}
   428  		return data, nil
   429  	}
   430  
   431  	// we use mpool to compress body, then write the dst into the buffer
   432  	origin, err := c.pool.Alloc(size)
   433  	if err != nil {
   434  		return nil, err
   435  	}
   436  	defer c.pool.Free(origin)
   437  	if _, err := msg.MarshalTo(origin); err != nil {
   438  		return nil, err
   439  	}
   440  
   441  	n := c.compressBound(len(origin))
   442  	dst, err := c.pool.Alloc(n)
   443  	if err != nil {
   444  		return nil, err
   445  	}
   446  	defer c.pool.Free(dst)
   447  
   448  	dst, err = compress(origin, dst)
   449  	if err != nil {
   450  		return nil, err
   451  	}
   452  
   453  	index := out.GetWriteOffset()
   454  	out.MustWrite(dst)
   455  	return out.RawSlice(out.GetReadIndex()+index, out.GetWriteIndex()), nil
   456  }
   457  
   458  func (c *baseCodec) readMessage(
   459  	flag byte,
   460  	data []byte,
   461  	offset int,
   462  	expectChecksum uint64,
   463  	payloadSize int,
   464  	msg *RPCMessage) error {
   465  	if offset == len(data) {
   466  		return nil
   467  	}
   468  
   469  	// invalid body packet
   470  	if offset >= len(data)-payloadSize {
   471  		getLogger().Warn("invalid body packet",
   472  			zap.Int("offset", offset),
   473  			zap.Int("len", len(data)),
   474  			zap.Int("payloadSize", payloadSize),
   475  			zap.String("data-hex", hex.EncodeToString(data)))
   476  		return moerr.NewInvalidInputNoCtx("invalid body packet, offset %d, len %d, payload size %d",
   477  			offset, len(data), payloadSize)
   478  	}
   479  
   480  	body := data[offset : len(data)-payloadSize]
   481  	payload := data[len(data)-payloadSize:]
   482  	if flag&flagChecksumEnabled != 0 {
   483  		if err := validChecksum(body, payload, expectChecksum); err != nil {
   484  			return err
   485  		}
   486  	}
   487  
   488  	if flag&flagCompressEnabled != 0 {
   489  		dstBody, err := c.uncompress(body)
   490  		if err != nil {
   491  			return err
   492  		}
   493  		defer c.pool.Free(dstBody)
   494  		body = dstBody
   495  
   496  		if payloadSize > 0 {
   497  			dstPayload, err := c.uncompress(payload)
   498  			if err != nil {
   499  				return err
   500  			}
   501  			// we cannot free dstPayload here as it will be set to the result for subsequent
   502  			// modules.
   503  			// TODO: We currently use copy to avoid memory leaks in mpool, but this introduces
   504  			// a memory allocation overhead that will need to be optimized away.
   505  			defer c.pool.Free(dstPayload)
   506  			payload = clone(dstPayload)
   507  		}
   508  	}
   509  
   510  	if err := msg.Message.Unmarshal(body); err != nil {
   511  		return err
   512  	}
   513  
   514  	if payloadSize > 0 {
   515  		msg.Message.(PayloadMessage).SetPayloadField(payload)
   516  	}
   517  	return nil
   518  }
   519  
var (
	// checksumPool reuses xxhash digests across checksum computations; use
	// acquireChecksum and releaseChecksum to get and return a digest.
	checksumPool = sync.Pool{
		New: func() any {
			return xxhash.New()
		},
	}
)
   527  
   528  func acquireChecksum() *xxhash.Digest {
   529  	return checksumPool.Get().(*xxhash.Digest)
   530  }
   531  
// releaseChecksum resets checksum and puts it back into checksumPool, so the
// next acquireChecksum caller gets a clean digest.
func releaseChecksum(checksum *xxhash.Digest) {
	checksum.Reset()
	checksumPool.Put(checksum)
}
   536  
   537  func skip(n int, out *buf.ByteBuf) int {
   538  	_, offset := setWriterIndexAfterGow(out, n)
   539  	return offset
   540  }
   541  
   542  func writeIntAt(offset int, out *buf.ByteBuf, value int) {
   543  	idx := out.GetReadIndex() + offset
   544  	buf.Int2BytesTo(value, out.RawSlice(idx, idx+4))
   545  }
   546  
   547  func writeUint64At(offset int, out *buf.ByteBuf, value uint64) {
   548  	idx := out.GetReadIndex() + offset
   549  	buf.Uint64ToBytesTo(value, out.RawSlice(idx, idx+8))
   550  }
   551  
   552  func writePayload(out *buf.ByteBuf, payload []byte, conn io.Writer, copyBuffer int) error {
   553  	if len(payload) == 0 {
   554  		return nil
   555  	}
   556  
   557  	// reset here to avoid buffer expansion as much as possible
   558  	defer out.Reset()
   559  
   560  	// first, write header and body to socket
   561  	if _, err := out.WriteTo(conn); err != nil {
   562  		return err
   563  	}
   564  
   565  	// write payload to socket
   566  	if err := buf.WriteTo(payload, conn, copyBuffer); err != nil {
   567  		return err
   568  	}
   569  	return nil
   570  }
   571  
   572  func writeChecksum(offset int, out *buf.ByteBuf, body, payload []byte) error {
   573  	checksum := acquireChecksum()
   574  	defer releaseChecksum(checksum)
   575  
   576  	_, err := checksum.Write(body)
   577  	if err != nil {
   578  		return err
   579  	}
   580  	if len(payload) > 0 {
   581  		_, err = checksum.Write(payload)
   582  		if err != nil {
   583  			return err
   584  		}
   585  	}
   586  	writeUint64At(offset, out, checksum.Sum64())
   587  	return nil
   588  }
   589  
   590  func getDecodeData(in *buf.ByteBuf) []byte {
   591  	return in.RawSlice(in.GetReadIndex(), in.GetMarkIndex())
   592  }
   593  
   594  func (c *baseCodec) readFlag(msg *RPCMessage, data []byte, offset int) (byte, int) {
   595  	flag := data[offset]
   596  	if flag&flagPing != 0 {
   597  		msg.Message = &flagOnlyMessage{flag: flagPing}
   598  		msg.internal = true
   599  	} else if flag&flagPong != 0 {
   600  		msg.Message = &flagOnlyMessage{flag: flagPong}
   601  		msg.internal = true
   602  	} else {
   603  		msg.Message = c.messageFactory()
   604  	}
   605  	return flag, 1
   606  }
   607  
   608  func readChecksum(flag byte, data []byte, offset int) (uint64, int) {
   609  	if flag&flagChecksumEnabled == 0 {
   610  		return 0, 0
   611  	}
   612  
   613  	return buf.Byte2Uint64(data[offset:]), checksumFieldBytes
   614  }
   615  
   616  func readPayloadSize(flag byte, data []byte, offset int) (int, int) {
   617  	if flag&flagHashPayload == 0 {
   618  		return 0, 0
   619  	}
   620  
   621  	return buf.Byte2Int(data[offset:]), payloadSizeFieldBytes
   622  }
   623  
   624  func readStreaming(flag byte, msg *RPCMessage, data []byte, offset int) int {
   625  	if flag&flagStreamingMessage == 0 {
   626  		return 0
   627  	}
   628  	msg.stream = true
   629  	msg.streamSequence = buf.Byte2Uint32(data[offset:])
   630  	return 4
   631  }
   632  
   633  func validChecksum(body, payload []byte, expectChecksum uint64) error {
   634  	checksum := acquireChecksum()
   635  	defer releaseChecksum(checksum)
   636  
   637  	_, err := checksum.Write(body)
   638  	if err != nil {
   639  		return err
   640  	}
   641  	if len(payload) > 0 {
   642  		_, err := checksum.Write(payload)
   643  		if err != nil {
   644  			return err
   645  		}
   646  	}
   647  	actulChecksum := checksum.Sum64()
   648  	if actulChecksum != expectChecksum {
   649  		return moerr.NewInternalErrorNoCtx("checksum mismatch, expect %d, got %d",
   650  			expectChecksum,
   651  			actulChecksum)
   652  	}
   653  	return nil
   654  }
   655  
   656  func setWriterIndexAfterGow(out *buf.ByteBuf, n int) (int, int) {
   657  	offset := out.Readable()
   658  	out.Grow(n)
   659  	out.SetWriteIndex(out.GetReadIndex() + offset + n)
   660  	return out.GetReadIndex() + offset, offset
   661  }
   662  
   663  func clone(src []byte) []byte {
   664  	v := make([]byte, len(src))
   665  	copy(v, src)
   666  	return v
   667  }