github.com/microsoft/moc@v0.17.1/pkg/certs/certs_test.go (about) 1 // Copyright (c) Microsoft Corporation. All rights reserved. 2 // Licensed under the Apache v2.0 license. 3 package certs 4 5 import ( 6 "bytes" 7 "context" 8 "crypto/rand" 9 "crypto/rsa" 10 "crypto/tls" 11 "crypto/x509" 12 "crypto/x509/pkix" 13 "encoding/asn1" 14 "fmt" 15 "log" 16 "math" 17 "math/big" 18 "net" 19 "testing" 20 "time" 21 22 "github.com/microsoft/moc/pkg/errors" 23 24 gomock "github.com/golang/mock/gomock" 25 mock "github.com/microsoft/moc/pkg/certs/mock" 26 "github.com/microsoft/moc/rpc/testagent" 27 "github.com/stretchr/testify/assert" 28 "google.golang.org/grpc" 29 "google.golang.org/grpc/codes" 30 "google.golang.org/grpc/credentials" 31 "google.golang.org/grpc/status" 32 ) 33 34 func IsTransportUnavailable(err error) bool { 35 if e, ok := status.FromError(err); ok && e.Code() == codes.Unavailable { 36 return true 37 } 38 return false 39 } 40 41 type TestTlsServer struct { 42 } 43 44 func (s *TestTlsServer) PingHello(ctx context.Context, in *testagent.Hello) (*testagent.Hello, error) { 45 return &testagent.Hello{Name: "Hello From the Server!" 
+ in.Name}, nil 46 } 47 48 func startHelloServer(grpcServer *grpc.Server, address string) { 49 lis, err := net.Listen("tcp", address) 50 if err != nil { 51 log.Fatalf("failed to listen: %v", err) 52 } 53 tlsServer := TestTlsServer{} 54 testagent.RegisterHelloAgentServer(grpcServer, &tlsServer) 55 if err := grpcServer.Serve(lis); err != nil { 56 log.Fatalf("failed to serve: %s", err) 57 } 58 } 59 60 type CertAuthority struct { 61 ca *CertificateAuthority 62 } 63 64 func (auth *CertAuthority) VerifyPeerCertificate(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error { 65 return auth.ca.VerifyClientCertificate(rawCerts) 66 } 67 68 func getTlsCreds(t *testing.T, tlsCert tls.Certificate, certAuth *CertAuthority) credentials.TransportCredentials { 69 70 return credentials.NewTLS(&tls.Config{ 71 CipherSuites: []uint16{ 72 tls.TLS_AES_128_GCM_SHA256, 73 tls.TLS_AES_256_GCM_SHA384, 74 tls.TLS_CHACHA20_POLY1305_SHA256, 75 tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, 76 tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305, 77 tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 78 }, 79 MinVersion: tls.VersionTLS12, 80 PreferServerCipherSuites: true, 81 ClientAuth: tls.RequestClientCert, 82 Certificates: []tls.Certificate{tlsCert}, 83 VerifyPeerCertificate: certAuth.VerifyPeerCertificate, 84 }) 85 } 86 87 func getGrpcServer(t *testing.T, creds credentials.TransportCredentials) *grpc.Server { 88 var opts []grpc.ServerOption 89 opts = append(opts, grpc.Creds(creds)) 90 grpcServer := grpc.NewServer(opts...) 
91 return grpcServer 92 } 93 94 func makeTlsCall(t *testing.T, address string, provider credentials.TransportCredentials) (*testagent.Hello, error) { 95 var conn *grpc.ClientConn 96 var err error 97 if provider != nil { 98 conn, err = grpc.Dial(address, grpc.WithTransportCredentials(provider)) 99 } else { 100 conn, err = grpc.Dial(address, grpc.WithInsecure()) 101 } 102 assert.NoErrorf(t, err, "Failed to dial", err) 103 defer conn.Close() 104 c := testagent.NewHelloAgentClient(conn) 105 return c.PingHello(context.Background(), &testagent.Hello{Name: "TLSServer"}) 106 } 107 108 func createTestCertificate(before, after time.Time) (string, error) { 109 key, err := rsa.GenerateKey(rand.Reader, 2048) 110 if err != nil { 111 return "", err 112 } 113 114 serial, err := rand.Int(rand.Reader, new(big.Int).SetInt64(math.MaxInt64)) 115 if err != nil { 116 return "", err 117 } 118 119 tmpl := x509.Certificate{ 120 SerialNumber: serial, 121 Subject: pkix.Name{ 122 CommonName: "test", 123 Organization: []string{"microsoft"}, 124 }, 125 NotBefore: before, 126 NotAfter: after, 127 KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, 128 MaxPathLenZero: true, 129 BasicConstraintsValid: true, 130 MaxPathLen: 0, 131 IsCA: true, 132 } 133 134 b, err := x509.CreateCertificate(rand.Reader, &tmpl, &tmpl, key.Public(), key) 135 if err != nil { 136 return "", err 137 } 138 139 x509Cert, err := x509.ParseCertificate(b) 140 if err != nil { 141 return "", err 142 } 143 144 pemCert := EncodeCertPEM(x509Cert) 145 return string(pemCert), nil 146 } 147 148 func NewTransportCredentialFromAuthFromPem(serverName string, tlsCert tls.Certificate, caCertPem []byte) (credentials.TransportCredentials, error) { 149 certPool := x509.NewCertPool() 150 // Append the client certificates from the CA 151 if ok := certPool.AppendCertsFromPEM(caCertPem); !ok { 152 return nil, fmt.Errorf("could not append the server certificate") 153 } 154 creds := &tls.Config{ 155 
ServerName: serverName, 156 RootCAs: certPool, 157 Certificates: []tls.Certificate{tlsCert}, 158 } 159 return credentials.NewTLS(creds), nil 160 } 161 162 func Test_TLSServer(t *testing.T) { 163 server := "localhost" 164 port := "9000" 165 address := server + ":" + port 166 ca, key, err := GenerateClientCertificate("test CA") 167 assert.NoErrorf(t, err, "Error creation in CA certificate failed: %v", err) 168 169 rootSigner, err := tls.X509KeyPair(EncodeCertPEM(ca), EncodePrivateKeyPEM(key)) 170 assert.NoErrorf(t, err, "Failed to load root key pair: %v", err) 171 172 caConfig := CAConfig{ 173 RootSigner: &rootSigner, 174 } 175 176 caAuth, err := NewCertificateAuthority(&caConfig) 177 assert.NoErrorf(t, err, "Error creation CA Auth: %v", err) 178 179 certPem := EncodeCertPEM(ca) 180 keyPem := EncodePrivateKeyPEM(key) 181 tlsCert, err := tls.X509KeyPair(certPem, keyPem) 182 assert.NoErrorf(t, err, "Failed to get tls cert", err) 183 184 creds := getTlsCreds(t, tlsCert, &CertAuthority{caAuth}) 185 grpcServer := getGrpcServer(t, creds) 186 go startHelloServer(grpcServer, address) 187 defer grpcServer.Stop() 188 time.Sleep((time.Second * 3)) 189 conf := Config{ 190 CommonName: "Test Cert", 191 Organization: []string{"microsoft"}, 192 } 193 conf.AltNames.DNSNames = []string{"Test Cert"} 194 csr, keyClientPem, err := GenerateCertificateRequest(&conf, nil) 195 assert.NoErrorf(t, err, "Error creation in CSR: %v", err) 196 197 signConf := SignConfig{Offset: time.Second * 5} 198 clientCertPem, err := caAuth.SignRequest(csr, nil, &signConf) 199 assert.NoErrorf(t, err, "Error signing CSR: %v", err) 200 tlsClientCert, err := tls.X509KeyPair(clientCertPem, keyClientPem) 201 assert.NoErrorf(t, err, "Failed to get tls cert", err) 202 203 provider, err := NewTransportCredentialFromAuthFromPem(server, tlsClientCert, EncodeCertPEM(ca)) 204 assert.NoErrorf(t, err, "Failed to create TLS Credentials", err) 205 // Making the certificate invalid 206 time.Sleep((time.Second * 10)) 207 _, err 
= makeTlsCall(t, address, provider) 208 assert.True(t, IsTransportUnavailable(err)) 209 } 210 211 func Test_CACerts(t *testing.T) { 212 ca, key, err := GenerateClientCertificate("test CA") 213 assert.NoErrorf(t, err, "Error creation in CA certificate failed: %v", err) 214 215 rootSigner, err := tls.X509KeyPair(EncodeCertPEM(ca), EncodePrivateKeyPEM(key)) 216 assert.NoErrorf(t, err, "Failed to load root key pair: %v", err) 217 218 caConfig := CAConfig{ 219 RootSigner: &rootSigner, 220 } 221 caAuth, err := NewCertificateAuthority(&caConfig) 222 assert.NoErrorf(t, err, "Error creation CA Auth: %v", err) 223 224 conf := Config{ 225 CommonName: "Test Cert", 226 Organization: []string{"microsoft"}, 227 } 228 conf.AltNames.DNSNames = []string{"Test Cert"} 229 csr, keyClientPem, err := GenerateCertificateRequest(&conf, nil) 230 assert.NoErrorf(t, err, "Error creation in CSR: %v", err) 231 keyClient, err := DecodePrivateKeyPEM(keyClientPem) 232 assert.NoErrorf(t, err, "Failed Decoding privatekey: %v", err) 233 clientCertPem, err := caAuth.SignRequest(csr, nil, nil) 234 assert.NoErrorf(t, err, "Error signing CSR: %v", err) 235 clientCert, err := DecodeCertPEM(clientCertPem) 236 assert.NoErrorf(t, err, "Failed Decoding cert: %v", err) 237 if (clientCert.NotAfter.Sub(clientCert.NotBefore)) != (time.Hour * 24 * 365) { 238 t.Errorf("Invalid certificate expiry") 239 } 240 241 foundCertDER := false 242 foundRenewCount := false 243 for _, ext := range clientCert.Extensions { 244 if ext.Id.Equal(oidOriginalCertificate) { 245 foundCertDER = true 246 } else if ext.Id.Equal(oidRenewCount) { 247 foundRenewCount = true 248 } 249 } 250 251 if foundRenewCount || foundCertDER { 252 t.Errorf("Found certDER or renewCount Extensions") 253 } 254 255 roots := x509.NewCertPool() 256 roots.AddCert(ca) 257 258 opts := x509.VerifyOptions{ 259 Roots: roots, 260 DNSName: "Test Cert", 261 KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, 262 } 263 264 if _, err := 
clientCert.Verify(opts); err != nil { 265 panic("failed to verify certificate: " + err.Error()) 266 } 267 if _, err = tls.X509KeyPair(EncodeCertPEM(clientCert), EncodePrivateKeyPEM(keyClient)); err != nil { 268 t.Errorf("Error Verifying key and cert: %s", err.Error()) 269 } 270 } 271 272 func Test_CACertsVerify(t *testing.T) { 273 ca, key, err := GenerateClientCertificate("test CA") 274 assert.NoErrorf(t, err, "Error creation in CA certificate failed: %v", err) 275 276 rootSigner, err := tls.X509KeyPair(EncodeCertPEM(ca), EncodePrivateKeyPEM(key)) 277 assert.NoErrorf(t, err, "Failed to load root key pair: %v", err) 278 279 caConfig := CAConfig{ 280 RootSigner: &rootSigner, 281 } 282 283 caAuth, err := NewCertificateAuthority(&caConfig) 284 assert.NoErrorf(t, err, "Error creation CA Auth: %v", err) 285 286 conf := Config{ 287 CommonName: "Test Cert", 288 Organization: []string{"microsoft"}, 289 } 290 conf.AltNames.DNSNames = []string{"Test Cert"} 291 csr, keyClientPem, err := GenerateCertificateRequest(&conf, nil) 292 assert.NoErrorf(t, err, "Error creation in CSR: %v", err) 293 keyClient, err := DecodePrivateKeyPEM(keyClientPem) 294 assert.NoErrorf(t, err, "Failed Decoding privatekey: %v", err) 295 296 signConf := SignConfig{Offset: time.Second * 5} 297 clientCertPem, err := caAuth.SignRequest(csr, nil, &signConf) 298 assert.NoErrorf(t, err, "Error signing CSR: %v", err) 299 300 clientCert, err := DecodeCertPEM(clientCertPem) 301 assert.NoErrorf(t, err, "Failed Decoding cert: %v", err) 302 303 if (clientCert.NotAfter.Sub(clientCert.NotBefore)) != signConf.Offset { 304 t.Errorf("Invalid certificate expiry") 305 } 306 307 foundCertDER := false 308 foundRenewCount := false 309 for _, ext := range clientCert.Extensions { 310 if ext.Id.Equal(oidOriginalCertificate) { 311 foundCertDER = true 312 } else if ext.Id.Equal(oidRenewCount) { 313 foundRenewCount = true 314 } 315 } 316 317 if foundRenewCount || foundCertDER { 318 t.Errorf("Found certDER or renewCount Extensions") 
319 } 320 321 clientCerts := [][]byte{clientCert.Raw} 322 323 err = caAuth.VerifyClientCertificate(clientCerts) 324 assert.NoErrorf(t, err, "failed to verify certificate: %v", err) 325 326 time.Sleep(time.Second * 6) 327 err = caAuth.VerifyClientCertificate(clientCerts) 328 assert.Errorf(t, err, "failed to verify certificate after Expiry") 329 330 _, err = tls.X509KeyPair(EncodeCertPEM(clientCert), EncodePrivateKeyPEM(keyClient)) 331 assert.NoErrorf(t, err, "Error Verifying key and cert: %v", err) 332 } 333 334 func Test_CACertsRenewVerify(t *testing.T) { 335 ca, key, err := GenerateClientCertificate("test CA") 336 assert.NoErrorf(t, err, "Error creation in CA certificate failed: %v", err) 337 338 rootSigner, err := tls.X509KeyPair(EncodeCertPEM(ca), EncodePrivateKeyPEM(key)) 339 assert.NoErrorf(t, err, "Failed to load root key pair: %v", err) 340 341 caConfig := CAConfig{ 342 RootSigner: &rootSigner, 343 } 344 caAuth, err := NewCertificateAuthority(&caConfig) 345 assert.NoErrorf(t, err, "Error creation CA Auth: %v", err) 346 347 conf := Config{ 348 CommonName: "Test Cert", 349 Organization: []string{"microsoft"}, 350 } 351 conf.AltNames.DNSNames = []string{"Test Cert"} 352 csr, keyClientPem, err := GenerateCertificateRequest(&conf, nil) 353 assert.NoErrorf(t, err, "Error creation in CSR: %v", err) 354 355 keyClient, err := DecodePrivateKeyPEM(keyClientPem) 356 assert.NoErrorf(t, err, "Failed Decoding privatekey: %v", err) 357 358 signConf := SignConfig{Offset: time.Second * 5} 359 clientCertPem, err := caAuth.SignRequest(csr, nil, &signConf) 360 assert.NoErrorf(t, err, "Error signing CSR: %v", err) 361 362 clientCert, err := DecodeCertPEM(clientCertPem) 363 assert.NoErrorf(t, err, "Failed Decoding cert: %v", err) 364 365 // Test certificate duration 366 if (clientCert.NotAfter.Sub(clientCert.NotBefore)) != signConf.Offset { 367 t.Errorf("Invalid certificate expiry") 368 } 369 370 clientCerts := [][]byte{clientCert.Raw} 371 372 err = 
caAuth.VerifyClientCertificate(clientCerts) 373 assert.NoErrorf(t, err, "Failed to verify certificate: %v", err) 374 375 oldcert, err := tls.X509KeyPair(EncodeCertPEM(clientCert), EncodePrivateKeyPEM(keyClient)) 376 assert.NoErrorf(t, err, "Error creating X509 keypair: %v", err) 377 378 // ================= Renew 1 ======================== 379 csr1, keyClient1Pem, err := GenerateCertificateRenewRequest(&oldcert) 380 assert.NoErrorf(t, err, "Error creating renew CSR: %v", err) 381 382 keyClient1, err := DecodePrivateKeyPEM(keyClient1Pem) 383 assert.NoErrorf(t, err, "Failed Decoding privatekey: %v", err) 384 385 signConf = SignConfig{Offset: time.Second * 20} 386 certClient1Pem, err := caAuth.SignRequest(csr1, clientCert.Raw, &signConf) 387 assert.NoErrorf(t, err, "Error signing CSR: %v", err) 388 389 certClient1, err := DecodeCertPEM(certClient1Pem) 390 assert.NoErrorf(t, err, "Failed Decoding cert: %v", err) 391 392 // Test certificate duration 393 if (certClient1.NotAfter.Sub(certClient1.NotBefore)) != (time.Second * 5) { 394 t.Errorf("Invalid certificate expiry") 395 } 396 397 foundCertDER := false 398 foundRenewCount := false 399 var origCertDER []byte 400 var renewCount int64 = 0 401 for _, ext := range certClient1.Extensions { 402 if ext.Id.Equal(oidOriginalCertificate) { 403 origCertDER = ext.Value 404 foundCertDER = true 405 } else if ext.Id.Equal(oidRenewCount) { 406 asn1.Unmarshal(ext.Value, &renewCount) 407 foundRenewCount = true 408 } 409 } 410 411 if !(foundRenewCount && foundCertDER) { 412 t.Errorf("Not found certDER or renewCount Extensions") 413 } 414 415 if !bytes.Equal(origCertDER, clientCert.Raw) { 416 t.Errorf("Extension not Matching old cert") 417 } 418 419 if renewCount != 1 { 420 t.Errorf("Extension renew count is wrong") 421 } 422 423 clientCerts = [][]byte{certClient1.Raw} 424 err = caAuth.VerifyClientCertificate(clientCerts) 425 assert.NoErrorf(t, err, "failed to verify certificate: %v", err) 426 _, err = 
tls.X509KeyPair(EncodeCertPEM(certClient1), EncodePrivateKeyPEM(keyClient1)) 427 assert.NoErrorf(t, err, "Error Verifying key and cert: %v", err) 428 429 // ================= Renew 2 ======================== 430 oldcert, err = tls.X509KeyPair(EncodeCertPEM(certClient1), EncodePrivateKeyPEM(keyClient1)) 431 assert.NoErrorf(t, err, "Error creating X509 keypair: %v", err) 432 433 csr2, keyClient2Pem, err := GenerateCertificateRenewRequest(&oldcert) 434 assert.NoErrorf(t, err, "Error creating renew CSR: %v", err) 435 436 keyClient2, err := DecodePrivateKeyPEM(keyClient2Pem) 437 assert.NoErrorf(t, err, "Failed Decoding privatekey: %v", err) 438 439 certClient2Pem, err := caAuth.SignRequest(csr2, certClient1.Raw, nil) 440 assert.NoErrorf(t, err, "Error signing CSR: %v", err) 441 certClient2, err := DecodeCertPEM(certClient2Pem) 442 assert.NoErrorf(t, err, "Failed Decoding cert: %v", err) 443 444 // Test certificate duration 445 if (certClient2.NotAfter.Sub(certClient2.NotBefore)) != (time.Second * 5) { 446 t.Errorf("Invalid certificate expiry") 447 } 448 449 foundCertDER = false 450 foundRenewCount = false 451 for _, ext := range certClient2.Extensions { 452 if ext.Id.Equal(oidOriginalCertificate) { 453 origCertDER = ext.Value 454 foundCertDER = true 455 } else if ext.Id.Equal(oidRenewCount) { 456 asn1.Unmarshal(ext.Value, &renewCount) 457 foundRenewCount = true 458 } 459 } 460 461 if !(foundRenewCount && foundCertDER) { 462 t.Errorf("Not found certDER or renewCount Extensions") 463 } 464 465 // The origCertDER should point to the first cert 466 if !bytes.Equal(origCertDER, clientCert.Raw) { 467 t.Errorf("Extension not Matching old cert") 468 } 469 470 if renewCount != 2 { 471 t.Errorf("Extension renew count is wrong") 472 } 473 474 clientCerts = [][]byte{certClient2.Raw} 475 err = caAuth.VerifyClientCertificate(clientCerts) 476 assert.NoErrorf(t, err, "failed to verify certificate: %v", err) 477 _, err = tls.X509KeyPair(EncodeCertPEM(certClient2), 
EncodePrivateKeyPEM(keyClient2)) 478 assert.NoErrorf(t, err, "Error Verifying key and cert: %v", err) 479 } 480 481 func Test_CACertsRenewVerifySameKey(t *testing.T) { 482 ca, key, err := GenerateClientCertificate("test CA") 483 if err != nil { 484 t.Errorf("Error creation in CA certificate failed: %s", err.Error()) 485 } 486 487 rootSigner, err := tls.X509KeyPair(EncodeCertPEM(ca), EncodePrivateKeyPEM(key)) 488 if err != nil { 489 t.Errorf("Failed to load root key pair: %v", err) 490 return 491 } 492 493 caConfig := CAConfig{ 494 RootSigner: &rootSigner, 495 } 496 497 caAuth, err := NewCertificateAuthority(&caConfig) 498 if err != nil { 499 t.Errorf("Error creation CA Auth: %s", err.Error()) 500 } 501 502 conf := Config{ 503 CommonName: "Test Cert", 504 Organization: []string{"microsoft"}, 505 } 506 conf.AltNames.DNSNames = []string{"Test Cert"} 507 csr, keyClientPem, err := GenerateCertificateRequest(&conf, nil) 508 if err != nil { 509 t.Errorf("Error creation in CSR: %s", err.Error()) 510 } 511 keyClient, err := DecodePrivateKeyPEM(keyClientPem) 512 if err != nil { 513 t.Errorf("Failed Decoding privatekey: %s", err.Error()) 514 } 515 signConf := SignConfig{Offset: time.Second * 5} 516 clientCertPem, err := caAuth.SignRequest(csr, nil, &signConf) 517 if err != nil { 518 t.Errorf("Error signing CSR: %s", err.Error()) 519 } 520 clientCert, err := DecodeCertPEM(clientCertPem) 521 if err != nil { 522 t.Errorf("Failed Decoding cert: %s", err.Error()) 523 } 524 // Test certificate duration 525 if (clientCert.NotAfter.Sub(clientCert.NotBefore)) != signConf.Offset { 526 t.Errorf("Invalid certificate expiry") 527 } 528 529 clientCerts := [][]byte{clientCert.Raw} 530 531 if err := caAuth.VerifyClientCertificate(clientCerts); err != nil { 532 panic("failed to verify certificate: " + err.Error()) 533 } 534 oldcert, err := tls.X509KeyPair(EncodeCertPEM(clientCert), EncodePrivateKeyPEM(keyClient)) 535 if err != nil { 536 t.Errorf("Error creating X509 keypair: %s", err.Error()) 
537 } 538 539 // ================= Renew 1 ======================== 540 csr1, err := GenerateCertificateRenewRequestSameKey(&oldcert) 541 if err != nil { 542 t.Errorf("Error creating renew CSR: %s", err.Error()) 543 } 544 certClient1Pem, err := caAuth.SignRequest(csr1, clientCert.Raw, nil) 545 if err != nil { 546 t.Errorf("Error signing CSR: %s", err.Error()) 547 } 548 549 certClient1, err := DecodeCertPEM(certClient1Pem) 550 if err != nil { 551 t.Errorf("Failed Decoding cert: %s", err.Error()) 552 } 553 554 // Test certificate duration 555 if (certClient1.NotAfter.Sub(certClient1.NotBefore)) != signConf.Offset { 556 t.Errorf("Invalid certificate expiry") 557 } 558 559 foundCertDER := false 560 foundRenewCount := false 561 var origCertDER []byte 562 var renewCount int64 = 0 563 for _, ext := range certClient1.Extensions { 564 if ext.Id.Equal(oidOriginalCertificate) { 565 origCertDER = ext.Value 566 foundCertDER = true 567 } else if ext.Id.Equal(oidRenewCount) { 568 asn1.Unmarshal(ext.Value, &renewCount) 569 foundRenewCount = true 570 } 571 } 572 573 if !(foundRenewCount && foundCertDER) { 574 t.Errorf("Not found certDER or renewCount Extensions") 575 } 576 577 if !bytes.Equal(origCertDER, clientCert.Raw) { 578 t.Errorf("Extension not Matching old cert") 579 } 580 581 if renewCount != 1 { 582 t.Errorf("Extension renew count is wrong") 583 } 584 585 clientCerts = [][]byte{certClient1.Raw} 586 if err := caAuth.VerifyClientCertificate(clientCerts); err != nil { 587 t.Errorf("failed to verify certificate: " + err.Error()) 588 } 589 if _, err = tls.X509KeyPair(EncodeCertPEM(certClient1), EncodePrivateKeyPEM(keyClient)); err != nil { 590 t.Errorf("Error Verifying key and cert: %s", err.Error()) 591 } 592 593 // ================= Renew 2 ======================== 594 oldcert, err = tls.X509KeyPair(EncodeCertPEM(certClient1), EncodePrivateKeyPEM(keyClient)) 595 if err != nil { 596 t.Errorf("Error creating X509 keypair: %s", err.Error()) 597 } 598 csr2, err := 
GenerateCertificateRenewRequestSameKey(&oldcert) 599 if err != nil { 600 t.Errorf("Error creating renew CSR: %s", err.Error()) 601 } 602 certClient2Pem, err := caAuth.SignRequest(csr2, certClient1.Raw, nil) 603 if err != nil { 604 t.Errorf("Error signing CSR: %s", err.Error()) 605 } 606 607 certClient2, err := DecodeCertPEM(certClient2Pem) 608 if err != nil { 609 t.Errorf("Failed Decoding cert: %s", err.Error()) 610 } 611 // Test certificate duration 612 if (certClient2.NotAfter.Sub(certClient2.NotBefore)) != signConf.Offset { 613 t.Errorf("Invalid certificate expiry") 614 } 615 616 foundCertDER = false 617 foundRenewCount = false 618 for _, ext := range certClient2.Extensions { 619 if ext.Id.Equal(oidOriginalCertificate) { 620 origCertDER = ext.Value 621 foundCertDER = true 622 } else if ext.Id.Equal(oidRenewCount) { 623 asn1.Unmarshal(ext.Value, &renewCount) 624 foundRenewCount = true 625 } 626 } 627 628 if !(foundRenewCount && foundCertDER) { 629 t.Errorf("Not found certDER or renewCount Extensions") 630 } 631 632 // The origCertDER should point to the first cert 633 if !bytes.Equal(origCertDER, clientCert.Raw) { 634 t.Errorf("Extension not Matching old cert") 635 } 636 637 if renewCount != 2 { 638 t.Errorf("Extension renew count is wrong") 639 } 640 641 clientCerts = [][]byte{certClient2.Raw} 642 if err := caAuth.VerifyClientCertificate(clientCerts); err != nil { 643 t.Errorf("failed to verify certificate: " + err.Error()) 644 } 645 if _, err = tls.X509KeyPair(EncodeCertPEM(certClient2), EncodePrivateKeyPEM(keyClient)); err != nil { 646 t.Errorf("Error Verifying key and cert: %s", err.Error()) 647 } 648 } 649 650 func Test_BackoffFactor(t *testing.T) { 651 _, err := NewBackOffFactor(-1.0, 5) 652 if err == nil || !errors.IsInvalidInput(err) { 653 t.Errorf("Expected Error InvalidInput") 654 } 655 _, err = NewBackOffFactor(1.0, -5.0) 656 if err == nil || !errors.IsInvalidInput(err) { 657 t.Errorf("Expected Error InvalidInput") 658 } 659 _, err = 
NewBackOffFactor(-1.0, -5.0) 660 if err == nil || !errors.IsInvalidInput(err) { 661 t.Errorf("Expected Error InvalidInput") 662 } 663 } 664 665 func Test_BackoffFactor1(t *testing.T) { 666 factor, err := NewBackOffFactor(1.0, 5) 667 if err != nil { 668 t.Errorf("Error creating Factor: %s", err.Error()) 669 } 670 if factor.errorBackoffFactor != 5 || factor.renewBackoffFactor != 1 { 671 t.Errorf("renewBackoffFactor Expected:1.0 Actual:%f \n errorBackoffFactor Expected:5.0 Actual:%f", factor.renewBackoffFactor, factor.errorBackoffFactor) 672 } 673 } 674 675 func Test_CalculateTime(t *testing.T) { 676 factor, err := NewBackOffFactor(0.3, 0.02) 677 if err != nil { 678 t.Errorf("Error creating Factor: %s", err.Error()) 679 } 680 now := time.Now() 681 before := now.Add(time.Duration(time.Second * -10)) 682 after := now.Add(time.Duration(time.Second * 10)) 683 duration := calculateTime(before, after, now, factor) 684 if duration.RenewBackoffDuration != time.Duration(time.Second*4) { 685 t.Errorf("Wrong wait time returned Expected %s Actual %s", time.Duration(time.Second*4), duration.RenewBackoffDuration) 686 } 687 if duration.RenewBackoffDuration < time.Duration(0) { 688 t.Errorf("Wrong wait time returned Expected greater than zero %s", duration.RenewBackoffDuration) 689 } 690 if duration.ErrorBackoffDuration != time.Duration(time.Millisecond*400) { 691 t.Errorf("Wrong renewbackoff time returned Expected %s Actual %s", time.Duration(time.Millisecond*400), duration.ErrorBackoffDuration) 692 } 693 } 694 695 func Test_CalculateTime1(t *testing.T) { 696 factor, err := NewBackOffFactor(0.1, 0.002) 697 if err != nil { 698 t.Errorf("Error creating Factor: %s", err.Error()) 699 } 700 now := time.Now() 701 before := now.Add(time.Duration(time.Second * -30)) 702 after := now.Add(time.Duration(time.Second * 10)) 703 duration := calculateTime(before, after, now, factor) 704 if duration.RenewBackoffDuration != time.Duration(time.Second*6) { 705 t.Errorf("Wrong wait time returned 
Expected %s Actual %s", time.Duration(time.Second*6), duration.RenewBackoffDuration) 706 } 707 if duration.RenewBackoffDuration < time.Duration(0) { 708 t.Errorf("Wrong wait time returned Expected greater than zero %s", duration.RenewBackoffDuration) 709 } 710 if duration.ErrorBackoffDuration != time.Duration(time.Millisecond*80) { 711 t.Errorf("Wrong renewbackoff time returned Expected %s Actual %s", time.Duration(time.Millisecond*400), duration.ErrorBackoffDuration) 712 } 713 } 714 715 func Test_CalculateTime2(t *testing.T) { 716 factor, err := NewBackOffFactor(0.5, 0.002) 717 if err != nil { 718 t.Errorf("Error creating Factor: %s", err.Error()) 719 } 720 now := time.Now() 721 before := now.Add(time.Duration(time.Second * -30)) 722 after := now.Add(time.Duration(time.Second * 10)) 723 duration := calculateTime(before, after, now, factor) 724 if duration.RenewBackoffDuration != time.Duration(time.Second*-10) { 725 t.Errorf("Wrong wait time returned Expected %s Actual %s", time.Duration(time.Second*-10), duration.RenewBackoffDuration) 726 } 727 if duration.RenewBackoffDuration > time.Duration(0) { 728 t.Errorf("Wrong wait time returned Expected greater than zero %s", duration.RenewBackoffDuration) 729 } 730 if duration.ErrorBackoffDuration != time.Duration(time.Millisecond*80) { 731 t.Errorf("Wrong renewbackoff time returned Expected %s Actual %s", time.Duration(time.Millisecond*400), duration.ErrorBackoffDuration) 732 } 733 } 734 735 func Test_CalculateTime3(t *testing.T) { 736 factor, err := NewBackOffFactor(30.0/100.0, 0.02) 737 if err != nil { 738 t.Errorf("Error creating Factor: %s", err.Error()) 739 } 740 now := time.Now() 741 before := now.Add(time.Minute * -5) 742 after := now.Add(time.Duration(time.Minute*10 + time.Second*30)) 743 duration := calculateTime(before, after, now, factor) 744 if duration.RenewBackoffDuration != time.Duration(time.Minute*5+time.Second*51) { 745 t.Errorf("Wrong wait time returned Expected %s Actual %s", 
time.Duration(time.Minute*5+time.Second*51), duration.RenewBackoffDuration) 746 } 747 if duration.RenewBackoffDuration < time.Duration(0) { 748 t.Errorf("Wrong wait time returned Expected greater than zero %s", duration.RenewBackoffDuration) 749 } 750 if duration.ErrorBackoffDuration != time.Duration(time.Second*18+time.Millisecond*600) { 751 t.Errorf("Wrong renewbackoff time returned Expected %s Actual %s", time.Duration(time.Millisecond*400), duration.ErrorBackoffDuration) 752 } 753 } 754 755 func Test_CalculateTimeNegative(t *testing.T) { 756 factor, err := NewBackOffFactor(0.3, 0.02) 757 if err != nil { 758 t.Errorf("Error creating Factor: %s", err.Error()) 759 } 760 now := time.Now() 761 before := now.Add(time.Duration(time.Second * -20)) 762 after := now.Add(time.Duration(time.Second * -10)) 763 duration := calculateTime(before, after, now, factor) 764 if duration.RenewBackoffDuration != time.Duration(time.Second*-13) { 765 t.Errorf("Wrong wait time returned Expected %s Actual %s", time.Duration(time.Second*-13), duration.RenewBackoffDuration) 766 } 767 if duration.RenewBackoffDuration > time.Duration(0) { 768 t.Errorf("Wrong wait time returned Expected less than zero %s", duration.RenewBackoffDuration) 769 } 770 if duration.ErrorBackoffDuration != time.Duration(time.Millisecond*200) { 771 t.Errorf("Wrong renewbackoff time returned Expected %s Actual %s", time.Duration(time.Millisecond*200), duration.ErrorBackoffDuration) 772 } 773 } 774 775 func Test_CalculateTimeAfter(t *testing.T) { 776 factor, err := NewBackOffFactor(0.3, 0.02) 777 if err != nil { 778 t.Errorf("Error creating Factor: %s", err.Error()) 779 } 780 now := time.Now() 781 before := now.Add(time.Duration(time.Second * 10)) 782 after := now.Add(time.Duration(time.Second * 30)) 783 duration := calculateTime(before, after, now, factor) 784 if duration.RenewBackoffDuration != time.Duration(time.Second*24) { 785 t.Errorf("Wrong wait time returned Expected %s Actual %s", time.Duration(time.Second*24), 
duration.RenewBackoffDuration) 786 } 787 if duration.ErrorBackoffDuration != time.Duration(time.Millisecond*400) { 788 t.Errorf("Wrong renewbackoff time returned Expected %s Actual %s", time.Duration(time.Millisecond*400), duration.ErrorBackoffDuration) 789 } 790 } 791 792 func Test_CalculateRenewTime(t *testing.T) { 793 factor, err := NewBackOffFactor(0.3, 0.02) 794 if err != nil { 795 t.Errorf("Error creating Factor: %s", err.Error()) 796 } 797 now := time.Now() 798 before := now.Add(time.Duration(time.Second * -10)) 799 after := now.Add(time.Duration(time.Second * 10)) 800 cert, err := createTestCertificate(before, after) 801 if err != nil { 802 t.Errorf("Failed creating certificate: %s", err.Error()) 803 } 804 duration, err := CalculateRenewTime(cert, factor) 805 if err != nil { 806 t.Errorf("Failed calculating Certificate renewal backoff: %s", err.Error()) 807 } 808 if duration.RenewBackoffDuration > time.Duration(time.Second*4) || duration.RenewBackoffDuration < time.Duration(time.Second*1) { 809 t.Errorf("Wrong wait time returned Expected %s Actual %s", time.Duration(time.Second*4), duration.RenewBackoffDuration) 810 } 811 if duration.RenewBackoffDuration < time.Duration(0) { 812 t.Errorf("Wrong wait time returned Expected greater than zero %s", duration.RenewBackoffDuration) 813 } 814 if duration.ErrorBackoffDuration != time.Duration(time.Millisecond*400) { 815 t.Errorf("Wrong renewbackoff time returned Expected %s Actual %s", time.Duration(time.Millisecond*400), duration.ErrorBackoffDuration) 816 } 817 } 818 819 func Test_CalculateRenewTime1(t *testing.T) { 820 factor, err := NewBackOffFactor(0.1, 0.002) 821 if err != nil { 822 t.Errorf("Error creating Factor: %s", err.Error()) 823 } 824 now := time.Now() 825 before := now.Add(time.Duration(time.Second * -30)) 826 after := now.Add(time.Duration(time.Second * 10)) 827 cert, err := createTestCertificate(before, after) 828 if err != nil { 829 t.Errorf("Failed creating certificate: %s", err.Error()) 830 } 831 
duration, err := CalculateRenewTime(cert, factor)
	if err != nil {
		t.Errorf("Failed calculating Certificate renewal backoff: %s", err.Error())
	}
	if duration.RenewBackoffDuration > time.Duration(time.Second*6) || duration.RenewBackoffDuration < time.Duration(time.Second*3) {
		t.Errorf("Wrong wait time returned Expected %s Actual %s", time.Duration(time.Second*6), duration.RenewBackoffDuration)
	}
	if duration.RenewBackoffDuration < time.Duration(0) {
		t.Errorf("Wrong wait time returned Expected greater than zero %s", duration.RenewBackoffDuration)
	}
	if duration.ErrorBackoffDuration != time.Duration(time.Millisecond*80) {
		t.Errorf("Wrong renewbackoff time returned Expected %s Actual %s", time.Duration(time.Millisecond*400), duration.ErrorBackoffDuration)
	}
}

// Test_CalculateRenewTime2 checks CalculateRenewTime for a certificate valid
// from 30s ago until 10s from now, using NewBackOffFactor(0.5, 0.002).
// RenewBackoffDuration is expected to lie in [-13s, -10s] (so never positive)
// and ErrorBackoffDuration to be exactly 80ms.
func Test_CalculateRenewTime2(t *testing.T) {
	factor, err := NewBackOffFactor(0.5, 0.002)
	if err != nil {
		// Every later assertion depends on a valid factor; stop instead of
		// continuing with a meaningless value.
		t.Fatalf("Error creating Factor: %s", err.Error())
	}
	now := time.Now()
	before := now.Add(-30 * time.Second)
	after := now.Add(10 * time.Second)
	cert, err := createTestCertificate(before, after)
	if err != nil {
		t.Fatalf("Failed creating certificate: %s", err.Error())
	}
	duration, err := CalculateRenewTime(cert, factor)
	if err != nil {
		t.Fatalf("Failed calculating Certificate renewal backoff: %s", err.Error())
	}

	const (
		minRenew         = -13 * time.Second
		maxRenew         = -10 * time.Second
		wantErrorBackoff = 80 * time.Millisecond
	)
	if duration.RenewBackoffDuration > maxRenew || duration.RenewBackoffDuration < minRenew {
		t.Errorf("Wrong wait time returned Expected between %s and %s Actual %s", minRenew, maxRenew, duration.RenewBackoffDuration)
	}
	if duration.RenewBackoffDuration > 0 {
		t.Errorf("Wrong wait time returned Expected less than or equal to zero %s", duration.RenewBackoffDuration)
	}
	if duration.ErrorBackoffDuration != wantErrorBackoff {
		// Report the value actually asserted (80ms), not an unrelated constant.
		t.Errorf("Wrong renewbackoff time returned Expected %s Actual %s", wantErrorBackoff, duration.ErrorBackoffDuration)
	}
}

// Test_CalculateRenewTimeNegative checks CalculateRenewTime for a certificate
// whose validity window (20s ago until 10s ago) has already ended, using
// NewBackOffFactor(0.3, 0.02). RenewBackoffDuration is expected to lie in
// [-16s, -13s] and ErrorBackoffDuration to be exactly 200ms.
func Test_CalculateRenewTimeNegative(t *testing.T) {
	factor, err := NewBackOffFactor(0.3, 0.02)
	if err != nil {
		t.Fatalf("Error creating Factor: %s", err.Error())
	}
	now := time.Now()
	before := now.Add(-20 * time.Second)
	after := now.Add(-10 * time.Second)
	cert, err := createTestCertificate(before, after)
	if err != nil {
		t.Fatalf("Failed creating certificate: %s", err.Error())
	}
	duration, err := CalculateRenewTime(cert, factor)
	if err != nil {
		t.Fatalf("Failed calculating Certificate renewal backoff: %s", err.Error())
	}

	const (
		minRenew         = -16 * time.Second
		maxRenew         = -13 * time.Second
		wantErrorBackoff = 200 * time.Millisecond
	)
	if duration.RenewBackoffDuration > maxRenew || duration.RenewBackoffDuration < minRenew {
		t.Errorf("Wrong wait time returned Expected between %s and %s Actual %s", minRenew, maxRenew, duration.RenewBackoffDuration)
	}
	if duration.RenewBackoffDuration > 0 {
		t.Errorf("Wrong wait time returned Expected less than or equal to zero %s", duration.RenewBackoffDuration)
	}
	if duration.ErrorBackoffDuration != wantErrorBackoff {
		t.Errorf("Wrong renewbackoff time returned Expected %s Actual %s", wantErrorBackoff, duration.ErrorBackoffDuration)
	}
}

// Test_CalculateRenewTimeAfter checks CalculateRenewTime for a certificate
// whose validity window lies entirely in the future (10s to 30s from now),
// using NewBackOffFactor(0.3, 0.02). RenewBackoffDuration is expected to lie
// in [22s, 24s] (so never negative) and ErrorBackoffDuration to be exactly
// 400ms.
func Test_CalculateRenewTimeAfter(t *testing.T) {
	factor, err := NewBackOffFactor(0.3, 0.02)
	if err != nil {
		t.Fatalf("Error creating Factor: %s", err.Error())
	}
	now := time.Now()
	before := now.Add(10 * time.Second)
	after := now.Add(30 * time.Second)
	cert, err := createTestCertificate(before, after)
	if err != nil {
		t.Fatalf("Failed creating certificate: %s", err.Error())
	}
	duration, err := CalculateRenewTime(cert, factor)
	if err != nil {
		t.Fatalf("Failed calculating Certificate renewal backoff: %s", err.Error())
	}

	const (
		minRenew         = 22 * time.Second
		maxRenew         = 24 * time.Second
		wantErrorBackoff = 400 * time.Millisecond
	)
	if duration.RenewBackoffDuration < minRenew || duration.RenewBackoffDuration > maxRenew {
		t.Errorf("Wrong wait time returned Expected between %s and %s Actual %s", minRenew, maxRenew, duration.RenewBackoffDuration)
	}
	if duration.RenewBackoffDuration < 0 {
		t.Errorf("Wrong wait time returned Expected greater than zero %s", duration.RenewBackoffDuration)
	}
	if duration.ErrorBackoffDuration != wantErrorBackoff {
		t.Errorf("Wrong renewbackoff time returned Expected %s Actual %s", wantErrorBackoff, duration.ErrorBackoffDuration)
	}
}

// Test_CalculateCertExpiry expects IsCertificateExpired to report false for a
// certificate valid from 30s ago until 10s from now.
func Test_CalculateCertExpiry(t *testing.T) {
	now := time.Now()
	before := now.Add(-30 * time.Second)
	after := now.Add(10 * time.Second)
	cert, err := createTestCertificate(before, after)
	if err != nil {
		t.Fatalf("Failed creating certificate: %s", err.Error())
	}
	expired, err := IsCertificateExpired(cert)
	if err != nil {
		// The expired flag is not trustworthy when the call failed.
		t.Fatalf("Failed finding certificate expired: %s", err.Error())
	}

	if expired {
		t.Errorf("Certificate expired")
	}
}

// Test_CalculateCertExpiry1 expects IsCertificateExpired to report true for a
// certificate whose validity window (20s ago until 10s ago) has ended.
func Test_CalculateCertExpiry1(t *testing.T) {
	now := time.Now()
	before := now.Add(-20 * time.Second)
	after := now.Add(-10 * time.Second)
	cert, err := createTestCertificate(before, after)
	if err != nil {
		t.Fatalf("Failed creating certificate: %s", err.Error())
	}
	expired, err := IsCertificateExpired(cert)
	if err != nil {
		t.Fatalf("Failed finding certificate expired: %s", err.Error())
	}

	if !expired {
		t.Errorf("Certificate not expired")
	}
}

// Test_Revocation_IsRevoked sets an expectation on a generated
// MockRevocation and then satisfies it by calling IsRevoked on the mock.
// NOTE(review): this only exercises the mock plumbing. No production
// implementation of the revocation interface is invoked here.
func Test_Revocation_IsRevoked(t *testing.T) {
	ctrl := gomock.NewController(t)
	defer ctrl.Finish()

	// NOTE(review): the second and third return values of
	// GenerateClientCertificate are discarded. If the last one is an error,
	// consider checking it.
	ca, _, _ := GenerateClientCertificate("test CA")
	m := mock.NewMockRevocation(ctrl)
	m.EXPECT().IsRevoked(ca)
	m.IsRevoked(ca)
}