github.com/emmansun/gmsm@v0.29.1/smx509/csr_rsp_test.go (about)

     1  package smx509_test
     2  
     3  import (
     4  	"crypto"
     5  	"crypto/rand"
     6  	"crypto/x509"
     7  	"crypto/x509/pkix"
     8  	"encoding/pem"
     9  	"fmt"
    10  	"math/big"
    11  	"os"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/emmansun/gmsm/sm2"
    16  	"github.com/emmansun/gmsm/smx509"
    17  )
    18  
    19  type certKeyPair struct {
    20  	Certificate *smx509.Certificate
    21  	PrivateKey  *crypto.PrivateKey
    22  }
    23  
    24  func createTestCertificate() ([]*certKeyPair, error) {
    25  	signer, err := createTestCertificateByIssuer("Test CA", nil, true)
    26  	if err != nil {
    27  		return nil, err
    28  	}
    29  	pair1, err := createTestCertificateByIssuer("Test Org Sign", signer, false)
    30  	if err != nil {
    31  		return nil, err
    32  	}
    33  	pair2, err := createTestCertificateByIssuer("Test Org Enc", signer, false)
    34  	if err != nil {
    35  		return nil, err
    36  	}
    37  	return []*certKeyPair{pair1, pair2, signer}, nil
    38  }
    39  
    40  func createTestCertificateByIssuer(name string, issuer *certKeyPair, isCA bool) (*certKeyPair, error) {
    41  	var (
    42  		err        error
    43  		priv       crypto.PrivateKey
    44  		derCert    []byte
    45  		issuerCert *smx509.Certificate
    46  		issuerKey  crypto.PrivateKey
    47  	)
    48  	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 32)
    49  	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
    50  	if err != nil {
    51  		return nil, err
    52  	}
    53  
    54  	template := x509.Certificate{
    55  		SerialNumber: serialNumber,
    56  		Subject: pkix.Name{
    57  			CommonName:   name,
    58  			Organization: []string{"Acme Co"},
    59  		},
    60  		NotBefore:   time.Now().Add(-1 * time.Second),
    61  		NotAfter:    time.Now().AddDate(1, 0, 0),
    62  		KeyUsage:    x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
    63  		ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageEmailProtection},
    64  	}
    65  	if issuer != nil {
    66  		issuerCert = issuer.Certificate
    67  		issuerKey = *issuer.PrivateKey
    68  	}
    69  
    70  	priv, err = sm2.GenerateKey(rand.Reader)
    71  	if err != nil {
    72  		return nil, err
    73  	}
    74  
    75  	pkey := priv.(crypto.Signer)
    76  	if isCA {
    77  		template.IsCA = true
    78  		template.KeyUsage |= x509.KeyUsageCertSign
    79  		template.BasicConstraintsValid = true
    80  	}
    81  	if issuer == nil {
    82  		// no issuer given,make this a self-signed root cert
    83  		issuerCert = (*smx509.Certificate)(&template)
    84  		issuerKey = priv
    85  	}
    86  
    87  	derCert, err = smx509.CreateCertificate(rand.Reader, &template, (*x509.Certificate)(issuerCert), pkey.Public(), issuerKey)
    88  	if err != nil {
    89  		return nil, err
    90  	}
    91  	if len(derCert) == 0 {
    92  		return nil, fmt.Errorf("no certificate created, probably due to wrong keys. types were %T and %T", priv, issuerKey)
    93  	}
    94  	cert, err := smx509.ParseCertificate(derCert)
    95  	if err != nil {
    96  		return nil, err
    97  	}
    98  	pem.Encode(os.Stdout, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
    99  	return &certKeyPair{
   100  		Certificate: cert,
   101  		PrivateKey:  &priv,
   102  	}, nil
   103  }
   104  
   105  func TestMarshalCSRResponse(t *testing.T) {
   106  	pairs, err := createTestCertificate()
   107  	if err != nil {
   108  		t.Fatal(err)
   109  	}
   110  
   111  	signPrivKey, _ := (*pairs[0].PrivateKey).(*sm2.PrivateKey)
   112  	encPrivKey, _ := (*pairs[1].PrivateKey).(*sm2.PrivateKey)
   113  
   114  	// Call the function
   115  	result, err := smx509.MarshalCSRResponse([]*smx509.Certificate{pairs[0].Certificate, pairs[2].Certificate}, encPrivKey, []*smx509.Certificate{pairs[1].Certificate, pairs[2].Certificate})
   116  	// Check the result
   117  	if err != nil {
   118  		t.Errorf("Unexpected error: %v", err)
   119  	}
   120  
   121  	resp, err := smx509.ParseCSRResponse(signPrivKey, result)
   122  	if err != nil {
   123  		t.Errorf("Unexpected error: %v", err)
   124  	}
   125  	if len(resp.SignCerts) != 2 {
   126  		t.Errorf("Unexpected number of sign certs: %d", len(resp.SignCerts))
   127  	}
   128  	if resp.EncryptPrivateKey == nil || !encPrivKey.Equal(resp.EncryptPrivateKey) {
   129  		t.Errorf("Unexpected encrypt private key")
   130  	}
   131  	if len(resp.EncryptCerts) != 2 {
   132  		t.Errorf("Unexpected number of encrypt certs: %d", len(resp.EncryptCerts))
   133  	}
   134  
   135  	// Marshal sign certificate only
   136  	result, err = smx509.MarshalCSRResponse([]*smx509.Certificate{pairs[0].Certificate, pairs[2].Certificate}, nil, nil)
   137  	// Check the result
   138  	if err != nil {
   139  		t.Errorf("Unexpected error: %v", err)
   140  	}
   141  	resp, err = smx509.ParseCSRResponse(signPrivKey, result)
   142  	if err != nil {
   143  		t.Errorf("Unexpected error: %v", err)
   144  	}
   145  	if len(resp.SignCerts) != 2 {
   146  		t.Errorf("Unexpected number of sign certs: %d", len(resp.SignCerts))
   147  	}
   148  	if resp.EncryptPrivateKey != nil {
   149  		t.Errorf("Unexpected encrypt private key")
   150  	}
   151  	if resp.EncryptCerts != nil {
   152  		t.Errorf("Unexpected encrypt certs")
   153  	}
   154  
   155  	_, err = smx509.MarshalCSRResponse(nil, nil, nil)
   156  	if err == nil || err.Error() != "smx509: no sign certificate" {
   157  		t.Errorf("Unexpected error: %v", err)
   158  	}
   159  
   160  	_, err = smx509.MarshalCSRResponse([]*smx509.Certificate{pairs[0].Certificate, pairs[2].Certificate}, encPrivKey, nil)
   161  	if err == nil || err.Error() != "smx509: missing encrypt certificate" {
   162  		t.Errorf("Unexpected error: %v", err)
   163  	}
   164  
   165  	_, err = smx509.MarshalCSRResponse([]*smx509.Certificate{pairs[0].Certificate, pairs[2].Certificate}, encPrivKey, []*smx509.Certificate{pairs[2].Certificate})
   166  	if err == nil || err.Error() != "smx509: encrypt key pair mismatch" {
   167  		t.Errorf("Unexpected error: %v", err)
   168  	}
   169  }