github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/grpc/balancer/rls/internal/control_channel_test.go (about)

     1  /*
     2   *
     3   * Copyright 2021 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package rls
    20  
    21  import (
    22  	"context"
    23  	"errors"
    24  	"fmt"
    25  	"io/ioutil"
    26  	"strings"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/hxx258456/ccgo/x509"
    31  
    32  	tls "github.com/hxx258456/ccgo/gmtls"
    33  
    34  	"github.com/golang/protobuf/proto"
    35  	"github.com/google/go-cmp/cmp"
    36  	grpc "github.com/hxx258456/ccgo/grpc"
    37  	"github.com/hxx258456/ccgo/grpc/balancer"
    38  	"github.com/hxx258456/ccgo/grpc/balancer/rls/internal/test/e2e"
    39  	"github.com/hxx258456/ccgo/grpc/codes"
    40  	"github.com/hxx258456/ccgo/grpc/credentials"
    41  	"github.com/hxx258456/ccgo/grpc/internal"
    42  	rlspb "github.com/hxx258456/ccgo/grpc/internal/proto/grpc_lookup_v1"
    43  	"github.com/hxx258456/ccgo/grpc/metadata"
    44  	"github.com/hxx258456/ccgo/grpc/status"
    45  	"github.com/hxx258456/ccgo/grpc/testdata"
    46  )
    47  
    48  // TestControlChannelThrottled tests the case where the adaptive throttler
    49  // indicates that the control channel needs to be throttled.
    50  func (s) TestControlChannelThrottled(t *testing.T) {
    51  	// Start an RLS server and set the throttler to always throttle requests.
    52  	rlsServer, rlsReqCh := setupFakeRLSServer(t, nil)
    53  	overrideAdaptiveThrottler(t, alwaysThrottlingThrottler())
    54  
    55  	// Create a control channel to the fake RLS server.
    56  	ctrlCh, err := newControlChannel(rlsServer.Address, defaultTestTimeout, balancer.BuildOptions{}, nil)
    57  	if err != nil {
    58  		t.Fatalf("Failed to create control channel to RLS server: %v", err)
    59  	}
    60  	defer ctrlCh.close()
    61  
    62  	// Perform the lookup and expect the attempt to be throttled.
    63  	ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, nil)
    64  
    65  	select {
    66  	case <-rlsReqCh:
    67  		t.Fatal("RouteLookup RPC invoked when control channel is throtlled")
    68  	case <-time.After(defaultTestShortTimeout):
    69  	}
    70  }
    71  
    72  // TestLookupFailure tests the case where the RLS server responds with an error.
    73  func (s) TestLookupFailure(t *testing.T) {
    74  	// Start an RLS server and set the throttler to never throttle requests.
    75  	rlsServer, _ := setupFakeRLSServer(t, nil)
    76  	overrideAdaptiveThrottler(t, neverThrottlingThrottler())
    77  
    78  	// Setup the RLS server to respond with errors.
    79  	rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *e2e.RouteLookupResponse {
    80  		return &e2e.RouteLookupResponse{Err: errors.New("rls failure")}
    81  	})
    82  
    83  	// Create a control channel to the fake RLS server.
    84  	ctrlCh, err := newControlChannel(rlsServer.Address, defaultTestTimeout, balancer.BuildOptions{}, nil)
    85  	if err != nil {
    86  		t.Fatalf("Failed to create control channel to RLS server: %v", err)
    87  	}
    88  	defer ctrlCh.close()
    89  
    90  	// Perform the lookup and expect the callback to be invoked with an error.
    91  	errCh := make(chan error, 1)
    92  	ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(_ []string, _ string, err error) {
    93  		if err == nil {
    94  			errCh <- errors.New("rlsClient.lookup() succeeded, should have failed")
    95  			return
    96  		}
    97  		errCh <- nil
    98  	})
    99  
   100  	select {
   101  	case <-time.After(defaultTestTimeout):
   102  		t.Fatal("timeout when waiting for lookup callback to be invoked")
   103  	case err := <-errCh:
   104  		if err != nil {
   105  			t.Fatal(err)
   106  		}
   107  	}
   108  }
   109  
   110  // TestLookupDeadlineExceeded tests the case where the RLS server does not
   111  // respond within the configured rpc timeout.
   112  func (s) TestLookupDeadlineExceeded(t *testing.T) {
   113  	// A unary interceptor which sleeps for long enough to cause lookup RPCs to
   114  	// exceed their deadline.
   115  	rlsReqCh := make(chan struct{}, 1)
   116  	interceptor := func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
   117  		rlsReqCh <- struct{}{}
   118  		time.Sleep(2 * defaultTestShortTimeout)
   119  		return handler(ctx, req)
   120  	}
   121  
   122  	// Start an RLS server and set the throttler to never throttle.
   123  	rlsServer, _ := setupFakeRLSServer(t, nil, grpc.UnaryInterceptor(interceptor))
   124  	overrideAdaptiveThrottler(t, neverThrottlingThrottler())
   125  
   126  	// Create a control channel with a small deadline.
   127  	ctrlCh, err := newControlChannel(rlsServer.Address, defaultTestShortTimeout, balancer.BuildOptions{}, nil)
   128  	if err != nil {
   129  		t.Fatalf("Failed to create control channel to RLS server: %v", err)
   130  	}
   131  	defer ctrlCh.close()
   132  
   133  	// Perform the lookup and expect the callback to be invoked with an error.
   134  	errCh := make(chan error)
   135  	ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(_ []string, _ string, err error) {
   136  		if st, ok := status.FromError(err); !ok || st.Code() != codes.DeadlineExceeded {
   137  			errCh <- fmt.Errorf("rlsClient.lookup() returned error: %v, want %v", err, codes.DeadlineExceeded)
   138  			return
   139  		}
   140  		errCh <- nil
   141  	})
   142  
   143  	select {
   144  	case <-time.After(defaultTestTimeout):
   145  		t.Fatal("timeout when waiting for lookup callback to be invoked")
   146  	case err := <-errCh:
   147  		if err != nil {
   148  			t.Fatal(err)
   149  		}
   150  	}
   151  }
   152  
   153  // testCredsBundle wraps a test call creds and real transport creds.
   154  type testCredsBundle struct {
   155  	transportCreds credentials.TransportCredentials
   156  	callCreds      credentials.PerRPCCredentials
   157  }
   158  
   159  func (f *testCredsBundle) TransportCredentials() credentials.TransportCredentials {
   160  	return f.transportCreds
   161  }
   162  
   163  func (f *testCredsBundle) PerRPCCredentials() credentials.PerRPCCredentials {
   164  	return f.callCreds
   165  }
   166  
   167  func (f *testCredsBundle) NewWithMode(mode string) (credentials.Bundle, error) {
   168  	if mode != internal.CredsBundleModeFallback {
   169  		return nil, fmt.Errorf("unsupported mode: %v", mode)
   170  	}
   171  	return &testCredsBundle{
   172  		transportCreds: f.transportCreds,
   173  		callCreds:      f.callCreds,
   174  	}, nil
   175  }
   176  
   177  var (
   178  	// Call creds sent by the testPerRPCCredentials on the client, and verified
   179  	// by an interceptor on the server.
   180  	perRPCCredsData = map[string]string{
   181  		"test-key":     "test-value",
   182  		"test-key-bin": string([]byte{1, 2, 3}),
   183  	}
   184  )
   185  
   186  type testPerRPCCredentials struct {
   187  	callCreds map[string]string
   188  }
   189  
   190  func (f *testPerRPCCredentials) GetRequestMetadata(context.Context, ...string) (map[string]string, error) {
   191  	return f.callCreds, nil
   192  }
   193  
   194  func (f *testPerRPCCredentials) RequireTransportSecurity() bool {
   195  	return true
   196  }
   197  
   198  // Unary server interceptor which validates if the RPC contains call credentials
   199  // which match `perRPCCredsData
   200  func callCredsValidatingServerInterceptor(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (resp interface{}, err error) {
   201  	md, ok := metadata.FromIncomingContext(ctx)
   202  	if !ok {
   203  		return nil, status.Error(codes.PermissionDenied, "didn't find metadata in context")
   204  	}
   205  	for k, want := range perRPCCredsData {
   206  		got, ok := md[k]
   207  		if !ok {
   208  			return ctx, status.Errorf(codes.PermissionDenied, "didn't find call creds key %v in context", k)
   209  		}
   210  		if got[0] != want {
   211  			return ctx, status.Errorf(codes.PermissionDenied, "for key %v, got value %v, want %v", k, got, want)
   212  		}
   213  	}
   214  	return handler(ctx, req)
   215  }
   216  
   217  // makeTLSCreds is a test helper which creates a TLS based transport credentials
   218  // from files specified in the arguments.
   219  func makeTLSCreds(t *testing.T, certPath, keyPath, rootsPath string) credentials.TransportCredentials {
   220  	cert, err := tls.LoadX509KeyPair(testdata.Path(certPath), testdata.Path(keyPath))
   221  	if err != nil {
   222  		t.Fatalf("tls.LoadX509KeyPair(%q, %q) failed: %v", certPath, keyPath, err)
   223  	}
   224  	b, err := ioutil.ReadFile(testdata.Path(rootsPath))
   225  	if err != nil {
   226  		t.Fatalf("ioutil.ReadFile(%q) failed: %v", rootsPath, err)
   227  	}
   228  	roots := x509.NewCertPool()
   229  	if !roots.AppendCertsFromPEM(b) {
   230  		t.Fatal("failed to append certificates")
   231  	}
   232  	return credentials.NewTLS(&tls.Config{
   233  		Certificates: []tls.Certificate{cert},
   234  		RootCAs:      roots,
   235  	})
   236  }
   237  
   238  const (
   239  	wantHeaderData  = "headerData"
   240  	staleHeaderData = "staleHeaderData"
   241  )
   242  
   243  var (
   244  	keyMap = map[string]string{
   245  		"k1": "v1",
   246  		"k2": "v2",
   247  	}
   248  	wantTargets   = []string{"us_east_1.firestore.googleapis.com"}
   249  	lookupRequest = &rlspb.RouteLookupRequest{
   250  		TargetType:      "grpc",
   251  		KeyMap:          keyMap,
   252  		Reason:          rlspb.RouteLookupRequest_REASON_MISS,
   253  		StaleHeaderData: staleHeaderData,
   254  	}
   255  	lookupResponse = &e2e.RouteLookupResponse{
   256  		Resp: &rlspb.RouteLookupResponse{
   257  			Targets:    wantTargets,
   258  			HeaderData: wantHeaderData,
   259  		},
   260  	}
   261  )
   262  
   263  func testControlChannelCredsSuccess(t *testing.T, sopts []grpc.ServerOption, bopts balancer.BuildOptions) {
   264  	// Start an RLS server and set the throttler to never throttle requests.
   265  	rlsServer, _ := setupFakeRLSServer(t, nil, sopts...)
   266  	overrideAdaptiveThrottler(t, neverThrottlingThrottler())
   267  
   268  	// Setup the RLS server to respond with a valid response.
   269  	rlsServer.SetResponseCallback(func(_ context.Context, req *rlspb.RouteLookupRequest) *e2e.RouteLookupResponse {
   270  		return lookupResponse
   271  	})
   272  
   273  	// Verify that the request received by the RLS matches the expected one.
   274  	rlsServer.SetRequestCallback(func(got *rlspb.RouteLookupRequest) {
   275  		if diff := cmp.Diff(lookupRequest, got, cmp.Comparer(proto.Equal)); diff != "" {
   276  			t.Errorf("RouteLookupRequest diff (-want, +got):\n%s", diff)
   277  		}
   278  	})
   279  
   280  	// Create a control channel to the fake server.
   281  	ctrlCh, err := newControlChannel(rlsServer.Address, defaultTestTimeout, bopts, nil)
   282  	if err != nil {
   283  		t.Fatalf("Failed to create control channel to RLS server: %v", err)
   284  	}
   285  	defer ctrlCh.close()
   286  
   287  	// Perform the lookup and expect a successful callback invocation.
   288  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   289  	defer cancel()
   290  	errCh := make(chan error, 1)
   291  	ctrlCh.lookup(keyMap, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(targets []string, headerData string, err error) {
   292  		if err != nil {
   293  			errCh <- fmt.Errorf("rlsClient.lookup() failed with err: %v", err)
   294  			return
   295  		}
   296  		if !cmp.Equal(targets, wantTargets) || headerData != wantHeaderData {
   297  			errCh <- fmt.Errorf("rlsClient.lookup() = (%v, %s), want (%v, %s)", targets, headerData, wantTargets, wantHeaderData)
   298  			return
   299  		}
   300  		errCh <- nil
   301  	})
   302  
   303  	select {
   304  	case <-ctx.Done():
   305  		t.Fatal("timeout when waiting for lookup callback to be invoked")
   306  	case err := <-errCh:
   307  		if err != nil {
   308  			t.Fatal(err)
   309  		}
   310  	}
   311  }
   312  
   313  // TestControlChannelCredsSuccess tests creation of the control channel with
   314  // different credentials, which are expected to succeed.
   315  func (s) TestControlChannelCredsSuccess(t *testing.T) {
   316  	serverCreds := makeTLSCreds(t, "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem")
   317  	clientCreds := makeTLSCreds(t, "x509/client1_cert.pem", "x509/client1_key.pem", "x509/server_ca_cert.pem")
   318  
   319  	tests := []struct {
   320  		name  string
   321  		sopts []grpc.ServerOption
   322  		bopts balancer.BuildOptions
   323  	}{
   324  		{
   325  			name:  "insecure",
   326  			sopts: nil,
   327  			bopts: balancer.BuildOptions{},
   328  		},
   329  		{
   330  			name:  "transport creds only",
   331  			sopts: []grpc.ServerOption{grpc.Creds(serverCreds)},
   332  			bopts: balancer.BuildOptions{
   333  				DialCreds: clientCreds,
   334  				Authority: "x.test.example.com",
   335  			},
   336  		},
   337  		{
   338  			name: "creds bundle",
   339  			sopts: []grpc.ServerOption{
   340  				grpc.Creds(serverCreds),
   341  				grpc.UnaryInterceptor(callCredsValidatingServerInterceptor),
   342  			},
   343  			bopts: balancer.BuildOptions{
   344  				CredsBundle: &testCredsBundle{
   345  					transportCreds: clientCreds,
   346  					callCreds:      &testPerRPCCredentials{callCreds: perRPCCredsData},
   347  				},
   348  				Authority: "x.test.example.com",
   349  			},
   350  		},
   351  	}
   352  	for _, test := range tests {
   353  		t.Run(test.name, func(t *testing.T) {
   354  			testControlChannelCredsSuccess(t, test.sopts, test.bopts)
   355  		})
   356  	}
   357  }
   358  
   359  func testControlChannelCredsFailure(t *testing.T, sopts []grpc.ServerOption, bopts balancer.BuildOptions, wantCode codes.Code, wantErr string) {
   360  	// StartFakeRouteLookupServer a fake server.
   361  	//
   362  	// Start an RLS server and set the throttler to never throttle requests. The
   363  	// creds failures happen before the RPC handler on the server is invoked.
   364  	// So, there is need to setup the request and responses on the fake server.
   365  	rlsServer, _ := setupFakeRLSServer(t, nil, sopts...)
   366  	overrideAdaptiveThrottler(t, neverThrottlingThrottler())
   367  
   368  	// Create the control channel to the fake server.
   369  	ctrlCh, err := newControlChannel(rlsServer.Address, defaultTestTimeout, bopts, nil)
   370  	if err != nil {
   371  		t.Fatalf("Failed to create control channel to RLS server: %v", err)
   372  	}
   373  	defer ctrlCh.close()
   374  
   375  	// Perform the lookup and expect the callback to be invoked with an error.
   376  	errCh := make(chan error)
   377  	ctrlCh.lookup(nil, rlspb.RouteLookupRequest_REASON_MISS, staleHeaderData, func(_ []string, _ string, err error) {
   378  		if st, ok := status.FromError(err); !ok || st.Code() != wantCode || !strings.Contains(st.String(), wantErr) {
   379  			errCh <- fmt.Errorf("rlsClient.lookup() returned error: %v, wantCode: %v, wantErr: %s", err, wantCode, wantErr)
   380  			return
   381  		}
   382  		errCh <- nil
   383  	})
   384  
   385  	select {
   386  	case <-time.After(defaultTestTimeout):
   387  		t.Fatal("timeout when waiting for lookup callback to be invoked")
   388  	case err := <-errCh:
   389  		if err != nil {
   390  			t.Fatal(err)
   391  		}
   392  	}
   393  }
   394  
   395  // TestControlChannelCredsFailure tests creation of the control channel with
   396  // different credentials, which are expected to fail.
   397  func (s) TestControlChannelCredsFailure(t *testing.T) {
   398  	serverCreds := makeTLSCreds(t, "x509/server1_cert.pem", "x509/server1_key.pem", "x509/client_ca_cert.pem")
   399  	clientCreds := makeTLSCreds(t, "x509/client1_cert.pem", "x509/client1_key.pem", "x509/server_ca_cert.pem")
   400  
   401  	tests := []struct {
   402  		name     string
   403  		sopts    []grpc.ServerOption
   404  		bopts    balancer.BuildOptions
   405  		wantCode codes.Code
   406  		wantErr  string
   407  	}{
   408  		{
   409  			name:  "transport creds authority mismatch",
   410  			sopts: []grpc.ServerOption{grpc.Creds(serverCreds)},
   411  			bopts: balancer.BuildOptions{
   412  				DialCreds: clientCreds,
   413  				Authority: "authority-mismatch",
   414  			},
   415  			wantCode: codes.Unavailable,
   416  			wantErr:  "transport: authentication handshake failed: x509: certificate is valid for *.test.example.com, not authority-mismatch",
   417  		},
   418  		{
   419  			name:  "transport creds handshake failure",
   420  			sopts: nil, // server expects insecure connection
   421  			bopts: balancer.BuildOptions{
   422  				DialCreds: clientCreds,
   423  				Authority: "x.test.example.com",
   424  			},
   425  			wantCode: codes.Unavailable,
   426  			wantErr:  "transport: authentication handshake failed: tls: first record does not look like a TLS handshake",
   427  		},
   428  		{
   429  			name: "call creds mismatch",
   430  			sopts: []grpc.ServerOption{
   431  				grpc.Creds(serverCreds),
   432  				grpc.UnaryInterceptor(callCredsValidatingServerInterceptor), // server expects call creds
   433  			},
   434  			bopts: balancer.BuildOptions{
   435  				CredsBundle: &testCredsBundle{
   436  					transportCreds: clientCreds,
   437  					callCreds:      &testPerRPCCredentials{}, // sends no call creds
   438  				},
   439  				Authority: "x.test.example.com",
   440  			},
   441  			wantCode: codes.PermissionDenied,
   442  			wantErr:  "didn't find call creds",
   443  		},
   444  	}
   445  	for _, test := range tests {
   446  		t.Run(test.name, func(t *testing.T) {
   447  			testControlChannelCredsFailure(t, test.sopts, test.bopts, test.wantCode, test.wantErr)
   448  		})
   449  	}
   450  }
   451  
   452  type unsupportedCredsBundle struct {
   453  	credentials.Bundle
   454  }
   455  
   456  func (*unsupportedCredsBundle) NewWithMode(mode string) (credentials.Bundle, error) {
   457  	return nil, fmt.Errorf("unsupported mode: %v", mode)
   458  }
   459  
   460  // TestNewControlChannelUnsupportedCredsBundle tests the case where the control
   461  // channel is configured with a bundle which does not support the mode we use.
   462  func (s) TestNewControlChannelUnsupportedCredsBundle(t *testing.T) {
   463  	rlsServer, _ := setupFakeRLSServer(t, nil)
   464  
   465  	// Create the control channel to the fake server.
   466  	ctrlCh, err := newControlChannel(rlsServer.Address, defaultTestTimeout, balancer.BuildOptions{CredsBundle: &unsupportedCredsBundle{}}, nil)
   467  	if err == nil {
   468  		ctrlCh.close()
   469  		t.Fatal("newControlChannel succeeded when expected to fail")
   470  	}
   471  }