github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/sm4/ctr_cipher_asm.go (about) 1 //go:build amd64 || arm64 2 // +build amd64 arm64 3 4 package sm4 5 6 import ( 7 "crypto/cipher" 8 9 "github.com/hxx258456/ccgo/internal/subtle" 10 "github.com/hxx258456/ccgo/internal/xor" 11 ) 12 13 // Assert that sm4CipherAsm implements the ctrAble interface. 14 var _ ctrAble = (*sm4CipherAsm)(nil) 15 16 type ctr struct { 17 b *sm4CipherAsm 18 ctr []byte 19 out []byte 20 outUsed int 21 } 22 23 const streamBufferSize = 512 24 25 // NewCTR returns a Stream which encrypts/decrypts using the SM4 block 26 // cipher in counter mode. The length of iv must be the same as BlockSize. 27 func (sm4c *sm4CipherAsm) NewCTR(iv []byte) cipher.Stream { 28 if len(iv) != BlockSize { 29 panic("cipher.NewCTR: IV length must equal block size") 30 } 31 bufSize := streamBufferSize 32 if bufSize < BlockSize { 33 bufSize = BlockSize 34 } 35 s := &ctr{ 36 b: sm4c, 37 ctr: make([]byte, sm4c.batchBlocks*len(iv)), 38 out: make([]byte, 0, bufSize), 39 outUsed: 0, 40 } 41 copy(s.ctr, iv) 42 for i := 1; i < sm4c.batchBlocks; i++ { 43 s.genCtr(i * BlockSize) 44 } 45 return s 46 47 } 48 49 func (x *ctr) genCtr(start int) { 50 if start > 0 { 51 copy(x.ctr[start:], x.ctr[start-BlockSize:start]) 52 } else { 53 copy(x.ctr[start:], x.ctr[len(x.ctr)-BlockSize:]) 54 } 55 // Increment counter 56 end := start + BlockSize 57 for i := end - 1; i >= 0; i-- { 58 x.ctr[i]++ 59 if x.ctr[i] != 0 { 60 break 61 } 62 } 63 } 64 65 func (x *ctr) refill() { 66 remain := len(x.out) - x.outUsed 67 copy(x.out, x.out[x.outUsed:]) 68 x.out = x.out[:cap(x.out)] 69 for remain <= len(x.out)-x.b.blocksSize { 70 x.b.EncryptBlocks(x.out[remain:], x.ctr) 71 remain += x.b.blocksSize 72 73 // Increment counter 74 for i := 0; i < x.b.batchBlocks; i++ { 75 x.genCtr(i * BlockSize) 76 } 77 } 78 x.out = x.out[:remain] 79 x.outUsed = 0 80 } 81 82 func (x *ctr) XORKeyStream(dst, src []byte) { 83 if len(dst) < len(src) { 84 panic("cipher: output smaller than input") 85 } 86 if subtle.InexactOverlap(dst[:len(src)], src) { 87 panic("cipher: invalid buffer overlap") 88 } 89 for len(src) > 0 { 90 if x.outUsed >= len(x.out)-BlockSize { 91 x.refill() 92 } 93 n := xor.XorBytes(dst, src, x.out[x.outUsed:]) 94 dst = dst[n:] 95 src = src[n:] 96 x.outUsed += n 97 } 98 }