trpc.group/trpc-go/trpc-go@v1.0.3/http/restful_server_transport.go (about) 1 // 2 // 3 // Tencent is pleased to support the open source community by making tRPC available. 4 // 5 // Copyright (C) 2023 THL A29 Limited, a Tencent company. 6 // All rights reserved. 7 // 8 // If you have downloaded a copy of the tRPC source code from Tencent, 9 // please note that tRPC source code is licensed under the Apache 2.0 License, 10 // A copy of the Apache 2.0 License is included in this file. 11 // 12 // 13 14 package http 15 16 import ( 17 "context" 18 "crypto/tls" 19 "crypto/x509" 20 "errors" 21 "fmt" 22 "net" 23 "net/http" 24 "os" 25 "strconv" 26 "time" 27 28 "github.com/valyala/fasthttp" 29 "trpc.group/trpc-go/trpc-go/internal/reuseport" 30 trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc" 31 32 "trpc.group/trpc-go/trpc-go/codec" 33 "trpc.group/trpc-go/trpc-go/restful" 34 "trpc.group/trpc-go/trpc-go/transport" 35 ) 36 37 var ( 38 // DefaultRESTServerTransport is the default RESTful ServerTransport. 39 DefaultRESTServerTransport = NewRESTServerTransport(false, transport.WithReusePort(true)) 40 41 // DefaultRESTHeaderMatcher is the default REST HeaderMatcher. 42 DefaultRESTHeaderMatcher = func(ctx context.Context, 43 _ http.ResponseWriter, 44 r *http.Request, 45 serviceName, methodName string, 46 ) (context.Context, error) { 47 return putRESTMsgInCtx(ctx, r.Header.Get, serviceName, methodName) 48 } 49 50 // DefaultRESTFastHTTPHeaderMatcher is the default REST FastHTTPHeaderMatcher. 51 DefaultRESTFastHTTPHeaderMatcher = func( 52 ctx context.Context, 53 requestCtx *fasthttp.RequestCtx, 54 serviceName, methodName string, 55 ) (context.Context, error) { 56 headerGetter := func(k string) string { 57 return string(requestCtx.Request.Header.Peek(k)) 58 } 59 return putRESTMsgInCtx(ctx, headerGetter, serviceName, methodName) 60 } 61 62 errReplaceRouter = errors.New("not allow to replace router when is based on fasthttp") 63 ) 64 65 func init() { 66 // Compatible with thttp. 67 restful.SetCtxForCompatibility(func(ctx context.Context, w http.ResponseWriter, 68 r *http.Request) context.Context { 69 return WithHeader(ctx, &Header{Response: w, Request: r}) 70 }) 71 restful.DefaultHeaderMatcher = DefaultRESTHeaderMatcher 72 restful.DefaultFastHTTPHeaderMatcher = DefaultRESTFastHTTPHeaderMatcher 73 transport.RegisterServerTransport("restful", DefaultRESTServerTransport) 74 } 75 76 // putRESTMsgInCtx puts a new codec.Msg, service name and method name in ctx. 77 // Metadata will be extracted from the request header if the header value exists. 78 func putRESTMsgInCtx( 79 ctx context.Context, 80 headerGetter func(string) string, 81 service, method string, 82 ) (context.Context, error) { 83 ctx, msg := codec.WithNewMessage(ctx) 84 msg.WithCalleeServiceName(service) 85 msg.WithServerRPCName(method) 86 msg.WithSerializationType(codec.SerializationTypePB) 87 if v := headerGetter(TrpcTimeout); v != "" { 88 i, _ := strconv.Atoi(v) 89 msg.WithRequestTimeout(time.Millisecond * time.Duration(i)) 90 } 91 if v := headerGetter(TrpcCaller); v != "" { 92 msg.WithCallerServiceName(v) 93 } 94 if v := headerGetter(TrpcMessageType); v != "" { 95 i, _ := strconv.Atoi(v) 96 msg.WithDyeing((int32(i) & int32(trpcpb.TrpcMessageType_TRPC_DYEING_MESSAGE)) != 0) 97 } 98 if v := headerGetter(TrpcTransInfo); v != "" { 99 if _, err := unmarshalTransInfo(msg, v); err != nil { 100 return nil, err 101 } 102 } 103 return ctx, nil 104 } 105 106 // RESTServerTransport is the RESTful ServerTransport. 107 type RESTServerTransport struct { 108 basedOnFastHTTP bool 109 opts *transport.ServerTransportOptions 110 } 111 112 // NewRESTServerTransport creates a RESTful ServerTransport. 113 func NewRESTServerTransport(basedOnFastHTTP bool, opt ...transport.ServerTransportOption) transport.ServerTransport { 114 opts := &transport.ServerTransportOptions{ 115 IdleTimeout: time.Minute, 116 } 117 118 for _, o := range opt { 119 o(opts) 120 } 121 122 return &RESTServerTransport{ 123 basedOnFastHTTP: basedOnFastHTTP, 124 opts: opts, 125 } 126 } 127 128 // ListenAndServe implements interface of transport.ServerTransport. 129 func (st *RESTServerTransport) ListenAndServe(ctx context.Context, opt ...transport.ListenServeOption) error { 130 opts := &transport.ListenServeOptions{ 131 Network: "tcp", 132 } 133 for _, o := range opt { 134 o(opts) 135 } 136 // Get listener. 137 ln := opts.Listener 138 if ln == nil { 139 var err error 140 ln, err = st.getListener(opts) 141 if err != nil { 142 return fmt.Errorf("restfull server transport get listener err: %w", err) 143 } 144 } 145 // Save listener. 146 if err := transport.SaveListener(ln); err != nil { 147 return fmt.Errorf("save restful listener error: %w", err) 148 } 149 // Convert to tcpKeepAliveListener. 150 if tcpln, ok := ln.(*net.TCPListener); ok { 151 ln = tcpKeepAliveListener{tcpln} 152 } 153 // Config tls. 154 if len(opts.TLSKeyFile) != 0 && len(opts.TLSCertFile) != 0 { 155 tlsConf, err := generateTLSConfig(opts) 156 if err != nil { 157 return err 158 } 159 ln = tls.NewListener(ln, tlsConf) 160 } 161 162 go func() { 163 <-opts.StopListening 164 ln.Close() 165 }() 166 167 return st.serve(ctx, ln, opts) 168 } 169 170 // serve starts service. 171 func (st *RESTServerTransport) serve( 172 ctx context.Context, 173 ln net.Listener, 174 opts *transport.ListenServeOptions, 175 ) error { 176 // Get router. 177 router := restful.GetRouter(opts.ServiceName) 178 if router == nil { 179 return fmt.Errorf("service %s router not registered", opts.ServiceName) 180 } 181 182 if st.basedOnFastHTTP { // Based on fasthttp. 183 r, ok := router.(*restful.Router) 184 if !ok { 185 return errReplaceRouter 186 } 187 server := &fasthttp.Server{Handler: r.HandleRequestCtx} 188 go func() { 189 _ = server.Serve(ln) 190 }() 191 if st.opts.ReusePort { 192 go func() { 193 <-ctx.Done() 194 _ = server.Shutdown() 195 }() 196 } 197 return nil 198 } 199 // Based on net/http. 200 server := &http.Server{Addr: opts.Address, Handler: router} 201 go func() { 202 _ = server.Serve(ln) 203 }() 204 if st.opts.ReusePort { 205 go func() { 206 <-ctx.Done() 207 _ = server.Shutdown(context.TODO()) 208 }() 209 } 210 return nil 211 } 212 213 // getListener gets listener. 214 func (st *RESTServerTransport) getListener(opts *transport.ListenServeOptions) (net.Listener, error) { 215 var err error 216 var ln net.Listener 217 218 v, _ := os.LookupEnv(transport.EnvGraceRestart) 219 ok, _ := strconv.ParseBool(v) 220 if ok { 221 // Find the passed listener. 222 pln, err := transport.GetPassedListener(opts.Network, opts.Address) 223 if err != nil { 224 return nil, err 225 } 226 227 ln, ok = pln.(net.Listener) 228 if !ok { 229 return nil, errors.New("invalid net.Listener") 230 } 231 232 return ln, nil 233 } 234 235 if st.opts.ReusePort { 236 ln, err = reuseport.Listen(opts.Network, opts.Address) 237 if err != nil { 238 return nil, fmt.Errorf("restful reuseport listen error: %w", err) 239 } 240 } else { 241 ln, err = net.Listen(opts.Network, opts.Address) 242 if err != nil { 243 return nil, fmt.Errorf("restful listen error: %w", err) 244 } 245 } 246 247 return ln, nil 248 } 249 250 // generateTLSConfig generates config of tls. 251 func generateTLSConfig(opts *transport.ListenServeOptions) (*tls.Config, error) { 252 tlsConf := &tls.Config{} 253 254 cert, err := tls.LoadX509KeyPair(opts.TLSCertFile, opts.TLSKeyFile) 255 if err != nil { 256 return nil, err 257 } 258 tlsConf.Certificates = []tls.Certificate{cert} 259 260 // Two-way authentication. 261 if opts.CACertFile != "" { 262 tlsConf.ClientAuth = tls.RequireAndVerifyClientCert 263 if opts.CACertFile != "root" { 264 ca, err := os.ReadFile(opts.CACertFile) 265 if err != nil { 266 return nil, err 267 } 268 pool := x509.NewCertPool() 269 ok := pool.AppendCertsFromPEM(ca) 270 if !ok { 271 return nil, errors.New("failed to append certs from pem") 272 } 273 tlsConf.ClientCAs = pool 274 } 275 } 276 277 return tlsConf, nil 278 }