// Copyright (c) 2022 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

20 21 package tchannel 22 23 import ( 24 "bytes" 25 "io/ioutil" 26 "sync" 27 "testing" 28 "time" 29 30 "github.com/stretchr/testify/assert" 31 "github.com/stretchr/testify/require" 32 "github.com/uber/tchannel-go" 33 "github.com/uber/tchannel-go/testutils" 34 "go.uber.org/yarpc/api/transport" 35 "go.uber.org/yarpc/encoding/raw" 36 "go.uber.org/yarpc/internal/testtime" 37 "go.uber.org/yarpc/yarpcerrors" 38 "golang.org/x/net/context" 39 ) 40 41 func TestTransportNamer(t *testing.T) { 42 trans, err := NewTransport() 43 require.NoError(t, err) 44 assert.Equal(t, TransportName, trans.NewOutbound(nil).TransportName()) 45 } 46 47 func TestOutboundHeaders(t *testing.T) { 48 tests := []struct { 49 name string 50 originalHeaders bool 51 giveHeaders map[string]string 52 wantHeaders map[string]string 53 }{ 54 { 55 name: "exactCaseHeader options on", 56 giveHeaders: map[string]string{ 57 "foo-BAR-BaZ": "PiE", 58 "foo-bar": "LEMON", 59 "BAR-BAZ": "orange", 60 }, 61 wantHeaders: map[string]string{ 62 "foo-BAR-BaZ": "PiE", 63 "foo-bar": "LEMON", 64 "BAR-BAZ": "orange", 65 }, 66 originalHeaders: true, 67 }, 68 { 69 name: "exactCaseHeader options off", 70 giveHeaders: map[string]string{ 71 "foo-BAR-BaZ": "PiE", 72 "foo-bar": "LEMON", 73 "BAR-BAZ": "orange", 74 }, 75 wantHeaders: map[string]string{ 76 "foo-bar-baz": "PiE", 77 "foo-bar": "LEMON", 78 "bar-baz": "orange", 79 }, 80 }, 81 } 82 83 for _, tt := range tests { 84 t.Run(tt.name, func(t *testing.T) { 85 var handlerInvoked bool 86 server := testutils.NewServer(t, nil) 87 defer server.Close() 88 serverHostPort := server.PeerInfo().HostPort 89 90 server.GetSubChannel("service").SetHandler(tchannel.HandlerFunc( 91 func(ctx context.Context, call *tchannel.InboundCall) { 92 handlerInvoked = true 93 headers, err := readHeaders(tchannel.Raw, call.Arg2Reader) 94 if !assert.NoError(t, err, "failed to read request") { 95 return 96 } 97 98 deleteReservedHeaders(headers) 99 assert.Equal(t, tt.wantHeaders, headers.OriginalItems(), "headers 
did not match") 100 101 // write a response 102 err = writeArgs(call.Response(), []byte{0x00, 0x00}, []byte("")) 103 assert.NoError(t, err, "failed to write response") 104 })) 105 106 opts := []TransportOption{ServiceName("caller")} 107 if tt.originalHeaders { 108 opts = append(opts, OriginalHeaders()) 109 } 110 111 trans, err := NewTransport(opts...) 112 require.NoError(t, err) 113 require.NoError(t, trans.Start(), "failed to start transport") 114 defer trans.Stop() 115 116 out := trans.NewSingleOutbound(serverHostPort) 117 require.NoError(t, out.Start(), "failed to start outbound") 118 defer out.Stop() 119 120 ctx, cancel := context.WithTimeout(context.Background(), 200*testtime.Millisecond) 121 defer cancel() 122 _, err = out.Call( 123 ctx, 124 &transport.Request{ 125 Caller: "caller", 126 Service: "service", 127 Encoding: raw.Encoding, 128 Procedure: "hello", 129 Headers: transport.HeadersFromMap(tt.giveHeaders), 130 Body: bytes.NewBufferString("body"), 131 }, 132 ) 133 134 require.NoError(t, err, "failed to make call") 135 assert.True(t, handlerInvoked, "handler was never called by client") 136 }) 137 } 138 } 139 140 func TestCallSuccess(t *testing.T) { 141 var handlerInvoked bool 142 server := testutils.NewServer(t, nil) 143 defer server.Close() 144 serverHostPort := server.PeerInfo().HostPort 145 146 server.GetSubChannel("service").SetHandler(tchannel.HandlerFunc( 147 func(ctx context.Context, call *tchannel.InboundCall) { 148 handlerInvoked = true 149 150 assert.Equal(t, "caller", call.CallerName()) 151 assert.Equal(t, "service", call.ServiceName()) 152 assert.Equal(t, tchannel.Raw, call.Format()) 153 assert.Equal(t, "hello", call.MethodString()) 154 _, body, err := readArgs(call) 155 if assert.NoError(t, err, "failed to read request") { 156 assert.Equal(t, []byte("world"), body) 157 } 158 159 dl, ok := ctx.Deadline() 160 assert.True(t, ok, "deadline expected") 161 assert.WithinDuration(t, time.Now(), dl, 200*testtime.Millisecond) 162 163 err = 
writeArgs(call.Response(), 164 []byte{ 165 0x00, 0x01, 166 0x00, 0x03, 'f', 'o', 'o', 167 0x00, 0x03, 'b', 'a', 'r', 168 }, []byte("great success")) 169 assert.NoError(t, err, "failed to write response") 170 })) 171 172 out, trans := newSingleOutbound(t, serverHostPort) 173 defer out.Stop() 174 defer trans.Stop() 175 require.NoError(t, out.Start(), "failed to start outbound") 176 177 ctx, cancel := context.WithTimeout(context.Background(), 200*testtime.Millisecond) 178 defer cancel() 179 res, err := out.Call( 180 ctx, 181 &transport.Request{ 182 Caller: "caller", 183 Service: "service", 184 Encoding: raw.Encoding, 185 Procedure: "hello", 186 Body: bytes.NewBufferString("world"), 187 }, 188 ) 189 190 require.NoError(t, err, "failed to make call") 191 require.False(t, res.ApplicationError, "unexpected application error") 192 193 foo, ok := res.Headers.Get("foo") 194 assert.True(t, ok, "value for foo expected") 195 assert.Equal(t, "bar", foo, "foo value mismatch") 196 197 body, err := ioutil.ReadAll(res.Body) 198 if assert.NoError(t, err, "failed to read response body") { 199 assert.Equal(t, []byte("great success"), body) 200 } 201 202 assert.NoError(t, res.Body.Close(), "failed to close response body") 203 assert.True(t, handlerInvoked, "handler was never called by client") 204 } 205 206 func TestCallWithModifiedCallerName(t *testing.T) { 207 const ( 208 destService = "server" 209 alternateCallerName = "alternate-caller" 210 ) 211 212 server := testutils.NewServer(t, nil) 213 defer server.Close() 214 215 server.GetSubChannel(destService).SetHandler(tchannel.HandlerFunc( 216 func(ctx context.Context, call *tchannel.InboundCall) { 217 assert.Equal(t, alternateCallerName, call.CallerName()) 218 _, _, err := readArgs(call) 219 assert.NoError(t, err, "failed to read request") 220 221 err = writeArgs(call.Response(), []byte{0x00, 0x00} /*headers*/, nil /*body*/) 222 assert.NoError(t, err, "failed to write response") 223 })) 224 225 out, trans := newSingleOutbound(t, 
server.PeerInfo().HostPort) 226 require.NoError(t, out.Start(), "failed to start outbound") 227 defer out.Stop() 228 defer trans.Stop() 229 230 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 231 defer cancel() 232 res, err := out.Call( 233 ctx, 234 &transport.Request{ 235 Caller: alternateCallerName, // newSingleOutbound uses "caller", this should override it 236 Service: destService, 237 Encoding: "bar", 238 Procedure: "baz", 239 Body: bytes.NewBuffer(nil), 240 }, 241 ) 242 243 require.NoError(t, err, "failed to make call") 244 assert.NoError(t, res.Body.Close(), "failed to close response body") 245 } 246 247 func TestCallFailures(t *testing.T) { 248 const ( 249 unexpectedMethod = "unexpected" 250 unknownMethod = "unknown" 251 ) 252 253 server := testutils.NewServer(t, nil) 254 defer server.Close() 255 serverHostPort := server.PeerInfo().HostPort 256 257 server.GetSubChannel("service").SetHandler(tchannel.HandlerFunc( 258 func(ctx context.Context, call *tchannel.InboundCall) { 259 var err error 260 if call.MethodString() == unexpectedMethod { 261 err = tchannel.NewSystemError( 262 tchannel.ErrCodeUnexpected, "great sadness") 263 call.Response().SendSystemError(err) 264 } else if call.MethodString() == unknownMethod { 265 err = tchannel.NewSystemError( 266 tchannel.ErrCodeBadRequest, "unknown method") 267 call.Response().SendSystemError(err) 268 } else { 269 err = writeArgs(call.Response(), 270 []byte{ 271 0x00, 0x01, 272 0x00, 0x0d, '$', 'r', 'p', 'c', '$', '-', 's', 'e', 'r', 'v', 'i', 'c', 'e', 273 0x00, 0x05, 'w', 'r', 'o', 'n', 'g', 274 }, []byte("bad sadness")) 275 assert.NoError(t, err, "o write response") 276 } 277 })) 278 279 type testCase struct { 280 desc string 281 procedure string 282 message string 283 } 284 285 tests := []testCase{ 286 { 287 desc: "unexpected error", 288 procedure: unexpectedMethod, 289 message: "great sadness", 290 }, 291 { 292 desc: "missing procedure error", 293 procedure: unknownMethod, 294 message: 
"unknown method", 295 }, 296 { 297 desc: "service name mismatch error", 298 procedure: "wrong service name", 299 message: "does not match", 300 }, 301 } 302 303 for _, tt := range tests { 304 t.Run(tt.desc, func(t *testing.T) { 305 306 out, trans := newSingleOutbound(t, serverHostPort) 307 require.NoError(t, out.Start(), "failed to start outbound") 308 defer out.Stop() 309 defer trans.Stop() 310 311 ctx, cancel := context.WithTimeout(context.Background(), 200*testtime.Millisecond) 312 defer cancel() 313 _, err := out.Call( 314 ctx, 315 &transport.Request{ 316 Caller: "caller", 317 Service: "service", 318 Encoding: raw.Encoding, 319 Procedure: tt.procedure, 320 Body: bytes.NewReader([]byte("sup")), 321 }, 322 ) 323 324 require.Error(t, err, "expected failure") 325 assert.Contains(t, err.Error(), tt.message) 326 }) 327 } 328 } 329 330 func TestApplicationError(t *testing.T) { 331 server := testutils.NewServer(t, nil) 332 defer server.Close() 333 serverHostPort := server.PeerInfo().HostPort 334 335 server.GetSubChannel("service").SetHandler(tchannel.HandlerFunc( 336 func(ctx context.Context, call *tchannel.InboundCall) { 337 call.Response().SetApplicationError() 338 339 err := writeArgs( 340 call.Response(), 341 []byte{ 342 0x00, 0x03, 343 0x00, 0x1c, '$', 'r', 'p', 'c', '$', '-', 'a', 'p', 'p', 'l', 'i', 'c', 'a', 't', 'i', 'o', 'n', 344 '-', 'e', 'r', 'r', 'o', 'r', '-', 'c', 'o', 'd', 'e', 345 0x00, 0x02, '1', '0', 346 0x00, 0x1c, '$', 'r', 'p', 'c', '$', '-', 'a', 'p', 'p', 'l', 'i', 'c', 'a', 't', 'i', 'o', 'n', 347 '-', 'e', 'r', 'r', 'o', 'r', '-', 'n', 'a', 'm', 'e', 348 0x00, 0x03, 'b', 'A', 'z', 349 0x00, 0x1f, '$', 'r', 'p', 'c', '$', '-', 'a', 'p', 'p', 'l', 'i', 'c', 'a', 't', 'i', 'o', 'n', 350 '-', 'e', 'r', 'r', 'o', 'r', '-', 'd', 'e', 't', 'a', 'i', 'l', 's', 351 0x00, 0x03, 'F', 'o', 'O', 352 }, 353 []byte("foo"), 354 ) 355 assert.NoError(t, err, "failed to write response") 356 })) 357 358 out, trans := newSingleOutbound(t, serverHostPort) 359 defer 
out.Stop() 360 defer trans.Stop() 361 require.NoError(t, out.Start(), "failed to start outbound") 362 363 ctx, cancel := context.WithTimeout(context.Background(), 200*testtime.Millisecond) 364 defer cancel() 365 res, err := out.Call( 366 ctx, 367 &transport.Request{ 368 Caller: "caller", 369 Service: "service", 370 Encoding: raw.Encoding, 371 Procedure: "hello", 372 Body: &bytes.Buffer{}, 373 }, 374 ) 375 require.NoError(t, err, "failed to make call") 376 require.True(t, res.ApplicationError, "application error was not set") 377 require.NotNil(t, res.ApplicationErrorMeta.Code, "application error code was not set") 378 assert.Equal(t, "FoO", res.ApplicationErrorMeta.Details, "unexpected error message") 379 assert.Equal( 380 t, 381 yarpcerrors.CodeAborted, 382 *res.ApplicationErrorMeta.Code, 383 "application error code does not match the expected one", 384 ) 385 assert.Equal( 386 t, 387 "bAz", 388 res.ApplicationErrorMeta.Name, 389 "application error name does not match the expected one", 390 ) 391 392 } 393 394 func TestStartMultiple(t *testing.T) { 395 out, trans := newSingleOutbound(t, "localhost:4040") 396 defer out.Stop() 397 defer trans.Stop() 398 var wg sync.WaitGroup 399 signal := make(chan struct{}) 400 401 for i := 0; i < 10; i++ { 402 wg.Add(1) 403 go func() { 404 defer wg.Done() 405 <-signal 406 407 err := out.Start() 408 assert.NoError(t, err) 409 }() 410 } 411 close(signal) 412 wg.Wait() 413 } 414 415 func TestStopMultiple(t *testing.T) { 416 out, trans := newSingleOutbound(t, "localhost:4040") 417 defer out.Stop() 418 defer trans.Stop() 419 require.NoError(t, out.Start()) 420 421 var wg sync.WaitGroup 422 signal := make(chan struct{}) 423 424 for i := 0; i < 10; i++ { 425 wg.Add(1) 426 go func() { 427 defer wg.Done() 428 <-signal 429 430 err := out.Stop() 431 assert.NoError(t, err) 432 }() 433 } 434 close(signal) 435 wg.Wait() 436 } 437 438 func TestCallWithoutStarting(t *testing.T) { 439 out, trans := newSingleOutbound(t, "localhost:4040") 440 defer 
out.Stop() 441 defer trans.Stop() 442 ctx, cancel := context.WithTimeout(context.Background(), 200*testtime.Millisecond) 443 defer cancel() 444 _, err := out.Call( 445 ctx, 446 &transport.Request{ 447 Caller: "caller", 448 Service: "service", 449 Encoding: raw.Encoding, 450 Procedure: "foo", 451 Body: bytes.NewReader([]byte("sup")), 452 }, 453 ) 454 455 wantErr := yarpcerrors.FailedPreconditionErrorf("error waiting for tchannel outbound to start for service: service: context finished while waiting for instance to start: context deadline exceeded") 456 assert.EqualError(t, err, wantErr.Error()) 457 458 } 459 460 func TestOutboundNoRequest(t *testing.T) { 461 out, trans := newSingleOutbound(t, "localhost:4040") 462 defer out.Stop() 463 defer trans.Stop() 464 _, err := out.Call(context.Background(), nil) 465 wantErr := yarpcerrors.InvalidArgumentErrorf("request for tchannel outbound was nil") 466 assert.EqualError(t, err, wantErr.Error()) 467 } 468 469 func newSingleOutbound(t *testing.T, serverAddr string) (transport.UnaryOutbound, transport.Transport) { 470 trans, err := NewTransport(ServiceName("caller")) 471 require.NoError(t, err) 472 require.NoError(t, trans.Start()) 473 return trans.NewSingleOutbound(serverAddr), trans 474 }