trpc.group/trpc-go/trpc-go@v1.0.3/codec_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package trpc_test
    15  
    16  import (
    17  	"bytes"
    18  	"context"
    19  	"encoding/binary"
    20  	"errors"
    21  	"log"
    22  	"net"
    23  	"regexp"
    24  	"testing"
    25  	"time"
    26  
    27  	"github.com/stretchr/testify/assert"
    28  	"github.com/stretchr/testify/require"
    29  	"google.golang.org/protobuf/proto"
    30  	"trpc.group/trpc-go/trpc-go/internal/attachment"
    31  	trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc"
    32  
    33  	trpc "trpc.group/trpc-go/trpc-go"
    34  	"trpc.group/trpc-go/trpc-go/codec"
    35  	"trpc.group/trpc-go/trpc-go/errs"
    36  	"trpc.group/trpc-go/trpc-go/pool/multiplexed"
    37  	pb "trpc.group/trpc-go/trpc-go/testdata/trpc/helloworld"
    38  )
    39  
    40  func TestFramer_ReadFrame(t *testing.T) {
    41  	// test magic num mismatch
    42  	{
    43  		var err error
    44  		totalLen := 0
    45  		buf := new(bytes.Buffer)
    46  		// MagicNum 0x930, 2bytes
    47  		assert.Nil(t, binary.Write(buf, binary.BigEndian, uint16(trpcpb.TrpcMagic_TRPC_MAGIC_VALUE+1)))
    48  		// frame type, 1byte
    49  		assert.Nil(t, binary.Write(buf, binary.BigEndian, uint8(0)))
    50  		// stream frame type, 1byte
    51  		assert.Nil(t, binary.Write(buf, binary.BigEndian, uint8(0)))
    52  		// total len
    53  		assert.Nil(t, binary.Write(buf, binary.BigEndian, uint32(totalLen)))
    54  		// pb header len
    55  		assert.Nil(t, binary.Write(buf, binary.BigEndian, uint16(0)))
    56  		// stream ID
    57  		assert.Nil(t, binary.Write(buf, binary.BigEndian, uint16(0)))
    58  		// reserved
    59  		assert.Nil(t, binary.Write(buf, binary.BigEndian, uint32(0)))
    60  		assert.Nil(t, err)
    61  
    62  		fb := &trpc.FramerBuilder{}
    63  		fr := fb.New(bytes.NewReader(buf.Bytes()))
    64  		assert.NotNil(t, fr)
    65  		_, err = fr.ReadFrame()
    66  		assert.NotNil(t, err)
    67  	}
    68  
    69  	// test total len exceed max error
    70  	{
    71  		var err error
    72  		totalLen := trpc.DefaultMaxFrameSize + 1
    73  		buf := new(bytes.Buffer)
    74  		// MagicNum 0x930, 2bytes
    75  		assert.Nil(t, binary.Write(buf, binary.BigEndian, uint16(trpcpb.TrpcMagic_TRPC_MAGIC_VALUE)))
    76  		// frame type, 1byte
    77  		assert.Nil(t, binary.Write(buf, binary.BigEndian, uint8(0)))
    78  		// stream frame type, 1byte
    79  		assert.Nil(t, binary.Write(buf, binary.BigEndian, uint8(0)))
    80  		// total len
    81  		assert.Nil(t, binary.Write(buf, binary.BigEndian, uint32(totalLen)))
    82  		assert.Nil(t, binary.Write(buf, binary.BigEndian, uint16(0)))
    83  		// stream ID
    84  		assert.Nil(t, binary.Write(buf, binary.BigEndian, uint16(0)))
    85  		// reserved
    86  		assert.Nil(t, binary.Write(buf, binary.BigEndian, uint32(0)))
    87  		assert.Nil(t, err)
    88  
    89  		fb := &trpc.FramerBuilder{}
    90  		fr := fb.New(bytes.NewReader(buf.Bytes()))
    91  		assert.NotNil(t, fr)
    92  		_, err = fr.ReadFrame()
    93  		assert.NotNil(t, err)
    94  	}
    95  }
    96  
    97  func TestClientCodecEnvTransfer(t *testing.T) {
    98  	envTransfer := []byte("env transfer")
    99  	cliCodec := &trpc.ClientCodec{}
   100  
   101  	// if msg.EnvTransfer() empty, transmitted env info in req.TransInfo should be cleared
   102  	_, msg := codec.WithNewMessage(context.Background())
   103  	msg.WithClientMetaData(map[string][]byte{trpc.EnvTransfer: envTransfer})
   104  	msg.WithEnvTransfer("")
   105  	reqBuf, err := cliCodec.Encode(msg, nil)
   106  	assert.Nil(t, err)
   107  	head := &trpcpb.RequestProtocol{}
   108  	err = proto.Unmarshal(reqBuf[16:], head)
   109  	assert.Nil(t, err)
   110  	assert.Equal(t, head.TransInfo[trpc.EnvTransfer], []byte{})
   111  
   112  	// msg.EnvTransfer() not empty
   113  	_, msg = codec.WithNewMessage(context.Background())
   114  	msg.WithEnvTransfer("env transfer")
   115  	reqBuf, err = cliCodec.Encode(msg, nil)
   116  	assert.Nil(t, err)
   117  	head = &trpcpb.RequestProtocol{}
   118  	err = proto.Unmarshal(reqBuf[16:], head)
   119  	assert.Nil(t, err)
   120  	assert.Equal(t, head.TransInfo[trpc.EnvTransfer], envTransfer)
   121  }
   122  
   123  func TestClientCodecDyeing(t *testing.T) {
   124  	dyeingKey := "123456789"
   125  	cliCodec := &trpc.ClientCodec{}
   126  	_, msg := codec.WithNewMessage(context.Background())
   127  	msg.WithDyeingKey(dyeingKey)
   128  	reqBuf, err := cliCodec.Encode(msg, nil)
   129  	assert.Nil(t, err)
   130  	head := &trpcpb.RequestProtocol{}
   131  	err = proto.Unmarshal(reqBuf[16:], head)
   132  	assert.Nil(t, err)
   133  	assert.Equal(t, head.TransInfo[trpc.DyeingKey], []byte(dyeingKey))
   134  }
   135  
   136  func TestFramerBuilder(t *testing.T) {
   137  	t.Run("frame build is a SafeFramer", func(t *testing.T) {
   138  		fb := trpc.FramerBuilder{}
   139  		frame := fb.New(bytes.NewReader(nil))
   140  		require.True(t, frame.(codec.SafeFramer).IsSafe())
   141  	})
   142  	t.Run("ok, read valid response", func(t *testing.T) {
   143  		bts := mustEncode(t, []byte("hello-world"))
   144  		vid, buf, err := (&trpc.FramerBuilder{}).Parse(bytes.NewReader(bts))
   145  		require.Nil(t, err)
   146  		require.Zero(t, vid)
   147  		require.Equal(t, bts, buf)
   148  	})
   149  	t.Run("garbage data", func(t *testing.T) {
   150  		_, _, err := (&trpc.FramerBuilder{}).Parse(bytes.NewReader([]byte("hello-world xxxxxxxxxxxx")))
   151  		require.Regexp(t, regexp.MustCompile(`magic .+ not match`), err.Error())
   152  	})
   153  }
   154  
   155  func mustEncode(t *testing.T, body []byte) (buffer []byte) {
   156  	t.Helper()
   157  
   158  	msgHead := &trpcpb.RequestProtocol{
   159  		Version: uint32(trpcpb.TrpcProtoVersion_TRPC_PROTO_V1),
   160  		Callee:  []byte("trpc.test.helloworld.Greetor"),
   161  		Func:    []byte("/trpc.test.helloworld.Greetor/SayHello"),
   162  	}
   163  	head, err := proto.Marshal(msgHead)
   164  	if err != nil {
   165  		t.Fatal(err)
   166  	}
   167  
   168  	buf := new(bytes.Buffer)
   169  	// MagicNum 0x930, 2bytes
   170  	if err := binary.Write(buf, binary.BigEndian, uint16(trpcpb.TrpcMagic_TRPC_MAGIC_VALUE)); err != nil {
   171  		t.Fatal(err)
   172  	}
   173  	// frame type, 1byte
   174  	if err := binary.Write(buf, binary.BigEndian, uint8(0)); err != nil {
   175  		t.Fatal(err)
   176  	}
   177  	// stream frame type, 1byte
   178  	if err := binary.Write(buf, binary.BigEndian, uint8(0)); err != nil {
   179  		t.Fatal(err)
   180  	}
   181  	// total len
   182  	totalLen := 16 + len(head) + len(body)
   183  	if err := binary.Write(buf, binary.BigEndian, uint32(totalLen)); err != nil {
   184  		t.Fatal(err)
   185  	}
   186  	// pb header len
   187  	if err := binary.Write(buf, binary.BigEndian, uint16(len(head))); err != nil {
   188  		t.Fatal(err)
   189  	}
   190  	// stream ID
   191  	if err := binary.Write(buf, binary.BigEndian, uint16(0)); err != nil {
   192  		t.Fatal(err)
   193  	}
   194  	// reserved
   195  	if err := binary.Write(buf, binary.BigEndian, uint32(0)); err != nil {
   196  		t.Fatal(err)
   197  	}
   198  	// header
   199  	if err := binary.Write(buf, binary.BigEndian, head); err != nil {
   200  		t.Fatal(err)
   201  	}
   202  	// body
   203  	if err := binary.Write(buf, binary.BigEndian, body); err != nil {
   204  		t.Fatal(err)
   205  	}
   206  	return buf.Bytes()
   207  }
   208  
   209  func TestClientCodec_DecodeHeadOverflowsUint16(t *testing.T) {
   210  	cc := trpc.ClientCodec{}
   211  	msg := codec.Message(trpc.BackgroundContext())
   212  
   213  	msg.WithClientMetaData(codec.MetaData{"smallBuffer": make([]byte, 16)})
   214  	rspBuf, err := cc.Encode(msg, nil)
   215  	require.Nil(t, err)
   216  	require.Contains(t, string(rspBuf), "smallBuffer")
   217  
   218  	msg.WithClientMetaData(map[string][]byte{"largeBuffer": make([]byte, 64*1024)})
   219  	_, err = cc.Encode(msg, nil)
   220  	require.NotNil(t, err)
   221  }
   222  
   223  func TestServerCodec_DecodeHeadOverflowsUint16(t *testing.T) {
   224  	cc := trpc.ServerCodec{}
   225  	msg := codec.Message(trpc.BackgroundContext())
   226  
   227  	msg.WithServerMetaData(map[string][]byte{"smallBuffer": make([]byte, 16)})
   228  	rspBuf, err := cc.Encode(msg, nil)
   229  	require.Nil(t, err)
   230  	require.Contains(t, string(rspBuf), "smallBuffer")
   231  
   232  	msg.WithServerMetaData(
   233  		map[string][]byte{
   234  			"smallBuffer": make([]byte, 16),
   235  			"largeBuffer": make([]byte, 64*1024),
   236  		})
   237  	rspBuf, err = cc.Encode(msg, nil)
   238  	require.Nil(t, err)
   239  	require.Less(t, len(rspBuf), 64*1024)
   240  	require.NotContains(t, string(rspBuf), "smallBuffer")
   241  	require.NotContains(t, string(rspBuf), "largeBuffer")
   242  }
   243  
   244  func TestClientCodec_CallTypeEncode(t *testing.T) {
   245  	sc := trpc.ClientCodec{}
   246  	msg := codec.Message(trpc.BackgroundContext())
   247  	msg.WithCallType(codec.SendOnly)
   248  	reqBuf, err := sc.Encode(msg, nil)
   249  	assert.Nil(t, err)
   250  	head := &trpcpb.RequestProtocol{}
   251  	err = proto.Unmarshal(reqBuf[16:], head)
   252  	assert.Nil(t, err)
   253  	assert.Equal(t, head.GetCallType(), uint32(codec.SendOnly))
   254  }
   255  
   256  func TestServerCodec_CallTypeDecode(t *testing.T) {
   257  	cc := trpc.ClientCodec{}
   258  	sc := trpc.ServerCodec{}
   259  	msg := codec.Message(trpc.BackgroundContext())
   260  	msg.WithCallType(codec.SendOnly)
   261  	reqBuf, err := cc.Encode(msg, nil)
   262  	assert.Nil(t, err)
   263  	_, err = sc.Decode(msg, reqBuf)
   264  	assert.Nil(t, err)
   265  	assert.Equal(t, msg.CallType(), codec.SendOnly)
   266  }
   267  
   268  func TestClientCodec_EncodeErr(t *testing.T) {
   269  	t.Run("head len overflows uint16", func(t *testing.T) {
   270  		cc := trpc.ClientCodec{}
   271  		msg := codec.Message(trpc.BackgroundContext())
   272  		msg.WithClientMetaData(codec.MetaData{"overHeadLengthU16": make([]byte, 64*1024)})
   273  		_, err := cc.Encode(msg, nil)
   274  		assert.EqualError(t, err, "head len overflows uint16")
   275  	})
   276  	t.Run("frame len is too large", func(t *testing.T) {
   277  		cc := trpc.ClientCodec{}
   278  		msg := codec.Message(trpc.BackgroundContext())
   279  		_, err := cc.Encode(msg, make([]byte, trpc.DefaultMaxFrameSize))
   280  		assert.EqualError(t, err, "frame len is larger than MaxFrameSize(10485760)")
   281  	})
   282  	t.Run("encoding attachment failed", func(t *testing.T) {
   283  		cc := trpc.ClientCodec{}
   284  		msg := codec.Message(trpc.BackgroundContext())
   285  		msg.WithCommonMeta(codec.CommonMeta{attachment.ClientAttachmentKey{}: &attachment.Attachment{Request: &errorReader{}, Response: attachment.NoopAttachment{}}})
   286  		_, err := cc.Encode(msg, nil)
   287  		assert.EqualError(t, err, "encoding attachment: reading errorReader always returns error")
   288  	})
   289  
   290  }
   291  
   292  type errorReader struct{}
   293  
   294  func (*errorReader) Read(p []byte) (n int, err error) {
   295  	return 0, errors.New("reading errorReader always returns error")
   296  }
   297  
   298  func TestServerCodec_EncodeErr(t *testing.T) {
   299  	t.Run("head len overflows uint16", func(t *testing.T) {
   300  		msg := codec.Message(trpc.BackgroundContext())
   301  		sc := trpc.ServerCodec{}
   302  		msg.WithServerMetaData(codec.MetaData{"overHeadLengthU16": make([]byte, 64*1024)})
   303  		rspBuf, err := sc.Encode(msg, nil)
   304  		assert.Nil(t, err)
   305  
   306  		head := &trpcpb.ResponseProtocol{}
   307  		err = proto.Unmarshal(rspBuf[16:], head)
   308  		assert.Nil(t, err)
   309  		assert.Equal(t, int32(errs.RetServerEncodeFail), head.GetRet())
   310  	})
   311  	t.Run("frame len is too large", func(t *testing.T) {
   312  		msg := codec.Message(trpc.BackgroundContext())
   313  		sc := trpc.ServerCodec{}
   314  		rspBuf, err := sc.Encode(msg, make([]byte, trpc.DefaultMaxFrameSize))
   315  		assert.Nil(t, err)
   316  
   317  		head := &trpcpb.ResponseProtocol{}
   318  		err = proto.Unmarshal(rspBuf[16:], head)
   319  		assert.Nil(t, err)
   320  		assert.Equal(t, int32(errs.RetServerEncodeFail), head.GetRet())
   321  	})
   322  	t.Run("encoding attachment failed", func(t *testing.T) {
   323  		msg := codec.Message(trpc.BackgroundContext())
   324  		msg.WithCommonMeta(codec.CommonMeta{attachment.ServerAttachmentKey{}: &attachment.Attachment{Request: attachment.NoopAttachment{}, Response: &errorReader{}}})
   325  		sc := trpc.ServerCodec{}
   326  		_, err := sc.Encode(msg, nil)
   327  		assert.EqualError(t, err, "encoding attachment: reading errorReader always returns error")
   328  	})
   329  }
   330  
   331  func TestMultiplexFrame(t *testing.T) {
   332  	buf := mustEncode(t, []byte("helloworld"))
   333  	vid, frame, err := (&trpc.FramerBuilder{}).Parse(bytes.NewReader(buf))
   334  	require.Nil(t, err)
   335  	require.Equal(t, uint32(0), vid)
   336  	require.Equal(t, buf, frame)
   337  }
   338  
   339  func TestClientCodecNoModifyOriginalFrameHead(t *testing.T) {
   340  	_, msg := codec.WithNewMessage(context.Background())
   341  	fh := &trpc.FrameHead{
   342  		StreamID: 101,
   343  	}
   344  	msg.WithFrameHead(fh)
   345  	clientCodec := &trpc.ClientCodec{}
   346  	_, err := clientCodec.Encode(msg, []byte("helloworld"))
   347  	require.Nil(t, err)
   348  	require.Equal(t, uint32(101), fh.StreamID)
   349  }
   350  
   351  // GOMAXPROCS=1 go test -bench=ServerCodec_Decode -benchmem
   352  // -benchtime=10s -memprofile mem.out -cpuprofile cpu.out codec_test.go
   353  func BenchmarkServerCodec_Decode(b *testing.B) {
   354  	sc := &trpc.ServerCodec{}
   355  	cc := &trpc.ClientCodec{}
   356  	_, msg := codec.WithNewMessage(context.Background())
   357  
   358  	reqBody, err := proto.Marshal(&pb.HelloRequest{
   359  		Msg: "helloworld",
   360  	})
   361  	assert.Nil(b, err)
   362  
   363  	req, err := cc.Encode(msg, reqBody)
   364  	assert.Nil(b, err)
   365  	b.ResetTimer()
   366  	for n := 0; n < b.N; n++ {
   367  		sc.Decode(msg, req)
   368  	}
   369  }
   370  
   371  // GOMAXPROCS=1 go test -bench=ClientCodec_Encode -benchmem -benchtime=10s
   372  // -memprofile mem.out -cpuprofile cpu.out codec_test.go
   373  func BenchmarkClientCodec_Encode(b *testing.B) {
   374  	cc := &trpc.ClientCodec{}
   375  
   376  	_, msg := codec.WithNewMessage(context.Background())
   377  	reqBody, err := proto.Marshal(&pb.HelloRequest{
   378  		Msg: "helloworld",
   379  	})
   380  	assert.Nil(b, err)
   381  
   382  	b.ResetTimer()
   383  	for i := 0; i < b.N; i++ {
   384  		cc.Encode(msg, reqBody)
   385  	}
   386  }
   387  
   388  func TestUDPParseFail(t *testing.T) {
   389  	s := &udpServer{}
   390  	s.start(context.Background())
   391  	t.Cleanup(s.stop)
   392  
   393  	m := multiplexed.New(multiplexed.WithConnectNumber(1))
   394  	test := func(id uint32, buf []byte, wantErr error) {
   395  		ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
   396  		opts := multiplexed.NewGetOptions()
   397  		opts.WithVID(id)
   398  		opts.WithFrameParser(&trpc.FramerBuilder{})
   399  		mc, err := m.GetMuxConn(ctx, s.conn.LocalAddr().Network(), s.conn.LocalAddr().String(), opts)
   400  		assert.Nil(t, err)
   401  		require.Nil(t, mc.Write(buf))
   402  		_, err = mc.Read()
   403  		assert.Equal(t, err, wantErr)
   404  		cancel()
   405  	}
   406  	// fail when parse invalid buf
   407  	var id uint32 = 1
   408  	test(id, []byte("invalid buf"), context.DeadlineExceeded)
   409  
   410  	// succeed when parse valid buf
   411  	id = 2
   412  	msg := codec.Message(context.Background())
   413  	msg.WithFrameHead(&trpc.FrameHead{
   414  		StreamID: id,
   415  	})
   416  	sc := &trpc.ServerCodec{}
   417  	buf, _ := sc.Encode(msg, []byte("helloworld"))
   418  	test(id, buf, nil)
   419  }
   420  
   421  type udpServer struct {
   422  	cancel context.CancelFunc
   423  	conn   net.PacketConn
   424  }
   425  
   426  func (s *udpServer) start(ctx context.Context) error {
   427  	var err error
   428  	s.conn, err = net.ListenPacket("udp", "127.0.0.1:0")
   429  	if err != nil {
   430  		return err
   431  	}
   432  	ctx, s.cancel = context.WithCancel(ctx)
   433  	go func() {
   434  		buf := make([]byte, 65535)
   435  		for {
   436  			select {
   437  			case <-ctx.Done():
   438  				return
   439  			default:
   440  			}
   441  			n, addr, err := s.conn.ReadFrom(buf)
   442  			if err != nil {
   443  				log.Println("l.ReadFrom err: ", err)
   444  				return
   445  			}
   446  			s.conn.WriteTo(buf[:n], addr)
   447  		}
   448  	}()
   449  	return nil
   450  }
   451  
   452  func (s *udpServer) stop() {
   453  	s.cancel()
   454  	s.conn.Close()
   455  }