github.com/emmansun/gmsm@v0.29.1/ecdh/ecdh_test.go (about)

     1  package ecdh_test
     2  
     3  import (
     4  	"bytes"
     5  	"crypto"
     6  	"crypto/cipher"
     7  	"crypto/rand"
     8  	"encoding/hex"
     9  	"fmt"
    10  	"io"
    11  	"testing"
    12  
    13  	"github.com/emmansun/gmsm/ecdh"
    14  	"golang.org/x/crypto/chacha20"
    15  )
    16  
    17  // Check that PublicKey and PrivateKey implement the interfaces documented in
    18  // crypto.PublicKey and crypto.PrivateKey.
    19  var _ interface {
    20  	Equal(x crypto.PublicKey) bool
    21  } = &ecdh.PublicKey{}
    22  var _ interface {
    23  	Public() crypto.PublicKey
    24  	Equal(x crypto.PrivateKey) bool
    25  } = &ecdh.PrivateKey{}
    26  
    27  func hexDecode(t *testing.T, s string) []byte {
    28  	b, err := hex.DecodeString(s)
    29  	if err != nil {
    30  		t.Fatal("invalid hex string:", s)
    31  	}
    32  	return b
    33  }
    34  
    35  func TestNewPrivateKeyWithOrderMinus1(t *testing.T) {
    36  	_, err := ecdh.P256().NewPrivateKey([]byte{
    37  		0xff, 0xff, 0xff, 0xfe, 0xff, 0xff, 0xff, 0xff,
    38  		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
    39  		0x72, 0x03, 0xdf, 0x6b, 0x21, 0xc6, 0x05, 0x2b,
    40  		0x53, 0xbb, 0xf4, 0x09, 0x39, 0xd5, 0x41, 0x22})
    41  	if err == nil || err.Error() != "ecdh: invalid private key" {
    42  		t.Errorf("expected invalid private key")
    43  	} 
    44  }
    45  
    46  func TestECDH(t *testing.T) {
    47  	aliceKey, err := ecdh.P256().GenerateKey(rand.Reader)
    48  	if err != nil {
    49  		t.Fatal(err)
    50  	}
    51  	bobKey, err := ecdh.P256().GenerateKey(rand.Reader)
    52  	if err != nil {
    53  		t.Fatal(err)
    54  	}
    55  
    56  	alicePubKey, err := ecdh.P256().NewPublicKey(aliceKey.PublicKey().Bytes())
    57  	if err != nil {
    58  		t.Error(err)
    59  	}
    60  	if !bytes.Equal(aliceKey.PublicKey().Bytes(), alicePubKey.Bytes()) {
    61  		t.Error("encoded and decoded public keys are different")
    62  	}
    63  	if !aliceKey.PublicKey().Equal(alicePubKey) {
    64  		t.Error("encoded and decoded public keys are different")
    65  	}
    66  
    67  	alicePrivKey, err := ecdh.P256().NewPrivateKey(aliceKey.Bytes())
    68  	if err != nil {
    69  		t.Error(err)
    70  	}
    71  	if !bytes.Equal(aliceKey.Bytes(), alicePrivKey.Bytes()) {
    72  		t.Error("encoded and decoded private keys are different")
    73  	}
    74  	if !aliceKey.Equal(alicePrivKey) {
    75  		t.Error("encoded and decoded private keys are different")
    76  	}
    77  
    78  	bobSecret, err := bobKey.ECDH(aliceKey.PublicKey())
    79  	if err != nil {
    80  		t.Fatal(err)
    81  	}
    82  	aliceSecret, err := aliceKey.ECDH(bobKey.PublicKey())
    83  	if err != nil {
    84  		t.Fatal(err)
    85  	}
    86  
    87  	if !bytes.Equal(bobSecret, aliceSecret) {
    88  		t.Error("two ECDH computations came out different")
    89  	}
    90  }
    91  
    92  func TestSM2MQV(t *testing.T) {
    93  	aliceSKey, err := ecdh.P256().GenerateKey(rand.Reader)
    94  	if err != nil {
    95  		t.Fatal(err)
    96  	}
    97  	aliceEKey, err := ecdh.P256().GenerateKey(rand.Reader)
    98  	if err != nil {
    99  		t.Fatal(err)
   100  	}
   101  
   102  	bobSKey, err := ecdh.P256().GenerateKey(rand.Reader)
   103  	if err != nil {
   104  		t.Fatal(err)
   105  	}
   106  	bobEKey, err := ecdh.P256().GenerateKey(rand.Reader)
   107  	if err != nil {
   108  		t.Fatal(err)
   109  	}
   110  
   111  	bobSecret, err := bobSKey.SM2MQV(bobEKey, aliceSKey.PublicKey(), aliceEKey.PublicKey())
   112  	if err != nil {
   113  		t.Fatal(err)
   114  	}
   115  
   116  	aliceSecret, err := aliceSKey.SM2MQV(aliceEKey, bobSKey.PublicKey(), bobEKey.PublicKey())
   117  	if err != nil {
   118  		t.Fatal(err)
   119  	}
   120  
   121  	if !aliceSecret.Equal(bobSecret) {
   122  		t.Error("two SM2MQV computations came out different")
   123  	}
   124  }
   125  
   126  func TestSM2SharedKey(t *testing.T) {
   127  	aliceSKey, err := ecdh.P256().GenerateKey(rand.Reader)
   128  	if err != nil {
   129  		t.Fatal(err)
   130  	}
   131  	aliceEKey, err := ecdh.P256().GenerateKey(rand.Reader)
   132  	if err != nil {
   133  		t.Fatal(err)
   134  	}
   135  
   136  	bobSKey, err := ecdh.P256().GenerateKey(rand.Reader)
   137  	if err != nil {
   138  		t.Fatal(err)
   139  	}
   140  	bobEKey, err := ecdh.P256().GenerateKey(rand.Reader)
   141  	if err != nil {
   142  		t.Fatal(err)
   143  	}
   144  
   145  	bobSecret, err := bobSKey.SM2MQV(bobEKey, aliceSKey.PublicKey(), aliceEKey.PublicKey())
   146  	if err != nil {
   147  		t.Fatal(err)
   148  	}
   149  
   150  	aliceSecret, err := aliceSKey.SM2MQV(aliceEKey, bobSKey.PublicKey(), bobEKey.PublicKey())
   151  	if err != nil {
   152  		t.Fatal(err)
   153  	}
   154  
   155  	if !aliceSecret.Equal(bobSecret) {
   156  		t.Error("two SM2MQV computations came out different")
   157  	}
   158  
   159  	bobKey, err := bobSecret.SM2SharedKey(true, 48, bobSKey.PublicKey(), aliceSKey.PublicKey(), []byte("Bob"), []byte("Alice"))
   160  	if err != nil {
   161  		t.Fatal(err)
   162  	}
   163  
   164  	aliceKey, err := aliceSecret.SM2SharedKey(false, 48, aliceSKey.PublicKey(), bobSKey.PublicKey(), []byte("Alice"), []byte("Bob"))
   165  	if err != nil {
   166  		t.Fatal(err)
   167  	}
   168  
   169  	if !bytes.Equal(bobKey, aliceKey) {
   170  		t.Error("two SM2SharedKey computations came out different")
   171  	}
   172  }
   173  
   174  var vectors = []struct {
   175  	LocalStaticPriv, LocalEphemeralPriv   string
   176  	RemoteStaticPriv, RemoteEphemeralPriv string
   177  	SharedSecret, Key                     string
   178  }{
   179  	{
   180  		"e04c3fd77408b56a648ad439f673511a2ae248def3bab26bdfc9cdbd0ae9607e",
   181  		"6fe0bac5b09d3ab10f724638811c34464790520e4604e71e6cb0e5310623b5b1",
   182  		"7a1136f60d2c5531447e5a3093078c2a505abf74f33aefed927ac0a5b27e7dd7",
   183  		"d0233bdbb0b8a7bfe1aab66132ef06fc4efaedd5d5000692bc21185242a31f6f",
   184  		"046ab5c9709277837cedc515730d04751ef81c71e81e0e52357a98cf41796ab560508da6e858b40c6264f17943037434174284a847f32c4f54104a98af5148d89f",
   185  		"1ad809ebc56ddda532020c352e1e60b121ebeb7b4e632db4dd90a362cf844f8bba85140e30984ddb581199bf5a9dda22",
   186  	},
   187  	{
   188  		"cb5ac204b38d0e5c9fc38a467075986754018f7dbb7cbbc5b4c78d56a88a8ad8",
   189  		"1681a66c02b67fdadfc53cba9b417b9499d0159435c86bb8760c3a03ae157539",
   190  		"4f54b10e0d8e9e2fe5cc79893e37fd0fd990762d1372197ed92dde464b2773ef",
   191  		"a2fe43dea141e9acc88226eaba8908ad17e81376c92102cb8186e8fef61a8700",
   192  		"04677d055355a1dcc9de4df00d3a80b6daa76bdf54ff7e0a3a6359fcd0c6f1e4b4697fffc41bbbcc3a28ea3aa1c6c380d1e92f142233afa4b430d02ab4cebc43b2",
   193  		"7a103ae61a30ed9df573a5febb35a9609cbed5681bcb98a8545351bf7d6824cc4635df5203712ea506e2e3c4ec9b12e7",
   194  	},
   195  	{
   196  		"ee690a34a779ab48227a2f68b062a80f92e26d82835608dd01b7452f1e4fb296",
   197  		"2046c6cee085665e9f3abeba41fd38e17a26c08f2f5e8f0e1007afc0bf6a2a5d",
   198  		"8ef49ea427b13cc31151e1c96ae8a48cb7919063f2d342560fb7eaaffb93d8fe",
   199  		"9baf8d602e43fbae83fedb7368f98c969d378b8a647318f8cafb265296ae37de",
   200  		"04f7e9f1447968b284ff43548fcec3752063ea386b48bfabb9baf2f9c1caa05c2fb12c2cca37326ce27e68f8cc6414c2554895519c28da1ca21e61890d0bc525c4",
   201  		"b18e78e5072f301399dc1f4baf2956c0ed2d5f52f19abb1705131b0865b079031259ee6c629b4faed528bcfa1c5d2cbc",
   202  	},
   203  }
   204  
   205  func TestSM2SharedKeyVectors(t *testing.T) {
   206  	initiator := []byte("Alice")
   207  	responder := []byte("Bob")
   208  	kenLen := 48
   209  
   210  	for i, v := range vectors {
   211  		aliceSKey, err := ecdh.P256().NewPrivateKey(hexDecode(t, v.LocalStaticPriv))
   212  		if err != nil {
   213  			t.Fatal(err)
   214  		}
   215  		aliceEKey, err := ecdh.P256().NewPrivateKey(hexDecode(t, v.LocalEphemeralPriv))
   216  		if err != nil {
   217  			t.Fatal(err)
   218  		}
   219  		bobSKey, err := ecdh.P256().NewPrivateKey(hexDecode(t, v.RemoteStaticPriv))
   220  		if err != nil {
   221  			t.Fatal(err)
   222  		}
   223  		bobEKey, err := ecdh.P256().NewPrivateKey(hexDecode(t, v.RemoteEphemeralPriv))
   224  		if err != nil {
   225  			t.Fatal(err)
   226  		}
   227  
   228  		bobSecret, err := bobSKey.SM2MQV(bobEKey, aliceSKey.PublicKey(), aliceEKey.PublicKey())
   229  		if err != nil {
   230  			t.Fatal(err)
   231  		}
   232  
   233  		aliceSecret, err := aliceSKey.SM2MQV(aliceEKey, bobSKey.PublicKey(), bobEKey.PublicKey())
   234  		if err != nil {
   235  			t.Fatal(err)
   236  		}
   237  
   238  		if !aliceSecret.Equal(bobSecret) {
   239  			t.Error("two SM2MQV computations came out different")
   240  		}
   241  
   242  		if !bytes.Equal(aliceSecret.Bytes(), hexDecode(t, v.SharedSecret)) {
   243  			t.Errorf("%v shared secret is not expected.", i)
   244  		}
   245  
   246  		bobKey, err := bobSecret.SM2SharedKey(true, kenLen, bobSKey.PublicKey(), aliceSKey.PublicKey(), responder, initiator)
   247  		if err != nil {
   248  			t.Fatal(err)
   249  		}
   250  
   251  		aliceKey, err := aliceSecret.SM2SharedKey(false, kenLen, aliceSKey.PublicKey(), bobSKey.PublicKey(), initiator, responder)
   252  		if err != nil {
   253  			t.Fatal(err)
   254  		}
   255  
   256  		if !bytes.Equal(bobKey, aliceKey) {
   257  			t.Error("two SM2SharedKey computations came out different")
   258  		}
   259  
   260  		if !bytes.Equal(bobKey, hexDecode(t, v.Key)) {
   261  			t.Errorf("%v keying data is not expected.", i)
   262  		}
   263  	}
   264  }
   265  
   266  type countingReader struct {
   267  	r io.Reader
   268  	n int
   269  }
   270  
   271  func (r *countingReader) Read(p []byte) (int, error) {
   272  	n, err := r.r.Read(p)
   273  	r.n += n
   274  	return n, err
   275  }
   276  
   277  func TestGenerateKey(t *testing.T) {
   278  	r := &countingReader{r: rand.Reader}
   279  	k, err := ecdh.P256().GenerateKey(r)
   280  	if err != nil {
   281  		t.Fatal(err)
   282  	}
   283  
   284  	// GenerateKey does rejection sampling. If the masking works correctly,
   285  	// the probability of a rejection is 1-ord(G)/2^ceil(log2(ord(G))),
   286  	// which for all curves is small enough (at most 2^-32, for P-256) that
   287  	// a bit flip is more likely to make this test fail than bad luck.
   288  	// Account for the extra MaybeReadByte byte, too.
   289  	if got, expected := r.n, len(k.Bytes())+1; got > expected {
   290  		t.Errorf("expected GenerateKey to consume at most %v bytes, got %v", expected, got)
   291  	}
   292  }
   293  
   294  func TestString(t *testing.T) {
   295  	s := fmt.Sprintf("%s", ecdh.P256())
   296  	if s != "sm2p256v1" {
   297  		t.Errorf("unexpected Curve string encoding: %q", s)
   298  	}
   299  }
   300  
   301  func BenchmarkECDH(b *testing.B) {
   302  	benchmarkAllCurves(b, func(b *testing.B, curve ecdh.Curve) {
   303  		c, err := chacha20.NewUnauthenticatedCipher(make([]byte, 32), make([]byte, 12))
   304  		if err != nil {
   305  			b.Fatal(err)
   306  		}
   307  		rand := cipher.StreamReader{
   308  			S: c, R: zeroReader,
   309  		}
   310  
   311  		peerKey, err := curve.GenerateKey(rand)
   312  		if err != nil {
   313  			b.Fatal(err)
   314  		}
   315  		peerShare := peerKey.PublicKey().Bytes()
   316  		b.ResetTimer()
   317  		b.ReportAllocs()
   318  
   319  		var allocationsSink byte
   320  
   321  		for i := 0; i < b.N; i++ {
   322  			key, err := curve.GenerateKey(rand)
   323  			if err != nil {
   324  				b.Fatal(err)
   325  			}
   326  			share := key.PublicKey().Bytes()
   327  			peerPubKey, err := curve.NewPublicKey(peerShare)
   328  			if err != nil {
   329  				b.Fatal(err)
   330  			}
   331  			secret, err := key.ECDH(peerPubKey)
   332  			if err != nil {
   333  				b.Fatal(err)
   334  			}
   335  			allocationsSink ^= secret[0] ^ share[0]
   336  		}
   337  	})
   338  }
   339  
   340  func benchmarkAllCurves(b *testing.B, f func(b *testing.B, curve ecdh.Curve)) {
   341  	b.Run("SM2P256", func(b *testing.B) { f(b, ecdh.P256()) })
   342  }
   343  
   344  type zr struct{}
   345  
   346  // Read replaces the contents of dst with zeros. It is safe for concurrent use.
   347  func (zr) Read(dst []byte) (n int, err error) {
   348  	for i := range dst {
   349  		dst[i] = 0
   350  	}
   351  	return len(dst), nil
   352  }
   353  
   354  var zeroReader = zr{}