github.com/Hyperledger-TWGC/tjfoc-gm@v1.4.0/sm2/sm2_test.go (about)

     1  /*
     2  Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved.
     3  Licensed under the Apache License, Version 2.0 (the "License");
     4  you may not use this file except in compliance with the License.
     5  You may obtain a copy of the License at
     6  
     7                   http://www.apache.org/licenses/LICENSE-2.0
     8  
     9  Unless required by applicable law or agreed to in writing, software
    10  distributed under the License is distributed on an "AS IS" BASIS,
    11  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  See the License for the specific language governing permissions and
    13  limitations under the License.
    14  */
    15  
    16  package sm2
    17  
    18  import (
    19  	"bytes"
    20  	"crypto/ecdsa"
    21  	"crypto/elliptic"
    22  	"crypto/rand"
    23  	"fmt"
    24  	"io/ioutil"
    25  	"math/big"
    26  	"os"
    27  	"testing"
    28  )
    29  
    30  func TestSm2(t *testing.T) {
    31  	priv, err := GenerateKey(rand.Reader) // 生成密钥对
    32  	fmt.Println(priv)
    33  	if err != nil {
    34  		t.Fatal(err)
    35  	}
    36  	fmt.Printf("%v\n", priv.Curve.IsOnCurve(priv.X, priv.Y)) // 验证是否为sm2的曲线
    37  	pub := &priv.PublicKey
    38  	msg := []byte("123456")
    39  	d0, err := pub.EncryptAsn1(msg, rand.Reader)
    40  	if err != nil {
    41  		fmt.Printf("Error: failed to encrypt %s: %v\n", msg, err)
    42  		return
    43  	}
    44  	// fmt.Printf("Cipher text = %v\n", d0)
    45  	d1, err := priv.DecryptAsn1(d0)
    46  	if err != nil {
    47  		fmt.Printf("Error: failed to decrypt: %v\n", err)
    48  	}
    49  	fmt.Printf("clear text = %s\n", d1)
    50  
    51  	msg, _ = ioutil.ReadFile("ifile")             // 从文件读取数据
    52  	sign, err := priv.Sign(rand.Reader, msg, nil) // 签名
    53  	if err != nil {
    54  		t.Fatal(err)
    55  	}
    56  
    57  	err = ioutil.WriteFile("TestResult", sign, os.FileMode(0644))
    58  	if err != nil {
    59  		t.Fatal(err)
    60  	}
    61  	signdata, _ := ioutil.ReadFile("TestResult")
    62  	ok := priv.Verify(msg, signdata) // 密钥验证
    63  	if ok != true {
    64  		fmt.Printf("Verify error\n")
    65  	} else {
    66  		fmt.Printf("Verify ok\n")
    67  	}
    68  	pubKey := priv.PublicKey
    69  	ok = pubKey.Verify(msg, signdata) // 公钥验证
    70  	if ok != true {
    71  		fmt.Printf("Verify error\n")
    72  	} else {
    73  		fmt.Printf("Verify ok\n")
    74  	}
    75  
    76  }
    77  
    78  func BenchmarkSM2Sign(t *testing.B) {
    79  	t.ReportAllocs()
    80  	msg := []byte("test")
    81  	priv, err := GenerateKey(nil) // 生成密钥对
    82  	if err != nil {
    83  		t.Fatal(err)
    84  	}
    85  	t.ResetTimer()
    86  	for i := 0; i < t.N; i++ {
    87  		_, err := priv.Sign(nil, msg, nil) // 签名
    88  		if err != nil {
    89  			t.Fatal(err)
    90  		}
    91  	}
    92  }
    93  
    94  func BenchmarkSM2Verify(t *testing.B) {
    95  	t.ReportAllocs()
    96  	msg := []byte("test")
    97  	priv, err := GenerateKey(nil) // 生成密钥对
    98  	if err != nil {
    99  		t.Fatal(err)
   100  	}
   101  	sign, err := priv.Sign(nil, msg, nil) // 签名
   102  	if err != nil {
   103  		t.Fatal(err)
   104  	}
   105  	t.ResetTimer()
   106  	for i := 0; i < t.N; i++ {
   107  		priv.Verify(msg, sign) // 密钥验证
   108  	}
   109  }
   110  
   111  func BenchmarkEcdsaSign(t *testing.B) {
   112  	t.ReportAllocs()
   113  	msg := []byte("test")
   114  	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   115  	if err != nil {
   116  		t.Fatal(err)
   117  	}
   118  	t.ResetTimer()
   119  	for i := 0; i < t.N; i++ {
   120  		_, _, err := ecdsa.Sign(rand.Reader, priv, msg)
   121  		if err != nil {
   122  			t.Fatal(err)
   123  		}
   124  	}
   125  }
   126  
   127  func BenchmarkEcdsaVerify(t *testing.B) {
   128  	t.ReportAllocs()
   129  	msg := []byte("test")
   130  	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   131  	if err != nil {
   132  		t.Fatal(err)
   133  	}
   134  	r, s, err := ecdsa.Sign(rand.Reader, priv, msg)
   135  	if err != nil {
   136  		t.Fatal(err)
   137  	}
   138  	t.ResetTimer()
   139  	for i := 0; i < t.N; i++ {
   140  		ecdsa.Verify(&priv.PublicKey, msg, r, s)
   141  	}
   142  }
   143  
   144  func TestKEB2(t *testing.T) {
   145  	ida := []byte{'1', '2', '3', '4', '5', '6', '7', '8',
   146  		'1', '2', '3', '4', '5', '6', '7', '8'}
   147  	idb := []byte{'1', '2', '3', '4', '5', '6', '7', '8',
   148  		'1', '2', '3', '4', '5', '6', '7', '8'}
   149  	daBuf := []byte{0x81, 0xEB, 0x26, 0xE9, 0x41, 0xBB, 0x5A, 0xF1,
   150  		0x6D, 0xF1, 0x16, 0x49, 0x5F, 0x90, 0x69, 0x52,
   151  		0x72, 0xAE, 0x2C, 0xD6, 0x3D, 0x6C, 0x4A, 0xE1,
   152  		0x67, 0x84, 0x18, 0xBE, 0x48, 0x23, 0x00, 0x29}
   153  	dbBuf := []byte{0x78, 0x51, 0x29, 0x91, 0x7D, 0x45, 0xA9, 0xEA,
   154  		0x54, 0x37, 0xA5, 0x93, 0x56, 0xB8, 0x23, 0x38,
   155  		0xEA, 0xAD, 0xDA, 0x6C, 0xEB, 0x19, 0x90, 0x88,
   156  		0xF1, 0x4A, 0xE1, 0x0D, 0xEF, 0xA2, 0x29, 0xB5}
   157  	raBuf := []byte{0xD4, 0xDE, 0x15, 0x47, 0x4D, 0xB7, 0x4D, 0x06,
   158  		0x49, 0x1C, 0x44, 0x0D, 0x30, 0x5E, 0x01, 0x24,
   159  		0x00, 0x99, 0x0F, 0x3E, 0x39, 0x0C, 0x7E, 0x87,
   160  		0x15, 0x3C, 0x12, 0xDB, 0x2E, 0xA6, 0x0B, 0xB3}
   161  
   162  	rbBuf := []byte{0x7E, 0x07, 0x12, 0x48, 0x14, 0xB3, 0x09, 0x48,
   163  		0x91, 0x25, 0xEA, 0xED, 0x10, 0x11, 0x13, 0x16,
   164  		0x4E, 0xBF, 0x0F, 0x34, 0x58, 0xC5, 0xBD, 0x88,
   165  		0x33, 0x5C, 0x1F, 0x9D, 0x59, 0x62, 0x43, 0xD6}
   166  
   167  	expk := []byte{0x6C, 0x89, 0x34, 0x73, 0x54, 0xDE, 0x24, 0x84,
   168  		0xC6, 0x0B, 0x4A, 0xB1, 0xFD, 0xE4, 0xC6, 0xE5}
   169  
   170  	curve := P256Sm2()
   171  	curve.ScalarBaseMult(daBuf)
   172  	da := new(PrivateKey)
   173  	da.PublicKey.Curve = curve
   174  	da.D = new(big.Int).SetBytes(daBuf)
   175  	da.PublicKey.X, da.PublicKey.Y = curve.ScalarBaseMult(daBuf)
   176  
   177  	db := new(PrivateKey)
   178  	db.PublicKey.Curve = curve
   179  	db.D = new(big.Int).SetBytes(dbBuf)
   180  	db.PublicKey.X, db.PublicKey.Y = curve.ScalarBaseMult(dbBuf)
   181  
   182  	ra := new(PrivateKey)
   183  	ra.PublicKey.Curve = curve
   184  	ra.D = new(big.Int).SetBytes(raBuf)
   185  	ra.PublicKey.X, ra.PublicKey.Y = curve.ScalarBaseMult(raBuf)
   186  
   187  	rb := new(PrivateKey)
   188  	rb.PublicKey.Curve = curve
   189  	rb.D = new(big.Int).SetBytes(rbBuf)
   190  	rb.PublicKey.X, rb.PublicKey.Y = curve.ScalarBaseMult(rbBuf)
   191  
   192  	k1, Sb, S2, err := KeyExchangeB(16, ida, idb, db, &da.PublicKey, rb, &ra.PublicKey)
   193  	if err != nil {
   194  		t.Error(err)
   195  	}
   196  	k2, S1, Sa, err := KeyExchangeA(16, ida, idb, da, &db.PublicKey, ra, &rb.PublicKey)
   197  	if err != nil {
   198  		t.Error(err)
   199  	}
   200  	if bytes.Compare(k1, k2) != 0 {
   201  		t.Error("key exchange differ")
   202  	}
   203  	if bytes.Compare(k1, expk) != 0 {
   204  		t.Errorf("expected %x, found %x", expk, k1)
   205  	}
   206  	if bytes.Compare(S1, Sb) != 0 {
   207  		t.Error("hash verfication failed")
   208  	}
   209  	if bytes.Compare(Sa, S2) != 0 {
   210  		t.Error("hash verfication failed")
   211  	}
   212  }