github.com/cloudwego/kitex@v0.9.0/pkg/remote/codec/default_codec.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package codec
    18  
    19  import (
    20  	"context"
    21  	"encoding/binary"
    22  	"fmt"
    23  	"sync/atomic"
    24  
    25  	"github.com/cloudwego/kitex/pkg/kerrors"
    26  	"github.com/cloudwego/kitex/pkg/remote"
    27  	"github.com/cloudwego/kitex/pkg/remote/codec/perrors"
    28  	"github.com/cloudwego/kitex/pkg/retry"
    29  	"github.com/cloudwego/kitex/pkg/rpcinfo"
    30  	"github.com/cloudwego/kitex/pkg/serviceinfo"
    31  	"github.com/cloudwego/kitex/transport"
    32  )
    33  
    34  // The byte count of 32 and 16 integer values.
    35  const (
    36  	Size32 = 4
    37  	Size16 = 2
    38  )
    39  
    40  const (
    41  	// ThriftV1Magic is the magic code for thrift.VERSION_1
    42  	ThriftV1Magic = 0x80010000
    43  	// ProtobufV1Magic is the magic code for kitex protobuf
    44  	ProtobufV1Magic = 0x90010000
    45  
    46  	// MagicMask is bit mask for checking version.
    47  	MagicMask = 0xffff0000
    48  )
    49  
    50  var (
    51  	ttHeaderCodec   = ttHeader{}
    52  	meshHeaderCodec = meshHeader{}
    53  
    54  	_ remote.Codec       = (*defaultCodec)(nil)
    55  	_ remote.MetaDecoder = (*defaultCodec)(nil)
    56  )
    57  
    58  // NewDefaultCodec creates the default protocol sniffing codec supporting thrift and protobuf.
    59  func NewDefaultCodec() remote.Codec {
    60  	// No size limit by default
    61  	return &defaultCodec{
    62  		maxSize: 0,
    63  	}
    64  }
    65  
    66  // NewDefaultCodecWithSizeLimit creates the default protocol sniffing codec supporting thrift and protobuf but with size limit.
    67  // maxSize is in bytes
    68  func NewDefaultCodecWithSizeLimit(maxSize int) remote.Codec {
    69  	return &defaultCodec{
    70  		maxSize: maxSize,
    71  	}
    72  }
    73  
    74  type defaultCodec struct {
    75  	// maxSize limits the max size of the payload
    76  	maxSize int
    77  }
    78  
    79  // EncodePayload encode payload
    80  func (c *defaultCodec) EncodePayload(ctx context.Context, message remote.Message, out remote.ByteBuffer) error {
    81  	defer func() {
    82  		// notice: mallocLen() must exec before flush, or it will be reset
    83  		if ri := message.RPCInfo(); ri != nil {
    84  			if ms := rpcinfo.AsMutableRPCStats(ri.Stats()); ms != nil {
    85  				ms.SetSendSize(uint64(out.MallocLen()))
    86  			}
    87  		}
    88  	}()
    89  	var err error
    90  	var framedLenField []byte
    91  	headerLen := out.MallocLen()
    92  	tp := message.ProtocolInfo().TransProto
    93  
    94  	// 1. malloc framed field if needed
    95  	if tp&transport.Framed == transport.Framed {
    96  		if framedLenField, err = out.Malloc(Size32); err != nil {
    97  			return err
    98  		}
    99  		headerLen += Size32
   100  	}
   101  
   102  	// 2. encode payload
   103  	if err = c.encodePayload(ctx, message, out); err != nil {
   104  		return err
   105  	}
   106  
   107  	// 3. fill framed field if needed
   108  	var payloadLen int
   109  	if tp&transport.Framed == transport.Framed {
   110  		if framedLenField == nil {
   111  			return perrors.NewProtocolErrorWithMsg("no buffer allocated for the framed length field")
   112  		}
   113  		payloadLen = out.MallocLen() - headerLen
   114  		binary.BigEndian.PutUint32(framedLenField, uint32(payloadLen))
   115  	} else if message.ProtocolInfo().CodecType == serviceinfo.Protobuf {
   116  		return perrors.NewProtocolErrorWithMsg("protobuf just support 'framed' trans proto")
   117  	}
   118  	if tp&transport.TTHeader == transport.TTHeader {
   119  		payloadLen = out.MallocLen() - Size32
   120  	}
   121  	err = checkPayloadSize(payloadLen, c.maxSize)
   122  	return err
   123  }
   124  
   125  // EncodeMetaAndPayload encode meta and payload
   126  func (c *defaultCodec) EncodeMetaAndPayload(ctx context.Context, message remote.Message, out remote.ByteBuffer, me remote.MetaEncoder) error {
   127  	var err error
   128  	var totalLenField []byte
   129  	tp := message.ProtocolInfo().TransProto
   130  
   131  	// 1. encode header and return totalLenField if needed
   132  	// totalLenField will be filled after payload encoded
   133  	if tp&transport.TTHeader == transport.TTHeader {
   134  		if totalLenField, err = ttHeaderCodec.encode(ctx, message, out); err != nil {
   135  			return err
   136  		}
   137  	}
   138  	// 2. encode payload
   139  	if err = me.EncodePayload(ctx, message, out); err != nil {
   140  		return err
   141  	}
   142  	// 3. fill totalLen field for header if needed
   143  	if tp&transport.TTHeader == transport.TTHeader {
   144  		if totalLenField == nil {
   145  			return perrors.NewProtocolErrorWithMsg("no buffer allocated for the header length field")
   146  		}
   147  		payloadLen := out.MallocLen() - Size32
   148  		binary.BigEndian.PutUint32(totalLenField, uint32(payloadLen))
   149  	}
   150  	return nil
   151  }
   152  
   153  // Encode implements the remote.Codec interface, it does complete message encode include header and payload.
   154  func (c *defaultCodec) Encode(ctx context.Context, message remote.Message, out remote.ByteBuffer) (err error) {
   155  	return c.EncodeMetaAndPayload(ctx, message, out, c)
   156  }
   157  
   158  // DecodeMeta decode header
   159  func (c *defaultCodec) DecodeMeta(ctx context.Context, message remote.Message, in remote.ByteBuffer) (err error) {
   160  	var flagBuf []byte
   161  	if flagBuf, err = in.Peek(2 * Size32); err != nil {
   162  		return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("default codec read failed: %s", err.Error()))
   163  	}
   164  
   165  	if err = checkRPCState(ctx, message); err != nil {
   166  		// there is one call has finished in retry task, it doesn't need to do decode for this call
   167  		return err
   168  	}
   169  	isTTHeader := IsTTHeader(flagBuf)
   170  	// 1. decode header
   171  	if isTTHeader {
   172  		// TTHeader
   173  		if err = ttHeaderCodec.decode(ctx, message, in); err != nil {
   174  			return err
   175  		}
   176  		if flagBuf, err = in.Peek(2 * Size32); err != nil {
   177  			return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("ttheader read payload first 8 byte failed: %s", err.Error()))
   178  		}
   179  	} else if isMeshHeader(flagBuf) {
   180  		message.Tags()[remote.MeshHeader] = true
   181  		// MeshHeader
   182  		if err = meshHeaderCodec.decode(ctx, message, in); err != nil {
   183  			return err
   184  		}
   185  		if flagBuf, err = in.Peek(2 * Size32); err != nil {
   186  			return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("meshHeader read payload first 8 byte failed: %s", err.Error()))
   187  		}
   188  	}
   189  	return checkPayload(flagBuf, message, in, isTTHeader, c.maxSize)
   190  }
   191  
   192  // DecodePayload decode payload
   193  func (c *defaultCodec) DecodePayload(ctx context.Context, message remote.Message, in remote.ByteBuffer) error {
   194  	defer func() {
   195  		if ri := message.RPCInfo(); ri != nil {
   196  			if ms := rpcinfo.AsMutableRPCStats(ri.Stats()); ms != nil {
   197  				ms.SetRecvSize(uint64(in.ReadLen()))
   198  			}
   199  		}
   200  	}()
   201  
   202  	hasRead := in.ReadLen()
   203  	pCodec, err := remote.GetPayloadCodec(message)
   204  	if err != nil {
   205  		return err
   206  	}
   207  	if err = pCodec.Unmarshal(ctx, message, in); err != nil {
   208  		return err
   209  	}
   210  	if message.PayloadLen() == 0 {
   211  		// if protocol is PurePayload, should set payload length after decoded
   212  		message.SetPayloadLen(in.ReadLen() - hasRead)
   213  	}
   214  	return nil
   215  }
   216  
   217  // Decode implements the remote.Codec interface, it does complete message decode include header and payload.
   218  func (c *defaultCodec) Decode(ctx context.Context, message remote.Message, in remote.ByteBuffer) (err error) {
   219  	// 1. decode meta
   220  	if err = c.DecodeMeta(ctx, message, in); err != nil {
   221  		return err
   222  	}
   223  
   224  	// 2. decode payload
   225  	return c.DecodePayload(ctx, message, in)
   226  }
   227  
   228  func (c *defaultCodec) Name() string {
   229  	return "default"
   230  }
   231  
   232  // Select to use thrift or protobuf according to the protocol.
   233  func (c *defaultCodec) encodePayload(ctx context.Context, message remote.Message, out remote.ByteBuffer) error {
   234  	pCodec, err := remote.GetPayloadCodec(message)
   235  	if err != nil {
   236  		return err
   237  	}
   238  	return pCodec.Marshal(ctx, message, out)
   239  }
   240  
   241  /**
   242   * +------------------------------------------------------------+
   243   * |                  4Byte                 |       2Byte       |
   244   * +------------------------------------------------------------+
   245   * |   			     Length			    	|   HEADER MAGIC    |
   246   * +------------------------------------------------------------+
   247   */
   248  func IsTTHeader(flagBuf []byte) bool {
   249  	return binary.BigEndian.Uint32(flagBuf[Size32:])&MagicMask == TTHeaderMagic
   250  }
   251  
   252  /**
   253   * +----------------------------------------+
   254   * |       2Byte        |       2Byte       |
   255   * +----------------------------------------+
   256   * |    HEADER MAGIC    |   HEADER SIZE     |
   257   * +----------------------------------------+
   258   */
   259  func isMeshHeader(flagBuf []byte) bool {
   260  	return binary.BigEndian.Uint32(flagBuf[:Size32])&MagicMask == MeshHeaderMagic
   261  }
   262  
   263  /**
   264   * Kitex protobuf has framed field
   265   * +------------------------------------------------------------+
   266   * |                  4Byte                 |       2Byte       |
   267   * +------------------------------------------------------------+
   268   * |   			     Length			    	|   HEADER MAGIC    |
   269   * +------------------------------------------------------------+
   270   */
   271  func isProtobufKitex(flagBuf []byte) bool {
   272  	return binary.BigEndian.Uint32(flagBuf[Size32:])&MagicMask == ProtobufV1Magic
   273  }
   274  
   275  /**
   276   * +-------------------+
   277   * |       2Byte       |
   278   * +-------------------+
   279   * |   HEADER MAGIC    |
   280   * +-------------------
   281   */
   282  func isThriftBinary(flagBuf []byte) bool {
   283  	return binary.BigEndian.Uint32(flagBuf[:Size32])&MagicMask == ThriftV1Magic
   284  }
   285  
   286  /**
   287   * +------------------------------------------------------------+
   288   * |                  4Byte                 |       2Byte       |
   289   * +------------------------------------------------------------+
   290   * |   			     Length			    	|   HEADER MAGIC    |
   291   * +------------------------------------------------------------+
   292   */
   293  func isThriftFramedBinary(flagBuf []byte) bool {
   294  	return binary.BigEndian.Uint32(flagBuf[Size32:])&MagicMask == ThriftV1Magic
   295  }
   296  
   297  func checkRPCState(ctx context.Context, message remote.Message) error {
   298  	if message.RPCRole() == remote.Server {
   299  		return nil
   300  	}
   301  	if ctx.Err() == context.DeadlineExceeded || ctx.Err() == context.Canceled {
   302  		return kerrors.ErrRPCFinish
   303  	}
   304  	if respOp, ok := ctx.Value(retry.CtxRespOp).(*int32); ok {
   305  		if !atomic.CompareAndSwapInt32(respOp, retry.OpNo, retry.OpDoing) {
   306  			// previous call is being handling or done
   307  			// this flag is used to check request status in retry(backup request) scene
   308  			return kerrors.ErrRPCFinish
   309  		}
   310  	}
   311  	return nil
   312  }
   313  
   314  func checkPayload(flagBuf []byte, message remote.Message, in remote.ByteBuffer, isTTHeader bool, maxPayloadSize int) error {
   315  	var transProto transport.Protocol
   316  	var codecType serviceinfo.PayloadCodec
   317  	if isThriftBinary(flagBuf) {
   318  		codecType = serviceinfo.Thrift
   319  		if isTTHeader {
   320  			transProto = transport.TTHeader
   321  		} else {
   322  			transProto = transport.PurePayload
   323  		}
   324  	} else if isThriftFramedBinary(flagBuf) {
   325  		codecType = serviceinfo.Thrift
   326  		if isTTHeader {
   327  			transProto = transport.TTHeaderFramed
   328  		} else {
   329  			transProto = transport.Framed
   330  		}
   331  		payloadLen := binary.BigEndian.Uint32(flagBuf[:Size32])
   332  		message.SetPayloadLen(int(payloadLen))
   333  		if err := in.Skip(Size32); err != nil {
   334  			return err
   335  		}
   336  	} else if isProtobufKitex(flagBuf) {
   337  		codecType = serviceinfo.Protobuf
   338  		if isTTHeader {
   339  			transProto = transport.TTHeaderFramed
   340  		} else {
   341  			transProto = transport.Framed
   342  		}
   343  		payloadLen := binary.BigEndian.Uint32(flagBuf[:Size32])
   344  		message.SetPayloadLen(int(payloadLen))
   345  		if err := in.Skip(Size32); err != nil {
   346  			return err
   347  		}
   348  	} else {
   349  		first4Bytes := binary.BigEndian.Uint32(flagBuf[:Size32])
   350  		second4Bytes := binary.BigEndian.Uint32(flagBuf[Size32:])
   351  		// 0xfff4fffd is the interrupt message of telnet
   352  		err := perrors.NewProtocolErrorWithMsg(fmt.Sprintf("invalid payload (first4Bytes=%#x, second4Bytes=%#x)", first4Bytes, second4Bytes))
   353  		return err
   354  	}
   355  	if err := checkPayloadSize(message.PayloadLen(), maxPayloadSize); err != nil {
   356  		return err
   357  	}
   358  	message.SetProtocolInfo(remote.NewProtocolInfo(transProto, codecType))
   359  	cfg := rpcinfo.AsMutableRPCConfig(message.RPCInfo().Config())
   360  	if cfg != nil {
   361  		tp := message.ProtocolInfo().TransProto
   362  		cfg.SetTransportProtocol(tp)
   363  	}
   364  	return nil
   365  }
   366  
   367  func checkPayloadSize(payloadLen, maxSize int) error {
   368  	if maxSize > 0 && payloadLen > 0 && payloadLen > maxSize {
   369  		return perrors.NewProtocolErrorWithType(
   370  			perrors.InvalidData,
   371  			fmt.Sprintf("invalid data: payload size(%d) larger than the limit(%d)", payloadLen, maxSize),
   372  		)
   373  	}
   374  	return nil
   375  }