trpc.group/trpc-go/trpc-go@v1.0.3/http/restful_server_transport.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package http
    15  
    16  import (
    17  	"context"
    18  	"crypto/tls"
    19  	"crypto/x509"
    20  	"errors"
    21  	"fmt"
    22  	"net"
    23  	"net/http"
    24  	"os"
    25  	"strconv"
    26  	"time"
    27  
    28  	"github.com/valyala/fasthttp"
    29  	"trpc.group/trpc-go/trpc-go/internal/reuseport"
    30  	trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc"
    31  
    32  	"trpc.group/trpc-go/trpc-go/codec"
    33  	"trpc.group/trpc-go/trpc-go/restful"
    34  	"trpc.group/trpc-go/trpc-go/transport"
    35  )
    36  
    37  var (
    38  	// DefaultRESTServerTransport is the default RESTful ServerTransport.
    39  	DefaultRESTServerTransport = NewRESTServerTransport(false, transport.WithReusePort(true))
    40  
    41  	// DefaultRESTHeaderMatcher is the default REST HeaderMatcher.
    42  	DefaultRESTHeaderMatcher = func(ctx context.Context,
    43  		_ http.ResponseWriter,
    44  		r *http.Request,
    45  		serviceName, methodName string,
    46  	) (context.Context, error) {
    47  		return putRESTMsgInCtx(ctx, r.Header.Get, serviceName, methodName)
    48  	}
    49  
    50  	// DefaultRESTFastHTTPHeaderMatcher is the default REST FastHTTPHeaderMatcher.
    51  	DefaultRESTFastHTTPHeaderMatcher = func(
    52  		ctx context.Context,
    53  		requestCtx *fasthttp.RequestCtx,
    54  		serviceName, methodName string,
    55  	) (context.Context, error) {
    56  		headerGetter := func(k string) string {
    57  			return string(requestCtx.Request.Header.Peek(k))
    58  		}
    59  		return putRESTMsgInCtx(ctx, headerGetter, serviceName, methodName)
    60  	}
    61  
    62  	errReplaceRouter = errors.New("not allow to replace router when is based on fasthttp")
    63  )
    64  
    65  func init() {
    66  	// Compatible with thttp.
    67  	restful.SetCtxForCompatibility(func(ctx context.Context, w http.ResponseWriter,
    68  		r *http.Request) context.Context {
    69  		return WithHeader(ctx, &Header{Response: w, Request: r})
    70  	})
    71  	restful.DefaultHeaderMatcher = DefaultRESTHeaderMatcher
    72  	restful.DefaultFastHTTPHeaderMatcher = DefaultRESTFastHTTPHeaderMatcher
    73  	transport.RegisterServerTransport("restful", DefaultRESTServerTransport)
    74  }
    75  
    76  // putRESTMsgInCtx puts a new codec.Msg, service name and method name in ctx.
    77  // Metadata will be extracted from the request header if the header value exists.
    78  func putRESTMsgInCtx(
    79  	ctx context.Context,
    80  	headerGetter func(string) string,
    81  	service, method string,
    82  ) (context.Context, error) {
    83  	ctx, msg := codec.WithNewMessage(ctx)
    84  	msg.WithCalleeServiceName(service)
    85  	msg.WithServerRPCName(method)
    86  	msg.WithSerializationType(codec.SerializationTypePB)
    87  	if v := headerGetter(TrpcTimeout); v != "" {
    88  		i, _ := strconv.Atoi(v)
    89  		msg.WithRequestTimeout(time.Millisecond * time.Duration(i))
    90  	}
    91  	if v := headerGetter(TrpcCaller); v != "" {
    92  		msg.WithCallerServiceName(v)
    93  	}
    94  	if v := headerGetter(TrpcMessageType); v != "" {
    95  		i, _ := strconv.Atoi(v)
    96  		msg.WithDyeing((int32(i) & int32(trpcpb.TrpcMessageType_TRPC_DYEING_MESSAGE)) != 0)
    97  	}
    98  	if v := headerGetter(TrpcTransInfo); v != "" {
    99  		if _, err := unmarshalTransInfo(msg, v); err != nil {
   100  			return nil, err
   101  		}
   102  	}
   103  	return ctx, nil
   104  }
   105  
   106  // RESTServerTransport is the RESTful ServerTransport.
   107  type RESTServerTransport struct {
   108  	basedOnFastHTTP bool
   109  	opts            *transport.ServerTransportOptions
   110  }
   111  
   112  // NewRESTServerTransport creates a RESTful ServerTransport.
   113  func NewRESTServerTransport(basedOnFastHTTP bool, opt ...transport.ServerTransportOption) transport.ServerTransport {
   114  	opts := &transport.ServerTransportOptions{
   115  		IdleTimeout: time.Minute,
   116  	}
   117  
   118  	for _, o := range opt {
   119  		o(opts)
   120  	}
   121  
   122  	return &RESTServerTransport{
   123  		basedOnFastHTTP: basedOnFastHTTP,
   124  		opts:            opts,
   125  	}
   126  }
   127  
   128  // ListenAndServe implements interface of transport.ServerTransport.
   129  func (st *RESTServerTransport) ListenAndServe(ctx context.Context, opt ...transport.ListenServeOption) error {
   130  	opts := &transport.ListenServeOptions{
   131  		Network: "tcp",
   132  	}
   133  	for _, o := range opt {
   134  		o(opts)
   135  	}
   136  	// Get listener.
   137  	ln := opts.Listener
   138  	if ln == nil {
   139  		var err error
   140  		ln, err = st.getListener(opts)
   141  		if err != nil {
   142  			return fmt.Errorf("restfull server transport get listener err: %w", err)
   143  		}
   144  	}
   145  	// Save listener.
   146  	if err := transport.SaveListener(ln); err != nil {
   147  		return fmt.Errorf("save restful listener error: %w", err)
   148  	}
   149  	// Convert to tcpKeepAliveListener.
   150  	if tcpln, ok := ln.(*net.TCPListener); ok {
   151  		ln = tcpKeepAliveListener{tcpln}
   152  	}
   153  	// Config tls.
   154  	if len(opts.TLSKeyFile) != 0 && len(opts.TLSCertFile) != 0 {
   155  		tlsConf, err := generateTLSConfig(opts)
   156  		if err != nil {
   157  			return err
   158  		}
   159  		ln = tls.NewListener(ln, tlsConf)
   160  	}
   161  
   162  	go func() {
   163  		<-opts.StopListening
   164  		ln.Close()
   165  	}()
   166  
   167  	return st.serve(ctx, ln, opts)
   168  }
   169  
   170  // serve starts service.
   171  func (st *RESTServerTransport) serve(
   172  	ctx context.Context,
   173  	ln net.Listener,
   174  	opts *transport.ListenServeOptions,
   175  ) error {
   176  	// Get router.
   177  	router := restful.GetRouter(opts.ServiceName)
   178  	if router == nil {
   179  		return fmt.Errorf("service %s router not registered", opts.ServiceName)
   180  	}
   181  
   182  	if st.basedOnFastHTTP { // Based on fasthttp.
   183  		r, ok := router.(*restful.Router)
   184  		if !ok {
   185  			return errReplaceRouter
   186  		}
   187  		server := &fasthttp.Server{Handler: r.HandleRequestCtx}
   188  		go func() {
   189  			_ = server.Serve(ln)
   190  		}()
   191  		if st.opts.ReusePort {
   192  			go func() {
   193  				<-ctx.Done()
   194  				_ = server.Shutdown()
   195  			}()
   196  		}
   197  		return nil
   198  	}
   199  	// Based on net/http.
   200  	server := &http.Server{Addr: opts.Address, Handler: router}
   201  	go func() {
   202  		_ = server.Serve(ln)
   203  	}()
   204  	if st.opts.ReusePort {
   205  		go func() {
   206  			<-ctx.Done()
   207  			_ = server.Shutdown(context.TODO())
   208  		}()
   209  	}
   210  	return nil
   211  }
   212  
   213  // getListener gets listener.
   214  func (st *RESTServerTransport) getListener(opts *transport.ListenServeOptions) (net.Listener, error) {
   215  	var err error
   216  	var ln net.Listener
   217  
   218  	v, _ := os.LookupEnv(transport.EnvGraceRestart)
   219  	ok, _ := strconv.ParseBool(v)
   220  	if ok {
   221  		// Find the passed listener.
   222  		pln, err := transport.GetPassedListener(opts.Network, opts.Address)
   223  		if err != nil {
   224  			return nil, err
   225  		}
   226  
   227  		ln, ok = pln.(net.Listener)
   228  		if !ok {
   229  			return nil, errors.New("invalid net.Listener")
   230  		}
   231  
   232  		return ln, nil
   233  	}
   234  
   235  	if st.opts.ReusePort {
   236  		ln, err = reuseport.Listen(opts.Network, opts.Address)
   237  		if err != nil {
   238  			return nil, fmt.Errorf("restful reuseport listen error: %w", err)
   239  		}
   240  	} else {
   241  		ln, err = net.Listen(opts.Network, opts.Address)
   242  		if err != nil {
   243  			return nil, fmt.Errorf("restful listen error: %w", err)
   244  		}
   245  	}
   246  
   247  	return ln, nil
   248  }
   249  
   250  // generateTLSConfig generates config of tls.
   251  func generateTLSConfig(opts *transport.ListenServeOptions) (*tls.Config, error) {
   252  	tlsConf := &tls.Config{}
   253  
   254  	cert, err := tls.LoadX509KeyPair(opts.TLSCertFile, opts.TLSKeyFile)
   255  	if err != nil {
   256  		return nil, err
   257  	}
   258  	tlsConf.Certificates = []tls.Certificate{cert}
   259  
   260  	// Two-way authentication.
   261  	if opts.CACertFile != "" {
   262  		tlsConf.ClientAuth = tls.RequireAndVerifyClientCert
   263  		if opts.CACertFile != "root" {
   264  			ca, err := os.ReadFile(opts.CACertFile)
   265  			if err != nil {
   266  				return nil, err
   267  			}
   268  			pool := x509.NewCertPool()
   269  			ok := pool.AppendCertsFromPEM(ca)
   270  			if !ok {
   271  				return nil, errors.New("failed to append certs from pem")
   272  			}
   273  			tlsConf.ClientCAs = pool
   274  		}
   275  	}
   276  
   277  	return tlsConf, nil
   278  }