github.com/matrixorigin/matrixone@v0.7.0/pkg/common/morpc/codec.go (about)

     1  // Copyright 2021 - 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package morpc
    16  
    17  import (
    18  	"io"
    19  	"sync"
    20  
    21  	"github.com/cespare/xxhash/v2"
    22  	"github.com/fagongzi/goetty/v2/buf"
    23  	"github.com/fagongzi/goetty/v2/codec"
    24  	"github.com/fagongzi/goetty/v2/codec/length"
    25  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    26  	"github.com/matrixorigin/matrixone/pkg/common/mpool"
    27  	"github.com/matrixorigin/matrixone/pkg/txn/clock"
    28  	"github.com/pierrec/lz4/v4"
    29  )
    30  
    31  const (
    32  	flagHashPayload byte = 1 << iota
    33  	flagChecksumEnabled
    34  	flagHasCustomHeader
    35  	flagCompressEnabled
    36  	flagStreamingMessage
    37  	flagPing
    38  	flagPong
    39  )
    40  
    41  var (
    42  	defaultMaxBodyMessageSize = 1024 * 1024 * 100
    43  	checksumFieldBytes        = 8
    44  	totalSizeFieldBytes       = 4
    45  	payloadSizeFieldBytes     = 4
    46  
    47  	approximateHeaderSize = 1024 * 1024 * 10
    48  )
    49  
    50  func GetMessageSize() int {
    51  	return defaultMaxBodyMessageSize
    52  }
    53  
    54  // WithCodecEnableChecksum enable checksum
    55  func WithCodecEnableChecksum() CodecOption {
    56  	return func(c *messageCodec) {
    57  		c.bc.checksumEnabled = true
    58  	}
    59  }
    60  
    61  // WithCodecPayloadCopyBufferSize set payload copy buffer size, if is a PayloadMessage
    62  func WithCodecPayloadCopyBufferSize(value int) CodecOption {
    63  	return func(c *messageCodec) {
    64  		c.bc.payloadBufSize = value
    65  	}
    66  }
    67  
    68  // WithCodecIntegrationHLC intrgration hlc
    69  func WithCodecIntegrationHLC(clock clock.Clock) CodecOption {
    70  	return func(c *messageCodec) {
    71  		c.AddHeaderCodec(&hlcCodec{clock: clock})
    72  	}
    73  }
    74  
    75  // WithCodecMaxBodySize set rpc max body size
    76  func WithCodecMaxBodySize(size int) CodecOption {
    77  	return func(c *messageCodec) {
    78  		if size == 0 {
    79  			size = defaultMaxBodyMessageSize
    80  		}
    81  		c.codec = length.NewWithSize(c.bc, 0, 0, 0, size+approximateHeaderSize)
    82  		c.bc.maxBodySize = size
    83  	}
    84  }
    85  
    86  // WithCodecEnableCompress enable compress body and payload
    87  func WithCodecEnableCompress(pool *mpool.MPool) CodecOption {
    88  	return func(c *messageCodec) {
    89  		c.bc.compressEnabled = true
    90  		c.bc.pool = pool
    91  	}
    92  }
    93  
    94  type messageCodec struct {
    95  	codec codec.Codec
    96  	bc    *baseCodec
    97  }
    98  
    99  // NewMessageCodec create message codec. The message encoding format consists of a message header and a message body.
   100  // Format:
   101  //  1. Size, 4 bytes, required. Inlucde header and body.
   102  //  2. Message header
   103  //     2.1. Flag, 1 byte, required.
   104  //     2.2. Checksum, 8 byte, optional. Set if has a checksun flag
   105  //     2.3. PayloadSize, 4 byte, optional. Set if the message is a morpc.PayloadMessage.
   106  //     2.4. Streaming sequence, 4 byte, optional. Set if the message is in a streaming.
   107  //     2.5. Custom headers, optional. Set if has custom header codecs
   108  //  3. Message body
   109  //     3.1. message body, required.
   110  //     3.2. payload, optional. Set if has paylad flag.
   111  func NewMessageCodec(messageFactory func() Message, options ...CodecOption) Codec {
   112  	bc := &baseCodec{
   113  		messageFactory: messageFactory,
   114  		maxBodySize:    defaultMaxBodyMessageSize,
   115  	}
   116  	c := &messageCodec{
   117  		codec: length.NewWithSize(bc, 0, 0, 0, defaultMaxBodyMessageSize+approximateHeaderSize),
   118  		bc:    bc,
   119  	}
   120  	c.AddHeaderCodec(&deadlineContextCodec{})
   121  	c.AddHeaderCodec(&traceCodec{})
   122  
   123  	for _, opt := range options {
   124  		opt(c)
   125  	}
   126  	return c
   127  }
   128  
   129  func (c *messageCodec) Decode(in *buf.ByteBuf) (any, bool, error) {
   130  	return c.codec.Decode(in)
   131  }
   132  
   133  func (c *messageCodec) Encode(data interface{}, out *buf.ByteBuf, conn io.Writer) error {
   134  	return c.bc.Encode(data, out, conn)
   135  }
   136  
   137  func (c *messageCodec) Valid(msg Message) error {
   138  	n := msg.Size()
   139  	if n >= c.bc.maxBodySize {
   140  		return moerr.NewInternalErrorNoCtx("message body %d is too large, max is %d",
   141  			n,
   142  			c.bc.maxBodySize)
   143  	}
   144  	return nil
   145  }
   146  
   147  func (c *messageCodec) AddHeaderCodec(hc HeaderCodec) {
   148  	c.bc.headerCodecs = append(c.bc.headerCodecs, hc)
   149  }
   150  
   151  type baseCodec struct {
   152  	pool            *mpool.MPool
   153  	checksumEnabled bool
   154  	compressEnabled bool
   155  	payloadBufSize  int
   156  	maxBodySize     int
   157  	messageFactory  func() Message
   158  	headerCodecs    []HeaderCodec
   159  }
   160  
   161  func (c *baseCodec) Decode(in *buf.ByteBuf) (any, bool, error) {
   162  	msg := RPCMessage{}
   163  	offset := 0
   164  	data := getDecodeData(in)
   165  
   166  	// 2.1
   167  	flag, n := c.readFlag(&msg, data, offset)
   168  	offset += n
   169  
   170  	// 2.2
   171  	expectChecksum, n := readChecksum(flag, data, offset)
   172  	offset += n
   173  
   174  	// 2.3
   175  	payloadSize, n := readPayloadSize(flag, data, offset)
   176  	offset += n
   177  
   178  	// 2.4
   179  	n, err := c.readCustomHeaders(flag, &msg, data, offset)
   180  	if err != nil {
   181  		return nil, false, err
   182  	}
   183  	offset += n
   184  
   185  	// 2.5
   186  	offset += readStreaming(flag, &msg, data, offset)
   187  
   188  	// 3.1 and 3.2
   189  	if err := c.readMessage(flag, data, offset, expectChecksum, payloadSize, &msg); err != nil {
   190  		return nil, false, err
   191  	}
   192  
   193  	in.SetReadIndex(in.GetMarkIndex())
   194  	in.ClearMark()
   195  	return msg, true, nil
   196  }
   197  
   198  func (c *baseCodec) Encode(data interface{}, out *buf.ByteBuf, conn io.Writer) error {
   199  	msg, ok := data.(RPCMessage)
   200  	if !ok {
   201  		return moerr.NewInternalErrorNoCtx("not support %T %+v", data, data)
   202  	}
   203  
   204  	startWriteOffset := out.GetWriteOffset()
   205  	totalSize := 0
   206  	// The total message size cannot be determined at the beginning and needs to wait until all the
   207  	// dynamic content is determined before the total size can be determined. After the total size is
   208  	// determined, we need to write the total size data in the location of totalSizeAt
   209  	totalSizeAt := skip(totalSizeFieldBytes, out)
   210  
   211  	// 2.1 flag
   212  	flag := c.getFlag(msg)
   213  	out.MustWriteByte(flag)
   214  	totalSize += 1
   215  
   216  	// 2.2 checksum, similar to totalSize, we do not currently know the size of the message body.
   217  	checksumAt := -1
   218  	if flag&flagChecksumEnabled != 0 {
   219  		checksumAt = skip(checksumFieldBytes, out)
   220  		totalSize += checksumFieldBytes
   221  	}
   222  
   223  	// 2.3 payload
   224  	var payloadData []byte
   225  	var compressedPayloadData []byte
   226  	var payloadMsg PayloadMessage
   227  	var hasPayload bool
   228  
   229  	// skip all written data by this message
   230  	discardWritten := func() {
   231  		out.SetWriteIndexByOffset(startWriteOffset)
   232  		if hasPayload {
   233  			payloadMsg.SetPayloadField(payloadData)
   234  		}
   235  	}
   236  
   237  	if payloadMsg, hasPayload = msg.Message.(PayloadMessage); hasPayload {
   238  		// set payload filed to nil to avoid payload being written to the out buffer, and write directly
   239  		// to the socket afterwards to reduce one payload.
   240  		payloadData = payloadMsg.GetPayloadField()
   241  		payloadMsg.SetPayloadField(nil)
   242  		compressedPayloadData = payloadData
   243  
   244  		if c.compressEnabled && len(payloadData) > 0 {
   245  			v, err := c.compress(payloadData)
   246  			if err != nil {
   247  				discardWritten()
   248  				return err
   249  			}
   250  			defer c.pool.Free(v)
   251  			compressedPayloadData = v
   252  		}
   253  
   254  		out.WriteInt(len(compressedPayloadData))
   255  		totalSize += payloadSizeFieldBytes + len(compressedPayloadData)
   256  	}
   257  
   258  	// 2.4 Custom header size
   259  	n, err := c.encodeCustomHeaders(&msg, out)
   260  	if err != nil {
   261  		return err
   262  	}
   263  	totalSize += n
   264  
   265  	// 2.5 streaming message
   266  	if msg.stream {
   267  		out.WriteUint32(msg.streamSequence)
   268  		totalSize += 4
   269  	}
   270  
   271  	// 3.1 message body
   272  	body, err := c.writeBody(out, msg.Message)
   273  	if err != nil {
   274  		discardWritten()
   275  		return err
   276  	}
   277  
   278  	// now, header and body are all determined, we need to fill the totalSize and checksum
   279  	// fill total size
   280  	totalSize += len(body)
   281  	writeIntAt(totalSizeAt, out, totalSize)
   282  
   283  	// fill checksum
   284  	if checksumAt != -1 {
   285  		if err := writeChecksum(checksumAt, out, body, compressedPayloadData); err != nil {
   286  			discardWritten()
   287  			return err
   288  		}
   289  	}
   290  
   291  	// 3.2 payload
   292  	if hasPayload {
   293  		// resume payload to payload message
   294  		payloadMsg.SetPayloadField(payloadData)
   295  		if err := writePayload(out, compressedPayloadData, conn, c.payloadBufSize); err != nil {
   296  			return err
   297  		}
   298  	}
   299  
   300  	return nil
   301  }
   302  
   303  func (c *baseCodec) compress(src []byte) ([]byte, error) {
   304  	n := lz4.CompressBlockBound(len(src))
   305  	dst, err := c.pool.Alloc(n)
   306  	if err != nil {
   307  		return nil, err
   308  	}
   309  	dst, err = c.compressTo(src, dst)
   310  	if err != nil {
   311  		c.pool.Free(dst)
   312  		return nil, err
   313  	}
   314  	return dst, nil
   315  }
   316  
   317  func (c *baseCodec) uncompress(src []byte) ([]byte, error) {
   318  	// The lz4 library requires a []byte with a large enough dst when
   319  	// decompressing, otherwise it will return an ErrInvalidSourceShortBuffer, we
   320  	// can't confirm how large a dst we need to give initially, so when we encounter
   321  	// an ErrInvalidSourceShortBuffer, we expand the size and retry.
   322  	n := len(src) * 2
   323  	for {
   324  		dst, err := c.pool.Alloc(n)
   325  		if err != nil {
   326  			return nil, err
   327  		}
   328  		dst, err = uncompress(src, dst)
   329  		if err == nil {
   330  			return dst, nil
   331  		}
   332  
   333  		c.pool.Free(dst)
   334  		if err != lz4.ErrInvalidSourceShortBuffer {
   335  			return nil, err
   336  		}
   337  		n *= 2
   338  	}
   339  }
   340  
   341  func (c *baseCodec) compressTo(src, dst []byte) ([]byte, error) {
   342  	dst, err := compress(src, dst)
   343  	if err != nil {
   344  		return nil, err
   345  	}
   346  	return dst, nil
   347  }
   348  
   349  func (c *baseCodec) compressBound(size int) int {
   350  	return lz4.CompressBlockBound(size)
   351  }
   352  
   353  func (c *baseCodec) getFlag(msg RPCMessage) byte {
   354  	flag := byte(0)
   355  	if c.checksumEnabled {
   356  		flag |= flagChecksumEnabled
   357  	}
   358  	if c.compressEnabled {
   359  		flag |= flagCompressEnabled
   360  	}
   361  	if len(c.headerCodecs) > 0 {
   362  		flag |= flagHasCustomHeader
   363  	}
   364  	if _, ok := msg.Message.(PayloadMessage); ok {
   365  		flag |= flagHashPayload
   366  	}
   367  	if msg.stream {
   368  		flag |= flagStreamingMessage
   369  	}
   370  	if msg.internal {
   371  		if m, ok := msg.Message.(*flagOnlyMessage); ok {
   372  			flag |= m.flag
   373  		}
   374  	}
   375  	return flag
   376  }
   377  
   378  func (c *baseCodec) encodeCustomHeaders(msg *RPCMessage, out *buf.ByteBuf) (int, error) {
   379  	if len(c.headerCodecs) == 0 {
   380  		return 0, nil
   381  	}
   382  
   383  	size := 0
   384  	for _, hc := range c.headerCodecs {
   385  		v, err := hc.Encode(msg, out)
   386  		if err != nil {
   387  			return 0, err
   388  		}
   389  		size += v
   390  	}
   391  	return size, nil
   392  }
   393  
   394  func (c *baseCodec) readCustomHeaders(flag byte, msg *RPCMessage, data []byte, offset int) (int, error) {
   395  	if flag&flagHasCustomHeader == 0 {
   396  		return 0, nil
   397  	}
   398  
   399  	readed := 0
   400  	for _, hc := range c.headerCodecs {
   401  		n, err := hc.Decode(msg, data[offset+readed:])
   402  		if err != nil {
   403  			return 0, err
   404  		}
   405  		readed += n
   406  	}
   407  	return readed, nil
   408  }
   409  
   410  func (c *baseCodec) writeBody(
   411  	out *buf.ByteBuf,
   412  	msg Message) ([]byte, error) {
   413  	size := msg.Size()
   414  	if size == 0 {
   415  		return nil, nil
   416  	}
   417  	if !c.compressEnabled {
   418  		index, _ := setWriterIndexAfterGow(out, size)
   419  		data := out.RawSlice(index, index+size)
   420  		if _, err := msg.MarshalTo(data); err != nil {
   421  			return nil, err
   422  		}
   423  		return data, nil
   424  	}
   425  
   426  	// we use mpool to compress body, then write the dst into the buffer
   427  	origin, err := c.pool.Alloc(size)
   428  	if err != nil {
   429  		return nil, err
   430  	}
   431  	defer c.pool.Free(origin)
   432  	if _, err := msg.MarshalTo(origin); err != nil {
   433  		return nil, err
   434  	}
   435  
   436  	n := c.compressBound(len(origin))
   437  	dst, err := c.pool.Alloc(n)
   438  	if err != nil {
   439  		return nil, err
   440  	}
   441  	defer c.pool.Free(dst)
   442  
   443  	dst, err = compress(origin, dst)
   444  	if err != nil {
   445  		return nil, err
   446  	}
   447  
   448  	index := out.GetWriteOffset()
   449  	out.MustWrite(dst)
   450  	return out.RawSlice(out.GetReadIndex()+index, out.GetWriteIndex()), nil
   451  }
   452  
   453  func (c *baseCodec) readMessage(flag byte, data []byte, offset int, expectChecksum uint64, payloadSize int, msg *RPCMessage) error {
   454  	if offset == len(data) {
   455  		return nil
   456  	}
   457  
   458  	body := data[offset : len(data)-payloadSize]
   459  	payload := data[len(data)-payloadSize:]
   460  	if flag&flagChecksumEnabled != 0 {
   461  		if err := validChecksum(body, payload, expectChecksum); err != nil {
   462  			return err
   463  		}
   464  	}
   465  
   466  	if flag&flagCompressEnabled != 0 {
   467  		dstBody, err := c.uncompress(body)
   468  		if err != nil {
   469  			return err
   470  		}
   471  		defer c.pool.Free(dstBody)
   472  		body = dstBody
   473  
   474  		if payloadSize > 0 {
   475  			dstPayload, err := c.uncompress(payload)
   476  			if err != nil {
   477  				return err
   478  			}
   479  			defer c.pool.Free(dstPayload)
   480  			payload = dstPayload
   481  		}
   482  	}
   483  
   484  	if err := msg.Message.Unmarshal(body); err != nil {
   485  		return err
   486  	}
   487  
   488  	if payloadSize > 0 {
   489  		msg.Message.(PayloadMessage).SetPayloadField(payload)
   490  	}
   491  	return nil
   492  }
   493  
   494  var (
   495  	checksumPool = sync.Pool{
   496  		New: func() any {
   497  			return xxhash.New()
   498  		},
   499  	}
   500  )
   501  
   502  func acquireChecksum() *xxhash.Digest {
   503  	return checksumPool.Get().(*xxhash.Digest)
   504  }
   505  
   506  func releaseChecksum(checksum *xxhash.Digest) {
   507  	checksum.Reset()
   508  	checksumPool.Put(checksum)
   509  }
   510  
   511  func skip(n int, out *buf.ByteBuf) int {
   512  	_, offset := setWriterIndexAfterGow(out, n)
   513  	return offset
   514  }
   515  
   516  func writeIntAt(offset int, out *buf.ByteBuf, value int) {
   517  	idx := out.GetReadIndex() + offset
   518  	buf.Int2BytesTo(value, out.RawSlice(idx, idx+4))
   519  }
   520  
   521  func writeUint64At(offset int, out *buf.ByteBuf, value uint64) {
   522  	idx := out.GetReadIndex() + offset
   523  	buf.Uint64ToBytesTo(value, out.RawSlice(idx, idx+8))
   524  }
   525  
   526  func writePayload(out *buf.ByteBuf, payload []byte, conn io.Writer, copyBuffer int) error {
   527  	if len(payload) == 0 {
   528  		return nil
   529  	}
   530  
   531  	// reset here to avoid buffer expansion as much as possible
   532  	defer out.Reset()
   533  
   534  	// first, write header and body to socket
   535  	if _, err := out.WriteTo(conn); err != nil {
   536  		return err
   537  	}
   538  
   539  	// write payload to socket
   540  	if err := buf.WriteTo(payload, conn, copyBuffer); err != nil {
   541  		return err
   542  	}
   543  	return nil
   544  }
   545  
   546  func writeChecksum(offset int, out *buf.ByteBuf, body, payload []byte) error {
   547  	checksum := acquireChecksum()
   548  	defer releaseChecksum(checksum)
   549  
   550  	_, err := checksum.Write(body)
   551  	if err != nil {
   552  		return err
   553  	}
   554  	if len(payload) > 0 {
   555  		_, err = checksum.Write(payload)
   556  		if err != nil {
   557  			return err
   558  		}
   559  	}
   560  	writeUint64At(offset, out, checksum.Sum64())
   561  	return nil
   562  }
   563  
   564  func getDecodeData(in *buf.ByteBuf) []byte {
   565  	return in.RawSlice(in.GetReadIndex(), in.GetMarkIndex())
   566  }
   567  
   568  func (c *baseCodec) readFlag(msg *RPCMessage, data []byte, offset int) (byte, int) {
   569  	flag := data[offset]
   570  	if flag&flagPing != 0 {
   571  		msg.Message = &flagOnlyMessage{flag: flagPing}
   572  		msg.internal = true
   573  	} else if flag&flagPong != 0 {
   574  		msg.Message = &flagOnlyMessage{flag: flagPong}
   575  		msg.internal = true
   576  	} else {
   577  		msg.Message = c.messageFactory()
   578  	}
   579  	return flag, 1
   580  }
   581  
   582  func readChecksum(flag byte, data []byte, offset int) (uint64, int) {
   583  	if flag&flagChecksumEnabled == 0 {
   584  		return 0, 0
   585  	}
   586  
   587  	return buf.Byte2Uint64(data[offset:]), checksumFieldBytes
   588  }
   589  
   590  func readPayloadSize(flag byte, data []byte, offset int) (int, int) {
   591  	if flag&flagHashPayload == 0 {
   592  		return 0, 0
   593  	}
   594  
   595  	return buf.Byte2Int(data[offset:]), payloadSizeFieldBytes
   596  }
   597  
   598  func readStreaming(flag byte, msg *RPCMessage, data []byte, offset int) int {
   599  	if flag&flagStreamingMessage == 0 {
   600  		return 0
   601  	}
   602  	msg.stream = true
   603  	msg.streamSequence = buf.Byte2Uint32(data[offset:])
   604  	return 4
   605  }
   606  
   607  func validChecksum(body, payload []byte, expectChecksum uint64) error {
   608  	checksum := acquireChecksum()
   609  	defer releaseChecksum(checksum)
   610  
   611  	_, err := checksum.Write(body)
   612  	if err != nil {
   613  		return err
   614  	}
   615  	if len(payload) > 0 {
   616  		_, err := checksum.Write(payload)
   617  		if err != nil {
   618  			return err
   619  		}
   620  	}
   621  	actulChecksum := checksum.Sum64()
   622  	if actulChecksum != expectChecksum {
   623  		return moerr.NewInternalErrorNoCtx("checksum mismatch, expect %d, got %d",
   624  			expectChecksum,
   625  			actulChecksum)
   626  	}
   627  	return nil
   628  }
   629  
   630  func setWriterIndexAfterGow(out *buf.ByteBuf, n int) (int, int) {
   631  	offset := out.Readable()
   632  	out.Grow(n)
   633  	out.SetWriteIndex(out.GetReadIndex() + offset + n)
   634  	return out.GetReadIndex() + offset, offset
   635  }