trpc.group/trpc-go/trpc-go@v1.0.3/restful/errors.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  package restful
    15  
import (
	"context"
	stderrors "errors"
	"net/http"

	"github.com/valyala/fasthttp"
	trpcpb "trpc.group/trpc/trpc-protocol/pb/go/trpc"

	"trpc.group/trpc-go/trpc-go/errs"
	"trpc.group/trpc-go/trpc-go/restful/errors"
)
    26  
    27  const (
    28  	// MarshalErrorContent is the content of http response body indicating error marshaling failure.
    29  	MarshalErrorContent = `{"code": 11, "message": "failed to marshal error"}`
    30  )
    31  
    32  // ErrorHandler handles tRPC errors.
    33  type ErrorHandler func(context.Context, http.ResponseWriter, *http.Request, error)
    34  
    35  // FastHTTPErrorHandler handles tRPC errors when fasthttp is used.
    36  type FastHTTPErrorHandler func(context.Context, *fasthttp.RequestCtx, error)
    37  
    38  // WithStatusCode is the error that corresponds to an HTTP status code.
    39  type WithStatusCode struct {
    40  	StatusCode int
    41  	Err        error
    42  }
    43  
    44  // Error implements Go error.
    45  func (w *WithStatusCode) Error() string {
    46  	return w.Err.Error()
    47  }
    48  
    49  // Unwrap returns the wrapped error.
    50  func (w *WithStatusCode) Unwrap() error {
    51  	return w.Err
    52  }
    53  
    54  // tRPC error code => http status code
    55  var httpStatusMap = map[trpcpb.TrpcRetCode]int{
    56  	errs.RetServerDecodeFail:   http.StatusBadRequest,
    57  	errs.RetServerEncodeFail:   http.StatusInternalServerError,
    58  	errs.RetServerNoService:    http.StatusNotFound,
    59  	errs.RetServerNoFunc:       http.StatusNotFound,
    60  	errs.RetServerTimeout:      http.StatusGatewayTimeout,
    61  	errs.RetServerOverload:     http.StatusTooManyRequests,
    62  	errs.RetServerSystemErr:    http.StatusInternalServerError,
    63  	errs.RetServerAuthFail:     http.StatusUnauthorized,
    64  	errs.RetServerValidateFail: http.StatusBadRequest,
    65  	errs.RetUnknown:            http.StatusInternalServerError,
    66  }
    67  
    68  // marshalError marshals an error.
    69  func marshalError(err error, s Serializer) ([]byte, error) {
    70  	// All Serializers for tRPC-Go RESTful are expected to marshal proto messages.
    71  	// So it's better to convert a tRPC error to an *errors.Err.
    72  	terr := &errors.Err{
    73  		Code:    int32(errs.Code(err)),
    74  		Message: errs.Msg(err),
    75  	}
    76  
    77  	return s.Marshal(terr)
    78  }
    79  
    80  // statusCodeFromError returns the status code from the error.
    81  func statusCodeFromError(err error) int {
    82  	statusCode := http.StatusInternalServerError
    83  
    84  	if withStatusCode, ok := err.(*WithStatusCode); ok {
    85  		statusCode = withStatusCode.StatusCode
    86  	} else {
    87  		if statusFromMap, ok := httpStatusMap[errs.Code(err)]; ok {
    88  			statusCode = statusFromMap
    89  		}
    90  	}
    91  
    92  	return statusCode
    93  }
    94  
    95  // DefaultErrorHandler is the default ErrorHandler.
    96  var DefaultErrorHandler = func(ctx context.Context, w http.ResponseWriter, r *http.Request, err error) {
    97  	// get outbound Serializer
    98  	_, s := serializerForTranscoding(r.Header[headerContentType],
    99  		r.Header[headerAccept])
   100  	w.Header().Set(headerContentType, s.ContentType())
   101  
   102  	// marshal error
   103  	buf, merr := marshalError(err, s)
   104  	if merr != nil {
   105  		w.WriteHeader(http.StatusInternalServerError)
   106  		w.Write([]byte(MarshalErrorContent))
   107  		return
   108  	}
   109  	// write response
   110  	w.WriteHeader(statusCodeFromError(err))
   111  	w.Write(buf)
   112  }
   113  
   114  // DefaultFastHTTPErrorHandler is the default FastHTTPErrorHandler.
   115  var DefaultFastHTTPErrorHandler = func(ctx context.Context, requestCtx *fasthttp.RequestCtx, err error) {
   116  	// get outbound Serializer
   117  	_, s := serializerForTranscoding(
   118  		[]string{bytes2str(requestCtx.Request.Header.Peek(headerContentType))},
   119  		[]string{bytes2str(requestCtx.Request.Header.Peek(headerAccept))},
   120  	)
   121  	requestCtx.Response.Header.Set(headerContentType, s.ContentType())
   122  
   123  	// marshal error
   124  	buf, merr := marshalError(err, s)
   125  	if merr != nil {
   126  		requestCtx.Response.SetStatusCode(http.StatusInternalServerError)
   127  		requestCtx.Write([]byte(MarshalErrorContent))
   128  		return
   129  	}
   130  	// write response
   131  	requestCtx.SetStatusCode(statusCodeFromError(err))
   132  	requestCtx.Write(buf)
   133  }