github.com/hechain20/hechain@v0.0.0-20220316014945-b544036ba106/internal/pkg/comm/config_test.go (about) 1 /* 2 Copyright hechain. All Rights Reserved. 3 4 SPDX-License-Identifier: Apache-2.0 5 */ 6 7 package comm 8 9 import ( 10 "crypto/tls" 11 "crypto/x509" 12 "io/ioutil" 13 "path/filepath" 14 "testing" 15 "time" 16 17 "github.com/hechain20/hechain/common/crypto/tlsgen" 18 "github.com/stretchr/testify/require" 19 "google.golang.org/grpc" 20 "google.golang.org/grpc/keepalive" 21 ) 22 23 func TestServerKeepaliveOptions(t *testing.T) { 24 t.Parallel() 25 26 kap := keepalive.ServerParameters{ 27 Time: DefaultKeepaliveOptions.ServerInterval, 28 Timeout: DefaultKeepaliveOptions.ServerTimeout, 29 } 30 kep := keepalive.EnforcementPolicy{ 31 MinTime: DefaultKeepaliveOptions.ServerMinInterval, 32 PermitWithoutStream: true, 33 } 34 expectedOpts := []grpc.ServerOption{ 35 grpc.KeepaliveParams(kap), 36 grpc.KeepaliveEnforcementPolicy(kep), 37 } 38 opts := DefaultKeepaliveOptions.ServerKeepaliveOptions() 39 40 // Unable to test equality of options since the option methods return 41 // functions and each instance is a different func. 42 // Unable to test the equality of applying the options to the server 43 // implementation because the server embeds channels. 44 // Fallback to a sanity check. 45 require.Len(t, opts, len(expectedOpts)) 46 for i := range opts { 47 require.IsType(t, expectedOpts[i], opts[i]) 48 } 49 } 50 51 func TestClientKeepaliveOptions(t *testing.T) { 52 t.Parallel() 53 54 kap := keepalive.ClientParameters{ 55 Time: DefaultKeepaliveOptions.ClientInterval, 56 Timeout: DefaultKeepaliveOptions.ClientTimeout, 57 PermitWithoutStream: true, 58 } 59 expectedOpts := []grpc.DialOption{grpc.WithKeepaliveParams(kap)} 60 opts := DefaultKeepaliveOptions.ClientKeepaliveOptions() 61 62 // Unable to test equality of options since the option methods return 63 // functions and each instance is a different func. 64 // Fallback to a sanity check. 65 require.Len(t, opts, len(expectedOpts)) 66 for i := range opts { 67 require.IsType(t, expectedOpts[i], opts[i]) 68 } 69 } 70 71 func TestClientConfigClone(t *testing.T) { 72 origin := ClientConfig{ 73 KaOpts: KeepaliveOptions{ 74 ClientInterval: time.Second, 75 }, 76 SecOpts: SecureOptions{ 77 Key: []byte{1, 2, 3}, 78 }, 79 DialTimeout: time.Second, 80 AsyncConnect: true, 81 } 82 83 clone := origin 84 85 // Same content, different inner fields references. 86 require.Equal(t, origin, clone) 87 88 // We change the contents of the fields and ensure it doesn't 89 // propagate across instances. 90 origin.AsyncConnect = false 91 origin.KaOpts.ServerInterval = time.Second 92 origin.KaOpts.ClientInterval = time.Hour 93 origin.SecOpts.Certificate = []byte{1, 2, 3} 94 origin.SecOpts.Key = []byte{5, 4, 6} 95 origin.DialTimeout = time.Second * 2 96 97 clone.SecOpts.UseTLS = true 98 clone.KaOpts.ServerMinInterval = time.Hour 99 100 expectedOriginState := ClientConfig{ 101 KaOpts: KeepaliveOptions{ 102 ClientInterval: time.Hour, 103 ServerInterval: time.Second, 104 }, 105 SecOpts: SecureOptions{ 106 Key: []byte{5, 4, 6}, 107 Certificate: []byte{1, 2, 3}, 108 }, 109 DialTimeout: time.Second * 2, 110 } 111 112 expectedCloneState := ClientConfig{ 113 KaOpts: KeepaliveOptions{ 114 ClientInterval: time.Second, 115 ServerMinInterval: time.Hour, 116 }, 117 SecOpts: SecureOptions{ 118 Key: []byte{1, 2, 3}, 119 UseTLS: true, 120 }, 121 DialTimeout: time.Second, 122 AsyncConnect: true, 123 } 124 125 require.Equal(t, expectedOriginState, origin) 126 require.Equal(t, expectedCloneState, clone) 127 } 128 129 func TestSecureOptionsTLSConfig(t *testing.T) { 130 ca1, err := tlsgen.NewCA() 131 require.NoError(t, err, "failed to create CA1") 132 ca2, err := tlsgen.NewCA() 133 require.NoError(t, err, "failed to create CA2") 134 ckp, err := ca1.NewClientCertKeyPair() 135 require.NoError(t, err, "failed to create client key pair") 136 clientCert, err := tls.X509KeyPair(ckp.Cert, ckp.Key) 137 require.NoError(t, err, "failed to create client certificate") 138 139 newCertPool := func(cas ...tlsgen.CA) *x509.CertPool { 140 cp := x509.NewCertPool() 141 for _, ca := range cas { 142 ok := cp.AppendCertsFromPEM(ca.CertBytes()) 143 require.True(t, ok, "failed to add cert to pool") 144 } 145 return cp 146 } 147 148 tests := []struct { 149 desc string 150 so SecureOptions 151 tc *tls.Config 152 expectedErr string 153 }{ 154 {desc: "TLSDisabled"}, 155 {desc: "TLSEnabled", so: SecureOptions{UseTLS: true}, tc: &tls.Config{MinVersion: tls.VersionTLS12}}, 156 { 157 desc: "ServerNameOverride", 158 so: SecureOptions{UseTLS: true, ServerNameOverride: "bob"}, 159 tc: &tls.Config{MinVersion: tls.VersionTLS12, ServerName: "bob"}, 160 }, 161 { 162 desc: "WithServerRootCAs", 163 so: SecureOptions{UseTLS: true, ServerRootCAs: [][]byte{ca1.CertBytes(), ca2.CertBytes()}}, 164 tc: &tls.Config{MinVersion: tls.VersionTLS12, RootCAs: newCertPool(ca1, ca2)}, 165 }, 166 { 167 desc: "BadServerRootCertificate", 168 so: SecureOptions{ 169 UseTLS: true, 170 ServerRootCAs: [][]byte{ 171 []byte("-----BEGIN CERTIFICATE-----\nYm9ndXM=\n-----END CERTIFICATE-----"), 172 }, 173 }, 174 expectedErr: "error adding root certificate", 175 }, 176 { 177 desc: "WithRequiredClientKeyPair", 178 so: SecureOptions{UseTLS: true, RequireClientCert: true, Key: ckp.Key, Certificate: ckp.Cert}, 179 tc: &tls.Config{MinVersion: tls.VersionTLS12, Certificates: []tls.Certificate{clientCert}}, 180 }, 181 { 182 desc: "MissingClientKey", 183 so: SecureOptions{UseTLS: true, RequireClientCert: true, Certificate: ckp.Cert}, 184 expectedErr: "both Key and Certificate are required when using mutual TLS", 185 }, 186 { 187 desc: "MissingClientCert", 188 so: SecureOptions{UseTLS: true, RequireClientCert: true, Key: ckp.Key}, 189 expectedErr: "both Key and Certificate are required when using mutual TLS", 190 }, 191 { 192 desc: "WithTimeShift", 193 so: SecureOptions{UseTLS: true, TimeShift: 2 * time.Hour}, 194 tc: &tls.Config{MinVersion: tls.VersionTLS12}, 195 }, 196 } 197 for _, tt := range tests { 198 t.Run(tt.desc, func(t *testing.T) { 199 tc, err := tt.so.TLSConfig() 200 if tt.expectedErr != "" { 201 require.ErrorContainsf(t, err, tt.expectedErr, "got %v, want %s", err, tt.expectedErr) 202 return 203 } 204 require.NoError(t, err) 205 206 if len(tt.so.ServerRootCAs) != 0 { 207 require.NotNil(t, tc.RootCAs) 208 require.Len(t, tc.RootCAs.Subjects(), len(tt.so.ServerRootCAs)) 209 for _, subj := range tt.tc.RootCAs.Subjects() { 210 require.Contains(t, tc.RootCAs.Subjects(), subj, "missing subject %x", subj) 211 } 212 tt.tc.RootCAs, tc.RootCAs = nil, nil 213 } 214 215 if tt.so.TimeShift != 0 { 216 require.NotNil(t, tc.Time) 217 require.WithinDuration(t, time.Now().Add(-1*tt.so.TimeShift), tc.Time(), 10*time.Second) 218 tc.Time = nil 219 } 220 221 require.Equal(t, tt.tc, tc) 222 }) 223 } 224 } 225 226 func TestClientConfigDialOptions_GoodConfig(t *testing.T) { 227 testCerts := LoadTestCerts(t) 228 229 config := ClientConfig{} 230 opts, err := config.DialOptions() 231 require.NoError(t, err) 232 require.NotEmpty(t, opts) 233 234 secOpts := SecureOptions{ 235 UseTLS: true, 236 ServerRootCAs: [][]byte{testCerts.CAPEM}, 237 RequireClientCert: false, 238 } 239 config.SecOpts = secOpts 240 opts, err = config.DialOptions() 241 require.NoError(t, err) 242 require.NotEmpty(t, opts) 243 244 secOpts = SecureOptions{ 245 Certificate: testCerts.CertPEM, 246 Key: testCerts.KeyPEM, 247 UseTLS: true, 248 ServerRootCAs: [][]byte{testCerts.CAPEM}, 249 RequireClientCert: true, 250 } 251 clientCert, err := secOpts.ClientCertificate() 252 require.NoError(t, err) 253 require.Equal(t, testCerts.ClientCert, clientCert) 254 config.SecOpts = secOpts 255 opts, err = config.DialOptions() 256 require.NoError(t, err) 257 require.NotEmpty(t, opts) 258 } 259 260 func TestClientConfigDialOptions_BadConfig(t *testing.T) { 261 testCerts := LoadTestCerts(t) 262 263 // bad root cert 264 config := ClientConfig{ 265 SecOpts: SecureOptions{ 266 UseTLS: true, 267 ServerRootCAs: [][]byte{[]byte(badPEM)}, 268 }, 269 } 270 _, err := config.DialOptions() 271 require.ErrorContains(t, err, "error adding root certificate") 272 273 // missing key 274 config.SecOpts = SecureOptions{ 275 Certificate: []byte("cert"), 276 UseTLS: true, 277 RequireClientCert: true, 278 } 279 _, err = config.DialOptions() 280 require.ErrorContains(t, err, "both Key and Certificate are required when using mutual TLS") 281 282 // missing cert 283 config.SecOpts = SecureOptions{ 284 Key: []byte("key"), 285 UseTLS: true, 286 RequireClientCert: true, 287 } 288 _, err = config.DialOptions() 289 require.ErrorContains(t, err, "both Key and Certificate are required when using mutual TLS") 290 291 // bad key 292 config.SecOpts = SecureOptions{ 293 Certificate: testCerts.CertPEM, 294 Key: []byte(badPEM), 295 UseTLS: true, 296 RequireClientCert: true, 297 } 298 _, err = config.DialOptions() 299 require.ErrorContains(t, err, "failed to load client certificate") 300 301 // bad cert 302 config.SecOpts = SecureOptions{ 303 Certificate: []byte(badPEM), 304 Key: testCerts.KeyPEM, 305 UseTLS: true, 306 RequireClientCert: true, 307 } 308 _, err = config.DialOptions() 309 require.ErrorContains(t, err, "failed to load client certificate") 310 } 311 312 type TestCerts struct { 313 CAPEM []byte 314 CertPEM []byte 315 KeyPEM []byte 316 ClientCert tls.Certificate 317 ServerCert tls.Certificate 318 } 319 320 func LoadTestCerts(t *testing.T) TestCerts { 321 t.Helper() 322 323 var certs TestCerts 324 var err error 325 certs.CAPEM, err = ioutil.ReadFile(filepath.Join("testdata", "certs", "Org1-cert.pem")) 326 if err != nil { 327 t.Fatalf("unexpected error reading root cert for test: %v", err) 328 } 329 certs.CertPEM, err = ioutil.ReadFile(filepath.Join("testdata", "certs", "Org1-client1-cert.pem")) 330 if err != nil { 331 t.Fatalf("unexpected error reading cert for test: %v", err) 332 } 333 certs.KeyPEM, err = ioutil.ReadFile(filepath.Join("testdata", "certs", "Org1-client1-key.pem")) 334 if err != nil { 335 t.Fatalf("unexpected error reading key for test: %v", err) 336 } 337 certs.ClientCert, err = tls.X509KeyPair(certs.CertPEM, certs.KeyPEM) 338 if err != nil { 339 t.Fatalf("unexpected error loading certificate for test: %v", err) 340 } 341 certs.ServerCert, err = tls.LoadX509KeyPair( 342 filepath.Join("testdata", "certs", "Org1-server1-cert.pem"), 343 filepath.Join("testdata", "certs", "Org1-server1-key.pem"), 344 ) 345 require.NoError(t, err) 346 347 return certs 348 }