github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/sm4soft/padding/pkcs7_padding_io.go (about)

     1  // Copyright 2022 s1ren@github.com/hxx258456.
     2  
     3  /*
     4  sm4soft 是sm4的纯软实现,基于tjfoc国密算法库`tjfoc/gmsm`做了少量修改。
     5  对应版权声明: thrid_licenses/github.com/tjfoc/gmsm/版权声明
     6  */
     7  
     8  package padding
     9  
    10  import (
    11  	"bytes"
    12  	"errors"
    13  	"io"
    14  )
    15  
    16  // PKCS7PaddingReader 符合PKCS#7填充的输入流
    17  type PKCS7PaddingReader struct {
    18  	fIn       io.Reader
    19  	padding   io.Reader
    20  	blockSize int
    21  	readed    int64
    22  	eof       bool
    23  	eop       bool
    24  }
    25  
    26  // NewPKCS7PaddingReader 创建PKCS7填充Reader
    27  // in: 输入流
    28  // blockSize: 分块大小
    29  func NewPKCS7PaddingReader(in io.Reader, blockSize int) *PKCS7PaddingReader {
    30  	return &PKCS7PaddingReader{
    31  		fIn:       in,
    32  		padding:   nil,
    33  		eof:       false,
    34  		eop:       false,
    35  		blockSize: blockSize,
    36  	}
    37  }
    38  
    39  func (p *PKCS7PaddingReader) Read(buf []byte) (int, error) {
    40  	/*
    41  		- 读取文件
    42  			- 文件长度充足, 直接返还
    43  			- 不充足
    44  		- 读取到 n 字节, 剩余需要 m 字节
    45  		- 从 padding 中读取然后追加到 buff
    46  			- EOF  直接返回, 整个Reader end
    47  	*/
    48  	// 都读取完了
    49  	if p.eof && p.eop {
    50  		return 0, io.EOF
    51  	}
    52  
    53  	var n, off = 0, 0
    54  	var err error
    55  	if !p.eof {
    56  		// 读取文件
    57  		n, err = p.fIn.Read(buf)
    58  		if err != nil && !errors.Is(err, io.EOF) {
    59  			// 错误返回
    60  			return 0, err
    61  		}
    62  		p.readed += int64(n)
    63  		if errors.Is(err, io.EOF) {
    64  			// 标志文件结束
    65  			p.eof = true
    66  		}
    67  		if n == len(buf) {
    68  			// 长度足够直接返回
    69  			return n, nil
    70  		}
    71  		// 文件长度已经不足,根据已经已经读取的长度创建Padding
    72  		p.newPadding()
    73  		// 长度不足向Padding中索要
    74  		off = n
    75  	}
    76  
    77  	if !p.eop {
    78  		// 读取流
    79  		var n2 = 0
    80  		n2, err = p.padding.Read(buf[off:])
    81  		n += n2
    82  		if errors.Is(err, io.EOF) {
    83  			p.eop = true
    84  		}
    85  	}
    86  	return n, err
    87  }
    88  
    89  // 新建Padding
    90  func (p *PKCS7PaddingReader) newPadding() {
    91  	if p.padding != nil {
    92  		return
    93  	}
    94  	size := p.blockSize - int(p.readed%int64(p.blockSize))
    95  	padding := bytes.Repeat([]byte{byte(size)}, size)
    96  	p.padding = bytes.NewReader(padding)
    97  }
    98  
    99  // PKCS7PaddingWriter 符合PKCS#7去除的输入流,最后一个 分组根据会根据填充情况去除填充。
   100  type PKCS7PaddingWriter struct {
   101  	cache     *bytes.Buffer // 缓存区
   102  	swap      []byte        // 临时交换区
   103  	out       io.Writer     // 输出位置
   104  	blockSize int           // 分块大小
   105  }
   106  
   107  // NewPKCS7PaddingWriter PKCS#7 填充Writer 可以去除填充
   108  func NewPKCS7PaddingWriter(out io.Writer, blockSize int) *PKCS7PaddingWriter {
   109  	cache := bytes.NewBuffer(make([]byte, 0, 1024))
   110  	swap := make([]byte, 1024)
   111  	return &PKCS7PaddingWriter{out: out, blockSize: blockSize, cache: cache, swap: swap}
   112  }
   113  
   114  // Write 保留一个填充大小的数据,其余全部写入输出中
   115  func (p *PKCS7PaddingWriter) Write(buff []byte) (n int, err error) {
   116  	// 写入缓存
   117  	n, err = p.cache.Write(buff)
   118  	if err != nil {
   119  		return 0, err
   120  	}
   121  	if p.cache.Len() > p.blockSize {
   122  		// 把超过一个分组长度的部分读取出来,写入到实际的out中
   123  		size := p.cache.Len() - p.blockSize
   124  		_, _ = p.cache.Read(p.swap[:size])
   125  		_, err = p.out.Write(p.swap[:size])
   126  		if err != nil {
   127  			return 0, err
   128  		}
   129  	}
   130  	return n, err
   131  
   132  }
   133  
   134  // Final 去除填充写入最后一个分块
   135  func (p *PKCS7PaddingWriter) Final() error {
   136  	// 在Write 之后 cache 只会保留一个Block长度数据
   137  	b := p.cache.Bytes()
   138  	length := len(b)
   139  	if length != p.blockSize {
   140  		return errors.New("非法的PKCS7填充")
   141  	}
   142  	if length == 0 {
   143  		return nil
   144  	}
   145  	unpadding := int(b[length-1])
   146  	if unpadding > p.blockSize || unpadding == 0 {
   147  		return errors.New("非法的PKCS7填充")
   148  	}
   149  	_, err := p.out.Write(b[:(length - unpadding)])
   150  	return err
   151  }