github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/grpc/credentials/tls/certprovider/pemfile/watcher_test.go (about)

     1  /*
     2   *
     3   * Copyright 2020 gRPC authors.
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     http://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  package pemfile
    20  
    21  import (
    22  	"context"
    23  	"fmt"
    24  	"io/ioutil"
    25  	"os"
    26  	"path"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/google/go-cmp/cmp"
    31  	"github.com/google/go-cmp/cmp/cmpopts"
    32  	"github.com/hxx258456/ccgo/grpc/credentials/tls/certprovider"
    33  	"github.com/hxx258456/ccgo/grpc/internal/grpctest"
    34  	"github.com/hxx258456/ccgo/grpc/internal/testutils"
    35  	"github.com/hxx258456/ccgo/grpc/testdata"
    36  )
    37  
    38  const (
    39  	// These are the names of files inside temporary directories, which the
    40  	// plugin is asked to watch.
    41  	certFile = "cert.pem"
    42  	keyFile  = "key.pem"
    43  	rootFile = "ca.pem"
    44  
    45  	defaultTestRefreshDuration = 100 * time.Millisecond
    46  	defaultTestTimeout         = 5 * time.Second
    47  )
    48  
    49  type s struct {
    50  	grpctest.Tester
    51  }
    52  
    53  func Test(t *testing.T) {
    54  	grpctest.RunSubTests(t, s{})
    55  }
    56  
    57  func compareKeyMaterial(got, want *certprovider.KeyMaterial) error {
    58  	if len(got.Certs) != len(want.Certs) {
    59  		return fmt.Errorf("keyMaterial certs = %+v, want %+v", got, want)
    60  	}
    61  	for i := 0; i < len(got.Certs); i++ {
    62  		if !got.Certs[i].Leaf.Equal(want.Certs[i].Leaf) {
    63  			return fmt.Errorf("keyMaterial certs = %+v, want %+v", got, want)
    64  		}
    65  	}
    66  
    67  	// x509.CertPool contains only unexported fields some of which contain other
    68  	// unexported fields. So usage of cmp.AllowUnexported() or
    69  	// cmpopts.IgnoreUnexported() does not help us much here. Also, the standard
    70  	// library does not provide a way to compare CertPool values. Comparing the
    71  	// subjects field of the certs in the CertPool seems like a reasonable
    72  	// approach.
    73  	if gotR, wantR := got.Roots.Subjects(), want.Roots.Subjects(); !cmp.Equal(gotR, wantR, cmpopts.EquateEmpty()) {
    74  		return fmt.Errorf("keyMaterial roots = %v, want %v", gotR, wantR)
    75  	}
    76  	return nil
    77  }
    78  
    79  // TestNewProvider tests the NewProvider() function with different inputs.
    80  func (s) TestNewProvider(t *testing.T) {
    81  	tests := []struct {
    82  		desc      string
    83  		options   Options
    84  		wantError bool
    85  	}{
    86  		{
    87  			desc:      "No credential files specified",
    88  			options:   Options{},
    89  			wantError: true,
    90  		},
    91  		{
    92  			desc: "Only identity cert is specified",
    93  			options: Options{
    94  				CertFile: testdata.Path("x509/client1_cert.pem"),
    95  			},
    96  			wantError: true,
    97  		},
    98  		{
    99  			desc: "Only identity key is specified",
   100  			options: Options{
   101  				KeyFile: testdata.Path("x509/client1_key.pem"),
   102  			},
   103  			wantError: true,
   104  		},
   105  		{
   106  			desc: "Identity cert/key pair is specified",
   107  			options: Options{
   108  				KeyFile:  testdata.Path("x509/client1_key.pem"),
   109  				CertFile: testdata.Path("x509/client1_cert.pem"),
   110  			},
   111  		},
   112  		{
   113  			desc: "Only root certs are specified",
   114  			options: Options{
   115  				RootFile: testdata.Path("x509/client_ca_cert.pem"),
   116  			},
   117  		},
   118  		{
   119  			desc: "Everything is specified",
   120  			options: Options{
   121  				KeyFile:  testdata.Path("x509/client1_key.pem"),
   122  				CertFile: testdata.Path("x509/client1_cert.pem"),
   123  				RootFile: testdata.Path("x509/client_ca_cert.pem"),
   124  			},
   125  			wantError: false,
   126  		},
   127  	}
   128  	for _, test := range tests {
   129  		t.Run(test.desc, func(t *testing.T) {
   130  			provider, err := NewProvider(test.options)
   131  			if (err != nil) != test.wantError {
   132  				t.Fatalf("NewProvider(%v) = %v, want %v", test.options, err, test.wantError)
   133  			}
   134  			if err != nil {
   135  				return
   136  			}
   137  			provider.Close()
   138  		})
   139  	}
   140  }
   141  
   142  // wrappedDistributor wraps a distributor and pushes on a channel whenever new
   143  // key material is pushed to the distributor.
   144  type wrappedDistributor struct {
   145  	*certprovider.Distributor
   146  	distCh *testutils.Channel
   147  }
   148  
   149  func newWrappedDistributor(distCh *testutils.Channel) *wrappedDistributor {
   150  	return &wrappedDistributor{
   151  		distCh:      distCh,
   152  		Distributor: certprovider.NewDistributor(),
   153  	}
   154  }
   155  
   156  func (wd *wrappedDistributor) Set(km *certprovider.KeyMaterial, err error) {
   157  	wd.Distributor.Set(km, err)
   158  	wd.distCh.Send(nil)
   159  }
   160  
   161  func createTmpFile(t *testing.T, src, dst string) {
   162  	t.Helper()
   163  
   164  	data, err := ioutil.ReadFile(src)
   165  	if err != nil {
   166  		t.Fatalf("ioutil.ReadFile(%q) failed: %v", src, err)
   167  	}
   168  	if err := ioutil.WriteFile(dst, data, os.ModePerm); err != nil {
   169  		t.Fatalf("ioutil.WriteFile(%q) failed: %v", dst, err)
   170  	}
   171  	t.Logf("Wrote file at: %s", dst)
   172  	t.Logf("%s", string(data))
   173  }
   174  
   175  // createTempDirWithFiles creates a temporary directory under the system default
   176  // tempDir with the given dirSuffix. It also reads from certSrc, keySrc and
   177  // rootSrc files are creates appropriate files under the newly create tempDir.
   178  // Returns the name of the created tempDir.
   179  func createTmpDirWithFiles(t *testing.T, dirSuffix, certSrc, keySrc, rootSrc string) string {
   180  	t.Helper()
   181  
   182  	// Create a temp directory. Passing an empty string for the first argument
   183  	// uses the system temp directory.
   184  	dir, err := ioutil.TempDir("", dirSuffix)
   185  	if err != nil {
   186  		t.Fatalf("ioutil.TempDir() failed: %v", err)
   187  	}
   188  	t.Logf("Using tmpdir: %s", dir)
   189  
   190  	createTmpFile(t, testdata.Path(certSrc), path.Join(dir, certFile))
   191  	createTmpFile(t, testdata.Path(keySrc), path.Join(dir, keyFile))
   192  	createTmpFile(t, testdata.Path(rootSrc), path.Join(dir, rootFile))
   193  	return dir
   194  }
   195  
   196  // initializeProvider performs setup steps common to all tests (except the one
   197  // which uses symlinks).
   198  func initializeProvider(t *testing.T, testName string) (string, certprovider.Provider, *testutils.Channel, func()) {
   199  	t.Helper()
   200  
   201  	// Override the newDistributor to one which pushes on a channel that we
   202  	// can block on.
   203  	origDistributorFunc := newDistributor
   204  	distCh := testutils.NewChannel()
   205  	d := newWrappedDistributor(distCh)
   206  	newDistributor = func() distributor { return d }
   207  
   208  	// Create a new provider to watch the files in tmpdir.
   209  	dir := createTmpDirWithFiles(t, testName+"*", "x509/client1_cert.pem", "x509/client1_key.pem", "x509/client_ca_cert.pem")
   210  	opts := Options{
   211  		CertFile:        path.Join(dir, certFile),
   212  		KeyFile:         path.Join(dir, keyFile),
   213  		RootFile:        path.Join(dir, rootFile),
   214  		RefreshDuration: defaultTestRefreshDuration,
   215  	}
   216  	prov, err := NewProvider(opts)
   217  	if err != nil {
   218  		t.Fatalf("NewProvider(%+v) failed: %v", opts, err)
   219  	}
   220  
   221  	// Make sure the provider picks up the files and pushes the key material on
   222  	// to the distributors.
   223  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   224  	defer cancel()
   225  	for i := 0; i < 2; i++ {
   226  		// Since we have root and identity certs, we need to make sure the
   227  		// update is pushed on both of them.
   228  		if _, err := distCh.Receive(ctx); err != nil {
   229  			t.Fatalf("timeout waiting for provider to read files and push key material to distributor: %v", err)
   230  		}
   231  	}
   232  
   233  	return dir, prov, distCh, func() {
   234  		newDistributor = origDistributorFunc
   235  		prov.Close()
   236  	}
   237  }
   238  
   239  // TestProvider_NoUpdate tests the case where a file watcher plugin is created
   240  // successfully, and the underlying files do not change. Verifies that the
   241  // plugin does not push new updates to the distributor in this case.
   242  func (s) TestProvider_NoUpdate(t *testing.T) {
   243  	_, prov, distCh, cancel := initializeProvider(t, "no_update")
   244  	defer cancel()
   245  
   246  	// Make sure the provider is healthy and returns key material.
   247  	ctx, cc := context.WithTimeout(context.Background(), defaultTestTimeout)
   248  	defer cc()
   249  	if _, err := prov.KeyMaterial(ctx); err != nil {
   250  		t.Fatalf("provider.KeyMaterial() failed: %v", err)
   251  	}
   252  
   253  	// Files haven't change. Make sure no updates are pushed by the provider.
   254  	sCtx, sc := context.WithTimeout(context.Background(), 2*defaultTestRefreshDuration)
   255  	defer sc()
   256  	if _, err := distCh.Receive(sCtx); err == nil {
   257  		t.Fatal("new key material pushed to distributor when underlying files did not change")
   258  	}
   259  }
   260  
   261  // TestProvider_UpdateSuccess tests the case where a file watcher plugin is
   262  // created successfully and the underlying files change. Verifies that the
   263  // changes are picked up by the provider.
   264  func (s) TestProvider_UpdateSuccess(t *testing.T) {
   265  	dir, prov, distCh, cancel := initializeProvider(t, "update_success")
   266  	defer cancel()
   267  
   268  	// Make sure the provider is healthy and returns key material.
   269  	ctx, cc := context.WithTimeout(context.Background(), defaultTestTimeout)
   270  	defer cc()
   271  	km1, err := prov.KeyMaterial(ctx)
   272  	if err != nil {
   273  		t.Fatalf("provider.KeyMaterial() failed: %v", err)
   274  	}
   275  
   276  	// Change only the root file.
   277  	createTmpFile(t, testdata.Path("x509/server_ca_cert.pem"), path.Join(dir, rootFile))
   278  	if _, err := distCh.Receive(ctx); err != nil {
   279  		t.Fatal("timeout waiting for new key material to be pushed to the distributor")
   280  	}
   281  
   282  	// Make sure update is picked up.
   283  	km2, err := prov.KeyMaterial(ctx)
   284  	if err != nil {
   285  		t.Fatalf("provider.KeyMaterial() failed: %v", err)
   286  	}
   287  	if err := compareKeyMaterial(km1, km2); err == nil {
   288  		t.Fatal("expected provider to return new key material after update to underlying file")
   289  	}
   290  
   291  	// Change only cert/key files.
   292  	createTmpFile(t, testdata.Path("x509/client2_cert.pem"), path.Join(dir, certFile))
   293  	createTmpFile(t, testdata.Path("x509/client2_key.pem"), path.Join(dir, keyFile))
   294  	if _, err := distCh.Receive(ctx); err != nil {
   295  		t.Fatal("timeout waiting for new key material to be pushed to the distributor")
   296  	}
   297  
   298  	// Make sure update is picked up.
   299  	km3, err := prov.KeyMaterial(ctx)
   300  	if err != nil {
   301  		t.Fatalf("provider.KeyMaterial() failed: %v", err)
   302  	}
   303  	if err := compareKeyMaterial(km2, km3); err == nil {
   304  		t.Fatal("expected provider to return new key material after update to underlying file")
   305  	}
   306  }
   307  
   308  // TestProvider_UpdateSuccessWithSymlink tests the case where a file watcher
   309  // plugin is created successfully to watch files through a symlink and the
   310  // symlink is updates to point to new files. Verifies that the changes are
   311  // picked up by the provider.
   312  func (s) TestProvider_UpdateSuccessWithSymlink(t *testing.T) {
   313  	// Override the newDistributor to one which pushes on a channel that we
   314  	// can block on.
   315  	origDistributorFunc := newDistributor
   316  	distCh := testutils.NewChannel()
   317  	d := newWrappedDistributor(distCh)
   318  	newDistributor = func() distributor { return d }
   319  	defer func() { newDistributor = origDistributorFunc }()
   320  
   321  	// Create two tempDirs with different files.
   322  	dir1 := createTmpDirWithFiles(t, "update_with_symlink1_*", "x509/client1_cert.pem", "x509/client1_key.pem", "x509/client_ca_cert.pem")
   323  	dir2 := createTmpDirWithFiles(t, "update_with_symlink2_*", "x509/server1_cert.pem", "x509/server1_key.pem", "x509/server_ca_cert.pem")
   324  
   325  	// Create a symlink under a new tempdir, and make it point to dir1.
   326  	tmpdir, err := ioutil.TempDir("", "test_symlink_*")
   327  	if err != nil {
   328  		t.Fatalf("ioutil.TempDir() failed: %v", err)
   329  	}
   330  	symLinkName := path.Join(tmpdir, "test_symlink")
   331  	if err := os.Symlink(dir1, symLinkName); err != nil {
   332  		t.Fatalf("failed to create symlink to %q: %v", dir1, err)
   333  	}
   334  
   335  	// Create a provider which watches the files pointed to by the symlink.
   336  	opts := Options{
   337  		CertFile:        path.Join(symLinkName, certFile),
   338  		KeyFile:         path.Join(symLinkName, keyFile),
   339  		RootFile:        path.Join(symLinkName, rootFile),
   340  		RefreshDuration: defaultTestRefreshDuration,
   341  	}
   342  	prov, err := NewProvider(opts)
   343  	if err != nil {
   344  		t.Fatalf("NewProvider(%+v) failed: %v", opts, err)
   345  	}
   346  	defer prov.Close()
   347  
   348  	// Make sure the provider picks up the files and pushes the key material on
   349  	// to the distributors.
   350  	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
   351  	defer cancel()
   352  	for i := 0; i < 2; i++ {
   353  		// Since we have root and identity certs, we need to make sure the
   354  		// update is pushed on both of them.
   355  		if _, err := distCh.Receive(ctx); err != nil {
   356  			t.Fatalf("timeout waiting for provider to read files and push key material to distributor: %v", err)
   357  		}
   358  	}
   359  	km1, err := prov.KeyMaterial(ctx)
   360  	if err != nil {
   361  		t.Fatalf("provider.KeyMaterial() failed: %v", err)
   362  	}
   363  
   364  	// Update the symlink to point to dir2.
   365  	symLinkTmpName := path.Join(tmpdir, "test_symlink.tmp")
   366  	if err := os.Symlink(dir2, symLinkTmpName); err != nil {
   367  		t.Fatalf("failed to create symlink to %q: %v", dir2, err)
   368  	}
   369  	if err := os.Rename(symLinkTmpName, symLinkName); err != nil {
   370  		t.Fatalf("failed to update symlink: %v", err)
   371  	}
   372  
   373  	// Make sure the provider picks up the new files and pushes the key material
   374  	// on to the distributors.
   375  	for i := 0; i < 2; i++ {
   376  		// Since we have root and identity certs, we need to make sure the
   377  		// update is pushed on both of them.
   378  		if _, err := distCh.Receive(ctx); err != nil {
   379  			t.Fatalf("timeout waiting for provider to read files and push key material to distributor: %v", err)
   380  		}
   381  	}
   382  	km2, err := prov.KeyMaterial(ctx)
   383  	if err != nil {
   384  		t.Fatalf("provider.KeyMaterial() failed: %v", err)
   385  	}
   386  
   387  	if err := compareKeyMaterial(km1, km2); err == nil {
   388  		t.Fatal("expected provider to return new key material after symlink update")
   389  	}
   390  }
   391  
   392  // TestProvider_UpdateFailure_ThenSuccess tests the case where updating cert/key
   393  // files fail. Verifies that the failed update does not push anything on the
   394  // distributor. Then the update succeeds, and the test verifies that the key
   395  // material is updated.
   396  func (s) TestProvider_UpdateFailure_ThenSuccess(t *testing.T) {
   397  	dir, prov, distCh, cancel := initializeProvider(t, "update_failure")
   398  	defer cancel()
   399  
   400  	// Make sure the provider is healthy and returns key material.
   401  	ctx, cc := context.WithTimeout(context.Background(), defaultTestTimeout)
   402  	defer cc()
   403  	km1, err := prov.KeyMaterial(ctx)
   404  	if err != nil {
   405  		t.Fatalf("provider.KeyMaterial() failed: %v", err)
   406  	}
   407  
   408  	// Update only the cert file. The key file is left unchanged. This should
   409  	// lead to these two files being not compatible with each other. This
   410  	// simulates the case where the watching goroutine might catch the files in
   411  	// the midst of an update.
   412  	createTmpFile(t, testdata.Path("x509/server1_cert.pem"), path.Join(dir, certFile))
   413  
   414  	// Since the last update left the files in an incompatible state, the update
   415  	// should not be picked up by our provider.
   416  	sCtx, sc := context.WithTimeout(context.Background(), 2*defaultTestRefreshDuration)
   417  	defer sc()
   418  	if _, err := distCh.Receive(sCtx); err == nil {
   419  		t.Fatal("new key material pushed to distributor when underlying files did not change")
   420  	}
   421  
   422  	// The provider should return key material corresponding to the old state.
   423  	km2, err := prov.KeyMaterial(ctx)
   424  	if err != nil {
   425  		t.Fatalf("provider.KeyMaterial() failed: %v", err)
   426  	}
   427  	if err := compareKeyMaterial(km1, km2); err != nil {
   428  		t.Fatalf("expected provider to not update key material: %v", err)
   429  	}
   430  
   431  	// Update the key file to match the cert file.
   432  	createTmpFile(t, testdata.Path("x509/server1_key.pem"), path.Join(dir, keyFile))
   433  
   434  	// Make sure update is picked up.
   435  	if _, err := distCh.Receive(ctx); err != nil {
   436  		t.Fatal("timeout waiting for new key material to be pushed to the distributor")
   437  	}
   438  	km3, err := prov.KeyMaterial(ctx)
   439  	if err != nil {
   440  		t.Fatalf("provider.KeyMaterial() failed: %v", err)
   441  	}
   442  	if err := compareKeyMaterial(km2, km3); err == nil {
   443  		t.Fatal("expected provider to return new key material after update to underlying file")
   444  	}
   445  }