//
//
// Tencent is pleased to support the open source community by making tRPC available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company.
// All rights reserved.
//
// If you have downloaded a copy of the tRPC source code from Tencent,
// please note that tRPC source code is licensed under the Apache 2.0 License,
// A copy of the Apache 2.0 License is included in this file.
//
//

package http_test

// https certificate file generation method:
// 1. ca certificate:
// openssl genrsa -out ca.key 2048
// openssl req -x509 -new -nodes -key ca.key -subj "/CN=*" -days 5000 -out ca.pem
// 2. server certificate:
// openssl genrsa -out server.key 2048
// openssl req -new -key server.key -subj "/CN=*" -out server.csr
// openssl x509 -req -in server.csr -CA ca.pem -CAkey ca.key -CAcreateserial -out server.crt -days 5000 -extfile <(printf "subjectAltName=DNS:localhost")
// 3. client certificate:
// openssl genrsa -out client.key 2048
// openssl req -new -key client.key -subj "/CN=*" -out client.csr
// openssl x509 -req -in client.csr -CA ca.pem -CAkey ca.key -CAcreateserial -out client.crt -days 5000 -extfile <(printf "subjectAltName=DNS:localhost")
// 4.
show certificate content: 29 // openssl x509 -text -in server.crt -noout 30 31 import ( 32 "bytes" 33 "context" 34 "crypto/tls" 35 "errors" 36 "fmt" 37 "io" 38 "mime/multipart" 39 "net" 40 "net/http" 41 "net/http/httptest" 42 "net/url" 43 "os" 44 "path" 45 "path/filepath" 46 "strconv" 47 "strings" 48 "testing" 49 "time" 50 51 "github.com/stretchr/testify/require" 52 "golang.org/x/net/http2" 53 trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc" 54 55 "trpc.group/trpc-go/trpc-go/client" 56 "trpc.group/trpc-go/trpc-go/codec" 57 "trpc.group/trpc-go/trpc-go/errs" 58 "trpc.group/trpc-go/trpc-go/filter" 59 thttp "trpc.group/trpc-go/trpc-go/http" 60 "trpc.group/trpc-go/trpc-go/log" 61 "trpc.group/trpc-go/trpc-go/naming/registry" 62 "trpc.group/trpc-go/trpc-go/server" 63 "trpc.group/trpc-go/trpc-go/testdata/restful/helloworld" 64 "trpc.group/trpc-go/trpc-go/transport" 65 ) 66 67 func newNoopStdHTTPServer() *http.Server { return &http.Server{} } 68 69 func TestStartServer(t *testing.T) { 70 ctx := context.Background() 71 tp := thttp.NewServerTransport(newNoopStdHTTPServer) 72 ln, err := net.Listen("tcp", "127.0.0.1:0") 73 require.Nil(t, err) 74 defer ln.Close() 75 option := transport.WithListener(ln) 76 handler := transport.WithHandler(transport.Handler(&h{})) 77 require.Nil(t, tp.ListenAndServe(ctx, option, handler), "Failed to new client transport") 78 require.NotNil(t, tp.ListenAndServe(ctx, transport.WithListenAddress("127.0.0.1:8888"), handler, transport.WithListenNetwork("tcp1"))) 79 tls := transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", "ca1") 80 require.NotNil(t, tp.ListenAndServe(ctx, option, handler, tls)) 81 } 82 83 func TestH2C(t *testing.T) { 84 ctx := context.Background() 85 ln, err := net.Listen("tcp", "127.0.0.1:0") 86 require.Nil(t, err) 87 defer ln.Close() 88 handler := transport.WithHandler(transport.Handler(&h{})) 89 tp := thttp.NewServerTransport(newNoopStdHTTPServer, thttp.WithReusePort(), thttp.WithEnableH2C()) 90 
require.Nil(t, tp.ListenAndServe(ctx, transport.WithListener(ln), handler)) 91 } 92 93 func TestDisableReusePort(t *testing.T) { 94 ctx := context.Background() 95 tp := thttp.NewServerTransport(newNoopStdHTTPServer) 96 ln1, err := net.Listen("tcp", "127.0.0.1:0") 97 require.Nil(t, err) 98 defer ln1.Close() 99 option := transport.WithListener(ln1) 100 handler := transport.WithHandler(transport.Handler(&h{})) 101 require.Nil(t, tp.ListenAndServe(ctx, option, handler), "Failed to new client transport") 102 103 option = transport.WithListenAddress(ln1.Addr().String()) 104 require.NotNil(t, tp.ListenAndServe(ctx, option, handler, transport.WithListenNetwork("tcp1"))) 105 106 ln2, err := net.Listen("tcp", "127.0.0.1:0") 107 require.Nil(t, err) 108 defer ln2.Close() 109 option = transport.WithListener(ln2) 110 tls := transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", "") 111 require.Nil(t, tp.ListenAndServe(ctx, option, handler, tls)) 112 113 ln3, err := net.Listen("tcp", "127.0.0.1:0") 114 require.Nil(t, err) 115 defer ln3.Close() 116 option = transport.WithListener(ln3) 117 tls = transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", "root") 118 require.Nil(t, tp.ListenAndServe(ctx, option, handler, tls)) 119 120 ln4, err := net.Listen("tcp", "127.0.0.1:0") 121 require.Nil(t, err) 122 defer ln4.Close() 123 option = transport.WithListener(ln4) 124 tls = transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", "../testdata/ca.key") 125 require.NotNil(t, tp.ListenAndServe(ctx, option, handler, tls)) 126 } 127 128 func TestStartServerWithNoHandler(t *testing.T) { 129 ctx := context.Background() 130 tp := thttp.NewServerTransport(newNoopStdHTTPServer) 131 ln, err := net.Listen("tcp", "127.0.0.1:0") 132 require.Nil(t, err) 133 defer ln.Close() 134 option := transport.WithListener(ln) 135 require.NotNil(t, tp.ListenAndServe(ctx, option), "http server transport handler empty") 136 } 137 138 func TestErrHandler(t 
*testing.T) { 139 ctx := context.Background() 140 tp := thttp.NewServerTransport(newNoopStdHTTPServer) 141 ln, err := net.Listen("tcp", "127.0.0.1:0") 142 require.Nil(t, err) 143 defer ln.Close() 144 option := transport.WithListener(ln) 145 h := transport.WithHandler(transport.Handler(&errHandler{})) 146 require.Nil(t, tp.ListenAndServe(ctx, option, h)) 147 148 ct := thttp.NewClientTransport(true) 149 ctx, msg := codec.WithNewMessage(ctx) 150 msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello") 151 msg.WithClientReqHead(&thttp.ClientReqHeader{}) 152 msg.WithClientRspHead(&thttp.ClientRspHeader{}) 153 154 rsp, err := ct.RoundTrip(ctx, []byte("{\"username\":\"xyz\","+ 155 "\"password\":\"xyz\",\"from\":\"xyz\"}"), 156 transport.WithDialAddress(ln.Addr().String()), 157 ) 158 require.Nil(t, rsp, "roundtrip rsp not empty") 159 require.Nil(t, err, "Failed to roundtrip") 160 } 161 162 func TestErrHeaderHandler(t *testing.T) { 163 ctx := context.Background() 164 tp := thttp.NewServerTransport(newNoopStdHTTPServer) 165 ln, err := net.Listen("tcp", "127.0.0.1:0") 166 require.Nil(t, err) 167 defer func() { require.Nil(t, ln.Close()) }() 168 err = tp.ListenAndServe(ctx, 169 transport.WithHandler(transport.Handler(&errHeaderHandler{})), 170 transport.WithListener(ln), 171 ) 172 require.Nil(t, err) 173 174 ct := thttp.NewClientTransport(true) 175 ctx, msg := codec.WithNewMessage(ctx) 176 msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello") 177 msg.WithClientReqHead(&thttp.ClientReqHeader{}) 178 msg.WithClientRspHead(&thttp.ClientRspHeader{}) 179 180 rsp, err := ct.RoundTrip(ctx, []byte("{\"username\":\"xyz\","+ 181 "\"password\":\"xyz\",\"from\":\"xyz\"}"), 182 transport.WithDialAddress(ln.Addr().String()), 183 ) 184 require.Nil(t, rsp, "roundtrip rsp not empty") 185 require.Nil(t, err, "Failed to roundtrip") 186 } 187 188 func TestListenAndServeFailedDueToBadCertificationFile(t *testing.T) { 189 ctx := context.Background() 190 oldLogger := 
log.DefaultLogger 191 defer func() { 192 log.DefaultLogger = oldLogger 193 }() 194 errorCh := make(chan error) 195 log.DefaultLogger = &testLog{Logger: oldLogger, errorCh: errorCh} 196 197 ln, err := net.Listen("tcp", "127.0.0.1:0") 198 require.Nil(t, err) 199 defer func() { require.Nil(t, ln.Close()) }() 200 const badCertFile = "bad-file.cert" 201 require.Nil( 202 t, 203 thttp.NewServerTransport(newNoopStdHTTPServer).ListenAndServe( 204 ctx, 205 transport.WithListener(ln), 206 transport.WithHandler(transport.Handler(&h{})), 207 transport.WithServeTLS(badCertFile, "../testdata/server.key", ""), 208 ), 209 "failed to new client transport", 210 ) 211 212 select { 213 case <-time.After(time.Second): 214 t.Fatal("listen on a bad cert should log an error") 215 case err := <-errorCh: 216 require.Contains(t, err.Error(), badCertFile) 217 } 218 } 219 220 func TestStartTLSServerAndNoCheckServer(t *testing.T) { 221 ctx := context.Background() 222 ln, err := net.Listen("tcp", "127.0.0.1:0") 223 require.Nil(t, err) 224 defer func() { require.Nil(t, ln.Close()) }() 225 // Only enables https server and do not verify client certificate. 
226 require.Nil( 227 t, 228 thttp.NewServerTransport(newNoopStdHTTPServer).ListenAndServe( 229 ctx, 230 transport.WithListener(ln), 231 transport.WithHandler(transport.Handler(&h{})), 232 transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", ""), 233 ), 234 "Failed to new client transport", 235 ) 236 237 ct := thttp.NewClientTransport(false) 238 ctx, msg := codec.WithNewMessage(ctx) 239 msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello") 240 msg.WithClientReqHead(&thttp.ClientReqHeader{}) 241 msg.WithClientRspHead(&thttp.ClientRspHeader{}) 242 243 rsp, err := ct.RoundTrip( 244 ctx, 245 []byte("{\"username\":\"xyz\","+"\"password\":\"xyz\",\"from\":\"xyz\"}"), 246 transport.WithDialAddress(ln.Addr().String()), 247 // Fully trust the https server and do not verify server certificate, 248 // can only be used in test env. 249 transport.WithDialTLS("", "", "none", ""), 250 ) 251 require.Nil(t, rsp, "roundtrip rsp not empty") 252 require.Nil(t, err, "Failed to roundtrip") 253 } 254 255 func TestServerWithListenerOption(t *testing.T) { 256 ln, err := net.Listen("tcp", "localhost:0") 257 require.Nil(t, err) 258 defer ln.Close() 259 service := server.New( 260 server.WithServiceName("trpc.http.server.ListenerTest"), 261 server.WithNetwork("tcp"), 262 server.WithProtocol("http"), 263 server.WithListener(ln), 264 ) 265 thttp.HandleFunc("/index", func(w http.ResponseWriter, r *http.Request) error { 266 fmt.Printf("Protocol: %s\n", r.Proto) 267 w.Write([]byte(r.Proto)) 268 return nil 269 }) 270 thttp.RegisterDefaultService(service) 271 s := &server.Server{} 272 s.AddService("trpc.http.server.ListenerTest", service) 273 go func() { 274 require.Nil(t, s.Serve()) 275 }() 276 defer s.Close(nil) 277 time.Sleep(100 * time.Millisecond) 278 279 resp, err := http.Get(fmt.Sprintf("http://%v/index", ln.Addr())) 280 require.Nil(t, err) 281 defer resp.Body.Close() 282 body, err := io.ReadAll(resp.Body) 283 require.Nil(t, err) 284 require.Equal(t, 
[]byte("HTTP/1.1"), body) 285 286 const invalidAddr = "localhost:910439" 287 resp, err = http.Get(fmt.Sprintf("http://%s/index", invalidAddr)) 288 require.NotNil(t, err) 289 require.Nil(t, resp) 290 } 291 292 func TestStartDisableKeepAlivesServer(t *testing.T) { 293 ln, err := net.Listen("tcp", "localhost:0") 294 require.Nil(t, err) 295 defer ln.Close() 296 s := &server.Server{} 297 service := server.New( 298 server.WithListener(ln), 299 server.WithServiceName("trpc.http.server.ListenerTest"), 300 server.WithNetwork("tcp"), 301 server.WithProtocol("http"), 302 server.WithTransport(thttp.NewServerTransport(newNoopStdHTTPServer)), 303 server.WithDisableKeepAlives(true), 304 ) 305 thttp.HandleFunc("/disable-keepalives", func(w http.ResponseWriter, _ *http.Request) error { 306 w.Header().Set("Connection", "keep-alive") 307 return nil 308 }) 309 thttp.RegisterDefaultService(service) 310 s.AddService("trpc.http.server.ListenerTest", service) 311 go func() { 312 err := s.Serve() 313 require.Nil(t, err) 314 }() 315 defer func() { 316 _ = s.Close(nil) 317 }() 318 319 time.Sleep(100 * time.Millisecond) 320 321 dailCount := 0 322 client := &http.Client{ 323 Transport: &http.Transport{ 324 DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { 325 dailCount++ 326 conn, err := (&net.Dialer{}).DialContext(ctx, network, addr) 327 return conn, err 328 }, 329 }, 330 } 331 num := 3 332 url := fmt.Sprintf("http://%s/disable-keepalives", ln.Addr()) 333 for i := 0; i < num; i++ { 334 resp, err := client.Get(url) 335 require.Nil(t, err) 336 defer resp.Body.Close() 337 _, err = io.Copy(io.Discard, resp.Body) 338 require.Nil(t, err) 339 } 340 require.Equal(t, num, dailCount) 341 } 342 343 func TestStartH2cServer(t *testing.T) { 344 ln, err := net.Listen("tcp", "localhost:0") 345 require.Nil(t, err) 346 defer ln.Close() 347 s := &server.Server{} 348 service := server.New( 349 server.WithListener(ln), 350 server.WithServiceName("trpc.h2c.server.Greeter"), 351 
server.WithNetwork("tcp"), 352 server.WithProtocol("http2"), 353 server.WithTransport(thttp.NewServerTransport(newNoopStdHTTPServer, thttp.WithEnableH2C())), 354 ) 355 thttp.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) error { 356 fmt.Printf("Protocol: %s\n", r.Proto) 357 w.Write([]byte(r.Proto)) 358 return nil 359 }) 360 thttp.HandleFunc("/main", func(w http.ResponseWriter, r *http.Request) error { 361 fmt.Printf("Protocol: %s\n", r.Proto) 362 w.Write([]byte(r.Proto)) 363 return nil 364 }) 365 thttp.RegisterDefaultService(service) 366 s.AddService("trpc.h2c.server.Greeter", service) 367 368 go func() { 369 err := s.Serve() 370 require.Nil(t, err) 371 }() 372 373 time.Sleep(100 * time.Millisecond) 374 375 // h2c client 376 h2cClient := http.Client{ 377 Transport: &http2.Transport{ 378 AllowHTTP: true, 379 DialTLS: func(network, addr string, cfg *tls.Config) (net.Conn, error) { 380 return net.Dial(network, addr) 381 }, 382 }, 383 } 384 url := fmt.Sprintf("http://%s/", ln.Addr()) 385 resp, err := h2cClient.Get(url + "main") 386 require.Nil(t, err) 387 defer resp.Body.Close() 388 body, err := io.ReadAll(resp.Body) 389 require.Nil(t, err) 390 require.Equal(t, []byte("HTTP/2.0"), body) 391 392 // http1 client 393 resp2, err := http.Get(url) 394 require.Nil(t, err) 395 defer resp2.Body.Close() 396 body, err = io.ReadAll(resp2.Body) 397 require.Nil(t, err) 398 require.Equal(t, []byte("HTTP/1.1"), body) 399 require.Equal(t, http.StatusOK, resp2.StatusCode) 400 } 401 402 func TestHttp2StartTLSServerAndNoCheckServer(t *testing.T) { 403 ctx := context.Background() 404 ln, err := net.Listen("tcp", "127.0.0.1:0") 405 require.Nil(t, err) 406 defer func() { require.Nil(t, ln.Close()) }() 407 // Only enables https server and do not verify client certificate. 
408 require.Nil( 409 t, 410 thttp.NewServerTransport(newNoopStdHTTPServer).ListenAndServe( 411 ctx, 412 transport.WithListener(ln), 413 transport.WithHandler(transport.Handler(&h{})), 414 transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", ""), 415 ), 416 "Failed to new client transport", 417 ) 418 419 ct := thttp.NewClientTransport(true) 420 ctx, msg := codec.WithNewMessage(ctx) 421 msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello") 422 msg.WithClientReqHead(&thttp.ClientReqHeader{}) 423 msg.WithClientRspHead(&thttp.ClientRspHeader{}) 424 425 rsp, err := ct.RoundTrip( 426 ctx, 427 []byte("{\"username\":\"xyz\","+"\"password\":\"xyz\",\"from\":\"xyz\"}"), 428 transport.WithDialAddress(ln.Addr().String()), 429 // Fully trust the https server and do not verify server certificate, 430 // can only be used in test env. 431 transport.WithDialTLS("", "", "none", ""), 432 ) 433 require.Nil(t, rsp, "roundtrip rsp not empty") 434 require.Nil(t, err, "Failed to roundtrip") 435 } 436 437 func TestStartTLSServerAndCheckServer(t *testing.T) { 438 ctx := context.Background() 439 tp := thttp.NewServerTransport(newNoopStdHTTPServer) 440 ln, err := net.Listen("tcp", "127.0.0.1:0") 441 require.Nil(t, err) 442 defer func() { require.Nil(t, ln.Close()) }() 443 err = tp.ListenAndServe(ctx, 444 transport.WithHandler(transport.Handler(&h{})), 445 // Only enables https server and do not verify client certificate. 
446 transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", ""), 447 transport.WithListener(ln), 448 ) 449 require.Nil(t, err, "Failed to new client transport") 450 451 ct := thttp.NewClientTransport(false) 452 ctx, msg := codec.WithNewMessage(ctx) 453 msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello") 454 msg.WithClientReqHead(&thttp.ClientReqHeader{}) 455 msg.WithClientRspHead(&thttp.ClientRspHeader{}) 456 457 rsp, err := ct.RoundTrip(ctx, []byte("{\"username\":\"xyz\","+ 458 "\"password\":\"xyz\",\"from\":\"xyz\"}"), 459 transport.WithDialAddress(ln.Addr().String()), 460 // Uses ca public key to verify server certificate. 461 transport.WithDialTLS("", "", "../testdata/ca.pem", "localhost"), 462 ) 463 require.Nil(t, rsp, "roundtrip rsp not empty") 464 require.Nil(t, err, "Failed to roundtrip") 465 } 466 467 func TestStartTLSServerAndCheckClientNoCert(t *testing.T) { 468 ctx := context.Background() 469 tp := thttp.NewServerTransport(newNoopStdHTTPServer) 470 ln, err := net.Listen("tcp", "127.0.0.1:0") 471 require.Nil(t, err) 472 defer func() { require.Nil(t, ln.Close()) }() 473 err = tp.ListenAndServe(ctx, 474 transport.WithHandler(transport.Handler(&h{})), 475 // Enables two-way authentication http server and need to verify client certificate. 
476 transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", "../testdata/ca.pem"), 477 transport.WithListener(ln), 478 ) 479 require.Nil(t, err, "Failed to new client transport") 480 481 ct := thttp.NewClientTransport(false) 482 ctx, msg := codec.WithNewMessage(ctx) 483 msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello") 484 msg.WithClientReqHead(&thttp.ClientReqHeader{}) 485 msg.WithClientRspHead(&thttp.ClientRspHeader{}) 486 487 _, err = ct.RoundTrip(ctx, []byte("{\"username\":\"xyz\","+ 488 "\"password\":\"xyz\",\"from\":\"xyz\"}"), 489 transport.WithDialAddress(ln.Addr().String()), 490 // If the client's own certificate is not sent, will return TLS verification failed. 491 transport.WithDialTLS("", "", "../testdata/ca.pem", "localhost"), 492 ) 493 require.NotNil(t, err, "Failed to roundtrip") 494 } 495 496 func TestStartTLSServerAndCheckClient(t *testing.T) { 497 ctx := context.Background() 498 tp := thttp.NewServerTransport(newNoopStdHTTPServer) 499 ln, err := net.Listen("tcp", "127.0.0.1:0") 500 require.Nil(t, err) 501 defer func() { require.Nil(t, ln.Close()) }() 502 // Enables two-way authentication http server and need to verify client certificate. 503 err = tp.ListenAndServe(ctx, 504 transport.WithHandler(transport.Handler(&h{})), 505 // Only enables https server and do not verify client certificate. 
506 transport.WithServeTLS("../testdata/server.crt", "../testdata/server.key", "../testdata/ca.pem"), 507 transport.WithListener(ln), 508 ) 509 require.Nil(t, err, "Failed to new client transport") 510 511 ct := thttp.NewClientTransport(false) 512 ctx, msg := codec.WithNewMessage(ctx) 513 msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello") 514 msg.WithClientReqHead(&thttp.ClientReqHeader{}) 515 msg.WithClientRspHead(&thttp.ClientRspHeader{}) 516 517 rsp, err := ct.RoundTrip(ctx, []byte("{\"username\":\"xyz\","+ 518 "\"password\":\"xyz\",\"from\":\"xyz\"}"), 519 transport.WithDialAddress(ln.Addr().String()), 520 // Need to send the client's own certificate to server. 521 transport.WithDialTLS("../testdata/client.crt", "../testdata/client.key", "../testdata/ca.pem", "localhost"), 522 ) 523 require.Nil(t, rsp, "roundtrip rsp not empty") 524 require.Nil(t, err, "Failed to roundtrip") 525 } 526 527 func TestNewClientTransport(t *testing.T) { 528 ct := thttp.NewClientTransport(false) 529 require.NotNil(t, ct, "Failed to new client transport") 530 531 ct2 := thttp.NewClientTransport(true) 532 require.NotNil(t, ct2, "Failed to new http2 client transport") 533 } 534 535 func TestClientRoundTrip(t *testing.T) { 536 ctx := context.Background() 537 ct := thttp.NewClientTransport(false) 538 ctx, msg := codec.WithNewMessage(ctx) 539 msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello") 540 msg.WithClientReqHead(&thttp.ClientReqHeader{}) 541 msg.WithClientRspHead(&thttp.ClientRspHeader{}) 542 ln, err := net.Listen("tcp", "127.0.0.1:0") 543 require.Nil(t, err) 544 defer ln.Close() 545 go http.Serve(ln, nil) 546 rsp, err := ct.RoundTrip(ctx, []byte("{\"username\":\"xyz\","+ 547 "\"password\":\"xyz\",\"from\":\"xyz\"}"), 548 transport.WithDialAddress(ln.Addr().String())) 549 require.Nil(t, rsp, "roundtrip rsp not empty") 550 require.Nil(t, err, "Failed to roundtrip") 551 } 552 553 func TestClientRoundTripWithNoHead(t *testing.T) { 554 ctx := 
context.Background() 555 ct := thttp.NewClientTransport(false) 556 ctx, msg := codec.WithNewMessage(ctx) 557 msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello") 558 559 rsp, err := ct.RoundTrip(ctx, []byte("{\"username\":\"xyz\","+ 560 "\"password\":\"xyz\",\"from\":\"xyz\"}"), 561 transport.WithDialAddress("127.0.0.1:18080")) 562 require.Nil(t, rsp, "no head roundtrip rsp not empty") 563 require.NotNil(t, err, "no head roundtrip err nil") 564 565 } 566 567 func TestClientWithSelectorNode(t *testing.T) { 568 ctx := context.Background() 569 type testCase struct { 570 target string 571 address string 572 listener net.Listener 573 } 574 var tests []testCase 575 for i := 0; i < 2; i++ { 576 ln, err := net.Listen("tcp", "127.0.0.1:0") 577 require.Nil(t, err) 578 defer ln.Close() 579 addr := ln.Addr().String() 580 tests = append(tests, testCase{"ip://" + addr, addr, ln}) 581 } 582 for _, tt := range tests { 583 tp := thttp.NewServerTransport(newNoopStdHTTPServer) 584 option := transport.WithListener(tt.listener) 585 handler := transport.WithHandler(transport.Handler(&h{})) 586 err := tp.ListenAndServe(ctx, option, handler) 587 require.Nil(t, err, "Failed to new client transport") 588 589 proxy := thttp.NewClientProxy("trpc.test.helloworld.Greeter", 590 client.WithTarget(tt.target), 591 client.WithSerializationType(codec.SerializationTypeNoop), 592 ) 593 594 reqBody := &codec.Body{ 595 Data: []byte("{\"username\":\"xyz\"," + 596 "\"password\":\"xyz\",\"from\":\"xyz\"}"), 597 } 598 rspBody := &codec.Body{} 599 n := ®istry.Node{} 600 require.Nil(t, 601 proxy.Post(ctx, "/trpc.test.helloworld.Greeter/SayHello", reqBody, rspBody, client.WithSelectorNode(n)), 602 "Failed to post") 603 require.Equal(t, tt.address, n.Address) 604 } 605 } 606 607 func TestClient(t *testing.T) { 608 ctx := context.Background() 609 old := codec.GetSerializer(codec.SerializationTypeJSON) 610 defer func() { codec.RegisterSerializer(codec.SerializationTypeJSON, old) }() 611 
codec.RegisterSerializer(codec.SerializationTypeJSON, &codec.JSONPBSerialization{}) 612 tp := thttp.NewServerTransport(newNoopStdHTTPServer) 613 ln, err := net.Listen("tcp", "127.0.0.1:0") 614 require.Nil(t, err) 615 defer ln.Close() 616 option := transport.WithListener(ln) 617 handler := transport.WithHandler(transport.Handler(&h{})) 618 require.Nil(t, tp.ListenAndServe(ctx, option, handler), "Failed to new client transport") 619 620 header := &thttp.ClientReqHeader{} 621 header.AddHeader("ContentType", "application/json") 622 623 proxy := thttp.NewClientProxy("trpc.test.helloworld.Greeter", 624 client.WithTarget("ip://"+ln.Addr().String()), 625 client.WithSerializationType(codec.SerializationTypeNoop), 626 client.WithReqHead(header), 627 client.WithMetaData("k1", []byte("v1")), 628 ) 629 reqBody := &codec.Body{ 630 Data: []byte("{\"username\":\"xyz\"," + 631 "\"password\":\"xyz\",\"from\":\"xyz\"}"), 632 } 633 rspBody := &codec.Body{} 634 635 require.Nil(t, proxy.Post(ctx, "/trpc.test.helloworld.Greeter/SayHello", reqBody, rspBody), "Failed to post") 636 require.Nil(t, proxy.Put(ctx, "/trpc.test.helloworld.Greeter/SayHello", reqBody, rspBody), "Failed to put") 637 require.Nil(t, proxy.Delete(ctx, "/trpc.test.helloworld.Greeter/SayHello", reqBody, rspBody), "Failed to delete") 638 require.Nil(t, proxy.Get(ctx, "/trpc.test.helloworld.Greeter/SayHello", rspBody), "Failed to get") 639 require.Nil(t, proxy.Patch(ctx, "/trpc.test.helloworld.Greeter/SayHello", reqBody, rspBody), "Failed to patch") 640 641 // Test client with options. 
642 proxy = thttp.NewClientProxy("trpc.test.helloworld.Greeter") 643 reqBody = &codec.Body{ 644 Data: []byte("{\"username\":\"xyz\"," + 645 "\"password\":\"xyz\",\"from\":\"xyz\"}"), 646 } 647 rspBody = &codec.Body{} 648 require.Nil(t, 649 proxy.Post(ctx, "/trpc.test.helloworld.Greeter/SayHello", reqBody, rspBody, 650 client.WithTarget("ip://"+ln.Addr().String()), 651 client.WithSerializationType(codec.SerializationTypeNoop), 652 client.WithReqHead(header), 653 client.WithMetaData("k1", []byte("v1")), 654 ), "Failed to post") 655 656 require.NotNil(t, 657 proxy.Post(ctx, "/trpc.test.helloworld.Greeter/SayHello", reqBody, rspBody, 658 client.WithTarget("ip://127.0.0.1:180"), 659 ), "Failed to post") 660 } 661 662 func TestReqHeader(t *testing.T) { 663 ctx := context.Background() 664 // Invalid url. 665 header := &thttp.ClientReqHeader{} 666 header.AddHeader("Content-Type", "application/json") 667 proxy := thttp.NewClientProxy("trpc.test.helloworld.Greeter", 668 client.WithTarget("ip://127.0.0.1:18080:www.baidu.com//"), 669 client.WithSerializationType(codec.SerializationTypeNoop), 670 client.WithReqHead(header), 671 ) 672 reqBody := &codec.Body{} 673 rspBody := &codec.Body{} 674 err := proxy.Post(ctx, "/trpc.test.helloworld.Greeter/SayHello", reqBody, rspBody) 675 require.NotNil(t, err) 676 } 677 678 func TestReqHeaderWithContentType(t *testing.T) { 679 ctx := context.Background() 680 ln, err := net.Listen("tcp", "127.0.0.1:0") 681 require.Nil(t, err) 682 defer ln.Close() 683 option := transport.WithListener(ln) 684 handler := transport.WithHandler(transport.Handler(&h{})) 685 tp := thttp.NewServerTransport(newNoopStdHTTPServer) 686 require.Nil(t, tp.ListenAndServe(ctx, option, handler), "Failed to new client transport") 687 var tests = []struct { 688 expected string 689 }{ 690 {"application/json"}, 691 {"application/jsonp"}, 692 {"application/jsonp123"}, 693 {"application/text123"}, 694 } 695 for _, tt := range tests { 696 header := &thttp.ClientReqHeader{} 697 
header.AddHeader("Content-Type", tt.expected) 698 proxy := thttp.NewClientProxy("trpc.test.helloworld.Greeter", 699 client.WithTarget("ip://"+ln.Addr().String()), 700 client.WithSerializationType(codec.SerializationTypeForm), 701 client.WithReqHead(header), 702 ) 703 reqBody := &codec.Body{} 704 rspBody := &codec.Body{} 705 err := proxy.Post(ctx, "/trpc.test.helloworld.Greeter/SayHello", reqBody, rspBody) 706 require.Nil(t, err) 707 } 708 } 709 710 func TestHandler(t *testing.T) { 711 var ( 712 handler = func(w http.ResponseWriter, r *http.Request) { 713 return 714 } 715 handlerFunc = func(w http.ResponseWriter, r *http.Request) error { 716 return nil 717 } 718 service = server.New(server.WithProtocol("http")) 719 ) 720 721 thttp.Handle("*", http.HandlerFunc(handler)) 722 thttp.HandleFunc("/path/do/not/equal/to/*", handlerFunc) 723 thttp.RegisterDefaultService(service) 724 725 for _, method := range thttp.ServiceDesc.Methods { 726 method.Func(nil, context.TODO(), func(reqBody interface{}) (filter.ServerChain, error) { 727 return make([]filter.ServerFilter, 0), nil 728 }) 729 730 method.Func(nil, context.TODO(), func(reqBody interface{}) (filter.ServerChain, error) { 731 return make([]filter.ServerFilter, 0), errors.New("invalid filter") 732 }) 733 734 header := &thttp.Header{ 735 Request: &http.Request{}, 736 Response: &httptest.ResponseRecorder{}, 737 } 738 ctx := thttp.WithHeader(context.TODO(), header) 739 _, err := method.Func(nil, ctx, func(reqBody interface{}) (filter.ServerChain, error) { 740 return make([]filter.ServerFilter, 0), nil 741 }) 742 require.Nil(t, err) 743 } 744 } 745 746 func TestMux(t *testing.T) { 747 var handler = func(w http.ResponseWriter, r *http.Request) { 748 return 749 } 750 mux := http.NewServeMux() 751 mux.HandleFunc("/", handler) 752 753 var service = &mockService{} 754 thttp.RegisterServiceMux(service, mux) 755 desc, _ := service.desc.(*server.ServiceDesc) 756 for _, method := range desc.Methods { 757 method.Func(nil, 
context.TODO(), func(reqBody interface{}) (filter.ServerChain, error) { 758 return make([]filter.ServerFilter, 0), nil 759 }) 760 761 method.Func(nil, context.TODO(), func(reqBody interface{}) (filter.ServerChain, error) { 762 return make([]filter.ServerFilter, 0), errors.New("invalid filter") 763 }) 764 765 req, _ := http.NewRequest("GET", "/", nil) 766 header := &thttp.Header{ 767 Request: req, 768 Response: &httptest.ResponseRecorder{}, 769 } 770 ctx := thttp.WithHeader(context.TODO(), header) 771 _, err := method.Func(nil, ctx, func(reqBody interface{}) (filter.ServerChain, error) { 772 return make([]filter.ServerFilter, 0), nil 773 }) 774 require.Nil(t, err) 775 } 776 } 777 778 // TestCheckRedirect tests set CheckRedirect 779 func TestCheckRedirect(t *testing.T) { 780 ctx := context.Background() 781 ln, err := net.Listen("tcp", "127.0.0.1:0") 782 require.Nil(t, err) 783 defer ln.Close() 784 // server 785 go func() { 786 // real backend 787 h := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { 788 w.Write([]byte("real")) 789 }) 790 http.Handle("/real", h) 791 792 // redirect a 793 rha := http.RedirectHandler("/b", http.StatusMovedPermanently) 794 http.Handle("/a", rha) 795 796 // redirect b 797 rhb := http.RedirectHandler("/real", http.StatusMovedPermanently) 798 http.Handle("/b", rhb) 799 800 http.Serve(ln, nil) 801 }() 802 time.Sleep(200 * time.Millisecond) 803 804 // sets CheckRedirect 805 checkRedirect := func(_ *http.Request, via []*http.Request) error { 806 if len(via) > 1 { 807 return errors.New("more than once") 808 } 809 return nil 810 } 811 thttp.DefaultClientTransport.(*thttp.ClientTransport).CheckRedirect = checkRedirect 812 proxy := thttp.NewClientProxy("trpc.test.helloworld.Greeter", 813 client.WithTarget("ip://"+ln.Addr().String()), 814 client.WithSerializationType(codec.SerializationTypeNoop), 815 ) 816 reqBody := &codec.Body{} 817 rspBody := &codec.Body{} 818 // only redirect once form /b 819 require.Nil(t, proxy.Post(ctx, "/b", 
reqBody, rspBody))
	// redirect twice from /a
	err = proxy.Post(ctx, "/a", reqBody, rspBody)
	require.NotNil(t, err)
	require.Equal(t, true, strings.Contains(err.Error(), "more than once"))
}

// TestTransportError checks how client-side failures are mapped to framework
// error codes: exceeding the client timeout must yield errs.RetClientTimeout,
// and a context canceled before the call must yield errs.RetClientCanceled.
func TestTransportError(t *testing.T) {
	// A private mux is used instead of http.DefaultServeMux: registering the
	// same pattern twice on the global mux panics (e.g. under `go test -count=2`),
	// and global routes would leak into every other server using the default mux.
	mux := http.NewServeMux()
	mux.HandleFunc("/timeout", func(http.ResponseWriter, *http.Request) {
		time.Sleep(time.Second) // Outlasts the 500ms client timeout configured below.
	})
	mux.HandleFunc("/cancel", func(http.ResponseWriter, *http.Request) {})
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	require.Nil(t, err)
	defer ln.Close()
	// http.Serve returns once ln is closed by the deferred Close above.
	go func() { _ = http.Serve(ln, mux) }()
	time.Sleep(200 * time.Millisecond)

	proxy := thttp.NewClientProxy("trpc.test.helloworld.Greeter",
		client.WithTarget("ip://"+ln.Addr().String()),
		client.WithSerializationType(codec.SerializationTypeNoop),
		client.WithTimeout(time.Millisecond*500),
	)
	rspBody := &codec.Body{}

	err = proxy.Get(context.Background(), "/timeout", rspBody)
	var terr *errs.Error
	require.True(t, errors.As(err, &terr), "want *errs.Error, got %T: %v", err, err)
	require.EqualValues(t, int32(errs.RetClientTimeout), terr.Code)

	ctx, cancel := context.WithCancel(context.Background())
	cancel() // Cancel up front so the call observes an already-done context.
	err = proxy.Get(ctx, "/cancel", rspBody)
	require.True(t, errors.As(err, &terr), "want *errs.Error, got %T: %v", err, err)
	require.EqualValues(t, int32(errs.RetClientCanceled), terr.Code)
}

// TestClientRoundDyeing checks that RoundTrip marks a dyed message by setting
// the trpc message-type header of the outgoing request to TRPC_DYEING_MESSAGE.
// RoundTrip itself is expected to return an error here (presumably because the
// bare request has nowhere to be sent); only the header side effect is asserted.
func TestClientRoundDyeing(t *testing.T) {
	ct := thttp.NewClientTransport(false)
	ctx, msg := codec.WithNewMessage(context.Background())
	const dyeingKey = "dyeingkey"
	msg.WithDyeing(true)
	msg.WithDyeingKey(dyeingKey)
	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
	req := &http.Request{Header: http.Header{}}
	msg.WithClientReqHead(&thttp.ClientReqHeader{Request: req})
	msg.WithClientRspHead(&thttp.ClientRspHeader{})
	msg.WithClientMetaData(codec.MetaData{thttp.TrpcDyeingKey: []byte(dyeingKey)})

	_, err := ct.RoundTrip(ctx, nil)
	require.NotNil(t, err)
	require.Equal(t,
		strconv.Itoa(int(trpcpb.TrpcMessageType_TRPC_DYEING_MESSAGE)),
		req.Header.Get(thttp.TrpcMessageType))
}

// TestClientRoundEnvTransfer checks that when an env-transfer value is set on
// the message, the trans-info header of the outgoing request mentions the
// thttp.TrpcEnv key. RoundTrip is expected to fail; only the header is asserted.
func TestClientRoundEnvTransfer(t *testing.T) {
	ct := thttp.NewClientTransport(false)
	ctx, msg := codec.WithNewMessage(context.Background())
	msg.WithEnvTransfer("feat,master")
	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
	req := &http.Request{Header: http.Header{}}
	msg.WithClientReqHead(&thttp.ClientReqHeader{Request: req})
	msg.WithClientRspHead(&thttp.ClientRspHeader{})

	_, err := ct.RoundTrip(ctx, nil)
	require.NotNil(t, err)
	require.Contains(t, req.Header.Get(thttp.TrpcTransInfo), thttp.TrpcEnv)
}

// TestDisableBase64EncodeTransInfo checks that, with
// transport.WithDisableEncodeTransInfoBase64, the env-transfer value, a client
// metadata value and the dyeing key all appear verbatim in the trans-info
// header. RoundTrip is expected to fail; only the header is asserted.
func TestDisableBase64EncodeTransInfo(t *testing.T) {
	ct := thttp.NewClientTransport(false, transport.WithDisableEncodeTransInfoBase64())
	ctx, msg := codec.WithNewMessage(context.Background())
	const (
		envTrans  = "feat,master"
		metaVal   = "value"
		dyeingKey = "dyeingkey"
	)
	msg.WithEnvTransfer(envTrans)
	msg.WithClientMetaData(codec.MetaData{"key": []byte(metaVal)})
	msg.WithDyeing(true)
	msg.WithDyeingKey(dyeingKey)
	msg.WithClientRPCName("/trpc.test.helloworld.Greeter/SayHello")
	req := &http.Request{Header: http.Header{}}
	msg.WithClientReqHead(&thttp.ClientReqHeader{Request: req})
	msg.WithClientRspHead(&thttp.ClientRspHeader{})

	_, err := ct.RoundTrip(ctx, nil)
	require.NotNil(t, err)
	transInfo := req.Header.Get(thttp.TrpcTransInfo)
	require.Contains(t, transInfo, envTrans)
	require.Contains(t, transInfo, metaVal)
	require.Contains(t, transInfo, dyeingKey)
}

// TestDisableServiceRouterTransInfo checks that the env entry written into the
// trans-info header comes from the message's env-transfer value rather than
// from the TrpcEnv client metadata, and that an empty env-transfer value yields
// an empty env entry.
func TestDisableServiceRouterTransInfo(t *testing.T) {
	a := require.New(t)
	ct := thttp.NewClientTransport(false)
	ctx, msg := codec.WithNewMessage(context.Background())
	// This emulates the metadata of a decoded trpc-protocol client request.
	msg.WithClientMetaData(codec.MetaData{thttp.TrpcEnv: []byte("orienv")})
	msg.WithEnvTransfer("feat,master")
	req := &http.Request{Header: http.Header{}}
	msg.WithClientReqHead(&thttp.ClientReqHeader{Request: req})
	msg.WithClientRspHead(&thttp.ClientRspHeader{})

	_, err := ct.RoundTrip(ctx, nil)
	a.NotNil(err)
	info, err := thttp.UnmarshalTransInfo(msg, req.Header.Get(thttp.TrpcTransInfo))
	a.NoError(err)
	a.Equal("feat,master", string(info[thttp.TrpcEnv]))

	msg.WithEnvTransfer("") // DisableServiceRouter would clear EnvTransfer.
	_, err = ct.RoundTrip(ctx, nil)
	a.NotNil(err)
	info, err = thttp.UnmarshalTransInfo(msg, req.Header.Get(thttp.TrpcTransInfo))
	a.NoError(err)
	a.Equal("", string(info[thttp.TrpcEnv]))
}

// TestHTTPSUseClientVerify serves a no-protocol HTTP service over TLS with a CA
// file configured on the server (presumably enabling client-certificate
// verification) and checks that a client presenting client.crt, verifying the
// server as "localhost", receives the handler's response.
func TestHTTPSUseClientVerify(t *testing.T) {
	const (
		network = "tcp"
		address = "127.0.0.1:0"
	)
	ln, err := net.Listen(network, address)
	require.Nil(t, err)
	defer ln.Close()
	serviceName := "trpc.app.server.Service" + t.Name()
	service := server.New(
		server.WithServiceName(serviceName),
		server.WithNetwork(network),
		server.WithProtocol("http_no_protocol"),
		server.WithListener(ln),
		server.WithTLS(
			"../testdata/server.crt",
			"../testdata/server.key",
			"../testdata/ca.pem",
		),
	)
	pattern := "/" + t.Name()
	thttp.RegisterNoProtocolServiceMux(service, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Write([]byte(t.Name()))
	}))
	s := &server.Server{}
	s.AddService(serviceName, service)
	go s.Serve()
	defer s.Close(nil)
	time.Sleep(100 * time.Millisecond)

	c := thttp.NewClientProxy(
		serviceName,
		client.WithTarget("ip://"+ln.Addr().String()),
	)
	req := &codec.Body{}
	rsp := &codec.Body{}
	require.Nil(t,
		c.Post(context.Background(), pattern, req, rsp,
			client.WithCurrentSerializationType(codec.SerializationTypeNoop),
			client.WithSerializationType(codec.SerializationTypeNoop),
			client.WithCurrentCompressType(codec.CompressTypeNoop),
			client.WithTLS(
				"../testdata/client.crt",
				"../testdata/client.key",
				"../testdata/ca.pem",
				"localhost",
			),
		))
	require.Equal(t, []byte(t.Name()), rsp.Data)
}

// TestHTTPSSkipClientVerify serves a no-protocol HTTP service over TLS without
// a server-side CA file and checks that a client with no certificate and CA
// file "none" (presumably meaning "skip server verification") receives the
// handler's response.
func TestHTTPSSkipClientVerify(t *testing.T) {
	const (
		network = "tcp"
		address = "127.0.0.1:0"
	)
	ln, err := net.Listen(network, address)
	require.Nil(t, err)
	defer ln.Close()
	serviceName := "trpc.app.server.Service" + t.Name()
	service := server.New(
		server.WithServiceName(serviceName),
		server.WithNetwork(network),
		server.WithProtocol("http_no_protocol"),
		server.WithListener(ln),
		server.WithTLS(
			"../testdata/server.crt",
			"../testdata/server.key",
			"",
		),
	)
	pattern := "/" + t.Name()
	thttp.RegisterNoProtocolServiceMux(service, http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.Write([]byte(t.Name()))
	}))
	s := &server.Server{}
	s.AddService(serviceName, service)
	go s.Serve()
	defer s.Close(nil)
	time.Sleep(100 * time.Millisecond)

	c := thttp.NewClientProxy(
		serviceName,
		client.WithTarget("ip://"+ln.Addr().String()),
	)
	req := &codec.Body{}
	rsp := &codec.Body{}
	require.Nil(t,
		c.Post(context.Background(), pattern, req, rsp,
			client.WithCurrentSerializationType(codec.SerializationTypeNoop),
			client.WithSerializationType(codec.SerializationTypeNoop),
			client.WithCurrentCompressType(codec.CompressTypeNoop),
			client.WithTLS(
				"", "", "none", "",
			),
		))
	require.Equal(t, []byte(t.Name()), rsp.Data)
}

// TestListenAndServeHTTPHead checks that a handler run by the server transport
// can obtain the HTTP header via thttp.Head(ctx), and that the same header is
// also retrievable from the underlying request's context.
func TestListenAndServeHTTPHead(t *testing.T) {
	ctx := context.Background()
	const (
		network = "tcp"
		address = "127.0.0.1:0"
	)
	ln, err := net.Listen(network, address)
	require.Nil(t, err)
	defer ln.Close()
	st := thttp.NewServerTransport(newNoopStdHTTPServer)
	require.Nil(t, st.ListenAndServe(ctx,
		transport.WithHandler(&httpHeadHandler{
			handle: func(ctx context.Context, _ []byte) ([]byte, error) {
				head := thttp.Head(ctx)
				head.Response.WriteHeader(http.StatusOK)
				// Report whether the request's own context also carries the header.
				head.Response.Write([]byte(fmt.Sprintf("%+v", thttp.Head(head.Request.Context()) != nil)))
				return nil, nil
			},
		}),
		transport.WithListener(ln),
	))
	time.Sleep(200 * time.Millisecond)
	rsp, err := http.Get("http://" + ln.Addr().String())
	require.Nil(t, err)
	// Close the body so the underlying connection is released.
	defer rsp.Body.Close()
	require.Equal(t, http.StatusOK, rsp.StatusCode)
	bs, err := io.ReadAll(rsp.Body)
	require.Nil(t, err)
	require.Equal(t, fmt.Sprintf("%+v", true), string(bs))
}

// httpHeadHandler adapts a plain function to the handler accepted by
// transport.WithHandler.
type httpHeadHandler struct {
	handle func(ctx context.Context, req []byte) (rsp []byte, err error)
}

// Handle delegates to the wrapped function.
func (hh *httpHeadHandler) Handle(ctx context.Context, req []byte) (rsp []byte, err error) {
	return hh.handle(ctx, req)
}

// TestHTTPStreamFileUpload uploads README.md as a multipart form passed through
// ClientReqHeader.ReqBody (streamed instead of serialized from req) and checks
// that the server echoes back the uploaded file name.
func TestHTTPStreamFileUpload(t *testing.T) {
	// Start server.
	const (
		network = "tcp"
		address = "127.0.0.1:0"
	)
	ln, err := net.Listen(network, address)
	require.Nil(t, err)
	defer ln.Close()
	go http.Serve(ln, &fileHandler{})
	// Start client.
	c := thttp.NewClientProxy(
		"trpc.app.server.Service_http",
		client.WithTarget("ip://"+ln.Addr().String()),
	)
	// Open and read file.
	fileDir, err := os.Getwd()
	require.Nil(t, err)
	fileName := "README.md"
	filePath := path.Join(fileDir, fileName)
	file, err := os.Open(filePath)
	require.Nil(t, err)
	defer file.Close()
	// Construct multipart form file.
	body := &bytes.Buffer{}
	writer := multipart.NewWriter(body)
	part, err := writer.CreateFormFile("field_name", filepath.Base(file.Name()))
	require.Nil(t, err)
	// A failed copy would otherwise silently upload a truncated file.
	_, err = io.Copy(part, file)
	require.Nil(t, err)
	require.Nil(t, writer.Close())
	// Add multipart form data header.
	header := http.Header{}
	header.Add("Content-Type", writer.FormDataContentType())
	reqHeader := &thttp.ClientReqHeader{
		Header:  header,
		ReqBody: body, // Stream send.
	}
	req := &codec.Body{}
	rsp := &codec.Body{}
	// Upload file.
	require.Nil(t,
		c.Post(context.Background(), "/", req, rsp,
			client.WithCurrentSerializationType(codec.SerializationTypeNoop),
			client.WithSerializationType(codec.SerializationTypeNoop),
			client.WithCurrentCompressType(codec.CompressTypeNoop),
			client.WithReqHead(reqHeader),
		))
	require.Equal(t, []byte(fileName), rsp.Data)
}

// fileHandler is a test HTTP handler that expects a multipart form file in the
// "field_name" field and writes back the uploaded file's name.
type fileHandler struct{}

// ServeHTTP answers 400 when the form file cannot be read, otherwise 200 with
// the uploaded file name as the body.
func (*fileHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	f, hdr, err := r.FormFile("field_name")
	if err != nil {
		w.WriteHeader(http.StatusBadRequest)
		return
	}
	// FormFile returns an open multipart.File that the caller must close.
	defer f.Close()
	w.WriteHeader(http.StatusOK)
	// Write back file name.
	w.Write([]byte(hdr.Filename))
}

// TestHTTPStreamRead fetches README.md with ManualReadBody enabled and reads
// the payload directly from ClientRspHeader.Response.Body instead of rsp.Data.
func TestHTTPStreamRead(t *testing.T) {
	// Start server.
	const (
		network = "tcp"
		address = "127.0.0.1:0"
	)
	ln, err := net.Listen(network, address)
	require.Nil(t, err)
	defer ln.Close()
	go http.Serve(ln, &fileServer{})

	// Start client.
	c := thttp.NewClientProxy(
		"trpc.app.server.Service_http",
		client.WithTarget("ip://"+ln.Addr().String()),
	)

	// Enable manual body reading in order to
	// disable the framework's automatic body reading capability,
	// so that users can manually do their own client-side streaming reads.
	rspHead := &thttp.ClientRspHeader{
		ManualReadBody: true,
	}
	req := &codec.Body{}
	rsp := &codec.Body{}
	require.Nil(t,
		c.Post(context.Background(), "/", req, rsp,
			client.WithCurrentSerializationType(codec.SerializationTypeNoop),
			client.WithSerializationType(codec.SerializationTypeNoop),
			client.WithCurrentCompressType(codec.CompressTypeNoop),
			client.WithRspHead(rspHead),
		))
	require.Nil(t, rsp.Data)
	require.NotNil(t, rspHead.Response)
	// Do stream reads directly from rspHead.Response.Body.
	body := rspHead.Response.Body
	// Do remember to close the body.
	defer body.Close()
	bs, err := io.ReadAll(body)
	require.Nil(t, err)
	// io.ReadAll never returns a nil slice, so NotNil would always pass;
	// NotEmpty actually checks that file content arrived.
	require.NotEmpty(t, bs)
}

// TestHTTPSendReceiveChunk streams README.md to chunkedServer with chunked
// transfer encoding and reads the chunked echo back piece by piece.
func TestHTTPSendReceiveChunk(t *testing.T) {
	// HTTP chunked example:
	// 1. Client sends chunks: Add "chunked" transfer encoding header, and use io.Reader as body.
	// 2. Client reads chunks: The Go/net/http automatically handles the chunked reading.
	//    Users can simply read resp.Body in a loop until io.EOF.
	// 3. Server reads chunks: Similar to client reads chunks.
	// 4. Server sends chunks: Assert http.ResponseWriter as http.Flusher, call flusher.Flush() after
	//    writing a part of data, it will automatically trigger "chunked" encoding to send a chunk.

	// Start server.
	const (
		network = "tcp"
		address = "127.0.0.1:0"
	)
	ln, err := net.Listen(network, address)
	require.Nil(t, err)
	defer ln.Close()
	go http.Serve(ln, &chunkedServer{})

	// Start client.
	c := thttp.NewClientProxy(
		"trpc.app.server.Service_http",
		client.WithTarget("ip://"+ln.Addr().String()),
	)

	// Open and read file.
	fileDir, err := os.Getwd()
	require.Nil(t, err)
	fileName := "README.md"
	filePath := path.Join(fileDir, fileName)
	file, err := os.Open(filePath)
	require.Nil(t, err)
	defer file.Close()

	// 1. Client sends chunks.

	// Add request headers.
	header := http.Header{}
	header.Add("Content-Type", "text/plain")
	// Add chunked transfer encoding header.
	header.Add("Transfer-Encoding", "chunked")
	reqHead := &thttp.ClientReqHeader{
		Header:  header,
		ReqBody: file, // Stream send (for chunks).
	}

	// Enable manual body reading in order to
	// disable the framework's automatic body reading capability,
	// so that users can manually do their own client-side streaming reads.
	rspHead := &thttp.ClientRspHeader{
		ManualReadBody: true,
	}
	req := &codec.Body{}
	rsp := &codec.Body{}
	require.Nil(t,
		c.Post(context.Background(), "/", req, rsp,
			client.WithCurrentSerializationType(codec.SerializationTypeNoop),
			client.WithSerializationType(codec.SerializationTypeNoop),
			client.WithCurrentCompressType(codec.CompressTypeNoop),
			client.WithReqHead(reqHead),
			client.WithRspHead(rspHead),
		))
	require.Nil(t, rsp.Data)

	// 2. Client reads chunks.

	// Do stream reads directly from rspHead.Response.Body.
	body := rspHead.Response.Body
	// Do remember to close the body.
	defer body.Close()
	buf := make([]byte, 4096)
	for idx := 0; ; idx++ {
		n, err := body.Read(buf)
		// Per the io.Reader contract, bytes returned alongside an error
		// (including io.EOF) are valid and must be consumed first.
		if n > 0 {
			t.Logf("read chunk %d of length %d: %q", idx, n, buf[:n])
		}
		if err == io.EOF {
			t.Logf("reached io.EOF")
			break
		}
		// Any other error is a real failure; without this check a persistent
		// read error would make the loop spin forever.
		require.Nil(t, err)
	}
}

// chunkedServer reads the whole (possibly chunked) request body and echoes it
// back in at most 10 flushed pieces, pausing between them.
type chunkedServer struct{}

// ServeHTTP implements http.Handler for chunkedServer.
func (*chunkedServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// 3. Server reads chunks.

	// io.ReadAll will read until io.EOF.
	// Go/net/http will automatically handle chunked body reads.
	bs, err := io.ReadAll(r.Body)
	if err != nil {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte(fmt.Sprintf("io.ReadAll err: %+v", err)))
		return
	}

	// 4. Server sends chunks.

	// Send HTTP chunks using http.Flusher.
	// Reference: https://stackoverflow.com/questions/26769626/send-a-chunked-http-response-from-a-go-server.
	// The "Transfer-Encoding" header will be handled by the writer implicitly, so no need to set it.
	flusher, ok := w.(http.Flusher)
	if !ok {
		w.WriteHeader(http.StatusInternalServerError)
		w.Write([]byte("expected http.ResponseWriter to be an http.Flusher"))
		return
	}
	const chunks = 10
	chunkSize := (len(bs) + chunks - 1) / chunks // Ceiling division.
	// Iterate by offset instead of by chunk index: with a ceiling-sized chunk,
	// i*chunkSize can exceed len(bs) for short bodies (11 bytes -> chunk size 2
	// -> offset 12 at i=6), and bs[start:end] with start > end panics.
	for start := 0; start < len(bs); start += chunkSize {
		end := start + chunkSize
		if end > len(bs) {
			end = len(bs)
		}
		w.Write(bs[start:end])
		flusher.Flush() // Trigger "chunked" encoding and send a chunk.
		time.Sleep(500 * time.Millisecond)
	}
}

// fileServer serves ./README.md for every request.
type fileServer struct{}

// ServeHTTP implements http.Handler for fileServer.
func (*fileServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	http.ServeFile(w, r, "./README.md")
}

// TestHTTPSendAndReceiveSSE has the server answer a POST with three
// server-sent events that echo the request body, and reads them on the client
// through ManualReadBody.
func TestHTTPSendAndReceiveSSE(t *testing.T) {
	const (
		network = "tcp"
		address = "127.0.0.1:0"
	)
	ln, err := net.Listen(network, address)
	require.Nil(t, err)
	defer ln.Close()
	serviceName := "trpc.app.server.Service" + t.Name()
	service := server.New(
		server.WithServiceName(serviceName),
		server.WithNetwork(network),
		server.WithProtocol("http_no_protocol"),
		server.WithListener(ln),
	)
	pattern := "/" + t.Name()
	thttp.RegisterNoProtocolServiceMux(service, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		flusher, ok := w.(http.Flusher)
		if !ok {
			http.Error(w, "Streaming unsupported!", http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", "text/event-stream")
		w.Header().Set("Cache-Control", "no-cache")
		w.Header().Set("Connection", "keep-alive")
		w.Header().Set("Access-Control-Allow-Origin", "*")
		bs, err := io.ReadAll(r.Body)
		if err != nil {
			http.Error(w, err.Error(), http.StatusBadRequest)
			return
		}
		msg := string(bs)
		for i := 0; i < 3; i++ {
			// Fields of one SSE event are separated by a single newline and the
			// event ends with a blank line; a blank line between "event:" and
			// "data:" would split them into two separate events.
			event := "event: message\ndata: " + msg + strconv.Itoa(i) + "\n\n"
			if _, err := w.Write([]byte(event)); err != nil {
				// Headers are already committed, so an error status can no
				// longer be sent; just stop streaming.
				return
			}
			flusher.Flush()
			time.Sleep(500 * time.Millisecond)
		}
	}))
	s := &server.Server{}
	s.AddService(serviceName, service)
	go s.Serve()
	defer s.Close(nil)
	time.Sleep(100 * time.Millisecond)

	c := thttp.NewClientProxy(
		serviceName,
		client.WithTarget("ip://"+ln.Addr().String()),
	)
	header := http.Header{}
	header.Set("Cache-Control", "no-cache")
	header.Set("Accept", "text/event-stream")
	header.Set("Connection", "keep-alive")
	reqHeader := &thttp.ClientReqHeader{
		Header: header,
	}
	// Enable manual body reading in order to
	// disable the framework's automatic body reading capability,
	// so that users can manually do their own client-side streaming reads.
	rspHead := &thttp.ClientRspHeader{
		ManualReadBody: true,
	}
	req := &codec.Body{Data: []byte("hello")}
	rsp := &codec.Body{}
	require.Nil(t,
		c.Post(context.Background(), pattern, req, rsp,
			client.WithCurrentSerializationType(codec.SerializationTypeNoop),
			client.WithSerializationType(codec.SerializationTypeNoop),
			client.WithCurrentCompressType(codec.CompressTypeNoop),
			client.WithReqHead(reqHeader),
			client.WithRspHead(rspHead),
		))
	// Do stream reads directly from rspHead.Response.Body.
	body := rspHead.Response.Body
	// Do remember to close the body.
	defer body.Close()
	data := make([]byte, 1024)
	for {
		n, err := body.Read(data)
		// Bytes returned together with io.EOF are still valid data.
		if n > 0 {
			t.Logf("Received message: \n%s\n", string(data[:n]))
		}
		if err == io.EOF {
			break
		}
		require.Nil(t, err)
	}
}

// TestHTTPClientReqRspDifferentContentType sends a form-serialized request and
// has the server reply with a protobuf body labeled "application/protobuf",
// then checks that the reply is decoded into a helloworld.HelloReply
// (presumably the client picks the decoder from the response Content-Type).
func TestHTTPClientReqRspDifferentContentType(t *testing.T) {
	const (
		network = "tcp"
		address = "127.0.0.1:0"
	)
	ln, err := net.Listen(network, address)
	require.Nil(t, err)
	defer ln.Close()
	serviceName := "trpc.app.server.Service" + t.Name()
	service := server.New(
		server.WithServiceName(serviceName),
		server.WithNetwork(network),
		server.WithProtocol("http_no_protocol"),
		server.WithListener(ln),
	)
	const (
		hello = "hello "
		key   = "key"
	)
	pattern := "/" + t.Name()
	thttp.RegisterNoProtocolServiceMux(service, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		bs, err := io.ReadAll(r.Body)
		if err != nil {
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		req, err := url.ParseQuery(string(bs))
		if err != nil {
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		rsp := &helloworld.HelloReply{Message: hello + req.Get(key)}
		bs, err = codec.Marshal(codec.SerializationTypePB, rsp)
		if err != nil {
			w.WriteHeader(http.StatusInternalServerError)
			return
		}
		w.Header().Set("Content-Type", "application/protobuf")
		w.Write(bs)
	}))
	s := &server.Server{}
	s.AddService(serviceName, service)
	go s.Serve()
	defer s.Close(nil)
	time.Sleep(100 * time.Millisecond)

	c := thttp.NewClientProxy(
		serviceName,
		client.WithTarget("ip://"+ln.Addr().String()),
	)
	req := make(url.Values)
	req.Add(key, t.Name())
	rsp := &helloworld.HelloReply{}
	require.Nil(t,
		c.Post(context.Background(), pattern, req, rsp,
			client.WithSerializationType(codec.SerializationTypeForm),
		))
	require.Equal(t, hello+t.Name(), rsp.Message)
}

// h is a minimal transport handler that prints a line per request and returns
// an empty response.
type h struct{}

// Handle prints "recv http req" and returns a nil body and nil error.
func (*h) Handle(ctx context.Context, reqBuf []byte) (rsp []byte, err error) {
	fmt.Println("recv http req")
	return nil, nil
}

// testLog is a log.Logger whose Errorf forwards each formatted message as an
// error on errorCh; all other methods come from the embedded Logger.
type testLog struct {
	log.Logger
	errorCh chan error
}

// Errorf sends the formatted error on errorCh. The send blocks until a
// receiver is ready or buffer space is available.
func (tl *testLog) Errorf(format string, args ...interface{}) {
	tl.errorCh <- fmt.Errorf(format, args...)
}

// mockService is a mock service.
type mockService struct {
	desc interface{}
}

// Register records serviceDesc and ignores serviceImpl; it never fails.
func (m *mockService) Register(serviceDesc interface{}, serviceImpl interface{}) error {
	m.desc = serviceDesc
	return nil
}

// Serve is a no-op that always succeeds.
func (m *mockService) Serve() error {
	return nil
}

// Close is a no-op that always succeeds.
func (m *mockService) Close(chan struct{}) error {
	return nil
}

// errHandler is a transport handler that always fails with a generic error.
type errHandler struct{}

// Handle always returns the error "mock error".
func (*errHandler) Handle(ctx context.Context, reqBuf []byte) (rsp []byte, err error) {
	return nil, errors.New("mock error")
}

// errHeaderHandler is a transport handler that always fails with
// thttp.ErrEncodeMissingHeader.
type errHeaderHandler struct{}

// Handle always returns thttp.ErrEncodeMissingHeader.
func (*errHeaderHandler) Handle(ctx context.Context, reqBuf []byte) (rsp []byte, err error) {
	return nil, thttp.ErrEncodeMissingHeader
}