github.com/hechain20/hechain@v0.0.0-20220316014945-b544036ba106/internal/pkg/comm/config_test.go (about)

     1  /*
     2  Copyright hechain. All Rights Reserved.
     3  
     4  SPDX-License-Identifier: Apache-2.0
     5  */
     6  
     7  package comm
     8  
     9  import (
    10  	"crypto/tls"
    11  	"crypto/x509"
    12  	"io/ioutil"
    13  	"path/filepath"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/hechain20/hechain/common/crypto/tlsgen"
    18  	"github.com/stretchr/testify/require"
    19  	"google.golang.org/grpc"
    20  	"google.golang.org/grpc/keepalive"
    21  )
    22  
    23  func TestServerKeepaliveOptions(t *testing.T) {
    24  	t.Parallel()
    25  
    26  	kap := keepalive.ServerParameters{
    27  		Time:    DefaultKeepaliveOptions.ServerInterval,
    28  		Timeout: DefaultKeepaliveOptions.ServerTimeout,
    29  	}
    30  	kep := keepalive.EnforcementPolicy{
    31  		MinTime:             DefaultKeepaliveOptions.ServerMinInterval,
    32  		PermitWithoutStream: true,
    33  	}
    34  	expectedOpts := []grpc.ServerOption{
    35  		grpc.KeepaliveParams(kap),
    36  		grpc.KeepaliveEnforcementPolicy(kep),
    37  	}
    38  	opts := DefaultKeepaliveOptions.ServerKeepaliveOptions()
    39  
    40  	// Unable to test equality of options since the option methods return
    41  	// functions and each instance is a different func.
    42  	// Unable to test the equality of applying the options to the server
    43  	// implementation because the server embeds channels.
    44  	// Fallback to a sanity check.
    45  	require.Len(t, opts, len(expectedOpts))
    46  	for i := range opts {
    47  		require.IsType(t, expectedOpts[i], opts[i])
    48  	}
    49  }
    50  
    51  func TestClientKeepaliveOptions(t *testing.T) {
    52  	t.Parallel()
    53  
    54  	kap := keepalive.ClientParameters{
    55  		Time:                DefaultKeepaliveOptions.ClientInterval,
    56  		Timeout:             DefaultKeepaliveOptions.ClientTimeout,
    57  		PermitWithoutStream: true,
    58  	}
    59  	expectedOpts := []grpc.DialOption{grpc.WithKeepaliveParams(kap)}
    60  	opts := DefaultKeepaliveOptions.ClientKeepaliveOptions()
    61  
    62  	// Unable to test equality of options since the option methods return
    63  	// functions and each instance is a different func.
    64  	// Fallback to a sanity check.
    65  	require.Len(t, opts, len(expectedOpts))
    66  	for i := range opts {
    67  		require.IsType(t, expectedOpts[i], opts[i])
    68  	}
    69  }
    70  
    71  func TestClientConfigClone(t *testing.T) {
    72  	origin := ClientConfig{
    73  		KaOpts: KeepaliveOptions{
    74  			ClientInterval: time.Second,
    75  		},
    76  		SecOpts: SecureOptions{
    77  			Key: []byte{1, 2, 3},
    78  		},
    79  		DialTimeout:  time.Second,
    80  		AsyncConnect: true,
    81  	}
    82  
    83  	clone := origin
    84  
    85  	// Same content, different inner fields references.
    86  	require.Equal(t, origin, clone)
    87  
    88  	// We change the contents of the fields and ensure it doesn't
    89  	// propagate across instances.
    90  	origin.AsyncConnect = false
    91  	origin.KaOpts.ServerInterval = time.Second
    92  	origin.KaOpts.ClientInterval = time.Hour
    93  	origin.SecOpts.Certificate = []byte{1, 2, 3}
    94  	origin.SecOpts.Key = []byte{5, 4, 6}
    95  	origin.DialTimeout = time.Second * 2
    96  
    97  	clone.SecOpts.UseTLS = true
    98  	clone.KaOpts.ServerMinInterval = time.Hour
    99  
   100  	expectedOriginState := ClientConfig{
   101  		KaOpts: KeepaliveOptions{
   102  			ClientInterval: time.Hour,
   103  			ServerInterval: time.Second,
   104  		},
   105  		SecOpts: SecureOptions{
   106  			Key:         []byte{5, 4, 6},
   107  			Certificate: []byte{1, 2, 3},
   108  		},
   109  		DialTimeout: time.Second * 2,
   110  	}
   111  
   112  	expectedCloneState := ClientConfig{
   113  		KaOpts: KeepaliveOptions{
   114  			ClientInterval:    time.Second,
   115  			ServerMinInterval: time.Hour,
   116  		},
   117  		SecOpts: SecureOptions{
   118  			Key:    []byte{1, 2, 3},
   119  			UseTLS: true,
   120  		},
   121  		DialTimeout:  time.Second,
   122  		AsyncConnect: true,
   123  	}
   124  
   125  	require.Equal(t, expectedOriginState, origin)
   126  	require.Equal(t, expectedCloneState, clone)
   127  }
   128  
   129  func TestSecureOptionsTLSConfig(t *testing.T) {
   130  	ca1, err := tlsgen.NewCA()
   131  	require.NoError(t, err, "failed to create CA1")
   132  	ca2, err := tlsgen.NewCA()
   133  	require.NoError(t, err, "failed to create CA2")
   134  	ckp, err := ca1.NewClientCertKeyPair()
   135  	require.NoError(t, err, "failed to create client key pair")
   136  	clientCert, err := tls.X509KeyPair(ckp.Cert, ckp.Key)
   137  	require.NoError(t, err, "failed to create client certificate")
   138  
   139  	newCertPool := func(cas ...tlsgen.CA) *x509.CertPool {
   140  		cp := x509.NewCertPool()
   141  		for _, ca := range cas {
   142  			ok := cp.AppendCertsFromPEM(ca.CertBytes())
   143  			require.True(t, ok, "failed to add cert to pool")
   144  		}
   145  		return cp
   146  	}
   147  
   148  	tests := []struct {
   149  		desc        string
   150  		so          SecureOptions
   151  		tc          *tls.Config
   152  		expectedErr string
   153  	}{
   154  		{desc: "TLSDisabled"},
   155  		{desc: "TLSEnabled", so: SecureOptions{UseTLS: true}, tc: &tls.Config{MinVersion: tls.VersionTLS12}},
   156  		{
   157  			desc: "ServerNameOverride",
   158  			so:   SecureOptions{UseTLS: true, ServerNameOverride: "bob"},
   159  			tc:   &tls.Config{MinVersion: tls.VersionTLS12, ServerName: "bob"},
   160  		},
   161  		{
   162  			desc: "WithServerRootCAs",
   163  			so:   SecureOptions{UseTLS: true, ServerRootCAs: [][]byte{ca1.CertBytes(), ca2.CertBytes()}},
   164  			tc:   &tls.Config{MinVersion: tls.VersionTLS12, RootCAs: newCertPool(ca1, ca2)},
   165  		},
   166  		{
   167  			desc: "BadServerRootCertificate",
   168  			so: SecureOptions{
   169  				UseTLS: true,
   170  				ServerRootCAs: [][]byte{
   171  					[]byte("-----BEGIN CERTIFICATE-----\nYm9ndXM=\n-----END CERTIFICATE-----"),
   172  				},
   173  			},
   174  			expectedErr: "error adding root certificate",
   175  		},
   176  		{
   177  			desc: "WithRequiredClientKeyPair",
   178  			so:   SecureOptions{UseTLS: true, RequireClientCert: true, Key: ckp.Key, Certificate: ckp.Cert},
   179  			tc:   &tls.Config{MinVersion: tls.VersionTLS12, Certificates: []tls.Certificate{clientCert}},
   180  		},
   181  		{
   182  			desc:        "MissingClientKey",
   183  			so:          SecureOptions{UseTLS: true, RequireClientCert: true, Certificate: ckp.Cert},
   184  			expectedErr: "both Key and Certificate are required when using mutual TLS",
   185  		},
   186  		{
   187  			desc:        "MissingClientCert",
   188  			so:          SecureOptions{UseTLS: true, RequireClientCert: true, Key: ckp.Key},
   189  			expectedErr: "both Key and Certificate are required when using mutual TLS",
   190  		},
   191  		{
   192  			desc: "WithTimeShift",
   193  			so:   SecureOptions{UseTLS: true, TimeShift: 2 * time.Hour},
   194  			tc:   &tls.Config{MinVersion: tls.VersionTLS12},
   195  		},
   196  	}
   197  	for _, tt := range tests {
   198  		t.Run(tt.desc, func(t *testing.T) {
   199  			tc, err := tt.so.TLSConfig()
   200  			if tt.expectedErr != "" {
   201  				require.ErrorContainsf(t, err, tt.expectedErr, "got %v, want %s", err, tt.expectedErr)
   202  				return
   203  			}
   204  			require.NoError(t, err)
   205  
   206  			if len(tt.so.ServerRootCAs) != 0 {
   207  				require.NotNil(t, tc.RootCAs)
   208  				require.Len(t, tc.RootCAs.Subjects(), len(tt.so.ServerRootCAs))
   209  				for _, subj := range tt.tc.RootCAs.Subjects() {
   210  					require.Contains(t, tc.RootCAs.Subjects(), subj, "missing subject %x", subj)
   211  				}
   212  				tt.tc.RootCAs, tc.RootCAs = nil, nil
   213  			}
   214  
   215  			if tt.so.TimeShift != 0 {
   216  				require.NotNil(t, tc.Time)
   217  				require.WithinDuration(t, time.Now().Add(-1*tt.so.TimeShift), tc.Time(), 10*time.Second)
   218  				tc.Time = nil
   219  			}
   220  
   221  			require.Equal(t, tt.tc, tc)
   222  		})
   223  	}
   224  }
   225  
   226  func TestClientConfigDialOptions_GoodConfig(t *testing.T) {
   227  	testCerts := LoadTestCerts(t)
   228  
   229  	config := ClientConfig{}
   230  	opts, err := config.DialOptions()
   231  	require.NoError(t, err)
   232  	require.NotEmpty(t, opts)
   233  
   234  	secOpts := SecureOptions{
   235  		UseTLS:            true,
   236  		ServerRootCAs:     [][]byte{testCerts.CAPEM},
   237  		RequireClientCert: false,
   238  	}
   239  	config.SecOpts = secOpts
   240  	opts, err = config.DialOptions()
   241  	require.NoError(t, err)
   242  	require.NotEmpty(t, opts)
   243  
   244  	secOpts = SecureOptions{
   245  		Certificate:       testCerts.CertPEM,
   246  		Key:               testCerts.KeyPEM,
   247  		UseTLS:            true,
   248  		ServerRootCAs:     [][]byte{testCerts.CAPEM},
   249  		RequireClientCert: true,
   250  	}
   251  	clientCert, err := secOpts.ClientCertificate()
   252  	require.NoError(t, err)
   253  	require.Equal(t, testCerts.ClientCert, clientCert)
   254  	config.SecOpts = secOpts
   255  	opts, err = config.DialOptions()
   256  	require.NoError(t, err)
   257  	require.NotEmpty(t, opts)
   258  }
   259  
   260  func TestClientConfigDialOptions_BadConfig(t *testing.T) {
   261  	testCerts := LoadTestCerts(t)
   262  
   263  	// bad root cert
   264  	config := ClientConfig{
   265  		SecOpts: SecureOptions{
   266  			UseTLS:        true,
   267  			ServerRootCAs: [][]byte{[]byte(badPEM)},
   268  		},
   269  	}
   270  	_, err := config.DialOptions()
   271  	require.ErrorContains(t, err, "error adding root certificate")
   272  
   273  	// missing key
   274  	config.SecOpts = SecureOptions{
   275  		Certificate:       []byte("cert"),
   276  		UseTLS:            true,
   277  		RequireClientCert: true,
   278  	}
   279  	_, err = config.DialOptions()
   280  	require.ErrorContains(t, err, "both Key and Certificate are required when using mutual TLS")
   281  
   282  	// missing cert
   283  	config.SecOpts = SecureOptions{
   284  		Key:               []byte("key"),
   285  		UseTLS:            true,
   286  		RequireClientCert: true,
   287  	}
   288  	_, err = config.DialOptions()
   289  	require.ErrorContains(t, err, "both Key and Certificate are required when using mutual TLS")
   290  
   291  	// bad key
   292  	config.SecOpts = SecureOptions{
   293  		Certificate:       testCerts.CertPEM,
   294  		Key:               []byte(badPEM),
   295  		UseTLS:            true,
   296  		RequireClientCert: true,
   297  	}
   298  	_, err = config.DialOptions()
   299  	require.ErrorContains(t, err, "failed to load client certificate")
   300  
   301  	// bad cert
   302  	config.SecOpts = SecureOptions{
   303  		Certificate:       []byte(badPEM),
   304  		Key:               testCerts.KeyPEM,
   305  		UseTLS:            true,
   306  		RequireClientCert: true,
   307  	}
   308  	_, err = config.DialOptions()
   309  	require.ErrorContains(t, err, "failed to load client certificate")
   310  }
   311  
   312  type TestCerts struct {
   313  	CAPEM      []byte
   314  	CertPEM    []byte
   315  	KeyPEM     []byte
   316  	ClientCert tls.Certificate
   317  	ServerCert tls.Certificate
   318  }
   319  
   320  func LoadTestCerts(t *testing.T) TestCerts {
   321  	t.Helper()
   322  
   323  	var certs TestCerts
   324  	var err error
   325  	certs.CAPEM, err = ioutil.ReadFile(filepath.Join("testdata", "certs", "Org1-cert.pem"))
   326  	if err != nil {
   327  		t.Fatalf("unexpected error reading root cert for test: %v", err)
   328  	}
   329  	certs.CertPEM, err = ioutil.ReadFile(filepath.Join("testdata", "certs", "Org1-client1-cert.pem"))
   330  	if err != nil {
   331  		t.Fatalf("unexpected error reading cert for test: %v", err)
   332  	}
   333  	certs.KeyPEM, err = ioutil.ReadFile(filepath.Join("testdata", "certs", "Org1-client1-key.pem"))
   334  	if err != nil {
   335  		t.Fatalf("unexpected error reading key for test: %v", err)
   336  	}
   337  	certs.ClientCert, err = tls.X509KeyPair(certs.CertPEM, certs.KeyPEM)
   338  	if err != nil {
   339  		t.Fatalf("unexpected error loading certificate for test: %v", err)
   340  	}
   341  	certs.ServerCert, err = tls.LoadX509KeyPair(
   342  		filepath.Join("testdata", "certs", "Org1-server1-cert.pem"),
   343  		filepath.Join("testdata", "certs", "Org1-server1-key.pem"),
   344  	)
   345  	require.NoError(t, err)
   346  
   347  	return certs
   348  }