github.com/cloudwego/kitex@v0.9.0/pkg/remote/codec/header_codec.go (about) 1 /* 2 * Copyright 2021 CloudWeGo Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package codec 18 19 import ( 20 "context" 21 "encoding/binary" 22 "fmt" 23 "io" 24 25 "github.com/cloudwego/kitex/pkg/klog" 26 "github.com/cloudwego/kitex/pkg/remote" 27 "github.com/cloudwego/kitex/pkg/remote/codec/perrors" 28 "github.com/cloudwego/kitex/pkg/remote/transmeta" 29 "github.com/cloudwego/kitex/pkg/rpcinfo" 30 "github.com/cloudwego/kitex/pkg/rpcinfo/remoteinfo" 31 "github.com/cloudwego/kitex/pkg/serviceinfo" 32 "github.com/cloudwego/kitex/pkg/utils" 33 ) 34 35 /** 36 * TTHeader Protocol 37 * +-------------2Byte--------------|-------------2Byte-------------+ 38 * +----------------------------------------------------------------+ 39 * | 0| LENGTH | 40 * +----------------------------------------------------------------+ 41 * | 0| HEADER MAGIC | FLAGS | 42 * +----------------------------------------------------------------+ 43 * | SEQUENCE NUMBER | 44 * +----------------------------------------------------------------+ 45 * | 0| Header Size(/32) | ... 46 * +--------------------------------- 47 * 48 * Header is of variable size: 49 * (and starts at offset 14) 50 * 51 * +----------------------------------------------------------------+ 52 * | PROTOCOL ID |NUM TRANSFORMS . |TRANSFORM 0 ID (uint8)| 53 * +----------------------------------------------------------------+ 54 * | TRANSFORM 0 DATA ... 55 * +----------------------------------------------------------------+ 56 * | ... ... | 57 * +----------------------------------------------------------------+ 58 * | INFO 0 ID (uint8) | INFO 0 DATA ... 59 * +----------------------------------------------------------------+ 60 * | ... ... | 61 * +----------------------------------------------------------------+ 62 * | | 63 * | PAYLOAD | 64 * | | 65 * +----------------------------------------------------------------+ 66 */ 67 68 // Header keys 69 const ( 70 // Header Magics 71 // 0 and 16th bits must be 0 to differentiate from framed & unframed 72 TTHeaderMagic uint32 = 0x10000000 73 MeshHeaderMagic uint32 = 0xFFAF0000 74 MeshHeaderLenMask uint32 = 0x0000FFFF 75 76 // HeaderMask uint32 = 0xFFFF0000 77 FlagsMask uint32 = 0x0000FFFF 78 MethodMask uint32 = 0x41000000 // method first byte [A-Za-z_] 79 MaxFrameSize uint32 = 0x3FFFFFFF 80 MaxHeaderSize uint32 = 65536 81 ) 82 83 type HeaderFlags uint16 84 85 const ( 86 HeaderFlagsKey string = "HeaderFlags" 87 HeaderFlagSupportOutOfOrder HeaderFlags = 0x01 88 HeaderFlagDuplexReverse HeaderFlags = 0x08 89 HeaderFlagSASL HeaderFlags = 0x10 90 ) 91 92 const ( 93 TTHeaderMetaSize = 14 94 ) 95 96 // ProtocolID is the wrapped protocol id used in THeader. 97 type ProtocolID uint8 98 99 // Supported ProtocolID values. 100 const ( 101 ProtocolIDThriftBinary ProtocolID = 0x00 102 ProtocolIDThriftCompact ProtocolID = 0x02 // Kitex not support 103 ProtocolIDThriftCompactV2 ProtocolID = 0x03 // Kitex not support 104 ProtocolIDKitexProtobuf ProtocolID = 0x04 105 ProtocolIDDefault = ProtocolIDThriftBinary 106 ) 107 108 type InfoIDType uint8 // uint8 109 110 const ( 111 InfoIDPadding InfoIDType = 0 112 InfoIDKeyValue InfoIDType = 0x01 113 InfoIDIntKeyValue InfoIDType = 0x10 114 InfoIDACLToken InfoIDType = 0x11 115 ) 116 117 type ttHeader struct{} 118 119 func (t ttHeader) encode(ctx context.Context, message remote.Message, out remote.ByteBuffer) (totalLenField []byte, err error) { 120 // 1. header meta 121 var headerMeta []byte 122 headerMeta, err = out.Malloc(TTHeaderMetaSize) 123 if err != nil { 124 return nil, perrors.NewProtocolErrorWithMsg(fmt.Sprintf("ttHeader malloc header meta failed, %s", err.Error())) 125 } 126 127 totalLenField = headerMeta[0:4] 128 headerInfoSizeField := headerMeta[12:14] 129 binary.BigEndian.PutUint32(headerMeta[4:8], TTHeaderMagic+uint32(getFlags(message))) 130 binary.BigEndian.PutUint32(headerMeta[8:12], uint32(message.RPCInfo().Invocation().SeqID())) 131 132 var transformIDs []uint8 // transformIDs not support TODO compress 133 // 2. header info, malloc and write 134 if err = WriteByte(byte(getProtocolID(message.ProtocolInfo())), out); err != nil { 135 return nil, perrors.NewProtocolErrorWithMsg(fmt.Sprintf("ttHeader write protocol id failed, %s", err.Error())) 136 } 137 if err = WriteByte(byte(len(transformIDs)), out); err != nil { 138 return nil, perrors.NewProtocolErrorWithMsg(fmt.Sprintf("ttHeader write transformIDs length failed, %s", err.Error())) 139 } 140 for tid := range transformIDs { 141 if err = WriteByte(byte(tid), out); err != nil { 142 return nil, perrors.NewProtocolErrorWithMsg(fmt.Sprintf("ttHeader write transformIDs failed, %s", err.Error())) 143 } 144 } 145 // PROTOCOL ID(u8) + NUM TRANSFORMS(always 0)(u8) + TRANSFORM IDs([]u8) 146 headerInfoSize := 1 + 1 + len(transformIDs) 147 headerInfoSize, err = writeKVInfo(headerInfoSize, message, out) 148 if err != nil { 149 return nil, perrors.NewProtocolErrorWithMsg(fmt.Sprintf("ttHeader write kv info failed, %s", err.Error())) 150 } 151 152 if uint32(headerInfoSize) > MaxHeaderSize { 153 return nil, perrors.NewProtocolErrorWithMsg(fmt.Sprintf("invalid header length[%d]", headerInfoSize)) 154 } 155 binary.BigEndian.PutUint16(headerInfoSizeField, uint16(headerInfoSize/4)) 156 return totalLenField, err 157 } 158 159 func (t ttHeader) decode(ctx context.Context, message remote.Message, in remote.ByteBuffer) error { 160 headerMeta, err := in.Next(TTHeaderMetaSize) 161 if err != nil { 162 return perrors.NewProtocolError(err) 163 } 164 if !IsTTHeader(headerMeta) { 165 return perrors.NewProtocolErrorWithMsg("not TTHeader protocol") 166 } 167 totalLen := Bytes2Uint32NoCheck(headerMeta[:Size32]) 168 169 flags := Bytes2Uint16NoCheck(headerMeta[Size16*3:]) 170 setFlags(flags, message) 171 172 seqID := Bytes2Uint32NoCheck(headerMeta[Size32*2 : Size32*3]) 173 if err = SetOrCheckSeqID(int32(seqID), message); err != nil { 174 klog.Warnf("the seqID in TTHeader check failed, error=%s", err.Error()) 175 // some framework doesn't write correct seqID in TTheader, to ignore err only check it in payload 176 // print log to push the downstream framework to refine it. 177 } 178 headerInfoSize := Bytes2Uint16NoCheck(headerMeta[Size32*3:TTHeaderMetaSize]) * 4 179 if uint32(headerInfoSize) > MaxHeaderSize || headerInfoSize < 2 { 180 return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("invalid header length[%d]", headerInfoSize)) 181 } 182 183 var headerInfo []byte 184 if headerInfo, err = in.Next(int(headerInfoSize)); err != nil { 185 return perrors.NewProtocolError(err) 186 } 187 if err = checkProtocolID(headerInfo[0], message); err != nil { 188 return err 189 } 190 hdIdx := 2 191 transformIDNum := int(headerInfo[1]) 192 if int(headerInfoSize)-hdIdx < transformIDNum { 193 return perrors.NewProtocolErrorWithType(perrors.InvalidData, fmt.Sprintf("need read %d transformIDs, but not enough", transformIDNum)) 194 } 195 transformIDs := make([]uint8, transformIDNum) 196 for i := 0; i < transformIDNum; i++ { 197 transformIDs[i] = headerInfo[hdIdx] 198 hdIdx++ 199 } 200 201 if err := readKVInfo(hdIdx, headerInfo, message); err != nil { 202 return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("ttHeader read kv info failed, %s, headerInfo=%#x", err.Error(), headerInfo)) 203 } 204 fillBasicInfoOfTTHeader(message) 205 206 message.SetPayloadLen(int(totalLen - uint32(headerInfoSize) + Size32 - TTHeaderMetaSize)) 207 return err 208 } 209 210 func writeKVInfo(writtenSize int, message remote.Message, out remote.ByteBuffer) (writeSize int, err error) { 211 writeSize = writtenSize 212 tm := message.TransInfo() 213 // str kv info 214 strKVMap := tm.TransStrInfo() 215 strKVSize := len(strKVMap) 216 // write gdpr token into InfoIDACLToken 217 // supplementary doc: https://www.cloudwego.io/docs/kitex/reference/transport_protocol_ttheader/ 218 if gdprToken, ok := strKVMap[transmeta.GDPRToken]; ok { 219 strKVSize-- 220 // INFO ID TYPE(u8) 221 if err = WriteByte(byte(InfoIDACLToken), out); err != nil { 222 return writeSize, err 223 } 224 writeSize += 1 225 226 wLen, err := WriteString2BLen(gdprToken, out) 227 if err != nil { 228 return writeSize, err 229 } 230 writeSize += wLen 231 } 232 233 if strKVSize > 0 { 234 // INFO ID TYPE(u8) + NUM HEADERS(u16) 235 if err = WriteByte(byte(InfoIDKeyValue), out); err != nil { 236 return writeSize, err 237 } 238 if err = WriteUint16(uint16(strKVSize), out); err != nil { 239 return writeSize, err 240 } 241 writeSize += 3 242 for key, val := range strKVMap { 243 if key == transmeta.GDPRToken { 244 continue 245 } 246 keyWLen, err := WriteString2BLen(key, out) 247 if err != nil { 248 return writeSize, err 249 } 250 valWLen, err := WriteString2BLen(val, out) 251 if err != nil { 252 return writeSize, err 253 } 254 writeSize = writeSize + keyWLen + valWLen 255 } 256 } 257 258 // int kv info 259 intKVSize := len(tm.TransIntInfo()) 260 if intKVSize > 0 { 261 // INFO ID TYPE(u8) + NUM HEADERS(u16) 262 if err = WriteByte(byte(InfoIDIntKeyValue), out); err != nil { 263 return writeSize, err 264 } 265 if err = WriteUint16(uint16(intKVSize), out); err != nil { 266 return writeSize, err 267 } 268 writeSize += 3 269 for key, val := range tm.TransIntInfo() { 270 if err = WriteUint16(key, out); err != nil { 271 return writeSize, err 272 } 273 valWLen, err := WriteString2BLen(val, out) 274 if err != nil { 275 return writeSize, err 276 } 277 writeSize = writeSize + 2 + valWLen 278 } 279 } 280 281 // padding = (4 - headerInfoSize%4) % 4 282 padding := (4 - writeSize%4) % 4 283 paddingBuf, err := out.Malloc(padding) 284 if err != nil { 285 return writeSize, err 286 } 287 for i := 0; i < len(paddingBuf); i++ { 288 paddingBuf[i] = byte(0) 289 } 290 writeSize += padding 291 return 292 } 293 294 func readKVInfo(idx int, buf []byte, message remote.Message) error { 295 intInfo := message.TransInfo().TransIntInfo() 296 strInfo := message.TransInfo().TransStrInfo() 297 for { 298 infoID, err := Bytes2Uint8(buf, idx) 299 idx++ 300 if err != nil { 301 // this is the last field, read until there is no more padding 302 if err == io.EOF { 303 break 304 } else { 305 return err 306 } 307 } 308 switch InfoIDType(infoID) { 309 case InfoIDPadding: 310 continue 311 case InfoIDKeyValue: 312 _, err := readStrKVInfo(&idx, buf, strInfo) 313 if err != nil { 314 return err 315 } 316 case InfoIDIntKeyValue: 317 _, err := readIntKVInfo(&idx, buf, intInfo) 318 if err != nil { 319 return err 320 } 321 case InfoIDACLToken: 322 if err := readACLToken(&idx, buf, strInfo); err != nil { 323 return err 324 } 325 default: 326 return fmt.Errorf("invalid infoIDType[%#x]", infoID) 327 } 328 } 329 return nil 330 } 331 332 func readIntKVInfo(idx *int, buf []byte, info map[uint16]string) (has bool, err error) { 333 kvSize, err := Bytes2Uint16(buf, *idx) 334 *idx += 2 335 if err != nil { 336 return false, fmt.Errorf("error reading int kv info size: %s", err.Error()) 337 } 338 if kvSize <= 0 { 339 return false, nil 340 } 341 for i := uint16(0); i < kvSize; i++ { 342 key, err := Bytes2Uint16(buf, *idx) 343 *idx += 2 344 if err != nil { 345 return false, fmt.Errorf("error reading int kv info: %s", err.Error()) 346 } 347 val, n, err := ReadString2BLen(buf, *idx) 348 *idx += n 349 if err != nil { 350 return false, fmt.Errorf("error reading int kv info: %s", err.Error()) 351 } 352 info[key] = val 353 } 354 return true, nil 355 } 356 357 func readStrKVInfo(idx *int, buf []byte, info map[string]string) (has bool, err error) { 358 kvSize, err := Bytes2Uint16(buf, *idx) 359 *idx += 2 360 if err != nil { 361 return false, fmt.Errorf("error reading str kv info size: %s", err.Error()) 362 } 363 if kvSize <= 0 { 364 return false, nil 365 } 366 for i := uint16(0); i < kvSize; i++ { 367 key, n, err := ReadString2BLen(buf, *idx) 368 *idx += n 369 if err != nil { 370 return false, fmt.Errorf("error reading str kv info: %s", err.Error()) 371 } 372 val, n, err := ReadString2BLen(buf, *idx) 373 *idx += n 374 if err != nil { 375 return false, fmt.Errorf("error reading str kv info: %s", err.Error()) 376 } 377 info[key] = val 378 } 379 return true, nil 380 } 381 382 // readACLToken reads acl token 383 func readACLToken(idx *int, buf []byte, info map[string]string) error { 384 val, n, err := ReadString2BLen(buf, *idx) 385 *idx += n 386 if err != nil { 387 return fmt.Errorf("error reading acl token: %s", err.Error()) 388 } 389 info[transmeta.GDPRToken] = val 390 return nil 391 } 392 393 func getFlags(message remote.Message) HeaderFlags { 394 var headerFlags HeaderFlags 395 if message.Tags() != nil && message.Tags()[HeaderFlagsKey] != nil { 396 if hfs, ok := message.Tags()[HeaderFlagsKey].(HeaderFlags); ok { 397 headerFlags = hfs 398 } else { 399 klog.Warnf("KITEX: the type of headerFlags is invalid, %T", message.Tags()[HeaderFlagsKey]) 400 } 401 } 402 return headerFlags 403 } 404 405 func setFlags(flags uint16, message remote.Message) { 406 if message.MessageType() == remote.Call { 407 message.Tags()[HeaderFlagsKey] = HeaderFlags(flags) 408 } 409 } 410 411 // protoID just for ttheader 412 func getProtocolID(pi remote.ProtocolInfo) ProtocolID { 413 switch pi.CodecType { 414 case serviceinfo.Protobuf: 415 // ProtocolIDKitexProtobuf is 0x03 at old version(<=v1.9.1) , but it conflicts with ThriftCompactV2. 416 // Change the ProtocolIDKitexProtobuf to 0x04 from v1.9.2. But notice! that it is an incompatible change of protocol. 417 // For keeping compatible, Kitex use ProtocolIDDefault send ttheader+KitexProtobuf request to ignore the old version 418 // check failed if use 0x04. It doesn't make sense, but it won't affect the correctness of RPC call because the actual 419 // protocol check at checkPayload func which check payload with HEADER MAGIC bytes of payload. 420 return ProtocolIDDefault 421 } 422 return ProtocolIDDefault 423 } 424 425 // protoID just for ttheader 426 func checkProtocolID(protoID uint8, message remote.Message) error { 427 switch protoID { 428 case uint8(ProtocolIDThriftBinary): 429 case uint8(ProtocolIDKitexProtobuf): 430 case uint8(ProtocolIDThriftCompactV2): 431 // just for compatibility 432 default: 433 return fmt.Errorf("unsupported ProtocolID[%d]", protoID) 434 } 435 return nil 436 } 437 438 /** 439 * +-------------2Byte-------------|-------------2Byte--------------+ 440 * +----------------------------------------------------------------+ 441 * | HEADER MAGIC | HEADER SIZE | 442 * +----------------------------------------------------------------+ 443 * | HEADER MAP SIZE | HEADER MAP... | 444 * +----------------------------------------------------------------+ 445 * | | 446 * | PAYLOAD | 447 * | | 448 * +----------------------------------------------------------------+ 449 */ 450 type meshHeader struct{} 451 452 //lint:ignore U1000 until encode is used 453 func (m meshHeader) encode(ctx context.Context, message remote.Message, payloadBuf, out remote.ByteBuffer) error { 454 // do nothing, kitex just support decode meshHeader, encode protocol depend on the payload 455 return nil 456 } 457 458 func (m meshHeader) decode(ctx context.Context, message remote.Message, in remote.ByteBuffer) error { 459 headerMeta, err := in.Next(Size32) 460 if err != nil { 461 return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("meshHeader read header meta failed, %s", err.Error())) 462 } 463 if !isMeshHeader(headerMeta) { 464 return perrors.NewProtocolErrorWithMsg("not MeshHeader protocol") 465 } 466 headerLen := Bytes2Uint16NoCheck(headerMeta[Size16:]) 467 var headerInfo []byte 468 if headerInfo, err = in.Next(int(headerLen)); err != nil { 469 return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("meshHeader read header buf failed, %s", err.Error())) 470 } 471 mapInfo := message.TransInfo().TransStrInfo() 472 idx := 0 473 if _, err = readStrKVInfo(&idx, headerInfo, mapInfo); err != nil { 474 return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("meshHeader read kv info failed, %s", err.Error())) 475 } 476 fillBasicInfoOfTTHeader(message) 477 return nil 478 } 479 480 // Fill basic from_info(from service, from address) which carried by ttheader to rpcinfo. 481 // It is better to fill rpcinfo in matahandlers in terms of design, 482 // but metahandlers are executed after payloadDecode, we don't know from_info when error happen in payloadDecode. 483 // So 'fillBasicInfoOfTTHeader' is just for getting more info to output log when decode error happen. 484 func fillBasicInfoOfTTHeader(msg remote.Message) { 485 if msg.RPCRole() == remote.Server { 486 fi := rpcinfo.AsMutableEndpointInfo(msg.RPCInfo().From()) 487 if fi != nil { 488 if v := msg.TransInfo().TransStrInfo()[transmeta.HeaderTransRemoteAddr]; v != "" { 489 fi.SetAddress(utils.NewNetAddr("tcp", v)) 490 } 491 if v := msg.TransInfo().TransIntInfo()[transmeta.FromService]; v != "" { 492 fi.SetServiceName(v) 493 } 494 } 495 if ink, ok := msg.RPCInfo().Invocation().(rpcinfo.InvocationSetter); ok { 496 if svcName, ok := msg.TransInfo().TransStrInfo()[transmeta.HeaderIDLServiceName]; ok { 497 ink.SetServiceName(svcName) 498 } 499 } 500 } else { 501 ti := remoteinfo.AsRemoteInfo(msg.RPCInfo().To()) 502 if ti != nil { 503 if v := msg.TransInfo().TransStrInfo()[transmeta.HeaderTransRemoteAddr]; v != "" { 504 ti.SetRemoteAddr(utils.NewNetAddr("tcp", v)) 505 } 506 } 507 } 508 }