go.uber.org/yarpc@v1.72.1/transport/roundtrip_test.go (about) 1 // Copyright (c) 2022 Uber Technologies, Inc. 2 // 3 // Permission is hereby granted, free of charge, to any person obtaining a copy 4 // of this software and associated documentation files (the "Software"), to deal 5 // in the Software without restriction, including without limitation the rights 6 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 7 // copies of the Software, and to permit persons to whom the Software is 8 // furnished to do so, subject to the following conditions: 9 // 10 // The above copyright notice and this permission notice shall be included in 11 // all copies or substantial portions of the Software. 12 // 13 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 14 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 15 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 16 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 17 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 18 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 19 // THE SOFTWARE. 20 21 package transport_test 22 23 import ( 24 "bytes" 25 "context" 26 "fmt" 27 "io/ioutil" 28 "net" 29 "testing" 30 "time" 31 32 "github.com/stretchr/testify/assert" 33 "github.com/stretchr/testify/require" 34 "github.com/uber/tchannel-go" 35 "github.com/uber/tchannel-go/testutils" 36 "go.uber.org/yarpc/api/transport" 37 "go.uber.org/yarpc/api/transport/transporttest" 38 "go.uber.org/yarpc/encoding/raw" 39 "go.uber.org/yarpc/internal/testtime" 40 "go.uber.org/yarpc/transport/grpc" 41 "go.uber.org/yarpc/transport/http" 42 tch "go.uber.org/yarpc/transport/tchannel" 43 "go.uber.org/yarpc/yarpcerrors" 44 ) 45 46 // all tests in this file should use these names for callers and services. 47 const ( 48 testCaller = "testService-client" 49 testService = "testService" 50 51 testProcedure = "hello" 52 testProcedureOneway = "hello-oneway" 53 ) 54 55 // roundTripTransport provides a function that sets up and tears down an 56 // Inbound, and provides an Outbound which knows how to call that Inbound. 57 type roundTripTransport interface { 58 // Name is the string representation of the transport. eg http, grpc, tchannel 59 Name() string 60 // Set up an Inbound serving Router r, and call f with an Outbound that 61 // knows how to talk to that Inbound. 62 WithRouter(r transport.Router, f func(transport.UnaryOutbound)) 63 WithRouterOneway(r transport.Router, f func(transport.OnewayOutbound)) 64 } 65 66 type staticRouter struct { 67 Handler transport.UnaryHandler 68 OnewayHandler transport.OnewayHandler 69 } 70 71 func (r staticRouter) Register([]transport.Procedure) { 72 panic("cannot register methods on a static router") 73 } 74 75 func (r staticRouter) Procedures() []transport.Procedure { 76 return []transport.Procedure{{Name: testProcedure, Service: testService}} 77 } 78 79 func (r staticRouter) Choose(ctx context.Context, req *transport.Request) (transport.HandlerSpec, error) { 80 if req.Procedure == testProcedure { 81 return transport.NewUnaryHandlerSpec(r.Handler), nil 82 } 83 return transport.NewOnewayHandlerSpec(r.OnewayHandler), nil 84 } 85 86 // handlerFunc wraps a function into a transport.Router 87 type unaryHandlerFunc func(context.Context, *transport.Request, transport.ResponseWriter) error 88 89 func (f unaryHandlerFunc) Handle(ctx context.Context, r *transport.Request, w transport.ResponseWriter) error { 90 return f(ctx, r, w) 91 } 92 93 // onewayHandlerFunc wraps a function into a transport.Router 94 type onewayHandlerFunc func(context.Context, *transport.Request) error 95 96 func (f onewayHandlerFunc) HandleOneway(ctx context.Context, r *transport.Request) error { 97 return f(ctx, r) 98 } 99 100 // httpTransport implements a roundTripTransport for HTTP. 101 type httpTransport struct{ t *testing.T } 102 103 func (ht httpTransport) Name() string { 104 return "http" 105 } 106 107 func (ht httpTransport) WithRouter(r transport.Router, f func(transport.UnaryOutbound)) { 108 httpTransport := http.NewTransport() 109 110 i := httpTransport.NewInbound("127.0.0.1:0") 111 i.SetRouter(r) 112 require.NoError(ht.t, i.Start(), "failed to start") 113 defer i.Stop() 114 115 o := httpTransport.NewSingleOutbound(fmt.Sprintf("http://%s", i.Addr().String())) 116 require.NoError(ht.t, o.Start(), "failed to start outbound") 117 defer o.Stop() 118 f(o) 119 } 120 121 func (ht httpTransport) WithRouterOneway(r transport.Router, f func(transport.OnewayOutbound)) { 122 httpTransport := http.NewTransport() 123 124 i := httpTransport.NewInbound("127.0.0.1:0") 125 i.SetRouter(r) 126 require.NoError(ht.t, i.Start(), "failed to start") 127 defer i.Stop() 128 129 o := httpTransport.NewSingleOutbound(fmt.Sprintf("http://%s", i.Addr().String())) 130 require.NoError(ht.t, o.Start(), "failed to start outbound") 131 defer o.Stop() 132 f(o) 133 } 134 135 // tchannelTransport implements a roundTripTransport for TChannel. 136 type tchannelTransport struct{ t *testing.T } 137 138 func (tt tchannelTransport) Name() string { 139 return "tchannel" 140 } 141 142 func (tt tchannelTransport) WithRouter(r transport.Router, f func(transport.UnaryOutbound)) { 143 serverOpts := testutils.NewOpts().SetServiceName(testService) 144 clientOpts := testutils.NewOpts().SetServiceName(testCaller) 145 testutils.WithServer(tt.t, serverOpts, func(ch *tchannel.Channel, hostPort string) { 146 ix, err := tch.NewChannelTransport(tch.WithChannel(ch)) 147 require.NoError(tt.t, err) 148 149 i := ix.NewInbound() 150 i.SetRouter(r) 151 require.NoError(tt.t, ix.Start(), "failed to start inbound transport") 152 require.NoError(tt.t, i.Start(), "failed to start inbound") 153 154 defer i.Stop() 155 // ^ the server is already listening so this will just set up the 156 // handler. 157 158 client := testutils.NewClient(tt.t, clientOpts) 159 ox, err := tch.NewChannelTransport(tch.WithChannel(client)) 160 require.NoError(tt.t, err) 161 162 o := ox.NewSingleOutbound(hostPort) 163 require.NoError(tt.t, ox.Start(), "failed to start outbound transport") 164 require.NoError(tt.t, o.Start(), "failed to start outbound") 165 defer o.Stop() 166 167 f(o) 168 }) 169 } 170 171 func (tt tchannelTransport) WithRouterOneway(r transport.Router, f func(transport.OnewayOutbound)) { 172 panic("tchannel does not support oneway calls") 173 } 174 175 // grpcTransport implements a roundTripTransport for gRPC. 176 type grpcTransport struct{ t *testing.T } 177 178 func (gt grpcTransport) Name() string { 179 return "grpc" 180 } 181 182 func (gt grpcTransport) WithRouter(r transport.Router, f func(transport.UnaryOutbound)) { 183 grpcTransport := grpc.NewTransport() 184 require.NoError(gt.t, grpcTransport.Start(), "failed to start transport") 185 defer grpcTransport.Stop() 186 187 listener, err := net.Listen("tcp", "127.0.0.1:0") 188 require.NoError(gt.t, err) 189 i := grpcTransport.NewInbound(listener) 190 i.SetRouter(r) 191 require.NoError(gt.t, i.Start(), "failed to start inbound") 192 defer i.Stop() 193 194 o := grpcTransport.NewSingleOutbound(listener.Addr().String()) 195 require.NoError(gt.t, o.Start(), "failed to start outbound") 196 defer o.Stop() 197 f(o) 198 } 199 200 func (gt grpcTransport) WithRouterOneway(r transport.Router, f func(transport.OnewayOutbound)) { 201 panic("grpc does not support oneway calls") 202 } 203 204 func TestSimpleRoundTrip(t *testing.T) { 205 transports := []roundTripTransport{ 206 httpTransport{t}, 207 tchannelTransport{t}, 208 grpcTransport{t}, 209 } 210 211 tests := []struct { 212 name string 213 214 requestHeaders transport.Headers 215 requestBody string 216 responseHeaders transport.Headers 217 responseBody string 218 responseError error 219 220 wantError func(error) 221 }{ 222 { 223 name: "headers", 224 requestHeaders: transport.NewHeaders().With("token", "1234"), 225 requestBody: "world", 226 responseHeaders: transport.NewHeaders().With("status", "ok"), 227 responseBody: "hello, world", 228 }, 229 { 230 name: "internal err", 231 requestBody: "foo", 232 responseError: yarpcerrors.Newf(yarpcerrors.CodeInternal, "great sadness"), 233 wantError: func(err error) { 234 assert.True(t, yarpcerrors.FromError(err).Code() == yarpcerrors.CodeInternal, err.Error()) 235 }, 236 }, 237 { 238 name: "invalid arg", 239 requestBody: "bar", 240 responseError: yarpcerrors.Newf(yarpcerrors.CodeInvalidArgument, "missing service name"), 241 wantError: func(err error) { 242 assert.True(t, yarpcerrors.FromError(err).Code() == yarpcerrors.CodeInvalidArgument, err.Error()) 243 }, 244 }, 245 } 246 247 for _, tt := range tests { 248 for _, trans := range transports { 249 t.Run(tt.name+"/"+trans.Name(), func(t *testing.T) { 250 requestMatcher := transporttest.NewRequestMatcher(t, &transport.Request{ 251 Caller: testCaller, 252 Service: testService, 253 Transport: trans.Name(), 254 Procedure: testProcedure, 255 Encoding: raw.Encoding, 256 Headers: tt.requestHeaders, 257 Body: bytes.NewBufferString(tt.requestBody), 258 }) 259 260 handler := unaryHandlerFunc(func(_ context.Context, r *transport.Request, w transport.ResponseWriter) error { 261 r.Headers.Del("user-agent") // for gRPC 262 r.Headers.Del(":authority") // for gRPC 263 assert.True(t, requestMatcher.Matches(r), "request mismatch: received %v", r) 264 265 if tt.responseError != nil { 266 return tt.responseError 267 } 268 269 if tt.responseHeaders.Len() > 0 { 270 w.AddHeaders(tt.responseHeaders) 271 } 272 273 _, err := w.Write([]byte(tt.responseBody)) 274 assert.NoError(t, err, "failed to write response for %v", r) 275 return err 276 }) 277 278 ctx, cancel := context.WithTimeout(context.Background(), 200*testtime.Millisecond) 279 defer cancel() 280 281 router := staticRouter{Handler: handler} 282 trans.WithRouter(router, func(o transport.UnaryOutbound) { 283 res, err := o.Call(ctx, &transport.Request{ 284 Caller: testCaller, 285 Service: testService, 286 Procedure: testProcedure, 287 Encoding: raw.Encoding, 288 Headers: tt.requestHeaders, 289 Body: bytes.NewBufferString(tt.requestBody), 290 }) 291 292 if tt.wantError != nil { 293 if assert.Error(t, err, "%T: expected error, got %v", trans, res) { 294 tt.wantError(err) 295 } 296 297 } else { 298 responseMatcher := transporttest.NewResponseMatcher(t, &transport.Response{ 299 Headers: tt.responseHeaders, 300 Body: ioutil.NopCloser(bytes.NewReader([]byte(tt.responseBody))), 301 }) 302 303 if assert.NoError(t, err, "%T: call failed", trans) { 304 assert.True(t, responseMatcher.Matches(res), "%T: response mismatch", trans) 305 } 306 } 307 }) 308 }) 309 } 310 } 311 } 312 313 func TestSimpleRoundTripOneway(t *testing.T) { 314 trans := httpTransport{t} 315 316 tests := []struct { 317 name string 318 requestHeaders transport.Headers 319 requestBody string 320 }{ 321 { 322 name: "hello world", 323 requestHeaders: transport.NewHeaders().With("foo", "bar"), 324 requestBody: "hello world", 325 }, 326 { 327 name: "empty", 328 requestHeaders: transport.NewHeaders(), 329 requestBody: "", 330 }, 331 } 332 333 rootCtx := context.Background() 334 335 for _, tt := range tests { 336 t.Run(tt.name, func(t *testing.T) { 337 338 requestMatcher := transporttest.NewRequestMatcher(t, &transport.Request{ 339 Caller: testCaller, 340 Service: testService, 341 Transport: trans.Name(), 342 Procedure: testProcedureOneway, 343 Encoding: raw.Encoding, 344 Headers: tt.requestHeaders, 345 Body: bytes.NewReader([]byte(tt.requestBody)), 346 }) 347 348 handlerDone := make(chan struct{}) 349 350 onewayHandler := onewayHandlerFunc(func(_ context.Context, r *transport.Request) error { 351 assert.True(t, requestMatcher.Matches(r), "request mismatch: received %v", r) 352 353 // Pretend to work: this delay should not slow down tests since it is a 354 // server-side operation 355 testtime.Sleep(5 * time.Second) 356 357 // close the channel, telling the client (which should not be waiting for 358 // a response) that the handler finished executing 359 close(handlerDone) 360 361 return nil 362 }) 363 364 router := staticRouter{OnewayHandler: onewayHandler} 365 366 trans.WithRouterOneway(router, func(o transport.OnewayOutbound) { 367 ctx, cancel := context.WithTimeout(rootCtx, time.Second) 368 defer cancel() 369 ack, err := o.CallOneway(ctx, &transport.Request{ 370 Caller: testCaller, 371 Service: testService, 372 Procedure: testProcedureOneway, 373 Encoding: raw.Encoding, 374 Headers: tt.requestHeaders, 375 Body: bytes.NewReader([]byte(tt.requestBody)), 376 }) 377 378 select { 379 case <-handlerDone: 380 // if the server filled the channel, it means we waited for the server 381 // to complete the request 382 assert.Fail(t, "client waited for server handler to finish executing") 383 default: 384 } 385 386 if assert.NoError(t, err, "%T: oneway call failed for test '%v'", trans, tt.name) { 387 assert.NotNil(t, ack) 388 } 389 }) 390 }) 391 } 392 }