github.com/emmansun/gmsm@v0.29.1/pkcs/cipher.go (about)

     1  // Package pkcs implements ciphers used by PKCS#7 & PKCS#8.
     2  package pkcs
     3  
     4  import (
     5  	"crypto/cipher"
     6  	"crypto/x509/pkix"
     7  	"encoding/asn1"
     8  	"errors"
     9  	"fmt"
    10  	"io"
    11  
    12  	smcipher "github.com/emmansun/gmsm/cipher"
    13  	"github.com/emmansun/gmsm/padding"
    14  )
    15  
// Cipher represents a symmetric cipher used to encrypt key material
// in PBES2.
type Cipher interface {
	// KeySize returns the key size of the cipher, in bytes.
	KeySize() int
	// Encrypt encrypts plaintext with key, drawing any required randomness
	// from rand. The returned AlgorithmIdentifier is the algorithm identifier
	// used for encryption, including its parameters.
	Encrypt(rand io.Reader, key, plaintext []byte) (*pkix.AlgorithmIdentifier, []byte, error)
	// Decrypt decrypts ciphertext with key. parameters holds the parameters
	// taken from the DER-encoded AlgorithmIdentifier.
	Decrypt(key []byte, parameters *asn1.RawValue, ciphertext []byte) ([]byte, error)
	// OID returns the object identifier of the cipher.
	OID() asn1.ObjectIdentifier
}
    30  
// ciphers is the registry of cipher constructors, keyed by the string form
// of the cipher's OID (see RegisterCipher and GetCipher).
// NOTE(review): no synchronization is visible here; presumably registration
// happens during package initialization only — confirm before registering
// concurrently.
var ciphers = make(map[string]func() Cipher)
    32  
    33  // RegisterCipher registers a function that returns a new instance of the given
    34  // cipher. This allows the library to support client-provided ciphers.
    35  func RegisterCipher(oid asn1.ObjectIdentifier, cipher func() Cipher) {
    36  	ciphers[oid.String()] = cipher
    37  }
    38  
    39  // GetCipher returns an instance of the cipher specified by the given algorithm identifier.
    40  func GetCipher(alg pkix.AlgorithmIdentifier) (Cipher, error) {
    41  	oid := alg.Algorithm.String()
    42  	if oid == oidSM4.String() {
    43  		if len(alg.Parameters.Bytes) != 0 || len(alg.Parameters.FullBytes) != 0 {
    44  			return SM4CBC, nil
    45  		} else {
    46  			return SM4ECB, nil
    47  		}
    48  	}
    49  	newCipher, ok := ciphers[oid]
    50  	if !ok {
    51  		return nil, fmt.Errorf("pbes: unsupported cipher (OID: %s)", oid)
    52  	}
    53  	return newCipher(), nil
    54  }
    55  
// baseBlockCipher holds the state shared by the block-cipher based Cipher
// implementations in this file and provides their KeySize and OID methods.
type baseBlockCipher struct {
	// oid is the object identifier reported by OID and written into the
	// AlgorithmIdentifier produced by Encrypt.
	oid      asn1.ObjectIdentifier
	// keySize is the value reported by KeySize.
	keySize  int
	// newBlock builds the underlying block cipher for a given key.
	newBlock func(key []byte) (cipher.Block, error)
}
    61  
// KeySize returns the configured key size of the cipher.
func (b *baseBlockCipher) KeySize() int {
	return b.keySize
}
    65  
// OID returns the object identifier configured for the cipher.
func (b *baseBlockCipher) OID() asn1.ObjectIdentifier {
	return b.oid
}
    69  
// ecbBlockCipher is a block cipher operated in ECB mode with PKCS#7 padding.
// Its AlgorithmIdentifier carries no parameters.
type ecbBlockCipher struct {
	baseBlockCipher
}
    73  
    74  func (ecb *ecbBlockCipher) Encrypt(rand io.Reader, key, plaintext []byte) (*pkix.AlgorithmIdentifier, []byte, error) {
    75  	block, err := ecb.newBlock(key)
    76  	if err != nil {
    77  		return nil, nil, err
    78  	}
    79  	mode := smcipher.NewECBEncrypter(block)
    80  	pkcs7 := padding.NewPKCS7Padding(uint(block.BlockSize()))
    81  	plaintext = pkcs7.Pad(plaintext)
    82  	ciphertext := make([]byte, len(plaintext))
    83  	mode.CryptBlocks(ciphertext, plaintext)
    84  
    85  	encryptionScheme := pkix.AlgorithmIdentifier{
    86  		Algorithm: ecb.oid,
    87  	}
    88  
    89  	return &encryptionScheme, ciphertext, nil
    90  }
    91  
    92  func (ecb *ecbBlockCipher) Decrypt(key []byte, parameters *asn1.RawValue, ciphertext []byte) ([]byte, error) {
    93  	block, err := ecb.newBlock(key)
    94  	if err != nil {
    95  		return nil, err
    96  	}
    97  	mode := smcipher.NewECBDecrypter(block)
    98  	plaintext := make([]byte, len(ciphertext))
    99  	mode.CryptBlocks(plaintext, ciphertext)
   100  	pkcs7 := padding.NewPKCS7Padding(uint(block.BlockSize()))
   101  	unpadded, err := pkcs7.Unpad(plaintext)
   102  	if err != nil { // In order to be compatible with some implementations without padding
   103  		return plaintext, nil
   104  	}
   105  	return unpadded, nil
   106  }
   107  
// cbcBlockCipher is a block cipher operated in CBC mode with PKCS#7 padding.
// The IV is carried in the AlgorithmIdentifier parameters.
type cbcBlockCipher struct {
	baseBlockCipher
	// ivSize is the number of random IV bytes generated by Encrypt.
	ivSize int
}
   112  
   113  func (c *cbcBlockCipher) Encrypt(rand io.Reader, key, plaintext []byte) (*pkix.AlgorithmIdentifier, []byte, error) {
   114  	block, err := c.newBlock(key)
   115  	if err != nil {
   116  		return nil, nil, err
   117  	}
   118  
   119  	iv := make([]byte, c.ivSize)
   120  	if _, err := rand.Read(iv); err != nil {
   121  		return nil, nil, err
   122  	}
   123  
   124  	ciphertext, err := cbcEncrypt(block, iv, plaintext)
   125  	if err != nil {
   126  		return nil, nil, err
   127  	}
   128  
   129  	marshalledIV, err := asn1.Marshal(iv)
   130  	if err != nil {
   131  		return nil, nil, err
   132  	}
   133  
   134  	encryptionScheme := pkix.AlgorithmIdentifier{
   135  		Algorithm:  c.oid,
   136  		Parameters: asn1.RawValue{FullBytes: marshalledIV},
   137  	}
   138  
   139  	return &encryptionScheme, ciphertext, nil
   140  }
   141  
   142  func (c *cbcBlockCipher) Decrypt(key []byte, parameters *asn1.RawValue, ciphertext []byte) ([]byte, error) {
   143  	block, err := c.newBlock(key)
   144  	if err != nil {
   145  		return nil, err
   146  	}
   147  
   148  	var iv []byte
   149  	if _, err := asn1.Unmarshal(parameters.FullBytes, &iv); err != nil {
   150  		return nil, errors.New("pbes: invalid cipher parameters")
   151  	}
   152  
   153  	return cbcDecrypt(block, iv, ciphertext)
   154  }
   155  
   156  func cbcEncrypt(block cipher.Block, iv, plaintext []byte) ([]byte, error) {
   157  	mode := cipher.NewCBCEncrypter(block, iv)
   158  	pkcs7 := padding.NewPKCS7Padding(uint(block.BlockSize()))
   159  	plainText := pkcs7.Pad(plaintext)
   160  	ciphertext := make([]byte, len(plainText))
   161  	mode.CryptBlocks(ciphertext, plainText)
   162  	return ciphertext, nil
   163  }
   164  
   165  func cbcDecrypt(block cipher.Block, iv, ciphertext []byte) ([]byte, error) {
   166  	mode := cipher.NewCBCDecrypter(block, iv)
   167  	pkcs7 := padding.NewPKCS7Padding(uint(block.BlockSize()))
   168  	plaintext := make([]byte, len(ciphertext))
   169  	mode.CryptBlocks(plaintext, ciphertext)
   170  	return pkcs7.Unpad(plaintext)
   171  }
   172  
// gcmBlockCipher is a block cipher operated in GCM mode. The nonce and tag
// length are carried in the AlgorithmIdentifier parameters (see gcmParameters).
type gcmBlockCipher struct {
	baseBlockCipher
	// nonceSize is the number of random nonce bytes generated by Encrypt.
	nonceSize int
}
   177  
// gcmParameters is the ASN.1 structure that the GCM cipher marshals into, and
// parses from, the AlgorithmIdentifier parameters. It follows RFC 5084:
// https://datatracker.ietf.org/doc/rfc5084/
//
//	GCMParameters ::= SEQUENCE {
//		aes-nonce        OCTET STRING, -- recommended size is 12 octets
//		aes-ICVlen       AES-GCM-ICVlen DEFAULT 12 }
type gcmParameters struct {
	// Nonce is the GCM nonce (aes-nonce).
	Nonce  []byte
	// ICVLen is the authentication tag length in bytes (aes-ICVlen); it is
	// optional in the encoding and defaults to 12.
	ICVLen int `asn1:"default:12,optional"`
}
   187  
   188  func (c *gcmBlockCipher) Encrypt(rand io.Reader, key, plaintext []byte) (*pkix.AlgorithmIdentifier, []byte, error) {
   189  	block, err := c.newBlock(key)
   190  	if err != nil {
   191  		return nil, nil, err
   192  	}
   193  
   194  	nonce := make([]byte, c.nonceSize)
   195  	if _, err := rand.Read(nonce); err != nil {
   196  		return nil, nil, err
   197  	}
   198  
   199  	aead, err := cipher.NewGCMWithNonceSize(block, c.nonceSize)
   200  	if err != nil {
   201  		return nil, nil, err
   202  	}
   203  	ciphertext := aead.Seal(nil, nonce, plaintext, nil)
   204  	paramSeq := gcmParameters{
   205  		Nonce:  nonce,
   206  		ICVLen: aead.Overhead(),
   207  	}
   208  	paramBytes, err := asn1.Marshal(paramSeq)
   209  	if err != nil {
   210  		return nil, nil, err
   211  	}
   212  	encryptionAlgorithm := pkix.AlgorithmIdentifier{
   213  		Algorithm: c.oid,
   214  		Parameters: asn1.RawValue{
   215  			FullBytes: paramBytes,
   216  		},
   217  	}
   218  	return &encryptionAlgorithm, ciphertext, nil
   219  }
   220  
   221  func (c *gcmBlockCipher) Decrypt(key []byte, parameters *asn1.RawValue, ciphertext []byte) ([]byte, error) {
   222  	block, err := c.newBlock(key)
   223  	if err != nil {
   224  		return nil, err
   225  	}
   226  	params := gcmParameters{}
   227  	_, err = asn1.Unmarshal(parameters.FullBytes, &params)
   228  	if err != nil {
   229  		return nil, err
   230  	}
   231  	aead, err := cipher.NewGCMWithNonceSize(block, len(params.Nonce))
   232  	if err != nil {
   233  		return nil, err
   234  	}
   235  	if params.ICVLen != aead.Overhead() {
   236  		return nil, errors.New("pbes: we do not support non-standard tag size")
   237  	}
   238  
   239  	return aead.Open(nil, params.Nonce, ciphertext, nil)
   240  }