gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/sm4soft/sm4_test.go (about)

     1  // Copyright (c) 2022 zhaochun
     2  // core-gm is licensed under Mulan PSL v2.
     3  // You can use this software according to the terms and conditions of the Mulan PSL v2.
     4  // You may obtain a copy of Mulan PSL v2 at:
     5  //          http://license.coscl.org.cn/MulanPSL2
     6  // THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
     7  // See the Mulan PSL v2 for more details.
     8  
     9  /*
    10  sm4soft 是sm4的纯软实现,基于tjfoc国密算法库`tjfoc/gmsm`做了少量修改。
    11  对应版权声明: thrid_licenses/github.com/tjfoc/gmsm/版权声明
    12  */
    13  
    14  package sm4soft
    15  
    16  import (
    17  	"fmt"
    18  	"reflect"
    19  	"testing"
    20  )
    21  
    22  func TestSM4(t *testing.T) {
    23  	// 定义密钥,16字节
    24  	key := []byte("abcdef1234567890")
    25  	fmt.Printf("key字节数组 : %v\n", key)
    26  	fmt.Printf("key字符串 : %s\n", key)
    27  
    28  	// 将key写入key.pem
    29  	err := WriteKeyToPemFile("testdata/key.pem", key, nil)
    30  	if err != nil {
    31  		t.Fatalf("WriteKeyToPem error")
    32  	}
    33  	// 读取key.pem
    34  	key, err = ReadKeyFromPemFile("testdata/key.pem", nil)
    35  	fmt.Printf("读取到的key字节数组 : %v\n", key)
    36  	fmt.Printf("读取到的key字符串 : %s\n", key)
    37  	if err != nil {
    38  		t.Fatal(err)
    39  	}
    40  
    41  	// 定义数据
    42  	// data := []byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10}
    43  	data := []byte("天行健君子以自强不息")
    44  	fmt.Printf("data字节数组 : %v\n", data)
    45  	fmt.Printf("data十六进制 : %x\n", data)
    46  	fmt.Printf("data字符串 : %s\n", data)
    47  
    48  	// ECB模式加密
    49  	ecbMsg, err := Sm4Ecb(key, data, true)
    50  	if err != nil {
    51  		t.Errorf("sm4 enc error:%s", err)
    52  		return
    53  	}
    54  	fmt.Printf("ecbMsg 16进制 : %x\n", ecbMsg)
    55  	// ECB模式解密
    56  	ecbDec, err := Sm4Ecb(key, ecbMsg, false)
    57  	if err != nil {
    58  		t.Errorf("sm4 dec error:%s", err)
    59  		return
    60  	}
    61  	fmt.Printf("ecbDec : %s\n", ecbDec)
    62  	if !testCompare(data, ecbDec) {
    63  		t.Errorf("sm4 self enc and dec failed")
    64  	}
    65  
    66  	// 定义初始化向量,16字节
    67  	// iv := []byte("0000000000000000")
    68  	iv := []byte("1234def567890abc")
    69  	// err = SetIVDefault(iv)
    70  	fmt.Printf("err = %v\n", err)
    71  	fmt.Printf("iv字节数组 : %v\n", iv)
    72  	fmt.Printf("iv16进制 : %x\n", iv)
    73  	fmt.Printf("iv字符串 : %s\n", iv)
    74  
    75  	// CBC模式加密
    76  	cbcMsg, err := Sm4Cbc(key, iv, data, true)
    77  	if err != nil {
    78  		t.Errorf("sm4 enc error:%s", err)
    79  	}
    80  	fmt.Printf("cbcMsg 16进制 : %x\n", cbcMsg)
    81  	// CBC模式解密
    82  	cbcDec, err := Sm4Cbc(key, iv, cbcMsg, false)
    83  	if err != nil {
    84  		t.Errorf("sm4 dec error:%s", err)
    85  		return
    86  	}
    87  	fmt.Printf("cbcDec : %s\n", cbcDec)
    88  	if !testCompare(data, cbcDec) {
    89  		t.Errorf("sm4 self enc and dec failed")
    90  	}
    91  
    92  	// CFB模式加密
    93  	cfbMsg, err := Sm4CFB(key, iv, data, true)
    94  	if err != nil {
    95  		t.Errorf("sm4 enc error:%s", err)
    96  	}
    97  	fmt.Printf("cfbMsg 16进制 : %x\n", cfbMsg)
    98  	// CFB模式解密
    99  	cfbDec, err := Sm4CFB(key, iv, cfbMsg, false)
   100  	if err != nil {
   101  		t.Errorf("sm4 dec error:%s", err)
   102  		return
   103  	}
   104  	fmt.Printf("cfbDec : %s\n", cfbDec)
   105  
   106  	// OFB模式加密
   107  	ofbMsg, err := Sm4OFB(key, iv, data, true)
   108  	if err != nil {
   109  		t.Errorf("sm4 enc error:%s", err)
   110  	}
   111  	fmt.Printf("ofbMsg 16进制 : %x\n", ofbMsg)
   112  	// OFB模式解密
   113  	ofbDec, err := Sm4OFB(key, iv, ofbMsg, false)
   114  	if err != nil {
   115  		t.Errorf("sm4 dec error:%s", err)
   116  		return
   117  	}
   118  	fmt.Printf("ofbDec : %s\n", ofbDec)
   119  }
   120  
   121  func TestNewCipher(t *testing.T) {
   122  	key := []byte("1234567890abcdef")
   123  	// 直接用NewCipher只能对16字节的数据加密
   124  	data := []byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10}
   125  	// data := []byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32}
   126  	c, err := NewCipher(key)
   127  	if err != nil {
   128  		t.Fatal(err)
   129  	}
   130  	d0 := make([]byte, 16)
   131  	c.Encrypt(d0, data)
   132  	d1 := make([]byte, 16)
   133  	c.Decrypt(d1, d0)
   134  }
   135  
   136  func BenchmarkSM4(t *testing.B) {
   137  	t.ReportAllocs()
   138  	key := []byte("1234567890abcdef")
   139  	data := []byte{0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0xfe, 0xdc, 0xba, 0x98, 0x76, 0x54, 0x32, 0x10}
   140  	err := WriteKeyToPemFile("key.pem", key, nil)
   141  	if err != nil {
   142  		t.Fatalf("WriteKeyToPem error")
   143  	}
   144  	key, err = ReadKeyFromPemFile("key.pem", nil)
   145  	if err != nil {
   146  		t.Fatal(err)
   147  	}
   148  	c, err := NewCipher(key)
   149  	if err != nil {
   150  		t.Fatal(err)
   151  	}
   152  
   153  	for i := 0; i < t.N; i++ {
   154  		d0 := make([]byte, 16)
   155  		c.Encrypt(d0, data)
   156  		d1 := make([]byte, 16)
   157  		c.Decrypt(d1, d0)
   158  	}
   159  }
   160  
   161  func TestErrKeyLen(t *testing.T) {
   162  	fmt.Printf("\n--------------test key len------------------")
   163  	key := []byte("1234567890abcdefg")
   164  	_, err := NewCipher(key)
   165  	if err != nil {
   166  		fmt.Println("\nError key len !")
   167  	}
   168  	key = []byte("1234")
   169  	_, err = NewCipher(key)
   170  	if err != nil {
   171  		fmt.Println("Error key len !")
   172  	}
   173  	fmt.Println("------------------end----------------------")
   174  }
   175  
   176  func testCompare(key1, key2 []byte) bool {
   177  	if len(key1) != len(key2) {
   178  		return false
   179  	}
   180  	for i, v := range key1 {
   181  		if i == 1 {
   182  			fmt.Println("type of v", reflect.TypeOf(v))
   183  		}
   184  		a := key2[i]
   185  		if a != v {
   186  			return false
   187  		}
   188  	}
   189  	return true
   190  }