go.uber.org/yarpc@v1.72.1/transport/tchannel/handler.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package tchannel
    22  
import (
	"bytes"
	"context"
	"fmt"
	"strconv"
	"time"
	"unicode/utf8"

	"github.com/opentracing/opentracing-go"
	"github.com/uber/tchannel-go"
	"go.uber.org/multierr"
	"go.uber.org/yarpc/api/transport"
	"go.uber.org/yarpc/internal/bufferpool"
	"go.uber.org/yarpc/pkg/errors"
	"go.uber.org/yarpc/yarpcerrors"
	"go.uber.org/zap"
	ncontext "golang.org/x/net/context"
)
    40  
    41  // inboundCall provides an interface similar tchannel.InboundCall.
    42  //
    43  // We use it instead of *tchannel.InboundCall because tchannel.InboundCall is
    44  // not an interface, so we have little control over its behavior in tests.
    45  type inboundCall interface {
    46  	ServiceName() string
    47  	CallerName() string
    48  	MethodString() string
    49  	ShardKey() string
    50  	RoutingKey() string
    51  	RoutingDelegate() string
    52  
    53  	Format() tchannel.Format
    54  
    55  	Arg2Reader() (tchannel.ArgReader, error)
    56  	Arg3Reader() (tchannel.ArgReader, error)
    57  
    58  	Response() inboundCallResponse
    59  }
    60  
    61  // inboundCallResponse provides an interface similar to
    62  // tchannel.InboundCallResponse.
    63  //
    64  // Its purpose is the same as inboundCall: Make it easier to test functions
    65  // that consume InboundCallResponse without having control of
    66  // InboundCallResponse's behavior.
    67  type inboundCallResponse interface {
    68  	Arg2Writer() (tchannel.ArgWriter, error)
    69  	Arg3Writer() (tchannel.ArgWriter, error)
    70  	Blackhole()
    71  	SendSystemError(err error) error
    72  	SetApplicationError() error
    73  }
    74  
    75  // responseWriter provides an interface similar to handlerWriter.
    76  //
    77  // It allows us to control handlerWriter during testing.
    78  type responseWriter interface {
    79  	AddHeaders(h transport.Headers)
    80  	AddHeader(key string, value string)
    81  	Close() error
    82  	ReleaseBuffer()
    83  	IsApplicationError() bool
    84  	SetApplicationError()
    85  	SetApplicationErrorMeta(meta *transport.ApplicationErrorMeta)
    86  	Write(s []byte) (int, error)
    87  }
    88  
    89  // tchannelCall wraps a TChannel InboundCall into an inboundCall.
    90  //
    91  // We need to do this so that we can change the return type of call.Response()
    92  // to match inboundCall's Response().
    93  type tchannelCall struct{ *tchannel.InboundCall }
    94  
    95  func (c tchannelCall) Response() inboundCallResponse {
    96  	return c.InboundCall.Response()
    97  }
    98  
    99  // handler wraps a transport.UnaryHandler into a TChannel Handler.
   100  type handler struct {
   101  	existing                       map[string]tchannel.Handler
   102  	router                         transport.Router
   103  	tracer                         opentracing.Tracer
   104  	headerCase                     headerCase
   105  	logger                         *zap.Logger
   106  	newResponseWriter              func(inboundCallResponse, tchannel.Format, headerCase) responseWriter
   107  	excludeServiceHeaderInResponse bool
   108  }
   109  
   110  func (h handler) Handle(ctx ncontext.Context, call *tchannel.InboundCall) {
   111  	h.handle(ctx, tchannelCall{call})
   112  }
   113  
// handle serves a single inbound call and is the common path behind Handle.
//
// It runs callHandler and then turns the outcome into a TChannel response,
// checking in this order:
//   - CodeResourceExhausted: the response is black-holed and nothing is
//     written, so TChannel clients time out instead of receiving an error.
//   - any other error not marked as an application error: reported with
//     SendSystemError, and the response writer is not closed.
//   - an application error: the YARPC error code, plus the name and message
//     when non-empty, are added as response headers, then the writer is
//     closed.
//   - success: the response writer is closed (for handlerWriter this writes
//     the buffered headers and body).
//
// Failures to send or close the response are only logged or reported while
// ctx has not exceeded its deadline, i.e. while the client may still wait.
func (h handler) handle(ctx context.Context, call inboundCall) {
	// you MUST close the responseWriter no matter what unless you have a tchannel.SystemError
	responseWriter := h.newResponseWriter(call.Response(), call.Format(), h.headerCase)
	// Release any buffer held by the writer on every exit path.
	defer responseWriter.ReleaseBuffer()

	if !h.excludeServiceHeaderInResponse {
		// echo accepted rpc-service in response header
		responseWriter.AddHeader(ServiceHeaderKey, call.ServiceName())
	}

	err := h.callHandler(ctx, call, responseWriter)

	// black-hole requests on resource exhausted errors
	if yarpcerrors.FromError(err).Code() == yarpcerrors.CodeResourceExhausted {
		// all TChannel clients will time out instead of receiving an error
		call.Response().Blackhole()
		return
	}

	// Evaluated after the handler returns: once the deadline has passed the
	// client is no longer waiting, so delivery failures below are not logged.
	clientTimedOut := ctx.Err() == context.DeadlineExceeded

	// Non-application errors become TChannel system errors. getSystemError
	// always yields a tchannel.SystemError, so the writer is not closed here.
	if err != nil && !responseWriter.IsApplicationError() {
		sendSysErr := call.Response().SendSystemError(getSystemError(err))
		if sendSysErr != nil && !clientTimedOut {
			// only log errors if client is still waiting for our response
			h.logger.Error("SendSystemError failed", zap.Error(sendSysErr))
		}
		return
	}
	if err != nil && responseWriter.IsApplicationError() {
		// we have an error, so we're going to propagate it as a yarpc error,
		// regardless of whether or not it is a system error.
		status := yarpcerrors.FromError(errors.WrapHandlerError(err, call.ServiceName(), call.MethodString()))
		// TODO: what to do with error? we could have a whole complicated scheme to
		// return a SystemError here, might want to do that
		text, _ := status.Code().MarshalText()
		responseWriter.AddHeader(ErrorCodeHeaderKey, string(text))
		if status.Name() != "" {
			responseWriter.AddHeader(ErrorNameHeaderKey, status.Name())
		}
		if status.Message() != "" {
			responseWriter.AddHeader(ErrorMessageHeaderKey, status.Message())
		}
	}
	// Closing the writer sends the response. If that fails while the client
	// is still waiting, fall back to a system error and log both failures.
	if reswErr := responseWriter.Close(); reswErr != nil && !clientTimedOut {
		if sendSysErr := call.Response().SendSystemError(getSystemError(reswErr)); sendSysErr != nil {
			h.logger.Error("SendSystemError failed", zap.Error(sendSysErr))
		}
		h.logger.Error("responseWriter failed to close", zap.Error(reswErr))
	}
}
   165  
   166  func (h handler) callHandler(ctx context.Context, call inboundCall, responseWriter responseWriter) error {
   167  	start := time.Now()
   168  	_, ok := ctx.Deadline()
   169  	if !ok {
   170  		return tchannel.ErrTimeoutRequired
   171  	}
   172  
   173  	treq := &transport.Request{
   174  		Caller:          call.CallerName(),
   175  		Service:         call.ServiceName(),
   176  		Encoding:        transport.Encoding(call.Format()),
   177  		Transport:       TransportName,
   178  		Procedure:       call.MethodString(),
   179  		ShardKey:        call.ShardKey(),
   180  		RoutingKey:      call.RoutingKey(),
   181  		RoutingDelegate: call.RoutingDelegate(),
   182  	}
   183  
   184  	ctx, headers, err := readRequestHeaders(ctx, call.Format(), call.Arg2Reader)
   185  	if err != nil {
   186  		return errors.RequestHeadersDecodeError(treq, err)
   187  	}
   188  
   189  	// callerProcedure is a rpc header but recevied in application headers, so moving this header to transprotRequest
   190  	// by updating treq.CallerProcedure.
   191  	treq = headerCallerProcedureToRequest(treq, &headers)
   192  	treq.Headers = headers
   193  
   194  	if tcall, ok := call.(tchannelCall); ok {
   195  		tracer := h.tracer
   196  		ctx = tchannel.ExtractInboundSpan(ctx, tcall.InboundCall, headers.Items(), tracer)
   197  	}
   198  
   199  	buf := bufferpool.Get()
   200  	defer bufferpool.Put(buf)
   201  
   202  	body, err := call.Arg3Reader()
   203  	if err != nil {
   204  		return err
   205  	}
   206  
   207  	if _, err = buf.ReadFrom(body); err != nil {
   208  		return err
   209  	}
   210  	if err = body.Close(); err != nil {
   211  		return err
   212  	}
   213  
   214  	treq.Body = bytes.NewReader(buf.Bytes())
   215  	treq.BodySize = buf.Len()
   216  
   217  	if err := transport.ValidateRequest(treq); err != nil {
   218  		return err
   219  	}
   220  
   221  	spec, err := h.router.Choose(ctx, treq)
   222  	if err != nil {
   223  		if yarpcerrors.FromError(err).Code() != yarpcerrors.CodeUnimplemented {
   224  			return err
   225  		}
   226  		if tcall, ok := call.(tchannelCall); !ok {
   227  			if m, ok := h.existing[call.MethodString()]; ok {
   228  				m.Handle(ctx, tcall.InboundCall)
   229  				return nil
   230  			}
   231  		}
   232  		return err
   233  	}
   234  
   235  	if err := transport.ValidateRequestContext(ctx); err != nil {
   236  		return err
   237  	}
   238  	switch spec.Type() {
   239  	case transport.Unary:
   240  		return transport.InvokeUnaryHandler(transport.UnaryInvokeRequest{
   241  			Context:        ctx,
   242  			StartTime:      start,
   243  			Request:        treq,
   244  			ResponseWriter: responseWriter,
   245  			Handler:        spec.Unary(),
   246  			Logger:         h.logger,
   247  		})
   248  
   249  	default:
   250  		return yarpcerrors.Newf(yarpcerrors.CodeUnimplemented, "transport tchannel does not handle %s handlers", spec.Type().String())
   251  	}
   252  }
   253  
   254  type handlerWriter struct {
   255  	failedWith       error
   256  	format           tchannel.Format
   257  	headers          transport.Headers
   258  	buffer           *bufferpool.Buffer
   259  	response         inboundCallResponse
   260  	applicationError bool
   261  	headerCase       headerCase
   262  }
   263  
   264  func newHandlerWriter(response inboundCallResponse, format tchannel.Format, headerCase headerCase) responseWriter {
   265  	return &handlerWriter{
   266  		response:   response,
   267  		format:     format,
   268  		headerCase: headerCase,
   269  	}
   270  }
   271  
   272  func (hw *handlerWriter) AddHeaders(h transport.Headers) {
   273  	for k, v := range h.OriginalItems() {
   274  		if isReservedHeaderKey(k) {
   275  			hw.failedWith = appendError(hw.failedWith, fmt.Errorf("cannot use reserved header key: %s", k))
   276  			return
   277  		}
   278  		hw.AddHeader(k, v)
   279  	}
   280  }
   281  
   282  func (hw *handlerWriter) AddHeader(key string, value string) {
   283  	hw.headers = hw.headers.With(key, value)
   284  }
   285  
   286  func (hw *handlerWriter) SetApplicationError() {
   287  	hw.applicationError = true
   288  }
   289  
   290  func (hw *handlerWriter) SetApplicationErrorMeta(applicationErrorMeta *transport.ApplicationErrorMeta) {
   291  	if applicationErrorMeta == nil {
   292  		return
   293  	}
   294  	if applicationErrorMeta.Code != nil {
   295  		hw.AddHeader(ApplicationErrorCodeHeaderKey, strconv.Itoa(int(*applicationErrorMeta.Code)))
   296  	}
   297  	if applicationErrorMeta.Name != "" {
   298  		hw.AddHeader(ApplicationErrorNameHeaderKey, applicationErrorMeta.Name)
   299  	}
   300  	if applicationErrorMeta.Details != "" {
   301  		hw.AddHeader(ApplicationErrorDetailsHeaderKey, truncateAppErrDetails(applicationErrorMeta.Details))
   302  	}
   303  }
   304  
   305  func truncateAppErrDetails(val string) string {
   306  	if len(val) <= _maxAppErrDetailsHeaderLen {
   307  		return val
   308  	}
   309  	stripIndex := _maxAppErrDetailsHeaderLen - len(_truncatedHeaderMessage)
   310  	return val[:stripIndex] + _truncatedHeaderMessage
   311  }
   312  
   313  func (hw *handlerWriter) IsApplicationError() bool {
   314  	return hw.applicationError
   315  }
   316  
   317  func (hw *handlerWriter) Write(s []byte) (int, error) {
   318  	if hw.failedWith != nil {
   319  		return 0, hw.failedWith
   320  	}
   321  
   322  	if hw.buffer == nil {
   323  		hw.buffer = bufferpool.Get()
   324  	}
   325  
   326  	n, err := hw.buffer.Write(s)
   327  	if err != nil {
   328  		hw.failedWith = appendError(hw.failedWith, err)
   329  	}
   330  	return n, err
   331  }
   332  
   333  func (hw *handlerWriter) Close() error {
   334  	retErr := hw.failedWith
   335  	if hw.IsApplicationError() {
   336  		if err := hw.response.SetApplicationError(); err != nil {
   337  			retErr = appendError(retErr, fmt.Errorf("SetApplicationError() failed: %v", err))
   338  		}
   339  	}
   340  
   341  	headers := headerMap(hw.headers, hw.headerCase)
   342  	retErr = appendError(retErr, writeHeaders(hw.format, headers, nil, hw.response.Arg2Writer))
   343  
   344  	// Arg3Writer must be opened and closed regardless of if there is data
   345  	// However, if there is a system error, we do not want to do this
   346  	bodyWriter, err := hw.response.Arg3Writer()
   347  	if err != nil {
   348  		return appendError(retErr, err)
   349  	}
   350  	defer func() { retErr = appendError(retErr, bodyWriter.Close()) }()
   351  	if hw.buffer != nil {
   352  		if _, err := hw.buffer.WriteTo(bodyWriter); err != nil {
   353  			return appendError(retErr, err)
   354  		}
   355  	}
   356  
   357  	return retErr
   358  }
   359  
   360  func (hw *handlerWriter) ReleaseBuffer() {
   361  	if hw.buffer != nil {
   362  		bufferpool.Put(hw.buffer)
   363  		hw.buffer = nil
   364  	}
   365  }
   366  
   367  func getSystemError(err error) error {
   368  	if _, ok := err.(tchannel.SystemError); ok {
   369  		return err
   370  	}
   371  	if !yarpcerrors.IsStatus(err) {
   372  		return tchannel.NewSystemError(tchannel.ErrCodeUnexpected, err.Error())
   373  	}
   374  	status := yarpcerrors.FromError(err)
   375  	tchannelCode, ok := _codeToTChannelCode[status.Code()]
   376  	if !ok {
   377  		tchannelCode = tchannel.ErrCodeUnexpected
   378  	}
   379  	return tchannel.NewSystemError(tchannelCode, status.Message())
   380  }
   381  
   382  func appendError(left error, right error) error {
   383  	if _, ok := left.(tchannel.SystemError); ok {
   384  		return left
   385  	}
   386  	if _, ok := right.(tchannel.SystemError); ok {
   387  		return right
   388  	}
   389  	return multierr.Append(left, right)
   390  }