trpc.group/trpc-go/trpc-go@v1.0.3/http/transport_test.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package http_test 15 16 // https certificate file generation method: 17 // 1. ca certificate: 18 // openssl genrsa -out ca.key 2048 19 // openssl req -x509 -new -nodes -key ca.key -subj "/CN=*" -days 5000 -out ca.pem 20 // 2. server certificate: 21 // openssl genrsa -out server.key 2048 22 // openssl req -new -key server.key -subj "/CN=*" -out server.csr 23 // openssl x509 -req -in server.csr -CA ca.pem -CAkey ca.key -CAcreateserial -out server.crt -days 5000 <(printf "subjectAltName=DNS:localhost") 24 // 3. client certificate: 25 // openssl genrsa -out client.key 2048 26 // openssl req -new -key client.key -subj "/CN=*" -out client.csr 27 // openssl x509 -req -in client.csr -CA ca.pem -CAkey ca.key -CAcreateserial -out client.crt -days 5000 <(printf "subjectAltName=DNS:localhost") 28 // 4. 
show certificate content: 29 // openssl x509 -text -in server.crt -noout 30 31 import ( 32 "bytes" 33 "context" 34 "crypto/tls" 35 "errors" 36 "fmt" 37 "io" 38 "mime/multipart" 39 "net" 40 "net/http" 41 "net/http/httptest" 42 "net/url" 43 "os" 44 "path" 45 "path/filepath" 46 "strconv" 47 "strings" 48 "testing" 49 "time" 50 51 "github.com/stretchr/testify/require" 52 "golang.org/x/net/http2" 53 trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc" 54 55 "trpc.group/trpc-go/trpc-go/client" 56 "trpc.group/trpc-go/trpc-go/codec" 57 "trpc.group/trpc-go/trpc-go/errs" 58 "trpc.group/trpc-go/trpc-go/filter" 59 thttp "trpc.group/trpc-go/trpc-go/http" 60 "trpc.group/trpc-go/trpc-go/log" 61 "trpc.group/trpc-go/trpc-go/naming/registry" 62 "trpc.group/trpc-go/trpc-go/server" 63 "trpc.group/trpc-go/trpc-go/testdata/restful/helloworld" 64 "trpc.group/trpc-go/trpc-go/transport" 65 ) 66 67 func newNoopStdHTTPServer() *http.Server { return &http.Server{} } 68 69 func TestStartServer(t *testing.T) { 70 ctx := context.Background() 71 tp := thttp.NewServerTransport(newNoopStdHTTPServer) 72 ln, err := net.Listen("tcp", "127.0.0.1:0") 73 require.Nil(t, err) 74 defer ln.Close() 75 option := transport.WithListener(ln) 76 handler := transport.WithHandler(transport.Handler(&h{})) 77 require.Nil(t, tp.ListenAndServe(ctx, option, handler), "Failed to new client transport") 78 require.NotNil(t, tp.ListenAndServe(ctx, transport.WithListenAddress("127.0.0.1:8888"), handler, transport.WithListenNetwork("tcp1"))) 79 tls := transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", "ca1") 80 require.NotNil(t, tp.ListenAndServe(ctx, option, handler, tls)) 81 } 82 83 func TestH2C(t *testing.T) { 84 ctx := context.Background() 85 ln, err := net.Listen("tcp", "127.0.0.1:0") 86 require.Nil(t, err) 87 defer ln.Close() 88 handler := transport.WithHandler(transport.Handler(&h{})) 89 tp := thttp.NewServerTransport(newNoopStdHTTPServer, thttp.WithReusePort(), thttp.WithEnableH2C()) 90 
require.Nil(t, tp.ListenAndServe(ctx, transport.WithListener(ln), handler)) 91 } 92 93 func TestDisableReusePort(t *testing.T) { 94 ctx := context.Background() 95 tp := thttp.NewServerTransport(newNoopStdHTTPServer) 96 ln1, err := net.Listen("tcp", "127.0.0.1:0") 97 require.Nil(t, err) 98 defer ln1.Close() 99 option := transport.WithListener(ln1) 100 handler := transport.WithHandler(transport.Handler(&h{})) 101 require.Nil(t, tp.ListenAndServe(ctx, option, handler), "Failed to new client transport") 102 103 option = transport.WithListenAddress(ln1.Addr().String()) 104 require.NotNil(t, tp.ListenAndServe(ctx, option, handler, transport.WithListenNetwork("tcp1"))) 105 106 ln2, err := net.Listen("tcp", "127.0.0.1:0") 107 require.Nil(t, err) 108 defer ln2.Close() 109 option = transport.WithListener(ln2) 110 tls := transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", "") 111 require.Nil(t, tp.ListenAndServe(ctx, option, handler, tls)) 112 113 ln3, err := net.Listen("tcp", "127.0.0.1:0") 114 require.Nil(t, err) 115 defer ln3.Close() 116 option = transport.WithListener(ln3) 117 tls = transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", "root") 118 require.Nil(t, tp.ListenAndServe(ctx, option, handler, tls)) 119 120 ln4, err := net.Listen("tcp", "127.0.0.1:0") 121 require.Nil(t, err) 122 defer ln4.Close() 123 option = transport.WithListener(ln4) 124 tls = transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", "../testdata/ca.key") 125 require.NotNil(t, tp.ListenAndServe(ctx, option, handler, tls)) 126 } 127 128 func TestStartServerWithNoHandler(t *testing.T) { 129 ctx := context.Background() 130 tp := thttp.NewServerTransport(newNoopStdHTTPServer) 131 ln, err := net.Listen("tcp", "127.0.0.1:0") 132 require.Nil(t, err) 133 defer ln.Close() 134 option := transport.WithListener(ln) 135 require.NotNil(t, tp.ListenAndServe(ctx, option), "http server transport handler empty") 136 } 137 138 func TestErrHandler(t 
*testing.T) { 139 ctx := context.Background() 140 tp := thttp.NewServerTransport(newNoopStdHTTPServer) 141 ln, err := net.Listen("tcp", "127.0.0.1:0") 142 require.Nil(t, err) 143 defer ln.Close() 144 option := transport.WithListener(ln) 145 h := transport.WithHandler(transport.Handler(&errHandler{})) 146 require.Nil(t, tp.ListenAndServe(ctx, option, h)) 147 148 ct := thttp.NewClientTransport(true) 149 ctx, msg := codec.WithNewMessage(ctx) 150 msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello") 151 msg.WithClientReqHead(&thttp.ClientReqHeader{}) 152 msg.WithClientRspHead(&thttp.ClientRspHeader{}) 153 154 rsp, err := ct.RoundTrip(ctx, []byte("{\"username\":\"xyz\","+ 155 "\"password\":\"xyz\",\"from\":\"xyz\"}"), 156 transport.WithDialAddress(ln.Addr().String()), 157 ) 158 require.Nil(t, rsp, "roundtrip rsp not empty") 159 require.Nil(t, err, "Failed to roundtrip") 160 } 161 162 func TestErrHeaderHandler(t *testing.T) { 163 ctx := context.Background() 164 tp := thttp.NewServerTransport(newNoopStdHTTPServer) 165 ln, err := net.Listen("tcp", "127.0.0.1:0") 166 require.Nil(t, err) 167 defer func() { require.Nil(t, ln.Close()) }() 168 err = tp.ListenAndServe(ctx, 169 transport.WithHandler(transport.Handler(&errHeaderHandler{})), 170 transport.WithListener(ln), 171 ) 172 require.Nil(t, err) 173 174 ct := thttp.NewClientTransport(true) 175 ctx, msg := codec.WithNewMessage(ctx) 176 msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello") 177 msg.WithClientReqHead(&thttp.ClientReqHeader{}) 178 msg.WithClientRspHead(&thttp.ClientRspHeader{}) 179 180 rsp, err := ct.RoundTrip(ctx, []byte("{\"username\":\"xyz\","+ 181 "\"password\":\"xyz\",\"from\":\"xyz\"}"), 182 transport.WithDialAddress(ln.Addr().String()), 183 ) 184 require.Nil(t, rsp, "roundtrip rsp not empty") 185 require.Nil(t, err, "Failed to roundtrip") 186 } 187 188 func TestListenAndServeFailedDueToBadCertificationFile(t *testing.T) { 189 ctx := context.Background() 190 oldLogger := 
log.DefaultLogger
	defer func() {
		log.DefaultLogger = oldLogger
	}()
	// testLog (defined elsewhere in this package) presumably forwards logged
	// errors to errorCh — TODO confirm.
	errorCh := make(chan error)
	log.DefaultLogger = &testLog{Logger: oldLogger, errorCh: errorCh}

	ln, err := net.Listen("tcp", "127.0.0.1:0")
	require.Nil(t, err)
	defer func() { require.Nil(t, ln.Close()) }()
	const badCertFile = "bad-file.cert"
	// ListenAndServe itself returns nil; the certificate problem surfaces
	// asynchronously as a logged error.
	require.Nil(
		t,
		thttp.NewServerTransport(newNoopStdHTTPServer).ListenAndServe(
			ctx,
			transport.WithListener(ln),
			transport.WithHandler(transport.Handler(&h{})),
			transport.WithServeTLS(badCertFile, "../testdata/server.key", ""),
		),
		"failed to new client transport",
	)

	select {
	case <-time.After(time.Second):
		t.Fatal("listen on a bad cert should log an error")
	case err := <-errorCh:
		require.Contains(t, err.Error(), badCertFile)
	}
}

// TestStartTLSServerAndNoCheckServer starts a one-way TLS server and calls it
// through NewClientTransport(false) with server verification disabled ("none").
func TestStartTLSServerAndNoCheckServer(t *testing.T) {
	ctx := context.Background()
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	require.Nil(t, err)
	defer func() { require.Nil(t, ln.Close()) }()
	// Only enable the https server; the client certificate is not verified.
	require.Nil(
		t,
		thttp.NewServerTransport(newNoopStdHTTPServer).ListenAndServe(
			ctx,
			transport.WithListener(ln),
			transport.WithHandler(transport.Handler(&h{})),
			transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", ""),
		),
		"Failed to new client transport",
	)

	ct := thttp.NewClientTransport(false)
	ctx, msg := codec.WithNewMessage(ctx)
	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
	msg.WithClientReqHead(&thttp.ClientReqHeader{})
	msg.WithClientRspHead(&thttp.ClientRspHeader{})

	rsp, err := ct.RoundTrip(
		ctx,
		[]byte(`{"username":"xyz","password":"xyz","from":"xyz"}`),
		transport.WithDialAddress(ln.Addr().String()),
		// Fully trust the https server and do not verify the server certificate;
		// only acceptable in a test environment.
		transport.WithDialTLS("", "", "none", ""),
	)
	require.Nil(t, rsp, "roundtrip rsp not empty")
	require.Nil(t, err, "Failed to roundtrip")
}

// TestServerWithListenerOption serves an HTTP service on a caller-provided
// listener and checks that a plain http.Get sees HTTP/1.1. It also checks
// that an out-of-range port fails to dial.
func TestServerWithListenerOption(t *testing.T) {
	ln, err := net.Listen("tcp", "localhost:0")
	require.Nil(t, err)
	defer ln.Close()
	service := server.New(
		server.WithServiceName("trpc.http.server.ListenerTest"),
		server.WithNetwork("tcp"),
		server.WithProtocol("http"),
		server.WithListener(ln),
	)
	thttp.HandleFunc("/index", func(w http.ResponseWriter, r *http.Request) error {
		fmt.Printf("Protocol: %s\n", r.Proto)
		_, err := w.Write([]byte(r.Proto))
		return err
	})
	thttp.RegisterDefaultService(service)
	s := &server.Server{}
	s.AddService("trpc.http.server.ListenerTest", service)
	go func() {
		// require.* calls t.FailNow, which may only be called from the test
		// goroutine; report Serve failures with t.Errorf instead.
		if err := s.Serve(); err != nil {
			t.Errorf("serve: %v", err)
		}
	}()
	defer s.Close(nil)
	time.Sleep(100 * time.Millisecond)

	resp, err := http.Get(fmt.Sprintf("http://%v/index", ln.Addr()))
	require.Nil(t, err)
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	require.Nil(t, err)
	require.Equal(t, []byte("HTTP/1.1"), body)

	const invalidAddr = "localhost:910439"
	resp, err = http.Get(fmt.Sprintf("http://%s/index", invalidAddr))
	require.NotNil(t, err)
	require.Nil(t, resp)
}

// TestStartDisableKeepAlivesServer checks that a server started with
// WithDisableKeepAlives(true) forces a new connection per request. This holds
// even when the handler asks for keep-alive.
func TestStartDisableKeepAlivesServer(t *testing.T) {
	ln, err := net.Listen("tcp", "localhost:0")
	require.Nil(t, err)
	defer ln.Close()
	s := &server.Server{}
	service := server.New(
		server.WithListener(ln),
		server.WithServiceName("trpc.http.server.ListenerTest"),
		server.WithNetwork("tcp"),
		server.WithProtocol("http"),
		server.WithTransport(thttp.NewServerTransport(newNoopStdHTTPServer)),
		server.WithDisableKeepAlives(true),
	)
	thttp.HandleFunc("/disable-keepalives", func(w http.ResponseWriter, _ *http.Request) error {
		w.Header().Set("Connection", "keep-alive")
		return nil
	})
	thttp.RegisterDefaultService(service)
	s.AddService("trpc.http.server.ListenerTest", service)
	go func() {
		// t.Errorf rather than require.*: FailNow must not run off the test goroutine.
		if err := s.Serve(); err != nil {
			t.Errorf("serve: %v", err)
		}
	}()
	defer func() {
		_ = s.Close(nil)
	}()

	time.Sleep(100 * time.Millisecond)

	dialCount := 0
	// Named httpClient so the imported trpc client package is not shadowed.
	httpClient := &http.Client{
		Transport: &http.Transport{
			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
				dialCount++
				return (&net.Dialer{}).DialContext(ctx, network, addr)
			},
		},
	}
	num := 3
	// Named target so the net/url package import is not shadowed.
	target := fmt.Sprintf("http://%s/disable-keepalives", ln.Addr())
	for i := 0; i < num; i++ {
		resp, err := httpClient.Get(target)
		require.Nil(t, err)
		_, err = io.Copy(io.Discard, resp.Body)
		// Close per iteration instead of deferring inside the loop.
		resp.Body.Close()
		require.Nil(t, err)
	}
	require.Equal(t, num, dialCount)
}

// TestStartH2cServer serves with h2c enabled and checks that an h2c client
// sees HTTP/2.0 while a plain client still gets HTTP/1.1.
func TestStartH2cServer(t *testing.T) {
	ln, err := net.Listen("tcp", "localhost:0")
	require.Nil(t, err)
	defer ln.Close()
	s := &server.Server{}
	service := server.New(
		server.WithListener(ln),
		server.WithServiceName("trpc.h2c.server.Greeter"),
		server.WithNetwork("tcp"),
		server.WithProtocol("http2"),
		server.WithTransport(thttp.NewServerTransport(newNoopStdHTTPServer, thttp.WithEnableH2C())),
	)
	thttp.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) error {
		fmt.Printf("Protocol: %s\n", r.Proto)
		_, err := w.Write([]byte(r.Proto))
		return err
	})
	thttp.HandleFunc("/main", func(w http.ResponseWriter, r *http.Request) error {
		fmt.Printf("Protocol: %s\n", r.Proto)
		_, err := w.Write([]byte(r.Proto))
		return err
	})
	thttp.RegisterDefaultService(service)
	s.AddService("trpc.h2c.server.Greeter", service)

	go func() {
		if err := s.Serve(); err != nil {
			t.Errorf("serve: %v", err)
		}
	}()
	// Stop the server so it does not outlive the test.
	defer func() { _ = s.Close(nil) }()

	time.Sleep(100 * time.Millisecond)

	// h2c client: HTTP/2 over a plain TCP connection.
	h2cClient := http.Client{
		Transport: &http2.Transport{
			AllowHTTP: true,
			DialTLS: func(network, addr string, _ *tls.Config) (net.Conn, error) {
				return net.Dial(network, addr)
			},
		},
	}
	// Runs before the server Close above (defers are LIFO).
	defer h2cClient.CloseIdleConnections()
	baseURL := fmt.Sprintf("http://%s/", ln.Addr())
	resp, err := h2cClient.Get(baseURL + "main")
	require.Nil(t, err)
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	require.Nil(t, err)
	require.Equal(t, []byte("HTTP/2.0"), body)

	// http1 client
	resp2, err := http.Get(baseURL)
	require.Nil(t, err)
	defer resp2.Body.Close()
	body, err = io.ReadAll(resp2.Body)
	require.Nil(t, err)
	require.Equal(t, []byte("HTTP/1.1"), body)
	require.Equal(t, http.StatusOK, resp2.StatusCode)
}

// TestHttp2StartTLSServerAndNoCheckServer starts a one-way TLS server and calls
// it through the HTTP/2 client transport with server verification disabled.
func TestHttp2StartTLSServerAndNoCheckServer(t *testing.T) {
	ctx := context.Background()
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	require.Nil(t, err)
	defer func() { require.Nil(t, ln.Close()) }()
	// Only enable the https server; the client certificate is not verified.
408 require.Nil( 409 t, 410 thttp.NewServerTransport(newNoopStdHTTPServer).ListenAndServe( 411 ctx, 412 transport.WithListener(ln), 413 transport.WithHandler(transport.Handler(&h{})), 414 transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", ""), 415 ), 416 "Failed to new client transport", 417 ) 418 419 ct := thttp.NewClientTransport(true) 420 ctx, msg := codec.WithNewMessage(ctx) 421 msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello") 422 msg.WithClientReqHead(&thttp.ClientReqHeader{}) 423 msg.WithClientRspHead(&thttp.ClientRspHeader{}) 424 425 rsp, err := ct.RoundTrip( 426 ctx, 427 []byte("{\"username\":\"xyz\","+"\"password\":\"xyz\",\"from\":\"xyz\"}"), 428 transport.WithDialAddress(ln.Addr().String()), 429 // Fully trust the https server and do not verify server certificate, 430 // can only be used in test env. 431 transport.WithDialTLS("", "", "none", ""), 432 ) 433 require.Nil(t, rsp, "roundtrip rsp not empty") 434 require.Nil(t, err, "Failed to roundtrip") 435 } 436 437 func TestStartTLSServerAndCheckServer(t *testing.T) { 438 ctx := context.Background() 439 tp := thttp.NewServerTransport(newNoopStdHTTPServer) 440 ln, err := net.Listen("tcp", "127.0.0.1:0") 441 require.Nil(t, err) 442 defer func() { require.Nil(t, ln.Close()) }() 443 err = tp.ListenAndServe(ctx, 444 transport.WithHandler(transport.Handler(&h{})), 445 // Only enables https server and do not verify client certificate. 
446 transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", ""), 447 transport.WithListener(ln), 448 ) 449 require.Nil(t, err, "Failed to new client transport") 450 451 ct := thttp.NewClientTransport(false) 452 ctx, msg := codec.WithNewMessage(ctx) 453 msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello") 454 msg.WithClientReqHead(&thttp.ClientReqHeader{}) 455 msg.WithClientRspHead(&thttp.ClientRspHeader{}) 456 457 rsp, err := ct.RoundTrip(ctx, []byte("{\"username\":\"xyz\","+ 458 "\"password\":\"xyz\",\"from\":\"xyz\"}"), 459 transport.WithDialAddress(ln.Addr().String()), 460 // Uses ca public key to verify server certificate. 461 transport.WithDialTLS("", "", "../testdata/ca.pem", "localhost"), 462 ) 463 require.Nil(t, rsp, "roundtrip rsp not empty") 464 require.Nil(t, err, "Failed to roundtrip") 465 } 466 467 func TestStartTLSServerAndCheckClientNoCert(t *testing.T) { 468 ctx := context.Background() 469 tp := thttp.NewServerTransport(newNoopStdHTTPServer) 470 ln, err := net.Listen("tcp", "127.0.0.1:0") 471 require.Nil(t, err) 472 defer func() { require.Nil(t, ln.Close()) }() 473 err = tp.ListenAndServe(ctx, 474 transport.WithHandler(transport.Handler(&h{})), 475 // Enables two-way authentication http server and need to verify client certificate. 
476 transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", "../testdata/ca.pem"), 477 transport.WithListener(ln), 478 ) 479 require.Nil(t, err, "Failed to new client transport") 480 481 ct := thttp.NewClientTransport(false) 482 ctx, msg := codec.WithNewMessage(ctx) 483 msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello") 484 msg.WithClientReqHead(&thttp.ClientReqHeader{}) 485 msg.WithClientRspHead(&thttp.ClientRspHeader{}) 486 487 _, err = ct.RoundTrip(ctx, []byte("{\"username\":\"xyz\","+ 488 "\"password\":\"xyz\",\"from\":\"xyz\"}"), 489 transport.WithDialAddress(ln.Addr().String()), 490 // If the client's own certificate is not sent, will return TLS verification failed. 491 transport.WithDialTLS("", "", "../testdata/ca.pem", "localhost"), 492 ) 493 require.NotNil(t, err, "Failed to roundtrip") 494 } 495 496 func TestStartTLSServerAndCheckClient(t *testing.T) { 497 ctx := context.Background() 498 tp := thttp.NewServerTransport(newNoopStdHTTPServer) 499 ln, err := net.Listen("tcp", "127.0.0.1:0") 500 require.Nil(t, err) 501 defer func() { require.Nil(t, ln.Close()) }() 502 // Enables two-way authentication http server and need to verify client certificate. 503 err = tp.ListenAndServe(ctx, 504 transport.WithHandler(transport.Handler(&h{})), 505 // Only enables https server and do not verify client certificate. 
506 transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", "../testdata/ca.pem"), 507 transport.WithListener(ln), 508 ) 509 require.Nil(t, err, "Failed to new client transport") 510 511 ct := thttp.NewClientTransport(false) 512 ctx, msg := codec.WithNewMessage(ctx) 513 msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello") 514 msg.WithClientReqHead(&thttp.ClientReqHeader{}) 515 msg.WithClientRspHead(&thttp.ClientRspHeader{}) 516 517 rsp, err := ct.RoundTrip(ctx, []byte("{\"username\":\"xyz\","+ 518 "\"password\":\"xyz\",\"from\":\"xyz\"}"), 519 transport.WithDialAddress(ln.Addr().String()), 520 // Need to send the client's own certificate to server. 521 transport.WithDialTLS("../testdata/client.crt", "../testdata/client.key", "../testdata/ca.pem", "localhost"), 522 ) 523 require.Nil(t, rsp, "roundtrip rsp not empty") 524 require.Nil(t, err, "Failed to roundtrip") 525 } 526 527 func TestNewClientTransport(t *testing.T) { 528 ct := thttp.NewClientTransport(false) 529 require.NotNil(t, ct, "Failed to new client transport") 530 531 ct2 := thttp.NewClientTransport(true) 532 require.NotNil(t, ct2, "Failed to new http2 client transport") 533 } 534 535 func TestClientRoundTrip(t *testing.T) { 536 ctx := context.Background() 537 ct := thttp.NewClientTransport(false) 538 ctx, msg := codec.WithNewMessage(ctx) 539 msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello") 540 msg.WithClientReqHead(&thttp.ClientReqHeader{}) 541 msg.WithClientRspHead(&thttp.ClientRspHeader{}) 542 ln, err := net.Listen("tcp", "127.0.0.1:0") 543 require.Nil(t, err) 544 defer ln.Close() 545 go http.Serve(ln, nil) 546 rsp, err := ct.RoundTrip(ctx, []byte("{\"username\":\"xyz\","+ 547 "\"password\":\"xyz\",\"from\":\"xyz\"}"), 548 transport.WithDialAddress(ln.Addr().String())) 549 require.Nil(t, rsp, "roundtrip rsp not empty") 550 require.Nil(t, err, "Failed to roundtrip") 551 } 552 553 func TestClientRoundTripWithNoHead(t *testing.T) { 554 ctx := 
context.Background()
	ct := thttp.NewClientTransport(false)
	ctx, msg := codec.WithNewMessage(ctx)
	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")

	rsp, err := ct.RoundTrip(ctx, []byte(`{"username":"xyz","password":"xyz","from":"xyz"}`),
		transport.WithDialAddress("127.0.0.1:18080"))
	require.Nil(t, rsp, "no head roundtrip rsp not empty")
	require.NotNil(t, err, "no head roundtrip err nil")
}

// TestClientWithSelectorNode checks that client.WithSelectorNode reports the
// address of the backend that was actually called.
func TestClientWithSelectorNode(t *testing.T) {
	ctx := context.Background()
	type testCase struct {
		target   string
		address  string
		listener net.Listener
	}
	var tests []testCase
	for i := 0; i < 2; i++ {
		ln, err := net.Listen("tcp", "127.0.0.1:0")
		require.Nil(t, err)
		// Deferred on purpose: the listeners are used by the second loop and
		// must stay open until the test returns.
		defer ln.Close()
		addr := ln.Addr().String()
		tests = append(tests, testCase{"ip://" + addr, addr, ln})
	}
	for _, tt := range tests {
		tp := thttp.NewServerTransport(newNoopStdHTTPServer)
		option := transport.WithListener(tt.listener)
		handler := transport.WithHandler(transport.Handler(&h{}))
		err := tp.ListenAndServe(ctx, option, handler)
		require.Nil(t, err, "Failed to new client transport")

		proxy := thttp.NewClientProxy("trpc.test.helloworld.Greeter",
			client.WithTarget(tt.target),
			client.WithSerializationType(codec.SerializationTypeNoop),
		)

		reqBody := &codec.Body{
			Data: []byte(`{"username":"xyz","password":"xyz","from":"xyz"}`),
		}
		rspBody := &codec.Body{}
		n := &registry.Node{}
		require.Nil(t,
			proxy.Post(ctx, "/trpc.test.helloworld.Greeter/SayHello", reqBody, rspBody, client.WithSelectorNode(n)),
			"Failed to post")
		require.Equal(t, tt.address, n.Address)
	}
}

// TestClient exercises every HTTP verb of the client proxy against a local
// server transport. Options are passed both at construction and per call,
// and an unreachable target must fail.
func TestClient(t *testing.T) {
	ctx := context.Background()
	old := codec.GetSerializer(codec.SerializationTypeJSON)
	defer func() { codec.RegisterSerializer(codec.SerializationTypeJSON, old) }()
	codec.RegisterSerializer(codec.SerializationTypeJSON, &codec.JSONPBSerialization{})
	tp := thttp.NewServerTransport(newNoopStdHTTPServer)
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	require.Nil(t, err)
	defer ln.Close()
	option := transport.WithListener(ln)
	handler := transport.WithHandler(transport.Handler(&h{}))
	require.Nil(t, tp.ListenAndServe(ctx, option, handler), "Failed to new client transport")

	header := &thttp.ClientReqHeader{}
	// NOTE(review): "ContentType" is not the standard "Content-Type" header
	// name used elsewhere in this file; kept as-is, confirm it is intentional.
	header.AddHeader("ContentType", "application/json")

	const rpcPath = "/trpc.test.helloworld.Greeter/SayHello"
	proxy := thttp.NewClientProxy("trpc.test.helloworld.Greeter",
		client.WithTarget("ip://"+ln.Addr().String()),
		client.WithSerializationType(codec.SerializationTypeNoop),
		client.WithReqHead(header),
		client.WithMetaData("k1", []byte("v1")),
	)
	reqBody := &codec.Body{
		Data: []byte(`{"username":"xyz","password":"xyz","from":"xyz"}`),
	}
	rspBody := &codec.Body{}

	require.Nil(t, proxy.Post(ctx, rpcPath, reqBody, rspBody), "Failed to post")
	require.Nil(t, proxy.Put(ctx, rpcPath, reqBody, rspBody), "Failed to put")
	require.Nil(t, proxy.Delete(ctx, rpcPath, reqBody, rspBody), "Failed to delete")
	require.Nil(t, proxy.Get(ctx, rpcPath, rspBody), "Failed to get")
	require.Nil(t, proxy.Patch(ctx, rpcPath, reqBody, rspBody), "Failed to patch")

	// Test client with options.
	proxy = thttp.NewClientProxy("trpc.test.helloworld.Greeter")
	reqBody = &codec.Body{
		Data: []byte(`{"username":"xyz","password":"xyz","from":"xyz"}`),
	}
	rspBody = &codec.Body{}
	require.Nil(t,
		proxy.Post(ctx, rpcPath, reqBody, rspBody,
			client.WithTarget("ip://"+ln.Addr().String()),
			client.WithSerializationType(codec.SerializationTypeNoop),
			client.WithReqHead(header),
			client.WithMetaData("k1", []byte("v1")),
		), "Failed to post")

	require.NotNil(t,
		proxy.Post(ctx, rpcPath, reqBody, rspBody,
			client.WithTarget("ip://127.0.0.1:180"),
		), "Failed to post")
}

// TestReqHeader checks that Post fails for a malformed target address.
func TestReqHeader(t *testing.T) {
	ctx := context.Background()
	// Invalid url.
	header := &thttp.ClientReqHeader{}
	header.AddHeader("Content-Type", "application/json")
	proxy := thttp.NewClientProxy("trpc.test.helloworld.Greeter",
		client.WithTarget("ip://127.0.0.1:18080:www.baidu.com//"),
		client.WithSerializationType(codec.SerializationTypeNoop),
		client.WithReqHead(header),
	)
	reqBody := &codec.Body{}
	rspBody := &codec.Body{}
	err := proxy.Post(ctx, "/trpc.test.helloworld.Greeter/SayHello", reqBody, rspBody)
	require.NotNil(t, err)
}

// TestReqHeaderWithContentType posts with several Content-Type values using
// form serialization and expects each call to succeed.
func TestReqHeaderWithContentType(t *testing.T) {
	ctx := context.Background()
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	require.Nil(t, err)
	defer ln.Close()
	option := transport.WithListener(ln)
	handler := transport.WithHandler(transport.Handler(&h{}))
	tp := thttp.NewServerTransport(newNoopStdHTTPServer)
	require.Nil(t, tp.ListenAndServe(ctx, option, handler), "Failed to new client transport")
	tests := []struct {
		expected string
	}{
		{"application/json"},
		{"application/jsonp"},
		{"application/jsonp123"},
		{"application/text123"},
	}
	for _, tt := range tests {
		header := &thttp.ClientReqHeader{}
		header.AddHeader("Content-Type", tt.expected)
		proxy := thttp.NewClientProxy("trpc.test.helloworld.Greeter",
			client.WithTarget("ip://"+ln.Addr().String()),
			client.WithSerializationType(codec.SerializationTypeForm),
			client.WithReqHead(header),
		)
		reqBody := &codec.Body{}
		rspBody := &codec.Body{}
		err := proxy.Post(ctx, "/trpc.test.helloworld.Greeter/SayHello", reqBody, rspBody)
		require.Nil(t, err)
	}
}

// TestHandler registers handlers on the default thttp service and invokes
// every generated method. The last call per method, with a thttp.Header in
// ctx, must succeed.
func TestHandler(t *testing.T) {
	var (
		handler     = func(http.ResponseWriter, *http.Request) {}
		handlerFunc = func(http.ResponseWriter, *http.Request) error { return nil }
		service     = server.New(server.WithProtocol("http"))
	)

	thttp.Handle("*", http.HandlerFunc(handler))
	thttp.HandleFunc("/path/do/not/equal/to/*", handlerFunc)
	thttp.RegisterDefaultService(service)

	for _, method := range thttp.ServiceDesc.Methods {
		// The results of the next two calls are intentionally discarded: they
		// only exercise the paths with no thttp.Header in ctx and with a
		// failing filter factory.
		method.Func(nil, context.TODO(), func(reqBody interface{}) (filter.ServerChain, error) {
			return make([]filter.ServerFilter, 0), nil
		})

		method.Func(nil, context.TODO(), func(reqBody interface{}) (filter.ServerChain, error) {
			return make([]filter.ServerFilter, 0), errors.New("invalid filter")
		})

		header := &thttp.Header{
			Request:  &http.Request{},
			Response: &httptest.ResponseRecorder{},
		}
		ctx := thttp.WithHeader(context.TODO(), header)
		_, err := method.Func(nil, ctx, func(reqBody interface{}) (filter.ServerChain, error) {
			return make([]filter.ServerFilter, 0), nil
		})
		require.Nil(t, err)
	}
}

// TestMux registers an http.ServeMux through thttp.RegisterServiceMux and
// invokes every generated method.
func TestMux(t *testing.T) {
	handler := func(http.ResponseWriter, *http.Request) {}
	mux := http.NewServeMux()
	mux.HandleFunc("/", handler)

	service := &mockService{}
	thttp.RegisterServiceMux(service, mux)
	// Fail clearly instead of dereferencing a nil desc below.
	desc, ok := service.desc.(*server.ServiceDesc)
	require.True(t, ok, "registered service desc is %T, want *server.ServiceDesc", service.desc)
	for _, method := range desc.Methods {
		method.Func(nil,
context.TODO(), func(reqBody interface{}) (filter.ServerChain, error) {
			return make([]filter.ServerFilter, 0), nil
		})

		method.Func(nil, context.TODO(), func(reqBody interface{}) (filter.ServerChain, error) {
			return make([]filter.ServerFilter, 0), errors.New("invalid filter")
		})

		req, err := http.NewRequest(http.MethodGet, "/", nil)
		require.Nil(t, err)
		header := &thttp.Header{
			Request:  req,
			Response: &httptest.ResponseRecorder{},
		}
		ctx := thttp.WithHeader(context.TODO(), header)
		_, err = method.Func(nil, ctx, func(reqBody interface{}) (filter.ServerChain, error) {
			return make([]filter.ServerFilter, 0), nil
		})
		require.Nil(t, err)
	}
}

// TestCheckRedirect tests setting CheckRedirect on the default client
// transport. One redirect (/b -> /real) is allowed; two redirects
// (/a -> /b -> /real) are rejected.
func TestCheckRedirect(t *testing.T) {
	ctx := context.Background()
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	require.Nil(t, err)
	defer ln.Close()

	// Use a private mux rather than http.DefaultServeMux. Re-registering the
	// same patterns on the global mux panics under `go test -count=N`.
	// Registering before Serve also removes the race with the serving goroutine.
	mux := http.NewServeMux()
	// real backend
	mux.Handle("/real", http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Write([]byte("real"))
	}))
	// redirect a
	mux.Handle("/a", http.RedirectHandler("/b", http.StatusMovedPermanently))
	// redirect b
	mux.Handle("/b", http.RedirectHandler("/real", http.StatusMovedPermanently))
	// ln is already bound, so connections queue until Serve accepts them and
	// no startup sleep is needed. Serve returns once ln is closed.
	go http.Serve(ln, mux)

	// sets CheckRedirect
	checkRedirect := func(_ *http.Request, via []*http.Request) error {
		if len(via) > 1 {
			return errors.New("more than once")
		}
		return nil
	}
	ct := thttp.DefaultClientTransport.(*thttp.ClientTransport)
	oldCheckRedirect := ct.CheckRedirect
	ct.CheckRedirect = checkRedirect
	defer func() { ct.CheckRedirect = oldCheckRedirect }()

	proxy := thttp.NewClientProxy("trpc.test.helloworld.Greeter",
		client.WithTarget("ip://"+ln.Addr().String()),
		client.WithSerializationType(codec.SerializationTypeNoop),
	)
	reqBody := &codec.Body{}
	rspBody := &codec.Body{}
	// only redirect once from /b
	require.Nil(t, proxy.Post(ctx, "/b", reqBody, rspBody))
	// redirect twice from /a
	err = proxy.Post(ctx, "/a", reqBody, rspBody)
	require.NotNil(t, err)
	require.True(t, strings.Contains(err.Error(), "more than once"), err.Error())
}

// TestTransportError checks that a client timeout and a cancelled context are
// returned as *errs.Error with RetClientTimeout and RetClientCanceled.
func TestTransportError(t *testing.T) {
	// Private mux for the same reason as in TestCheckRedirect.
	mux := http.NewServeMux()
	mux.HandleFunc("/timeout", func(http.ResponseWriter, *http.Request) {
		time.Sleep(time.Second)
	})
	mux.HandleFunc("/cancel", func(http.ResponseWriter, *http.Request) {})
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	require.Nil(t, err)
	defer ln.Close()
	// ln is already bound; Serve returns once it is closed.
	go func() { http.Serve(ln, mux) }()

	proxy := thttp.NewClientProxy("trpc.test.helloworld.Greeter",
		client.WithTarget("ip://"+ln.Addr().String()),
		client.WithSerializationType(codec.SerializationTypeNoop),
		client.WithTimeout(time.Millisecond*500),
	)
	rspBody := &codec.Body{}

	err = proxy.Get(context.Background(), "/timeout", rspBody)
	terr, ok := err.(*errs.Error)
	require.True(t, ok, "unexpected error type %T: %v", err, err)
	require.EqualValues(t, int32(errs.RetClientTimeout), terr.Code)

	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	err = proxy.Get(ctx, "/cancel", rspBody)
	terr, ok = err.(*errs.Error)
	require.True(t, ok, "unexpected error type %T: %v", err, err)
	require.EqualValues(t, int32(errs.RetClientCanceled), terr.Code)
}

// TestClientRoundDyeing checks that a dyeing message sets the trpc message-type
// header to TRPC_DYEING_MESSAGE. The RoundTrip itself fails because no dial
// address is given.
func TestClientRoundDyeing(t *testing.T) {
	ctx := context.Background()
	ct := thttp.NewClientTransport(false)
	ctx, msg := codec.WithNewMessage(ctx)
	msg.WithDyeing(true)
	dyeingKey := "dyeingkey"
	msg.WithDyeingKey(dyeingKey)
	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
	req := &http.Request{
		Header: http.Header{},
	}
	reqHeader := &thttp.ClientReqHeader{
		Request: req,
	}
	msg.WithClientReqHead(reqHeader)
	rspHeader := &thttp.ClientRspHeader{}
	msg.WithClientRspHead(rspHeader)
	meta := codec.MetaData{
		thttp.TrpcDyeingKey: []byte(dyeingKey),
	}
	msg.WithClientMetaData(meta)
	_, err := ct.RoundTrip(ctx, nil)
	require.NotNil(t, err)
	require.Equal(t, strconv.Itoa(int(trpcpb.TrpcMessageType_TRPC_DYEING_MESSAGE)),
		req.Header.Get(thttp.TrpcMessageType))
}

// TestClientRoundEnvTransfer checks that an env-transfer value puts the
// TrpcEnv key into the trans-info header, even though the RoundTrip fails.
func TestClientRoundEnvTransfer(t *testing.T) {
	ctx := context.Background()
	ct := thttp.NewClientTransport(false)
	ctx, msg := codec.WithNewMessage(ctx)
	msg.WithEnvTransfer("feat,master")
	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
	req := &http.Request{
		Header: http.Header{},
	}
	reqHeader := &thttp.ClientReqHeader{
		Request: req,
	}
	msg.WithClientReqHead(reqHeader)
	rspHeader := &thttp.ClientRspHeader{}
	msg.WithClientRspHead(rspHeader)
	_, err := ct.RoundTrip(ctx, nil)
	require.NotNil(t, err)
	require.Contains(t, req.Header.Get(thttp.TrpcTransInfo), thttp.TrpcEnv)
}

// TestDisableBase64EncodeTransInfo checks that with base64 encoding disabled,
// the raw env, metadata and dyeing values appear in the trans-info header.
func TestDisableBase64EncodeTransInfo(t *testing.T) {
	ctx := context.Background()
	ct := thttp.NewClientTransport(false, transport.WithDisableEncodeTransInfoBase64())
	ctx, msg := codec.WithNewMessage(ctx)
	var (
		envTrans  = "feat,master"
		metaVal   = "value"
		dyeingKey = "dyeingkey"
	)
	msg.WithEnvTransfer(envTrans)
	msg.WithClientMetaData(codec.MetaData{"key": []byte(metaVal)})
	msg.WithDyeing(true)
	msg.WithDyeingKey(dyeingKey)
	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
	req := &http.Request{
		Header: http.Header{},
	}
	reqHeader := &thttp.ClientReqHeader{
		Request: req,
	}
	msg.WithClientReqHead(reqHeader)
	rspHeader := &thttp.ClientRspHeader{}
	msg.WithClientRspHead(rspHeader)
	_, err := ct.RoundTrip(ctx, nil)
	require.NotNil(t, err)
	transInfo := req.Header.Get(thttp.TrpcTransInfo)
	require.Contains(t, transInfo, envTrans)
	require.Contains(t, transInfo, metaVal)
	require.Contains(t, transInfo, dyeingKey)
}

// TestDisableServiceRouterTransInfo checks that the message's env-transfer
// value, not the TrpcEnv metadata, is written to the trans-info header. It
// also checks that clearing the env transfer clears TrpcEnv there.
func TestDisableServiceRouterTransInfo(t *testing.T) {
	ctx := context.Background()
	a := require.New(t)
	ct := thttp.NewClientTransport(false)
	ctx, msg := codec.WithNewMessage(ctx)
	// This emulates a client request decoded from the trpc protocol.
	msg.WithClientMetaData(codec.MetaData{thttp.TrpcEnv: []byte("orienv")})
	msg.WithEnvTransfer("feat,master")
	req := &http.Request{
		Header: http.Header{},
	}
	reqHeader := &thttp.ClientReqHeader{
		Request: req,
	}
	msg.WithClientReqHead(reqHeader)
	rspHeader := &thttp.ClientRspHeader{}
	msg.WithClientRspHead(rspHeader)
	_, err := ct.RoundTrip(ctx, nil)
	a.NotNil(err)
	info, err := thttp.UnmarshalTransInfo(msg, req.Header.Get(thttp.TrpcTransInfo))
	a.NoError(err)
	a.Equal("feat,master", string(info[thttp.TrpcEnv]))

	msg.WithEnvTransfer("") // DisableServiceRouter would clear EnvTransfer.
	_, err = ct.RoundTrip(ctx, nil)
	a.NotNil(err)
	info, err = thttp.UnmarshalTransInfo(msg, req.Header.Get(thttp.TrpcTransInfo))
	a.NoError(err)
	a.Equal("", string(info[thttp.TrpcEnv]))
}

// TestHTTPSUseClientVerify serves a no-protocol HTTP service over TLS with a CA
// configured (server.WithTLS) and calls it with the test client certificate.
func TestHTTPSUseClientVerify(t *testing.T) {
	const (
		network = "tcp"
		address = "127.0.0.1:0"
	)
	ln, err := net.Listen(network, address)
	require.Nil(t, err)
	defer ln.Close()
	serviceName := "trpc.app.server.Service" + t.Name()
	service := server.New(
		server.WithServiceName(serviceName),
		server.WithNetwork(network),
		server.WithProtocol("http_no_protocol"),
		server.WithListener(ln),
		server.WithTLS(
			"../testdata/server.crt",
			"../testdata/server.key",
			"../testdata/ca.pem",
		),
	)
	pattern := "/" + t.Name()
	thttp.RegisterNoProtocolServiceMux(service, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Write([]byte(t.Name()))
	}))
	s := &server.Server{}
	s.AddService(serviceName, service)
	go s.Serve()
	defer
s.Close(nil) 995 time.Sleep(100 * time.Millisecond) 996 997 c := thttp.NewClientProxy( 998 serviceName, 999 client.WithTarget("ip://"+ln.Addr().String()), 1000 ) 1001 req := &codec.Body{} 1002 rsp := &codec.Body{} 1003 require.Nil(t, 1004 c.Post(context.Background(), pattern, req, rsp, 1005 client.WithCurrentSerializationType(codec.SerializationTypeNoop), 1006 client.WithSerializationType(codec.SerializationTypeNoop), 1007 client.WithCurrentCompressType(codec.CompressTypeNoop), 1008 client.WithTLS( 1009 "../testdata/client.crt", 1010 "../testdata/client.key", 1011 "../testdata/ca.pem", 1012 "localhost", 1013 ), 1014 )) 1015 require.Equal(t, []byte(t.Name()), rsp.Data) 1016 } 1017 1018 func TestHTTPSSkipClientVerify(t *testing.T) { 1019 const ( 1020 network = "tcp" 1021 address = "127.0.0.1:0" 1022 ) 1023 ln, err := net.Listen(network, address) 1024 require.Nil(t, err) 1025 defer ln.Close() 1026 serviceName := "trpc.app.server.Service" + t.Name() 1027 service := server.New( 1028 server.WithServiceName(serviceName), 1029 server.WithNetwork(network), 1030 server.WithProtocol("http_no_protocol"), 1031 server.WithListener(ln), 1032 server.WithTLS( 1033 "../testdata/server.crt", 1034 "../testdata/server.key", 1035 "", 1036 ), 1037 ) 1038 pattern := "/" + t.Name() 1039 thttp.RegisterNoProtocolServiceMux(service, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 1040 w.Write([]byte(t.Name())) 1041 })) 1042 s := &server.Server{} 1043 s.AddService(serviceName, service) 1044 go s.Serve() 1045 defer s.Close(nil) 1046 time.Sleep(100 * time.Millisecond) 1047 1048 c := thttp.NewClientProxy( 1049 serviceName, 1050 client.WithTarget("ip://"+ln.Addr().String()), 1051 ) 1052 req := &codec.Body{} 1053 rsp := &codec.Body{} 1054 require.Nil(t, 1055 c.Post(context.Background(), pattern, req, rsp, 1056 client.WithCurrentSerializationType(codec.SerializationTypeNoop), 1057 client.WithSerializationType(codec.SerializationTypeNoop), 1058 
client.WithCurrentCompressType(codec.CompressTypeNoop), 1059 client.WithTLS( 1060 "", "", "none", "", 1061 ), 1062 )) 1063 require.Equal(t, []byte(t.Name()), rsp.Data) 1064 } 1065 1066 func TestListenAndServeHTTPHead(t *testing.T) { 1067 ctx := context.Background() 1068 const ( 1069 network = "tcp" 1070 address = "127.0.0.1:0" 1071 ) 1072 ln, err := net.Listen(network, address) 1073 require.Nil(t, err) 1074 defer ln.Close() 1075 st := thttp.NewServerTransport(newNoopStdHTTPServer) 1076 require.Nil(t, st.ListenAndServe(ctx, 1077 transport.WithHandler(&httpHeadHandler{ 1078 func(ctx context.Context, _ []byte) (rsp []byte, err error) { 1079 head := thttp.Head(ctx) 1080 head.Response.WriteHeader(http.StatusOK) 1081 head.Response.Write([]byte(fmt.Sprintf("%+v", thttp.Head(head.Request.Context()) != nil))) 1082 return 1083 }}), 1084 transport.WithListener(ln), 1085 )) 1086 time.Sleep(200 * time.Millisecond) 1087 rsp, err := http.Get("http://" + ln.Addr().String()) 1088 require.Nil(t, err) 1089 bs, err := io.ReadAll(rsp.Body) 1090 require.Nil(t, err) 1091 require.Equal(t, fmt.Sprintf("%+v", true), string(bs)) 1092 } 1093 1094 type httpHeadHandler struct { 1095 handle func(ctx context.Context, req []byte) (rsp []byte, err error) 1096 } 1097 1098 func (h *httpHeadHandler) Handle(ctx context.Context, req []byte) (rsp []byte, err error) { 1099 return h.handle(ctx, req) 1100 } 1101 1102 func TestHTTPStreamFileUpload(t *testing.T) { 1103 // Start server. 1104 const ( 1105 network = "tcp" 1106 address = "127.0.0.1:0" 1107 ) 1108 ln, err := net.Listen(network, address) 1109 require.Nil(t, err) 1110 defer ln.Close() 1111 go http.Serve(ln, &fileHandler{}) 1112 // Start client. 1113 c := thttp.NewClientProxy( 1114 "trpc.app.server.Service_http", 1115 client.WithTarget("ip://"+ln.Addr().String()), 1116 ) 1117 // Open and read file. 
1118 fileDir, err := os.Getwd() 1119 require.Nil(t, err) 1120 fileName := "README.md" 1121 filePath := path.Join(fileDir, fileName) 1122 file, err := os.Open(filePath) 1123 require.Nil(t, err) 1124 defer file.Close() 1125 // Construct multipart form file. 1126 body := &bytes.Buffer{} 1127 writer := multipart.NewWriter(body) 1128 part, err := writer.CreateFormFile("field_name", filepath.Base(file.Name())) 1129 require.Nil(t, err) 1130 io.Copy(part, file) 1131 require.Nil(t, writer.Close()) 1132 // Add multipart form data header. 1133 header := http.Header{} 1134 header.Add("Content-Type", writer.FormDataContentType()) 1135 reqHeader := &thttp.ClientReqHeader{ 1136 Header: header, 1137 ReqBody: body, // Stream send. 1138 } 1139 req := &codec.Body{} 1140 rsp := &codec.Body{} 1141 // Upload file. 1142 require.Nil(t, 1143 c.Post(context.Background(), "/", req, rsp, 1144 client.WithCurrentSerializationType(codec.SerializationTypeNoop), 1145 client.WithSerializationType(codec.SerializationTypeNoop), 1146 client.WithCurrentCompressType(codec.CompressTypeNoop), 1147 client.WithReqHead(reqHeader), 1148 )) 1149 require.Equal(t, []byte(fileName), rsp.Data) 1150 } 1151 1152 type fileHandler struct{} 1153 1154 func (*fileHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 1155 _, h, err := r.FormFile("field_name") 1156 if err != nil { 1157 w.WriteHeader(http.StatusBadRequest) 1158 return 1159 } 1160 w.WriteHeader(http.StatusOK) 1161 // Write back file name. 1162 w.Write([]byte(h.Filename)) 1163 return 1164 } 1165 1166 func TestHTTPStreamRead(t *testing.T) { 1167 // Start server. 1168 const ( 1169 network = "tcp" 1170 address = "127.0.0.1:0" 1171 ) 1172 ln, err := net.Listen(network, address) 1173 require.Nil(t, err) 1174 defer ln.Close() 1175 go http.Serve(ln, &fileServer{}) 1176 1177 // Start client. 
1178 c := thttp.NewClientProxy( 1179 "trpc.app.server.Service_http", 1180 client.WithTarget("ip://"+ln.Addr().String()), 1181 ) 1182 1183 // Enable manual body reading in order to 1184 // disable the framework's automatic body reading capability, 1185 // so that users can manually do their own client-side streaming reads. 1186 rspHead := &thttp.ClientRspHeader{ 1187 ManualReadBody: true, 1188 } 1189 req := &codec.Body{} 1190 rsp := &codec.Body{} 1191 require.Nil(t, 1192 c.Post(context.Background(), "/", req, rsp, 1193 client.WithCurrentSerializationType(codec.SerializationTypeNoop), 1194 client.WithSerializationType(codec.SerializationTypeNoop), 1195 client.WithCurrentCompressType(codec.CompressTypeNoop), 1196 client.WithRspHead(rspHead), 1197 )) 1198 require.Nil(t, rsp.Data) 1199 body := rspHead.Response.Body // Do stream reads directly from rspHead.Response.Body. 1200 defer body.Close() // Do remember to close the body. 1201 bs, err := io.ReadAll(body) 1202 require.Nil(t, err) 1203 require.NotNil(t, bs) 1204 } 1205 1206 func TestHTTPSendReceiveChunk(t *testing.T) { 1207 // HTTP chunked example: 1208 // 1. Client sends chunks: Add "chunked" transfer encoding header, and use io.Reader as body. 1209 // 2. Client reads chunks: The Go/net/http automatically handles the chunked reading. 1210 // Users can simply read resp.Body in a loop until io.EOF. 1211 // 3. Server reads chunks: Similar to client reads chunks. 1212 // 4. Server sends chunks: Assert http.ResponseWriter as http.Flusher, call flusher.Flush() after 1213 // writing a part of data, it will automatically trigger "chunked" encoding to send a chunk. 1214 1215 // Start server. 1216 const ( 1217 network = "tcp" 1218 address = "127.0.0.1:0" 1219 ) 1220 ln, err := net.Listen(network, address) 1221 require.Nil(t, err) 1222 defer ln.Close() 1223 go http.Serve(ln, &chunkedServer{}) 1224 1225 // Start client. 
1226 c := thttp.NewClientProxy( 1227 "trpc.app.server.Service_http", 1228 client.WithTarget("ip://"+ln.Addr().String()), 1229 ) 1230 1231 // Open and read file. 1232 fileDir, err := os.Getwd() 1233 require.Nil(t, err) 1234 fileName := "README.md" 1235 filePath := path.Join(fileDir, fileName) 1236 file, err := os.Open(filePath) 1237 require.Nil(t, err) 1238 defer file.Close() 1239 1240 // 1. Client sends chunks. 1241 1242 // Add request headers. 1243 header := http.Header{} 1244 header.Add("Content-Type", "text/plain") 1245 // Add chunked transfer encoding header. 1246 header.Add("Transfer-Encoding", "chunked") 1247 reqHead := &thttp.ClientReqHeader{ 1248 Header: header, 1249 ReqBody: file, // Stream send (for chunks). 1250 } 1251 1252 // Enable manual body reading in order to 1253 // disable the framework's automatic body reading capability, 1254 // so that users can manually do their own client-side streaming reads. 1255 rspHead := &thttp.ClientRspHeader{ 1256 ManualReadBody: true, 1257 } 1258 req := &codec.Body{} 1259 rsp := &codec.Body{} 1260 require.Nil(t, 1261 c.Post(context.Background(), "/", req, rsp, 1262 client.WithCurrentSerializationType(codec.SerializationTypeNoop), 1263 client.WithSerializationType(codec.SerializationTypeNoop), 1264 client.WithCurrentCompressType(codec.CompressTypeNoop), 1265 client.WithReqHead(reqHead), 1266 client.WithRspHead(rspHead), 1267 )) 1268 require.Nil(t, rsp.Data) 1269 1270 // 2. Client reads chunks. 1271 1272 // Do stream reads directly from rspHead.Response.Body. 1273 body := rspHead.Response.Body 1274 defer body.Close() // Do remember to close the body. 
1275 buf := make([]byte, 4096) 1276 var idx int 1277 for { 1278 n, err := body.Read(buf) 1279 if err == io.EOF { 1280 t.Logf("reached io.EOF\n") 1281 break 1282 } 1283 t.Logf("read chunk %d of length %d: %q\n", idx, n, buf[:n]) 1284 idx++ 1285 } 1286 } 1287 1288 type chunkedServer struct{} 1289 1290 func (*chunkedServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { 1291 // 3. Server reads chunks. 1292 1293 // io.ReadAll will read until io.EOF. 1294 // Go/net/http will automatically handle chunked body reads. 1295 bs, err := io.ReadAll(r.Body) 1296 if err != nil { 1297 w.WriteHeader(http.StatusInternalServerError) 1298 w.Write([]byte(fmt.Sprintf("io.ReadAll err: %+v", err))) 1299 return 1300 } 1301 1302 // 4. Server sends chunks. 1303 1304 // Send HTTP chunks using http.Flusher. 1305 // Reference: https://stackoverflow.com/questions/26769626/send-a-chunked-http-response-from-a-go-server. 1306 // The "Transfer-Encoding" header will be handled by the writer implicitly, so no need to set it. 1307 flusher, ok := w.(http.Flusher) 1308 if !ok { 1309 w.WriteHeader(http.StatusInternalServerError) 1310 w.Write([]byte("expected http.ResponseWriter to be an http.Flusher")) 1311 return 1312 } 1313 chunks := 10 1314 chunkSize := (len(bs) + chunks - 1) / chunks 1315 for i := 0; i < chunks; i++ { 1316 start := i * chunkSize 1317 end := (i + 1) * chunkSize 1318 if end > len(bs) { 1319 end = len(bs) 1320 } 1321 w.Write(bs[start:end]) 1322 flusher.Flush() // Trigger "chunked" encoding and send a chunk. 
1323 time.Sleep(500 * time.Millisecond) 1324 } 1325 return 1326 } 1327 1328 type fileServer struct{} 1329 1330 func (*fileServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { 1331 http.ServeFile(w, r, "./README.md") 1332 return 1333 } 1334 1335 func TestHTTPSendAndReceiveSSE(t *testing.T) { 1336 const ( 1337 network = "tcp" 1338 address = "127.0.0.1:0" 1339 ) 1340 ln, err := net.Listen(network, address) 1341 require.Nil(t, err) 1342 defer ln.Close() 1343 serviceName := "trpc.app.server.Service" + t.Name() 1344 service := server.New( 1345 server.WithServiceName(serviceName), 1346 server.WithNetwork(network), 1347 server.WithProtocol("http_no_protocol"), 1348 server.WithListener(ln), 1349 ) 1350 pattern := "/" + t.Name() 1351 thttp.RegisterNoProtocolServiceMux(service, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1352 flusher, ok := w.(http.Flusher) 1353 if !ok { 1354 http.Error(w, "Streaming unsupported!", http.StatusInternalServerError) 1355 return 1356 } 1357 w.Header().Set("Content-Type", "text/event-stream") 1358 w.Header().Set("Cache-Control", "no-cache") 1359 w.Header().Set("Connection", "keep-alive") 1360 w.Header().Set("Access-Control-Allow-Origin", "*") 1361 bs, err := io.ReadAll(r.Body) 1362 if err != nil { 1363 http.Error(w, err.Error(), http.StatusBadRequest) 1364 return 1365 } 1366 msg := string(bs) 1367 for i := 0; i < 3; i++ { 1368 msgBytes := []byte("event: message\n\ndata: " + msg + strconv.Itoa(i) + "\n\n") 1369 _, err = w.Write(msgBytes) 1370 if err != nil { 1371 http.Error(w, err.Error(), http.StatusInternalServerError) 1372 return 1373 } 1374 flusher.Flush() 1375 time.Sleep(500 * time.Millisecond) 1376 } 1377 return 1378 })) 1379 s := &server.Server{} 1380 s.AddService(serviceName, service) 1381 go s.Serve() 1382 defer s.Close(nil) 1383 time.Sleep(100 * time.Millisecond) 1384 1385 c := thttp.NewClientProxy( 1386 serviceName, 1387 client.WithTarget("ip://"+ln.Addr().String()), 1388 ) 1389 header := http.Header{} 1390 
header.Set("Cache-Control", "no-cache") 1391 header.Set("Accept", "text/event-stream") 1392 header.Set("Connection", "keep-alive") 1393 reqHeader := &thttp.ClientReqHeader{ 1394 Header: header, 1395 } 1396 // Enable manual body reading in order to 1397 // disable the framework's automatic body reading capability, 1398 // so that users can manually do their own client-side streaming reads. 1399 rspHead := &thttp.ClientRspHeader{ 1400 ManualReadBody: true, 1401 } 1402 req := &codec.Body{Data: []byte("hello")} 1403 rsp := &codec.Body{} 1404 require.Nil(t, 1405 c.Post(context.Background(), pattern, req, rsp, 1406 client.WithCurrentSerializationType(codec.SerializationTypeNoop), 1407 client.WithSerializationType(codec.SerializationTypeNoop), 1408 client.WithCurrentCompressType(codec.CompressTypeNoop), 1409 client.WithReqHead(reqHeader), 1410 client.WithRspHead(rspHead), 1411 client.WithTimeout(time.Minute), 1412 )) 1413 body := rspHead.Response.Body // Do stream reads directly from rspHead.Response.Body. 1414 defer body.Close() // Do remember to close the body. 
1415 data := make([]byte, 1024) 1416 for { 1417 n, err := body.Read(data) 1418 if err == io.EOF { 1419 break 1420 } 1421 require.Nil(t, err) 1422 t.Logf("Received message: \n%s\n", string(data[:n])) 1423 } 1424 } 1425 1426 func TestHTTPClientReqRspDifferentContentType(t *testing.T) { 1427 const ( 1428 network = "tcp" 1429 address = "127.0.0.1:0" 1430 ) 1431 ln, err := net.Listen(network, address) 1432 require.Nil(t, err) 1433 defer ln.Close() 1434 serviceName := "trpc.app.server.Service" + t.Name() 1435 service := server.New( 1436 server.WithServiceName(serviceName), 1437 server.WithNetwork(network), 1438 server.WithProtocol("http_no_protocol"), 1439 server.WithListener(ln), 1440 ) 1441 const ( 1442 hello = "hello " 1443 key = "key" 1444 ) 1445 pattern := "/" + t.Name() 1446 thttp.RegisterNoProtocolServiceMux(service, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 1447 bs, err := io.ReadAll(r.Body) 1448 if err != nil { 1449 w.WriteHeader(http.StatusBadRequest) 1450 return 1451 } 1452 req, err := url.ParseQuery(string(bs)) 1453 if err != nil { 1454 w.WriteHeader(http.StatusBadRequest) 1455 return 1456 } 1457 rsp := &helloworld.HelloReply{Message: hello + req.Get(key)} 1458 bs, err = codec.Marshal(codec.SerializationTypePB, rsp) 1459 if err != nil { 1460 w.WriteHeader(http.StatusInternalServerError) 1461 return 1462 } 1463 w.Header().Add("Content-Type", "application/protobuf") 1464 w.Write(bs) 1465 return 1466 })) 1467 s := &server.Server{} 1468 s.AddService(serviceName, service) 1469 go s.Serve() 1470 defer s.Close(nil) 1471 time.Sleep(100 * time.Millisecond) 1472 1473 c := thttp.NewClientProxy( 1474 serviceName, 1475 client.WithTarget("ip://"+ln.Addr().String()), 1476 ) 1477 req := make(url.Values) 1478 req.Add(key, t.Name()) 1479 rsp := &helloworld.HelloReply{} 1480 require.Nil(t, 1481 c.Post(context.Background(), pattern, req, rsp, 1482 client.WithSerializationType(codec.SerializationTypeForm), 1483 )) 1484 require.Equal(t, hello+t.Name(), 
rsp.Message) 1485 } 1486 1487 func TestHTTPGotConnectionRemoteAddr(t *testing.T) { 1488 ctx := context.Background() 1489 for i := 0; i < 3; i++ { 1490 proxy := thttp.NewClientProxy(t.Name(), client.WithTarget("dns://new.qq.com/")) 1491 rsp := &codec.Body{} 1492 require.Nil(t, proxy.Get(ctx, "/", rsp, 1493 client.WithSerializationType(codec.SerializationTypeNoop), 1494 client.WithFilter( 1495 func(ctx context.Context, req, rsp interface{}, next filter.ClientHandleFunc) error { 1496 err := next(ctx, req, rsp) 1497 msg := codec.Message(ctx) 1498 addr := msg.RemoteAddr() 1499 require.NotNil(t, addr, "expect to get remote addr from msg in connection reuse case") 1500 t.Logf("addr = %+v\n", addr) 1501 return err 1502 }))) 1503 } 1504 } 1505 1506 type h struct{} 1507 1508 func (*h) Handle(ctx context.Context, reqBuf []byte) (rsp []byte, err error) { 1509 fmt.Println("recv http req") 1510 return nil, nil 1511 } 1512 1513 type testLog struct { 1514 log.Logger 1515 errorCh chan error 1516 } 1517 1518 func (ln *testLog) Errorf(format string, args ...interface{}) { 1519 ln.errorCh <- fmt.Errorf(format, args...) 1520 } 1521 1522 // mockService is a mock service. 1523 type mockService struct { 1524 desc interface{} 1525 } 1526 1527 // Register registers route information. 1528 func (m *mockService) Register(serviceDesc interface{}, serviceImpl interface{}) error { 1529 m.desc = serviceDesc 1530 return nil 1531 } 1532 1533 // Serve runs service. 1534 func (m *mockService) Serve() error { 1535 return nil 1536 } 1537 1538 // Close closes service. 
1539 func (m *mockService) Close(chan struct{}) error { 1540 return nil 1541 } 1542 1543 type errHandler struct{} 1544 1545 func (*errHandler) Handle(ctx context.Context, reqBuf []byte) (rsp []byte, err error) { 1546 return nil, errors.New("mock error") 1547 } 1548 1549 type errHeaderHandler struct{} 1550 1551 func (*errHeaderHandler) Handle(ctx context.Context, reqBuf []byte) (rsp []byte, err error) { 1552 return nil, thttp.ErrEncodeMissingHeader 1553 }