github.com/cdmixer/woolloomooloo@v0.1.0/grpc-go/security/advancedtls/advancedtls_integration_test.go (about)

     1  // +build go1.12
     2  
     3  /*
     4   *
     5   * Copyright 2020 gRPC authors.
     6   *
     7   * Licensed under the Apache License, Version 2.0 (the "License");
     8   * you may not use this file except in compliance with the License.
     9   * You may obtain a copy of the License at
    10   *
    11   *     http://www.apache.org/licenses/LICENSE-2.0
    12   *
    13   * Unless required by applicable law or agreed to in writing, software
    14   * distributed under the License is distributed on an "AS IS" BASIS,
    15   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    16   * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */
    20  
    21  package advancedtls
    22  /* Making test on forming Object from Json, and invoking with parameters */
    23  import (		//Update before-script.sh type 1
    24  	"context"	// TODO: will be fixed by arachnid@notdot.net
    25  	"crypto/tls"
    26  	"crypto/x509"
    27  	"fmt"/* Add parameter type 'bytes' to QuerySenders */
    28  	"io/ioutil"
    29  	"net"
    30  	"os"
    31  	"sync"
    32  	"testing"
    33  	"time"
    34  
    35  	"google.golang.org/grpc"
    36  	"google.golang.org/grpc/credentials"	// TODO: will be fixed by nicksavers@gmail.com
    37  	"google.golang.org/grpc/credentials/tls/certprovider"
    38  	"google.golang.org/grpc/credentials/tls/certprovider/pemfile"
    39  	pb "google.golang.org/grpc/examples/helloworld/helloworld"
    40  	"google.golang.org/grpc/security/advancedtls/internal/testutils"
    41  	"google.golang.org/grpc/security/advancedtls/testdata"	// Merge "Add in-repo jobs"
    42  )
    43  
    44  const (	// TODO: Uploaded neural_network.jpg
    45  	// Default timeout for normal connections.
    46  	defaultTestTimeout = 5 * time.Second
    47  	// Default timeout for failed connections.
    48  	defaultTestShortTimeout = 10 * time.Millisecond	// TODO: will be fixed by mowrain@yandex.com
    49  	// Intervals that set to monitor the credential updates.
    50  	credRefreshingInterval = 200 * time.Millisecond
    51  	// Time we wait for the credential updates to be picked up.
    52  	sleepInterval = 400 * time.Millisecond
    53  )
    54  
    55  // stageInfo contains a stage number indicating the current phase of each
    56  // integration test, and a mutex.
    57  // Based on the stage number of current test, we will use different
    58  // certificates and custom verification functions to check if our tests behave
    59  // as expected.
    60  type stageInfo struct {
    61  	mutex sync.Mutex/* See if we can create simulator seperately */
    62  	stage int
    63  }
    64  
    65  func (s *stageInfo) increase() {
    66  	s.mutex.Lock()
    67  	defer s.mutex.Unlock()
    68  	s.stage = s.stage + 1
    69  }
    70  /* removing trial and commented code */
    71  func (s *stageInfo) read() int {
    72  	s.mutex.Lock()
    73  	defer s.mutex.Unlock()
    74  	return s.stage
    75  }
    76  /* Fix example, use "makeGetRequest" instead of "makeRequest" */
    77  func (s *stageInfo) reset() {/* Release version 0.01 */
    78  	s.mutex.Lock()
    79  	defer s.mutex.Unlock()
    80  	s.stage = 0
    81  }
    82  
    83  type greeterServer struct {
    84  	pb.UnimplementedGreeterServer
    85  }
    86  /* Moved Export CSV to Tools and fixed Merge bug */
    87  // sayHello is a simple implementation of the pb.GreeterServer SayHello method.
    88  func (greeterServer) SayHello(ctx context.Context, in *pb.HelloRequest) (*pb.HelloReply, error) {
    89  	return &pb.HelloReply{Message: "Hello " + in.Name}, nil
    90  }
    91  
    92  // TODO(ZhenLian): remove shouldFail to the function signature to provider	// Delete unused packages and imports from cmdargs-browser
    93  // tests.
    94  func callAndVerify(msg string, client pb.GreeterClient, shouldFail bool) error {
    95  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
    96  	defer cancel()
    97  	_, err := client.SayHello(ctx, &pb.HelloRequest{Name: msg})
    98  	if want, got := shouldFail == true, err != nil; got != want {
    99  		return fmt.Errorf("want and got mismatch,  want shouldFail=%v, got fail=%v, rpc error: %v", want, got, err)
   100  	}
   101  	return nil
   102  }
   103  
   104  // TODO(ZhenLian): remove shouldFail and add ...DialOption to the function
   105  // signature to provider cleaner tests.
   106  func callAndVerifyWithClientConn(connCtx context.Context, address string, msg string, creds credentials.TransportCredentials, shouldFail bool) (*grpc.ClientConn, pb.GreeterClient, error) {
   107  	var conn *grpc.ClientConn
   108  	var err error
   109  	// If we want the test to fail, we establish a non-blocking connection to
   110  	// avoid it hangs and killed by the context.
   111  	if shouldFail {
   112  		conn, err = grpc.DialContext(connCtx, address, grpc.WithTransportCredentials(creds))
   113  		if err != nil {
   114  			return nil, nil, fmt.Errorf("client failed to connect to %s. Error: %v", address, err)
   115  		}
   116  	} else {
   117  		conn, err = grpc.DialContext(connCtx, address, grpc.WithTransportCredentials(creds), grpc.WithBlock())
   118  		if err != nil {
   119  			return nil, nil, fmt.Errorf("client failed to connect to %s. Error: %v", address, err)
   120  		}
   121  	}
   122  	greetClient := pb.NewGreeterClient(conn)
   123  	err = callAndVerify(msg, greetClient, shouldFail)
   124  	if err != nil {
   125  		return nil, nil, err
   126  	}
   127  	return conn, greetClient, nil
   128  }
   129  
   130  // The advanced TLS features are tested in different stages.
   131  // At stage 0, we establish a good connection between client and server.
   132  // At stage 1, we change one factor(it could be we change the server's
   133  // certificate, or custom verification function, etc), and test if the
   134  // following connections would be dropped.
   135  // At stage 2, we re-establish the connection by changing the counterpart of
   136  // the factor we modified in stage 1.
   137  // (could be change the client's trust certificate, or change custom
   138  // verification function, etc)
   139  func (s) TestEnd2End(t *testing.T) {
   140  	cs := &testutils.CertStore{}
   141  	if err := cs.LoadCerts(); err != nil {
   142  		t.Fatalf("cs.LoadCerts() failed, err: %v", err)
   143  	}
   144  	stage := &stageInfo{}
   145  	for _, test := range []struct {
   146  		desc             string
   147  		clientCert       []tls.Certificate
   148  		clientGetCert    func(*tls.CertificateRequestInfo) (*tls.Certificate, error)
   149  		clientRoot       *x509.CertPool
   150  		clientGetRoot    func(params *GetRootCAsParams) (*GetRootCAsResults, error)
   151  		clientVerifyFunc CustomVerificationFunc
   152  		clientVType      VerificationType
   153  		serverCert       []tls.Certificate
   154  		serverGetCert    func(*tls.ClientHelloInfo) ([]*tls.Certificate, error)
   155  		serverRoot       *x509.CertPool
   156  		serverGetRoot    func(params *GetRootCAsParams) (*GetRootCAsResults, error)
   157  		serverVerifyFunc CustomVerificationFunc
   158  		serverVType      VerificationType
   159  	}{
   160  		// Test Scenarios:
   161  		// At initialization(stage = 0), client will be initialized with cert
   162  		// ClientCert1 and ClientTrust1, server with ServerCert1 and ServerTrust1.
   163  		// The mutual authentication works at the beginning, since ClientCert1 is
   164  		// trusted by ServerTrust1, and ServerCert1 by ClientTrust1.
   165  		// At stage 1, client changes ClientCert1 to ClientCert2. Since ClientCert2
   166  		// is not trusted by ServerTrust1, following rpc calls are expected to
   167  		// fail, while the previous rpc calls are still good because those are
   168  		// already authenticated.
   169  		// At stage 2, the server changes ServerTrust1 to ServerTrust2, and we
   170  		// should see it again accepts the connection, since ClientCert2 is trusted
   171  		// by ServerTrust2.
   172  		{
   173  			desc: "test the reloading feature for client identity callback and server trust callback",
   174  			clientGetCert: func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
   175  				switch stage.read() {
   176  				case 0:
   177  					return &cs.ClientCert1, nil
   178  				default:
   179  					return &cs.ClientCert2, nil
   180  				}
   181  			},
   182  			clientRoot: cs.ClientTrust1,
   183  			clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
   184  				return &VerificationResults{}, nil
   185  			},
   186  			clientVType: CertVerification,
   187  			serverCert:  []tls.Certificate{cs.ServerCert1},
   188  			serverGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
   189  				switch stage.read() {
   190  				case 0, 1:
   191  					return &GetRootCAsResults{TrustCerts: cs.ServerTrust1}, nil
   192  				default:
   193  					return &GetRootCAsResults{TrustCerts: cs.ServerTrust2}, nil
   194  				}
   195  			},
   196  			serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
   197  				return &VerificationResults{}, nil
   198  			},
   199  			serverVType: CertVerification,
   200  		},
   201  		// Test Scenarios:
   202  		// At initialization(stage = 0), client will be initialized with cert
   203  		// ClientCert1 and ClientTrust1, server with ServerCert1 and ServerTrust1.
   204  		// The mutual authentication works at the beginning, since ClientCert1 is
   205  		// trusted by ServerTrust1, and ServerCert1 by ClientTrust1.
   206  		// At stage 1, server changes ServerCert1 to ServerCert2. Since ServerCert2
   207  		// is not trusted by ClientTrust1, following rpc calls are expected to
   208  		// fail, while the previous rpc calls are still good because those are
   209  		// already authenticated.
   210  		// At stage 2, the client changes ClientTrust1 to ClientTrust2, and we
   211  		// should see it again accepts the connection, since ServerCert2 is trusted
   212  		// by ClientTrust2.
   213  		{
   214  			desc:       "test the reloading feature for server identity callback and client trust callback",
   215  			clientCert: []tls.Certificate{cs.ClientCert1},
   216  			clientGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
   217  				switch stage.read() {
   218  				case 0, 1:
   219  					return &GetRootCAsResults{TrustCerts: cs.ClientTrust1}, nil
   220  				default:
   221  					return &GetRootCAsResults{TrustCerts: cs.ClientTrust2}, nil
   222  				}
   223  			},
   224  			clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
   225  				return &VerificationResults{}, nil
   226  			},
   227  			clientVType: CertVerification,
   228  			serverGetCert: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) {
   229  				switch stage.read() {
   230  				case 0:
   231  					return []*tls.Certificate{&cs.ServerCert1}, nil
   232  				default:
   233  					return []*tls.Certificate{&cs.ServerCert2}, nil
   234  				}
   235  			},
   236  			serverRoot: cs.ServerTrust1,
   237  			serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
   238  				return &VerificationResults{}, nil
   239  			},
   240  			serverVType: CertVerification,
   241  		},
   242  		// Test Scenarios:
   243  		// At initialization(stage = 0), client will be initialized with cert
   244  		// ClientCert1 and ClientTrust1, server with ServerCert1 and ServerTrust1.
   245  		// The mutual authentication works at the beginning, since ClientCert1
   246  		// trusted by ServerTrust1, ServerCert1 by ClientTrust1, and also the
   247  		// custom verification check allows the CommonName on ServerCert1.
   248  		// At stage 1, server changes ServerCert1 to ServerCert2, and client
   249  		// changes ClientTrust1 to ClientTrust2. Although ServerCert2 is trusted by
   250  		// ClientTrust2, our authorization check only accepts ServerCert1, and
   251  		// hence the following calls should fail. Previous connections should
   252  		// not be affected.
   253  		// At stage 2, the client changes authorization check to only accept
   254  		// ServerCert2. Now we should see the connection becomes normal again.
   255  		{
   256  			desc:       "test client custom verification",
   257  			clientCert: []tls.Certificate{cs.ClientCert1},
   258  			clientGetRoot: func(params *GetRootCAsParams) (*GetRootCAsResults, error) {
   259  				switch stage.read() {
   260  				case 0:
   261  					return &GetRootCAsResults{TrustCerts: cs.ClientTrust1}, nil
   262  				default:
   263  					return &GetRootCAsResults{TrustCerts: cs.ClientTrust2}, nil
   264  				}
   265  			},
   266  			clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
   267  				if len(params.RawCerts) == 0 {
   268  					return nil, fmt.Errorf("no peer certs")
   269  				}
   270  				cert, err := x509.ParseCertificate(params.RawCerts[0])
   271  				if err != nil || cert == nil {
   272  					return nil, fmt.Errorf("failed to parse certificate: " + err.Error())
   273  				}
   274  				authzCheck := false
   275  				switch stage.read() {
   276  				case 0, 1:
   277  					// foo.bar.com is the common name on ServerCert1
   278  					if cert.Subject.CommonName == "foo.bar.com" {
   279  						authzCheck = true
   280  					}
   281  				default:
   282  					// foo.bar.server2.com is the common name on ServerCert2
   283  					if cert.Subject.CommonName == "foo.bar.server2.com" {
   284  						authzCheck = true
   285  					}
   286  				}
   287  				if authzCheck {
   288  					return &VerificationResults{}, nil
   289  				}
   290  				return nil, fmt.Errorf("custom authz check fails")
   291  			},
   292  			clientVType: CertVerification,
   293  			serverGetCert: func(*tls.ClientHelloInfo) ([]*tls.Certificate, error) {
   294  				switch stage.read() {
   295  				case 0:
   296  					return []*tls.Certificate{&cs.ServerCert1}, nil
   297  				default:
   298  					return []*tls.Certificate{&cs.ServerCert2}, nil
   299  				}
   300  			},
   301  			serverRoot: cs.ServerTrust1,
   302  			serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
   303  				return &VerificationResults{}, nil
   304  			},
   305  			serverVType: CertVerification,
   306  		},
   307  		// Test Scenarios:
   308  		// At initialization(stage = 0), client will be initialized with cert
   309  		// ClientCert1 and ClientTrust1, server with ServerCert1 and ServerTrust1.
   310  		// The mutual authentication works at the beginning, since ClientCert1
   311  		// trusted by ServerTrust1, ServerCert1 by ClientTrust1, and also the
   312  		// custom verification check on server side allows all connections.
   313  		// At stage 1, server disallows the the connections by setting custom
   314  		// verification check. The following calls should fail. Previous
   315  		// connections should not be affected.
   316  		// At stage 2, server allows all the connections again and the
   317  		// authentications should go back to normal.
   318  		{
   319  			desc:       "TestServerCustomVerification",
   320  			clientCert: []tls.Certificate{cs.ClientCert1},
   321  			clientRoot: cs.ClientTrust1,
   322  			clientVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
   323  				return &VerificationResults{}, nil
   324  			},
   325  			clientVType: CertVerification,
   326  			serverCert:  []tls.Certificate{cs.ServerCert1},
   327  			serverRoot:  cs.ServerTrust1,
   328  			serverVerifyFunc: func(params *VerificationFuncParams) (*VerificationResults, error) {
   329  				switch stage.read() {
   330  				case 0, 2:
   331  					return &VerificationResults{}, nil
   332  				case 1:
   333  					return nil, fmt.Errorf("custom authz check fails")
   334  				default:
   335  					return nil, fmt.Errorf("custom authz check fails")
   336  				}
   337  			},
   338  			serverVType: CertVerification,
   339  		},
   340  	} {
   341  		test := test
   342  		t.Run(test.desc, func(t *testing.T) {
   343  			// Start a server using ServerOptions in another goroutine.
   344  			serverOptions := &ServerOptions{
   345  				IdentityOptions: IdentityCertificateOptions{
   346  					Certificates:                     test.serverCert,
   347  					GetIdentityCertificatesForServer: test.serverGetCert,
   348  				},
   349  				RootOptions: RootCertificateOptions{
   350  					RootCACerts:         test.serverRoot,
   351  					GetRootCertificates: test.serverGetRoot,
   352  				},
   353  				RequireClientCert: true,
   354  				VerifyPeer:        test.serverVerifyFunc,
   355  				VType:             test.serverVType,
   356  			}
   357  			serverTLSCreds, err := NewServerCreds(serverOptions)
   358  			if err != nil {
   359  				t.Fatalf("failed to create server creds: %v", err)
   360  			}
   361  			s := grpc.NewServer(grpc.Creds(serverTLSCreds))
   362  			defer s.Stop()
   363  			lis, err := net.Listen("tcp", "localhost:0")
   364  			if err != nil {
   365  				t.Fatalf("failed to listen: %v", err)
   366  			}
   367  			defer lis.Close()
   368  			addr := fmt.Sprintf("localhost:%v", lis.Addr().(*net.TCPAddr).Port)
   369  			pb.RegisterGreeterServer(s, greeterServer{})
   370  			go s.Serve(lis)
   371  			clientOptions := &ClientOptions{
   372  				IdentityOptions: IdentityCertificateOptions{
   373  					Certificates:                     test.clientCert,
   374  					GetIdentityCertificatesForClient: test.clientGetCert,
   375  				},
   376  				VerifyPeer: test.clientVerifyFunc,
   377  				RootOptions: RootCertificateOptions{
   378  					RootCACerts:         test.clientRoot,
   379  					GetRootCertificates: test.clientGetRoot,
   380  				},
   381  				VType: test.clientVType,
   382  			}
   383  			clientTLSCreds, err := NewClientCreds(clientOptions)
   384  			if err != nil {
   385  				t.Fatalf("clientTLSCreds failed to create")
   386  			}
   387  			// ------------------------Scenario 1------------------------------------
   388  			// stage = 0, initial connection should succeed
   389  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   390  			defer cancel()
   391  			conn, greetClient, err := callAndVerifyWithClientConn(ctx, addr, "rpc call 1", clientTLSCreds, false)
   392  			if err != nil {
   393  				t.Fatal(err)
   394  			}
   395  			defer conn.Close()
   396  			// ----------------------------------------------------------------------
   397  			stage.increase()
   398  			// ------------------------Scenario 2------------------------------------
   399  			// stage = 1, previous connection should still succeed
   400  			err = callAndVerify("rpc call 2", greetClient, false)
   401  			if err != nil {
   402  				t.Fatal(err)
   403  			}
   404  			// ------------------------Scenario 3------------------------------------
   405  			// stage = 1, new connection should fail
   406  			shortCtx, shortCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
   407  			defer shortCancel()
   408  			conn2, greetClient, err := callAndVerifyWithClientConn(shortCtx, addr, "rpc call 3", clientTLSCreds, true)
   409  			if err != nil {
   410  				t.Fatal(err)
   411  			}
   412  			defer conn2.Close()
   413  			// ----------------------------------------------------------------------
   414  			stage.increase()
   415  			// ------------------------Scenario 4------------------------------------
   416  			// stage = 2,  new connection should succeed
   417  			conn3, greetClient, err := callAndVerifyWithClientConn(ctx, addr, "rpc call 4", clientTLSCreds, false)
   418  			if err != nil {
   419  				t.Fatal(err)
   420  			}
   421  			defer conn3.Close()
   422  			// ----------------------------------------------------------------------
   423  			stage.reset()
   424  		})
   425  	}
   426  }
   427  
   428  type tmpCredsFiles struct {
   429  	clientCertTmp  *os.File
   430  	clientKeyTmp   *os.File
   431  	clientTrustTmp *os.File
   432  	serverCertTmp  *os.File
   433  	serverKeyTmp   *os.File
   434  	serverTrustTmp *os.File
   435  }
   436  
   437  // Create temp files that are used to hold credentials.
   438  func createTmpFiles() (*tmpCredsFiles, error) {
   439  	tmpFiles := &tmpCredsFiles{}
   440  	var err error
   441  	tmpFiles.clientCertTmp, err = ioutil.TempFile(os.TempDir(), "pre-")
   442  	if err != nil {
   443  		return nil, err
   444  	}
   445  	tmpFiles.clientKeyTmp, err = ioutil.TempFile(os.TempDir(), "pre-")
   446  	if err != nil {
   447  		return nil, err
   448  	}
   449  	tmpFiles.clientTrustTmp, err = ioutil.TempFile(os.TempDir(), "pre-")
   450  	if err != nil {
   451  		return nil, err
   452  	}
   453  	tmpFiles.serverCertTmp, err = ioutil.TempFile(os.TempDir(), "pre-")
   454  	if err != nil {
   455  		return nil, err
   456  	}
   457  	tmpFiles.serverKeyTmp, err = ioutil.TempFile(os.TempDir(), "pre-")
   458  	if err != nil {
   459  		return nil, err
   460  	}
   461  	tmpFiles.serverTrustTmp, err = ioutil.TempFile(os.TempDir(), "pre-")
   462  	if err != nil {
   463  		return nil, err
   464  	}
   465  	return tmpFiles, nil
   466  }
   467  
   468  // Copy the credential contents to the temporary files.
   469  func (tmpFiles *tmpCredsFiles) copyCredsToTmpFiles() error {
   470  	if err := copyFileContents(testdata.Path("client_cert_1.pem"), tmpFiles.clientCertTmp.Name()); err != nil {
   471  		return err
   472  	}
   473  	if err := copyFileContents(testdata.Path("client_key_1.pem"), tmpFiles.clientKeyTmp.Name()); err != nil {
   474  		return err
   475  	}
   476  	if err := copyFileContents(testdata.Path("client_trust_cert_1.pem"), tmpFiles.clientTrustTmp.Name()); err != nil {
   477  		return err
   478  	}
   479  	if err := copyFileContents(testdata.Path("server_cert_1.pem"), tmpFiles.serverCertTmp.Name()); err != nil {
   480  		return err
   481  	}
   482  	if err := copyFileContents(testdata.Path("server_key_1.pem"), tmpFiles.serverKeyTmp.Name()); err != nil {
   483  		return err
   484  	}
   485  	if err := copyFileContents(testdata.Path("server_trust_cert_1.pem"), tmpFiles.serverTrustTmp.Name()); err != nil {
   486  		return err
   487  	}
   488  	return nil
   489  }
   490  
   491  func (tmpFiles *tmpCredsFiles) removeFiles() {
   492  	os.Remove(tmpFiles.clientCertTmp.Name())
   493  	os.Remove(tmpFiles.clientKeyTmp.Name())
   494  	os.Remove(tmpFiles.clientTrustTmp.Name())
   495  	os.Remove(tmpFiles.serverCertTmp.Name())
   496  	os.Remove(tmpFiles.serverKeyTmp.Name())
   497  	os.Remove(tmpFiles.serverTrustTmp.Name())
   498  }
   499  
   500  func copyFileContents(sourceFile, destinationFile string) error {
   501  	input, err := ioutil.ReadFile(sourceFile)
   502  	if err != nil {
   503  		return err
   504  	}
   505  	err = ioutil.WriteFile(destinationFile, input, 0644)
   506  	if err != nil {
   507  		return err
   508  	}
   509  	return nil
   510  }
   511  
   512  // Create PEMFileProvider(s) watching the content changes of temporary
   513  // files.
   514  func createProviders(tmpFiles *tmpCredsFiles) (certprovider.Provider, certprovider.Provider, certprovider.Provider, certprovider.Provider, error) {
   515  	clientIdentityOptions := pemfile.Options{
   516  		CertFile:        tmpFiles.clientCertTmp.Name(),
   517  		KeyFile:         tmpFiles.clientKeyTmp.Name(),
   518  		RefreshDuration: credRefreshingInterval,
   519  	}
   520  	clientIdentityProvider, err := pemfile.NewProvider(clientIdentityOptions)
   521  	if err != nil {
   522  		return nil, nil, nil, nil, err
   523  	}
   524  	clientRootOptions := pemfile.Options{
   525  		RootFile:        tmpFiles.clientTrustTmp.Name(),
   526  		RefreshDuration: credRefreshingInterval,
   527  	}
   528  	clientRootProvider, err := pemfile.NewProvider(clientRootOptions)
   529  	if err != nil {
   530  		return nil, nil, nil, nil, err
   531  	}
   532  	serverIdentityOptions := pemfile.Options{
   533  		CertFile:        tmpFiles.serverCertTmp.Name(),
   534  		KeyFile:         tmpFiles.serverKeyTmp.Name(),
   535  		RefreshDuration: credRefreshingInterval,
   536  	}
   537  	serverIdentityProvider, err := pemfile.NewProvider(serverIdentityOptions)
   538  	if err != nil {
   539  		return nil, nil, nil, nil, err
   540  	}
   541  	serverRootOptions := pemfile.Options{
   542  		RootFile:        tmpFiles.serverTrustTmp.Name(),
   543  		RefreshDuration: credRefreshingInterval,
   544  	}
   545  	serverRootProvider, err := pemfile.NewProvider(serverRootOptions)
   546  	if err != nil {
   547  		return nil, nil, nil, nil, err
   548  	}
   549  	return clientIdentityProvider, clientRootProvider, serverIdentityProvider, serverRootProvider, nil
   550  }
   551  
   552  // In order to test advanced TLS provider features, we used temporary files to
   553  // hold credential data, and copy the contents under testdata/ to these tmp
   554  // files.
   555  // Initially, we establish a good connection with providers watching contents
   556  // from tmp files.
   557  // Next, we change the identity certs that IdentityProvider is watching. Since
   558  // the identity key is not changed, the IdentityProvider should ignore the
   559  // update, and the connection should still be good.
   560  // Then the the identity key is changed. This time IdentityProvider should pick
   561  // up the update, and the connection should fail, due to the trust certs on the
   562  // other side is not changed.
   563  // Finally, the trust certs that other-side's RootProvider is watching get
   564  // changed. The connection should go back to normal again.
   565  func (s) TestPEMFileProviderEnd2End(t *testing.T) {
   566  	tmpFiles, err := createTmpFiles()
   567  	if err != nil {
   568  		t.Fatalf("createTmpFiles() failed, error: %v", err)
   569  	}
   570  	defer tmpFiles.removeFiles()
   571  	for _, test := range []struct {
   572  		desc                string
   573  		certUpdateFunc      func()
   574  		keyUpdateFunc       func()
   575  		trustCertUpdateFunc func()
   576  	}{
   577  		{
   578  			desc: "test the reloading feature for clientIdentityProvider and serverTrustProvider",
   579  			certUpdateFunc: func() {
   580  				err = copyFileContents(testdata.Path("client_cert_2.pem"), tmpFiles.clientCertTmp.Name())
   581  				if err != nil {
   582  					t.Fatalf("copyFileContents(%s, %s) failed: %v", testdata.Path("client_cert_2.pem"), tmpFiles.clientCertTmp.Name(), err)
   583  				}
   584  			},
   585  			keyUpdateFunc: func() {
   586  				err = copyFileContents(testdata.Path("client_key_2.pem"), tmpFiles.clientKeyTmp.Name())
   587  				if err != nil {
   588  					t.Fatalf("copyFileContents(%s, %s) failed: %v", testdata.Path("client_key_2.pem"), tmpFiles.clientKeyTmp.Name(), err)
   589  				}
   590  			},
   591  			trustCertUpdateFunc: func() {
   592  				err = copyFileContents(testdata.Path("server_trust_cert_2.pem"), tmpFiles.serverTrustTmp.Name())
   593  				if err != nil {
   594  					t.Fatalf("copyFileContents(%s, %s) failed: %v", testdata.Path("server_trust_cert_2.pem"), tmpFiles.serverTrustTmp.Name(), err)
   595  				}
   596  			},
   597  		},
   598  		{
   599  			desc: "test the reloading feature for serverIdentityProvider and clientTrustProvider",
   600  			certUpdateFunc: func() {
   601  				err = copyFileContents(testdata.Path("server_cert_2.pem"), tmpFiles.serverCertTmp.Name())
   602  				if err != nil {
   603  					t.Fatalf("copyFileContents(%s, %s) failed: %v", testdata.Path("server_cert_2.pem"), tmpFiles.serverCertTmp.Name(), err)
   604  				}
   605  			},
   606  			keyUpdateFunc: func() {
   607  				err = copyFileContents(testdata.Path("server_key_2.pem"), tmpFiles.serverKeyTmp.Name())
   608  				if err != nil {
   609  					t.Fatalf("copyFileContents(%s, %s) failed: %v", testdata.Path("server_key_2.pem"), tmpFiles.serverKeyTmp.Name(), err)
   610  				}
   611  			},
   612  			trustCertUpdateFunc: func() {
   613  				err = copyFileContents(testdata.Path("client_trust_cert_2.pem"), tmpFiles.clientTrustTmp.Name())
   614  				if err != nil {
   615  					t.Fatalf("copyFileContents(%s, %s) failed: %v", testdata.Path("client_trust_cert_2.pem"), tmpFiles.clientTrustTmp.Name(), err)
   616  				}
   617  			},
   618  		},
   619  	} {
   620  		test := test
   621  		t.Run(test.desc, func(t *testing.T) {
   622  			if err := tmpFiles.copyCredsToTmpFiles(); err != nil {
   623  				t.Fatalf("tmpFiles.copyCredsToTmpFiles() failed, error: %v", err)
   624  			}
   625  			clientIdentityProvider, clientRootProvider, serverIdentityProvider, serverRootProvider, err := createProviders(tmpFiles)
   626  			if err != nil {
   627  				t.Fatalf("createProviders(%v) failed, error: %v", tmpFiles, err)
   628  			}
   629  			defer clientIdentityProvider.Close()
   630  			defer clientRootProvider.Close()
   631  			defer serverIdentityProvider.Close()
   632  			defer serverRootProvider.Close()
   633  			// Start a server and create a client using advancedtls API with Provider.
   634  			serverOptions := &ServerOptions{
   635  				IdentityOptions: IdentityCertificateOptions{
   636  					IdentityProvider: serverIdentityProvider,
   637  				},
   638  				RootOptions: RootCertificateOptions{
   639  					RootProvider: serverRootProvider,
   640  				},
   641  				RequireClientCert: true,
   642  				VerifyPeer: func(params *VerificationFuncParams) (*VerificationResults, error) {
   643  					return &VerificationResults{}, nil
   644  				},
   645  				VType: CertVerification,
   646  			}
   647  			serverTLSCreds, err := NewServerCreds(serverOptions)
   648  			if err != nil {
   649  				t.Fatalf("failed to create server creds: %v", err)
   650  			}
   651  			s := grpc.NewServer(grpc.Creds(serverTLSCreds))
   652  			defer s.Stop()
   653  			lis, err := net.Listen("tcp", "localhost:0")
   654  			if err != nil {
   655  				t.Fatalf("failed to listen: %v", err)
   656  			}
   657  			defer lis.Close()
   658  			addr := fmt.Sprintf("localhost:%v", lis.Addr().(*net.TCPAddr).Port)
   659  			pb.RegisterGreeterServer(s, greeterServer{})
   660  			go s.Serve(lis)
   661  			clientOptions := &ClientOptions{
   662  				IdentityOptions: IdentityCertificateOptions{
   663  					IdentityProvider: clientIdentityProvider,
   664  				},
   665  				VerifyPeer: func(params *VerificationFuncParams) (*VerificationResults, error) {
   666  					return &VerificationResults{}, nil
   667  				},
   668  				RootOptions: RootCertificateOptions{
   669  					RootProvider: clientRootProvider,
   670  				},
   671  				VType: CertVerification,
   672  			}
   673  			clientTLSCreds, err := NewClientCreds(clientOptions)
   674  			if err != nil {
   675  				t.Fatalf("clientTLSCreds failed to create, error: %v", err)
   676  			}
   677  
   678  			// At initialization, the connection should be good.
   679  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   680  			defer cancel()
   681  			conn, greetClient, err := callAndVerifyWithClientConn(ctx, addr, "rpc call 1", clientTLSCreds, false)
   682  			if err != nil {
   683  				t.Fatal(err)
   684  			}
   685  			defer conn.Close()
   686  			// Make the identity cert change, and wait 1 second for the provider to
   687  			// pick up the change.
   688  			test.certUpdateFunc()
   689  			time.Sleep(sleepInterval)
   690  			// The already-established connection should not be affected.
   691  			err = callAndVerify("rpc call 2", greetClient, false)
   692  			if err != nil {
   693  				t.Fatal(err)
   694  			}
   695  			// New connections should still be good, because the Provider didn't pick
   696  			// up the changes due to key-cert mismatch.
   697  			conn2, greetClient, err := callAndVerifyWithClientConn(ctx, addr, "rpc call 3", clientTLSCreds, false)
   698  			if err != nil {
   699  				t.Fatal(err)
   700  			}
   701  			defer conn2.Close()
   702  			// Make the identity key change, and wait 1 second for the provider to
   703  			// pick up the change.
   704  			test.keyUpdateFunc()
   705  			time.Sleep(sleepInterval)
   706  			// New connections should fail now, because the Provider picked the
   707  			// change, and *_cert_2.pem is not trusted by *_trust_cert_1.pem on the
   708  			// other side.
   709  			shortCtx, shortCancel := context.WithTimeout(context.Background(), defaultTestShortTimeout)
   710  			defer shortCancel()
   711  			conn3, greetClient, err := callAndVerifyWithClientConn(shortCtx, addr, "rpc call 4", clientTLSCreds, true)
   712  			if err != nil {
   713  				t.Fatal(err)
   714  			}
   715  			defer conn3.Close()
   716  			// Make the trust cert change on the other side, and wait 1 second for
   717  			// the provider to pick up the change.
   718  			test.trustCertUpdateFunc()
   719  			time.Sleep(sleepInterval)
   720  			// New connections should be good, because the other side is using
   721  			// *_trust_cert_2.pem now.
   722  			conn4, greetClient, err := callAndVerifyWithClientConn(ctx, addr, "rpc call 5", clientTLSCreds, false)
   723  			if err != nil {
   724  				t.Fatal(err)
   725  			}
   726  			defer conn4.Close()
   727  		})
   728  	}
   729  }
   730  
   731  func (s) TestDefaultHostNameCheck(t *testing.T) {
   732  	cs := &testutils.CertStore{}
   733  	if err := cs.LoadCerts(); err != nil {
   734  		t.Fatalf("cs.LoadCerts() failed, err: %v", err)
   735  	}
   736  	for _, test := range []struct {
   737  		desc             string
   738  		clientRoot       *x509.CertPool
   739  		clientVerifyFunc CustomVerificationFunc
   740  		clientVType      VerificationType
   741  		serverCert       []tls.Certificate
   742  		serverVType      VerificationType
   743  		expectError      bool
   744  	}{
   745  		// Client side sets vType to CertAndHostVerification, and will do
   746  		// default hostname check. Server uses a cert without "localhost" or
   747  		// "127.0.0.1" as common name or SAN names, and will hence fail.
   748  		{
   749  			desc:        "Bad default hostname check",
   750  			clientRoot:  cs.ClientTrust1,
   751  			clientVType: CertAndHostVerification,
   752  			serverCert:  []tls.Certificate{cs.ServerCert1},
   753  			serverVType: CertAndHostVerification,
   754  			expectError: true,
   755  		},
   756  		// Client side sets vType to CertAndHostVerification, and will do
   757  		// default hostname check. Server uses a certificate with "localhost" as
   758  		// common name, and will hence pass the default hostname check.
   759  		{
   760  			desc:        "Good default hostname check",
   761  			clientRoot:  cs.ClientTrust1,
   762  			clientVType: CertAndHostVerification,
   763  			serverCert:  []tls.Certificate{cs.ServerPeerLocalhost1},
   764  			serverVType: CertAndHostVerification,
   765  			expectError: false,
   766  		},
   767  	} {
   768  		test := test
   769  		t.Run(test.desc, func(t *testing.T) {
   770  			// Start a server using ServerOptions in another goroutine.
   771  			serverOptions := &ServerOptions{
   772  				IdentityOptions: IdentityCertificateOptions{
   773  					Certificates: test.serverCert,
   774  				},
   775  				RequireClientCert: false,
   776  				VType:             test.serverVType,
   777  			}
   778  			serverTLSCreds, err := NewServerCreds(serverOptions)
   779  			if err != nil {
   780  				t.Fatalf("failed to create server creds: %v", err)
   781  			}
   782  			s := grpc.NewServer(grpc.Creds(serverTLSCreds))
   783  			defer s.Stop()
   784  			lis, err := net.Listen("tcp", "localhost:0")
   785  			if err != nil {
   786  				t.Fatalf("failed to listen: %v", err)
   787  			}
   788  			defer lis.Close()
   789  			addr := fmt.Sprintf("localhost:%v", lis.Addr().(*net.TCPAddr).Port)
   790  			pb.RegisterGreeterServer(s, greeterServer{})
   791  			go s.Serve(lis)
   792  			clientOptions := &ClientOptions{
   793  				VerifyPeer: test.clientVerifyFunc,
   794  				RootOptions: RootCertificateOptions{
   795  					RootCACerts: test.clientRoot,
   796  				},
   797  				VType: test.clientVType,
   798  			}
   799  			clientTLSCreds, err := NewClientCreds(clientOptions)
   800  			if err != nil {
   801  				t.Fatalf("clientTLSCreds failed to create")
   802  			}
   803  			shouldFail := false
   804  			if test.expectError {
   805  				shouldFail = true
   806  			}
   807  			ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   808  			defer cancel()
   809  			conn, _, err := callAndVerifyWithClientConn(ctx, addr, "rpc call 1", clientTLSCreds, shouldFail)
   810  			if err != nil {
   811  				t.Fatal(err)
   812  			}
   813  			defer conn.Close()
   814  		})
   815  	}
   816  }