google.golang.org/grpc@v1.62.1/balancer/rls/control_channel_test.go (about) 1 /* 2 * 3 * Copyright 2021 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19 package rls 20 21 import ( 22 "context" 23 "crypto/tls" 24 "crypto/x509" 25 "errors" 26 "fmt" 27 "os" 28 "regexp" 29 "testing" 30 "time" 31 32 "github.com/google/go-cmp/cmp" 33 "google.golang.org/grpc" 34 "google.golang.org/grpc/balancer" 35 "google.golang.org/grpc/codes" 36 "google.golang.org/grpc/credentials" 37 "google.golang.org/grpc/internal" 38 rlspb "google.golang.org/grpc/internal/proto/grpc_lookup_v1" 39 rlstest "google.golang.org/grpc/internal/testutils/rls" 40 "google.golang.org/grpc/metadata" 41 "google.golang.org/grpc/status" 42 "google.golang.org/grpc/testdata" 43 "google.golang.org/protobuf/proto" 44 ) 45 46 // TestControlChannelThrottled tests the case where the adaptive throttler 47 // indicates that the control channel needs to be throttled. 48 func (s) TestControlChannelThrottled(t *testing.T) { 49 // Start an RLS server and set the throttler to always throttle requests. 50 rlsServer, rlsReqCh := rlstest.SetupFakeRLSServer(t, nil) 51 overrideAdaptiveThrottler(t, alwaysThrottlingThrottler()) 52 53 // Create a control channel to the fake RLS server. 54 ctrlCh, err := newControlChannel(rlsServer.Address, "", defaultTestTimeout, balancer.BuildOptions{}, nil) 55 if err != nil { 56 t.Fatalf("Failed to create control channel to RLS server: %v", err) 57 } 58 defer ctrlCh.close() 59 60 // Perform the lookup and expect the attempt to be throttled. 61 ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, nil) 62 63 select { 64 case <-rlsReqCh: 65 t.Fatal("RouteLookup RPC invoked when control channel is throtlled") 66 case <-time.After(defaultTestShortTimeout): 67 } 68 } 69 70 // TestLookupFailure tests the case where the RLS server responds with an error. 71 func (s) TestLookupFailure(t *testing.T) { 72 // Start an RLS server and set the throttler to never throttle requests. 73 rlsServer, _ := rlstest.SetupFakeRLSServer(t, nil) 74 overrideAdaptiveThrottler(t, neverThrottlingThrottler()) 75 76 // Setup the RLS server to respond with errors. 77 rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse { 78 return &rlstest.RouteLookupResponse{Err: errors.New("rls failure")} 79 }) 80 81 // Create a control channel to the fake RLS server. 82 ctrlCh, err := newControlChannel(rlsServer.Address, "", defaultTestTimeout, balancer.BuildOptions{}, nil) 83 if err != nil { 84 t.Fatalf("Failed to create control channel to RLS server: %v", err) 85 } 86 defer ctrlCh.close() 87 88 // Perform the lookup and expect the callback to be invoked with an error. 89 errCh := make(chan error, 1) 90 ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(_ []string, _ string, err error) { 91 if err == nil { 92 errCh <- errors.New("rlsClient.lookup() succeeded, should have failed") 93 return 94 } 95 errCh <- nil 96 }) 97 98 select { 99 case <-time.After(defaultTestTimeout): 100 t.Fatal("timeout when waiting for lookup callback to be invoked") 101 case err := <-errCh: 102 if err != nil { 103 t.Fatal(err) 104 } 105 } 106 } 107 108 // TestLookupDeadlineExceeded tests the case where the RLS server does not 109 // respond within the configured rpc timeout. 110 func (s) TestLookupDeadlineExceeded(t *testing.T) { 111 // A unary interceptor which returns a status error with DeadlineExceeded. 112 interceptor := func(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { 113 return nil, status.Error(codes.DeadlineExceeded, "deadline exceeded") 114 } 115 116 // Start an RLS server and set the throttler to never throttle. 117 rlsServer, _ := rlstest.SetupFakeRLSServer(t, nil, grpc.UnaryInterceptor(interceptor)) 118 overrideAdaptiveThrottler(t, neverThrottlingThrottler()) 119 120 // Create a control channel with a small deadline. 121 ctrlCh, err := newControlChannel(rlsServer.Address, "", defaultTestShortTimeout, balancer.BuildOptions{}, nil) 122 if err != nil { 123 t.Fatalf("Failed to create control channel to RLS server: %v", err) 124 } 125 defer ctrlCh.close() 126 127 // Perform the lookup and expect the callback to be invoked with an error. 128 errCh := make(chan error) 129 ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(_ []string, _ string, err error) { 130 if st, ok := status.FromError(err); !ok || st.Code() != codes.DeadlineExceeded { 131 errCh <- fmt.Errorf("rlsClient.lookup() returned error: %v, want %v", err, codes.DeadlineExceeded) 132 return 133 } 134 errCh <- nil 135 }) 136 137 select { 138 case <-time.After(defaultTestTimeout): 139 t.Fatal("timeout when waiting for lookup callback to be invoked") 140 case err := <-errCh: 141 if err != nil { 142 t.Fatal(err) 143 } 144 } 145 } 146 147 // testCredsBundle wraps a test call creds and real transport creds. 148 type testCredsBundle struct { 149 transportCreds credentials.TransportCredentials 150 callCreds credentials.PerRPCCredentials 151 } 152 153 func (f *testCredsBundle) TransportCredentials() credentials.TransportCredentials { 154 return f.transportCreds 155 } 156 157 func (f *testCredsBundle) PerRPCCredentials() credentials.PerRPCCredentials { 158 return f.callCreds 159 } 160 161 func (f *testCredsBundle) NewWithMode(mode string) (credentials.Bundle, error) { 162 if mode != internal.CredsBundleModeFallback { 163 return nil, fmt.Errorf("unsupported mode: %v", mode) 164 } 165 return &testCredsBundle{ 166 transportCreds: f.transportCreds, 167 callCreds: f.callCreds, 168 }, nil 169 } 170 171 var ( 172 // Call creds sent by the testPerRPCCredentials on the client, and verified 173 // by an interceptor on the server. 174 perRPCCredsData = map[string]string{ 175 "test-key": "test-value", 176 "test-key-bin": string([]byte{1, 2, 3}), 177 } 178 ) 179 180 type testPerRPCCredentials struct { 181 callCreds map[string]string 182 } 183 184 func (f *testPerRPCCredentials) GetRequestMetadata(context.Context, ...string) (map[string]string, error) { 185 return f.callCreds, nil 186 } 187 188 func (f *testPerRPCCredentials) RequireTransportSecurity() bool { 189 return true 190 } 191 192 // Unary server interceptor which validates if the RPC contains call credentials 193 // which match `perRPCCredsData 194 func callCredsValidatingServerInterceptor(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp any, err error) { 195 md, ok := metadata.FromIncomingContext(ctx) 196 if !ok { 197 return nil, status.Error(codes.PermissionDenied, "didn't find metadata in context") 198 } 199 for k, want := range perRPCCredsData { 200 got, ok := md[k] 201 if !ok { 202 return ctx, status.Errorf(codes.PermissionDenied, "didn't find call creds key %v in context", k) 203 } 204 if got[0] != want { 205 return ctx, status.Errorf(codes.PermissionDenied, "for key %v, got value %v, want %v", k, got, want) 206 } 207 } 208 return handler(ctx, req) 209 } 210 211 // makeTLSCreds is a test helper which creates a TLS based transport credentials 212 // from files specified in the arguments. 213 func makeTLSCreds(t *testing.T, certPath, keyPath, rootsPath string) credentials.TransportCredentials { 214 cert, err := tls.LoadX509KeyPair(testdata.Path(certPath), testdata.Path(keyPath)) 215 if err != nil { 216 t.Fatalf("tls.LoadX509KeyPair(%q, %q) failed: %v", certPath, keyPath, err) 217 } 218 b, err := os.ReadFile(testdata.Path(rootsPath)) 219 if err != nil { 220 t.Fatalf("os.ReadFile(%q) failed: %v", rootsPath, err) 221 } 222 roots := x509.NewCertPool() 223 if !roots.AppendCertsFromPEM(b) { 224 t.Fatal("failed to append certificates") 225 } 226 return credentials.NewTLS(&tls.Config{ 227 Certificates: []tls.Certificate{cert}, 228 RootCAs: roots, 229 }) 230 } 231 232 const ( 233 wantHeaderData = "headerData" 234 staleHeaderData = "staleHeaderData" 235 ) 236 237 var ( 238 keyMap = map[string]string{ 239 "k1": "v1", 240 "k2": "v2", 241 } 242 wantTargets = []string{"us_east_1.firestore.googleapis.com"} 243 lookupRequest = &rlspb.RouteLookupRequest{ 244 TargetType: "grpc", 245 KeyMap: keyMap, 246 Reason: rlspb.RouteLookupRequest_REASON_MISS, 247 StaleHeaderData: staleHeaderData, 248 } 249 lookupResponse = &rlstest.RouteLookupResponse{ 250 Resp: &rlspb.RouteLookupResponse{ 251 Targets: wantTargets, 252 HeaderData: wantHeaderData, 253 }, 254 } 255 ) 256 257 func testControlChannelCredsSuccess(t *testing.T, sopts []grpc.ServerOption, bopts balancer.BuildOptions) { 258 // Start an RLS server and set the throttler to never throttle requests. 259 rlsServer, _ := rlstest.SetupFakeRLSServer(t, nil, sopts...) 260 overrideAdaptiveThrottler(t, neverThrottlingThrottler()) 261 262 // Setup the RLS server to respond with a valid response. 263 rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *rlstest.RouteLookupResponse { 264 return lookupResponse 265 }) 266 267 // Verify that the request received by the RLS matches the expected one. 268 rlsServer.SetRequestCallback(func(got *rlspb.RouteLookupRequest) { 269 if diff := cmp.Diff(lookupRequest, got, cmp.Comparer(proto.Equal)); diff != "" { 270 t.Errorf("RouteLookupRequest diff (-want, +got):\n%s", diff) 271 } 272 }) 273 274 // Create a control channel to the fake server. 275 ctrlCh, err := newControlChannel(rlsServer.Address, "", defaultTestTimeout, bopts, nil) 276 if err != nil { 277 t.Fatalf("Failed to create control channel to RLS server: %v", err) 278 } 279 defer ctrlCh.close() 280 281 // Perform the lookup and expect a successful callback invocation. 282 ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) 283 defer cancel() 284 errCh := make(chan error, 1) 285 ctrlCh.lookup(keyMap, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(targets []string, headerData string, err error) { 286 if err != nil { 287 errCh <- fmt.Errorf("rlsClient.lookup() failed with err: %v", err) 288 return 289 } 290 if !cmp.Equal(targets, wantTargets) || headerData != wantHeaderData { 291 errCh <- fmt.Errorf("rlsClient.lookup() = (%v, %s), want (%v, %s)", targets, headerData, wantTargets, wantHeaderData) 292 return 293 } 294 errCh <- nil 295 }) 296 297 select { 298 case <-ctx.Done(): 299 t.Fatal("timeout when waiting for lookup callback to be invoked") 300 case err := <-errCh: 301 if err != nil { 302 t.Fatal(err) 303 } 304 } 305 } 306 307 // TestControlChannelCredsSuccess tests creation of the control channel with 308 // different credentials, which are expected to succeed. 309 func (s) TestControlChannelCredsSuccess(t *testing.T) { 310 serverCreds := makeTLSCreds(t, "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem") 311 clientCreds := makeTLSCreds(t, "x509/client1_cert.pem", "x509/client1_key.pem", "x509/server_ca_cert.pem") 312 313 tests := []struct { 314 name string 315 sopts []grpc.ServerOption 316 bopts balancer.BuildOptions 317 }{ 318 { 319 name: "insecure", 320 sopts: nil, 321 bopts: balancer.BuildOptions{}, 322 }, 323 { 324 name: "transport creds only", 325 sopts: []grpc.ServerOption{grpc.Creds(serverCreds)}, 326 bopts: balancer.BuildOptions{ 327 DialCreds: clientCreds, 328 Authority: "x.test.example.com", 329 }, 330 }, 331 { 332 name: "creds bundle", 333 sopts: []grpc.ServerOption{ 334 grpc.Creds(serverCreds), 335 grpc.UnaryInterceptor(callCredsValidatingServerInterceptor), 336 }, 337 bopts: balancer.BuildOptions{ 338 CredsBundle: &testCredsBundle{ 339 transportCreds: clientCreds, 340 callCreds: &testPerRPCCredentials{callCreds: perRPCCredsData}, 341 }, 342 Authority: "x.test.example.com", 343 }, 344 }, 345 } 346 for _, test := range tests { 347 t.Run(test.name, func(t *testing.T) { 348 testControlChannelCredsSuccess(t, test.sopts, test.bopts) 349 }) 350 } 351 } 352 353 func testControlChannelCredsFailure(t *testing.T, sopts []grpc.ServerOption, bopts balancer.BuildOptions, wantCode codes.Code, wantErrRegex *regexp.Regexp) { 354 // StartFakeRouteLookupServer a fake server. 355 // 356 // Start an RLS server and set the throttler to never throttle requests. The 357 // creds failures happen before the RPC handler on the server is invoked. 358 // So, there is need to setup the request and responses on the fake server. 359 rlsServer, _ := rlstest.SetupFakeRLSServer(t, nil, sopts...) 360 overrideAdaptiveThrottler(t, neverThrottlingThrottler()) 361 362 // Create the control channel to the fake server. 363 ctrlCh, err := newControlChannel(rlsServer.Address, "", defaultTestTimeout, bopts, nil) 364 if err != nil { 365 t.Fatalf("Failed to create control channel to RLS server: %v", err) 366 } 367 defer ctrlCh.close() 368 369 // Perform the lookup and expect the callback to be invoked with an error. 370 errCh := make(chan error) 371 ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(_ []string, _ string, err error) { 372 if st, ok := status.FromError(err); !ok || st.Code() != wantCode || !wantErrRegex.MatchString(st.String()) { 373 errCh <- fmt.Errorf("rlsClient.lookup() returned error: %v, wantCode: %v, wantErr: %s", err, wantCode, wantErrRegex.String()) 374 return 375 } 376 errCh <- nil 377 }) 378 379 select { 380 case <-time.After(defaultTestTimeout): 381 t.Fatal("timeout when waiting for lookup callback to be invoked") 382 case err := <-errCh: 383 if err != nil { 384 t.Fatal(err) 385 } 386 } 387 } 388 389 // TestControlChannelCredsFailure tests creation of the control channel with 390 // different credentials, which are expected to fail. 391 func (s) TestControlChannelCredsFailure(t *testing.T) { 392 serverCreds := makeTLSCreds(t, "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem") 393 clientCreds := makeTLSCreds(t, "x509/client1_cert.pem", "x509/client1_key.pem", "x509/server_ca_cert.pem") 394 395 tests := []struct { 396 name string 397 sopts []grpc.ServerOption 398 bopts balancer.BuildOptions 399 wantCode codes.Code 400 wantErrRegex *regexp.Regexp 401 }{ 402 { 403 name: "transport creds authority mismatch", 404 sopts: []grpc.ServerOption{grpc.Creds(serverCreds)}, 405 bopts: balancer.BuildOptions{ 406 DialCreds: clientCreds, 407 Authority: "authority-mismatch", 408 }, 409 wantCode: codes.Unavailable, 410 wantErrRegex: regexp.MustCompile(`transport: authentication handshake failed: .* \*\.test\.example\.com.*authority-mismatch`), 411 }, 412 { 413 name: "transport creds handshake failure", 414 sopts: nil, // server expects insecure connection 415 bopts: balancer.BuildOptions{ 416 DialCreds: clientCreds, 417 Authority: "x.test.example.com", 418 }, 419 wantCode: codes.Unavailable, 420 wantErrRegex: regexp.MustCompile("transport: authentication handshake failed: .*"), 421 }, 422 { 423 name: "call creds mismatch", 424 sopts: []grpc.ServerOption{ 425 grpc.Creds(serverCreds), 426 grpc.UnaryInterceptor(callCredsValidatingServerInterceptor), // server expects call creds 427 }, 428 bopts: balancer.BuildOptions{ 429 CredsBundle: &testCredsBundle{ 430 transportCreds: clientCreds, 431 callCreds: &testPerRPCCredentials{}, // sends no call creds 432 }, 433 Authority: "x.test.example.com", 434 }, 435 wantCode: codes.PermissionDenied, 436 wantErrRegex: regexp.MustCompile("didn't find call creds"), 437 }, 438 } 439 for _, test := range tests { 440 t.Run(test.name, func(t *testing.T) { 441 testControlChannelCredsFailure(t, test.sopts, test.bopts, test.wantCode, test.wantErrRegex) 442 }) 443 } 444 } 445 446 type unsupportedCredsBundle struct { 447 credentials.Bundle 448 } 449 450 func (*unsupportedCredsBundle) NewWithMode(mode string) (credentials.Bundle, error) { 451 return nil, fmt.Errorf("unsupported mode: %v", mode) 452 } 453 454 // TestNewControlChannelUnsupportedCredsBundle tests the case where the control 455 // channel is configured with a bundle which does not support the mode we use. 456 func (s) TestNewControlChannelUnsupportedCredsBundle(t *testing.T) { 457 rlsServer, _ := rlstest.SetupFakeRLSServer(t, nil) 458 459 // Create the control channel to the fake server. 460 ctrlCh, err := newControlChannel(rlsServer.Address, "", defaultTestTimeout, balancer.BuildOptions{CredsBundle: &unsupportedCredsBundle{}}, nil) 461 if err == nil { 462 ctrlCh.close() 463 t.Fatal("newControlChannel succeeded when expected to fail") 464 } 465 }