github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/sm2/sm2_test.go (about)

     1  package sm2
     2  
     3  import (
     4  	"crypto"
     5  	"crypto/ecdsa"
     6  	"crypto/elliptic"
     7  	"crypto/rand"
     8  	"encoding/hex"
     9  	"math/big"
    10  	"reflect"
    11  	"testing"
    12  
    13  	"github.com/hxx258456/ccgo/sm3"
    14  )
    15  
    16  func Test_kdf(t *testing.T) {
    17  	x2, _ := new(big.Int).SetString("64D20D27D0632957F8028C1E024F6B02EDF23102A566C932AE8BD613A8E865FE", 16)
    18  	y2, _ := new(big.Int).SetString("58D225ECA784AE300A81A2D48281A828E1CEDF11C4219099840265375077BF78", 16)
    19  
    20  	expected := "006e30dae231b071dfad8aa379e90264491603"
    21  
    22  	result, success := kdf(append(x2.Bytes(), y2.Bytes()...), 19)
    23  	if !success {
    24  		t.Fatalf("failed")
    25  	}
    26  
    27  	resultStr := hex.EncodeToString(result)
    28  
    29  	if expected != resultStr {
    30  		t.Fatalf("expected %s, real value %s", expected, resultStr)
    31  	}
    32  }
    33  
    34  func Test_SplicingOrder(t *testing.T) {
    35  	priv, _ := GenerateKey(rand.Reader)
    36  	tests := []struct {
    37  		name      string
    38  		plainText string
    39  		from      ciphertextSplicingOrder
    40  		to        ciphertextSplicingOrder
    41  	}{
    42  		// TODO: Add test cases.
    43  		{"less than 32 1", "encryption standard", C1C2C3, C1C3C2},
    44  		{"less than 32 2", "encryption standard", C1C3C2, C1C2C3},
    45  		{"equals 32 1", "encryption standard encryption ", C1C2C3, C1C3C2},
    46  		{"equals 32 2", "encryption standard encryption ", C1C3C2, C1C2C3},
    47  		{"long than 32 1", "encryption standard encryption standard", C1C2C3, C1C3C2},
    48  		{"long than 32 2", "encryption standard encryption standard", C1C3C2, C1C2C3},
    49  	}
    50  	for _, tt := range tests {
    51  		t.Run(tt.name, func(t *testing.T) {
    52  			ciphertext, err := Encrypt(&priv.PublicKey, []byte(tt.plainText), rand.Reader, NewPlainEncrypterOpts(MarshalUncompressed, tt.from))
    53  			if err != nil {
    54  				t.Fatalf("encrypt failed %v", err)
    55  			}
    56  			plaintext, err := priv.Decrypt(rand.Reader, ciphertext, NewPlainDecrypterOpts(tt.from))
    57  			if err != nil {
    58  				t.Fatalf("decrypt failed %v", err)
    59  			}
    60  			if !reflect.DeepEqual(string(plaintext), tt.plainText) {
    61  				t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
    62  			}
    63  
    64  			//Adjust splicing order
    65  			ciphertext, err = AdjustCiphertextSplicingOrder(ciphertext, tt.from, tt.to)
    66  			if err != nil {
    67  				t.Fatalf("adjust splicing order failed %v", err)
    68  			}
    69  			plaintext, err = priv.Decrypt(rand.Reader, ciphertext, NewPlainDecrypterOpts(tt.to))
    70  			if err != nil {
    71  				t.Fatalf("decrypt failed after adjust splicing order %v", err)
    72  			}
    73  			if !reflect.DeepEqual(string(plaintext), tt.plainText) {
    74  				t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
    75  			}
    76  		})
    77  	}
    78  }
    79  
    80  func Test_encryptDecrypt_ASN1(t *testing.T) {
    81  	priv, _ := GenerateKey(rand.Reader)
    82  	tests := []struct {
    83  		name      string
    84  		plainText string
    85  	}{
    86  		// TODO: Add test cases.
    87  		{"less than 32", "encryption standard"},
    88  		{"equals 32", "encryption standard encryption "},
    89  		{"long than 32", "encryption standard encryption standard"},
    90  	}
    91  	for _, tt := range tests {
    92  		t.Run(tt.name, func(t *testing.T) {
    93  			encrypterOpts := ASN1EncrypterOpts
    94  			ciphertext, err := Encrypt(&priv.PublicKey, []byte(tt.plainText), rand.Reader, encrypterOpts)
    95  			if err != nil {
    96  				t.Fatalf("encrypt failed %v", err)
    97  			}
    98  			plaintext, err := priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts)
    99  			if err != nil {
   100  				t.Fatalf("decrypt failed %v", err)
   101  			}
   102  			if !reflect.DeepEqual(string(plaintext), tt.plainText) {
   103  				t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
   104  			}
   105  		})
   106  	}
   107  }
   108  
   109  func Test_Ciphertext2ASN1(t *testing.T) {
   110  	priv, _ := GenerateKey(rand.Reader)
   111  	tests := []struct {
   112  		name      string
   113  		plainText string
   114  	}{
   115  		// TODO: Add test cases.
   116  		{"less than 32", "encryption standard"},
   117  		{"equals 32", "encryption standard encryption "},
   118  		{"long than 32", "encryption standard encryption standard"},
   119  	}
   120  	for _, tt := range tests {
   121  		t.Run(tt.name, func(t *testing.T) {
   122  			ciphertext, err := Encrypt(&priv.PublicKey, []byte(tt.plainText), rand.Reader, nil)
   123  			if err != nil {
   124  				t.Fatalf("encrypt failed %v", err)
   125  			}
   126  			ciphertext, err = PlainCiphertext2ASN1(ciphertext, C1C3C2)
   127  			if err != nil {
   128  				t.Fatalf("convert to ASN.1 failed %v", err)
   129  			}
   130  			plaintext, err := priv.Decrypt(rand.Reader, ciphertext, ASN1DecrypterOpts)
   131  			if err != nil {
   132  				t.Fatalf("decrypt failed %v", err)
   133  			}
   134  			if !reflect.DeepEqual(string(plaintext), tt.plainText) {
   135  				t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
   136  			}
   137  		})
   138  	}
   139  }
   140  
   141  func Test_ASN1Ciphertext2Plain(t *testing.T) {
   142  	priv, _ := GenerateKey(rand.Reader)
   143  	tests := []struct {
   144  		name      string
   145  		plainText string
   146  	}{
   147  		// TODO: Add test cases.
   148  		{"less than 32", "encryption standard"},
   149  		{"equals 32", "encryption standard encryption "},
   150  		{"long than 32", "encryption standard encryption standard"},
   151  	}
   152  	for _, tt := range tests {
   153  		t.Run(tt.name, func(t *testing.T) {
   154  			ciphertext, err := EncryptAsn1(&priv.PublicKey, []byte(tt.plainText), rand.Reader)
   155  			if err != nil {
   156  				t.Fatalf("encrypt failed %v", err)
   157  			}
   158  			ciphertext, err = ASN1Ciphertext2Plain(ciphertext, nil)
   159  			if err != nil {
   160  				t.Fatalf("convert to plain failed %v", err)
   161  			}
   162  			plaintext, err := priv.Decrypt(rand.Reader, ciphertext, nil)
   163  			if err != nil {
   164  				t.Fatalf("decrypt failed %v", err)
   165  			}
   166  			if !reflect.DeepEqual(string(plaintext), tt.plainText) {
   167  				t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
   168  			}
   169  		})
   170  	}
   171  }
   172  
   173  func Test_encryptDecrypt(t *testing.T) {
   174  	priv, _ := GenerateKey(rand.Reader)
   175  	tests := []struct {
   176  		name      string
   177  		plainText string
   178  	}{
   179  		// TODO: Add test cases.
   180  		{"less than 32", "encryption standard"},
   181  		{"equals 32", "encryption standard encryption "},
   182  		{"long than 32", "encryption standard encryption standard"},
   183  	}
   184  	for _, tt := range tests {
   185  		t.Run(tt.name, func(t *testing.T) {
   186  			ciphertext, err := Encrypt(&priv.PublicKey, []byte(tt.plainText), rand.Reader, nil)
   187  			if err != nil {
   188  				t.Fatalf("encrypt failed %v", err)
   189  			}
   190  			plaintext, err := Decrypt(priv, ciphertext, nil)
   191  			if err != nil {
   192  				t.Fatalf("decrypt failed %v", err)
   193  			}
   194  			if !reflect.DeepEqual(string(plaintext), tt.plainText) {
   195  				t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
   196  			}
   197  			// compress mode
   198  			encrypterOpts := NewPlainEncrypterOpts(MarshalCompressed, C1C3C2)
   199  			ciphertext, err = Encrypt(&priv.PublicKey, []byte(tt.plainText), rand.Reader, encrypterOpts)
   200  			if err != nil {
   201  				t.Fatalf("encrypt failed %v", err)
   202  			}
   203  			plaintext, err = Decrypt(priv, ciphertext, nil)
   204  			if err != nil {
   205  				t.Fatalf("decrypt failed %v", err)
   206  			}
   207  			if !reflect.DeepEqual(string(plaintext), tt.plainText) {
   208  				t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
   209  			}
   210  
   211  			// mixed mode
   212  			encrypterOpts = NewPlainEncrypterOpts(MarshalMixed, C1C3C2)
   213  			ciphertext, err = Encrypt(&priv.PublicKey, []byte(tt.plainText), rand.Reader, encrypterOpts)
   214  			if err != nil {
   215  				t.Fatalf("encrypt failed %v", err)
   216  			}
   217  			plaintext, err = Decrypt(priv, ciphertext, nil)
   218  			if err != nil {
   219  				t.Fatalf("decrypt failed %v", err)
   220  			}
   221  			if !reflect.DeepEqual(string(plaintext), tt.plainText) {
   222  				t.Errorf("Decrypt() = %v, want %v", string(plaintext), tt.plainText)
   223  			}
   224  		})
   225  	}
   226  }
   227  
   228  func Test_signVerify(t *testing.T) {
   229  	priv, _ := GenerateKey(rand.Reader)
   230  	tests := []struct {
   231  		name      string
   232  		plainText string
   233  	}{
   234  		// TODO: Add test cases.
   235  		{"less than 32", "encryption standard"},
   236  		{"equals 32", "encryption standard encryption "},
   237  		{"long than 32", "encryption standard encryption standard"},
   238  	}
   239  	for _, tt := range tests {
   240  		t.Run(tt.name, func(t *testing.T) {
   241  			hash := sm3.Sm3Sum([]byte(tt.plainText))
   242  			signature, err := priv.Sign(rand.Reader, hash[:], nil)
   243  			if err != nil {
   244  				t.Fatalf("sign failed %v", err)
   245  			}
   246  			result := VerifyASN1(&priv.PublicKey, hash[:], signature)
   247  			if !result {
   248  				t.Fatal("verify failed")
   249  			}
   250  		})
   251  	}
   252  }
   253  
   254  // Check that signatures are safe even with a broken entropy source.
   255  func TestNonceSafety(t *testing.T) {
   256  	priv, _ := GenerateKey(rand.Reader)
   257  
   258  	hashed := []byte("testing")
   259  	r0, s0, err := SignAfterZA(zeroReader, priv, hashed)
   260  	if err != nil {
   261  		t.Errorf("SM2: error signing: %s", err)
   262  		return
   263  	}
   264  
   265  	hashed = []byte("testing...")
   266  	r1, s1, err := SignAfterZA(zeroReader, priv, hashed)
   267  	if err != nil {
   268  		t.Errorf("SM2: error signing: %s", err)
   269  		return
   270  	}
   271  
   272  	if s0.Cmp(s1) == 0 {
   273  		// This should never happen.
   274  		t.Error("SM2: the signatures on two different messages were the same")
   275  	}
   276  
   277  	if r0.Cmp(r1) == 0 {
   278  		t.Error("SM2: the nonce used for two diferent messages was the same")
   279  	}
   280  }
   281  
   282  // Check that signatures remain non-deterministic with a functional entropy source.
   283  func TestINDCCA(t *testing.T) {
   284  	priv, _ := GenerateKey(rand.Reader)
   285  
   286  	hashed := []byte("testing")
   287  	r0, s0, err := SignAfterZA(rand.Reader, priv, hashed)
   288  	if err != nil {
   289  		t.Errorf("SM2: error signing: %s", err)
   290  		return
   291  	}
   292  
   293  	r1, s1, err := SignAfterZA(rand.Reader, priv, hashed)
   294  	if err != nil {
   295  		t.Errorf("SM2: error signing: %s", err)
   296  		return
   297  	}
   298  
   299  	if s0.Cmp(s1) == 0 {
   300  		t.Error("SM2: two signatures of the same message produced the same result")
   301  	}
   302  
   303  	if r0.Cmp(r1) == 0 {
   304  		t.Error("SM2: two signatures of the same message produced the same nonce")
   305  	}
   306  }
   307  
   308  func TestEqual(t *testing.T) {
   309  	private, _ := GenerateKey(rand.Reader)
   310  	public := &private.PublicKey
   311  
   312  	if !public.Equal(public) {
   313  		t.Errorf("public key is not equal to itself: %q", public)
   314  	}
   315  	if !public.Equal(crypto.Signer(private).Public()) {
   316  		t.Errorf("private.Public() is not Equal to public: %q", public)
   317  	}
   318  	if !private.Equal(private) {
   319  		t.Errorf("private key is not equal to itself: %q", private)
   320  	}
   321  
   322  	otherPriv, _ := GenerateKey(rand.Reader)
   323  	otherPub := &otherPriv.PublicKey
   324  	if public.Equal(otherPub) {
   325  		t.Errorf("different public keys are Equal")
   326  	}
   327  	if private.Equal(otherPriv) {
   328  		t.Errorf("different private keys are Equal")
   329  	}
   330  }
   331  
   332  func BenchmarkGenerateKey_SM2(b *testing.B) {
   333  	b.ReportAllocs()
   334  	b.ResetTimer()
   335  	for i := 0; i < b.N; i++ {
   336  		if _, err := GenerateKey(rand.Reader); err != nil {
   337  			b.Fatal(err)
   338  		}
   339  	}
   340  }
   341  
   342  func BenchmarkGenerateKey_P256(b *testing.B) {
   343  	b.ReportAllocs()
   344  	b.ResetTimer()
   345  	for i := 0; i < b.N; i++ {
   346  		if _, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader); err != nil {
   347  			b.Fatal(err)
   348  		}
   349  	}
   350  }
   351  
   352  func BenchmarkSign_SM2(b *testing.B) {
   353  	priv, err := GenerateKey(rand.Reader)
   354  	if err != nil {
   355  		b.Fatal(err)
   356  	}
   357  	hashed := []byte("testing")
   358  
   359  	b.ReportAllocs()
   360  	b.ResetTimer()
   361  	for i := 0; i < b.N; i++ {
   362  		sig, err := SignASN1(rand.Reader, priv, hashed)
   363  		if err != nil {
   364  			b.Fatal(err)
   365  		}
   366  		// Prevent the compiler from optimizing out the operation.
   367  		hashed[0] = sig[0]
   368  	}
   369  }
   370  
   371  func BenchmarkSign_P256(b *testing.B) {
   372  	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   373  	if err != nil {
   374  		b.Fatal(err)
   375  	}
   376  	hashed := []byte("testing")
   377  
   378  	b.ReportAllocs()
   379  	b.ResetTimer()
   380  	for i := 0; i < b.N; i++ {
   381  		sig, err := ecdsa.SignASN1(rand.Reader, priv, hashed)
   382  		if err != nil {
   383  			b.Fatal(err)
   384  		}
   385  		// Prevent the compiler from optimizing out the operation.
   386  		hashed[0] = sig[0]
   387  	}
   388  }
   389  
   390  func BenchmarkVerify_P256(b *testing.B) {
   391  	priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
   392  	if err != nil {
   393  		b.Fatal(err)
   394  	}
   395  	hashed := []byte("testing")
   396  	r, s, err := ecdsa.Sign(rand.Reader, priv, hashed)
   397  	if err != nil {
   398  		b.Fatal(err)
   399  	}
   400  
   401  	b.ReportAllocs()
   402  	b.ResetTimer()
   403  	for i := 0; i < b.N; i++ {
   404  		if !ecdsa.Verify(&priv.PublicKey, hashed, r, s) {
   405  			b.Fatal("verify failed")
   406  		}
   407  	}
   408  }
   409  
   410  func BenchmarkVerify_SM2(b *testing.B) {
   411  	priv, err := GenerateKey(rand.Reader)
   412  	if err != nil {
   413  		b.Fatal(err)
   414  	}
   415  	hashed := []byte("testing")
   416  	r, s, err := SignAfterZA(rand.Reader, priv, hashed)
   417  	if err != nil {
   418  		b.Fatal(err)
   419  	}
   420  
   421  	b.ReportAllocs()
   422  	b.ResetTimer()
   423  	for i := 0; i < b.N; i++ {
   424  		if !verifyGeneric(&priv.PublicKey, hashed, r, s) {
   425  			b.Fatal("verify failed")
   426  		}
   427  	}
   428  }
   429  
   430  // func benchmarkEncrypt(b *testing.B, curve elliptic.Curve, plaintext string) {
   431  // 	priv, err := ecdsa.GenerateKey(curve, rand.Reader)
   432  // 	if err != nil {
   433  // 		b.Fatal(err)
   434  // 	}
   435  // 	b.ReportAllocs()
   436  // 	b.ResetTimer()
   437  // 	for i := 0; i < b.N; i++ {
   438  // 		Encrypt(&priv.PublicKey, []byte(plaintext), rand.Reader, nil)
   439  // 	}
   440  // }
   441  
   442  // func BenchmarkLessThan32_P256(b *testing.B) {
   443  // 	benchmarkEncrypt(b, elliptic.P256(), "encryption standard")
   444  // }
   445  
   446  // func BenchmarkLessThan32_SM2(b *testing.B) {
   447  // 	benchmarkEncrypt(b, P256(), "encryption standard")
   448  // }
   449  
   450  // func BenchmarkMoreThan32_P256(b *testing.B) {
   451  // 	benchmarkEncrypt(b, elliptic.P256(), "encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard")
   452  // }
   453  
   454  // func BenchmarkMoreThan32_SM2(b *testing.B) {
   455  // 	benchmarkEncrypt(b, P256(), "encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard encryption standard")
   456  // }