github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/sm2soft/sm2_test.go (about)

     1  // Copyright 2022 s1ren@github.com/hxx258456.
     2  
     3  /*
     4  sm2soft 是sm2的纯软实现,基于tjfoc国密算法库`tjfoc/gmsm`做了少量修改。
     5  对应版权声明: thrid_licenses/github.com/tjfoc/gmsm/版权声明
     6  */
     7  
     8  package sm2soft
     9  
    10  import (
    11  	"bytes"
    12  	"crypto/rand"
    13  	"fmt"
    14  	"io/ioutil"
    15  	"math/big"
    16  	"os"
    17  	"testing"
    18  
    19  	"github.com/hxx258456/ccgo/sm3"
    20  )
    21  
    22  func TestSm2(t *testing.T) {
    23  	// 生成sm2密钥对
    24  	priv, err := GenerateKey(rand.Reader)
    25  	fmt.Println("私钥: ", priv.D)
    26  	if err != nil {
    27  		t.Fatal(err)
    28  	}
    29  	// 验证生成的公钥是否在sm2的椭圆曲线上
    30  	fmt.Printf("公钥是否在sm2的椭圆曲线上: %v\n", priv.Curve.IsOnCurve(priv.X, priv.Y))
    31  	// 公钥
    32  	pub := &priv.PublicKey
    33  	fmt.Println("公钥: ", pub.X, pub.Y)
    34  
    35  	// 定义明文
    36  	msg := []byte("12345,上山打老虎")
    37  	fmt.Printf("明文: %s\n", msg)
    38  
    39  	// 公钥加密,C1C3C2模式,结果asn1编码
    40  	d0, err := pub.EncryptAsn1(msg, rand.Reader)
    41  	if err != nil {
    42  		fmt.Printf("Error: failed to encrypt %s: %v\n", msg, err)
    43  		return
    44  	}
    45  	fmt.Printf("公钥加密结果(C1C3C2) : %v\n", d0)
    46  
    47  	// 私钥解密,C1C3C2模式,先asn1解码
    48  	d1, err := priv.DecryptAsn1(d0)
    49  	if err != nil {
    50  		fmt.Printf("Error: failed to decrypt: %v\n", err)
    51  	}
    52  	fmt.Printf("私钥解密结果(C1C3C2) : %s\n", d1)
    53  
    54  	// 公钥加密 C1C2C3
    55  	d2, err := Encrypt(pub, msg, rand.Reader, C1C2C3)
    56  	if err != nil {
    57  		fmt.Printf("Error: failed to encrypt %s: %v\n", msg, err)
    58  		return
    59  	}
    60  	fmt.Printf("公钥加密结果(C1C2C3) : %v\n", d2)
    61  	// 私钥解密,C1C2C3
    62  	d3, err := Decrypt(priv, d2, C1C2C3)
    63  	if err != nil {
    64  		fmt.Printf("Error: failed to decrypt: %v\n", err)
    65  	}
    66  	fmt.Printf("私钥解密结果(C1C2C3) : %s\n", d3)
    67  
    68  	// 从文件读取消息
    69  	msg, _ = ioutil.ReadFile("testdata/msg")
    70  	hashFunc := sm3.New()
    71  	hashFunc.Write(msg)
    72  	digest := hashFunc.Sum(nil)
    73  
    74  	// 私钥签名
    75  	sign, err := priv.Sign(rand.Reader, digest, nil)
    76  	if err != nil {
    77  		t.Fatal(err)
    78  	}
    79  	// 签名写入文件
    80  	err = ioutil.WriteFile("testdata/signdata", sign, os.FileMode(0644))
    81  	if err != nil {
    82  		t.Fatal(err)
    83  	}
    84  	// 读取签名文件
    85  	signdata, _ := ioutil.ReadFile("testdata/signdata")
    86  	// 公钥验签
    87  	ok := pub.Verify(digest, signdata)
    88  	if ok != true {
    89  		fmt.Printf("公钥验签失败\n")
    90  	} else {
    91  		fmt.Printf("公钥验签成功\n")
    92  	}
    93  }
    94  
    95  func BenchmarkSM2(t *testing.B) {
    96  	t.ReportAllocs()
    97  	msg := []byte("test")
    98  	priv, err := GenerateKey(nil) // 生成密钥对
    99  	if err != nil {
   100  		t.Fatal(err)
   101  	}
   102  	t.ResetTimer()
   103  	for i := 0; i < t.N; i++ {
   104  		sign, err := priv.Sign(nil, msg, nil) // 签名
   105  		if err != nil {
   106  			t.Fatal(err)
   107  		}
   108  		priv.PublicKey.Verify(msg, sign) // 密钥验证
   109  	}
   110  }
   111  
   112  func BenchmarkVerify_SM2(b *testing.B) {
   113  	priv, err := GenerateKey(rand.Reader)
   114  	if err != nil {
   115  		b.Fatal(err)
   116  	}
   117  	hashed := []byte("testing")
   118  	r, s, err := Sm2Sign(priv, hashed, nil, rand.Reader)
   119  	if err != nil {
   120  		b.Fatal(err)
   121  	}
   122  
   123  	b.ReportAllocs()
   124  	b.ResetTimer()
   125  	for i := 0; i < b.N; i++ {
   126  		if !Sm2Verify(&priv.PublicKey, hashed, nil, r, s) {
   127  			b.Fatal("verify failed")
   128  		}
   129  	}
   130  }
   131  
   132  func TestKEB2(t *testing.T) {
   133  	ida := []byte{'1', '2', '3', '4', '5', '6', '7', '8',
   134  		'1', '2', '3', '4', '5', '6', '7', '8'}
   135  	idb := []byte{'1', '2', '3', '4', '5', '6', '7', '8',
   136  		'1', '2', '3', '4', '5', '6', '7', '8'}
   137  	daBuf := []byte{0x81, 0xEB, 0x26, 0xE9, 0x41, 0xBB, 0x5A, 0xF1,
   138  		0x6D, 0xF1, 0x16, 0x49, 0x5F, 0x90, 0x69, 0x52,
   139  		0x72, 0xAE, 0x2C, 0xD6, 0x3D, 0x6C, 0x4A, 0xE1,
   140  		0x67, 0x84, 0x18, 0xBE, 0x48, 0x23, 0x00, 0x29}
   141  	dbBuf := []byte{0x78, 0x51, 0x29, 0x91, 0x7D, 0x45, 0xA9, 0xEA,
   142  		0x54, 0x37, 0xA5, 0x93, 0x56, 0xB8, 0x23, 0x38,
   143  		0xEA, 0xAD, 0xDA, 0x6C, 0xEB, 0x19, 0x90, 0x88,
   144  		0xF1, 0x4A, 0xE1, 0x0D, 0xEF, 0xA2, 0x29, 0xB5}
   145  	raBuf := []byte{0xD4, 0xDE, 0x15, 0x47, 0x4D, 0xB7, 0x4D, 0x06,
   146  		0x49, 0x1C, 0x44, 0x0D, 0x30, 0x5E, 0x01, 0x24,
   147  		0x00, 0x99, 0x0F, 0x3E, 0x39, 0x0C, 0x7E, 0x87,
   148  		0x15, 0x3C, 0x12, 0xDB, 0x2E, 0xA6, 0x0B, 0xB3}
   149  
   150  	rbBuf := []byte{0x7E, 0x07, 0x12, 0x48, 0x14, 0xB3, 0x09, 0x48,
   151  		0x91, 0x25, 0xEA, 0xED, 0x10, 0x11, 0x13, 0x16,
   152  		0x4E, 0xBF, 0x0F, 0x34, 0x58, 0xC5, 0xBD, 0x88,
   153  		0x33, 0x5C, 0x1F, 0x9D, 0x59, 0x62, 0x43, 0xD6}
   154  
   155  	expk := []byte{0x6C, 0x89, 0x34, 0x73, 0x54, 0xDE, 0x24, 0x84,
   156  		0xC6, 0x0B, 0x4A, 0xB1, 0xFD, 0xE4, 0xC6, 0xE5}
   157  
   158  	curve := P256Sm2()
   159  	curve.ScalarBaseMult(daBuf)
   160  	da := new(PrivateKey)
   161  	da.PublicKey.Curve = curve
   162  	da.D = new(big.Int).SetBytes(daBuf)
   163  	da.PublicKey.X, da.PublicKey.Y = curve.ScalarBaseMult(daBuf)
   164  
   165  	db := new(PrivateKey)
   166  	db.PublicKey.Curve = curve
   167  	db.D = new(big.Int).SetBytes(dbBuf)
   168  	db.PublicKey.X, db.PublicKey.Y = curve.ScalarBaseMult(dbBuf)
   169  
   170  	ra := new(PrivateKey)
   171  	ra.PublicKey.Curve = curve
   172  	ra.D = new(big.Int).SetBytes(raBuf)
   173  	ra.PublicKey.X, ra.PublicKey.Y = curve.ScalarBaseMult(raBuf)
   174  
   175  	rb := new(PrivateKey)
   176  	rb.PublicKey.Curve = curve
   177  	rb.D = new(big.Int).SetBytes(rbBuf)
   178  	rb.PublicKey.X, rb.PublicKey.Y = curve.ScalarBaseMult(rbBuf)
   179  
   180  	k1, Sb, S2, err := KeyExchangeB(16, ida, idb, db, &da.PublicKey, rb, &ra.PublicKey)
   181  	if err != nil {
   182  		t.Error(err)
   183  	}
   184  	k2, S1, Sa, err := KeyExchangeA(16, ida, idb, da, &db.PublicKey, ra, &rb.PublicKey)
   185  	if err != nil {
   186  		t.Error(err)
   187  	}
   188  	if !bytes.Equal(k1, k2) {
   189  		t.Error("key exchange differ")
   190  	}
   191  	if !bytes.Equal(k1, expk) {
   192  		t.Errorf("expected %x, found %x", expk, k1)
   193  	}
   194  	if !bytes.Equal(S1, Sb) {
   195  		t.Error("hash verfication failed")
   196  	}
   197  	if !bytes.Equal(Sa, S2) {
   198  		t.Error("hash verfication failed")
   199  	}
   200  }