github.com/cloudwego/kitex@v0.9.0/pkg/remote/codec/header_codec.go (about)

     1  /*
     2   * Copyright 2021 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package codec
    18  
    19  import (
    20  	"context"
    21  	"encoding/binary"
    22  	"fmt"
    23  	"io"
    24  
    25  	"github.com/cloudwego/kitex/pkg/klog"
    26  	"github.com/cloudwego/kitex/pkg/remote"
    27  	"github.com/cloudwego/kitex/pkg/remote/codec/perrors"
    28  	"github.com/cloudwego/kitex/pkg/remote/transmeta"
    29  	"github.com/cloudwego/kitex/pkg/rpcinfo"
    30  	"github.com/cloudwego/kitex/pkg/rpcinfo/remoteinfo"
    31  	"github.com/cloudwego/kitex/pkg/serviceinfo"
    32  	"github.com/cloudwego/kitex/pkg/utils"
    33  )
    34  
    35  /**
    36   *  TTHeader Protocol
    37   *  +-------------2Byte--------------|-------------2Byte-------------+
    38   *	+----------------------------------------------------------------+
    39   *	| 0|                          LENGTH                             |
    40   *	+----------------------------------------------------------------+
    41   *	| 0|       HEADER MAGIC          |            FLAGS              |
    42   *	+----------------------------------------------------------------+
    43   *	|                         SEQUENCE NUMBER                        |
    44   *	+----------------------------------------------------------------+
    45   *	| 0|     Header Size(/32)        | ...
    46   *	+---------------------------------
    47   *
    48   *	Header is of variable size:
    49   *	(and starts at offset 14)
    50   *
    51   *	+----------------------------------------------------------------+
 *	| PROTOCOL ID (uint8) | NUM TRANSFORMS (uint8) | TRANSFORM 0 ID (uint8) |
    53   *	+----------------------------------------------------------------+
    54   *	|  TRANSFORM 0 DATA ...
    55   *	+----------------------------------------------------------------+
    56   *	|         ...                              ...                   |
    57   *	+----------------------------------------------------------------+
    58   *	|        INFO 0 ID (uint8)      |       INFO 0  DATA ...
    59   *	+----------------------------------------------------------------+
    60   *	|         ...                              ...                   |
    61   *	+----------------------------------------------------------------+
    62   *	|                                                                |
    63   *	|                              PAYLOAD                           |
    64   *	|                                                                |
    65   *	+----------------------------------------------------------------+
    66   */
    67  
// Header keys
//
// Wire-level constants shared by the TTHeader and MeshHeader codecs below.
const (
	// Header Magics
	// 0 and 16th bits must be 0 to differentiate from framed & unframed
	//
	// TTHeaderMagic occupies the upper 16 bits of meta bytes 4-8; encode adds
	// the 16-bit HeaderFlags to it, which lands them in the lower 16 bits.
	TTHeaderMagic     uint32 = 0x10000000
	// NOTE(review): MeshHeaderMagic / MeshHeaderLenMask are not referenced in
	// this section; presumably isMeshHeader uses them to split the first 4
	// bytes into magic (high half) and length (low half) - verify.
	MeshHeaderMagic   uint32 = 0xFFAF0000
	MeshHeaderLenMask uint32 = 0x0000FFFF

	// HeaderMask        uint32 = 0xFFFF0000
	FlagsMask     uint32 = 0x0000FFFF
	MethodMask    uint32 = 0x41000000 // method first byte [A-Za-z_]
	MaxFrameSize  uint32 = 0x3FFFFFFF
	// MaxHeaderSize is the upper bound, in bytes, that both encode and decode
	// enforce on the TTHeader header-info size.
	MaxHeaderSize uint32 = 65536
)
    82  
// HeaderFlags is the 16-bit FLAGS field of the TTHeader meta (bytes 6-7).
type HeaderFlags uint16

const (
	// HeaderFlagsKey is the message tag key under which setFlags stores, and
	// getFlags looks up, the HeaderFlags of a message.
	HeaderFlagsKey              string      = "HeaderFlags"
	HeaderFlagSupportOutOfOrder HeaderFlags = 0x01
	HeaderFlagDuplexReverse     HeaderFlags = 0x08
	HeaderFlagSASL              HeaderFlags = 0x10
)

const (
	// TTHeaderMetaSize is the size of the fixed TTHeader prefix:
	// LENGTH(4) + MAGIC/FLAGS(4) + SEQUENCE NUMBER(4) + HEADER SIZE(2).
	TTHeaderMetaSize = 14
)
    95  
// ProtocolID is the wrapped protocol id used in THeader.
// It is the first byte of the TTHeader header info.
type ProtocolID uint8

// Supported ProtocolID values.
const (
	ProtocolIDThriftBinary    ProtocolID = 0x00
	ProtocolIDThriftCompact   ProtocolID = 0x02 // not supported by Kitex
	ProtocolIDThriftCompactV2 ProtocolID = 0x03 // not supported by Kitex; still accepted by checkProtocolID for compatibility
	ProtocolIDKitexProtobuf   ProtocolID = 0x04
	// ProtocolIDDefault is what getProtocolID writes for every codec type.
	ProtocolIDDefault                    = ProtocolIDThriftBinary
)
   107  
// InfoIDType tags each section of the TTHeader info area (the part of the
// header info that follows the protocol id and transform ids).
type InfoIDType uint8

const (
	// InfoIDPadding is a zero byte; writeKVInfo pads the header info to a
	// multiple of 4 bytes with these and readKVInfo skips them.
	InfoIDPadding     InfoIDType = 0
	// InfoIDKeyValue precedes a u16 count and that many string/string pairs.
	InfoIDKeyValue    InfoIDType = 0x01
	// InfoIDIntKeyValue precedes a u16 count and that many u16/string pairs.
	InfoIDIntKeyValue InfoIDType = 0x10
	// InfoIDACLToken precedes a single string carrying transmeta.GDPRToken.
	InfoIDACLToken    InfoIDType = 0x11
)

// ttHeader encodes and decodes the TTHeader framing; it holds no state.
type ttHeader struct{}
   118  
   119  func (t ttHeader) encode(ctx context.Context, message remote.Message, out remote.ByteBuffer) (totalLenField []byte, err error) {
   120  	// 1. header meta
   121  	var headerMeta []byte
   122  	headerMeta, err = out.Malloc(TTHeaderMetaSize)
   123  	if err != nil {
   124  		return nil, perrors.NewProtocolErrorWithMsg(fmt.Sprintf("ttHeader malloc header meta failed, %s", err.Error()))
   125  	}
   126  
   127  	totalLenField = headerMeta[0:4]
   128  	headerInfoSizeField := headerMeta[12:14]
   129  	binary.BigEndian.PutUint32(headerMeta[4:8], TTHeaderMagic+uint32(getFlags(message)))
   130  	binary.BigEndian.PutUint32(headerMeta[8:12], uint32(message.RPCInfo().Invocation().SeqID()))
   131  
   132  	var transformIDs []uint8 // transformIDs not support TODO compress
   133  	// 2.  header info, malloc and write
   134  	if err = WriteByte(byte(getProtocolID(message.ProtocolInfo())), out); err != nil {
   135  		return nil, perrors.NewProtocolErrorWithMsg(fmt.Sprintf("ttHeader write protocol id failed, %s", err.Error()))
   136  	}
   137  	if err = WriteByte(byte(len(transformIDs)), out); err != nil {
   138  		return nil, perrors.NewProtocolErrorWithMsg(fmt.Sprintf("ttHeader write transformIDs length failed, %s", err.Error()))
   139  	}
   140  	for tid := range transformIDs {
   141  		if err = WriteByte(byte(tid), out); err != nil {
   142  			return nil, perrors.NewProtocolErrorWithMsg(fmt.Sprintf("ttHeader write transformIDs failed, %s", err.Error()))
   143  		}
   144  	}
   145  	// PROTOCOL ID(u8) + NUM TRANSFORMS(always 0)(u8) + TRANSFORM IDs([]u8)
   146  	headerInfoSize := 1 + 1 + len(transformIDs)
   147  	headerInfoSize, err = writeKVInfo(headerInfoSize, message, out)
   148  	if err != nil {
   149  		return nil, perrors.NewProtocolErrorWithMsg(fmt.Sprintf("ttHeader write kv info failed, %s", err.Error()))
   150  	}
   151  
   152  	if uint32(headerInfoSize) > MaxHeaderSize {
   153  		return nil, perrors.NewProtocolErrorWithMsg(fmt.Sprintf("invalid header length[%d]", headerInfoSize))
   154  	}
   155  	binary.BigEndian.PutUint16(headerInfoSizeField, uint16(headerInfoSize/4))
   156  	return totalLenField, err
   157  }
   158  
   159  func (t ttHeader) decode(ctx context.Context, message remote.Message, in remote.ByteBuffer) error {
   160  	headerMeta, err := in.Next(TTHeaderMetaSize)
   161  	if err != nil {
   162  		return perrors.NewProtocolError(err)
   163  	}
   164  	if !IsTTHeader(headerMeta) {
   165  		return perrors.NewProtocolErrorWithMsg("not TTHeader protocol")
   166  	}
   167  	totalLen := Bytes2Uint32NoCheck(headerMeta[:Size32])
   168  
   169  	flags := Bytes2Uint16NoCheck(headerMeta[Size16*3:])
   170  	setFlags(flags, message)
   171  
   172  	seqID := Bytes2Uint32NoCheck(headerMeta[Size32*2 : Size32*3])
   173  	if err = SetOrCheckSeqID(int32(seqID), message); err != nil {
   174  		klog.Warnf("the seqID in TTHeader check failed, error=%s", err.Error())
   175  		// some framework doesn't write correct seqID in TTheader, to ignore err only check it in payload
   176  		// print log to push the downstream framework to refine it.
   177  	}
   178  	headerInfoSize := Bytes2Uint16NoCheck(headerMeta[Size32*3:TTHeaderMetaSize]) * 4
   179  	if uint32(headerInfoSize) > MaxHeaderSize || headerInfoSize < 2 {
   180  		return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("invalid header length[%d]", headerInfoSize))
   181  	}
   182  
   183  	var headerInfo []byte
   184  	if headerInfo, err = in.Next(int(headerInfoSize)); err != nil {
   185  		return perrors.NewProtocolError(err)
   186  	}
   187  	if err = checkProtocolID(headerInfo[0], message); err != nil {
   188  		return err
   189  	}
   190  	hdIdx := 2
   191  	transformIDNum := int(headerInfo[1])
   192  	if int(headerInfoSize)-hdIdx < transformIDNum {
   193  		return perrors.NewProtocolErrorWithType(perrors.InvalidData, fmt.Sprintf("need read %d transformIDs, but not enough", transformIDNum))
   194  	}
   195  	transformIDs := make([]uint8, transformIDNum)
   196  	for i := 0; i < transformIDNum; i++ {
   197  		transformIDs[i] = headerInfo[hdIdx]
   198  		hdIdx++
   199  	}
   200  
   201  	if err := readKVInfo(hdIdx, headerInfo, message); err != nil {
   202  		return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("ttHeader read kv info failed, %s, headerInfo=%#x", err.Error(), headerInfo))
   203  	}
   204  	fillBasicInfoOfTTHeader(message)
   205  
   206  	message.SetPayloadLen(int(totalLen - uint32(headerInfoSize) + Size32 - TTHeaderMetaSize))
   207  	return err
   208  }
   209  
   210  func writeKVInfo(writtenSize int, message remote.Message, out remote.ByteBuffer) (writeSize int, err error) {
   211  	writeSize = writtenSize
   212  	tm := message.TransInfo()
   213  	// str kv info
   214  	strKVMap := tm.TransStrInfo()
   215  	strKVSize := len(strKVMap)
   216  	// write gdpr token into InfoIDACLToken
   217  	// supplementary doc: https://www.cloudwego.io/docs/kitex/reference/transport_protocol_ttheader/
   218  	if gdprToken, ok := strKVMap[transmeta.GDPRToken]; ok {
   219  		strKVSize--
   220  		// INFO ID TYPE(u8)
   221  		if err = WriteByte(byte(InfoIDACLToken), out); err != nil {
   222  			return writeSize, err
   223  		}
   224  		writeSize += 1
   225  
   226  		wLen, err := WriteString2BLen(gdprToken, out)
   227  		if err != nil {
   228  			return writeSize, err
   229  		}
   230  		writeSize += wLen
   231  	}
   232  
   233  	if strKVSize > 0 {
   234  		// INFO ID TYPE(u8) + NUM HEADERS(u16)
   235  		if err = WriteByte(byte(InfoIDKeyValue), out); err != nil {
   236  			return writeSize, err
   237  		}
   238  		if err = WriteUint16(uint16(strKVSize), out); err != nil {
   239  			return writeSize, err
   240  		}
   241  		writeSize += 3
   242  		for key, val := range strKVMap {
   243  			if key == transmeta.GDPRToken {
   244  				continue
   245  			}
   246  			keyWLen, err := WriteString2BLen(key, out)
   247  			if err != nil {
   248  				return writeSize, err
   249  			}
   250  			valWLen, err := WriteString2BLen(val, out)
   251  			if err != nil {
   252  				return writeSize, err
   253  			}
   254  			writeSize = writeSize + keyWLen + valWLen
   255  		}
   256  	}
   257  
   258  	// int kv info
   259  	intKVSize := len(tm.TransIntInfo())
   260  	if intKVSize > 0 {
   261  		// INFO ID TYPE(u8) + NUM HEADERS(u16)
   262  		if err = WriteByte(byte(InfoIDIntKeyValue), out); err != nil {
   263  			return writeSize, err
   264  		}
   265  		if err = WriteUint16(uint16(intKVSize), out); err != nil {
   266  			return writeSize, err
   267  		}
   268  		writeSize += 3
   269  		for key, val := range tm.TransIntInfo() {
   270  			if err = WriteUint16(key, out); err != nil {
   271  				return writeSize, err
   272  			}
   273  			valWLen, err := WriteString2BLen(val, out)
   274  			if err != nil {
   275  				return writeSize, err
   276  			}
   277  			writeSize = writeSize + 2 + valWLen
   278  		}
   279  	}
   280  
   281  	// padding = (4 - headerInfoSize%4) % 4
   282  	padding := (4 - writeSize%4) % 4
   283  	paddingBuf, err := out.Malloc(padding)
   284  	if err != nil {
   285  		return writeSize, err
   286  	}
   287  	for i := 0; i < len(paddingBuf); i++ {
   288  		paddingBuf[i] = byte(0)
   289  	}
   290  	writeSize += padding
   291  	return
   292  }
   293  
   294  func readKVInfo(idx int, buf []byte, message remote.Message) error {
   295  	intInfo := message.TransInfo().TransIntInfo()
   296  	strInfo := message.TransInfo().TransStrInfo()
   297  	for {
   298  		infoID, err := Bytes2Uint8(buf, idx)
   299  		idx++
   300  		if err != nil {
   301  			// this is the last field, read until there is no more padding
   302  			if err == io.EOF {
   303  				break
   304  			} else {
   305  				return err
   306  			}
   307  		}
   308  		switch InfoIDType(infoID) {
   309  		case InfoIDPadding:
   310  			continue
   311  		case InfoIDKeyValue:
   312  			_, err := readStrKVInfo(&idx, buf, strInfo)
   313  			if err != nil {
   314  				return err
   315  			}
   316  		case InfoIDIntKeyValue:
   317  			_, err := readIntKVInfo(&idx, buf, intInfo)
   318  			if err != nil {
   319  				return err
   320  			}
   321  		case InfoIDACLToken:
   322  			if err := readACLToken(&idx, buf, strInfo); err != nil {
   323  				return err
   324  			}
   325  		default:
   326  			return fmt.Errorf("invalid infoIDType[%#x]", infoID)
   327  		}
   328  	}
   329  	return nil
   330  }
   331  
   332  func readIntKVInfo(idx *int, buf []byte, info map[uint16]string) (has bool, err error) {
   333  	kvSize, err := Bytes2Uint16(buf, *idx)
   334  	*idx += 2
   335  	if err != nil {
   336  		return false, fmt.Errorf("error reading int kv info size: %s", err.Error())
   337  	}
   338  	if kvSize <= 0 {
   339  		return false, nil
   340  	}
   341  	for i := uint16(0); i < kvSize; i++ {
   342  		key, err := Bytes2Uint16(buf, *idx)
   343  		*idx += 2
   344  		if err != nil {
   345  			return false, fmt.Errorf("error reading int kv info: %s", err.Error())
   346  		}
   347  		val, n, err := ReadString2BLen(buf, *idx)
   348  		*idx += n
   349  		if err != nil {
   350  			return false, fmt.Errorf("error reading int kv info: %s", err.Error())
   351  		}
   352  		info[key] = val
   353  	}
   354  	return true, nil
   355  }
   356  
   357  func readStrKVInfo(idx *int, buf []byte, info map[string]string) (has bool, err error) {
   358  	kvSize, err := Bytes2Uint16(buf, *idx)
   359  	*idx += 2
   360  	if err != nil {
   361  		return false, fmt.Errorf("error reading str kv info size: %s", err.Error())
   362  	}
   363  	if kvSize <= 0 {
   364  		return false, nil
   365  	}
   366  	for i := uint16(0); i < kvSize; i++ {
   367  		key, n, err := ReadString2BLen(buf, *idx)
   368  		*idx += n
   369  		if err != nil {
   370  			return false, fmt.Errorf("error reading str kv info: %s", err.Error())
   371  		}
   372  		val, n, err := ReadString2BLen(buf, *idx)
   373  		*idx += n
   374  		if err != nil {
   375  			return false, fmt.Errorf("error reading str kv info: %s", err.Error())
   376  		}
   377  		info[key] = val
   378  	}
   379  	return true, nil
   380  }
   381  
   382  // readACLToken reads acl token
   383  func readACLToken(idx *int, buf []byte, info map[string]string) error {
   384  	val, n, err := ReadString2BLen(buf, *idx)
   385  	*idx += n
   386  	if err != nil {
   387  		return fmt.Errorf("error reading acl token: %s", err.Error())
   388  	}
   389  	info[transmeta.GDPRToken] = val
   390  	return nil
   391  }
   392  
   393  func getFlags(message remote.Message) HeaderFlags {
   394  	var headerFlags HeaderFlags
   395  	if message.Tags() != nil && message.Tags()[HeaderFlagsKey] != nil {
   396  		if hfs, ok := message.Tags()[HeaderFlagsKey].(HeaderFlags); ok {
   397  			headerFlags = hfs
   398  		} else {
   399  			klog.Warnf("KITEX: the type of headerFlags is invalid, %T", message.Tags()[HeaderFlagsKey])
   400  		}
   401  	}
   402  	return headerFlags
   403  }
   404  
   405  func setFlags(flags uint16, message remote.Message) {
   406  	if message.MessageType() == remote.Call {
   407  		message.Tags()[HeaderFlagsKey] = HeaderFlags(flags)
   408  	}
   409  }
   410  
   411  // protoID just for ttheader
   412  func getProtocolID(pi remote.ProtocolInfo) ProtocolID {
   413  	switch pi.CodecType {
   414  	case serviceinfo.Protobuf:
   415  		// ProtocolIDKitexProtobuf is 0x03 at old version(<=v1.9.1) , but it conflicts with ThriftCompactV2.
   416  		// Change the ProtocolIDKitexProtobuf to 0x04 from v1.9.2. But notice! that it is an incompatible change of protocol.
   417  		// For keeping compatible, Kitex use ProtocolIDDefault send ttheader+KitexProtobuf request to ignore the old version
   418  		// check failed if use 0x04. It doesn't make sense, but it won't affect the correctness of RPC call because the actual
   419  		// protocol check at checkPayload func which check payload with HEADER MAGIC bytes of payload.
   420  		return ProtocolIDDefault
   421  	}
   422  	return ProtocolIDDefault
   423  }
   424  
   425  // protoID just for ttheader
   426  func checkProtocolID(protoID uint8, message remote.Message) error {
   427  	switch protoID {
   428  	case uint8(ProtocolIDThriftBinary):
   429  	case uint8(ProtocolIDKitexProtobuf):
   430  	case uint8(ProtocolIDThriftCompactV2):
   431  		// just for compatibility
   432  	default:
   433  		return fmt.Errorf("unsupported ProtocolID[%d]", protoID)
   434  	}
   435  	return nil
   436  }
   437  
   438  /**
   439   * +-------------2Byte-------------|-------------2Byte--------------+
   440   * +----------------------------------------------------------------+
   441   * |       HEADER MAGIC            |      HEADER SIZE               |
   442   * +----------------------------------------------------------------+
   443   * |       HEADER MAP SIZE         |    HEADER MAP...               |
   444   * +----------------------------------------------------------------+
   445   * |                                                                |
   446   * |                            PAYLOAD                             |
   447   * |                                                                |
   448   * +----------------------------------------------------------------+
   449   */
   450  type meshHeader struct{}
   451  
   452  //lint:ignore U1000 until encode is used
   453  func (m meshHeader) encode(ctx context.Context, message remote.Message, payloadBuf, out remote.ByteBuffer) error {
   454  	// do nothing, kitex just support decode meshHeader, encode protocol depend on the payload
   455  	return nil
   456  }
   457  
   458  func (m meshHeader) decode(ctx context.Context, message remote.Message, in remote.ByteBuffer) error {
   459  	headerMeta, err := in.Next(Size32)
   460  	if err != nil {
   461  		return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("meshHeader read header meta failed, %s", err.Error()))
   462  	}
   463  	if !isMeshHeader(headerMeta) {
   464  		return perrors.NewProtocolErrorWithMsg("not MeshHeader protocol")
   465  	}
   466  	headerLen := Bytes2Uint16NoCheck(headerMeta[Size16:])
   467  	var headerInfo []byte
   468  	if headerInfo, err = in.Next(int(headerLen)); err != nil {
   469  		return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("meshHeader read header buf failed, %s", err.Error()))
   470  	}
   471  	mapInfo := message.TransInfo().TransStrInfo()
   472  	idx := 0
   473  	if _, err = readStrKVInfo(&idx, headerInfo, mapInfo); err != nil {
   474  		return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("meshHeader read kv info failed, %s", err.Error()))
   475  	}
   476  	fillBasicInfoOfTTHeader(message)
   477  	return nil
   478  }
   479  
   480  // Fill basic from_info(from service, from address) which carried by ttheader to rpcinfo.
   481  // It is better to fill rpcinfo in matahandlers in terms of design,
   482  // but metahandlers are executed after payloadDecode, we don't know from_info when error happen in payloadDecode.
   483  // So 'fillBasicInfoOfTTHeader' is just for getting more info to output log when decode error happen.
   484  func fillBasicInfoOfTTHeader(msg remote.Message) {
   485  	if msg.RPCRole() == remote.Server {
   486  		fi := rpcinfo.AsMutableEndpointInfo(msg.RPCInfo().From())
   487  		if fi != nil {
   488  			if v := msg.TransInfo().TransStrInfo()[transmeta.HeaderTransRemoteAddr]; v != "" {
   489  				fi.SetAddress(utils.NewNetAddr("tcp", v))
   490  			}
   491  			if v := msg.TransInfo().TransIntInfo()[transmeta.FromService]; v != "" {
   492  				fi.SetServiceName(v)
   493  			}
   494  		}
   495  		if ink, ok := msg.RPCInfo().Invocation().(rpcinfo.InvocationSetter); ok {
   496  			if svcName, ok := msg.TransInfo().TransStrInfo()[transmeta.HeaderIDLServiceName]; ok {
   497  				ink.SetServiceName(svcName)
   498  			}
   499  		}
   500  	} else {
   501  		ti := remoteinfo.AsRemoteInfo(msg.RPCInfo().To())
   502  		if ti != nil {
   503  			if v := msg.TransInfo().TransStrInfo()[transmeta.HeaderTransRemoteAddr]; v != "" {
   504  				ti.SetRemoteAddr(utils.NewNetAddr("tcp", v))
   505  			}
   506  		}
   507  	}
   508  }