github.com/emmansun/gmsm@v0.29.1/smx509/csr_rsp.go (about)

     1  // Marshal & Parse CSRResponse which is defined in GM/T 0092-2020
     2  // Specification of certificate request syntax based on SM2 cryptographic algorithm.
     3  
     4  package smx509
     5  
     6  import (
     7  	"bytes"
     8  	"crypto/ecdsa"
     9  	"crypto/rand"
    10  	"encoding/asn1"
    11  	"errors"
    12  
    13  	"github.com/emmansun/gmsm/sm2"
    14  )
    15  
// CSRResponse represents the response of a certificate signing request, as
// defined in GM/T 0092-2020 (Specification of certificate request syntax based
// on SM2 cryptographic algorithm).
type CSRResponse struct {
	// SignCerts holds the signing certificates. ParseCSRResponse requires at
	// least one and checks the first against the caller's sign private key.
	SignCerts         []*Certificate
	// EncryptPrivateKey is the encryption private key recovered from the
	// enveloped key in the response, or nil if the response carries none.
	EncryptPrivateKey *sm2.PrivateKey
	// EncryptCerts holds the encryption certificates, or nil if absent. When
	// EncryptPrivateKey is set, the first certificate holds its public key.
	EncryptCerts      []*Certificate
}
    22  
// tbsCSRResponse mirrors the ASN.1 structure of a CSRResponse defined in
// GM/T 0092-2020 (Specification of certificate request syntax based on SM2
// cryptographic algorithm), Section 8 and Appendix A:
//
//	CSRResponse ::= SEQUENCE {
//	    signCertificate     CertificateSet,
//	    encryptedPrivateKey [0] SM2EnvelopedKey OPTIONAL,
//	    encryptCertificate  [1] CertificateSet OPTIONAL
//	}
//
// encoding/asn1 maps the struct fields to the SEQUENCE elements in order, so
// the field order and asn1 tags below are the wire format and must not change:
//   - SignCerts keeps each certificate of the set as a RawValue; its FullBytes
//     is what gets parsed as a certificate, and it is filled from
//     Certificate.Raw when marshalling.
//   - EncryptedPrivateKey matches only a context-specific [0] element; its
//     content bytes are the enveloped key passed to sm2.ParseEnvelopedPrivateKey.
//   - EncryptCerts matches a context-specific [1] element; rawCertificates
//     keeps the whole encoded element so it can be decoded later.
type tbsCSRResponse struct {
	SignCerts           []asn1.RawValue `asn1:"set"` // SignCerts ::= SET OF Certificate
	EncryptedPrivateKey asn1.RawValue   `asn1:"optional,tag:0"`
	EncryptCerts        rawCertificates `asn1:"optional,tag:1"`
}
    36  
// rawCertificates captures one complete ASN.1 element (tag, length and
// content) whose content is a concatenation of DER-encoded certificates.
// Raw is an asn1.RawContent first field, so asn1.Unmarshal stores the
// element's full encoding in it. See Parse for decoding and
// marshalCertificateBytes for building one.
type rawCertificates struct {
	Raw asn1.RawContent
}
    40  
    41  // ParseCSRResponse parses a CSRResponse from DER format.
    42  // We do NOT verify the cert chain here, it's the caller's responsibility.
    43  func ParseCSRResponse(signPrivateKey *sm2.PrivateKey, der []byte) (CSRResponse, error) {
    44  	result := CSRResponse{}
    45  	resp := &tbsCSRResponse{}
    46  	rest, err := asn1.Unmarshal(der, resp)
    47  	if err != nil || len(rest) > 0 {
    48  		return result, errors.New("smx509: invalid CSRResponse asn1 data")
    49  	}
    50  
    51  	signCerts := make([]*Certificate, len(resp.SignCerts))
    52  	for i, rawCert := range resp.SignCerts {
    53  		signCert, err := ParseCertificate(rawCert.FullBytes)
    54  		if err != nil {
    55  			return result, err
    56  		}
    57  		signCerts[i] = signCert
    58  	}
    59  
    60  	// check sign public key against the private key
    61  	if !signPrivateKey.PublicKey.Equal(signCerts[0].PublicKey) {
    62  		return result, errors.New("smx509: sign cert public key mismatch")
    63  	}
    64  
    65  	var encPrivateKey *sm2.PrivateKey
    66  	if len(resp.EncryptedPrivateKey.Bytes) > 0 {
    67  		encPrivateKey, err = sm2.ParseEnvelopedPrivateKey(signPrivateKey, resp.EncryptedPrivateKey.Bytes)
    68  		if err != nil {
    69  			return result, err
    70  		}
    71  	}
    72  	var encryptCerts []*Certificate
    73  	if len(resp.EncryptCerts.Raw) > 0 {
    74  		encryptCerts, err = resp.EncryptCerts.Parse()
    75  		if err != nil {
    76  			return result, err
    77  		}
    78  	}
    79  
    80  	// check the public key of the encrypt certificate
    81  	if encPrivateKey != nil && len(encryptCerts) == 0 {
    82  		return result, errors.New("smx509: missing encrypt certificate")
    83  	}
    84  
    85  	if encPrivateKey != nil && !encPrivateKey.PublicKey.Equal(encryptCerts[0].PublicKey) {
    86  		return result, errors.New("smx509: encrypt key pair mismatch")
    87  	}
    88  
    89  	result.SignCerts = signCerts
    90  	result.EncryptPrivateKey = encPrivateKey
    91  	result.EncryptCerts = encryptCerts
    92  	return result, nil
    93  }
    94  
    95  // MarshalCSRResponse marshals a CSRResponse to DER format.
    96  func MarshalCSRResponse(signCerts []*Certificate, encryptPrivateKey *sm2.PrivateKey, encryptCerts []*Certificate) ([]byte, error) {
    97  	if len(signCerts) == 0 {
    98  		return nil, errors.New("smx509: no sign certificate")
    99  	}
   100  	signPubKey, ok := signCerts[0].PublicKey.(*ecdsa.PublicKey)
   101  	if !ok || !sm2.IsSM2PublicKey(signPubKey) {
   102  		return nil, errors.New("smx509: invalid sign public key")
   103  	}
   104  
   105  	// check the public key of the encrypt certificate
   106  	if encryptPrivateKey != nil && len(encryptCerts) == 0 {
   107  		return nil, errors.New("smx509: missing encrypt certificate")
   108  	}
   109  	if encryptPrivateKey != nil && !encryptPrivateKey.PublicKey.Equal(encryptCerts[0].PublicKey) {
   110  		return nil, errors.New("smx509: encrypt key pair mismatch")
   111  	}
   112  
   113  	resp := tbsCSRResponse{}
   114  	resp.SignCerts = make([]asn1.RawValue, 0, len(signCerts))
   115  	for _, cert := range signCerts {
   116  		resp.SignCerts = append(resp.SignCerts, asn1.RawValue{FullBytes: cert.Raw})
   117  	}
   118  	if encryptPrivateKey != nil && len(encryptCerts) > 0 {
   119  		privateKeyBytes, err := sm2.MarshalEnvelopedPrivateKey(rand.Reader, signPubKey, encryptPrivateKey)
   120  		if err != nil {
   121  			return nil, err
   122  		}
   123  		resp.EncryptedPrivateKey = asn1.RawValue{Class: asn1.ClassContextSpecific, Tag: 0, IsCompound: true, Bytes: privateKeyBytes}
   124  		resp.EncryptCerts = marshalCertificates(encryptCerts)
   125  	}
   126  	return asn1.Marshal(resp)
   127  }
   128  
   129  // concats and wraps the certificates in the RawValue structure
   130  func marshalCertificates(certs []*Certificate) rawCertificates {
   131  	var buf bytes.Buffer
   132  	for _, cert := range certs {
   133  		buf.Write(cert.Raw)
   134  	}
   135  	rawCerts, _ := marshalCertificateBytes(buf.Bytes())
   136  	return rawCerts
   137  }
   138  
   139  // Even though, the tag & length are stripped out during marshalling the
   140  // RawContent, we have to encode it into the RawContent. If its missing,
   141  // then `asn1.Marshal()` will strip out the certificate wrapper instead.
   142  func marshalCertificateBytes(certs []byte) (rawCertificates, error) {
   143  	var val = asn1.RawValue{Bytes: certs, Class: asn1.ClassContextSpecific, Tag: 0, IsCompound: true}
   144  	b, err := asn1.Marshal(val)
   145  	if err != nil {
   146  		return rawCertificates{}, err
   147  	}
   148  	return rawCertificates{Raw: b}, nil
   149  }
   150  
   151  func (raw rawCertificates) Parse() ([]*Certificate, error) {
   152  	if len(raw.Raw) == 0 {
   153  		return nil, nil
   154  	}
   155  
   156  	var val asn1.RawValue
   157  	if _, err := asn1.Unmarshal(raw.Raw, &val); err != nil {
   158  		return nil, err
   159  	}
   160  
   161  	return ParseCertificates(val.Bytes)
   162  }