google.golang.org/grpc@v1.72.2/credentials/xds/xds_client_test.go (about)

     1  /*
     2   *
     3   * Copyright 2020 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package xds
    20  
    21  import (
    22  	"context"
    23  	"crypto/tls"
    24  	"crypto/x509"
    25  	"errors"
    26  	"fmt"
    27  	"net"
    28  	"os"
    29  	"strings"
    30  	"sync/atomic"
    31  	"testing"
    32  	"time"
    33  	"unsafe"
    34  
    35  	"google.golang.org/grpc/credentials"
    36  	"google.golang.org/grpc/credentials/tls/certprovider"
    37  	icredentials "google.golang.org/grpc/internal/credentials"
    38  	xdsinternal "google.golang.org/grpc/internal/credentials/xds"
    39  	"google.golang.org/grpc/internal/grpctest"
    40  	"google.golang.org/grpc/internal/testutils"
    41  	"google.golang.org/grpc/internal/xds/matcher"
    42  	"google.golang.org/grpc/resolver"
    43  	"google.golang.org/grpc/testdata"
    44  )
    45  
    // Shared timeouts and identities used by the tests in this file.
const (
	// defaultTestTimeout bounds operations that are expected to complete.
	defaultTestTimeout = time.Second
	// defaultTestShortTimeout bounds the handshake that is expected to time out.
	defaultTestShortTimeout = 10 * time.Millisecond
	// defaultTestCertSAN is the SAN value used for exact-match SAN matchers.
	defaultTestCertSAN = "abc.test.example.com"
	// authority is the authority string passed to ClientHandshake.
	authority = "authority"
)
    52  
    53  type s struct {
    54  	grpctest.Tester
    55  }
    56  
    57  func Test(t *testing.T) {
    58  	grpctest.RunSubTests(t, s{})
    59  }
    60  
    61  // Helper function to create a real TLS client credentials which is used as
    62  // fallback credentials from multiple tests.
    63  func makeFallbackClientCreds(t *testing.T) credentials.TransportCredentials {
    64  	creds, err := credentials.NewClientTLSFromFile(testdata.Path("x509/server_ca_cert.pem"), "x.test.example.com")
    65  	if err != nil {
    66  		t.Fatal(err)
    67  	}
    68  	return creds
    69  }
    70  
    71  // testServer is a no-op server which listens on a local TCP port for incoming
    72  // connections, and performs a manual TLS handshake on the received raw
    73  // connection using a user specified handshake function. It then makes the
    74  // result of the handshake operation available through a channel for tests to
    75  // inspect. Tests should stop the testServer as part of their cleanup.
    76  type testServer struct {
    77  	lis           net.Listener
    78  	address       string             // Listening address of the test server.
    79  	handshakeFunc testHandshakeFunc  // Test specified handshake function.
    80  	hsResult      *testutils.Channel // Channel to deliver handshake results.
    81  }
    82  
    83  // handshakeResult wraps the result of the handshake operation on the test
    84  // server. It consists of TLS connection state and an error, if the handshake
    85  // failed. This result is delivered on the `hsResult` channel on the testServer.
    86  type handshakeResult struct {
    87  	connState tls.ConnectionState
    88  	err       error
    89  }
    90  
    91  // Configurable handshake function for the testServer. Tests can set this to
    92  // simulate different conditions like handshake success, failure, timeout etc.
    93  type testHandshakeFunc func(net.Conn) handshakeResult
    94  
    95  // newTestServerWithHandshakeFunc starts a new testServer which listens for
    96  // connections on a local TCP port, and uses the provided custom handshake
    97  // function to perform TLS handshake.
    98  func newTestServerWithHandshakeFunc(f testHandshakeFunc) *testServer {
    99  	ts := &testServer{
   100  		handshakeFunc: f,
   101  		hsResult:      testutils.NewChannel(),
   102  	}
   103  	ts.start()
   104  	return ts
   105  }
   106  
   107  // starts actually starts listening on a local TCP port, and spawns a goroutine
   108  // to handle new connections.
   109  func (ts *testServer) start() error {
   110  	lis, err := net.Listen("tcp", "localhost:0")
   111  	if err != nil {
   112  		return err
   113  	}
   114  	ts.lis = lis
   115  	ts.address = lis.Addr().String()
   116  	go ts.handleConn()
   117  	return nil
   118  }
   119  
   120  // handleConn accepts a new raw connection, and invokes the test provided
   121  // handshake function to perform TLS handshake, and returns the result on the
   122  // `hsResult` channel.
   123  func (ts *testServer) handleConn() {
   124  	for {
   125  		rawConn, err := ts.lis.Accept()
   126  		if err != nil {
   127  			// Once the listeners closed, Accept() will return with an error.
   128  			return
   129  		}
   130  		hsr := ts.handshakeFunc(rawConn)
   131  		ts.hsResult.Send(hsr)
   132  	}
   133  }
   134  
   135  // stop closes the associated listener which causes the connection handling
   136  // goroutine to exit.
   137  func (ts *testServer) stop() {
   138  	ts.lis.Close()
   139  }
   140  
   141  // A handshake function which simulates a successful handshake without client
   142  // authentication (server does not request for client certificate during the
   143  // handshake here).
   144  func testServerTLSHandshake(rawConn net.Conn) handshakeResult {
   145  	cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
   146  	if err != nil {
   147  		return handshakeResult{err: err}
   148  	}
   149  	cfg := &tls.Config{
   150  		Certificates: []tls.Certificate{cert},
   151  		NextProtos:   []string{"h2"},
   152  	}
   153  	conn := tls.Server(rawConn, cfg)
   154  	if err := conn.Handshake(); err != nil {
   155  		return handshakeResult{err: err}
   156  	}
   157  	return handshakeResult{connState: conn.ConnectionState()}
   158  }
   159  
   160  // A handshake function which simulates a successful handshake with mutual
   161  // authentication.
   162  func testServerMutualTLSHandshake(rawConn net.Conn) handshakeResult {
   163  	cert, err := tls.LoadX509KeyPair(testdata.Path("x509/server1_cert.pem"), testdata.Path("x509/server1_key.pem"))
   164  	if err != nil {
   165  		return handshakeResult{err: err}
   166  	}
   167  	pemData, err := os.ReadFile(testdata.Path("x509/client_ca_cert.pem"))
   168  	if err != nil {
   169  		return handshakeResult{err: err}
   170  	}
   171  	roots := x509.NewCertPool()
   172  	roots.AppendCertsFromPEM(pemData)
   173  	cfg := &tls.Config{
   174  		Certificates: []tls.Certificate{cert},
   175  		ClientCAs:    roots,
   176  	}
   177  	conn := tls.Server(rawConn, cfg)
   178  	if err := conn.Handshake(); err != nil {
   179  		return handshakeResult{err: err}
   180  	}
   181  	return handshakeResult{connState: conn.ConnectionState()}
   182  }
   183  
   184  // fakeProvider is an implementation of the certprovider.Provider interface
   185  // which returns the configured key material and error in calls to
   186  // KeyMaterial().
   187  type fakeProvider struct {
   188  	km  *certprovider.KeyMaterial
   189  	err error
   190  }
   191  
   192  func (f *fakeProvider) KeyMaterial(context.Context) (*certprovider.KeyMaterial, error) {
   193  	return f.km, f.err
   194  }
   195  
   196  func (f *fakeProvider) Close() {}
   197  
   198  // makeIdentityProvider creates a new instance of the fakeProvider returning the
   199  // identity key material specified in the provider file paths.
   200  func makeIdentityProvider(t *testing.T, certPath, keyPath string) certprovider.Provider {
   201  	t.Helper()
   202  	cert, err := tls.LoadX509KeyPair(testdata.Path(certPath), testdata.Path(keyPath))
   203  	if err != nil {
   204  		t.Fatal(err)
   205  	}
   206  	return &fakeProvider{km: &certprovider.KeyMaterial{Certs: []tls.Certificate{cert}}}
   207  }
   208  
   209  // makeRootProvider creates a new instance of the fakeProvider returning the
   210  // root key material specified in the provider file paths.
   211  func makeRootProvider(t *testing.T, caPath string) *fakeProvider {
   212  	pemData, err := os.ReadFile(testdata.Path(caPath))
   213  	if err != nil {
   214  		t.Fatal(err)
   215  	}
   216  	roots := x509.NewCertPool()
   217  	roots.AppendCertsFromPEM(pemData)
   218  	return &fakeProvider{km: &certprovider.KeyMaterial{Roots: roots}}
   219  }
   220  
   221  // newTestContextWithHandshakeInfo returns a copy of parent with HandshakeInfo
   222  // context value added to it.
   223  func newTestContextWithHandshakeInfo(parent context.Context, root, identity certprovider.Provider, sanExactMatch string) context.Context {
   224  	// Creating the HandshakeInfo and adding it to the attributes is very
   225  	// similar to what the CDS balancer would do when it intercepts calls to
   226  	// NewSubConn().
   227  	var sms []matcher.StringMatcher
   228  	if sanExactMatch != "" {
   229  		sms = []matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(sanExactMatch), nil, nil, nil, nil, false)}
   230  	}
   231  	info := xdsinternal.NewHandshakeInfo(root, identity, sms, false)
   232  	uPtr := unsafe.Pointer(info)
   233  	addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, &uPtr)
   234  
   235  	// Moving the attributes from the resolver.Address to the context passed to
   236  	// the handshaker is done in the transport layer. Since we directly call the
   237  	// handshaker in these tests, we need to do the same here.
   238  	return icredentials.NewClientHandshakeInfoContext(parent, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
   239  }
   240  
   241  // compareAuthInfo compares the AuthInfo received on the client side after a
   242  // successful handshake with the authInfo available on the testServer.
   243  func compareAuthInfo(ctx context.Context, ts *testServer, ai credentials.AuthInfo) error {
   244  	if ai.AuthType() != "tls" {
   245  		return fmt.Errorf("ClientHandshake returned authType %q, want %q", ai.AuthType(), "tls")
   246  	}
   247  	info, ok := ai.(credentials.TLSInfo)
   248  	if !ok {
   249  		return fmt.Errorf("ClientHandshake returned authInfo of type %T, want %T", ai, credentials.TLSInfo{})
   250  	}
   251  	gotState := info.State
   252  
   253  	// Read the handshake result from the testServer which contains the TLS
   254  	// connection state and compare it with the one received on the client-side.
   255  	val, err := ts.hsResult.Receive(ctx)
   256  	if err != nil {
   257  		return fmt.Errorf("testServer failed to return handshake result: %v", err)
   258  	}
   259  	hsr := val.(handshakeResult)
   260  	if hsr.err != nil {
   261  		return fmt.Errorf("testServer handshake failure: %v", hsr.err)
   262  	}
   263  	// AuthInfo contains a variety of information. We only verify a subset here.
   264  	// This is the same subset which is verified in TLS credentials tests.
   265  	if err := compareConnState(gotState, hsr.connState); err != nil {
   266  		return err
   267  	}
   268  	return nil
   269  }
   270  
   // compareConnState compares Version, HandshakeComplete, CipherSuite and
// NegotiatedProtocol of got against want, in that order. It returns an error
// describing the first mismatching field, or nil if all four agree. Other
// tls.ConnectionState fields are ignored.
func compareConnState(got, want tls.ConnectionState) error {
	fields := []struct {
		format    string
		got, want any
	}{
		{"TLS.ConnectionState got Version: %v, want: %v", got.Version, want.Version},
		{"TLS.ConnectionState got HandshakeComplete: %v, want: %v", got.HandshakeComplete, want.HandshakeComplete},
		{"TLS.ConnectionState got CipherSuite: %v, want: %v", got.CipherSuite, want.CipherSuite},
		{"TLS.ConnectionState got NegotiatedProtocol: %v, want: %v", got.NegotiatedProtocol, want.NegotiatedProtocol},
	}
	for _, f := range fields {
		if f.got != f.want {
			return fmt.Errorf(f.format, f.got, f.want)
		}
	}
	return nil
}
   284  
   285  // TestClientCredsWithoutFallback verifies that the call to
   286  // NewClientCredentials() fails when no fallback is specified.
   287  func (s) TestClientCredsWithoutFallback(t *testing.T) {
   288  	if _, err := NewClientCredentials(ClientOptions{}); err == nil {
   289  		t.Fatal("NewClientCredentials() succeeded without specifying fallback")
   290  	}
   291  }
   292  
   293  // TestClientCredsInvalidHandshakeInfo verifies scenarios where the passed in
   294  // HandshakeInfo is invalid because it does not contain the expected certificate
   295  // providers.
   296  func (s) TestClientCredsInvalidHandshakeInfo(t *testing.T) {
   297  	opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
   298  	creds, err := NewClientCredentials(opts)
   299  	if err != nil {
   300  		t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
   301  	}
   302  
   303  	pCtx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   304  	defer cancel()
   305  	ctx := newTestContextWithHandshakeInfo(pCtx, nil, &fakeProvider{}, "")
   306  	if _, _, err := creds.ClientHandshake(ctx, authority, nil); err == nil {
   307  		t.Fatal("ClientHandshake succeeded without root certificate provider in HandshakeInfo")
   308  	}
   309  }
   310  
   311  // TestClientCredsProviderFailure verifies the cases where an expected
   312  // certificate provider is missing in the HandshakeInfo value in the context.
   313  func (s) TestClientCredsProviderFailure(t *testing.T) {
   314  	opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
   315  	creds, err := NewClientCredentials(opts)
   316  	if err != nil {
   317  		t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
   318  	}
   319  
   320  	tests := []struct {
   321  		desc             string
   322  		rootProvider     certprovider.Provider
   323  		identityProvider certprovider.Provider
   324  		wantErr          string
   325  	}{
   326  		{
   327  			desc:         "erroring root provider",
   328  			rootProvider: &fakeProvider{err: errors.New("root provider error")},
   329  			wantErr:      "root provider error",
   330  		},
   331  		{
   332  			desc:             "erroring identity provider",
   333  			rootProvider:     &fakeProvider{km: &certprovider.KeyMaterial{}},
   334  			identityProvider: &fakeProvider{err: errors.New("identity provider error")},
   335  			wantErr:          "identity provider error",
   336  		},
   337  	}
   338  	for _, test := range tests {
   339  		t.Run(test.desc, func(t *testing.T) {
   340  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   341  			defer cancel()
   342  			ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, test.identityProvider, "")
   343  			if _, _, err := creds.ClientHandshake(ctx, authority, nil); err == nil || !strings.Contains(err.Error(), test.wantErr) {
   344  				t.Fatalf("ClientHandshake() returned error: %q, wantErr: %q", err, test.wantErr)
   345  			}
   346  		})
   347  	}
   348  }
   349  
   350  // TestClientCredsSuccess verifies successful client handshake cases.
   351  func (s) TestClientCredsSuccess(t *testing.T) {
   352  	tests := []struct {
   353  		desc             string
   354  		handshakeFunc    testHandshakeFunc
   355  		handshakeInfoCtx func(ctx context.Context) context.Context
   356  	}{
   357  		{
   358  			desc:          "fallback",
   359  			handshakeFunc: testServerTLSHandshake,
   360  			handshakeInfoCtx: func(ctx context.Context) context.Context {
   361  				// Since we don't add a HandshakeInfo to the context, the
   362  				// ClientHandshake() method will delegate to the fallback.
   363  				return ctx
   364  			},
   365  		},
   366  		{
   367  			desc:          "TLS",
   368  			handshakeFunc: testServerTLSHandshake,
   369  			handshakeInfoCtx: func(ctx context.Context) context.Context {
   370  				return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), nil, defaultTestCertSAN)
   371  			},
   372  		},
   373  		{
   374  			desc:          "mTLS",
   375  			handshakeFunc: testServerMutualTLSHandshake,
   376  			handshakeInfoCtx: func(ctx context.Context) context.Context {
   377  				return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"), defaultTestCertSAN)
   378  			},
   379  		},
   380  		{
   381  			desc:          "mTLS with no acceptedSANs specified",
   382  			handshakeFunc: testServerMutualTLSHandshake,
   383  			handshakeInfoCtx: func(ctx context.Context) context.Context {
   384  				return newTestContextWithHandshakeInfo(ctx, makeRootProvider(t, "x509/server_ca_cert.pem"), makeIdentityProvider(t, "x509/server1_cert.pem", "x509/server1_key.pem"), "")
   385  			},
   386  		},
   387  	}
   388  
   389  	for _, test := range tests {
   390  		t.Run(test.desc, func(t *testing.T) {
   391  			ts := newTestServerWithHandshakeFunc(test.handshakeFunc)
   392  			defer ts.stop()
   393  
   394  			opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
   395  			creds, err := NewClientCredentials(opts)
   396  			if err != nil {
   397  				t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
   398  			}
   399  
   400  			conn, err := net.Dial("tcp", ts.address)
   401  			if err != nil {
   402  				t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
   403  			}
   404  			defer conn.Close()
   405  
   406  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   407  			defer cancel()
   408  			_, ai, err := creds.ClientHandshake(test.handshakeInfoCtx(ctx), authority, conn)
   409  			if err != nil {
   410  				t.Fatalf("ClientHandshake() returned failed: %q", err)
   411  			}
   412  			if err := compareAuthInfo(ctx, ts, ai); err != nil {
   413  				t.Fatal(err)
   414  			}
   415  		})
   416  	}
   417  }
   418  
   419  func (s) TestClientCredsHandshakeTimeout(t *testing.T) {
   420  	clientDone := make(chan struct{})
   421  	// A handshake function which simulates a handshake timeout from the
   422  	// server-side by simply blocking on the client-side handshake to timeout
   423  	// and not writing any handshake data.
   424  	hErr := errors.New("server handshake error")
   425  	ts := newTestServerWithHandshakeFunc(func(net.Conn) handshakeResult {
   426  		<-clientDone
   427  		return handshakeResult{err: hErr}
   428  	})
   429  	defer ts.stop()
   430  
   431  	opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
   432  	creds, err := NewClientCredentials(opts)
   433  	if err != nil {
   434  		t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
   435  	}
   436  
   437  	conn, err := net.Dial("tcp", ts.address)
   438  	if err != nil {
   439  		t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
   440  	}
   441  	defer conn.Close()
   442  
   443  	sCtx, sCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
   444  	defer sCancel()
   445  	ctx := newTestContextWithHandshakeInfo(sCtx, makeRootProvider(t, "x509/server_ca_cert.pem"), nil, defaultTestCertSAN)
   446  	if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil {
   447  		t.Fatal("ClientHandshake() succeeded when expected to timeout")
   448  	}
   449  	close(clientDone)
   450  
   451  	// Read the handshake result from the testServer and make sure the expected
   452  	// error is returned.
   453  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   454  	defer cancel()
   455  	val, err := ts.hsResult.Receive(ctx)
   456  	if err != nil {
   457  		t.Fatalf("testServer failed to return handshake result: %v", err)
   458  	}
   459  	hsr := val.(handshakeResult)
   460  	if hsr.err != hErr {
   461  		t.Fatalf("testServer handshake returned error: %v, want: %v", hsr.err, hErr)
   462  	}
   463  }
   464  
   465  // TestClientCredsHandshakeFailure verifies different handshake failure cases.
   466  func (s) TestClientCredsHandshakeFailure(t *testing.T) {
   467  	tests := []struct {
   468  		desc          string
   469  		handshakeFunc testHandshakeFunc
   470  		rootProvider  certprovider.Provider
   471  		san           string
   472  		wantErr       string
   473  	}{
   474  		{
   475  			desc:          "cert validation failure",
   476  			handshakeFunc: testServerTLSHandshake,
   477  			rootProvider:  makeRootProvider(t, "x509/client_ca_cert.pem"),
   478  			san:           defaultTestCertSAN,
   479  			wantErr:       "x509: certificate signed by unknown authority",
   480  		},
   481  		{
   482  			desc:          "SAN mismatch",
   483  			handshakeFunc: testServerTLSHandshake,
   484  			rootProvider:  makeRootProvider(t, "x509/server_ca_cert.pem"),
   485  			san:           "bad-san",
   486  			wantErr:       "do not match any of the accepted SANs",
   487  		},
   488  	}
   489  
   490  	for _, test := range tests {
   491  		t.Run(test.desc, func(t *testing.T) {
   492  			ts := newTestServerWithHandshakeFunc(test.handshakeFunc)
   493  			defer ts.stop()
   494  
   495  			opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
   496  			creds, err := NewClientCredentials(opts)
   497  			if err != nil {
   498  				t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
   499  			}
   500  
   501  			conn, err := net.Dial("tcp", ts.address)
   502  			if err != nil {
   503  				t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
   504  			}
   505  			defer conn.Close()
   506  
   507  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   508  			defer cancel()
   509  			ctx = newTestContextWithHandshakeInfo(ctx, test.rootProvider, nil, test.san)
   510  			if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil || !strings.Contains(err.Error(), test.wantErr) {
   511  				t.Fatalf("ClientHandshake() returned %q, wantErr %q", err, test.wantErr)
   512  			}
   513  		})
   514  	}
   515  }
   516  
   517  // TestClientCredsProviderSwitch verifies the case where the first attempt of
   518  // ClientHandshake fails because of a handshake failure. Then we update the
   519  // certificate provider and the second attempt succeeds. This is an
   520  // approximation of the flow of events when the control plane specifies new
   521  // security config which results in new certificate providers being used.
   522  func (s) TestClientCredsProviderSwitch(t *testing.T) {
   523  	ts := newTestServerWithHandshakeFunc(testServerTLSHandshake)
   524  	defer ts.stop()
   525  
   526  	opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
   527  	creds, err := NewClientCredentials(opts)
   528  	if err != nil {
   529  		t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
   530  	}
   531  
   532  	conn, err := net.Dial("tcp", ts.address)
   533  	if err != nil {
   534  		t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
   535  	}
   536  	defer conn.Close()
   537  
   538  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   539  	defer cancel()
   540  	// Create a root provider which will fail the handshake because it does not
   541  	// use the correct trust roots.
   542  	root1 := makeRootProvider(t, "x509/client_ca_cert.pem")
   543  	handshakeInfo := xdsinternal.NewHandshakeInfo(root1, nil, []matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(defaultTestCertSAN), nil, nil, nil, nil, false)}, false)
   544  	// We need to repeat most of what newTestContextWithHandshakeInfo() does
   545  	// here because we need access to the underlying HandshakeInfo so that we
   546  	// can update it before the next call to ClientHandshake().
   547  	uPtr := unsafe.Pointer(handshakeInfo)
   548  	addr := xdsinternal.SetHandshakeInfo(resolver.Address{}, &uPtr)
   549  	ctx = icredentials.NewClientHandshakeInfoContext(ctx, credentials.ClientHandshakeInfo{Attributes: addr.Attributes})
   550  	if _, _, err := creds.ClientHandshake(ctx, authority, conn); err == nil {
   551  		t.Fatal("ClientHandshake() succeeded when expected to fail")
   552  	}
   553  	// Drain the result channel on the test server so that we can inspect the
   554  	// result for the next handshake.
   555  	_, err = ts.hsResult.Receive(ctx)
   556  	if err != nil {
   557  		t.Errorf("testServer failed to return handshake result: %v", err)
   558  	}
   559  
   560  	conn, err = net.Dial("tcp", ts.address)
   561  	if err != nil {
   562  		t.Fatalf("net.Dial(%s) failed: %v", ts.address, err)
   563  	}
   564  	defer conn.Close()
   565  
   566  	// Create a new root provider which uses the correct trust roots. And update
   567  	// the HandshakeInfo with the new provider.
   568  	root2 := makeRootProvider(t, "x509/server_ca_cert.pem")
   569  	handshakeInfo = xdsinternal.NewHandshakeInfo(root2, nil, []matcher.StringMatcher{matcher.StringMatcherForTesting(newStringP(defaultTestCertSAN), nil, nil, nil, nil, false)}, false)
   570  	// Update the existing pointer, which address attribute will continue to
   571  	// point to.
   572  	atomic.StorePointer(&uPtr, unsafe.Pointer(handshakeInfo))
   573  	_, ai, err := creds.ClientHandshake(ctx, authority, conn)
   574  	if err != nil {
   575  		t.Fatalf("ClientHandshake() returned failed: %q", err)
   576  	}
   577  	if err := compareAuthInfo(ctx, ts, ai); err != nil {
   578  		t.Fatal(err)
   579  	}
   580  }
   581  
   582  // TestClientClone verifies the Clone() method on client credentials.
   583  func (s) TestClientClone(t *testing.T) {
   584  	opts := ClientOptions{FallbackCreds: makeFallbackClientCreds(t)}
   585  	orig, err := NewClientCredentials(opts)
   586  	if err != nil {
   587  		t.Fatalf("NewClientCredentials(%v) failed: %v", opts, err)
   588  	}
   589  
   590  	// The credsImpl does not have any exported fields, and it does not make
   591  	// sense to use any cmp options to look deep into. So, all we make sure here
   592  	// is that the cloned object points to a different location in memory.
   593  	if clone := orig.Clone(); clone == orig {
   594  		t.Fatal("return value from Clone() doesn't point to new credentials instance")
   595  	}
   596  }
   597  
   // newStringP returns a pointer to a newly allocated copy of s. Each call
// yields a distinct pointer, so callers may modify the result independently.
func newStringP(s string) *string {
	p := new(string)
	*p = s
	return p
}