go.uber.org/yarpc@v1.72.1/transport/tchannel/handler.go (about) 1 // Copyright (c) 2022 Uber Technologies, Inc. 2 // 3 // Permission is hereby granted, free of charge, to any person obtaining a copy 4 // of this software and associated documentation files (the "Software"), to deal 5 // in the Software without restriction, including without limitation the rights 6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 // copies of the Software, and to permit persons to whom the Software is 8 // furnished to do so, subject to the following conditions: 9 // 10 // The above copyright notice and this permission notice shall be included in 11 // all copies or substantial portions of the Software. 12 // 13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 // THE SOFTWARE. 20 21 package tchannel 22 23 import ( 24 "bytes" 25 "context" 26 "fmt" 27 "strconv" 28 "time" 29 30 "github.com/opentracing/opentracing-go" 31 "github.com/uber/tchannel-go" 32 "go.uber.org/multierr" 33 "go.uber.org/yarpc/api/transport" 34 "go.uber.org/yarpc/internal/bufferpool" 35 "go.uber.org/yarpc/pkg/errors" 36 "go.uber.org/yarpc/yarpcerrors" 37 "go.uber.org/zap" 38 ncontext "golang.org/x/net/context" 39 ) 40 41 // inboundCall provides an interface similar tchannel.InboundCall. 42 // 43 // We use it instead of *tchannel.InboundCall because tchannel.InboundCall is 44 // not an interface, so we have little control over its behavior in tests. 45 type inboundCall interface { 46 ServiceName() string 47 CallerName() string 48 MethodString() string 49 ShardKey() string 50 RoutingKey() string 51 RoutingDelegate() string 52 53 Format() tchannel.Format 54 55 Arg2Reader() (tchannel.ArgReader, error) 56 Arg3Reader() (tchannel.ArgReader, error) 57 58 Response() inboundCallResponse 59 } 60 61 // inboundCallResponse provides an interface similar to 62 // tchannel.InboundCallResponse. 63 // 64 // Its purpose is the same as inboundCall: Make it easier to test functions 65 // that consume InboundCallResponse without having control of 66 // InboundCallResponse's behavior. 67 type inboundCallResponse interface { 68 Arg2Writer() (tchannel.ArgWriter, error) 69 Arg3Writer() (tchannel.ArgWriter, error) 70 Blackhole() 71 SendSystemError(err error) error 72 SetApplicationError() error 73 } 74 75 // responseWriter provides an interface similar to handlerWriter. 76 // 77 // It allows us to control handlerWriter during testing. 78 type responseWriter interface { 79 AddHeaders(h transport.Headers) 80 AddHeader(key string, value string) 81 Close() error 82 ReleaseBuffer() 83 IsApplicationError() bool 84 SetApplicationError() 85 SetApplicationErrorMeta(meta *transport.ApplicationErrorMeta) 86 Write(s []byte) (int, error) 87 } 88 89 // tchannelCall wraps a TChannel InboundCall into an inboundCall. 90 // 91 // We need to do this so that we can change the return type of call.Response() 92 // to match inboundCall's Response(). 93 type tchannelCall struct{ *tchannel.InboundCall } 94 95 func (c tchannelCall) Response() inboundCallResponse { 96 return c.InboundCall.Response() 97 } 98 99 // handler wraps a transport.UnaryHandler into a TChannel Handler. 100 type handler struct { 101 existing map[string]tchannel.Handler 102 router transport.Router 103 tracer opentracing.Tracer 104 headerCase headerCase 105 logger *zap.Logger 106 newResponseWriter func(inboundCallResponse, tchannel.Format, headerCase) responseWriter 107 excludeServiceHeaderInResponse bool 108 } 109 110 func (h handler) Handle(ctx ncontext.Context, call *tchannel.InboundCall) { 111 h.handle(ctx, tchannelCall{call}) 112 } 113 114 func (h handler) handle(ctx context.Context, call inboundCall) { 115 // you MUST close the responseWriter no matter what unless you have a tchannel.SystemError 116 responseWriter := h.newResponseWriter(call.Response(), call.Format(), h.headerCase) 117 defer responseWriter.ReleaseBuffer() 118 119 if !h.excludeServiceHeaderInResponse { 120 // echo accepted rpc-service in response header 121 responseWriter.AddHeader(ServiceHeaderKey, call.ServiceName()) 122 } 123 124 err := h.callHandler(ctx, call, responseWriter) 125 126 // black-hole requests on resource exhausted errors 127 if yarpcerrors.FromError(err).Code() == yarpcerrors.CodeResourceExhausted { 128 // all TChannel clients will time out instead of receiving an error 129 call.Response().Blackhole() 130 return 131 } 132 133 clientTimedOut := ctx.Err() == context.DeadlineExceeded 134 135 if err != nil && !responseWriter.IsApplicationError() { 136 sendSysErr := call.Response().SendSystemError(getSystemError(err)) 137 if sendSysErr != nil && !clientTimedOut { 138 // only log errors if client is still waiting for our response 139 h.logger.Error("SendSystemError failed", zap.Error(sendSysErr)) 140 } 141 return 142 } 143 if err != nil && responseWriter.IsApplicationError() { 144 // we have an error, so we're going to propagate it as a yarpc error, 145 // regardless of whether or not it is a system error. 146 status := yarpcerrors.FromError(errors.WrapHandlerError(err, call.ServiceName(), call.MethodString())) 147 // TODO: what to do with error? we could have a whole complicated scheme to 148 // return a SystemError here, might want to do that 149 text, _ := status.Code().MarshalText() 150 responseWriter.AddHeader(ErrorCodeHeaderKey, string(text)) 151 if status.Name() != "" { 152 responseWriter.AddHeader(ErrorNameHeaderKey, status.Name()) 153 } 154 if status.Message() != "" { 155 responseWriter.AddHeader(ErrorMessageHeaderKey, status.Message()) 156 } 157 } 158 if reswErr := responseWriter.Close(); reswErr != nil && !clientTimedOut { 159 if sendSysErr := call.Response().SendSystemError(getSystemError(reswErr)); sendSysErr != nil { 160 h.logger.Error("SendSystemError failed", zap.Error(sendSysErr)) 161 } 162 h.logger.Error("responseWriter failed to close", zap.Error(reswErr)) 163 } 164 } 165 166 func (h handler) callHandler(ctx context.Context, call inboundCall, responseWriter responseWriter) error { 167 start := time.Now() 168 _, ok := ctx.Deadline() 169 if !ok { 170 return tchannel.ErrTimeoutRequired 171 } 172 173 treq := &transport.Request{ 174 Caller: call.CallerName(), 175 Service: call.ServiceName(), 176 Encoding: transport.Encoding(call.Format()), 177 Transport: TransportName, 178 Procedure: call.MethodString(), 179 ShardKey: call.ShardKey(), 180 RoutingKey: call.RoutingKey(), 181 RoutingDelegate: call.RoutingDelegate(), 182 } 183 184 ctx, headers, err := readRequestHeaders(ctx, call.Format(), call.Arg2Reader) 185 if err != nil { 186 return errors.RequestHeadersDecodeError(treq, err) 187 } 188 189 // callerProcedure is a rpc header but recevied in application headers, so moving this header to transprotRequest 190 // by updating treq.CallerProcedure. 191 treq = headerCallerProcedureToRequest(treq, &headers) 192 treq.Headers = headers 193 194 if tcall, ok := call.(tchannelCall); ok { 195 tracer := h.tracer 196 ctx = tchannel.ExtractInboundSpan(ctx, tcall.InboundCall, headers.Items(), tracer) 197 } 198 199 buf := bufferpool.Get() 200 defer bufferpool.Put(buf) 201 202 body, err := call.Arg3Reader() 203 if err != nil { 204 return err 205 } 206 207 if _, err = buf.ReadFrom(body); err != nil { 208 return err 209 } 210 if err = body.Close(); err != nil { 211 return err 212 } 213 214 treq.Body = bytes.NewReader(buf.Bytes()) 215 treq.BodySize = buf.Len() 216 217 if err := transport.ValidateRequest(treq); err != nil { 218 return err 219 } 220 221 spec, err := h.router.Choose(ctx, treq) 222 if err != nil { 223 if yarpcerrors.FromError(err).Code() != yarpcerrors.CodeUnimplemented { 224 return err 225 } 226 if tcall, ok := call.(tchannelCall); !ok { 227 if m, ok := h.existing[call.MethodString()]; ok { 228 m.Handle(ctx, tcall.InboundCall) 229 return nil 230 } 231 } 232 return err 233 } 234 235 if err := transport.ValidateRequestContext(ctx); err != nil { 236 return err 237 } 238 switch spec.Type() { 239 case transport.Unary: 240 return transport.InvokeUnaryHandler(transport.UnaryInvokeRequest{ 241 Context: ctx, 242 StartTime: start, 243 Request: treq, 244 ResponseWriter: responseWriter, 245 Handler: spec.Unary(), 246 Logger: h.logger, 247 }) 248 249 default: 250 return yarpcerrors.Newf(yarpcerrors.CodeUnimplemented, "transport tchannel does not handle %s handlers", spec.Type().String()) 251 } 252 } 253 254 type handlerWriter struct { 255 failedWith error 256 format tchannel.Format 257 headers transport.Headers 258 buffer *bufferpool.Buffer 259 response inboundCallResponse 260 applicationError bool 261 headerCase headerCase 262 } 263 264 func newHandlerWriter(response inboundCallResponse, format tchannel.Format, headerCase headerCase) responseWriter { 265 return &handlerWriter{ 266 response: response, 267 format: format, 268 headerCase: headerCase, 269 } 270 } 271 272 func (hw *handlerWriter) AddHeaders(h transport.Headers) { 273 for k, v := range h.OriginalItems() { 274 if isReservedHeaderKey(k) { 275 hw.failedWith = appendError(hw.failedWith, fmt.Errorf("cannot use reserved header key: %s", k)) 276 return 277 } 278 hw.AddHeader(k, v) 279 } 280 } 281 282 func (hw *handlerWriter) AddHeader(key string, value string) { 283 hw.headers = hw.headers.With(key, value) 284 } 285 286 func (hw *handlerWriter) SetApplicationError() { 287 hw.applicationError = true 288 } 289 290 func (hw *handlerWriter) SetApplicationErrorMeta(applicationErrorMeta *transport.ApplicationErrorMeta) { 291 if applicationErrorMeta == nil { 292 return 293 } 294 if applicationErrorMeta.Code != nil { 295 hw.AddHeader(ApplicationErrorCodeHeaderKey, strconv.Itoa(int(*applicationErrorMeta.Code))) 296 } 297 if applicationErrorMeta.Name != "" { 298 hw.AddHeader(ApplicationErrorNameHeaderKey, applicationErrorMeta.Name) 299 } 300 if applicationErrorMeta.Details != "" { 301 hw.AddHeader(ApplicationErrorDetailsHeaderKey, truncateAppErrDetails(applicationErrorMeta.Details)) 302 } 303 } 304 305 func truncateAppErrDetails(val string) string { 306 if len(val) <= _maxAppErrDetailsHeaderLen { 307 return val 308 } 309 stripIndex := _maxAppErrDetailsHeaderLen - len(_truncatedHeaderMessage) 310 return val[:stripIndex] + _truncatedHeaderMessage 311 } 312 313 func (hw *handlerWriter) IsApplicationError() bool { 314 return hw.applicationError 315 } 316 317 func (hw *handlerWriter) Write(s []byte) (int, error) { 318 if hw.failedWith != nil { 319 return 0, hw.failedWith 320 } 321 322 if hw.buffer == nil { 323 hw.buffer = bufferpool.Get() 324 } 325 326 n, err := hw.buffer.Write(s) 327 if err != nil { 328 hw.failedWith = appendError(hw.failedWith, err) 329 } 330 return n, err 331 } 332 333 func (hw *handlerWriter) Close() error { 334 retErr := hw.failedWith 335 if hw.IsApplicationError() { 336 if err := hw.response.SetApplicationError(); err != nil { 337 retErr = appendError(retErr, fmt.Errorf("SetApplicationError() failed: %v", err)) 338 } 339 } 340 341 headers := headerMap(hw.headers, hw.headerCase) 342 retErr = appendError(retErr, writeHeaders(hw.format, headers, nil, hw.response.Arg2Writer)) 343 344 // Arg3Writer must be opened and closed regardless of if there is data 345 // However, if there is a system error, we do not want to do this 346 bodyWriter, err := hw.response.Arg3Writer() 347 if err != nil { 348 return appendError(retErr, err) 349 } 350 defer func() { retErr = appendError(retErr, bodyWriter.Close()) }() 351 if hw.buffer != nil { 352 if _, err := hw.buffer.WriteTo(bodyWriter); err != nil { 353 return appendError(retErr, err) 354 } 355 } 356 357 return retErr 358 } 359 360 func (hw *handlerWriter) ReleaseBuffer() { 361 if hw.buffer != nil { 362 bufferpool.Put(hw.buffer) 363 hw.buffer = nil 364 } 365 } 366 367 func getSystemError(err error) error { 368 if _, ok := err.(tchannel.SystemError); ok { 369 return err 370 } 371 if !yarpcerrors.IsStatus(err) { 372 return tchannel.NewSystemError(tchannel.ErrCodeUnexpected, err.Error()) 373 } 374 status := yarpcerrors.FromError(err) 375 tchannelCode, ok := _codeToTChannelCode[status.Code()] 376 if !ok { 377 tchannelCode = tchannel.ErrCodeUnexpected 378 } 379 return tchannel.NewSystemError(tchannelCode, status.Message()) 380 } 381 382 func appendError(left error, right error) error { 383 if _, ok := left.(tchannel.SystemError); ok { 384 return left 385 } 386 if _, ok := right.(tchannel.SystemError); ok { 387 return right 388 } 389 return multierr.Append(left, right) 390 }