trpc.group/trpc-go/trpc-go@v1.0.3/trpc_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package trpc_test
    15  
    16  import (
    17  	"bytes"
    18  	"context"
    19  	"os"
    20  	"path/filepath"
    21  	"testing"
    22  
    23  	trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc"
    24  
    25  	trpc "trpc.group/trpc-go/trpc-go"
    26  	"trpc.group/trpc-go/trpc-go/codec"
    27  	"trpc.group/trpc-go/trpc-go/log"
    28  	"trpc.group/trpc-go/trpc-go/plugin"
    29  	"trpc.group/trpc-go/trpc-go/transport"
    30  
    31  	"github.com/stretchr/testify/assert"
    32  	"github.com/stretchr/testify/require"
    33  	"google.golang.org/protobuf/proto"
    34  )
    35  
    36  var ctx = context.Background()
    37  
    38  func init() {
    39  	trpc.LoadGlobalConfig("testdata/trpc_go.yaml")
    40  	trpc.Setup(trpc.GlobalConfig())
    41  }
    42  
    43  // go test -v -coverprofile=cover.out
    44  // go tool cover -func=cover.out
    45  
    46  func TestCodec(t *testing.T) {
    47  
    48  	msg := trpc.Message(ctx)
    49  	assert.NotNil(t, msg)
    50  
    51  	ctx := trpc.BackgroundContext()
    52  	assert.NotNil(t, ctx)
    53  
    54  	val := trpc.GetMetaData(ctx, "no-exist")
    55  	assert.Nil(t, val)
    56  
    57  	trpc.SetMetaData(ctx, "exist1", []byte("value1"))
    58  	val = trpc.GetMetaData(ctx, "exist1")
    59  	assert.NotNil(t, val)
    60  	assert.Equal(t, []byte("value1"), val)
    61  
    62  	trpc.SetMetaData(ctx, "exist2", []byte("value2"))
    63  	val = trpc.GetMetaData(ctx, "exist2")
    64  	assert.NotNil(t, val)
    65  	assert.Equal(t, []byte("value2"), val)
    66  
    67  	serverCodec := codec.GetServer("trpc")
    68  	clientCodec := codec.GetClient("trpc")
    69  	frameBuilder := transport.GetFramerBuilder("trpc")
    70  
    71  	assert.Equal(t, trpc.DefaultServerCodec, serverCodec)
    72  	assert.Equal(t, trpc.DefaultClientCodec, clientCodec)
    73  
    74  	request := trpc.Request(ctx)
    75  	response := trpc.Response(ctx)
    76  	assert.NotNil(t, request)
    77  	assert.NotNil(t, response)
    78  
    79  	msg = trpc.Message(ctx)
    80  	msg.WithServerReqHead(request)
    81  	msg.WithServerRspHead(response)
    82  	request = trpc.Request(ctx)
    83  	response = trpc.Response(ctx)
    84  	assert.NotNil(t, request)
    85  	assert.NotNil(t, response)
    86  
    87  	request.Func = []byte("test")
    88  	data, err := proto.Marshal(request)
    89  	assert.Nil(t, err)
    90  	assert.Equal(t, []byte{0x3a, 0x4, 0x74, 0x65, 0x73, 0x74}, data)
    91  
    92  	response.ErrorMsg = []byte("ok")
    93  	data, err = proto.Marshal(response)
    94  	assert.Nil(t, err)
    95  	assert.Equal(t, []byte{0x32, 0x2, 0x6f, 0x6b}, data)
    96  
    97  	// frame head: 2 bytes magic stx(0x930) + 1 byte stream type(1) + 1 byte stream frame type(2)
    98  	// + 4 bytes total len(23) + 2 bytes pb header len(6) + 4 bytes stream id(0)
    99  	// + 2 bytes reserved(0) + head + body
   100  	in := []byte{0x9, 0x30, 0, 2, 0, 0, 0, 23, 0, 6, 0, 0, 0, 0, 0, 0, 0x3a, 0x4, 0x74, 0x65, 0x73, 0x74, 1}
   101  	reader := bytes.NewReader(in)
   102  	frame := frameBuilder.New(reader)
   103  	data, err = frame.ReadFrame()
   104  	assert.Nil(t, err)
   105  	assert.Equal(t, in, data)
   106  
   107  	// invalid magic num
   108  	in1 := []byte{0x30, 0x9, 1, 2, 0, 0, 0, 23, 0, 6, 0, 0, 0, 0, 0, 0, 0x3a, 0x4, 0x74, 0x65, 0x73, 0x74, 1}
   109  	reader = bytes.NewReader(in1)
   110  	frame = frameBuilder.New(reader)
   111  	_, err = frame.ReadFrame()
   112  	assert.Contains(t, err.Error(), "not match")
   113  
   114  	msg = codec.Message(ctx)
   115  	reqBody, err := serverCodec.Decode(msg, in)
   116  	assert.Nil(t, err)
   117  	assert.Equal(t, []byte{1}, reqBody)
   118  
   119  	// head len invalid
   120  	in2 := []byte{0x30, 0x9, 0, 2, 0, 0, 0, 23, 0, 7, 0, 0, 0, 0, 0, 0, 0x3a, 0x4, 0x74, 0x65, 0x73, 0x74, 1}
   121  	reqBody2, err := serverCodec.Decode(msg, in2)
   122  	assert.NotNil(t, err)
   123  	assert.Nil(t, reqBody2)
   124  
   125  	rspBuf, err := serverCodec.Encode(msg, reqBody)
   126  	assert.Nil(t, err)
   127  	assert.NotNil(t, rspBuf)
   128  
   129  	reqBuf, err := clientCodec.Encode(msg, reqBody)
   130  	assert.Nil(t, err)
   131  	assert.NotNil(t, reqBuf)
   132  
   133  	in3 := []byte{0x9, 0x30, 0, 2, 0, 0, 0, 21, 0, 4, 0, 0, 0, 0, 0, 0, 0x32, 0x2, 0x6f, 0x6b, 1}
   134  	msg.ClientReqHead().(*trpcpb.RequestProtocol).RequestId = 0
   135  	rspBody, err := clientCodec.Decode(msg, in3)
   136  	assert.Nil(t, err)
   137  	assert.Equal(t, []byte{1}, rspBody)
   138  }
   139  
   140  func TestVersion(t *testing.T) {
   141  	version := trpc.Version()
   142  
   143  	assert.NotNil(t, version)
   144  }
   145  
   146  func TestConfig(t *testing.T) {
   147  
   148  	trpc.ServerConfigPath = "./testdata/trpc_go.yaml"
   149  
   150  	conf := trpc.GlobalConfig()
   151  	assert.NotNil(t, conf)
   152  	assert.Equal(t, 3, len(conf.Server.Service))
   153  	assert.Equal(t, "trpc.test.helloworld.Greeter1", conf.Server.Service[0].Name)
   154  	assert.Equal(t, true, *conf.Server.Service[0].ServerAsync)
   155  	assert.Equal(t, 1000, conf.Server.Service[1].MaxRoutines)
   156  	assert.Equal(t, false, *conf.Server.Service[0].Writev)
   157  
   158  	cfg := &trpc.Config{}
   159  	cfg.Server.Network = "tcp"
   160  	cfg.Server.Protocol = "trpc"
   161  	cfg.Client.Network = "tcp"
   162  	cfg.Client.Protocol = "trpc"
   163  	trpc.SetGlobalConfig(cfg)
   164  }
   165  
   166  func TestNewServer(t *testing.T) {
   167  
   168  	trpc.ServerConfigPath = "./testdata/trpc_go.yaml"
   169  
   170  	logger := log.NewZapLog(log.Config{
   171  		{
   172  			Writer: log.OutputFile,
   173  			WriteConfig: log.WriteConfig{
   174  				LogPath:   os.TempDir(),
   175  				Filename:  "trpc.log",
   176  				WriteMode: 1,
   177  			},
   178  			Level: "DEBUG",
   179  		},
   180  	})
   181  	dftLogger := log.DefaultLogger
   182  	log.SetLogger(logger)
   183  	defer log.SetLogger(dftLogger)
   184  
   185  	fp := filepath.Join(os.TempDir(), "trpc.log")
   186  	defer os.Remove(fp)
   187  
   188  	s := trpc.NewServer()
   189  	assert.NotNil(t, s)
   190  	assert.NotNil(t, s.Service("trpc.test.helloworld.Greeter1"))
   191  	assert.NotNil(t, s.Service("trpc.test.helloworld.Greeter2"))
   192  	assert.NotNil(t, s.Service("trpc.test.helloworld.Greeter3"))
   193  	assert.Equal(t, codec.DefaultReaderSize, codec.GetReaderSize())
   194  
   195  	buf, err := os.ReadFile(fp)
   196  	assert.Nil(t, err)
   197  
   198  	// test namingservice not exist
   199  	// registry set for service1、service2
   200  	assert.Contains(t, string(buf), "trpc.test.helloworld.Greeter1 registry not exist")
   201  	assert.Contains(t, string(buf), "trpc.test.helloworld.Greeter2 registry not exist")
   202  	// registry not set for service3
   203  	assert.NotContains(t, string(buf), "trpc.test.helloworld.Greeter3 registry not exist")
   204  }
   205  
   206  func TestProtocol(t *testing.T) {
   207  	request := trpc.Request(ctx)
   208  	response := trpc.Response(ctx)
   209  
   210  	assert.NotNil(t, request.String())
   211  	assert.NotNil(t, response.String())
   212  	assert.Equal(t, uint32(0), request.GetContentType())
   213  	assert.Equal(t, uint32(0), request.GetRequestId())
   214  	assert.Equal(t, uint32(0), request.GetCallType())
   215  	assert.Equal(t, uint32(0), request.GetVersion())
   216  	assert.Equal(t, uint32(0), request.GetMessageType())
   217  	assert.Nil(t, response.GetErrorMsg())
   218  	assert.Nil(t, request.GetCallee())
   219  	assert.Nil(t, request.GetCaller())
   220  }
   221  
   222  func TestGetAdminService(t *testing.T) {
   223  	cfg := t.TempDir() + "trpc_go.yaml"
   224  	require.Nil(t, os.WriteFile(cfg, []byte{}, 0644))
   225  	oldPath := trpc.ServerConfigPath
   226  	trpc.ServerConfigPath = cfg
   227  	defer func() { trpc.ServerConfigPath = oldPath }()
   228  
   229  	_ = trpc.NewServer()
   230  	admin, err := trpc.GetAdminService(trpc.NewServer())
   231  	require.Nil(t, err)
   232  	require.NotNil(t, admin)
   233  
   234  	require.Nil(t, os.WriteFile(cfg, []byte(`
   235  server:
   236    admin:
   237      port: 9528
   238  `), 0644))
   239  
   240  	s := trpc.NewServer()
   241  	adminService, err := trpc.GetAdminService(s)
   242  	require.Nil(t, err)
   243  	require.NotNil(t, adminService)
   244  }
   245  
   246  func TestNewServerWithClosablePlugin(t *testing.T) {
   247  	closed := make(chan struct{})
   248  	plugin.Register("default", &closablePlugin{onClose: func() error {
   249  		close(closed)
   250  		return nil
   251  	}})
   252  
   253  	cfg := t.TempDir() + "trpc_go.yaml"
   254  	require.Nil(t, os.WriteFile(cfg, []byte(`
   255  plugins:
   256    closable_plugin:
   257      default:
   258  `), 0644))
   259  	oldPath := trpc.ServerConfigPath
   260  	trpc.ServerConfigPath = cfg
   261  	defer func() { trpc.ServerConfigPath = oldPath }()
   262  
   263  	s := trpc.NewServer()
   264  	require.Nil(t, s.Close(nil))
   265  	select {
   266  	case <-closed:
   267  	default:
   268  		require.FailNow(t, "plugin is not closed when server close")
   269  	}
   270  }
   271  
   272  type closablePlugin struct {
   273  	onClose func() error
   274  }
   275  
   276  func (*closablePlugin) Type() string {
   277  	return "closable_plugin"
   278  }
   279  
   280  func (*closablePlugin) Setup(string, plugin.Decoder) error {
   281  	return nil
   282  }
   283  
   284  func (p *closablePlugin) Close() error {
   285  	return p.onClose()
   286  }