trpc.group/trpc-go/trpc-go@v1.0.3/restful/router_test.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package restful_test 15 16 import ( 17 "context" 18 "encoding/json" 19 "errors" 20 "fmt" 21 "io" 22 "net" 23 "net/http" 24 "os" 25 "strconv" 26 "testing" 27 "time" 28 29 "github.com/stretchr/testify/require" 30 31 trpc "trpc.group/trpc-go/trpc-go" 32 "trpc.group/trpc-go/trpc-go/errs" 33 "trpc.group/trpc-go/trpc-go/filter" 34 thttp "trpc.group/trpc-go/trpc-go/http" 35 "trpc.group/trpc-go/trpc-go/restful" 36 "trpc.group/trpc-go/trpc-go/server" 37 "trpc.group/trpc-go/trpc-go/testdata/restful/helloworld" 38 ) 39 40 // ------------------------------------- old stub -----------------------------------------// 41 42 type GreeterService interface { 43 SayHello(ctx context.Context, req *helloworld.HelloRequest) (rsp *helloworld.HelloReply, err error) 44 } 45 46 func GreeterService_SayHello_Handler(svr interface{}, ctx context.Context, f server.FilterFunc) ( 47 rspBody interface{}, err error) { 48 req := &helloworld.HelloRequest{} 49 filters, err := f(req) 50 if err != nil { 51 return nil, err 52 } 53 handleFunc := func(ctx context.Context, reqbody interface{}) (rspbody interface{}, err error) { 54 return svr.(GreeterService).SayHello(ctx, reqbody.(*helloworld.HelloRequest)) 55 } 56 57 rsp, err := filters.Filter(ctx, req, handleFunc) 58 if err != nil { 59 return nil, err 60 } 61 62 return rsp, nil 63 } 64 65 var GreeterServer_ServiceDesc = server.ServiceDesc{ 66 ServiceName: "trpc.examples.restful.helloworld.Greeter", 67 HandlerType: (*GreeterService)(nil), 68 Methods: []server.Method{ 69 { 70 Name: "/trpc.examples.restful.helloworld.Greeter/SayHello", 71 Func: GreeterService_SayHello_Handler, 72 Bindings: []*restful.Binding{ 73 { 74 Name: "/trpc.examples.restful.helloworld.Greeter/SayHello", 75 Input: func() restful.ProtoMessage { return new(helloworld.HelloRequest) }, 76 Output: func() restful.ProtoMessage { return new(helloworld.HelloReply) }, 77 Filter: func(svc interface{}, ctx context.Context, reqBody interface{}) (interface{}, error) { 78 return svc.(GreeterService).SayHello(ctx, reqBody.(*helloworld.HelloRequest)) 79 }, 80 HTTPMethod: "GET", 81 Pattern: restful.Enforce("/v2/bar/{name}"), 82 Body: nil, 83 ResponseBody: nil, 84 }, 85 }, 86 }, 87 }, 88 } 89 90 func RegisterGreeterService(s server.Service, svr GreeterService) { 91 if err := s.Register(&GreeterServer_ServiceDesc, svr); err != nil { 92 panic(fmt.Sprintf("Greeter register error:%v", err)) 93 } 94 } 95 96 // ------------------------------------------------------------------------------------------// 97 98 type greeter struct{} 99 100 func (s *greeter) SayHello(ctx context.Context, req *helloworld.HelloRequest) (*helloworld.HelloReply, error) { 101 rsp := &helloworld.HelloReply{} 102 rsp.Message = req.Name 103 return rsp, nil 104 } 105 106 func TestPreviousVersionStub(t *testing.T) { 107 var serverFilter filter.ServerFilter = func(ctx context.Context, req interface{}, 108 next filter.ServerHandleFunc) (rsp interface{}, err error) { 109 helloReq, ok := req.(*helloworld.HelloRequest) 110 if !ok { 111 return nil, errors.New("invalid request") 112 } 113 if helloReq.Name != "world" { 114 return nil, errors.New("wrong name") 115 } 116 resp, err := next(ctx, req) 117 if err != nil { 118 return nil, err 119 } 120 helloResp, ok := resp.(*helloworld.HelloReply) 121 if !ok { 122 return nil, errors.New("invalid response") 123 } 124 helloResp.Message += "a" 125 return helloResp, nil 126 } 127 filter.Register("restful.oldversion.stub", serverFilter, nil) 128 129 // service registration 130 s := &server.Server{} 131 service := server.New( 132 server.WithAddress("127.0.0.1:32781"), 133 server.WithServiceName("trpc.test.helloworld.GreeterPreviousVersionStub"), 134 server.WithNetwork("tcp"), 135 server.WithProtocol("restful"), 136 server.WithFilter(filter.GetServer("restful.oldversion.stub")), 137 ) 138 s.AddService("trpc.test.helloworld.GreeterPreviousVersionStub", service) 139 RegisterGreeterService(s, &greeter{}) 140 141 // start server 142 go func() { 143 err := s.Serve() 144 require.Nil(t, err) 145 }() 146 147 time.Sleep(100 * time.Millisecond) 148 149 // create restful request 150 req, err := http.NewRequest("GET", "http://127.0.0.1:32781/v2/bar/world", nil) 151 require.Nil(t, err) 152 153 // send restful request 154 cli := http.Client{} 155 resp1, err := cli.Do(req) 156 require.Nil(t, err) 157 defer resp1.Body.Close() 158 require.Equal(t, resp1.StatusCode, http.StatusOK) 159 bodyBytes1, err := io.ReadAll(resp1.Body) 160 require.Nil(t, err) 161 type responseBody struct { 162 Message string `json:"message"` 163 } 164 respBody := &responseBody{} 165 json.Unmarshal(bodyBytes1, respBody) 166 require.Equal(t, "worlda", respBody.Message) 167 168 resp2, err := cli.Do(req) 169 require.Nil(t, err) 170 defer resp2.Body.Close() 171 require.Equal(t, resp2.StatusCode, http.StatusOK) 172 bodyBytes2, err := io.ReadAll(resp2.Body) 173 require.Nil(t, err) 174 json.Unmarshal(bodyBytes2, respBody) 175 require.Equal(t, "worlda", respBody.Message) 176 } 177 178 func TestTRPCGlobalMessage(t *testing.T) { 179 cfgPath := t.TempDir() + "/cfg.yaml" 180 require.Nil(t, os.WriteFile(cfgPath, []byte(` 181 global: 182 namespace: development 183 env_name: environment 184 container_name: container 185 enable_set: Y 186 full_set_name: full.set.name 187 server: 188 service: 189 - name: trpc.test.helloworld.Greeter 190 protocol: restful 191 `), 0644)) 192 trpc.ServerConfigPath = cfgPath 193 194 l, err := net.Listen("tcp", "127.0.0.1:0") 195 require.Nil(t, err) 196 197 s := trpc.NewServer(server.WithRESTOptions( 198 restful.WithFilterFunc(func() filter.ServerChain { 199 return []filter.ServerFilter{ 200 func(ctx context.Context, req interface{}, next filter.ServerHandleFunc) (rsp interface{}, err error) { 201 msg := trpc.Message(ctx) 202 require.Equal(t, "development", msg.Namespace()) 203 require.Equal(t, "environment", msg.EnvName()) 204 require.Equal(t, "container", msg.CalleeContainerName()) 205 require.Equal(t, "full.set.name", msg.SetName()) 206 return next(ctx, req) 207 }, 208 } 209 })), 210 server.WithListener(l)) 211 RegisterGreeterService(s, &greeter{}) 212 go func() { 213 fmt.Println(s.Serve()) 214 }() 215 216 rsp, err := http.Get(fmt.Sprintf("http://%s/v2/bar/world", l.Addr().String())) 217 require.Nil(t, err) 218 require.Equal(t, http.StatusOK, rsp.StatusCode) 219 } 220 221 func TestHTTPOkWithDetailedError(t *testing.T) { 222 l, err := net.Listen("tcp", "127.0.0.1:0") 223 require.Nil(t, err) 224 s := server.New( 225 server.WithListener(l), 226 server.WithServiceName("trpc.test.helloworld.Greeter2"), 227 server.WithNetwork("tcp"), 228 server.WithProtocol("restful"), 229 server.WithRESTOptions( 230 restful.WithErrorHandler(func(ctx context.Context, w http.ResponseWriter, r *http.Request, err error) { 231 restful.DefaultErrorHandler(ctx, w, r, &restful.WithStatusCode{StatusCode: http.StatusOK, Err: err}) 232 })), 233 server.WithFilter(func( 234 ctx context.Context, 235 req interface{}, 236 next filter.ServerHandleFunc, 237 ) (rsp interface{}, err error) { 238 return nil, errs.New(errs.RetServerThrottled, "always throttled") 239 })) 240 RegisterGreeterService(s, &greeter{}) 241 go func() { 242 fmt.Println(s.Serve()) 243 }() 244 245 rsp, err := http.Get(fmt.Sprintf("http://%s/v2/bar/world", l.Addr().String())) 246 require.Nil(t, err) 247 defer rsp.Body.Close() 248 require.Equal(t, http.StatusOK, rsp.StatusCode) 249 rspBody, err := io.ReadAll(rsp.Body) 250 require.Nil(t, err) 251 require.Contains(t, string(rspBody), strconv.Itoa(int(errs.RetServerThrottled))) 252 require.NotContains(t, string(rspBody), strconv.Itoa(int(errs.RetUnknown))) 253 require.Contains(t, string(rspBody), "always throttled") 254 } 255 256 func TestNoPanicOnFilterReturnsNil(t *testing.T) { 257 l, err := net.Listen("tcp", "127.0.0.1:0") 258 require.Nil(t, err) 259 s := server.New( 260 server.WithListener(l), 261 server.WithServiceName("trpc.test.helloworld.Greeter3"), 262 server.WithNetwork("tcp"), 263 server.WithProtocol("restful"), 264 server.WithFilter(func( 265 ctx context.Context, req interface{}, next filter.ServerHandleFunc, 266 ) (rsp interface{}, err error) { 267 head := ctx.Value(thttp.ContextKeyHeader).(*thttp.Header) 268 head.Response.Header().Add(t.Name(), t.Name()) 269 return nil, nil 270 })) 271 RegisterGreeterService(s, &greeter{}) 272 go func() { 273 fmt.Println(s.Serve()) 274 }() 275 276 rsp, err := http.Get(fmt.Sprintf("http://%s/v2/bar/world", l.Addr().String())) 277 require.Nil(t, err) 278 defer rsp.Body.Close() 279 require.Equal(t, http.StatusOK, rsp.StatusCode) 280 require.Equal(t, t.Name(), rsp.Header.Get(t.Name())) 281 } 282 283 func TestTimeout(t *testing.T) { 284 l, err := net.Listen("tcp", "localhost:") 285 require.Nil(t, err) 286 s := server.New( 287 server.WithListener(l), 288 server.WithServiceName(t.Name()), 289 server.WithProtocol("restful"), 290 server.WithTimeout(time.Millisecond*100)) 291 RegisterGreeterService(s, &greeterAlwaysTimeout{}) 292 errCh := make(chan error) 293 go func() { errCh <- s.Serve() }() 294 select { 295 case err := <-errCh: 296 require.FailNow(t, "serve failed", err) 297 case <-time.After(time.Millisecond * 100): 298 } 299 defer s.Close(nil) 300 301 start := time.Now() 302 rsp, err := http.Get(fmt.Sprintf("http://%s/v2/bar/world", l.Addr().String())) 303 require.Nil(t, err) 304 require.Equal(t, http.StatusGatewayTimeout, rsp.StatusCode) 305 require.InDelta(t, time.Millisecond*100, time.Since(start), float64(time.Millisecond*30)) 306 } 307 308 type greeterAlwaysTimeout struct{} 309 310 func (*greeterAlwaysTimeout) SayHello(ctx context.Context, req *helloworld.HelloRequest) (*helloworld.HelloReply, error) { 311 <-ctx.Done() 312 return nil, errs.NewFrameError(errs.RetServerTimeout, "ctx timeout") 313 }