go.uber.org/yarpc@v1.72.1/transport/http/handler.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package http
    22  
    23  import (
    24  	"bytes"
    25  	"context"
    26  	"fmt"
    27  	"net/http"
    28  	"strconv"
    29  	"time"
    30  
    31  	"github.com/opentracing/opentracing-go"
    32  	"github.com/opentracing/opentracing-go/ext"
    33  	opentracinglog "github.com/opentracing/opentracing-go/log"
    34  	"go.uber.org/yarpc"
    35  	"go.uber.org/yarpc/api/transport"
    36  	"go.uber.org/yarpc/internal/bufferpool"
    37  	"go.uber.org/yarpc/internal/iopool"
    38  	"go.uber.org/yarpc/pkg/errors"
    39  	"go.uber.org/yarpc/yarpcerrors"
    40  	"go.uber.org/zap"
    41  )
    42  
    43  func popHeader(h http.Header, n string) string {
    44  	v := h.Get(n)
    45  	h.Del(n)
    46  	return v
    47  }
    48  
// handler adapts a transport.Router into a handler for net/http: each inbound
// HTTP request is converted into a transport.Request, routed through router,
// and dispatched to the chosen unary or oneway handler.
type handler struct {
	// router chooses the handler spec for each request (see callHandler).
	router            transport.Router
	// tracer extracts the caller's span context from the request headers and
	// starts the server-side span (see createSpan).
	tracer            opentracing.Tracer
	// grabHeaders is a set of extra request header names whose values are
	// copied verbatim into the transport.Request headers.
	grabHeaders       map[string]struct{}
	// bothResponseError is the server-side opt-in for returning a response
	// body together with an error; it only takes effect when the request also
	// carries AcceptsBothResponseErrorHeader set to AcceptTrue.
	bothResponseError bool
	// logger is passed through to the unary and oneway invocation helpers.
	logger            *zap.Logger
}
    57  
    58  func (h handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
    59  	responseWriter := newResponseWriter(w)
    60  	service := popHeader(req.Header, ServiceHeader)
    61  	procedure := popHeader(req.Header, ProcedureHeader)
    62  	bothResponseError := popHeader(req.Header, AcceptsBothResponseErrorHeader) == AcceptTrue
    63  	// add response header to echo accepted rpc-service
    64  	responseWriter.AddSystemHeader(ServiceHeader, service)
    65  	status := yarpcerrors.FromError(errors.WrapHandlerError(h.callHandler(responseWriter, req, service, procedure), service, procedure))
    66  	if status == nil {
    67  		responseWriter.Close(http.StatusOK)
    68  		return
    69  	}
    70  	if statusCodeText, marshalErr := status.Code().MarshalText(); marshalErr != nil {
    71  		status = yarpcerrors.Newf(yarpcerrors.CodeInternal, "error %s had code %v which is unknown", status.Error(), status.Code())
    72  		responseWriter.AddSystemHeader(ErrorCodeHeader, "internal")
    73  	} else {
    74  		responseWriter.AddSystemHeader(ErrorCodeHeader, string(statusCodeText))
    75  	}
    76  	if status.Name() != "" {
    77  		responseWriter.AddSystemHeader(ErrorNameHeader, status.Name())
    78  	}
    79  	if bothResponseError && h.bothResponseError {
    80  		responseWriter.AddSystemHeader(BothResponseErrorHeader, AcceptTrue)
    81  		responseWriter.AddSystemHeader(ErrorMessageHeader, status.Message())
    82  		if details := status.Details(); details != nil {
    83  			responseWriter.AddSystemHeader(ErrorDetailsHeader, string(details))
    84  			responseWriter.ResetBuffer()
    85  			_, _ = responseWriter.Write(details)
    86  		}
    87  	} else {
    88  		responseWriter.ResetBuffer()
    89  		_, _ = fmt.Fprintln(responseWriter, status.Message())
    90  		responseWriter.AddSystemHeader("Content-Type", "text/plain; charset=utf8")
    91  	}
    92  	httpStatusCode, ok := _codeToStatusCode[status.Code()]
    93  	if !ok {
    94  		httpStatusCode = http.StatusInternalServerError
    95  	}
    96  	responseWriter.Close(httpStatusCode)
    97  }
    98  
    99  func (h handler) callHandler(responseWriter *responseWriter, req *http.Request, service string, procedure string) (retErr error) {
   100  	start := time.Now()
   101  	defer req.Body.Close()
   102  	if req.Method != http.MethodPost {
   103  		return yarpcerrors.Newf(yarpcerrors.CodeNotFound, "request method was %s but only %s is allowed", req.Method, http.MethodPost)
   104  	}
   105  
   106  	treq := &transport.Request{
   107  		Caller:          popHeader(req.Header, CallerHeader),
   108  		Service:         service,
   109  		Procedure:       procedure,
   110  		Encoding:        transport.Encoding(popHeader(req.Header, EncodingHeader)),
   111  		Transport:       TransportName,
   112  		ShardKey:        popHeader(req.Header, ShardKeyHeader),
   113  		RoutingKey:      popHeader(req.Header, RoutingKeyHeader),
   114  		RoutingDelegate: popHeader(req.Header, RoutingDelegateHeader),
   115  		CallerProcedure: popHeader(req.Header, CallerProcedureHeader),
   116  		Headers:         applicationHeaders.FromHTTPHeaders(req.Header, transport.Headers{}),
   117  		Body:            req.Body,
   118  		BodySize:        int(req.ContentLength),
   119  	}
   120  	for header := range h.grabHeaders {
   121  		if value := req.Header.Get(header); value != "" {
   122  			treq.Headers = treq.Headers.With(header, value)
   123  		}
   124  	}
   125  	if err := transport.ValidateRequest(treq); err != nil {
   126  		return err
   127  	}
   128  	defer func() {
   129  		if retErr == nil {
   130  			if contentType := getContentType(treq.Encoding); contentType != "" {
   131  				responseWriter.AddSystemHeader("Content-Type", contentType)
   132  			}
   133  		}
   134  	}()
   135  
   136  	ctx := req.Context()
   137  	ctx, cancel, parseTTLErr := parseTTL(ctx, treq, popHeader(req.Header, TTLMSHeader))
   138  	// parseTTLErr != nil is a problem only if the request is unary.
   139  	defer cancel()
   140  	ctx, span := h.createSpan(ctx, req, treq, start)
   141  
   142  	spec, err := h.router.Choose(ctx, treq)
   143  	if err != nil {
   144  		updateSpanWithErr(span, err)
   145  		return err
   146  	}
   147  
   148  	if parseTTLErr != nil {
   149  		return parseTTLErr
   150  	}
   151  	if err := transport.ValidateRequestContext(ctx); err != nil {
   152  		return err
   153  	}
   154  	switch spec.Type() {
   155  	case transport.Unary:
   156  		defer span.Finish()
   157  
   158  		err = transport.InvokeUnaryHandler(transport.UnaryInvokeRequest{
   159  			Context:        ctx,
   160  			StartTime:      start,
   161  			Request:        treq,
   162  			Handler:        spec.Unary(),
   163  			ResponseWriter: responseWriter,
   164  			Logger:         h.logger,
   165  		})
   166  
   167  	case transport.Oneway:
   168  		err = handleOnewayRequest(span, treq, spec.Oneway(), h.logger)
   169  
   170  	default:
   171  		err = yarpcerrors.Newf(yarpcerrors.CodeUnimplemented, "transport http does not handle %s handlers", spec.Type().String())
   172  	}
   173  
   174  	updateSpanWithErr(span, err)
   175  	return err
   176  }
   177  
   178  func handleOnewayRequest(
   179  	span opentracing.Span,
   180  	treq *transport.Request,
   181  	onewayHandler transport.OnewayHandler,
   182  	logger *zap.Logger,
   183  ) error {
   184  	// we will lose access to the body unless we read all the bytes before
   185  	// returning from the request
   186  	var buff bytes.Buffer
   187  	if _, err := iopool.Copy(&buff, treq.Body); err != nil {
   188  		return err
   189  	}
   190  	treq.Body = &buff
   191  
   192  	// create a new context for oneway requests since the HTTP handler cancels
   193  	// http.Request's context when ServeHTTP returns
   194  	ctx := opentracing.ContextWithSpan(context.Background(), span)
   195  
   196  	go func() {
   197  		// ensure the span lasts for length of the handler in case of errors
   198  		defer span.Finish()
   199  
   200  		err := transport.InvokeOnewayHandler(transport.OnewayInvokeRequest{
   201  			Context: ctx,
   202  			Request: treq,
   203  			Handler: onewayHandler,
   204  			Logger:  logger,
   205  		})
   206  		updateSpanWithErr(span, err)
   207  	}()
   208  	return nil
   209  }
   210  
   211  func updateSpanWithErr(span opentracing.Span, err error) {
   212  	if err != nil {
   213  		span.SetTag("error", true)
   214  		span.LogFields(opentracinglog.String("event", err.Error()))
   215  	}
   216  }
   217  
   218  func (h handler) createSpan(ctx context.Context, req *http.Request, treq *transport.Request, start time.Time) (context.Context, opentracing.Span) {
   219  	// Extract opentracing etc baggage from headers
   220  	// Annotate the inbound context with a trace span
   221  	tracer := h.tracer
   222  	carrier := opentracing.HTTPHeadersCarrier(req.Header)
   223  	parentSpanCtx, _ := tracer.Extract(opentracing.HTTPHeaders, carrier)
   224  	// parentSpanCtx may be nil, ext.RPCServerOption handles a nil parent
   225  	// gracefully.
   226  	tags := opentracing.Tags{
   227  		"rpc.caller":    treq.Caller,
   228  		"rpc.service":   treq.Service,
   229  		"rpc.encoding":  treq.Encoding,
   230  		"rpc.transport": "http",
   231  	}
   232  	for k, v := range yarpc.OpentracingTags {
   233  		tags[k] = v
   234  	}
   235  	span := tracer.StartSpan(
   236  		treq.Procedure,
   237  		opentracing.StartTime(start),
   238  		ext.RPCServerOption(parentSpanCtx), // implies ChildOf
   239  		tags,
   240  	)
   241  	ext.PeerService.Set(span, treq.Caller)
   242  	ctx = opentracing.ContextWithSpan(ctx, span)
   243  	return ctx, span
   244  }
   245  
   246  var (
   247  	_ transport.ResponseWriter             = (*responseWriter)(nil)
   248  	_ transport.ApplicationErrorMetaSetter = (*responseWriter)(nil)
   249  )
   250  
// responseWriter adapts an http.ResponseWriter into a transport.ResponseWriter.
// Headers are set directly on w as they are added. The body is buffered and
// is written to w, after the status code, only when Close is called.
type responseWriter struct {
	// w is the underlying net/http response writer.
	w      http.ResponseWriter
	// buffer holds the response body. It is nil until the first Write takes
	// one from bufferpool, and Close returns it to the pool.
	buffer *bufferpool.Buffer
}
   257  func newResponseWriter(w http.ResponseWriter) *responseWriter {
   258  	w.Header().Set(ApplicationStatusHeader, ApplicationSuccessStatus)
   259  	return &responseWriter{w: w}
   260  }
   261  
   262  func (rw *responseWriter) Write(s []byte) (int, error) {
   263  	if rw.buffer == nil {
   264  		rw.buffer = bufferpool.Get()
   265  	}
   266  	return rw.buffer.Write(s)
   267  }
   268  
   269  func (rw *responseWriter) AddHeaders(h transport.Headers) {
   270  	applicationHeaders.ToHTTPHeaders(h, rw.w.Header())
   271  }
   272  
   273  func (rw *responseWriter) SetApplicationError() {
   274  	rw.w.Header().Set(ApplicationStatusHeader, ApplicationErrorStatus)
   275  }
   276  
   277  func (rw *responseWriter) SetApplicationErrorMeta(meta *transport.ApplicationErrorMeta) {
   278  	if meta == nil {
   279  		return
   280  	}
   281  	if meta.Code != nil {
   282  		rw.w.Header().Set(_applicationErrorCodeHeader, strconv.Itoa(int(*meta.Code)))
   283  	}
   284  	if meta.Name != "" {
   285  		rw.w.Header().Set(_applicationErrorNameHeader, meta.Name)
   286  	}
   287  	if meta.Details != "" {
   288  		rw.w.Header().Set(_applicationErrorDetailsHeader, truncateAppErrDetails(meta.Details))
   289  	}
   290  }
   291  
   292  func truncateAppErrDetails(val string) string {
   293  	if len(val) <= _maxAppErrDetailsHeaderLen {
   294  		return val
   295  	}
   296  	stripIndex := _maxAppErrDetailsHeaderLen - len(_truncatedHeaderMessage)
   297  	return val[:stripIndex] + _truncatedHeaderMessage
   298  }
   299  
   300  func (rw *responseWriter) AddSystemHeader(key string, value string) {
   301  	rw.w.Header().Set(key, value)
   302  }
   303  
   304  func (rw *responseWriter) ResetBuffer() {
   305  	if rw.buffer != nil {
   306  		rw.buffer.Reset()
   307  	}
   308  }
   309  
   310  func (rw *responseWriter) Close(httpStatusCode int) {
   311  	rw.w.WriteHeader(httpStatusCode)
   312  	if rw.buffer != nil {
   313  		// TODO: what to do with error?
   314  		_, _ = rw.buffer.WriteTo(rw.w)
   315  		bufferpool.Put(rw.buffer)
   316  	}
   317  }
   318  
   319  func getContentType(encoding transport.Encoding) string {
   320  	switch encoding {
   321  	case "json":
   322  		return "application/json"
   323  	case "raw":
   324  		return "application/octet-stream"
   325  	case "thrift":
   326  		return "application/vnd.apache.thrift.binary"
   327  	case "proto":
   328  		return "application/x-protobuf"
   329  	default:
   330  		return ""
   331  	}
   332  }