trpc.group/trpc-go/trpc-go@v1.0.3/http/transport.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 // Package http provides support for http protocol by default, 15 // provides rpc server with http protocol, and provides rpc client 16 // for calling http protocol. 17 package http 18 19 import ( 20 "bytes" 21 "context" 22 "crypto/tls" 23 "encoding/base64" 24 "errors" 25 "fmt" 26 "io" 27 "net" 28 "net/http" 29 stdhttp "net/http" 30 "net/http/httptrace" 31 "net/url" 32 "os" 33 "strconv" 34 "strings" 35 "sync" 36 "time" 37 38 "golang.org/x/net/http2" 39 "golang.org/x/net/http2/h2c" 40 icontext "trpc.group/trpc-go/trpc-go/internal/context" 41 "trpc.group/trpc-go/trpc-go/internal/reuseport" 42 trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc" 43 44 "trpc.group/trpc-go/trpc-go/codec" 45 "trpc.group/trpc-go/trpc-go/errs" 46 icodec "trpc.group/trpc-go/trpc-go/internal/codec" 47 itls "trpc.group/trpc-go/trpc-go/internal/tls" 48 "trpc.group/trpc-go/trpc-go/log" 49 "trpc.group/trpc-go/trpc-go/rpcz" 50 "trpc.group/trpc-go/trpc-go/transport" 51 ) 52 53 func init() { 54 st := NewServerTransport(func() *stdhttp.Server { return &stdhttp.Server{} }) 55 DefaultServerTransport = st 56 DefaultHTTP2ServerTransport = st 57 // Server transport (protocol file service). 58 transport.RegisterServerTransport("http", st) 59 transport.RegisterServerTransport("http2", st) 60 // Server transport (no protocol file service). 61 transport.RegisterServerTransport("http_no_protocol", st) 62 transport.RegisterServerTransport("http2_no_protocol", st) 63 // Client transport. 64 transport.RegisterClientTransport("http", DefaultClientTransport) 65 transport.RegisterClientTransport("http2", DefaultHTTP2ClientTransport) 66 } 67 68 // DefaultServerTransport is the default server http transport. 69 var DefaultServerTransport transport.ServerTransport 70 71 // DefaultHTTP2ServerTransport is the default server http2 transport. 72 var DefaultHTTP2ServerTransport transport.ServerTransport 73 74 // ServerTransport is the http transport layer. 75 type ServerTransport struct { 76 newServer func() *stdhttp.Server 77 reusePort bool 78 enableH2C bool 79 } 80 81 // NewServerTransport creates a new ServerTransport which implement transport.ServerTransport. 82 // The parameter newStdHttpServer is used to create the underlying stdhttp.Server when ListenAndServe, and that server 83 // is modified by opts of this function and ListenAndServe. 84 func NewServerTransport( 85 newStdHttpServer func() *stdhttp.Server, 86 opts ...OptServerTransport, 87 ) *ServerTransport { 88 st := ServerTransport{newServer: newStdHttpServer} 89 for _, opt := range opts { 90 opt(&st) 91 } 92 return &st 93 } 94 95 // ListenAndServe handles configuration. 96 func (t *ServerTransport) ListenAndServe(ctx context.Context, opt ...transport.ListenServeOption) error { 97 opts := &transport.ListenServeOptions{ 98 Network: "tcp", 99 } 100 for _, o := range opt { 101 o(opts) 102 } 103 if opts.Handler == nil { 104 return errors.New("http server transport handler empty") 105 } 106 return t.listenAndServeHTTP(ctx, opts) 107 } 108 109 var emptyBuf []byte 110 111 func (t *ServerTransport) listenAndServeHTTP(ctx context.Context, opts *transport.ListenServeOptions) error { 112 // All trpc-go http server transport only register this http.Handler. 113 serveFunc := func(w stdhttp.ResponseWriter, r *stdhttp.Request) { 114 h := &Header{Request: r, Response: w} 115 ctx := WithHeader(r.Context(), h) 116 117 // Generates new empty general message structure data and save it to ctx. 118 ctx, msg := codec.WithNewMessage(ctx) 119 defer codec.PutBackMessage(msg) 120 // The old request must be replaced to ensure that the context is embedded. 121 h.Request = r.WithContext(ctx) 122 defer func() { 123 // Fix issues/778 124 if r.MultipartForm == nil { 125 r.MultipartForm = h.Request.MultipartForm 126 } 127 }() 128 129 span, ender, ctx := rpcz.NewSpanContext(ctx, "http-server") 130 defer ender.End() 131 span.SetAttribute(rpcz.HTTPAttributeURL, r.URL) 132 span.SetAttribute(rpcz.HTTPAttributeRequestContentLength, r.ContentLength) 133 134 // Records LocalAddr and RemoteAddr to Context. 135 localAddr, ok := h.Request.Context().Value(stdhttp.LocalAddrContextKey).(net.Addr) 136 if ok { 137 msg.WithLocalAddr(localAddr) 138 } 139 raddr, _ := net.ResolveTCPAddr("tcp", h.Request.RemoteAddr) 140 msg.WithRemoteAddr(raddr) 141 _, err := opts.Handler.Handle(ctx, emptyBuf) 142 if err != nil { 143 span.SetAttribute(rpcz.TRPCAttributeError, err) 144 log.Errorf("http server transport handle fail:%v", err) 145 if err == ErrEncodeMissingHeader || errors.Is(err, errs.ErrServerNoResponse) { 146 w.WriteHeader(http.StatusInternalServerError) 147 _, _ = w.Write([]byte(fmt.Sprintf("http server handle error: %+v", err))) 148 } 149 return 150 } 151 } 152 153 s, err := t.newHTTPServer(serveFunc, opts) 154 if err != nil { 155 return err 156 } 157 158 if err := t.serve(ctx, s, opts); err != nil { 159 return err 160 } 161 return nil 162 } 163 164 func (t *ServerTransport) serve(ctx context.Context, s *stdhttp.Server, opts *transport.ListenServeOptions) error { 165 ln := opts.Listener 166 if ln == nil { 167 var err error 168 ln, err = t.getListener(opts.Network, s.Addr) 169 if err != nil { 170 return fmt.Errorf("http server transport get listener err: %w", err) 171 } 172 } 173 174 if err := transport.SaveListener(ln); err != nil { 175 return fmt.Errorf("save http listener error: %w", err) 176 } 177 178 if len(opts.TLSKeyFile) != 0 && len(opts.TLSCertFile) != 0 { 179 go func() { 180 if err := s.ServeTLS( 181 tcpKeepAliveListener{ln.(*net.TCPListener)}, 182 opts.TLSCertFile, 183 opts.TLSKeyFile, 184 ); err != stdhttp.ErrServerClosed { 185 log.Errorf("serve TLS failed: %w", err) 186 } 187 }() 188 } else { 189 go func() { 190 _ = s.Serve(tcpKeepAliveListener{ln.(*net.TCPListener)}) 191 }() 192 } 193 194 // Reuse ports: Kernel distributes IO ReadReady events to multiple cores and threads to accelerate IO efficiency. 195 if t.reusePort { 196 go func() { 197 <-ctx.Done() 198 _ = s.Shutdown(context.TODO()) 199 }() 200 } 201 go func() { 202 <-opts.StopListening 203 ln.Close() 204 }() 205 return nil 206 } 207 208 func (t *ServerTransport) getListener(network, addr string) (net.Listener, error) { 209 var ln net.Listener 210 v, _ := os.LookupEnv(transport.EnvGraceRestart) 211 ok, _ := strconv.ParseBool(v) 212 if ok { 213 // Find the passed listener. 214 pln, err := transport.GetPassedListener(network, addr) 215 if err != nil { 216 return nil, err 217 } 218 ln, ok = pln.(net.Listener) 219 if !ok { 220 return nil, fmt.Errorf("invalid listener type, want net.Listener, got %T", pln) 221 } 222 return ln, nil 223 } 224 225 if t.reusePort { 226 ln, err := reuseport.Listen(network, addr) 227 if err != nil { 228 return nil, fmt.Errorf("http reuseport listen error:%v", err) 229 } 230 return ln, nil 231 } 232 233 ln, err := net.Listen(network, addr) 234 if err != nil { 235 return nil, fmt.Errorf("http listen error:%v", err) 236 } 237 return ln, nil 238 } 239 240 // newHTTPServer creates http server. 241 func (t *ServerTransport) newHTTPServer( 242 serveFunc func(w stdhttp.ResponseWriter, r *stdhttp.Request), 243 opts *transport.ListenServeOptions, 244 ) (*stdhttp.Server, error) { 245 s := t.newServer() 246 s.Addr = opts.Address 247 s.Handler = stdhttp.HandlerFunc(serveFunc) 248 if t.enableH2C { 249 h2s := &http2.Server{} 250 s.Handler = h2c.NewHandler(stdhttp.HandlerFunc(serveFunc), h2s) 251 return s, nil 252 } 253 if len(opts.CACertFile) != 0 { // Enable two-way authentication to verify client certificate. 254 s.TLSConfig = &tls.Config{ 255 ClientAuth: tls.RequireAndVerifyClientCert, 256 } 257 certPool, err := itls.GetCertPool(opts.CACertFile) 258 if err != nil { 259 return nil, fmt.Errorf("http server get ca cert file error:%v", err) 260 } 261 s.TLSConfig.ClientCAs = certPool 262 } 263 if opts.DisableKeepAlives { 264 s.SetKeepAlivesEnabled(false) 265 } 266 if opts.IdleTimeout > 0 { 267 s.IdleTimeout = opts.IdleTimeout 268 } 269 return s, nil 270 } 271 272 // tcpKeepAliveListener sets TCP keep-alive timeouts on accepted 273 // connections. It's used by ListenAndServe and ListenAndServeTLS so 274 // dead TCP connections (e.g. closing laptop mid-download) eventually 275 // go away. 276 type tcpKeepAliveListener struct { 277 *net.TCPListener 278 } 279 280 // Accept accepts new request. 281 func (ln tcpKeepAliveListener) Accept() (net.Conn, error) { 282 tc, err := ln.AcceptTCP() 283 if err != nil { 284 return nil, err 285 } 286 _ = tc.SetKeepAlive(true) 287 _ = tc.SetKeepAlivePeriod(3 * time.Minute) 288 return tc, nil 289 } 290 291 // ClientTransport client side http transport. 292 type ClientTransport struct { 293 stdhttp.Client // http client, exposed variables, allow user to customize settings. 294 opts *transport.ClientTransportOptions 295 tlsClients map[string]*stdhttp.Client // Different certificate file use different TLS client. 296 tlsLock sync.RWMutex 297 http2Only bool 298 } 299 300 // DefaultClientTransport default client http transport. 301 var DefaultClientTransport = NewClientTransport(false) 302 303 // DefaultHTTP2ClientTransport default client http2 transport. 304 var DefaultHTTP2ClientTransport = NewClientTransport(true) 305 306 // NewClientTransport creates http transport. 307 func NewClientTransport(http2Only bool, opt ...transport.ClientTransportOption) transport.ClientTransport { 308 opts := &transport.ClientTransportOptions{} 309 310 // Write func options to field opts. 311 for _, o := range opt { 312 o(opts) 313 } 314 return &ClientTransport{ 315 opts: opts, 316 Client: stdhttp.Client{ 317 Transport: NewRoundTripper(StdHTTPTransport), 318 }, 319 tlsClients: make(map[string]*stdhttp.Client), 320 http2Only: http2Only, 321 } 322 } 323 324 func (ct *ClientTransport) getRequest(reqHeader *ClientReqHeader, 325 reqBody []byte, msg codec.Msg, opts *transport.RoundTripOptions) (*stdhttp.Request, error) { 326 req, err := ct.newRequest(reqHeader, reqBody, msg, opts) 327 if err != nil { 328 return nil, err 329 } 330 331 if reqHeader.Header != nil { 332 req.Header = make(stdhttp.Header) 333 for h, val := range reqHeader.Header { 334 req.Header[h] = val 335 } 336 } 337 if len(reqHeader.Host) != 0 { 338 req.Host = reqHeader.Host 339 } 340 req.Header.Set(TrpcCaller, msg.CallerServiceName()) 341 req.Header.Set(TrpcCallee, msg.CalleeServiceName()) 342 req.Header.Set(TrpcTimeout, strconv.Itoa(int(msg.RequestTimeout()/time.Millisecond))) 343 if opts.DisableConnectionPool { 344 req.Header.Set(Connection, "close") 345 req.Close = true 346 } 347 if t := msg.CompressType(); icodec.IsValidCompressType(t) && t != codec.CompressTypeNoop { 348 req.Header.Set("Content-Encoding", compressTypeContentEncoding[t]) 349 } 350 if msg.SerializationType() != codec.SerializationTypeNoop { 351 if len(req.Header.Get("Content-Type")) == 0 { 352 req.Header.Set("Content-Type", 353 serializationTypeContentType[msg.SerializationType()]) 354 } 355 } 356 if err := ct.setTransInfo(msg, req); err != nil { 357 return nil, err 358 } 359 if len(opts.TLSServerName) == 0 { 360 opts.TLSServerName = req.Host 361 } 362 return req, nil 363 } 364 365 func (ct *ClientTransport) setTransInfo(msg codec.Msg, req *stdhttp.Request) error { 366 var m map[string]string 367 if md := msg.ClientMetaData(); len(md) > 0 { 368 m = make(map[string]string, len(md)) 369 for k, v := range md { 370 m[k] = ct.encodeBytes(v) 371 } 372 } 373 374 // Set dyeing information. 375 if msg.Dyeing() { 376 if m == nil { 377 m = make(map[string]string) 378 } 379 m[TrpcDyeingKey] = ct.encodeString(msg.DyeingKey()) 380 req.Header.Set(TrpcMessageType, strconv.Itoa(int(trpcpb.TrpcMessageType_TRPC_DYEING_MESSAGE))) 381 } 382 383 if msg.EnvTransfer() != "" { 384 if m == nil { 385 m = make(map[string]string) 386 } 387 m[TrpcEnv] = ct.encodeString(msg.EnvTransfer()) 388 } else { 389 // If msg.EnvTransfer() empty, transmitted env info in req.TransInfo should be cleared 390 if _, ok := m[TrpcEnv]; ok { 391 m[TrpcEnv] = "" 392 } 393 } 394 395 if len(m) > 0 { 396 val, err := codec.Marshal(codec.SerializationTypeJSON, m) 397 if err != nil { 398 return errs.NewFrameError(errs.RetClientValidateFail, "http client json marshal metadata fail: "+err.Error()) 399 } 400 req.Header.Set(TrpcTransInfo, string(val)) 401 } 402 403 return nil 404 } 405 406 func (ct *ClientTransport) newRequest(reqHeader *ClientReqHeader, 407 reqBody []byte, msg codec.Msg, opts *transport.RoundTripOptions) (*stdhttp.Request, error) { 408 if reqHeader.Request != nil { 409 return reqHeader.Request, nil 410 } 411 scheme := reqHeader.Schema 412 if scheme == "" { 413 if len(opts.CACertFile) > 0 || strings.HasSuffix(opts.Address, ":443") { 414 scheme = "https" 415 } else { 416 scheme = "http" 417 } 418 } 419 420 body := reqHeader.ReqBody 421 if body == nil { 422 body = bytes.NewReader(reqBody) 423 } 424 425 request, err := stdhttp.NewRequest( 426 reqHeader.Method, 427 fmt.Sprintf("%s://%s%s", scheme, opts.Address, msg.ClientRPCName()), 428 body) 429 if err != nil { 430 return nil, errs.NewFrameError(errs.RetClientNetErr, 431 "http client transport NewRequest: "+err.Error()) 432 } 433 return request, nil 434 } 435 436 func (ct *ClientTransport) encodeBytes(in []byte) string { 437 if ct.opts.DisableHTTPEncodeTransInfoBase64 { 438 return string(in) 439 } 440 return base64.StdEncoding.EncodeToString(in) 441 } 442 443 func (ct *ClientTransport) encodeString(in string) string { 444 if ct.opts.DisableHTTPEncodeTransInfoBase64 { 445 return in 446 } 447 return base64.StdEncoding.EncodeToString([]byte(in)) 448 } 449 450 // RoundTrip sends and receives http packets, put http response into ctx, 451 // no need to return rspBuf here. 452 func (ct *ClientTransport) RoundTrip( 453 ctx context.Context, 454 reqBody []byte, 455 callOpts ...transport.RoundTripOption, 456 ) (rspBody []byte, err error) { 457 msg := codec.Message(ctx) 458 reqHeader, ok := msg.ClientReqHead().(*ClientReqHeader) 459 if !ok { 460 return nil, errs.NewFrameError(errs.RetClientEncodeFail, 461 "http client transport: ReqHead should be type of *http.ClientReqHeader") 462 } 463 rspHeader, ok := msg.ClientRspHead().(*ClientRspHeader) 464 if !ok { 465 return nil, errs.NewFrameError(errs.RetClientEncodeFail, 466 "http client transport: RspHead should be type of *http.ClientRspHeader") 467 } 468 469 var opts transport.RoundTripOptions 470 for _, o := range callOpts { 471 o(&opts) 472 } 473 474 // Sets reqHeader. 475 req, err := ct.getRequest(reqHeader, reqBody, msg, &opts) 476 if err != nil { 477 return nil, err 478 } 479 trace := &httptrace.ClientTrace{ 480 GotConn: func(info httptrace.GotConnInfo) { 481 msg.WithRemoteAddr(info.Conn.RemoteAddr()) 482 }, 483 } 484 reqCtx := ctx 485 cancel := context.CancelFunc(func() {}) 486 if rspHeader.ManualReadBody { 487 // In the scenario of Manual Read body, the lifecycle of rsp body is different 488 // from that of invoke ctx, and is independently controlled by body.Close(). 489 // Therefore, the timeout/cancel function in the original context needs to be replaced. 490 controlCtx := context.Background() 491 if deadline, ok := ctx.Deadline(); ok { 492 controlCtx, cancel = context.WithDeadline(context.Background(), deadline) 493 } 494 reqCtx = icontext.NewContextWithValues(controlCtx, ctx) 495 } 496 defer func() { 497 if err != nil { 498 cancel() 499 } 500 }() 501 request := req.WithContext(httptrace.WithClientTrace(reqCtx, trace)) 502 503 client, err := ct.getStdHTTPClient(opts.CACertFile, opts.TLSCertFile, 504 opts.TLSKeyFile, opts.TLSServerName) 505 if err != nil { 506 return nil, err 507 } 508 509 rspHeader.Response, err = client.Do(request) 510 if err != nil { 511 if e, ok := err.(*url.Error); ok { 512 if e.Timeout() { 513 return nil, errs.NewFrameError(errs.RetClientTimeout, 514 "http client transport RoundTrip timeout: "+err.Error()) 515 } 516 } 517 if ctx.Err() == context.Canceled { 518 return nil, errs.NewFrameError(errs.RetClientCanceled, 519 "http client transport RoundTrip canceled: "+err.Error()) 520 } 521 return nil, errs.NewFrameError(errs.RetClientNetErr, 522 "http client transport RoundTrip: "+err.Error()) 523 } 524 decorateWithCancel(rspHeader, cancel) 525 return emptyBuf, nil 526 } 527 528 func decorateWithCancel(rspHeader *ClientRspHeader, cancel context.CancelFunc) { 529 // Quoted from: https://github.com/golang/go/blob/go1.21.4/src/net/http/response.go#L69 530 // 531 // "As of Go 1.12, the Body will also implement io.Writer on a successful "101 Switching Protocols" response, 532 // as used by WebSockets and HTTP/2's "h2c" mode." 533 // 534 // Therefore, we require an extra check to ensure io.Writer's conformity, 535 // which will then expose the corresponding method. 536 // 537 // It's important to note that an embedded body may not be capable of exposing all the attached interfaces. 538 // Consequently, we perform an explicit interface assertion here. 539 if body, ok := rspHeader.Response.Body.(io.ReadWriteCloser); ok { 540 rspHeader.Response.Body = &writableResponseBodyWithCancel{ReadWriteCloser: body, cancel: cancel} 541 } else { 542 rspHeader.Response.Body = &responseBodyWithCancel{ReadCloser: rspHeader.Response.Body, cancel: cancel} 543 } 544 } 545 546 // writableResponseBodyWithCancel implements io.ReadWriteCloser. 547 // It wraps response body and cancel function. 548 type writableResponseBodyWithCancel struct { 549 io.ReadWriteCloser 550 cancel context.CancelFunc 551 } 552 553 func (b *writableResponseBodyWithCancel) Close() error { 554 b.cancel() 555 return b.ReadWriteCloser.Close() 556 } 557 558 // responseBodyWithCancel implements io.ReadCloser. 559 // It wraps response body and cancel function. 560 type responseBodyWithCancel struct { 561 io.ReadCloser 562 cancel context.CancelFunc 563 } 564 565 func (b *responseBodyWithCancel) Close() error { 566 b.cancel() 567 return b.ReadCloser.Close() 568 } 569 570 func (ct *ClientTransport) getStdHTTPClient(caFile, certFile, 571 keyFile, serverName string) (*stdhttp.Client, error) { 572 if len(caFile) == 0 { // HTTP requests share one client. 573 return &ct.Client, nil 574 } 575 576 cacheKey := fmt.Sprintf("%s-%s-%s", caFile, certFile, serverName) 577 ct.tlsLock.RLock() 578 cli, ok := ct.tlsClients[cacheKey] 579 ct.tlsLock.RUnlock() 580 if ok { 581 return cli, nil 582 } 583 584 ct.tlsLock.Lock() 585 defer ct.tlsLock.Unlock() 586 cli, ok = ct.tlsClients[cacheKey] 587 if ok { 588 return cli, nil 589 } 590 591 conf, err := itls.GetClientConfig(serverName, caFile, certFile, keyFile) 592 if err != nil { 593 return nil, err 594 } 595 client := &stdhttp.Client{ 596 CheckRedirect: ct.Client.CheckRedirect, 597 Timeout: ct.Client.Timeout, 598 } 599 if ct.http2Only { 600 client.Transport = &http2.Transport{ 601 TLSClientConfig: conf, 602 } 603 } else { 604 tr := StdHTTPTransport.Clone() 605 tr.TLSClientConfig = conf 606 client.Transport = NewRoundTripper(tr) 607 } 608 ct.tlsClients[cacheKey] = client 609 return client, nil 610 } 611 612 // StdHTTPTransport all RoundTripper object used by http and https. 613 var StdHTTPTransport = &stdhttp.Transport{ 614 Proxy: stdhttp.ProxyFromEnvironment, 615 DialContext: (&net.Dialer{ 616 Timeout: 30 * time.Second, 617 KeepAlive: 30 * time.Second, 618 DualStack: true, 619 }).DialContext, 620 ForceAttemptHTTP2: true, 621 IdleConnTimeout: 50 * time.Second, 622 TLSHandshakeTimeout: 10 * time.Second, 623 MaxIdleConnsPerHost: 100, 624 DisableCompression: true, 625 ExpectContinueTimeout: time.Second, 626 } 627 628 // NewRoundTripper creates new NewRoundTripper and can be replaced. 629 var NewRoundTripper = newValueDetachedTransport