github.com/cloudflare/circl@v1.5.0/dh/x25519/curve_test.go (about)

     1  package x25519
     2  
     3  import (
     4  	"crypto/rand"
     5  	"math/big"
     6  	"testing"
     7  
     8  	"github.com/cloudflare/circl/internal/conv"
     9  	"github.com/cloudflare/circl/internal/test"
    10  	fp "github.com/cloudflare/circl/math/fp25519"
    11  )
    12  
    13  func getModulus() *big.Int {
    14  	p := big.NewInt(1)
    15  	p.Lsh(p, 255).Sub(p, big.NewInt(19))
    16  	return p
    17  }
    18  
    19  func doubleBig(x1, z1, p *big.Int) {
    20  	// Montgomery point doubling in projective (X:Z) coordinates.
    21  	A24 := big.NewInt(121666)
    22  	A, B, C := big.NewInt(0), big.NewInt(0), big.NewInt(0)
    23  	A.Add(x1, z1).Mod(A, p)
    24  	B.Sub(x1, z1).Mod(B, p)
    25  	A.Mul(A, A)
    26  	B.Mul(B, B)
    27  	C.Sub(A, B)
    28  	x1.Mul(A, B).Mod(x1, p)
    29  	z1.Mul(C, A24).Add(z1, B).Mul(z1, C).Mod(z1, p)
    30  }
    31  
    32  func diffAddBig(work [5]*big.Int, p *big.Int, b uint) {
    33  	// Equation 7 at https://eprint.iacr.org/2017/264
    34  	mu, x1, z1, x2, z2 := work[0], work[1], work[2], work[3], work[4]
    35  	A, B := big.NewInt(0), big.NewInt(0)
    36  	if b != 0 {
    37  		t := new(big.Int)
    38  		t.Set(x1)
    39  		x1.Set(x2)
    40  		x2.Set(t)
    41  		t.Set(z1)
    42  		z1.Set(z2)
    43  		z2.Set(t)
    44  	}
    45  	A.Add(x1, z1)
    46  	B.Sub(x1, z1)
    47  	B.Mul(B, mu).Mod(B, p)
    48  	x1.Add(A, B).Mod(x1, p)
    49  	z1.Sub(A, B).Mod(z1, p)
    50  	x1.Mul(x1, x1).Mul(x1, z2).Mod(x1, p)
    51  	z1.Mul(z1, z1).Mul(z1, x2).Mod(z1, p)
    52  	mu.Mod(mu, p)
    53  	x2.Mod(x2, p)
    54  	z2.Mod(z2, p)
    55  }
    56  
    57  func ladderStepBig(work [5]*big.Int, p *big.Int, b uint) {
    58  	A24 := big.NewInt(121666)
    59  	x1 := work[0]
    60  	x2, z2 := work[1], work[2]
    61  	x3, z3 := work[3], work[4]
    62  	A, B, C, D := big.NewInt(0), big.NewInt(0), big.NewInt(0), big.NewInt(0)
    63  	DA, CB, E := big.NewInt(0), big.NewInt(0), big.NewInt(0)
    64  	A.Add(x2, z2).Mod(A, p)
    65  	B.Sub(x2, z2).Mod(B, p)
    66  	C.Add(x3, z3).Mod(C, p)
    67  	D.Sub(x3, z3).Mod(D, p)
    68  	DA.Mul(D, A).Mod(DA, p)
    69  	CB.Mul(C, B).Mod(CB, p)
    70  	if b != 0 {
    71  		t := new(big.Int)
    72  		t.Set(A)
    73  		A.Set(C)
    74  		C.Set(t)
    75  		t.Set(B)
    76  		B.Set(D)
    77  		D.Set(t)
    78  	}
    79  	AA := A.Mul(A, A).Mod(A, p)
    80  	BB := B.Mul(B, B).Mod(B, p)
    81  	E.Sub(AA, BB)
    82  	x1.Mod(x1, p)
    83  	x2.Mul(AA, BB).Mod(x2, p)
    84  	z2.Mul(E, A24).Add(z2, BB).Mul(z2, E).Mod(z2, p)
    85  	x3.Add(DA, CB)
    86  	z3.Sub(DA, CB)
    87  	x3.Mul(x3, x3).Mod(x3, p)
    88  	z3.Mul(z3, z3).Mul(z3, x1).Mod(z3, p)
    89  }
    90  
    91  func testMulA24(t *testing.T, f func(z, x *fp.Elt)) {
    92  	const numTests = 1 << 9
    93  	var x, z fp.Elt
    94  	p := getModulus()
    95  	A24 := big.NewInt(121666)
    96  	for i := 0; i < numTests; i++ {
    97  		_, _ = rand.Read(x[:])
    98  		bigX := conv.BytesLe2BigInt(x[:])
    99  		f(&z, &x)
   100  		got := conv.BytesLe2BigInt(z[:])
   101  		got.Mod(got, p)
   102  
   103  		want := bigX.Mul(bigX, A24).Mod(bigX, p)
   104  
   105  		if got.Cmp(want) != 0 {
   106  			test.ReportError(t, got, want, x)
   107  		}
   108  	}
   109  }
   110  
   111  func testDouble(t *testing.T, f func(x, z *fp.Elt)) {
   112  	const numTests = 1 << 9
   113  	var x, z fp.Elt
   114  	p := getModulus()
   115  	for i := 0; i < numTests; i++ {
   116  		_, _ = rand.Read(x[:])
   117  		_, _ = rand.Read(z[:])
   118  
   119  		bigX := conv.BytesLe2BigInt(x[:])
   120  		bigZ := conv.BytesLe2BigInt(z[:])
   121  		f(&x, &z)
   122  		got0 := conv.BytesLe2BigInt(x[:])
   123  		got1 := conv.BytesLe2BigInt(z[:])
   124  		got0.Mod(got0, p)
   125  		got1.Mod(got1, p)
   126  
   127  		doubleBig(bigX, bigZ, p)
   128  		want0 := bigX
   129  		want1 := bigZ
   130  
   131  		if got0.Cmp(want0) != 0 {
   132  			test.ReportError(t, got0, want0, x, z)
   133  		}
   134  		if got1.Cmp(want1) != 0 {
   135  			test.ReportError(t, got1, want1, x, z)
   136  		}
   137  	}
   138  }
   139  
   140  func testDiffAdd(t *testing.T, f func(w *[5]fp.Elt, b uint)) {
   141  	const numTests = 1 << 9
   142  	p := getModulus()
   143  	var w [5]fp.Elt
   144  	bigWork := [5]*big.Int{}
   145  	for i := 0; i < numTests; i++ {
   146  		for j := range w {
   147  			_, _ = rand.Read(w[j][:])
   148  			bigWork[j] = conv.BytesLe2BigInt(w[j][:])
   149  		}
   150  		b := uint(w[0][0] & 1)
   151  
   152  		f(&w, b)
   153  
   154  		diffAddBig(bigWork, p, b)
   155  
   156  		for j := range w {
   157  			got := conv.BytesLe2BigInt(w[j][:])
   158  			got.Mod(got, p)
   159  			want := bigWork[j]
   160  			if got.Cmp(want) != 0 {
   161  				test.ReportError(t, got, want, w, b)
   162  			}
   163  		}
   164  	}
   165  }
   166  
   167  func testLadderStep(t *testing.T, f func(w *[5]fp.Elt, b uint)) {
   168  	const numTests = 1 << 9
   169  	var w [5]fp.Elt
   170  	bigWork := [5]*big.Int{}
   171  	p := getModulus()
   172  	for i := 0; i < numTests; i++ {
   173  		for j := range w {
   174  			_, _ = rand.Read(w[j][:])
   175  			bigWork[j] = conv.BytesLe2BigInt(w[j][:])
   176  		}
   177  		b := uint(w[0][0] & 1)
   178  
   179  		f(&w, b)
   180  
   181  		ladderStepBig(bigWork, p, b)
   182  
   183  		for j := range bigWork {
   184  			got := conv.BytesLe2BigInt(w[j][:])
   185  			got.Mod(got, p)
   186  			want := bigWork[j]
   187  			if got.Cmp(want) != 0 {
   188  				test.ReportError(t, got, want, w, b)
   189  			}
   190  		}
   191  	}
   192  }
   193  
   194  func TestGeneric(t *testing.T) {
   195  	t.Run("Double", func(t *testing.T) { testDouble(t, doubleGeneric) })
   196  	t.Run("DiffAdd", func(t *testing.T) { testDiffAdd(t, diffAddGeneric) })
   197  	t.Run("LadderStep", func(t *testing.T) { testLadderStep(t, ladderStepGeneric) })
   198  	t.Run("MulA24", func(t *testing.T) { testMulA24(t, mulA24Generic) })
   199  }
   200  
   201  func TestNative(t *testing.T) {
   202  	t.Run("Double", func(t *testing.T) { testDouble(t, double) })
   203  	t.Run("DiffAdd", func(t *testing.T) { testDiffAdd(t, diffAdd) })
   204  	t.Run("LadderStep", func(t *testing.T) { testLadderStep(t, ladderStep) })
   205  	t.Run("MulA24", func(t *testing.T) { testMulA24(t, mulA24) })
   206  }