// go.uber.org/yarpc@v1.72.1/transport/http/integration_test.go
//
// Copyright (c) 2022 Uber Technologies, Inc.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
// THE SOFTWARE.

20 21 package http 22 23 import ( 24 "context" 25 "crypto/tls" 26 "errors" 27 "fmt" 28 "net" 29 "testing" 30 "time" 31 32 "github.com/stretchr/testify/assert" 33 "github.com/stretchr/testify/require" 34 "go.uber.org/goleak" 35 "go.uber.org/multierr" 36 "go.uber.org/yarpc/api/transport" 37 yarpctls "go.uber.org/yarpc/api/transport/tls" 38 "go.uber.org/yarpc/encoding/json" 39 "go.uber.org/yarpc/internal/clientconfig" 40 pkgerrors "go.uber.org/yarpc/pkg/errors" 41 "go.uber.org/yarpc/transport/internal/tls/testscenario" 42 ) 43 44 func TestInboundTLS(t *testing.T) { 45 defer goleak.VerifyNone(t) 46 47 scenario := testscenario.Create(t, time.Minute, time.Minute) 48 tests := []struct { 49 desc string 50 inboundOptions []InboundOption 51 transportOptions []TransportOption 52 isTLSClient bool 53 }{ 54 { 55 desc: "plaintext_client_permissive_tls_server", 56 inboundOptions: []InboundOption{ 57 InboundTLSConfiguration(scenario.ServerTLSConfig()), 58 InboundTLSMode(yarpctls.Permissive), 59 }, 60 }, 61 { 62 desc: "tls_client_enforced_tls_server", 63 inboundOptions: []InboundOption{ 64 InboundTLSConfiguration(scenario.ServerTLSConfig()), 65 InboundTLSMode(yarpctls.Enforced), 66 }, 67 transportOptions: []TransportOption{ 68 DialContext(func(ctx context.Context, network, addr string) (net.Conn, error) { 69 return tls.Dial(network, addr, scenario.ClientTLSConfig()) 70 }), 71 }, 72 isTLSClient: true, 73 }, 74 } 75 for _, tt := range tests { 76 t.Run(tt.desc, func(t *testing.T) { 77 doWithTestEnv(t, testEnvOptions{ 78 Procedures: json.Procedure("testFoo", testFooHandler), 79 InboundOptions: tt.inboundOptions, 80 TransportOptions: tt.transportOptions, 81 }, func(t *testing.T, testEnv *testEnv) { 82 client := json.New(testEnv.ClientConfig) 83 var response testFooResponse 84 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 85 defer cancel() 86 87 err := client.Call(ctx, "testFoo", &testFooRequest{One: "one"}, &response) 88 require.Nil(t, err) 89 assert.Equal(t, 
testFooResponse{One: "one"}, response) 90 }) 91 }) 92 } 93 } 94 95 func TestOutboundTLS(t *testing.T) { 96 defer goleak.VerifyNone(t) 97 98 scenario := testscenario.Create(t, time.Minute, time.Minute) 99 tests := []struct { 100 desc string 101 withCustomDialer bool 102 }{ 103 {desc: "without_custom_dialer", withCustomDialer: false}, 104 {desc: "with_custom_dialer", withCustomDialer: true}, 105 } 106 for _, tt := range tests { 107 t.Run(tt.desc, func(t *testing.T) { 108 var opts []TransportOption 109 // This is used for asserting if custom dialer is invoked. 110 var invokedCustomDialer bool 111 if tt.withCustomDialer { 112 opts = []TransportOption{ 113 DialContext(func(ctx context.Context, network, addr string) (net.Conn, error) { 114 invokedCustomDialer = true 115 return (&net.Dialer{}).DialContext(ctx, network, addr) 116 }), 117 } 118 } 119 doWithTestEnv(t, testEnvOptions{ 120 Procedures: json.Procedure("testFoo", testFooHandler), 121 InboundOptions: []InboundOption{ 122 InboundTLSConfiguration(scenario.ServerTLSConfig()), 123 InboundTLSMode(yarpctls.Enforced), 124 }, 125 OutboundOptions: []OutboundOption{OutboundTLSConfiguration(scenario.ClientTLSConfig())}, 126 TransportOptions: opts, 127 }, func(t *testing.T, testEnv *testEnv) { 128 client := json.New(testEnv.ClientConfig) 129 var response testFooResponse 130 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 131 defer cancel() 132 133 err := client.Call(ctx, "testFoo", &testFooRequest{One: "one"}, &response) 134 require.Nil(t, err) 135 assert.Equal(t, testFooResponse{One: "one"}, response) 136 assert.Equal(t, tt.withCustomDialer, invokedCustomDialer) 137 }) 138 }) 139 } 140 } 141 142 func TestBothResponseError(t *testing.T) { 143 tests := []struct { 144 inboundBothResponseError bool 145 outboundBothResponseError bool 146 }{ 147 { 148 inboundBothResponseError: false, 149 outboundBothResponseError: false, 150 }, 151 { 152 inboundBothResponseError: false, 153 outboundBothResponseError: true, 
154 }, 155 { 156 inboundBothResponseError: true, 157 outboundBothResponseError: false, 158 }, 159 { 160 inboundBothResponseError: true, 161 outboundBothResponseError: true, 162 }, 163 } 164 165 for _, tt := range tests { 166 t.Run(fmt.Sprintf("inbound(%v)-outbound(%v)", tt.inboundBothResponseError, tt.outboundBothResponseError), func(t *testing.T) { 167 doWithTestEnv(t, testEnvOptions{ 168 Procedures: json.Procedure("testFoo", testFooHandler), 169 InboundOptions: []InboundOption{ 170 func(i *Inbound) { 171 i.bothResponseError = tt.inboundBothResponseError 172 }, 173 }, 174 OutboundOptions: []OutboundOption{ 175 func(o *Outbound) { 176 o.bothResponseError = tt.outboundBothResponseError 177 }, 178 }, 179 }, func(t *testing.T, testEnv *testEnv) { 180 client := json.New(testEnv.ClientConfig) 181 var response testFooResponse 182 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 183 defer cancel() 184 err := client.Call(ctx, "testFoo", &testFooRequest{One: "one", Error: "bar"}, &response) 185 186 assert.Equal(t, pkgerrors.WrapHandlerError(errors.New("bar"), "example", "testFoo"), err) 187 if tt.inboundBothResponseError && tt.outboundBothResponseError { 188 assert.Equal(t, "one", response.One) 189 } else { 190 assert.Empty(t, response.One) 191 } 192 }) 193 }) 194 } 195 } 196 197 type testFooRequest struct { 198 One string 199 Error string 200 } 201 202 type testFooResponse struct { 203 One string 204 } 205 206 func testFooHandler(_ context.Context, request *testFooRequest) (*testFooResponse, error) { 207 var err error 208 if request.Error != "" { 209 err = errors.New(request.Error) 210 } 211 return &testFooResponse{ 212 One: request.One, 213 }, err 214 } 215 216 func doWithTestEnv(t *testing.T, options testEnvOptions, f func(*testing.T, *testEnv)) { 217 testEnv, err := newTestEnv(options) 218 require.NoError(t, err) 219 defer func() { 220 assert.NoError(t, testEnv.Close()) 221 }() 222 f(t, testEnv) 223 } 224 225 type testEnv struct { 226 Inbound 
*Inbound 227 Outbound *Outbound 228 ClientConfig transport.ClientConfig 229 } 230 231 type testEnvOptions struct { 232 Procedures []transport.Procedure 233 TransportOptions []TransportOption 234 InboundOptions []InboundOption 235 OutboundOptions []OutboundOption 236 } 237 238 func newTestEnv(options testEnvOptions) (_ *testEnv, err error) { 239 t := NewTransport(options.TransportOptions...) 240 if err := t.Start(); err != nil { 241 return nil, err 242 } 243 defer func() { 244 if err != nil { 245 err = multierr.Append(err, t.Stop()) 246 } 247 }() 248 249 inbound := t.NewInbound("127.0.0.1:0", options.InboundOptions...) 250 inbound.SetRouter(newTestRouter(options.Procedures)) 251 if err := inbound.Start(); err != nil { 252 return nil, err 253 } 254 defer func() { 255 if err != nil { 256 err = multierr.Append(err, inbound.Stop()) 257 } 258 }() 259 260 outbound := t.NewSingleOutbound(fmt.Sprintf("http://%s", inbound.Addr().String()), options.OutboundOptions...) 261 if err := outbound.Start(); err != nil { 262 return nil, err 263 } 264 defer func() { 265 if err != nil { 266 err = multierr.Append(err, outbound.Stop()) 267 } 268 }() 269 270 caller := "example-client" 271 service := "example" 272 clientConfig := clientconfig.MultiOutbound( 273 caller, 274 service, 275 transport.Outbounds{ 276 ServiceName: caller, 277 Unary: outbound, 278 }, 279 ) 280 281 return &testEnv{ 282 inbound, 283 outbound, 284 clientConfig, 285 }, nil 286 } 287 288 func (e *testEnv) Close() error { 289 return multierr.Combine( 290 e.Outbound.Stop(), 291 e.Inbound.Stop(), 292 ) 293 } 294 295 type testRouter struct { 296 procedures []transport.Procedure 297 } 298 299 func newTestRouter(procedures []transport.Procedure) *testRouter { 300 return &testRouter{procedures} 301 } 302 303 func (r *testRouter) Procedures() []transport.Procedure { 304 return r.procedures 305 } 306 307 func (r *testRouter) Choose(_ context.Context, request *transport.Request) (transport.HandlerSpec, error) { 308 for _, 
procedure := range r.procedures { 309 if procedure.Name == request.Procedure { 310 return procedure.HandlerSpec, nil 311 } 312 } 313 return transport.HandlerSpec{}, fmt.Errorf("no procedure for name %s", request.Procedure) 314 }