go.uber.org/yarpc@v1.72.1/transport/tchannel/handler_test.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package tchannel
    22  
    23  import (
    24  	"bytes"
    25  	"context"
    26  	"errors"
    27  	"fmt"
    28  	"strconv"
    29  	"strings"
    30  	"testing"
    31  	"time"
    32  
    33  	"github.com/golang/mock/gomock"
    34  	"github.com/stretchr/testify/assert"
    35  	"github.com/stretchr/testify/require"
    36  	"github.com/uber/tchannel-go"
    37  	"go.uber.org/yarpc/api/transport"
    38  	"go.uber.org/yarpc/api/transport/transporttest"
    39  	"go.uber.org/yarpc/encoding/json"
    40  	"go.uber.org/yarpc/encoding/raw"
    41  	"go.uber.org/yarpc/internal/routertest"
    42  	"go.uber.org/yarpc/internal/testtime"
    43  	pkgerrors "go.uber.org/yarpc/pkg/errors"
    44  	"go.uber.org/yarpc/yarpcerrors"
    45  	"go.uber.org/zap"
    46  	"go.uber.org/zap/zapcore"
    47  	"go.uber.org/zap/zaptest/observer"
    48  )
    49  
    50  func TestHandlerErrors(t *testing.T) {
    51  	mockCtrl := gomock.NewController(t)
    52  	defer mockCtrl.Finish()
    53  
    54  	tests := []struct {
    55  		desc              string
    56  		format            tchannel.Format
    57  		headers           []byte
    58  		wantHeaders       map[string]string
    59  		newResponseWriter func(inboundCallResponse, tchannel.Format, headerCase) responseWriter
    60  		recorder          recorder
    61  		wantLogLevel      zapcore.Level
    62  		wantLogMessage    string
    63  		wantErrMessage    string
    64  	}{
    65  		{
    66  			desc:              "test tchannel json handler",
    67  			format:            tchannel.JSON,
    68  			headers:           []byte(`{"Rpc-Header-Foo": "bar"}`),
    69  			wantHeaders:       map[string]string{"rpc-header-foo": "bar"},
    70  			newResponseWriter: newHandlerWriter,
    71  			recorder:          newResponseRecorder(),
    72  		},
    73  		{
    74  			desc:   "test tchannel thrift handler",
    75  			format: tchannel.Thrift,
    76  			headers: []byte{
    77  				0x00, 0x01, // 1 header
    78  				0x00, 0x03, 'F', 'o', 'o', // Foo
    79  				0x00, 0x03, 'B', 'a', 'r', // Bar
    80  			},
    81  			wantHeaders:       map[string]string{"foo": "Bar"},
    82  			newResponseWriter: newHandlerWriter,
    83  			recorder:          newResponseRecorder(),
    84  		},
    85  		{
    86  			desc:              "test responseWriter.Close() failure logging",
    87  			format:            tchannel.JSON,
    88  			headers:           []byte(`{"Rpc-Header-Foo": "bar"}`),
    89  			wantHeaders:       map[string]string{"rpc-header-foo": "bar"},
    90  			newResponseWriter: newFaultyHandlerWriter,
    91  			recorder:          newResponseRecorder(),
    92  			wantLogLevel:      zapcore.ErrorLevel,
    93  			wantLogMessage:    "responseWriter failed to close",
    94  			wantErrMessage:    "faultyHandlerWriter failed to close",
    95  		},
    96  		{
    97  			desc:              "test SendSystemError() failure logging",
    98  			format:            tchannel.JSON,
    99  			headers:           []byte(`{"Rpc-Header-Foo": "bar"}`),
   100  			wantHeaders:       map[string]string{"rpc-header-foo": "bar"},
   101  			newResponseWriter: newFaultyHandlerWriter,
   102  			recorder:          newFaultyResponseRecorder(),
   103  			wantLogLevel:      zapcore.ErrorLevel,
   104  			wantLogMessage:    "SendSystemError failed",
   105  			wantErrMessage:    "SendSystemError failure",
   106  		},
   107  	}
   108  
   109  	for _, tt := range tests {
   110  		core, logs := observer.New(zapcore.ErrorLevel)
   111  		rpcHandler := transporttest.NewMockUnaryHandler(mockCtrl)
   112  		router := transporttest.NewMockRouter(mockCtrl)
   113  
   114  		spec := transport.NewUnaryHandlerSpec(rpcHandler)
   115  
   116  		tchHandler := handler{router: router, logger: zap.New(core).Named("tchannel"), newResponseWriter: tt.newResponseWriter}
   117  
   118  		router.EXPECT().Choose(gomock.Any(), routertest.NewMatcher().
   119  			WithService("service").
   120  			WithProcedure("hello"),
   121  		).Return(spec, nil)
   122  
   123  		rpcHandler.EXPECT().Handle(
   124  			transporttest.NewContextMatcher(t),
   125  			transporttest.NewRequestMatcher(t,
   126  				&transport.Request{
   127  					Caller:          "caller",
   128  					Service:         "service",
   129  					Transport:       "tchannel",
   130  					Headers:         transport.HeadersFromMap(tt.wantHeaders),
   131  					Encoding:        transport.Encoding(tt.format),
   132  					Procedure:       "hello",
   133  					ShardKey:        "shard",
   134  					RoutingKey:      "routekey",
   135  					RoutingDelegate: "routedelegate",
   136  					Body:            bytes.NewReader([]byte("world")),
   137  				}),
   138  			gomock.Any(),
   139  		).Return(nil)
   140  
   141  		respRecorder := tt.recorder
   142  
   143  		ctx, cancel := context.WithTimeout(context.Background(), testtime.Second)
   144  		defer cancel()
   145  		tchHandler.handle(ctx, &fakeInboundCall{
   146  			service:         "service",
   147  			caller:          "caller",
   148  			format:          tt.format,
   149  			method:          "hello",
   150  			shardkey:        "shard",
   151  			routingkey:      "routekey",
   152  			routingdelegate: "routedelegate",
   153  			arg2:            tt.headers,
   154  			arg3:            []byte("world"),
   155  			resp:            respRecorder,
   156  		})
   157  
   158  		getLog := func() observer.LoggedEntry {
   159  			entries := logs.TakeAll()
   160  			return entries[0]
   161  		}
   162  
   163  		if tt.wantLogMessage != "" {
   164  			log := getLog()
   165  			logContext := log.ContextMap()
   166  			assert.Equal(t, tt.wantLogLevel, log.Entry.Level, "Unexpected log level")
   167  			assert.Equal(t, tt.wantLogMessage, log.Entry.Message, "Unexpected log message written")
   168  			assert.Equal(t, tt.wantErrMessage, logContext["error"], "Unexpected error message")
   169  			assert.Equal(t, "tchannel", log.LoggerName, "Unexpected logger name")
   170  			assert.Error(t, respRecorder.SystemError(), "Error expected with logging")
   171  		}
   172  
   173  	}
   174  }
   175  
   176  func TestHandlerFailures(t *testing.T) {
   177  	tests := []struct {
   178  		desc              string
   179  		ctx               context.Context // context to use in the callm a default one is used otherwise.
   180  		ctxFunc           func() (context.Context, context.CancelFunc)
   181  		sendCall          *fakeInboundCall
   182  		expectCall        func(*transporttest.MockUnaryHandler)
   183  		wantStatus        tchannel.SystemErrCode // expected status
   184  		newResponseWriter func(inboundCallResponse, tchannel.Format, headerCase) responseWriter
   185  		recorder          recorder
   186  		wantLogLevel      zapcore.Level
   187  		wantLogMessage    string
   188  		wantErrMessage    string
   189  	}{
   190  		{
   191  			desc: "no timeout on context",
   192  			ctx:  context.Background(),
   193  			sendCall: &fakeInboundCall{
   194  				service: "foo",
   195  				caller:  "bar",
   196  				method:  "hello",
   197  				format:  tchannel.Raw,
   198  				arg2:    []byte{0x00, 0x00},
   199  				arg3:    []byte{0x00},
   200  			},
   201  			wantStatus:        tchannel.ErrCodeBadRequest,
   202  			newResponseWriter: newHandlerWriter,
   203  			recorder:          newResponseRecorder(),
   204  			wantLogLevel:      zapcore.ErrorLevel,
   205  		},
   206  		{
   207  			desc: "arg2 reader error",
   208  			sendCall: &fakeInboundCall{
   209  				service: "foo",
   210  				caller:  "bar",
   211  				method:  "hello",
   212  				format:  tchannel.Raw,
   213  				arg2:    nil,
   214  				arg3:    []byte{0x00},
   215  			},
   216  			wantStatus:        tchannel.ErrCodeBadRequest,
   217  			newResponseWriter: newHandlerWriter,
   218  			recorder:          newResponseRecorder(),
   219  			wantLogLevel:      zapcore.ErrorLevel,
   220  		},
   221  		{
   222  			desc: "arg2 parse error",
   223  			sendCall: &fakeInboundCall{
   224  				service: "foo",
   225  				caller:  "bar",
   226  				method:  "hello",
   227  				format:  tchannel.JSON,
   228  				arg2:    []byte("{not valid JSON}"),
   229  				arg3:    []byte{0x00},
   230  			},
   231  			wantStatus:        tchannel.ErrCodeBadRequest,
   232  			newResponseWriter: newHandlerWriter,
   233  			recorder:          newResponseRecorder(),
   234  			wantLogLevel:      zapcore.ErrorLevel,
   235  		},
   236  		{
   237  			desc: "arg3 reader error",
   238  			sendCall: &fakeInboundCall{
   239  				service: "foo",
   240  				caller:  "bar",
   241  				method:  "hello",
   242  				format:  tchannel.Raw,
   243  				arg2:    []byte{0x00, 0x00},
   244  				arg3:    nil,
   245  			},
   246  			wantStatus:        tchannel.ErrCodeUnexpected,
   247  			newResponseWriter: newHandlerWriter,
   248  			recorder:          newResponseRecorder(),
   249  			wantLogLevel:      zapcore.ErrorLevel,
   250  		},
   251  		{
   252  			desc: "internal error",
   253  			sendCall: &fakeInboundCall{
   254  				service: "foo",
   255  				caller:  "bar",
   256  				method:  "hello",
   257  				format:  tchannel.Raw,
   258  				arg2:    []byte{0x00, 0x00},
   259  				arg3:    []byte{0x00},
   260  			},
   261  			expectCall: func(h *transporttest.MockUnaryHandler) {
   262  				h.EXPECT().Handle(
   263  					transporttest.NewContextMatcher(t, transporttest.ContextTTL(testtime.Second)),
   264  					transporttest.NewRequestMatcher(
   265  						t, &transport.Request{
   266  							Caller:    "bar",
   267  							Service:   "foo",
   268  							Transport: "tchannel",
   269  							Encoding:  raw.Encoding,
   270  							Procedure: "hello",
   271  							Body:      bytes.NewReader([]byte{0x00}),
   272  						},
   273  					), gomock.Any(),
   274  				).Return(fmt.Errorf("great sadness"))
   275  			},
   276  			wantStatus:        tchannel.ErrCodeUnexpected,
   277  			newResponseWriter: newHandlerWriter,
   278  			recorder:          newResponseRecorder(),
   279  			wantLogLevel:      zapcore.ErrorLevel,
   280  		},
   281  		{
   282  			desc: "arg3 encode error",
   283  			sendCall: &fakeInboundCall{
   284  				service: "foo",
   285  				caller:  "bar",
   286  				method:  "hello",
   287  				format:  tchannel.JSON,
   288  				arg2:    []byte("{}"),
   289  				arg3:    []byte("{}"),
   290  			},
   291  			expectCall: func(h *transporttest.MockUnaryHandler) {
   292  				req := &transport.Request{
   293  					Caller:    "bar",
   294  					Service:   "foo",
   295  					Transport: "tchannel",
   296  					Encoding:  json.Encoding,
   297  					Procedure: "hello",
   298  					Body:      bytes.NewReader([]byte("{}")),
   299  				}
   300  				h.EXPECT().Handle(
   301  					transporttest.NewContextMatcher(t, transporttest.ContextTTL(testtime.Second)),
   302  					transporttest.NewRequestMatcher(t, req),
   303  					gomock.Any(),
   304  				).Return(
   305  					pkgerrors.ResponseBodyEncodeError(req, errors.New(
   306  						"serialization derp",
   307  					)))
   308  			},
   309  			wantStatus:        tchannel.ErrCodeBadRequest,
   310  			newResponseWriter: newHandlerWriter,
   311  			recorder:          newResponseRecorder(),
   312  			wantLogLevel:      zapcore.ErrorLevel,
   313  		},
   314  		{
   315  			desc: "handler timeout",
   316  			ctxFunc: func() (context.Context, context.CancelFunc) {
   317  				return context.WithTimeout(context.Background(), testtime.Millisecond)
   318  			},
   319  			sendCall: &fakeInboundCall{
   320  				service: "foo",
   321  				caller:  "bar",
   322  				method:  "waituntiltimeout",
   323  				format:  tchannel.Raw,
   324  				arg2:    []byte{0x00, 0x00},
   325  				arg3:    []byte{0x00},
   326  			},
   327  			expectCall: func(h *transporttest.MockUnaryHandler) {
   328  				req := &transport.Request{
   329  					Caller:    "bar",
   330  					Service:   "foo",
   331  					Transport: "tchannel",
   332  					Encoding:  raw.Encoding,
   333  					Procedure: "waituntiltimeout",
   334  					Body:      bytes.NewReader([]byte{0x00}),
   335  				}
   336  				h.EXPECT().Handle(
   337  					transporttest.NewContextMatcher(
   338  						t, transporttest.ContextTTL(testtime.Millisecond)),
   339  					transporttest.NewRequestMatcher(t, req),
   340  					gomock.Any(),
   341  				).Do(func(ctx context.Context, _ *transport.Request, _ transport.ResponseWriter) {
   342  					<-ctx.Done()
   343  				}).Return(context.DeadlineExceeded)
   344  			},
   345  			wantStatus:        tchannel.ErrCodeTimeout,
   346  			newResponseWriter: newHandlerWriter,
   347  			recorder:          newResponseRecorder(),
   348  			wantLogLevel:      zapcore.ErrorLevel,
   349  		},
   350  		{
   351  			desc: "handler panic",
   352  			sendCall: &fakeInboundCall{
   353  				service: "foo",
   354  				caller:  "bar",
   355  				method:  "panic",
   356  				format:  tchannel.Raw,
   357  				arg2:    []byte{0x00, 0x00},
   358  				arg3:    []byte{0x00},
   359  			},
   360  			expectCall: func(h *transporttest.MockUnaryHandler) {
   361  				req := &transport.Request{
   362  					Caller:    "bar",
   363  					Service:   "foo",
   364  					Transport: "tchannel",
   365  					Encoding:  raw.Encoding,
   366  					Procedure: "panic",
   367  					Body:      bytes.NewReader([]byte{0x00}),
   368  				}
   369  				h.EXPECT().Handle(
   370  					transporttest.NewContextMatcher(
   371  						t, transporttest.ContextTTL(testtime.Second)),
   372  					transporttest.NewRequestMatcher(t, req),
   373  					gomock.Any(),
   374  				).Do(func(context.Context, *transport.Request, transport.ResponseWriter) {
   375  					panic("oops I panicked!")
   376  				})
   377  			},
   378  			wantStatus:        tchannel.ErrCodeUnexpected,
   379  			newResponseWriter: newHandlerWriter,
   380  			recorder:          newResponseRecorder(),
   381  			wantLogLevel:      zapcore.ErrorLevel,
   382  			wantLogMessage:    "Unary handler panicked",
   383  		},
   384  		{
   385  			desc: "test SendSystemError() error logging",
   386  			sendCall: &fakeInboundCall{
   387  				service: "foo",
   388  				caller:  "bar",
   389  				method:  "hello",
   390  				format:  tchannel.Raw,
   391  				arg2:    nil,
   392  				arg3:    []byte{0x00},
   393  			},
   394  			wantStatus:        tchannel.ErrCodeBadRequest,
   395  			newResponseWriter: newHandlerWriter,
   396  			recorder:          newFaultyResponseRecorder(),
   397  			wantLogLevel:      zapcore.ErrorLevel,
   398  			wantLogMessage:    "SendSystemError failed",
   399  			wantErrMessage:    "SendSystemError failure",
   400  		},
   401  	}
   402  
   403  	for _, tt := range tests {
   404  		t.Run(tt.desc, func(t *testing.T) {
   405  
   406  			ctx, cancel := context.WithTimeout(context.Background(), testtime.Second)
   407  			if tt.ctx != nil {
   408  				ctx = tt.ctx
   409  			} else if tt.ctxFunc != nil {
   410  				ctx, cancel = tt.ctxFunc()
   411  			}
   412  			defer cancel()
   413  
   414  			core, logs := observer.New(zapcore.ErrorLevel)
   415  			mockCtrl := gomock.NewController(t)
   416  			defer mockCtrl.Finish()
   417  
   418  			thandler := transporttest.NewMockUnaryHandler(mockCtrl)
   419  			spec := transport.NewUnaryHandlerSpec(thandler)
   420  
   421  			if tt.expectCall != nil {
   422  				tt.expectCall(thandler)
   423  			}
   424  
   425  			resp := tt.recorder
   426  			tt.sendCall.resp = resp
   427  
   428  			router := transporttest.NewMockRouter(mockCtrl)
   429  			router.EXPECT().Choose(gomock.Any(), routertest.NewMatcher().
   430  				WithService(tt.sendCall.service).
   431  				WithProcedure(tt.sendCall.method),
   432  			).Return(spec, nil).AnyTimes()
   433  
   434  			handler{router: router, logger: zap.New(core).Named("tchannel"), newResponseWriter: tt.newResponseWriter}.handle(ctx, tt.sendCall)
   435  			err := resp.SystemError()
   436  			require.Error(t, err, "expected error for %q", tt.desc)
   437  
   438  			systemErr, isSystemErr := err.(tchannel.SystemError)
   439  			require.True(t, isSystemErr, "expected %v for %q to be a system error", err, tt.desc)
   440  			assert.Equal(t, tt.wantStatus, systemErr.Code(), tt.desc)
   441  
   442  			getLog := func() observer.LoggedEntry {
   443  				entries := logs.TakeAll()
   444  				return entries[0]
   445  			}
   446  
   447  			if tt.wantLogMessage != "" {
   448  				log := getLog()
   449  				logContext := log.ContextMap()
   450  				assert.Equal(t, tt.wantLogLevel, log.Entry.Level, "Unexpected log level")
   451  				assert.Equal(t, tt.wantLogMessage, log.Entry.Message, "Unexpected log message written")
   452  				assert.Equal(t, "tchannel", log.LoggerName, "Unexpected logger name")
   453  				if tt.wantErrMessage != "" {
   454  					assert.Equal(t, tt.wantErrMessage, logContext["error"], "Unexpected error message")
   455  				}
   456  			}
   457  		})
   458  	}
   459  }
   460  
   461  func TestResponseWriter(t *testing.T) {
   462  	yErrAborted := yarpcerrors.CodeAborted
   463  
   464  	tests := []struct {
   465  		name             string
   466  		format           tchannel.Format
   467  		apply            func(responseWriter)
   468  		arg2             map[string]string // use map since ordering isn't guaranteed
   469  		arg3             []byte
   470  		applicationError bool
   471  		headerCase       headerCase
   472  	}{
   473  		{
   474  			name:   "raw lowercase headers",
   475  			format: tchannel.Raw,
   476  			apply: func(w responseWriter) {
   477  				headers := transport.HeadersFromMap(map[string]string{"foo": "bar"})
   478  				w.AddHeaders(headers)
   479  				_, err := w.Write([]byte("hello "))
   480  				require.NoError(t, err)
   481  				_, err = w.Write([]byte("world"))
   482  				require.NoError(t, err)
   483  			},
   484  			arg2: map[string]string{"foo": "bar"},
   485  			arg3: []byte("hello world"),
   486  		},
   487  		{
   488  			name:   "raw mixed-case headers",
   489  			format: tchannel.Raw,
   490  			apply: func(w responseWriter) {
   491  				headers := transport.HeadersFromMap(map[string]string{"FoO": "bAr"})
   492  				w.AddHeaders(headers)
   493  				_, err := w.Write([]byte("hello "))
   494  				require.NoError(t, err)
   495  				_, err = w.Write([]byte("world"))
   496  				require.NoError(t, err)
   497  			},
   498  			arg2:       map[string]string{"FoO": "bAr"},
   499  			arg3:       []byte("hello world"),
   500  			headerCase: originalHeaderCase,
   501  		},
   502  		{
   503  			name:   "raw multiple writes",
   504  			format: tchannel.Raw,
   505  			apply: func(w responseWriter) {
   506  				_, err := w.Write([]byte("foo"))
   507  				require.NoError(t, err)
   508  				_, err = w.Write([]byte("bar"))
   509  				require.NoError(t, err)
   510  			},
   511  			arg2: nil,
   512  			arg3: []byte("foobar"),
   513  		},
   514  		{
   515  			name:   "json lowercase headers",
   516  			format: tchannel.JSON,
   517  			apply: func(w responseWriter) {
   518  				headers := transport.HeadersFromMap(map[string]string{"foo": "bar"})
   519  				w.AddHeaders(headers)
   520  
   521  				_, err := w.Write([]byte("{}"))
   522  				require.NoError(t, err)
   523  			},
   524  			arg2: map[string]string{"foo": "bar"},
   525  			arg3: []byte("{}"),
   526  		},
   527  		{
   528  			name:   "json mixed-case headers",
   529  			format: tchannel.JSON,
   530  			apply: func(w responseWriter) {
   531  				headers := transport.HeadersFromMap(map[string]string{"FoO": "bAr"})
   532  				w.AddHeaders(headers)
   533  
   534  				_, err := w.Write([]byte("{}"))
   535  				require.NoError(t, err)
   536  			},
   537  			arg2:       map[string]string{"FoO": "bAr"},
   538  			arg3:       []byte("{}"),
   539  			headerCase: originalHeaderCase,
   540  		},
   541  		{
   542  			name:   "json empty",
   543  			format: tchannel.JSON,
   544  			apply: func(w responseWriter) {
   545  				_, err := w.Write([]byte("{}"))
   546  				require.NoError(t, err)
   547  			},
   548  			arg2: nil,
   549  			arg3: []byte("{}"),
   550  		},
   551  		{
   552  			name:   "application error write",
   553  			format: tchannel.Raw,
   554  			apply: func(w responseWriter) {
   555  				w.SetApplicationError()
   556  				w.SetApplicationErrorMeta(
   557  					&transport.ApplicationErrorMeta{
   558  						Name:    "bAz",
   559  						Code:    &yErrAborted,
   560  						Details: "App Error Details",
   561  					},
   562  				)
   563  				_, err := w.Write([]byte("hello"))
   564  				require.NoError(t, err)
   565  			},
   566  			arg2: map[string]string{
   567  				"$rpc$-application-error-code":    "10",
   568  				"$rpc$-application-error-name":    "bAz",
   569  				"$rpc$-application-error-details": "App Error Details",
   570  			},
   571  			arg3:             []byte("hello"),
   572  			applicationError: true,
   573  		},
   574  	}
   575  
   576  	for _, tt := range tests {
   577  		t.Run(tt.name, func(t *testing.T) {
   578  
   579  			call := &fakeInboundCall{format: tt.format}
   580  			resp := newResponseRecorder()
   581  			call.resp = resp
   582  
   583  			w := newHandlerWriter(call.Response(), call.Format(), tt.headerCase)
   584  			tt.apply(w)
   585  			assert.NoError(t, w.Close())
   586  
   587  			assert.Nil(t, resp.systemErr)
   588  
   589  			// read headers as a map since ordering is not guaranteed
   590  			gotHeaders, err := readHeaders(tt.format, func() (tchannel.ArgReader, error) { return resp.arg2, nil })
   591  			require.NoError(t, err)
   592  
   593  			assert.Equal(t, tt.arg2, gotHeaders.OriginalItems(), "headers mismatch")
   594  			assert.Equal(t, tt.arg3, resp.arg3.Bytes())
   595  
   596  			if tt.applicationError {
   597  				assert.True(t, resp.applicationError, "expected an application error")
   598  			}
   599  		})
   600  	}
   601  }
   602  
   603  func TestResponseWriterFailure(t *testing.T) {
   604  	tests := []struct {
   605  		setupResp func(*responseRecorder)
   606  		messages  []string
   607  	}{
   608  		{
   609  			setupResp: func(rr *responseRecorder) {
   610  				rr.arg2 = nil
   611  			},
   612  			messages: []string{"no arg2 provided"},
   613  		},
   614  		{
   615  			setupResp: func(rr *responseRecorder) {
   616  				rr.arg3 = nil
   617  			},
   618  			messages: []string{"no arg3 provided"},
   619  		},
   620  	}
   621  
   622  	for _, tt := range tests {
   623  		resp := newResponseRecorder()
   624  		tt.setupResp(resp)
   625  
   626  		w := newHandlerWriter(resp, tchannel.Raw, canonicalizedHeaderCase)
   627  		_, err := w.Write([]byte("foo"))
   628  		assert.NoError(t, err)
   629  		_, err = w.Write([]byte("bar"))
   630  		assert.NoError(t, err)
   631  		err = w.Close()
   632  		assert.Error(t, err)
   633  		for _, msg := range tt.messages {
   634  			assert.Contains(t, err.Error(), msg)
   635  		}
   636  	}
   637  }
   638  
   639  func TestResponseWriterEmptyBodyHeaders(t *testing.T) {
   640  	res := newResponseRecorder()
   641  	w := newHandlerWriter(res, tchannel.Raw, canonicalizedHeaderCase)
   642  
   643  	w.AddHeaders(transport.NewHeaders().With("foo", "bar"))
   644  	require.NoError(t, w.Close())
   645  
   646  	assert.NotEmpty(t, res.arg2.Bytes(), "headers must not be empty")
   647  	assert.Empty(t, res.arg3.Bytes(), "body must be empty but was %#v", res.arg3.Bytes())
   648  	assert.False(t, res.applicationError, "application error must be false")
   649  }
   650  
   651  func TestGetSystemError(t *testing.T) {
   652  	tests := []struct {
   653  		giveErr  error
   654  		wantCode tchannel.SystemErrCode
   655  	}{
   656  		{
   657  			giveErr:  yarpcerrors.UnavailableErrorf("test"),
   658  			wantCode: tchannel.ErrCodeDeclined,
   659  		},
   660  		{
   661  			giveErr:  errors.New("test"),
   662  			wantCode: tchannel.ErrCodeUnexpected,
   663  		},
   664  		{
   665  			giveErr:  yarpcerrors.InvalidArgumentErrorf("test"),
   666  			wantCode: tchannel.ErrCodeBadRequest,
   667  		},
   668  		{
   669  			giveErr:  tchannel.NewSystemError(tchannel.ErrCodeBusy, "test"),
   670  			wantCode: tchannel.ErrCodeBusy,
   671  		},
   672  		{
   673  			giveErr:  yarpcerrors.Newf(yarpcerrors.Code(1235), "test"),
   674  			wantCode: tchannel.ErrCodeUnexpected,
   675  		},
   676  	}
   677  	for i, tt := range tests {
   678  		t.Run(strconv.Itoa(i), func(t *testing.T) {
   679  			gotErr := getSystemError(tt.giveErr)
   680  			tchErr, ok := gotErr.(tchannel.SystemError)
   681  			require.True(t, ok, "did not return tchannel error")
   682  			assert.Equal(t, tt.wantCode, tchErr.Code())
   683  		})
   684  	}
   685  }
   686  
   687  func TestHandlerSystemErrorLogs(t *testing.T) {
   688  	mockCtrl := gomock.NewController(t)
   689  	defer mockCtrl.Finish()
   690  
   691  	zapCore, observedLogs := observer.New(zapcore.ErrorLevel)
   692  	router := transporttest.NewMockRouter(mockCtrl)
   693  	transportHandler := &testUnaryHandler{}
   694  	spec := transport.NewUnaryHandlerSpec(transportHandler)
   695  
   696  	tchannelHandler := handler{
   697  		router:            router,
   698  		logger:            zap.New(zapCore),
   699  		newResponseWriter: newHandlerWriter,
   700  	}
   701  
   702  	router.EXPECT().Choose(gomock.Any(), gomock.Any()).Return(spec, nil).Times(4)
   703  
   704  	inboundCall := &fakeInboundCall{
   705  		service: "foo-service",
   706  		caller:  "foo-caller",
   707  		method:  "foo-method",
   708  		format:  tchannel.JSON,
   709  		arg2:    []byte{},
   710  		arg3:    []byte{},
   711  		resp:    newFaultyResponseRecorder(),
   712  	}
   713  
   714  	t.Run("client awaiting response", func(t *testing.T) {
   715  		t.Run("handler success", func(t *testing.T) {
   716  			transportHandler.reset()
   717  
   718  			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   719  			defer cancel()
   720  
   721  			tchannelHandler.handle(ctx, inboundCall)
   722  			logs := observedLogs.TakeAll()
   723  			require.Len(t, logs, 2, "unexpected number of logs")
   724  
   725  			assert.Equal(t, logs[0].Message, "SendSystemError failed", "unexpected log message")
   726  			assert.Equal(t, logs[1].Message, "responseWriter failed to close", "unexpected log message")
   727  		})
   728  
   729  		t.Run("handler error", func(t *testing.T) {
   730  			transportHandler.reset()
   731  			transportHandler.err = errors.New("handler error")
   732  
   733  			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   734  			defer cancel()
   735  
   736  			tchannelHandler.handle(ctx, inboundCall)
   737  			logs := observedLogs.TakeAll()
   738  			require.Len(t, logs, 1, "unexpected number of logs")
   739  
   740  			assert.Equal(t, logs[0].Message, "SendSystemError failed", "unexpected log message")
   741  		})
   742  	})
   743  
   744  	t.Run("client timed out", func(t *testing.T) {
   745  		t.Run("handler success", func(t *testing.T) {
   746  			transportHandler.reset()
   747  
   748  			ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
   749  			defer cancel()
   750  
   751  			transportHandler.fn = func() { <-ctx.Done() } // ensure client times out
   752  
   753  			tchannelHandler.handle(ctx, inboundCall)
   754  			require.Empty(t, observedLogs.TakeAll(), "expected no logs")
   755  		})
   756  
   757  		t.Run("handler err", func(t *testing.T) {
   758  			transportHandler.reset()
   759  			transportHandler.err = errors.New("handler error")
   760  
   761  			ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
   762  			defer cancel()
   763  
   764  			transportHandler.fn = func() { <-ctx.Done() } // ensure client times out
   765  
   766  			tchannelHandler.handle(ctx, inboundCall)
   767  			require.Empty(t, observedLogs.TakeAll(), "expected no logs")
   768  		})
   769  	})
   770  }
   771  
   772  func TestTruncatedHeader(t *testing.T) {
   773  	tests := []struct {
   774  		name         string
   775  		value        string
   776  		wantTruncate bool
   777  	}{
   778  		{
   779  			name:  "no-op",
   780  			value: "foo bar",
   781  		},
   782  		{
   783  			name:  "max",
   784  			value: strings.Repeat("a", _maxAppErrDetailsHeaderLen),
   785  		},
   786  		{
   787  			name:         "truncate",
   788  			value:        strings.Repeat("b", _maxAppErrDetailsHeaderLen*2),
   789  			wantTruncate: true,
   790  		},
   791  	}
   792  
   793  	for _, tt := range tests {
   794  		t.Run(tt.name, func(t *testing.T) {
   795  			got := truncateAppErrDetails(tt.value)
   796  
   797  			if !tt.wantTruncate {
   798  				assert.Equal(t, tt.value, got, "expected no-op")
   799  				return
   800  			}
   801  
   802  			assert.True(t, strings.HasSuffix(got, _truncatedHeaderMessage), "unexpected truncate suffix")
   803  			assert.Len(t, got, _maxAppErrDetailsHeaderLen, "did not truncate")
   804  		})
   805  	}
   806  }
   807  
   808  func TestRpcServiceHeader(t *testing.T) {
   809  	hw := &handlerWriter{}
   810  	h := handler{
   811  		headerCase: canonicalizedHeaderCase,
   812  		newResponseWriter: func(inboundCallResponse, tchannel.Format, headerCase) responseWriter {
   813  			return hw
   814  		},
   815  	}
   816  	resp := newResponseRecorder()
   817  	expectedServiceHeader := "foo"
   818  	call := &fakeInboundCall{
   819  		service: expectedServiceHeader,
   820  		resp:    resp,
   821  	}
   822  	ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
   823  	defer cancel()
   824  
   825  	h.handle(ctx, call)
   826  	assert.Equal(t, expectedServiceHeader, hw.headers.OriginalItems()[ServiceHeaderKey])
   827  
   828  	h.excludeServiceHeaderInResponse = true
   829  	hw.headers.Del(ServiceHeaderKey)
   830  	h.handle(ctx, call)
   831  	assert.Equal(t, "", hw.headers.OriginalItems()[ServiceHeaderKey])
   832  }
   833  
   834  type testUnaryHandler struct {
   835  	err error
   836  	fn  func()
   837  }
   838  
   839  func (h *testUnaryHandler) Handle(context.Context, *transport.Request, transport.ResponseWriter) error {
   840  	if h.fn != nil {
   841  		h.fn()
   842  	}
   843  	return h.err
   844  }
   845  
   846  func (h *testUnaryHandler) reset() {
   847  	h.err = nil
   848  	h.fn = nil
   849  }