// Copyright (c) 2022 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.
20 21 package http 22 23 import ( 24 "bytes" 25 "context" 26 "fmt" 27 "io/ioutil" 28 "net/http" 29 "net/http/httptest" 30 "strconv" 31 "strings" 32 "testing" 33 "time" 34 35 "github.com/golang/mock/gomock" 36 "github.com/opentracing/opentracing-go" 37 "github.com/stretchr/testify/assert" 38 "github.com/stretchr/testify/require" 39 yarpc "go.uber.org/yarpc" 40 "go.uber.org/yarpc/api/transport" 41 "go.uber.org/yarpc/api/transport/transporttest" 42 "go.uber.org/yarpc/encoding/raw" 43 "go.uber.org/yarpc/internal/routertest" 44 "go.uber.org/yarpc/yarpcerrors" 45 ) 46 47 func TestHandlerSuccess(t *testing.T) { 48 mockCtrl := gomock.NewController(t) 49 defer mockCtrl.Finish() 50 51 headers := make(http.Header) 52 headers.Set(CallerHeader, "moe") 53 headers.Set(EncodingHeader, "raw") 54 headers.Set(TTLMSHeader, "1000") 55 headers.Set(ProcedureHeader, "nyuck") 56 headers.Set(ServiceHeader, "curly") 57 headers.Set(ShardKeyHeader, "shard") 58 headers.Set(RoutingKeyHeader, "routekey") 59 headers.Set(RoutingDelegateHeader, "routedelegate") 60 headers.Set(CallerProcedureHeader, "callerprocedure") 61 62 router := transporttest.NewMockRouter(mockCtrl) 63 rpcHandler := transporttest.NewMockUnaryHandler(mockCtrl) 64 spec := transport.NewUnaryHandlerSpec(rpcHandler) 65 66 router.EXPECT().Choose(gomock.Any(), routertest.NewMatcher(). 67 WithService("curly"). 
68 WithProcedure("nyuck"), 69 ).Return(spec, nil) 70 71 rpcHandler.EXPECT().Handle( 72 transporttest.NewContextMatcher(t, 73 transporttest.ContextTTL(time.Second), 74 ), 75 transporttest.NewRequestMatcher( 76 t, &transport.Request{ 77 Caller: "moe", 78 Service: "curly", 79 Transport: "http", 80 Encoding: raw.Encoding, 81 Procedure: "nyuck", 82 ShardKey: "shard", 83 RoutingKey: "routekey", 84 RoutingDelegate: "routedelegate", 85 CallerProcedure: "callerprocedure", 86 Body: bytes.NewReader([]byte("Nyuck Nyuck")), 87 }, 88 ), 89 gomock.Any(), 90 ).Return(nil) 91 92 httpHandler := handler{router: router, tracer: &opentracing.NoopTracer{}, bothResponseError: true} 93 req := &http.Request{ 94 Method: "POST", 95 Header: headers, 96 Body: ioutil.NopCloser(bytes.NewReader([]byte("Nyuck Nyuck"))), 97 } 98 rw := httptest.NewRecorder() 99 httpHandler.ServeHTTP(rw, req) 100 code := rw.Code 101 assert.Equal(t, code, 200, "expected 200 code") 102 assert.Equal(t, rw.Body.String(), "") 103 } 104 105 func TestHandlerHeaders(t *testing.T) { 106 mockCtrl := gomock.NewController(t) 107 defer mockCtrl.Finish() 108 109 tests := []struct { 110 giveEncoding string 111 giveHeaders http.Header 112 grabHeaders map[string]struct{} 113 114 wantTTL time.Duration 115 wantHeaders map[string]string 116 }{ 117 { 118 giveEncoding: "json", 119 giveHeaders: http.Header{ 120 TTLMSHeader: {"1000"}, 121 "Rpc-Header-Foo": {"bar"}, 122 "X-Baz": {"bat"}, 123 }, 124 grabHeaders: map[string]struct{}{"x-baz": {}}, 125 wantTTL: time.Second, 126 wantHeaders: map[string]string{ 127 "foo": "bar", 128 "x-baz": "bat", 129 }, 130 }, 131 { 132 giveEncoding: "raw", 133 giveHeaders: http.Header{ 134 TTLMSHeader: {"100"}, 135 "Rpc-Foo": {"ignored"}, 136 }, 137 wantTTL: 100 * time.Millisecond, 138 wantHeaders: map[string]string{}, 139 }, 140 { 141 giveEncoding: "thrift", 142 giveHeaders: http.Header{ 143 TTLMSHeader: {"1000"}, 144 }, 145 wantTTL: time.Second, 146 wantHeaders: map[string]string{}, 147 }, 148 { 149 
giveEncoding: "proto", 150 giveHeaders: http.Header{ 151 TTLMSHeader: {"1000"}, 152 }, 153 wantTTL: time.Second, 154 wantHeaders: map[string]string{}, 155 }, 156 } 157 158 for _, tt := range tests { 159 router := transporttest.NewMockRouter(mockCtrl) 160 rpcHandler := transporttest.NewMockUnaryHandler(mockCtrl) 161 spec := transport.NewUnaryHandlerSpec(rpcHandler) 162 163 router.EXPECT().Choose(gomock.Any(), routertest.NewMatcher(). 164 WithService("service"). 165 WithProcedure("hello"), 166 ).Return(spec, nil) 167 168 httpHandler := handler{router: router, tracer: &opentracing.NoopTracer{}, grabHeaders: tt.grabHeaders, bothResponseError: true} 169 170 rpcHandler.EXPECT().Handle( 171 transporttest.NewContextMatcher(t, 172 transporttest.ContextTTL(tt.wantTTL), 173 ), 174 transporttest.NewRequestMatcher(t, 175 &transport.Request{ 176 Caller: "caller", 177 Service: "service", 178 Transport: "http", 179 Encoding: transport.Encoding(tt.giveEncoding), 180 Procedure: "hello", 181 Headers: transport.HeadersFromMap(tt.wantHeaders), 182 Body: bytes.NewReader([]byte("world")), 183 }), 184 gomock.Any(), 185 ).Return(nil) 186 187 headers := http.Header{} 188 for k, vs := range tt.giveHeaders { 189 for _, v := range vs { 190 headers.Add(k, v) 191 } 192 } 193 headers.Set(CallerHeader, "caller") 194 headers.Set(ServiceHeader, "service") 195 headers.Set(EncodingHeader, tt.giveEncoding) 196 headers.Set(ProcedureHeader, "hello") 197 198 req := &http.Request{ 199 Method: "POST", 200 Header: headers, 201 Body: ioutil.NopCloser(bytes.NewReader([]byte("world"))), 202 } 203 rw := httptest.NewRecorder() 204 httpHandler.ServeHTTP(rw, req) 205 assert.Equal(t, 200, rw.Code, "expected 200 status code") 206 assert.Equal(t, getContentType(transport.Encoding(tt.giveEncoding)), rw.Header().Get("Content-Type")) 207 } 208 } 209 210 func TestHandlerFailures(t *testing.T) { 211 mockCtrl := gomock.NewController(t) 212 defer mockCtrl.Finish() 213 214 service, procedure := "fake", "hello" 215 216 
baseHeaders := make(http.Header) 217 baseHeaders.Set(CallerHeader, "somecaller") 218 baseHeaders.Set(EncodingHeader, "raw") 219 baseHeaders.Set(TTLMSHeader, "1000") 220 baseHeaders.Set(ProcedureHeader, procedure) 221 baseHeaders.Set(ServiceHeader, service) 222 223 headersWithBadTTL := headerCopyWithout(baseHeaders, TTLMSHeader) 224 headersWithBadTTL.Set(TTLMSHeader, "not a number") 225 226 tests := []struct { 227 req *http.Request 228 229 // if we expect an error as a result of the TTL 230 errTTL bool 231 wantCode yarpcerrors.Code 232 }{ 233 { 234 req: &http.Request{Method: "GET"}, 235 wantCode: yarpcerrors.CodeNotFound, 236 }, 237 { 238 req: &http.Request{ 239 Method: "POST", 240 Header: headerCopyWithout(baseHeaders, CallerHeader), 241 }, 242 wantCode: yarpcerrors.CodeInvalidArgument, 243 }, 244 { 245 req: &http.Request{ 246 Method: "POST", 247 Header: headerCopyWithout(baseHeaders, ServiceHeader), 248 }, 249 wantCode: yarpcerrors.CodeInvalidArgument, 250 }, 251 { 252 req: &http.Request{ 253 Method: "POST", 254 Header: headerCopyWithout(baseHeaders, ProcedureHeader), 255 }, 256 wantCode: yarpcerrors.CodeInvalidArgument, 257 }, 258 { 259 req: &http.Request{ 260 Method: "POST", 261 Header: headerCopyWithout(baseHeaders, TTLMSHeader), 262 }, 263 wantCode: yarpcerrors.CodeInvalidArgument, 264 errTTL: true, 265 }, 266 { 267 req: &http.Request{ 268 Method: "POST", 269 }, 270 wantCode: yarpcerrors.CodeInvalidArgument, 271 }, 272 { 273 req: &http.Request{ 274 Method: "POST", 275 Header: headersWithBadTTL, 276 }, 277 wantCode: yarpcerrors.CodeInvalidArgument, 278 errTTL: true, 279 }, 280 } 281 282 for _, tt := range tests { 283 req := tt.req 284 if req.Body == nil { 285 req.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) 286 } 287 288 reg := transporttest.NewMockRouter(mockCtrl) 289 290 if tt.errTTL { 291 // since TTL is checked after we've determined the transport type, if we have an 292 // error with TTL it will be discovered after we read from the router 293 spec := 
transport.NewUnaryHandlerSpec(panickedHandler{}) 294 reg.EXPECT().Choose(gomock.Any(), routertest.NewMatcher(). 295 WithService(service). 296 WithProcedure(procedure), 297 ).Return(spec, nil) 298 } 299 300 h := handler{router: reg, tracer: &opentracing.NoopTracer{}, bothResponseError: true} 301 302 rw := httptest.NewRecorder() 303 h.ServeHTTP(rw, tt.req) 304 305 httpStatusCode := rw.Code 306 assert.True(t, httpStatusCode >= 400 && httpStatusCode < 500, "expected 400 level code") 307 code := statusCodeToBestCode(httpStatusCode) 308 assert.Equal(t, tt.wantCode, code) 309 assert.Equal(t, "text/plain; charset=utf8", rw.Header().Get("Content-Type")) 310 } 311 } 312 313 func TestHandlerInternalFailure(t *testing.T) { 314 mockCtrl := gomock.NewController(t) 315 defer mockCtrl.Finish() 316 317 headers := make(http.Header) 318 headers.Set(CallerHeader, "somecaller") 319 headers.Set(EncodingHeader, "raw") 320 headers.Set(TTLMSHeader, "1000") 321 headers.Set(ProcedureHeader, "hello") 322 headers.Set(ServiceHeader, "fake") 323 324 request := http.Request{ 325 Method: "POST", 326 Header: headers, 327 Body: ioutil.NopCloser(bytes.NewReader([]byte{})), 328 } 329 330 rpcHandler := transporttest.NewMockUnaryHandler(mockCtrl) 331 rpcHandler.EXPECT().Handle( 332 transporttest.NewContextMatcher(t, transporttest.ContextTTL(time.Second)), 333 transporttest.NewRequestMatcher( 334 t, &transport.Request{ 335 Caller: "somecaller", 336 Service: "fake", 337 Transport: "http", 338 Encoding: raw.Encoding, 339 Procedure: "hello", 340 Body: bytes.NewReader([]byte{}), 341 }, 342 ), 343 gomock.Any(), 344 ).Return(fmt.Errorf("great sadness")) 345 346 router := transporttest.NewMockRouter(mockCtrl) 347 spec := transport.NewUnaryHandlerSpec(rpcHandler) 348 349 router.EXPECT().Choose(gomock.Any(), routertest.NewMatcher(). 350 WithService("fake"). 
351 WithProcedure("hello"), 352 ).Return(spec, nil) 353 354 httpHandler := handler{router: router, tracer: &opentracing.NoopTracer{}, bothResponseError: true} 355 httpResponse := httptest.NewRecorder() 356 httpHandler.ServeHTTP(httpResponse, &request) 357 358 code := httpResponse.Code 359 assert.True(t, code >= 500 && code < 600, "expected 500 level response") 360 assert.Equal(t, 361 `error for service "fake" and procedure "hello": great sadness`+"\n", 362 httpResponse.Body.String()) 363 } 364 365 type panickedHandler struct{} 366 367 func (th panickedHandler) Handle(context.Context, *transport.Request, transport.ResponseWriter) error { 368 panic("oops I panicked!") 369 } 370 371 func TestHandlerPanic(t *testing.T) { 372 httpTransport := NewTransport() 373 inbound := httpTransport.NewInbound("localhost:0") 374 serverDispatcher := yarpc.NewDispatcher(yarpc.Config{ 375 Name: "yarpc-test", 376 Inbounds: yarpc.Inbounds{inbound}, 377 }) 378 serverDispatcher.Register([]transport.Procedure{ 379 { 380 Name: "panic", 381 HandlerSpec: transport.NewUnaryHandlerSpec(panickedHandler{}), 382 }, 383 }) 384 385 require.NoError(t, serverDispatcher.Start()) 386 defer serverDispatcher.Stop() 387 388 clientDispatcher := yarpc.NewDispatcher(yarpc.Config{ 389 Name: "yarpc-test-client", 390 Outbounds: yarpc.Outbounds{ 391 "yarpc-test": { 392 Unary: httpTransport.NewSingleOutbound(fmt.Sprintf("http://%s", inbound.Addr().String())), 393 }, 394 }, 395 }) 396 require.NoError(t, clientDispatcher.Start()) 397 defer clientDispatcher.Stop() 398 399 client := raw.New(clientDispatcher.ClientConfig("yarpc-test")) 400 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 401 defer cancel() 402 _, err := client.Call(ctx, "panic", []byte{}) 403 404 assert.Equal(t, yarpcerrors.CodeUnknown, yarpcerrors.FromError(err).Code()) 405 } 406 407 func headerCopyWithout(headers http.Header, names ...string) http.Header { 408 newHeaders := make(http.Header) 409 for k, vs := range headers { 410 for 
_, v := range vs { 411 newHeaders.Add(k, v) 412 } 413 } 414 415 for _, k := range names { 416 newHeaders.Del(k) 417 } 418 419 return newHeaders 420 } 421 422 func TestResponseWriter(t *testing.T) { 423 const ( 424 appErrDetails = "thrift ex message" 425 appErrName = "thrift ex name" 426 ) 427 appErrCode := yarpcerrors.CodeAborted 428 429 recorder := httptest.NewRecorder() 430 writer := newResponseWriter(recorder) 431 432 headers := transport.HeadersFromMap(map[string]string{ 433 "foo": "bar", 434 "shard-key": "123", 435 }) 436 writer.AddHeaders(headers) 437 438 writer.SetApplicationErrorMeta(&transport.ApplicationErrorMeta{ 439 Details: appErrDetails, 440 Name: appErrName, 441 Code: &appErrCode, 442 }) 443 444 _, err := writer.Write([]byte("hello")) 445 require.NoError(t, err) 446 writer.Close(http.StatusOK) 447 448 assert.Equal(t, "bar", recorder.Header().Get("rpc-header-foo")) 449 assert.Equal(t, "123", recorder.Header().Get("rpc-header-shard-key")) 450 assert.Equal(t, "hello", recorder.Body.String()) 451 452 assert.Equal(t, appErrDetails, recorder.Header().Get(_applicationErrorDetailsHeader)) 453 assert.Equal(t, appErrName, recorder.Header().Get(_applicationErrorNameHeader)) 454 assert.Equal(t, strconv.Itoa(int(appErrCode)), recorder.Header().Get(_applicationErrorCodeHeader)) 455 } 456 457 func TestTruncatedHeader(t *testing.T) { 458 tests := []struct { 459 name string 460 value string 461 wantTruncate bool 462 }{ 463 { 464 name: "no-op", 465 value: "foo bar", 466 }, 467 { 468 name: "max", 469 value: strings.Repeat("a", _maxAppErrDetailsHeaderLen), 470 }, 471 { 472 name: "truncate", 473 value: strings.Repeat("b", _maxAppErrDetailsHeaderLen*2), 474 wantTruncate: true, 475 }, 476 } 477 478 for _, tt := range tests { 479 t.Run(tt.name, func(t *testing.T) { 480 got := truncateAppErrDetails(tt.value) 481 482 if !tt.wantTruncate { 483 assert.Equal(t, tt.value, got, "expected no-op") 484 return 485 } 486 487 assert.True(t, strings.HasSuffix(got, 
_truncatedHeaderMessage), "unexpected truncate suffix") 488 assert.Len(t, got, _maxAppErrDetailsHeaderLen, "did not truncate") 489 }) 490 } 491 }