gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/grpc/credentials/xds/xds_client_test.go (about)

     1  /*
     2   *
     3   * Copyright 2020 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package xds
    20  
    21  import (
    22  	"context"
    23  	"errors"
    24  	"fmt"
    25  	"io/ioutil"
    26  	"net"
    27  	"strings"
    28  	"testing"
    29  	"time"
    30  
    31  	"gitee.com/ks-custle/core-gm/x509"
    32  
    33  	tls "gitee.com/ks-custle/core-gm/gmtls"
    34  
    35  	"gitee.com/ks-custle/core-gm/grpc/credentials"
    36  	"gitee.com/ks-custle/core-gm/grpc/credentials/tls/certprovider"
    37  	icredentials "gitee.com/ks-custle/core-gm/grpc/internal/credentials"
    38  	xdsinternal "gitee.com/ks-custle/core-gm/grpc/internal/credentials/xds"
    39  	"gitee.com/ks-custle/core-gm/grpc/internal/grpctest"
    40  	"gitee.com/ks-custle/core-gm/grpc/internal/testutils"
    41  	"gitee.com/ks-custle/core-gm/grpc/internal/xds/matcher"
    42  	"gitee.com/ks-custle/core-gm/grpc/resolver"
    43  	"gitee.com/ks-custle/core-gm/grpc/testdata"
    44  )
    45  
    46  const (
    47  	defaultTestTimeout      = 1 * time.Second
    48  	defaultTestShortTimeout = 10 * time.Millisecond
    49  	defaultTestCertSAN      = "abc.test.example.com"
    50  	authority               = "authority"
    51  )
    52  
    53  type s struct {
    54  	grpctest.Tester
    55  }
    56  
    57  func Test(t *testing.T) {
    58  	grpctest.RunSubTests(t, s{})
    59  }
    60  
    61  // Helper function to create a real TLS client credentials which is used as
    62  // fallback credentials from multiple tests.
    63  func makeFallbackClientCreds(t *testing.T) credentials.TransportCredentials {
    64  	creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com")
    65  	if err != nil {
    66  		t.Fatal(err)
    67  	}
    68  	return creds
    69  }
    70  
    71  // testServer is a no-op server which listens on a local TCP port for incoming
    72  // connections, and performs a manual TLS handshake on the received raw
    73  // connection using a user specified handshake function. It then makes the
    74  // result of the handshake operation available through a channel for tests to
    75  // inspect. Tests should stop the testServer as part of their cleanup.
    76  type testServer struct {
    77  	lis           net.Listener
    78  	address       string             // Listening address of the test server.
    79  	handshakeFunc testHandshakeFunc  // Test specified handshake function.
    80  	hsResult      *testutils.Channel // Channel to deliver handshake results.
    81  }
    82  
    83  // handshakeResult wraps the result of the handshake operation on the test
    84  // server. It consists of TLS connection state and an error, if the handshake
    85  // failed. This result is delivered on the `hsResult` channel on the testServer.
    86  type handshakeResult struct {
    87  	connState tls.ConnectionState
    88  	err       error
    89  }
    90  
    91  // Configurable handshake function for the testServer. Tests can set this to
    92  // simulate different conditions like handshake success, failure, timeout etc.
    93  type testHandshakeFunc func(net.Conn) handshakeResult
    94  
    95  // newTestServerWithHandshakeFunc starts a new testServer which listens for
    96  // connections on a local TCP port, and uses the provided custom handshake
    97  // function to perform TLS handshake.
    98  func newTestServerWithHandshakeFunc(f testHandshakeFunc) *testServer {
    99  	ts := &testServer{
   100  		handshakeFunc: f,
   101  		hsResult:      testutils.NewChannel(),
   102  	}
   103  	ts.start()
   104  	return ts
   105  }
   106  
   107  // starts actually starts listening on a local TCP port, and spawns a goroutine
   108  // to handle new connections.
   109  func (ts *testServer) start() error {
   110  	lis, err := net.Listen("tcp", "localhost:0")
   111  	if err != nil {
   112  		return err
   113  	}
   114  	ts.lis = lis
   115  	ts.address = lis.Addr().String()
   116  	go ts.handleConn()
   117  	return nil
   118  }
   119  
   120  // handleconn accepts a new raw connection, and invokes the test provided
   121  // handshake function to perform TLS handshake, and returns the result on the
   122  // `hsResult` channel.
   123  func (ts *testServer) handleConn() {
   124  	for {
   125  		rawConn, err := ts.lis.Accept()
   126  		if err != nil {
   127  			// Once the listeners closed, Accept() will return with an error.
   128  			return
   129  		}
   130  		hsr := ts.handshakeFunc(rawConn)
   131  		ts.hsResult.Send(hsr)
   132  	}
   133  }
   134  
   135  // stop closes the associated listener which causes the connection handling
   136  // goroutine to exit.
   137  func (ts *testServer) stop() {
   138  	ts.lis.Close()
   139  }
   140  
   141  // A handshake function which simulates a successful handshake without client
   142  // authentication (server does not request for client certificate during the
   143  // handshake here).
   144  func testServerTLSHandshake(rawConn net.Conn) handshakeResult {
   145  	cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
   146  	if err != nil {
   147  		return handshakeResult{err: err}
   148  	}
   149  	cfg := &tls.Config{Certificates: []tls.Certificate{cert}}
   150  	conn := tls.Server(rawConn, cfg)
   151  	if err := conn.Handshake(); err != nil {
   152  		return handshakeResult{err: err}
   153  	}
   154  	return handshakeResult{connState: conn.ConnectionState()}
   155  }
   156  
   157  // A handshake function which simulates a successful handshake with mutual
   158  // authentication.
   159  func testServerMutualTLSHandshake(rawConn net.Conn) handshakeResult {
   160  	cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
   161  	if err != nil {
   162  		return handshakeResult{err: err}
   163  	}
   164  	pemData, err := ioutil.ReadFile(testdata.Path("x509/client_ca_cert.pem"))
   165  	if err != nil {
   166  		return handshakeResult{err: err}
   167  	}
   168  	roots := x509.NewCertPool()
   169  	roots.AppendCertsFromPEM(pemData)
   170  	cfg := &tls.Config{
   171  		Certificates: []tls.Certificate{cert},
   172  		ClientCAs:    roots,
   173  	}
   174  	conn := tls.Server(rawConn, cfg)
   175  	if err := conn.Handshake(); err != nil {
   176  		return handshakeResult{err: err}
   177  	}
   178  	return handshakeResult{connState: conn.ConnectionState()}
   179  }
   180  
   181  // fakeProvider is an implementation of the certprovider.Provider interface
   182  // which returns the configured key material and error in calls to
   183  // KeyMaterial().
   184  type fakeProvider struct {
   185  	km  *certprovider.KeyMaterial
   186  	err error
   187  }
   188  
   189  func (f *fakeProvider) KeyMaterial(ctx context.Context) (*certprovider.KeyMaterial, error) {
   190  	return f.km, f.err
   191  }
   192  
   193  func (f *fakeProvider) Close() {}
   194  
   195  // makeIdentityProvider creates a new instance of the fakeProvider returning the
   196  // identity key material specified in the provider file paths.
   197  func makeIdentityProvider(t *testing.T, certPath, keyPath string) certprovider.Provider {
   198  	t.Helper()
   199  	cert, err := tls.LoadX509KeyPair(testdata.Path(certPath), testdata.Path(keyPath))
   200  	if err != nil {
   201  		t.Fatal(err)
   202  	}
   203  	return &fakeProvider{km: &certprovider.KeyMaterial{Certs: []tls.Certificate{cert}}}
   204  }
   205  
   206  // makeRootProvider creates a new instance of the fakeProvider returning the
   207  // root key material specified in the provider file paths.
   208  func makeRootProvider(t *testing.T, caPath string) *fakeProvider {
   209  	pemData, err := ioutil.ReadFile(testdata.Path(caPath))
   210  	if err != nil {
   211  		t.Fatal(err)
   212  	}
   213  	roots := x509.NewCertPool()
   214  	roots.AppendCertsFromPEM(pemData)
   215  	return &fakeProvider{km: &certprovider.KeyMaterial{Roots: roots}}
   216  }
   217  
   218  // newTestContextWithHandshakeInfo returns a copy of parent with HandshakeInfo
   219  // context value added to it.
   220  func newTestContextWithHandshakeInfo(parent context.Context, root, identity certprovider.Provider, sanExactMatch string) context.Context {
   221  	// Creating the HandshakeInfo and adding it to the attributes is very
   222  	// similar to what the CDS balancer would do when it intercepts calls to
   223  	// NewSubConn().
   224  	info := xdsinternal.NewHandshakeInfo(root, identity)
   225  	if sanExactMatch != "" {
   226  		info.SetSANMatchers([]matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(sanExactMatch), nil, nil, nil, nil, false)})
   227  	}
   228  	addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, info)
   229  
   230  	// Moving the attributes from the resolver.Address to the context passed to
   231  	// the handshaker is done in the transport layer. Since we directly call the
   232  	// handshaker in these tests, we need to do the same here.
   233  	return icredentials.NewClientHandshakeInfoContext(parent, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
   234  }
   235  
   236  // compareAuthInfo compares the AuthInfo received on the client side after a
   237  // successful handshake with the authInfo available on the testServer.
   238  func compareAuthInfo(ctx context.Context, ts *testServer, ai credentials.AuthInfo) error {
   239  	if ai.AuthType() != "tls" {
   240  		return fmt.Errorf("ClientHandshake returned authType %q, want %q", ai.AuthType(), "tls")
   241  	}
   242  	info, ok := ai.(credentials.TLSInfo)
   243  	if !ok {
   244  		return fmt.Errorf("ClientHandshake returned authInfo of type %T, want %T", ai, credentials.TLSInfo{})
   245  	}
   246  	gotState := info.State
   247  
   248  	// Read the handshake result from the testServer which contains the TLS
   249  	// connection state and compare it with the one received on the client-side.
   250  	val, err := ts.hsResult.Receive(ctx)
   251  	if err != nil {
   252  		return fmt.Errorf("testServer failed to return handshake result: %v", err)
   253  	}
   254  	hsr := val.(handshakeResult)
   255  	if hsr.err != nil {
   256  		return fmt.Errorf("testServer handshake failure: %v", hsr.err)
   257  	}
   258  	// AuthInfo contains a variety of information. We only verify a subset here.
   259  	// This is the same subset which is verified in TLS credentials tests.
   260  	if err := compareConnState(gotState, hsr.connState); err != nil {
   261  		return err
   262  	}
   263  	return nil
   264  }
   265  
   266  func compareConnState(got, want tls.ConnectionState) error {
   267  	switch {
   268  	case got.Version != want.Version:
   269  		return fmt.Errorf("TLS.ConnectionState got Version: %v, want: %v", got.Version, want.Version)
   270  	case got.HandshakeComplete != want.HandshakeComplete:
   271  		return fmt.Errorf("TLS.ConnectionState got HandshakeComplete: %v, want: %v", got.HandshakeComplete, want.HandshakeComplete)
   272  	case got.CipherSuite != want.CipherSuite:
   273  		return fmt.Errorf("TLS.ConnectionState got CipherSuite: %v, want: %v", got.CipherSuite, want.CipherSuite)
   274  	case got.NegotiatedProtocol != want.NegotiatedProtocol:
   275  		return fmt.Errorf("TLS.ConnectionState got NegotiatedProtocol: %v, want: %v", got.NegotiatedProtocol, want.NegotiatedProtocol)
   276  	}
   277  	return nil
   278  }
   279  
   280  // TestClientCredsWithoutFallback verifies that the call to
   281  // NewClientCredentials() fails when no fallback is specified.
   282  func (s) TestClientCredsWithoutFallback(t *testing.T) {
   283  	if _, err := NewClientCredentials(ClientOptions{}); err == nil {
   284  		t.Fatal("NewClientCredentials() succeeded without specifying fallback")
   285  	}
   286  }
   287  
   288  // TestClientCredsInvalidHandshakeInfo verifies scenarios where the passed in
   289  // HandshakeInfo is invalid because it does not contain the expected certificate
   290  // providers.
   291  func (s) TestClientCredsInvalidHandshakeInfo(t *testing.T) {
   292  	opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
   293  	creds, err := NewClientCredentials(opts)
   294  	if err != nil {
   295  		t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
   296  	}
   297  
   298  	pCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   299  	defer cancel()
   300  	ctx := newTestContextWithHandshakeInfo(pCtx, nil, &fakeProvider{}, "")
   301  	if _, _, err := creds.ClientHandshake(ctx, authority, nil); err == nil {
   302  		t.Fatal("ClientHandshake succeeded without root certificate provider in HandshakeInfo")
   303  	}
   304  }
   305  
   306  // TestClientCredsProviderFailure verifies the cases where an expected
   307  // certificate provider is missing in the HandshakeInfo value in the context.
   308  func (s) TestClientCredsProviderFailure(t *testing.T) {
   309  	opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
   310  	creds, err := NewClientCredentials(opts)
   311  	if err != nil {
   312  		t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
   313  	}
   314  
   315  	tests := []struct {
   316  		desc             string
   317  		rootProvider     certprovider.Provider
   318  		identityProvider certprovider.Provider
   319  		wantErr          string
   320  	}{
   321  		{
   322  			desc:         "erroring root provider",
   323  			rootProvider: &fakeProvider{err: errors.New("root provider error")},
   324  			wantErr:      "root provider error",
   325  		},
   326  		{
   327  			desc:             "erroring identity provider",
   328  			rootProvider:     &fakeProvider{km: &certprovider.KeyMaterial{}},
   329  			identityProvider: &fakeProvider{err: errors.New("identity provider error")},
   330  			wantErr:          "identity provider error",
   331  		},
   332  	}
   333  	for _, test := range tests {
   334  		t.Run(test.desc, func(t *testing.T) {
   335  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   336  			defer cancel()
   337  			ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, test.identityProvider, "")
   338  			if _, _, err := creds.ClientHandshake(ctx, authority, nil); err == nil || !strings.Contains(err.Error(), test.wantErr) {
   339  				t.Fatalf("ClientHandshake() returned error: %q, wantErr: %q", err, test.wantErr)
   340  			}
   341  		})
   342  	}
   343  }
   344  
   345  // TestClientCredsSuccess verifies successful client handshake cases.
   346  func (s) TestClientCredsSuccess(t *testing.T) {
   347  	tests := []struct {
   348  		desc             string
   349  		handshakeFunc    testHandshakeFunc
   350  		handshakeInfoCtx func(ctx context.Context) context.Context
   351  	}{
   352  		{
   353  			desc:          "fallback",
   354  			handshakeFunc: testServerTLSHandshake,
   355  			handshakeInfoCtx: func(ctx context.Context) context.Context {
   356  				// Since we don't add a HandshakeInfo to the context, the
   357  				// ClientHandshake() method will delegate to the fallback.
   358  				return ctx
   359  			},
   360  		},
   361  		{
   362  			desc:          "TLS",
   363  			handshakeFunc: testServerTLSHandshake,
   364  			handshakeInfoCtx: func(ctx context.Context) context.Context {
   365  				return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), nil, defaultTestCertSAN)
   366  			},
   367  		},
   368  		{
   369  			desc:          "mTLS",
   370  			handshakeFunc: testServerMutualTLSHandshake,
   371  			handshakeInfoCtx: func(ctx context.Context) context.Context {
   372  				return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"), defaultTestCertSAN)
   373  			},
   374  		},
   375  		{
   376  			desc:          "mTLS with no acceptedSANs specified",
   377  			handshakeFunc: testServerMutualTLSHandshake,
   378  			handshakeInfoCtx: func(ctx context.Context) context.Context {
   379  				return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"), "")
   380  			},
   381  		},
   382  	}
   383  
   384  	for _, test := range tests {
   385  		t.Run(test.desc, func(t *testing.T) {
   386  			ts := newTestServerWithHandshakeFunc(test.handshakeFunc)
   387  			defer ts.stop()
   388  
   389  			opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
   390  			creds, err := NewClientCredentials(opts)
   391  			if err != nil {
   392  				t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
   393  			}
   394  
   395  			conn, err := net.Dial("tcp", ts.address)
   396  			if err != nil {
   397  				t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
   398  			}
   399  			defer conn.Close()
   400  
   401  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   402  			defer cancel()
   403  			_, ai, err := creds.ClientHandshake(test.handshakeInfoCtx(ctx), authority, conn)
   404  			if err != nil {
   405  				t.Fatalf("ClientHandshake() returned failed: %q", err)
   406  			}
   407  			if err := compareAuthInfo(ctx, ts, ai); err != nil {
   408  				t.Fatal(err)
   409  			}
   410  		})
   411  	}
   412  }
   413  
   414  func (s) TestClientCredsHandshakeTimeout(t *testing.T) {
   415  	clientDone := make(chan struct{})
   416  	// A handshake function which simulates a handshake timeout from the
   417  	// server-side by simply blocking on the client-side handshake to timeout
   418  	// and not writing any handshake data.
   419  	hErr := errors.New("server handshake error")
   420  	ts := newTestServerWithHandshakeFunc(func(rawConn net.Conn) handshakeResult {
   421  		<-clientDone
   422  		return handshakeResult{err: hErr}
   423  	})
   424  	defer ts.stop()
   425  
   426  	opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
   427  	creds, err := NewClientCredentials(opts)
   428  	if err != nil {
   429  		t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
   430  	}
   431  
   432  	conn, err := net.Dial("tcp", ts.address)
   433  	if err != nil {
   434  		t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
   435  	}
   436  	defer conn.Close()
   437  
   438  	sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
   439  	defer sCancel()
   440  	ctx := newTestContextWithHandshakeInfo(sCtx, makeRootProvider(t, "x509/server_ca_cert.pem"), nil, defaultTestCertSAN)
   441  	if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil {
   442  		t.Fatal("ClientHandshake() succeeded when expected to timeout")
   443  	}
   444  	close(clientDone)
   445  
   446  	// Read the handshake result from the testServer and make sure the expected
   447  	// error is returned.
   448  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   449  	defer cancel()
   450  	val, err := ts.hsResult.Receive(ctx)
   451  	if err != nil {
   452  		t.Fatalf("testServer failed to return handshake result: %v", err)
   453  	}
   454  	hsr := val.(handshakeResult)
   455  	if hsr.err != hErr {
   456  		t.Fatalf("testServer handshake returned error: %v, want: %v", hsr.err, hErr)
   457  	}
   458  }
   459  
   460  // TestClientCredsHandshakeFailure verifies different handshake failure cases.
   461  func (s) TestClientCredsHandshakeFailure(t *testing.T) {
   462  	tests := []struct {
   463  		desc          string
   464  		handshakeFunc testHandshakeFunc
   465  		rootProvider  certprovider.Provider
   466  		san           string
   467  		wantErr       string
   468  	}{
   469  		{
   470  			desc:          "cert validation failure",
   471  			handshakeFunc: testServerTLSHandshake,
   472  			rootProvider:  makeRootProvider(t, "x509/client_ca_cert.pem"),
   473  			san:           defaultTestCertSAN,
   474  			wantErr:       "x509: certificate signed by unknown authority",
   475  		},
   476  		{
   477  			desc:          "SAN mismatch",
   478  			handshakeFunc: testServerTLSHandshake,
   479  			rootProvider:  makeRootProvider(t, "x509/server_ca_cert.pem"),
   480  			san:           "bad-san",
   481  			wantErr:       "does not match any of the accepted SANs",
   482  		},
   483  	}
   484  
   485  	for _, test := range tests {
   486  		t.Run(test.desc, func(t *testing.T) {
   487  			ts := newTestServerWithHandshakeFunc(test.handshakeFunc)
   488  			defer ts.stop()
   489  
   490  			opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
   491  			creds, err := NewClientCredentials(opts)
   492  			if err != nil {
   493  				t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
   494  			}
   495  
   496  			conn, err := net.Dial("tcp", ts.address)
   497  			if err != nil {
   498  				t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
   499  			}
   500  			defer conn.Close()
   501  
   502  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   503  			defer cancel()
   504  			ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, nil, test.san)
   505  			if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil || !strings.Contains(err.Error(), test.wantErr) {
   506  				t.Fatalf("ClientHandshake() returned %q, wantErr %q", err, test.wantErr)
   507  			}
   508  		})
   509  	}
   510  }
   511  
   512  // TestClientCredsProviderSwitch verifies the case where the first attempt of
   513  // ClientHandshake fails because of a handshake failure. Then we update the
   514  // certificate provider and the second attempt succeeds. This is an
   515  // approximation of the flow of events when the control plane specifies new
   516  // security config which results in new certificate providers being used.
   517  func (s) TestClientCredsProviderSwitch(t *testing.T) {
   518  	ts := newTestServerWithHandshakeFunc(testServerTLSHandshake)
   519  	defer ts.stop()
   520  
   521  	opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
   522  	creds, err := NewClientCredentials(opts)
   523  	if err != nil {
   524  		t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
   525  	}
   526  
   527  	conn, err := net.Dial("tcp", ts.address)
   528  	if err != nil {
   529  		t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
   530  	}
   531  	defer conn.Close()
   532  
   533  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   534  	defer cancel()
   535  	// Create a root provider which will fail the handshake because it does not
   536  	// use the correct trust roots.
   537  	root1 := makeRootProvider(t, "x509/client_ca_cert.pem")
   538  	handshakeInfo := xdsinternal.NewHandshakeInfo(root1, nil)
   539  	handshakeInfo.SetSANMatchers([]matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(defaultTestCertSAN), nil, nil, nil, nil, false)})
   540  
   541  	// We need to repeat most of what newTestContextWithHandshakeInfo() does
   542  	// here because we need access to the underlying HandshakeInfo so that we
   543  	// can update it before the next call to ClientHandshake().
   544  	addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, handshakeInfo)
   545  	ctx = icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
   546  	if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil {
   547  		t.Fatal("ClientHandshake() succeeded when expected to fail")
   548  	}
   549  	// Drain the result channel on the test server so that we can inspect the
   550  	// result for the next handshake.
   551  	_, err = ts.hsResult.Receive(ctx)
   552  	if err != nil {
   553  		t.Errorf("testServer failed to return handshake result: %v", err)
   554  	}
   555  
   556  	conn, err = net.Dial("tcp", ts.address)
   557  	if err != nil {
   558  		t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
   559  	}
   560  	defer conn.Close()
   561  
   562  	// Create a new root provider which uses the correct trust roots. And update
   563  	// the HandshakeInfo with the new provider.
   564  	root2 := makeRootProvider(t, "x509/server_ca_cert.pem")
   565  	handshakeInfo.SetRootCertProvider(root2)
   566  	_, ai, err := creds.ClientHandshake(ctx, authority, conn)
   567  	if err != nil {
   568  		t.Fatalf("ClientHandshake() returned failed: %q", err)
   569  	}
   570  	if err := compareAuthInfo(ctx, ts, ai); err != nil {
   571  		t.Fatal(err)
   572  	}
   573  }
   574  
   575  // TestClientClone verifies the Clone() method on client credentials.
   576  func (s) TestClientClone(t *testing.T) {
   577  	opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
   578  	orig, err := NewClientCredentials(opts)
   579  	if err != nil {
   580  		t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
   581  	}
   582  
   583  	// The credsImpl does not have any exported fields, and it does not make
   584  	// sense to use any cmp options to look deep into. So, all we make sure here
   585  	// is that the cloned object points to a different location in memory.
   586  	if clone := orig.Clone(); clone == orig {
   587  		t.Fatal("return value from Clone() doesn't point to new credentials instance")
   588  	}
   589  }
   590  
   591  func newStringP(s string) *string {
   592  	return &s
   593  }