github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/sm4soft/padding/pkcs7_padding_io.go (about) 1 // Copyright 2022 s1ren@github.com/hxx258456. 2 3 /* 4 sm4soft 是sm4的纯软实现,基于tjfoc国密算法库`tjfoc/gmsm`做了少量修改。 5 对应版权声明: thrid_licenses/github.com/tjfoc/gmsm/版权声明 6 */ 7 8 package padding 9 10 import ( 11 "bytes" 12 "errors" 13 "io" 14 ) 15 16 // PKCS7PaddingReader 符合PKCS#7填充的输入流 17 type PKCS7PaddingReader struct { 18 fIn io.Reader 19 padding io.Reader 20 blockSize int 21 readed int64 22 eof bool 23 eop bool 24 } 25 26 // NewPKCS7PaddingReader 创建PKCS7填充Reader 27 // in: 输入流 28 // blockSize: 分块大小 29 func NewPKCS7PaddingReader(in io.Reader, blockSize int) *PKCS7PaddingReader { 30 return &PKCS7PaddingReader{ 31 fIn: in, 32 padding: nil, 33 eof: false, 34 eop: false, 35 blockSize: blockSize, 36 } 37 } 38 39 func (p *PKCS7PaddingReader) Read(buf []byte) (int, error) { 40 /* 41 - 读取文件 42 - 文件长度充足, 直接返还 43 - 不充足 44 - 读取到 n 字节, 剩余需要 m 字节 45 - 从 padding 中读取然后追加到 buff 46 - EOF 直接返回, 整个Reader end 47 */ 48 // 都读取完了 49 if p.eof && p.eop { 50 return 0, io.EOF 51 } 52 53 var n, off = 0, 0 54 var err error 55 if !p.eof { 56 // 读取文件 57 n, err = p.fIn.Read(buf) 58 if err != nil && !errors.Is(err, io.EOF) { 59 // 错误返回 60 return 0, err 61 } 62 p.readed += int64(n) 63 if errors.Is(err, io.EOF) { 64 // 标志文件结束 65 p.eof = true 66 } 67 if n == len(buf) { 68 // 长度足够直接返回 69 return n, nil 70 } 71 // 文件长度已经不足,根据已经已经读取的长度创建Padding 72 p.newPadding() 73 // 长度不足向Padding中索要 74 off = n 75 } 76 77 if !p.eop { 78 // 读取流 79 var n2 = 0 80 n2, err = p.padding.Read(buf[off:]) 81 n += n2 82 if errors.Is(err, io.EOF) { 83 p.eop = true 84 } 85 } 86 return n, err 87 } 88 89 // 新建Padding 90 func (p *PKCS7PaddingReader) newPadding() { 91 if p.padding != nil { 92 return 93 } 94 size := p.blockSize - int(p.readed%int64(p.blockSize)) 95 padding := bytes.Repeat([]byte{byte(size)}, size) 96 p.padding = bytes.NewReader(padding) 97 } 98 99 // PKCS7PaddingWriter 符合PKCS#7去除的输入流,最后一个 分组根据会根据填充情况去除填充。 100 type PKCS7PaddingWriter struct { 101 cache *bytes.Buffer // 缓存区 102 swap []byte // 临时交换区 103 out io.Writer // 输出位置 104 blockSize int // 分块大小 105 } 106 107 // NewPKCS7PaddingWriter PKCS#7 填充Writer 可以去除填充 108 func NewPKCS7PaddingWriter(out io.Writer, blockSize int) *PKCS7PaddingWriter { 109 cache := bytes.NewBuffer(make([]byte, 0, 1024)) 110 swap := make([]byte, 1024) 111 return &PKCS7PaddingWriter{out: out, blockSize: blockSize, cache: cache, swap: swap} 112 } 113 114 // Write 保留一个填充大小的数据,其余全部写入输出中 115 func (p *PKCS7PaddingWriter) Write(buff []byte) (n int, err error) { 116 // 写入缓存 117 n, err = p.cache.Write(buff) 118 if err != nil { 119 return 0, err 120 } 121 if p.cache.Len() > p.blockSize { 122 // 把超过一个分组长度的部分读取出来,写入到实际的out中 123 size := p.cache.Len() - p.blockSize 124 _, _ = p.cache.Read(p.swap[:size]) 125 _, err = p.out.Write(p.swap[:size]) 126 if err != nil { 127 return 0, err 128 } 129 } 130 return n, err 131 132 } 133 134 // Final 去除填充写入最后一个分块 135 func (p *PKCS7PaddingWriter) Final() error { 136 // 在Write 之后 cache 只会保留一个Block长度数据 137 b := p.cache.Bytes() 138 length := len(b) 139 if length != p.blockSize { 140 return errors.New("非法的PKCS7填充") 141 } 142 if length == 0 { 143 return nil 144 } 145 unpadding := int(b[length-1]) 146 if unpadding > p.blockSize || unpadding == 0 { 147 return errors.New("非法的PKCS7填充") 148 } 149 _, err := p.out.Write(b[:(length - unpadding)]) 150 return err 151 }