trpc.group/trpc-go/trpc-go@v1.0.3/restful/router_test.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package restful_test
    15  
    16  import (
    17  	"context"
    18  	"encoding/json"
    19  	"errors"
    20  	"fmt"
    21  	"io"
    22  	"net"
    23  	"net/http"
    24  	"os"
    25  	"strconv"
    26  	"testing"
    27  	"time"
    28  
    29  	"github.com/stretchr/testify/require"
    30  
    31  	trpc "trpc.group/trpc-go/trpc-go"
    32  	"trpc.group/trpc-go/trpc-go/errs"
    33  	"trpc.group/trpc-go/trpc-go/filter"
    34  	thttp "trpc.group/trpc-go/trpc-go/http"
    35  	"trpc.group/trpc-go/trpc-go/restful"
    36  	"trpc.group/trpc-go/trpc-go/server"
    37  	"trpc.group/trpc-go/trpc-go/testdata/restful/helloworld"
    38  )
    39  
    40  // ------------------------------------- old stub -----------------------------------------//
    41  
    42  type GreeterService interface {
    43  	SayHello(ctx context.Context, req *helloworld.HelloRequest) (rsp *helloworld.HelloReply, err error)
    44  }
    45  
    46  func GreeterService_SayHello_Handler(svr interface{}, ctx context.Context, f server.FilterFunc) (
    47  	rspBody interface{}, err error) {
    48  	req := &helloworld.HelloRequest{}
    49  	filters, err := f(req)
    50  	if err != nil {
    51  		return nil, err
    52  	}
    53  	handleFunc := func(ctx context.Context, reqbody interface{}) (rspbody interface{}, err error) {
    54  		return svr.(GreeterService).SayHello(ctx, reqbody.(*helloworld.HelloRequest))
    55  	}
    56  
    57  	rsp, err := filters.Filter(ctx, req, handleFunc)
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  
    62  	return rsp, nil
    63  }
    64  
    65  var GreeterServer_ServiceDesc = server.ServiceDesc{
    66  	ServiceName: "trpc.examples.restful.helloworld.Greeter",
    67  	HandlerType: (*GreeterService)(nil),
    68  	Methods: []server.Method{
    69  		{
    70  			Name: "/trpc.examples.restful.helloworld.Greeter/SayHello",
    71  			Func: GreeterService_SayHello_Handler,
    72  			Bindings: []*restful.Binding{
    73  				{
    74  					Name:   "/trpc.examples.restful.helloworld.Greeter/SayHello",
    75  					Input:  func() restful.ProtoMessage { return new(helloworld.HelloRequest) },
    76  					Output: func() restful.ProtoMessage { return new(helloworld.HelloReply) },
    77  					Filter: func(svc interface{}, ctx context.Context, reqBody interface{}) (interface{}, error) {
    78  						return svc.(GreeterService).SayHello(ctx, reqBody.(*helloworld.HelloRequest))
    79  					},
    80  					HTTPMethod:   "GET",
    81  					Pattern:      restful.Enforce("/v2/bar/{name}"),
    82  					Body:         nil,
    83  					ResponseBody: nil,
    84  				},
    85  			},
    86  		},
    87  	},
    88  }
    89  
    90  func RegisterGreeterService(s server.Service, svr GreeterService) {
    91  	if err := s.Register(&GreeterServer_ServiceDesc, svr); err != nil {
    92  		panic(fmt.Sprintf("Greeter register error:%v", err))
    93  	}
    94  }
    95  
    96  // ------------------------------------------------------------------------------------------//
    97  
    98  type greeter struct{}
    99  
   100  func (s *greeter) SayHello(ctx context.Context, req *helloworld.HelloRequest) (*helloworld.HelloReply, error) {
   101  	rsp := &helloworld.HelloReply{}
   102  	rsp.Message = req.Name
   103  	return rsp, nil
   104  }
   105  
   106  func TestPreviousVersionStub(t *testing.T) {
   107  	var serverFilter filter.ServerFilter = func(ctx context.Context, req interface{},
   108  		next filter.ServerHandleFunc) (rsp interface{}, err error) {
   109  		helloReq, ok := req.(*helloworld.HelloRequest)
   110  		if !ok {
   111  			return nil, errors.New("invalid request")
   112  		}
   113  		if helloReq.Name != "world" {
   114  			return nil, errors.New("wrong name")
   115  		}
   116  		resp, err := next(ctx, req)
   117  		if err != nil {
   118  			return nil, err
   119  		}
   120  		helloResp, ok := resp.(*helloworld.HelloReply)
   121  		if !ok {
   122  			return nil, errors.New("invalid response")
   123  		}
   124  		helloResp.Message += "a"
   125  		return helloResp, nil
   126  	}
   127  	filter.Register("restful.oldversion.stub", serverFilter, nil)
   128  
   129  	// service registration
   130  	s := &server.Server{}
   131  	service := server.New(
   132  		server.WithAddress("127.0.0.1:32781"),
   133  		server.WithServiceName("trpc.test.helloworld.GreeterPreviousVersionStub"),
   134  		server.WithNetwork("tcp"),
   135  		server.WithProtocol("restful"),
   136  		server.WithFilter(filter.GetServer("restful.oldversion.stub")),
   137  	)
   138  	s.AddService("trpc.test.helloworld.GreeterPreviousVersionStub", service)
   139  	RegisterGreeterService(s, &greeter{})
   140  
   141  	// start server
   142  	go func() {
   143  		err := s.Serve()
   144  		require.Nil(t, err)
   145  	}()
   146  
   147  	time.Sleep(100 * time.Millisecond)
   148  
   149  	// create restful request
   150  	req, err := http.NewRequest("GET", "http://127.0.0.1:32781/v2/bar/world", nil)
   151  	require.Nil(t, err)
   152  
   153  	// send restful request
   154  	cli := http.Client{}
   155  	resp1, err := cli.Do(req)
   156  	require.Nil(t, err)
   157  	defer resp1.Body.Close()
   158  	require.Equal(t, resp1.StatusCode, http.StatusOK)
   159  	bodyBytes1, err := io.ReadAll(resp1.Body)
   160  	require.Nil(t, err)
   161  	type responseBody struct {
   162  		Message string `json:"message"`
   163  	}
   164  	respBody := &responseBody{}
   165  	json.Unmarshal(bodyBytes1, respBody)
   166  	require.Equal(t, "worlda", respBody.Message)
   167  
   168  	resp2, err := cli.Do(req)
   169  	require.Nil(t, err)
   170  	defer resp2.Body.Close()
   171  	require.Equal(t, resp2.StatusCode, http.StatusOK)
   172  	bodyBytes2, err := io.ReadAll(resp2.Body)
   173  	require.Nil(t, err)
   174  	json.Unmarshal(bodyBytes2, respBody)
   175  	require.Equal(t, "worlda", respBody.Message)
   176  }
   177  
   178  func TestTRPCGlobalMessage(t *testing.T) {
   179  	cfgPath := t.TempDir() + "/cfg.yaml"
   180  	require.Nil(t, os.WriteFile(cfgPath, []byte(`
   181  global:
   182    namespace: development
   183    env_name: environment
   184    container_name: container
   185    enable_set: Y
   186    full_set_name: full.set.name
   187  server:
   188    service:
   189      - name: trpc.test.helloworld.Greeter
   190        protocol: restful
   191  `), 0644))
   192  	trpc.ServerConfigPath = cfgPath
   193  
   194  	l, err := net.Listen("tcp", "127.0.0.1:0")
   195  	require.Nil(t, err)
   196  
   197  	s := trpc.NewServer(server.WithRESTOptions(
   198  		restful.WithFilterFunc(func() filter.ServerChain {
   199  			return []filter.ServerFilter{
   200  				func(ctx context.Context, req interface{}, next filter.ServerHandleFunc) (rsp interface{}, err error) {
   201  					msg := trpc.Message(ctx)
   202  					require.Equal(t, "development", msg.Namespace())
   203  					require.Equal(t, "environment", msg.EnvName())
   204  					require.Equal(t, "container", msg.CalleeContainerName())
   205  					require.Equal(t, "full.set.name", msg.SetName())
   206  					return next(ctx, req)
   207  				},
   208  			}
   209  		})),
   210  		server.WithListener(l))
   211  	RegisterGreeterService(s, &greeter{})
   212  	go func() {
   213  		fmt.Println(s.Serve())
   214  	}()
   215  
   216  	rsp, err := http.Get(fmt.Sprintf("http://%s/v2/bar/world", l.Addr().String()))
   217  	require.Nil(t, err)
   218  	require.Equal(t, http.StatusOK, rsp.StatusCode)
   219  }
   220  
   221  func TestHTTPOkWithDetailedError(t *testing.T) {
   222  	l, err := net.Listen("tcp", "127.0.0.1:0")
   223  	require.Nil(t, err)
   224  	s := server.New(
   225  		server.WithListener(l),
   226  		server.WithServiceName("trpc.test.helloworld.Greeter2"),
   227  		server.WithNetwork("tcp"),
   228  		server.WithProtocol("restful"),
   229  		server.WithRESTOptions(
   230  			restful.WithErrorHandler(func(ctx context.Context, w http.ResponseWriter, r *http.Request, err error) {
   231  				restful.DefaultErrorHandler(ctx, w, r, &restful.WithStatusCode{StatusCode: http.StatusOK, Err: err})
   232  			})),
   233  		server.WithFilter(func(
   234  			ctx context.Context,
   235  			req interface{},
   236  			next filter.ServerHandleFunc,
   237  		) (rsp interface{}, err error) {
   238  			return nil, errs.New(errs.RetServerThrottled, "always throttled")
   239  		}))
   240  	RegisterGreeterService(s, &greeter{})
   241  	go func() {
   242  		fmt.Println(s.Serve())
   243  	}()
   244  
   245  	rsp, err := http.Get(fmt.Sprintf("http://%s/v2/bar/world", l.Addr().String()))
   246  	require.Nil(t, err)
   247  	defer rsp.Body.Close()
   248  	require.Equal(t, http.StatusOK, rsp.StatusCode)
   249  	rspBody, err := io.ReadAll(rsp.Body)
   250  	require.Nil(t, err)
   251  	require.Contains(t, string(rspBody), strconv.Itoa(int(errs.RetServerThrottled)))
   252  	require.NotContains(t, string(rspBody), strconv.Itoa(int(errs.RetUnknown)))
   253  	require.Contains(t, string(rspBody), "always throttled")
   254  }
   255  
   256  func TestNoPanicOnFilterReturnsNil(t *testing.T) {
   257  	l, err := net.Listen("tcp", "127.0.0.1:0")
   258  	require.Nil(t, err)
   259  	s := server.New(
   260  		server.WithListener(l),
   261  		server.WithServiceName("trpc.test.helloworld.Greeter3"),
   262  		server.WithNetwork("tcp"),
   263  		server.WithProtocol("restful"),
   264  		server.WithFilter(func(
   265  			ctx context.Context, req interface{}, next filter.ServerHandleFunc,
   266  		) (rsp interface{}, err error) {
   267  			head := ctx.Value(thttp.ContextKeyHeader).(*thttp.Header)
   268  			head.Response.Header().Add(t.Name(), t.Name())
   269  			return nil, nil
   270  		}))
   271  	RegisterGreeterService(s, &greeter{})
   272  	go func() {
   273  		fmt.Println(s.Serve())
   274  	}()
   275  
   276  	rsp, err := http.Get(fmt.Sprintf("http://%s/v2/bar/world", l.Addr().String()))
   277  	require.Nil(t, err)
   278  	defer rsp.Body.Close()
   279  	require.Equal(t, http.StatusOK, rsp.StatusCode)
   280  	require.Equal(t, t.Name(), rsp.Header.Get(t.Name()))
   281  }
   282  
   283  func TestTimeout(t *testing.T) {
   284  	l, err := net.Listen("tcp", "localhost:")
   285  	require.Nil(t, err)
   286  	s := server.New(
   287  		server.WithListener(l),
   288  		server.WithServiceName(t.Name()),
   289  		server.WithProtocol("restful"),
   290  		server.WithTimeout(time.Millisecond*100))
   291  	RegisterGreeterService(s, &greeterAlwaysTimeout{})
   292  	errCh := make(chan error)
   293  	go func() { errCh <- s.Serve() }()
   294  	select {
   295  	case err := <-errCh:
   296  		require.FailNow(t, "serve failed", err)
   297  	case <-time.After(time.Millisecond * 100):
   298  	}
   299  	defer s.Close(nil)
   300  
   301  	start := time.Now()
   302  	rsp, err := http.Get(fmt.Sprintf("http://%s/v2/bar/world", l.Addr().String()))
   303  	require.Nil(t, err)
   304  	require.Equal(t, http.StatusGatewayTimeout, rsp.StatusCode)
   305  	require.InDelta(t, time.Millisecond*100, time.Since(start), float64(time.Millisecond*30))
   306  }
   307  
   308  type greeterAlwaysTimeout struct{}
   309  
   310  func (*greeterAlwaysTimeout) SayHello(ctx context.Context, req *helloworld.HelloRequest) (*helloworld.HelloReply, error) {
   311  	<-ctx.Done()
   312  	return nil, errs.NewFrameError(errs.RetServerTimeout, "ctx timeout")
   313  }