trpc.group/trpc-go/trpc-go@v1.0.3/codec_test.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package trpc_test 15 16 import ( 17 "bytes" 18 "context" 19 "encoding/binary" 20 "errors" 21 "log" 22 "net" 23 "regexp" 24 "testing" 25 "time" 26 27 "github.com/stretchr/testify/assert" 28 "github.com/stretchr/testify/require" 29 "google.golang.org/protobuf/proto" 30 "trpc.group/trpc-go/trpc-go/internal/attachment" 31 trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc" 32 33 trpc "trpc.group/trpc-go/trpc-go" 34 "trpc.group/trpc-go/trpc-go/codec" 35 "trpc.group/trpc-go/trpc-go/errs" 36 "trpc.group/trpc-go/trpc-go/pool/multiplexed" 37 pb "trpc.group/trpc-go/trpc-go/testdata/trpc/helloworld" 38 ) 39 40 func TestFramer_ReadFrame(t *testing.T) { 41 // test magic num mismatch 42 { 43 var err error 44 totalLen := 0 45 buf := new(bytes.Buffer) 46 // MagicNum 0x930, 2bytes 47 assert.Nil(t, binary.Write(buf, binary.BigEndian, uint16(trpcpb.TrpcMagic_TRPC_MAGIC_VALUE+1))) 48 // frame type, 1byte 49 assert.Nil(t, binary.Write(buf, binary.BigEndian, uint8(0))) 50 // stream frame type, 1byte 51 assert.Nil(t, binary.Write(buf, binary.BigEndian, uint8(0))) 52 // total len 53 assert.Nil(t, binary.Write(buf, binary.BigEndian, uint32(totalLen))) 54 // pb header len 55 assert.Nil(t, binary.Write(buf, binary.BigEndian, uint16(0))) 56 // stream ID 57 assert.Nil(t, binary.Write(buf, binary.BigEndian, uint16(0))) 58 // reserved 59 assert.Nil(t, binary.Write(buf, binary.BigEndian, uint32(0))) 60 assert.Nil(t, err) 61 62 fb := &trpc.FramerBuilder{} 63 fr := fb.New(bytes.NewReader(buf.Bytes())) 64 assert.NotNil(t, fr) 65 _, err = fr.ReadFrame() 66 assert.NotNil(t, err) 67 } 68 69 // test total len exceed max error 70 { 71 var err error 72 totalLen := trpc.DefaultMaxFrameSize + 1 73 buf := new(bytes.Buffer) 74 // MagicNum 0x930, 2bytes 75 assert.Nil(t, binary.Write(buf, binary.BigEndian, uint16(trpcpb.TrpcMagic_TRPC_MAGIC_VALUE))) 76 // frame type, 1byte 77 assert.Nil(t, binary.Write(buf, binary.BigEndian, uint8(0))) 78 // stream frame type, 1byte 79 assert.Nil(t, binary.Write(buf, binary.BigEndian, uint8(0))) 80 // total len 81 assert.Nil(t, binary.Write(buf, binary.BigEndian, uint32(totalLen))) 82 assert.Nil(t, binary.Write(buf, binary.BigEndian, uint16(0))) 83 // stream ID 84 assert.Nil(t, binary.Write(buf, binary.BigEndian, uint16(0))) 85 // reserved 86 assert.Nil(t, binary.Write(buf, binary.BigEndian, uint32(0))) 87 assert.Nil(t, err) 88 89 fb := &trpc.FramerBuilder{} 90 fr := fb.New(bytes.NewReader(buf.Bytes())) 91 assert.NotNil(t, fr) 92 _, err = fr.ReadFrame() 93 assert.NotNil(t, err) 94 } 95 } 96 97 func TestClientCodecEnvTransfer(t *testing.T) { 98 envTransfer := []byte("env transfer") 99 cliCodec := &trpc.ClientCodec{} 100 101 // if msg.EnvTransfer() empty, transmitted env info in req.TransInfo should be cleared 102 _, msg := codec.WithNewMessage(context.Background()) 103 msg.WithClientMetaData(map[string][]byte{trpc.EnvTransfer: envTransfer}) 104 msg.WithEnvTransfer("") 105 reqBuf, err := cliCodec.Encode(msg, nil) 106 assert.Nil(t, err) 107 head := &trpcpb.RequestProtocol{} 108 err = proto.Unmarshal(reqBuf[16:], head) 109 assert.Nil(t, err) 110 assert.Equal(t, head.TransInfo[trpc.EnvTransfer], []byte{}) 111 112 // msg.EnvTransfer() not empty 113 _, msg = codec.WithNewMessage(context.Background()) 114 msg.WithEnvTransfer("env transfer") 115 reqBuf, err = cliCodec.Encode(msg, nil) 116 assert.Nil(t, err) 117 head = &trpcpb.RequestProtocol{} 118 err = proto.Unmarshal(reqBuf[16:], head) 119 assert.Nil(t, err) 120 assert.Equal(t, head.TransInfo[trpc.EnvTransfer], envTransfer) 121 } 122 123 func TestClientCodecDyeing(t *testing.T) { 124 dyeingKey := "123456789" 125 cliCodec := &trpc.ClientCodec{} 126 _, msg := codec.WithNewMessage(context.Background()) 127 msg.WithDyeingKey(dyeingKey) 128 reqBuf, err := cliCodec.Encode(msg, nil) 129 assert.Nil(t, err) 130 head := &trpcpb.RequestProtocol{} 131 err = proto.Unmarshal(reqBuf[16:], head) 132 assert.Nil(t, err) 133 assert.Equal(t, head.TransInfo[trpc.DyeingKey], []byte(dyeingKey)) 134 } 135 136 func TestFramerBuilder(t *testing.T) { 137 t.Run("frame build is a SafeFramer", func(t *testing.T) { 138 fb := trpc.FramerBuilder{} 139 frame := fb.New(bytes.NewReader(nil)) 140 require.True(t, frame.(codec.SafeFramer).IsSafe()) 141 }) 142 t.Run("ok, read valid response", func(t *testing.T) { 143 bts := mustEncode(t, []byte("hello-world")) 144 vid, buf, err := (&trpc.FramerBuilder{}).Parse(bytes.NewReader(bts)) 145 require.Nil(t, err) 146 require.Zero(t, vid) 147 require.Equal(t, bts, buf) 148 }) 149 t.Run("garbage data", func(t *testing.T) { 150 _, _, err := (&trpc.FramerBuilder{}).Parse(bytes.NewReader([]byte("hello-world xxxxxxxxxxxx"))) 151 require.Regexp(t, regexp.MustCompile(`magic .+ not match`), err.Error()) 152 }) 153 } 154 155 func mustEncode(t *testing.T, body []byte) (buffer []byte) { 156 t.Helper() 157 158 msgHead := &trpcpb.RequestProtocol{ 159 Version: uint32(trpcpb.TrpcProtoVersion_TRPC_PROTO_V1), 160 Callee: []byte("trpc.test.helloworld.Greetor"), 161 Func: []byte("/trpc.test.helloworld.Greetor/SayHello"), 162 } 163 head, err := proto.Marshal(msgHead) 164 if err != nil { 165 t.Fatal(err) 166 } 167 168 buf := new(bytes.Buffer) 169 // MagicNum 0x930, 2bytes 170 if err := binary.Write(buf, binary.BigEndian, uint16(trpcpb.TrpcMagic_TRPC_MAGIC_VALUE)); err != nil { 171 t.Fatal(err) 172 } 173 // frame type, 1byte 174 if err := binary.Write(buf, binary.BigEndian, uint8(0)); err != nil { 175 t.Fatal(err) 176 } 177 // stream frame type, 1byte 178 if err := binary.Write(buf, binary.BigEndian, uint8(0)); err != nil { 179 t.Fatal(err) 180 } 181 // total len 182 totalLen := 16 + len(head) + len(body) 183 if err := binary.Write(buf, binary.BigEndian, uint32(totalLen)); err != nil { 184 t.Fatal(err) 185 } 186 // pb header len 187 if err := binary.Write(buf, binary.BigEndian, uint16(len(head))); err != nil { 188 t.Fatal(err) 189 } 190 // stream ID 191 if err := binary.Write(buf, binary.BigEndian, uint16(0)); err != nil { 192 t.Fatal(err) 193 } 194 // reserved 195 if err := binary.Write(buf, binary.BigEndian, uint32(0)); err != nil { 196 t.Fatal(err) 197 } 198 // header 199 if err := binary.Write(buf, binary.BigEndian, head); err != nil { 200 t.Fatal(err) 201 } 202 // body 203 if err := binary.Write(buf, binary.BigEndian, body); err != nil { 204 t.Fatal(err) 205 } 206 return buf.Bytes() 207 } 208 209 func TestClientCodec_DecodeHeadOverflowsUint16(t *testing.T) { 210 cc := trpc.ClientCodec{} 211 msg := codec.Message(trpc.BackgroundContext()) 212 213 msg.WithClientMetaData(codec.MetaData{"smallBuffer": make([]byte, 16)}) 214 rspBuf, err := cc.Encode(msg, nil) 215 require.Nil(t, err) 216 require.Contains(t, string(rspBuf), "smallBuffer") 217 218 msg.WithClientMetaData(map[string][]byte{"largeBuffer": make([]byte, 64*1024)}) 219 _, err = cc.Encode(msg, nil) 220 require.NotNil(t, err) 221 } 222 223 func TestServerCodec_DecodeHeadOverflowsUint16(t *testing.T) { 224 cc := trpc.ServerCodec{} 225 msg := codec.Message(trpc.BackgroundContext()) 226 227 msg.WithServerMetaData(map[string][]byte{"smallBuffer": make([]byte, 16)}) 228 rspBuf, err := cc.Encode(msg, nil) 229 require.Nil(t, err) 230 require.Contains(t, string(rspBuf), "smallBuffer") 231 232 msg.WithServerMetaData( 233 map[string][]byte{ 234 "smallBuffer": make([]byte, 16), 235 "largeBuffer": make([]byte, 64*1024), 236 }) 237 rspBuf, err = cc.Encode(msg, nil) 238 require.Nil(t, err) 239 require.Less(t, len(rspBuf), 64*1024) 240 require.NotContains(t, string(rspBuf), "smallBuffer") 241 require.NotContains(t, string(rspBuf), "largeBuffer") 242 } 243 244 func TestClientCodec_CallTypeEncode(t *testing.T) { 245 sc := trpc.ClientCodec{} 246 msg := codec.Message(trpc.BackgroundContext()) 247 msg.WithCallType(codec.SendOnly) 248 reqBuf, err := sc.Encode(msg, nil) 249 assert.Nil(t, err) 250 head := &trpcpb.RequestProtocol{} 251 err = proto.Unmarshal(reqBuf[16:], head) 252 assert.Nil(t, err) 253 assert.Equal(t, head.GetCallType(), uint32(codec.SendOnly)) 254 } 255 256 func TestServerCodec_CallTypeDecode(t *testing.T) { 257 cc := trpc.ClientCodec{} 258 sc := trpc.ServerCodec{} 259 msg := codec.Message(trpc.BackgroundContext()) 260 msg.WithCallType(codec.SendOnly) 261 reqBuf, err := cc.Encode(msg, nil) 262 assert.Nil(t, err) 263 _, err = sc.Decode(msg, reqBuf) 264 assert.Nil(t, err) 265 assert.Equal(t, msg.CallType(), codec.SendOnly) 266 } 267 268 func TestClientCodec_EncodeErr(t *testing.T) { 269 t.Run("head len overflows uint16", func(t *testing.T) { 270 cc := trpc.ClientCodec{} 271 msg := codec.Message(trpc.BackgroundContext()) 272 msg.WithClientMetaData(codec.MetaData{"overHeadLengthU16": make([]byte, 64*1024)}) 273 _, err := cc.Encode(msg, nil) 274 assert.EqualError(t, err, "head len overflows uint16") 275 }) 276 t.Run("frame len is too large", func(t *testing.T) { 277 cc := trpc.ClientCodec{} 278 msg := codec.Message(trpc.BackgroundContext()) 279 _, err := cc.Encode(msg, make([]byte, trpc.DefaultMaxFrameSize)) 280 assert.EqualError(t, err, "frame len is larger than MaxFrameSize(10485760)") 281 }) 282 t.Run("encoding attachment failed", func(t *testing.T) { 283 cc := trpc.ClientCodec{} 284 msg := codec.Message(trpc.BackgroundContext()) 285 msg.WithCommonMeta(codec.CommonMeta{attachment.ClientAttachmentKey{}: &attachment.Attachment{Request: &errorReader{}, Response: attachment.NoopAttachment{}}}) 286 _, err := cc.Encode(msg, nil) 287 assert.EqualError(t, err, "encoding attachment: reading errorReader always returns error") 288 }) 289 290 } 291 292 type errorReader struct{} 293 294 func (*errorReader) Read(p []byte) (n int, err error) { 295 return 0, errors.New("reading errorReader always returns error") 296 } 297 298 func TestServerCodec_EncodeErr(t *testing.T) { 299 t.Run("head len overflows uint16", func(t *testing.T) { 300 msg := codec.Message(trpc.BackgroundContext()) 301 sc := trpc.ServerCodec{} 302 msg.WithServerMetaData(codec.MetaData{"overHeadLengthU16": make([]byte, 64*1024)}) 303 rspBuf, err := sc.Encode(msg, nil) 304 assert.Nil(t, err) 305 306 head := &trpcpb.ResponseProtocol{} 307 err = proto.Unmarshal(rspBuf[16:], head) 308 assert.Nil(t, err) 309 assert.Equal(t, int32(errs.RetServerEncodeFail), head.GetRet()) 310 }) 311 t.Run("frame len is too large", func(t *testing.T) { 312 msg := codec.Message(trpc.BackgroundContext()) 313 sc := trpc.ServerCodec{} 314 rspBuf, err := sc.Encode(msg, make([]byte, trpc.DefaultMaxFrameSize)) 315 assert.Nil(t, err) 316 317 head := &trpcpb.ResponseProtocol{} 318 err = proto.Unmarshal(rspBuf[16:], head) 319 assert.Nil(t, err) 320 assert.Equal(t, int32(errs.RetServerEncodeFail), head.GetRet()) 321 }) 322 t.Run("encoding attachment failed", func(t *testing.T) { 323 msg := codec.Message(trpc.BackgroundContext()) 324 msg.WithCommonMeta(codec.CommonMeta{attachment.ServerAttachmentKey{}: &attachment.Attachment{Request: attachment.NoopAttachment{}, Response: &errorReader{}}}) 325 sc := trpc.ServerCodec{} 326 _, err := sc.Encode(msg, nil) 327 assert.EqualError(t, err, "encoding attachment: reading errorReader always returns error") 328 }) 329 } 330 331 func TestMultiplexFrame(t *testing.T) { 332 buf := mustEncode(t, []byte("helloworld")) 333 vid, frame, err := (&trpc.FramerBuilder{}).Parse(bytes.NewReader(buf)) 334 require.Nil(t, err) 335 require.Equal(t, uint32(0), vid) 336 require.Equal(t, buf, frame) 337 } 338 339 func TestClientCodecNoModifyOriginalFrameHead(t *testing.T) { 340 _, msg := codec.WithNewMessage(context.Background()) 341 fh := &trpc.FrameHead{ 342 StreamID: 101, 343 } 344 msg.WithFrameHead(fh) 345 clientCodec := &trpc.ClientCodec{} 346 _, err := clientCodec.Encode(msg, []byte("helloworld")) 347 require.Nil(t, err) 348 require.Equal(t, uint32(101), fh.StreamID) 349 } 350 351 // GOMAXPROCS=1 go test -bench=ServerCodec_Decode -benchmem 352 // -benchtime=10s -memprofile mem.out -cpuprofile cpu.out codec_test.go 353 func BenchmarkServerCodec_Decode(b *testing.B) { 354 sc := &trpc.ServerCodec{} 355 cc := &trpc.ClientCodec{} 356 _, msg := codec.WithNewMessage(context.Background()) 357 358 reqBody, err := proto.Marshal(&pb.HelloRequest{ 359 Msg: "helloworld", 360 }) 361 assert.Nil(b, err) 362 363 req, err := cc.Encode(msg, reqBody) 364 assert.Nil(b, err) 365 b.ResetTimer() 366 for n := 0; n < b.N; n++ { 367 sc.Decode(msg, req) 368 } 369 } 370 371 // GOMAXPROCS=1 go test -bench=ClientCodec_Encode -benchmem -benchtime=10s 372 // -memprofile mem.out -cpuprofile cpu.out codec_test.go 373 func BenchmarkClientCodec_Encode(b *testing.B) { 374 cc := &trpc.ClientCodec{} 375 376 _, msg := codec.WithNewMessage(context.Background()) 377 reqBody, err := proto.Marshal(&pb.HelloRequest{ 378 Msg: "helloworld", 379 }) 380 assert.Nil(b, err) 381 382 b.ResetTimer() 383 for i := 0; i < b.N; i++ { 384 cc.Encode(msg, reqBody) 385 } 386 } 387 388 func TestUDPParseFail(t *testing.T) { 389 s := &udpServer{} 390 s.start(context.Background()) 391 t.Cleanup(s.stop) 392 393 m := multiplexed.New(multiplexed.WithConnectNumber(1)) 394 test := func(id uint32, buf []byte, wantErr error) { 395 ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) 396 opts := multiplexed.NewGetOptions() 397 opts.WithVID(id) 398 opts.WithFrameParser(&trpc.FramerBuilder{}) 399 mc, err := m.GetMuxConn(ctx, s.conn.LocalAddr().Network(), s.conn.LocalAddr().String(), opts) 400 assert.Nil(t, err) 401 require.Nil(t, mc.Write(buf)) 402 _, err = mc.Read() 403 assert.Equal(t, err, wantErr) 404 cancel() 405 } 406 // fail when parse invalid buf 407 var id uint32 = 1 408 test(id, []byte("invalid buf"), context.DeadlineExceeded) 409 410 // succeed when parse valid buf 411 id = 2 412 msg := codec.Message(context.Background()) 413 msg.WithFrameHead(&trpc.FrameHead{ 414 StreamID: id, 415 }) 416 sc := &trpc.ServerCodec{} 417 buf, _ := sc.Encode(msg, []byte("helloworld")) 418 test(id, buf, nil) 419 } 420 421 type udpServer struct { 422 cancel context.CancelFunc 423 conn net.PacketConn 424 } 425 426 func (s *udpServer) start(ctx context.Context) error { 427 var err error 428 s.conn, err = net.ListenPacket("udp", "127.0.0.1:0") 429 if err != nil { 430 return err 431 } 432 ctx, s.cancel = context.WithCancel(ctx) 433 go func() { 434 buf := make([]byte, 65535) 435 for { 436 select { 437 case <-ctx.Done(): 438 return 439 default: 440 } 441 n, addr, err := s.conn.ReadFrom(buf) 442 if err != nil { 443 log.Println("l.ReadFrom err: ", err) 444 return 445 } 446 s.conn.WriteTo(buf[:n], addr) 447 } 448 }() 449 return nil 450 } 451 452 func (s *udpServer) stop() { 453 s.cancel() 454 s.conn.Close() 455 }