// Copyright (c) 2022 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
20 21 package tchannel 22 23 import ( 24 "bytes" 25 "context" 26 "errors" 27 "fmt" 28 "strconv" 29 "strings" 30 "testing" 31 "time" 32 33 "github.com/golang/mock/gomock" 34 "github.com/stretchr/testify/assert" 35 "github.com/stretchr/testify/require" 36 "github.com/uber/tchannel-go" 37 "go.uber.org/yarpc/api/transport" 38 "go.uber.org/yarpc/api/transport/transporttest" 39 "go.uber.org/yarpc/encoding/json" 40 "go.uber.org/yarpc/encoding/raw" 41 "go.uber.org/yarpc/internal/routertest" 42 "go.uber.org/yarpc/internal/testtime" 43 pkgerrors "go.uber.org/yarpc/pkg/errors" 44 "go.uber.org/yarpc/yarpcerrors" 45 "go.uber.org/zap" 46 "go.uber.org/zap/zapcore" 47 "go.uber.org/zap/zaptest/observer" 48 ) 49 50 func TestHandlerErrors(t *testing.T) { 51 mockCtrl := gomock.NewController(t) 52 defer mockCtrl.Finish() 53 54 tests := []struct { 55 desc string 56 format tchannel.Format 57 headers []byte 58 wantHeaders map[string]string 59 newResponseWriter func(inboundCallResponse, tchannel.Format, headerCase) responseWriter 60 recorder recorder 61 wantLogLevel zapcore.Level 62 wantLogMessage string 63 wantErrMessage string 64 }{ 65 { 66 desc: "test tchannel json handler", 67 format: tchannel.JSON, 68 headers: []byte(`{"Rpc-Header-Foo": "bar"}`), 69 wantHeaders: map[string]string{"rpc-header-foo": "bar"}, 70 newResponseWriter: newHandlerWriter, 71 recorder: newResponseRecorder(), 72 }, 73 { 74 desc: "test tchannel thrift handler", 75 format: tchannel.Thrift, 76 headers: []byte{ 77 0x00, 0x01, // 1 header 78 0x00, 0x03, 'F', 'o', 'o', // Foo 79 0x00, 0x03, 'B', 'a', 'r', // Bar 80 }, 81 wantHeaders: map[string]string{"foo": "Bar"}, 82 newResponseWriter: newHandlerWriter, 83 recorder: newResponseRecorder(), 84 }, 85 { 86 desc: "test responseWriter.Close() failure logging", 87 format: tchannel.JSON, 88 headers: []byte(`{"Rpc-Header-Foo": "bar"}`), 89 wantHeaders: map[string]string{"rpc-header-foo": "bar"}, 90 newResponseWriter: newFaultyHandlerWriter, 91 recorder: newResponseRecorder(), 
92 wantLogLevel: zapcore.ErrorLevel, 93 wantLogMessage: "responseWriter failed to close", 94 wantErrMessage: "faultyHandlerWriter failed to close", 95 }, 96 { 97 desc: "test SendSystemError() failure logging", 98 format: tchannel.JSON, 99 headers: []byte(`{"Rpc-Header-Foo": "bar"}`), 100 wantHeaders: map[string]string{"rpc-header-foo": "bar"}, 101 newResponseWriter: newFaultyHandlerWriter, 102 recorder: newFaultyResponseRecorder(), 103 wantLogLevel: zapcore.ErrorLevel, 104 wantLogMessage: "SendSystemError failed", 105 wantErrMessage: "SendSystemError failure", 106 }, 107 } 108 109 for _, tt := range tests { 110 core, logs := observer.New(zapcore.ErrorLevel) 111 rpcHandler := transporttest.NewMockUnaryHandler(mockCtrl) 112 router := transporttest.NewMockRouter(mockCtrl) 113 114 spec := transport.NewUnaryHandlerSpec(rpcHandler) 115 116 tchHandler := handler{router: router, logger: zap.New(core).Named("tchannel"), newResponseWriter: tt.newResponseWriter} 117 118 router.EXPECT().Choose(gomock.Any(), routertest.NewMatcher(). 119 WithService("service"). 
120 WithProcedure("hello"), 121 ).Return(spec, nil) 122 123 rpcHandler.EXPECT().Handle( 124 transporttest.NewContextMatcher(t), 125 transporttest.NewRequestMatcher(t, 126 &transport.Request{ 127 Caller: "caller", 128 Service: "service", 129 Transport: "tchannel", 130 Headers: transport.HeadersFromMap(tt.wantHeaders), 131 Encoding: transport.Encoding(tt.format), 132 Procedure: "hello", 133 ShardKey: "shard", 134 RoutingKey: "routekey", 135 RoutingDelegate: "routedelegate", 136 Body: bytes.NewReader([]byte("world")), 137 }), 138 gomock.Any(), 139 ).Return(nil) 140 141 respRecorder := tt.recorder 142 143 ctx, cancel := context.WithTimeout(context.Background(), testtime.Second) 144 defer cancel() 145 tchHandler.handle(ctx, &fakeInboundCall{ 146 service: "service", 147 caller: "caller", 148 format: tt.format, 149 method: "hello", 150 shardkey: "shard", 151 routingkey: "routekey", 152 routingdelegate: "routedelegate", 153 arg2: tt.headers, 154 arg3: []byte("world"), 155 resp: respRecorder, 156 }) 157 158 getLog := func() observer.LoggedEntry { 159 entries := logs.TakeAll() 160 return entries[0] 161 } 162 163 if tt.wantLogMessage != "" { 164 log := getLog() 165 logContext := log.ContextMap() 166 assert.Equal(t, tt.wantLogLevel, log.Entry.Level, "Unexpected log level") 167 assert.Equal(t, tt.wantLogMessage, log.Entry.Message, "Unexpected log message written") 168 assert.Equal(t, tt.wantErrMessage, logContext["error"], "Unexpected error message") 169 assert.Equal(t, "tchannel", log.LoggerName, "Unexpected logger name") 170 assert.Error(t, respRecorder.SystemError(), "Error expected with logging") 171 } 172 173 } 174 } 175 176 func TestHandlerFailures(t *testing.T) { 177 tests := []struct { 178 desc string 179 ctx context.Context // context to use in the callm a default one is used otherwise. 
180 ctxFunc func() (context.Context, context.CancelFunc) 181 sendCall *fakeInboundCall 182 expectCall func(*transporttest.MockUnaryHandler) 183 wantStatus tchannel.SystemErrCode // expected status 184 newResponseWriter func(inboundCallResponse, tchannel.Format, headerCase) responseWriter 185 recorder recorder 186 wantLogLevel zapcore.Level 187 wantLogMessage string 188 wantErrMessage string 189 }{ 190 { 191 desc: "no timeout on context", 192 ctx: context.Background(), 193 sendCall: &fakeInboundCall{ 194 service: "foo", 195 caller: "bar", 196 method: "hello", 197 format: tchannel.Raw, 198 arg2: []byte{0x00, 0x00}, 199 arg3: []byte{0x00}, 200 }, 201 wantStatus: tchannel.ErrCodeBadRequest, 202 newResponseWriter: newHandlerWriter, 203 recorder: newResponseRecorder(), 204 wantLogLevel: zapcore.ErrorLevel, 205 }, 206 { 207 desc: "arg2 reader error", 208 sendCall: &fakeInboundCall{ 209 service: "foo", 210 caller: "bar", 211 method: "hello", 212 format: tchannel.Raw, 213 arg2: nil, 214 arg3: []byte{0x00}, 215 }, 216 wantStatus: tchannel.ErrCodeBadRequest, 217 newResponseWriter: newHandlerWriter, 218 recorder: newResponseRecorder(), 219 wantLogLevel: zapcore.ErrorLevel, 220 }, 221 { 222 desc: "arg2 parse error", 223 sendCall: &fakeInboundCall{ 224 service: "foo", 225 caller: "bar", 226 method: "hello", 227 format: tchannel.JSON, 228 arg2: []byte("{not valid JSON}"), 229 arg3: []byte{0x00}, 230 }, 231 wantStatus: tchannel.ErrCodeBadRequest, 232 newResponseWriter: newHandlerWriter, 233 recorder: newResponseRecorder(), 234 wantLogLevel: zapcore.ErrorLevel, 235 }, 236 { 237 desc: "arg3 reader error", 238 sendCall: &fakeInboundCall{ 239 service: "foo", 240 caller: "bar", 241 method: "hello", 242 format: tchannel.Raw, 243 arg2: []byte{0x00, 0x00}, 244 arg3: nil, 245 }, 246 wantStatus: tchannel.ErrCodeUnexpected, 247 newResponseWriter: newHandlerWriter, 248 recorder: newResponseRecorder(), 249 wantLogLevel: zapcore.ErrorLevel, 250 }, 251 { 252 desc: "internal error", 253 sendCall: 
&fakeInboundCall{ 254 service: "foo", 255 caller: "bar", 256 method: "hello", 257 format: tchannel.Raw, 258 arg2: []byte{0x00, 0x00}, 259 arg3: []byte{0x00}, 260 }, 261 expectCall: func(h *transporttest.MockUnaryHandler) { 262 h.EXPECT().Handle( 263 transporttest.NewContextMatcher(t, transporttest.ContextTTL(testtime.Second)), 264 transporttest.NewRequestMatcher( 265 t, &transport.Request{ 266 Caller: "bar", 267 Service: "foo", 268 Transport: "tchannel", 269 Encoding: raw.Encoding, 270 Procedure: "hello", 271 Body: bytes.NewReader([]byte{0x00}), 272 }, 273 ), gomock.Any(), 274 ).Return(fmt.Errorf("great sadness")) 275 }, 276 wantStatus: tchannel.ErrCodeUnexpected, 277 newResponseWriter: newHandlerWriter, 278 recorder: newResponseRecorder(), 279 wantLogLevel: zapcore.ErrorLevel, 280 }, 281 { 282 desc: "arg3 encode error", 283 sendCall: &fakeInboundCall{ 284 service: "foo", 285 caller: "bar", 286 method: "hello", 287 format: tchannel.JSON, 288 arg2: []byte("{}"), 289 arg3: []byte("{}"), 290 }, 291 expectCall: func(h *transporttest.MockUnaryHandler) { 292 req := &transport.Request{ 293 Caller: "bar", 294 Service: "foo", 295 Transport: "tchannel", 296 Encoding: json.Encoding, 297 Procedure: "hello", 298 Body: bytes.NewReader([]byte("{}")), 299 } 300 h.EXPECT().Handle( 301 transporttest.NewContextMatcher(t, transporttest.ContextTTL(testtime.Second)), 302 transporttest.NewRequestMatcher(t, req), 303 gomock.Any(), 304 ).Return( 305 pkgerrors.ResponseBodyEncodeError(req, errors.New( 306 "serialization derp", 307 ))) 308 }, 309 wantStatus: tchannel.ErrCodeBadRequest, 310 newResponseWriter: newHandlerWriter, 311 recorder: newResponseRecorder(), 312 wantLogLevel: zapcore.ErrorLevel, 313 }, 314 { 315 desc: "handler timeout", 316 ctxFunc: func() (context.Context, context.CancelFunc) { 317 return context.WithTimeout(context.Background(), testtime.Millisecond) 318 }, 319 sendCall: &fakeInboundCall{ 320 service: "foo", 321 caller: "bar", 322 method: "waituntiltimeout", 323 format: 
tchannel.Raw, 324 arg2: []byte{0x00, 0x00}, 325 arg3: []byte{0x00}, 326 }, 327 expectCall: func(h *transporttest.MockUnaryHandler) { 328 req := &transport.Request{ 329 Caller: "bar", 330 Service: "foo", 331 Transport: "tchannel", 332 Encoding: raw.Encoding, 333 Procedure: "waituntiltimeout", 334 Body: bytes.NewReader([]byte{0x00}), 335 } 336 h.EXPECT().Handle( 337 transporttest.NewContextMatcher( 338 t, transporttest.ContextTTL(testtime.Millisecond)), 339 transporttest.NewRequestMatcher(t, req), 340 gomock.Any(), 341 ).Do(func(ctx context.Context, _ *transport.Request, _ transport.ResponseWriter) { 342 <-ctx.Done() 343 }).Return(context.DeadlineExceeded) 344 }, 345 wantStatus: tchannel.ErrCodeTimeout, 346 newResponseWriter: newHandlerWriter, 347 recorder: newResponseRecorder(), 348 wantLogLevel: zapcore.ErrorLevel, 349 }, 350 { 351 desc: "handler panic", 352 sendCall: &fakeInboundCall{ 353 service: "foo", 354 caller: "bar", 355 method: "panic", 356 format: tchannel.Raw, 357 arg2: []byte{0x00, 0x00}, 358 arg3: []byte{0x00}, 359 }, 360 expectCall: func(h *transporttest.MockUnaryHandler) { 361 req := &transport.Request{ 362 Caller: "bar", 363 Service: "foo", 364 Transport: "tchannel", 365 Encoding: raw.Encoding, 366 Procedure: "panic", 367 Body: bytes.NewReader([]byte{0x00}), 368 } 369 h.EXPECT().Handle( 370 transporttest.NewContextMatcher( 371 t, transporttest.ContextTTL(testtime.Second)), 372 transporttest.NewRequestMatcher(t, req), 373 gomock.Any(), 374 ).Do(func(context.Context, *transport.Request, transport.ResponseWriter) { 375 panic("oops I panicked!") 376 }) 377 }, 378 wantStatus: tchannel.ErrCodeUnexpected, 379 newResponseWriter: newHandlerWriter, 380 recorder: newResponseRecorder(), 381 wantLogLevel: zapcore.ErrorLevel, 382 wantLogMessage: "Unary handler panicked", 383 }, 384 { 385 desc: "test SendSystemError() error logging", 386 sendCall: &fakeInboundCall{ 387 service: "foo", 388 caller: "bar", 389 method: "hello", 390 format: tchannel.Raw, 391 arg2: nil, 
392 arg3: []byte{0x00}, 393 }, 394 wantStatus: tchannel.ErrCodeBadRequest, 395 newResponseWriter: newHandlerWriter, 396 recorder: newFaultyResponseRecorder(), 397 wantLogLevel: zapcore.ErrorLevel, 398 wantLogMessage: "SendSystemError failed", 399 wantErrMessage: "SendSystemError failure", 400 }, 401 } 402 403 for _, tt := range tests { 404 t.Run(tt.desc, func(t *testing.T) { 405 406 ctx, cancel := context.WithTimeout(context.Background(), testtime.Second) 407 if tt.ctx != nil { 408 ctx = tt.ctx 409 } else if tt.ctxFunc != nil { 410 ctx, cancel = tt.ctxFunc() 411 } 412 defer cancel() 413 414 core, logs := observer.New(zapcore.ErrorLevel) 415 mockCtrl := gomock.NewController(t) 416 defer mockCtrl.Finish() 417 418 thandler := transporttest.NewMockUnaryHandler(mockCtrl) 419 spec := transport.NewUnaryHandlerSpec(thandler) 420 421 if tt.expectCall != nil { 422 tt.expectCall(thandler) 423 } 424 425 resp := tt.recorder 426 tt.sendCall.resp = resp 427 428 router := transporttest.NewMockRouter(mockCtrl) 429 router.EXPECT().Choose(gomock.Any(), routertest.NewMatcher(). 430 WithService(tt.sendCall.service). 
431 WithProcedure(tt.sendCall.method), 432 ).Return(spec, nil).AnyTimes() 433 434 handler{router: router, logger: zap.New(core).Named("tchannel"), newResponseWriter: tt.newResponseWriter}.handle(ctx, tt.sendCall) 435 err := resp.SystemError() 436 require.Error(t, err, "expected error for %q", tt.desc) 437 438 systemErr, isSystemErr := err.(tchannel.SystemError) 439 require.True(t, isSystemErr, "expected %v for %q to be a system error", err, tt.desc) 440 assert.Equal(t, tt.wantStatus, systemErr.Code(), tt.desc) 441 442 getLog := func() observer.LoggedEntry { 443 entries := logs.TakeAll() 444 return entries[0] 445 } 446 447 if tt.wantLogMessage != "" { 448 log := getLog() 449 logContext := log.ContextMap() 450 assert.Equal(t, tt.wantLogLevel, log.Entry.Level, "Unexpected log level") 451 assert.Equal(t, tt.wantLogMessage, log.Entry.Message, "Unexpected log message written") 452 assert.Equal(t, "tchannel", log.LoggerName, "Unexpected logger name") 453 if tt.wantErrMessage != "" { 454 assert.Equal(t, tt.wantErrMessage, logContext["error"], "Unexpected error message") 455 } 456 } 457 }) 458 } 459 } 460 461 func TestResponseWriter(t *testing.T) { 462 yErrAborted := yarpcerrors.CodeAborted 463 464 tests := []struct { 465 name string 466 format tchannel.Format 467 apply func(responseWriter) 468 arg2 map[string]string // use map since ordering isn't guaranteed 469 arg3 []byte 470 applicationError bool 471 headerCase headerCase 472 }{ 473 { 474 name: "raw lowercase headers", 475 format: tchannel.Raw, 476 apply: func(w responseWriter) { 477 headers := transport.HeadersFromMap(map[string]string{"foo": "bar"}) 478 w.AddHeaders(headers) 479 _, err := w.Write([]byte("hello ")) 480 require.NoError(t, err) 481 _, err = w.Write([]byte("world")) 482 require.NoError(t, err) 483 }, 484 arg2: map[string]string{"foo": "bar"}, 485 arg3: []byte("hello world"), 486 }, 487 { 488 name: "raw mixed-case headers", 489 format: tchannel.Raw, 490 apply: func(w responseWriter) { 491 headers := 
transport.HeadersFromMap(map[string]string{"FoO": "bAr"}) 492 w.AddHeaders(headers) 493 _, err := w.Write([]byte("hello ")) 494 require.NoError(t, err) 495 _, err = w.Write([]byte("world")) 496 require.NoError(t, err) 497 }, 498 arg2: map[string]string{"FoO": "bAr"}, 499 arg3: []byte("hello world"), 500 headerCase: originalHeaderCase, 501 }, 502 { 503 name: "raw multiple writes", 504 format: tchannel.Raw, 505 apply: func(w responseWriter) { 506 _, err := w.Write([]byte("foo")) 507 require.NoError(t, err) 508 _, err = w.Write([]byte("bar")) 509 require.NoError(t, err) 510 }, 511 arg2: nil, 512 arg3: []byte("foobar"), 513 }, 514 { 515 name: "json lowercase headers", 516 format: tchannel.JSON, 517 apply: func(w responseWriter) { 518 headers := transport.HeadersFromMap(map[string]string{"foo": "bar"}) 519 w.AddHeaders(headers) 520 521 _, err := w.Write([]byte("{}")) 522 require.NoError(t, err) 523 }, 524 arg2: map[string]string{"foo": "bar"}, 525 arg3: []byte("{}"), 526 }, 527 { 528 name: "json mixed-case headers", 529 format: tchannel.JSON, 530 apply: func(w responseWriter) { 531 headers := transport.HeadersFromMap(map[string]string{"FoO": "bAr"}) 532 w.AddHeaders(headers) 533 534 _, err := w.Write([]byte("{}")) 535 require.NoError(t, err) 536 }, 537 arg2: map[string]string{"FoO": "bAr"}, 538 arg3: []byte("{}"), 539 headerCase: originalHeaderCase, 540 }, 541 { 542 name: "json empty", 543 format: tchannel.JSON, 544 apply: func(w responseWriter) { 545 _, err := w.Write([]byte("{}")) 546 require.NoError(t, err) 547 }, 548 arg2: nil, 549 arg3: []byte("{}"), 550 }, 551 { 552 name: "application error write", 553 format: tchannel.Raw, 554 apply: func(w responseWriter) { 555 w.SetApplicationError() 556 w.SetApplicationErrorMeta( 557 &transport.ApplicationErrorMeta{ 558 Name: "bAz", 559 Code: &yErrAborted, 560 Details: "App Error Details", 561 }, 562 ) 563 _, err := w.Write([]byte("hello")) 564 require.NoError(t, err) 565 }, 566 arg2: map[string]string{ 567 
"$rpc$-application-error-code": "10", 568 "$rpc$-application-error-name": "bAz", 569 "$rpc$-application-error-details": "App Error Details", 570 }, 571 arg3: []byte("hello"), 572 applicationError: true, 573 }, 574 } 575 576 for _, tt := range tests { 577 t.Run(tt.name, func(t *testing.T) { 578 579 call := &fakeInboundCall{format: tt.format} 580 resp := newResponseRecorder() 581 call.resp = resp 582 583 w := newHandlerWriter(call.Response(), call.Format(), tt.headerCase) 584 tt.apply(w) 585 assert.NoError(t, w.Close()) 586 587 assert.Nil(t, resp.systemErr) 588 589 // read headers as a map since ordering is not guaranteed 590 gotHeaders, err := readHeaders(tt.format, func() (tchannel.ArgReader, error) { return resp.arg2, nil }) 591 require.NoError(t, err) 592 593 assert.Equal(t, tt.arg2, gotHeaders.OriginalItems(), "headers mismatch") 594 assert.Equal(t, tt.arg3, resp.arg3.Bytes()) 595 596 if tt.applicationError { 597 assert.True(t, resp.applicationError, "expected an application error") 598 } 599 }) 600 } 601 } 602 603 func TestResponseWriterFailure(t *testing.T) { 604 tests := []struct { 605 setupResp func(*responseRecorder) 606 messages []string 607 }{ 608 { 609 setupResp: func(rr *responseRecorder) { 610 rr.arg2 = nil 611 }, 612 messages: []string{"no arg2 provided"}, 613 }, 614 { 615 setupResp: func(rr *responseRecorder) { 616 rr.arg3 = nil 617 }, 618 messages: []string{"no arg3 provided"}, 619 }, 620 } 621 622 for _, tt := range tests { 623 resp := newResponseRecorder() 624 tt.setupResp(resp) 625 626 w := newHandlerWriter(resp, tchannel.Raw, canonicalizedHeaderCase) 627 _, err := w.Write([]byte("foo")) 628 assert.NoError(t, err) 629 _, err = w.Write([]byte("bar")) 630 assert.NoError(t, err) 631 err = w.Close() 632 assert.Error(t, err) 633 for _, msg := range tt.messages { 634 assert.Contains(t, err.Error(), msg) 635 } 636 } 637 } 638 639 func TestResponseWriterEmptyBodyHeaders(t *testing.T) { 640 res := newResponseRecorder() 641 w := newHandlerWriter(res, 
tchannel.Raw, canonicalizedHeaderCase) 642 643 w.AddHeaders(transport.NewHeaders().With("foo", "bar")) 644 require.NoError(t, w.Close()) 645 646 assert.NotEmpty(t, res.arg2.Bytes(), "headers must not be empty") 647 assert.Empty(t, res.arg3.Bytes(), "body must be empty but was %#v", res.arg3.Bytes()) 648 assert.False(t, res.applicationError, "application error must be false") 649 } 650 651 func TestGetSystemError(t *testing.T) { 652 tests := []struct { 653 giveErr error 654 wantCode tchannel.SystemErrCode 655 }{ 656 { 657 giveErr: yarpcerrors.UnavailableErrorf("test"), 658 wantCode: tchannel.ErrCodeDeclined, 659 }, 660 { 661 giveErr: errors.New("test"), 662 wantCode: tchannel.ErrCodeUnexpected, 663 }, 664 { 665 giveErr: yarpcerrors.InvalidArgumentErrorf("test"), 666 wantCode: tchannel.ErrCodeBadRequest, 667 }, 668 { 669 giveErr: tchannel.NewSystemError(tchannel.ErrCodeBusy, "test"), 670 wantCode: tchannel.ErrCodeBusy, 671 }, 672 { 673 giveErr: yarpcerrors.Newf(yarpcerrors.Code(1235), "test"), 674 wantCode: tchannel.ErrCodeUnexpected, 675 }, 676 } 677 for i, tt := range tests { 678 t.Run(strconv.Itoa(i), func(t *testing.T) { 679 gotErr := getSystemError(tt.giveErr) 680 tchErr, ok := gotErr.(tchannel.SystemError) 681 require.True(t, ok, "did not return tchannel error") 682 assert.Equal(t, tt.wantCode, tchErr.Code()) 683 }) 684 } 685 } 686 687 func TestHandlerSystemErrorLogs(t *testing.T) { 688 mockCtrl := gomock.NewController(t) 689 defer mockCtrl.Finish() 690 691 zapCore, observedLogs := observer.New(zapcore.ErrorLevel) 692 router := transporttest.NewMockRouter(mockCtrl) 693 transportHandler := &testUnaryHandler{} 694 spec := transport.NewUnaryHandlerSpec(transportHandler) 695 696 tchannelHandler := handler{ 697 router: router, 698 logger: zap.New(zapCore), 699 newResponseWriter: newHandlerWriter, 700 } 701 702 router.EXPECT().Choose(gomock.Any(), gomock.Any()).Return(spec, nil).Times(4) 703 704 inboundCall := &fakeInboundCall{ 705 service: "foo-service", 706 caller: 
"foo-caller", 707 method: "foo-method", 708 format: tchannel.JSON, 709 arg2: []byte{}, 710 arg3: []byte{}, 711 resp: newFaultyResponseRecorder(), 712 } 713 714 t.Run("client awaiting response", func(t *testing.T) { 715 t.Run("handler success", func(t *testing.T) { 716 transportHandler.reset() 717 718 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 719 defer cancel() 720 721 tchannelHandler.handle(ctx, inboundCall) 722 logs := observedLogs.TakeAll() 723 require.Len(t, logs, 2, "unexpected number of logs") 724 725 assert.Equal(t, logs[0].Message, "SendSystemError failed", "unexpected log message") 726 assert.Equal(t, logs[1].Message, "responseWriter failed to close", "unexpected log message") 727 }) 728 729 t.Run("handler error", func(t *testing.T) { 730 transportHandler.reset() 731 transportHandler.err = errors.New("handler error") 732 733 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 734 defer cancel() 735 736 tchannelHandler.handle(ctx, inboundCall) 737 logs := observedLogs.TakeAll() 738 require.Len(t, logs, 1, "unexpected number of logs") 739 740 assert.Equal(t, logs[0].Message, "SendSystemError failed", "unexpected log message") 741 }) 742 }) 743 744 t.Run("client timed out", func(t *testing.T) { 745 t.Run("handler success", func(t *testing.T) { 746 transportHandler.reset() 747 748 ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) 749 defer cancel() 750 751 transportHandler.fn = func() { <-ctx.Done() } // ensure client times out 752 753 tchannelHandler.handle(ctx, inboundCall) 754 require.Empty(t, observedLogs.TakeAll(), "expected no logs") 755 }) 756 757 t.Run("handler err", func(t *testing.T) { 758 transportHandler.reset() 759 transportHandler.err = errors.New("handler error") 760 761 ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) 762 defer cancel() 763 764 transportHandler.fn = func() { <-ctx.Done() } // ensure client times out 765 766 
tchannelHandler.handle(ctx, inboundCall) 767 require.Empty(t, observedLogs.TakeAll(), "expected no logs") 768 }) 769 }) 770 } 771 772 func TestTruncatedHeader(t *testing.T) { 773 tests := []struct { 774 name string 775 value string 776 wantTruncate bool 777 }{ 778 { 779 name: "no-op", 780 value: "foo bar", 781 }, 782 { 783 name: "max", 784 value: strings.Repeat("a", _maxAppErrDetailsHeaderLen), 785 }, 786 { 787 name: "truncate", 788 value: strings.Repeat("b", _maxAppErrDetailsHeaderLen*2), 789 wantTruncate: true, 790 }, 791 } 792 793 for _, tt := range tests { 794 t.Run(tt.name, func(t *testing.T) { 795 got := truncateAppErrDetails(tt.value) 796 797 if !tt.wantTruncate { 798 assert.Equal(t, tt.value, got, "expected no-op") 799 return 800 } 801 802 assert.True(t, strings.HasSuffix(got, _truncatedHeaderMessage), "unexpected truncate suffix") 803 assert.Len(t, got, _maxAppErrDetailsHeaderLen, "did not truncate") 804 }) 805 } 806 } 807 808 func TestRpcServiceHeader(t *testing.T) { 809 hw := &handlerWriter{} 810 h := handler{ 811 headerCase: canonicalizedHeaderCase, 812 newResponseWriter: func(inboundCallResponse, tchannel.Format, headerCase) responseWriter { 813 return hw 814 }, 815 } 816 resp := newResponseRecorder() 817 expectedServiceHeader := "foo" 818 call := &fakeInboundCall{ 819 service: expectedServiceHeader, 820 resp: resp, 821 } 822 ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) 823 defer cancel() 824 825 h.handle(ctx, call) 826 assert.Equal(t, expectedServiceHeader, hw.headers.OriginalItems()[ServiceHeaderKey]) 827 828 h.excludeServiceHeaderInResponse = true 829 hw.headers.Del(ServiceHeaderKey) 830 h.handle(ctx, call) 831 assert.Equal(t, "", hw.headers.OriginalItems()[ServiceHeaderKey]) 832 } 833 834 type testUnaryHandler struct { 835 err error 836 fn func() 837 } 838 839 func (h *testUnaryHandler) Handle(context.Context, *transport.Request, transport.ResponseWriter) error { 840 if h.fn != nil { 841 h.fn() 842 } 843 return 
h.err 844 } 845 846 func (h *testUnaryHandler) reset() { 847 h.err = nil 848 h.fn = nil 849 }