trpc.group/trpc-go/trpc-go@v1.0.2/http/restful_server_transport.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package http 15 16 import ( 17 "context" 18 "crypto/tls" 19 "crypto/x509" 20 "errors" 21 "fmt" 22 "net" 23 "net/http" 24 "os" 25 "strconv" 26 "time" 27 28 "github.com/valyala/fasthttp" 29 "trpc.group/trpc-go/trpc-go/internal/reuseport" 30 trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc" 31 32 "trpc.group/trpc-go/trpc-go/codec" 33 "trpc.group/trpc-go/trpc-go/restful" 34 "trpc.group/trpc-go/trpc-go/transport" 35 ) 36 37 var ( 38 // DefaultRESTServerTransport is the default RESTful ServerTransport. 39 DefaultRESTServerTransport = NewRESTServerTransport(false, transport.WithReusePort(true)) 40 41 // DefaultRESTHeaderMatcher is the default REST HeaderMatcher. 42 DefaultRESTHeaderMatcher = func(ctx context.Context, 43 _ http.ResponseWriter, 44 r *http.Request, 45 serviceName, methodName string, 46 ) (context.Context, error) { 47 return putRESTMsgInCtx(ctx, r.Header.Get, serviceName, methodName) 48 } 49 50 // DefaultRESTFastHTTPHeaderMatcher is the default REST FastHTTPHeaderMatcher. 51 DefaultRESTFastHTTPHeaderMatcher = func( 52 ctx context.Context, 53 requestCtx *fasthttp.RequestCtx, 54 serviceName, methodName string, 55 ) (context.Context, error) { 56 headerGetter := func(k string) string { 57 return string(requestCtx.Request.Header.Peek(k)) 58 } 59 return putRESTMsgInCtx(ctx, headerGetter, serviceName, methodName) 60 } 61 62 errReplaceRouter = errors.New("not allow to replace router when is based on fasthttp") 63 ) 64 65 func init() { 66 // Compatible with thttp. 67 restful.SetCtxForCompatibility(func(ctx context.Context, w http.ResponseWriter, 68 r *http.Request) context.Context { 69 return WithHeader(ctx, &Header{Response: w, Request: r}) 70 }) 71 restful.DefaultHeaderMatcher = DefaultRESTHeaderMatcher 72 restful.DefaultFastHTTPHeaderMatcher = DefaultRESTFastHTTPHeaderMatcher 73 transport.RegisterServerTransport("restful", DefaultRESTServerTransport) 74 } 75 76 // putRESTMsgInCtx puts a new codec.Msg, service name and method name in ctx. 77 // Metadata will be extracted from the request header if the header value exists. 78 func putRESTMsgInCtx( 79 ctx context.Context, 80 headerGetter func(string) string, 81 service, method string, 82 ) (context.Context, error) { 83 ctx, msg := codec.WithNewMessage(ctx) 84 msg.WithCalleeServiceName(service) 85 msg.WithServerRPCName(method) 86 msg.WithCalleeMethod(method) 87 msg.WithSerializationType(codec.SerializationTypePB) 88 if v := headerGetter(TrpcTimeout); v != "" { 89 i, _ := strconv.Atoi(v) 90 msg.WithRequestTimeout(time.Millisecond * time.Duration(i)) 91 } 92 if v := headerGetter(TrpcCaller); v != "" { 93 msg.WithCallerServiceName(v) 94 } 95 if v := headerGetter(TrpcMessageType); v != "" { 96 i, _ := strconv.Atoi(v) 97 msg.WithDyeing((int32(i) & int32(trpcpb.TrpcMessageType_TRPC_DYEING_MESSAGE)) != 0) 98 } 99 if v := headerGetter(TrpcTransInfo); v != "" { 100 if _, err := unmarshalTransInfo(msg, v); err != nil { 101 return nil, err 102 } 103 } 104 return ctx, nil 105 } 106 107 // RESTServerTransport is the RESTful ServerTransport. 108 type RESTServerTransport struct { 109 basedOnFastHTTP bool 110 opts *transport.ServerTransportOptions 111 } 112 113 // NewRESTServerTransport creates a RESTful ServerTransport. 114 func NewRESTServerTransport(basedOnFastHTTP bool, opt ...transport.ServerTransportOption) transport.ServerTransport { 115 opts := &transport.ServerTransportOptions{ 116 IdleTimeout: time.Minute, 117 } 118 119 for _, o := range opt { 120 o(opts) 121 } 122 123 return &RESTServerTransport{ 124 basedOnFastHTTP: basedOnFastHTTP, 125 opts: opts, 126 } 127 } 128 129 // ListenAndServe implements interface of transport.ServerTransport. 130 func (st *RESTServerTransport) ListenAndServe(ctx context.Context, opt ...transport.ListenServeOption) error { 131 opts := &transport.ListenServeOptions{ 132 Network: "tcp", 133 } 134 for _, o := range opt { 135 o(opts) 136 } 137 // Get listener. 138 ln := opts.Listener 139 if ln == nil { 140 var err error 141 ln, err = st.getListener(opts) 142 if err != nil { 143 return fmt.Errorf("restfull server transport get listener err: %w", err) 144 } 145 } 146 // Save listener. 147 if err := transport.SaveListener(ln); err != nil { 148 return fmt.Errorf("save restful listener error: %w", err) 149 } 150 // Convert to tcpKeepAliveListener. 151 if tcpln, ok := ln.(*net.TCPListener); ok { 152 ln = tcpKeepAliveListener{tcpln} 153 } 154 // Config tls. 155 if len(opts.TLSKeyFile) != 0 && len(opts.TLSCertFile) != 0 { 156 tlsConf, err := generateTLSConfig(opts) 157 if err != nil { 158 return err 159 } 160 ln = tls.NewListener(ln, tlsConf) 161 } 162 163 return st.serve(ctx, ln, opts) 164 } 165 166 // serve starts service. 167 func (st *RESTServerTransport) serve(ctx context.Context, ln net.Listener, 168 opts *transport.ListenServeOptions) error { 169 // Get router. 170 router := restful.GetRouter(opts.ServiceName) 171 if router == nil { 172 return fmt.Errorf("service %s router not registered", opts.ServiceName) 173 } 174 175 if st.basedOnFastHTTP { // Based on fasthttp. 176 r, ok := router.(*restful.Router) 177 if !ok { 178 return errReplaceRouter 179 } 180 server := &fasthttp.Server{Handler: r.HandleRequestCtx} 181 go func() { 182 _ = server.Serve(ln) 183 }() 184 if st.opts.ReusePort { 185 go func() { 186 <-ctx.Done() 187 _ = server.Shutdown() 188 }() 189 } 190 return nil 191 } 192 // Based on net/http. 193 server := &http.Server{Addr: opts.Address, Handler: router} 194 go func() { 195 _ = server.Serve(ln) 196 }() 197 if st.opts.ReusePort { 198 go func() { 199 <-ctx.Done() 200 _ = server.Shutdown(context.TODO()) 201 }() 202 } 203 return nil 204 } 205 206 // getListener gets listener. 207 func (st *RESTServerTransport) getListener(opts *transport.ListenServeOptions) (net.Listener, error) { 208 var err error 209 var ln net.Listener 210 211 v, _ := os.LookupEnv(transport.EnvGraceRestart) 212 ok, _ := strconv.ParseBool(v) 213 if ok { 214 // Find the passed listener. 215 pln, err := transport.GetPassedListener(opts.Network, opts.Address) 216 if err != nil { 217 return nil, err 218 } 219 220 ln, ok = pln.(net.Listener) 221 if !ok { 222 return nil, errors.New("invalid net.Listener") 223 } 224 225 return ln, nil 226 } 227 228 if st.opts.ReusePort { 229 ln, err = reuseport.Listen(opts.Network, opts.Address) 230 if err != nil { 231 return nil, fmt.Errorf("restful reuseport listen error: %w", err) 232 } 233 } else { 234 ln, err = net.Listen(opts.Network, opts.Address) 235 if err != nil { 236 return nil, fmt.Errorf("restful listen error: %w", err) 237 } 238 } 239 240 return ln, nil 241 } 242 243 // generateTLSConfig generates config of tls. 244 func generateTLSConfig(opts *transport.ListenServeOptions) (*tls.Config, error) { 245 tlsConf := &tls.Config{} 246 247 cert, err := tls.LoadX509KeyPair(opts.TLSCertFile, opts.TLSKeyFile) 248 if err != nil { 249 return nil, err 250 } 251 tlsConf.Certificates = []tls.Certificate{cert} 252 253 // Two-way authentication. 254 if opts.CACertFile != "" { 255 tlsConf.ClientAuth = tls.RequireAndVerifyClientCert 256 if opts.CACertFile != "root" { 257 ca, err := os.ReadFile(opts.CACertFile) 258 if err != nil { 259 return nil, err 260 } 261 pool := x509.NewCertPool() 262 ok := pool.AppendCertsFromPEM(ca) 263 if !ok { 264 return nil, errors.New("failed to append certs from pem") 265 } 266 tlsConf.ClientCAs = pool 267 } 268 } 269 270 return tlsConf, nil 271 }