trpc.group/trpc-go/trpc-go@v1.0.3/restful/router.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package restful
    15  
    16  import (
    17  	"context"
    18  	"fmt"
    19  	"io"
    20  	"net/http"
    21  	"strconv"
    22  	"strings"
    23  	"sync"
    24  
    25  	"google.golang.org/protobuf/proto"
    26  	"google.golang.org/protobuf/types/known/emptypb"
    27  
    28  	"trpc.group/trpc-go/trpc-go/codec"
    29  	"trpc.group/trpc-go/trpc-go/errs"
    30  	"trpc.group/trpc-go/trpc-go/filter"
    31  	"trpc.group/trpc-go/trpc-go/internal/dat"
    32  )
    33  
// Router is the restful router. It implements http.Handler (see ServeHTTP)
// and dispatches each request to one of the transcoders registered through
// AddImplBinding.
type Router struct {
	// opts holds the options assembled by NewRouter.
	opts *Options
	// transcoders groups the registered transcoders by HTTP method. ServeHTTP
	// tries the slice for req.Method in registration order and uses the first
	// transcoder whose pattern matches the request path.
	transcoders map[string][]*transcoder
}
    39  
    40  // NewRouter creates a Router.
    41  func NewRouter(opts ...Option) *Router {
    42  	o := Options{
    43  		ErrorHandler:          DefaultErrorHandler,
    44  		HeaderMatcher:         DefaultHeaderMatcher,
    45  		ResponseHandler:       DefaultResponseHandler,
    46  		FastHTTPErrHandler:    DefaultFastHTTPErrorHandler,
    47  		FastHTTPHeaderMatcher: DefaultFastHTTPHeaderMatcher,
    48  		FastHTTPRespHandler:   DefaultFastHTTPRespHandler,
    49  	}
    50  	for _, opt := range opts {
    51  		opt(&o)
    52  	}
    53  	o.rebuildHeaderMatcher()
    54  
    55  	return &Router{
    56  		opts:        &o,
    57  		transcoders: make(map[string][]*transcoder),
    58  	}
    59  }
    60  
    61  var (
    62  	routers    = make(map[string]http.Handler) // tRPC service name -> Router
    63  	routerLock sync.RWMutex
    64  )
    65  
    66  // RegisterRouter registers a Router which corresponds to a tRPC Service.
    67  func RegisterRouter(name string, router http.Handler) {
    68  	routerLock.Lock()
    69  	routers[name] = router
    70  	routerLock.Unlock()
    71  }
    72  
    73  // GetRouter returns a Router which corresponds to a tRPC Service.
    74  func GetRouter(name string) http.Handler {
    75  	routerLock.RLock()
    76  	router := routers[name]
    77  	routerLock.RUnlock()
    78  	return router
    79  }
    80  
    81  // ProtoMessage is alias of proto.Message.
    82  type ProtoMessage proto.Message
    83  
    84  // Initializer initializes a ProtoMessage.
    85  type Initializer func() ProtoMessage
    86  
    87  // BodyLocator locates which fields of the proto message would be
    88  // populated according to HttpRule body.
    89  type BodyLocator interface {
    90  	Body() string
    91  	Locate(ProtoMessage) interface{}
    92  }
    93  
    94  // ResponseBodyLocator locates which fields of the proto message would be marshaled
    95  // according to HttpRule response_body.
    96  type ResponseBodyLocator interface {
    97  	ResponseBody() string
    98  	Locate(ProtoMessage) interface{}
    99  }
   100  
   101  // HandleFunc is tRPC method handle function.
   102  type HandleFunc func(svc interface{}, ctx context.Context, reqBody interface{}) (interface{}, error)
   103  
   104  // ExtractFilterFunc extracts tRPC service filter chain.
   105  type ExtractFilterFunc func() filter.ServerChain
   106  
// Binding is the binding of a tRPC method to an HttpRule. AddImplBinding
// copies these fields into the transcoder it builds for the binding.
type Binding struct {
	// Name is the tRPC method name; the router passes it to HeaderMatcher
	// as methodName.
	Name string
	// Input is the Initializer of the request message.
	Input Initializer
	// Output is the Initializer of the response message; when nil, an
	// *emptypb.Empty initializer is used instead.
	Output Initializer
	// Filter is the handle function stored as the transcoder's handler.
	Filter HandleFunc
	// HTTPMethod is the HTTP method the binding is registered under.
	HTTPMethod string
	// Pattern is the URL pattern matched against req.URL.Path.
	Pattern *Pattern
	// Body locates the request fields populated per the HttpRule body
	// (may be nil).
	Body BodyLocator
	// ResponseBody locates the response fields marshaled per the HttpRule
	// response_body.
	ResponseBody ResponseBodyLocator
}
   118  
   119  // AddImplBinding creates a new binding with a specified service implementation.
   120  func (r *Router) AddImplBinding(binding *Binding, serviceImpl interface{}) error {
   121  	tr, err := r.newTranscoder(binding, serviceImpl)
   122  	if err != nil {
   123  		return fmt.Errorf("new transcoder during add impl binding: %w", err)
   124  	}
   125  	// add transcoder
   126  	r.transcoders[binding.HTTPMethod] = append(r.transcoders[binding.HTTPMethod], tr)
   127  	return nil
   128  }
   129  
   130  func (r *Router) newTranscoder(binding *Binding, serviceImpl interface{}) (*transcoder, error) {
   131  	if binding.Output == nil {
   132  		binding.Output = func() ProtoMessage { return &emptypb.Empty{} }
   133  	}
   134  
   135  	// create a transcoder
   136  	tr := &transcoder{
   137  		name:                 binding.Name,
   138  		input:                binding.Input,
   139  		output:               binding.Output,
   140  		handler:              binding.Filter,
   141  		httpMethod:           binding.HTTPMethod,
   142  		pat:                  binding.Pattern,
   143  		body:                 binding.Body,
   144  		respBody:             binding.ResponseBody,
   145  		router:               r,
   146  		discardUnknownParams: r.opts.DiscardUnknownParams,
   147  		serviceImpl:          serviceImpl,
   148  	}
   149  
   150  	// create a dat, filter all fields specified in HttpRule
   151  	var fps [][]string
   152  	if fromPat := binding.Pattern.FieldPaths(); fromPat != nil {
   153  		fps = append(fps, fromPat...)
   154  	}
   155  	if binding.Body != nil {
   156  		if fromBody := binding.Body.Body(); fromBody != "" && fromBody != "*" {
   157  			fps = append(fps, strings.Split(fromBody, "."))
   158  		}
   159  	}
   160  	if len(fps) > 0 {
   161  		doubleArrayTrie, err := dat.Build(fps)
   162  		if err != nil {
   163  			return nil, fmt.Errorf("failed to build dat: %w", err)
   164  		}
   165  		tr.dat = doubleArrayTrie
   166  	}
   167  	return tr, nil
   168  }
   169  
   170  // ctxForCompatibility is used only for compatibility with thttp.
   171  var ctxForCompatibility func(context.Context, http.ResponseWriter, *http.Request) context.Context
   172  
   173  // SetCtxForCompatibility is used only for compatibility with thttp.
   174  func SetCtxForCompatibility(f func(context.Context, http.ResponseWriter, *http.Request) context.Context) {
   175  	ctxForCompatibility = f
   176  }
   177  
   178  // HeaderMatcher matches http request header to tRPC Stub Context.
   179  type HeaderMatcher func(
   180  	ctx context.Context,
   181  	w http.ResponseWriter,
   182  	r *http.Request,
   183  	serviceName, methodName string,
   184  ) (context.Context, error)
   185  
   186  // DefaultHeaderMatcher is the default HeaderMatcher.
   187  var DefaultHeaderMatcher = func(
   188  	ctx context.Context,
   189  	w http.ResponseWriter,
   190  	req *http.Request,
   191  	serviceName, methodName string,
   192  ) (context.Context, error) {
   193  	// Noted: it's better to do the same thing as withNewMessage.
   194  	return withNewMessage(ctx, serviceName, methodName), nil
   195  }
   196  
   197  // withNewMessage create a new codec.Msg, put it into ctx,
   198  // and set target service name and method name.
   199  func withNewMessage(ctx context.Context, serviceName, methodName string) context.Context {
   200  	ctx, msg := codec.WithNewMessage(ctx)
   201  	msg.WithServerRPCName(methodName)
   202  	msg.WithCalleeServiceName(serviceName)
   203  	msg.WithSerializationType(codec.SerializationTypePB)
   204  	return ctx
   205  }
   206  
   207  // CustomResponseHandler is the custom response handler.
   208  type CustomResponseHandler func(
   209  	ctx context.Context,
   210  	w http.ResponseWriter,
   211  	r *http.Request,
   212  	resp proto.Message,
   213  	body []byte,
   214  ) error
   215  
   216  var httpStatusKey = "t-http-status"
   217  
   218  // SetStatusCodeOnSucceed sets status code on succeed, should be 2XX.
   219  // It's not supposed to call this function but use WithStatusCode in restful/errors.go
   220  // to set status code on error.
   221  func SetStatusCodeOnSucceed(ctx context.Context, code int) {
   222  	msg := codec.Message(ctx)
   223  	metadata := msg.ServerMetaData()
   224  	if metadata == nil {
   225  		metadata = codec.MetaData{}
   226  	}
   227  	metadata[httpStatusKey] = []byte(strconv.Itoa(code))
   228  	msg.WithServerMetaData(metadata)
   229  }
   230  
   231  // GetStatusCodeOnSucceed returns status code on succeed.
   232  // SetStatusCodeOnSucceed must be called first in tRPC method.
   233  func GetStatusCodeOnSucceed(ctx context.Context) int {
   234  	if metadata := codec.Message(ctx).ServerMetaData(); metadata != nil {
   235  		if buf, ok := metadata[httpStatusKey]; ok {
   236  			if code, err := strconv.Atoi(bytes2str(buf)); err == nil {
   237  				return code
   238  			}
   239  		}
   240  	}
   241  	return http.StatusOK
   242  }
   243  
   244  // DefaultResponseHandler is the default CustomResponseHandler.
   245  var DefaultResponseHandler = func(
   246  	ctx context.Context,
   247  	w http.ResponseWriter,
   248  	r *http.Request,
   249  	resp proto.Message,
   250  	body []byte,
   251  ) error {
   252  	// compress
   253  	var writer io.Writer = w
   254  	_, c := compressorForTranscoding(r.Header[headerContentEncoding],
   255  		r.Header[headerAcceptEncoding])
   256  	if c != nil {
   257  		writeCloser, err := c.Compress(w)
   258  		if err != nil {
   259  			return fmt.Errorf("failed to compress resp body: %w", err)
   260  		}
   261  		defer writeCloser.Close()
   262  		w.Header().Set(headerContentEncoding, c.ContentEncoding())
   263  		writer = writeCloser
   264  	}
   265  
   266  	// set response content-type
   267  	_, s := serializerForTranscoding(r.Header[headerContentType],
   268  		r.Header[headerAccept])
   269  	w.Header().Set(headerContentType, s.ContentType())
   270  
   271  	// set status code
   272  	statusCode := GetStatusCodeOnSucceed(ctx)
   273  	w.WriteHeader(statusCode)
   274  
   275  	// response body
   276  	if statusCode != http.StatusNoContent && statusCode != http.StatusNotModified {
   277  		writer.Write(body)
   278  	}
   279  
   280  	return nil
   281  }
   282  
   283  // putBackCtxMessage calls codec.PutBackMessage to put a codec.Msg back to pool,
   284  // if the codec.Msg has been put into ctx.
   285  func putBackCtxMessage(ctx context.Context) {
   286  	if msg, ok := ctx.Value(codec.ContextKeyMessage).(codec.Msg); ok {
   287  		codec.PutBackMessage(msg)
   288  	}
   289  }
   290  
   291  // ServeHTTP implements http.Handler.
   292  // TODO: better routing handling.
   293  func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
   294  	ctx := ctxForCompatibility(req.Context(), w, req)
   295  	for _, tr := range r.transcoders[req.Method] {
   296  		fieldValues, err := tr.pat.Match(req.URL.Path)
   297  		if err == nil {
   298  			r.handle(ctx, w, req, tr, fieldValues)
   299  			return
   300  		}
   301  	}
   302  	r.opts.ErrorHandler(ctx, w, req, errs.New(errs.RetServerNoFunc, "failed to match any pattern"))
   303  }
   304  
// handle serves a request whose path has already matched tr's pattern.
//
// It builds the tRPC context with HeaderMatcher and applies the effective
// timeout. It negotiates compressors and serializers from the request headers,
// transcodes the request through tr and passes the result to ResponseHandler.
// Every failure is reported through r.opts.ErrorHandler.
//
// fieldValues holds the values captured by tr.pat.Match in ServeHTTP.
func (r *Router) handle(
	ctx context.Context,
	w http.ResponseWriter,
	req *http.Request,
	tr *transcoder,
	fieldValues map[string]string,
) {
	modifiedCtx, err := r.opts.HeaderMatcher(ctx, w, req, r.opts.ServiceName, tr.name)
	if err != nil {
		// The error is reported with the pre-matcher ctx, not modifiedCtx.
		r.opts.ErrorHandler(ctx, w, req, errs.New(errs.RetServerDecodeFail, err.Error()))
		return
	}
	ctx = modifiedCtx
	// The argument is evaluated now, so the message put back is the one in the
	// matcher's ctx, not in the timeout-wrapped ctx created below.
	defer putBackCtxMessage(ctx)

	// Effective timeout: r.opts.Timeout, replaced by the message's request
	// timeout when that is positive and either smaller or the only one set.
	// A result of zero means no deadline is applied.
	timeout := r.opts.Timeout
	requestTimeout := codec.Message(ctx).RequestTimeout()
	if requestTimeout > 0 && (requestTimeout < timeout || timeout == 0) {
		timeout = requestTimeout
	}
	if timeout > 0 {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, timeout)
		// Deferred after putBackCtxMessage, so it runs before it (LIFO).
		defer cancel()
	}

	// get inbound/outbound Compressor and Serializer
	reqCompressor, respCompressor := compressorForTranscoding(req.Header[headerContentEncoding],
		req.Header[headerAcceptEncoding])
	reqSerializer, respSerializer := serializerForTranscoding(req.Header[headerContentType],
		req.Header[headerAccept])

	// set transcoder params
	// NOTE(review): the ignored assertion assumes paramsPool only ever yields
	// *transcodeParams; otherwise params would be nil and the next line panics.
	params, _ := paramsPool.Get().(*transcodeParams)
	params.reqCompressor = reqCompressor
	params.respCompressor = respCompressor
	params.reqSerializer = reqSerializer
	params.respSerializer = respSerializer
	params.body = req.Body
	params.fieldValues = fieldValues
	params.form = req.URL.Query()
	// Returned to the pool when handle exits.
	defer putBackParams(params)

	// transcode
	resp, body, err := tr.transcode(ctx, params)
	if err != nil {
		r.opts.ErrorHandler(ctx, w, req, err)
		return
	}

	// custom response handling; a failure is reported as RetServerEncodeFail
	if err := r.opts.ResponseHandler(ctx, w, req, resp, body); err != nil {
		r.opts.ErrorHandler(ctx, w, req, errs.New(errs.RetServerEncodeFail, err.Error()))
	}
}