github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/sm4/ctr_cipher_asm.go (about)

     1  //go:build amd64 || arm64
     2  // +build amd64 arm64
     3  
     4  package sm4
     5  
     6  import (
     7  	"crypto/cipher"
     8  
     9  	"github.com/hxx258456/ccgo/internal/subtle"
    10  	"github.com/hxx258456/ccgo/internal/xor"
    11  )
    12  
// Compile-time assertion that *sm4CipherAsm implements the ctrAble interface
// (it provides the NewCTR method defined in this file).
var _ ctrAble = (*sm4CipherAsm)(nil)
    15  
// ctr is the counter-mode stream state returned by (*sm4CipherAsm).NewCTR.
// It keeps a batch of consecutive counter blocks and a buffer of keystream
// bytes produced from them.
type ctr struct {
	// b is the SM4 cipher whose EncryptBlocks method turns the counter batch
	// into keystream.
	b *sm4CipherAsm
	// ctr holds b.batchBlocks counter blocks of BlockSize bytes each, laid
	// out back to back; the first one is initialised from the IV.
	ctr []byte
	// out buffers generated keystream; its capacity is fixed when the stream
	// is created and its length is the number of valid keystream bytes.
	out []byte
	// outUsed is the number of leading bytes of out that have already been
	// consumed by XORKeyStream.
	outUsed int
}
    22  
// streamBufferSize is the capacity, in bytes, requested for the keystream
// buffer of a ctr stream (NewCTR raises it to BlockSize if it were smaller).
const streamBufferSize = 512
    24  
    25  // NewCTR returns a Stream which encrypts/decrypts using the SM4 block
    26  // cipher in counter mode. The length of iv must be the same as BlockSize.
    27  func (sm4c *sm4CipherAsm) NewCTR(iv []byte) cipher.Stream {
    28  	if len(iv) != BlockSize {
    29  		panic("cipher.NewCTR: IV length must equal block size")
    30  	}
    31  	bufSize := streamBufferSize
    32  	if bufSize < BlockSize {
    33  		bufSize = BlockSize
    34  	}
    35  	s := &ctr{
    36  		b:       sm4c,
    37  		ctr:     make([]byte, sm4c.batchBlocks*len(iv)),
    38  		out:     make([]byte, 0, bufSize),
    39  		outUsed: 0,
    40  	}
    41  	copy(s.ctr, iv)
    42  	for i := 1; i < sm4c.batchBlocks; i++ {
    43  		s.genCtr(i * BlockSize)
    44  	}
    45  	return s
    46  
    47  }
    48  
    49  func (x *ctr) genCtr(start int) {
    50  	if start > 0 {
    51  		copy(x.ctr[start:], x.ctr[start-BlockSize:start])
    52  	} else {
    53  		copy(x.ctr[start:], x.ctr[len(x.ctr)-BlockSize:])
    54  	}
    55  	// Increment counter
    56  	end := start + BlockSize
    57  	for i := end - 1; i >= 0; i-- {
    58  		x.ctr[i]++
    59  		if x.ctr[i] != 0 {
    60  			break
    61  		}
    62  	}
    63  }
    64  
    65  func (x *ctr) refill() {
    66  	remain := len(x.out) - x.outUsed
    67  	copy(x.out, x.out[x.outUsed:])
    68  	x.out = x.out[:cap(x.out)]
    69  	for remain <= len(x.out)-x.b.blocksSize {
    70  		x.b.EncryptBlocks(x.out[remain:], x.ctr)
    71  		remain += x.b.blocksSize
    72  
    73  		// Increment counter
    74  		for i := 0; i < x.b.batchBlocks; i++ {
    75  			x.genCtr(i * BlockSize)
    76  		}
    77  	}
    78  	x.out = x.out[:remain]
    79  	x.outUsed = 0
    80  }
    81  
    82  func (x *ctr) XORKeyStream(dst, src []byte) {
    83  	if len(dst) < len(src) {
    84  		panic("cipher: output smaller than input")
    85  	}
    86  	if subtle.InexactOverlap(dst[:len(src)], src) {
    87  		panic("cipher: invalid buffer overlap")
    88  	}
    89  	for len(src) > 0 {
    90  		if x.outUsed >= len(x.out)-BlockSize {
    91  			x.refill()
    92  		}
    93  		n := xor.XorBytes(dst, src, x.out[x.outUsed:])
    94  		dst = dst[n:]
    95  		src = src[n:]
    96  		x.outUsed += n
    97  	}
    98  }