trpc.group/trpc-go/trpc-go@v1.0.2/restful/router.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package restful
    15  
    16  import (
    17  	"context"
    18  	"fmt"
    19  	"io"
    20  	"net/http"
    21  	"strconv"
    22  	"strings"
    23  	"sync"
    24  
    25  	"google.golang.org/protobuf/proto"
    26  	"google.golang.org/protobuf/types/known/emptypb"
    27  
    28  	"trpc.group/trpc-go/trpc-go/codec"
    29  	"trpc.group/trpc-go/trpc-go/errs"
    30  	"trpc.group/trpc-go/trpc-go/filter"
    31  	"trpc.group/trpc-go/trpc-go/internal/dat"
    32  )
    33  
    34  // Router is restful router.
    35  type Router struct {
    36  	opts        *Options
    37  	transcoders map[string][]*transcoder
    38  }
    39  
    40  // NewRouter creates a Router.
    41  func NewRouter(opts ...Option) *Router {
    42  	o := Options{
    43  		ErrorHandler:          DefaultErrorHandler,
    44  		HeaderMatcher:         DefaultHeaderMatcher,
    45  		ResponseHandler:       DefaultResponseHandler,
    46  		FastHTTPErrHandler:    DefaultFastHTTPErrorHandler,
    47  		FastHTTPHeaderMatcher: DefaultFastHTTPHeaderMatcher,
    48  		FastHTTPRespHandler:   DefaultFastHTTPRespHandler,
    49  	}
    50  	for _, opt := range opts {
    51  		opt(&o)
    52  	}
    53  	o.rebuildHeaderMatcher()
    54  
    55  	return &Router{
    56  		opts:        &o,
    57  		transcoders: make(map[string][]*transcoder),
    58  	}
    59  }
    60  
    61  var (
    62  	routers    = make(map[string]http.Handler) // tRPC service name -> Router
    63  	routerLock sync.RWMutex
    64  )
    65  
    66  // RegisterRouter registers a Router which corresponds to a tRPC Service.
    67  func RegisterRouter(name string, router http.Handler) {
    68  	routerLock.Lock()
    69  	routers[name] = router
    70  	routerLock.Unlock()
    71  }
    72  
    73  // GetRouter returns a Router which corresponds to a tRPC Service.
    74  func GetRouter(name string) http.Handler {
    75  	routerLock.RLock()
    76  	router := routers[name]
    77  	routerLock.RUnlock()
    78  	return router
    79  }
    80  
    81  // ProtoMessage is alias of proto.Message.
    82  type ProtoMessage proto.Message
    83  
    84  // Initializer initializes a ProtoMessage.
    85  type Initializer func() ProtoMessage
    86  
    87  // BodyLocator locates which fields of the proto message would be
    88  // populated according to HttpRule body.
    89  type BodyLocator interface {
    90  	Body() string
    91  	Locate(ProtoMessage) interface{}
    92  }
    93  
    94  // ResponseBodyLocator locates which fields of the proto message would be marshaled
    95  // according to HttpRule response_body.
    96  type ResponseBodyLocator interface {
    97  	ResponseBody() string
    98  	Locate(ProtoMessage) interface{}
    99  }
   100  
   101  // HandleFunc is tRPC method handle function.
   102  type HandleFunc func(svc interface{}, ctx context.Context, reqBody interface{}) (interface{}, error)
   103  
   104  // ExtractFilterFunc extracts tRPC service filter chain.
   105  type ExtractFilterFunc func() filter.ServerChain
   106  
   107  // Binding is the binding of tRPC method and HttpRule.
   108  type Binding struct {
   109  	Name         string
   110  	Input        Initializer
   111  	Output       Initializer
   112  	Filter       HandleFunc
   113  	HTTPMethod   string
   114  	Pattern      *Pattern
   115  	Body         BodyLocator
   116  	ResponseBody ResponseBodyLocator
   117  }
   118  
   119  // AddImplBinding creates a new binding with a specified service implementation.
   120  func (r *Router) AddImplBinding(binding *Binding, serviceImpl interface{}) error {
   121  	tr, err := r.newTranscoder(binding, serviceImpl)
   122  	if err != nil {
   123  		return fmt.Errorf("new transcoder during add impl binding: %w", err)
   124  	}
   125  	// add transcoder
   126  	r.transcoders[binding.HTTPMethod] = append(r.transcoders[binding.HTTPMethod], tr)
   127  	return nil
   128  }
   129  
   130  func (r *Router) newTranscoder(binding *Binding, serviceImpl interface{}) (*transcoder, error) {
   131  	if binding.Output == nil {
   132  		binding.Output = func() ProtoMessage { return &emptypb.Empty{} }
   133  	}
   134  
   135  	// create a transcoder
   136  	tr := &transcoder{
   137  		name:                 binding.Name,
   138  		input:                binding.Input,
   139  		output:               binding.Output,
   140  		handler:              binding.Filter,
   141  		httpMethod:           binding.HTTPMethod,
   142  		pat:                  binding.Pattern,
   143  		body:                 binding.Body,
   144  		respBody:             binding.ResponseBody,
   145  		router:               r,
   146  		discardUnknownParams: r.opts.DiscardUnknownParams,
   147  		serviceImpl:          serviceImpl,
   148  	}
   149  
   150  	// create a dat, filter all fields specified in HttpRule
   151  	var fps [][]string
   152  	if fromPat := binding.Pattern.FieldPaths(); fromPat != nil {
   153  		fps = append(fps, fromPat...)
   154  	}
   155  	if binding.Body != nil {
   156  		if fromBody := binding.Body.Body(); fromBody != "" && fromBody != "*" {
   157  			fps = append(fps, strings.Split(fromBody, "."))
   158  		}
   159  	}
   160  	if len(fps) > 0 {
   161  		doubleArrayTrie, err := dat.Build(fps)
   162  		if err != nil {
   163  			return nil, fmt.Errorf("failed to build dat: %w", err)
   164  		}
   165  		tr.dat = doubleArrayTrie
   166  	}
   167  	return tr, nil
   168  }
   169  
   170  // ctxForCompatibility is used only for compatibility with thttp.
   171  var ctxForCompatibility func(context.Context, http.ResponseWriter, *http.Request) context.Context
   172  
   173  // SetCtxForCompatibility is used only for compatibility with thttp.
   174  func SetCtxForCompatibility(f func(context.Context, http.ResponseWriter, *http.Request) context.Context) {
   175  	ctxForCompatibility = f
   176  }
   177  
   178  // HeaderMatcher matches http request header to tRPC Stub Context.
   179  type HeaderMatcher func(
   180  	ctx context.Context,
   181  	w http.ResponseWriter,
   182  	r *http.Request,
   183  	serviceName, methodName string,
   184  ) (context.Context, error)
   185  
   186  // DefaultHeaderMatcher is the default HeaderMatcher.
   187  var DefaultHeaderMatcher = func(
   188  	ctx context.Context,
   189  	w http.ResponseWriter,
   190  	req *http.Request,
   191  	serviceName, methodName string,
   192  ) (context.Context, error) {
   193  	// Noted: it's better to do the same thing as withNewMessage.
   194  	return withNewMessage(ctx, serviceName, methodName), nil
   195  }
   196  
   197  // withNewMessage create a new codec.Msg, put it into ctx,
   198  // and set target service name and method name.
   199  func withNewMessage(ctx context.Context, serviceName, methodName string) context.Context {
   200  	ctx, msg := codec.WithNewMessage(ctx)
   201  	msg.WithServerRPCName(methodName)
   202  	msg.WithCalleeMethod(methodName)
   203  	msg.WithCalleeServiceName(serviceName)
   204  	msg.WithSerializationType(codec.SerializationTypePB)
   205  	return ctx
   206  }
   207  
   208  // CustomResponseHandler is the custom response handler.
   209  type CustomResponseHandler func(
   210  	ctx context.Context,
   211  	w http.ResponseWriter,
   212  	r *http.Request,
   213  	resp proto.Message,
   214  	body []byte,
   215  ) error
   216  
   217  var httpStatusKey = "t-http-status"
   218  
   219  // SetStatusCodeOnSucceed sets status code on succeed, should be 2XX.
   220  // It's not supposed to call this function but use WithStatusCode in restful/errors.go
   221  // to set status code on error.
   222  func SetStatusCodeOnSucceed(ctx context.Context, code int) {
   223  	msg := codec.Message(ctx)
   224  	metadata := msg.ServerMetaData()
   225  	if metadata == nil {
   226  		metadata = codec.MetaData{}
   227  	}
   228  	metadata[httpStatusKey] = []byte(strconv.Itoa(code))
   229  	msg.WithServerMetaData(metadata)
   230  }
   231  
   232  // GetStatusCodeOnSucceed returns status code on succeed.
   233  // SetStatusCodeOnSucceed must be called first in tRPC method.
   234  func GetStatusCodeOnSucceed(ctx context.Context) int {
   235  	if metadata := codec.Message(ctx).ServerMetaData(); metadata != nil {
   236  		if buf, ok := metadata[httpStatusKey]; ok {
   237  			if code, err := strconv.Atoi(bytes2str(buf)); err == nil {
   238  				return code
   239  			}
   240  		}
   241  	}
   242  	return http.StatusOK
   243  }
   244  
   245  // DefaultResponseHandler is the default CustomResponseHandler.
   246  var DefaultResponseHandler = func(
   247  	ctx context.Context,
   248  	w http.ResponseWriter,
   249  	r *http.Request,
   250  	resp proto.Message,
   251  	body []byte,
   252  ) error {
   253  	// compress
   254  	var writer io.Writer = w
   255  	_, c := compressorForTranscoding(r.Header[headerContentEncoding],
   256  		r.Header[headerAcceptEncoding])
   257  	if c != nil {
   258  		writeCloser, err := c.Compress(w)
   259  		if err != nil {
   260  			return fmt.Errorf("failed to compress resp body: %w", err)
   261  		}
   262  		defer writeCloser.Close()
   263  		w.Header().Set(headerContentEncoding, c.ContentEncoding())
   264  		writer = writeCloser
   265  	}
   266  
   267  	// set response content-type
   268  	_, s := serializerForTranscoding(r.Header[headerContentType],
   269  		r.Header[headerAccept])
   270  	w.Header().Set(headerContentType, s.ContentType())
   271  
   272  	// set status code
   273  	statusCode := GetStatusCodeOnSucceed(ctx)
   274  	w.WriteHeader(statusCode)
   275  
   276  	// response body
   277  	if statusCode != http.StatusNoContent && statusCode != http.StatusNotModified {
   278  		writer.Write(body)
   279  	}
   280  
   281  	return nil
   282  }
   283  
   284  // putBackCtxMessage calls codec.PutBackMessage to put a codec.Msg back to pool,
   285  // if the codec.Msg has been put into ctx.
   286  func putBackCtxMessage(ctx context.Context) {
   287  	if msg, ok := ctx.Value(codec.ContextKeyMessage).(codec.Msg); ok {
   288  		codec.PutBackMessage(msg)
   289  	}
   290  }
   291  
   292  // ServeHTTP implements http.Handler.
   293  // TODO: better routing handling.
   294  func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
   295  	ctx := ctxForCompatibility(req.Context(), w, req)
   296  	for _, tr := range r.transcoders[req.Method] {
   297  		fieldValues, err := tr.pat.Match(req.URL.Path)
   298  		if err == nil {
   299  			r.handle(ctx, w, req, tr, fieldValues)
   300  			return
   301  		}
   302  	}
   303  	r.opts.ErrorHandler(ctx, w, req, errs.New(errs.RetServerNoFunc, "failed to match any pattern"))
   304  }
   305  
   306  func (r *Router) handle(
   307  	ctx context.Context,
   308  	w http.ResponseWriter,
   309  	req *http.Request,
   310  	tr *transcoder,
   311  	fieldValues map[string]string,
   312  ) {
   313  	modifiedCtx, err := r.opts.HeaderMatcher(ctx, w, req, r.opts.ServiceName, tr.name)
   314  	if err != nil {
   315  		r.opts.ErrorHandler(ctx, w, req, errs.New(errs.RetServerDecodeFail, err.Error()))
   316  		return
   317  	}
   318  	ctx = modifiedCtx
   319  	defer putBackCtxMessage(ctx)
   320  
   321  	timeout := r.opts.Timeout
   322  	requestTimeout := codec.Message(ctx).RequestTimeout()
   323  	if requestTimeout > 0 && (requestTimeout < timeout || timeout == 0) {
   324  		timeout = requestTimeout
   325  	}
   326  	if timeout > 0 {
   327  		var cancel context.CancelFunc
   328  		ctx, cancel = context.WithTimeout(ctx, timeout)
   329  		defer cancel()
   330  	}
   331  
   332  	// get inbound/outbound Compressor and Serializer
   333  	reqCompressor, respCompressor := compressorForTranscoding(req.Header[headerContentEncoding],
   334  		req.Header[headerAcceptEncoding])
   335  	reqSerializer, respSerializer := serializerForTranscoding(req.Header[headerContentType],
   336  		req.Header[headerAccept])
   337  
   338  	// set transcoder params
   339  	params, _ := paramsPool.Get().(*transcodeParams)
   340  	params.reqCompressor = reqCompressor
   341  	params.respCompressor = respCompressor
   342  	params.reqSerializer = reqSerializer
   343  	params.respSerializer = respSerializer
   344  	params.body = req.Body
   345  	params.fieldValues = fieldValues
   346  	params.form = req.URL.Query()
   347  	defer putBackParams(params)
   348  
   349  	// transcode
   350  	resp, body, err := tr.transcode(ctx, params)
   351  	if err != nil {
   352  		r.opts.ErrorHandler(ctx, w, req, err)
   353  		return
   354  	}
   355  
   356  	// custom response handling
   357  	if err := r.opts.ResponseHandler(ctx, w, req, resp, body); err != nil {
   358  		r.opts.ErrorHandler(ctx, w, req, errs.New(errs.RetServerEncodeFail, err.Error()))
   359  	}
   360  }