trpc.group/trpc-go/trpc-go@v1.0.3/http/restful_server_transport_test.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package http_test 15 16 import ( 17 "bytes" 18 "context" 19 "crypto/tls" 20 "crypto/x509" 21 "encoding/base64" 22 "encoding/json" 23 "errors" 24 "io" 25 "net" 26 "net/http" 27 "os" 28 "strings" 29 "testing" 30 "time" 31 32 "github.com/stretchr/testify/require" 33 "github.com/valyala/fasthttp" 34 35 trpc "trpc.group/trpc-go/trpc-go" 36 "trpc.group/trpc-go/trpc-go/codec" 37 thttp "trpc.group/trpc-go/trpc-go/http" 38 itls "trpc.group/trpc-go/trpc-go/internal/tls" 39 "trpc.group/trpc-go/trpc-go/restful" 40 "trpc.group/trpc-go/trpc-go/server" 41 "trpc.group/trpc-go/trpc-go/testdata/restful/helloworld" 42 "trpc.group/trpc-go/trpc-go/transport" 43 ) 44 45 func TestCompatibility(t *testing.T) { 46 // Registers service. 47 serviceName := "trpc.test.server.Greeter" + t.Name() 48 ln, err := net.Listen("tcp", "127.0.0.1:0") 49 require.Nil(t, err) 50 defer ln.Close() 51 url := "http://" + ln.Addr().String() 52 s := &server.Server{} 53 service := server.New( 54 server.WithListener(ln), 55 server.WithServiceName(serviceName), 56 server.WithProtocol("restful"), 57 ) 58 s.AddService(serviceName, service) 59 helloworld.RegisterGreeterService(s, &greeterServerImpl{}) 60 61 go func() { require.Nil(t, s.Serve()) }() 62 defer s.Close(nil) 63 64 time.Sleep(100 * time.Millisecond) 65 66 // Removes compatibility setting. 67 restful.SetCtxForCompatibility( 68 func(ctx context.Context, w http.ResponseWriter, r *http.Request) context.Context { 69 return ctx 70 }, 71 ) 72 73 // Sends restful request. 74 req1, err := http.NewRequest("POST", url+"/v1/foobar", 75 bytes.NewBuffer([]byte(`{"name": "xyz"}`))) 76 require.Nil(t, err) 77 cli := http.Client{} 78 resp1, err := cli.Do(req1) 79 require.Nil(t, err) 80 defer resp1.Body.Close() 81 require.Equal(t, resp1.StatusCode, http.StatusInternalServerError) 82 83 // Adds compatibility setting. 84 restful.SetCtxForCompatibility(func(ctx context.Context, w http.ResponseWriter, 85 r *http.Request) context.Context { 86 return thttp.WithHeader(ctx, &thttp.Header{Response: w, Request: r}) 87 }) 88 89 // Sends restful request. 90 req2, err := http.NewRequest("POST", url+"/v1/foobar", 91 bytes.NewBuffer([]byte(`{"name": "xyz"}`))) 92 require.Nil(t, err) 93 resp2, err := cli.Do(req2) 94 require.Nil(t, err) 95 defer resp2.Body.Close() 96 require.Equal(t, resp2.StatusCode, http.StatusOK) 97 } 98 99 func TestEnableTLS(t *testing.T) { 100 // Registers service. 101 s := &server.Server{} 102 conf, err := itls.GetServerConfig("../testdata/ca.pem", "../testdata/server.crt", "../testdata/server.key") 103 require.Nil(t, err, "%+v", err) 104 ln, err := tls.Listen("tcp", "127.0.0.1:0", conf) 105 require.Nil(t, err) 106 defer ln.Close() 107 addr := strings.Split(ln.Addr().String(), ":") 108 require.Equal(t, 2, len(addr)) 109 port := addr[1] 110 // Must use localhost to replace 127.0.0.1, or else the following error will occur: 111 // tls: failed to verify certificate: x509: cannot validate certificate for 127.0.0.1 because it doesn't contain any IP SANs. 112 url := "https://localhost:" + port 113 service := server.New( 114 server.WithListener(ln), 115 server.WithServiceName("trpc.test.helloworld.Greeter"), 116 server.WithProtocol("restful"), 117 ) 118 s.AddService("trpc.test.helloworld.Greeter", service) 119 helloworld.RegisterGreeterService(s, &greeterServerImpl{}) 120 121 go func() { require.Nil(t, s.Serve()) }() 122 defer s.Close(nil) 123 124 time.Sleep(100 * time.Millisecond) 125 126 // Sends https request. 127 pool := x509.NewCertPool() 128 ca, err := os.ReadFile("../testdata/ca.pem") 129 require.Nil(t, err) 130 pool.AppendCertsFromPEM(ca) 131 cert, err := tls.LoadX509KeyPair("../testdata/client.crt", "../testdata/client.key") 132 require.Nil(t, err) 133 134 cli := &http.Client{ 135 Transport: &http.Transport{ 136 TLSClientConfig: &tls.Config{ 137 RootCAs: pool, 138 Certificates: []tls.Certificate{cert}, 139 }, 140 }, 141 } 142 143 req, err := http.NewRequest("POST", url+"/v1/foobar", 144 bytes.NewBuffer([]byte(`{"name": "xyz"}`))) 145 require.Nil(t, err) 146 147 resp, err := cli.Do(req) 148 require.Nil(t, err, "%+v", err) 149 defer resp.Body.Close() 150 require.Equal(t, resp.StatusCode, http.StatusOK) 151 152 bodyBytes, err := io.ReadAll(resp.Body) 153 require.Nil(t, err) 154 type responseBody struct { 155 Message string `json:"message"` 156 } 157 respBody := &responseBody{} 158 json.Unmarshal(bodyBytes, respBody) 159 require.Equal(t, respBody.Message, "test restful server transport") 160 } 161 162 func TestReplaceRouter(t *testing.T) { 163 st := thttp.NewRESTServerTransport(true, transport.WithReusePort(true)) 164 restful.RegisterRouter("replacing", http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})) 165 restful.RegisterRouter("no_replacing", restful.NewRouter()) 166 err := st.ListenAndServe(context.Background(), transport.WithServiceName("replacing")) 167 require.NotNil(t, err) 168 err = st.ListenAndServe(context.Background(), transport.WithServiceName("no_replacing")) 169 require.Nil(t, err) 170 } 171 172 var ( 173 headerMatcherTransInfo, _ = json.Marshal(map[string]string{ 174 "kfuin": base64.StdEncoding.EncodeToString([]byte("3009025887")), 175 }) 176 ) 177 178 func TestDefaultRESTHeaderMatcher(t *testing.T) { 179 bgctx := trpc.BackgroundContext() 180 req := http.Request{Header: make(http.Header)} 181 req.Header.Set(thttp.TrpcCaller, "TestDefaultHeaderMatcher") 182 req.Header.Set(thttp.TrpcTransInfo, string(headerMatcherTransInfo)) 183 req.Header.Set(thttp.TrpcTimeout, "2000") 184 req.Header.Set(thttp.TrpcMessageType, "1") 185 ctx, err := thttp.DefaultRESTHeaderMatcher(bgctx, nil, &req, "UTService", "UTMethod") 186 require.Nil(t, err) 187 msg := codec.Message(ctx) 188 require.Equal(t, "UTService", msg.CalleeServiceName()) 189 require.Equal(t, "UTMethod", msg.ServerRPCName()) 190 require.Equal(t, "TestDefaultHeaderMatcher", msg.CallerServiceName()) 191 require.Equal(t, time.Duration(2000*time.Millisecond), msg.RequestTimeout()) 192 require.Equal(t, "3009025887", string(trpc.GetMetaData(ctx, "kfuin"))) 193 require.Equal(t, true, msg.Dyeing()) 194 195 req.Header.Set(thttp.TrpcTransInfo, "") 196 req.Header.Set(thttp.TrpcMessageType, "0") 197 ctx, err = thttp.DefaultRESTHeaderMatcher(bgctx, nil, &req, "UTService", "UTMethod") 198 require.Nil(t, err) 199 msg = codec.Message(ctx) 200 require.Equal(t, "", string(trpc.GetMetaData(ctx, "kfuin"))) 201 require.Equal(t, false, msg.Dyeing()) 202 } 203 204 func TestDefaultRESTFastHTTPHeaderMatcher(t *testing.T) { 205 bgctx := trpc.BackgroundContext() 206 req := fasthttp.RequestCtx{} 207 req.Request.Header.Set(thttp.TrpcCaller, "TestDefaultHeaderMatcher") 208 req.Request.Header.Set(thttp.TrpcTransInfo, string(headerMatcherTransInfo)) 209 req.Request.Header.Set(thttp.TrpcTimeout, "2000") 210 req.Request.Header.Set(thttp.TrpcMessageType, "1") 211 ctx, err := thttp.DefaultRESTFastHTTPHeaderMatcher(bgctx, &req, "UTService", "UTMethod") 212 require.Nil(t, err) 213 msg := codec.Message(ctx) 214 require.Equal(t, "UTService", msg.CalleeServiceName()) 215 require.Equal(t, "UTMethod", msg.ServerRPCName()) 216 require.Equal(t, "TestDefaultHeaderMatcher", msg.CallerServiceName()) 217 require.Equal(t, time.Duration(2000*time.Millisecond), msg.RequestTimeout()) 218 require.Equal(t, "3009025887", string(trpc.GetMetaData(ctx, "kfuin"))) 219 require.Equal(t, true, msg.Dyeing()) 220 221 req = fasthttp.RequestCtx{} 222 req.Request.Header.Set(thttp.TrpcTransInfo, "xyz") 223 _, err = thttp.DefaultRESTFastHTTPHeaderMatcher(bgctx, &req, "UTService", "UTMethod") 224 require.NotNil(t, err) 225 } 226 227 func TestPassListenerUseTLS(t *testing.T) { 228 // Registers service. 229 serviceName := "trpc.test.helloworld.Greeter" + t.Name() 230 ln, err := net.Listen("tcp", "127.0.0.1:0") 231 require.Nil(t, err) 232 addr := strings.Split(ln.Addr().String(), ":") 233 require.Equal(t, 2, len(addr)) 234 port := addr[1] 235 // Must use localhost to replace 127.0.0.1, or else the following error will occur: 236 // tls: failed to verify certificate: x509: cannot validate certificate for 127.0.0.1 because it doesn't contain any IP SANs. 237 url := "https://localhost:" + port 238 s := &server.Server{} 239 service := server.New( 240 server.WithListener(ln), 241 server.WithServiceName(serviceName), 242 server.WithProtocol("restful"), 243 server.WithTLS("../testdata/server.crt", "../testdata/server.key", "../testdata/ca.pem"), 244 ) 245 s.AddService(serviceName, service) 246 helloworld.RegisterGreeterService(s, &greeterServerImpl{}) 247 248 go func() { 249 err := s.Serve() 250 require.Nil(t, err) 251 }() 252 defer s.Close(nil) 253 254 time.Sleep(100 * time.Millisecond) 255 256 // Sends https request. 257 pool := x509.NewCertPool() 258 ca, err := os.ReadFile("../testdata/ca.pem") 259 require.Nil(t, err) 260 pool.AppendCertsFromPEM(ca) 261 cert, err := tls.LoadX509KeyPair("../testdata/client.crt", "../testdata/client.key") 262 require.Nil(t, err) 263 264 cli := &http.Client{ 265 Transport: &http.Transport{ 266 TLSClientConfig: &tls.Config{ 267 RootCAs: pool, 268 Certificates: []tls.Certificate{cert}, 269 }, 270 }, 271 } 272 273 req, err := http.NewRequest("POST", url+"/v1/foobar", 274 bytes.NewBuffer([]byte(`{"name": "xyz"}`))) 275 require.Nil(t, err) 276 277 resp, err := cli.Do(req) 278 require.Nil(t, err, "err: %+v", err) 279 defer resp.Body.Close() 280 require.Equal(t, resp.StatusCode, http.StatusOK) 281 282 bodyBytes, err := io.ReadAll(resp.Body) 283 require.Nil(t, err) 284 type responseBody struct { 285 Message string `json:"message"` 286 } 287 respBody := &responseBody{} 288 json.Unmarshal(bodyBytes, respBody) 289 require.Equal(t, respBody.Message, "test restful server transport") 290 } 291 292 func TestListenAndServeInvalidAddrErr(t *testing.T) { 293 serviceName := "trpc.test.helloworld.Greeter" + t.Name() 294 s := &server.Server{} 295 invalidAddr := "888.888.888.888:88888" 296 service := server.New( 297 server.WithAddress(invalidAddr), 298 server.WithServiceName(serviceName), 299 server.WithProtocol("restful"), 300 ) 301 s.AddService(serviceName, service) 302 require.NotNil(t, s.Serve()) 303 } 304 305 type greeterServerImpl struct{} 306 307 func (s *greeterServerImpl) SayHello(ctx context.Context, req *helloworld.HelloRequest) (*helloworld.HelloReply, error) { 308 rsp := &helloworld.HelloReply{} 309 if thttp.Head(ctx) == nil { 310 return nil, errors.New("test error") 311 } 312 rsp.Message = "test restful server transport" 313 return rsp, nil 314 }