trpc.group/trpc-go/trpc-go@v1.0.2/http/restful_server_transport.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package http
    15  
    16  import (
    17  	"context"
    18  	"crypto/tls"
    19  	"crypto/x509"
    20  	"errors"
    21  	"fmt"
    22  	"net"
    23  	"net/http"
    24  	"os"
    25  	"strconv"
    26  	"time"
    27  
    28  	"github.com/valyala/fasthttp"
    29  	"trpc.group/trpc-go/trpc-go/internal/reuseport"
    30  	trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc"
    31  
    32  	"trpc.group/trpc-go/trpc-go/codec"
    33  	"trpc.group/trpc-go/trpc-go/restful"
    34  	"trpc.group/trpc-go/trpc-go/transport"
    35  )
    36  
    37  var (
    38  	// DefaultRESTServerTransport is the default RESTful ServerTransport.
    39  	DefaultRESTServerTransport = NewRESTServerTransport(false, transport.WithReusePort(true))
    40  
    41  	// DefaultRESTHeaderMatcher is the default REST HeaderMatcher.
    42  	DefaultRESTHeaderMatcher = func(ctx context.Context,
    43  		_ http.ResponseWriter,
    44  		r *http.Request,
    45  		serviceName, methodName string,
    46  	) (context.Context, error) {
    47  		return putRESTMsgInCtx(ctx, r.Header.Get, serviceName, methodName)
    48  	}
    49  
    50  	// DefaultRESTFastHTTPHeaderMatcher is the default REST FastHTTPHeaderMatcher.
    51  	DefaultRESTFastHTTPHeaderMatcher = func(
    52  		ctx context.Context,
    53  		requestCtx *fasthttp.RequestCtx,
    54  		serviceName, methodName string,
    55  	) (context.Context, error) {
    56  		headerGetter := func(k string) string {
    57  			return string(requestCtx.Request.Header.Peek(k))
    58  		}
    59  		return putRESTMsgInCtx(ctx, headerGetter, serviceName, methodName)
    60  	}
    61  
    62  	errReplaceRouter = errors.New("not allow to replace router when is based on fasthttp")
    63  )
    64  
    65  func init() {
    66  	// Compatible with thttp.
    67  	restful.SetCtxForCompatibility(func(ctx context.Context, w http.ResponseWriter,
    68  		r *http.Request) context.Context {
    69  		return WithHeader(ctx, &Header{Response: w, Request: r})
    70  	})
    71  	restful.DefaultHeaderMatcher = DefaultRESTHeaderMatcher
    72  	restful.DefaultFastHTTPHeaderMatcher = DefaultRESTFastHTTPHeaderMatcher
    73  	transport.RegisterServerTransport("restful", DefaultRESTServerTransport)
    74  }
    75  
    76  // putRESTMsgInCtx puts a new codec.Msg, service name and method name in ctx.
    77  // Metadata will be extracted from the request header if the header value exists.
    78  func putRESTMsgInCtx(
    79  	ctx context.Context,
    80  	headerGetter func(string) string,
    81  	service, method string,
    82  ) (context.Context, error) {
    83  	ctx, msg := codec.WithNewMessage(ctx)
    84  	msg.WithCalleeServiceName(service)
    85  	msg.WithServerRPCName(method)
    86  	msg.WithCalleeMethod(method)
    87  	msg.WithSerializationType(codec.SerializationTypePB)
    88  	if v := headerGetter(TrpcTimeout); v != "" {
    89  		i, _ := strconv.Atoi(v)
    90  		msg.WithRequestTimeout(time.Millisecond * time.Duration(i))
    91  	}
    92  	if v := headerGetter(TrpcCaller); v != "" {
    93  		msg.WithCallerServiceName(v)
    94  	}
    95  	if v := headerGetter(TrpcMessageType); v != "" {
    96  		i, _ := strconv.Atoi(v)
    97  		msg.WithDyeing((int32(i) & int32(trpcpb.TrpcMessageType_TRPC_DYEING_MESSAGE)) != 0)
    98  	}
    99  	if v := headerGetter(TrpcTransInfo); v != "" {
   100  		if _, err := unmarshalTransInfo(msg, v); err != nil {
   101  			return nil, err
   102  		}
   103  	}
   104  	return ctx, nil
   105  }
   106  
// RESTServerTransport is the RESTful ServerTransport.
// It serves a registered restful router over either net/http or fasthttp.
type RESTServerTransport struct {
	// basedOnFastHTTP selects the fasthttp server in serve; when false, net/http is used.
	basedOnFastHTTP bool
	// opts holds the transport options collected by NewRESTServerTransport;
	// its ReusePort field is read when creating the listener and when serving.
	opts            *transport.ServerTransportOptions
}
   112  
   113  // NewRESTServerTransport creates a RESTful ServerTransport.
   114  func NewRESTServerTransport(basedOnFastHTTP bool, opt ...transport.ServerTransportOption) transport.ServerTransport {
   115  	opts := &transport.ServerTransportOptions{
   116  		IdleTimeout: time.Minute,
   117  	}
   118  
   119  	for _, o := range opt {
   120  		o(opts)
   121  	}
   122  
   123  	return &RESTServerTransport{
   124  		basedOnFastHTTP: basedOnFastHTTP,
   125  		opts:            opts,
   126  	}
   127  }
   128  
   129  // ListenAndServe implements interface of transport.ServerTransport.
   130  func (st *RESTServerTransport) ListenAndServe(ctx context.Context, opt ...transport.ListenServeOption) error {
   131  	opts := &transport.ListenServeOptions{
   132  		Network: "tcp",
   133  	}
   134  	for _, o := range opt {
   135  		o(opts)
   136  	}
   137  	// Get listener.
   138  	ln := opts.Listener
   139  	if ln == nil {
   140  		var err error
   141  		ln, err = st.getListener(opts)
   142  		if err != nil {
   143  			return fmt.Errorf("restfull server transport get listener err: %w", err)
   144  		}
   145  	}
   146  	// Save listener.
   147  	if err := transport.SaveListener(ln); err != nil {
   148  		return fmt.Errorf("save restful listener error: %w", err)
   149  	}
   150  	// Convert to tcpKeepAliveListener.
   151  	if tcpln, ok := ln.(*net.TCPListener); ok {
   152  		ln = tcpKeepAliveListener{tcpln}
   153  	}
   154  	// Config tls.
   155  	if len(opts.TLSKeyFile) != 0 && len(opts.TLSCertFile) != 0 {
   156  		tlsConf, err := generateTLSConfig(opts)
   157  		if err != nil {
   158  			return err
   159  		}
   160  		ln = tls.NewListener(ln, tlsConf)
   161  	}
   162  
   163  	return st.serve(ctx, ln, opts)
   164  }
   165  
   166  // serve starts service.
   167  func (st *RESTServerTransport) serve(ctx context.Context, ln net.Listener,
   168  	opts *transport.ListenServeOptions) error {
   169  	// Get router.
   170  	router := restful.GetRouter(opts.ServiceName)
   171  	if router == nil {
   172  		return fmt.Errorf("service %s router not registered", opts.ServiceName)
   173  	}
   174  
   175  	if st.basedOnFastHTTP { // Based on fasthttp.
   176  		r, ok := router.(*restful.Router)
   177  		if !ok {
   178  			return errReplaceRouter
   179  		}
   180  		server := &fasthttp.Server{Handler: r.HandleRequestCtx}
   181  		go func() {
   182  			_ = server.Serve(ln)
   183  		}()
   184  		if st.opts.ReusePort {
   185  			go func() {
   186  				<-ctx.Done()
   187  				_ = server.Shutdown()
   188  			}()
   189  		}
   190  		return nil
   191  	}
   192  	// Based on net/http.
   193  	server := &http.Server{Addr: opts.Address, Handler: router}
   194  	go func() {
   195  		_ = server.Serve(ln)
   196  	}()
   197  	if st.opts.ReusePort {
   198  		go func() {
   199  			<-ctx.Done()
   200  			_ = server.Shutdown(context.TODO())
   201  		}()
   202  	}
   203  	return nil
   204  }
   205  
   206  // getListener gets listener.
   207  func (st *RESTServerTransport) getListener(opts *transport.ListenServeOptions) (net.Listener, error) {
   208  	var err error
   209  	var ln net.Listener
   210  
   211  	v, _ := os.LookupEnv(transport.EnvGraceRestart)
   212  	ok, _ := strconv.ParseBool(v)
   213  	if ok {
   214  		// Find the passed listener.
   215  		pln, err := transport.GetPassedListener(opts.Network, opts.Address)
   216  		if err != nil {
   217  			return nil, err
   218  		}
   219  
   220  		ln, ok = pln.(net.Listener)
   221  		if !ok {
   222  			return nil, errors.New("invalid net.Listener")
   223  		}
   224  
   225  		return ln, nil
   226  	}
   227  
   228  	if st.opts.ReusePort {
   229  		ln, err = reuseport.Listen(opts.Network, opts.Address)
   230  		if err != nil {
   231  			return nil, fmt.Errorf("restful reuseport listen error: %w", err)
   232  		}
   233  	} else {
   234  		ln, err = net.Listen(opts.Network, opts.Address)
   235  		if err != nil {
   236  			return nil, fmt.Errorf("restful listen error: %w", err)
   237  		}
   238  	}
   239  
   240  	return ln, nil
   241  }
   242  
   243  // generateTLSConfig generates config of tls.
   244  func generateTLSConfig(opts *transport.ListenServeOptions) (*tls.Config, error) {
   245  	tlsConf := &tls.Config{}
   246  
   247  	cert, err := tls.LoadX509KeyPair(opts.TLSCertFile, opts.TLSKeyFile)
   248  	if err != nil {
   249  		return nil, err
   250  	}
   251  	tlsConf.Certificates = []tls.Certificate{cert}
   252  
   253  	// Two-way authentication.
   254  	if opts.CACertFile != "" {
   255  		tlsConf.ClientAuth = tls.RequireAndVerifyClientCert
   256  		if opts.CACertFile != "root" {
   257  			ca, err := os.ReadFile(opts.CACertFile)
   258  			if err != nil {
   259  				return nil, err
   260  			}
   261  			pool := x509.NewCertPool()
   262  			ok := pool.AppendCertsFromPEM(ca)
   263  			if !ok {
   264  				return nil, errors.New("failed to append certs from pem")
   265  			}
   266  			tlsConf.ClientCAs = pool
   267  		}
   268  	}
   269  
   270  	return tlsConf, nil
   271  }