trpc.group/trpc-go/trpc-go@v1.0.2/http/transport.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
// Package http provides tRPC's default support for the HTTP protocol: an RPC
// server that speaks HTTP and an RPC client for calling HTTP services.
    17  package http
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"crypto/tls"
    23  	"encoding/base64"
    24  	"errors"
    25  	"fmt"
    26  	"net"
    27  	stdhttp "net/http"
    28  	"net/http/httptrace"
    29  	"net/url"
    30  	"os"
    31  	"strconv"
    32  	"strings"
    33  	"sync"
    34  	"time"
    35  
    36  	"golang.org/x/net/http2"
    37  	"golang.org/x/net/http2/h2c"
    38  	"trpc.group/trpc-go/trpc-go/internal/reuseport"
    39  	trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc"
    40  
    41  	"trpc.group/trpc-go/trpc-go/codec"
    42  	"trpc.group/trpc-go/trpc-go/errs"
    43  	icodec "trpc.group/trpc-go/trpc-go/internal/codec"
    44  	itls "trpc.group/trpc-go/trpc-go/internal/tls"
    45  	"trpc.group/trpc-go/trpc-go/log"
    46  	"trpc.group/trpc-go/trpc-go/rpcz"
    47  	"trpc.group/trpc-go/trpc-go/transport"
    48  )
    49  
    50  func init() {
    51  	st := NewServerTransport(func() *stdhttp.Server { return &stdhttp.Server{} })
    52  	DefaultServerTransport = st
    53  	DefaultHTTP2ServerTransport = st
    54  	// Server transport (protocol file service).
    55  	transport.RegisterServerTransport("http", st)
    56  	transport.RegisterServerTransport("http2", st)
    57  	// Server transport (no protocol file service).
    58  	transport.RegisterServerTransport("http_no_protocol", st)
    59  	transport.RegisterServerTransport("http2_no_protocol", st)
    60  	// Client transport.
    61  	transport.RegisterClientTransport("http", DefaultClientTransport)
    62  	transport.RegisterClientTransport("http2", DefaultHTTP2ClientTransport)
    63  }
    64  
    65  // DefaultServerTransport is the default server http transport.
    66  var DefaultServerTransport transport.ServerTransport
    67  
    68  // DefaultHTTP2ServerTransport is the default server http2 transport.
    69  var DefaultHTTP2ServerTransport transport.ServerTransport
    70  
    71  // ServerTransport is the http transport layer.
    72  type ServerTransport struct {
    73  	newServer func() *stdhttp.Server
    74  	reusePort bool
    75  	enableH2C bool
    76  }
    77  
    78  // NewServerTransport creates a new ServerTransport which implement transport.ServerTransport.
    79  // The parameter newStdHttpServer is used to create the underlying stdhttp.Server when ListenAndServe, and that server
    80  // is modified by opts of this function and ListenAndServe.
    81  func NewServerTransport(
    82  	newStdHttpServer func() *stdhttp.Server,
    83  	opts ...OptServerTransport,
    84  ) *ServerTransport {
    85  	st := ServerTransport{newServer: newStdHttpServer}
    86  	for _, opt := range opts {
    87  		opt(&st)
    88  	}
    89  	return &st
    90  }
    91  
    92  // ListenAndServe handles configuration.
    93  func (t *ServerTransport) ListenAndServe(ctx context.Context, opt ...transport.ListenServeOption) error {
    94  	opts := &transport.ListenServeOptions{
    95  		Network: "tcp",
    96  	}
    97  	for _, o := range opt {
    98  		o(opts)
    99  	}
   100  	if opts.Handler == nil {
   101  		return errors.New("http server transport handler empty")
   102  	}
   103  	return t.listenAndServeHTTP(ctx, opts)
   104  }
   105  
   106  var emptyBuf []byte
   107  
   108  func (t *ServerTransport) listenAndServeHTTP(ctx context.Context, opts *transport.ListenServeOptions) error {
   109  	// All trpc-go http server transport only register this http.Handler.
   110  	serveFunc := func(w stdhttp.ResponseWriter, r *stdhttp.Request) {
   111  		h := &Header{Request: r, Response: w}
   112  		ctx := WithHeader(r.Context(), h)
   113  
   114  		// Generates new empty general message structure data and save it to ctx.
   115  		ctx, msg := codec.WithNewMessage(ctx)
   116  		defer codec.PutBackMessage(msg)
   117  		// The old request must be replaced to ensure that the context is embedded.
   118  		h.Request = r.WithContext(ctx)
   119  		defer func() {
   120  			// Fix issues/778
   121  			if r.MultipartForm == nil {
   122  				r.MultipartForm = h.Request.MultipartForm
   123  			}
   124  		}()
   125  
   126  		span, ender, ctx := rpcz.NewSpanContext(ctx, "http-server")
   127  		defer ender.End()
   128  		span.SetAttribute(rpcz.HTTPAttributeURL, r.URL)
   129  		span.SetAttribute(rpcz.HTTPAttributeRequestContentLength, r.ContentLength)
   130  
   131  		// Records LocalAddr and RemoteAddr to Context.
   132  		localAddr, ok := h.Request.Context().Value(stdhttp.LocalAddrContextKey).(net.Addr)
   133  		if ok {
   134  			msg.WithLocalAddr(localAddr)
   135  		}
   136  		raddr, _ := net.ResolveTCPAddr("tcp", h.Request.RemoteAddr)
   137  		msg.WithRemoteAddr(raddr)
   138  		_, err := opts.Handler.Handle(ctx, emptyBuf)
   139  		if err != nil {
   140  			span.SetAttribute(rpcz.TRPCAttributeError, err)
   141  			log.Errorf("http server transport handle fail:%v", err)
   142  			if err == ErrEncodeMissingHeader {
   143  				w.WriteHeader(500)
   144  			}
   145  			return
   146  		}
   147  	}
   148  
   149  	s, err := t.newHTTPServer(serveFunc, opts)
   150  	if err != nil {
   151  		return err
   152  	}
   153  
   154  	if err := t.serve(ctx, s, opts); err != nil {
   155  		return err
   156  	}
   157  	return nil
   158  }
   159  
   160  func (t *ServerTransport) serve(ctx context.Context, s *stdhttp.Server, opts *transport.ListenServeOptions) error {
   161  	ln := opts.Listener
   162  	if ln == nil {
   163  		var err error
   164  		ln, err = t.getListener(opts.Network, s.Addr)
   165  		if err != nil {
   166  			return fmt.Errorf("http server transport get listener err: %w", err)
   167  		}
   168  	}
   169  
   170  	if err := transport.SaveListener(ln); err != nil {
   171  		return fmt.Errorf("save http listener error: %w", err)
   172  	}
   173  
   174  	if len(opts.TLSKeyFile) != 0 && len(opts.TLSCertFile) != 0 {
   175  		go func() {
   176  			if err := s.ServeTLS(
   177  				tcpKeepAliveListener{ln.(*net.TCPListener)},
   178  				opts.TLSCertFile,
   179  				opts.TLSKeyFile,
   180  			); err != stdhttp.ErrServerClosed {
   181  				log.Errorf("serve TLS failed: %w", err)
   182  			}
   183  		}()
   184  	} else {
   185  		go func() {
   186  			_ = s.Serve(tcpKeepAliveListener{ln.(*net.TCPListener)})
   187  		}()
   188  	}
   189  
   190  	// Reuse ports: Kernel distributes IO ReadReady events to multiple cores and threads to accelerate IO efficiency.
   191  	if t.reusePort {
   192  		go func() {
   193  			<-ctx.Done()
   194  			_ = s.Shutdown(context.TODO())
   195  		}()
   196  	}
   197  	return nil
   198  }
   199  
   200  func (t *ServerTransport) getListener(network, addr string) (net.Listener, error) {
   201  	var ln net.Listener
   202  	v, _ := os.LookupEnv(transport.EnvGraceRestart)
   203  	ok, _ := strconv.ParseBool(v)
   204  	if ok {
   205  		// Find the passed listener.
   206  		pln, err := transport.GetPassedListener(network, addr)
   207  		if err != nil {
   208  			return nil, err
   209  		}
   210  		ln, ok = pln.(net.Listener)
   211  		if !ok {
   212  			return nil, fmt.Errorf("invalid listener type, want net.Listener, got %T", pln)
   213  		}
   214  		return ln, nil
   215  	}
   216  
   217  	if t.reusePort {
   218  		ln, err := reuseport.Listen(network, addr)
   219  		if err != nil {
   220  			return nil, fmt.Errorf("http reuseport listen error:%v", err)
   221  		}
   222  		return ln, nil
   223  	}
   224  
   225  	ln, err := net.Listen(network, addr)
   226  	if err != nil {
   227  		return nil, fmt.Errorf("http listen error:%v", err)
   228  	}
   229  	return ln, nil
   230  }
   231  
   232  // newHTTPServer creates http server.
   233  func (t *ServerTransport) newHTTPServer(
   234  	serveFunc func(w stdhttp.ResponseWriter, r *stdhttp.Request),
   235  	opts *transport.ListenServeOptions,
   236  ) (*stdhttp.Server, error) {
   237  	s := t.newServer()
   238  	s.Addr = opts.Address
   239  	s.Handler = stdhttp.HandlerFunc(serveFunc)
   240  	if t.enableH2C {
   241  		h2s := &http2.Server{}
   242  		s.Handler = h2c.NewHandler(stdhttp.HandlerFunc(serveFunc), h2s)
   243  		return s, nil
   244  	}
   245  	if len(opts.CACertFile) != 0 { // Enable two-way authentication to verify client certificate.
   246  		s.TLSConfig = &tls.Config{
   247  			ClientAuth: tls.RequireAndVerifyClientCert,
   248  		}
   249  		certPool, err := itls.GetCertPool(opts.CACertFile)
   250  		if err != nil {
   251  			return nil, fmt.Errorf("http server get ca cert file error:%v", err)
   252  		}
   253  		s.TLSConfig.ClientCAs = certPool
   254  	}
   255  	if opts.DisableKeepAlives {
   256  		s.SetKeepAlivesEnabled(false)
   257  	}
   258  	if opts.IdleTimeout > 0 {
   259  		s.IdleTimeout = opts.IdleTimeout
   260  	}
   261  	return s, nil
   262  }
   263  
   264  // tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
   265  // connections. It's used by ListenAndServe and ListenAndServeTLS so
   266  // dead TCP connections (e.g. closing laptop mid-download) eventually
   267  // go away.
   268  type tcpKeepAliveListener struct {
   269  	*net.TCPListener
   270  }
   271  
   272  // Accept accepts new request.
   273  func (ln tcpKeepAliveListener) Accept() (net.Conn, error) {
   274  	tc, err := ln.AcceptTCP()
   275  	if err != nil {
   276  		return nil, err
   277  	}
   278  	_ = tc.SetKeepAlive(true)
   279  	_ = tc.SetKeepAlivePeriod(3 * time.Minute)
   280  	return tc, nil
   281  }
   282  
// ClientTransport is the client-side HTTP transport.
//
// Calls that configure no CA file go through the embedded Client. Calls that do configure one
// use a dedicated *stdhttp.Client, built on first use and cached in tlsClients keyed by the
// call's TLS settings.
type ClientTransport struct {
	stdhttp.Client // Shared client for calls without a CA file; exported so users can customize it.
	opts           *transport.ClientTransportOptions // Options, e.g. whether trans-info values are base64-encoded.
	tlsClients     map[string]*stdhttp.Client        // Cached TLS clients, one per distinct TLS configuration.
	tlsLock        sync.RWMutex                      // Guards tlsClients.
	http2Only      bool                              // When set, TLS clients use an HTTP/2-only transport.
}
   291  
   292  // DefaultClientTransport default client http transport.
   293  var DefaultClientTransport = NewClientTransport(false)
   294  
   295  // DefaultHTTP2ClientTransport default client http2 transport.
   296  var DefaultHTTP2ClientTransport = NewClientTransport(true)
   297  
   298  // NewClientTransport creates http transport.
   299  func NewClientTransport(http2Only bool, opt ...transport.ClientTransportOption) transport.ClientTransport {
   300  	opts := &transport.ClientTransportOptions{}
   301  
   302  	// Write func options to field opts.
   303  	for _, o := range opt {
   304  		o(opts)
   305  	}
   306  	return &ClientTransport{
   307  		opts: opts,
   308  		Client: stdhttp.Client{
   309  			Transport: NewRoundTripper(StdHTTPTransport),
   310  		},
   311  		tlsClients: make(map[string]*stdhttp.Client),
   312  		http2Only:  http2Only,
   313  	}
   314  }
   315  
   316  func (ct *ClientTransport) getRequest(reqHeader *ClientReqHeader,
   317  	reqBody []byte, msg codec.Msg, opts *transport.RoundTripOptions) (*stdhttp.Request, error) {
   318  	req, err := ct.newRequest(reqHeader, reqBody, msg, opts)
   319  	if err != nil {
   320  		return nil, err
   321  	}
   322  
   323  	if reqHeader.Header != nil {
   324  		req.Header = make(stdhttp.Header)
   325  		for h, val := range reqHeader.Header {
   326  			req.Header[h] = val
   327  		}
   328  	}
   329  	if len(reqHeader.Host) != 0 {
   330  		req.Host = reqHeader.Host
   331  	}
   332  	req.Header.Set(TrpcCaller, msg.CallerServiceName())
   333  	req.Header.Set(TrpcCallee, msg.CalleeServiceName())
   334  	req.Header.Set(TrpcTimeout, strconv.Itoa(int(msg.RequestTimeout()/time.Millisecond)))
   335  	if opts.DisableConnectionPool {
   336  		req.Header.Set(Connection, "close")
   337  		req.Close = true
   338  	}
   339  	if t := msg.CompressType(); icodec.IsValidCompressType(t) && t != codec.CompressTypeNoop {
   340  		req.Header.Set("Content-Encoding", compressTypeContentEncoding[t])
   341  	}
   342  	if msg.SerializationType() != codec.SerializationTypeNoop {
   343  		if len(req.Header.Get("Content-Type")) == 0 {
   344  			req.Header.Set("Content-Type",
   345  				serializationTypeContentType[msg.SerializationType()])
   346  		}
   347  	}
   348  	if err := ct.setTransInfo(msg, req); err != nil {
   349  		return nil, err
   350  	}
   351  	if len(opts.TLSServerName) == 0 {
   352  		opts.TLSServerName = req.Host
   353  	}
   354  	return req, nil
   355  }
   356  
   357  func (ct *ClientTransport) setTransInfo(msg codec.Msg, req *stdhttp.Request) error {
   358  	var m map[string]string
   359  	if md := msg.ClientMetaData(); len(md) > 0 {
   360  		m = make(map[string]string, len(md))
   361  		for k, v := range md {
   362  			m[k] = ct.encodeBytes(v)
   363  		}
   364  	}
   365  
   366  	// Set dyeing information.
   367  	if msg.Dyeing() {
   368  		if m == nil {
   369  			m = make(map[string]string)
   370  		}
   371  		m[TrpcDyeingKey] = ct.encodeString(msg.DyeingKey())
   372  		req.Header.Set(TrpcMessageType, strconv.Itoa(int(trpcpb.TrpcMessageType_TRPC_DYEING_MESSAGE)))
   373  	}
   374  
   375  	if msg.EnvTransfer() != "" {
   376  		if m == nil {
   377  			m = make(map[string]string)
   378  		}
   379  		m[TrpcEnv] = ct.encodeString(msg.EnvTransfer())
   380  	} else {
   381  		// If msg.EnvTransfer() empty, transmitted env info in req.TransInfo should be cleared
   382  		if _, ok := m[TrpcEnv]; ok {
   383  			m[TrpcEnv] = ""
   384  		}
   385  	}
   386  
   387  	if len(m) > 0 {
   388  		val, err := codec.Marshal(codec.SerializationTypeJSON, m)
   389  		if err != nil {
   390  			return errs.NewFrameError(errs.RetClientValidateFail, "http client json marshal metadata fail: "+err.Error())
   391  		}
   392  		req.Header.Set(TrpcTransInfo, string(val))
   393  	}
   394  
   395  	return nil
   396  }
   397  
   398  func (ct *ClientTransport) newRequest(reqHeader *ClientReqHeader,
   399  	reqBody []byte, msg codec.Msg, opts *transport.RoundTripOptions) (*stdhttp.Request, error) {
   400  	if reqHeader.Request != nil {
   401  		return reqHeader.Request, nil
   402  	}
   403  	scheme := reqHeader.Schema
   404  	if scheme == "" {
   405  		if len(opts.CACertFile) > 0 || strings.HasSuffix(opts.Address, ":443") {
   406  			scheme = "https"
   407  		} else {
   408  			scheme = "http"
   409  		}
   410  	}
   411  
   412  	body := reqHeader.ReqBody
   413  	if body == nil {
   414  		body = bytes.NewReader(reqBody)
   415  	}
   416  
   417  	request, err := stdhttp.NewRequest(
   418  		reqHeader.Method,
   419  		fmt.Sprintf("%s://%s%s", scheme, opts.Address, msg.ClientRPCName()),
   420  		body)
   421  	if err != nil {
   422  		return nil, errs.NewFrameError(errs.RetClientNetErr,
   423  			"http client transport NewRequest: "+err.Error())
   424  	}
   425  	return request, nil
   426  }
   427  
   428  func (ct *ClientTransport) encodeBytes(in []byte) string {
   429  	if ct.opts.DisableHTTPEncodeTransInfoBase64 {
   430  		return string(in)
   431  	}
   432  	return base64.StdEncoding.EncodeToString(in)
   433  }
   434  
   435  func (ct *ClientTransport) encodeString(in string) string {
   436  	if ct.opts.DisableHTTPEncodeTransInfoBase64 {
   437  		return in
   438  	}
   439  	return base64.StdEncoding.EncodeToString([]byte(in))
   440  }
   441  
   442  // RoundTrip sends and receives http packets, put http response into ctx,
   443  // no need to return rspBuf here.
   444  func (ct *ClientTransport) RoundTrip(
   445  	ctx context.Context,
   446  	reqBody []byte,
   447  	callOpts ...transport.RoundTripOption,
   448  ) (rspBody []byte, err error) {
   449  	msg := codec.Message(ctx)
   450  	reqHeader, ok := msg.ClientReqHead().(*ClientReqHeader)
   451  	if !ok {
   452  		return nil, errs.NewFrameError(errs.RetClientEncodeFail,
   453  			"http client transport: ReqHead should be type of *http.ClientReqHeader")
   454  	}
   455  	rspHeader, ok := msg.ClientRspHead().(*ClientRspHeader)
   456  	if !ok {
   457  		return nil, errs.NewFrameError(errs.RetClientEncodeFail,
   458  			"http client transport: RspHead should be type of *http.ClientRspHeader")
   459  	}
   460  
   461  	var opts transport.RoundTripOptions
   462  	for _, o := range callOpts {
   463  		o(&opts)
   464  	}
   465  
   466  	// Sets reqHeader.
   467  	req, err := ct.getRequest(reqHeader, reqBody, msg, &opts)
   468  	if err != nil {
   469  		return nil, err
   470  	}
   471  	trace := &httptrace.ClientTrace{
   472  		ConnectStart: func(network, addr string) {
   473  			tcpAddr, _ := net.ResolveTCPAddr(network, addr)
   474  			msg.WithRemoteAddr(tcpAddr)
   475  		},
   476  	}
   477  	request := req.WithContext(httptrace.WithClientTrace(ctx, trace))
   478  
   479  	client, err := ct.getStdHTTPClient(opts.CACertFile, opts.TLSCertFile,
   480  		opts.TLSKeyFile, opts.TLSServerName)
   481  	if err != nil {
   482  		return nil, err
   483  	}
   484  
   485  	rspHeader.Response, err = client.Do(request)
   486  	if err != nil {
   487  		if e, ok := err.(*url.Error); ok {
   488  			if e.Timeout() {
   489  				return nil, errs.NewFrameError(errs.RetClientTimeout,
   490  					"http client transport RoundTrip timeout: "+err.Error())
   491  			}
   492  		}
   493  		if ctx.Err() == context.Canceled {
   494  			return nil, errs.NewFrameError(errs.RetClientCanceled,
   495  				"http client transport RoundTrip canceled: "+err.Error())
   496  		}
   497  		return nil, errs.NewFrameError(errs.RetClientNetErr,
   498  			"http client transport RoundTrip: "+err.Error())
   499  	}
   500  	return emptyBuf, nil
   501  }
   502  
   503  func (ct *ClientTransport) getStdHTTPClient(caFile, certFile,
   504  	keyFile, serverName string) (*stdhttp.Client, error) {
   505  	if len(caFile) == 0 { // HTTP requests share one client.
   506  		return &ct.Client, nil
   507  	}
   508  
   509  	cacheKey := fmt.Sprintf("%s-%s-%s", caFile, certFile, serverName)
   510  	ct.tlsLock.RLock()
   511  	cli, ok := ct.tlsClients[cacheKey]
   512  	ct.tlsLock.RUnlock()
   513  	if ok {
   514  		return cli, nil
   515  	}
   516  
   517  	ct.tlsLock.Lock()
   518  	defer ct.tlsLock.Unlock()
   519  	cli, ok = ct.tlsClients[cacheKey]
   520  	if ok {
   521  		return cli, nil
   522  	}
   523  
   524  	conf, err := itls.GetClientConfig(serverName, caFile, certFile, keyFile)
   525  	if err != nil {
   526  		return nil, err
   527  	}
   528  	client := &stdhttp.Client{
   529  		CheckRedirect: ct.Client.CheckRedirect,
   530  		Timeout:       ct.Client.Timeout,
   531  	}
   532  	if ct.http2Only {
   533  		client.Transport = &http2.Transport{
   534  			TLSClientConfig: conf,
   535  		}
   536  	} else {
   537  		tr := StdHTTPTransport.Clone()
   538  		tr.TLSClientConfig = conf
   539  		client.Transport = NewRoundTripper(tr)
   540  	}
   541  	ct.tlsClients[cacheKey] = client
   542  	return client, nil
   543  }
   544  
   545  // StdHTTPTransport all RoundTripper object used by http and https.
   546  var StdHTTPTransport = &stdhttp.Transport{
   547  	Proxy: stdhttp.ProxyFromEnvironment,
   548  	DialContext: (&net.Dialer{
   549  		Timeout:   30 * time.Second,
   550  		KeepAlive: 30 * time.Second,
   551  		DualStack: true,
   552  	}).DialContext,
   553  	ForceAttemptHTTP2:     true,
   554  	IdleConnTimeout:       50 * time.Second,
   555  	TLSHandshakeTimeout:   10 * time.Second,
   556  	MaxIdleConnsPerHost:   100,
   557  	DisableCompression:    true,
   558  	ExpectContinueTimeout: time.Second,
   559  }
   560  
// NewRoundTripper builds the round tripper installed on client transports. It wraps:
//   - StdHTTPTransport, for the shared client created by NewClientTransport;
//   - a cloned, TLS-configured transport, for the cached per-CA clients that are not
//     HTTP/2-only.
//
// It is a variable so it can be replaced. Replace it before clients are created: existing
// clients keep the round tripper they were built with (including DefaultClientTransport,
// which is built at package initialization).
// NOTE(review): the default, newValueDetachedTransport, is defined elsewhere; its exact
// wrapping behavior is not visible here.
var NewRoundTripper = newValueDetachedTransport