go.uber.org/yarpc@v1.72.1/transport/grpc/integration_test.go (about) 1 // Copyright (c) 2022 Uber Technologies, Inc. 2 // 3 // Permission is hereby granted, free of charge, to any person obtaining a copy 4 // of this software and associated documentation files (the "Software"), to deal 5 // in the Software without restriction, including without limitation the rights 6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 // copies of the Software, and to permit persons to whom the Software is 8 // furnished to do so, subject to the following conditions: 9 // 10 // The above copyright notice and this permission notice shall be included in 11 // all copies or substantial portions of the Software. 12 // 13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 // THE SOFTWARE. 20 21 package grpc 22 23 import ( 24 "bytes" 25 "compress/gzip" 26 "context" 27 "errors" 28 "io" 29 "math" 30 "net" 31 "strings" 32 "sync/atomic" 33 "testing" 34 "time" 35 36 "github.com/gogo/protobuf/proto" 37 gogostatus "github.com/gogo/status" 38 "github.com/opentracing/opentracing-go" 39 "github.com/stretchr/testify/assert" 40 "github.com/stretchr/testify/require" 41 "go.uber.org/goleak" 42 "go.uber.org/multierr" 43 "go.uber.org/yarpc" 44 "go.uber.org/yarpc/api/transport" 45 yarpctls "go.uber.org/yarpc/api/transport/tls" 46 "go.uber.org/yarpc/encoding/protobuf" 47 "go.uber.org/yarpc/internal/clientconfig" 48 "go.uber.org/yarpc/internal/grpcctx" 49 "go.uber.org/yarpc/internal/prototest/example" 50 "go.uber.org/yarpc/internal/prototest/examplepb" 51 "go.uber.org/yarpc/internal/testtime" 52 intyarpcerrors "go.uber.org/yarpc/internal/yarpcerrors" 53 "go.uber.org/yarpc/peer" 54 "go.uber.org/yarpc/peer/hostport" 55 "go.uber.org/yarpc/pkg/procedure" 56 "go.uber.org/yarpc/transport/internal/tls/testscenario" 57 "go.uber.org/yarpc/yarpcerrors" 58 "go.uber.org/zap/zaptest" 59 "google.golang.org/grpc" 60 "google.golang.org/grpc/codes" 61 "google.golang.org/grpc/credentials" 62 "google.golang.org/grpc/status" 63 ) 64 65 func TestYARPCBasic(t *testing.T) { 66 t.Parallel() 67 te := testEnvOptions{ 68 TransportOptions: []TransportOption{ 69 Tracer(opentracing.NoopTracer{}), 70 }, 71 } 72 te.do(t, func(t *testing.T, e *testEnv) { 73 _, err := e.GetValueYARPC(context.Background(), "foo") 74 assert.Equal(t, yarpcerrors.Newf(yarpcerrors.CodeNotFound, "foo"), err) 75 assert.NoError(t, e.SetValueYARPC(context.Background(), "foo", "bar")) 76 value, err := e.GetValueYARPC(context.Background(), "foo") 77 assert.NoError(t, err) 78 assert.Equal(t, "bar", value) 79 }) 80 } 81 82 func TestGRPCBasic(t *testing.T) { 83 t.Parallel() 84 te := testEnvOptions{} 85 te.do(t, func(t *testing.T, e *testEnv) { 86 _, err := e.GetValueGRPC(context.Background(), "foo") 87 assert.Equal(t, status.Error(codes.NotFound, "foo"), err) 88 assert.NoError(t, e.SetValueGRPC(context.Background(), "foo", "bar")) 89 value, err := e.GetValueGRPC(context.Background(), "foo") 90 assert.NoError(t, err) 91 assert.Equal(t, "bar", value) 92 }) 93 } 94 95 func TestYARPCWellKnownError(t *testing.T) { 96 t.Parallel() 97 te := testEnvOptions{} 98 te.do(t, func(t *testing.T, e *testEnv) { 99 e.KeyValueYARPCServer.SetNextError(status.Error(codes.FailedPrecondition, "bar 1")) 100 err := e.SetValueYARPC(context.Background(), "foo", "bar") 101 assert.Equal(t, yarpcerrors.Newf(yarpcerrors.CodeFailedPrecondition, "bar 1"), err) 102 }) 103 } 104 105 func TestYARPCNamedError(t *testing.T) { 106 t.Parallel() 107 te := testEnvOptions{} 108 te.do(t, func(t *testing.T, e *testEnv) { 109 e.KeyValueYARPCServer.SetNextError(intyarpcerrors.NewWithNamef(yarpcerrors.CodeUnknown, "bar", "baz 1")) 110 err := e.SetValueYARPC(context.Background(), "foo", "bar") 111 assert.Equal(t, intyarpcerrors.NewWithNamef(yarpcerrors.CodeUnknown, "bar", "baz 1"), err) 112 }) 113 } 114 115 func TestYARPCNamedErrorNoMessage(t *testing.T) { 116 t.Parallel() 117 te := testEnvOptions{} 118 te.do(t, func(t *testing.T, e *testEnv) { 119 e.KeyValueYARPCServer.SetNextError(intyarpcerrors.NewWithNamef(yarpcerrors.CodeUnknown, "bar", "")) 120 err := e.SetValueYARPC(context.Background(), "foo", "bar") 121 assert.Equal(t, intyarpcerrors.NewWithNamef(yarpcerrors.CodeUnknown, "bar", ""), err) 122 }) 123 } 124 125 func TestYARPCErrorWithDetails(t *testing.T) { 126 t.Parallel() 127 te := testEnvOptions{} 128 te.do(t, func(t *testing.T, e *testEnv) { 129 e.KeyValueYARPCServer.SetNextError(protobuf.NewError(yarpcerrors.CodeNotFound, "hello world", protobuf.WithErrorDetails(&examplepb.SetValueResponse{}))) 130 err := e.SetValueYARPC(context.Background(), "foo", "bar") 131 require.Len(t, protobuf.GetErrorDetails(err), 1) 132 assert.Equal(t, protobuf.GetErrorDetails(err)[0], &examplepb.SetValueResponse{}) 133 assert.Equal(t, yarpcerrors.FromError(err).Code(), yarpcerrors.CodeNotFound) 134 assert.Equal(t, yarpcerrors.FromError(err).Message(), "hello world") 135 }) 136 } 137 138 func TestGRPCWellKnownError(t *testing.T) { 139 t.Parallel() 140 te := testEnvOptions{} 141 te.do(t, func(t *testing.T, e *testEnv) { 142 e.KeyValueYARPCServer.SetNextError(status.Error(codes.FailedPrecondition, "bar 1")) 143 err := e.SetValueGRPC(context.Background(), "foo", "bar") 144 assert.Equal(t, status.Error(codes.FailedPrecondition, "bar 1"), err) 145 }) 146 } 147 148 func TestGRPCNamedError(t *testing.T) { 149 t.Parallel() 150 te := testEnvOptions{} 151 te.do(t, func(t *testing.T, e *testEnv) { 152 e.KeyValueYARPCServer.SetNextError(intyarpcerrors.NewWithNamef(yarpcerrors.CodeUnknown, "bar", "baz 1")) 153 err := e.SetValueGRPC(context.Background(), "foo", "bar") 154 assert.Equal(t, status.Error(codes.Unknown, "bar: baz 1"), err) 155 }) 156 } 157 158 func TestGRPCNamedErrorNoMessage(t *testing.T) { 159 t.Parallel() 160 te := testEnvOptions{} 161 te.do(t, func(t *testing.T, e *testEnv) { 162 e.KeyValueYARPCServer.SetNextError(intyarpcerrors.NewWithNamef(yarpcerrors.CodeUnknown, "bar", "")) 163 err := e.SetValueGRPC(context.Background(), "foo", "bar") 164 assert.Equal(t, status.Error(codes.Unknown, "bar"), err) 165 }) 166 } 167 168 func TestGRPCErrorWithDetails(t *testing.T) { 169 t.Parallel() 170 te := testEnvOptions{} 171 te.do(t, func(t *testing.T, e *testEnv) { 172 e.KeyValueYARPCServer.SetNextError(protobuf.NewError(yarpcerrors.CodeNotFound, "hello world", protobuf.WithErrorDetails(&examplepb.SetValueResponse{}))) 173 err := e.SetValueGRPC(context.Background(), "foo", "bar") 174 st := gogostatus.Convert(err) 175 assert.Equal(t, st.Code(), codes.NotFound) 176 assert.Equal(t, st.Message(), "hello world") 177 assert.Equal(t, st.Details(), []interface{}{&examplepb.SetValueResponse{}}) 178 }) 179 } 180 181 func TestYARPCResponseAndError(t *testing.T) { 182 t.Parallel() 183 te := testEnvOptions{} 184 te.do(t, func(t *testing.T, e *testEnv) { 185 err := e.SetValueYARPC(context.Background(), "foo", "bar") 186 assert.NoError(t, err) 187 e.KeyValueYARPCServer.SetNextError(status.Error(codes.FailedPrecondition, "bar 1")) 188 value, err := e.GetValueYARPC(context.Background(), "foo") 189 assert.Equal(t, "bar", value) 190 assert.Equal(t, yarpcerrors.Newf(yarpcerrors.CodeFailedPrecondition, "bar 1"), err) 191 }) 192 } 193 194 func TestGRPCResponseAndError(t *testing.T) { 195 t.Skip("grpc-go clients do not support returning both a response and error as of now") 196 t.Parallel() 197 te := testEnvOptions{} 198 te.do(t, func(t *testing.T, e *testEnv) { 199 err := e.SetValueGRPC(context.Background(), "foo", "bar") 200 assert.NoError(t, err) 201 e.KeyValueYARPCServer.SetNextError(status.Error(codes.FailedPrecondition, "bar 1")) 202 value, err := e.GetValueGRPC(context.Background(), "foo") 203 assert.Equal(t, "bar", value) 204 assert.Equal(t, status.Error(codes.FailedPrecondition, "bar 1"), err) 205 }) 206 } 207 208 func TestYARPCMaxMsgSize(t *testing.T) { 209 t.Parallel() 210 value := strings.Repeat("a", defaultServerMaxRecvMsgSize+1) 211 t.Run("too big", func(t *testing.T) { 212 te := testEnvOptions{} 213 te.do(t, func(t *testing.T, e *testEnv) { 214 ctx, cancel := context.WithTimeout(context.Background(), testtime.Second*5) 215 defer cancel() 216 217 err := e.SetValueYARPC(ctx, "foo", value) 218 219 assert.Equal(t, yarpcerrors.CodeResourceExhausted.String(), yarpcerrors.FromError(err).Code().String()) 220 }) 221 }) 222 t.Run("just right", func(t *testing.T) { 223 te := testEnvOptions{ 224 TransportOptions: []TransportOption{ 225 ClientMaxRecvMsgSize(math.MaxInt32), 226 ClientMaxSendMsgSize(math.MaxInt32), 227 ServerMaxRecvMsgSize(math.MaxInt32), 228 ServerMaxSendMsgSize(math.MaxInt32), 229 }, 230 } 231 te.do(t, func(t *testing.T, e *testEnv) { 232 ctx, cancel := context.WithTimeout(context.Background(), testtime.Second*5) 233 defer cancel() 234 235 if assert.NoError(t, e.SetValueYARPC(ctx, "foo", value)) { 236 getValue, err := e.GetValueYARPC(ctx, "foo") 237 assert.NoError(t, err) 238 assert.Equal(t, value, getValue) 239 } 240 }) 241 }) 242 } 243 244 func TestLargeEcho(t *testing.T) { 245 t.Parallel() 246 value := strings.Repeat("a", 32768) 247 te := testEnvOptions{} 248 te.do(t, func(t *testing.T, e *testEnv) { 249 if assert.NoError(t, e.SetValueYARPC(context.Background(), "foo", value)) { 250 getValue, err := e.GetValueYARPC(context.Background(), "foo") 251 assert.NoError(t, err) 252 assert.Equal(t, value, getValue) 253 } 254 }) 255 } 256 257 func TestApplicationErrorPropagation(t *testing.T) { 258 t.Parallel() 259 te := testEnvOptions{} 260 te.do(t, func(t *testing.T, e *testEnv) { 261 response, err := e.Call( 262 context.Background(), 263 "GetValue", 264 &examplepb.GetValueRequest{Key: "foo"}, 265 protobuf.Encoding, 266 transport.Headers{}, 267 ) 268 require.Equal(t, yarpcerrors.NotFoundErrorf("foo"), err) 269 require.True(t, response.ApplicationError) 270 271 response, err = e.Call( 272 context.Background(), 273 "SetValue", 274 &examplepb.SetValueRequest{Key: "foo", Value: "hello"}, 275 protobuf.Encoding, 276 transport.Headers{}, 277 ) 278 require.NoError(t, err) 279 require.False(t, response.ApplicationError) 280 281 response, err = e.Call( 282 context.Background(), 283 "GetValue", 284 &examplepb.GetValueRequest{Key: "foo"}, 285 "bad_encoding", 286 transport.Headers{}, 287 ) 288 require.True(t, yarpcerrors.IsInvalidArgument(err)) 289 require.False(t, response.ApplicationError) 290 }) 291 } 292 293 func TestCustomContextDial(t *testing.T) { 294 t.Parallel() 295 errMsg := "my custom dialer error" 296 contextDial := func(context.Context, string) (net.Conn, error) { 297 return nil, errors.New(errMsg) 298 } 299 300 te := testEnvOptions{ 301 DialOptions: []DialOption{ContextDialer(contextDial)}, 302 } 303 te.do(t, func(t *testing.T, e *testEnv) { 304 err := e.SetValueYARPC(context.Background(), "foo", "bar") 305 require.Error(t, err) 306 assert.Contains(t, err.Error(), errMsg) 307 }) 308 } 309 310 // TestGRPCCompression aims to test the compression when both, the client and 311 // the server has the same compressors registered and have the same compressor 312 // enabled. 313 func TestGRPCCompression(t *testing.T) { 314 tagsCompression := map[string]string{"stage": "compress"} 315 tagsDecompression := map[string]string{"stage": "decompress"} 316 317 tests := []struct { 318 testEnvOptions 319 320 msg string 321 compressor transport.Compressor 322 wantErr string 323 wantMetrics []metric 324 }{ 325 { 326 msg: "no compression", 327 }, 328 { 329 msg: "fail compression of request", 330 compressor: _badCompressor, 331 wantErr: "code:internal message:grpc: error while compressing: assert.AnError general error for testing", 332 wantMetrics: []metric{ 333 {0, tagsCompression}, 334 }, 335 }, 336 { 337 msg: "fail decompression of request", 338 compressor: _badDecompressor, 339 wantErr: "code:internal message:grpc: failed to decompress the received message assert.AnError general error for testing", 340 wantMetrics: []metric{ 341 {32777, tagsCompression}, 342 {0, tagsDecompression}, 343 }, 344 }, 345 { 346 msg: "ok, dummy compression", 347 compressor: _goodCompressor, 348 wantMetrics: []metric{ 349 {32777, tagsCompression}, 350 {32777, tagsDecompression}, 351 {0, tagsCompression}, 352 {5, tagsCompression}, 353 {5, tagsDecompression}, 354 {32772, tagsCompression}, 355 {32772, tagsDecompression}, 356 }, 357 }, 358 { 359 msg: "ok, gzip compression", 360 compressor: _gzipCompressor, 361 wantMetrics: []metric{ 362 {82, tagsCompression}, 363 {82, tagsDecompression}, 364 {23, tagsCompression}, 365 {23, tagsDecompression}, 366 {29, tagsCompression}, 367 {29, tagsDecompression}, 368 {75, tagsCompression}, 369 {75, tagsDecompression}, 370 }, 371 }, 372 } 373 374 for _, tt := range tests { 375 tt := tt 376 t.Run(tt.msg, func(t *testing.T) { 377 _metrics.reset() 378 379 tt.testEnvOptions.DialOptions = []DialOption{Compressor(tt.compressor)} 380 tt.do(t, func(t *testing.T, e *testEnv) { 381 value := strings.Repeat("a", 32*1024) 382 err := e.SetValueYARPC(context.Background(), "foo", value) 383 if tt.wantErr != "" { 384 assert.Error(t, err) 385 assert.EqualError(t, err, tt.wantErr) 386 } else if assert.NoError(t, err) { 387 getValue, err := e.GetValueYARPC(context.Background(), "foo") 388 require.NoError(t, err) 389 assert.Equal(t, value, getValue) 390 } 391 }) 392 393 compressor := "" 394 if tt.compressor != nil { 395 compressor = tt.compressor.Name() 396 } 397 assert.Equal(t, newMetrics(tt.wantMetrics, map[string]string{ 398 "compressor": compressor, 399 }), _metrics) 400 }) 401 } 402 } 403 404 func TestTLSWithYARPCAndGRPC(t *testing.T) { 405 tests := []struct { 406 name string 407 clientValidity time.Duration 408 serverValidity time.Duration 409 wantErr bool 410 }{ 411 { 412 name: "valid certs both sides", 413 clientValidity: time.Minute, 414 serverValidity: time.Minute, 415 }, 416 { 417 name: "invalid server cert", 418 clientValidity: time.Minute, 419 serverValidity: -1, 420 wantErr: true, 421 }, 422 { 423 name: "invalid client cert", 424 clientValidity: -1, 425 serverValidity: time.Minute, 426 wantErr: true, 427 }, 428 } 429 430 for _, tt := range tests { 431 t.Run(tt.name, func(t *testing.T) { 432 scenario := testscenario.Create(t, tt.clientValidity, tt.serverValidity) 433 te := testEnvOptions{ 434 InboundOptions: []InboundOption{InboundCredentials(credentials.NewTLS(scenario.ServerTLSConfig()))}, 435 DialOptions: []DialOption{DialerCredentials(credentials.NewTLS(scenario.ClientTLSConfig()))}, 436 } 437 te.do(t, func(t *testing.T, e *testEnv) { 438 err := e.SetValueYARPC(context.Background(), "foo", "bar") 439 if tt.wantErr { 440 assert.Error(t, err) 441 } else { 442 assert.NoError(t, err) 443 } 444 445 err = e.SetValueGRPC(context.Background(), "foo", "bar") 446 if tt.wantErr { 447 assert.Error(t, err) 448 } else { 449 assert.NoError(t, err) 450 } 451 }) 452 }) 453 } 454 } 455 456 // TestCompressionWithMultipleOutbounds creates multiple outbound for the 457 // same hostport where one outbound has compression enabled. 458 // Validates compression is applied for the outbound with compression enabled 459 // and rest of the outbounds are still uncompressed. 460 func TestCompressionWithMultipleOutbounds(t *testing.T) { 461 env, err := newTestEnv(t, nil, nil, nil, nil) 462 require.NoError(t, err) 463 defer func() { assert.NoError(t, env.Close()) }() 464 465 chooser := peer.NewSingle(hostport.Identify(env.Inbound.Addr().String()), env.Transport.NewDialer()) 466 compressedOutbound := env.Transport.NewOutbound(chooser, OutboundCompressor(_goodCompressor)) 467 require.NoError(t, compressedOutbound.Start()) 468 defer compressedOutbound.Stop() 469 470 caller := "example-client" 471 service := "example" 472 clientConfig := clientconfig.MultiOutbound( 473 caller, 474 service, 475 transport.Outbounds{ 476 ServiceName: caller, 477 Unary: compressedOutbound, 478 }, 479 ) 480 compressedClient := examplepb.NewKeyValueYARPCClient(clientConfig) 481 482 ctx, cancel := context.WithTimeout(context.Background(), testtime.Second*5) 483 defer cancel() 484 485 // Send request over uncompressed outbound and assert compression metric 486 // is empty. 487 _metrics.reset() 488 require.NoError(t, env.SetValueYARPC(ctx, "foo", strings.Repeat("a", 32*1024))) 489 assert.Equal(t, &metricCollection{metrics: []metric{}}, _metrics) 490 491 // Send request over compressed outbound and assert compression metric 492 // is seen. 493 _metrics.reset() 494 _, err = compressedClient.SetValue(ctx, &examplepb.SetValueRequest{Key: "foo", Value: strings.Repeat("a", 32*1024)}) 495 require.NoError(t, err) 496 wantMetric := []metric{ 497 {32777, map[string]string{"stage": "compress"}}, 498 {32777, map[string]string{"stage": "decompress"}}, 499 {0, map[string]string{"stage": "compress"}}, 500 } 501 assert.Equal(t, newMetrics(wantMetric, map[string]string{ 502 "compressor": _goodCompressor.name, 503 }), _metrics) 504 } 505 506 func TestGRPCHeaderListSize(t *testing.T) { 507 tests := []struct { 508 desc string 509 options []TransportOption 510 headerSize int 511 errorMsg string 512 }{ 513 { 514 desc: "default_setting", 515 headerSize: 1024, 516 }, 517 { 518 desc: "limit_server_header_size", 519 headerSize: 1024, 520 options: []TransportOption{ServerMaxHeaderListSize(1000)}, 521 errorMsg: "header list size to send violates the maximum size (1000 bytes) set by server", 522 }, 523 { 524 desc: "limit_client_header_size", 525 headerSize: 1024, 526 options: []TransportOption{ClientMaxHeaderListSize(1000)}, 527 errorMsg: "stream terminated", 528 }, 529 { 530 desc: "allow_large_header_size", 531 headerSize: 1024 * 1024 * 1, // 1MB 532 options: []TransportOption{ServerMaxHeaderListSize(1024 * 1024 * 2), ClientMaxHeaderListSize(1024 * 1024 * 2)}, 533 }, 534 } 535 536 for _, tt := range tests { 537 t.Run(tt.desc, func(t *testing.T) { 538 headerVal := make([]byte, tt.headerSize) 539 // Set valid ASCII as grpc header cannot be a 0 byte slice. 540 for i := 0; i < tt.headerSize; i++ { 541 headerVal[i] = 'a' 542 } 543 te := testEnvOptions{ 544 TransportOptions: tt.options, 545 } 546 te.do(t, func(t *testing.T, e *testEnv) { 547 var resHeaders map[string]string 548 // Setting longer timeout as CI timesout on large payloads. 549 ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) 550 defer cancel() 551 552 err := e.SetValueYARPC(ctx, "foo", "bar", yarpc.ResponseHeaders(&resHeaders), yarpc.WithHeader("test-header", string(headerVal))) 553 if tt.errorMsg != "" { 554 require.Error(t, err) 555 assert.Contains(t, err.Error(), tt.errorMsg) 556 return 557 } 558 assert.NoError(t, err) 559 assert.Equal(t, resHeaders["test-header"], string(headerVal)) 560 }) 561 }) 562 } 563 } 564 565 func TestMuxTLS(t *testing.T) { 566 defer goleak.VerifyNone(t) 567 tests := []struct { 568 name string 569 isClientTLS bool 570 }{ 571 { 572 name: "plaintext_client", 573 isClientTLS: false, 574 }, 575 { 576 name: "tls_client", 577 isClientTLS: true, 578 }, 579 } 580 for _, tt := range tests { 581 t.Run(tt.name, func(t *testing.T) { 582 scenario := testscenario.Create(t, time.Minute, time.Minute) 583 var dialOptions []DialOption 584 if tt.isClientTLS { 585 dialOptions = append(dialOptions, DialerCredentials(credentials.NewTLS(scenario.ClientTLSConfig()))) 586 } 587 588 te := testEnvOptions{ 589 InboundOptions: []InboundOption{InboundTLSConfiguration(scenario.ServerTLSConfig()), InboundTLSMode(yarpctls.Permissive)}, 590 DialOptions: dialOptions, 591 } 592 te.do(t, func(t *testing.T, e *testEnv) { 593 err := e.SetValueYARPC(context.Background(), "foo", "bar") 594 assert.NoError(t, err) 595 596 err = e.SetValueGRPC(context.Background(), "foo", "bar") 597 assert.NoError(t, err) 598 }) 599 }) 600 } 601 } 602 603 func TestOutboundTLS(t *testing.T) { 604 defer goleak.VerifyNone(t) 605 scenario := testscenario.Create(t, time.Minute, time.Minute) 606 607 tests := []struct { 608 desc string 609 withCustomDialer bool 610 }{ 611 {desc: "without_custom_dialer", withCustomDialer: false}, 612 {desc: "with_custom_dialer", withCustomDialer: true}, 613 } 614 for _, tt := range tests { 615 t.Run(tt.desc, func(t *testing.T) { 616 dialOpts := []DialOption{ 617 DialerTLSConfig(scenario.ClientTLSConfig()), 618 } 619 // This is used for asserting if custom dialer is invoked. 620 var invokedCustomDialer int32 621 if tt.withCustomDialer { 622 dialOpts = append(dialOpts, ContextDialer(func(ctx context.Context, s string) (net.Conn, error) { 623 // Avoid write race warning as concurrent dialers will be 624 // invoked as two gRPC clients are created below. 625 atomic.AddInt32(&invokedCustomDialer, 1) 626 return (&net.Dialer{}).DialContext(ctx, "tcp", s) 627 })) 628 } 629 te := testEnvOptions{ 630 InboundOptions: []InboundOption{InboundTLSConfiguration(scenario.ServerTLSConfig()), InboundTLSMode(yarpctls.Permissive)}, 631 DialOptions: dialOpts, 632 } 633 te.do(t, func(t *testing.T, e *testEnv) { 634 err := e.SetValueYARPC(context.Background(), "foo", "bar") 635 assert.NoError(t, err) 636 637 err = e.SetValueGRPC(context.Background(), "foo", "bar") 638 assert.NoError(t, err) 639 }) 640 if tt.withCustomDialer { 641 assert.True(t, invokedCustomDialer > 0) 642 } 643 }) 644 } 645 } 646 647 type metricCollection struct { 648 metrics []metric 649 } 650 651 func (c *metricCollection) reset() { 652 c.metrics = c.metrics[:0] 653 } 654 655 func newMetrics(metrics []metric, tags map[string]string) *metricCollection { 656 c := metricCollection{ 657 metrics: make([]metric, len(metrics)), 658 } 659 for i, m := range metrics { 660 c.metrics[i] = metric{ 661 bytes: m.bytes, 662 tags: map[string]string{}, 663 } 664 for key, value := range m.tags { 665 c.metrics[i].tags[key] = value 666 } 667 for key, value := range tags { 668 c.metrics[i].tags[key] = value 669 } 670 } 671 return &c 672 } 673 674 type metric struct { 675 bytes int 676 tags map[string]string 677 } 678 679 func (m *metric) Increment(value int) { 680 m.bytes += value 681 } 682 683 // new creates a new metrics data point and passes returns it as one element slice 684 func (c *metricCollection) new(stage, compressor string) *metric { 685 l := len(c.metrics) 686 c.metrics = append(c.metrics, metric{ 687 bytes: 0, 688 tags: map[string]string{ 689 "compressor": compressor, 690 "stage": stage, 691 }, 692 }) 693 return &c.metrics[l] 694 } 695 696 type counter interface { 697 Increment(value int) 698 } 699 700 type testCompressor struct { 701 name string 702 metrics *metricCollection 703 comperr error 704 decomperr error 705 enableGZip bool 706 } 707 708 type testCompressorBehavior int 709 710 const ( 711 testCompressorOk = 1 << iota 712 testCompressorFailToCompress 713 testCompressorFailToDecompress 714 testCompressorGzip 715 ) 716 717 func newCompressor(name string, behavior testCompressorBehavior, metrics *metricCollection) *testCompressor { 718 comp := testCompressor{ 719 name: name, 720 metrics: metrics, 721 } 722 723 if behavior&testCompressorFailToCompress != 0 { 724 comp.comperr = assert.AnError 725 } 726 727 if behavior&testCompressorFailToDecompress != 0 { 728 comp.decomperr = assert.AnError 729 } 730 731 if behavior&testCompressorGzip != 0 { 732 comp.enableGZip = true 733 } 734 735 return &comp 736 } 737 738 func (c *testCompressor) Name() string { return c.name } 739 740 func (c *testCompressor) Compress(w io.Writer) (io.WriteCloser, error) { 741 metered := byteMeter{ 742 Writer: w, 743 counter: c.metrics.new("compress", c.name), 744 } 745 746 if c.enableGZip { 747 return gzip.NewWriter(&metered), nil 748 } 749 return &metered, c.comperr 750 } 751 752 func (c *testCompressor) Decompress(r io.Reader) (io.ReadCloser, error) { 753 metered := byteMeter{ 754 Reader: r, 755 counter: c.metrics.new("decompress", c.name), 756 } 757 758 if c.enableGZip { 759 return gzip.NewReader(&metered) 760 } 761 762 return &metered, c.decomperr 763 } 764 765 // byteMeter is a test type wrapper that counts the number of bytes transferred within the compressors. 766 type byteMeter struct { 767 io.Writer 768 io.Reader 769 counter counter 770 } 771 772 func (m *byteMeter) Write(p []byte) (int, error) { 773 m.counter.Increment(len(p)) 774 return m.Writer.Write(p) 775 } 776 777 func (m *byteMeter) Read(p []byte) (int, error) { 778 l, err := m.Reader.Read(p) 779 m.counter.Increment(l) 780 return l, err 781 } 782 783 func (m *byteMeter) Close() error { return nil } 784 785 type testEnv struct { 786 Caller string 787 Service string 788 Transport *Transport 789 Inbound *Inbound 790 Outbound *Outbound 791 ClientConn *grpc.ClientConn 792 ContextWrapper *grpcctx.ContextWrapper 793 ClientConfig transport.ClientConfig 794 Procedures []transport.Procedure 795 KeyValueGRPCClient examplepb.KeyValueClient 796 KeyValueYARPCClient examplepb.KeyValueYARPCClient 797 KeyValueYARPCServer *example.KeyValueYARPCServer 798 } 799 800 type testEnvOptions struct { 801 TransportOptions []TransportOption 802 InboundOptions []InboundOption 803 OutboundOptions []OutboundOption 804 DialOptions []DialOption 805 } 806 807 func (te *testEnvOptions) do(t *testing.T, f func(*testing.T, *testEnv)) { 808 testEnv, err := newTestEnv( 809 t, 810 te.TransportOptions, 811 te.InboundOptions, 812 te.OutboundOptions, 813 te.DialOptions, 814 ) 815 require.NoError(t, err) 816 defer func() { 817 assert.NoError(t, testEnv.Close()) 818 }() 819 f(t, testEnv) 820 } 821 822 func newTestEnv( 823 t *testing.T, 824 transportOptions []TransportOption, 825 inboundOptions []InboundOption, 826 outboundOptions []OutboundOption, 827 dialOptions []DialOption, 828 ) (_ *testEnv, err error) { 829 keyValueYARPCServer := example.NewKeyValueYARPCServer() 830 procedures := examplepb.BuildKeyValueYARPCProcedures(keyValueYARPCServer) 831 testRouter := newTestRouter(procedures) 832 833 listener, err := net.Listen("tcp", "127.0.0.1:0") 834 if err != nil { 835 return nil, err 836 } 837 838 logger := zaptest.NewLogger(t) 839 transportOptions = append(transportOptions, Logger(logger)) 840 trans := NewTransport(transportOptions...) 841 inbound := trans.NewInbound(listener, inboundOptions...) 842 inbound.SetRouter(testRouter) 843 chooser := peer.NewSingle(hostport.Identify(listener.Addr().String()), trans.NewDialer(dialOptions...)) 844 outbound := trans.NewOutbound(chooser, outboundOptions...) 845 846 if err := trans.Start(); err != nil { 847 return nil, err 848 } 849 defer func() { 850 if err != nil { 851 err = multierr.Append(err, trans.Stop()) 852 } 853 }() 854 855 if err := inbound.Start(); err != nil { 856 return nil, err 857 } 858 defer func() { 859 if err != nil { 860 err = multierr.Append(err, inbound.Stop()) 861 } 862 }() 863 864 if err := outbound.Start(); err != nil { 865 return nil, err 866 } 867 defer func() { 868 if err != nil { 869 err = multierr.Append(err, outbound.Stop()) 870 } 871 }() 872 873 var clientConn *grpc.ClientConn 874 875 clientConn, err = grpc.Dial(listener.Addr().String(), newDialOptions(dialOptions).grpcOptions(trans)...) 876 if err != nil { 877 return nil, err 878 } 879 keyValueClient := examplepb.NewKeyValueClient(clientConn) 880 881 caller := "example-client" 882 service := "example" 883 clientConfig := clientconfig.MultiOutbound( 884 caller, 885 service, 886 transport.Outbounds{ 887 ServiceName: caller, 888 Unary: outbound, 889 }, 890 ) 891 keyValueYARPCClient := examplepb.NewKeyValueYARPCClient(clientConfig) 892 893 contextWrapper := grpcctx.NewContextWrapper(). 894 WithCaller("example-client"). 895 WithService("example"). 896 WithEncoding(string(protobuf.Encoding)) 897 898 return &testEnv{ 899 Caller: caller, 900 Service: service, 901 Transport: trans, 902 Inbound: inbound, 903 Outbound: outbound, 904 ClientConn: clientConn, 905 ContextWrapper: contextWrapper, 906 ClientConfig: clientConfig, 907 Procedures: procedures, 908 KeyValueGRPCClient: keyValueClient, 909 KeyValueYARPCClient: keyValueYARPCClient, 910 KeyValueYARPCServer: keyValueYARPCServer, 911 }, nil 912 } 913 914 func (e *testEnv) Call( 915 ctx context.Context, 916 methodName string, 917 message proto.Message, 918 encoding transport.Encoding, 919 headers transport.Headers, 920 ) (*transport.Response, error) { 921 data, err := proto.Marshal(message) 922 if err != nil { 923 return nil, err 924 } 925 ctx, cancel := context.WithTimeout(ctx, testtime.Second) 926 defer cancel() 927 return e.Outbound.Call( 928 ctx, 929 &transport.Request{ 930 Caller: e.Caller, 931 Service: e.Service, 932 Encoding: encoding, 933 Procedure: procedure.ToName( 934 "uber.yarpc.internal.examples.protobuf.example.KeyValue", 935 methodName, 936 ), 937 Headers: headers, 938 Body: bytes.NewReader(data), 939 }, 940 ) 941 } 942 943 func (e *testEnv) GetValueYARPC(ctx context.Context, key string, options ...yarpc.CallOption) (string, error) { 944 ctx, cancel := context.WithTimeout(ctx, testtime.Second) 945 defer cancel() 946 response, err := e.KeyValueYARPCClient.GetValue(ctx, &examplepb.GetValueRequest{Key: key}, options...) 947 if response != nil { 948 return response.Value, err 949 } 950 return "", err 951 } 952 953 func (e *testEnv) SetValueYARPC(ctx context.Context, key string, value string, options ...yarpc.CallOption) error { 954 if _, ok := ctx.Deadline(); !ok { 955 var cancel func() 956 ctx, cancel = context.WithTimeout(ctx, testtime.Second) 957 defer cancel() 958 } 959 _, err := e.KeyValueYARPCClient.SetValue(ctx, &examplepb.SetValueRequest{Key: key, Value: value}, options...) 960 return err 961 } 962 963 func (e *testEnv) GetValueGRPC(ctx context.Context, key string) (string, error) { 964 ctx, cancel := context.WithTimeout(ctx, testtime.Second) 965 defer cancel() 966 response, err := e.KeyValueGRPCClient.GetValue(e.ContextWrapper.Wrap(ctx), &examplepb.GetValueRequest{Key: key}) 967 if response != nil { 968 return response.Value, err 969 } 970 return "", err 971 } 972 973 func (e *testEnv) SetValueGRPC(ctx context.Context, key string, value string) error { 974 ctx, cancel := context.WithTimeout(ctx, testtime.Second) 975 defer cancel() 976 _, err := e.KeyValueGRPCClient.SetValue(e.ContextWrapper.Wrap(ctx), &examplepb.SetValueRequest{Key: key, Value: value}) 977 return err 978 } 979 980 func (e *testEnv) Close() error { 981 return multierr.Combine( 982 e.ClientConn.Close(), 983 e.Transport.Stop(), 984 e.Outbound.Stop(), 985 e.Inbound.Stop(), 986 ) 987 } 988 989 type testRouter struct { 990 procedures []transport.Procedure 991 } 992 993 func newTestRouter(procedures []transport.Procedure) *testRouter { 994 return &testRouter{procedures} 995 } 996 997 func (r *testRouter) Procedures() []transport.Procedure { 998 return r.procedures 999 } 1000 1001 func (r *testRouter) Choose(_ context.Context, request *transport.Request) (transport.HandlerSpec, error) { 1002 for _, procedure := range r.procedures { 1003 if procedure.Name == request.Procedure { 1004 return procedure.HandlerSpec, nil 1005 } 1006 } 1007 return transport.HandlerSpec{}, yarpcerrors.UnimplementedErrorf("no procedure for name %s", request.Procedure) 1008 } 1009 1010 func TestYARPCErrorsConverted(t *testing.T) { 1011 // Ensures that all returned errors are gRPC errors and not YARPC errors 1012 1013 trans := NewTransport() 1014 1015 listener, err := net.Listen("tcp", "127.0.0.1:0") 1016 require.NoError(t, err) 1017 inbound := trans.NewInbound(listener) 1018 1019 outbound := trans.NewSingleOutbound(listener.Addr().String()) 1020 1021 router := &testRouter{} 1022 inbound.SetRouter(router) 1023 1024 require.NoError(t, trans.Start()) 1025 defer func() { assert.NoError(t, trans.Stop()) }() 1026 1027 require.NoError(t, inbound.Start()) 1028 defer func() { assert.NoError(t, inbound.Stop()) }() 1029 1030 require.NoError(t, outbound.Start()) 1031 defer func() { assert.NoError(t, outbound.Stop()) }() 1032 1033 t.Run("no procedure", func(t *testing.T) { 1034 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 1035 defer cancel() 1036 1037 _, err := outbound.Call(ctx, &transport.Request{ 1038 Caller: "caller", 1039 Service: "service", 1040 Encoding: "encoding", 1041 Procedure: "no procedure", 1042 Body: bytes.NewBufferString("foo-body"), 1043 }) 1044 1045 require.Error(t, err) 1046 assert.True(t, yarpcerrors.IsUnimplemented(err)) 1047 }) 1048 }