github.com/emmansun/gmsm@v0.29.1/sm2/sm2_test.go (about)

     1  package sm2
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"crypto"
     7  	"crypto/ecdsa"
     8  	"crypto/elliptic"
     9  	"crypto/rand"
    10  	"encoding/hex"
    11  	"io"
    12  	"math/big"
    13  	"reflect"
    14  	"testing"
    15  
    16  	"github.com/emmansun/gmsm/sm3"
    17  )
    18  
    19  func TestNewPrivateKey(t *testing.T) {
    20  	c := p256()
    21  	// test nil
    22  	_, err := NewPrivateKey(nil)
    23  	if err == nil || err.Error() != "sm2: invalid private key size" {
    24  		t.Errorf("should throw sm2: invalid private key size")
    25  	}
    26  	// test all zero
    27  	key := make([]byte, c.N.Size())
    28  	_, err = NewPrivateKey(key)
    29  	if err == nil || err != errInvalidPrivateKey {
    30  		t.Errorf("should throw errInvalidPrivateKey")
    31  	}
    32  	// test N-1
    33  	_, err = NewPrivateKey(c.nMinus1.Bytes(c.N))
    34  	if err == nil || err != errInvalidPrivateKey {
    35  		t.Errorf("should throw errInvalidPrivateKey")
    36  	}
    37  	// test N
    38  	_, err = NewPrivateKey(P256().Params().N.Bytes())
    39  	if err == nil || err != errInvalidPrivateKey {
    40  		t.Errorf("should throw errInvalidPrivateKey")
    41  	}
    42  	// test 1
    43  	key[31] = 1
    44  	_, err = NewPrivateKey(key)
    45  	if err != nil {
    46  		t.Fatal(err)
    47  	}
    48  	// test N-2
    49  	_, err = NewPrivateKey(c.nMinus2)
    50  	if err != nil {
    51  		t.Error(err)
    52  	}
    53  }
    54  
    55  func TestNewPrivateKeyFromInt(t *testing.T) {
    56  	// test nil
    57  	_, err := NewPrivateKeyFromInt(nil)
    58  	if err == nil || err.Error() != "sm2: invalid private key size" {
    59  		t.Errorf("should throw sm2: invalid private key size")
    60  	}
    61  	// test 1
    62  	_, err = NewPrivateKeyFromInt(big.NewInt(1))
    63  	if err != nil {
    64  		t.Fatal(err)
    65  	}
    66  	// test N
    67  	_, err = NewPrivateKeyFromInt(P256().Params().N)
    68  	if err == nil || err != errInvalidPrivateKey {
    69  		t.Errorf("should throw errInvalidPrivateKey")
    70  	}
    71  
    72  	// test N + 1
    73  	_, err = NewPrivateKeyFromInt(new(big.Int).Add(P256().Params().N, big.NewInt(1)))
    74  	if err == nil || err != errInvalidPrivateKey {
    75  		t.Errorf("should throw errInvalidPrivateKey")
    76  	}
    77  
    78  	c := p256()
    79  	// test N - 1
    80  	_, err = NewPrivateKeyFromInt(new(big.Int).SetBytes(c.nMinus1.Bytes(c.N)))
    81  	if err == nil || err != errInvalidPrivateKey {
    82  		t.Errorf("should throw errInvalidPrivateKey")
    83  	}
    84  }
    85  
    86  func TestNewPublicKey(t *testing.T) {
    87  	// test nil
    88  	_, err := NewPublicKey(nil)
    89  	if err == nil || err.Error() != "sm2: invalid public key" {
    90  		t.Errorf("should throw sm2: invalid public key")
    91  	}
    92  	// test without point format prefix byte
    93  	keypoints, _ := hex.DecodeString("8356e642a40ebd18d29ba3532fbd9f3bbee8f027c3f6f39a5ba2f870369f9988981f5efe55d1c5cdf6c0ef2b070847a14f7fdf4272a8df09c442f3058af94ba1")
    94  	_, err = NewPublicKey(keypoints)
    95  	if err == nil || err.Error() != "sm2: invalid public key" {
    96  		t.Errorf("should throw sm2: invalid public key")
    97  	}
    98  	// test correct point
    99  	keypoints, _ = hex.DecodeString("048356e642a40ebd18d29ba3532fbd9f3bbee8f027c3f6f39a5ba2f870369f9988981f5efe55d1c5cdf6c0ef2b070847a14f7fdf4272a8df09c442f3058af94ba1")
   100  	_, err = NewPublicKey(keypoints)
   101  	if err != nil {
   102  		t.Fatal(err)
   103  	}
   104  	// test point not on curve
   105  	keypoints, _ = hex.DecodeString("048356e642a40ebd18d29ba3532fbd9f3bbee8f027c3f6f39a5ba2f870369f9988981f5efe55d1c5cdf6c0ef2b070847a14f7fdf4272a8df09c442f3058af94ba2")
   106  	_, err = NewPublicKey(keypoints)
   107  	if err == nil || err.Error() != "point not on SM2 P256 curve" {
   108  		t.Errorf("should throw point not on SM2 P256 curve, got %v", err)
   109  	}
   110  }
   111  
   112  func TestSplicingOrder(t *testing.T) {
   113  	priv, _ := GenerateKey(rand.Reader)
   114  	tests := []struct {
   115  		name      string
   116  		plainText string
   117  		from      ciphertextSplicingOrder
   118  		to        ciphertextSplicingOrder
   119  	}{
   120  		// TODO: Add test cases.
   121  		{"less than 32 1", "encryption standard", C1C2C3, C1C3C2},
   122  		{"less than 32 2", "encryption standard", C1C3C2, C1C2C3},
   123  		{"equals 32 1", "encryption standard encryption ", C1C2C3, C1C3C2},
   124  		{"equals 32 2", "encryption standard encryption ", C1C3C2, C1C2C3},
   125  		{"long than 32 1", "encryption standard encryption standard", C1C2C3, C1C3C2},
   126  		{"long than 32 2", "encryption standard encryption standard", C1C3C2, C1C2C3},
   127  	}
   128  	for _, tt := range tests {
   129  		t.Run(tt.name, func(t *testing.T) {
   130  			ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), NewPlainEncrypterOpts(MarshalUncompressed, tt.from))
   131  			if err != nil {
   132  				t.Fatalf("encrypt failed %v", err)
   133  			}
   134  			plaintext, err := priv.Decrypt(rand.Reader, ciphertext, NewPlainDecrypterOpts(tt.from))
   135  			if err != nil {
   136  				t.Fatalf("decrypt failed %v", err)
   137  			}
   138  			if !reflect.DeepEqual(string(plaintext), tt.plainText) {
   139  				t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
   140  			}
   141  
   142  			//Adjust splicing order
   143  			ciphertext, err = AdjustCiphertextSplicingOrder(ciphertext, tt.from, tt.to)
   144  			if err != nil {
   145  				t.Fatalf("adjust splicing order failed %v", err)
   146  			}
   147  			plaintext, err = priv.Decrypt(rand.Reader, ciphertext, NewPlainDecrypterOpts(tt.to))
   148  			if err != nil {
   149  				t.Fatalf("decrypt failed after adjust splicing order %v", err)
   150  			}
   151  			if !reflect.DeepEqual(string(plaintext), tt.plainText) {
   152  				t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
   153  			}
   154  		})
   155  	}
   156  }
   157  
   158  func TestEncryptDecryptASN1(t *testing.T) {
   159  	priv, _ := GenerateKey(rand.Reader)
   160  	priv2, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   161  	key2 := new(PrivateKey)
   162  	key2.PrivateKey = *priv2
   163  	tests := []struct {
   164  		name      string
   165  		plainText string
   166  		priv      *PrivateKey
   167  	}{
   168  		// TODO: Add test cases.
   169  		{"less than 32", "encryption standard", priv},
   170  		{"equals 32", "encryption standard encryption ", priv},
   171  		{"long than 32", "encryption standard encryption standard", priv},
   172  		{"less than 32", "encryption standard", key2},
   173  		{"equals 32", "encryption standard encryption ", key2},
   174  		{"long than 32", "encryption standard encryption standard", key2},
   175  	}
   176  	for _, tt := range tests {
   177  		t.Run(tt.name, func(t *testing.T) {
   178  			encrypterOpts := ASN1EncrypterOpts
   179  			ciphertext, err := Encrypt(rand.Reader, &tt.priv.PublicKey, []byte(tt.plainText), encrypterOpts)
   180  			if err != nil {
   181  				t.Fatalf("%v encrypt failed %v", tt.priv.Curve.Params().Name, err)
   182  			}
   183  			plaintext, err := tt.priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts)
   184  			if err != nil {
   185  				t.Fatalf("%v decrypt 1 failed %v", tt.priv.Curve.Params().Name, err)
   186  			}
   187  			if !reflect.DeepEqual(string(plaintext), tt.plainText) {
   188  				t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
   189  			}
   190  			plaintext, err = tt.priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts)
   191  			if err != nil {
   192  				t.Fatalf("%v decrypt 2 failed %v", tt.priv.Curve.Params().Name, err)
   193  			}
   194  			if !reflect.DeepEqual(string(plaintext), tt.plainText) {
   195  				t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
   196  			}
   197  		})
   198  	}
   199  }
   200  
   201  func TestPlainCiphertext2ASN1(t *testing.T) {
   202  	ciphertext, _ := hex.DecodeString("047928e22045eec8dc00e95639dd0c1c8dfb75cf8cedcf496731a6a6f423baa54c5014c60b73495886d8d7bc996a4a716cb58e6bfc8e03078b24e7b0f5cba0efd5b9272c27fc263bb59eaca6eabc97c0323bf1de953aeabaf59700b3bf49c9a1056decc08dd18544960541a2239afa7b1512df05")
   203  	_, err := PlainCiphertext2ASN1(append([]byte{0x30}, ciphertext...), C1C3C2)
   204  	if err == nil {
   205  		t.Fatalf("expected error")
   206  	}
   207  	_, err = PlainCiphertext2ASN1(ciphertext[:65], C1C3C2)
   208  	if err == nil {
   209  		t.Fatalf("expected error")
   210  	}
   211  	ciphertext[0] = 0x10
   212  	_, err = PlainCiphertext2ASN1(ciphertext, C1C3C2)
   213  	if err == nil {
   214  		t.Fatalf("expected error")
   215  	}
   216  }
   217  
   218  func TestAdjustCiphertextSplicingOrder(t *testing.T) {
   219  	ciphertext, _ := hex.DecodeString("047928e22045eec8dc00e95639dd0c1c8dfb75cf8cedcf496731a6a6f423baa54c5014c60b73495886d8d7bc996a4a716cb58e6bfc8e03078b24e7b0f5cba0efd5b9272c27fc263bb59eaca6eabc97c0323bf1de953aeabaf59700b3bf49c9a1056decc08dd18544960541a2239afa7b1512df05")
   220  	res, err := AdjustCiphertextSplicingOrder(ciphertext, C1C3C2, C1C3C2)
   221  	if err != nil || &res[0] != &ciphertext[0] {
   222  		t.Fatalf("should be same one")
   223  	}
   224  	_, err = AdjustCiphertextSplicingOrder(ciphertext[:65], C1C3C2, C1C2C3)
   225  	if err == nil {
   226  		t.Fatalf("expected error")
   227  	}
   228  	ciphertext[0] = 0x10
   229  	_, err = AdjustCiphertextSplicingOrder(ciphertext, C1C3C2, C1C2C3)
   230  	if err == nil {
   231  		t.Fatalf("expected error")
   232  	}
   233  }
   234  
   235  func TestCiphertext2ASN1(t *testing.T) {
   236  	priv, _ := GenerateKey(rand.Reader)
   237  	tests := []struct {
   238  		name      string
   239  		plainText string
   240  	}{
   241  		// TODO: Add test cases.
   242  		{"less than 32", "encryption standard"},
   243  		{"equals 32", "encryption standard encryption "},
   244  		{"long than 32", "encryption standard encryption standard"},
   245  	}
   246  	for _, tt := range tests {
   247  		t.Run(tt.name, func(t *testing.T) {
   248  			ciphertext1, err := Encrypt(rand.Reader, &priv.PublicKey, []byte(tt.plainText), nil)
   249  			if err != nil {
   250  				t.Fatalf("encrypt failed %v", err)
   251  			}
   252  
   253  			ciphertext, err := PlainCiphertext2ASN1(ciphertext1, C1C3C2)
   254  			if err != nil {
   255  				t.Fatalf("convert to ASN.1 failed %v", err)
   256  			}
   257  			plaintext, err := priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts)
   258  			if err != nil {
   259  				t.Fatalf("decrypt failed %v", err)
   260  			}
   261  			if !reflect.DeepEqual(string(plaintext), tt.plainText) {
   262  				t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
   263  			}
   264  
   265  			ciphertext2, err := AdjustCiphertextSplicingOrder(ciphertext1, C1C3C2, C1C2C3)
   266  			if err != nil {
   267  				t.Fatalf("adjust order failed %v", err)
   268  			}
   269  			ciphertext, err = PlainCiphertext2ASN1(ciphertext2, C1C2C3)
   270  			if err != nil {
   271  				t.Fatalf("convert to ASN.1 failed %v", err)
   272  			}
   273  			plaintext, err = priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts)
   274  			if err != nil {
   275  				t.Fatalf("decrypt failed %v", err)
   276  			}
   277  			if !reflect.DeepEqual(string(plaintext), tt.plainText) {
   278  				t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
   279  			}
   280  		})
   281  	}
   282  }
   283  
   284  func TestCiphertextASN12Plain(t *testing.T) {
   285  	priv, _ := GenerateKey(rand.Reader)
   286  	tests := []struct {
   287  		name      string
   288  		plainText string
   289  	}{
   290  		// TODO: Add test cases.
   291  		{"less than 32", "encryption standard"},
   292  		{"equals 32", "encryption standard encryption "},
   293  		{"long than 32", "encryption standard encryption standard"},
   294  	}
   295  	for _, tt := range tests {
   296  		t.Run(tt.name, func(t *testing.T) {
   297  			ciphertext, err := EncryptASN1(rand.Reader, &priv.PublicKey, []byte(tt.plainText))
   298  			if err != nil {
   299  				t.Fatalf("encrypt failed %v", err)
   300  			}
   301  			ciphertext, err = ASN1Ciphertext2Plain(ciphertext, nil)
   302  			if err != nil {
   303  				t.Fatalf("convert to plain failed %v", err)
   304  			}
   305  			plaintext, err := priv.Decrypt(rand.Reader, ciphertext, nil)
   306  			if err != nil {
   307  				t.Fatalf("decrypt failed %v", err)
   308  			}
   309  			if !reflect.DeepEqual(string(plaintext), tt.plainText) {
   310  				t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
   311  			}
   312  		})
   313  	}
   314  }
   315  
   316  func TestEncryptWithInfinitePublicKey(t *testing.T) {
   317  	pub := new(ecdsa.PublicKey)
   318  	pub.Curve = P256()
   319  	pub.X = big.NewInt(0)
   320  	pub.Y = big.NewInt(0)
   321  
   322  	_, err := Encrypt(rand.Reader, pub, []byte("sm2 encryption standard"), nil)
   323  	if err == nil {
   324  		t.Fatalf("should be failed")
   325  	}
   326  }
   327  
   328  func TestEncryptEmptyPlaintext(t *testing.T) {
   329  	priv, _ := GenerateKey(rand.Reader)
   330  	ciphertext, err := Encrypt(rand.Reader, &priv.PublicKey, nil, nil)
   331  	if err != nil || ciphertext != nil {
   332  		t.Fatalf("nil plaintext should return nil")
   333  	}
   334  	ciphertext, err = Encrypt(rand.Reader, &priv.PublicKey, []byte{}, nil)
   335  	if err != nil || ciphertext != nil {
   336  		t.Fatalf("empty plaintext should return nil")
   337  	}
   338  }
   339  
   340  func TestEncryptDecrypt(t *testing.T) {
   341  	priv, _ := GenerateKey(rand.Reader)
   342  	priv2, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   343  	key2 := new(PrivateKey)
   344  	key2.PrivateKey = *priv2
   345  	tests := []struct {
   346  		name      string
   347  		plainText string
   348  		priv      *PrivateKey
   349  	}{
   350  		// TODO: Add test cases.
   351  		{"less than 32", "encryption standard", priv},
   352  		{"equals 32", "encryption standard encryption ", priv},
   353  		{"long than 32", "encryption standard encryption standard", priv},
   354  		{"less than 32", "encryption standard", key2},
   355  		{"equals 32", "encryption standard encryption ", key2},
   356  		{"long than 32", "encryption standard encryption standard", key2},
   357  	}
   358  	for _, tt := range tests {
   359  		t.Run(tt.name, func(t *testing.T) {
   360  			ciphertext, err := Encrypt(rand.Reader, &tt.priv.PublicKey, []byte(tt.plainText), nil)
   361  			if err != nil {
   362  				t.Fatalf("encrypt failed %v", err)
   363  			}
   364  			plaintext, err := Decrypt(tt.priv, ciphertext)
   365  			if err != nil {
   366  				t.Fatalf("decrypt failed %v", err)
   367  			}
   368  			if !reflect.DeepEqual(string(plaintext), tt.plainText) {
   369  				t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
   370  			}
   371  			// compress mode
   372  			encrypterOpts := NewPlainEncrypterOpts(MarshalCompressed, C1C3C2)
   373  			ciphertext, err = Encrypt(rand.Reader, &tt.priv.PublicKey, []byte(tt.plainText), encrypterOpts)
   374  			if err != nil {
   375  				t.Fatalf("encrypt failed %v", err)
   376  			}
   377  			plaintext, err = Decrypt(tt.priv, ciphertext)
   378  			if err != nil {
   379  				t.Fatalf("decrypt failed %v", err)
   380  			}
   381  			if !reflect.DeepEqual(string(plaintext), tt.plainText) {
   382  				t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
   383  			}
   384  
   385  			// hybrid mode
   386  			encrypterOpts = NewPlainEncrypterOpts(MarshalHybrid, C1C3C2)
   387  			ciphertext, err = Encrypt(rand.Reader, &tt.priv.PublicKey, []byte(tt.plainText), encrypterOpts)
   388  			if err != nil {
   389  				t.Fatalf("encrypt failed %v", err)
   390  			}
   391  			plaintext, err = Decrypt(tt.priv, ciphertext)
   392  			if err != nil {
   393  				t.Fatalf("decrypt failed %v", err)
   394  			}
   395  			if !reflect.DeepEqual(string(plaintext), tt.plainText) {
   396  				t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
   397  			}
   398  			plaintext, err = Decrypt(tt.priv, ciphertext)
   399  			if err != nil {
   400  				t.Fatalf("decrypt failed %v", err)
   401  			}
   402  			if !reflect.DeepEqual(string(plaintext), tt.plainText) {
   403  				t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
   404  			}
   405  		})
   406  	}
   407  }
   408  
   409  func TestInvalidCiphertext(t *testing.T) {
   410  	priv, _ := GenerateKey(rand.Reader)
   411  	tests := []struct {
   412  		name       string
   413  		ciphertext []byte
   414  	}{
   415  		// TODO: Add test cases.
   416  		{errCiphertextTooShort.Error(), make([]byte, 65)},
   417  		{ErrDecryption.Error(), append([]byte{0x04}, make([]byte, 96)...)},
   418  		{ErrDecryption.Error(), append([]byte{0x04}, make([]byte, 97)...)},
   419  		{ErrDecryption.Error(), append([]byte{0x02}, make([]byte, 65)...)},
   420  		{ErrDecryption.Error(), append([]byte{0x30}, make([]byte, 97)...)},
   421  		{ErrDecryption.Error(), make([]byte, 97)},
   422  	}
   423  	for i, tt := range tests {
   424  		_, err := Decrypt(priv, tt.ciphertext)
   425  		if err.Error() != tt.name {
   426  			t.Fatalf("case %v, expected %v, got %v\n", i, tt.name, err.Error())
   427  		}
   428  	}
   429  }
   430  
   431  func TestPrivateKeyPlus1WithOrderMinus1(t *testing.T) {
   432  	priv := new(PrivateKey)
   433  	priv.D = new(big.Int).Sub(P256().Params().N, big.NewInt(1))
   434  	priv.Curve = P256()
   435  	priv.PublicKey.X, priv.PublicKey.Y = P256().ScalarBaseMult(priv.D.Bytes())
   436  
   437  	_, err := priv.inverseOfPrivateKeyPlus1(p256())
   438  	if err == nil || err != errInvalidPrivateKey {
   439  		t.Errorf("expected invalid private key error")
   440  	}
   441  }
   442  
   443  func TestSignVerify(t *testing.T) {
   444  	priv, _ := GenerateKey(rand.Reader)
   445  	tests := []struct {
   446  		name      string
   447  		plainText string
   448  	}{
   449  		// TODO: Add test cases.
   450  		{"less than 32", "encryption standard"},
   451  		{"equals 32", "encryption standard encryption "},
   452  		{"long than 32", "encryption standard encryption standard"},
   453  	}
   454  	for _, tt := range tests {
   455  		t.Run(tt.name, func(t *testing.T) {
   456  			hashed := sm3.Sum([]byte(tt.plainText))
   457  			signature, err := priv.Sign(rand.Reader, hashed[:], nil)
   458  			if err != nil {
   459  				t.Fatalf("sign failed %v", err)
   460  			}
   461  			result := VerifyASN1(&priv.PublicKey, hashed[:], signature)
   462  			if !result {
   463  				t.Fatal("verify failed")
   464  			}
   465  			hashed[0] ^= 0xff
   466  			if VerifyASN1(&priv.PublicKey, hashed[:], signature) {
   467  				t.Errorf("VerifyASN1 always works!")
   468  			}
   469  		})
   470  	}
   471  }
   472  
   473  func testRecoverPublicKeysFromSM2Signature(t *testing.T, priv *PrivateKey) {
   474  	tests := []struct {
   475  		name      string
   476  		plainText string
   477  	}{
   478  		{"less than 32", "encryption standard"},
   479  		{"equals 32", "encryption standard encryption "},
   480  		{"long than 32", "encryption standard encryption standard"},
   481  	}
   482  	for _, tt := range tests {
   483  		t.Run(tt.name, func(t *testing.T) {
   484  			hashValue, err := CalculateSM2Hash(&priv.PublicKey, []byte(tt.plainText), nil)
   485  			if err != nil {
   486  				t.Fatalf("hash failed %v", err)
   487  			}
   488  			sig, err := priv.Sign(rand.Reader, hashValue, nil)
   489  			if err != nil {
   490  				t.Fatalf("sign failed %v", err)
   491  			}
   492  
   493  			pubs, err := RecoverPublicKeysFromSM2Signature(hashValue, sig)
   494  			if err != nil {
   495  				t.Fatalf("recover sig=%x, priv=%x, failed %v", sig, priv.D.Bytes(), err)
   496  			}
   497  			found := false
   498  			for _, pub := range pubs {
   499  				if !VerifyASN1(pub, hashValue, sig) {
   500  					t.Errorf("failed to verify hash for sig=%x, priv=%x", sig, priv.D.Bytes())
   501  				}
   502  				if pub.Equal(&priv.PublicKey) {
   503  					found = true
   504  				}
   505  			}
   506  			if !found {
   507  				t.Errorf("recover failed, not found public key for sig=%x, priv=%x", sig, priv.D.Bytes())
   508  			}
   509  		})
   510  	}
   511  }
   512  
   513  func TestRecoverPublicKeysFromSM2Signature(t *testing.T) {
   514  	priv, _ := GenerateKey(rand.Reader)
   515  	testRecoverPublicKeysFromSM2Signature(t, priv)
   516  	keyInt := bigFromHex("d6833540d019e0438a5dd73b414f26ab43d8064b99671206944e284dbd969093")
   517  	priv, _ = NewPrivateKeyFromInt(keyInt)
   518  	testRecoverPublicKeysFromSM2Signature(t, priv)
   519  
   520  	// failed case
   521  	hashValue, _ := CalculateSM2Hash(&priv.PublicKey, []byte("encryption standard encryption "), nil)
   522  	signature, _ := hex.DecodeString("3045022000cd0b56bf6be810032d28ff27d6f3468f1f1a09bcf8581f30a5de6692c85ea602210096ba29c086134af1be139dd572f2f2908f30e01fd0c28e06a687cbb0ff6e33ce")
   523  	// verify signature with public key
   524  	if !VerifyASN1(&priv.PublicKey, hashValue, signature) {
   525  		t.Errorf("failed to verify hash for sig=%x, priv=%x", signature, priv.D.Bytes())
   526  	}
   527  	pubs, err := RecoverPublicKeysFromSM2Signature(hashValue, signature)
   528  	if err != nil {
   529  		t.Fatalf("recover failed %v", err)
   530  	}
   531  	found := false
   532  	for _, pub := range pubs {
   533  		if !VerifyASN1(pub, hashValue, signature) {
   534  			t.Errorf("failed to verify hash for sig=%x, priv=%x", signature, priv.D.Bytes())
   535  		}
   536  		if pub.Equal(&priv.PublicKey) {
   537  			found = true
   538  		}
   539  	}
   540  	if !found {
   541  		t.Errorf("recover failed, not found public key for sig=%x, priv=%x", signature, priv.D.Bytes())
   542  	}
   543  }
   544  
   545  func TestSignVerifyLegacy(t *testing.T) {
   546  	priv, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   547  	tests := []struct {
   548  		name      string
   549  		plainText string
   550  	}{
   551  		// TODO: Add test cases.
   552  		{"less than 32", "encryption standard"},
   553  		{"equals 32", "encryption standard encryption "},
   554  		{"long than 32", "encryption standard encryption standard"},
   555  	}
   556  	for _, tt := range tests {
   557  		t.Run(tt.name, func(t *testing.T) {
   558  			hashed := sm3.Sum([]byte(tt.plainText))
   559  			r, s, err := Sign(rand.Reader, priv, hashed[:])
   560  			if err != nil {
   561  				t.Fatalf("sign failed %v", err)
   562  			}
   563  			result := Verify(&priv.PublicKey, hashed[:], r, s)
   564  			if !result {
   565  				t.Fatal("verify failed")
   566  			}
   567  			hashed[0] ^= 0xff
   568  			if Verify(&priv.PublicKey, hashed[:], r, s) {
   569  				t.Errorf("VerifyASN1 always works!")
   570  			}
   571  		})
   572  	}
   573  }
   574  
   575  // Check that signatures remain non-deterministic with a functional entropy source.
   576  func TestINDCCA(t *testing.T) {
   577  	priv, err := GenerateKey(rand.Reader)
   578  	if err != nil {
   579  		t.Errorf("failed to generate key")
   580  	}
   581  
   582  	hashed := []byte("testing")
   583  	r0, s0, err := Sign(rand.Reader, &priv.PrivateKey, hashed)
   584  	if err != nil {
   585  		t.Errorf("SM2: error signing: %s", err)
   586  		return
   587  	}
   588  
   589  	r1, s1, err := Sign(rand.Reader, &priv.PrivateKey, hashed)
   590  	if err != nil {
   591  		t.Errorf("SM2: error signing: %s", err)
   592  		return
   593  	}
   594  
   595  	if s0.Cmp(s1) == 0 {
   596  		t.Error("SM2: two signatures of the same message produced the same result")
   597  	}
   598  
   599  	if r0.Cmp(r1) == 0 {
   600  		t.Error("SM2: two signatures of the same message produced the same nonce")
   601  	}
   602  }
   603  
   604  func TestNegativeInputs(t *testing.T) {
   605  	key, err := GenerateKey(rand.Reader)
   606  	if err != nil {
   607  		t.Errorf("failed to generate key")
   608  	}
   609  
   610  	var hash [32]byte
   611  	r := new(big.Int).SetInt64(1)
   612  	r.Lsh(r, 550 /* larger than any supported curve */)
   613  	r.Neg(r)
   614  
   615  	if Verify(&key.PublicKey, hash[:], r, r) {
   616  		t.Errorf("bogus signature accepted")
   617  	}
   618  }
   619  
   620  func TestZeroHashSignature(t *testing.T) {
   621  	zeroHash := make([]byte, 64)
   622  
   623  	privKey, err := GenerateKey(rand.Reader)
   624  	if err != nil {
   625  		panic(err)
   626  	}
   627  
   628  	// Sign a hash consisting of all zeros.
   629  	r, s, err := Sign(rand.Reader, &privKey.PrivateKey, zeroHash)
   630  	if err != nil {
   631  		panic(err)
   632  	}
   633  
   634  	// Confirm that it can be verified.
   635  	if !Verify(&privKey.PublicKey, zeroHash, r, s) {
   636  		t.Errorf("zero hash signature verify failed")
   637  	}
   638  }
   639  
   640  func TestZeroSignature(t *testing.T) {
   641  	privKey, err := GenerateKey(rand.Reader)
   642  	if err != nil {
   643  		panic(err)
   644  	}
   645  	if Verify(&privKey.PublicKey, make([]byte, 64), big.NewInt(0), big.NewInt(0)) {
   646  		t.Error("Verify with r,s=0 succeeded")
   647  	}
   648  }
   649  
   650  func TestNegtativeSignature(t *testing.T) {
   651  	zeroHash := make([]byte, 64)
   652  
   653  	privKey, err := GenerateKey(rand.Reader)
   654  	if err != nil {
   655  		panic(err)
   656  	}
   657  	r, s, err := Sign(rand.Reader, &privKey.PrivateKey, zeroHash)
   658  	if err != nil {
   659  		panic(err)
   660  	}
   661  
   662  	r = r.Neg(r)
   663  	if Verify(&privKey.PublicKey, zeroHash, r, s) {
   664  		t.Error("Verify with r=-r succeeded")
   665  	}
   666  }
   667  
   668  func TestRPlusNSignature(t *testing.T) {
   669  	zeroHash := make([]byte, 64)
   670  
   671  	privKey, err := GenerateKey(rand.Reader)
   672  	if err != nil {
   673  		panic(err)
   674  	}
   675  	r, s, err := Sign(rand.Reader, &privKey.PrivateKey, zeroHash)
   676  	if err != nil {
   677  		panic(err)
   678  	}
   679  
   680  	r = r.Add(r, P256().Params().N)
   681  	if Verify(&privKey.PublicKey, zeroHash, r, s) {
   682  		t.Error("Verify with r=r+n succeeded")
   683  	}
   684  }
   685  
   686  func TestRMinusNSignature(t *testing.T) {
   687  	zeroHash := make([]byte, 64)
   688  
   689  	privKey, err := GenerateKey(rand.Reader)
   690  	if err != nil {
   691  		panic(err)
   692  	}
   693  	r, s, err := Sign(rand.Reader, &privKey.PrivateKey, zeroHash)
   694  	if err != nil {
   695  		panic(err)
   696  	}
   697  
   698  	r = r.Sub(r, P256().Params().N)
   699  	if Verify(&privKey.PublicKey, zeroHash, r, s) {
   700  		t.Error("Verify with r=r-n succeeded")
   701  	}
   702  }
   703  
   704  func TestEqual(t *testing.T) {
   705  	private, _ := GenerateKey(rand.Reader)
   706  	public := &private.PublicKey
   707  
   708  	if !public.Equal(public) {
   709  		t.Errorf("public key is not equal to itself: %q", public)
   710  	}
   711  	if !public.Equal(crypto.Signer(private).Public()) {
   712  		t.Errorf("private.Public() is not Equal to public: %q", public)
   713  	}
   714  	if !private.Equal(private) {
   715  		t.Errorf("private key is not equal to itself: %q", private.PrivateKey)
   716  	}
   717  
   718  	otherPriv, _ := GenerateKey(rand.Reader)
   719  	otherPub := &otherPriv.PublicKey
   720  	if public.Equal(otherPub) {
   721  		t.Errorf("different public keys are Equal")
   722  	}
   723  	if private.Equal(otherPriv) {
   724  		t.Errorf("different private keys are Equal")
   725  	}
   726  }
   727  
   728  func TestPublicKeyToECDH(t *testing.T) {
   729  	priv, _ := GenerateKey(rand.Reader)
   730  	_, err := PublicKeyToECDH(&priv.PublicKey)
   731  	if err != nil {
   732  		t.Fatal(err)
   733  	}
   734  
   735  	p256, _ := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   736  	_, err = PublicKeyToECDH(&p256.PublicKey)
   737  	if err == nil {
   738  		t.Fatal("should be error")
   739  	}
   740  }
   741  
   742  func TestRandomPoint(t *testing.T) {
   743  	c := p256()
   744  	t.Cleanup(func() { testingOnlyRejectionSamplingLooped = nil })
   745  	var loopCount int
   746  	testingOnlyRejectionSamplingLooped = func() { loopCount++ }
   747  
   748  	// A sequence of all ones will generate 2^N-1, which should be rejected.
   749  	// (Unless, for example, we are masking too many bits.)
   750  	r := io.MultiReader(bytes.NewReader(bytes.Repeat([]byte{0xff}, 100)), rand.Reader)
   751  	if k, p, err := randomPoint(c, r, false); err != nil {
   752  		t.Fatal(err)
   753  	} else if k.IsZero() == 1 {
   754  		t.Error("k is zero")
   755  	} else if p.Bytes()[0] != 4 {
   756  		t.Error("p is infinity")
   757  	}
   758  	if loopCount == 0 {
   759  		t.Error("overflow was not rejected")
   760  	}
   761  	loopCount = 0
   762  
   763  	// A sequence of all zeroes will generate zero, which should be rejected.
   764  	r = io.MultiReader(bytes.NewReader(bytes.Repeat([]byte{0}, 100)), rand.Reader)
   765  	if k, p, err := randomPoint(c, r, false); err != nil {
   766  		t.Fatal(err)
   767  	} else if k.IsZero() == 1 {
   768  		t.Error("k is zero")
   769  	} else if p.Bytes()[0] != 4 {
   770  		t.Error("p is infinity")
   771  	}
   772  	if loopCount == 0 {
   773  		t.Error("zero was not rejected")
   774  	}
   775  }
   776  
   777  func BenchmarkGenerateKey_SM2(b *testing.B) {
   778  	r := bufio.NewReaderSize(rand.Reader, 1<<15)
   779  	b.ReportAllocs()
   780  	b.ResetTimer()
   781  	for i := 0; i < b.N; i++ {
   782  		if _, err := GenerateKey(r); err != nil {
   783  			b.Fatal(err)
   784  		}
   785  	}
   786  }
   787  
   788  func BenchmarkGenerateKey_P256(b *testing.B) {
   789  	r := bufio.NewReaderSize(rand.Reader, 1<<15)
   790  	b.ReportAllocs()
   791  	b.ResetTimer()
   792  	for i := 0; i < b.N; i++ {
   793  		if _, err := ecdsa.GenerateKey(elliptic.P256(), r); err != nil {
   794  			b.Fatal(err)
   795  		}
   796  	}
   797  }
   798  
   799  func BenchmarkSign_SM2(b *testing.B) {
   800  	r := bufio.NewReaderSize(rand.Reader, 1<<15)
   801  	priv, err := GenerateKey(r)
   802  	if err != nil {
   803  		b.Fatal(err)
   804  	}
   805  	hashed := sm3.Sum([]byte("testing"))
   806  
   807  	b.ReportAllocs()
   808  	b.ResetTimer()
   809  	for i := 0; i < b.N; i++ {
   810  		sig, err := SignASN1(rand.Reader, priv, hashed[:], nil)
   811  		if err != nil {
   812  			b.Fatal(err)
   813  		}
   814  		// Prevent the compiler from optimizing out the operation.
   815  		hashed[0] = sig[0]
   816  	}
   817  }
   818  
   819  func BenchmarkSign_SM2Specific(b *testing.B) {
   820  	r := bufio.NewReaderSize(rand.Reader, 1<<15)
   821  	priv, err := GenerateKey(r)
   822  	if err != nil {
   823  		b.Fatal(err)
   824  	}
   825  	hashed := []byte("testingtestingtestingtestingtestingtestingtestingtestingtestingtestingtestingtestingtestingtestingtestingtestingtestingtestingtestingtesting")
   826  	b.RunParallel(func(p *testing.PB) {
   827  		for p.Next() {
   828  			_, err := priv.SignWithSM2(rand.Reader, nil, hashed)
   829  			if err != nil {
   830  				b.Fatal(err)
   831  			}
   832  		}
   833  	})
   834  }
   835  
   836  func BenchmarkSign_P256(b *testing.B) {
   837  	r := bufio.NewReaderSize(rand.Reader, 1<<15)
   838  	priv, err := ecdsa.GenerateKey(elliptic.P256(), r)
   839  	if err != nil {
   840  		b.Fatal(err)
   841  	}
   842  	hashed := []byte("testing")
   843  
   844  	b.ReportAllocs()
   845  	b.ResetTimer()
   846  	for i := 0; i < b.N; i++ {
   847  		sig, err := ecdsa.SignASN1(rand.Reader, priv, hashed)
   848  		if err != nil {
   849  			b.Fatal(err)
   850  		}
   851  		// Prevent the compiler from optimizing out the operation.
   852  		hashed[0] = sig[0]
   853  	}
   854  }
   855  
   856  func BenchmarkVerify_P256(b *testing.B) {
   857  	rd := bufio.NewReaderSize(rand.Reader, 1<<15)
   858  	priv, err := ecdsa.GenerateKey(elliptic.P256(), rd)
   859  	if err != nil {
   860  		b.Fatal(err)
   861  	}
   862  	hashed := []byte("testing")
   863  	r, s, err := ecdsa.Sign(rand.Reader, priv, hashed)
   864  	if err != nil {
   865  		b.Fatal(err)
   866  	}
   867  
   868  	b.ReportAllocs()
   869  	b.ResetTimer()
   870  	for i := 0; i < b.N; i++ {
   871  		if !ecdsa.Verify(&priv.PublicKey, hashed, r, s) {
   872  			b.Fatal("verify failed")
   873  		}
   874  	}
   875  }
   876  
   877  func BenchmarkVerify_SM2(b *testing.B) {
   878  	rd := bufio.NewReaderSize(rand.Reader, 1<<15)
   879  	priv, err := GenerateKey(rd)
   880  	if err != nil {
   881  		b.Fatal(err)
   882  	}
   883  	hashed := []byte("testing")
   884  	r, s, err := Sign(rand.Reader, &priv.PrivateKey, hashed)
   885  	if err != nil {
   886  		b.Fatal(err)
   887  	}
   888  
   889  	b.ReportAllocs()
   890  	b.ResetTimer()
   891  	for i := 0; i < b.N; i++ {
   892  		if !Verify(&priv.PublicKey, hashed, r, s) {
   893  			b.Fatal("verify failed")
   894  		}
   895  	}
   896  }
   897  
   898  func benchmarkEncrypt(b *testing.B, curve elliptic.Curve, plaintext []byte) {
   899  	r := bufio.NewReaderSize(rand.Reader, 1<<15)
   900  	priv, err := ecdsa.GenerateKey(curve, r)
   901  	if err != nil {
   902  		b.Fatal(err)
   903  	}
   904  	b.SetBytes(int64(len(plaintext)))
   905  	b.ReportAllocs()
   906  	b.ResetTimer()
   907  	for i := 0; i < b.N; i++ {
   908  		Encrypt(rand.Reader, &priv.PublicKey, []byte(plaintext), nil)
   909  	}
   910  }
   911  
   912  func BenchmarkEncryptNoMoreThan32_P256(b *testing.B) {
   913  	benchmarkEncrypt(b, elliptic.P256(), make([]byte, 31))
   914  }
   915  
   916  func BenchmarkEncryptNoMoreThan32_SM2(b *testing.B) {
   917  	benchmarkEncrypt(b, P256(), make([]byte, 31))
   918  }
   919  
   920  func BenchmarkEncrypt128_P256(b *testing.B) {
   921  	benchmarkEncrypt(b, elliptic.P256(), make([]byte, 128))
   922  }
   923  
   924  func BenchmarkEncrypt128_SM2(b *testing.B) {
   925  	benchmarkEncrypt(b, P256(), make([]byte, 128))
   926  }
   927  
   928  func BenchmarkEncrypt512_SM2(b *testing.B) {
   929  	benchmarkEncrypt(b, P256(), make([]byte, 512))
   930  }
   931  
   932  func BenchmarkEncrypt1K_SM2(b *testing.B) {
   933  	benchmarkEncrypt(b, P256(), make([]byte, 1024))
   934  }
   935  
   936  func BenchmarkEncrypt8K_SM2(b *testing.B) {
   937  	benchmarkEncrypt(b, P256(), make([]byte, 8*1024))
   938  }